## CLIP ViT-B/32 Zero-Shot Analysis - STL-10

In [11]:
import torch
from torchvision import datasets
from torch.utils.data import DataLoader, ConcatDataset
import clip


In [12]:
device = "cuda" if torch.cuda.is_available() else "mps" if torch.mps.is_available() else "cpu"
print("Device:", device)
model, preprocess = clip.load("ViT-B/32", device=device)

Device: mps


### STL-10 Dataloader (test set)

In [13]:
test_stl = datasets.STL10(
        root="./data",
        split='test',
        download=True,
        transform=preprocess
    )

loader = DataLoader(
        test_stl,
        batch_size= 128,
        shuffle=False,
        num_workers=2
    )

### Generating text embeddings of prompts with ensembling

In [14]:
stl10_classes = test_stl.classes

# Prompts used (for STL-10) in the paper "Learning Transferable Visual Models From Natural Language Supervision"
templates = [
    "a photo of the {}.",
    "a photo of the {}."
]

def get_text_features(templates, classes, model, device):
    all_text_features = []
    for classname in classes:
        if templates is not None:
            texts = [template.format(classname) for template in templates]
        else:
            texts = [classname]
        tokenized = clip.tokenize(texts).to(device)
        text_features = model.encode_text(tokenized)
        text_features /= text_features.norm(dim=-1, keepdim=True)
        class_feature = text_features.mean(dim=0)
        class_feature /= class_feature.norm()
        all_text_features.append(class_feature)
    text_features = torch.stack(all_text_features, dim=0)
    return text_features

plain_text_features = get_text_features(None, stl10_classes, model, device).cpu()
ensemble_text_features = get_text_features(templates, stl10_classes, model, device).cpu()

### Zero-shot classification

In [15]:
def zero_shot_classification(loader, text_features, model, device, classes):
    correct = 0
    total = 0
    class_correct = [0] * len(classes)
    class_total = [0] * len(classes)
    class_accuracies = {}
    model.eval()
    with torch.no_grad():
        for images, labels in loader:
            images, labels = images.to(device), labels.to(device)

            # Encode images
            image_features = model.encode_image(images)
            image_features /= image_features.norm(dim=-1, keepdim=True)

            # Compute similarity with text features
            similarity = 100.0 * image_features @ text_features.to(device).T
            preds = similarity.argmax(dim=1)

            correct += (preds == labels).sum().item()
            total += labels.size(0)

            for i in range(len(labels)):
                label = labels[i].item()
                if preds[i].item() == label:
                    class_correct[label] += 1
                class_total[label] += 1

    accuracy = correct / total * 100

    for i, classname in enumerate(classes):
        acc = 100.0 * class_correct[i] / class_total[i]
        class_accuracies[classname] = acc

    return accuracy, class_accuracies

In [16]:
plain_accuracy, plain_class_accuracies = zero_shot_classification(loader, plain_text_features, model, device, stl10_classes)
ensemble_accuracy, ensemble_class_accuracies = zero_shot_classification(loader, ensemble_text_features, model, device, stl10_classes)

print(f"Zero-shot accuracy on STL-10 (plain labels): {plain_accuracy:.2f}%")
print("Per-class accuracy:")
for classname, acc in plain_class_accuracies.items():
    print(f"    {classname:10s}: {acc:.2f}%")

print(f"\nZero-shot accuracy on STL-10 (label ensembling): {ensemble_accuracy:.2f}%")
print("Per-class accuracy:")
for classname, acc in ensemble_class_accuracies.items():
    print(f"    {classname:10s}: {acc:.2f}%")

Zero-shot accuracy on STL-10 (plain labels): 96.25%
Per-class accuracy:
    airplane  : 95.12%
    bird      : 99.75%
    car       : 94.75%
    cat       : 84.38%
    deer      : 98.50%
    dog       : 96.62%
    horse     : 99.00%
    monkey    : 94.88%
    ship      : 99.88%
    truck     : 99.62%

Zero-shot accuracy on STL-10 (label ensembling): 96.78%
Per-class accuracy:
    airplane  : 99.12%
    bird      : 99.75%
    car       : 95.75%
    cat       : 84.25%
    deer      : 98.12%
    dog       : 96.12%
    horse     : 98.62%
    monkey    : 97.25%
    ship      : 99.88%
    truck     : 98.88%
