In [1]:
from transformers import CLIPTokenizerFast, CLIPModel
import torch
import torch.nn as nn
from torchvision.transforms import Compose, ToTensor,  Resize, Normalize
from PIL import Image, ImageDraw
from torch.utils.data import DataLoader, Dataset, BatchSampler
import torch.optim as optim

from sklearn.metrics import accuracy_score, f1_score

import numpy as np
import pandas as pd
from tqdm import tqdm

## Finetune CLIP for image classification
The idea is to re-train CLIP by matching images with their descriptions

### Parameter

In [2]:
DIRECTROY = 'data'
MODEL_PATH = 'models'
IMG_SIZE = 224
BATCH_SIZE = 32
EPOCHS = 100
LR = 0.0001
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [3]:
df_train = pd.read_csv(f'{DIRECTROY}/reduced_train.csv') 
df_test = pd.read_csv(f'{DIRECTROY}/reduced_test.csv') 
num_classes = len(df_train['class'].unique())
classes = df_train['class'].unique().tolist()

In [4]:
class2label = {c:l for c, l in zip(df_train['newid'], df_train['label'])}

In [5]:
sorted_class2label = dict(sorted(class2label.items()))

In [6]:
labels = []
for k, v in sorted_class2label.items():
    labels.append(v)

In [7]:
df_test_public = df_test[df_test['Usage'] == 'Public']
df_test_private = df_test[df_test['Usage'] == 'Private']

In [8]:
image_transforms = Compose([
    Resize((IMG_SIZE, IMG_SIZE)),
    ToTensor(), 
    Normalize(mean=[0.48145466, 0.4578275, 0.40821073], std=[0.26862954, 0.26130258, 0.27577711])
])

In [9]:
class CustomDataset(Dataset):
    def __init__(self, df, transforms, directory):
        self.tokenizer =  CLIPTokenizerFast.from_pretrained("openai/clip-vit-base-patch16")
        self.df = df
        self.transforms = transforms
        self.directory = directory
        self.labels = torch.Tensor(df['newid'].values).long()
        self.imgs = torch.cat([ self.transforms(self.resize_img(Image.open(f'{DIRECTROY}/{self.directory}/{x}')).convert('RGB')).half().reshape(1,3,IMG_SIZE,IMG_SIZE) for x in tqdm(df['name'].values)])
        self.tokenized = self.tokenizer(df['label'].tolist(), padding=True, truncation=True, return_tensors="pt")
        self.input_ids = self.tokenized['input_ids']
        self.attention_mask = self.tokenized['attention_mask']
        
    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        img = self.imgs[idx]
        label = self.labels[idx]
        input_ids = self.input_ids[idx]
        attention_mask = self.attention_mask[idx]
        return img, label, input_ids, attention_mask

### CustomSampler for Dataloader

CLIP requires during each iteration, the classes must me different from each others. This custom sampler will always make sure that think happen.

The label just ensure that it does not larger than the dataloader length

It basicly place list of labels in random position, concatenate values of those position in a vector, adding the 0s and warp it in form of matrix (bs, x). With batchs that contain 0s (happen less than batch size), it will randomly filled with other labels not in the batchs 

