In [1]:
# Data is downloaded from https://www.kaggle.com/datasets/adityajn105/flickr8k?resource=download-directory

In [2]:
# Root folder of the downloaded Flickr8k data. The misspelled name is kept
# because other cells may already refer to it.
DATA_FLODER = "../Data"
# Both paths are derived from the root so relocating the data is a one-line change.
PATH_captions = f"{DATA_FLODER}/captions.txt"
PATH_IMAGES = f"{DATA_FLODER}/Images"

In [3]:
import pandas as pd

df = pd.read_csv(PATH_captions)
# One integer id per distinct image file, numbered in order of first appearance.
# Deriving it from the `image` column (instead of row position // 5) keeps the
# ids correct even if an image does not have exactly five consecutive captions.
df['id'] = pd.factorize(df['image'])[0]
# Persist the augmented table; make_train_valid_dfs reads it back from this file.
df.to_csv("captions.csv", index=False)
df = pd.read_csv("captions.csv")

df.sample(5)

Unnamed: 0,image,caption,id
7496,2308108566_2cba6bca53.jpg,A person biking through the woods .,1499
21565,3169394115_2193158cee.jpg,A dog stands on a bench in the snow .,4313
16429,2892395757_0a1b0eedd2.jpg,Man wearing helmet and racing gear on a bicycle,3285
11996,260231029_966e2f1727.jpg,A black dog plays with a brown dog on the sand .,2399
9632,2448793019_5881c025f9.jpg,a young child in blue jeans sliding down a blu...,1926


# Config

In [4]:
import torch

  from .autonotebook import tqdm as notebook_tqdm


In [5]:
# Training hyperparameters.
# Number of samples per batch; passed to every DataLoader.
batch_size = 32
# Learning rate of the optimizer group holding both projection heads.
head_lr = 1e-3
# Learning rates of the image-encoder and text-encoder optimizer groups.
image_encoder_lr = 1e-4
text_encoder_lr = 1e-5
# Weight decay applied to the projection-head group only (the AdamW default is set to 0).
weight_decay = 1e-3
# `patience` and `factor` arguments of the ReduceLROnPlateau scheduler.
patience = 1
factor = 0.8
# Number of passes of the training loop.
epochs = 4
# Batches are moved to this device: CUDA when available, otherwise CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [6]:
# Encoder settings.
# NOTE(review): model_name and text_encoder_model are not referenced in the
# visible cells; presumably they select the image / text encoder backbones — confirm.
model_name = 'resnet50'
# Default input width (`embedding_dim`) of the image projection head in CLIPModel.
image_embedding = 2048
text_encoder_model = "distilbert-base-uncased"
# Default input width (`embedding_dim`) of the text projection head in CLIPModel.
text_embedding = 768
# Tokenizer checkpoint name (same string as the text encoder checkpoint above).
text_tokenizer = "distilbert-base-uncased"
# Passed as `max_length` to the tokenizer (with truncation=True) in CLIPDataset.
max_length = 200

In [7]:
# NOTE(review): pretrained / trainable are not referenced in the visible cells;
# presumably read by the encoder classes — confirm.
pretrained = True   # for both image encoder and text encoder
trainable = True    # for both image encoder and text encoder
# Default `temperature` of CLIPModel: the logits are divided by it and the
# averaged similarity used for the soft targets is multiplied by it.
temperature = 1.0
# Side length passed to A.Resize(size, size) in get_transforms.
size = 224

In [8]:
# Projection-head settings.
# NOTE(review): none of these are referenced in the visible cells; presumably
# they are consumed by ProjectionHead, which is defined elsewhere — confirm.
num_projection_layers = 1
projection_dim = 256 
dropout = 0.1

In [9]:
# When True, make_train_valid_dfs only considers image ids 0-99 for the split.
debug = False

# Utils

