In [4]:
# import ClipBase
from fomo.models.clip.clip_base import ClipBase
from torch import nn
import torch
from PIL import Image

In [6]:
# Instantiate the baseline CLIP wrapper with its default arguments, call its
# to_cpu() helper, and switch to eval mode (the trailing ";" hides the repr
# that eval() returns).
model = ClipBase()
model.to_cpu()
model.eval();

In [7]:
# Open the two demo pictures, then run each through the model's preprocessing.
img1, img2 = (Image.open(path) for path in ("./cat.jpg", "./dog.jpg"))
img1_tensor, img2_tensor = (model.transform(picture) for picture in (img1, img2))

# Stack the preprocessed images into a single batch.
img_input = torch.stack([img1_tensor, img2_tensor])

# Embed the batch and prepend a length-1 axis so the image acts as one token.
img_emb = model.encode_images(img_input).unsqueeze(0)
print(img_emb.shape)

torch.Size([1, 2, 512])


In [113]:
# One text embedding per candidate prompt.
prompts = ["a photo of a cat", "a photo of a dog", "a photo of a bird"]
prompt_emb = model.encode_text(prompts)
# Repeat every prompt embedding across the image batch. The batch size is read
# from img_emb (axis 1) instead of being hard-coded, so the cell keeps working
# when the number of images changes.
prompt_emb = prompt_emb.unsqueeze(1).expand(-1, img_emb.shape[1], -1)

In [115]:
# Join the image token(s) and the prompt tokens along axis 0.
# NOTE(review): `input` shadows the Python builtin of the same name; consider
# renaming it once no other cell depends on this global.
input = torch.cat([img_emb, prompt_emb], dim=0)
input.shape

torch.Size([4, 2, 512])

In [125]:
class MiniTransformer(nn.Module):
    def __init__(self, num_classes: int = 5) -> None:
        super(MiniTransformer, self).__init__()

        self.wq = nn.Sequential(
            nn.Linear(512, 2),
            nn.Linear(2, 512),
        )
        self.wk = nn.Sequential(
            nn.Linear(512, 2),
            nn.Linear(2, 512),
        )
        self.wv = nn.Sequential(
            nn.Linear(512, 2),
            nn.Linear(2, 512),
        )
        self.transformer = nn.MultiheadAttention(embed_dim=512, num_heads=2)
        self._attn_mask = self._init_attn_mask(num_classes, 1)

    @staticmethod
    def _init_attn_mask(num_prompts: int, num_images: int) -> torch.Tensor:
        num_total = num_prompts + num_images
        mask = torch.zeros((num_total, num_total))

        for i in range(num_prompts):
            for j in range(num_prompts, num_total):
                mask[i, j] = 1

        for i in range(num_prompts, num_total):
            for j in range(num_prompts):
                mask[i, j] = 1

        return mask

    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        batch_size = inputs.shape[1]
        query = self.wq(inputs)
        key = self.wk(inputs)
        value = self.wv(inputs)

        mask = self._attn_mask.clone().detach().unsqueeze(0).repeat(batch_size*2, 1, 1)

        return self.transformer(query, key, value, attn_mask=mask)

In [163]:
from torch.functional import F

