In [1]:
# ViT multi-class (9) classification
# Dataset: PathMNIST MedMNIST (Yang et al., 2021)
# Model: Transformer ViTForImageClassification

# Install third-party packages only when they are missing: `python -c "import X"`
# exits non-zero when X cannot be imported, so the `||` branch runs pip only then.
# NOTE(review): `!python` / `pip` here resolve to whatever is first on the shell PATH,
# which is not guaranteed to be the kernel's interpreter; `%pip install` targets the
# kernel environment. Versions are unpinned, so reruns may pick up newer releases.
!python -c "import medmnist" || pip install -q medmnist
# `--no-deps` installs `evaluate` without its dependencies; presumably they are
# already present in the host image (this notebook targets Kaggle) — confirm elsewhere.
!python -c "import evaluate" || pip install -q evaluate --no-deps

import torch
from torchvision import transforms
from torch.utils.data import Subset
from torch.utils.data import DataLoader
# Use the GPU when CUDA is available, otherwise fall back to CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
from transformers import ViTConfig, ViTForImageClassification, TrainingArguments, Trainer

import numpy as np
from PIL import Image
from medmnist import PathMNIST
# NOTE(review): Subset, Image, accuracy_score and f1_score are not referenced in the
# cells shown; the metrics cell uses the `evaluate` library instead.
from sklearn.metrics import accuracy_score, f1_score

# Preprocessing pipeline: reduce every image to a single (grey) channel, then turn
# the PIL image into a tensor with ToTensor.
# NOTE(review): PathMNIST is documented as colour (RGB) histopathology, so this step
# throws away stain colour information. Switching to RGB would also require
# num_channels=3 in the ViT config — confirm the grayscale choice is intentional.
transform = transforms.Compose(
    [transforms.Grayscale(num_output_channels=1), transforms.ToTensor()]
)

# Train and validation splits; download=True lets medmnist fetch the data file
# when it is not already on disk.
train_dataset = PathMNIST(split="train", download=True, transform=transform)
val_dataset = PathMNIST(split="val", download=True, transform=transform)

# Plain PyTorch loaders over the two splits.
# NOTE(review): the Trainer cells in this notebook build their own loaders from the
# datasets + collate_fn, so these two objects are not used there.
loader_kwargs = dict(batch_size=64)
train_loader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)
val_loader = DataLoader(val_dataset, **loader_kwargs)

Traceback (most recent call last):
  File "<string>", line 1, in <module>
