## CLIP ViT-B/32 Zero-Shot Analysis

In [21]:
import torch
from torchvision import datasets
from torch.utils.data import DataLoader, ConcatDataset
import numpy as np
import clip
from PIL import Image


In [22]:
# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"
print("Device:", device)

# clip.load returns the network plus the image transform that matches it;
# that transform is reused as the dataset transform below.
model, preprocess = clip.load("ViT-B/32", device=device)

Device: cuda


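### CIFAR-10 DataLoader (full dataset: train + test splits)

Zero-shot evaluation involves no training, so both splits are concatenated and scored together. Note that published CLIP zero-shot results for CIFAR-10 are reported on the test split only, so the accuracy below is not directly comparable to them.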

In [23]:
# Feed images through the same transform CLIP was loaded with.
transform = preprocess

# Load both CIFAR-10 splits with identical settings; the dict keeps the
# construction order (train first, then test).
cifar10_splits = {
    split: datasets.CIFAR10(
        root="./data",
        train=(split == "train"),
        download=True,
        transform=transform,
    )
    for split in ("train", "test")
}
train_dataset = cifar10_splits["train"]
test_dataset = cifar10_splits["test"]

# Zero-shot evaluation uses no training, so every image from both splits is scored.
full_dataset = ConcatDataset([train_dataset, test_dataset])

# shuffle=False keeps the batch order fixed across runs.
loader = DataLoader(full_dataset, batch_size=128, shuffle=False, num_workers=2)

### Class text embeddings via prompt ensembling

In [24]:
cifar10_classes = [
    "airplane", "automobile", "bird", "cat", "deer",
    "dog", "frog", "horse", "ship", "truck"
]
# Row i of the classifier built below is compared against dataset label i, so
# the class order must match the dataset's own label order exactly.
assert cifar10_classes == list(train_dataset.classes), "class order must match CIFAR-10 label indices"

# Prompts used in the paper "Learning Transferable Visual Models From Natural Language Supervision"
templates = [
    'a photo of a {}.',
    'a blurry photo of a {}.',
    'a black and white photo of a {}.',
    'a low contrast photo of a {}.',
    'a high contrast photo of a {}.',
    'a bad photo of a {}.',
    'a good photo of a {}.',
    'a photo of a small {}.',
    'a photo of a big {}.',
    'a photo of the {}.',
    'a blurry photo of the {}.',
    'a black and white photo of the {}.',
    'a low contrast photo of the {}.',
    'a high contrast photo of the {}.',
    'a bad photo of the {}.',
    'a good photo of the {}.',
    'a photo of the small {}.',
    'a photo of the big {}.',
]

# Prompt ensembling: for each class, embed every templated prompt, L2-normalize
# each prompt embedding, average them, and re-normalize the mean so every class
# vector has unit length (dot products with normalized image features are then
# cosine similarities).
with torch.no_grad():
    all_text_features = []
    for classname in cifar10_classes:
        prompts = [template.format(classname) for template in templates]
        tokenized = clip.tokenize(prompts).to(device)
        prompt_features = model.encode_text(tokenized)
        prompt_features = prompt_features / prompt_features.norm(dim=-1, keepdim=True)
        class_feature = prompt_features.mean(dim=0)
        class_feature = class_feature / class_feature.norm()
        all_text_features.append(class_feature)
    # Stack once, after every class has been embedded: one row per class, in
    # `cifar10_classes` order. Consumed as `text_features` by the next cell.
    text_features = torch.stack(all_text_features, dim=0)

assert text_features.shape[0] == len(cifar10_classes)

### Zero-shot classification

In [25]:
# Each image is assigned the class whose ensembled text embedding has the
# highest (scaled) cosine similarity with the normalized image embedding.
correct = 0
total = 0

with torch.no_grad():
    for images, labels in loader:
        images = images.to(device)
        labels = labels.to(device)

        image_features = model.encode_image(images)
        image_features /= image_features.norm(dim=-1, keepdim=True)

        # Rows are images, columns are classes; the 100x scale does not change
        # which column wins the argmax.
        similarity = 100.0 * image_features @ text_features.T
        preds = similarity.argmax(dim=1)

        hits = preds.eq(labels).sum().item()
        correct, total = correct + hits, total + labels.size(0)

accuracy = correct / total * 100
print(f"Zero-shot CLIP accuracy on CIFAR-10: {accuracy:.2f}%")

Zero-shot CLIP accuracy on CIFAR-10: 89.68%
