In [5]:
import os
import math
import random
from pathlib import Path

import torch
from torch import nn
from torch.utils.data import DataLoader, random_split
from torchvision import datasets

from transformers import (
    ViTForImageClassification,
    AutoImageProcessor,
    get_cosine_schedule_with_warmup,
)



In [6]:
class SafeImageFolder(datasets.ImageFolder):
    """SafeImageFolder that ignores hidden classes like `.ipynb_checkpoints`."""
    def find_classes(self, directory: str):
        classes, _ = super().find_classes(directory)
        # drop any class directory that starts with a dot
        classes = [c for c in classes if not c.startswith('.')]
        classes = sorted(classes)
        class_to_idx = {cls_name: i for i, cls_name in enumerate(classes)}
        return classes, class_to_idx

In [2]:
# -----------------------------
# Config
# -----------------------------
DATA_ROOT = "/data/imagenet100-224/train"  # class-subfolders inside here
MODEL_NAME = "google/vit-base-patch16-224"
OUTPUT_DIR = "./vit-imagenet100-head-only"
VAL_SPLIT = 0.1
SEED = 42

EPOCHS = 10
BATCH_SIZE = 32            # adjust to your GPU memory
LR = 3e-4                  # a bit higher since we're training only the head
WEIGHT_DECAY = 0.01
WARMUP_RATIO = 0.05
GRAD_ACCUM_STEPS = 1       # increase if you need effective larger batch
AMP = True                 # mixed precision

NUM_WORKERS = min(8, os.cpu_count() or 2)
torch.backends.cudnn.benchmark = True
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")



In [7]:
# -----------------------------
# Data
# -----------------------------
# We’ll use torchvision’s ImageFolder and apply the ViT processor in the collate_fn.
print("Indexing dataset…")
full_ds = SafeImageFolder(DATA_ROOT)
num_classes = len(full_ds.classes)

# Build label maps from folder names
label2id = {cls_name: i for cls_name, i in full_ds.class_to_idx.items()}
id2label = {i: cls_name for cls_name, i in full_ds.class_to_idx.items()}

# Train/val split from the single folder
random.seed(SEED)
N = len(full_ds)
n_val = int(N * VAL_SPLIT)
n_train = N - n_val
train_ds, val_ds = random_split(full_ds, [n_train, n_val], generator=torch.Generator().manual_seed(SEED))
print(f"Train: {n_train:,} | Val: {n_val:,} | Classes: {num_classes}")

# Image processor handles resize, to-tensor, normalization for ViT
processor = AutoImageProcessor.from_pretrained(MODEL_NAME)

def collate_fn(batch):
    # batch: list of (PIL_image, label)
    images, labels = zip(*batch)
    enc = processor(images=list(images), return_tensors="pt")
    enc["labels"] = torch.tensor(labels, dtype=torch.long)
    return enc

train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True,
                          num_workers=NUM_WORKERS, pin_memory=True, collate_fn=collate_fn)
val_loader = DataLoader(val_ds, batch_size=BATCH_SIZE, shuffle=False,
                        num_workers=NUM_WORKERS, pin_memory=True, collate_fn=collate_fn)



Indexing dataset…
Train: 117,000 | Val: 13,000 | Classes: 100


preprocessor_config.json:   0%|          | 0.00/160 [00:00<?, ?B/s]

Fast image processor class <class 'transformers.models.vit.image_processing_vit_fast.ViTImageProcessorFast'> is available for this model. Using slow image processor class. To use the fast image processor class set `use_fast=True`.


In [8]:
# -----------------------------
# Model: head-only fine-tuning
# -----------------------------
print("Loading model…")
model = ViTForImageClassification.from_pretrained(
    MODEL_NAME,
    num_labels=num_classes,
    id2label=id2label,
    label2id=label2id,
    ignore_mismatched_sizes=True,  # replaces the classification head if shape differs
)
model.to(device)

# Freeze everything except the classifier head
for p in model.parameters():
    p.requires_grad = False
for p in model.classifier.parameters():
    p.requires_grad = True

# Verify which params train
trainable_params = [n for n, p in model.named_parameters() if p.requires_grad]
print("Trainable params:", trainable_params)

# Optimizer/scheduler
head_params = (p for p in model.classifier.parameters() if p.requires_grad)
optimizer = torch.optim.AdamW(head_params, lr=LR, weight_decay=WEIGHT_DECAY)

num_update_steps_per_epoch = math.ceil(len(train_loader) / GRAD_ACCUM_STEPS)
max_train_steps = EPOCHS * num_update_steps_per_epoch
num_warmup_steps = int(WARMUP_RATIO * max_train_steps)

lr_scheduler = get_cosine_schedule_with_warmup(
    optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=max_train_steps
)

scaler = torch.cuda.amp.GradScaler(enabled=AMP)

