In [1]:
import open_clip
import torch

model, _, preprocess = open_clip.create_model_and_transforms(
    "ViT-B-32",
    pretrained="laion2b_s34b_b79k"
)
# The open_clip README notes that models are created in train mode and should be
# switched with .eval() before inference (matters for BatchNorm / stochastic depth).
model.eval()

tokenizer = open_clip.get_tokenizer("ViT-B-32")

# One prompt per STL10 class. The zero-shot evaluation below compares the argmax
# index over these prompts with the STL10 label id, so this list must follow
# STL10's class order.
labels = [
    "an airplane",
    "a bird",
    "a car",
    "a cat",
    "a deer",
    "a dog",
    "a horse",
    "a monkey",
    "a ship",
    "a truck"
]

text = tokenizer(labels)
# Inference only: without no_grad the embeddings would carry an autograd graph
# that is never used but stays alive as long as `text_embeddings` does.
with torch.no_grad():
    text_embeddings = model.encode_text(text)

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Number of prompt embeddings: one row per class label.
text_embeddings.shape[0]

10

Now, `text_embeddings` contains a 512-dimensional vector for each text prompt, the same size as the vision encoder's output. After L2-normalizing both, the dot product of a text embedding with an image embedding is their cosine similarity, so we can compute the class probabilities for our dataset as follows:

In [3]:
import torch.nn.functional as F

def embeddings_to_class_probs(vision_embeddings, text_embeddings, logit_scale=100.):
    """Convert CLIP image and text embeddings into per-class probabilities.

    Both inputs are L2-normalised along the last dimension, so their matrix
    product is a matrix of cosine similarities. The similarities are multiplied
    by ``logit_scale`` and passed through a softmax over the classes.

    Args:
        vision_embeddings: image embeddings, shape (num_images, dim).
        text_embeddings: one embedding per class prompt, shape (num_classes, dim).
        logit_scale: temperature multiplier applied to the cosine similarities.
            Defaults to 100 (the constant this notebook has always used).
            NOTE(review): ``model.logit_scale.exp()`` could be passed to use the
            model's learned value instead.

    Returns:
        Tensor of shape (num_images, num_classes) whose rows sum to 1.
    """
    # F.normalize clamps the norm at a small epsilon, so an all-zero vector gives
    # uniform probabilities instead of NaNs from a 0/0 division.
    vision_embeddings = F.normalize(vision_embeddings, dim=-1)
    text_embeddings = F.normalize(text_embeddings, dim=-1)
    logits = vision_embeddings @ text_embeddings.T
    class_probs = F.softmax(logit_scale * logits, dim=-1)
    return class_probs

Now that we have the text embeddings for our target task and a method to compare them against the image embeddings, all that's left is to run the STL10 test set through the OpenCLIP vision encoder, compute the output class probabilities, and compare the predictions against the ground-truth labels.

In [4]:
import tqdm
import torch
from torchvision.datasets import STL10

dataset = STL10(
    root="./stl10",
    download=True,
    split="test"
)

# Zero-shot evaluation trains nothing: use inference mode and (below) disable
# autograd, otherwise each of the forward passes builds a graph that is discarded.
model.eval()
# Run on whatever device the model currently lives on. This keeps the cell
# re-runnable after a later cell has moved `model` to a GPU.
model_device = next(model.parameters()).device
class_text_embeddings = text_embeddings.detach().to(model_device)

num_correct = 0

with torch.no_grad():
    for image, label in tqdm.tqdm(dataset):
        input_tensor = preprocess(image).unsqueeze(0).to(model_device)
        vision_embeddings = model.encode_image(input_tensor)
        output_class_probs = embeddings_to_class_probs(vision_embeddings, class_text_embeddings)
        output_label = torch.argmax(output_class_probs, dim=-1)
        num_correct += int(torch.count_nonzero(output_label == label))

accuracy = 100. * num_correct / len(dataset)

100%|██████████| 8000/8000 [04:51<00:00, 27.45it/s]


In [5]:
# Zero-shot top-1 accuracy on the STL10 test split, in percent.
accuracy

96.675

So, out of the box and without any additional training, the OpenCLIP encoder gets 96.68% accuracy on the STL10 test set! With no tricks, this is already fairly competitive; for comparison, see other published results on the STL10 benchmark.

## Using a linear head for classification

