In [1]:
import torch
from torch import nn
from torch.utils.data import DataLoader, Subset
from transformers import AutoImageProcessor, ViTModel
from torchvision import datasets, transforms
from PIL import Image
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score
import numpy as np

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

cuda


In [2]:
model_name = "google/vit-base-patch16-224"
image_processor = AutoImageProcessor.from_pretrained(model_name, do_rescale=False)
model = ViTModel.from_pretrained(model_name)
model = model.to(device)
model.eval()
for p in model.parameters():
    p.requires_grad = False

Fetching 1 files:   0%|          | 0/1 [00:00<?, ?it/s]

Fast image processor class <class 'transformers.models.vit.image_processing_vit_fast.ViTImageProcessorFast'> is available for this model. Using slow image processor class. To use the fast image processor class set `use_fast=True`.


Fetching 1 files:   0%|          | 0/1 [00:00<?, ?it/s]

Some weights of ViTModel were not initialized from the model checkpoint at google/vit-base-patch16-224 and are newly initialized: ['pooler.dense.bias', 'pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [3]:
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor()
])

train_dataset = Subset(datasets.CIFAR10(root="./data", train=True, download=True, transform=transform), range(10000))
test_dataset  = Subset(datasets.CIFAR10(root="./data", train=False, download=True, transform=transform), range(2000))

train_loader = DataLoader(train_dataset, batch_size=128, shuffle=False, num_workers=2)
test_loader  = DataLoader(test_dataset, batch_size=128, shuffle=False, num_workers=2)

In [None]:
def get_representations(dataloader):
    cls_list, mean_list, labels_list = [], [], []
    with torch.no_grad():
        for images, labels in dataloader:
            pixel_values = image_processor(images=list(images), return_tensors="pt")["pixel_values"].to(device)

            outputs = model(pixel_values)
            hidden = outputs.last_hidden_state  # [B, num_tokens, hidden_dim]

            cls_repr = hidden[:, 0, :].cpu().numpy()
            mean_repr = hidden[:, 1:, :].mean(1).cpu().numpy()

            cls_list.append(cls_repr)
            mean_list.append(mean_repr)
            labels_list.append(labels.numpy())

    return (np.vstack(cls_list),
            np.vstack(mean_list),
            np.hstack(labels_list))

In [5]:
print("Extracting train features...")
train_cls, train_mean, y_train = get_representations(train_loader)

print("Extracting test features...")
test_cls, test_mean, y_test = get_representations(test_loader)

Extracting train features...
Extracting test features...


In [6]:
clf_cls = LogisticRegression(max_iter=2000).fit(train_cls, y_train)
clf_mean = LogisticRegression(max_iter=2000).fit(train_mean, y_train)

acc_cls = accuracy_score(y_test, clf_cls.predict(test_cls))
acc_mean = accuracy_score(y_test, clf_mean.predict(test_mean))

print(f"CLS linear probe accuracy:  {acc_cls:.4f}")
print(f"Mean linear probe accuracy: {acc_mean:.4f}")

CLS linear probe accuracy:  0.9670
Mean linear probe accuracy: 0.9710
