In [1]:
from fomo.models.clip.clip_base import ClipBase
import torch
from PIL import Image

  from .autonotebook import tqdm as notebook_tqdm


# CLIP Base Demo

In [2]:
clip = ClipBase()
clip.to_cpu()
clip.eval();

100%|███████████████████████████████████████| 335M/335M [02:33<00:00, 2.29MiB/s]


In [3]:
clip.precompute_prompt_features(["a picture of a cat", "a picture of kitten"])

In [4]:
img = Image.open("./cat.jpg")

In [5]:
img_tensor = clip.transform(img).view(1, 3, 224, 224)

In [7]:
with torch.no_grad():
    logits = clip.forward(img_tensor)
print(logits)

tensor([[29.0192, 26.0525]])


In [8]:
torch.functional.F.softmax(logits, dim=1)

tensor([[0.9510, 0.0490]])

# Datasets

In [1]:
from fomo.utils.data.datasets import DatasetInitializer
from fomo.utils.data import utils

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# loading zero_shot_dataset
zero_shot_dataset = DatasetInitializer.from_str("cifar10").value(train=True)

# zero_shot_dataset contains two properties: torch dataset and labels
print(zero_shot_dataset.dataset)
print(zero_shot_dataset.labels)

Files already downloaded and verified
Dataset CIFAR10
    Number of datapoints: 50000
    Root location: data
    Split: Train
['plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']


In [3]:
# you can also split the dataset into train and eval splits using utils

# using percentage split
train_dataset, val_dataset = utils.split_train_val(zero_shot_dataset.dataset, train_size=0.8)
print(len(train_dataset), len(val_dataset))

40000 10000


In [4]:
# using number of samples split
train_dataset, val_dataset = utils.split_train_val(zero_shot_dataset.dataset, train_eval_samples=[10, 20])
print(len(train_dataset), len(val_dataset))

10 20


# Learner

In [1]:
from fomo.pipelines.train import Learner
from fomo.pipelines.types.learner_args import LearnerArgs

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
learner_args = LearnerArgs()
learner_args.device = "cpu"
learner_args.epochs = 2
learner_args.model_type = "clip_linear"
learner_args.train_eval_size = (10, 10)

learner = Learner(learner_args)

Files already downloaded and verified
Files already downloaded and verified
Turning off gradients in both the image and the text encoder
Parameters to be updated: {'image_linear.0.bias', 'image_linear.0.weight'}
Number of learnable paramms: 262656




In [3]:
learner.run()

  0%|          | 0/1 [00:00<?, ?it/s]

Epoch: [0][0/1]	Time  8.351 ( 8.351)	Data  7.830 ( 7.830)	Loss 3.1633e+00 (3.1633e+00)	Acc@1   0.00 (  0.00)


100%|██████████| 1/1 [00:14<00:00, 14.31s/it]
100%|██████████| 1/1 [00:08<00:00,  8.21s/it]

Validate: [0/1]	Time  8.214 ( 8.214)	Loss 2.3106e+00 (2.3106e+00)	Prompt Acc@1  10.00 ( 10.00)


100%|██████████| 1/1 [00:13<00:00, 13.22s/it]


 * Prompt Acc@1 10.000
saved best file


  0%|          | 0/1 [00:00<?, ?it/s]

Epoch: [1][0/1]	Time  8.372 ( 8.372)	Data  7.829 ( 7.829)	Loss 2.3008e+00 (2.3008e+00)	Acc@1  10.00 ( 10.00)


100%|██████████| 1/1 [00:13<00:00, 13.95s/it]
100%|██████████| 1/1 [00:08<00:00,  8.20s/it]

Validate: [0/1]	Time  8.202 ( 8.202)	Loss 2.3206e+00 (2.3206e+00)	Prompt Acc@1  10.00 ( 10.00)


100%|██████████| 1/1 [00:13<00:00, 13.21s/it]


 * Prompt Acc@1 10.000
There's no improvement for 1 epochs.


  1%|          | 1/157 [00:08<22:58,  8.84s/it]

Validate: [  0/157]	Time  8.839 ( 8.839)	Loss 2.3017e+00 (2.3017e+00)	Prompt Acc@1  14.06 ( 14.06)


  7%|▋         | 11/157 [00:38<07:29,  3.08s/it]

Validate: [ 10/157]	Time  3.080 ( 3.537)	Loss 2.3043e+00 (2.2982e+00)	Prompt Acc@1   7.81 ( 10.65)


 13%|█▎        | 21/157 [01:09<06:58,  3.08s/it]

Validate: [ 20/157]	Time  2.972 ( 3.311)	Loss 2.2972e+00 (2.2991e+00)	Prompt Acc@1  12.50 (  9.30)


 20%|█▉        | 31/157 [01:40<06:25,  3.06s/it]

Validate: [ 30/157]	Time  2.970 ( 3.241)	Loss 2.3166e+00 (2.3019e+00)	Prompt Acc@1   3.12 (  9.07)


 26%|██▌       | 41/157 [02:09<05:40,  2.94s/it]

