In [3]:
import torch
from transformers import CLIPProcessor
from src.datasets import CIFAR100FSCIL, CUB200FSCIL, MiniImageNetFSCIL

### Load the dataset and the CLIP image/text preprocessors

In [4]:
device = "cuda" if torch.cuda.is_available() else "cpu"

# Load the processor once and take both components from it; calling
# from_pretrained twice re-reads (and possibly re-downloads) the same files.
CLIP_CHECKPOINT = "openai/clip-vit-base-patch16"
clip_processor = CLIPProcessor.from_pretrained(CLIP_CHECKPOINT)

# `image_processor` is the current name; `feature_extractor` is the deprecated
# alias kept for older transformers releases that lack `image_processor`.
img_preprocess = getattr(clip_processor, "image_processor", None)
if img_preprocess is None:
    img_preprocess = clip_processor.feature_extractor
text_preprocess = clip_processor.tokenizer

dataset = CIFAR100FSCIL(transform=img_preprocess)



### 🧷 Model Parameters

The model is initialized as `CLIPForPromptTuning(L_g=4, L_s=4, D_g=10, D_s=2)`, where:

- `L_g = 4`: number of **General prompts (G-prompts)** per modality (vision and language).
- `L_s = 4`: number of **Shared prompts (S-prompts)**, shared across both modalities.
- `D_g = 10`: depth of the **G-prompts**, i.e. the number of encoder layers into which they are inserted.
- `D_s = 2`: depth of the **S-prompts**, i.e. the number of encoder layers into which they are inserted.
- For reference, the CLIP text encoder and image encoder each have 12 layers.

In [5]:
from src.models.clip_models import CLIPForPromptTuning
# Prompt-tuning hyperparameters, as described in the markdown cell above,
# gathered in one place so they are easy to tweak.
prompt_config = dict(L_g=4, L_s=4, D_g=10, D_s=2)
model = CLIPForPromptTuning(**prompt_config)
model = model.to(device)

In [6]:
# Template wrapped around every class name before tokenization; "{}" is the
# class-name slot. The default "{}" keeps the bare class name (the original
# behavior); set it to e.g. "a photo of a {}." to try a CLIP-style prompt.
PROMPT_TEMPLATE = "{}"

text_label_mapping = dataset.text_label_mapping.values()
# NOTE(review): `prompt_labels[y]` is used later to name label `y`, which assumes
# the mapping's insertion order matches the integer class ids — TODO confirm
# against CIFAR100FSCIL.
prompt_labels = [PROMPT_TEMPLATE.format(class_name) for class_name in text_label_mapping]
print(f"Number of classes: {len(prompt_labels)}")
print(f"Model Classes: {prompt_labels}")

