#### Import required libraries
___

In [9]:
import numpy as np
import torch
import clip
from clip import SimpleTokenizer

from src import FoodDataModule, KPerClassSampler
from src import CLIP_Captions, CLIP_Linear
from src import TextTransformer

import pytorch_lightning as pl
from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.callbacks.early_stopping import EarlyStopping

#### Load the backbone model
___

Print all available models

In [14]:
# Backbone names accepted by clip.load(); `clip_backbone` below must be one of them.
available_backbones = clip.available_models()
available_backbones

['RN50', 'RN101', 'RN50x4', 'RN50x16', 'ViT-B/32', 'ViT-B/16']

In [None]:
# Load the pretrained CLIP backbone together with its matching image preprocessing.
clip_backbone = "ViT-B/32"
assert clip_backbone in clip.available_models(), f"Unknown CLIP backbone: {clip_backbone}"

device = "cuda:0" if torch.cuda.is_available() else "cpu"
# `device` is "cuda:0" (not "cuda") when a GPU is present, so check the prefix;
# an equality test against "cuda" would fail even on GPU machines.
assert device.startswith("cuda"), "No GPU detected in your machine"

# jit=False loads the non-JIT (more hackable) model variant.
model, preprocess = clip.load(clip_backbone, device=device, jit=False)
# Read the resolution from the loaded visual encoder rather than hard-coding 224,
# so it stays correct for larger backbones (e.g. RN50x4 / RN50x16).
input_resolution = model.visual.input_resolution
context_length = model.context_length
vocab_size = model.vocab_size

num_parameters = sum(p.numel() for p in model.parameters())
print("Model parameters:", f"{num_parameters:,}")
print("Input resolution:", input_resolution)
print("Context length:", context_length)
print("Vocab size:", vocab_size)

#### Prepare dataset
___

In [6]:
# Root folder of the Food-101 images, relative to the notebook's working directory.
# Edit this single constant to point at your local copy; avoid machine-specific
# absolute paths so the notebook runs unchanged on other machines.
dataset_root = "data/food-101/images"
batch_size = 32

# `preprocess` is the CLIP image transform returned by clip.load() above.
dm = FoodDataModule(folder=dataset_root, batch_size=batch_size, image_transform=preprocess)

In [18]:
# Caption templates; "{}" is presumably filled with the class name inside
# TextTransformer (assumption — confirm in src).
templates = ["a photo of {}, a type of food."]

# Tokenizes the templated captions for a class name, padded/truncated to CLIP's
# context length (assumed from the `context_length` argument — verify in src).
text_transformer = TextTransformer(
    tokenizer=SimpleTokenizer(),
    templates=templates,
    context_length=context_length,
)

In [8]:
# Build the tokenized captions for every class:
# shape (num_classes, num_captions, context_length).
num_classes = len(dm.dataset.class_to_idx)
num_captions = len(templates)

# Every row must be filled below; otherwise a class would be left with all-zero tokens.
assert len(dm.dataset.idx_to_class) == num_classes, "idx_to_class does not cover every class index"

# Fill on the CPU and transfer to `device` once, instead of one copy per class.
tokenized_captions = torch.zeros((num_classes, num_captions, context_length), dtype=torch.int)
for idx, class_name in dm.dataset.idx_to_class.items():
    # Expected to yield (num_captions, context_length) tokens for this class.
    tokenized_captions[idx] = text_transformer(class_name)

tokenized_captions = tokenized_captions.to(device)
tokenized_captions.shape

torch.Size([2, 1, 77])

#### Training
___

`clip_type` selects the fine-tuning mode: with `'captions'`, both text (caption) and image features are used to train the model; with `'linear'`, CLIP is used as an ordinary image classifier.

In [None]:
clip_type = 'captions'  # one of: 'captions', 'linear'

if clip_type == 'captions':
    # Fine-tune with both image features and the per-class caption features.
    clip_model = CLIP_Captions(model.to(device), tokenized_captions, out_features=512)
elif clip_type == 'linear':
    # Use CLIP as an ordinary image classifier over `num_classes` classes.
    clip_model = CLIP_Linear(model.to(device), num_classes, out_features=512)
else:
    # Fail loudly on typos instead of silently training the wrong model.
    raise ValueError(f"Unknown clip_type {clip_type!r}; expected 'captions' or 'linear'")

Start training

In [None]:
# Backbone names such as "ViT-B/32" contain "/", which would otherwise split the
# log path into nested directories.
log_dir = f"logs/CLIP_{clip_type}_{clip_backbone.replace('/', '-')}"
logger = TensorBoardLogger(log_dir)
# Keep the checkpoint with the highest validation accuracy.
checkpoint = ModelCheckpoint(log_dir, monitor='val/accuracy', mode='max')

trainer = pl.Trainer(gpus=1,
                     gradient_clip_val=1,
                     auto_lr_find=True,
                     logger=logger,
                     callbacks=[checkpoint, EarlyStopping(monitor='val/loss')],
                     max_epochs=10
                     )

# Tune the hyperparameters, i.e. find a good initial learning rate (auto_lr_find=True)
trainer.tune(clip_model, datamodule=dm)

# Start training
trainer.fit(clip_model, datamodule=dm)

# Load the best checkpoint saved by ModelCheckpoint before testing.
# best_model_path is empty when no checkpoint was written (e.g. the metric was never logged).
assert checkpoint.best_model_path, "No checkpoint was saved; check that 'val/accuracy' is logged"
best_state = torch.load(checkpoint.best_model_path, map_location=device)['state_dict']
clip_model.load_state_dict(best_state)

# Test
trainer.test(clip_model, datamodule=dm)

Load tensorboard and observe performance

In [None]:
# Show TensorBoard inline; the training cell writes each run under a subfolder of logs/.
%reload_ext tensorboard
%tensorboard --logdir logs