Comparison of Zero-shot and Linear Probe Performance of Open CLIP-B-32 on CIFAR-10

#### Install packages

In [None]:
!pip install open_clip_torch
!pip install ftfy regex tqdm

In [2]:
import torch
import os
import numpy as np
from PIL import Image
import open_clip
from torchvision.datasets import CIFAR10
from torch.utils.data import DataLoader, random_split
from tqdm import tqdm
from sklearn.linear_model import LogisticRegression



#### Load CLIP ViT-B-32

In [3]:
device = "cuda" if torch.cuda.is_available() else "cpu"
device

'cuda'

A CLIP ViT-B/32 model trained on the LAION-2B English subset of LAION-5B.

In [4]:
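## Load the CLIP model with pretrained LAION-2B weights
model, _, preprocess = open_clip.create_model_and_transforms('ViT-B-32', pretrained='laion2b_s34b_b79k', device=device)
model.eval()  # open_clip returns the model in train mode; switch to eval before extracting features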


#### Load CIFAR10 and create DataLoader
root = os.path.expanduser("~/cache")
train_data = CIFAR10(root, download=True, train=True, transform=preprocess)
test_data = CIFAR10(root, download=True, train=False, transform=preprocess)

# Split the training dataset into 80% train and 20% validation (for Linear Probe Hyperparam Sweep)
train_size = int(0.8 * len(train_data))  # 80% for training
val_size = len(train_data) - train_size  # 20% for validation
train_data, val_data = random_split(train_data, [train_size, val_size])

batch_size = 2048


Files already downloaded and verified
Files already downloaded and verified
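
The `random_split` call above is not seeded, so the 80/20 partition (and with it the validation accuracies in the sweep below) can differ from run to run. A minimal sketch of a reproducible split over indices, using NumPy only; the helper name `split_indices` and the seed value are assumptions made for illustration and are not part of the notebook:

```python
import numpy as np

def split_indices(n_samples, train_fraction=0.8, seed=0):
    """Return disjoint (train_idx, val_idx) index arrays from a seeded shuffle."""
    rng = np.random.default_rng(seed)   # fixed seed -> same permutation on every run
    perm = rng.permutation(n_samples)
    train_size = int(train_fraction * n_samples)
    return perm[:train_size], perm[train_size:]

# The CIFAR-10 training set has 50,000 images
train_idx, val_idx = split_indices(50_000)
print(len(train_idx), len(val_idx))  # 40000 10000
```

With PyTorch, a similar effect can be obtained by passing `generator=torch.Generator().manual_seed(0)` to `random_split`, or by wrapping index arrays like these in `torch.utils.data.Subset`.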


### Linear Probe Performance

#### Get features from Image Encoder

In [5]:
# Extract image features and labels for the linear probe; features have shape (n_samples, dim)
def get_features(dataset):
    all_features = []
    all_labels = []
    
    with torch.no_grad():
        for images, labels in tqdm(DataLoader(dataset, batch_size=batch_size)):
            features = model.encode_image(images.to(device))

            all_features.append(features)
            all_labels.append(labels)

    return torch.cat(all_features).cpu().numpy(), torch.cat(all_labels).cpu().numpy()

In [6]:
# Linear Probe Data
# Calculate features for train, validation, and test sets
train_features, train_labels = get_features(train_data)
val_features, val_labels = get_features(val_data)
test_features, test_labels = get_features(test_data)


100%|██████████| 20/20 [00:29<00:00,  1.46s/it]
100%|██████████| 5/5 [00:06<00:00,  1.27s/it]
100%|██████████| 5/5 [00:06<00:00,  1.30s/it]


#### Run Sweep to get the best L2 regularization

In [7]:
# Define the hyperparameter sweep over C, the inverse of the L2 regularization strength (smaller C = stronger regularization)
C_values = [10**-6, 10**-4, 10**-2, 1, 10**2, 10**4, 10**6]  # Logarithmic sweep

best_C = None
best_accuracy = 0

# Perform hyperparameter sweep over C
for C in C_values:
    print(f"Training with C = {C}")
    
    # Initialize the Logistic Regression classifier
    classifier = LogisticRegression(random_state=0, C=C, max_iter=1000, verbose=0)
    
    # Train the classifier on the training features
    classifier.fit(train_features, train_labels)
    
    # Evaluate the classifier on the validation set
    val_predictions = classifier.predict(val_features)
    val_accuracy = np.mean((val_labels == val_predictions).astype(float)) * 100.
    print(f"Validation Accuracy = {val_accuracy:.3f}%")
    
    # Track the best C value based on validation accuracy
    if val_accuracy > best_accuracy:
        best_accuracy = val_accuracy
        best_C = C

print(f"Best C value: {best_C} with validation accuracy of {best_accuracy:.3f}%")

Training with C = 1e-06
Validation Accuracy = 94.070%
Training with C = 0.0001
Validation Accuracy = 95.510%
Training with C = 0.01
Validation Accuracy = 96.790%
Training with C = 1


STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.

Increase the number of iterations (max_iter) or scale the data as shown in:
    https://scikit-learn.org/stable/modules/preprocessing.html
Please also refer to the documentation for alternative solver options:
    https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression
  n_iter_i = _check_optimize_result(


Validation Accuracy = 96.630%
Training with C = 100




Validation Accuracy = 95.790%
Training with C = 10000




Validation Accuracy = 95.730%
Training with C = 1000000
Validation Accuracy = 95.710%
Best C value: 0.01 with validation accuracy of 96.790%
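
For reference, `C` in scikit-learn's `LogisticRegression` is the inverse of the regularization strength. With the default L2 penalty, the multinomial objective can be written as follows, up to an overall scaling that does not change the minimizer and with intercept terms omitted. The symbols $W$, $w_k$, $x_i$, $y_i$, $n$ and $K$ are notation introduced here, not taken from the notebook:

$$
\min_{W}\; \frac{1}{2}\lVert W \rVert_F^2 \;-\; C \sum_{i=1}^{n} \log \frac{\exp\left(w_{y_i}^\top x_i\right)}{\sum_{k=1}^{K} \exp\left(w_k^\top x_i\right)}
$$

Here $x_i$ is the image feature of sample $i$, $y_i$ its label, and $w_k$ the weight row for class $k$. A small `C` shrinks the data term relative to the penalty, so it means strong regularization. This is consistent with the sweep above, where validation accuracy peaks at C = 0.01 and drops slightly as C grows.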




For some C values, scikit-learn emits "ConvergenceWarning: lbfgs failed to converge (status=1): STOP: TOTAL NO. of ITERATIONS REACHED LIMIT", meaning the lbfgs solver reached max_iter=1000 before converging.
This project follows the original CLIP 

Test accuracy with the best L2 Regularization value

In [8]:
# Train the classifier on the training split (80% of the CIFAR-10 training set) with the best C
final_classifier = LogisticRegression(random_state=0, C=best_C, max_iter=1000, verbose=1)
final_classifier.fit(train_features, train_labels)

# Evaluate on the test set
test_predictions = final_classifier.predict(test_features)
test_accuracy = np.mean((test_labels == test_predictions).astype(float)) * 100.
print(f"Test Accuracy = {test_accuracy:.3f}%")

Test Accuracy = 96.790%


### Zero-shot Performance

Text Processing

In [9]:
## Zero-shot learning
model.eval()  # Set the model to evaluation mode
tokenizer = open_clip.get_tokenizer('ViT-B-32')

# Prepare text inputs (text descriptions for the CIFAR-10 classes)
text_inputs = torch.cat([tokenizer(f"a photo of a {c}") for c in test_data.classes]).to(device)
with torch.no_grad():  # inference only, so skip building the autograd graph
    text_features = model.encode_text(text_inputs)
    text_features /= text_features.norm(dim=-1, keepdim=True)  # L2-normalize so dot products are cosine similarities


Zero-shot evaluation

In [10]:
correct_top1 = 0
correct_top5 = 0
total = len(test_data)

# Loop through all the test samples in batches
with torch.no_grad():
    for images, labels in tqdm(DataLoader(test_data, batch_size=batch_size)):
        image_features = model.encode_image(images.to(device))
        image_features /= image_features.norm(dim=-1, keepdim=True)  # Normalize

        similarity = (100.0 * image_features @ text_features.T).softmax(dim=-1)

        # Get the top 5 predictions for each image in the batch
        values, indices = similarity.topk(5, dim=-1)

        # Calculate top-1 and top-5 accuracy for each image in the batch
        for i in range(len(labels)):
            # Check if the true class is in the top 1 prediction
            if indices[i, 0].item() == labels[i]:
                correct_top1 += 1

            # Check if the true class is in the top 5 predictions
            if labels[i] in indices[i].tolist():
                correct_top5 += 1


100%|██████████| 5/5 [00:09<00:00,  1.92s/it]
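
The scoring step in the loop above can be illustrated in isolation. The sketch below reproduces it with NumPy on made-up 2-D embeddings: normalize both sides, scale the cosine similarities by 100 (the same fixed factor used in the notebook), and apply a softmax over classes. The function name `zero_shot_probs` and the toy vectors are assumptions for illustration only.

```python
import numpy as np

def zero_shot_probs(image_features, text_features, scale=100.0):
    """Softmax over scaled cosine similarities between images and class prompts."""
    img = image_features / np.linalg.norm(image_features, axis=-1, keepdims=True)
    txt = text_features / np.linalg.norm(text_features, axis=-1, keepdims=True)
    logits = scale * img @ txt.T                    # shape (n_images, n_classes)
    logits -= logits.max(axis=-1, keepdims=True)    # subtract the row max for numerical stability
    exp = np.exp(logits)
    return exp / exp.sum(axis=-1, keepdims=True)

# Toy embeddings (made up): two "class prompts" and two "images"
text = np.array([[1.0, 0.0], [0.0, 1.0]])
images = np.array([[3.0, 1.0], [0.5, 2.0]])
probs = zero_shot_probs(images, text)
print(probs.argmax(axis=-1))  # [0 1]
```

Because the softmax is monotonic within each row, it does not change the ranking of classes; the top-k indices would be the same if taken directly on the scaled similarities.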


Calculate Top-1 and Top-5 Accuracy

In [11]:
# Calculate top-1 and top-5 accuracy
top1_accuracy = correct_top1 / total * 100
top5_accuracy = correct_top5 / total * 100

print(f"Top-1 accuracy: {top1_accuracy:.2f}%")
print(f"Top-5 accuracy: {top5_accuracy:.2f}%")


Top-1 accuracy: 93.65%
Top-5 accuracy: 99.83%
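
The per-sample Python loop used to count hits can also be expressed as a vectorized comparison. A minimal sketch with NumPy on made-up scores; `topk_accuracy` is a hypothetical helper, not a function from the notebook or from any library:

```python
import numpy as np

def topk_accuracy(scores, labels, k):
    """Percentage of rows whose true label is among the k highest-scoring classes."""
    topk = np.argsort(-scores, axis=-1)[:, :k]       # indices of the k largest scores per row
    hits = (topk == labels[:, None]).any(axis=-1)    # True where the label appears in the top k
    return hits.mean() * 100

# Made-up scores for 4 samples over 3 classes
scores = np.array([
    [0.7, 0.2, 0.1],
    [0.1, 0.3, 0.6],
    [0.4, 0.5, 0.1],
    [0.2, 0.3, 0.5],
])
labels = np.array([0, 2, 0, 0])
print(topk_accuracy(scores, labels, k=1))  # 50.0
print(topk_accuracy(scores, labels, k=2))  # 75.0
```

Note that CIFAR-10 has only 10 classes, so random guessing already yields 50% top-5 accuracy; the top-1 figure is the more informative of the two.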


### Influence of Prompt Wording on Zero-shot Performance

In [12]:
# Define multiple variations of text descriptions for CIFAR-10 classes
text_variations = [
    # Simple description
    ("Simple description", lambda c: f"a photo of a {c}"),      # 93.65%
    ("Image description", lambda c: f"an image of a {c}"),      # 93.62%

    # Detailed description
    # ("Detailed description", lambda c: f"a detailed photo of a {c} with fine details"),       92.94%
    ("Detailed description", lambda c: f"a photo with the main subject of a {c}"),              # 94.14%


    # Adding adjectives
    # ("Large description", lambda c: f"a large photo of a {c}"),       91.98%
    # ("Small description", lambda c: f"a small photo of a {c}"),       93.06%

    # Sentence structure variations
    ("Sentence structure 1", lambda c: f"this is a photo of a {c}"),        # 93.39%
    ("Sentence structure 2", lambda c: f"a beautiful photo of a {c}"),      # 93.66%
]

accuracies = {}

# Iterate over the different text input variations
for variation_name, description in text_variations:
    # Prepare text inputs
    text_inputs = torch.cat([tokenizer(description(c)) for c in test_data.classes]).to(device)
    with torch.no_grad():  # inference only, so skip building the autograd graph
        text_features = model.encode_text(text_inputs)
        text_features /= text_features.norm(dim=-1, keepdim=True)  # Normalize

    correct_top1 = 0
    correct_top5 = 0
    total = len(test_data)

    # Loop through all the test samples in batches
    with torch.no_grad():
        for images, labels in tqdm(DataLoader(test_data, batch_size=batch_size)):
            image_features = model.encode_image(images.to(device))
            image_features /= image_features.norm(dim=-1, keepdim=True)  # Normalize

            similarity = (100.0 * image_features @ text_features.T).softmax(dim=-1)

            # Get the top 5 predictions for each image in the batch
            values, indices = similarity.topk(5, dim=-1)

            # Calculate top-1 and top-5 accuracy for each image in the batch
            for i in range(len(labels)):
                # Check if the true class is in the top 1 prediction
                if indices[i, 0].item() == labels[i]:
                    correct_top1 += 1

                # Check if the true class is in the top 5 predictions
                if labels[i] in indices[i].tolist():
                    correct_top5 += 1

    # Calculate top-1 and top-5 accuracy
    top1_accuracy = correct_top1 / total * 100
    top5_accuracy = correct_top5 / total * 100

    # Store the results for this variation
    accuracies[variation_name] = (top1_accuracy, top5_accuracy)

# Output the accuracies for all variations
for variation_name, (top1, top5) in accuracies.items():
    print(f"Text variation: {variation_name}")
    print(f"  Top-1 accuracy: {top1:.2f}%")
    print(f"  Top-5 accuracy: {top5:.2f}%")


100%|██████████| 5/5 [00:09<00:00,  1.90s/it]
100%|██████████| 5/5 [00:09<00:00,  2.00s/it]
100%|██████████| 5/5 [00:10<00:00,  2.04s/it]
100%|██████████| 5/5 [00:09<00:00,  1.90s/it]
100%|██████████| 5/5 [00:09<00:00,  1.99s/it]

Text variation: Simple description
  Top-1 accuracy: 93.65%
  Top-5 accuracy: 99.83%
Text variation: Image description
  Top-1 accuracy: 93.62%
  Top-5 accuracy: 99.83%
Text variation: Detailed description
  Top-1 accuracy: 94.14%
  Top-5 accuracy: 99.77%
Text variation: Sentence structure 1
  Top-1 accuracy: 93.39%
  Top-5 accuracy: 99.79%
Text variation: Sentence structure 2
  Top-1 accuracy: 93.66%
  Top-5 accuracy: 99.72%
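
To put the numbers side by side, the short sketch below summarizes the top-1 results reported in this notebook: the spread across prompt wordings and the gap between the best zero-shot prompt and the linear probe. The values are copied from the outputs above; the variable names are chosen here for illustration.

```python
# Top-1 accuracies (%) copied from the outputs reported above
zero_shot_top1 = {
    "a photo of a {c}": 93.65,
    "an image of a {c}": 93.62,
    "a photo with the main subject of a {c}": 94.14,
    "this is a photo of a {c}": 93.39,
    "a beautiful photo of a {c}": 93.66,
}
linear_probe_top1 = 96.79  # test accuracy of the linear probe with the best C

best_prompt = max(zero_shot_top1, key=zero_shot_top1.get)
spread = max(zero_shot_top1.values()) - min(zero_shot_top1.values())
gap = linear_probe_top1 - zero_shot_top1[best_prompt]

print(f"Best prompt: {best_prompt!r} ({zero_shot_top1[best_prompt]:.2f}%)")
print(f"Spread across prompts: {spread:.2f} points")
print(f"Linear probe minus best zero-shot: {gap:.2f} points")
```

The spread across these five prompts is under one percentage point, while the linear probe leads the best zero-shot prompt by a few points. Since the prompts were compared on the test set itself, picking the best one this way is likely a slightly optimistic estimate.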



