## CLIP ViT-B/32 Zero-Shot Analysis - CIFAR-10

In [5]:
import torch
from torchvision import datasets
from torch.utils.data import DataLoader, ConcatDataset
import clip


In [6]:
# Prefer CUDA, then Apple-silicon MPS, then CPU.
# torch.backends.mps.is_available() is the documented MPS availability check
# (present since PyTorch 1.12); torch.mps.is_available() is not available on
# every PyTorch version and would raise AttributeError on CPU-only machines.
device = (
    "cuda" if torch.cuda.is_available()
    else "mps" if torch.backends.mps.is_available()
    else "cpu"
)
print("Device:", device)
# `preprocess` is the image transform CLIP pairs with this checkpoint; it is
# used as the dataset transform below.
model, preprocess = clip.load("ViT-B/32", device=device)

Device: mps


### CIFAR-10 test-set DataLoader

In [7]:
# Evaluation data: the CIFAR-10 test split, transformed with the `preprocess`
# function returned by clip.load so inputs match what the image encoder expects.
test_dataset = datasets.CIFAR10(root="./data", train=False, download=True, transform=preprocess)

# shuffle=False keeps batch order fixed, so repeated evaluation passes are comparable.
loader = DataLoader(test_dataset, batch_size=128, shuffle=False, num_workers=2)

100%|██████████| 170M/170M [02:04<00:00, 1.37MB/s] 


### Text embeddings: plain class names vs. prompt ensembling

In [8]:
# Class names used to build the text prompts. Their order must match the
# dataset's integer labels, because labels index into this list during
# evaluation (presumably torchvision's CIFAR10.classes order; verify with
# test_dataset.classes).
cifar10_classes = [
    "airplane",
    "automobile",
    "bird",
    "cat",
    "deer",
    "dog",
    "frog",
    "horse",
    "ship",
    "truck",
]

# CIFAR-10 prompt templates from CLIP ("Learning Transferable Visual Models From
# Natural Language Supervision"). Each "{}" is replaced with a class name; the
# first nine use the article "a", the last nine repeat them with "the".
templates = [
    "a photo of a {}.",
    "a blurry photo of a {}.",
    "a black and white photo of a {}.",
    "a low contrast photo of a {}.",
    "a high contrast photo of a {}.",
    "a bad photo of a {}.",
    "a good photo of a {}.",
    "a photo of a small {}.",
    "a photo of a big {}.",
    "a photo of the {}.",
    "a blurry photo of the {}.",
    "a black and white photo of the {}.",
    "a low contrast photo of the {}.",
    "a high contrast photo of the {}.",
    "a bad photo of the {}.",
    "a good photo of the {}.",
    "a photo of the small {}.",
    "a photo of the big {}.",
]

def get_text_features(templates, classes, model, device):
    """Build one L2-normalised CLIP text embedding per class.

    For each class name, every template is filled in with the name and the
    resulting prompts are encoded. Each prompt embedding is L2-normalised, then
    the normalised embeddings are averaged and re-normalised (prompt
    ensembling). With ``templates=None`` the bare class name is the only prompt.

    Args:
        templates: sequence of format strings with one ``{}`` placeholder, or
            ``None`` to use the raw class names as prompts.
        classes: class names; output rows follow this order.
        model: CLIP-like model exposing ``encode_text(tokens)``.
        device: device the tokenised prompts are moved to.

    Returns:
        Tensor stacking one unit-norm embedding per class, detached from
        autograd. Its shape is presumably ``(len(classes), embed_dim)``, given
        that ``encode_text`` returns ``(n_prompts, embed_dim)``.

    Raises:
        ValueError: if ``classes`` is empty.
    """
    all_text_features = []
    # Inference only: without no_grad the returned tensor keeps a grad_fn that
    # holds the text encoder's activations for every prompt alive.
    with torch.no_grad():
        for classname in classes:
            if templates is not None:
                texts = [template.format(classname) for template in templates]
            else:
                texts = [classname]
            tokenized = clip.tokenize(texts).to(device)
            text_features = model.encode_text(tokenized)
            # Normalise each prompt first so every template gets equal weight in the mean.
            text_features = text_features / text_features.norm(dim=-1, keepdim=True)
            class_feature = text_features.mean(dim=0)
            # The mean of unit vectors is shorter than 1; re-normalise for cosine similarity.
            class_feature = class_feature / class_feature.norm()
            all_text_features.append(class_feature)

    if not all_text_features:
        raise ValueError("classes must contain at least one class name")
    return torch.stack(all_text_features, dim=0)

