# ViLT Model Exploration

In [1]:
from PIL import Image
# ViltModel is a raw model with no heads. Can use to define heads
from transformers import ViltProcessor, ViltModel
from torchvision.transforms.v2 import functional as F
from torchvision.transforms import v2
from torchvision import tv_tensors

import skimage
import torch
import os
import numpy as np
import json
import torch.nn as nn

from torch.optim import Adam
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
from ipywidgets import FloatProgress
from sklearn.metrics import roc_auc_score, accuracy_score

# Define and Load Model and Processor

In [2]:
processor = ViltProcessor.from_pretrained("dandelin/vilt-b32-mlm")
model = ViltModel.from_pretrained("dandelin/vilt-b32-mlm")

In [3]:
processor.image_processor

ViltImageProcessor {
  "_valid_processor_keys": [
    "images",
    "do_resize",
    "size",
    "size_divisor",
    "resample",
    "do_rescale",
    "rescale_factor",
    "do_normalize",
    "image_mean",
    "image_std",
    "do_pad",
    "return_tensors",
    "data_format",
    "input_data_format"
  ],
  "do_normalize": true,
  "do_pad": true,
  "do_rescale": true,
  "do_resize": true,
  "image_mean": [
    0.5,
    0.5,
    0.5
  ],
  "image_processor_type": "ViltImageProcessor",
  "image_std": [
    0.5,
    0.5,
    0.5
  ],
  "resample": 3,
  "rescale_factor": 0.00392156862745098,
  "size": {
    "shortest_edge": 384
  },
  "size_divisor": 32
}

In [4]:
test_processed = processor(Image.open('./dataset/img/98724.png').convert('RGB'),'funny meme')
test_processed['input_ids']

[101, 6057, 2033, 4168, 102]

In [5]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
device

'cuda'

In [6]:
# No separate text and visual encoders, combbined into one transformer encoder that processes both text and image patches in ViT style
model

