## CLIP ViT-B/32 Shape / Texture Bias

In [8]:
import torch
from torchvision import datasets
from torch.utils.data import DataLoader
import clip
import os


In [9]:
# Pick the compute device: CUDA if present, otherwise Apple MPS, otherwise CPU.
# `torch.backends.mps.is_available()` is the documented availability check for
# the MPS backend.
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"
print("Device:", device)

# Load CLIP ViT-B/32 on the chosen device; `preprocess` is the image transform
# that is later passed to the dataset.
model, preprocess = clip.load("ViT-B/32", device=device)

Device: mps


### Custom Cue-Conflicted STL-10 Test Set

In [10]:
def _numeric_file_id(sample):
    """Return the integer formed by all digits in a sample's file name.

    `sample` is a `(path, class_index)` entry of `ImageFolder.samples`.
    Raises ValueError when the file name contains no digits.
    """
    path, _ = sample
    return int(''.join(filter(str.isdigit, os.path.basename(path))))


cc_dataset = datasets.ImageFolder(root='data/cue_conflict_dataset', transform=preprocess)

# Order the images by the number embedded in their file name. The sort is done
# in place on the list the dataset indexes into.
cc_dataset.samples.sort(key=_numeric_file_id)
# `targets` is a separate list built before the sort; rebuild it so that
# `targets[i]` still matches `samples[i]`.
cc_dataset.targets = [class_index for _, class_index in cc_dataset.samples]

# Deterministic, unshuffled iteration over the evaluation set.
cc_loader = DataLoader(
    cc_dataset,
    batch_size=128,
    shuffle=False,
    num_workers=2
)

### Text Embeddings
Simple prompt template: `a {class}` (one prompt per STL-10 class).

In [11]:
stl10_classes = [
    "airplane", "bird", "car", "cat", "deer",
    "dog", "horse", "monkey", "ship", "truck"
]

# Prompt templates filled in with each class name. With several templates the
# per-class embedding is the re-normalised mean of the prompt embeddings; with
# this single template the averaging leaves the embedding unchanged.
prompt_templates = ["a {}"]

all_text_features = []
# Embeddings are only used for inference, so do not record an autograd graph.
with torch.no_grad():
    for classname in stl10_classes:
        prompts = [template.format(classname) for template in prompt_templates]
        tokenized = clip.tokenize(prompts).to(device)

        # One unit-length embedding per prompt.
        prompt_features = model.encode_text(tokenized)
        prompt_features = prompt_features / prompt_features.norm(dim=-1, keepdim=True)

        # Collapse the prompts into a single unit-length class embedding.
        class_feature = prompt_features.mean(dim=0)
        class_feature = class_feature / class_feature.norm()
        all_text_features.append(class_feature)

# One row per class, in the order of `stl10_classes`.
text_features = torch.stack(all_text_features, dim=0)

