In [1]:
# ! git clone https://github.com/apple/ml-mobileclip.git /home/ml-mobileclip
# ! sed -i 's/torchvision==/torchvision>=/' /home/ml-mobileclip/requirements.txt
# ! pip install -e /home/ml-mobileclip

In [2]:
# !wget https://huggingface.co/apple/MobileCLIP-S2/resolve/main/mobileclip_s2.pt?download=true -O /home/mobileclip_s2.pt

# MobileCLIP FineTuning

In [3]:
import torch
from torch import nn
from torch.utils.data import DataLoader
# from torch.utils.tensorboard import SummaryWriter
from PIL import Image
import numpy as np
from glob import glob
import os
from tqdm import tqdm
import mobileclip
import datetime
tensorboard = False
if tensorboard:
    from torch.utils.tensorboard import SummaryWriter

device = "cuda" if torch.cuda.is_available() else "cpu"



## Hyper Params

In [4]:
# optim
opt_kargs = {
    'lr' : 1e-2, # TODO: Figure out LR decay, either step decay or warm restart
    'betas': (0.9, 0.95),
    'weight_decay' : 0.0,
    'eps': 1e-8,
}

num_epochs = 20
temperature = torch.tensor(0.1, device=device) # TODO: How to finetune this?
lit_mode = False
batch_size = 16
dataset_path = "/home/onion_clean"
b_test = False

checkpoint_path = "/home/mobile_clip_finetuned.pth"

# best_checkpoint_file = "mobile_clip_finetuned.pth" # If None will not load
best_checkpoint_file = "/home/mobile_clip_finetuned.pth"

## Load model and tokenizer

In [5]:
model, _, preprocess = mobileclip.create_model_and_transforms('mobileclip_s2', pretrained='/home/mobileclip_s2.pt')
tokenizer = mobileclip.get_tokenizer('mobileclip_s2')
# This should be called only for inference
model = model.to(device)
model

  chkpt = torch.load(pretrained)