In [10]:
class CustomSampler(BatchSampler):
    def __init__(self, labels, n_samples, n_classes):
        
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 
        self.labels = torch.IntTensor(labels.values).to(self.device)
        self.n_samples = n_samples
        self.n_classes = n_classes
        self.num_labels = len(self.labels.unique())
        self.labels_idx = dict()
        for i in range(self.num_labels):
            self.labels_idx[i] = torch.where(self.labels == i)[0]
            
        self.bs = n_samples*n_classes
        
        self.extra = (self.bs - len(self.labels)%self.bs)%self.bs
        
    @staticmethod
    def random_mix(ts):
        order = torch.randperm(len(ts))
        return ts[order] 

    def __iter__(self):
        order = np.random.choice(self.num_labels, self.num_labels, replace=False)
        idxs = torch.cat([self.random_mix(self.labels_idx[l]) for l in order]).to(self.device) + 1
        idxs = torch.cat([idxs, torch.zeros(self.extra).long().to(self.device)]).to(self.device)
        idxs = idxs.view(self.bs,len(idxs)//self.bs).T
        
        for bs in idxs:
            get = torch.nonzero(bs)
            re = bs[get].squeeze() - 1
            re = re.to(self.device)
            count = 0
            i = 0
            # print(self.random_mix(self.labels_idx[order[i]])[count])
            if len(re) < self.bs:
                if count>len(self.labels_idx[order[i]]):
                    count = 0
                    i += 1
                lucky = torch.ones(1, device=self.device)*self.random_mix(self.labels_idx[order[i]])[count]
                re = torch.cat([re, lucky.long().to(self.device)])
                count += 1
            # print(re.tolist())
     
            yield re
            

### Loading the dataset from drive

In [11]:
train_dataset = torch.load(f'{DIRECTROY}/train_dataset/train_dataset_reduced_all.pth')
train_dataloader = DataLoader(train_dataset, batch_sampler=CustomSampler(df_train['newid'],  BATCH_SIZE,1))
test_dataset = torch.load(f'{DIRECTROY}/test_public_dataset/test_public_reduced_dataset_0.pth')
test_dataloader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)

### Loading CLIP

In [12]:
model = CLIPModel.from_pretrained("openai/clip-vit-base-patch16")
model = model.to(device)
tokenizer = CLIPTokenizerFast.from_pretrained("openai/clip-vit-base-patch16")

### Passing every prompts during testing

In [13]:
all_prompts = tokenizer(labels, return_tensors="pt", padding=True, truncation=True)
all_prompts['input_ids'] = all_prompts['input_ids'].to(device)
all_prompts['attention_mask'] = all_prompts['attention_mask'].to(device)

### Training process

In [14]:
loss_img = nn.CrossEntropyLoss()
loss_txt = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr = LR)
scheduler = optim.lr_scheduler.LinearLR(optimizer, start_factor=1.0, end_factor=0.1, total_iters=EPOCHS)

In [15]:
max_accuracy = 0
for epoch in range(EPOCHS):
    model.train()
    train_loss = 0.0
    
    # Training loop
    print('Training epoch:', epoch+1)
    len_train = 0
    
    for inputs, labels, input_ids, attention_mask  in tqdm(train_dataloader):
        optimizer.zero_grad()
        inputs = inputs.to(device)
        labels = labels.to(device)
        input_ids = input_ids.to(device)
        attention_mask = attention_mask.to(device)
        
        logits = model(pixel_values=inputs, input_ids=input_ids, attention_mask=attention_mask)
        logits_per_image = logits.logits_per_image
        logits_per_text = logits.logits_per_text
        
        ground_truth = torch.arange(BATCH_SIZE).to(device)
        
        loss = (loss_img(logits_per_image,ground_truth) + loss_txt(logits_per_text,ground_truth))/2
        
        loss.backward()
        optimizer.step()
        
        train_loss += loss.item()
    len_train += len(train_dataset)
        
    scheduler.step()    
    train_loss/=len_train
    print(f'Epoch {epoch+1}/{EPOCHS}, Loss: {train_loss}')
    
    eval_loss = 0.0
    model.eval()
    
    true_labels = []
    pred_labels = []
    
    with torch.no_grad():
        len_test = 0

        test_dataset = torch.load(f'{DIRECTROY}/test_public_dataset/test_public_reduced_dataset_0.pth')
        test_dataloader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)
        for inputs, labels, input_ids, attention_mask  in tqdm(test_dataloader):
            inputs = inputs.to(device)
            labels = labels.to(device)
            input_ids = input_ids.to(device)
            attention_mask = attention_mask.to(device)
            
            logits = model(pixel_values=inputs, input_ids=all_prompts['input_ids'], attention_mask=all_prompts['attention_mask'])
            logits_per_image = logits.logits_per_image
            
            outputs = torch.argmax(logits_per_image, 1).flatten().cpu().numpy()
            labels = labels.flatten().cpu().numpy()
            
            
            true_labels.extend(labels)
            pred_labels.extend(outputs)
        len_test += len(test_dataset)
        
        print(f'Epoch {epoch+1}/{EPOCHS},')
        print(f'Accuracy: {accuracy_score(true_labels, pred_labels)}')
        print(f'F1 Score Weighted: {f1_score(true_labels, pred_labels, average="weighted")}')
        print(f'F1 Score Macro: {f1_score(true_labels, pred_labels, average="macro")}')
        if accuracy_score(true_labels, pred_labels) > max_accuracy:
            max_accuracy = accuracy_score(true_labels, pred_labels)
            torch.save(model, f'{MODEL_PATH}/model_clip_{epoch+1}.pth')
            torch.save(optimizer, f'{MODEL_PATH}/optimizer/optimizer_clip_{epoch+1}.pth')
            
            