As shown, using the text prompts as class labels, we were able to achieve pretty good accuracy on the STL10 dataset without any training or ground truth labels. But what if we have ground truth labels available? Can we use this to improve the accuracy?

With this option, we'll explore how we can use some ground truth data to train a tiny logistic regression layer (linear layer followed by softmax) at the end of the OpenCLIP model and see if this improves the accuracy.

In [6]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.datasets import STL10
from torch.utils.data import DataLoader


# Seed the probe initialisation and the mini-batch shuffling so runs are repeatable.
torch.manual_seed(0)

device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)
model.eval()

train_dataset = STL10(
    root="./stl10",
    download=True,
    split="train",
    transform=preprocess
)
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=False)

# The OpenCLIP encoder is frozen. `preprocess` is the third return value of
# create_model_and_transforms (open_clip's validation transform, presumably
# deterministic resize/crop/normalise). So every epoch would see identical
# embeddings: encode the training set once instead of once per epoch.
train_embedding_batches, train_label_batches = [], []
with torch.no_grad():
    for image, label in train_loader:
        vision_embeddings = model.encode_image(image.to(device))
        # Cast to float32 to match the probe's weights (e.g. fp16/bf16 encoder output).
        train_embedding_batches.append(vision_embeddings.float())
        train_label_batches.append(label.to(device))
train_embeddings = torch.cat(train_embedding_batches)
train_labels = torch.cat(train_label_batches)

# Size the probe from the actual encoder output instead of hard-coding 512.
linear_probe = nn.Linear(train_embeddings.shape[-1], len(labels))
linear_probe.to(device)

optimizer = torch.optim.Adam(linear_probe.parameters(), lr=3e-4)

batch_size = 64
num_epochs = 10
for epoch in range(num_epochs):
    linear_probe.train()
    total_loss = 0.0
    # New random order every epoch, chunked into mini-batches of indices.
    batches = torch.randperm(len(train_labels), device=device).split(batch_size)
    for batch_indices in batches:
        optimizer.zero_grad()
        output_logits = linear_probe(train_embeddings[batch_indices])
        loss = F.cross_entropy(output_logits, train_labels[batch_indices])
        loss.backward()
        optimizer.step()
        total_loss += loss.item()

    avg_loss = total_loss / len(batches)
    print(f"Epoch {epoch+1}/{num_epochs}, Loss: {avg_loss:.4f}")


Epoch 1/10, Loss: 1.7450
Epoch 2/10, Loss: 0.8890
Epoch 3/10, Loss: 0.4993
Epoch 4/10, Loss: 0.3268
Epoch 5/10, Loss: 0.2386
Epoch 6/10, Loss: 0.1871
Epoch 7/10, Loss: 0.1549
Epoch 8/10, Loss: 0.1337
Epoch 9/10, Loss: 0.1159
Epoch 10/10, Loss: 0.1016


In [7]:
# Measure the trained linear probe's top-1 accuracy on the STL10 test split.
test_dataset = STL10(root="./stl10", download=True, split="test", transform=preprocess)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)

# Both the frozen encoder and the probe run in inference mode.
linear_probe.eval()
model.eval()

num_correct_test = 0
total_test_samples = 0

with torch.no_grad():
    for images, targets in tqdm.tqdm(test_loader, desc="Evaluating on test set"):
        images, targets = images.to(device), targets.to(device)

        # OpenCLIP image features; upcast half precision to match the probe's weights.
        features = model.encode_image(images)
        if features.dtype == torch.float16:
            features = features.float()

        predictions = linear_probe(features).argmax(dim=1)

        num_correct_test += (predictions == targets).sum().item()
        total_test_samples += targets.size(0)

test_accuracy = 100. * num_correct_test / total_test_samples
print(f"Test Accuracy: {test_accuracy:.2f}%")

Evaluating on test set: 100%|██████████| 125/125 [01:31<00:00,  1.37it/s]

Test Accuracy: 98.42%





After training the linear probe, we evaluate it on the STL10 test set as before, and our accuracy is now 98.42%!

Great! By using some labeled data, we were able to train a small logistic regression layer that improves the accuracy of OpenCLIP on the STL10 dataset by about 1.7 percentage points (from 96.68% to 98.42%)!

This improvement is likely because our text prompts, like "an airplane", might not perfectly match the labels as they appear in the STL10 dataset. But by seeing a few examples for each label, we can learn reference embeddings that more accurately represent the class labels.

