In [49]:
import os
import re
import random
from PIL import Image
import torch
from tqdm.notebook import tqdm
from transformers import CLIPModel,CLIPProcessor

In [50]:
def load_split_file(file_path):
    """Read a split listing and return its samples.

    Each non-blank line is expected to hold an image path followed by an
    integer class label, separated by whitespace.

    Args:
        file_path: Path of the split text file.

    Returns:
        A list of ``(path, label)`` tuples, in file order, where ``label`` is
        an ``int``.

    Raises:
        ValueError: If a non-blank line has no label column or the label is
            not an integer.
    """
    samples = []
    with open(file_path, 'r') as f:
        for line in f:
            line = line.strip()
            # Blank lines (e.g. a trailing newline at the end of the file)
            # carry no sample and must not abort the whole load.
            if not line:
                continue
            # Split from the right so that only the final column is treated as
            # the label; this keeps image paths containing spaces intact.
            path, label = line.rsplit(None, 1)
            samples.append((path, int(label)))
    return samples

def load_class_names(file_path):
    """Read a class listing and return a mapping from class id to name.

    Each non-blank line is expected to start with an integer id, followed by
    whitespace and the class name (which may itself contain spaces or digits).

    Args:
        file_path: Path of the class-name text file.

    Returns:
        A dict mapping ``int`` id -> ``str`` name, in file order. A line that
        holds only an id maps to the empty string.

    Raises:
        ValueError: If the first token of a non-blank line is not an integer.
    """
    id_to_name = {}
    with open(file_path, 'r') as f:
        for line in f:
            # Split off only the leading id; everything after it is the name.
            # (Stripping every digit from the line would corrupt names that
            # legitimately contain numbers.)
            parts = line.split(maxsplit=1)
            if not parts:
                continue  # blank line
            idx = parts[0]
            name = parts[1].strip() if len(parts) > 1 else ''
            id_to_name[int(idx)] = name
    return id_to_name

In [51]:
# Run on the GPU when PyTorch reports CUDA as available, otherwise on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {device}")

Using device: cuda


In [52]:
# Dataset location, relative to the notebook's working directory.
dataset_root = '../ip102_v1.1/'
images_root = os.path.join(dataset_root, "images")

# One listing file per split, all stored directly under the dataset root.
train_txt, val_txt, test_txt = (
    os.path.join(dataset_root, listing)
    for listing in ("train.txt", "val.txt", "test.txt")
)

# Each split is loaded as a list of (relative image path, integer label) tuples.
train_data = load_split_file(train_txt)
test_data = load_split_file(test_txt)
val_data = load_split_file(val_txt)

In [53]:
# The same checkpoint name is used for the preprocessor and the model weights.
model_name = "openai/clip-vit-base-patch32"
processor = CLIPProcessor.from_pretrained(model_name)
model = CLIPModel.from_pretrained(model_name)
model = model.to(device)
# Put the model in evaluation mode; as the cell's last expression this also
# makes the notebook display the module's repr.
model.eval()