Training epoch: 1


914it [04:14,  3.60it/s]


Epoch 1/100, Loss: 0.05026475822608008


100%|██████████| 190/190 [00:27<00:00,  6.98it/s]


Epoch 1/100,
Accuracy: 0.24827245804540968
F1 Score Weighted: 0.2289517041726222
F1 Score Macro: 0.21129561703577748
Training epoch: 2


914it [04:15,  3.58it/s]


Epoch 2/100, Loss: 0.024940774320141544


100%|██████████| 190/190 [00:26<00:00,  7.13it/s]


Epoch 2/100,
Accuracy: 0.33316880552813427
F1 Score Weighted: 0.319435814334516
F1 Score Macro: 0.30072761958069333
Training epoch: 3


914it [04:14,  3.60it/s]


Epoch 3/100, Loss: 0.018754650398697004


100%|██████████| 190/190 [00:26<00:00,  7.11it/s]


Epoch 3/100,
Accuracy: 0.3627838104639684
F1 Score Weighted: 0.3415101007243461
F1 Score Macro: 0.3263892358732765
Training epoch: 4


914it [04:14,  3.60it/s]


Epoch 4/100, Loss: 0.015605918177215385


100%|██████████| 190/190 [00:26<00:00,  7.12it/s]


Epoch 4/100,
Accuracy: 0.35620269825600526
F1 Score Weighted: 0.3417744027665416
F1 Score Macro: 0.3229343693293088
Training epoch: 5


914it [04:13,  3.60it/s]


Epoch 5/100, Loss: 0.013527242615026606


100%|██████████| 190/190 [00:26<00:00,  7.13it/s]


Epoch 5/100,
Accuracy: 0.39568937150378414
F1 Score Weighted: 0.37998050034055114
F1 Score Macro: 0.358172986620073
Training epoch: 6


914it [04:14,  3.59it/s]


Epoch 6/100, Loss: 0.012097978652531507


100%|██████████| 190/190 [00:26<00:00,  7.12it/s]


Epoch 6/100,
Accuracy: 0.38153998025666336
F1 Score Weighted: 0.37466491695530524
F1 Score Macro: 0.34890379605895255
Training epoch: 7


914it [04:13,  3.60it/s]


Epoch 7/100, Loss: 0.010900929995428679


100%|██████████| 190/190 [00:26<00:00,  7.12it/s]


Epoch 7/100,
Accuracy: 0.40506745640013164
F1 Score Weighted: 0.39854325262860224
F1 Score Macro: 0.3727206339406033
Training epoch: 8


914it [04:13,  3.60it/s]


Epoch 8/100, Loss: 0.009894512348773603


100%|██████████| 190/190 [00:26<00:00,  7.12it/s]


Epoch 8/100,
Accuracy: 0.40095426127015465
F1 Score Weighted: 0.3920460002074479
F1 Score Macro: 0.37295352625674283
Training epoch: 9


914it [04:12,  3.62it/s]


Epoch 9/100, Loss: 0.009384287150981623


100%|██████████| 190/190 [00:26<00:00,  7.13it/s]


Epoch 9/100,
Accuracy: 0.4068772622573215
F1 Score Weighted: 0.40370535780081307
F1 Score Macro: 0.38218794730931327
Training epoch: 10


914it [04:15,  3.58it/s]


Epoch 10/100, Loss: 0.008824919744688069