## Training a student model to mimic OpenCLIP

We've now seen that with the large OpenCLIP model we can achieve competitive results on the STL10 image classification dataset with little effort. But OpenCLIP is large and is likely to have higher memory consumption and latency than other model architectures. In addition, as a vision transformer, OpenCLIP is less able to exploit the Deep Learning Accelerator (DLA) on Jetson AGX Orin because of the matrix multiplications in its attention layers. CNN models like ResNet-18, on the other hand, are highly optimized for both the GPU and the DLA on Jetson, which lets us run them at higher throughput and with less memory. This motivates knowledge distillation: training a smaller CNN "student" model to mimic the outputs of the OpenCLIP "teacher" model.

However, knowledge distillation can affect the accuracy of the model, so we'd like to better understand which factors matter most. To do this, we ran experiments that seek to answer the following questions:

- How does distillation compare to training with ground-truth labels?
- How does the data distribution used for training impact model accuracy?
- How does the distillation method impact model accuracy? Is it better to train on the class probabilities or on internal features?
- How does the student model architecture impact model accuracy? Will ResNet-50 obtain higher accuracy than ResNet-18?



In [1]:
from stl10_utils import (
    precompute_clip_stl10_train_image_embeddings,
    precompute_clip_stl10_test_image_embeddings,
    precompute_clip_stl10_text_embeddings,
    train_resnet18_from_scratch,
    train_resnet18_linear_probe_train_only
)

# Precompute OpenCLIP embeddings for the STL10 train images, test images and class
# prompts (in that order), then train a ResNet-18 from scratch as a baseline
# (presumably on ground-truth labels — see stl10_utils).
for precompute_step in (
    precompute_clip_stl10_train_image_embeddings,
    precompute_clip_stl10_test_image_embeddings,
    precompute_clip_stl10_text_embeddings,
):
    precompute_step()

train_resnet18_from_scratch()


  from .autonotebook import tqdm as notebook_tqdm
5000it [00:00, 8787.94it/s]
8000it [00:00, 9257.65it/s]


Using device: mps for train_model_from_scratch




Epoch 1 - Test Loss: 1.7259, Test Accuracy: 34.29%
Epoch 2 - Test Loss: 1.6783, Test Accuracy: 39.52%
Epoch 3 - Test Loss: 1.4484, Test Accuracy: 46.79%
Epoch 4 - Test Loss: 1.3064, Test Accuracy: 51.29%
Epoch 5 - Test Loss: 1.2840, Test Accuracy: 52.86%
Epoch 6 - Test Loss: 1.2543, Test Accuracy: 53.89%
Epoch 7 - Test Loss: 1.3261, Test Accuracy: 55.14%
Epoch 8 - Test Loss: 1.2820, Test Accuracy: 57.17%
Epoch 9 - Test Loss: 1.3840, Test Accuracy: 56.45%
Epoch 10 - Test Loss: 1.3339, Test Accuracy: 57.70%
Finished Training


In [3]:
# Per its log output, this trains a linear probe for 10 epochs (checkpoints under
# data/experiments/train_probe_model_linear) and then starts training a ResNet-18
# student model using that probe.
# NOTE(review): in the saved run this cell fails with
# `TypeError: object of type 'NoneType' has no len()` right after
# "Starting student model training (arch: resnet18) using linear probe." The
# failing code is inside stl10_utils and is not visible here; fix it there
# before relying on this cell's results.
train_resnet18_linear_probe_train_only()

Using device: mps for probe model in train_student_linear_probe
Probe checkpoint data/experiments/train_probe_model_linear/checkpoint_9.pth not found. Attempting to train probe model...
Using device: mps for train_probe_model_linear
Starting training for linear probe model. Output will be in data/experiments/train_probe_model_linear
Using device: mps for train_probe_model


100%|██████████| 79/79 [00:12<00:00,  6.20it/s]
100%|██████████| 125/125 [00:12<00:00,  9.82it/s]

| EPOCH 0 | TRAIN LOSS 3.7658857170836875 | TEST ACC 97.138 |
Saving checkpoint for epoch 0 to data/experiments/train_probe_model_linear/checkpoint_0.pth



100%|██████████| 79/79 [00:11<00:00,  6.69it/s]
100%|██████████| 125/125 [00:12<00:00,  9.80it/s]