CLIPModel(
  (text_model): CLIPTextTransformer(
    (embeddings): CLIPTextEmbeddings(
      (token_embedding): Embedding(49408, 512)
      (position_embedding): Embedding(77, 512)
    )
    (encoder): CLIPEncoder(
      (layers): ModuleList(
        (0-11): 12 x CLIPEncoderLayer(
          (self_attn): CLIPSdpaAttention(
            (k_proj): Linear(in_features=512, out_features=512, bias=True)
            (v_proj): Linear(in_features=512, out_features=512, bias=True)
            (q_proj): Linear(in_features=512, out_features=512, bias=True)
            (out_proj): Linear(in_features=512, out_features=512, bias=True)
          )
          (layer_norm1): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
          (mlp): CLIPMLP(
            (activation_fn): QuickGELUActivation()
            (fc1): Linear(in_features=512, out_features=2048, bias=True)
            (fc2): Linear(in_features=2048, out_features=512, bias=True)
          )
          (layer_norm2): LayerNorm((512,), eps=1e

In [None]:
def _encode_image(image_path):
    """Return the L2-normalised CLIP embedding of one image as a 1-D tensor.

    ``image_path`` is resolved relative to the notebook-level ``images_root``.
    """
    full_path = os.path.join(images_root, image_path)
    # The context manager releases the file handle as soon as the pixels have
    # been decoded into the converted copy.
    with Image.open(full_path) as raw_image:
        image = raw_image.convert("RGB")
    inputs = processor(images=image, return_tensors="pt").to(device)
    with torch.no_grad():
        features = model.get_image_features(**inputs)
        features = features / features.norm(p=2, dim=-1, keepdim=True)
    return features.squeeze(0)


def _text_prototypes(class_names):
    """Return one L2-normalised CLIP text embedding per class name (one row each)."""
    text_prompts = [f"a photo of a {name}" for name in class_names]
    # Truncate up front so prompts longer than the tokenizer's maximum length
    # are shortened instead of making the text encoder fail.
    text_inputs = processor(
        text=text_prompts,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=processor.tokenizer.model_max_length,
    ).to(device)
    with torch.no_grad():
        text_features = model.get_text_features(**text_inputs)
        return text_features / text_features.norm(p=2, dim=-1, keepdim=True)


def _image_prototypes(num_classes, k_shots, seed=None):
    """Return one prototype per class: the normalised mean of up to ``k_shots`` train embeddings."""
    # Sample a shuffled copy so the notebook-level ``train_data`` list is not
    # reordered as a side effect of running an evaluation.
    rng = random if seed is None else random.Random(seed)
    shuffled_train = rng.sample(train_data, len(train_data))

    shots = {class_idx: [] for class_idx in range(num_classes)}
    unfilled = num_classes  # classes that still need more shots
    for image_path, true_label in shuffled_train:
        if unfilled == 0:
            break
        class_shots = shots[true_label]
        if len(class_shots) >= k_shots:
            continue
        class_shots.append(_encode_image(image_path))
        if len(class_shots) == k_shots:
            unfilled -= 1

    prototypes = torch.zeros(num_classes, model.config.projection_dim).to(device)
    for class_idx, class_shots in shots.items():
        if len(class_shots) < k_shots:
            print(
                f"Warning: Not enough shots collected for class {class_idx} "
                f"({len(class_shots)}/{k_shots})"
            )
        if class_shots:
            prototypes[class_idx] = torch.stack(class_shots).mean(dim=0)

    norms = prototypes.norm(p=2, dim=-1, keepdim=True)
    # A class without any shot keeps an all-zero prototype; dividing by 1
    # instead of 0 leaves it at zero rather than turning the row into NaN,
    # which would corrupt the argmax for every evaluated image.
    norms[norms == 0] = 1.0
    return prototypes / norms


def test_model_few_shot(path, k_shots=0, seed=None):
    """Evaluate CLIP on the notebook-level ``val_data`` split.

    With ``k_shots == 0`` each class is represented by the text embedding of
    the prompt ``"a photo of a <name>"``. With ``k_shots > 0`` each class is
    represented by the mean image embedding of up to ``k_shots`` randomly
    chosen samples from the notebook-level ``train_data``. Every evaluation
    image is assigned to the class whose representation has the highest dot
    product with the image's normalised embedding.

    Relies on the notebook-level ``model``, ``processor``, ``device``,
    ``images_root``, ``train_data`` and ``val_data``.

    Args:
        path: Class-name file, read with ``load_class_names``. Class index
            ``i`` corresponds to the i-th entry of that file, in file order.
        k_shots: Number of training images per class used to build the
            prototypes; ``0`` selects the zero-shot text-prompt mode.
        seed: Optional seed for choosing the few-shot samples. ``None`` (the
            default) draws from the global ``random`` state.

    Returns:
        The accuracy in percent, or ``None`` when ``val_data`` is empty.

    Raises:
        ValueError: If ``k_shots`` is negative or the class file has no entries.
    """
    if k_shots < 0:
        raise ValueError(f"k_shots must be >= 0, got {k_shots}")

    id_to_name = load_class_names(path)
    num_classes = len(id_to_name)
    if num_classes == 0:
        raise ValueError(f"No class names found in {path!r}")

    if k_shots == 0:
        class_features = _text_prototypes(list(id_to_name.values()))
    else:
        print(f"Preparing {k_shots}-shot prototypes...")
        class_features = _image_prototypes(num_classes, k_shots, seed=seed)

    print(f"Starting evaluation with {k_shots}-shot learning...")

    correct = 0
    total = 0
    for image_path, true_label in tqdm(val_data):
        image_features = _encode_image(image_path)
        similarity = image_features @ class_features.T
        pred_label = similarity.argmax().item()
        if pred_label == true_label:
            correct += 1
        total += 1

    if total == 0:
        print("No evaluation samples found; accuracy is undefined.")
        return None

    accuracy = correct / total * 100
    # NOTE(review): the label says "Test" but the loop above iterates over
    # val_data; confirm which split is intended before reporting numbers.
    print(f"Test Accuracy ({k_shots}-shot): {accuracy:.2f}%")
    return accuracy

In [None]:
# Zero-shot run (k_shots=0) with the class names listed in gemini-small.txt.
test_model_few_shot(
    os.path.join(
        dataset_root,
        "../large-multi-modal/caption generation/gemini-small.txt",
    ),
    k_shots=0,
)


--- 5-shot Evaluation ---
Starting evaluation with 0-shot learning...


  0%|          | 0/7508 [00:00<?, ?it/s]

Test Accuracy (0-shot): 22.11%


In [55]:
# 10-shot run: same class-name file, ten training images per class.
test_model_few_shot(
    os.path.join(
        dataset_root,
        "../large-multi-modal/caption generation/gemini-small.txt",
    ),
    k_shots=10,
)

Preparing 10-shot prototypes...
Starting evaluation with 10-shot learning...


  0%|          | 0/7508 [00:00<?, ?it/s]

Test Accuracy (10-shot): 27.85%
