# ViLT Model Exploration

In [1]:
from PIL import Image
# ViltModel is a raw model with no heads. Can use to define heads
from transformers import ViltProcessor, ViltModel
from torchvision.transforms.v2 import functional as F
from torchvision.transforms import v2
from torchvision import tv_tensors

import skimage
import torch
import os
import numpy as np
import json
import torch.nn as nn

from torch.optim import Adam
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
from ipywidgets import FloatProgress
from sklearn.metrics import roc_auc_score, accuracy_score

# Define and Load Model and Processor

In [2]:
processor = ViltProcessor.from_pretrained("dandelin/vilt-b32-mlm")
model = ViltModel.from_pretrained("dandelin/vilt-b32-mlm")

In [3]:
processor.image_processor

ViltImageProcessor {
  "_valid_processor_keys": [
    "images",
    "do_resize",
    "size",
    "size_divisor",
    "resample",
    "do_rescale",
    "rescale_factor",
    "do_normalize",
    "image_mean",
    "image_std",
    "do_pad",
    "return_tensors",
    "data_format",
    "input_data_format"
  ],
  "do_normalize": true,
  "do_pad": true,
  "do_rescale": true,
  "do_resize": true,
  "image_mean": [
    0.5,
    0.5,
    0.5
  ],
  "image_processor_type": "ViltImageProcessor",
  "image_std": [
    0.5,
    0.5,
    0.5
  ],
  "resample": 3,
  "rescale_factor": 0.00392156862745098,
  "size": {
    "shortest_edge": 384
  },
  "size_divisor": 32
}

In [4]:
test_processed = processor(Image.open('./dataset/img/98724.png').convert('RGB'),'funny meme')
test_processed['input_ids']

[101, 6057, 2033, 4168, 102]

In [5]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
device

'cuda'

In [6]:
# No separate text and visual encoders, combbined into one transformer encoder that processes both text and image patches in ViT style
model