In [10]:
class AvgMeter:
    """Running weighted average of a scalar metric.

    Attributes:
        name: Label used when the meter is printed.
        sum: Accumulated ``val * count`` over all updates since the last reset.
        count: Accumulated ``count`` over all updates since the last reset.
        avg: ``sum / count`` after the most recent update (0 right after a reset).
    """

    def __init__(self, name="Metric"):
        self.name = name
        self.reset()

    def reset(self):
        """Set the running total, the sample count and the average back to 0."""
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, count=1):
        """Fold in a value ``val`` that was observed over ``count`` samples."""
        self.sum += val * count
        self.count += count
        self.avg = self.sum / self.count

    def __repr__(self):
        return f"{self.name}: {self.avg:.4f}"

def get_lr(optimizer):
    """Return the "lr" entry of the optimizer's first parameter group.

    Returns None when ``optimizer.param_groups`` is empty.
    """
    first_group = next(iter(optimizer.param_groups), None)
    if first_group is None:
        return None
    return first_group["lr"]

# Dataset

In [13]:
import itertools

import albumentations as A          # a fast data augmentation library
import cv2
import numpy as np
import torch.nn.functional as F
from torch import nn
from tqdm.auto import tqdm
from transformers import DistilBertModel, DistilBertConfig, DistilBertTokenizer

ModuleNotFoundError: No module named 'albumentations'

In [None]:
def make_train_valid_dfs():
    """Split the captions table into train and validation frames by image id.

    Reads ``captions.csv`` from the working directory and holds out 20% of the
    image ids for validation, drawn without replacement after seeding NumPy's
    global RNG with 42. Every caption row of an image lands on the same side of
    the split. When the module-level ``debug`` flag is true, only ids 0-99 are
    considered, so rows with any other id appear in neither frame.

    Returns:
        tuple: ``(train_dataframe, valid_dataframe)``, each re-indexed from 0.
    """
    dataframe = pd.read_csv("captions.csv")

    max_id = dataframe["id"].max() + 1 if not debug else 100

    image_ids = np.arange(0, max_id)
    # Seeding the global RNG keeps the split identical from run to run.
    np.random.seed(42)
    valid_ids = np.random.choice(
        image_ids, size=int(0.2 * len(image_ids)), replace=False
    )
    # Vectorised set difference: a per-id `in` test against the NumPy array
    # would rescan the whole validation array for every candidate id.
    train_ids = np.setdiff1d(image_ids, valid_ids)
    train_dataframe = dataframe[dataframe["id"].isin(train_ids)].reset_index(drop=True)
    valid_dataframe = dataframe[dataframe["id"].isin(valid_ids)].reset_index(drop=True)
    return train_dataframe, valid_dataframe

In [None]:
# Build the train / validation caption frames (split by image id).
train_df, valid_df = make_train_valid_dfs()

In [None]:
# Preview the first rows of the training split.
train_df.head()

In [None]:
# Preview the first rows of the validation split.
valid_df.head()

In [None]:
class CLIPDataset(torch.utils.data.Dataset):
    """Image/caption pairs whose captions are tokenized once at construction.

    Each item is a dict with one tensor per tokenizer output key, plus
    ``"image"`` (the transformed image as a float tensor with the last axis
    moved to the front) and ``"caption"`` (the raw caption string).
    """

    def __init__(self, image_filenames, captions, tokenizer, transforms):
        """
        image_filenames and captions must have the same length; so, if there are
        multiple captions for each image, image_filenames must repeat the file
        name once per caption.

        Args:
            image_filenames: Indexable collection of file names relative to PATH_IMAGES.
            captions: Iterable of caption strings.
            tokenizer: Callable invoked as
                ``tokenizer(list_of_captions, padding=True, truncation=True,
                max_length=max_length)``; must return a mapping of per-caption
                sequences.
            transforms: Callable invoked as ``transforms(image=array)`` that
                returns a mapping with an ``'image'`` entry.
        """

        self.image_filenames = image_filenames
        # Materialise once and reuse: a one-shot iterable would otherwise be
        # exhausted before it reaches the tokenizer.
        self.captions = list(captions)
        self.encoded_captions = tokenizer(
            self.captions, padding=True, truncation=True, max_length=max_length
        )
        self.transforms = transforms

    def __getitem__(self, idx):
        """Return the tokenized caption, transformed image and raw caption at ``idx``.

        Raises:
            FileNotFoundError: If the image file cannot be read.
        """
        item = {
            key: torch.tensor(values[idx])
            for key, values in self.encoded_captions.items()
        }

        image_path = f"{PATH_IMAGES}/{self.image_filenames[idx]}"
        image = cv2.imread(image_path)
        if image is None:
            # cv2.imread reports a missing/unreadable file by returning None
            # rather than raising; fail here with the offending path.
            raise FileNotFoundError(f"Could not read image: {image_path}")
        # OpenCV loads images as BGR; convert to RGB before the transforms.
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        image = self.transforms(image=image)['image']
        # Move the last (channel) axis to the front.
        item['image'] = torch.tensor(image).permute(2, 0, 1).float()
        item['caption'] = self.captions[idx]

        return item

    def __len__(self):
        """Number of caption entries."""
        return len(self.captions)


