# CLIP Modeling

## 0. Imports

In [1]:
%load_ext jupyter_black

In [2]:
import sys

sys.path.append("..")

In [4]:
import os

import numpy as np
import pandas as pd

import torch
import torch.nn as nn
import torch.nn.functional as F

import timm

import transformers
from transformers import AutoModel

In [5]:
from src.dataset.datamodule import CLIPDataModule

## 1. CLIP DataModule

In [6]:
# NOTE(review): presumably set to avoid the HF tokenizers fork warning/deadlock
# when the dataloader uses worker processes (num_workers > 0 below) — confirm.
os.environ["TOKENIZERS_PARALLELISM"] = "false"

# Keyword arguments forwarded verbatim to CLIPDataModule.
dm_params = dict(
    data_path="../data/captions.csv",
    img_dir="../data/images/",
    tokenizer_name="bert-base-uncased",
    img_size=224,
    txt_max_length=200,
    # presumably held-out fractions for validation / test — confirm in CLIPDataModule
    val_size=0.2,
    test_size=0.2,
    batch_size=32,
    num_workers=4,
)

dm = CLIPDataModule(**dm_params)

In [7]:
# Pull one training batch; the cells below read its "image", "input_ids" and
# "attention_mask" entries.
batch = next(iter(dm.train_dataloader()))

## 2. Encoder

### 2.1 Image Encoder

In [8]:
class ImageEncoder(nn.Module):
    """Wrap a timm backbone as a feature extractor for images.

    The backbone is created with ``num_classes=0`` (no classifier head) and
    ``global_pool="avg"``, so calling it yields pooled feature vectors rather
    than class logits.

    Args:
        model_name: timm model name, e.g. ``"vit_base_patch16_224"``.
        use_pretrained: Forwarded to timm as ``pretrained=``.
        is_trainable: When falsy, every backbone parameter has ``requires_grad``
            set to this value, freezing the backbone.
    """

    def __init__(
        self, model_name: str, use_pretrained: bool = True, is_trainable: bool = True
    ):
        super().__init__()

        self.model_name = model_name
        self.use_pretrained = use_pretrained
        self.is_trainable = is_trainable

        # Image backbone without a classification head.
        backbone = timm.create_model(
            model_name,
            num_classes=0,
            global_pool="avg",
            pretrained=use_pretrained,
        )
        self.model = backbone

        # Freezing is opt-in; by default the backbone is fine-tuned.
        freeze = not self.is_trainable
        if freeze:
            for param in backbone.parameters():
                param.requires_grad = self.is_trainable

    def forward(self, img: torch.Tensor):
        """Return the pooled backbone features for a batch of images."""
        features = self.model(img)
        return features

### 2.2 Text Encoder

In [9]:
class TextEncoder(nn.Module):
    """Transformer text encoder returning the hidden state at token index 0.

    Args:
        model_name: Hugging Face model identifier resolved by ``AutoModel``.
        use_pretrained: If True, load pretrained weights with
            ``AutoModel.from_pretrained``. If False, build the same architecture
            from its config with randomly initialised weights, matching the
            ``pretrained=False`` behaviour of the timm image encoder.
        is_trainable: If False, every encoder parameter gets
            ``requires_grad = False``.
    """

    def __init__(
        self, model_name: str, use_pretrained: bool = True, is_trainable: bool = True
    ):
        super().__init__()

        self.model_name = model_name
        self.use_pretrained = use_pretrained
        self.is_trainable = is_trainable
        # Position of the [CLS] token; assumes a tokenizer that prepends it
        # (as bert-base-uncased does) — confirm for other tokenizers.
        self.cls_token_idx = 0

        if use_pretrained:
            self.model = AutoModel.from_pretrained(model_name)
        else:
            # Only the config is fetched; weights start from random init.
            config = transformers.AutoConfig.from_pretrained(model_name)
            self.model = AutoModel.from_config(config)

        if not self.is_trainable:
            for parameter in self.model.parameters():
                parameter.requires_grad = self.is_trainable

    def forward(self, input_ids, attention_mask):
        """Encode token ids and return one vector per sequence.

        Args:
            input_ids: Token ids (presumably ``(batch, seq_len)`` — confirm).
            attention_mask: Mask aligned with ``input_ids``.

        Returns:
            ``last_hidden_state[:, cls_token_idx, :]``.
        """
        output = self.model(input_ids=input_ids, attention_mask=attention_mask)
        last_hidden_state = output.last_hidden_state
        return last_hidden_state[:, self.cls_token_idx, :]

### 2.3 Projection Head

