In [None]:
%%capture
!pip install -U sentence-transformers

In [None]:
from os import path
import collections
import pathlib
import pickle
import urllib
import urllib.request
import zipfile

import numpy as np
import pandas as pd
import torch
from torchvision.io import ImageReadMode

# Download the dataset

In [None]:
# Folder the Flickr8k images are extracted into. The "Flicker" spelling matches the
# directory name used by flickr8k() below. NOTE(review): not referenced by any other
# visible cell.
dataset_img_path = pathlib.Path('flickr8k', 'Flicker8k_Dataset')

In [None]:
# Reference: https://colab.research.google.com/github/tensorflow/docs/blob/master/site/en/tutorials/text/image_captioning.ipynb#scrollTo=kaNy_l7tGuAZ&line=1&uniqifier=1

def _download_and_extract(url, dest):
  """Download the zip archive at `url` and extract all of its members into `dest`."""
  file_path, _ = urllib.request.urlretrieve(url)
  try:
    with zipfile.ZipFile(file_path, 'r') as zip_ref:
      zip_ref.extractall(dest)
  finally:
    # urlretrieve() without a target filename downloads to a temporary file;
    # urlcleanup() removes temporaries left by earlier urlretrieve() calls.
    urllib.request.urlcleanup()


def _read_split(path, split_name, cap_dict):
  """Pair every image listed in Flickr_8k.<split_name>Images.txt with its captions."""
  split_file = path / f'Flickr_8k.{split_name}Images.txt'
  return [
      (str(path / 'Flicker8k_Dataset' / fname), cap_dict[fname])
      for fname in split_file.read_text(encoding='utf-8').splitlines()
      if fname.strip()  # ignore blank lines
  ]


def flickr8k(path='flickr8k'):
  """Load the Flickr8k train/dev/test splits, downloading the archives on first use.

  Args:
    path: Directory (str or path-like) that holds, or will receive, the extracted
      image and text archives.

  Returns:
    A ``(train, dev, test)`` tuple. Each element is a list of
    ``(image_path, captions)`` pairs. ``image_path`` is a ``str`` under
    ``<path>/Flicker8k_Dataset``. ``captions`` is the list of caption strings
    recorded for that image in ``Flickr8k.token.txt``, or empty if the image has
    none there.
  """
  path = pathlib.Path(path)
  dataset_path = path / 'Flicker8k_Dataset'
  token_path = path / 'Flickr8k.token.txt'

  # Check each archive on its own, so a partial earlier run (images extracted, text
  # files missing) gets repaired instead of failing when the caption files are read.
  if not dataset_path.exists():
    _download_and_extract(
        'https://github.com/jbrownlee/Datasets/releases/download/Flickr8k/Flickr8k_Dataset.zip',
        path)
  if not token_path.exists():
    _download_and_extract(
        'https://github.com/jbrownlee/Datasets/releases/download/Flickr8k/Flickr8k_text.zip',
        path)

  # Token lines have the form "<image file>#<caption index>\t<caption>".
  cap_dict = collections.defaultdict(list)
  for line in token_path.read_text(encoding='utf-8').splitlines():
    if not line.strip():
      continue  # tolerate blank lines
    image_ref, caption = line.split('\t', 1)
    cap_dict[image_ref.split('#')[0]].append(caption)

  return tuple(_read_split(path, split_name, cap_dict) for split_name in ('train', 'dev', 'test'))

In [None]:
# Lists of (image_path, captions) pairs for each split; downloads on the first run.
train_raw, dev_raw, test_raw = flickr8k(path='flickr8k')

In [None]:
# One caption per entry, grouped by image in the same order as the *_raw lists.
train_captions = [caption for _img, image_captions in train_raw for caption in image_captions]
dev_captions = [caption for _img, image_captions in dev_raw for caption in image_captions]
test_captions = [caption for _img, image_captions in test_raw for caption in image_captions]

# Calculate Features

In [None]:
# Prefer the GPU when one is visible to PyTorch; fall back to the CPU otherwise.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

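## Caption Embeddings

Each image has several captions, so the embedding arrays have one row per caption (grouped by image, in `train_raw` / `dev_raw` / `test_raw` order), not one row per image.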

In [None]:
from sentence_transformers import SentenceTransformer

In [None]:
# NOTE(review): this looks like a raw MiniLM checkpoint rather than a model trained
# for sentence embeddings, so sentence-transformers may build a default pooling
# head on top of it. Confirm this is intended (vs. e.g. all-MiniLM-L6-v2).
TEXT_MODEL_NAME = 'nreimers/MiniLM-L6-H384-uncased'
model = SentenceTransformer(TEXT_MODEL_NAME, device=device)

In [None]:
# Embed every caption. Rows line up one-to-one with the *_captions lists.
CAPTION_BATCH_SIZE = 256
train_embeddings = model.encode(train_captions, batch_size=CAPTION_BATCH_SIZE)
dev_embeddings = model.encode(dev_captions, batch_size=CAPTION_BATCH_SIZE)
test_embeddings = model.encode(test_captions, batch_size=CAPTION_BATCH_SIZE)

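## Image Vectors

One feature vector per image, in the same order as `train_raw` / `dev_raw` / `test_raw` (the data loaders below do not shuffle).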

In [None]:
import torchvision
from torch.utils.data import Dataset, DataLoader
from torchvision.io import read_image

# Pretrained ResNet-18 weights (DEFAULT = torchvision's default checkpoint). Reused
# below both for the matching preprocessing and to build the backbone.
weights = torchvision.models.ResNet18_Weights.DEFAULT

### Prepare The Dataset

In [None]:
class CustomImageDataset(Dataset):
  """Map-style dataset that decodes one image per path.

  Items are image tensors only (no labels). Indexing with 0..len-1 yields images
  in the order of `image_paths`.
  """

  def __init__(self, image_paths, transform=None, target_transform=None):
    """
    Args:
      image_paths: Sequence of image file paths (str or path-like).
      transform: Optional callable applied to each decoded image tensor.
      target_transform: Kept for signature compatibility. Unused, because items
        carry no target.
    """
    self.image_paths = image_paths
    self.transform = transform
    self.target_transform = target_transform

  def __len__(self):
    return len(self.image_paths)

  def __getitem__(self, idx):
    """Decode image `idx` as a 3-channel RGB tensor and apply `transform` if set."""
    img_path = self.image_paths[idx]
    # Force RGB. read_image's default mode keeps the file's own channel count
    # (e.g. 1 for a grayscale JPEG), which would not match the 3-channel input the
    # pretrained ResNet preprocessing and model expect.
    image = read_image(str(img_path), mode=ImageReadMode.RGB)
    if self.transform is not None:
      image = self.transform(image)
    return image

In [None]:
# Image file paths only, in the same order as the *_raw lists.
train_img_paths = [pair[0] for pair in train_raw]
dev_img_paths = [pair[0] for pair in dev_raw]
test_img_paths = [pair[0] for pair in test_raw]

In [None]:
# Preprocessing that torchvision pairs with `weights`, so images are prepared the way
# the pretrained backbone expects.
transforms = weights.transforms()

In [None]:
# One dataset per split, all sharing the preprocessing tied to the pretrained weights.
img_train_ds = CustomImageDataset(image_paths=train_img_paths, transform=transforms)
img_dev_ds = CustomImageDataset(image_paths=dev_img_paths, transform=transforms)
img_test_ds = CustomImageDataset(image_paths=test_img_paths, transform=transforms)

In [None]:
# shuffle=False keeps the extracted vectors in *_img_paths order, which is what
# lines them up with the caption data.
img_train_dataloader = DataLoader(dataset=img_train_ds, shuffle=False, batch_size=64)
img_dev_dataloader = DataLoader(dataset=img_dev_ds, shuffle=False, batch_size=64)
img_test_dataloader = DataLoader(dataset=img_test_ds, shuffle=False, batch_size=64)

### Extract and save features

In [None]:
def extract_vectors(model, dataloader, device):
  """Run `model` over every batch from `dataloader` and stack the outputs as NumPy.

  Args:
    model: Callable (e.g. a ``torch.nn.Module``) that maps a batch tensor to an
      output tensor whose first dimension is the batch. This function does not
      change the model's train/eval mode.
    dataloader: Iterable of input batch tensors.
    device: Device each batch is moved to before the forward pass. It should match
      the model's device.

  Returns:
    ``np.ndarray`` with one row per input example, in dataloader order. Size-1
    non-batch axes are removed (e.g. trailing 1x1 spatial dims of a pooled feature
    map); the batch axis is always kept.

  Raises:
    ValueError: from ``np.concatenate`` if `dataloader` yields no batches.
  """
  vectors = []

  # Inference only: skip autograd bookkeeping to save memory and time.
  with torch.no_grad():
    for imgs in dataloader:
      hiddens = model(imgs.to(device)).detach().cpu().numpy()
      # Drop singleton axes except the batch axis. A plain np.squeeze would also
      # drop the batch axis for a one-example batch, leaving a lower-rank array
      # that np.concatenate cannot join with the others.
      kept_dims = tuple(dim for dim in hiddens.shape[1:] if dim != 1)
      # Copy so stored rows never alias a buffer owned by the model or the loader.
      vectors.append(hiddens.reshape((hiddens.shape[0],) + kept_dims).copy())

  return np.concatenate(vectors, axis=0)

def save_vectors(vectors, name, base_path):
  """Pickle `vectors` to ``<base_path>/<name>_vectors.pkl``, overwriting any existing file.

  Args:
    vectors: Any picklable object (here, NumPy feature arrays).
    name: Prefix for the output file name.
    base_path: Existing directory (str or path-like) to write into.
  """
  target = pathlib.Path(base_path) / f'{name}_vectors.pkl'
  with target.open('wb') as handle:
    pickle.dump(vectors, handle)

In [None]:
# ResNet-18 backbone used as a frozen feature extractor. Pass `weights` by keyword:
# torchvision (>= 0.13) deprecates passing it positionally.
model = torchvision.models.resnet18(weights=weights)
# Drop the last child module (the `fc` classification head in torchvision's ResNet)
# so the network outputs pooled features instead of class logits.
model = torch.nn.Sequential(*list(model.children())[:-1])
# eval() puts BatchNorm into inference mode for deterministic feature extraction.
model = model.to(device).eval()

In [None]:
# One feature row per image, in the (unshuffled) dataloader order.
train_vectors = extract_vectors(model=model, dataloader=img_train_dataloader, device=device)
dev_vectors = extract_vectors(model=model, dataloader=img_dev_dataloader, device=device)
test_vectors = extract_vectors(model=model, dataloader=img_test_dataloader, device=device)

In [None]:
# Output directory for the pickled features. NOTE(review): an absolute Google Drive
# (Colab) path. No visible cell mounts Drive, so confirm it is mounted before running.
base_path = '/content/drive/MyDrive/collective_learning/px-multimodal-repr/binaries/flickr8k/'

for split_name, split_vectors in [
    ('train_text', train_embeddings),
    ('dev_text', dev_embeddings),
    ('test_text', test_embeddings),
    ('train_img', train_vectors),
    ('dev_img', dev_vectors),
    ('test_img', test_vectors),
]:
  save_vectors(split_vectors, split_name, base_path)