## CLIP ViT-B/32 Zero-Shot Analysis - Domain-Shifted (PACS)

In [1]:
import torch
from torchvision import datasets
from torch.utils.data import DataLoader
import clip


  from pkg_resources import packaging


In [2]:
# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print("Device:", device)

# clip.load hands back the model together with the image transform that is
# later passed to the datasets as `transform=preprocess`.
model, preprocess = clip.load("ViT-B/32", device=device)

Device: cuda


### PACS Dataloaders for each domain

In [3]:
def _make_eval_loader(dataset):
    """Wrap *dataset* in the fixed-order DataLoader shared by every PACS domain."""
    return DataLoader(
        dataset,
        batch_size=64,
        shuffle=False,
        num_workers=4,
        pin_memory=True,
    )


# One image directory per PACS domain evaluated in this notebook.
sketch_dir = "data/pacs_data/pacs_data/sketch"
art_dir = "data/pacs_data/pacs_data/art_painting"
cartoon_dir = "data/pacs_data/pacs_data/cartoon"

# Every image goes through the `preprocess` transform returned by clip.load.
sketch_dataset = datasets.ImageFolder(root=sketch_dir, transform=preprocess)
art_dataset = datasets.ImageFolder(root=art_dir, transform=preprocess)
cartoon_dataset = datasets.ImageFolder(root=cartoon_dir, transform=preprocess)

# Identical loader settings for all three domains.
sketch_loader = _make_eval_loader(sketch_dataset)
art_loader = _make_eval_loader(art_dataset)
cartoon_loader = _make_eval_loader(cartoon_dataset)


### Generating text embeddings of prompts with ensembling for each domain

In [4]:
# Class names taken from the sketch ImageFolder. The same list is reused below
# to build the text embeddings and to score the art and cartoon loaders as well.
# NOTE(review): this assumes the art_painting and cartoon folders yield the same
# class -> label-index mapping as the sketch folder — confirm.
PACS_classes = sketch_dataset.classes

# Prompt ensemble adapted from the paper "Learning Transferable Visual Models
# From Natural Language Supervision" (similar to the CIFAR prompt set).
# Each final template has two positional slots: the first is filled with the
# domain word, the second with the class name. The doubled braces survive the
# article substitution below as plain `{}` placeholders.
_PROMPT_SHAPES = (
    'a {{}} of {article} {{}}.',
    'a blurry {{}} of {article} {{}}.',
    'a black and white {{}} of {article} {{}}.',
    'a bad {{}} of {article} {{}}.',
    'a good {{}} of {article} {{}}.',
    'a {{}} of {article} small {{}}.',
    'a {{}} of {article} big {{}}.',
)

# All "a ..." variants first, then all "the ..." variants.
templates = [
    shape.format(article=article)
    for article in ('a', 'the')
    for shape in _PROMPT_SHAPES
]

def get_text_features(templates, domain, classes, model, device):
    """Build one prompt-ensembled, L2-normalised text embedding per class.

    For every class name, each template is filled with ``(domain, classname)``,
    tokenised with ``clip.tokenize`` and encoded with ``model.encode_text``.
    The per-prompt embeddings are normalised, averaged over the prompts, and
    the average is normalised again.

    Args:
        templates: Iterable of format strings with two positional ``{}``
            slots (domain first, class name second).
        domain: Word substituted into the first slot of every template.
        classes: Iterable of class names; the output rows follow this order.
        model: Object exposing ``encode_text(tokens)``.
        device: Device the tokenised prompts are moved to before encoding.

    Returns:
        A tensor with one row per entry of ``classes`` (the per-class
        embeddings stacked along dim 0). It carries no autograd history.
    """
    class_embeddings = []

    # Inference only: without no_grad every encode_text call would record an
    # autograd graph and the returned embeddings would keep it alive.
    with torch.no_grad():
        for classname in classes:
            texts = [template.format(domain, classname) for template in templates]
            tokenized = clip.tokenize(texts).to(device)

            # assumes encode_text returns one row per prompt — TODO confirm
            text_features = model.encode_text(tokenized)
            text_features = text_features / text_features.norm(dim=-1, keepdim=True)

            # Average the normalised prompt embeddings, then re-normalise so
            # every class vector has unit length.
            class_feature = text_features.mean(dim=0)
            class_embeddings.append(class_feature / class_feature.norm())

    return torch.stack(class_embeddings, dim=0)

# One class-embedding matrix per domain; only the domain word inserted into
# the prompts differs between the three calls.
sketch_text_features = get_text_features(
    templates, domain='sketch', classes=PACS_classes, model=model, device=device
)
art_text_features = get_text_features(
    templates, domain='painting', classes=PACS_classes, model=model, device=device
)
cartoon_text_features = get_text_features(
    templates, domain='cartoon', classes=PACS_classes, model=model, device=device
)


