## Imports and Setup

The setup below assumes your `kaggle.json` API key has been uploaded to the Colab runtime before you execute the cell below.

It installs the dependencies that are not already available in Google Colab, then downloads and unzips the `flickr8k.lance` dataset.

In [1]:
# Create the Kaggle config directory; -p keeps the cell re-runnable if it already exists
! mkdir -p /root/.kaggle/
! pip install -q pyarrow pylance transformers kaggle timm
! mv /content/kaggle.json /root/.kaggle/
# Restrict the API key to owner read/write only instead of widening its permissions
! chmod 600 /root/.kaggle/kaggle.json
# Download the Lance-formatted Flickr8k archive, unpack it and drop the zip to save disk space
! kaggle datasets download -d heyytanay/flickr-8k-lance
! unzip -qq flickr-8k-lance.zip
! rm flickr-8k-lance.zip

[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m25.4/25.4 MB[0m [31m54.7 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.3/2.3 MB[0m [31m78.6 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m21.3/21.3 MB[0m [31m75.3 MB/s[0m eta [36m0:00:00[0m
Dataset URL: https://www.kaggle.com/datasets/heyytanay/flickr-8k-lance
License(s): CC0-1.0
Downloading flickr-8k-lance.zip to /content
100% 1.03G/1.04G [00:49<00:00, 23.3MB/s]
100% 1.04G/1.04G [00:49<00:00, 22.6MB/s]


In [29]:
import cv2
import lance

import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms

import timm
from transformers import AutoModel, AutoTokenizer

import matplotlib.pyplot as plt

import itertools
from tqdm import tqdm

import warnings
warnings.simplefilter('ignore')

## Config, Utility and Transformations

Define a `Config` class holding the hyperparameters, utility functions to load an image and its captions, and the image transformations used to pre-process the images.

In [28]:
class Config:
    """Hyperparameters and model choices shared by the cells of this notebook."""

    # --- Data ---
    img_size = (128, 128)   # target size handed to transforms.Resize
    max_len = 18            # fixed token length used when tokenizing captions
    bs = 32                 # batch size given to the DataLoader

    # --- Encoders ---
    img_encoder_model = 'resnet50'          # model name passed to timm
    text_encoder_model = 'bert-base-cased'  # checkpoint name passed to transformers
    img_embed_dim = 2048    # input width of the image projection head
    text_embed_dim = 768    # input width of the text projection head
    projection_dim = 256    # output width of both projection heads

    # --- Optimisation ---
    head_lr = 1e-3          # learning rate for both projection heads
    img_enc_lr = 1e-4       # learning rate for the image encoder
    text_enc_lr = 1e-5      # learning rate for the text encoder
    temperature = 1.0       # temperature handed to loss_fn during training
    num_epochs = 2          # number of passes over the dataloader

In [70]:
def load_image(ds, idx):
    """Load the image stored at row ``idx`` and decode it into an RGB array.

    Args:
        ds: Dataset object exposing ``take(indices, columns=...)`` whose result
            offers ``to_pydict()``; the ``image`` column is expected to hold
            encoded image bytes.
        idx: Row index of the image to load.

    Returns:
        The decoded image as a NumPy array in RGB channel order.

    Raises:
        ValueError: If the stored bytes cannot be decoded as an image.
    """
    row = ds.take([idx], columns=['image']).to_pydict()
    # ``take`` yields a one-element list of bytes objects; join them into one buffer
    encoded = np.frombuffer(b''.join(row['image']), dtype=np.uint8)
    img = cv2.imdecode(encoded, cv2.IMREAD_COLOR)
    if img is None:
        # imdecode signals failure by returning None; fail loudly here instead of
        # letting cvtColor raise an opaque error further down
        raise ValueError(f"Could not decode image at index {idx}")
    # OpenCV decodes to BGR; swap to RGB. (The channel swap is the same one the
    # RGB2BGR constant performs, but BGR2RGB names the actual intent.)
    return cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

def load_caption(ds, idx):
    """Return the longest caption stored for the image at row ``idx``.

    Each row carries several captions; only the one with the most characters
    is used (the first such caption wins when lengths tie).
    """
    row = ds.take([idx], columns=['captions']).to_pydict()
    candidates = row['captions'][0]
    longest = max(candidates, key=len)
    return longest

In [10]:
# Pre-processing pipeline applied to every image before it reaches the model
train_augments = transforms.Compose([
    # Turn the decoded image array into a tensor
    transforms.ToTensor(),
    # Bring every image to the common size configured in Config
    transforms.Resize(size=Config.img_size),
    # Normalize with mean 0.5 and std 0.5
    transforms.Normalize(mean=[0.5], std=[0.5]),
])

## CLIPLanceDataset

This is a custom PyTorch dataset that loads the Flickr8k images and captions from the Lance dataset, tokenizes / pre-processes them, and returns them in the proper format.

In [83]:
class CLIPLanceDataset(Dataset):
    """Dataset yielding ``(image, tokenized caption)`` pairs read from a Lance dataset.

    Args:
        lance_path: Path of the Lance dataset to open.
        max_len: Length every caption is padded / truncated to when tokenized.
        tokenizer: Tokenizer to apply to captions; when none is supplied the
            ``bert-base-cased`` tokenizer is loaded.
        transforms: Optional callable applied to each loaded image.
    """

    def __init__(self, lance_path, max_len=18, tokenizer=None, transforms=None):
        self.ds = lance.dataset(lance_path)
        self.max_len = max_len
        # Fall back to the default tokenizer only when the caller gave none
        self.tokenizer = tokenizer or AutoTokenizer.from_pretrained('bert-base-cased')
        self.transforms = transforms

    def __len__(self):
        """Number of rows in the underlying Lance dataset."""
        return self.ds.count_rows()

    def __getitem__(self, idx):
        """Return the (optionally transformed) image and tokenized caption at ``idx``."""
        image = load_image(self.ds, idx)
        text = load_caption(self.ds, idx)

        # Pad / truncate to a fixed length so samples can be stacked into a batch
        encoded = self.tokenizer(
            text,
            max_length=self.max_len,
            padding='max_length',
            truncation=True,
            return_tensors='pt',
        )
        # Flatten each returned tensor in place; without this the tokenized
        # captions cause errors during training
        for key in list(encoded.keys()):
            encoded[key] = encoded[key].flatten()

        image = self.transforms(image) if self.transforms else image
        return image, encoded

## Model

The Image Encoder, Text Encoder and the overall model architecture were adapted from and inspired by Manan Goel's work, [Implementing CLIP with PyTorch Lightning](https://wandb.ai/manan-goel/coco-clip/reports/Implementing-CLIP-With-PyTorch-Lightning--VmlldzoyMzg4Njk1).

This specific implementation is for CLIP for natural language-based image search.

The part that matters most to us in this example is how the data loading works, which can be found in the earlier sample and can be used as a drop-in replacement for many use cases.

In [84]:
class ImageEncoder(nn.Module):
    """Image tower: a timm backbone that maps an image batch to feature vectors.

    Args:
        model_name: Name of the timm model to create.
        pretrained: Whether timm should load pretrained weights.
    """

    def __init__(self, model_name, pretrained=True):
        super().__init__()
        # NOTE(review): num_classes=0 with global_pool="avg" is presumably what makes
        # timm return pooled features rather than class logits - confirm in timm docs
        self.backbone = timm.create_model(
            model_name, pretrained=pretrained, num_classes=0, global_pool="avg"
        )

        # Explicitly mark every backbone weight as trainable (full fine-tuning)
        for weight in self.backbone.parameters():
            weight.requires_grad = True

    def forward(self, img):
        """Run the backbone on ``img`` and return its output unchanged."""
        features = self.backbone(img)
        return features

class TextEncoder(nn.Module):
    """Text tower: a transformers model reduced to one vector per caption.

    Args:
        model_name: Checkpoint name handed to ``AutoModel.from_pretrained``.
    """

    def __init__(self, model_name):
        super().__init__()

        self.backbone = AutoModel.from_pretrained(model_name)

        # Explicitly mark every backbone weight as trainable (full fine-tuning)
        for weight in self.backbone.parameters():
            weight.requires_grad = True

    def forward(self, captions):
        """Encode tokenized ``captions`` and return the hidden state at position 0.

        ``captions`` is unpacked as keyword arguments into the backbone.
        """
        hidden_states = self.backbone(**captions).last_hidden_state
        # Keep only the first sequence position of every caption
        # (presumably the [CLS] token for BERT-style models)
        return hidden_states[:, 0, :]

class Head(nn.Module):
    """Projection head mapping encoder features into the shared embedding space.

    The input is linearly projected, refined by GELU -> Linear -> Dropout, added
    back to the projection as a residual, and finally layer-normalized.

    Args:
        embedding_dim: Width of the incoming encoder features.
        projection_dim: Width of the produced embeddings.
    """

    def __init__(self, embedding_dim, projection_dim):
        super().__init__()
        self.projection = nn.Linear(embedding_dim, projection_dim)
        self.gelu = nn.GELU()
        self.fc = nn.Linear(projection_dim, projection_dim)

        self.dropout = nn.Dropout(0.3)
        self.layer_norm = nn.LayerNorm(projection_dim)

    def forward(self, x):
        """Project ``x`` and return the normalized residual embedding."""
        skip = self.projection(x)
        refined = self.dropout(self.fc(self.gelu(skip)))
        # Residual connection around the refinement branch, then normalization
        return self.layer_norm(refined + skip)

## Train

Finally, we define the models, optimizer and dataloader, and train the model for 2 epochs.

In [85]:
# Image tower: encoder backbone plus its projection head, both on the GPU
img_encoder = ImageEncoder(model_name=Config.img_encoder_model).to('cuda')
img_head = Head(Config.img_embed_dim, Config.projection_dim).to('cuda')

# Text tower: tokenizer, encoder backbone and projection head
tokenizer = AutoTokenizer.from_pretrained(Config.text_encoder_model)
text_encoder = TextEncoder(model_name=Config.text_encoder_model).to('cuda')
text_head = Head(Config.text_embed_dim, Config.projection_dim).to('cuda')

# One optimizer parameter group per component so each gets its own learning rate;
# the two projection heads share a single group
parameters = [
    dict(params=img_encoder.parameters(), lr=Config.img_enc_lr),
    dict(params=text_encoder.parameters(), lr=Config.text_enc_lr),
    dict(
        params=itertools.chain(img_head.parameters(), text_head.parameters()),
        lr=Config.head_lr,
    ),
]

optimizer = torch.optim.Adam(parameters)

Some utilities to simplify the training.

In [87]:
def loss_fn(img_embed, text_embed, temperature=0.2):
    """Symmetric soft-target cross-entropy between image and text embeddings.

    Args:
        img_embed: 2-D tensor of image embeddings, one row per sample.
        text_embed: 2-D tensor of text embeddings, one row per sample.
        temperature: Divides the text-image logits and multiplies the averaged
            intra-modal similarities used to build the soft targets.

    Returns:
        A 1-D tensor holding one loss value per sample (not reduced).
    """
    # Cross-modal similarity of every caption against every image
    logits = (text_embed @ img_embed.T) / temperature

    # Soft targets: softmax over the mean of image-image and text-text similarity
    mean_similarity = (img_embed @ img_embed.T + text_embed @ text_embed.T) / 2
    targets = F.softmax(mean_similarity * temperature, dim=-1)

    # Cross-entropy in both directions (text -> image and image -> text)
    text_loss = (-targets * F.log_softmax(logits, dim=-1)).sum(1)
    img_loss = (-targets.T * F.log_softmax(logits.T, dim=-1)).sum(1)
    return (img_loss + text_loss) / 2.0

def forward(img, caption):
    """Move a batch to the GPU and embed it with both towers.

    Args:
        img: Batch of image tensors.
        caption: Mapping of tokenizer outputs; its tensors are replaced in
            place by their GPU copies.

    Returns:
        Tuple ``(img_embed, text_embed)`` of projected embeddings.
    """
    device = 'cuda'

    # Transfer the inputs to the device the models live on
    img = img.to(device)
    for key in list(caption.keys()):
        caption[key] = caption[key].to(device)

    # Image tower first, then text tower (tuple elements evaluate left to right)
    return img_head(img_encoder(img)), text_head(text_encoder(caption))

In [86]:
# Dataset over the unzipped Lance files, reusing the tokenizer and image pipeline defined above
dataset = CLIPLanceDataset(
    "flickr8k.lance",
    tokenizer=tokenizer,
    transforms=train_augments,
    max_len=Config.max_len,
)

# Batches are served in dataset order (no shuffling) from pinned host memory
dataloader = DataLoader(
    dataset=dataset,
    batch_size=Config.bs,
    shuffle=False,
    pin_memory=True,
)

In [91]:
# Train!
# Switch all four components to training mode so the Dropout in the heads is active
for _net in (img_encoder, img_head, text_encoder, text_head):
    _net.train()

for epoch in range(Config.num_epochs):
    print(f"{'='*20} Epoch: {epoch+1} / {Config.num_epochs} {'='*20}")

    prog_bar = tqdm(dataloader)
    for img, caption in prog_bar:
        # Drop the gradients accumulated by the previous step
        optimizer.zero_grad(set_to_none=True)

        # Embed the batch and reduce the per-sample losses to a scalar
        img_embed, text_embed = forward(img, caption)
        loss = loss_fn(
            img_embed,
            text_embed,
            temperature=Config.temperature,
        ).mean()

        # Back-propagate and update encoders and heads
        loss.backward()
        optimizer.step()

        # Show the current batch loss on the progress bar
        prog_bar.set_description(f"loss: {loss.item():.4f}")
    print()



loss: 2.0799: 100%|██████████| 253/253 [02:14<00:00,  1.88it/s]





loss: 1.3064: 100%|██████████| 253/253 [02:10<00:00,  1.94it/s]