ModuleNotFoundError: No module named 'medmnist'
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m115.9/115.9 kB[0m [31m2.8 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m363.4/363.4 MB[0m [31m3.9 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m13.8/13.8 MB[0m [31m81.2 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m24.6/24.6 MB[0m [31m62.8 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m883.7/883.7 kB[0m [31m37.6 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m664.8/664.8 MB[0m [31m1.9 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m211.5/211.5 MB[0m [31m7.5 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━

2025-09-29 20:11:39.551097: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1759176699.747032      19 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1759176699.800834      19 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
100%|██████████| 206M/206M [00:07<00:00, 29.2MB/s]


In [2]:
# Build a small ViT classifier from a config (randomly initialised, not pretrained).
# 28x28 inputs cut into 7x7 patches give a 4x4 grid of 16 patch tokens.
# NOTE(review): intermediate_size is left at the ViTConfig default, which the
# transformers docs give as 3072 — far above the usual 4 * hidden_size for
# hidden_size=128. Confirm this MLP width is intended.
vit_settings = {
    "image_size": 28,
    "patch_size": 7,
    "num_channels": 1,  # must agree with the single-channel grayscale transform
    "hidden_size": 128,
    "num_hidden_layers": 4,
    "num_attention_heads": 4,  # 128 / 4 = 32 dimensions per head
    "num_labels": 9,  # nine target classes (see notebook header)
}
config = ViTConfig(**vit_settings)
model = ViTForImageClassification(config).to(device)

In [3]:
# training using standard transformer pipeline
def collate_fn(batch):
    """Stack ``(image, label)`` pairs into the dict of tensors the HF Trainer consumes.

    Args:
        batch: sequence of ``(image, label)`` pairs produced by the dataset. Every
            image is a tensor of the same shape. Each label is a scalar or a
            one-element array (medmnist appears to yield shape ``(1,)`` arrays —
            confirm).

    Returns:
        dict with ``"pixel_values"`` of shape ``(B, *image.shape)`` and ``"labels"``
        as a 1-D ``torch.long`` tensor of length ``B``.

    Raises:
        ValueError: if the labels do not contain exactly one value per image
            (e.g. multi-label targets), instead of silently mis-shaping them.
    """
    images, labels = zip(*batch)
    # Merge into a single ndarray first. Building a tensor directly from a tuple of
    # ndarrays is slow, and torch warns about it (see the warning pointing at the old
    # line in the training output).
    label_array = np.asarray(labels, dtype=np.int64)
    # reshape(-1) rather than squeeze(): squeeze() turns a final batch of size 1 into
    # a 0-d tensor that no longer lines up with the (1, num_labels) logits.
    label_tensor = torch.from_numpy(label_array).reshape(-1)
    if label_tensor.numel() != len(images):
        raise ValueError(
            f"expected one label per image, got {label_tensor.numel()} labels "
            f"for {len(images)} images"
        )
    return {
        "pixel_values": torch.stack(images),
        "labels": label_tensor,
    }

# Directory used as the Trainer output_dir and as the destination of the final weights.
MODEL_DIR = "./model"

# Optimisation hyper-parameters handed to the HF Trainer. No evaluation strategy is
# configured here; validation metrics are computed in a separate cell afterwards.
optim_settings = dict(
    learning_rate=3e-4,
    weight_decay=0.01,
    num_train_epochs=10,
    per_device_train_batch_size=64,
    per_device_eval_batch_size=64,
)
training_args = TrainingArguments(
    output_dir=MODEL_DIR,
    save_strategy="no",  # no intermediate checkpoints; the model is saved once below
    report_to="none",  # disable wandb for kaggle
    **optim_settings,
)

# Standard HF training loop over the train split, batching with collate_fn.
trainer = Trainer(
    model=model,
    args=training_args,
    data_collator=collate_fn,
    train_dataset=train_dataset,
)
trainer.train()

# Persist the trained model with save_pretrained.
model.save_pretrained(MODEL_DIR)

  "labels": torch.tensor(labels).squeeze().long()


Step,Training Loss
500,1.7942
1000,1.4905
1500,1.4069
2000,1.3696
2500,1.3428
3000,1.3125
3500,1.2824
4000,1.2618
4500,1.2381
5000,1.2065


In [4]:
import evaluate

# Hugging Face `evaluate` metric objects consumed by compute_metrics. Loading them
# fetches the metric scripts (the "Downloading builder script" lines in the output).
accuracy, f1 = (evaluate.load(metric_name) for metric_name in ("accuracy", "f1"))

def compute_metrics(eval_pred):
    """Turn the Trainer's raw evaluation output into accuracy and weighted-F1 scores.

    Args:
        eval_pred: ``(logits, labels)`` pair handed over by the Trainer; the class
            dimension of ``logits`` is the last axis.

    Returns:
        dict with ``"accuracy"`` and ``"f1"`` (F1 averaged with ``average="weighted"``),
        both computed by the module-level ``accuracy`` / ``f1`` metric objects.
    """
    logits, labels = eval_pred
    predicted = logits.argmax(axis=-1)
    shared = {"predictions": predicted, "references": labels}
    acc_value = accuracy.compute(**shared)["accuracy"]
    weighted_f1 = f1.compute(**shared, average="weighted")["f1"]
    return {"accuracy": acc_value, "f1": weighted_f1}

# A fresh Trainer wired to the validation split and the metric function. It wraps the
# same `model` object trained above, so evaluate() only runs inference on val_dataset;
# the returned metrics dict is the cell's displayed output.
eval_trainer_kwargs = dict(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    data_collator=collate_fn,
    compute_metrics=compute_metrics,
)
trainer = Trainer(**eval_trainer_kwargs)
trainer.evaluate()

Downloading builder script: 0.00B [00:00, ?B/s]

Downloading builder script: 0.00B [00:00, ?B/s]

{'eval_loss': 0.9395607113838196,
 'eval_model_preparation_time': 0.0013,
 'eval_accuracy': 0.6499400239904038,
 'eval_f1': 0.6425615748102118,
 'eval_runtime': 2.0832,
 'eval_samples_per_second': 4802.291,
 'eval_steps_per_second': 75.366}