In [None]:
!gdown 1JJjMiNieTz7xYs6UeVqd02M3DW4fnEfU
!unzip cpvr2016_flowers.zip

In [None]:
import os

def load_captions(captions_folder, image_folder):
    """Map each image in ``image_folder`` to the first line of its caption file.

    For every entry ``<stem><ext>`` in ``image_folder`` the file
    ``<captions_folder>/<stem>.txt`` is read, and its first line (with surrounding
    whitespace stripped) becomes that image's caption.

    Args:
        captions_folder: Directory holding one ``<stem>.txt`` caption file per image.
        image_folder: Directory holding the images.

    Returns:
        dict[str, str]: image stem -> caption. Entries are in sorted filename order,
        so anything indexed by this dict's order is reproducible across platforms.

    Raises:
        FileNotFoundError: if an image has no matching caption file.
    """
    captions = {}
    # os.listdir order is arbitrary; sorting makes the dict order deterministic.
    for image_file in sorted(os.listdir(image_folder)):
        # splitext only strips the final extension, so dotted stems survive intact.
        image_name = os.path.splitext(image_file)[0]
        if image_name in captions:
            # Keep the first file seen for a stem (e.g. x.jpg before x.png).
            continue
        caption_file = os.path.join(captions_folder, image_name + ".txt")
        with open(caption_file, "r", encoding="utf-8") as f:
            # readline() already returns the whole first line, i.e. one caption;
            # indexing it would keep only a single character.
            captions[image_name] = f.readline().strip()
    return captions

captions_folder = "./cpvr2016_flowers/captions"
image_folder = "./cpvr2016_flowers/images"

captions = load_captions(captions_folder, image_folder)
# Preview a handful of entries rather than rendering every caption in the dataset.
dict(list(captions.items())[:5])

In [None]:
import torch
import numpy as np
from sentence_transformers import SentenceTransformer

# Run the sentence encoder on the GPU when one is available.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# The model id must match the hub name exactly ("mpnet", not "mpet"); passing
# device= lets SentenceTransformer place the model on the chosen device itself.
bert_model = SentenceTransformer("all-mpnet-base-v2", device=str(device))

def encode_caption(captions, model=None, batch_size=32):
    """Embed every caption with a sentence-transformer model in batched calls.

    Args:
        captions: dict mapping image name -> caption string (as built by
            ``load_captions``).
        model: object exposing ``encode(list_of_str, batch_size=...)`` that returns
            one embedding row per input sentence, in input order. Defaults to the
            notebook-level ``bert_model``.
        batch_size: number of captions encoded per forward pass.

    Returns:
        dict mapping image name -> {"embed": torch.Tensor, "text": caption}, with
        keys in the same order as ``captions``.
    """
    if model is None:
        model = bert_model
    image_names = list(captions)
    if not image_names:
        return {}
    texts = [captions[name] for name in image_names]
    # One batched encode call instead of one forward pass per caption.
    embeddings = model.encode(texts, batch_size=batch_size)
    return {
        name: {
            # torch.tensor copies each row into its own storage, so dataset items do
            # not all share (and drag along) the full embedding matrix.
            "embed": torch.tensor(vector),
            "text": text,
        }
        for name, text, vector in zip(image_names, texts, embeddings)
    }

# Precompute one sentence embedding per caption so the dataset can serve them directly.
encoded_captions = encode_caption(captions=captions)

In [None]:
from PIL import Image
from torch.utils.data import Dataset

class FlowerDataset(Dataset):
    """Pairs each flower image with its precomputed caption embedding and text.

    Args:
        img_dir: Directory containing the images.
        captions: dict mapping image name (file stem) -> {"embed": tensor, "text": str},
            e.g. the output of ``encode_caption``. Its key order defines the index order.
        transform: Optional callable applied to the RGB ``PIL.Image``. Without one the
            item's "image" is a PIL image, which the default DataLoader collate
            function cannot batch.
        img_ext: Extension appended to each image name to build its path.
    """

    def __init__(self, img_dir, captions, transform=None, img_ext=".jpg"):
        self.img_dir = img_dir
        self.captions = captions
        self.transform = transform
        self.img_ext = img_ext
        # Freeze the index -> image-name mapping once, in the captions dict's order.
        self.img_names = list(self.captions.keys())

    def __len__(self):
        return len(self.img_names)

    def __getitem__(self, idx):
        """Return {"image", "embed_caption", "text"} for the ``idx``-th image."""
        img_name = self.img_names[idx]
        img_path = os.path.join(self.img_dir, img_name + self.img_ext)
        # Close the file handle deterministically; convert() returns a new image,
        # so it stays valid after the source is closed.
        with Image.open(img_path) as img:
            image = img.convert("RGB")
        if self.transform is not None:
            image = self.transform(image)
        entry = self.captions[img_name]
        return {
            "image": image,
            "embed_caption": entry["embed"],
            "text": entry["text"],
        }