# Two class-embedding variants: bare class names (templates=None) and the
# 18-template prompt ensemble. Both are moved to CPU.
plain_text_features = get_text_features(
    templates=None, classes=cifar10_classes, model=model, device=device
).cpu()
ensemble_text_features = get_text_features(
    templates=templates, classes=cifar10_classes, model=model, device=device
).cpu()

### Zero-shot classification

In [9]:
def zero_shot_classification(loader, text_features, model, device, classes):
    """Evaluate zero-shot classification over every batch in ``loader``.

    Each image is assigned the class whose text embedding has the highest
    similarity with the L2-normalised image embedding.

    Args:
        loader: iterable of ``(images, labels)`` batches; ``labels`` is a 1-D
            integer tensor of indices into ``classes``.
        text_features: one embedding per class, in ``classes`` order
            (presumably unit-norm, as produced by ``get_text_features``).
        model: object exposing ``eval()`` and ``encode_image(images)``;
            ``eval()`` is called before evaluation.
        device: device that images and text features are moved to.
        classes: class names; their order defines the label indices.

    Returns:
        ``(accuracy, class_accuracies)``: the overall top-1 accuracy in percent,
        and a dict mapping each class name to its accuracy in percent. A class
        with no samples in ``loader`` maps to ``nan``.

    Raises:
        ValueError: if ``loader`` yields no samples.
    """
    num_classes = len(classes)
    class_correct = torch.zeros(num_classes, dtype=torch.long)
    class_total = torch.zeros(num_classes, dtype=torch.long)

    model.eval()
    with torch.no_grad():
        # Loop-invariant: move the class embeddings to the device once.
        text_features_dev = text_features.to(device)
        for images, labels in loader:
            image_features = model.encode_image(images.to(device))
            image_features = image_features / image_features.norm(dim=-1, keepdim=True)

            # Align dtypes (e.g. fp16 image features vs fp32 text features);
            # the 100.0 logit scale does not change the argmax.
            similarity = 100.0 * image_features @ text_features_dev.to(image_features.dtype).T
            preds = similarity.argmax(dim=1).cpu()
            labels_cpu = labels.cpu()

            # Vectorised per-class tallies instead of one .item() sync per sample.
            class_total += torch.bincount(labels_cpu, minlength=num_classes)
            class_correct += torch.bincount(labels_cpu[preds == labels_cpu], minlength=num_classes)

    correct_counts = class_correct.tolist()
    total_counts = class_total.tolist()
    total = sum(total_counts)
    if total == 0:
        raise ValueError("loader yielded no samples")
    accuracy = sum(correct_counts) / total * 100

    class_accuracies = {
        classname: (100.0 * correct_counts[i] / total_counts[i] if total_counts[i] else float("nan"))
        for i, classname in enumerate(classes)
    }
    return accuracy, class_accuracies

In [10]:
# Evaluate both text-embedding variants on the same test loader.
plain_accuracy, plain_class_accuracies = zero_shot_classification(
    loader, plain_text_features, model, device, cifar10_classes
)
ensemble_accuracy, ensemble_class_accuracies = zero_shot_classification(
    loader, ensemble_text_features, model, device, cifar10_classes
)

# One report loop instead of two copy-pasted ones; the printed text is unchanged
# (the second report is preceded by a blank line, as before).
reports = [
    ("Zero-shot accuracy on CIFAR-10 (plain labels)", plain_accuracy, plain_class_accuracies),
    ("Zero-shot accuracy on CIFAR-10 (label ensembling)", ensemble_accuracy, ensemble_class_accuracies),
]
for report_index, (heading, overall, per_class) in enumerate(reports):
    separator = "\n" if report_index else ""
    print(f"{separator}{heading}: {overall:.2f}%")
    print("Per-class accuracy:")
    for classname, acc in per_class.items():
        print(f"    {classname:10s}: {acc:.2f}%")

Zero-shot accuracy on CIFAR-10 (plain labels): 87.35%
Per-class accuracy:
    airplane  : 81.20%
    automobile: 89.80%
    bird      : 91.70%
    cat       : 79.20%
    deer      : 82.10%
    dog       : 85.20%
    frog      : 76.80%
    horse     : 97.10%
    ship      : 96.40%
    truck     : 94.00%

Zero-shot accuracy on CIFAR-10 (label ensembling): 89.85%
Per-class accuracy:
    airplane  : 91.60%
    automobile: 94.40%
    bird      : 89.10%
    cat       : 85.30%
    deer      : 82.40%
    dog       : 87.50%
    frog      : 81.50%
    horse     : 97.50%
    ship      : 95.40%
    truck     : 93.80%
