In [1]:
from transformers import CLIPTokenizerFast, CLIPModel
import torch
from torchvision.datasets import ImageFolder
from torchvision.transforms import Compose, ToTensor, Lambda, Resize, Normalize
from PIL import Image, ImageDraw
from torch.utils.data import DataLoader, Dataset

from sklearn.metrics import accuracy_score,f1_score

import numpy as np
import pandas as pd
from tqdm import tqdm

In [2]:
# --- Notebook-wide configuration -------------------------------------------
# NOTE: the constant is spelled `DIRECTROY` (sic); the spelling is kept because
# other cells in this notebook reference it under that name.
DIRECTROY = "data"     # root folder for the CSV tables, images and saved datasets
MODEL_PATH = "models"  # presumably where checkpoints live; not referenced in the cells shown here
IMG_SIZE = 224         # side length (pixels) every image tensor is reshaped to
BATCH_SIZE = 32        # batch size handed to the DataLoader

# Run on the GPU when one is visible, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

In [3]:
# Read the reduced train / test metadata tables from the data directory.
train_csv = f'{DIRECTROY}/reduced_train.csv'
test_csv = f'{DIRECTROY}/reduced_test.csv'
df_train = pd.read_csv(train_csv)
df_test = pd.read_csv(test_csv)

# Number of distinct class ids (`newid`) present in the training table.
num_classes = len(pd.unique(df_train['newid']))

In [4]:
# Map every class id (`newid`) to its text label; if an id occurs on several
# rows, the label of the last such row wins.
class2label = dict(zip(df_train['newid'], df_train['label']))

In [19]:
# Re-key the id -> label mapping in ascending class-id order, so that position
# i of `labels` holds the text of the i-th smallest `newid`.
sorted_class2label = {class_id: class2label[class_id] for class_id in sorted(class2label)}
labels = list(sorted_class2label.values())

In [6]:
# Sanity check: the number of ids in the id -> label mapping next to the
# distinct-id count computed from df_train; the two should agree.
len(class2label), num_classes

(649, 649)

In [7]:
# Number of distinct label strings in the training table. In the recorded run
# this equals the number of class ids, i.e. no two ids share a label text.
len(df_train['label'].unique())

649

In [8]:
# Partition the test table by its `Usage` column into the Public and Private rows.
df_test_public = df_test.loc[df_test['Usage'].eq('Public')]
df_test_private = df_test.loc[df_test['Usage'].eq('Private')]

In [9]:
class CustomDataset(Dataset):
    """In-memory image / caption dataset for CLIP.

    Every file named in ``df['name']`` is opened from
    ``{DIRECTROY}/{directory}/``, resized, converted to RGB, passed through
    ``transforms``, cast to float16 and stacked into one tensor. The
    ``df['label']`` strings are tokenised once, up front, with the
    ``openai/clip-vit-base-patch16`` tokenizer.

    Args:
        df: table with the columns ``name`` (image file name), ``newid``
            (integer class id) and ``label`` (class text).
        transforms: callable applied to each RGB PIL image; its result must be
            a tensor that can be reshaped to ``(3, IMG_SIZE, IMG_SIZE)``.
        directory: sub-folder of ``DIRECTROY`` that contains the image files.
    """

    def __init__(self, df, transforms, directory):
        self.tokenizer = CLIPTokenizerFast.from_pretrained("openai/clip-vit-base-patch16")
        self.df = df
        self.transforms = transforms
        self.directory = directory
        # Build the int64 tensor directly: torch.Tensor(...) would first create
        # a float32 tensor, which cannot represent large integer ids exactly.
        self.labels = torch.tensor(df['newid'].values, dtype=torch.long)
        self.imgs = torch.stack([self._load_image(name) for name in tqdm(df['name'].values)])
        self.tokenized = self.tokenizer(df['label'].tolist(), padding=True, truncation=True, return_tensors="pt")
        self.input_ids = self.tokenized['input_ids']
        self.attention_mask = self.tokenized['attention_mask']

    def resize_img(self, img):
        """Return ``img`` resized to ``IMG_SIZE`` x ``IMG_SIZE`` pixels.

        NOTE(review): ``__init__`` always called this method, but the notebook
        never defined it, so constructing the dataset raised
        ``AttributeError``. A plain square resize is the minimal
        implementation compatible with the ``(3, IMG_SIZE, IMG_SIZE)`` reshape;
        confirm it matches the preprocessing used to build any previously
        saved datasets.
        """
        return img.resize((IMG_SIZE, IMG_SIZE))

    def _load_image(self, name):
        """Load one image file and return it as a float16 ``(3, IMG_SIZE, IMG_SIZE)`` tensor."""
        # The context manager closes the file handle as soon as the pixels are
        # read; resize() returns an independent image, so it outlives the block.
        with Image.open(f'{DIRECTROY}/{self.directory}/{name}') as raw:
            rgb = self.resize_img(raw).convert('RGB')
        return self.transforms(rgb).half().reshape(3, IMG_SIZE, IMG_SIZE)

    def __len__(self):
        """Number of samples, i.e. the number of rows in ``df``."""
        return len(self.df)

    def __getitem__(self, idx):
        """Return ``(image, class_id, input_ids, attention_mask)`` for sample ``idx``."""
        return (
            self.imgs[idx],
            self.labels[idx],
            self.input_ids[idx],
            self.attention_mask[idx],
        )

