In [None]:
from transformers import CLIPTokenizerFast, CLIPModel
import torch
from torchvision.transforms import Compose, ToTensor, Lambda, Resize, Normalize
from PIL import Image, ImageDraw
from torch.utils.data import DataLoader, Dataset
import torch.optim as optim

from sklearn.metrics import accuracy_score,f1_score

import numpy as np
import pandas as pd
from tqdm import tqdm

In [None]:
# --- Paths ---------------------------------------------------------------
# Root folder used to build the paths of the CSV files, images and dataset shards.
DIRECTROY = 'data'
# Model folder name (presumably where checkpoints are stored -- TODO confirm).
MODEL_PATH = 'models'

# --- Hyper-parameters ----------------------------------------------------
IMG_SIZE = 224     # images are resized to IMG_SIZE x IMG_SIZE
BATCH_SIZE = 32    # samples per DataLoader batch
EPOCHS = 100       # number of passes of the training loop
LR = 1e-4          # learning rate handed to the optimizer

# --- Compute device ------------------------------------------------------
# Use CUDA when PyTorch reports it as available, otherwise run on the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

In [None]:
# Load the train and test annotation tables from the data directory.
df_train = pd.read_csv(f'{DIRECTROY}/train.csv')
df_test = pd.read_csv(f'{DIRECTROY}/test_kaggletest.csv')

# Distinct values of the 'class' column as a plain Python list.
# `Series.unique()` already returns an array, so `.tolist()` is called on it
# directly -- the returned array has no `.values` attribute, and the previous
# `.unique().values.tolist()` raised AttributeError.
classes = df_train['class'].unique().tolist()
num_classes = len(classes)

In [None]:
# Partition the test table on its 'Usage' column: rows marked 'Public' and
# rows marked 'Private' each go to their own frame (original index is kept).
df_test_public = df_test.loc[df_test['Usage'].eq('Public')]
df_test_private = df_test.loc[df_test['Usage'].eq('Private')]

In [None]:
# Per-channel normalisation statistics used by the pretrained OpenAI CLIP
# checkpoints. The model fine-tuned in this notebook is a CLIP checkpoint, so
# its inputs should be normalised with these values rather than the ImageNet
# statistics (0.485/0.456/0.406, 0.229/0.224/0.225) used previously.
# NOTE(review): any dataset tensors cached on disk with the old statistics
# need to be regenerated to benefit from this change.
CLIP_MEAN = [0.48145466, 0.4578275, 0.40821073]
CLIP_STD = [0.26862954, 0.26130258, 0.27577711]

# Preprocessing pipeline: resize to a square IMG_SIZE x IMG_SIZE image,
# convert to a tensor, then normalise each channel.
image_transforms = Compose([
    Resize((IMG_SIZE, IMG_SIZE)),
    ToTensor(),
    Normalize(mean=CLIP_MEAN, std=CLIP_STD),
])


In [None]:
class CustomDataset(Dataset):
    """Image/label dataset backed by a DataFrame.

    Column 0 of ``df`` is read as the image file name and column 1 as the
    label. Image files are opened from ``{DIRECTROY}/{directory}/<file name>``,
    where ``DIRECTROY`` is the notebook-level data root.

    Args:
        df: table with the file name in its first column and the label in its
            second column (accessed positionally).
        transforms: callable applied to the loaded RGB image.
        directory: sub-folder of ``DIRECTROY`` that contains the image files.
    """

    def __init__(self, df, transforms, directory):
        self.df = df
        self.transforms = transforms
        self.directory = directory

    def __len__(self):
        """Return the number of rows in the backing table."""
        return len(self.df)

    def __getitem__(self, idx):
        """Return ``(transformed_image, label)`` for the row at position ``idx``."""
        file_name = self.df.iloc[idx, 0]
        path = f'{DIRECTROY}/{self.directory}/{file_name}'
        # Open inside a context manager so the file handle is released right
        # away, and force RGB so grayscale / palette / RGBA files all yield
        # three channels before the transforms run.
        with Image.open(path) as raw:
            img = raw.convert('RGB')
        img = self.transforms(img)
        # The label is returned exactly as stored in the table.
        label = self.df.iloc[idx, 1]
        return img, label

In [None]:
# Wrap the annotation tables in datasets: training rows read images from the
# 'train' sub-folder, public test rows from the 'test' sub-folder.
train_dataset = CustomDataset(df=df_train, transforms=image_transforms, directory='train')
test_dataset = CustomDataset(df=df_test_public, transforms=image_transforms, directory='test')

# Batch both datasets; only the training loader shuffles its samples.
train_dataloader = DataLoader(
    dataset=train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
)
test_dataloader = DataLoader(
    dataset=test_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
)

In [None]:
# Pretrained checkpoint shared by the model and its tokenizer.
CLIP_CHECKPOINT = "openai/clip-vit-base-patch16"

