In [1]:
from transformers import CLIPTokenizerFast, CLIPModel
import torch
import torch.nn as nn
from torchvision.transforms import Compose, ToTensor,  Resize, Normalize
from PIL import Image, ImageDraw
from torch.utils.data import DataLoader, Dataset, BatchSampler
import torch.optim as optim

from sklearn.metrics import accuracy_score, f1_score

import numpy as np
import pandas as pd
from tqdm import tqdm

## Finetune CLIP for image classification
The idea is to re-train CLIP by matching images with their descriptions

### Parameter

In [2]:
DIRECTROY = 'data'
MODEL_PATH = 'models'
IMG_SIZE = 224
BATCH_SIZE = 32
EPOCHS = 100
LR = 0.0001
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [3]:
df_train = pd.read_csv(f'{DIRECTROY}/reduced_train.csv') 
df_test = pd.read_csv(f'{DIRECTROY}/reduced_test.csv') 
num_classes = len(df_train['class'].unique())
classes = df_train['class'].unique().tolist()

In [4]:
class2label = {c:l for c, l in zip(df_train['newid'], df_train['label'])}

In [5]:
sorted_class2label = dict(sorted(class2label.items()))

In [6]:
labels = []
for k, v in sorted_class2label.items():
    labels.append(v)

In [7]:
df_test_public = df_test[df_test['Usage'] == 'Public']
df_test_private = df_test[df_test['Usage'] == 'Private']

In [7]:
image_transforms = Compose([
    Resize((IMG_SIZE, IMG_SIZE)),
    ToTensor(), 
    Normalize(mean=[0.48145466, 0.4578275, 0.40821073], std=[0.26862954, 0.26130258, 0.27577711])
])

In [8]:
class CustomDataset(Dataset):
    def __init__(self, df, transforms, directory):
        self.tokenizer =  CLIPTokenizerFast.from_pretrained("openai/clip-vit-base-patch16")
        self.df = df
        self.transforms = transforms
        self.directory = directory
        self.labels = torch.Tensor(df['newid'].values).long()
        self.imgs = torch.cat([ self.transforms(self.resize_img(Image.open(f'{DIRECTROY}/{self.directory}/{x}')).convert('RGB')).half().reshape(1,3,IMG_SIZE,IMG_SIZE) for x in tqdm(df['name'].values)])
        self.tokenized = self.tokenizer(df['label'].tolist(), padding=True, truncation=True, return_tensors="pt")
        self.input_ids = self.tokenized['input_ids']
        self.attention_mask = self.tokenized['attention_mask']
        
    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        img = self.imgs[idx]
        label = self.labels[idx]
        input_ids = self.input_ids[idx]
        attention_mask = self.attention_mask[idx]
        return img, label, input_ids, attention_mask

### CustomSampler for Dataloader

CLIP requires during each iteration, the classes must me different from each others. This custom sampler will always make sure that think happen.

The label just ensure that it does not larger than the dataloader length

It basicly place list of labels in random position, concatenate values of those position in a vector, adding the 0s and warp it in form of matrix (bs, x). With batchs that contain 0s (happen less than batch size), it will randomly filled with other labels not in the batchs 

In [9]:
class CustomSampler(BatchSampler):
    def __init__(self, labels, n_samples, n_classes):
        
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 
        self.labels = torch.IntTensor(labels.values).to(self.device)
        self.n_samples = n_samples
        self.n_classes = n_classes
        self.num_labels = len(self.labels.unique())
        self.labels_idx = dict()
        for i in range(self.num_labels):
            self.labels_idx[i] = torch.where(self.labels == i)[0]
            
        self.bs = n_samples*n_classes
        
        self.extra = (self.bs - len(self.labels)%self.bs)%self.bs
        
    @staticmethod
    def random_mix(ts):
        order = torch.randperm(len(ts))
        return ts[order] 

    def __iter__(self):
        order = np.random.choice(self.num_labels, self.num_labels, replace=False)
        idxs = torch.cat([self.random_mix(self.labels_idx[l]) for l in order]).to(self.device) + 1
        idxs = torch.cat([idxs, torch.zeros(self.extra).long().to(self.device)]).to(self.device)
        idxs = idxs.view(self.bs,len(idxs)//self.bs).T
        
        for bs in idxs:
            get = torch.nonzero(bs)
            re = bs[get].squeeze() - 1
            re = re.to(self.device)
            count = 0
            i = 0
            # print(self.random_mix(self.labels_idx[order[i]])[count])
            if len(re) < self.bs:
                if count>len(self.labels_idx[order[i]]):
                    count = 0
                    i += 1
                lucky = torch.ones(1, device=self.device)*self.random_mix(self.labels_idx[order[i]])[count]
                re = torch.cat([re, lucky.long().to(self.device)])
                count += 1
            # print(re.tolist())
     
            yield re
            

### Loading the dataset from drive

In [11]:
train_dataset = torch.load(f'{DIRECTROY}/train_dataset/train_dataset_reduced_all.pth')
train_dataloader = DataLoader(train_dataset, batch_sampler=CustomSampler(df_train['newid'],  BATCH_SIZE,1))
test_dataset = torch.load(f'{DIRECTROY}/test_public_dataset/test_public_reduced_dataset_0.pth')
test_dataloader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)

