# Finetuning CLIP Model

In [1]:
import torch
import os
import numpy as np
import clip
from PIL import Image
import json
import skimage
import torch.nn as nn

from torch.utils.data import Dataset, DataLoader
from torchvision.transforms import v2
from torchvision import tv_tensors

from tqdm import tqdm

device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Device is {device}")
model, preprocess = clip.load("ViT-B/32", device=device, jit=False)

Device is cuda


# Define Dataset and Dataloader

In [2]:
class HatefulMemesDataset(Dataset):
    def __init__(self, root_dir, jsonl_path, transforms=None):
        self.root_dir = root_dir
        with open(os.path.join(self.root_dir, jsonl_path), 'r') as f:
            self.jsonl = list(f)
        self.transforms = transforms
    
    def __len__(self):
        return len(self.jsonl)

    def __getitem__(self, i):
        json_str = self.jsonl[i]
        json_loaded = json.loads(json_str)
        img_path = os.path.join(self.root_dir, json_loaded['img'])
        label = torch.tensor(json_loaded['label'])
        # caption = json_loaded['text']
        # Directly use non-hateful and hateful as the text we are classifying image to
        text = clip.tokenize(["non-hateful", "hateful"])
        img = Image.open(img_path)
        img = preprocess(img)
        # Wrap in tv_tensors
        img = tv_tensors.Image(img)
        if self.transforms is not None:
            img = self.transforms(img)
        return img, text, label

# Finetuning code

Finetune without changing any part of the model

## Finetune Functions

In [3]:
# Freeze layers
def prepare_model(model):
    # Freeze all
    for param in model.parameters():
        param.requires_grad = False
    # Unfreeze layers in image and text embedding transformer networks
    for param in model.visual.transformer.resblocks[11].parameters():
        param.requires_grad = True
    for param in model.transformer.resblocks[11].parameters():
        param.requires_grad = True
    # Unfreeze model token embedding layer
    for param in model.token_embedding.parameters():
        param.requires_grad = True
    # Layer norm laers unfreeze
    for param in model.visual.ln_post.parameters():
        param.requires_grad = True
    for param in model.ln_final.parameters():
        param.requires_grad = True

In [4]:
prepare_model(model)
# Model needs to be in float, otherwise NAN for loss and training does not work beyond first iteration
model.float()

