In [1]:
from fomo.models.clip.clip_base import ClipBase
import torch
from PIL import Image

  from .autonotebook import tqdm as notebook_tqdm


# CLIP Base Demo

In [2]:
# Build the base CLIP wrapper using the constructor's default arguments
clip = ClipBase()
# Keep the model on the CPU so the demo does not need a GPU
clip.to_cpu()
# Switch to evaluation mode (presumably nn.Module.eval; ClipBase itself is not shown here).
# The trailing ';' stops the notebook from echoing the call's return value.
clip.eval();

In [3]:
# Encode the candidate captions once up front so the forward call below can be made
# without passing prompts; the two captions correspond to the two logits printed later
clip.precompute_prompt_features(["a picture of a cat", "a picture of kitten"])

In [4]:
# Load the sample image; the path is relative to the notebook's working directory
img = Image.open("./cat.jpg")

In [5]:
# Preprocess the image with the model's own transform and add a leading batch axis.
# unsqueeze(0) is used instead of a hard-coded view(1, 3, 224, 224) so the cell keeps
# working for backbones whose preprocessing produces a different resolution.
img_tensor = clip.transform(img).unsqueeze(0)

In [7]:
# Inference only: disable gradient tracking while scoring the image
with torch.no_grad():
    # No prompts are passed, so this presumably relies on the prompt features
    # precomputed in the earlier cell (ClipBase.forward is not shown here)
    logits = clip.forward(img_tensor)
print(logits)

tensor([[29.0192, 26.0525]])


In [8]:
# Turn the logits into probabilities; dim=1 normalises across the prompts of each image.
# torch.nn.functional is the public home of softmax (torch.functional.F is only an
# internal import alias and is not part of the documented API).
torch.nn.functional.softmax(logits, dim=1)

tensor([[0.9510, 0.0490]])

# Datasets

In [3]:
from fomo.utils.data.datasets import DatasetInitializer
from fomo.utils.data import utils

In [4]:
# Look up the "cifar10" entry by name and instantiate its training split
zero_shot_dataset = DatasetInitializer.from_str("cifar10").value(train=True)

# The returned object exposes two properties: the torch dataset and the class labels
print(zero_shot_dataset.dataset, zero_shot_dataset.labels, sep="\n")

Files already downloaded and verified
Dataset CIFAR10
    Number of datapoints: 50000
    Root location: data
    Split: Train
['plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']


In [5]:
# utils can also divide a dataset into train and eval splits.

# Fractional split: train_size is the share of samples that goes to the train part
train_dataset, val_dataset = utils.split_train_val(
    zero_shot_dataset.dataset,
    train_size=0.8,
)
print(f"{len(train_dataset)} {len(val_dataset)}")

40000 10000


In [6]:
# Absolute split: request explicit sample counts for the train and eval parts
train_dataset, val_dataset = utils.split_train_val(
    zero_shot_dataset.dataset,
    train_eval_samples=[10, 20],
)
print(f"{len(train_dataset)} {len(val_dataset)}")

10 20


In [11]:
# Subsample the dataset by class group and report the size of each result.
# The three supported modes are exercised in turn; after the loop,
# subsampled_dataset still refers to the "new" subset, as before.
for subsample in ("all", "base", "new"):
    subsampled_dataset = utils.subsample_classes(zero_shot_dataset.dataset, subsample)
    print(len(subsampled_dataset))

50000
25000
25000


# Learner

In [1]:
from fomo.pipelines.train import Learner
from fomo.pipelines.types.learner_args import LearnerArgs

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Start from the default learner configuration and override a few fields for a quick demo
learner_args = LearnerArgs()
learner_args.device = "cpu"  # run on the CPU
learner_args.epochs = 1  # a single training epoch
learner_args.model_type = "clip_linear"  # model variant; the log below shows only image_linear.* being updated
learner_args.use_wandb = True  # turn on Weights & Biases logging (run() prints the wandb login banner)
learner_args.train_subsample = "base"  # presumably the same class groups as utils.subsample_classes -- TODO confirm
learner_args.test_subsample = "new"  # presumably the same class groups as utils.subsample_classes -- TODO confirm
learner_args.train_eval_size = (10, 10)  # presumably (train, eval) sample counts -- TODO confirm

# Constructing the Learner loads the datasets and builds the model (see the log output below)
learner = Learner(learner_args)