### Loading CLIP

In [14]:
model = CLIPModel.from_pretrained("openai/clip-vit-base-patch16")
model = model.to(device)
tokenizer = CLIPTokenizerFast.from_pretrained("openai/clip-vit-base-patch16")

### Passing every prompts during testing

In [15]:
all_prompts = tokenizer([f'a photo of {x}' for x in labels], return_tensors="pt", padding=True, truncation=True)
all_prompts['input_ids'] = all_prompts['input_ids'].to(device)
all_prompts['attention_mask'] = all_prompts['attention_mask'].to(device)

### Training process

In [16]:
loss_img = nn.CrossEntropyLoss()
loss_txt = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr = LR)
scheduler = optim.lr_scheduler.LinearLR(optimizer, start_factor=1.0, end_factor=0.1, total_iters=EPOCHS)

In [17]:
test_dataset = torch.load(f'{DIRECTROY}/test_public_dataset/test_public_reduced_dataset_0.pth')
test_dataloader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)

In [18]:
max_accuracy = 0
for epoch in range(EPOCHS):
    model.train()
    train_loss = 0.0
    
    # Training loop
    print('Training epoch:', epoch+1)
    len_train = 0
    
    for i in range(2):
        train_dataset = torch.load(f'{DIRECTROY}/train_dataset/train_dataset_reduced_aug_{i}.pth')
        train_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
        for inputs, labels, input_ids, attention_mask  in tqdm(train_dataloader):
            optimizer.zero_grad()
            inputs = inputs.to(device)
            labels = labels.to(device)
            input_ids = input_ids.to(device)
            attention_mask = attention_mask.to(device)
            
            logits = model(pixel_values=inputs, input_ids=input_ids, attention_mask=attention_mask)
            logits_per_image = logits.logits_per_image
            logits_per_text = logits.logits_per_text
            
            ground_truth = torch.arange(BATCH_SIZE).to(device)
            
            loss = (loss_img(logits_per_image,ground_truth) + loss_txt(logits_per_text,ground_truth))/2
            
            loss.backward()
            optimizer.step()
            
            train_loss += loss.item()
        len_train += len(train_dataset)
        del train_dataset
        del train_dataloader

    scheduler.step()    
    train_loss/=len_train
    print(f'Epoch {epoch+1}/{EPOCHS}, Loss: {train_loss}')
    
    eval_loss = 0.0
    model.eval()
    
    true_labels = []
    pred_labels = []
    
    with torch.no_grad():
        len_test = 0

        
        for inputs, labels, input_ids, attention_mask  in tqdm(test_dataloader):
            inputs = inputs.to(device)
            labels = labels.to(device)
            input_ids = input_ids.to(device)
            attention_mask = attention_mask.to(device)
            
            logits = model(pixel_values=inputs, input_ids=all_prompts['input_ids'], attention_mask=all_prompts['attention_mask'])
            logits_per_image = logits.logits_per_image
            
            outputs = torch.argmax(logits_per_image, 1).flatten().cpu().numpy()
            labels = labels.flatten().cpu().numpy()
            
            
            true_labels.extend(labels)
            pred_labels.extend(outputs)
        len_test += len(test_dataset)
        
        print(f'Epoch {epoch+1}/{EPOCHS},')
        print(f'Accuracy: {accuracy_score(true_labels, pred_labels)}')
        print(f'F1 Score Weighted: {f1_score(true_labels, pred_labels, average="weighted")}')
        print(f'F1 Score Macro: {f1_score(true_labels, pred_labels, average="macro")}')
        if accuracy_score(true_labels, pred_labels) > max_accuracy:
            max_accuracy = accuracy_score(true_labels, pred_labels)
            torch.save(model, f'{MODEL_PATH}/model_clip_{epoch+1}.pth')
            torch.save(optimizer, f'{MODEL_PATH}/optimizer/optimizer_clip_{epoch+1}.pth')
            
            

Training epoch: 1


100%|██████████| 1137/1137 [05:19<00:00,  3.56it/s]


NameError: name 'train_dataset' is not defined