# ViLT Model Zero-Shot Assessment

In [20]:
from PIL import Image
# ViltModel is the raw model with no heads; it can be used as a base for defining custom heads.
from transformers import ViltProcessor, ViltForImageAndTextRetrieval
from torchvision.transforms.v2 import functional as F
from torchvision.transforms import v2
from torchvision import tv_tensors

import skimage
import torch
import os
import numpy as np
import json
import torch.nn as nn

from torch.optim import Adam
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
from ipywidgets import FloatProgress
from sklearn.metrics import roc_auc_score, accuracy_score

# Load Model

In [21]:
# Load the processor and the retrieval model from one shared checkpoint id so
# the preprocessing can never drift out of sync with the weights.
VILT_CHECKPOINT = "dandelin/vilt-b32-finetuned-coco"
processor = ViltProcessor.from_pretrained(VILT_CHECKPOINT)
model = ViltForImageAndTextRetrieval.from_pretrained(VILT_CHECKPOINT)

In [22]:
# Prefer the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
device

'cuda'

In [23]:
# Move the model onto the selected device. As the last expression in the cell,
# the returned module is also displayed (its layer-by-layer repr).
model.to(device)

ViltForImageAndTextRetrieval(
  (vilt): ViltModel(
    (embeddings): ViltEmbeddings(
      (text_embeddings): TextEmbeddings(
        (word_embeddings): Embedding(30522, 768)
        (position_embeddings): Embedding(40, 768)
        (token_type_embeddings): Embedding(2, 768)
        (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
        (dropout): Dropout(p=0.0, inplace=False)
      )
      (patch_embeddings): ViltPatchEmbeddings(
        (projection): Conv2d(3, 768, kernel_size=(32, 32), stride=(32, 32))
      )
      (token_type_embeddings): Embedding(2, 768)
      (dropout): Dropout(p=0.0, inplace=False)
    )
    (encoder): ViltEncoder(
      (layer): ModuleList(
        (0-11): 12 x ViltLayer(
          (attention): ViltAttention(
            (attention): ViltSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_feature

# Run Inference and Obtain Results

In [24]:
def test_vilt_model_zero_shot_capability(model, jsonl_path, image_root='./dataset'):
    """Evaluate an image-text matching model as a zero-shot hateful/non-hateful classifier.

    Every image listed in ``jsonl_path`` is scored twice by ``model``: once
    paired with the text ``'hateful'`` and once with ``'non-hateful'``. The two
    matching logits are softmaxed into a two-class distribution
    (index 0 -> non-hateful, index 1 -> hateful).

    Inputs are built with the notebook-level ``processor`` and moved to the
    device the model's parameters live on, so the function works on both CPU
    and GPU.

    Args:
        model: Model called as ``model(**encoding)`` whose output exposes a
            single matching score through ``.logits``.
        jsonl_path: Path to a JSON-lines file; every non-blank line must be an
            object with an ``'img'`` key (image path relative to
            ``image_root``) and a ``'label'`` key (0 = non-hateful, 1 = hateful).
        image_root: Directory the ``'img'`` paths are resolved against.
            Defaults to ``'./dataset'``.

    Returns:
        tuple: ``(roc_auc, accuracy)`` where the ROC AUC is computed from the
        hateful-class probabilities and the accuracy from the arg-max class.

    Raises:
        Whatever ``open``, ``json.loads``, ``Image.open`` or the sklearn
        metrics raise, e.g. for a missing file, a malformed record, or a label
        list on which the ROC AUC is undefined.
    """
    # Follow the model's own device instead of assuming a GPU is present.
    device = next(model.parameters()).device

    with open(jsonl_path, 'r') as json_f:
        # Ignore blank lines (e.g. a trailing newline) that json.loads would reject.
        json_list = [line for line in json_f if line.strip()]

    labels = []
    model_probs = []
    model_preds = []
    for json_str in tqdm(json_list):
        result = json.loads(json_str)
        img_path, label = result['img'], result['label']
        labels.append(label)

        # Close the underlying file as soon as the pixels have been decoded.
        with Image.open(os.path.join(image_root, img_path)) as raw_image:
            image = raw_image.convert('RGB')

        hateful_encoding = processor(image, 'hateful', return_tensors='pt').to(device)
        non_hateful_encoding = processor(image, 'non-hateful', return_tensors='pt').to(device)

        with torch.no_grad():
            hateful_logit = model(**hateful_encoding).logits[0, :].item()
            non_hateful_logit = model(**non_hateful_encoding).logits[0, :].item()

        # Class order: 0 -> non-hateful, 1 -> hateful.
        class_probs = torch.softmax(torch.tensor([non_hateful_logit, hateful_logit]), dim=0)
        model_probs.append(class_probs[1].item())
        # Ties resolve to class 0, matching arg-max's first-maximum behaviour.
        model_preds.append(int(hateful_logit > non_hateful_logit))

    return roc_auc_score(labels, model_probs), accuracy_score(labels, model_preds)


In [25]:
# Zero-shot metrics on the unseen dev split.
dev_roc, dev_acc = test_vilt_model_zero_shot_capability(
    model,
    './dataset/dev_unseen.jsonl',
)
print(f'Dev ROC: {dev_roc}', f'Dev Accuracy: {dev_acc}', sep='\n')

  0%|          | 0/540 [00:00<?, ?it/s]

100%|██████████| 540/540 [00:37<00:00, 14.28it/s]

Dev ROC: 0.5244117647058824
Dev Accuracy: 0.5574074074074075





In [26]:
# Zero-shot metrics on the unseen test split.
test_roc, test_acc = test_vilt_model_zero_shot_capability(
    model,
    './dataset/test_unseen.jsonl',
)
print(f'Test ROC: {test_roc}', f'Test Accuracy: {test_acc}', sep='\n')

100%|██████████| 2000/2000 [02:12<00:00, 15.10it/s]

Test ROC: 0.4672501333333333
Test Accuracy: 0.514