Number of classes: 100
Model Classes: ['apple', 'aquarium fish', 'baby', 'bear', 'beaver', 'bed', 'bee', 'beetle', 'bicycle', 'bottle', 'bowl', 'boy', 'bridge', 'bus', 'butterfly', 'camel', 'can', 'castle', 'caterpillar', 'cattle', 'chair', 'chimpanzee', 'clock', 'cloud', 'cockroach', 'couch', 'crab', 'crocodile', 'cup', 'dinosaur', 'dolphin', 'elephant', 'flatfish', 'forest', 'fox', 'girl', 'hamster', 'house', 'kangaroo', 'keyboard', 'lamp', 'lawn mower', 'leopard', 'lion', 'lizard', 'lobster', 'man', 'maple tree', 'motorcycle', 'mountain', 'mouse', 'mushroom', 'oak tree', 'orange', 'orchid', 'otter', 'palm tree', 'pear', 'pickup truck', 'pine tree', 'plain', 'plate', 'poppy', 'porcupine', 'possum', 'rabbit', 'raccoon', 'ray', 'road', 'rocket', 'rose', 'sea', 'seal', 'shark', 'shrew', 'skunk', 'skyscraper', 'snail', 'snake', 'spider', 'squirrel', 'streetcar', 'sunflower', 'sweet pepper', 'table', 'tank', 'telephone', 'television', 'tiger', 'tractor', 'train', 'trout', 'tulip', 'turtle

In [7]:
# Tokenize every class prompt in one batch, padded to the longest prompt.
out_text_tokens = text_preprocess(
    prompt_labels,
    padding=True,
    return_tensors="pt",
)
# Move token ids and the matching attention mask to the compute device.
text_tokens, attn_mask = (
    out_text_tokens[key].to(device) for key in ("input_ids", "attention_mask")
)

In [None]:
session_no = 0  # base session
base_session_stream = dataset.eval_stream[session_no]
# NOTE(review): `_datasets` is a private attribute of the eval-stream entry;
# prefer a public accessor if the dataset class offers one.
test_dataset = base_session_stream._datasets[0]
test_length = len(test_dataset)
print(f"Test dataset length: {test_length}")

Test dataset length: 6000


In [45]:
import pandas as pd
import random

# Draw one random test sample. A dedicated, seeded RNG makes the pick
# reproducible under Restart & Run All without touching the global `random`
# state. (The saved outputs below came from an unseeded run, so the sample shown
# after re-running will differ.)
SAMPLE_SEED = 0
rand = random.Random(SAMPLE_SEED).randrange(test_length)  # uniform index in [0, test_length)
# Presumably (raw image, preprocessed pixel values, integer label): the later
# cells display `img`, feed `x` to the model and index `prompt_labels[y]`.
# TODO confirm.
img, x, y = test_dataset[rand]

In [46]:
# Human-readable name of the sampled image's ground-truth class.
true_label = prompt_labels[y]
print(f"True label: {true_label}")

True label: bear


In [47]:
from PIL import Image
# Leave the image as the cell's last expression so Jupyter renders it inline
# through its rich (PNG) repr. `img.show()` hands the image to an external OS
# viewer instead, so nothing appears in the notebook, and on a headless or
# remote kernel typically nothing appears at all.
img

In [48]:
# Add a batch dimension: the model receives a batch of one image.
# torch.as_tensor accepts numpy arrays and tensors alike. Unlike torch.tensor,
# it does not force a copy (and emit a warning) when `x` is already a tensor.
x_new = torch.as_tensor(x, device=device).unsqueeze(0)

# Inference only: disable autograd so no graph is built for this forward pass.
# NOTE(review): the model is never switched to eval() in the cells shown; if
# CLIPForPromptTuning contains dropout/batch-norm, call model.eval() first.
# TODO confirm.
with torch.no_grad():
    logits, text_out = model(x_new, text_tokens, attn_mask)
logits

tensor([[21.9442, 21.1817, 22.2317, 25.6748, 24.9217, 19.4653, 22.0614, 19.9956,
         19.9834, 20.8015, 21.1082, 21.2212, 22.0211, 19.2919, 21.6196, 25.7606,
         22.1121, 19.9578, 20.5058, 25.8859, 20.7491, 24.8579, 21.5507, 21.9781,
         18.8758, 20.9095, 20.5628, 21.2551, 20.1599, 22.5594, 22.4197, 23.4722,
         22.3390, 23.7674, 23.0566, 20.6650, 23.0290, 19.6536, 24.5385, 21.4802,
         19.5188, 22.9245, 21.9093, 24.3556, 19.8583, 20.8884, 22.3249, 21.9323,
         19.5393, 22.4188, 22.5605, 21.9058, 20.1038, 21.6178, 22.7976, 23.1640,
         20.0658, 20.4957, 20.9667, 21.8960, 21.7744, 20.0843, 21.0496, 24.9927,
         23.3907, 22.7302, 22.9145, 21.1126, 21.8598, 21.1380, 22.1159, 22.6109,
         22.3848, 21.5915, 22.7176, 23.7071, 21.0948, 20.7128, 20.4113, 19.8888,
         24.2704, 20.1246, 21.4347, 22.8881, 21.9425, 20.9563, 22.3376, 23.3204,
         23.4490, 21.8377, 21.2983, 22.6431, 21.7847, 21.0763, 19.7935, 22.6815,
         22.2401, 23.9674, 2

In [49]:
# Shapes of the image-vs-class logits and the per-class text embeddings.
tuple(t.shape for t in (logits, text_out))

(torch.Size([1, 100]), torch.Size([100, 512]))

In [50]:
# Highest-scoring class for the single image in the batch, as a plain int.
output_idx = logits.argmax(dim=1).item()
true_name = prompt_labels[y]
print(f"True: {true_name}")
predicted_name = prompt_labels[output_idx]
print(f"Predicted: {predicted_name}")

True: bear
Predicted: cattle