In [12]:
def zero_shot_classification(loader, text_features, model, device, classes):
    """Measure zero-shot classification accuracy against per-class text embeddings.

    Every image is assigned the class whose text embedding has the highest
    similarity with the image's normalised embedding.

    Args:
        loader: Iterable yielding `(images, labels)` batches, where `labels`
            holds integer class indices into `classes`.
        text_features: Tensor with one embedding row per class, ordered like
            `classes`.
        model: Object providing `eval()` and `encode_image(images)`.
        device: Device the images and text embeddings are moved to.
        classes: Sequence of class names.

    Returns:
        A tuple `(accuracy, class_accuracies)`: the overall accuracy in percent
        and a dict mapping each class name to its accuracy in percent. An
        accuracy is NaN when there were no samples to compute it from (empty
        loader, or a class that never occurs in the labels).

    Raises:
        IndexError: If a label is not a valid index into `classes`.
    """
    num_classes = len(classes)
    # Per-class counters are kept on the CPU and updated once per batch.
    class_correct = torch.zeros(num_classes, dtype=torch.long)
    class_total = torch.zeros(num_classes, dtype=torch.long)

    model.eval()
    # Loop-invariant: move the text embeddings once instead of on every batch.
    text_features = text_features.to(device)

    with torch.no_grad():
        for images, labels in loader:
            images = images.to(device)
            labels = labels.cpu()
            if labels.numel() and int(labels.max()) >= num_classes:
                raise IndexError("label index out of range for `classes`")

            # Encode images and scale each embedding to unit length.
            image_features = model.encode_image(images)
            image_features = image_features / image_features.norm(dim=-1, keepdim=True)

            # Similarity of every image with every class embedding.
            similarity = 100.0 * image_features @ text_features.T
            preds = similarity.argmax(dim=1).cpu()

            # Count samples and correct predictions per class in one shot.
            hits = preds == labels
            class_total += torch.bincount(labels, minlength=num_classes)
            class_correct += torch.bincount(labels[hits], minlength=num_classes)

    total = int(class_total.sum())
    correct = int(class_correct.sum())
    accuracy = correct / total * 100 if total else float("nan")

    class_accuracies = {}
    for classname, n_correct, n_total in zip(classes, class_correct.tolist(), class_total.tolist()):
        # A class without samples has no defined accuracy.
        class_accuracies[classname] = 100.0 * n_correct / n_total if n_total else float("nan")

    return accuracy, class_accuracies

In [13]:
# Run the zero-shot evaluation on the cue-conflict loader.
cc_accuracy, cc_class_accuracies = zero_shot_classification(
    loader=cc_loader,
    text_features=text_features,
    model=model,
    device=device,
    classes=stl10_classes,
)

# Report the overall figure first, then one line per class.
print(f"Zero-shot accuracy on cue-conflicted STL-10 test set: {cc_accuracy:.2f}%")
print("Per-class accuracy:")
for classname in cc_class_accuracies:
    acc = cc_class_accuracies[classname]
    print(f"    {classname:10s}: {acc:.2f}%")


Zero-shot accuracy on cue-conflicted STL-10 test set: 75.11%
Per-class accuracy:
    airplane  : 68.25%
    bird      : 87.38%
    car       : 96.50%
    cat       : 62.12%
    deer      : 71.38%
    dog       : 72.62%
    horse     : 91.00%
    monkey    : 57.88%
    ship      : 68.62%
    truck     : 75.38%


In [15]:
# Baseline accuracies (percent) that the cue-conflict results are compared to.
# NOTE(review): these are hard-coded; presumably they come from a zero-shot run
# on the unmodified STL-10 test set. Confirm the source and regenerate them if
# the model or prompt template changes.
stl_test_acc = 96.78
stl_class_accuracies = {
    "airplane": 99.12,
    "bird": 99.75,
    "car": 95.75,
    "cat": 84.25,
    "deer": 98.12,
    "dog": 96.12,
    "horse": 98.62,
    "monkey": 97.25,
    "ship": 99.88,
    "truck": 98.88,
}
# Backward-compatible alias for the original (misspelled) variable name.
stl_class_accuraries = stl_class_accuracies

# "Shape bias" here is cue-conflict accuracy expressed as a percentage of the
# baseline accuracy. It is a ratio of two accuracies, so it can exceed 100%
# when a class scores higher on the cue-conflict set than on the baseline.
clip_shape_bias = (cc_accuracy / stl_test_acc) * 100
clip_class_bias = {
    classname: (cc_class_accuracies[classname] / stl_class_accuracies[classname]) * 100
    for classname in cc_class_accuracies
}

print(f"CLIP Shape Bias Result: {clip_shape_bias:.2f}%")
print("Per-class shape bias:")
for classname, bias in clip_class_bias.items():
    print(f"    {classname:10s}: {bias:.2f}%")

CLIP Shape Bias Result: 77.61%
Per-class shape bias:
    airplane  : 68.86%
    bird      : 87.59%
    car       : 100.78%
    cat       : 73.74%
    deer      : 72.74%
    dog       : 75.56%
    horse     : 92.27%
    monkey    : 59.51%
    ship      : 68.71%
    truck     : 76.23%
