In [7]:
import open_clip
import torch

# Load the pretrained OpenCLIP ViT-B-32 model and the image transform that
# goes with it (the middle return value is not used in this notebook).
model, _, preprocess = open_clip.create_model_and_transforms(
    "ViT-B-32",
    pretrained="laion2b_s34b_b79k"
)
# The model is only used for inference below, so put it in eval mode.
model.eval()

tokenizer = open_clip.get_tokenizer("ViT-B-32")

# One text prompt per class. The position of a prompt in this list is later
# compared against the dataset's integer label, so the order matters.
labels = [
    "an airplane",
    "a bird",
    "a car",
    "a cat",
    "a deer",
    "a dog",
    "a horse",
    "a monkey",
    "a ship",
    "a truck"
]

text = tokenizer(labels)
# The text embeddings are fixed reference vectors; no gradients are needed,
# so avoid building (and retaining) an autograd graph for them.
with torch.no_grad():
    text_embeddings = model.encode_text(text)

In [8]:
# Number of rows in `text_embeddings` (one per prompt in `labels`).
len(text_embeddings)

10

Now, `text_embeddings` contains a 512-length vector for each text prompt. This vector has the same size as the vision encoder output. The dot product of a (normalized) text vector with the (normalized) vision features indicates their similarity, so we can determine the class probabilities for our dataset as follows:

In [9]:
import torch.nn.functional as F

def embeddings_to_class_probs(vision_embeddings, text_embeddings, logit_scale=100.0):
    """Convert image and text embeddings into per-class probabilities.

    Both inputs are L2-normalized along their last dimension, so the matrix
    product below yields cosine similarities. The similarities are multiplied
    by ``logit_scale`` and passed through a softmax over the text axis.

    Args:
        vision_embeddings: Tensor whose last dimension is the embedding size,
            one row per image.
        text_embeddings: 2-D tensor with one row per class prompt and the same
            embedding size as ``vision_embeddings``.
        logit_scale: Multiplier applied to the cosine similarities before the
            softmax. Defaults to 100.0, the value this notebook has always
            used; larger values give sharper distributions.

    Returns:
        Tensor of class probabilities, one row per image and one column per
        text prompt; each row sums to 1.
    """
    vision_embeddings = vision_embeddings / vision_embeddings.norm(dim=-1, keepdim=True)
    text_embeddings = text_embeddings / text_embeddings.norm(dim=-1, keepdim=True)
    # Cosine similarity between every image row and every prompt row.
    logits = vision_embeddings @ text_embeddings.T
    class_probs = F.softmax(logit_scale * logits, dim=-1)
    return class_probs

Now that we have the text embeddings for our target task, and a method to compare these against the image embeddings, all that's left to do is run the STL10 dataset through the OpenCLIP vision encoder, compute the output class probabilities, and compare the result against the ground truth label.

In [14]:
import tqdm
import torch
from torchvision.datasets import STL10

# STL10 test split, without a transform: `preprocess` is applied per image
# inside the loop below.
dataset = STL10(
    root="./stl10",
    download=True,
    split="test"
)

# This cell only runs inference, so use eval mode and skip autograd tracking
# (otherwise every forward pass builds a graph that is immediately discarded).
model.eval()

num_correct = 0

with torch.no_grad():
    for image, label in tqdm.tqdm(dataset):
        # Add a leading batch dimension of size 1 for the encoder.
        input_tensor = preprocess(image).unsqueeze(0)
        vision_embeddings = model.encode_image(input_tensor)
        output_class_probs = embeddings_to_class_probs(vision_embeddings, text_embeddings)
        output_label = torch.argmax(output_class_probs, dim=-1)
        # Counts 1 when the predicted prompt index equals the ground-truth label.
        num_correct += int(torch.count_nonzero(output_label == label))

# Accuracy as a percentage of the whole test split.
accuracy = 100. * num_correct / len(dataset)

100%|██████████| 8000/8000 [04:44<00:00, 28.11it/s]


In [20]:
# Zero-shot accuracy (in percent) computed over the STL10 test split above.
accuracy

96.675

And after this, out of the box, the OpenCLIP encoder, without any additional training, gets 96.68% accuracy on the STL10 test dataset! With no tricks, we achieved fairly competitive accuracy on the STL10 dataset. For comparison, you can see other competitive results on the STL10 dataset here.

## Using a linear head for classification

As shown, using the text prompts as class labels, we were able to achieve pretty good accuracy on the STL10 dataset without any training or ground truth labels. But what if we have ground truth labels available? Can we use this to improve the accuracy?

