In [None]:
!pip install transformers datasets evaluate torch torchvision -q

import torch
from datasets import load_dataset
from transformers import ViTImageProcessor, ViTForImageClassification, TrainingArguments, Trainer
import evaluate
import numpy as np

# 1. Load dataset (CIFAR-10)
# Later cells index the result by split name ("train" / "test").
DATASET_NAME = "cifar10"
dataset = load_dataset(DATASET_NAME)

# 2. Preprocessing
# Image processor loaded from the same checkpoint name the model uses below,
# so the preprocessing matches the pre-trained weights.
PROCESSOR_CHECKPOINT = "google/vit-base-patch16-224"
processor = ViTImageProcessor.from_pretrained(PROCESSOR_CHECKPOINT)

def transform(example_batch):
    """Run a batch of examples through the image processor.

    Args:
        example_batch: mapping with an ``'img'`` entry (an iterable of
            images) and a ``'label'`` entry (the matching labels).

    Returns:
        The processor's output for the images (requested as PyTorch
        tensors), with the batch's labels added under ``'labels'``.
    """
    # The processor is called once for the whole batch of images.
    images = list(example_batch['img'])
    encoded = processor(images, return_tensors="pt")
    encoded['labels'] = example_batch['label']
    return encoded

# Preprocess every split once, up front. `map` writes its results to Arrow
# tables, so without an explicit format each row's `pixel_values` comes back
# as nested Python lists and `torch.stack` in the collator fails.
# `with_format("torch")` makes rows come back as tensors instead.
# NOTE(review): this materialises every preprocessed image on disk; for a lazy
# alternative use `dataset.with_transform(transform)` together with
# `TrainingArguments(remove_unused_columns=False)`.
processed_dataset = dataset.map(transform, batched=True).with_format("torch")


def collate_fn(batch):
    """Assemble a list of per-example dicts into one model-ready batch.

    Args:
        batch: list of dicts, each with a ``'pixel_values'`` entry and a
            ``'labels'`` entry. ``'pixel_values'`` may be a tensor or any
            nested sequence/array (datasets produced by ``map`` return plain
            Python lists unless a torch format is set); ``'labels'`` may be an
            int or a 0-dim tensor.

    Returns:
        dict with ``'pixel_values'`` stacked along a new leading batch
        dimension and ``'labels'`` as a 1-D ``torch.long`` tensor.
    """
    # `as_tensor` is a no-op for tensors and converts lists/arrays, so the
    # collator works regardless of the dataset's output format.
    pixel_values = torch.stack(
        [torch.as_tensor(example['pixel_values']) for example in batch]
    )
    # `int(...)` accepts both Python ints and 0-dim tensors.
    labels = torch.as_tensor(
        [int(example['labels']) for example in batch], dtype=torch.long
    )
    return {'pixel_values': pixel_values, 'labels': labels}

# 3. Load pre-trained ViT
# Class names come from the dataset's label feature; both lookup tables use
# the stringified class index, which is what later cells index `id2label` with.
label_names = dataset['train'].features['label'].names
id2label = {str(idx): name for idx, name in enumerate(label_names)}
label2id = {name: idx_str for idx_str, name in id2label.items()}

model = ViTForImageClassification.from_pretrained(
    "google/vit-base-patch16-224",
    num_labels=10,
    id2label=id2label,
    label2id=label2id,
    # Presumably the checkpoint's classification head has a different number
    # of classes than `num_labels`; this flag lets loading proceed anyway.
    ignore_mismatched_sizes=True,
)

# 4. Training setup
# Metric object used by `compute_metrics` below.
accuracy = evaluate.load("accuracy")


def compute_metrics(eval_pred):
    """Score one evaluation pass with the accuracy metric.

    Args:
        eval_pred: a pair that unpacks into ``(logits, labels)``.

    Returns:
        The result of ``accuracy.compute`` for the arg-max class of each
        row of logits against the reference labels.
    """
    scores, references = eval_pred
    # The predicted class is the index of the largest logit on the last axis.
    predicted_classes = np.argmax(scores, axis=-1)
    return accuracy.compute(predictions=predicted_classes, references=references)

# Hyper-parameters and bookkeeping options for the fine-tuning run, gathered
# in one mapping so they are easy to scan and tweak.
training_config = dict(
    output_dir="./vit-cifar10",
    # Evaluate and checkpoint on the same schedule: once per epoch.
    eval_strategy="epoch",
    save_strategy="epoch",
    learning_rate=2e-5,
    per_device_train_batch_size=32,
    per_device_eval_batch_size=32,
    num_train_epochs=3,
    weight_decay=0.01,
    logging_dir="./logs",
    logging_steps=10,
    push_to_hub=False,
    report_to="none",  # Disable wandb logging
)
training_args = TrainingArguments(**training_config)

# 5. Trainer
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=processed_dataset["train"],
    eval_dataset=processed_dataset["test"],
    # `tokenizer=` is the deprecated name of this argument; current
    # transformers releases expect the processor via `processing_class=`.
    # NOTE(review): requires a transformers release that has
    # `processing_class` — confirm against the installed version.
    processing_class=processor,
    data_collator=collate_fn,
    compute_metrics=compute_metrics,
)

# 6. Train model
trainer.train()

# 7. Evaluate
# `evaluate()` runs on the `eval_dataset` given above (the "test" split); the
# value returned by `compute_metrics` is read back under the "eval_accuracy" key.
results = trainer.evaluate()
print("Test Accuracy:", results["eval_accuracy"])

Fetching 1 files:   0%|          | 0/1 [00:00<?, ?it/s]

Map:   0%|          | 0/50000 [00:00<?, ? examples/s]

In [None]:
import matplotlib.pyplot as plt

# Get a few raw test examples. Slicing the dataset yields a dict of lists:
# 'img' holds the images and 'label' the integer class ids.
NUM_PREVIEW = 5
test_batch = dataset["test"][:NUM_PREVIEW]

# Preprocess with the same processor used for training, and put the inputs on
# whatever device the model currently lives on (it may have been moved to an
# accelerator during training).
inputs = processor(test_batch["img"], return_tensors="pt")
pixel_values = inputs["pixel_values"].to(model.device)

# Make predictions (eval mode + no_grad: inference only)
model.eval()
with torch.no_grad():
    outputs = model(pixel_values=pixel_values)
    preds = outputs.logits.argmax(-1).cpu()

# Show images + predicted labels. The original images are plotted rather than
# the processor output, which is model input and not meant for display.
for image, pred_id, true_id in zip(test_batch["img"], preds.tolist(), test_batch["label"]):
    fig, ax = plt.subplots()
    ax.imshow(image)
    ax.set_title(f"Predicted: {id2label[str(pred_id)]} | True: {id2label[str(true_id)]}")
    ax.axis("off")
    plt.show()
