# Imports

In [None]:
import os
from os.path import join
import time
import h5py
from sentence_transformers import SentenceTransformer
from PIL import Image
import torch
from torchvision import transforms

# Downloading Data

In [None]:
# Oxford 102 Flowers dataset download link: http://www.robots.ox.ac.uk/~vgg/data/flowers/102/
# extract the archive and upload its "jpg" image folder into the project folder on Drive (datadir below)

# caption dataset download link: https://drive.google.com/uc?export=download&confirm=l7Ld&id=0B0ywwgffWnLLcms2WWJQRFNSWXM
# extract the archive and upload its "text_c10" caption folder into the project folder on Drive (datadir below)

# Google Drive Setup

In [None]:
# Attach Google Drive at /content/drive; force_remount=True asks Colab to
# mount again even if Drive is already attached in this runtime.
import google.colab.drive as drive

drive.mount("/content/drive", force_remount=True)

In [None]:
# Project folder on the mounted Drive; it holds the uploaded "jpg" and
# "text_c10" folders. Make it the working directory and echo the result.
datadir = join("/content/drive/My Drive", "CS444", "Final_Project")
os.chdir(datadir)
print(os.getcwd())

# Create Caption Embeddings

In [None]:
# caption saving code based on: https://github.com/paarthneekhara/text-to-image/blob/master/data_loader.py

def _load_flower_captions(datadir, captions_per_image):
    """Return ``{image filename: [caption, ...]}`` for every ``.jpg`` in ``<datadir>/jpg``.

    Caption text is read from ``<datadir>/text_c10/class_00001`` ...
    ``class_00102``; each ``image_XXXXX.txt`` contributes its first
    ``captions_per_image`` non-empty lines to ``image_XXXXX.jpg``.

    Raises:
        KeyError: if a caption file names an image that is not in ``jpg``.
    """
    img_dir = join(datadir, 'jpg')
    image_files = sorted(f for f in os.listdir(img_dir) if f.endswith('.jpg'))
    # num_samples
    print(len(image_files))

    # key = image file, value = list of its captions
    image_captions = {img_file: [] for img_file in image_files}

    caption_dir = join(datadir, 'text_c10')
    # 102 class subdirectories (class_00001, ..., class_00102)
    for class_idx in range(1, 103):
        class_dir = join(caption_dir, f'class_{class_idx:05d}')
        for cap_file in sorted(os.listdir(class_dir)):
            if not cap_file.endswith('.txt'):
                continue
            with open(join(class_dir, cap_file), encoding='utf-8') as f:
                captions = [line for line in f.read().split('\n') if line]
            # caption file "image_XXXXX.txt" describes "image_XXXXX.jpg"
            img_file = os.path.splitext(cap_file)[0] + '.jpg'
            image_captions[img_file].extend(captions[:captions_per_image])

    # len(image_captions) always equals the number of jpgs, so count the
    # images that actually received at least one caption instead
    n_captioned = sum(1 for caps in image_captions.values() if caps)
    print("images with captions:", n_captioned)
    return image_captions


def save_caption_vectors_flowers(
    datadir,
    text_encoder=None,
    output_name='finetuned_encoded_captions.hdf5',
    captions_per_image=5,
):
    """Encode the flower captions with a SentenceTransformer and save them to HDF5.

    Caption saving code based on:
    https://github.com/paarthneekhara/text-to-image/blob/master/data_loader.py

    Args:
        datadir: project folder containing the ``jpg`` and ``text_c10`` folders.
        text_encoder: SentenceTransformer model name or path. Defaults to the
            best encoder from evaluation,
            ``<datadir>/text_encoders/finetuned10_multi-qa-mpnet-base-dot-v1``.
            Alternatives used previously: ``'all-MiniLM-L6-v2'`` (saved as
            ``basic_encoded_captions.hdf5``) and ``'multi-qa-mpnet-base-dot-v1'``
            (saved as ``advanced_encoded_captions.hdf5``).
        output_name: HDF5 file created (overwritten) inside ``datadir``.
        captions_per_image: maximum number of captions kept per image.

    Writes one HDF5 dataset per image, named by the image filename, holding
    that image's encoded captions.

    Raises:
        KeyError: if a caption file names an image that is not in ``jpg``.
    """
    image_captions = _load_flower_captions(datadir, captions_per_image)

    if text_encoder is None:
        text_encoder = join(datadir, 'text_encoders',
                            'finetuned10_multi-qa-mpnet-base-dot-v1')
    model = SentenceTransformer(text_encoder)

    # Encode everything before opening the output, so a failure during
    # encoding does not truncate an existing HDF5 file.
    # model.encode takes a list of strings and returns (n_captions, embed_dim).
    encoded_captions = {
        img_file: model.encode(captions)
        for img_file, captions in image_captions.items()
    }

    # the context manager closes the file even if a write fails
    with h5py.File(join(datadir, output_name), 'w') as h:
        for img_file, vectors in encoded_captions.items():
            h.create_dataset(img_file, data=vectors)

save_caption_vectors_flowers(datadir)

8189
images with captions: 8189


# Cache Resized Images on Disk as Tensors

In [None]:
# Resize every flower jpg once and save all of them, keyed by filename, as a
# single torch file in the project folder.
image_dir = join(datadir, 'jpg')
# join() inserts the separator; plain concatenation (datadir + 'image_cache.pt')
# wrote ".../Final_Projectimage_cache.pt" outside the project folder.
cache_path = join(datadir, 'image_cache.pt')
# target (width, height) passed to PIL's resize
image_size = (64, 64)

to_tensor = transforms.ToTensor()

# load all images into a dict of tensors (sorted for a deterministic order)
tensor_cache = {}
for fname in sorted(os.listdir(image_dir)):
    if not fname.lower().endswith('.jpg'):
        continue
    # the with-block closes the file handle; convert() returns a new, fully
    # loaded image that remains usable after the source is closed
    with Image.open(join(image_dir, fname)) as src:
        img = src.convert('RGB')
    # resize images before saving to reduce RAM use during computation + storage
    img = img.resize(image_size, Image.BILINEAR)
    tensor_cache[fname] = to_tensor(img)

# save to drive
torch.save(tensor_cache, cache_path)