In [177]:
class ClipExtension(ClipBase):
    """CLIP wrapper that refines prompt and image features with a MiniTransformer.

    Prompt features and the image feature are concatenated into one token
    sequence, passed through ``self.transformer`` and added back residually;
    the logits are the dot products between refined prompt and image tokens.

    Args:
        backbone: CLIP backbone name forwarded to ``ClipBase``.
        root: Root directory forwarded to ``ClipBase``.
        num_classes: Number of prompts the attention block is configured for.
    """

    def __init__(self, backbone: str = "ViT-B/16", root: str = "./data", num_classes: int = 10) -> None:
        # pass default arguments to the parent class
        super().__init__(backbone, root=root)

        self._num_classes = num_classes
        self.transformer = MiniTransformer(num_classes=self._num_classes)

    @property
    def learnable_param_names(self) -> set[str]:
        """Name fragments identifying the parameters that should be trained."""
        # IMPORTANT: Add the name of the learnable parameters in the model
        # NOTE(review): if callers match these as substrings of parameter
        # names, "transformer" also matches CLIP's own
        # "...transformer.resblocks..." weights -- confirm that is intended.
        return {"transformer"}

    # If needed you can override the to_cpu and to_cuda methods
    def to_cpu(self) -> None:
        """Move CLIP and the attention block to the CPU; CLIP is cast to float32."""
        self._clip.to(torch.device("cpu"))
        self.transformer.to(torch.device("cpu"))
        self._clip.float()

    def to_cuda(self) -> None:
        """Move CLIP and the attention block to the default CUDA device."""
        self.transformer.to(torch.device("cuda"))
        self._clip.to(torch.device("cuda"))

    def forward(self, images: torch.Tensor, prompts: list[str] | None = None) -> torch.Tensor:
        """Score every image against every prompt.

        Args:
            images: Preprocessed image batch accepted by ``encode_images``.
            prompts: Prompts to encode; when empty or ``None`` the pre-computed
                prompt features are used instead.

        Returns:
            Logits from a batched matrix product of prompt and image tokens;
            presumably ``(batch, num_prompts, 1)`` when ``encode_images``
            yields one feature vector per image -- TODO confirm.

        Raises:
            ValueError: If neither prompts nor pre-computed features exist.
        """
        if prompts:
            text_features = self.encode_text(prompts)
        elif self._precomputed_prompt_features is not None:
            text_features = self._precomputed_prompt_features
        else:
            raise ValueError("At least one prompts or pre-computed promt features has to be present.")

        image_features = self.encode_images(images).unsqueeze(0)

        # Split on the number of prompts actually present rather than the
        # configured class count, so the prompt/image slices are always right.
        num_prompts = text_features.shape[0]
        text_features = text_features.unsqueeze(1).expand(-1, image_features.shape[1], -1)

        inputs = torch.cat([text_features, image_features], dim=0)
        tr_outputs = self.transformer(inputs)[0]

        # Out-of-place residual: `inputs` is saved by the transformer's Linear
        # layers for their backward pass, so modifying it in place would make
        # loss.backward() fail.
        refined = inputs + tr_outputs

        prompt_tokens = refined[:num_prompts].permute(1, 0, 2)
        image_tokens = refined[num_prompts:].permute(1, 0, 2)

        logits_per_image: torch.Tensor = torch.bmm(prompt_tokens, image_tokens.transpose(1, 2))

        return logits_per_image

In [169]:
# num_classes must equal the number of prompts precomputed for this model
# (three, in the next cell): the attention mask and the prompt/image split in
# ClipExtension are both sized from it.
model = ClipExtension(num_classes=3)
model.to_cpu()
model.eval();

In [170]:
# Precompute text features for the three candidate prompts; presumably this
# fills the cache that forward() falls back to when called without prompts.
model.precompute_prompt_features(["a photo of a cat", "a photo of a dog", "a photo of a bird"])

In [173]:
# Call the module itself rather than .forward() so nn.Module hooks run; the
# trailing singleton axis (one image token per sample) is squeezed away.
out = model(img_input).squeeze(-1)

In [174]:
# Turn each image's per-prompt logits into a probability distribution.
F.softmax(out, dim=-1)

tensor([[0.3497, 0.3201, 0.3302],
        [0.3273, 0.3431, 0.3296]], grad_fn=<SoftmaxBackward0>)

In [1]:
from fomo.models.clip.clip_transformer import ClipTransformer
from PIL import Image
import torch

# Debugging aid: makes autograd report the forward op behind a failing
# backward pass. It slows every step, so remove it once the graph is sound.
torch.autograd.set_detect_anomaly(True)

clip = ClipTransformer()
clip.to_cpu()
clip.precompute_prompt_features(["a photo of a cat", "a photo of a dog", "a photo of a bird"])

# Freeze everything except parameters whose name contains one of the
# substrings the model declares as learnable.
for name, param in clip.named_parameters():
    param.requires_grad = any(
        learnable_param_name in name for learnable_param_name in clip.learnable_param_names
    )