With this option, we'll explore how we can use some ground truth data to train a tiny logistic regression layer (linear layer followed by softmax) at the end of the OpenCLIP model and see if this improves the accuracy.

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.datasets import STL10
from torch.utils.data import DataLoader, TensorDataset


device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)
model.eval()  # OpenCLIP is a frozen feature extractor here

# Fix the seed so the probe initialization and batch shuffling are repeatable.
torch.manual_seed(0)

train_dataset = STL10(
    root="./stl10",
    download=True,
    split="train",
    transform=preprocess
)

# The encoder is frozen and `preprocess` is applied identically every time, so
# an image's embedding does not change between epochs. Encode the training
# split once instead of re-running OpenCLIP on every image in every epoch.
embedding_chunks = []
label_chunks = []
with torch.no_grad():
    for image, label in DataLoader(train_dataset, batch_size=64, shuffle=False):
        vision_embeddings = model.encode_image(image.to(device))
        # Store as float32 on the CPU; the probe is trained in full precision.
        embedding_chunks.append(vision_embeddings.float().cpu())
        label_chunks.append(label)
train_embeddings = torch.cat(embedding_chunks)
train_labels = torch.cat(label_chunks)

# Size the probe from the actual embedding width rather than a magic number.
linear_probe = nn.Linear(train_embeddings.shape[1], len(labels))
linear_probe.to(device)

optimizer = torch.optim.Adam(linear_probe.parameters(), lr=3e-4)

# Batches of (embedding, label) pairs, reshuffled every epoch.
train_loader = DataLoader(
    TensorDataset(train_embeddings, train_labels),
    batch_size=64,
    shuffle=True
)

num_epochs = 10
for epoch in range(num_epochs):
    linear_probe.train()
    total_loss = 0
    for vision_embeddings, label in train_loader:
        vision_embeddings = vision_embeddings.to(device)
        label = label.to(device)

        optimizer.zero_grad()
        output_logits = linear_probe(vision_embeddings)

        # cross_entropy applies log-softmax internally, so raw logits go in.
        loss = F.cross_entropy(output_logits, label)

        loss.backward()
        optimizer.step()
        total_loss += loss.item()

    # Mean of the per-batch losses for this epoch.
    avg_loss = total_loss / len(train_loader)
    print(f"Epoch {epoch+1}/{num_epochs}, Loss: {avg_loss:.4f}")


Epoch 1/10, Loss: 1.7709
Epoch 2/10, Loss: 0.8997
Epoch 3/10, Loss: 0.5078
Epoch 4/10, Loss: 0.3331
Epoch 5/10, Loss: 0.2411
Epoch 6/10, Loss: 0.1895
Epoch 7/10, Loss: 0.1553
Epoch 8/10, Loss: 0.1323
Epoch 9/10, Loss: 0.1152
Epoch 10/10, Loss: 0.1027


In [None]:
# Measure how well the trained linear probe classifies the STL10 test split.
test_dataset = STL10(root="./stl10", download=True, split="test", transform=preprocess)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)

# Both the encoder and the probe are used for inference only.
model.eval()
linear_probe.eval()

num_correct_test = 0
total_test_samples = 0

with torch.no_grad():
    for image, label in tqdm.tqdm(test_loader, desc="Evaluating on test set"):
        image, label = image.to(device), label.to(device)

        # Image features from OpenCLIP, promoted to float32 if they came back
        # in half precision.
        vision_embeddings = model.encode_image(image)
        if vision_embeddings.dtype == torch.float16:
            vision_embeddings = vision_embeddings.float()

        output_logits = linear_probe(vision_embeddings)

        # Predicted class = index of the largest logit in each row.
        predicted_labels = output_logits.argmax(dim=1)

        num_correct_test += int((predicted_labels == label).sum())
        total_test_samples += len(label)

test_accuracy = 100. * num_correct_test / total_test_samples
print(f"Test Accuracy: {test_accuracy:.2f}%")

Evaluating on test set: 100%|██████████| 125/125 [01:25<00:00,  1.45it/s]

Test Accuracy: 98.42%





After training the linear probe, we evaluate it on the STL10 test dataset, similar to before, and our accuracy is now 98.42%!

Great! By using some labeled data, we were able to train a small logistic regression layer that improves the accuracy of OpenCLIP on the STL10 dataset by nearly 2 percentage points (from 96.68% to 98.42%)!

This improvement is likely because our text prompts, like "an airplane", might not perfectly match the labels as they appear in the STL10 dataset. But by seeing a few examples for each label, we can learn reference embeddings that more accurately represent the class labels.

## Training a student model to mimic OpenCLIP