CLIP(
  (image_encoder): MCi(
    (model): FastViT(
      (patch_embed): Sequential(
        (0): MobileOneBlock(
          (se): Identity()
          (activation): GELU(approximate='none')
          (reparam_conv): Conv2d(3, 80, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
        )
        (1): MobileOneBlock(
          (se): Identity()
          (activation): GELU(approximate='none')
          (reparam_conv): Conv2d(80, 80, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), groups=80)
        )
        (2): MobileOneBlock(
          (se): Identity()
          (activation): GELU(approximate='none')
          (reparam_conv): Conv2d(80, 80, kernel_size=(1, 1), stride=(1, 1))
        )
      )
      (network): ModuleList(
        (0): Sequential(
          (0): RepMixerBlock(
            (token_mixer): RepMixer(
              (reparam_conv): Conv2d(80, 80, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=80)
            )
            (convffn): ConvFFN(
              (con

## Load Dataset

In [6]:
labels = ['raw', 'translucent', 'golden brown']
texts = [ f"{label} chopped onions in a dark pan" for label in labels ]
text_tokenized = tokenizer(texts).numpy()
print(text_tokenized)

def generate_txt_tokenized_given_path(img_path):
    lbl_i = int(os.path.basename(os.path.dirname(img_path)))
    return text_tokenized[lbl_i]
    
def load_data(what):
    img_path_l = glob(os.path.join(dataset_path, what, "*/*.jpg"))
    assert len(img_path_l) > 0, "Invalid Dataset"
    txt_l = [ generate_txt_tokenized_given_path(img_path) for img_path in img_path_l ]
    assert_data(img_path_l, txt_l)
    return img_path_l, txt_l

def assert_data(img_path_l, txt_l):
    for i, (img_path, txt) in enumerate(zip(img_path_l, txt_l)):
        assert isinstance(img_path, str) and os.path.isfile(img_path), f"{i} {img_path} not present!"
        assert isinstance(txt, np.ndarray) and txt.shape == text_tokenized[0].shape

train_img_path_l, train_txt_l = load_data("train")
val_img_path_l, val_txt_l = load_data("valid")

[[49406  6323 22580 15255   530   320  3144  7437 49407     0     0     0
      0     0     0     0     0     0     0     0     0     0     0     0
      0     0     0     0     0     0     0     0     0     0     0     0
      0     0     0     0     0     0     0     0     0     0     0     0
      0     0     0     0     0     0     0     0     0     0     0     0
      0     0     0     0     0     0     0     0     0     0     0     0
      0     0     0     0     0]
 [49406 49052 22580 15255   530   320  3144  7437 49407     0     0     0
      0     0     0     0     0     0     0     0     0     0     0     0
      0     0     0     0     0     0     0     0     0     0     0     0
      0     0     0     0     0     0     0     0     0     0     0     0
      0     0     0     0     0     0     0     0     0     0     0     0
      0     0     0     0     0     0     0     0     0     0     0     0
      0     0     0     0     0]
 [49406  3878  2866 22580 15255   530   320  3

In [7]:
class DatasetLoader:
    
    def __init__(self, img_path_l, txt_l):
        assert len(img_path_l) == len(txt_l)
        self.img_path_l = img_path_l
        self.txt_l = txt_l
        
    def __len__(self):
        return b_test if b_test else len(self.img_path_l)

    def __getitem__(self, idx):
        # TODO: Add augmentation
        image = preprocess(Image.open(self.img_path_l[idx]))
        return image, self.txt_l[idx]

train_data_loader = DataLoader(DatasetLoader(train_img_path_l, train_txt_l), batch_size=b_test if b_test else batch_size, shuffle=True)
val_data_loader = DataLoader(DatasetLoader(val_img_path_l, val_txt_l), batch_size=b_test if b_test else batch_size, shuffle=False)

## Training

### CLIP Loss

In [8]:
# temperature = nn.Parameter(temperature)
def clip_loss(img_features, txt_features):
    img_features = img_features / img_features.norm(dim=-1, keepdim=True)
    txt_features = txt_features / txt_features.norm(dim=-1, keepdim=True)
    
    logits = (img_features @ txt_features.T) * torch.exp(temperature)
    
    labels = torch.arange(len(logits), device=logits.device)
    loss_i = nn.CrossEntropyLoss()(logits, labels)
    loss_t = nn.CrossEntropyLoss()(logits.T, labels)
    
    return (loss_i + loss_t) / 2, logits

opt = torch.optim.Adam(model.parameters(), **opt_kargs)

### Checkpoint save and reload

In [9]:
def save_checkpoint(test_loss, epoch):
    torch.save({
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': opt.state_dict(),
        'loss': test_loss,
        'epoch': epoch
    }, checkpoint_path)

def load_checkpoint():
    checkpoint = torch.load(best_checkpoint_file, weights_only=False)
    model.load_state_dict(checkpoint['model_state_dict'])
    opt.load_state_dict(checkpoint['optimizer_state_dict'])
    test_loss = checkpoint['loss']
    last_epoch = checkpoint['epoch']
    return last_epoch, test_loss

last_epoch, best_loss = 0, float('inf') # TODO: Use last_epoch this for lr decay restore
if best_checkpoint_file:
    last_epoch, best_loss = load_checkpoint()
    print(f"Loaded checkpoint {best_checkpoint_file} with test loss {best_loss}")

Loaded checkpoint /home/mobile_clip_finetuned.pth with test loss 2.2912741162184345


### Training Loop

In [10]:
if tensorboard:
    writer = SummaryWriter(log_dir=f"tb_logs/{datetime.now().strftime('%d%m%y_%H%M%S')}")
    
def tb_log_scaler(what, val, step):
    if tensorboard:
        writer.add_scalar(what, val, global_step=step)

# Train loop
train_step, val_step = 0, 0
for epoch in range(1, num_epochs + 1):
    print(f"\nEpoch {epoch}/{num_epochs}")

    # Training
    pbar = tqdm(train_data_loader, total=b_test if b_test else len(train_data_loader), desc="Training", ncols=100)
    model.train()
    running_loss = 0.0
    for img, txt in pbar:
        opt.zero_grad()
        img = img.to(device)
        txt = txt.to(device)
        
        # forward
        img_features = model.encode_image(img)
        txt_features = model.encode_text(txt)

        
        if lit_mode:
            img_features = img_features.detach()
        # loss
        loss, logits = clip_loss(img_features, txt_features)
        assert logits.size(0) == img.size(0)
        
        tb_log_scaler("train/loss_step", loss.item(), train_step)
        train_step += 1

        # backward
        loss.backward()
        opt.step()

        running_loss += loss.item() 
        # running_loss / (batch_idx + 1)

        pbar.set_postfix_str(f"Loss: {running_loss / (pbar.n + 1):.4f}")
    train_loss = running_loss / len(train_data_loader)
    print(f'Training Loss {train_loss:.4f}')
    tb_log_scaler("train/loss_epoch", train_loss, epoch)    

    # # Validation
    # model.eval()
    # running_loss = 0
    # correct, total = 0, 0
    # with torch.no_grad():
        
    #     pbar = tqdm(val_data_loader, total=b_test if b_test else len(val_data_loader), desc="Testing ", ncols=100)
    #     for img, txt in pbar:
    #         img = img.to(device)
    #         txt = txt.to(device)

    #         img_features = model.encode_image(img)
    #         txt_features = model.encode_text(txt)

    #         loss, logits = clip_loss(img_features, txt_features)
    #         assert logits.size(0) == img.size(0)
    #         tb_log_scaler("val/loss_step", loss.item(), val_step)
    #         val_step += 1

    #         lbl_hat = torch.argmax(logits, axis=1)
    #         lbl = torch.arange(len(lbl_hat), device=lbl_hat.device)

    #         correct += (lbl_hat == lbl).sum().cpu().item() 
    #         total += logits.size(0)

    #         running_loss += loss.item()
    #         pbar.set_postfix_str(f"Validation Loss: {running_loss / (pbar.n + 1):.4f} Accuracy: {correct / total:.2%}")
    # test_loss = running_loss / len(val_data_loader)
    # test_acc = correct / total
    # print(f"Test Loss:{test_loss:.4f} Test Acc:{test_acc:.2%}")
    # tb_log_scaler("val/loss_epoch", test_loss, epoch)
    # tb_log_scaler("val/accuracy", correct / total, epoch)
    if train_loss < best_loss:
        best_loss = train_loss
        save_checkpoint(best_loss, epoch)
        print(f"Saved checkpoint to {checkpoint_path} with Loss: {best_loss}")


Epoch 1/20


Training: 100%|██████████| 346/346 [04:40<00:00,  1.24it/s, Loss: 2.3259]


Training Loss 2.3259

Epoch 2/20


Training:  19%|█▉        | 66/346 [00:53<03:46,  1.23it/s, Loss: 2.3097]


KeyboardInterrupt: 