# print the learnable parameters
for name, param in clip.named_parameters():
    if param.requires_grad:
        print(name)

# Build a two-image batch (cat, dog) with the model's own preprocessing.
img1 = Image.open("./cat.jpg")
img2 = Image.open("./dog.jpg")
img1_tensor = clip.transform(img1)
img2_tensor = clip.transform(img2)

img_input = torch.stack([img1_tensor, img2_tensor])

# Call the module (not .forward) so nn.Module hooks are honoured.
out = clip(img_input).squeeze(-1)

# Targets: image 0 -> prompt 0 (cat), image 1 -> prompt 1 (dog). A backward
# pass checks that gradients reach the learnable parameters without error.
loss = torch.nn.functional.cross_entropy(out, torch.tensor([0, 1]))
loss.backward()

  from .autonotebook import tqdm as notebook_tqdm


mmha.mha.in_proj_weight
mmha.mha.in_proj_bias
mmha.mha.out_proj.weight
mmha.mha.out_proj.bias


In [2]:
from fomo.pipelines.train import Learner
from fomo.pipelines.types.learner_args import LearnerArgs
import torch
from PIL import Image

In [4]:
# Start from the default training arguments and override a few for a quick
# CPU run of the "clip_transformer" model.
learner_args = LearnerArgs()
learner_args.device = "cpu"
learner_args.epochs = 100
learner_args.model_type = "clip_transformer"
# presumably (number of training samples, number of evaluation samples) -- TODO confirm
learner_args.train_eval_size = (100, 200)

# Build the learner from the arguments above.
learner = Learner(learner_args)