ViltModel(
  (embeddings): ViltEmbeddings(
    (text_embeddings): TextEmbeddings(
      (word_embeddings): Embedding(30522, 768)
      (position_embeddings): Embedding(40, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.0, inplace=False)
    )
    (patch_embeddings): ViltPatchEmbeddings(
      (projection): Conv2d(3, 768, kernel_size=(32, 32), stride=(32, 32))
    )
    (token_type_embeddings): Embedding(2, 768)
    (dropout): Dropout(p=0.0, inplace=False)
  )
  (encoder): ViltEncoder(
    (layer): ModuleList(
      (0-11): 12 x ViltLayer(
        (attention): ViltAttention(
          (attention): ViltSelfAttention(
            (query): Linear(in_features=768, out_features=768, bias=True)
            (key): Linear(in_features=768, out_features=768, bias=True)
            (value): Linear(in_features=768, out_features=768, bias=True)
            (dropout): Dropout(p=0.0, inplace=Fa

# Define Model with Classification Head

In [7]:
class ViltForClassification(nn.Module):
    def __init__(self, vilt_base_model, mapping_num, map_dim, device):
        super().__init__()
        self.vilt = vilt_base_model
        self.mapping_num = mapping_num
        self.map_dim = map_dim
        self.device = device
        self.vilt.to(self.device)
        # Update all model parameters to be frozen, except for last few layers of certain parts
        for param in self.vilt.parameters():
            param.requires_grad = False
        # Unfreeze for last layer of encoder
        for param in self.vilt.encoder.layer[11].parameters():
            param.requires_grad = True
        # Unfreeze for final linear pooler layer
        for param in self.vilt.pooler.parameters():
            param.requires_grad = True
        # Series of Relu -> Linear -> Dropout layers
        map_layers = [nn.Linear(768, self.map_dim), nn.Dropout(0.2)]
        for _ in range(mapping_num):
            map_layers.extend([nn.ReLU(), nn.Linear(self.map_dim, self.map_dim), nn.Dropout(0.2)])
        self.map_layers = nn.Sequential(*map_layers)
        # Final layer, 2 classes hateful or non-hateful
        self.final_linear = nn.Linear(self.map_dim, 2)
    
    def forward(self, batch):
        # Unpack the processed image and text and feed to VILT model
        input_ids = batch['input_ids'].to(self.device)
        token_type_ids = batch['token_type_ids'].to(self.device)
        attn_mask = batch['attention_mask'].to(self.device)
        pixel_vals = batch['pixel_values'].to(self.device)
        pixel_mask = batch['pixel_mask'].to(self.device)
        # Remove additional dimension at 1, specify to not remove batch during inference
        model_outs = self.vilt(input_ids.squeeze(1), attn_mask.squeeze(1), token_type_ids.squeeze(1), pixel_vals.squeeze(1), pixel_mask.squeeze(1))
        # Feed model outs to  map_layers
        map_outs = self.map_layers(model_outs.pooler_output)
        # Final output
        output = self.final_linear(map_outs)
        return output
        

In [8]:
# Freeze layers
def prepare_model(model):
    # Freeze all
    for param in model.parameters():
        param.requires_grad = False
    # Unfreeze final layers in ViLT
    for param in model.vilt.encoder.layer[-1].parameters():
        param.requires_grad = True
    for param in model.vilt.layernorm.parameters():
        param.requires_grad = True
    for param in model.vilt.pooler.parameters():
        param.requires_grad = True
    # Unfreeze for map layers and final linear layer
    for param in model.map_layers.parameters():
        param.requires_grad = True
    for param in model.final_linear.parameters():
        param.requires_grad = True

# Define Dataset

In [9]:
class HatefulMemesDataset(Dataset):
    def __init__(self, root_dir, jsonl_path, transforms=None):
        self.root_dir = root_dir
        with open(os.path.join(self.root_dir, jsonl_path), 'r') as f:
            self.jsonl = list(f)
        self.processor = ViltProcessor.from_pretrained("dandelin/vilt-b32-mlm")
        self.transforms = transforms
    
    def __len__(self):
        return len(self.jsonl)

    def __getitem__(self, i):
        json_str = self.jsonl[i]
        json_loaded = json.loads(json_str)
        img_path = os.path.join(self.root_dir, json_loaded['img'])
        label = torch.tensor(json_loaded['label'])
        caption = json_loaded['text']
        img = np.array(Image.open(img_path).convert('RGB')).transpose((2, 0, 1))
        img = tv_tensors.Image(img)
        if self.transforms is not None:
            img = self.transforms(img)
        img = np.array(img).transpose((1,2,0))
        # Need to truncate, pad to max length to ensure that no dataloader issues happen!
        processed = processor(img, caption, truncation=True, padding='max_length', return_tensors='pt')
        return {
            'input_ids': processed['input_ids'],
            'token_type_ids': processed['token_type_ids'],
            'attention_mask': processed['attention_mask'],
            'pixel_values': processed['pixel_values'],
            'pixel_mask': processed['pixel_mask'],
            'label':  label
        }

# Define Training Code

In [10]:
def finetune_one_epoch(model, dataloader, optimizer, loss_fn, device):
    model.train()
    epoch_loss = 0
    epoch_acc = 0
    for batch_sample in tqdm(dataloader):
        label = batch_sample['label'].to(device)
        pred = model(batch_sample)
        loss = loss_fn(pred, label)
        # print(total_loss)
        correct = (torch.argmax(pred, dim = 1) == label).sum()
        acc = correct / label.shape[0]
        # Backprop
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # Add loss
        epoch_loss += loss.item() * label.shape[0]
        epoch_acc += acc.item() * label.shape[0]
    final_epoch_loss = epoch_loss / len(dataloader.sampler)
    final_epoch_acc = epoch_acc / len(dataloader.sampler)
    return final_epoch_loss, final_epoch_acc, model

In [11]:
# Finetune on label directly, ignore caption
def finetune_model(model, dataloader, optimizer, loss, epochs, device):
    model.to(device)
    # Keep track of train metrics
    train_loss_ls = []
    train_acc_ls = []
    for epoch in tqdm(range(1, epochs + 1)):
        # Train model
        epoch_loss, epoch_acc, model = finetune_one_epoch(model, dataloader, optimizer, loss, device)
        train_loss_ls.append(epoch_loss)
        train_acc_ls.append(epoch_acc)
        print(f"Epoch {epoch}\n")
        print(f"\ttrain_loss: {epoch_loss}\n")
        print(f"\ttrain_acc: {epoch_acc}\n")
    return model, train_loss_ls, train_acc_ls

# Define Model and Dataloaders

In [12]:
classification_model = ViltForClassification(model, 3, 38, 'cuda')
classification_model

ViltForClassification(
  (vilt): ViltModel(
    (embeddings): ViltEmbeddings(
      (text_embeddings): TextEmbeddings(
        (word_embeddings): Embedding(30522, 768)
        (position_embeddings): Embedding(40, 768)
        (token_type_embeddings): Embedding(2, 768)
        (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
        (dropout): Dropout(p=0.0, inplace=False)
      )
      (patch_embeddings): ViltPatchEmbeddings(
        (projection): Conv2d(3, 768, kernel_size=(32, 32), stride=(32, 32))
      )
      (token_type_embeddings): Embedding(2, 768)
      (dropout): Dropout(p=0.0, inplace=False)
    )
    (encoder): ViltEncoder(
      (layer): ModuleList(
        (0-11): 12 x ViltLayer(
          (attention): ViltAttention(
            (attention): ViltSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, 

In [13]:
prepare_model(classification_model)

In [14]:
optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, classification_model.parameters()), lr  = 5e-5)
loss = nn.CrossEntropyLoss()

In [15]:
train_dataset = HatefulMemesDataset('./dataset', 'train.jsonl', transforms=v2.Compose([v2.CenterCrop((800, 800))]))
train_dataloader = DataLoader(
    train_dataset, batch_size=128
)

In [16]:
finetuned_model, train_losses, train_accs = finetune_model(classification_model, train_dataloader, optimizer, loss, 10, device)

100%|██████████| 67/67 [04:46<00:00,  4.28s/it]
 10%|█         | 1/10 [04:46<43:00, 286.69s/it]

Epoch 1

	train_loss: 0.6553485418487998

	train_acc: 0.6444705882352941



100%|██████████| 67/67 [04:40<00:00,  4.19s/it]
 20%|██        | 2/10 [09:27<37:45, 283.16s/it]

Epoch 2

	train_loss: 0.5996144534559811

	train_acc: 0.6472941176558242



100%|██████████| 67/67 [04:48<00:00,  4.31s/it]
 30%|███       | 3/10 [14:16<33:19, 285.71s/it]

Epoch 3

	train_loss: 0.5617582945543177

	train_acc: 0.6782352943139918



100%|██████████| 67/67 [04:38<00:00,  4.16s/it]
 40%|████      | 4/10 [18:54<28:17, 282.90s/it]

Epoch 4

	train_loss: 0.5322473987972035

	train_acc: 0.7272941176751081



100%|██████████| 67/67 [04:41<00:00,  4.20s/it]
 50%|█████     | 5/10 [23:35<23:31, 282.28s/it]

Epoch 5

	train_loss: 0.4926264838330886

	train_acc: 0.7638823532777674



100%|██████████| 67/67 [04:43<00:00,  4.22s/it]
 60%|██████    | 6/10 [28:18<18:50, 282.54s/it]

Epoch 6

	train_loss: 0.45623255095762366

	train_acc: 0.7917647059384514



100%|██████████| 67/67 [04:55<00:00,  4.41s/it]
 70%|███████   | 7/10 [33:14<14:20, 286.85s/it]

Epoch 7

	train_loss: 0.42330831261242136

	train_acc: 0.82035294156916



100%|██████████| 67/67 [05:12<00:00,  4.66s/it]
 80%|████████  | 8/10 [38:26<09:49, 294.88s/it]

Epoch 8

	train_loss: 0.3896473267779631

	train_acc: 0.8443529413728152



100%|██████████| 67/67 [04:53<00:00,  4.38s/it]
 90%|█████████ | 9/10 [43:19<04:54, 294.38s/it]

Epoch 9

	train_loss: 0.35273694586753845

	train_acc: 0.867411764902227



100%|██████████| 67/67 [05:18<00:00,  4.75s/it]
100%|██████████| 10/10 [48:38<00:00, 291.86s/it]

Epoch 10

	train_loss: 0.3234019802458146

	train_acc: 0.8827058826334336






# Evaluate

In [17]:
def test_model_capability(model, jsonl_path):
    labels = []
    model_probs = []
    model_preds = []
    # 0 -> Non-hateful, 1 -> Hateful
    with open(jsonl_path, 'r') as json_f:
        json_list = list(json_f)
    for json_str in tqdm(json_list):
        result = json.loads(json_str)
        img_path, label, caption = result['img'], result['label'], result['text']
        labels.append(label)
        # Read image
        batch = processor(Image.open(os.path.join('./dataset', img_path)).convert('RGB'), caption, truncation=True, padding='max_length', return_tensors='pt')
        with torch.no_grad():
            out = model(batch)
            probs = out.softmax(dim=-1).cpu().numpy()
            class_1_prob = probs[0][1]
            model_probs.append(class_1_prob)
            model_preds.append(np.argmax(probs))
    return roc_auc_score(labels, model_probs), accuracy_score(labels, model_preds)

In [18]:
# test_model_capability(finetuned_model, './dataset/dev_seen.jsonl')

In [19]:
test_model_capability(finetuned_model, './dataset/dev_unseen.jsonl')

100%|██████████| 540/540 [00:17<00:00, 31.20it/s]


(0.6508088235294118, 0.6592592592592592)

In [20]:
test_model_capability(finetuned_model, './dataset/test_unseen.jsonl')

100%|██████████| 2000/2000 [01:00<00:00, 33.27it/s]


(0.6675093333333333, 0.6815)

In [21]:
torch.save(finetuned_model.state_dict(), 'models/vilt-b32-mlm_map_num_3_10_epoch_5e-5.pt')