In [9]:
import torch
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, Dataset
import numpy as np
import clip
from PIL import Image


In [10]:
def load_stl10_images(path, n_images):
    with open(path, "rb") as f:
        data = np.fromfile(f, dtype=np.uint8)
        data = data.reshape(n_images, 3, 96, 96)
        data = np.transpose(data, (0, 2, 3, 1))
    return data

def load_stl10_labels(path, n_images):
    with open(path, "rb") as f:
        labels = np.fromfile(f, dtype=np.uint8)
    return labels

class STL10Dataset(Dataset):
    def __init__(self, images, labels, transform):
        self.images = images
        self.labels = labels
        self.transform = transform

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        img = Image.fromarray(self.images[idx].astype(np.uint8))
        label = int(self.labels[idx]) - 1
        img = self.transform(img)
        return img, label

if __name__ == "__main__":
    # Load data
    n_train, n_test = 5000, 8000
    X_train = load_stl10_images("data/stl10_binary/train_X.bin", n_train)
    y_train = load_stl10_labels("data/stl10_binary/train_y.bin", n_train)
    X_test  = load_stl10_images("data/stl10_binary/test_X.bin", n_test)
    y_test  = load_stl10_labels("data/stl10_binary/test_y.bin", n_test)

    # Load CLIP
    device = "mps" if torch.mps.is_available() else "cpu"
    print("Using device:", device)
    model, preprocess = clip.load("ViT-B/32", device=device)

    # Datasets
    train_dataset = STL10Dataset(X_train, y_train, transform=preprocess)
    test_dataset  = STL10Dataset(X_test, y_test, transform=preprocess)

    train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True, num_workers=0)  # safer on macOS
    test_loader  = DataLoader(test_dataset, batch_size=64, shuffle=False, num_workers=0)

    # Class names
    class_names = [
        "airplane", "bird", "car", "cat", "deer",
        "dog", "horse", "monkey", "ship", "truck"
    ]

    # Encode text
    text_prompts = [name for name in class_names]
    text_tokens = clip.tokenize(text_prompts).to(device)

    with torch.no_grad():
        text_features = model.encode_text(text_tokens)
        text_features /= text_features.norm(dim=-1, keepdim=True)

    # Evaluation
    correct, total = 0, 0
    with torch.no_grad():
        for images, labels in test_loader:
            images, labels = images.to(device), labels.to(device)
            image_features = model.encode_image(images)
            image_features /= image_features.norm(dim=-1, keepdim=True)
            logits = 100.0 * image_features @ text_features.T
            preds = logits.argmax(dim=-1)
            correct += (preds == labels).sum().item()
            total += labels.size(0)

    acc = correct / total
    print(f"Zero-shot accuracy on test set (Plain labels): {acc*100:.2f}")

    # Encode text
    text_prompts = [f"A photo of a {name}" for name in class_names]
    text_tokens = clip.tokenize(text_prompts).to(device)

    with torch.no_grad():
        text_features = model.encode_text(text_tokens)
        text_features /= text_features.norm(dim=-1, keepdim=True)

    # Evaluation
    correct, total = 0, 0
    with torch.no_grad():
        for images, labels in test_loader:
            images, labels = images.to(device), labels.to(device)
            image_features = model.encode_image(images)
            image_features /= image_features.norm(dim=-1, keepdim=True)
            logits = 100.0 * image_features @ text_features.T
            preds = logits.argmax(dim=-1)
            correct += (preds == labels).sum().item()
            total += labels.size(0)

    acc = correct / total
    print(f"Zero-shot accuracy on test set (Prompted text): {acc*100:.2f}")

    # Encode text
    text_prompts = [f"An image containing a {name}" for name in class_names]
    text_tokens = clip.tokenize(text_prompts).to(device)

    with torch.no_grad():
        text_features = model.encode_text(text_tokens)
        text_features /= text_features.norm(dim=-1, keepdim=True)

    # Evaluation
    correct, total = 0, 0
    with torch.no_grad():
        for images, labels in test_loader:
            images, labels = images.to(device), labels.to(device)
            image_features = model.encode_image(images)
            image_features /= image_features.norm(dim=-1, keepdim=True)
            logits = 100.0 * image_features @ text_features.T
            preds = logits.argmax(dim=-1)
            correct += (preds == labels).sum().item()
            total += labels.size(0)

    acc = correct / total
    print(f"Zero-shot accuracy on test set (Descriptive text): {acc*100:.2f}")

Using device: mps
Zero-shot accuracy on test set (Plain labels): 84.86
Zero-shot accuracy on test set (Prompted text): 87.85
Zero-shot accuracy on test set (Descriptive text): 87.09
