## CLIP ViT-B/32 Zero-Shot Analysis - PACS Sketch

In [None]:
import torch
from torchvision import datasets
from torch.utils.data import DataLoader
import clip
from PIL import Image


In [13]:
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(f"Device: {device}")

# clip.load returns the model together with its matching image preprocessing callable.
model, preprocess = clip.load("ViT-B/32", device=device)

Device: cuda


### PACS Sketch Dataloader (full set)

In [14]:
# Images are transformed with the preprocessing callable returned by clip.load.
transform = preprocess

# Location of the PACS "sketch" domain on disk.
data_dir = "data/pacs_data/pacs_data/sketch"

sketch_dataset = datasets.ImageFolder(data_dir, transform=transform)

# Evaluation only: fixed iteration order (no shuffling), batches of 64.
sketch_loader = DataLoader(sketch_dataset, batch_size=64, shuffle=False, num_workers=4, pin_memory=True)


### Generating text embeddings of prompts with ensembling

In [None]:
# Class names come from the dataset object, so the row order of the text
# embedding matrix built below matches the integer labels the loader yields.
PACS_classes = sketch_dataset.classes

# Prompts adapted from the paper "Learning Transferable Visual Models From Natural Language Supervision"
# Similar to the ones used for CIFAR
templates = [
    'a sketch of a {}.',
    'a blurry sketch of a {}.',
    'a black and white sketch of a {}.',
    'a bad sketch of a {}.',
    'a good sketch of a {}.',
    'a sketch of a small {}.',
    'a sketch of a big {}.',
    'a sketch of the {}.',
    'a blurry sketch of the {}.',
    'a black and white drawing of the {}.',
    'a bad drawing of the {}.',
    'a good drawing of the {}.',
    'a drawing of the small {}.',
    'a drawing of the big {}.',
]

# Prompt ensembling: for each class, embed every template, L2-normalise each
# prompt embedding, average them, and renormalise the average so that each
# class is represented by a single unit-length vector.
with torch.no_grad():
    all_text_features = []
    for classname in PACS_classes:
        texts = [template.format(classname) for template in templates]
        tokenized = clip.tokenize(texts).to(device)

        # One embedding per prompt of this class (kept under its own name so it
        # is never confused with the final per-class matrix `text_features`).
        prompt_features = model.encode_text(tokenized)
        prompt_features = prompt_features / prompt_features.norm(dim=-1, keepdim=True)

        class_feature = prompt_features.mean(dim=0)
        class_feature = class_feature / class_feature.norm()
        all_text_features.append(class_feature)

    # Stack once, after every class has been processed: one row per class,
    # in the same order as PACS_classes.
    text_features = torch.stack(all_text_features, dim=0)

### Zero-shot classification

In [16]:
# Running tallies over the whole sketch set.
correct, total = 0, 0

with torch.no_grad():
    for batch_images, batch_labels in sketch_loader:
        batch_images = batch_images.to(device)
        batch_labels = batch_labels.to(device)

        # Embed the batch and scale every image embedding to unit length.
        image_features = model.encode_image(batch_images)
        image_features /= image_features.norm(dim=-1, keepdim=True)

        # Scaled dot product between each image and each class text embedding;
        # the predicted class is the column with the highest score.
        scaled_image_features = 100.0 * image_features
        similarity = scaled_image_features @ text_features.T
        preds = similarity.argmax(dim=1)

        hits = preds == batch_labels
        correct += hits.sum().item()
        total += batch_labels.size(0)

# Percentage of images whose top-scoring class equals the ground-truth label.
accuracy = correct / total * 100
print(f"Zero-shot CLIP accuracy on PACS Sketch: {accuracy:.2f}%")

Zero-shot CLIP accuracy on PACS Sketch: 86.59%
