In [1]:
import os

import numpy as np
import torch
from datasets import load_dataset
from sklearn.metrics import accuracy_score, classification_report, f1_score
from torch.utils.data import DataLoader
from tqdm.auto import tqdm
from transformers import ViTForImageClassification, ViTImageProcessor

In [2]:
# Load the CIFAR-10 dataset published on the Hugging Face Hub as "uoft-cs/cifar10".
cifar_data = load_dataset(path="uoft-cs/cifar10")

In [None]:
NUM_CLASSES = 10
# The 10-way classifier head below is created with fresh random weights, so
# the RNG must be seeded for the baseline metrics to be reproducible.
SEED = 42

# The checkpoint ships a 1000-way head; asking for NUM_CLASSES outputs with
# ignore_mismatched_sizes=True discards it and creates a new, untrained head
# (see the "newly initialized" warning this cell emits).
torch.manual_seed(SEED)
baseline_model = ViTForImageClassification.from_pretrained(
    "google/vit-base-patch16-224",
    num_labels=NUM_CLASSES,
    ignore_mismatched_sizes=True,
)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Move to the target device and switch to inference mode in one step.
# Assigning the result keeps the cell from echoing the full module repr.
baseline_model = baseline_model.to(device).eval()

Some weights of ViTForImageClassification were not initialized from the model checkpoint at google/vit-base-patch16-224 and are newly initialized because the shapes did not match:
- classifier.bias: found shape torch.Size([1000]) in the checkpoint and torch.Size([10]) in the model instantiated
- classifier.weight: found shape torch.Size([1000, 768]) in the checkpoint and torch.Size([10, 768]) in the model instantiated
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [None]:
# Image processor for the same checkpoint as baseline_model, so pixel inputs
# are prepared the way that checkpoint expects.
processor = ViTImageProcessor.from_pretrained(
    pretrained_model_name_or_path="google/vit-base-patch16-224"
)

def preprocess_function(examples):
    """Convert one batch of CIFAR-10 rows into ViT model inputs.

    Args:
        examples: batched mapping (as passed by ``Dataset.map(batched=True)``)
            with an ``"img"`` sequence of images exposing ``.convert``
            (presumably PIL) and a ``"label"`` sequence.

    Returns:
        The processor's output (PyTorch tensors) with the batch's labels
        attached under the ``"labels"`` key.
    """
    batch_inputs = processor(
        [picture.convert("RGB") for picture in examples["img"]],
        return_tensors="pt",
    )
    batch_inputs["labels"] = examples["label"]
    return batch_inputs

# preprocess_function already writes a "labels" column, so the raw "label"
# column is dropped during the map. Renaming it afterwards would fail because
# rename_column refuses a target name that already exists. The raw "img"
# column is dropped as well, since only the processed inputs are needed
# downstream.
tokenized_cifar = cifar_data.map(
    preprocess_function,
    batched=True,
    # Up to 32 worker processes (the original host had 128 CPUs), but never
    # more than this machine actually has.
    num_proc=min(32, os.cpu_count() or 1),
    remove_columns=["img", "label"],
)

# Return torch tensors when indexing, so a DataLoader can batch them directly.
tokenized_cifar = tokenized_cifar.with_format("torch")

In [None]:
# Evaluate baseline_model on the CIFAR-10 test split.
# NOTE: the classifier head was freshly initialized when the model was loaded,
# so these metrics describe an untrained head and should sit near chance level.
EVAL_BATCH_SIZE = 64

# Expose only the model inputs, as torch tensors; any leftover raw columns are
# ignored, so the DataLoader's default collation can stack them.
test_dataset = tokenized_cifar["test"].with_format(
    "torch", columns=["pixel_values", "labels"]
)
test_dataloader = DataLoader(test_dataset, batch_size=EVAL_BATCH_SIZE, shuffle=False)

all_preds = []
all_labels = []

with torch.no_grad():
    for batch in tqdm(test_dataloader, desc="Calculating Baseline Metrics"):
        pixel_values = batch["pixel_values"].to(device)
        logits = baseline_model(pixel_values=pixel_values).logits
        all_preds.extend(torch.argmax(logits, dim=-1).cpu().numpy())
        all_labels.extend(batch["labels"].numpy())

y_true = np.array(all_labels)
y_pred = np.array(all_preds)

baseline_accuracy = accuracy_score(y_true, y_pred)
print(f"Baseline Accuracy: {baseline_accuracy:.4f}")

# zero_division=0 matches the report below and silences warnings for classes
# the untrained head never predicts.
baseline_f1_weighted = f1_score(y_true, y_pred, average='weighted', zero_division=0)
print(f"Baseline F1-Score (Weighted): {baseline_f1_weighted:.4f}")

# Human-readable class names from the raw dataset's label feature. Passing
# `labels` explicitly keeps each row aligned with its name even if a class is
# missing from both y_true and y_pred.
target_names = cifar_data["test"].features["label"].names

report = classification_report(
    y_true,
    y_pred,
    labels=list(range(NUM_CLASSES)),
    target_names=target_names,
    zero_division=0,
)
print("\nClassification Report:\n", report)