In [None]:
from google.colab import drive

# Mount Google Drive inside the Colab VM; force_remount=True re-establishes an existing mount.
# NOTE(review): every path used later in this notebook is relative to the working directory,
# not under /content/gdrive — confirm whether outputs were meant to be saved to Drive.
drive.mount(mountpoint='/content/gdrive', force_remount=True)

In [None]:
!pip install ftfy regex tqdm
!pip install git+https://github.com/openai/CLIP.git

In [None]:
!wget http://images.cocodataset.org/zips/train2014.zip
!wget http://images.cocodataset.org/annotations/annotations_trainval2014.zip

In [None]:
!unzip train2014.zip
!unzip annotations_trainval2014.zip

In [None]:
import collections
import json

# Locations of the extracted COCO 2014 caption annotations and training images.
annotation_file = "./annotations/captions_train2014.json"
PATH = './train2014/'

with open(annotation_file, 'r') as f:
  annotations = json.load(f)

# Group every caption under the on-disk path of the image it describes. Keys are inserted
# in the order each image_id first appears in the annotation list.
image_path_to_caption = collections.defaultdict(list)
for ann in annotations['annotations']:
  img_file = PATH + 'COCO_train2014_' + '%012d.jpg' % (ann['image_id'])
  image_path_to_caption[img_file].append(ann['caption'])

# Ordered list of unique image paths; this order defines dataset / feature row order below.
image_paths = list(image_path_to_caption)

In [None]:
# Sanity check: show the path constructed for the first captioned image.
first_image_path = image_paths[0]
print(first_image_path)

In [None]:
import pickle

# Persist the ordered path list and the path -> captions mapping, in that order.
# NOTE: despite the .txt extension these are binary pickle files; the names are kept as-is
# because other code presumably loads them by these exact names — confirm before renaming.
outputs = (
  ("coco_filenames.txt", image_paths),
  ('coco_captions.txt', image_path_to_caption),
)
for out_name, payload in outputs:
  with open(out_name, "wb") as fh:
    pickle.dump(payload, fh)

In [None]:
from torch.utils.data import Dataset, DataLoader
from PIL import Image
# collate() below needs torch; import it here so this cell does not rely on a later cell.
import torch

class Coco2014Dataset(Dataset):
  """Map-style dataset over COCO 2014 images and their captions.

  Args:
    fnames: ordered sequence of image file paths; item ``i`` is ``fnames[i]``.
    captions: mapping from image path to that image's captions, looked up as
      ``captions[fnames[i]]``.
    transform: optional callable applied to the loaded image (e.g. CLIP's ``preprocess``).

  Each item is ``(image, captions)``. ``image`` is the RGB PIL image, or ``transform(image)``
  when a transform is given.
  """

  def __init__(self, fnames, captions, transform=None):
    self.fnames = fnames
    self.captions = captions
    self.transform = transform

  def __len__(self):
    return len(self.fnames)

  def __getitem__(self, i):
    fname = self.fnames[i]
    # COCO contains some non-RGB (e.g. grayscale) images. Normalise to 3 channels so every
    # transformed tensor has the same shape and torch.stack in collate() can batch them.
    # The context manager closes the file right away instead of leaving a lazy handle open.
    with Image.open(fname) as raw:
      img = raw.convert('RGB')
    if self.transform is not None:
      img = self.transform(img)
    return img, self.captions[fname]


def collate(batch):
  """Collate ``(image_tensor, captions)`` pairs into ``(stacked_images, list_of_caption_lists)``.

  Images are stacked along a new leading batch dimension. Captions are variable-length lists
  of strings, so they are kept as a plain Python list in batch order.
  """
  imgs = torch.stack([img for img, _ in batch])
  captions = [caps for _, caps in batch]
  return imgs, captions

In [None]:
!pip install ipywidgets

In [None]:
import clip
import torch
# tqdm.auto picks the widget progress bar inside Jupyter/Colab but falls back to a text bar
# when ipywidgets is unavailable; tqdm.notebook fails outright in that case.
from tqdm.auto import tqdm

# Load CLIP ViT-B/32 together with its matching image preprocessing transform.
device = "cuda" if torch.cuda.is_available() else "cpu"
model, preprocess = clip.load('ViT-B/32', device=device)

# Item order follows image_paths; CLIP's preprocess is applied to every image.
dataset = Coco2014Dataset(image_paths, image_path_to_caption, preprocess)

In [None]:
import os

# shuffle=False keeps batches in dataset order, so feature row i corresponds to image_paths[i].
BATCH_SIZE = 128
# Cap worker processes at the available CPUs. PyTorch warns when num_workers exceeds them,
# and extra workers only add memory/contention.
NUM_WORKERS = min(8, os.cpu_count() or 1)
dataloader = DataLoader(
  dataset,
  batch_size=BATCH_SIZE,
  num_workers=NUM_WORKERS,
  collate_fn=collate,
  shuffle=False,
  pin_memory=(device == "cuda"),  # page-locked host memory speeds host->GPU copies
)

In [None]:
# Embed every image with CLIP's image encoder. The dataloader is unshuffled, so row i of
# fin_features corresponds to image_paths[i].
fin_features = []
with torch.no_grad():
  # Iterate the dataloader directly (not enumerate(...)) so tqdm knows the total and shows an ETA.
  for images, _ in tqdm(dataloader, desc="encoding images"):
    image_features = model.encode_image(images.to(device))
    # Move each batch to CPU so the concatenated tensor doesn't sit on the GPU and the file saved
    # from it can be torch.load-ed on machines without CUDA (no map_location needed).
    fin_features.append(image_features.cpu())
fin_features = torch.cat(fin_features)

In [None]:
# Serialize the stacked image embeddings; row order matches image_paths.
FEATURES_FILE = 'coco_features.pt'
torch.save(fin_features, FEATURES_FILE)

In [None]:
# Captions aligned row-for-row with fin_features. The unshuffled dataloader yields
# dataset.captions[dataset.fnames[i]] in index order, so read them directly instead of
# iterating the dataloader again. That would decode and preprocess every image a second
# time only to throw the pixels away.
all_captions = [list(dataset.captions[fname]) for fname in dataset.fnames]
assert len(all_captions) == len(fin_features), "captions and image features are misaligned"

# NOTE: binary pickle despite the .txt extension; the filename is kept for existing consumers.
with open('all_captions.txt', 'wb') as f:
  pickle.dump(all_captions, f)