Validate: [ 40/157]	Time  2.909 ( 3.170)	Loss 2.3203e+00 (2.3009e+00)	Prompt Acc@1   4.69 (  9.30)


 32%|███▏      | 51/157 [02:39<05:13,  2.96s/it]

Validate: [ 50/157]	Time  2.960 ( 3.128)	Loss 2.3281e+00 (2.3030e+00)	Prompt Acc@1   9.38 (  9.34)


 39%|███▉      | 61/157 [03:09<04:44,  2.96s/it]

Validate: [ 60/157]	Time  2.931 ( 3.101)	Loss 2.2982e+00 (2.3038e+00)	Prompt Acc@1  14.06 (  9.45)


 45%|████▌     | 71/157 [03:38<04:14,  2.96s/it]

Validate: [ 70/157]	Time  3.012 ( 3.082)	Loss 2.3027e+00 (2.3037e+00)	Prompt Acc@1   9.38 (  9.40)


 52%|█████▏    | 81/157 [04:08<03:44,  2.96s/it]

Validate: [ 80/157]	Time  2.932 ( 3.067)	Loss 2.2922e+00 (2.3044e+00)	Prompt Acc@1  10.94 (  9.45)


 58%|█████▊    | 91/157 [04:37<03:14,  2.94s/it]

Validate: [ 90/157]	Time  2.958 ( 3.053)	Loss 2.3166e+00 (2.3061e+00)	Prompt Acc@1   6.25 (  9.13)


 64%|██████▍   | 101/157 [05:07<02:45,  2.95s/it]

Validate: [100/157]	Time  2.947 ( 3.043)	Loss 2.2895e+00 (2.3067e+00)	Prompt Acc@1   7.81 (  9.24)


 71%|███████   | 111/157 [05:36<02:16,  2.96s/it]

Validate: [110/157]	Time  2.988 ( 3.035)	Loss 2.3122e+00 (2.3063e+00)	Prompt Acc@1   9.38 (  9.22)


 77%|███████▋  | 121/157 [06:06<01:45,  2.93s/it]

Validate: [120/157]	Time  2.895 ( 3.027)	Loss 2.2964e+00 (2.3063e+00)	Prompt Acc@1   9.38 (  9.25)


 83%|████████▎ | 131/157 [06:35<01:16,  2.96s/it]

Validate: [130/157]	Time  2.976 ( 3.021)	Loss 2.3266e+00 (2.3060e+00)	Prompt Acc@1   6.25 (  9.30)


 90%|████████▉ | 141/157 [07:05<00:47,  2.96s/it]

Validate: [140/157]	Time  2.973 ( 3.017)	Loss 2.2835e+00 (2.3052e+00)	Prompt Acc@1   9.38 (  9.30)


 96%|█████████▌| 151/157 [07:34<00:17,  2.94s/it]

Validate: [150/157]	Time  2.913 ( 3.013)	Loss 2.2895e+00 (2.3050e+00)	Prompt Acc@1   6.25 (  9.41)


100%|██████████| 157/157 [08:10<00:00,  3.12s/it]

 * Prompt Acc@1 9.430





## Extending ClipBase

In [10]:
# import ClipBase
from fomo.models.clip.clip_base import ClipBase
from torch import nn
import torch

In [13]:
class ClipExtension(ClipBase):
    def __init__(self, backbone: str = "ViT-B/16", root: str = "./data") -> None:
        # pass default arguments to the parent class
        super(ClipExtension, self).__init__(backbone, root=root)

        # add additional blocks to the model

        self.visual_mlp = nn.Sequential(
            nn.Linear(self._clip.visual.output_dim, 12),
            nn.Linear(12, self._clip.visual.output_dim)
        )

    @property
    def learnable_param_names(self) -> set[str]:
         # IMPORTANT: Add the name of the learnable parameters in the model
        return set(["image_linear"])

    # If needed you can override the to_cpu and to_cuda methods
    def to_cpu(self) -> None:
        self._clip.to(torch.device("cpu"))
        self.image_linear.to(torch.device("cpu"))
        self._clip.float()

    def to_cuda(self) -> None:
        self.image_linear.to(torch.device("cuda"))
        self._clip.to(torch.device("cuda"))

    def forward(self, images: torch.Tensor, prompts: list[str] | None = None) -> torch.Tensor:
        # Change the forward method to include the visual_mlp
        if prompts:
            text_features = self.encode_text(prompts)
        elif self._precomputed_prompt_features is not None:
            text_features = self._precomputed_prompt_features
        else:
            raise ValueError("At least one prompts or pre-computed promt features has to be present.")

        image_features = self.encode_images(images)

        image_features = self.image_linear(image_features)

        logits_per_image: torch.Tensor = self.logit_scale * image_features @ text_features.t()

        return logits_per_image

In [15]:
model = ClipExtension()

#print learnable parameters
print(model.learnable_param_names)

{'image_linear'}