100%|██████████| 190/190 [00:26<00:00,  7.08it/s]


Epoch 10/100,
Accuracy: 0.4218492925304376
F1 Score Weighted: 0.41998524507432883
F1 Score Macro: 0.39984093969378853
Training epoch: 11


914it [04:17,  3.55it/s]


Epoch 11/100, Loss: 0.008015973949133871


100%|██████████| 190/190 [00:27<00:00,  6.96it/s]


Epoch 11/100,
Accuracy: 0.46709443896018427
F1 Score Weighted: 0.46004539820145923
F1 Score Macro: 0.4456557834231198
Training epoch: 12


914it [04:17,  3.54it/s]


Epoch 12/100, Loss: 0.007371614060604764


100%|██████████| 190/190 [00:26<00:00,  7.07it/s]


Epoch 12/100,
Accuracy: 0.4407699901283317
F1 Score Weighted: 0.4326607075352418
F1 Score Macro: 0.40601741422804855
Training epoch: 13


914it [04:16,  3.56it/s]


Epoch 13/100, Loss: 0.007256390984828972


100%|██████████| 190/190 [00:26<00:00,  7.06it/s]


Epoch 13/100,
Accuracy: 0.4369858506087529
F1 Score Weighted: 0.4295050951957478
F1 Score Macro: 0.41074062605603556
Training epoch: 14


914it [04:16,  3.57it/s]


Epoch 14/100, Loss: 0.00695214707052649


100%|██████████| 190/190 [00:26<00:00,  7.07it/s]


Epoch 14/100,
Accuracy: 0.4600197433366239
F1 Score Weighted: 0.4559433918596189
F1 Score Macro: 0.43795551916665026
Training epoch: 15


914it [04:15,  3.58it/s]


Epoch 15/100, Loss: 0.00637879008563889


100%|██████████| 190/190 [00:26<00:00,  7.07it/s]


Epoch 15/100,
Accuracy: 0.4453767686739059
F1 Score Weighted: 0.44350688946794037
F1 Score Macro: 0.4222820709442925
Training epoch: 16


914it [04:14,  3.59it/s]


Epoch 16/100, Loss: 0.006483036291028869


100%|██████████| 190/190 [00:27<00:00,  7.02it/s]


Epoch 16/100,
Accuracy: 0.44784468575189207
F1 Score Weighted: 0.44826464060914883
F1 Score Macro: 0.43778798199947205
Training epoch: 17


914it [04:21,  3.49it/s]


Epoch 17/100, Loss: 0.0060839180611370796


100%|██████████| 190/190 [00:27<00:00,  7.00it/s]


Epoch 17/100,
Accuracy: 0.470714050674564
F1 Score Weighted: 0.46565870103825385
F1 Score Macro: 0.4458167530847786
Training epoch: 18


914it [04:18,  3.54it/s]


Epoch 18/100, Loss: 0.005633364566912806


100%|██████████| 190/190 [00:26<00:00,  7.04it/s]


Epoch 18/100,
Accuracy: 0.4429088515959197
F1 Score Weighted: 0.4435700611092661
F1 Score Macro: 0.4239850103735334
Training epoch: 19


914it [04:16,  3.57it/s]


Epoch 19/100, Loss: 0.0053187219390653075


100%|██████████| 190/190 [00:26<00:00,  7.07it/s]


Epoch 19/100,
Accuracy: 0.46281671602500823
F1 Score Weighted: 0.4640599818219774
F1 Score Macro: 0.4444222772069494
Training epoch: 20


914it [04:17,  3.56it/s]


Epoch 20/100, Loss: 0.005589253092317468


100%|██████████| 190/190 [00:26<00:00,  7.06it/s]


Epoch 20/100,
Accuracy: 0.4053965120105298
F1 Score Weighted: 0.40319679491072385
F1 Score Macro: 0.38667813461994316
Training epoch: 21


914it [04:17,  3.55it/s]


Epoch 21/100, Loss: 0.0053112239807644665


100%|██████████| 190/190 [00:27<00:00,  6.94it/s]


Epoch 21/100,
Accuracy: 0.5065811122079631
F1 Score Weighted: 0.500850261867146
F1 Score Macro: 0.4825533817494047
Training epoch: 22