We've now seen that using the large OpenCLIP model, we can achieve competitive results on the STL10 image classification dataset with little effort. But OpenCLIP is large, and is likely to have high memory consumption and latency compared to other model architectures. In addition, as a vision transformer model, OpenCLIP is less capable of exploiting the Deep Learning Accelerator (DLA) on Jetson AGX Orin, given the matrix multiplication in the attention layers. CNN models, like resnet18, on the other hand are highly optimized for both the GPU and DLA on Jetson, and allow us to run models at higher throughput and with less memory.

Knowledge distillation lets us train such a smaller CNN student model to mimic the outputs of OpenCLIP. However, knowledge distillation can impact the accuracy of the model, so we'd like to better understand what factors are most important. To do this, we've run a few experiments that seek to answer a few questions:

How does distillation compare to training with ground truth labels?
How does the data distribution used for training impact model accuracy?
How does the distillation method impact model accuracy? Is it better to train on the class probabilities or internal features?
How does the student model architecture impact model accuracy? Will resnet50 obtain higher accuracy than resnet18?



In [1]:
from stl10_utils import (
    precompute_clip_stl10_train_image_embeddings,
    precompute_clip_stl10_test_image_embeddings,
    precompute_clip_stl10_text_embeddings,
    train_resnet18_from_scratch,
    train_resnet18_zero_shot_train_only,
    train_resnet18_linear_probe_train_only
)

# The steps run strictly in this order, each called with no arguments; an
# exception in one step stops the cell before the remaining steps run.
pipeline_steps = (
    precompute_clip_stl10_train_image_embeddings,
    precompute_clip_stl10_test_image_embeddings,
    precompute_clip_stl10_text_embeddings,
    train_resnet18_from_scratch,
    train_resnet18_zero_shot_train_only,
    train_resnet18_linear_probe_train_only,
)
for run_step in pipeline_steps:
    run_step()

  from .autonotebook import tqdm as notebook_tqdm
5000it [00:00, 7526.99it/s]
8000it [00:00, 8100.05it/s]


Using device: mps for train_model_from_scratch




Epoch 1 - Test Loss: 1.7298, Test Accuracy: 34.62%
Epoch 2 - Test Loss: 1.5181, Test Accuracy: 44.05%
Epoch 3 - Test Loss: 1.4128, Test Accuracy: 47.83%
Epoch 4 - Test Loss: 1.2517, Test Accuracy: 53.14%
Epoch 5 - Test Loss: 1.3228, Test Accuracy: 51.08%
Epoch 6 - Test Loss: 1.2272, Test Accuracy: 55.34%
Epoch 7 - Test Loss: 1.3547, Test Accuracy: 53.99%
Epoch 8 - Test Loss: 1.3396, Test Accuracy: 55.64%
Epoch 9 - Test Loss: 1.4112, Test Accuracy: 54.89%
Epoch 10 - Test Loss: 1.4231, Test Accuracy: 56.30%
Finished Training
Using device: mps for train_student_classification_model


100%|██████████| 79/79 [00:48<00:00,  1.63it/s]
100%|██████████| 125/125 [00:43<00:00,  2.86it/s]

| EPOCH 0 | TRAIN LOSS 0.00030684381184798863 | TEST ACC 19.837 |
Saving checkpoint to data/experiments/train_resnet18_zero_shot_train_only/checkpoint_0.pth



100%|██████████| 79/79 [00:47<00:00,  1.67it/s]
100%|██████████| 125/125 [00:43<00:00,  2.86it/s]

| EPOCH 1 | TRAIN LOSS 0.00012261361221538243 | TEST ACC 24.475 |
Saving checkpoint to data/experiments/train_resnet18_zero_shot_train_only/checkpoint_1.pth



100%|██████████| 79/79 [00:47<00:00,  1.66it/s]
100%|██████████| 125/125 [00:43<00:00,  2.85it/s]


| EPOCH 2 | TRAIN LOSS 9.287601554931342e-05 | TEST ACC 28.387 |
Saving checkpoint to data/experiments/train_resnet18_zero_shot_train_only/checkpoint_2.pth


100%|██████████| 79/79 [00:47<00:00,  1.66it/s]
100%|██████████| 125/125 [00:43<00:00,  2.86it/s]

| EPOCH 3 | TRAIN LOSS 7.483997697944906e-05 | TEST ACC 33.388 |
Saving checkpoint to data/experiments/train_resnet18_zero_shot_train_only/checkpoint_3.pth



100%|██████████| 79/79 [00:47<00:00,  1.66it/s]
100%|██████████| 125/125 [00:43<00:00,  2.86it/s]

| EPOCH 4 | TRAIN LOSS 6.25165339903405e-05 | TEST ACC 29.837 |