| EPOCH 1 | TRAIN LOSS 0.2560871586723612 | TEST ACC 97.388 |
Saving checkpoint for epoch 1 to data/experiments/train_probe_model_linear/checkpoint_1.pth



100%|██████████| 79/79 [00:11<00:00,  6.71it/s]
100%|██████████| 125/125 [00:12<00:00,  9.84it/s]

| EPOCH 2 | TRAIN LOSS 0.1538377963898914 | TEST ACC 97.412 |
Saving checkpoint for epoch 2 to data/experiments/train_probe_model_linear/checkpoint_2.pth



100%|██████████| 79/79 [00:11<00:00,  6.65it/s]
100%|██████████| 125/125 [00:12<00:00,  9.78it/s]

| EPOCH 3 | TRAIN LOSS 0.07637745670479079 | TEST ACC 97.537 |
Saving checkpoint for epoch 3 to data/experiments/train_probe_model_linear/checkpoint_3.pth



100%|██████████| 79/79 [00:11<00:00,  6.72it/s]
100%|██████████| 125/125 [00:12<00:00,  9.85it/s]

| EPOCH 4 | TRAIN LOSS 0.048721451878915974 | TEST ACC 97.688 |
Saving checkpoint for epoch 4 to data/experiments/train_probe_model_linear/checkpoint_4.pth



100%|██████████| 79/79 [00:11<00:00,  6.69it/s]
100%|██████████| 125/125 [00:12<00:00,  9.81it/s]

| EPOCH 5 | TRAIN LOSS 0.02843519240430204 | TEST ACC 97.4 |
Saving checkpoint for epoch 5 to data/experiments/train_probe_model_linear/checkpoint_5.pth



100%|██████████| 79/79 [00:11<00:00,  6.70it/s]
100%|██████████| 125/125 [00:12<00:00,  9.81it/s]

| EPOCH 6 | TRAIN LOSS 0.02611771381016021 | TEST ACC 97.5 |
Saving checkpoint for epoch 6 to data/experiments/train_probe_model_linear/checkpoint_6.pth



100%|██████████| 79/79 [00:11<00:00,  6.72it/s]
100%|██████████| 125/125 [00:12<00:00,  9.81it/s]

| EPOCH 7 | TRAIN LOSS 0.026856589847502027 | TEST ACC 97.6 |
Saving checkpoint for epoch 7 to data/experiments/train_probe_model_linear/checkpoint_7.pth



100%|██████████| 79/79 [00:11<00:00,  6.73it/s]
100%|██████████| 125/125 [00:12<00:00,  9.81it/s]

| EPOCH 8 | TRAIN LOSS 0.014373054670230478 | TEST ACC 97.638 |
Saving checkpoint for epoch 8 to data/experiments/train_probe_model_linear/checkpoint_8.pth



100%|██████████| 79/79 [00:11<00:00,  6.64it/s]
100%|██████████| 125/125 [00:12<00:00,  9.82it/s]


| EPOCH 9 | TRAIN LOSS 0.008312782411593696 | TEST ACC 97.812 |
Saving checkpoint for epoch 9 to data/experiments/train_probe_model_linear/checkpoint_9.pth
Finished training for linear probe model.
Loading probe model weights from data/experiments/train_probe_model_linear/checkpoint_9.pth
Starting student model training (arch: resnet18) using linear probe.
Using device: mps for train_student_classification_model


TypeError: object of type 'NoneType' has no len()

In [None]:
# Use the same helper module that the earlier experiment cell imports successfully.
# NOTE(review): confirm stl10_utils also defines
# precompute_clip_stl10_unlabeled_image_embeddings and train_resnet18_linear_probe.
from stl10_utils import (
    precompute_clip_stl10_train_image_embeddings,
    precompute_clip_stl10_unlabeled_image_embeddings,
    precompute_clip_stl10_test_image_embeddings,
    precompute_clip_stl10_text_embeddings,
    train_resnet18_linear_probe
)

# Precompute OpenCLIP embeddings for the train, unlabeled and test image splits and
# for the class prompts, then run the linear-probe student experiment.
precompute_clip_stl10_train_image_embeddings()
precompute_clip_stl10_unlabeled_image_embeddings()
precompute_clip_stl10_test_image_embeddings()
precompute_clip_stl10_text_embeddings()
train_resnet18_linear_probe()