def get_transforms(mode="train"):
    """Build the albumentations preprocessing pipeline.

    Args:
        mode: Kept for the existing call sites; "train" and every other value
            currently produce the same pipeline (there is no train-only
            augmentation).

    Returns:
        An ``A.Compose`` that resizes to ``size`` x ``size`` and then applies
        ``A.Normalize`` with ``max_pixel_value=255.0``.
    """
    # Training and evaluation share one deterministic pipeline; add train-only
    # augmentations behind `mode == "train"` if they are ever needed.
    return A.Compose(
        [
            A.Resize(size, size, always_apply=True),
            A.Normalize(max_pixel_value=255.0, always_apply=True),
        ]
    )

In [None]:
# Load the tokenizer named in the config cell rather than repeating the literal.
tokenizer = DistilBertTokenizer.from_pretrained(text_tokenizer)

# Captions are tokenized once here, when the dataset is constructed.
train_dataset = CLIPDataset(
    train_df["image"].values,
    train_df["caption"].values,
    tokenizer=tokenizer,
    transforms=get_transforms(mode="train"),
)

train_dl = torch.utils.data.DataLoader(
    train_dataset, batch_size=batch_size, num_workers=0, shuffle=True
)

In [None]:
# Sanity check: pull a single batch and print the image / token tensor sizes.
# Only one batch is needed, so take it directly instead of opening a progress
# bar that is abandoned after the first iteration.
batch = next(iter(train_dl))
# Raw caption strings are not tensors, so they are dropped before the device move.
batch = {k: v.to(device) for k, v in batch.items() if k != "caption"}
print(batch["image"].size())
print(batch["input_ids"].size())

In [None]:
class CLIPModel(nn.Module):
    """Image encoder + text encoder with projection heads; forward returns a loss.

    Args:
        temperature: Divides the text-image logits and multiplies the averaged
            similarity that feeds the soft targets.
        image_embedding: ``embedding_dim`` given to the image projection head.
        text_embedding: ``embedding_dim`` given to the text projection head.
    """

    def __init__(
        self,
        temperature=temperature,
        image_embedding=image_embedding,
        text_embedding=text_embedding,
    ):
        super().__init__()
        self.temperature = temperature
        self.image_encoder = ImageEncoder()
        self.text_encoder = TextEncoder()
        self.image_projection = ProjectionHead(embedding_dim=image_embedding)
        self.text_projection = ProjectionHead(embedding_dim=text_embedding)

    def forward(self, batch):
        """Compute the scalar loss for one batch.

        ``batch`` must provide "image", "input_ids" and "attention_mask".
        Returns the mean of the per-sample average of the image-side and
        text-side losses.
        """
        # Encode both modalities first, then project them to a common width.
        img_feats = self.image_encoder(batch["image"])
        txt_feats = self.text_encoder(
            input_ids=batch["input_ids"],
            attention_mask=batch["attention_mask"],
        )
        img_emb = self.image_projection(img_feats)
        txt_emb = self.text_projection(txt_feats)

        # Text-to-image similarity logits, scaled by the temperature.
        logits = (txt_emb @ img_emb.T) / self.temperature

        # Soft targets: softmax over the mean of the within-modality similarities.
        img_sim = img_emb @ img_emb.T
        txt_sim = txt_emb @ txt_emb.T
        soft_targets = F.softmax(
            (img_sim + txt_sim) / 2 * self.temperature, dim=-1
        )

        # Loss in both directions (text rows, then image rows via the transpose).
        per_text = cross_entropy(logits, soft_targets, reduction='none')
        per_image = cross_entropy(logits.T, soft_targets.T, reduction='none')
        combined = (per_image + per_text) / 2.0
        return combined.mean()

