# Finetuning CLIP Model

In [1]:
import torch
import os
import numpy as np
# import clip
from PIL import Image
import json
import skimage
import torch.nn as nn

from transformers import AutoProcessor, CLIPModel
from torch.utils.data import Dataset, DataLoader
from torchvision.transforms import v2
from torchvision import tv_tensors

from tqdm import tqdm

device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Device is {device}")
# model, preprocess = clip.load("ViT-B/32", device=device, jit=False)

Device is cuda


In [2]:
model = CLIPModel.from_pretrained('openai/clip-vit-large-patch14')
processor = AutoProcessor.from_pretrained('openai/clip-vit-large-patch14')
model.to(device)

CLIPModel(
  (text_model): CLIPTextTransformer(
    (embeddings): CLIPTextEmbeddings(
      (token_embedding): Embedding(49408, 768)
      (position_embedding): Embedding(77, 768)
    )
    (encoder): CLIPEncoder(
      (layers): ModuleList(
        (0-11): 12 x CLIPEncoderLayer(
          (self_attn): CLIPAttention(
            (k_proj): Linear(in_features=768, out_features=768, bias=True)
            (v_proj): Linear(in_features=768, out_features=768, bias=True)
            (q_proj): Linear(in_features=768, out_features=768, bias=True)
            (out_proj): Linear(in_features=768, out_features=768, bias=True)
          )
          (layer_norm1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
          (mlp): CLIPMLP(
            (activation_fn): QuickGELUActivation()
            (fc1): Linear(in_features=768, out_features=3072, bias=True)
            (fc2): Linear(in_features=3072, out_features=768, bias=True)
          )
          (layer_norm2): LayerNorm((768,), eps=1e-05,

# Define Dataset and Dataloader

In [3]:
class HatefulMemesDataset(Dataset):
    def __init__(self, root_dir, jsonl_path, transforms=None):
        self.root_dir = root_dir
        with open(os.path.join(self.root_dir, jsonl_path), 'r') as f:
            self.jsonl = list(f)
        self.transforms = transforms
    
    def __len__(self):
        return len(self.jsonl)

    def __getitem__(self, i):
        json_str = self.jsonl[i]
        json_loaded = json.loads(json_str)
        img_path = os.path.join(self.root_dir, json_loaded['img'])
        label = torch.tensor(json_loaded['label'])
        # caption = json_loaded['text']
        # Directly use non-hateful and hateful as the text we are classifying image to
        # text = clip.tokenize(["non-hateful", "hateful"])
        img = np.array(Image.open(img_path).convert('RGB')).transpose((2, 0, 1))
        # img = preprocess(img)
        # Wrap in tv_tensors
        img = tv_tensors.Image(img)
        if self.transforms is not None:
            img = self.transforms(img)
        return img, label

# Finetuning code

Finetune without changing any part of the model

## Finetune Functions

In [4]:
model

CLIPModel(
  (text_model): CLIPTextTransformer(
    (embeddings): CLIPTextEmbeddings(
      (token_embedding): Embedding(49408, 768)
      (position_embedding): Embedding(77, 768)
    )
    (encoder): CLIPEncoder(
      (layers): ModuleList(
        (0-11): 12 x CLIPEncoderLayer(
          (self_attn): CLIPAttention(
            (k_proj): Linear(in_features=768, out_features=768, bias=True)
            (v_proj): Linear(in_features=768, out_features=768, bias=True)
            (q_proj): Linear(in_features=768, out_features=768, bias=True)
            (out_proj): Linear(in_features=768, out_features=768, bias=True)
          )
          (layer_norm1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
          (mlp): CLIPMLP(
            (activation_fn): QuickGELUActivation()
            (fc1): Linear(in_features=768, out_features=3072, bias=True)
            (fc2): Linear(in_features=3072, out_features=768, bias=True)
          )
          (layer_norm2): LayerNorm((768,), eps=1e-05,

In [5]:
# Freeze layers
def prepare_model(model):
    # Freeze all
    for param in model.parameters():
        param.requires_grad = False
    # Unfreeze layers in image and text embedding transformer networks
    for param in model.vision_model.encoder.layers[-1].parameters():
        param.requires_grad = True
    for param in model.text_model.encoder.layers[-1].parameters():
        param.requires_grad = True
    # Final layers unfreeze
    for param in model.visual_projection.parameters():
        param.requires_grad = True
    for param in model.text_projection.parameters():
        param.requires_grad = True
    # Layer norm layers unfreeze
    for param in model.vision_model.post_layernorm.parameters():
        param.requires_grad = True
    for param in model.text_model.final_layer_norm.parameters():
        param.requires_grad = True
    # # Unfreeze layers in image and text embedding transformer networks
    # for param in model.visual.transformer.resblocks[11].parameters():
    #     param.requires_grad = True
    # for param in model.transformer.resblocks[11].parameters():
    #     param.requires_grad = True
    # # Unfreeze model token embedding layer
    # for param in model.token_embedding.parameters():
    #     param.requires_grad = True
    # # Layer norm layers unfreeze
    # for param in model.visual.ln_post.parameters():
    #     param.requires_grad = True
    # for param in model.ln_final.parameters():
    #     param.requires_grad = True

In [6]:
prepare_model(model)
# Model needs to be in float, otherwise NAN for loss and training does not work beyond first iteration
# model.float()

In [7]:
optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr  = 1e-4)
loss = nn.CrossEntropyLoss()

In [8]:
def finetune_clip_one_epoch(model, processor, dataloader, optimizer, loss, device):
    model.train()
    epoch_loss = 0
    epoch_acc = 0
    for batch_sample in tqdm(dataloader):
        img, label = batch_sample
        # Classify for non-hateful vs hateful
        inputs = processor(text=["non-hateful", "hateful"], images = img, return_tensors='pt', padding=True).to(device)
        label = label.to(device)
        # Get logits
        outputs = model(**inputs)
        logit_image = outputs.logits_per_image
        logit_text = outputs.logits_per_text
        # Get individual and total losses
        image_loss = loss(logit_image, label)
        text_loss = loss(torch.transpose(logit_text, 0, 1), label)
        total_loss = (image_loss + text_loss) / 2
        # Get the number of correct predictions
        correct = (torch.argmax(logit_image, dim = 1) == label).sum()
        acc = correct / img.shape[0]
        # Backprop
        optimizer.zero_grad()
        total_loss.backward()
        optimizer.step()
        # Add loss
        epoch_loss += total_loss.item() * img.shape[0]
        epoch_acc += acc.item() * img.shape[0]
    final_epoch_loss = epoch_loss / len(dataloader.sampler)
    final_epoch_acc = epoch_acc / len(dataloader.sampler)
    return final_epoch_loss, final_epoch_acc, model

In [9]:
# Finetune on label directly, ignore caption
def finetune_model(model, processor, dataloader, optimizer, loss, epochs, device):
    model.to(device)
    # Keep track of train metrics
    train_loss_ls = []
    train_acc_ls = []
    # text = clip.tokenize(["non-hateful", "hateful"])
    for epoch in tqdm(range(1, epochs + 1)):
        # Train model
        epoch_loss, epoch_acc, model = finetune_clip_one_epoch(model, processor, dataloader, optimizer, loss, device)
        train_loss_ls.append(epoch_loss)
        train_acc_ls.append(epoch_acc)
        print(f"Epoch {epoch}\n")
        print(f"\ttrain_loss: {epoch_loss}\n")
        print(f"\ttrain_acc: {epoch_acc}\n")
    return model, train_loss_ls, train_acc_ls


In [10]:
train_dataset = HatefulMemesDataset('./dataset', 'train.jsonl', transforms=v2.Compose([v2.CenterCrop((800, 800))]))

In [11]:
len(train_dataset)

8500

In [12]:
train_dataloader = DataLoader(
    train_dataset, batch_size=64
)

In [13]:
train_dataloader

<torch.utils.data.dataloader.DataLoader at 0x21915db6c90>

## Finetune

In [14]:
finetuned_model, train_losses, train_accs = finetune_model(model, processor, train_dataloader, optimizer, loss, 10, device)

  0%|          | 0/10 [00:00<?, ?it/s]Unused or unrecognized kwargs: padding.
Unused or unrecognized kwargs: padding.
Unused or unrecognized kwargs: padding.
Unused or unrecognized kwargs: padding.
Unused or unrecognized kwargs: padding.
Unused or unrecognized kwargs: padding.
Unused or unrecognized kwargs: padding.
Unused or unrecognized kwargs: padding.
Unused or unrecognized kwargs: padding.
Unused or unrecognized kwargs: padding.
Unused or unrecognized kwargs: padding.
Unused or unrecognized kwargs: padding.
Unused or unrecognized kwargs: padding.
Unused or unrecognized kwargs: padding.
Unused or unrecognized kwargs: padding.
Unused or unrecognized kwargs: padding.
Unused or unrecognized kwargs: padding.
Unused or unrecognized kwargs: padding.
Unused or unrecognized kwargs: padding.
Unused or unrecognized kwargs: padding.
Unused or unrecognized kwargs: padding.
Unused or unrecognized kwargs: padding.
Unused or unrecognized kwargs: padding.
Unused or unrecognized kwargs: padding.
Un

Epoch 1

	train_loss: 0.5101710465445238

	train_acc: 0.764117647339316



Unused or unrecognized kwargs: padding.
Unused or unrecognized kwargs: padding.
Unused or unrecognized kwargs: padding.
Unused or unrecognized kwargs: padding.
Unused or unrecognized kwargs: padding.
Unused or unrecognized kwargs: padding.
Unused or unrecognized kwargs: padding.
Unused or unrecognized kwargs: padding.
Unused or unrecognized kwargs: padding.
Unused or unrecognized kwargs: padding.
Unused or unrecognized kwargs: padding.
Unused or unrecognized kwargs: padding.
Unused or unrecognized kwargs: padding.
Unused or unrecognized kwargs: padding.
Unused or unrecognized kwargs: padding.
Unused or unrecognized kwargs: padding.
Unused or unrecognized kwargs: padding.
Unused or unrecognized kwargs: padding.
Unused or unrecognized kwargs: padding.
Unused or unrecognized kwargs: padding.
Unused or unrecognized kwargs: padding.
Unused or unrecognized kwargs: padding.
Unused or unrecognized kwargs: padding.
Unused or unrecognized kwargs: padding.
Unused or unrecognized kwargs: padding.


Epoch 2

	train_loss: 0.3131133459911627

	train_acc: 0.8605882354904624



Unused or unrecognized kwargs: padding.
Unused or unrecognized kwargs: padding.
Unused or unrecognized kwargs: padding.
Unused or unrecognized kwargs: padding.
Unused or unrecognized kwargs: padding.
Unused or unrecognized kwargs: padding.
Unused or unrecognized kwargs: padding.
Unused or unrecognized kwargs: padding.
Unused or unrecognized kwargs: padding.
Unused or unrecognized kwargs: padding.
Unused or unrecognized kwargs: padding.
Unused or unrecognized kwargs: padding.
Unused or unrecognized kwargs: padding.
Unused or unrecognized kwargs: padding.
Unused or unrecognized kwargs: padding.
Unused or unrecognized kwargs: padding.
Unused or unrecognized kwargs: padding.
Unused or unrecognized kwargs: padding.
Unused or unrecognized kwargs: padding.
Unused or unrecognized kwargs: padding.
Unused or unrecognized kwargs: padding.
Unused or unrecognized kwargs: padding.
Unused or unrecognized kwargs: padding.
Unused or unrecognized kwargs: padding.
Unused or unrecognized kwargs: padding.


Epoch 3

	train_loss: 0.13155179732950295

	train_acc: 0.951764705882353



Unused or unrecognized kwargs: padding.
Unused or unrecognized kwargs: padding.
Unused or unrecognized kwargs: padding.
Unused or unrecognized kwargs: padding.
Unused or unrecognized kwargs: padding.
Unused or unrecognized kwargs: padding.
Unused or unrecognized kwargs: padding.
Unused or unrecognized kwargs: padding.
Unused or unrecognized kwargs: padding.
Unused or unrecognized kwargs: padding.
Unused or unrecognized kwargs: padding.
Unused or unrecognized kwargs: padding.
Unused or unrecognized kwargs: padding.
Unused or unrecognized kwargs: padding.
Unused or unrecognized kwargs: padding.
Unused or unrecognized kwargs: padding.
Unused or unrecognized kwargs: padding.
Unused or unrecognized kwargs: padding.
Unused or unrecognized kwargs: padding.
Unused or unrecognized kwargs: padding.
Unused or unrecognized kwargs: padding.
Unused or unrecognized kwargs: padding.
Unused or unrecognized kwargs: padding.
Unused or unrecognized kwargs: padding.
Unused or unrecognized kwargs: padding.


Epoch 4

	train_loss: 0.10794556440763614

	train_acc: 0.9609411764705882



Unused or unrecognized kwargs: padding.
Unused or unrecognized kwargs: padding.
Unused or unrecognized kwargs: padding.
Unused or unrecognized kwargs: padding.
Unused or unrecognized kwargs: padding.
Unused or unrecognized kwargs: padding.
Unused or unrecognized kwargs: padding.
Unused or unrecognized kwargs: padding.
Unused or unrecognized kwargs: padding.
Unused or unrecognized kwargs: padding.
Unused or unrecognized kwargs: padding.
Unused or unrecognized kwargs: padding.
Unused or unrecognized kwargs: padding.
Unused or unrecognized kwargs: padding.
Unused or unrecognized kwargs: padding.
Unused or unrecognized kwargs: padding.
Unused or unrecognized kwargs: padding.
Unused or unrecognized kwargs: padding.
Unused or unrecognized kwargs: padding.
Unused or unrecognized kwargs: padding.
Unused or unrecognized kwargs: padding.
Unused or unrecognized kwargs: padding.
Unused or unrecognized kwargs: padding.
Unused or unrecognized kwargs: padding.
Unused or unrecognized kwargs: padding.


Epoch 5

	train_loss: 0.10262114322886748

	train_acc: 0.9627058824651381



Unused or unrecognized kwargs: padding.
Unused or unrecognized kwargs: padding.
Unused or unrecognized kwargs: padding.
Unused or unrecognized kwargs: padding.
Unused or unrecognized kwargs: padding.
Unused or unrecognized kwargs: padding.
Unused or unrecognized kwargs: padding.
Unused or unrecognized kwargs: padding.
Unused or unrecognized kwargs: padding.
Unused or unrecognized kwargs: padding.
Unused or unrecognized kwargs: padding.
Unused or unrecognized kwargs: padding.
Unused or unrecognized kwargs: padding.
Unused or unrecognized kwargs: padding.
Unused or unrecognized kwargs: padding.
Unused or unrecognized kwargs: padding.
Unused or unrecognized kwargs: padding.
Unused or unrecognized kwargs: padding.
Unused or unrecognized kwargs: padding.
Unused or unrecognized kwargs: padding.
Unused or unrecognized kwargs: padding.
Unused or unrecognized kwargs: padding.
Unused or unrecognized kwargs: padding.
Unused or unrecognized kwargs: padding.
Unused or unrecognized kwargs: padding.


Epoch 6

	train_loss: 0.08013549749290241

	train_acc: 0.971411764902227



Unused or unrecognized kwargs: padding.
Unused or unrecognized kwargs: padding.
Unused or unrecognized kwargs: padding.
Unused or unrecognized kwargs: padding.
Unused or unrecognized kwargs: padding.
Unused or unrecognized kwargs: padding.
Unused or unrecognized kwargs: padding.
Unused or unrecognized kwargs: padding.
Unused or unrecognized kwargs: padding.
Unused or unrecognized kwargs: padding.
Unused or unrecognized kwargs: padding.
Unused or unrecognized kwargs: padding.
Unused or unrecognized kwargs: padding.
Unused or unrecognized kwargs: padding.
Unused or unrecognized kwargs: padding.
Unused or unrecognized kwargs: padding.
Unused or unrecognized kwargs: padding.
Unused or unrecognized kwargs: padding.
Unused or unrecognized kwargs: padding.
Unused or unrecognized kwargs: padding.
Unused or unrecognized kwargs: padding.
Unused or unrecognized kwargs: padding.
Unused or unrecognized kwargs: padding.
Unused or unrecognized kwargs: padding.
Unused or unrecognized kwargs: padding.


Epoch 7

	train_loss: 0.01968815822296721

	train_acc: 0.9936470588235294



Unused or unrecognized kwargs: padding.
Unused or unrecognized kwargs: padding.
Unused or unrecognized kwargs: padding.
Unused or unrecognized kwargs: padding.
Unused or unrecognized kwargs: padding.
Unused or unrecognized kwargs: padding.
Unused or unrecognized kwargs: padding.
Unused or unrecognized kwargs: padding.
Unused or unrecognized kwargs: padding.
Unused or unrecognized kwargs: padding.
Unused or unrecognized kwargs: padding.
Unused or unrecognized kwargs: padding.
Unused or unrecognized kwargs: padding.
Unused or unrecognized kwargs: padding.
Unused or unrecognized kwargs: padding.
Unused or unrecognized kwargs: padding.
Unused or unrecognized kwargs: padding.
Unused or unrecognized kwargs: padding.
Unused or unrecognized kwargs: padding.
Unused or unrecognized kwargs: padding.
Unused or unrecognized kwargs: padding.
Unused or unrecognized kwargs: padding.
Unused or unrecognized kwargs: padding.
Unused or unrecognized kwargs: padding.
Unused or unrecognized kwargs: padding.


Epoch 8

	train_loss: 0.011856869134175427

	train_acc: 0.9974117647058823



Unused or unrecognized kwargs: padding.
Unused or unrecognized kwargs: padding.
Unused or unrecognized kwargs: padding.
Unused or unrecognized kwargs: padding.
Unused or unrecognized kwargs: padding.
Unused or unrecognized kwargs: padding.
Unused or unrecognized kwargs: padding.
Unused or unrecognized kwargs: padding.
Unused or unrecognized kwargs: padding.
Unused or unrecognized kwargs: padding.
Unused or unrecognized kwargs: padding.
Unused or unrecognized kwargs: padding.
Unused or unrecognized kwargs: padding.
Unused or unrecognized kwargs: padding.
Unused or unrecognized kwargs: padding.
Unused or unrecognized kwargs: padding.
Unused or unrecognized kwargs: padding.
Unused or unrecognized kwargs: padding.
Unused or unrecognized kwargs: padding.
Unused or unrecognized kwargs: padding.
Unused or unrecognized kwargs: padding.
Unused or unrecognized kwargs: padding.
Unused or unrecognized kwargs: padding.
Unused or unrecognized kwargs: padding.
Unused or unrecognized kwargs: padding.


Epoch 9

	train_loss: 0.003607394662877435

	train_acc: 0.9992941176470588



Unused or unrecognized kwargs: padding.
Unused or unrecognized kwargs: padding.
Unused or unrecognized kwargs: padding.
Unused or unrecognized kwargs: padding.
Unused or unrecognized kwargs: padding.
Unused or unrecognized kwargs: padding.
Unused or unrecognized kwargs: padding.
Unused or unrecognized kwargs: padding.
Unused or unrecognized kwargs: padding.
Unused or unrecognized kwargs: padding.
Unused or unrecognized kwargs: padding.
Unused or unrecognized kwargs: padding.
Unused or unrecognized kwargs: padding.
Unused or unrecognized kwargs: padding.
Unused or unrecognized kwargs: padding.
Unused or unrecognized kwargs: padding.
Unused or unrecognized kwargs: padding.
Unused or unrecognized kwargs: padding.
Unused or unrecognized kwargs: padding.
Unused or unrecognized kwargs: padding.
Unused or unrecognized kwargs: padding.
Unused or unrecognized kwargs: padding.
Unused or unrecognized kwargs: padding.
Unused or unrecognized kwargs: padding.
Unused or unrecognized kwargs: padding.


Epoch 10

	train_loss: 0.0015606611585929331

	train_acc: 0.9996470588235294






# Assess on dev_unseen

In [15]:
from sklearn.metrics import roc_auc_score, accuracy_score

In [16]:
def test_model_zero_shot_capability(model, processor, device, jsonl_path):
    labels = []
    model_probs = []
    model_preds = []
    # 0 -> Non-hateful, 1 -> Hateful
    # text = clip.tokenize(["non-hateful", "hateful"]).to(device)
    with open(jsonl_path, 'r') as json_f:
        json_list = list(json_f)
    for json_str in tqdm(json_list):
        result = json.loads(json_str)
        img_path, label = result['img'], result['label']
        labels.append(label)
        # image = preprocess(Image.open(os.path.join('./dataset', img_path))).unsqueeze(0).to(device)
        with torch.no_grad():
            inputs = processor(text=['non-hateful', 'hateful'], images=Image.open(os.path.join('./dataset', img_path)), return_tensors='pt', padding=True).to(device)
            outputs = model(**inputs)
            logits_per_image = outputs.logits_per_image
            probs = logits_per_image.softmax(dim=-1).cpu().numpy()
            class_1_prob = probs[0][1]
            model_probs.append(class_1_prob)
            model_preds.append(np.argmax(probs))
    return roc_auc_score(labels, model_probs), accuracy_score(labels, model_preds)


In [18]:
# Dev ROC: 0.678470588235294
# Dev Accuracy: 0.6648148148148149
dev_roc, dev_acc = test_model_zero_shot_capability(finetuned_model, processor, device, './dataset/dev_unseen.jsonl')
print(f'Dev ROC: {dev_roc}')
print(f'Dev Accuracy: {dev_acc}')

  0%|          | 0/540 [00:00<?, ?it/s]Unused or unrecognized kwargs: padding.
  0%|          | 1/540 [00:00<03:17,  2.73it/s]Unused or unrecognized kwargs: padding.
Unused or unrecognized kwargs: padding.
  1%|          | 3/540 [00:00<01:11,  7.53it/s]Unused or unrecognized kwargs: padding.
Unused or unrecognized kwargs: padding.
  1%|          | 5/540 [00:00<00:50, 10.61it/s]Unused or unrecognized kwargs: padding.
Unused or unrecognized kwargs: padding.
  1%|▏         | 7/540 [00:00<00:42, 12.64it/s]Unused or unrecognized kwargs: padding.
Unused or unrecognized kwargs: padding.
  2%|▏         | 9/540 [00:00<00:37, 14.06it/s]Unused or unrecognized kwargs: padding.
Unused or unrecognized kwargs: padding.
  2%|▏         | 11/540 [00:00<00:34, 15.18it/s]Unused or unrecognized kwargs: padding.
Unused or unrecognized kwargs: padding.
  2%|▏         | 13/540 [00:01<00:32, 16.38it/s]Unused or unrecognized kwargs: padding.
Unused or unrecognized kwargs: padding.
  3%|▎         | 15/540 [00:01

Dev ROC: 0.7433088235294117
Dev Accuracy: 0.7055555555555556





In [19]:
# Dev ROC: 0.700788913603559
# Dev Accuracy: 0.636
dev_roc, dev_acc = test_model_zero_shot_capability(finetuned_model, processor, device, './dataset/dev_seen.jsonl')
print(f'Dev ROC: {dev_roc}')
print(f'Dev Accuracy: {dev_acc}')

  0%|          | 0/500 [00:00<?, ?it/s]Unused or unrecognized kwargs: padding.


  0%|          | 1/500 [00:00<02:33,  3.25it/s]Unused or unrecognized kwargs: padding.
Unused or unrecognized kwargs: padding.
  1%|          | 3/500 [00:00<00:58,  8.44it/s]Unused or unrecognized kwargs: padding.
Unused or unrecognized kwargs: padding.
Unused or unrecognized kwargs: padding.
  1%|          | 6/500 [00:00<00:37, 13.29it/s]Unused or unrecognized kwargs: padding.
Unused or unrecognized kwargs: padding.
Unused or unrecognized kwargs: padding.
  2%|▏         | 9/500 [00:00<00:30, 15.93it/s]Unused or unrecognized kwargs: padding.
Unused or unrecognized kwargs: padding.
Unused or unrecognized kwargs: padding.
  2%|▏         | 12/500 [00:00<00:26, 18.25it/s]Unused or unrecognized kwargs: padding.
Unused or unrecognized kwargs: padding.
Unused or unrecognized kwargs: padding.
  3%|▎         | 15/500 [00:00<00:25, 19.32it/s]Unused or unrecognized kwargs: padding.
Unused or unrecognized kwargs: padding.
Unused or unrecognized kwargs: padding.
  4%|▎         | 18/500 [00:01<00:24

Dev ROC: 0.7416908034757005
Dev Accuracy: 0.634





In [20]:
# test ROC: 0.682078831532613
# test Accuracy: 0.615
test_roc, test_acc = test_model_zero_shot_capability(finetuned_model, processor, device, './dataset/test_seen.jsonl')
print(f'test ROC: {test_roc}')
print(f'test Accuracy: {test_acc}')

  0%|          | 0/1000 [00:00<?, ?it/s]Unused or unrecognized kwargs: padding.


  0%|          | 1/1000 [00:00<05:13,  3.19it/s]Unused or unrecognized kwargs: padding.
Unused or unrecognized kwargs: padding.
  0%|          | 3/1000 [00:00<02:02,  8.14it/s]Unused or unrecognized kwargs: padding.
Unused or unrecognized kwargs: padding.
  0%|          | 5/1000 [00:00<01:28, 11.22it/s]Unused or unrecognized kwargs: padding.
Unused or unrecognized kwargs: padding.
  1%|          | 7/1000 [00:00<01:15, 13.17it/s]Unused or unrecognized kwargs: padding.
Unused or unrecognized kwargs: padding.
  1%|          | 9/1000 [00:00<01:05, 15.11it/s]Unused or unrecognized kwargs: padding.
Unused or unrecognized kwargs: padding.
Unused or unrecognized kwargs: padding.
  1%|          | 12/1000 [00:00<00:57, 17.10it/s]Unused or unrecognized kwargs: padding.
Unused or unrecognized kwargs: padding.
Unused or unrecognized kwargs: padding.
  2%|▏         | 15/1000 [00:01<00:54, 18.20it/s]Unused or unrecognized kwargs: padding.
Unused or unrecognized kwargs: padding.
  2%|▏         | 17/10

test ROC: 0.7694037615046019
test Accuracy: 0.64





In [21]:
# test ROC: 0.7010842666666666
# test Accuracy: 0.675
test_roc, test_acc = test_model_zero_shot_capability(finetuned_model, processor, device, './dataset/test_unseen.jsonl')
print(f'test ROC: {test_roc}')
print(f'test Accuracy: {test_acc}')

  0%|          | 0/2000 [00:00<?, ?it/s]Unused or unrecognized kwargs: padding.
Unused or unrecognized kwargs: padding.
Unused or unrecognized kwargs: padding.
  0%|          | 3/2000 [00:00<01:35, 20.87it/s]Unused or unrecognized kwargs: padding.
Unused or unrecognized kwargs: padding.
Unused or unrecognized kwargs: padding.
  0%|          | 6/2000 [00:00<01:46, 18.79it/s]Unused or unrecognized kwargs: padding.
Unused or unrecognized kwargs: padding.
  0%|          | 8/2000 [00:00<01:45, 18.80it/s]Unused or unrecognized kwargs: padding.
Unused or unrecognized kwargs: padding.
  0%|          | 10/2000 [00:00<01:46, 18.65it/s]Unused or unrecognized kwargs: padding.
Unused or unrecognized kwargs: padding.
  1%|          | 12/2000 [00:00<01:48, 18.30it/s]Unused or unrecognized kwargs: padding.
Unused or unrecognized kwargs: padding.
  1%|          | 14/2000 [00:00<01:46, 18.61it/s]Unused or unrecognized kwargs: padding.
Unused or unrecognized kwargs: padding.
  1%|          | 16/2000 [00:

test ROC: 0.7747930666666667
test Accuracy: 0.707