914it [04:18,  3.53it/s]


Epoch 22/100, Loss: 0.005003435726808464


100%|██████████| 190/190 [00:27<00:00,  6.91it/s]


Epoch 22/100,
Accuracy: 0.45788088186903586
F1 Score Weighted: 0.45520092069605617
F1 Score Macro: 0.43847979265932646
Training epoch: 23


914it [04:17,  3.55it/s]


Epoch 23/100, Loss: 0.005018313817362169


100%|██████████| 190/190 [00:26<00:00,  7.05it/s]


Epoch 23/100,
Accuracy: 0.4876604146100691
F1 Score Weighted: 0.4879339357178343
F1 Score Macro: 0.4679232354982869
Training epoch: 24


914it [04:17,  3.55it/s]


Epoch 24/100, Loss: 0.00482497355694276


100%|██████████| 190/190 [00:27<00:00,  7.03it/s]


Epoch 24/100,
Accuracy: 0.4878249424152682
F1 Score Weighted: 0.486866481933182
F1 Score Macro: 0.4545111450811121
Training epoch: 25


914it [04:18,  3.54it/s]


Epoch 25/100, Loss: 0.004615203781699359


100%|██████████| 190/190 [00:27<00:00,  6.96it/s]


Epoch 25/100,
Accuracy: 0.493747943402435
F1 Score Weighted: 0.49363553665539334
F1 Score Macro: 0.4684114201536851
Training epoch: 26


914it [04:20,  3.51it/s]


Epoch 26/100, Loss: 0.0044116396486416


100%|██████████| 190/190 [00:27<00:00,  7.00it/s]


Epoch 26/100,
Accuracy: 0.4861796643632774
F1 Score Weighted: 0.4872390833817201
F1 Score Macro: 0.46535834156916506
Training epoch: 27


914it [04:18,  3.54it/s]


Epoch 27/100, Loss: 0.004483442939639559


100%|██████████| 190/190 [00:27<00:00,  6.93it/s]


Epoch 27/100,
Accuracy: 0.4717012175057585
F1 Score Weighted: 0.47344314939487453
F1 Score Macro: 0.44905268418257765
Training epoch: 28


914it [04:20,  3.51it/s]


Epoch 28/100, Loss: 0.0042842511000297755


100%|██████████| 190/190 [00:27<00:00,  6.87it/s]


Epoch 28/100,
Accuracy: 0.5138203356367226
F1 Score Weighted: 0.5117450174180641
F1 Score Macro: 0.4871522028485476
Training epoch: 29


914it [04:20,  3.51it/s]


Epoch 29/100, Loss: 0.004335257926216252


100%|██████████| 190/190 [00:27<00:00,  6.86it/s]


Epoch 29/100,
Accuracy: 0.5133267522211253
F1 Score Weighted: 0.515395244120553
F1 Score Macro: 0.4885324280640756
Training epoch: 30


914it [04:19,  3.52it/s]


Epoch 30/100, Loss: 0.0040343458505110235


100%|██████████| 190/190 [00:27<00:00,  6.99it/s]


Epoch 30/100,
Accuracy: 0.5156301414939125
F1 Score Weighted: 0.5162043499569426
F1 Score Macro: 0.49085779229435217
Training epoch: 31


914it [04:18,  3.54it/s]


Epoch 31/100, Loss: 0.003979459120031782


100%|██████████| 190/190 [00:27<00:00,  6.92it/s]


Epoch 31/100,
Accuracy: 0.49868377755840737
F1 Score Weighted: 0.5018961302555547
F1 Score Macro: 0.47639093989243797
Training epoch: 32


914it [04:22,  3.48it/s]


Epoch 32/100, Loss: 0.003932189766119641


100%|██████████| 190/190 [00:27<00:00,  6.91it/s]


Epoch 32/100,
Accuracy: 0.5070746956235603
F1 Score Weighted: 0.5105567897498453
F1 Score Macro: 0.4888627938496638
Training epoch: 33


461it [02:09,  3.55it/s]


KeyboardInterrupt: 