### Zero-shot classification

In [5]:
def zero_shot_classification(loader, text_features, model, device, classes):
    """Score zero-shot predictions over a loader, overall and per class.

    Each batch of images is encoded with ``model.encode_image``, normalised,
    and compared against ``text_features``; the index of the most similar
    text row is the prediction and is compared with the integer label.

    Args:
        loader: Iterable yielding ``(images, labels)`` batches, where
            ``labels`` is a 1-D tensor of non-negative integer class indices.
        text_features: Matrix of class embeddings; row ``i`` is scored
            against label ``i``.
        model: Object exposing ``encode_image(images)``.
        device: Device the images are moved to before encoding.
        classes: Sequence of class names; position ``i`` names label ``i``.

    Returns:
        ``(accuracy, class_accuracies)`` where ``accuracy`` is the overall
        percentage of correct predictions and ``class_accuracies`` maps each
        class name to its percentage. An accuracy with no samples behind it
        (empty loader, or a class that never occurs) is ``float("nan")``
        rather than raising ``ZeroDivisionError``.
    """
    num_classes = len(classes)

    # Tallies are kept on the CPU and updated once per batch with bincount,
    # instead of one Python-level .item() round-trip per sample.
    class_correct = torch.zeros(num_classes, dtype=torch.long)
    class_total = torch.zeros(num_classes, dtype=torch.long)

    with torch.no_grad():
        for images, labels in loader:
            images = images.to(device)
            labels = labels.cpu()

            # Encode images
            image_features = model.encode_image(images)
            image_features = image_features / image_features.norm(dim=-1, keepdim=True)

            # Compute similarity with text features
            similarity = 100.0 * image_features @ text_features.T
            preds = similarity.argmax(dim=1).cpu()

            hits = preds == labels
            class_total += torch.bincount(labels, minlength=num_classes)
            class_correct += torch.bincount(labels[hits], minlength=num_classes)

    correct = int(class_correct.sum())
    total = int(class_total.sum())
    undefined = float("nan")

    accuracy = correct / total * 100 if total else undefined

    class_accuracies = {}
    per_class_counts = zip(classes, class_correct.tolist(), class_total.tolist())
    for classname, n_correct, n_seen in per_class_counts:
        class_accuracies[classname] = (
            100.0 * n_correct / n_seen if n_seen else undefined
        )

    return accuracy, class_accuracies

In [6]:
def _print_domain_report(headline, per_class):
    """Print one domain's headline accuracy followed by its per-class lines."""
    print(headline)
    print("Per-class accuracy:")
    for name, value in per_class.items():
        print(f"    {name:10s}: {value:.2f}%")


# Evaluate each domain against the text embeddings built for that domain.
sketch_accuracy, sketch_class_accuracies = zero_shot_classification(
    sketch_loader, sketch_text_features, model, device, PACS_classes
)
art_accuracy, art_class_accuracies = zero_shot_classification(
    art_loader, art_text_features, model, device, PACS_classes
)
cartoon_accuracy, cartoon_class_accuracies = zero_shot_classification(
    cartoon_loader, cartoon_text_features, model, device, PACS_classes
)

# Headlines are padded so the percentages line up; the leading "\n" on the
# second and third separates consecutive reports with a blank line.
_print_domain_report(
    f"Zero-shot CLIP accuracy on PACS Sketch:         {sketch_accuracy:.2f}%",
    sketch_class_accuracies,
)
_print_domain_report(
    f"\nZero-shot CLIP accuracy on PACS Art/Painting: {art_accuracy:.2f}%",
    art_class_accuracies,
)
_print_domain_report(
    f"\nZero-shot CLIP accuracy on PACS Cartoon:      {cartoon_accuracy:.2f}%",
    cartoon_class_accuracies,
)


Zero-shot CLIP accuracy on PACS Sketch:         86.56%
Per-class accuracy:
    dog       : 72.67%
    elephant  : 92.57%
    giraffe   : 70.92%
    guitar    : 99.01%
    horse     : 98.41%
    house     : 98.75%
    person    : 85.62%

Zero-shot CLIP accuracy on PACS Art/Painting: 94.87%
Per-class accuracy:
    dog       : 93.93%
    elephant  : 95.69%
    giraffe   : 98.25%
    guitar    : 95.11%
    horse     : 97.51%
    house     : 100.00%
    person    : 88.42%

Zero-shot CLIP accuracy on PACS Cartoon:      97.53%
Per-class accuracy:
    dog       : 96.66%
    elephant  : 99.78%
    giraffe   : 99.13%
    guitar    : 100.00%
    horse     : 96.30%
    house     : 99.65%
    person    : 93.09%
