In [1]:
import torch
from torch import nn
from torch.utils.data import DataLoader
from PIL import Image
import numpy as np
from glob import glob
from matplotlib import pyplot as plt
import os
from tqdm.notebook import tqdm
import open_clip
from mobileclip.modules.common.mobileone import reparameterize_model
from sklearn.metrics import accuracy_score

device = "cuda" if torch.cuda.is_available() else "cpu"



# Hyper Params

In [2]:
lr = 1e-5
num_epochs = 10
temperature = 0.1 # TODO: How to finetune this?
batch_size = 128
beta_1 = 0.9
beta_2 = 0.99 # Change to 0.95 if unstable
decay = 0
dataset_path = "/home/mnjm/workspace/clip/dataset/onion"
b_test = 5

# Load model and tokenizer

In [3]:
model, _, preprocess = open_clip.create_model_and_transforms('MobileCLIP-S2', pretrained='datacompdr')
tokenizer = open_clip.get_tokenizer('MobileCLIP-S2')
# This should be called only for inference
# model = reparameterize_model(model)
model

CustomTextCLIP(
  (visual): TimmModel(
    (trunk): FastVit(
      (stem): Sequential(
        (0): MobileOneBlock(
          (se): Identity()
          (conv_kxk): ModuleList(
            (0): ConvNormAct(
              (conv): Conv2d(3, 80, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
              (bn): BatchNormAct2d(
                80, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True
                (drop): Identity()
                (act): Identity()
              )
            )
          )
          (conv_scale): ConvNormAct(
            (conv): Conv2d(3, 80, kernel_size=(1, 1), stride=(2, 2), bias=False)
            (bn): BatchNormAct2d(
              80, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True
              (drop): Identity()
              (act): Identity()
            )
          )
          (act): GELU(approximate='none')
        )
        (1): MobileOneBlock(
          (se): Identity()
          (conv_kxk): ModuleList

In [4]:
tokenizer = open_clip.get_tokenizer('MobileCLIP-S2')

### (Optional) Save the model and tokenizer to train on Windows AI Server (cause no open_clip)

In [5]:
b_save = False
save_dir = "./mobile_clip_save"

os.makedirs(save_dir, exist_ok=True)

if b_save:
    if model is not None:
        # Save Model
        torch.save(model.state_dict(), os.path.join(save_dir, 'model.pth'))
        # Save model config

        # TODO: Complete it

# Load Dataset

In [6]:
labels = ['raw', 'translucent', 'golden brown']
texts = [ f"{label} chopped onions in a dark pan" for label in labels ]
text_tokenized = tokenizer(texts).numpy()
# print(text_tokenized.shape)

def generate_txt_tokenized_given_path(img_path):
    lbl_i = int(os.path.basename(os.path.dirname(img_path)))
    return text_tokenized[lbl_i]
    
def load_data(what):
    img_path_l = glob(os.path.join(dataset_path, what, "*/*.jpg"))
    txt_l = [ generate_txt_tokenized_given_path(img_path) for img_path in img_path_l ]
    assert_data(img_path_l, txt_l)
    return img_path_l, txt_l

def assert_data(img_path_l, txt_l):
    for i, (img_path, txt) in enumerate(zip(img_path_l, txt_l)):
        assert isinstance(img_path, str) and os.path.isfile(img_path), f"{i} {img_path} not present!"
        assert isinstance(txt, np.ndarray) and txt.shape == text_tokenized[0].shape

train_img_path_l, train_txt_l = load_data("train")
val_img_path_l, val_txt_l = load_data("valid")

In [7]:
class DatasetLoader:
    
    def __init__(self, img_path_l, txt_l):
        assert len(img_path_l) == len(txt_l)
        self.img_path_l = img_path_l
        self.txt_l = txt_l
        
    def __len__(self):
        return b_test if b_test else len(self.img_path_l)

    def __getitem__(self, idx):
        image = preprocess(Image.open(self.img_path_l[idx]))
        return image, self.txt_l[idx]

train_data_loader = DataLoader(DatasetLoader(train_img_path_l, train_txt_l), batch_size=b_test if b_test else batch_size, shuffle=True)
val_data_loader = DataLoader(DatasetLoader(val_img_path_l, val_txt_l), batch_size=b_test if b_test else batch_size, shuffle=True)

# Training

In [8]:
def clip_loss(img_features, txt_features):
    img_features = img_features / img_features.norm(dim=-1, keepdim=True)
    txt_features = txt_features / txt_features.norm(dim=-1, keepdim=True)
    
    logits = (img_features @ txt_features.T) * torch.exp(torch.tensor(temperature))
    
    labels = torch.arange(len(logits), device=logits.device)
    loss_i = nn.CrossEntropyLoss()(logits, labels)
    loss_t = nn.CrossEntropyLoss()(logits.T, labels)
    
    return (loss_i + loss_t) / 2

In [None]:
optimizer = torch.optim.Adam(model.parameters(), lr=5e-5,betas=(0.9,0.98),eps=1e-6,weight_decay=0.2)

# Train the model
for epoch in range(num_epochs):
    progress_bar = tqdm(train_data_loader, total=b_test if b_test else len(train_data_loader))

    # Training
    model.train()
    for batch in progress_bar:
        optimizer.zero_grad()
        img, txt = batch
        img = img.to(device)
        txt = txt.to(device)
        
        # forward
        img_features = model.encode_image(img)
        txt_features = model.encode_text(txt)

        # loss
        loss = clip_loss(img_features, txt_features)

        # backward
        loss.backward()
        optimizer.step()

        progress_bar.set_description(f"Epoch {epoch}/{num_epochs}, Loss: {loss.item():.4f}")

    # Validation
    model.eval()
    val_gt, val_pred = [], []
    for batch in val_data_loader:
        img, txt = batch
        img = img.to(device)
        txt = txt.to(device)

        img_features = model.encode_image(img)
        txt_features = model.encode_text(txt)

        img_features = img_features / img_features.norm(dim=-1, keepdim=True)
        txt_features = txt_features / txt_features.norm(dim=-1, keepdim=True)

        logits = (img_features @ txt_features.T) * torch.exp(torch.tensor(temperature))
        logits = torch.argmax(logits, axis=1)
        labels = torch.arange(len(logits), device=logits.device)
        
        val_pred.extend(logits)
        val_gt.extend(labels)
    val_accuracy = accuracy_score(val_gt, val_pred)
    print(f"Epoch {epoch + 1}/{num_epochs}, Validation Accuracy: {val_accuracy:.4f}")

  0%|          | 0/5 [00:00<?, ?it/s]

Epoch 1/10, Validation Accuracy: 0.2000


  0%|          | 0/5 [00:00<?, ?it/s]

Epoch 2/10, Validation Accuracy: 0.2000


  0%|          | 0/5 [00:00<?, ?it/s]