In [10]:
class ProjectionHead(nn.Module):
    """Project encoder features into the shared embedding space.

    Computes ``LayerNorm(P + Dropout(FC(GELU(P))))`` with ``P = Linear(x)``:
    a linear projection plus a residual GELU/Linear/Dropout branch, then
    layer normalisation.

    ref:  https://github.com/h-albert-lee/G-CLIP/blob/master/modules.py

    TODO: is layer_norm appropriate for the img encoder?
    """

    def __init__(self, embedding_dim: int, projection_dim: int, dropout: float):
        super().__init__()

        # Attribute names and creation order are load-bearing: they define the
        # state_dict keys and the order in which weights are initialised.
        self.projection = nn.Linear(embedding_dim, projection_dim)
        self.gelu = nn.GELU()
        self.fc = nn.Linear(projection_dim, projection_dim)
        self.dropout = nn.Dropout(dropout)
        self.layer_norm = nn.LayerNorm(projection_dim)

    def forward(self, x):
        base = self.projection(x)
        branch = self.dropout(self.fc(self.gelu(base)))
        # Residual connection around the non-linear branch.
        return self.layer_norm(branch + base)

## 3. CLIP Model

### 3.1 Line-by-Line Walkthrough

In [19]:
# Shared encoder switches.
is_trainable, use_pretrained = True, True

# Image encoder: timm model name and the feature size it must output.
img_model_name, img_embedding = "vit_base_patch16_224", 768

# Text encoder: Hugging Face model name and the feature size it must output.
text_model_name, text_embedding = "bert-base-uncased", 768

# Projection head: shared embedding size and dropout probability.
projection_dim, dropout = 256, 0.1

# CLIP loss temperature.
temperature = 1.0

In [20]:
# Pass the flags defined in the config cell so that editing is_trainable /
# use_pretrained there actually takes effect (mirrors CLIP.__init__).
img_encoder = ImageEncoder(
    img_model_name, use_pretrained=use_pretrained, is_trainable=is_trainable
)
text_encoder = TextEncoder(
    text_model_name, use_pretrained=use_pretrained, is_trainable=is_trainable
)

# Projection heads map both modalities to the same projection_dim.
img_projection = ProjectionHead(
    embedding_dim=img_embedding, projection_dim=projection_dim, dropout=dropout
)
text_projection = ProjectionHead(
    embedding_dim=text_embedding, projection_dim=projection_dim, dropout=dropout
)

In [21]:
# img & text features from encoder
img_features = img_encoder(batch["image"])
text_features = text_encoder(batch["input_ids"], batch["attention_mask"])

In [22]:
# img & text embedding from projection head
img_embeddings = img_projection(img_features)
text_embeddings = text_projection(text_features)

In [23]:
# Calculating the Loss
logits = (text_embeddings @ img_embeddings.T) / temperature
images_similarity = img_embeddings @ img_embeddings.T
texts_similarity = text_embeddings @ text_embeddings.T

In [24]:
# Soft targets from the averaged in-batch image/image and text/text similarity.
# NOTE(review): parses as ((a + b) / 2) * temperature — temperature multiplies
# here; identical at 1.0, confirm intent.
targets = F.softmax((images_similarity + texts_similarity) / 2 * temperature, dim=-1)
# Symmetric cross-entropy: texts->images on rows, images->texts on the transpose.
texts_loss = F.cross_entropy(input=logits, target=targets, reduction="none")
images_loss = F.cross_entropy(input=logits.T, target=targets.T, reduction="none")
loss = 0.5 * (images_loss + texts_loss)  # shape: (batch_size)

In [25]:
torch.mean(loss)

tensor(13.3602, grad_fn=<MeanBackward0>)

### 3.2 CLIP Model