100%|██████████| 79/79 [00:47<00:00,  1.66it/s]
100%|██████████| 125/125 [00:43<00:00,  2.85it/s]

| EPOCH 5 | TRAIN LOSS 5.174225883866377e-05 | TEST ACC 32.325 |



100%|██████████| 79/79 [00:47<00:00,  1.66it/s]
100%|██████████| 125/125 [00:43<00:00,  2.86it/s]


| EPOCH 6 | TRAIN LOSS 4.2296258562842764e-05 | TEST ACC 38.35 |
Saving checkpoint to data/experiments/train_resnet18_zero_shot_train_only/checkpoint_6.pth


100%|██████████| 79/79 [00:47<00:00,  1.66it/s]
100%|██████████| 125/125 [00:43<00:00,  2.86it/s]


| EPOCH 7 | TRAIN LOSS 3.617563146966674e-05 | TEST ACC 40.2 |
Saving checkpoint to data/experiments/train_resnet18_zero_shot_train_only/checkpoint_7.pth


100%|██████████| 79/79 [00:47<00:00,  1.65it/s]
100%|██████████| 125/125 [00:43<00:00,  2.85it/s]

| EPOCH 8 | TRAIN LOSS 3.096690251587012e-05 | TEST ACC 38.15 |



100%|██████████| 79/79 [00:47<00:00,  1.66it/s]
100%|██████████| 125/125 [00:43<00:00,  2.86it/s]


| EPOCH 9 | TRAIN LOSS 2.5984479237459836e-05 | TEST ACC 42.788 |
Saving checkpoint to data/experiments/train_resnet18_zero_shot_train_only/checkpoint_9.pth


AssertionError: Torch not compiled with CUDA enabled

In [1]:
# Run only the `train_resnet18_linear_probe_train_only` step on its own;
# none of the other `stl10_utils` steps are invoked in this cell.
from stl10_utils import (
    train_resnet18_linear_probe_train_only
)
train_resnet18_linear_probe_train_only()

  from .autonotebook import tqdm as notebook_tqdm


Using device: mps for train_probe_model


100%|██████████| 79/79 [00:41<00:00,  1.90it/s]
100%|██████████| 125/125 [00:41<00:00,  3.00it/s] 

| EPOCH 0 | TRAIN LOSS 1.6864405704450003 | TEST ACC 96.7 |
Saving checkpoint to data/experiments/train_probe_model_linear/checkpoint_0.pth



100%|██████████| 79/79 [00:41<00:00,  1.92it/s]
100%|██████████| 125/125 [00:41<00:00,  3.02it/s] 

| EPOCH 1 | TRAIN LOSS 0.8551359893400458 | TEST ACC 97.825 |
Saving checkpoint to data/experiments/train_probe_model_linear/checkpoint_1.pth



100%|██████████| 79/79 [00:41<00:00,  1.92it/s]
100%|██████████| 125/125 [00:41<00:00,  2.99it/s] 

| EPOCH 2 | TRAIN LOSS 0.48272129327436036 | TEST ACC 98.125 |
Saving checkpoint to data/experiments/train_probe_model_linear/checkpoint_2.pth



100%|██████████| 79/79 [00:41<00:00,  1.91it/s]
100%|██████████| 125/125 [00:41<00:00,  3.01it/s] 

| EPOCH 3 | TRAIN LOSS 0.31865229542496837 | TEST ACC 98.237 |
Saving checkpoint to data/experiments/train_probe_model_linear/checkpoint_3.pth



100%|██████████| 79/79 [00:41<00:00,  1.92it/s]
100%|██████████| 125/125 [00:09<00:00, 12.67it/s] 


KeyboardInterrupt: 

In [None]:
# NOTE(review): the executed cells above import their helpers from
# `stl10_utils`, while this cell imports from `stl10` — confirm which module
# is intended before running.
from stl10 import (
    precompute_clip_stl10_train_image_embeddings,
    precompute_clip_stl10_unlabeled_image_embeddings,
    precompute_clip_stl10_test_image_embeddings,
    precompute_clip_stl10_text_embeddings,
    train_resnet18_zero_shot,
    train_resnet18_linear_probe
)

# Each step is called with no arguments, in the order listed. Unlike the
# earlier pipeline cell, this one also calls
# `precompute_clip_stl10_unlabeled_image_embeddings`.
precompute_clip_stl10_train_image_embeddings()
precompute_clip_stl10_unlabeled_image_embeddings()
precompute_clip_stl10_test_image_embeddings()
precompute_clip_stl10_text_embeddings()
train_resnet18_zero_shot()
train_resnet18_linear_probe()