In [None]:
from transformers import CLIPTokenizerFast, CLIPModel
import torch
from torchvision.datasets import ImageFolder
from torchvision.transforms import Compose, ToTensor, Lambda, Resize, Normalize
from PIL import Image, ImageDraw
from torch.utils.data import DataLoader, Dataset

from sklearn.metrics import accuracy_score,f1_score

import numpy as np
import pandas as pd
from tqdm import tqdm

In [None]:
# Run configuration.
DIRECTROY = 'data'  # root folder for the CSV files and the saved test shards
MODEL_PATH = 'models'
BATCH_SIZE = 32
IMG_SIZE = 224

# Use the GPU when torch can see one, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [None]:
# Load the train / test CSVs and derive the class vocabulary from the training split.
df_train = pd.read_csv(f'{DIRECTROY}/train.csv')
df_test = pd.read_csv(f'{DIRECTROY}/test_kaggletest.csv')
# Series.unique() already returns a numpy ndarray (which has no `.values`
# attribute), so convert it with .tolist() directly. Order is first appearance
# in train.csv.
classes = df_train['class'].unique().tolist()
num_classes = len(classes)

In [None]:
# Partition the test rows by the value of their `Usage` column.
df_test_public = df_test.loc[df_test['Usage'].eq('Public')]
df_test_private = df_test.loc[df_test['Usage'].eq('Private')]

In [None]:
# Load the pretrained CLIP checkpoint and its matching fast tokenizer.
# The model is moved to `device` so inference runs on the GPU when one is
# available (inputs are also sent to the configured device during evaluation).
model = CLIPModel.from_pretrained("openai/clip-vit-base-patch16").to(device)
tokenizer = CLIPTokenizerFast.from_pretrained("openai/clip-vit-base-patch16")

In [None]:
# Total number of parameters in the loaded CLIP model. A bare
# `model.parameters` (no call) only echoes the bound-method repr.
sum(p.numel() for p in model.parameters())

In [None]:
# Loss used to score the per-image logits against the ground-truth labels.
criteria = torch.nn.CrossEntropyLoss()

# Tokenize every class name once; the padded batch of token ids is used as the
# text side of CLIP for every image batch during evaluation.
prompts = tokenizer(
    classes,
    padding=True,
    truncation=True,
    return_tensors="pt",
)

In [None]:
# Zero-shot evaluation of CLIP on the saved public-test shards.
# NOTE(review): assumes each shard yields (pixel_values, label) pairs where
# pixel_values is already CLIP-preprocessed (3 x IMG_SIZE x IMG_SIZE,
# normalized) — TODO confirm.
# NOTE(review): label k is assumed to mean classes[k]; `classes` is in
# first-appearance order from train.csv, whereas e.g. ImageFolder sorts class
# names alphabetically — verify the shards use the same encoding.
model.eval()
# Run on whatever device the model lives on so inputs and weights always match.
eval_device = next(model.parameters()).device

# The text prompts are the same for every batch: move them to the device once,
# without mutating the shared `prompts` object. attention_mask is kept because
# the prompts were padded.
text_inputs = {name: tensor.to(eval_device) for name, tensor in prompts.items()}

true_labels = []
pred_labels = []
test_loss = 0.0
len_dataset = 0
for shard_idx in range(3):
    # torch.load unpickles arbitrary objects: only load shards from a trusted
    # source. weights_only=False is needed to restore a Dataset object on
    # torch >= 2.6, where the default became weights_only=True.
    # NOTE(review): shards under test_public_dataset/ are named train_dataset_* — confirm.
    dataset = torch.load(
        f'{DIRECTROY}/test_public_dataset/train_dataset_{shard_idx}.pth',
        weights_only=False,
    )
    # Order does not matter for evaluation metrics, so no shuffling.
    dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=False)
    for inputs, labels in tqdm(dataloader, desc=f'shard {shard_idx}'):
        inputs = inputs.to(eval_device)
        labels = labels.to(eval_device)
        with torch.no_grad():
            # Pass everything by keyword: CLIPModel's first positional
            # argument is input_ids, not pixel_values.
            outputs = model(pixel_values=inputs, **text_inputs)
            logits_per_image = outputs.logits_per_image  # one row per image, one column per prompt
            loss = criteria(logits_per_image, labels)

        # CrossEntropyLoss averages over the batch; weight by batch size so the
        # division by the total sample count below yields a per-sample mean.
        test_loss += loss.item() * labels.size(0)
        true_labels.extend(labels.flatten().cpu().numpy())
        pred_labels.extend(logits_per_image.argmax(dim=1).flatten().cpu().numpy())
    len_dataset += len(dataset)

print(f'Loss: {test_loss/len_dataset}')
print(f'Accuracy: {accuracy_score(true_labels, pred_labels)}')
print(f'F1 Score Weighted: {f1_score(true_labels, pred_labels, average="weighted")}')
print(f'F1 Score Macro: {f1_score(true_labels, pred_labels, average="macro")}')
    