CLIP(
  (visual): VisionTransformer(
    (conv1): Conv2d(3, 768, kernel_size=(32, 32), stride=(32, 32), bias=False)
    (ln_pre): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
    (transformer): Transformer(
      (resblocks): Sequential(
        (0): ResidualAttentionBlock(
          (attn): MultiheadAttention(
            (out_proj): NonDynamicallyQuantizableLinear(in_features=768, out_features=768, bias=True)
          )
          (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
          (mlp): Sequential(
            (c_fc): Linear(in_features=768, out_features=3072, bias=True)
            (gelu): QuickGELU()
            (c_proj): Linear(in_features=3072, out_features=768, bias=True)
          )
          (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        )
        (1): ResidualAttentionBlock(
          (attn): MultiheadAttention(
            (out_proj): NonDynamicallyQuantizableLinear(in_features=768, out_features=768, bias=True)
          

In [5]:
optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr  = 1e-4)
loss = nn.CrossEntropyLoss()

In [6]:
def finetune_clip_one_epoch(model, dataloader, optimizer, loss, text, device):
    model.train()
    epoch_loss = 0
    epoch_acc = 0
    for batch_sample in tqdm(dataloader):
        img, _, label = batch_sample
        img = img.to(device)
        text = text.to(device)
        label = label.to(device)
        # Get logits
        logit_image, logit_text = model(img, text)
        # Get individual and total losses
        image_loss = loss(logit_image, label)
        text_loss = loss(torch.transpose(logit_text, 0, 1), label)
        total_loss = (image_loss + text_loss) / 2
        # print(total_loss)
        # Get the number of correct predictions
        correct = (torch.argmax(logit_image, dim = 1) == label).sum()
        acc = correct / img.shape[0]
        # Backprop
        optimizer.zero_grad()
        total_loss.backward()
        optimizer.step()
        # Add loss
        epoch_loss += total_loss.item() * img.shape[0]
        epoch_acc += acc.item() * img.shape[0]
    final_epoch_loss = epoch_loss / len(dataloader.sampler)
    final_epoch_acc = epoch_acc / len(dataloader.sampler)
    return final_epoch_loss, final_epoch_acc, model

In [7]:
# Finetune on label directly, ignore caption
def finetune_model(model, dataloader, optimizer, loss, epochs, device):
    model.to(device)
    # Keep track of train metrics
    train_loss_ls = []
    train_acc_ls = []
    text = clip.tokenize(["non-hateful", "hateful"])
    for epoch in tqdm(range(1, epochs + 1)):
        # Train model
        epoch_loss, epoch_acc, model = finetune_clip_one_epoch(model, dataloader, optimizer, loss, text, device)
        train_loss_ls.append(epoch_loss)
        train_acc_ls.append(epoch_acc)
        print(f"Epoch {epoch}\n")
        print(f"\ttrain_loss: {epoch_loss}\n")
        print(f"\ttrain_acc: {epoch_acc}\n")
    return model, train_loss_ls, train_acc_ls


In [8]:
train_dataset = HatefulMemesDataset('./dataset', 'train.jsonl')

In [9]:
len(train_dataset)

8500

In [10]:
train_dataloader = DataLoader(
    train_dataset, batch_size=128
)

In [11]:
train_dataloader

<torch.utils.data.dataloader.DataLoader at 0x2561f261890>

## Finetune

In [12]:
finetuned_model, train_losses, train_accs = finetune_model(model, train_dataloader, optimizer, loss, 10, device)

  attn_output = scaled_dot_product_attention(q, k, v, attn_mask, dropout_p, is_causal)
100%|██████████| 67/67 [01:54<00:00,  1.71s/it]
 10%|█         | 1/10 [01:54<17:11, 114.66s/it]

Epoch 1

	train_loss: 0.620481564802282

	train_acc: 0.6847058825492859



100%|██████████| 67/67 [01:51<00:00,  1.66s/it]
 20%|██        | 2/10 [03:46<15:02, 112.78s/it]

Epoch 2

	train_loss: 0.5174343039007748

	train_acc: 0.7423529412325691



100%|██████████| 67/67 [01:46<00:00,  1.59s/it]
 30%|███       | 3/10 [05:32<12:49, 109.89s/it]

Epoch 3

	train_loss: 0.3814533473954481

	train_acc: 0.8278823532216689



100%|██████████| 67/67 [01:45<00:00,  1.58s/it]
 40%|████      | 4/10 [07:18<10:49, 108.23s/it]

Epoch 4

	train_loss: 0.2134525564137627

	train_acc: 0.9084705885157865



100%|██████████| 67/67 [01:48<00:00,  1.61s/it]
 50%|█████     | 5/10 [09:06<09:00, 108.18s/it]

Epoch 5

	train_loss: 0.17561840916381163

	train_acc: 0.9287058826334337



100%|██████████| 67/67 [01:45<00:00,  1.57s/it]
 60%|██████    | 6/10 [10:51<07:08, 107.19s/it]

Epoch 6

	train_loss: 0.15109234586401898

	train_acc: 0.9381176470588235



100%|██████████| 67/67 [01:45<00:00,  1.57s/it]
 70%|███████   | 7/10 [12:37<05:19, 106.60s/it]

Epoch 7

	train_loss: 0.11414949060976505

	train_acc: 0.9543529411764706



100%|██████████| 67/67 [01:46<00:00,  1.59s/it]
 80%|████████  | 8/10 [14:23<03:33, 106.53s/it]

Epoch 8

	train_loss: 0.07431371414310792

	train_acc: 0.9689411764705882



100%|██████████| 67/67 [01:53<00:00,  1.70s/it]
 90%|█████████ | 9/10 [16:17<01:48, 108.78s/it]

Epoch 9

	train_loss: 0.05464837867372176

	train_acc: 0.9747058823529412



100%|██████████| 67/67 [01:59<00:00,  1.78s/it]
100%|██████████| 10/10 [18:16<00:00, 109.62s/it]

Epoch 10

	train_loss: 0.04104952521622181

	train_acc: 0.9781176470588235






# Assess on dev_unseen

In [13]:
from sklearn.metrics import roc_auc_score, accuracy_score

In [21]:
def test_model_zero_shot_capability(model, jsonl_path):
    labels = []
    model_probs = []
    model_preds = []
    # 0 -> Non-hateful, 1 -> Hateful
    text = clip.tokenize(["non-hateful", "hateful"]).to(device)
    with open(jsonl_path, 'r') as json_f:
        json_list = list(json_f)
    for json_str in tqdm(json_list):
        result = json.loads(json_str)
        img_path, label = result['img'], result['label']
        labels.append(label)
        # Read image
        image = preprocess(Image.open(os.path.join('./dataset', img_path))).unsqueeze(0).to(device)
        with torch.no_grad():
            logits_per_image, _ = model(image, text)
            probs = logits_per_image.softmax(dim=-1).cpu().numpy()
            # Get larger label probability
            class_1_prob = probs[0][1]
            model_probs.append(class_1_prob)
            model_preds.append(np.argmax(probs))
    return roc_auc_score(labels, model_probs), accuracy_score(labels, model_preds)


In [22]:
dev_roc, dev_acc = test_model_zero_shot_capability(finetuned_model, './dataset/dev_unseen.jsonl')
print(f'Dev ROC: {dev_roc}')
print(f'Dev Accuracy: {dev_acc}')

100%|██████████| 540/540 [00:14<00:00, 37.68it/s]

Dev ROC: 0.678470588235294
Dev Accuracy: 0.6648148148148149





In [16]:
dev_roc, dev_acc = test_model_zero_shot_capability(finetuned_model, './dataset/dev_seen.jsonl')
print(f'Dev ROC: {dev_roc}')
print(f'Dev Accuracy: {dev_acc}')

100%|██████████| 500/500 [00:13<00:00, 36.51it/s]

Dev ROC: 0.700788913603559
Dev Accuracy: 0.636





In [17]:
test_roc, test_acc = test_model_zero_shot_capability(finetuned_model, './dataset/test_seen.jsonl')
print(f'test ROC: {test_roc}')
print(f'test Accuracy: {test_acc}')

100%|██████████| 1000/1000 [00:27<00:00, 36.54it/s]

test ROC: 0.682078831532613
test Accuracy: 0.615





In [18]:
test_roc, test_acc = test_model_zero_shot_capability(finetuned_model, './dataset/test_unseen.jsonl')
print(f'test ROC: {test_roc}')
print(f'test Accuracy: {test_acc}')

100%|██████████| 2000/2000 [00:50<00:00, 39.48it/s]

test ROC: 0.7010842666666666
test Accuracy: 0.675



