### Contrastive Language-Image Pretraining
This notebook implements CLIP (Contrastive Language-Image Pre-training) as described in https://arxiv.org/abs/2103.00020.

### Importing libraries

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from torchvision.models import resnet18, ResNet18_Weights
from transformers import AlbertModel

### Defining the model

In [None]:
class CLIP(nn.Module):
    """Dual-encoder CLIP model: a ResNet18 image tower and an ALBERT text tower.

    Both towers are projected into a shared 512-d space. ``forward`` returns
    the scaled cosine-similarity matrix between every image and every caption
    in the batch.
    """

    def __init__(self):
        super(CLIP, self).__init__()
        # Pretrained ResNet18 whose classification head is replaced by a
        # 512 -> 512 projection into the joint embedding space.
        self.image_encoder = resnet18(weights=ResNet18_Weights.DEFAULT)  # ~11M params
        self.image_encoder.fc = nn.Linear(512, 512)

        self.text_encoder = AlbertModel.from_pretrained('albert-base-v2')  # ~11M params
        # Project the 768-d ALBERT hidden state into the 512-d joint space.
        self.text_proj = nn.Linear(768, 512)

        # Learnable *log* logit scale; the similarity is multiplied by exp(t).
        # The paper initializes the temperature to 0.07, i.e. a logit scale of
        # 1 / 0.07 (~14.3), so t must start at log(1 / 0.07). Initializing
        # t = 0.07 directly would give a scale of exp(0.07) ~= 1.07.
        self.t = nn.Parameter(torch.tensor(1 / 0.07).log())

    def forward(self, image, input_ids, attention_mask):
        """Return the (num_images, num_texts) matrix of scaled cosine similarities.

        Args:
            image: batch of images for the ResNet18 encoder.
            input_ids: token ids for the ALBERT encoder.
            attention_mask: attention mask matching ``input_ids``.
        """
        image_embedding = F.normalize(self.image_encoder(image), dim=-1)

        hidden = self.text_encoder(
            input_ids=input_ids, attention_mask=attention_mask
        ).last_hidden_state
        # Use the first ([CLS]) position as the caption representation. With
        # padding='max_length' tokenization, the last position is a pad token
        # for any caption shorter than max_length, not a summary of the caption.
        text_embedding = F.normalize(self.text_proj(hidden[:, 0]), dim=-1)

        # Keep the logit scale within [0.01, 100]; the paper caps it at 100 for stability.
        logit_scale = torch.clamp(self.t.exp(), 0.01, 100)
        return image_embedding @ text_embedding.T * logit_scale

In [None]:
class SCELoss(nn.Module):
    """Symmetric cross entropy: ``alpha * CE + beta * RCE``.

    RCE (reverse cross entropy) swaps the roles of prediction and target: it
    is the cross entropy of the predicted distribution against a one-hot
    target whose zeros are clamped to 1e-4 so the log stays finite.

    Args:
        alpha: weight of the standard cross-entropy term.
        beta: weight of the reverse cross-entropy term.
        num_classes: nominal number of classes. Kept for backward
            compatibility; the one-hot width is taken from ``pred`` at call
            time, so a smaller (e.g. final, partial) batch still works.
    """

    def __init__(self, alpha, beta, num_classes=32):
        super(SCELoss, self).__init__()
        self.alpha = alpha
        self.beta = beta
        self.num_classes = num_classes
        self.cross_entropy = torch.nn.CrossEntropyLoss()

    def forward(self, pred, labels):
        """Compute the weighted CE + RCE loss.

        Args:
            pred: unnormalized logits, assumed shape (N, C).
            labels: integer class indices of shape (N,), each in [0, C).

        Returns:
            Scalar loss tensor.
        """
        ce = self.cross_entropy(pred, labels)

        probs = F.softmax(pred, dim=1).clamp(min=1e-7, max=1.0)
        # Width comes from the logits: a fixed num_classes that differs from
        # pred.size(1) makes the product below fail to broadcast.
        one_hot = F.one_hot(labels, pred.size(1)).float().clamp(min=1e-4, max=1.0)
        # log(one_hot) is 0 at the true class and log(1e-4) everywhere else.
        rce = -(probs * torch.log(one_hot)).sum(dim=1)

        return self.alpha * ce + self.beta * rce.mean()

#### Training hyperparameters

In [None]:
# Image/caption pairs per contrastive batch. It is also passed to SCELoss as
# its number of "classes", since each sample's label is its index in the batch.
batch_size: int = 32

### Loading the data

In [None]:
# Install the Kaggle CLI used below to download the dataset (-q: quiet output).
!pip install -q kaggle

In [None]:
from google.colab import files

# Upload kaggle.json (Kaggle API credentials). Bind the result instead of
# leaving files.upload() as the cell's last expression: otherwise the notebook
# echoes the uploaded file contents, including the API key, into the output.
uploaded = files.upload()
# The file has already been saved to the working directory; drop the
# in-memory copy of the key.
del uploaded

Saving kaggle.json to kaggle.json


{'kaggle.json': b'{"username":"mliu2023","key":"<REDACTED>"}'}

In [None]:
!mkdir ~/.kaggle

!cp kaggle.json ~/.kaggle/

In [None]:
# Make the credentials file readable and writable by its owner only (mode 600).
!chmod 600 ~/.kaggle/kaggle.json

In [None]:
# Download the Flickr8k dataset archive (flickr8k.zip) into the current working directory.
!kaggle datasets download -d adityajn105/flickr8k

