In [None]:
from fomo.models.clip.clip_base import ClipBase
import torch
from PIL import Image

# CLIP Base Demo

In [None]:
# Instantiate the base CLIP wrapper with its default arguments, move it to the CPU,
# then call eval() (presumably nn.Module.eval, i.e. inference mode — confirm in ClipBase).
# The trailing `;` suppresses display of whatever eval() returns.
clip = ClipBase()
clip.to_cpu()
clip.eval();

In [None]:
# Pre-compute text features for the candidate prompts so `forward` can later be called with images only.
# Both prompts use the same template ("a picture of a <noun>") so they differ only in the class word.
clip.precompute_prompt_features(["a picture of a cat", "a picture of a kitten"])

In [None]:
# Load the demo image. `convert("RGB")` reads the pixel data into a new in-memory image,
# so the `with` block can close the file handle, and it guarantees a 3-channel image
# (the model input built below is 3-channel).
with Image.open("./cat.jpg") as raw_img:
    img = raw_img.convert("RGB")

In [None]:
# Preprocess with CLIP's own transform and add a leading batch dimension.
# `unsqueeze(0)` avoids hard-coding the 3x224x224 input size, so backbones with
# other input resolutions work as well.
img_tensor = clip.transform(img).unsqueeze(0)

In [None]:
# Score the image against the pre-computed prompts; no gradients are needed for inference.
with torch.no_grad():
    logits = clip.forward(img_tensor)
logits

In [None]:
# Turn the similarity logits into probabilities over the prompts (dim=1 is presumably the
# prompt axis of an (images, prompts) logits matrix). Uses the public Tensor.softmax API
# rather than `torch.functional.F`, which only resolves through a private internal import.
logits.softmax(dim=1)

# Datasets

In [None]:
from fomo.utils.data.datasets import DatasetInitializer
from fomo.utils.data import utils

In [None]:
# Load the CIFAR-10 training split as a zero-shot dataset.
# The returned object carries two attributes: the torch dataset and its class labels.
cifar10_initializer = DatasetInitializer.from_str("cifar10")
zero_shot_dataset = cifar10_initializer.value(train=True)

for attr_name in ("dataset", "labels"):
    print(getattr(zero_shot_dataset, attr_name))

In [None]:
# `utils.split_train_val` can also carve the dataset into train / validation subsets.
# Fractional split: train_size=0.8 keeps 80% of the samples for training
# (the remainder presumably forms the validation subset — confirm in utils).
split_fraction = 0.8
train_dataset, val_dataset = utils.split_train_val(
    zero_shot_dataset.dataset, train_size=split_fraction
)
print(*(len(subset) for subset in (train_dataset, val_dataset)))

In [None]:
# Absolute split: request a fixed number of samples instead of a fraction
# (the list is presumably ordered [train, eval] — confirm in utils).
requested_sizes = [10, 20]
train_dataset, val_dataset = utils.split_train_val(
    zero_shot_dataset.dataset, train_eval_samples=requested_sizes
)
print(*(len(subset) for subset in (train_dataset, val_dataset)))

In [11]:
# Class-subsample split: "all" keeps every class, while "base" and "new" each keep a subset
# of the classes (per the recorded output: 50000 / 25000 / 25000 images on the CIFAR-10 train split).
# After the loop, `subsampled_dataset` holds the "new" subset, as in the original cell.
for subsample in ("all", "base", "new"):
    subsampled_dataset = utils.subsample_classes(zero_shot_dataset.dataset, subsample)
    print(len(subsampled_dataset))

50000
25000
25000


# Learner

In [None]:
from fomo.pipelines.train import Learner
from fomo.pipelines.types.learner_args import LearnerArgs

In [None]:
# Quick CPU configuration: a single epoch of the "clip_linear" model, using the "base"
# class subsample for training and the "new" subsample for testing.
# NOTE(review): use_wandb=True presumably requires a configured Weights & Biases login.
learner_overrides = {
    "device": "cpu",
    "epochs": 1,
    "model_type": "clip_linear",
    "use_wandb": True,
    "train_subsample": "base",
    "test_subsample": "new",
    "train_eval_size": (10, 10),
}

learner_args = LearnerArgs()
for field_name, field_value in learner_overrides.items():
    setattr(learner_args, field_name, field_value)

learner = Learner(learner_args)

In [None]:
# Launch the Learner pipeline (fomo.pipelines.train) with the configuration built above.
learner.run()

## Extending ClipBase

In [None]:
# import ClipBase
from fomo.models.clip.clip_base import ClipBase
from torch import nn
import torch