In [14]:
# Pretrained CLIP (ViT-base, patch 16) checkpoint: the tokenizer for the text
# prompts and the model itself, moved to the selected device.
tokenizer = CLIPTokenizerFast.from_pretrained("openai/clip-vit-base-patch16")
model = CLIPModel.from_pretrained("openai/clip-vit-base-patch16").to(device)

In [15]:
# Displays the repr of the bound `parameters` method (it is not called); that
# repr embeds the module tree, which is what the output below shows.
# NOTE(review): evaluating `model` on its own would print the same tree.
model.parameters

<bound method Module.parameters of CLIPModel(
  (text_model): CLIPTextTransformer(
    (embeddings): CLIPTextEmbeddings(
      (token_embedding): Embedding(49408, 512)
      (position_embedding): Embedding(77, 512)
    )
    (encoder): CLIPEncoder(
      (layers): ModuleList(
        (0-11): 12 x CLIPEncoderLayer(
          (self_attn): CLIPAttention(
            (k_proj): Linear(in_features=512, out_features=512, bias=True)
            (v_proj): Linear(in_features=512, out_features=512, bias=True)
            (q_proj): Linear(in_features=512, out_features=512, bias=True)
            (out_proj): Linear(in_features=512, out_features=512, bias=True)
          )
          (layer_norm1): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
          (mlp): CLIPMLP(
            (activation_fn): QuickGELUActivation()
            (fc1): Linear(in_features=512, out_features=2048, bias=True)
            (fc2): Linear(in_features=2048, out_features=512, bias=True)
          )
          (layer_n

In [21]:
# Tokenise one text prompt per class, ordered by ascending class id, and keep
# the tensors on the model's device so they can be reused for every batch.
#
# The texts are rebuilt from `class2label` instead of being read from the
# global `labels`: that name is easy to rebind elsewhere in the notebook (for
# example as a loop variable holding batch targets), and re-running this cell
# afterwards would then tokenise the wrong data.
prompt_texts = [class2label[class_id] for class_id in sorted(class2label)]
prompts = tokenizer(prompt_texts, return_tensors="pt", padding=True, truncation=True)
prompts_input_ids = prompts['input_ids'].to(device)
prompts_attention_mask = prompts['attention_mask'].to(device)

In [22]:
# Load the saved test dataset (it is handed straight to a DataLoader, so it is
# presumably a pickled dataset object rather than a plain tensor/state dict).
#
# SECURITY: unpickling can execute arbitrary code - only load files you created
# yourself or otherwise trust.
# `weights_only=False` is passed explicitly because PyTorch >= 2.6 defaults to
# `weights_only=True`, which refuses to unpickle arbitrary Python classes.
test_dataset = torch.load(
    f'{DIRECTROY}/test_public_dataset/test_public_reduced_dataset_0.pth',
    weights_only=False,
)
test_dataloader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)

In [26]:
# Zero-shot evaluation: score every test image against the shared set of class
# prompts and predict the class whose prompt has the highest image-text logit.

# Column i of `logits_per_image` belongs to the i-th prompt, and the prompts
# are ordered by ascending class id, so the argmax position is translated back
# to the real `newid` (an identity mapping when ids are exactly 0..N-1).
class_ids = np.array(sorted(class2label))

true_labels = []
pred_labels = []

model.eval()
with torch.no_grad():
    # The loader also yields per-sample token ids / attention masks; they are
    # ignored here because every image is compared with the shared prompts.
    # The batch targets get their own name so the global `labels` list (the
    # prompt texts) is not overwritten by this loop.
    for pixel_values, batch_labels, _, _ in tqdm(test_dataloader):
        pixel_values = pixel_values.to(device)

        # NOTE: this forward pass re-encodes the same prompts for every batch.
        output = model(
            pixel_values=pixel_values,
            input_ids=prompts_input_ids,
            attention_mask=prompts_attention_mask,
        )

        best_prompt = torch.argmax(output.logits_per_image, 1).flatten().cpu().numpy()
        pred_labels.extend(class_ids[best_prompt])
        true_labels.extend(batch_labels.flatten().cpu().numpy())

print(f'Accuracy: {accuracy_score(true_labels, pred_labels)}')
print(f'F1 Score Weighted: {f1_score(true_labels, pred_labels, average="weighted")}')
print(f'F1 Score Macro: {f1_score(true_labels, pred_labels, average="macro")}')
    

100%|██████████| 190/190 [00:26<00:00,  7.18it/s]

Accuracy: 0.34501480750246794
F1 Score Weighted: 0.32035752586803135
F1 Score Macro: 0.31367785955833266





### Zero-shot CLIP reaches only 34.5% accuracy on the public test data (`test_public_reduced_dataset_0.pth`)