Files already downloaded and verified
Files already downloaded and verified
Turning off gradients in both the image and the text encoder
['_clip.positional_embedding', '_clip.text_projection', '_clip.logit_scale', '_clip.visual.class_embedding', '_clip.visual.positional_embedding', '_clip.visual.proj', '_clip.visual.conv1.weight', '_clip.visual.ln_pre.weight', '_clip.visual.ln_pre.bias', '_clip.visual.transformer.resblocks.0.attn.in_proj_weight', '_clip.visual.transformer.resblocks.0.attn.in_proj_bias', '_clip.visual.transformer.resblocks.0.attn.out_proj.weight', '_clip.visual.transformer.resblocks.0.attn.out_proj.bias', '_clip.visual.transformer.resblocks.0.ln_1.weight', '_clip.visual.transformer.resblocks.0.ln_1.bias', '_clip.visual.transformer.resblocks.0.mlp.c_fc.weight', '_clip.visual.transformer.resblocks.0.mlp.c_fc.bias', '_clip.visual.transformer.resblocks.0.mlp.c_proj.weight', '_clip.visual.transformer.resblocks.0.mlp.c_proj.bias', '_clip.visual.transformer.resblocks.0.ln_2.we



In [5]:
# Run the learner. In the recorded output, training halted on the early
# stopping criterion after 10 epochs without improvement, followed by a
# 157-batch validation pass.
learner.run()

  0%|          | 0/2 [00:00<?, ?it/s]

Epoch: [0][0/2]	Time 13.127 (13.127)	Data  8.908 ( 8.908)	Loss 2.2597e+00 (2.2597e+00)	Acc@1  78.12 ( 78.12)


100%|██████████| 2/2 [00:26<00:00, 13.19s/it]
 25%|██▌       | 1/4 [00:11<00:34, 11.40s/it]

Validate: [0/4]	Time 11.398 (11.398)	Loss 2.2532e+00 (2.2532e+00)	Prompt Acc@1  92.19 ( 92.19)


100%|██████████| 4/4 [00:38<00:00,  9.51s/it]


 * Prompt Acc@1 95.500
saved best file


  0%|          | 0/2 [00:00<?, ?it/s]

Epoch: [1][0/2]	Time 12.346 (12.346)	Data  8.873 ( 8.873)	Loss 2.2531e+00 (2.2531e+00)	Acc@1  89.06 ( 89.06)


100%|██████████| 2/2 [00:25<00:00, 12.54s/it]
 25%|██▌       | 1/4 [00:11<00:33, 11.23s/it]

Validate: [0/4]	Time 11.228 (11.228)	Loss 2.2532e+00 (2.2532e+00)	Prompt Acc@1  92.19 ( 92.19)


100%|██████████| 4/4 [00:38<00:00,  9.53s/it]


 * Prompt Acc@1 95.500
There's no improvement for 1 epochs.


  0%|          | 0/2 [00:00<?, ?it/s]

Epoch: [2][0/2]	Time 12.083 (12.083)	Data  8.604 ( 8.604)	Loss 2.2549e+00 (2.2549e+00)	Acc@1  85.94 ( 85.94)


100%|██████████| 2/2 [00:24<00:00, 12.42s/it]
 25%|██▌       | 1/4 [00:11<00:34, 11.50s/it]

Validate: [0/4]	Time 11.499 (11.499)	Loss 2.2532e+00 (2.2532e+00)	Prompt Acc@1  92.19 ( 92.19)


100%|██████████| 4/4 [00:38<00:00,  9.58s/it]


 * Prompt Acc@1 95.500
There's no improvement for 2 epochs.


  0%|          | 0/2 [00:00<?, ?it/s]

Epoch: [3][0/2]	Time 14.047 (14.047)	Data 10.476 (10.476)	Loss 2.2537e+00 (2.2537e+00)	Acc@1  87.50 ( 87.50)


100%|██████████| 2/2 [00:27<00:00, 13.64s/it]
 25%|██▌       | 1/4 [00:11<00:35, 11.75s/it]

Validate: [0/4]	Time 11.753 (11.753)	Loss 2.2532e+00 (2.2532e+00)	Prompt Acc@1  92.19 ( 92.19)


100%|██████████| 4/4 [00:38<00:00,  9.66s/it]


 * Prompt Acc@1 95.500
There's no improvement for 3 epochs.


  0%|          | 0/2 [00:00<?, ?it/s]

Epoch: [4][0/2]	Time 12.318 (12.318)	Data  8.715 ( 8.715)	Loss 2.2554e+00 (2.2554e+00)	Acc@1  84.38 ( 84.38)


100%|██████████| 2/2 [00:25<00:00, 12.66s/it]
 25%|██▌       | 1/4 [00:11<00:35, 11.80s/it]

Validate: [0/4]	Time 11.806 (11.806)	Loss 2.2531e+00 (2.2531e+00)	Prompt Acc@1  92.19 ( 92.19)


100%|██████████| 4/4 [00:39<00:00,  9.81s/it]


 * Prompt Acc@1 95.500
There's no improvement for 4 epochs.


  0%|          | 0/2 [00:00<?, ?it/s]

Epoch: [5][0/2]	Time 12.356 (12.356)	Data  8.892 ( 8.892)	Loss 2.2531e+00 (2.2531e+00)	Acc@1  87.50 ( 87.50)


100%|██████████| 2/2 [00:25<00:00, 12.54s/it]
 25%|██▌       | 1/4 [00:11<00:35, 11.90s/it]

Validate: [0/4]	Time 11.901 (11.901)	Loss 2.2531e+00 (2.2531e+00)	Prompt Acc@1  92.19 ( 92.19)


100%|██████████| 4/4 [00:38<00:00,  9.74s/it]


 * Prompt Acc@1 95.500
There's no improvement for 5 epochs.


  0%|          | 0/2 [00:00<?, ?it/s]

Epoch: [6][0/2]	Time 12.173 (12.173)	Data  8.696 ( 8.696)	Loss 2.2557e+00 (2.2557e+00)	Acc@1  82.81 ( 82.81)


100%|██████████| 2/2 [00:24<00:00, 12.48s/it]
 25%|██▌       | 1/4 [00:12<00:37, 12.41s/it]

Validate: [0/4]	Time 12.409 (12.409)	Loss 2.2531e+00 (2.2531e+00)	Prompt Acc@1  92.19 ( 92.19)


100%|██████████| 4/4 [00:40<00:00, 10.02s/it]


 * Prompt Acc@1 95.500
There's no improvement for 6 epochs.


  0%|          | 0/2 [00:00<?, ?it/s]

Epoch: [7][0/2]	Time 12.159 (12.159)	Data  8.807 ( 8.807)	Loss 2.2579e+00 (2.2579e+00)	Acc@1  81.25 ( 81.25)


100%|██████████| 2/2 [00:24<00:00, 12.47s/it]
 25%|██▌       | 1/4 [00:11<00:34, 11.57s/it]

Validate: [0/4]	Time 11.572 (11.572)	Loss 2.2531e+00 (2.2531e+00)	Prompt Acc@1  92.19 ( 92.19)


100%|██████████| 4/4 [00:38<00:00,  9.60s/it]


 * Prompt Acc@1 95.500
There's no improvement for 7 epochs.


  0%|          | 0/2 [00:00<?, ?it/s]

Epoch: [8][0/2]	Time 12.401 (12.401)	Data  8.706 ( 8.706)	Loss 2.2545e+00 (2.2545e+00)	Acc@1  84.38 ( 84.38)


100%|██████████| 2/2 [00:25<00:00, 12.68s/it]
 25%|██▌       | 1/4 [00:11<00:33, 11.28s/it]

Validate: [0/4]	Time 11.285 (11.285)	Loss 2.2531e+00 (2.2531e+00)	Prompt Acc@1  92.19 ( 92.19)


100%|██████████| 4/4 [00:38<00:00,  9.64s/it]


 * Prompt Acc@1 95.500
There's no improvement for 8 epochs.


  0%|          | 0/2 [00:00<?, ?it/s]

Epoch: [9][0/2]	Time 11.355 (11.355)	Data  8.137 ( 8.137)	Loss 2.2530e+00 (2.2530e+00)	Acc@1  89.06 ( 89.06)


100%|██████████| 2/2 [00:23<00:00, 11.98s/it]
 25%|██▌       | 1/4 [00:11<00:33, 11.09s/it]

Validate: [0/4]	Time 11.089 (11.089)	Loss 2.2530e+00 (2.2530e+00)	Prompt Acc@1  92.19 ( 92.19)


100%|██████████| 4/4 [00:37<00:00,  9.42s/it]


 * Prompt Acc@1 95.500
There's no improvement for 9 epochs.


  0%|          | 0/2 [00:00<?, ?it/s]

Epoch: [10][0/2]	Time 11.310 (11.310)	Data  8.120 ( 8.120)	Loss 2.2550e+00 (2.2550e+00)	Acc@1  85.94 ( 85.94)


100%|██████████| 2/2 [00:24<00:00, 12.01s/it]
 25%|██▌       | 1/4 [00:11<00:34, 11.35s/it]

Validate: [0/4]	Time 11.347 (11.347)	Loss 2.2530e+00 (2.2530e+00)	Prompt Acc@1  92.19 ( 92.19)


100%|██████████| 4/4 [00:38<00:00,  9.68s/it]


 * Prompt Acc@1 95.500
There's no improvement for 10 epochs.
The training halted by early stopping criterion.


  1%|          | 1/157 [00:09<25:11,  9.69s/it]

Validate: [  0/157]	Time  9.689 ( 9.689)	Loss 2.2538e+00 (2.2538e+00)	Prompt Acc@1  82.81 ( 82.81)


  7%|▋         | 11/157 [00:41<07:54,  3.25s/it]

Validate: [ 10/157]	Time  3.304 ( 3.753)	Loss 2.2549e+00 (2.2547e+00)	Prompt Acc@1  92.19 ( 90.77)


 13%|█▎        | 21/157 [01:12<07:05,  3.13s/it]

Validate: [ 20/157]	Time  3.076 ( 3.468)	Loss 2.2566e+00 (2.2544e+00)	Prompt Acc@1  92.19 ( 90.92)


 20%|█▉        | 31/157 [01:44<06:35,  3.14s/it]

Validate: [ 30/157]	Time  3.057 ( 3.369)	Loss 2.2573e+00 (2.2542e+00)	Prompt Acc@1  92.19 ( 91.18)


 26%|██▌       | 41/157 [02:17<06:14,  3.23s/it]

Validate: [ 40/157]	Time  3.277 ( 3.357)	Loss 2.2543e+00 (2.2545e+00)	Prompt Acc@1  90.62 ( 90.62)


 32%|███▏      | 51/157 [02:51<05:58,  3.38s/it]

Validate: [ 50/157]	Time  3.557 ( 3.366)	Loss 2.2529e+00 (2.2544e+00)	Prompt Acc@1  89.06 ( 90.59)


 39%|███▉      | 61/157 [03:24<05:07,  3.21s/it]

Validate: [ 60/157]	Time  3.133 ( 3.349)	Loss 2.2530e+00 (2.2544e+00)	Prompt Acc@1  90.62 ( 90.57)


 45%|████▌     | 71/157 [03:57<04:40,  3.27s/it]

Validate: [ 70/157]	Time  3.362 ( 3.339)	Loss 2.2542e+00 (2.2543e+00)	Prompt Acc@1  92.19 ( 90.69)


 52%|█████▏    | 81/157 [04:29<04:13,  3.33s/it]

Validate: [ 80/157]	Time  3.254 ( 3.333)	Loss 2.2562e+00 (2.2545e+00)	Prompt Acc@1  89.06 ( 90.64)


 58%|█████▊    | 91/157 [05:01<03:29,  3.17s/it]

Validate: [ 90/157]	Time  3.154 ( 3.317)	Loss 2.2528e+00 (2.2546e+00)	Prompt Acc@1  87.50 ( 90.45)


 64%|██████▍   | 101/157 [05:33<02:59,  3.20s/it]

Validate: [100/157]	Time  3.168 ( 3.306)	Loss 2.2578e+00 (2.2548e+00)	Prompt Acc@1  92.19 ( 90.39)


 71%|███████   | 111/157 [06:05<02:25,  3.16s/it]

Validate: [110/157]	Time  3.173 ( 3.291)	Loss 2.2531e+00 (2.2549e+00)	Prompt Acc@1  93.75 ( 90.30)


 77%|███████▋  | 121/157 [06:37<01:55,  3.20s/it]

Validate: [120/157]	Time  3.168 ( 3.287)	Loss 2.2539e+00 (2.2548e+00)	Prompt Acc@1  92.19 ( 90.50)


 83%|████████▎ | 131/157 [07:09<01:21,  3.14s/it]

Validate: [130/157]	Time  3.171 ( 3.276)	Loss 2.2542e+00 (2.2549e+00)	Prompt Acc@1  90.62 ( 90.46)


 90%|████████▉ | 141/157 [07:41<00:53,  3.32s/it]

Validate: [140/157]	Time  3.338 ( 3.275)	Loss 2.2595e+00 (2.2550e+00)	Prompt Acc@1  87.50 ( 90.39)


 96%|█████████▌| 151/157 [08:14<00:20,  3.33s/it]

Validate: [150/157]	Time  3.241 ( 3.275)	Loss 2.2578e+00 (2.2549e+00)	Prompt Acc@1  92.19 ( 90.50)


100%|██████████| 157/157 [08:51<00:00,  3.39s/it]

 * Prompt Acc@1 90.510





In [7]:
# Load the demo images and preprocess them with the trained model's transform.
img1 = Image.open("./cat.jpg")
img2 = Image.open("./dog.jpg")
img1_tensor = learner.model.transform(img1)
img2_tensor = learner.model.transform(img2)

# Only the cat image is used from here on; a batch axis is added in front.
img_input = img1_tensor.unsqueeze(0)

# NOTE(review): the recorded run of this cell raised
# "RuntimeError: Tensors must have same number of dimensions: got 3 and 2".
# The failing code is presumably inside the model's forward, which is not
# part of this notebook -- investigate there before relying on this cell.
learner.model.forward(img_input)

RuntimeError: Tensors must have same number of dimensions: got 3 and 2

In [3]:
# NOTE(review): the recorded run of this cell raised
# "AttributeError: 'function' object has no attribute 'dim'". The cause is
# not visible in this notebook -- TODO investigate before re-running.
learner.run()

  0%|          | 0/1 [00:14<?, ?it/s]


AttributeError: 'function' object has no attribute 'dim'