ViltModel(
  (embeddings): ViltEmbeddings(
    (text_embeddings): TextEmbeddings(
      (word_embeddings): Embedding(30522, 768)
      (position_embeddings): Embedding(40, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.0, inplace=False)
    )
    (patch_embeddings): ViltPatchEmbeddings(
      (projection): Conv2d(3, 768, kernel_size=(32, 32), stride=(32, 32))
    )
    (token_type_embeddings): Embedding(2, 768)
    (dropout): Dropout(p=0.0, inplace=False)
  )
  (encoder): ViltEncoder(
    (layer): ModuleList(
      (0-11): 12 x ViltLayer(
        (attention): ViltAttention(
          (attention): ViltSelfAttention(
            (query): Linear(in_features=768, out_features=768, bias=True)
            (key): Linear(in_features=768, out_features=768, bias=True)
            (value): Linear(in_features=768, out_features=768, bias=True)
            (dropout): Dropout(p=0.0, inplace=Fa

# Define Model with Classification Head

In [7]:
class ViltForClassification(nn.Module):
    def __init__(self, vilt_base_model, device):
        super().__init__()
        self.vilt = vilt_base_model
        self.device = device
        self.vilt.to(self.device)
        # Update all model parameters to be frozen, except for last few layers of certain parts
        for param in self.vilt.parameters():
            param.requires_grad = False
        # Unfreeze for last layer of encoder
        for param in self.vilt.encoder.layer[11].parameters():
            param.requires_grad = True
        # Unfreeze for final linear pooler layer
        for param in self.vilt.pooler.parameters():
            param.requires_grad = True
        # First linear layer
        self.relu1 = nn.ReLU()
        self.linear = nn.Linear(768, 38)
        self.relu2 = nn.ReLU()
        self.dropout = nn.Dropout(p=0.2)
        # Final layer, 2 classes hateful or non-hateful
        self.linear2 = nn.Linear(38, 2)
    
    def forward(self, batch):
        # Unpack the processed image and text and feed to VILT model
        input_ids = batch['input_ids'].to(self.device)
        token_type_ids = batch['token_type_ids'].to(self.device)
        attn_mask = batch['attention_mask'].to(self.device)
        pixel_vals = batch['pixel_values'].to(self.device)
        pixel_mask = batch['pixel_mask'].to(self.device)
        # Remove additional dimension at 1, specify to not remove batch during inference
        model_outs = self.vilt(input_ids.squeeze(1), attn_mask.squeeze(1), token_type_ids.squeeze(1), pixel_vals.squeeze(1), pixel_mask.squeeze(1))
        # Feed model outs to ReLU
        relu_model_outs = self.relu1(model_outs.pooler_output)
        # Linear layer and ReLU
        output_1 = self.linear(relu_model_outs)
        relu_output_1 = self.relu2(output_1)
        output = self.linear2(relu_output_1)
        return output
        

# Define Dataset

In [8]:
class HatefulMemesDataset(Dataset):
    def __init__(self, root_dir, jsonl_path, transforms=None):
        self.root_dir = root_dir
        with open(os.path.join(self.root_dir, jsonl_path), 'r') as f:
            self.jsonl = list(f)
        self.processor = ViltProcessor.from_pretrained("dandelin/vilt-b32-mlm")
        self.transforms = transforms
    
    def __len__(self):
        return len(self.jsonl)

    def __getitem__(self, i):
        json_str = self.jsonl[i]
        json_loaded = json.loads(json_str)
        img_path = os.path.join(self.root_dir, json_loaded['img'])
        label = torch.tensor(json_loaded['label'])
        caption = json_loaded['text']
        img = np.array(Image.open(img_path).convert('RGB')).transpose((2, 0, 1))
        img = tv_tensors.Image(img)
        if self.transforms is not None:
            img = self.transforms(img)
        img = np.array(img).transpose((1,2,0))
        # Need to truncate, pad to max length to ensure that no dataloader issues happen!
        processed = processor(img, caption, truncation=True, padding='max_length', return_tensors='pt')
        return {
            'input_ids': processed['input_ids'],
            'token_type_ids': processed['token_type_ids'],
            'attention_mask': processed['attention_mask'],
            'pixel_values': processed['pixel_values'],
            'pixel_mask': processed['pixel_mask'],
            'label':  label
        }

# Define Training Code

In [9]:
def finetune_one_epoch(model, dataloader, optimizer, loss_fn, device):
    model.train()
    epoch_loss = 0
    epoch_acc = 0
    for batch_sample in tqdm(dataloader):
        label = batch_sample['label'].to(device)
        pred = model(batch_sample)
        loss = loss_fn(pred, label)
        # print(total_loss)
        correct = (torch.argmax(pred, dim = 1) == label).sum()
        acc = correct / label.shape[0]
        # Backprop
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # Add loss
        epoch_loss += loss.item() * label.shape[0]
        epoch_acc += acc.item() * label.shape[0]
    final_epoch_loss = epoch_loss / len(dataloader.sampler)
    final_epoch_acc = epoch_acc / len(dataloader.sampler)
    return final_epoch_loss, final_epoch_acc, model

In [10]:
# Finetune on label directly, ignore caption
def finetune_model(model, dataloader, optimizer, loss, epochs, device):
    model.to(device)
    # Keep track of train metrics
    train_loss_ls = []
    train_acc_ls = []
    for epoch in tqdm(range(1, epochs + 1)):
        # Train model
        epoch_loss, epoch_acc, model = finetune_one_epoch(model, dataloader, optimizer, loss, device)
        train_loss_ls.append(epoch_loss)
        train_acc_ls.append(epoch_acc)
        print(f"Epoch {epoch}\n")
        print(f"\ttrain_loss: {epoch_loss}\n")
        print(f"\ttrain_acc: {epoch_acc}\n")
    return model, train_loss_ls, train_acc_ls

# Define Model and Dataloaders

In [11]:
classification_model = ViltForClassification(model, 'cuda')
classification_model

ViltForClassification(
  (vilt): ViltModel(
    (embeddings): ViltEmbeddings(
      (text_embeddings): TextEmbeddings(
        (word_embeddings): Embedding(30522, 768)
        (position_embeddings): Embedding(40, 768)
        (token_type_embeddings): Embedding(2, 768)
        (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
        (dropout): Dropout(p=0.0, inplace=False)
      )
      (patch_embeddings): ViltPatchEmbeddings(
        (projection): Conv2d(3, 768, kernel_size=(32, 32), stride=(32, 32))
      )
      (token_type_embeddings): Embedding(2, 768)
      (dropout): Dropout(p=0.0, inplace=False)
    )
    (encoder): ViltEncoder(
      (layer): ModuleList(
        (0-11): 12 x ViltLayer(
          (attention): ViltAttention(
            (attention): ViltSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, 

In [12]:
optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr  = 5e-5)
loss = nn.CrossEntropyLoss()

In [13]:
train_dataset = HatefulMemesDataset('./dataset', 'train.jsonl', transforms=v2.Compose([v2.CenterCrop((800, 800))]))
train_dataloader = DataLoader(
    train_dataset, batch_size=128
)

In [14]:
finetuned_model, train_losses, train_accs = finetune_model(classification_model, train_dataloader, optimizer, loss, 10, device)

100%|██████████| 67/67 [06:08<00:00,  5.50s/it]
 10%|█         | 1/10 [06:08<55:14, 368.31s/it]

Epoch 1

	train_loss: 0.6290790331784417

	train_acc: 0.655058823557461



100%|██████████| 67/67 [05:10<00:00,  4.63s/it]
 20%|██        | 2/10 [11:18<44:33, 334.22s/it]

Epoch 2

	train_loss: 0.5641836188260246

	train_acc: 0.7121176471710206



100%|██████████| 67/67 [05:16<00:00,  4.73s/it]
 30%|███       | 3/10 [16:35<38:03, 326.18s/it]

Epoch 3

	train_loss: 0.5215061726009145

	train_acc: 0.7436470588235294



100%|██████████| 67/67 [05:53<00:00,  5.27s/it]
 40%|████      | 4/10 [22:28<33:41, 336.88s/it]

Epoch 4

	train_loss: 0.48792184880200556

	train_acc: 0.7732941177312066



100%|██████████| 67/67 [05:13<00:00,  4.68s/it]
 50%|█████     | 5/10 [27:42<27:22, 328.59s/it]

Epoch 5

	train_loss: 0.45569926343244666

	train_acc: 0.7967058826895321



100%|██████████| 67/67 [05:16<00:00,  4.72s/it]
 60%|██████    | 6/10 [32:58<21:37, 324.38s/it]

Epoch 6

	train_loss: 0.42062176327144396

	train_acc: 0.8238823531655705



100%|██████████| 67/67 [05:02<00:00,  4.51s/it]
 70%|███████   | 7/10 [38:00<15:51, 317.15s/it]

Epoch 7

	train_loss: 0.38380383689263287

	train_acc: 0.8494117649302763



100%|██████████| 67/67 [04:55<00:00,  4.41s/it]
 80%|████████  | 8/10 [42:56<10:20, 310.36s/it]

Epoch 8

	train_loss: 0.34378311850042903

	train_acc: 0.8743529412886676



100%|██████████| 67/67 [04:58<00:00,  4.45s/it]
 90%|█████████ | 9/10 [47:54<05:06, 306.52s/it]

Epoch 9

	train_loss: 0.3059117962893318

	train_acc: 0.8961176472551683



100%|██████████| 67/67 [04:58<00:00,  4.46s/it]
100%|██████████| 10/10 [52:53<00:00, 317.37s/it]

Epoch 10

	train_loss: 0.2688689673157299

	train_acc: 0.9167058826334337






In [15]:
def test_model_capability(model, jsonl_path):
    labels = []
    model_probs = []
    model_preds = []
    # 0 -> Non-hateful, 1 -> Hateful
    with open(jsonl_path, 'r') as json_f:
        json_list = list(json_f)
    for json_str in tqdm(json_list):
        result = json.loads(json_str)
        img_path, label, caption = result['img'], result['label'], result['text']
        labels.append(label)
        # Read image
        batch = processor(Image.open(os.path.join('./dataset', img_path)).convert('RGB'), caption, truncation=True, padding='max_length', return_tensors='pt')
        with torch.no_grad():
            out = model(batch)
            probs = out.softmax(dim=-1).cpu().numpy()
            class_1_prob = probs[0][1]
            model_probs.append(class_1_prob)
            model_preds.append(np.argmax(probs))
    return roc_auc_score(labels, model_probs), accuracy_score(labels, model_preds)

In [16]:
test_model_capability(finetuned_model, './dataset/dev_seen.jsonl')

  0%|          | 0/500 [00:00<?, ?it/s]

100%|██████████| 500/500 [00:20<00:00, 24.62it/s]


(0.6708005952857211, 0.602)