In [None]:
# `tokenizer`, `train_dataset` and `train_dl` are already built in the earlier
# data-loading cell; only the validation objects are created here so the
# training captions are not tokenized a second time.
val_dataset = CLIPDataset(
    valid_df["image"].values,
    valid_df["caption"].values,
    tokenizer=tokenizer,
    transforms=get_transforms(mode="val"),
)

# Validation batches keep their original order (no shuffling).
val_dl = torch.utils.data.DataLoader(
    val_dataset, batch_size=batch_size, num_workers=0, shuffle=False
)

In [None]:
# The configuration lives in module-level variables (there is no CFG object),
# so the target device is simply `device`.
model = CLIPModel().to(device)

# One optimizer group per component so each can have its own learning rate;
# weight decay is applied to the projection heads only.
params = [
    {"params": model.image_encoder.parameters(), "lr": image_encoder_lr},
    {"params": model.text_encoder.parameters(), "lr": text_encoder_lr},
    {
        "params": itertools.chain(
            model.image_projection.parameters(),
            model.text_projection.parameters(),
        ),
        "lr": head_lr,
        "weight_decay": weight_decay,
    },
]

# Default weight decay is 0, so the encoder groups are not decayed.
optimizer = torch.optim.AdamW(params, weight_decay=0.)

lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode="min", patience=patience, factor=factor
)

In [None]:
def train_epoch(model, train_loader, optimizer, lr_scheduler, step):
    """Run one optimisation pass over ``train_loader``.

    Args:
        model: Callable that takes a batch dict and returns a scalar loss tensor.
        train_loader: Sized iterable of batch dicts; every entry except
            "caption" is moved to the module-level ``device``.
        optimizer: Optimizer stepped once per batch.
        lr_scheduler: Stepped once per batch, with no arguments, only when
            ``step == "batch"``.
        step: Scheduler granularity flag; only the value "batch" has an effect here.

    Returns:
        AvgMeter holding the batch-size-weighted running mean of the loss.
    """
    meter = AvgMeter()
    progress = tqdm(train_loader, total=len(train_loader))
    for raw_batch in progress:
        # Keep every field except the raw caption strings, on the target device.
        batch = {}
        for key, value in raw_batch.items():
            if key == "caption":
                continue
            batch[key] = value.to(device)

        loss = model(batch)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        if step == "batch":
            lr_scheduler.step()

        # Weight the running mean by the number of samples in this batch.
        n_samples = batch["image"].size(0)
        meter.update(loss.item(), n_samples)

        progress.set_postfix(train_loss=meter.avg, lr=get_lr(optimizer))
    return meter

In [None]:
# Scheduler granularity handed to train_epoch. "epoch" keeps train_epoch from
# calling lr_scheduler.step() after every batch, which the ReduceLROnPlateau
# scheduler cannot do without a metric value.
step = "epoch"

# NOTE(review): lr_scheduler is never stepped at epoch level and val_dl is not
# used; a validation pass followed by lr_scheduler.step(<validation loss>)
# appears to be missing here.
for epoch in range(epochs):
    print(f"Epoch: {epoch + 1}")
    model.train()
    train_loss = train_epoch(model, train_dl, optimizer, lr_scheduler, step)