# -----------------------------
# Training / Evaluation loops
# -----------------------------
def evaluate(model, loader):
    model.eval()
    correct, total, loss_sum = 0, 0, 0.0
    loss_fn = nn.CrossEntropyLoss()
    with torch.no_grad():
        for batch in loader:
            batch = {k: v.to(device, non_blocking=True) for k, v in batch.items()}
            outputs = model(pixel_values=batch["pixel_values"])
            loss = loss_fn(outputs.logits, batch["labels"])
            loss_sum += loss.item() * batch["labels"].size(0)

            preds = outputs.logits.argmax(dim=-1)
            correct += (preds == batch["labels"]).sum().item()
            total += batch["labels"].size(0)
    avg_loss = loss_sum / max(1, total)
    acc = correct / max(1, total)
    return avg_loss, acc

os.makedirs(OUTPUT_DIR, exist_ok=True)
best_val_acc = 0.0

global_step = 0
for epoch in range(1, EPOCHS + 1):
    model.train()
    running_loss = 0.0
    optimizer.zero_grad(set_to_none=True)

    for step, batch in enumerate(train_loader, start=1):
        batch = {k: v.to(device, non_blocking=True) for k, v in batch.items()}
        with torch.cuda.amp.autocast(enabled=AMP):
            outputs = model(pixel_values=batch["pixel_values"], labels=batch["labels"])
            loss = outputs.loss / GRAD_ACCUM_STEPS

        scaler.scale(loss).backward()

        if step % GRAD_ACCUM_STEPS == 0:
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad(set_to_none=True)
            lr_scheduler.step()
            global_step += 1

        running_loss += loss.item() * GRAD_ACCUM_STEPS

        if step % 50 == 0 or step == len(train_loader):
            current_lr = optimizer.param_groups[0]["lr"]
            print(f"Epoch {epoch} | Step {step}/{len(train_loader)} | "
                  f"Loss {running_loss/step:.4f} | LR {current_lr:.6f}")

    # --- Validation ---
    val_loss, val_acc = evaluate(model, val_loader)
    print(f"Epoch {epoch} done -> Val Loss: {val_loss:.4f} | Val Acc: {val_acc*100:.2f}%")

    # Save best
    if val_acc > best_val_acc:
        best_val_acc = val_acc
        save_path = os.path.join(OUTPUT_DIR, "best")
        model.save_pretrained(save_path)
        processor.save_pretrained(save_path)
        # Also save just the head weights if you want:
        torch.save(model.classifier.state_dict(), os.path.join(OUTPUT_DIR, "classifier_head.pt"))
        print(f"New best ({best_val_acc*100:.2f}%). Saved to {save_path}")

# Final save
model.save_pretrained(os.path.join(OUTPUT_DIR, "last"))
processor.save_pretrained(os.path.join(OUTPUT_DIR, "last"))

print("Training complete.")
print(f"Best Val Acc: {best_val_acc*100:.2f}%")


Loading model…


Some weights of ViTForImageClassification were not initialized from the model checkpoint at google/vit-base-patch16-224 and are newly initialized because the shapes did not match:
- classifier.bias: found shape torch.Size([1000]) in the checkpoint and torch.Size([100]) in the model instantiated
- classifier.weight: found shape torch.Size([1000, 768]) in the checkpoint and torch.Size([100, 768]) in the model instantiated
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Trainable params: ['classifier.weight', 'classifier.bias']


  scaler = torch.cuda.amp.GradScaler(enabled=AMP)
  with torch.cuda.amp.autocast(enabled=AMP):


Epoch 1 | Step 50/3657 | Loss 4.7320 | LR 0.000008
Epoch 1 | Step 100/3657 | Loss 4.6889 | LR 0.000016
Epoch 1 | Step 150/3657 | Loss 4.6406 | LR 0.000025
Epoch 1 | Step 200/3657 | Loss 4.5647 | LR 0.000033
Epoch 1 | Step 250/3657 | Loss 4.4698 | LR 0.000041
Epoch 1 | Step 300/3657 | Loss 4.3560 | LR 0.000049
Epoch 1 | Step 350/3657 | Loss 4.2217 | LR 0.000057
Epoch 1 | Step 400/3657 | Loss 4.0700 | LR 0.000066
Epoch 1 | Step 450/3657 | Loss 3.9054 | LR 0.000074
Epoch 1 | Step 500/3657 | Loss 3.7347 | LR 0.000082
Epoch 1 | Step 550/3657 | Loss 3.5598 | LR 0.000090
Epoch 1 | Step 600/3657 | Loss 3.3842 | LR 0.000098
Epoch 1 | Step 650/3657 | Loss 3.2119 | LR 0.000107
Epoch 1 | Step 700/3657 | Loss 3.0527 | LR 0.000115
Epoch 1 | Step 750/3657 | Loss 2.9050 | LR 0.000123
Epoch 1 | Step 800/3657 | Loss 2.7702 | LR 0.000131
Epoch 1 | Step 850/3657 | Loss 2.6463 | LR 0.000139
Epoch 1 | Step 900/3657 | Loss 2.5305 | LR 0.000148
Epoch 1 | Step 950/3657 | Loss 2.4249 | LR 0.000156
Epoch 1 | Ste