Files already downloaded and verified
Files already downloaded and verified
Turning off gradients in both the image and the text encoder
Parameters to be updated: {'image_linear.0.weight', 'image_linear.0.bias'}
Number of learnable paramms: 262656




In [3]:
# Execute the configured pipeline; per the log below this trains for the configured
# epoch, validates, saves the best checkpoint and then runs further evaluation passes
learner.run()

ERROR:wandb.jupyter:Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mdqmiss[0m. Use [1m`wandb login --relogin`[0m to force relogin


  0%|          | 0/1 [00:00<?, ?it/s]

Epoch: [0][0/1]	Time 11.054 (11.054)	Data 10.446 (10.446)	Loss 2.6379e+00 (2.6379e+00)	Acc@1   0.00 (  0.00)


100%|██████████| 1/1 [00:17<00:00, 17.17s/it]
100%|██████████| 1/1 [00:10<00:00, 10.45s/it]

Validate: [0/1]	Time 10.451 (10.451)	Loss 1.8925e+00 (1.8925e+00)	Prompt Acc@1  20.00 ( 20.00)


100%|██████████| 1/1 [00:15<00:00, 15.45s/it]


 * Prompt Acc@1 20.000
saved best file
Files already downloaded and verified


  1%|          | 1/157 [00:11<30:40, 11.80s/it]

Validate: [  0/157]	Time 11.799 (11.799)	Loss 2.5814e+00 (2.5814e+00)	Prompt Acc@1   9.38 (  9.38)


  7%|▋         | 11/157 [00:46<08:40,  3.57s/it]

Validate: [ 10/157]	Time  3.714 ( 4.184)	Loss 2.5792e+00 (2.4580e+00)	Prompt Acc@1   7.81 ( 10.37)


 13%|█▎        | 21/157 [01:19<07:32,  3.33s/it]

Validate: [ 20/157]	Time  3.238 ( 3.795)	Loss 2.4487e+00 (2.4532e+00)	Prompt Acc@1   9.38 ( 10.49)


 20%|█▉        | 31/157 [01:52<06:54,  3.29s/it]

Validate: [ 30/157]	Time  3.239 ( 3.638)	Loss 2.5285e+00 (2.4703e+00)	Prompt Acc@1   9.38 (  9.93)


 26%|██▌       | 41/157 [02:25<06:23,  3.31s/it]

Validate: [ 40/157]	Time  3.341 ( 3.560)	Loss 2.5206e+00 (2.4615e+00)	Prompt Acc@1  12.50 ( 10.02)


 32%|███▏      | 51/157 [03:00<06:01,  3.41s/it]

Validate: [ 50/157]	Time  3.375 ( 3.536)	Loss 2.6653e+00 (2.4694e+00)	Prompt Acc@1   9.38 (  9.99)


 39%|███▉      | 61/157 [03:33<05:20,  3.34s/it]

Validate: [ 60/157]	Time  3.338 ( 3.506)	Loss 2.4688e+00 (2.4721e+00)	Prompt Acc@1   7.81 (  9.84)


 45%|████▌     | 71/157 [04:07<04:48,  3.36s/it]

Validate: [ 70/157]	Time  3.312 ( 3.482)	Loss 2.5490e+00 (2.4694e+00)	Prompt Acc@1   7.81 (  9.95)


 52%|█████▏    | 81/157 [04:40<04:11,  3.31s/it]

Validate: [ 80/157]	Time  3.228 ( 3.466)	Loss 2.4381e+00 (2.4770e+00)	Prompt Acc@1  10.94 (  9.82)


 58%|█████▊    | 91/157 [05:14<03:42,  3.38s/it]

Validate: [ 90/157]	Time  3.420 ( 3.461)	Loss 2.4622e+00 (2.4848e+00)	Prompt Acc@1  12.50 (  9.91)


 64%|██████▍   | 101/157 [05:48<03:06,  3.32s/it]

Validate: [100/157]	Time  3.279 ( 3.449)	Loss 2.3963e+00 (2.4881e+00)	Prompt Acc@1  12.50 (  9.99)


 71%|███████   | 111/157 [06:21<02:32,  3.32s/it]

Validate: [110/157]	Time  3.291 ( 3.438)	Loss 2.4520e+00 (2.4841e+00)	Prompt Acc@1   7.81 ( 10.08)


 77%|███████▋  | 121/157 [06:54<01:58,  3.29s/it]

Validate: [120/157]	Time  3.280 ( 3.429)	Loss 2.3733e+00 (2.4836e+00)	Prompt Acc@1  18.75 ( 10.05)


 83%|████████▎ | 131/157 [07:28<01:26,  3.33s/it]

Validate: [130/157]	Time  3.369 ( 3.422)	Loss 2.6873e+00 (2.4849e+00)	Prompt Acc@1   4.69 (  9.97)


 90%|████████▉ | 141/157 [08:02<00:54,  3.39s/it]

Validate: [140/157]	Time  3.443 ( 3.420)	Loss 2.3139e+00 (2.4820e+00)	Prompt Acc@1  15.62 (  9.95)


 96%|█████████▌| 151/157 [08:35<00:19,  3.28s/it]

Validate: [150/157]	Time  3.193 ( 3.413)	Loss 2.3879e+00 (2.4802e+00)	Prompt Acc@1   9.38 (  9.96)


100%|██████████| 157/157 [09:13<00:00,  3.52s/it]


 * Prompt Acc@1 10.040
Files already downloaded and verified


  1%|▏         | 1/79 [00:11<15:01, 11.55s/it]

Validate: [ 0/79]	Time 11.553 (11.553)	Loss 1.9260e+00 (1.9260e+00)	Prompt Acc@1  20.31 ( 20.31)


 14%|█▍        | 11/79 [00:45<03:54,  3.45s/it]

Validate: [10/79]	Time  3.283 ( 4.146)	Loss 1.9104e+00 (1.8901e+00)	Prompt Acc@1  14.06 ( 20.60)


 27%|██▋       | 21/79 [01:19<03:17,  3.41s/it]

Validate: [20/79]	Time  3.359 ( 3.784)	Loss 1.9115e+00 (1.8893e+00)	Prompt Acc@1  20.31 ( 20.24)


 39%|███▉      | 31/79 [01:53<02:42,  3.38s/it]

Validate: [30/79]	Time  3.381 ( 3.654)	Loss 1.9061e+00 (1.9018e+00)	Prompt Acc@1  15.62 ( 19.56)


 52%|█████▏    | 41/79 [02:26<02:07,  3.37s/it]

Validate: [40/79]	Time  3.471 ( 3.577)	Loss 2.0188e+00 (1.9094e+00)	Prompt Acc@1  20.31 ( 19.63)


 65%|██████▍   | 51/79 [03:00<01:35,  3.40s/it]

Validate: [50/79]	Time  3.448 ( 3.540)	Loss 1.9069e+00 (1.9186e+00)	Prompt Acc@1  21.88 ( 19.94)


 77%|███████▋  | 61/79 [03:34<01:00,  3.36s/it]

Validate: [60/79]	Time  3.301 ( 3.515)	Loss 1.6355e+00 (1.9096e+00)	Prompt Acc@1  28.12 ( 19.98)


 90%|████████▉ | 71/79 [04:07<00:26,  3.35s/it]

Validate: [70/79]	Time  3.310 ( 3.491)	Loss 1.6977e+00 (1.9096e+00)	Prompt Acc@1  26.56 ( 19.83)


100%|██████████| 79/79 [04:51<00:00,  3.69s/it]


 * Prompt Acc@1 20.080
Files already downloaded and verified


  1%|▏         | 1/79 [00:11<15:27, 11.89s/it]

Validate: [ 0/79]	Time 11.888 (11.888)	Loss 1.6161e+00 (1.6161e+00)	Prompt Acc@1  25.00 ( 25.00)


 14%|█▍        | 11/79 [00:46<03:55,  3.47s/it]

Validate: [10/79]	Time  3.369 ( 4.197)	Loss 1.5206e+00 (1.6149e+00)	Prompt Acc@1  40.62 ( 27.98)


 27%|██▋       | 21/79 [01:19<03:14,  3.35s/it]

Validate: [20/79]	Time  3.240 ( 3.805)	Loss 1.6797e+00 (1.6312e+00)	Prompt Acc@1  25.00 ( 26.86)


 39%|███▉      | 31/79 [01:53<02:40,  3.35s/it]

Validate: [30/79]	Time  3.237 ( 3.660)	Loss 1.6802e+00 (1.6406e+00)	Prompt Acc@1  23.44 ( 25.71)


 52%|█████▏    | 41/79 [02:27<02:09,  3.40s/it]

Validate: [40/79]	Time  3.443 ( 3.591)	Loss 1.6760e+00 (1.6457e+00)	Prompt Acc@1  17.19 ( 25.04)


 65%|██████▍   | 51/79 [03:00<01:34,  3.37s/it]

Validate: [50/79]	Time  3.477 ( 3.538)	Loss 1.6201e+00 (1.6538e+00)	Prompt Acc@1  32.81 ( 24.69)


 77%|███████▋  | 61/79 [03:34<01:01,  3.43s/it]

Validate: [60/79]	Time  3.536 ( 3.518)	Loss 1.6810e+00 (1.6547e+00)	Prompt Acc@1  18.75 ( 24.59)


 90%|████████▉ | 71/79 [04:08<00:27,  3.39s/it]

Validate: [70/79]	Time  3.356 ( 3.506)	Loss 1.5504e+00 (1.6482e+00)	Prompt Acc@1  32.81 ( 25.13)


100%|██████████| 79/79 [04:53<00:00,  3.71s/it]

 * Prompt Acc@1 24.960





## Extending ClipBase

In [10]:
# import ClipBase
from fomo.models.clip.clip_base import ClipBase
from torch import nn
import torch

In [13]:
class ClipExtension(ClipBase):
    """Example of extending ``ClipBase`` with an extra trainable block on the image features.

    The extra block is stored as ``image_linear``. The same attribute name is used by
    ``learnable_param_names``, ``to_cpu``, ``to_cuda`` and ``forward``, so the block that
    is created in ``__init__`` is the one that is moved between devices, applied in the
    forward pass and reported as learnable.
    """

    def __init__(self, backbone: str = "ViT-B/16", root: str = "./data") -> None:
        """Create the base model and attach the additional image block.

        Args:
            backbone: Name of the CLIP backbone, forwarded to ``ClipBase``.
            root: Data/cache directory, forwarded to ``ClipBase``.
        """
        # pass default arguments to the parent class
        super().__init__(backbone, root=root)

        # add additional blocks to the model: a bottleneck that maps the image
        # features down to 12 dimensions and back to the original width
        feature_dim = self._clip.visual.output_dim
        self.image_linear = nn.Sequential(
            nn.Linear(feature_dim, 12),
            nn.Linear(12, feature_dim),
        )

    @property
    def learnable_param_names(self) -> set[str]:
        """Names of the parameter groups that should be trained.

        IMPORTANT: every learnable block added to the model must be listed here,
        under the same name as the attribute that holds it.
        """
        return {"image_linear"}

    # If needed you can override the to_cpu and to_cuda methods
    def to_cpu(self) -> None:
        """Move the CLIP model and the extra block to the CPU and cast CLIP to float32."""
        cpu = torch.device("cpu")
        self._clip.to(cpu)
        self.image_linear.to(cpu)
        self._clip.float()

    def to_cuda(self) -> None:
        """Move the extra block and the CLIP model to the CUDA device."""
        cuda = torch.device("cuda")
        self.image_linear.to(cuda)
        self._clip.to(cuda)

    def forward(self, images: torch.Tensor, prompts: list[str] | None = None) -> torch.Tensor:
        """Score images against prompts, passing the image features through ``image_linear``.

        Args:
            images: Batch of preprocessed images accepted by ``encode_images``.
            prompts: Optional text prompts. When omitted (or empty), the prompt
                features stored in ``_precomputed_prompt_features`` are used instead.

        Returns:
            The logits per image: the scaled, transformed image features multiplied
            by the transposed text features.

        Raises:
            ValueError: If no prompts are given and no prompt features were precomputed.
        """
        if prompts:
            text_features = self.encode_text(prompts)
        elif self._precomputed_prompt_features is not None:
            text_features = self._precomputed_prompt_features
        else:
            raise ValueError("At least one prompts or pre-computed promt features has to be present.")

        image_features = self.encode_images(images)

        # This is the only change relative to a plain image/text comparison:
        # the image features go through the extra block first
        image_features = self.image_linear(image_features)

        logits_per_image: torch.Tensor = self.logit_scale * image_features @ text_features.t()

        return logits_per_image

In [15]:
# Instantiate the extension with its default backbone and data root
model = ClipExtension()

# Print the names of the learnable parameter groups declared by the extension
print(model.learnable_param_names)

{'image_linear'}