# Load the model and move it to the compute device: the training cell moves
# every batch to `device`, so the weights have to live there as well.
model = CLIPModel.from_pretrained(CLIP_CHECKPOINT).to(device)
tokenizer = CLIPTokenizerFast.from_pretrained(CLIP_CHECKPOINT)

# Tokenize the class names as text prompts (padded to a common length) and
# keep the resulting tensors on the same device as the model.
prompts = tokenizer(classes, return_tensors="pt", padding=True, truncation=True).to(device)

In [None]:
# Cross-entropy loss for the classification objective.
criterion = torch.nn.CrossEntropyLoss()

# Adam over every parameter of the model, starting at the configured rate.
optimizer = optim.Adam(params=model.parameters(), lr=LR)

# Scale the learning-rate factor linearly from 1.0 to 0.1 across EPOCHS
# scheduler steps.
scheduler = optim.lr_scheduler.LinearLR(
    optimizer=optimizer,
    start_factor=1.0,
    end_factor=0.1,
    total_iters=EPOCHS,
)

In [None]:
# Number of pre-saved dataset shards read from disk for each split.
NUM_TRAIN_SHARDS = 18
NUM_TEST_SHARDS = 2


def _clip_class_logits(clip_model, images, text_inputs):
    """Score a batch of images against the tokenized class prompts.

    CLIP needs both the images (`pixel_values`) and the text prompts; the
    image-vs-text similarity scores are exposed as `logits_per_image`
    (the output object has no `logits` attribute).
    """
    outputs = clip_model(
        pixel_values=images,
        input_ids=text_inputs['input_ids'],
        attention_mask=text_inputs['attention_mask'],
    )
    return outputs.logits_per_image


# The batches are moved to `device` below, so the model and the prompt tensors
# must be on that device too (both moves are no-ops if already done).
model.to(device)
text_inputs = {name: tensor.to(device) for name, tensor in prompts.items()}

for epoch in range(EPOCHS):
    model.train()
    train_loss = 0.0

    # Training loop
    print('Training epoch:', epoch+1)
    len_train = 0  # number of training samples seen this epoch
    for i in range(NUM_TRAIN_SHARDS):
        # NOTE(review): torch.load unpickles arbitrary objects -- only load
        # shard files from a trusted source.
        # Assumes each shard is a sample-level dataset of (image tensor,
        # integer label) pairs whose labels index into `classes` -- TODO confirm.
        train_dataset = torch.load(f'{DIRECTROY}/train_dataset/train_dataset_{i}.pth')
        train_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
        # Iterate the DataLoader (batched, shuffled), not the raw dataset.
        for inputs, labels in tqdm(train_dataloader):
            optimizer.zero_grad()
            inputs = inputs.to(device)
            labels = labels.to(device)

            logits = _clip_class_logits(model, inputs, text_inputs)
            loss = criterion(logits, labels)

            loss.backward()
            optimizer.step()

            # `loss` is a batch mean: weight it by the batch size so the
            # division below yields a true per-sample average.
            batch_size = labels.size(0)
            train_loss += loss.item() * batch_size
            len_train += batch_size

    scheduler.step()
    train_loss /= max(len_train, 1)  # guard against an empty epoch
    print(f'Epoch {epoch+1}/{EPOCHS}, Loss: {train_loss}')

    eval_loss = 0.0
    model.eval()

    true_labels = []
    pred_labels = []

    with torch.no_grad():
        len_test = 0  # number of evaluation samples seen this epoch
        for i in range(NUM_TEST_SHARDS):
            # NOTE(review): same unpickling caveat as for the training shards.
            test_dataset = torch.load(f'{DIRECTROY}/test_public_dataset/test_public_dataset_{i}.pth')
            test_dataloader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)
            for inputs, labels in tqdm(test_dataloader):
                inputs = inputs.to(device)
                labels = labels.to(device)

                logits = _clip_class_logits(model, inputs, text_inputs)
                loss = criterion(logits, labels)

                batch_size = labels.size(0)
                eval_loss += loss.item() * batch_size
                len_test += batch_size

                # Predicted class = prompt with the highest similarity score.
                predictions = logits.argmax(dim=1).cpu().numpy()
                targets = labels.flatten().cpu().numpy()

                true_labels.extend(targets)
                pred_labels.extend(predictions)

        eval_loss /= max(len_test, 1)  # guard against an empty evaluation set
        print(f'Epoch {epoch+1}/{EPOCHS}, Loss: {eval_loss}')
        print(f'Accuracy: {accuracy_score(true_labels, pred_labels)}')
        print(f'F1 Score Weighted: {f1_score(true_labels, pred_labels, average="weighted")}')
        print(f'F1 Score Macro: {f1_score(true_labels, pred_labels, average="macro")}')
            