Dataset URL: https://www.kaggle.com/datasets/adityajn105/flickr8k
License(s): CC0-1.0
Downloading flickr8k.zip to /content
100% 1.03G/1.04G [00:09<00:00, 203MB/s]
100% 1.04G/1.04G [00:09<00:00, 114MB/s]


In [None]:
!unzip -qq /content/flickr8k.zip

In [None]:
import pandas as pd

# Caption table and image folder extracted from flickr8k.zip under /content.
data = pd.read_csv(filepath_or_buffer="/content/captions.txt")
image_path = '/content/Images'

In [None]:
# Plain Python list of caption strings from the 'caption' column.
captions = data['caption'].to_list()

In [None]:
from transformers import AlbertTokenizer
from torch.utils.data import Dataset, DataLoader
from torchvision.io import read_image
from torchvision.transforms import Resize
import os

class FlickrDataset(Dataset):
    """Image/caption pairs for CLIP training, with captions tokenized for ALBERT.

    Each item is ``(image, input_ids, attention_mask)``:
    - ``image`` is a uint8 tensor resized to 224x224 and read as 3-channel RGB.
    - ``input_ids`` and ``attention_mask`` are 1-D tensors padded/truncated to 64 tokens.

    Args:
        image_path: directory containing the image files.
        captions: caption strings; ``captions[i]`` is the text of sample ``i``.
        image_names: optional file names (relative to ``image_path``) parallel
            to ``captions``, so that ``image_names[i]`` is the image described
            by ``captions[i]``. Strongly recommended. Without it, sample ``i``
            pairs the ``i``-th file of a sorted directory listing with
            ``captions[i]``, which is only correct if ``captions`` is in that
            same order.

    Raises:
        ValueError: if ``image_names`` and ``captions`` differ in length.
    """

    def __init__(self, image_path, captions, image_names=None):
        self.image_path = image_path
        self.captions = captions
        if image_names is not None:
            if len(image_names) != len(captions):
                raise ValueError(
                    f"image_names ({len(image_names)}) and captions "
                    f"({len(captions)}) must have the same length"
                )
            self.image_paths = list(image_names)
        else:
            # os.listdir returns entries in arbitrary order; sort so that the
            # index -> file mapping is at least reproducible.
            self.image_paths = sorted(
                f for f in os.listdir(image_path)
                if os.path.isfile(os.path.join(image_path, f))
            )
        self.resize = Resize((224, 224))
        self.tokenizer = AlbertTokenizer.from_pretrained('albert-base-v2')

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        # Local import keeps the file's import cell unchanged.
        from torchvision.io import ImageReadMode

        inputs = self.tokenizer(
            self.captions[idx],
            max_length=64,
            padding='max_length',
            truncation=True,
            return_tensors='pt'
        )
        # Force 3 channels. A grayscale or RGBA file would otherwise produce a
        # tensor with a different channel count, which fails to batch.
        image = read_image(
            os.path.join(self.image_path, self.image_paths[idx]),
            mode=ImageReadMode.RGB,
        )
        # The tokenizer returns a leading batch dimension of 1; drop it so the
        # DataLoader can stack samples.
        return (
            self.resize(image),
            inputs['input_ids'].squeeze(0),
            inputs['attention_mask'].squeeze(0),
        )

dataset = FlickrDataset(image_path, captions)
# drop_last: every batch must contain exactly batch_size pairs. SCELoss is
# constructed with num_classes=batch_size and the training labels are
# arange(batch), so a smaller final batch would not match.
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, drop_last=True)

### Training loop

In [None]:
# Build the dual-encoder model and move all of its parameters to the current GPU.
model = CLIP().to('cuda')

In [None]:
# Equal weighting of the CE and reverse-CE terms; one "class" per batch slot.
sceloss = SCELoss(alpha=0.5, beta=0.5, num_classes=batch_size)

In [None]:
# AdamW over every model parameter (both encoders, projections and the
# temperature), using PyTorch's default hyperparameters.
# NOTE(review): the default lr (1e-3) is high for fine-tuning pretrained encoders — confirm.
optimizer = torch.optim.AdamW(params=model.parameters())

In [None]:
# ImageNet channel statistics that the pretrained ResNet18 weights expect.
# Built once on the GPU, outside the loop.
imagenet_mean = torch.tensor([0.485, 0.456, 0.406], device='cuda').view(1, 3, 1, 1)
imagenet_std = torch.tensor([0.229, 0.224, 0.225], device='cuda').view(1, 3, 1, 1)

# Make sure every submodule (dropout, batch norm) runs in training mode.
model.train()
for (images, input_ids, attention_mask) in dataloader:
    # Assumes the dataset yields uint8 pixels in [0, 255] (read_image + Resize).
    # Scale to [0, 1], then normalize as the pretrained image encoder expects.
    images = (images.cuda().float() / 255.0 - imagenet_mean) / imagenet_std
    input_ids = input_ids.cuda()
    attention_mask = attention_mask.cuda()

    # Matching image/caption pairs sit on the diagonal of the logits, so the
    # target for row i is i. The labels must live on the same device as the logits.
    labels = torch.arange(len(images), device=images.device)
    # Keep the logits attached to the autograd graph (no detach/cpu) so that
    # loss.backward() produces gradients for the model parameters.
    outputs = model(images, input_ids, attention_mask)
    # NOTE(review): the CLIP paper averages cross-entropy over both logits axes
    # (image->text and text->image). sceloss only scores rows (image->text).
    # Confirm this is intended.
    loss = sceloss(outputs, labels)

    # backward pass
    optimizer.zero_grad()
    loss.backward()
    # update weights
    optimizer.step()

    print(loss.item())