## CLIP ViT-B/32 Shape / Texture Bias

In [33]:
import torch
from torchvision import datasets
from torch.utils.data import DataLoader
import clip
import os


In [34]:
# Pick the best available accelerator: CUDA first, then Apple-silicon MPS, else CPU.
# torch.backends.mps.is_available() is the long-standing documented MPS check;
# torch.mps.is_available() is missing from older PyTorch releases.
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"
print("Device:", device)

# clip.load returns the model together with the image transform that is
# reused below as the dataset preprocessing step.
model, preprocess = clip.load("ViT-B/32", device=device)

Device: mps


### Custom Cue-Conflicted STL-10 Test Set

In [None]:
# Use CLIP's own preprocessing for every dataset so both loaders feed the
# model identically prepared images.
transform = preprocess

# STL-10 test split; downloaded into ./data when not already present.
stl_testset = datasets.STL10(root='data', split='test', download=True, transform=transform)

test_loader = DataLoader(
    stl_testset,
    batch_size=128,
    shuffle=False,
    num_workers=2
)

# Cue-conflict images laid out as one sub-folder per class.
cc_dataset = datasets.ImageFolder(root='data/cue_conflict_dataset', transform=transform)
# Order samples by the number formed from the digits in each file name
# (raises ValueError if a file name contains no digits).
cc_dataset.samples.sort(key=lambda x: int(''.join(filter(str.isdigit, os.path.basename(x[0])))))
# `targets` is a separate list built at construction time; rebuild it so it
# stays aligned with the re-ordered `samples`.
cc_dataset.targets = [label for _, label in cc_dataset.samples]

cc_loader = DataLoader(
    cc_dataset,
    batch_size=128,
    shuffle=False,
    num_workers=2
)

  0%|          | 295k/2.64G [00:18<46:25:13, 15.8kB/s] 


KeyboardInterrupt: 

### Text Embeddings
Simple Template: 'a {class}'

In [None]:
stl10_classes = [
    "airplane", "bird", "car", "cat", "deer",
    "dog", "horse", "monkey","ship", "truck"
]

# Build one unit-norm text embedding per class from the prompt "a {class}".
# Inference only: run under no_grad so no autograd graph is recorded and the
# in-place normalisations below act on plain tensors.
all_text_features = []
with torch.no_grad():
    for classname in stl10_classes:
        text = f"a {classname}"
        tokenized = clip.tokenize(text).to(device)
        prompt_features = model.encode_text(tokenized)
        prompt_features /= prompt_features.norm(dim=-1, keepdim=True)
        # Average over the prompt axis (a single prompt here) and re-normalise.
        class_feature = prompt_features.mean(dim=0)
        class_feature /= class_feature.norm()
        all_text_features.append(class_feature)

# One row per class, in the same order as stl10_classes.
text_features = torch.stack(all_text_features, dim=0)

In [37]:
def zero_shot_classification(loader, text_features, model, device, classes):
    """Measure zero-shot classification accuracy over a data loader.

    Every image batch is encoded with ``model.encode_image``, L2-normalised
    and compared against ``text_features``; the prediction is the index of
    the most similar text embedding.

    Args:
        loader: Iterable yielding ``(images, labels)`` tensor batches, where
            each label is an integer index into ``classes``.
        text_features: Tensor with one embedding row per class, in the same
            order as ``classes``.
        model: Object providing ``eval()`` and ``encode_image(images)``.
        device: Device that batches and ``text_features`` are moved to.
        classes: Sequence of class names.

    Returns:
        A tuple ``(accuracy, class_accuracies)``: the overall accuracy in
        percent and a dict mapping each class name to its accuracy in
        percent. A value is ``nan`` when there were no samples to compute
        it from (empty loader, or a class absent from the data).
    """
    correct = 0
    total = 0
    class_correct = [0] * len(classes)
    class_total = [0] * len(classes)

    # Loop-invariant: move the text embeddings to the target device once.
    text_features = text_features.to(device)

    model.eval()
    with torch.no_grad():
        for images, labels in loader:
            images, labels = images.to(device), labels.to(device)

            # Encode images and project onto the unit sphere.
            image_features = model.encode_image(images)
            image_features /= image_features.norm(dim=-1, keepdim=True)

            # Similarity against every class embedding; highest score wins.
            similarity = 100.0 * image_features @ text_features.T
            preds = similarity.argmax(dim=1)
            hits = preds == labels

            correct += hits.sum().item()
            total += labels.size(0)

            # One host transfer per batch instead of one .item() per sample.
            for label, hit in zip(labels.tolist(), hits.tolist()):
                class_total[label] += 1
                class_correct[label] += hit

    undefined = float("nan")
    # Guard the divisions: an empty loader or an unseen class has no accuracy.
    accuracy = correct / total * 100 if total else undefined

    class_accuracies = {}
    for i, classname in enumerate(classes):
        if class_total[i]:
            class_accuracies[classname] = 100.0 * class_correct[i] / class_total[i]
        else:
            class_accuracies[classname] = undefined

    return accuracy, class_accuracies

In [None]:
# --- Clean STL-10 test split -------------------------------------------------
test_accuracy, test_class_accuracies = zero_shot_classification(
    loader=test_loader,
    text_features=text_features,
    model=model,
    device=device,
    classes=stl10_classes,
)

print(f"Zero-shot accuracy on STL-10 test set: {test_accuracy:.2f}%")
print("Per-class accuracy:")
for classname, acc in tuple(test_class_accuracies.items()):
    print(f"    {classname:10s}: {acc:.2f}%")

# --- Cue-conflicted images ---------------------------------------------------
cc_accuracy, cc_class_accuracies = zero_shot_classification(
    loader=cc_loader,
    text_features=text_features,
    model=model,
    device=device,
    classes=stl10_classes,
)

print(f"Zero-shot accuracy on cue-conflicted STL-10 test set: {cc_accuracy:.2f}%")
print("Per-class accuracy:")
for classname, acc in tuple(cc_class_accuracies.items()):
    print(f"    {classname:10s}: {acc:.2f}%")


# TODO: shape-bias calculations

Zero-shot accuracy on cue-conflicted STL-10 subset: 77.58%
Per-class accuracy:
    airplane  : 73.88%
    bird      : 88.00%
    car       : 96.50%
    cat       : 65.00%
    deer      : 71.62%
    dog       : 73.38%
    horse     : 91.38%
    monkey    : 59.50%
    ship      : 76.62%
    truck     : 79.88%
