In [None]:
!pip install open_clip_torch tensorboard

[0m

# MobileCLIP FineTuning

In [2]:
import torch
from torch import nn
from torch.utils.data import DataLoader
# from torch.utils.tensorboard import SummaryWriter
from PIL import Image
import numpy as np
from glob import glob
import os
from tqdm.notebook import tqdm
import open_clip
# from mobileclip.modules.common.mobileone import reparameterize_model
# from torch.utils.tensorboard import SummaryWriter

device = "cuda" if torch.cuda.is_available() else "cpu"

## Hyper Params

In [3]:
# optim
opt_kargs = {
    'lr' : 1e-2, # TODO: Figure out LR decay, either step decay or warm restart
    'betas': (0.9, 0.95),
    'weight_decay' : 0.0,
    'eps': 1e-8,
}

num_epochs = 20
temperature = torch.tensor(0.1, device=device) # TODO: How to finetune this?
batch_size = 16
dataset_path = "./onion_dataset"
b_test = False

checkpoint_path = "./mobile_clip_finetuned.pth"

# best_checkpoint_file = "mobile_clip_finetuned.pth" # If None will not load
best_checkpoint_file = None

## Load model and tokenizer

In [4]:
model, _, preprocess = open_clip.create_model_and_transforms('MobileCLIP-S2', pretrained='datacompdr')
tokenizer = open_clip.get_tokenizer('MobileCLIP-S2')
# This should be called only for inference
# model = reparameterize_model(model)
model = model.to(device)
model

CustomTextCLIP(
  (visual): TimmModel(
    (trunk): FastVit(
      (stem): Sequential(
        (0): MobileOneBlock(
          (se): Identity()
          (conv_kxk): ModuleList(
            (0): ConvNormAct(
              (conv): Conv2d(3, 80, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
              (bn): BatchNormAct2d(
                80, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True
                (drop): Identity()
                (act): Identity()
              )
            )
          )
          (conv_scale): ConvNormAct(
            (conv): Conv2d(3, 80, kernel_size=(1, 1), stride=(2, 2), bias=False)
            (bn): BatchNormAct2d(
              80, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True
              (drop): Identity()
              (act): Identity()
            )
          )
          (act): GELU(approximate='none')
        )
        (1): MobileOneBlock(
          (se): Identity()
          (conv_kxk): ModuleList

In [5]:
tokenizer = open_clip.get_tokenizer('MobileCLIP-S2')

## Load Dataset

In [6]:
labels = ['raw', 'translucent', 'golden brown']
texts = [ f"{label} chopped onions in a dark pan" for label in labels ]
text_tokenized = tokenizer(texts).numpy()
# print(text_tokenized.shape)

def generate_txt_tokenized_given_path(img_path):
    lbl_i = int(os.path.basename(os.path.dirname(img_path)))
    return text_tokenized[lbl_i]
    
def load_data(what):
    img_path_l = glob(os.path.join(dataset_path, what, "*/*.jpg"))
    txt_l = [ generate_txt_tokenized_given_path(img_path) for img_path in img_path_l ]
    assert_data(img_path_l, txt_l)
    return img_path_l, txt_l

def assert_data(img_path_l, txt_l):
    for i, (img_path, txt) in enumerate(zip(img_path_l, txt_l)):
        assert isinstance(img_path, str) and os.path.isfile(img_path), f"{i} {img_path} not present!"
        assert isinstance(txt, np.ndarray) and txt.shape == text_tokenized[0].shape

train_img_path_l, train_txt_l = load_data("train")
val_img_path_l, val_txt_l = load_data("valid")

In [7]:
class DatasetLoader:
    
    def __init__(self, img_path_l, txt_l):
        assert len(img_path_l) == len(txt_l)
        self.img_path_l = img_path_l
        self.txt_l = txt_l
        
    def __len__(self):
        return b_test if b_test else len(self.img_path_l)

    def __getitem__(self, idx):
        # TODO: Add augmentation
        image = preprocess(Image.open(self.img_path_l[idx]))
        return image, self.txt_l[idx]

train_data_loader = DataLoader(DatasetLoader(train_img_path_l, train_txt_l), batch_size=b_test if b_test else batch_size, shuffle=True)
val_data_loader = DataLoader(DatasetLoader(val_img_path_l, val_txt_l), batch_size=b_test if b_test else batch_size, shuffle=True)

## Training

### CLIP Loss

In [8]:
temperature = nn.Parameter(temperature)
def clip_loss(img_features, txt_features):
    img_features = img_features / img_features.norm(dim=-1, keepdim=True)
    txt_features = txt_features / txt_features.norm(dim=-1, keepdim=True)
    
    logits = (img_features @ txt_features.T) * torch.exp(temperature)
    
    labels = torch.arange(len(logits), device=logits.device)
    loss_i = nn.CrossEntropyLoss()(logits, labels)
    loss_t = nn.CrossEntropyLoss()(logits.T, labels)
    
    return (loss_i + loss_t) / 2, logits

opt = torch.optim.Adam(model.parameters(), **opt_kargs)

### Checkpoint save and reload

In [9]:
def save_checkpoint(test_loss, epoch):
    torch.save({
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': opt.state_dict(),
        'loss': test_loss,
        'epoch': epoch
    }, checkpoint_path)

def load_checkpoint():
    checkpoint = torch.load(best_checkpoint_file, weights_only=False)
    model.load_state_dict(checkpoint['model_state_dict'])
    opt.load_state_dict(checkpoint['optimizer_state_dict'])
    test_loss = checkpoint['loss']
    last_epoch = checkpoint['epoch']
    return last_epoch, test_loss

last_epoch, test_loss = 0, 0.0 # TODO: Use last_epoch this for lr decay restore
if best_checkpoint_file:
    last_epoch, test_loss = load_checkpoint()
    print(f"Loaded checkpoint {best_checkpoint_file} with test loss {test_loss}")

### Training Loop

In [None]:
# Train the model
best_test_loss = torch.inf
for epoch in range(1, num_epochs + 1):
    print(f"\nEpoch {epoch}/{num_epochs}")

    # Training
    pbar = tqdm(train_data_loader, total=b_test if b_test else len(train_data_loader), desc="Training")
    model.train()
    running_loss = 0.0
    for batch_idx, batch in enumerate(pbar):
        opt.zero_grad()
        img, txt = batch
        img = img.to(device)
        txt = txt.to(device)
        
        # forward
        img_features = model.encode_image(img).detach()
        txt_features = model.encode_text(txt)

        # loss
        loss, _ = clip_loss(img_features, txt_features)

        # backward
        loss.backward()
        opt.step()

        running_loss += loss.item() 
        # running_loss / (batch_idx + 1)

        pbar.set_postfix_str(f"Loss: {running_loss / (batch_idx + 1):.4f}")

    # Validation
    model.eval()
    running_loss = 0
    with torch.no_grad():
        correct, total = 0, 0
        pbar = tqdm(val_data_loader, total=b_test if b_test else len(val_data_loader), desc="Testing ")
        for batch_idx, batch in enumerate(pbar):
            img, txt = batch
            img = img.to(device)
            txt = txt.to(device)

            img_features = model.encode_image(img)
            txt_features = model.encode_text(txt)

            img_features = img_features / img_features.norm(dim=-1, keepdim=True)
            txt_features = txt_features / txt_features.norm(dim=-1, keepdim=True)

            loss, logits = clip_loss(img_features, txt_features)

            lbl_ = torch.argmax(logits, axis=1)
            lbl = torch.arange(len(logits), device=logits.device)

            correct += (lbl_ == lbl).sum().cpu().item() 
            total += logits.size(0)

            running_loss += loss.item()
            pbar.set_postfix_str(f"Validation Loss: {running_loss / (batch_idx + 1):.4f} Accuracy: {correct / total:.2%}")
    test_loss = running_loss / len(val_data_loader)
    if test_loss < best_test_loss:
        save_checkpoint(test_loss, epoch)
        best_test_loss = test_loss
        print(f"Saved checkpoint to {checkpoint_path} with Loss: {test_loss}")


Epoch 1/20


Training:   0%|          | 0/346 [00:00<?, ?it/s]

Testing :   0%|          | 0/94 [00:00<?, ?it/s]

Saved checkpoint to ./mobile_clip_finetuned.pth with Loss: 2.647172151093787

Epoch 2/20


Training:   0%|          | 0/346 [00:00<?, ?it/s]

Testing :   0%|          | 0/94 [00:00<?, ?it/s]


Epoch 3/20


Training:   0%|          | 0/346 [00:00<?, ?it/s]

Testing :   0%|          | 0/94 [00:00<?, ?it/s]

Saved checkpoint to ./mobile_clip_finetuned.pth with Loss: 2.6465915194217193

Epoch 4/20


Training:   0%|          | 0/346 [00:00<?, ?it/s]

Testing :   0%|          | 0/94 [00:00<?, ?it/s]


Epoch 5/20


Training:   0%|          | 0/346 [00:00<?, ?it/s]

Testing :   0%|          | 0/94 [00:00<?, ?it/s]

Saved checkpoint to ./mobile_clip_finetuned.pth with Loss: 2.644302803151151

Epoch 6/20


Training:   0%|          | 0/346 [00:00<?, ?it/s]

Testing :   0%|          | 0/94 [00:00<?, ?it/s]


Epoch 7/20


Training:   0%|          | 0/346 [00:00<?, ?it/s]

Testing :   0%|          | 0/94 [00:00<?, ?it/s]

Saved checkpoint to ./mobile_clip_finetuned.pth with Loss: 2.6396697414048176

Epoch 8/20


Training:   0%|          | 0/346 [00:00<?, ?it/s]

Testing :   0%|          | 0/94 [00:00<?, ?it/s]