# Finetuning BLIP Model

In [1]:
import torch
import os
import numpy as np
# import clip
from PIL import Image
import json
import skimage
import torch.nn as nn

from transformers import AutoProcessor, BlipModel
from torch.utils.data import Dataset, DataLoader
from torchvision.transforms import v2
from torchvision import tv_tensors

from tqdm import tqdm

device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Device is {device}")
# model, preprocess = clip.load("ViT-B/32", device=device, jit=False)

Device is cuda


In [2]:
model = BlipModel.from_pretrained("Salesforce/blip-image-captioning-large")
processor = AutoProcessor.from_pretrained("Salesforce/blip-image-captioning-large")
model.to(device)

Some weights of BlipModel were not initialized from the model checkpoint at Salesforce/blip-image-captioning-large and are newly initialized: ['logit_scale', 'text_model.embeddings.LayerNorm.bias', 'text_model.embeddings.LayerNorm.weight', 'text_model.embeddings.position_embeddings.weight', 'text_model.embeddings.word_embeddings.weight', 'text_model.encoder.layer.0.attention.output.LayerNorm.bias', 'text_model.encoder.layer.0.attention.output.LayerNorm.weight', 'text_model.encoder.layer.0.attention.output.dense.bias', 'text_model.encoder.layer.0.attention.output.dense.weight', 'text_model.encoder.layer.0.attention.self.key.bias', 'text_model.encoder.layer.0.attention.self.key.weight', 'text_model.encoder.layer.0.attention.self.query.bias', 'text_model.encoder.layer.0.attention.self.query.weight', 'text_model.encoder.layer.0.attention.self.value.bias', 'text_model.encoder.layer.0.attention.self.value.weight', 'text_model.encoder.layer.0.crossattention.output.LayerNorm.bias', 'text_model

BlipModel(
  (text_model): BlipTextModel(
    (embeddings): BlipTextEmbeddings(
      (word_embeddings): Embedding(30524, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.0, inplace=False)
    )
    (encoder): BlipTextEncoder(
      (layer): ModuleList(
        (0-11): 12 x BlipTextLayer(
          (attention): BlipTextAttention(
            (self): BlipTextSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.0, inplace=False)
            )
            (output): BlipTextSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
         

# Define Dataset and Dataloader

In [3]:
class HatefulMemesDataset(Dataset):
    def __init__(self, root_dir, jsonl_path, transforms=None):
        self.root_dir = root_dir
        with open(os.path.join(self.root_dir, jsonl_path), 'r') as f:
            self.jsonl = list(f)
        self.transforms = transforms
    
    def __len__(self):
        return len(self.jsonl)

    def __getitem__(self, i):
        json_str = self.jsonl[i]
        json_loaded = json.loads(json_str)
        img_path = os.path.join(self.root_dir, json_loaded['img'])
        label = torch.tensor(json_loaded['label'])
        # caption = json_loaded['text']
        # Directly use non-hateful and hateful as the text we are classifying image to
        # text = clip.tokenize(["non-hateful", "hateful"])
        img = np.array(Image.open(img_path).convert('RGB')).transpose((2, 0, 1))
        # img = preprocess(img)
        # Wrap in tv_tensors
        img = tv_tensors.Image(img)
        if self.transforms is not None:
            img = self.transforms(img)
        return img, label

# Finetuning code

Finetune without changing any part of the model

## Finetune Functions

In [4]:
model

BlipModel(
  (text_model): BlipTextModel(
    (embeddings): BlipTextEmbeddings(
      (word_embeddings): Embedding(30524, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.0, inplace=False)
    )
    (encoder): BlipTextEncoder(
      (layer): ModuleList(
        (0-11): 12 x BlipTextLayer(
          (attention): BlipTextAttention(
            (self): BlipTextSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.0, inplace=False)
            )
            (output): BlipTextSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
         

In [5]:
# Freeze layers
def prepare_model(model):
    # Freeze all
    for param in model.parameters():
        param.requires_grad = False
    # Unfreeze layers in image and text embedding transformer networks
    for param in model.vision_model.encoder.layers[-1].parameters():
        param.requires_grad = True
    for param in model.text_model.encoder.layer[-1].parameters():
        param.requires_grad = True
    # Final layers unfreeze
    for param in model.visual_projection.parameters():
        param.requires_grad = True
    for param in model.text_projection.parameters():
        param.requires_grad = True
    # Layer norm layers unfreeze
    for param in model.vision_model.post_layernorm.parameters():
        param.requires_grad = True
    # Pooler layer in text_model
    for param in model.text_model.pooler.parameters():
        param.requires_grad = True
    # # Unfreeze layers in image and text embedding transformer networks
    # for param in model.visual.transformer.resblocks[11].parameters():
    #     param.requires_grad = True
    # for param in model.transformer.resblocks[11].parameters():
    #     param.requires_grad = True
    # # Unfreeze model token embedding layer
    # for param in model.token_embedding.parameters():
    #     param.requires_grad = True
    # # Layer norm layers unfreeze
    # for param in model.visual.ln_post.parameters():
    #     param.requires_grad = True
    # for param in model.ln_final.parameters():
    #     param.requires_grad = True

In [6]:
prepare_model(model)
# Model needs to be in float, otherwise NAN for loss and training does not work beyond first iteration
# model.float()

In [7]:
optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr  = 1e-4)
loss = nn.CrossEntropyLoss()

In [8]:
def finetune_blip_one_epoch(model, processor, dataloader, optimizer, loss, device):
    model.train()
    epoch_loss = 0
    epoch_acc = 0
    for batch_sample in tqdm(dataloader):
        img, label = batch_sample
        # Classify for non-hateful vs hateful
        inputs = processor(text=["non-hateful", "hateful"], images = img, return_tensors='pt', padding=True).to(device)
        label = label.to(device)
        # Get logits
        outputs = model(**inputs)
        logit_image = outputs.logits_per_image
        logit_text = outputs.logits_per_text
        # Get individual and total losses
        image_loss = loss(logit_image, label)
        text_loss = loss(torch.transpose(logit_text, 0, 1), label)
        total_loss = (image_loss + text_loss) / 2
        # Get the number of correct predictions
        correct = (torch.argmax(logit_image, dim = 1) == label).sum()
        acc = correct / img.shape[0]
        # Backprop
        optimizer.zero_grad()
        total_loss.backward()
        optimizer.step()
        # Add loss
        epoch_loss += total_loss.item() * img.shape[0]
        epoch_acc += acc.item() * img.shape[0]
    final_epoch_loss = epoch_loss / len(dataloader.sampler)
    final_epoch_acc = epoch_acc / len(dataloader.sampler)
    return final_epoch_loss, final_epoch_acc, model

In [9]:
# Finetune on label directly, ignore caption
def finetune_model(model, processor, dataloader, optimizer, loss, epochs, device):
    model.to(device)
    # Keep track of train metrics
    train_loss_ls = []
    train_acc_ls = []
    # text = clip.tokenize(["non-hateful", "hateful"])
    for epoch in tqdm(range(1, epochs + 1)):
        # Train model
        epoch_loss, epoch_acc, model = finetune_blip_one_epoch(model, processor, dataloader, optimizer, loss, device)
        train_loss_ls.append(epoch_loss)
        train_acc_ls.append(epoch_acc)
        print(f"Epoch {epoch}\n")
        print(f"\ttrain_loss: {epoch_loss}\n")
        print(f"\ttrain_acc: {epoch_acc}\n")
    return model, train_loss_ls, train_acc_ls


In [10]:
train_dataset = HatefulMemesDataset('./dataset', 'train.jsonl', transforms=v2.Compose([v2.CenterCrop((800, 800))]))

In [11]:
len(train_dataset)

8500

In [12]:
train_dataloader = DataLoader(
    train_dataset, batch_size=32
)

In [13]:
train_dataloader

<torch.utils.data.dataloader.DataLoader at 0x2b46dc05d50>

## Finetune

In [14]:
finetuned_model, train_losses, train_accs = finetune_model(model, processor, train_dataloader, optimizer, loss, 10, device)

100%|██████████| 266/266 [13:19<00:00,  3.01s/it]
 10%|█         | 1/10 [13:19<1:59:54, 799.34s/it]

Epoch 1

	train_loss: 0.5621032547530006

	train_acc: 0.7118823529411765



100%|██████████| 266/266 [13:16<00:00,  2.99s/it]
 20%|██        | 2/10 [26:35<1:46:22, 797.75s/it]

Epoch 2

	train_loss: 0.5205558384867276

	train_acc: 0.7531764705882353



100%|██████████| 266/266 [13:32<00:00,  3.06s/it]
 30%|███       | 3/10 [40:08<1:33:52, 804.71s/it]

Epoch 3

	train_loss: 0.43619851579736263

	train_acc: 0.7978823529411765



100%|██████████| 266/266 [14:16<00:00,  3.22s/it]
 40%|████      | 4/10 [54:25<1:22:30, 825.05s/it]

Epoch 4

	train_loss: 0.31385113758900585

	train_acc: 0.8675294117647059



100%|██████████| 266/266 [13:46<00:00,  3.11s/it]
 50%|█████     | 5/10 [1:08:11<1:08:47, 825.50s/it]

Epoch 5

	train_loss: 0.21621458967994242

	train_acc: 0.9184705883194418



100%|██████████| 266/266 [13:28<00:00,  3.04s/it]
 60%|██████    | 6/10 [1:21:40<54:39, 819.80s/it]  

Epoch 6

	train_loss: 0.15716930064482285

	train_acc: 0.9441176470588235



100%|██████████| 266/266 [13:24<00:00,  3.03s/it]
 70%|███████   | 7/10 [1:35:05<40:44, 814.95s/it]

Epoch 7

	train_loss: 0.11142677786946296

	train_acc: 0.9615294117647059



100%|██████████| 266/266 [13:59<00:00,  3.16s/it]
 80%|████████  | 8/10 [1:49:05<27:25, 822.90s/it]

Epoch 8

	train_loss: 0.08980157093265477

	train_acc: 0.9694117647058823



100%|██████████| 266/266 [13:09<00:00,  2.97s/it]
 90%|█████████ | 9/10 [2:02:14<13:32, 812.54s/it]

Epoch 9

	train_loss: 0.08111194276037242

	train_acc: 0.9730588235294118



100%|██████████| 266/266 [12:54<00:00,  2.91s/it]
100%|██████████| 10/10 [2:15:09<00:00, 810.96s/it]

Epoch 10

	train_loss: 0.08041742061352466

	train_acc: 0.9736470588235294






# Assess on dev_unseen

In [15]:
from sklearn.metrics import roc_auc_score, accuracy_score

In [16]:
def test_model_zero_shot_capability(model, processor, device, jsonl_path):
    labels = []
    model_probs = []
    model_preds = []
    # 0 -> Non-hateful, 1 -> Hateful
    # text = clip.tokenize(["non-hateful", "hateful"]).to(device)
    with open(jsonl_path, 'r') as json_f:
        json_list = list(json_f)
    for json_str in tqdm(json_list):
        result = json.loads(json_str)
        img_path, label = result['img'], result['label']
        labels.append(label)
        # image = preprocess(Image.open(os.path.join('./dataset', img_path))).unsqueeze(0).to(device)
        with torch.no_grad():
            inputs = processor(text=['non-hateful', 'hateful'], images=Image.open(os.path.join('./dataset', img_path)), return_tensors='pt', padding=True).to(device)
            outputs = model(**inputs)
            logits_per_image = outputs.logits_per_image
            probs = logits_per_image.softmax(dim=-1).cpu().numpy()
            class_1_prob = probs[0][1]
            model_probs.append(class_1_prob)
            model_preds.append(np.argmax(probs))
    return roc_auc_score(labels, model_probs), accuracy_score(labels, model_preds)


In [17]:
# Dev ROC: 0.678470588235294
# Dev Accuracy: 0.6648148148148149
dev_roc, dev_acc = test_model_zero_shot_capability(finetuned_model, processor, device, './dataset/dev_unseen.jsonl')
print(f'Dev ROC: {dev_roc}')
print(f'Dev Accuracy: {dev_acc}')

  0%|          | 0/540 [00:00<?, ?it/s]

100%|██████████| 540/540 [00:45<00:00, 11.78it/s]

Dev ROC: 0.6740441176470588
Dev Accuracy: 0.6648148148148149





In [18]:

test_roc, test_acc = test_model_zero_shot_capability(finetuned_model, processor, device, './dataset/test_unseen.jsonl')
print(f'test ROC: {test_roc}')
print(f'test Accuracy: {test_acc}')

  0%|          | 0/2000 [00:00<?, ?it/s]

100%|██████████| 2000/2000 [02:49<00:00, 11.79it/s]

test ROC: 0.6972655999999999
test Accuracy: 0.679





In [19]:
torch.save(finetuned_model.state_dict(), 'models/blip-image-captioning-large_10_epoch_1e-4.pt')