In [None]:
import copy

import numpy as np
import torch
import clip
from clip.simple_tokenizer import SimpleTokenizer

import pytorch_lightning as pl
from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_lightning.callbacks import ModelCheckpoint

from src import FoodDataModule, KPerClassSampler
from src import CLIP_Contrastive
from src import TextTransformer

In [2]:
# Load the CLIP backbone and the image preprocessing transform returned with it.
clip_backbone = "ViT-B/16"

# Use the GPU when one is visible to torch, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Load with jit=False, then cast the model's floating-point weights to float32.
model, image_transform = clip.load(clip_backbone, jit=False)
model = model.float()

# Model metadata read off the loaded CLIP model for use by later cells.
input_resolution, context_length, vocab_size = (
    model.visual.input_resolution,
    model.context_length,
    model.vocab_size,
)

In [5]:
# Build the Food-101 data module from the image folder, using CLIP's own
# image transform, and run its setup step so the datasets are available.
dataset_root = "data/food-101/images"
datamodule = FoodDataModule(
    folder=dataset_root,
    batch_size=32,
    image_transform=image_transform,
)
datamodule.setup()

In [6]:
# Caption templates; "{}" is the placeholder that appears in each template.
templates = ["a photo of {}, a type of food."]

# Callable that turns a class name into tokenized captions (one per template).
text_transformer = TextTransformer(
    tokenizer=SimpleTokenizer(),
    templates=templates,
    context_length=context_length,
)

num_classes = len(datamodule.dataset.class_to_idx)
num_captions = len(templates)

# One row of token ids per (class, template) pair, filled in class-index order.
tokenized_captions = torch.zeros(
    num_classes, num_captions, context_length, dtype=torch.int
)
for idx, class_name in datamodule.dataset.idx_to_class.items():
    tokenized_captions[idx] = text_transformer(class_name)

# Move the captions to the selected device and display the resulting shape.
tokenized_captions = tokenized_captions.to(device)
tokenized_captions.shape

torch.Size([101, 1, 77])

In [None]:
# Few-shot sweep: for each k, train on k samples per class, pick the best
# checkpoint by validation accuracy, and record the test accuracy.
ks = [1, 2, 4, 8, 16, 32, 64, 128, 256, 512]
epochs = 5

# Maps k -> test accuracy. Defined here so the cell runs on a fresh kernel;
# it was previously assigned into without ever being created.
k_results = {}

for k in ks:
    print(f'\nK: {k}')
    train_sampler = KPerClassSampler(dataset=datamodule.train_dataset, k=k, seed=42)

    train_loader = datamodule.train_dataloader(train_sampler)
    val_loader = datamodule.val_dataloader()

    # Each k must start from the same pretrained weights. Passing the shared
    # `model` object would let fine-tuning from the previous k carry over into
    # this run, so every iteration trains its own deep copy instead.
    backbone = copy.deepcopy(model).to(device)
    clip_model = CLIP_Contrastive(backbone, tokenized_captions, out_features=512)

    # NOTE: clip_backbone ("ViT-B/16") contains a "/", so this path resolves to
    # a nested directory (".../CLIP_contrastive_few_shot_ViT-B/16_<k>").
    log_dir = f'logs/CLIP_contrastive_few_shot_{clip_backbone}_{k}'
    logger = TensorBoardLogger(log_dir)
    # Keep the checkpoint with the highest validation accuracy.
    checkpoint = ModelCheckpoint(log_dir, monitor='val/accuracy', mode='max')

    trainer = pl.Trainer(
        gpus=1,
        max_epochs=epochs,
        gradient_clip_val=1,
        amp_backend='native',
        auto_lr_find=True,
        logger=logger,
        callbacks=[checkpoint]
    )

    # Run the tuner (learning-rate finder, enabled by auto_lr_find) before fitting.
    trainer.tune(clip_model, train_dataloaders=train_loader, val_dataloaders=val_loader)

    trainer.fit(clip_model, train_loader, val_loader)

    # Evaluate the best checkpoint rather than the weights from the last epoch.
    clip_model.load_state_dict(torch.load(checkpoint.best_model_path)['state_dict'])
    accuracy = trainer.test(clip_model, datamodule=datamodule)

    k_results[k] = accuracy[0]['test/accuracy']

In [3]:
# Report the CUDA memory allocated and reserved on device 0, converted from
# bytes to GB (bytes / 1024**3) and rounded to one decimal place.
for label, query_bytes in (
    ('Allocated:', torch.cuda.memory_allocated),
    ('Cached:   ', torch.cuda.memory_reserved),
):
    print(label, round(query_bytes(0) / 1024**3, 1), 'GB')

Allocated: 0.6 GB
Cached:    0.6 GB