In [None]:
class ClipExtension(ClipBase):
    """Example ClipBase subclass that adds a small trainable MLP on the image features.

    The extra block (``visual_mlp``) projects the image embedding down to 12
    dimensions and back up to its original width, so its output can still be
    multiplied against the text features in ``forward``.
    """

    def __init__(self, backbone: str = "ViT-B/16", root: str = "./data") -> None:
        """Build the base CLIP model and the extra ``visual_mlp`` block.

        Args:
            backbone: CLIP backbone name, forwarded to ``ClipBase``.
            root: Directory forwarded to ``ClipBase`` (presumably where weights/data are stored).
        """
        # pass default arguments to the parent class
        super().__init__(backbone, root=root)

        # Additional block: output_dim -> 12 -> output_dim keeps the feature width unchanged,
        # so the image/text similarity in `forward` still has matching dimensions.
        embed_dim = self._clip.visual.output_dim
        self.visual_mlp = nn.Sequential(
            nn.Linear(embed_dim, 12),
            nn.Linear(12, embed_dim),
        )

    @property
    def learnable_param_names(self) -> set[str]:
        """Names of the blocks whose parameters should be trained.

        IMPORTANT: every learnable block added in ``__init__`` must be listed here,
        using its attribute name (presumably matched against parameter names such
        as ``visual_mlp.0.weight`` by the training pipeline).
        """
        return {"visual_mlp"}

    # If needed you can override the to_cpu and to_cuda methods
    def to_cpu(self) -> None:
        """Move CLIP and the extra MLP to the CPU and cast the CLIP weights to float32."""
        cpu = torch.device("cpu")
        self._clip.to(cpu)
        self.visual_mlp.to(cpu)
        # Presumably the CLIP weights may be half precision, which many CPU ops do not support.
        self._clip.float()

    def to_cuda(self) -> None:
        """Move the extra MLP and CLIP to the default CUDA device."""
        cuda = torch.device("cuda")
        self.visual_mlp.to(cuda)
        self._clip.to(cuda)

    def forward(self, images: torch.Tensor, prompts: list[str] | None = None) -> torch.Tensor:
        """Compute image-to-prompt logits, passing the image features through ``visual_mlp``.

        Args:
            images: Batch of preprocessed images.
            prompts: Optional prompts to encode on the fly; when empty/None the features
                stored in ``_precomputed_prompt_features`` are used instead.

        Returns:
            ``logit_scale * image_features @ text_features.T`` (one row per image).

        Raises:
            ValueError: if no prompts are given and no prompt features were pre-computed.
        """
        if prompts:
            text_features = self.encode_text(prompts)
        elif self._precomputed_prompt_features is not None:
            text_features = self._precomputed_prompt_features
        else:
            raise ValueError("At least one prompts or pre-computed promt features has to be present.")

        # The extra MLP is applied on top of CLIP's image embedding.
        image_features = self.visual_mlp(self.encode_images(images))

        logits_per_image: torch.Tensor = self.logit_scale * image_features @ text_features.t()

        return logits_per_image

In [None]:
model = ClipExtension()

# Show which blocks the training pipeline will treat as learnable.
learnable_names = model.learnable_param_names
print(learnable_names)

### N-class K-shot dataloader

In [None]:

from fomo.pipelines.train import Learner
from fomo.pipelines.types.learner_args import LearnerArgs
from fomo.pipelines.utils.initializers import initalize_datasets, initalize_n_class_k_shot_dataloaders, intialize_model

# Configure an N-class K-shot run (n classes, k examples per class) and build its dataloaders.
learner_args = LearnerArgs()
# Prefer Apple-Silicon MPS (the device this cell originally targeted), then CUDA,
# and fall back to CPU so the cell also runs on machines without an MPS device.
if torch.backends.mps.is_available():
    learner_args.device = "mps"
elif torch.cuda.is_available():
    learner_args.device = "cuda"
else:
    learner_args.device = "cpu"
# learner_args.epochs = 300
# learner_args.patience = 10
# learner_args.print_freq = 20
# learner_args.save_freq = 300

learner_args.model_type = "clip_transformer"
# NOTE(review): "dataloder_type" looks misspelled; confirm it matches the LearnerArgs field name,
# otherwise this assignment silently creates a new attribute and has no effect.
learner_args.dataloder_type = 'n_class_k_shot'
learner_args.n = 5
learner_args.k = 16
# learner_args.train_size = 0.8
learner_args.train_eval_size = [100, 20]

# Build the model first: its image transformation is needed to construct the datasets.
model = intialize_model(
    learner_args.model_type, learner_args.model_backbone, learner_args.device
)
transforms = model.transforms

(train_dataset, test_dataset), labels = initalize_datasets(learner_args.dataset, transforms)

train_loader, val_loader, test_loader = initalize_n_class_k_shot_dataloaders(
    train_dataset, test_dataset, learner_args
)

### Test: training `clip_mm_mlp_adapter` on `oxford_pets`

In [None]:
from fomo.pipelines.train import Learner
from fomo.pipelines.types.learner_args import LearnerArgs

# Training run of the "clip_mm_mlp_adapter" model on the "oxford_pets" dataset, logged to W&B.
# To switch to N-class K-shot sampling, additionally set
# dataloder_type='n_class_k_shot', n=5 and k=16.
adapter_run_config = {
    # schedule / logging
    "epochs": 100,
    "patience": 20,
    "print_freq": 50,
    "save_freq": 200,
    # optimiser
    "learning_rate": 0.02,
    "momentum": 0.9,
    "weight_decay": 0.002,
    "use_wandb": True,
    # data / model
    "dataset": 'oxford_pets',
    "model_type": "clip_mm_mlp_adapter",
    "train_size": None,  # a fractional split such as 0.8 was also tried here
    "train_eval_size": [592, 592],
    "batch_size": 64,
}

learner_args = LearnerArgs()
for option, setting in adapter_run_config.items():
    setattr(learner_args, option, setting)

learner = Learner(learner_args)
learner.run()

Turning off gradients in both the image and the text encoder
Parameters to be updated: {'mm_to_text_mlp.0.weight', 'mm_to_visual_mlp.2.weight', 'mm_to_text_mlp.2.weight', 'mm_to_visual_mlp.0.weight'}
Number of learnable paramms: 98304