In [26]:
class CLIP(nn.Module):
    """Dual-encoder CLIP model returning a contrastive loss and both embeddings.

    Images go through ``ImageEncoder`` -> ``ProjectionHead`` and texts through
    ``TextEncoder`` -> ``ProjectionHead``, so both end up with size
    ``projection_dim``; a symmetric soft-target contrastive loss is then
    computed over the batch.

    Args:
        img_model_name: timm model name for the image encoder.
        text_model_name: Hugging Face model name for the text encoder.
        temperature: Strictly positive scale; logits are divided by it.
        img_embedding: Feature size produced by the image encoder.
        text_embedding: Feature size produced by the text encoder.
        projection_dim: Size of the shared embedding space.
        dropout: Dropout probability inside both projection heads.
        is_trainable: Forwarded to both encoders; False freezes their weights
            (projection heads stay trainable).
        use_pretrained: Forwarded to both encoders as their pretrained-weights
            switch.

    Raises:
        ValueError: If ``temperature`` is not strictly positive (or is NaN).
    """

    def __init__(
        self,
        img_model_name: str,
        text_model_name: str,
        temperature: float,
        img_embedding: int,
        text_embedding: int,
        projection_dim: int,
        dropout: float,
        is_trainable: bool = True,
        use_pretrained: bool = True,
    ):
        super().__init__()

        # Fail fast: 0 makes the logits inf/NaN, a negative value inverts their
        # ordering; either otherwise only surfaces as a broken loss in training.
        if not temperature > 0:
            raise ValueError(f"temperature must be > 0, got {temperature!r}")

        self.img_model_name = img_model_name
        self.text_model_name = text_model_name
        self.temperature = temperature

        self.img_encoder = ImageEncoder(img_model_name, use_pretrained, is_trainable)
        self.text_encoder = TextEncoder(text_model_name, use_pretrained, is_trainable)

        self.img_projection = ProjectionHead(
            embedding_dim=img_embedding, projection_dim=projection_dim, dropout=dropout
        )
        self.text_projection = ProjectionHead(
            embedding_dim=text_embedding, projection_dim=projection_dim, dropout=dropout
        )

    @staticmethod
    def _contrastive_loss(
        img_embeddings: torch.Tensor, text_embeddings: torch.Tensor, temperature: float
    ) -> torch.Tensor:
        """Return the per-sample symmetric soft-target contrastive loss.

        Args:
            img_embeddings: ``(batch, dim)`` projected image embeddings.
            text_embeddings: ``(batch, dim)`` projected text embeddings.
            temperature: Divisor applied to the text/image logits.

        Returns:
            Tensor of shape ``(batch,)``.
        """
        # Row i scores text i against every image in the batch.
        logits = (text_embeddings @ img_embeddings.T) / temperature
        imgs_similarity = img_embeddings @ img_embeddings.T
        texts_similarity = text_embeddings @ text_embeddings.T

        # Soft targets from in-batch image/image and text/text similarity,
        # instead of a hard identity matrix.
        # NOTE(review): this parses as ((a + b) / 2) * temperature, i.e. the
        # temperature multiplies here; identical at 1.0 — confirm intent.
        targets = F.softmax(
            (imgs_similarity + texts_similarity) / 2 * temperature, dim=-1
        )
        texts_loss = F.cross_entropy(logits, targets, reduction="none")
        imgs_loss = F.cross_entropy(logits.T, targets.T, reduction="none")
        return (imgs_loss + texts_loss) / 2.0

    def forward(self, batch: dict[str, torch.Tensor]):
        """Encode a batch and compute the contrastive loss.

        Args:
            batch: Mapping with ``"image"``, ``"input_ids"`` and
                ``"attention_mask"`` tensors (presumably as produced by
                CLIPDataModule's dataloaders — confirm shapes).

        Returns:
            Dict with the batch-mean ``"loss"`` plus the projected
            ``"img_embeddings"`` and ``"text_embeddings"``
            (each ``(batch_size, projection_dim)``).
        """
        # img & text features from encoder
        img_features = self.img_encoder(batch["image"])
        text_features = self.text_encoder(batch["input_ids"], batch["attention_mask"])

        # img & text embedding from projection head
        img_embeddings = self.img_projection(img_features)
        text_embeddings = self.text_projection(text_features)

        loss = self._contrastive_loss(img_embeddings, text_embeddings, self.temperature)

        return {
            "loss": loss.mean(),
            "img_embeddings": img_embeddings,
            "text_embeddings": text_embeddings,
        }

In [27]:
# Same hyperparameters as the walkthrough above, packed as CLIP kwargs.
model_params = dict(
    is_trainable=True,
    use_pretrained=True,
    # img encoder
    img_model_name="vit_base_patch16_224",
    img_embedding=768,
    # text encoder
    text_model_name="bert-base-uncased",
    text_embedding=768,
    # projection head
    projection_dim=256,
    dropout=0.1,
    # clip
    temperature=1.0,
)

model = CLIP(**model_params)

In [28]:
outputs = model(batch=batch)

In [29]:
# Batch-mean contrastive loss returned by CLIP.forward (a scalar tensor).
outputs["loss"]

tensor(19.7848, grad_fn=<MeanBackward0>)

In [30]:
outputs["img_embeddings"].size()

torch.Size([32, 256])

In [31]:
outputs["text_embeddings"].size()

torch.Size([32, 256])