# Vision Transformer (ViT) — Rice Diseases Classification

**Mục tiêu**: So sánh **công bằng** giữa ViT và Graphormer trên cùng một bộ dữ liệu bệnh lúa.  
Pipeline dataset (split, augmentation, seed) **giống hệt** `train_graphormer_simple.py`.

---
### Điều kiện so sánh công bằng đã đảm bảo
| Yếu tố | Graphormer | ViT (notebook này) |
|---|---|---|
| Splits | 70/15/15 stratified, seed=42 | **Giống hệt** (dùng cùng file `split_indices.pt` + `metadata.json` nếu có; nếu không thì tạo split mới bằng cùng thuật toán) |
| Augmentation | rotate90/180/270, flipH/V, zoom in/out | **Xấp xỉ** (dùng các phép biến đổi ngẫu nhiên tương ứng của torchvision) |
| Optimizer | AdamW, lr=3e-4, wd=0.01 | **Giống hệt** |
| Scheduler | Polynomial decay + 1000 warmup | **Giống hệt** (nhưng warmup bị giới hạn tối đa 10% tổng số bước huấn luyện) |
| Mixed precision | FP16 (`--fp16`) | AMP FP16 |
| Gradient clipping | 5.0 | **Giống hệt** |
| Seed | 42 | **Giống hệt** |
| Classes | 4 (BrownSpot/Healthy/Hispa/LeafBlast) | **Giống hệt** |

## 0. Cài đặt thư viện

In [None]:
# Run this cell only once on Colab: installs the `timm` package (quietly, -q).
!pip install timm -q

## 1. Siêu tham số (tùy chỉnh tại đây)

In [None]:
# ==========================================================================
#  CONFIGURATION — edit the values below before running
# ==========================================================================

# ── Data paths ─────────────────────────────────────────────────────────────
# Folder with the raw images (same layout that process_images.py reads).
# Colab example: '/content/drive/MyDrive/rice_diseases_data'
IMAGE_DIR = '/content/rice_diseases_data'

# Folder holding split_indices.pt and metadata.json produced by the Graphormer
# pipeline. Set to None (e.g. process_images.py was never run) to build a
# fresh split instead.
GRAPHORMER_PROCESSED_DIR = (
    '/content/Graphormer/examples/rice_diseases/rice_diseases_graphs/processed'
)

# ── ViT model (timm model name) ────────────────────────────────────────────
#   'vit_tiny_patch16_224'   ~5.7M params — lightest option
#   'vit_small_patch16_224'  ~22M params  — suggested for a Colab T4
#   'vit_base_patch16_224'   ~86M params  — stronger, needs more VRAM
VIT_MODEL_NAME = 'vit_small_patch16_224'

# Start from ImageNet-pretrained weights?
#   True  → fine-tuning (usually better, but not like-for-like if Graphormer
#           is trained from scratch)
#   False → training from scratch (fairest against a non-pretrained Graphormer)
USE_PRETRAINED = False

# ── Training (values chosen to mirror the Graphormer run) ──────────────────
EPOCHS = 50           # Graphormer simple: 50 epochs
BATCH_SIZE = 32       # Graphormer simple: 32
LR = 3e-4             # Graphormer: 3e-4
WEIGHT_DECAY = 1e-2   # Graphormer: 0.01
CLIP_NORM = 5.0       # Graphormer: --clip-norm 5.0
WARMUP_STEPS = 1000   # Graphormer: --warmup-updates 1000
NUM_WORKERS = 2       # DataLoader worker processes
SEED = 42             # Graphormer seed

# ── Input images ───────────────────────────────────────────────────────────
IMG_SIZE = 224        # the patch16 ViT variants above use 224×224 inputs

# ── Checkpoints ────────────────────────────────────────────────────────────
SAVE_DIR = './ckpts_vit_rice'

# ── Augmentation (modelled on Graphormer's process_images.py) ─────────────
USE_AUGMENTATION = True

# ── Mixed precision ────────────────────────────────────────────────────────
USE_AMP = True        # Graphormer: --fp16

# ==========================================================================
print("\n".join([
    "Configuration loaded.",
    f"  Model         : {VIT_MODEL_NAME}",
    f"  Pretrained    : {USE_PRETRAINED}",
    f"  Epochs        : {EPOCHS}",
    f"  Batch size    : {BATCH_SIZE}",
    f"  LR            : {LR}",
    f"  Weight decay  : {WEIGHT_DECAY}",
    f"  AMP (FP16)    : {USE_AMP}",
    f"  Augmentation  : {USE_AUGMENTATION}",
    f"  Seed          : {SEED}",
]))

## 2. Imports & Reproducibility

In [None]:
import os
import json
import time
import random
import warnings
from pathlib import Path
from collections import Counter

import numpy as np
from PIL import Image
from sklearn.model_selection import train_test_split

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.cuda.amp import GradScaler, autocast
from torch.utils.data import Dataset, DataLoader, Subset
import torchvision.transforms as T
import timm

# Silence every warning for tidier notebook output (note: this also hides
# deprecation and scheduler-ordering warnings from PyTorch).
warnings.filterwarnings('ignore')


def set_seed(seed: int):
    """Seed Python, NumPy and PyTorch (CPU and all CUDA devices) and make cuDNN deterministic.

    Also exports ``PYTHONHASHSEED``. That variable only influences Python
    processes started afterwards; hash randomisation of the interpreter that is
    already running was fixed at its startup.

    Args:
        seed: value used for every RNG.
    """
    os.environ['PYTHONHASHSEED'] = str(seed)
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    # Disable autotuning and request deterministic cuDNN kernels.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True

set_seed(SEED)

# Run on the first CUDA GPU when one is visible, otherwise on the CPU.
DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(f"Device : {DEVICE}")
if DEVICE.type == 'cuda':
    print(f"GPU    : {torch.cuda.get_device_name(0)}")
    print(f"VRAM   : {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")

## 3. Dataset — Cùng split và augmentation với Graphormer

In [None]:
# ============================================================
# 3.1  Locate images and read split indices (as Graphormer does)
# ============================================================

CLASS_NAMES = ['BrownSpot', 'Healthy', 'Hispa', 'LeafBlast']
CLASS_TO_IDX = {name: idx for idx, name in enumerate(CLASS_NAMES)}


def find_images(image_dir: str) -> dict:
    """Collect image paths per class, following the folder layouts process_images.py reads.

    Two layouts are supported:
      * ``<image_dir>/LabelledRice/Labelled/<class>/`` (or ``<image_dir>/Labelled/<class>/``);
      * otherwise ``<image_dir>/RiceDiseaseDataset/{train,validation}/<class>/``,
        where both sub-splits are merged (train first, then validation).

    Within a folder, ``*.jpg`` matches come first, then ``*.jpeg``, then ``*.png``,
    each in ``Path.glob`` order. Patterns are matched case-sensitively on
    case-sensitive filesystems (e.g. ``.JPG`` files are not picked up).

    Args:
        image_dir: root folder of the raw dataset.

    Returns:
        Dict mapping every name in ``CLASS_NAMES`` to a list of path strings
        (empty when nothing was found for that class).
    """
    def _images_in(folder: Path) -> list:
        # One pass per extension, preserving the jpg → jpeg → png ordering.
        return [str(hit) for pattern in ('*.jpg', '*.jpeg', '*.png')
                for hit in folder.glob(pattern)]

    root = Path(image_dir)
    found = {name: [] for name in CLASS_NAMES}

    labelled = root / 'LabelledRice' / 'Labelled'
    if not labelled.exists():
        labelled = root / 'Labelled'

    if not labelled.exists():
        # Fallback layout: merge the train and validation folders.
        for split_name in ['train', 'validation']:
            split_root = root / 'RiceDiseaseDataset' / split_name
            if not split_root.exists():
                continue
            for name in CLASS_NAMES:
                class_root = split_root / name
                if class_root.exists():
                    found[name].extend(_images_in(class_root))
        return found

    for name in CLASS_NAMES:
        class_root = labelled / name
        if class_root.exists():
            found[name] = _images_in(class_root)
    return found


# Locate every image for the four classes.
data_dict = find_images(IMAGE_DIR)

# Flatten into parallel (path, label) lists in CLASS_NAMES order.
# Paths are sorted within each class: Path.glob yields files in filesystem
# order, which is not guaranteed to be stable across machines, and the seeded
# stratified split below assigns samples by position — without sorting, the
# "fixed-seed" split could differ between environments.
all_paths, all_labels = [], []
for cls in CLASS_NAMES:
    for p in sorted(data_dict[cls]):
        all_paths.append(p)
        all_labels.append(CLASS_TO_IDX[cls])

all_paths  = np.array(all_paths)
all_labels = np.array(all_labels)

print(f"Total images found: {len(all_paths)}")
if len(all_paths) == 0:
    # Only fatal when the split below has to be built from these images
    # (i.e. Graphormer's metadata.json is not available).
    print(f"  WARNING: no .jpg/.jpeg/.png images found under IMAGE_DIR={IMAGE_DIR!r}.")
for cls in CLASS_NAMES:
    n = (all_labels == CLASS_TO_IDX[cls]).sum()
    print(f"  {cls:12s}: {n}")

In [None]:
# ============================================================
# 3.2  Build the split — prefer Graphormer's split_indices.pt
#       so both models see the SAME train/val/test samples
# ============================================================

# Location of Graphormer's saved split, or None when no processed dir is configured.
split_file = None
if GRAPHORMER_PROCESSED_DIR:
    split_file = Path(GRAPHORMER_PROCESSED_DIR) / 'split_indices.pt'

def create_stratified_splits(n, labels, seed=42):
    """Split ``range(n)`` into stratified 70/15/15 train/val/test index tensors.

    Mirrors the split in rice_diseases_dataset.py: a stratified 30 % hold-out
    is carved off first, then halved (again stratified) into validation and
    test. The same ``seed`` drives both splits.

    Args:
        n: number of samples.
        labels: per-sample class labels; must support fancy indexing with an
            index array (e.g. a NumPy array), since the hold-out labels are
            selected as ``labels[holdout]``.
        seed: ``random_state`` passed to both ``train_test_split`` calls.

    Returns:
        ``(train_idx, val_idx, test_idx)`` as ``torch`` tensors.
    """
    every_index = np.arange(n)
    train_part, holdout = train_test_split(
        every_index, test_size=0.3, stratify=labels, random_state=seed
    )
    val_part, test_part = train_test_split(
        holdout, test_size=0.5, stratify=labels[holdout], random_state=seed
    )
    return tuple(torch.from_numpy(part) for part in (train_part, val_part, test_part))


if split_file and split_file.exists():
    # ── CASE 1: reuse Graphormer's saved split ──────────────────────────
    # split_indices.pt indexes Graphormer's processed-sample list, so the
    # indices are only meaningful together with the image order recorded in
    # metadata.json.
    print(f"Found Graphormer split file: {split_file}")

    meta_file = Path(GRAPHORMER_PROCESSED_DIR) / 'metadata.json'
    if meta_file.exists():
        with open(meta_file, encoding='utf-8') as f:
            meta = json.load(f)
        # Assumes meta['image_paths'] / meta['labels'] follow the same order as
        # the samples split_indices.pt refers to, and that labels are the
        # integer class ids of CLASS_NAMES — TODO confirm against process_images.py.
        graphormer_paths = np.array(meta['image_paths'])
        graphormer_labels = np.array(meta['labels'])
        if len(graphormer_paths) != len(graphormer_labels):
            raise ValueError(
                f"{meta_file} lists {len(graphormer_paths)} image_paths but "
                f"{len(graphormer_labels)} labels."
            )

        splits = torch.load(split_file)
        # Normalise to int64 tensors whatever sequence type was saved:
        # later cells call `.numpy()` on these.
        train_idx = torch.as_tensor(splits['train_idx'], dtype=torch.long)
        val_idx   = torch.as_tensor(splits['val_idx'], dtype=torch.long)
        test_idx  = torch.as_tensor(splits['test_idx'], dtype=torch.long)

        # Fail fast if the split does not fit this metadata (NumPy would
        # silently accept negative indices) or leaks samples across splits —
        # either would silently invalidate the comparison.
        n_meta = len(graphormer_paths)
        for split_name, split_idx in (('train', train_idx), ('val', val_idx), ('test', test_idx)):
            if split_idx.numel() and (int(split_idx.min()) < 0 or int(split_idx.max()) >= n_meta):
                raise ValueError(
                    f"{split_name}_idx in {split_file} is out of range for the "
                    f"{n_meta} images listed in {meta_file}."
                )
        train_set, val_set, test_set = (set(i.tolist()) for i in (train_idx, val_idx, test_idx))
        if (train_set & val_set) or (train_set & test_set) or (val_set & test_set):
            raise ValueError(f"Split indices in {split_file} overlap between train/val/test.")

        # The ViT pipeline uses Graphormer's paths/labels directly, so the
        # saved indices address exactly the same samples.
        vit_paths  = graphormer_paths
        vit_labels = graphormer_labels
        print("  ✓ Using Graphormer's exact image ordering and splits (strongest guarantee).")

    else:
        print("  WARNING: metadata.json not found. Cannot map Graphormer splits to image paths.")
        print("  Falling back to fresh stratified split (same algorithm, may differ in exact samples).")
        vit_paths, vit_labels = all_paths, all_labels
        train_idx, val_idx, test_idx = create_stratified_splits(len(vit_paths), vit_labels, SEED)

else:
    # ── CASE 2: build a fresh split with the same algorithm ─────────────
    print("No Graphormer split file found. Creating fresh 70/15/15 stratified split.")
    print("Algorithm is identical to rice_diseases_dataset.py.")
    vit_paths, vit_labels = all_paths, all_labels
    train_idx, val_idx, test_idx = create_stratified_splits(len(vit_paths), vit_labels, SEED)


print(f"\nSplit sizes:")
print(f"  Train : {len(train_idx)}")
print(f"  Val   : {len(val_idx)}")
print(f"  Test  : {len(test_idx)}")

# Per-class distribution in train set
train_labels_arr = vit_labels[train_idx.numpy()]
print("\nTrain class distribution:")
for cls in CLASS_NAMES:
    n = (train_labels_arr == CLASS_TO_IDX[cls]).sum()
    print(f"  {cls:12s}: {n}")

In [None]:
# ============================================================
# 3.3  Transforms — augmentation modelled on process_images.py
# ============================================================

# Graphormer augmentations: rotate_90, rotate_180, rotate_270,
#   flip_horizontal, flip_vertical, zoom_in (crop 80%), zoom_out (pad)
# → approximated with torchvision random transforms in build_train_transform
#   (NOTE(review): Graphormer's exact zoom parameters are not visible here — confirm).

IMAGENET_MEAN = [0.485, 0.456, 0.406]   # ImageNet per-channel mean (R, G, B)
IMAGENET_STD  = [0.229, 0.224, 0.225]   # ImageNet per-channel std (R, G, B)


def build_train_transform(img_size: int, augment: bool) -> T.Compose:
    """Build the training pipeline: square resize, optional augmentation, tensor + ImageNet normalisation.

    When ``augment`` is True, the augmentations approximate Graphormer's
    rotate_90/180/270, flip_horizontal/vertical, zoom_in and zoom_out:

      * exactly one rotation drawn uniformly from 0/90/180/270 degrees
        (lossless, because the image is square after the resize);
      * independent horizontal and vertical flips (p=0.5 each);
      * with p=0.5, either a zoom-in (square crop covering 64–80 % of the area,
        resized back to ``img_size``) or a zoom-out (content shrunk to 80 %
        around the centre, borders filled with 0), chosen uniformly.

    No colour jitter is applied because Graphormer does not use it.

    Args:
        img_size: output side length in pixels.
        augment: whether to insert the random augmentations.

    Returns:
        A ``T.Compose`` producing a normalised float tensor.
    """
    ops = [T.Resize((img_size, img_size))]

    if augment:
        # One uniformly chosen rotation. Three independent RandomApply(p=0.33)
        # rotations compose into a skewed distribution (~37 % net 0°, ~18 % net 180°).
        rotations = [T.RandomRotation(degrees=(angle, angle)) for angle in (0, 90, 180, 270)]
        zooms = [
            # Zoom-in: square random crop of 64–80 % of the area, resized back.
            T.RandomResizedCrop(img_size, scale=(0.64, 0.80), ratio=(1.0, 1.0)),
            # Zoom-out: scale content to 80 % and pad the border (fill=0 default).
            T.RandomAffine(degrees=0, scale=(0.80, 0.80)),
        ]
        ops += [
            T.RandomChoice(rotations),
            T.RandomHorizontalFlip(p=0.5),
            T.RandomVerticalFlip(p=0.5),
            T.RandomApply([T.RandomChoice(zooms)], p=0.5),
        ]

    ops += [
        T.ToTensor(),
        T.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
    ]
    return T.Compose(ops)


def build_eval_transform(img_size: int) -> T.Compose:
    """Return the deterministic validation/test pipeline.

    Resizes to ``img_size × img_size``, converts to a tensor and applies
    ImageNet normalisation — no randomness, so evaluation is repeatable.
    """
    resize = T.Resize((img_size, img_size))
    to_tensor = T.ToTensor()
    normalize = T.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD)
    return T.Compose([resize, to_tensor, normalize])


train_transform, eval_transform = (
    build_train_transform(IMG_SIZE, USE_AUGMENTATION),
    build_eval_transform(IMG_SIZE),
)

# One print call; sep="\n" reproduces the original four-line output exactly.
print("Train transform:", train_transform, "\nEval transform:", eval_transform, sep="\n")

In [None]:
# ============================================================
# 3.4  Dataset class
# ============================================================

class RiceDiseaseImageDataset(Dataset):
    """Image-classification dataset that reads raw images from disk on demand.

    Item ``i`` is ``(transform(RGB image at paths[i]), int(labels[i]))``; without
    a transform the PIL image itself is returned. In this notebook instances are
    built from the same train/val/test indices as the Graphormer pipeline.

    Args:
        paths: image file paths, one per sample.
        labels: class ids aligned with ``paths``.
        transform: optional callable applied to the PIL image.

    Raises:
        ValueError: if ``paths`` and ``labels`` differ in length.
    """

    def __init__(self, paths: np.ndarray, labels: np.ndarray, transform=None):
        if len(paths) != len(labels):
            raise ValueError(
                f"paths and labels must have the same length "
                f"(got {len(paths)} and {len(labels)})"
            )
        self.paths = paths
        self.labels = labels
        self.transform = transform

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, idx):
        # The context manager closes the file deterministically; convert()
        # reads the pixel data into a new image before the file is closed.
        with Image.open(self.paths[idx]) as raw:
            img = raw.convert('RGB')
        if self.transform is not None:
            img = self.transform(img)
        label = int(self.labels[idx])
        return img, label


# One dataset per split: training samples use the (optionally augmenting)
# train transform, validation/test samples the deterministic eval transform.
train_dataset = RiceDiseaseImageDataset(
    vit_paths[train_idx.numpy()], vit_labels[train_idx.numpy()], transform=train_transform
)
val_dataset = RiceDiseaseImageDataset(
    vit_paths[val_idx.numpy()], vit_labels[val_idx.numpy()], transform=eval_transform
)
test_dataset = RiceDiseaseImageDataset(
    vit_paths[test_idx.numpy()], vit_labels[test_idx.numpy()], transform=eval_transform
)

# Pinned host memory only speeds up host→GPU copies; skip it on CPU-only runs.
# The training loader gets its own seeded generator: it drives the shuffle
# order and the workers' base seed, so data order and augmentation no longer
# depend on how many global RNG draws happened earlier (e.g. a different
# model's weight initialisation) — keeping data identical across compared models.
train_loader = DataLoader(
    train_dataset, batch_size=BATCH_SIZE, shuffle=True,
    num_workers=NUM_WORKERS, pin_memory=(DEVICE.type == 'cuda'), drop_last=False,
    generator=torch.Generator().manual_seed(SEED),
)
val_loader = DataLoader(
    val_dataset, batch_size=BATCH_SIZE, shuffle=False,
    num_workers=NUM_WORKERS, pin_memory=(DEVICE.type == 'cuda')
)
test_loader = DataLoader(
    test_dataset, batch_size=BATCH_SIZE, shuffle=False,
    num_workers=NUM_WORKERS, pin_memory=(DEVICE.type == 'cuda')
)

print(f"Train batches: {len(train_loader)} | Val batches: {len(val_loader)} | Test batches: {len(test_loader)}")

## 4. Model — ViT via `timm`

In [None]:
# ============================================================
#  Choosing the ViT variant
# ============================================================
#
#  Graphormer slim (fairseq): 12 layers, embed=256, ffn=512, heads=8
#   → estimated ~6–10M params
#
#  ViT-Small patch16: 12 layers, embed=384, heads=6 → ~22M params
#  ViT-Tiny  patch16: 12 layers, embed=192, heads=3 → ~5.7M params
#
#  Recommendations:
#  • USE_PRETRAINED=False keeps both models trained from scratch. By the
#    estimates above ViT-Tiny is the closer match in parameter count; the
#    default ViT-Small is roughly 2–4× larger than Graphormer slim.
#  • To showcase transfer learning instead, set USE_PRETRAINED=True and
#    state it explicitly in the paper.

model = timm.create_model(
    VIT_MODEL_NAME,
    num_classes=len(CLASS_NAMES),
    pretrained=USE_PRETRAINED,
    img_size=IMG_SIZE,
).to(DEVICE)

# Only parameters that will actually be optimised are counted.
n_params = sum(param.numel() for param in model.parameters() if param.requires_grad)
print(f"Model      : {VIT_MODEL_NAME}")
print(f"Pretrained : {USE_PRETRAINED}")
print(f"Parameters : {n_params:,}")
print(f"Num classes: {len(CLASS_NAMES)}: {CLASS_NAMES}")

## 5. Optimizer & Scheduler — Giống hệt Graphormer

In [None]:
# AdamW with Graphormer's (fairseq) settings: betas=(0.9, 0.999), eps=1e-8.
# Weight decay is applied to every parameter (biases and norm weights included).
optimizer = torch.optim.AdamW(
    model.parameters(), lr=LR, betas=(0.9, 0.999), eps=1e-8, weight_decay=WEIGHT_DECAY
)

# Linear warmup followed by polynomial decay with power 1 (i.e. linear).
# Graphormer reference: --lr-scheduler polynomial_decay --power 1
#                       --warmup-updates 1000 --total-num-update 50000
# Here the warmup is capped at 10 % of this run's total optimiser steps, so it
# can be shorter than Graphormer's 1000 updates on a small dataset.
total_steps  = EPOCHS * len(train_loader)
warmup_steps = min(WARMUP_STEPS, total_steps // 10)


def lr_lambda(current_step: int):
    """Return the multiplier on the base LR for 0-based optimiser step ``current_step``.

    Rises linearly from 0 to 1 over ``warmup_steps``, then falls linearly to 0
    at ``total_steps`` and is clamped at 0 afterwards.
    """
    if current_step >= warmup_steps:
        decay_span = float(max(1, total_steps - warmup_steps))
        return max(0.0, 1.0 - (current_step - warmup_steps) / decay_span)
    return float(current_step) / float(max(1, warmup_steps))


scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)

# Loss scaler for mixed precision (a pass-through when AMP is disabled).
scaler = GradScaler(enabled=USE_AMP)

print(f"Optimizer    : AdamW(lr={LR}, wd={WEIGHT_DECAY}, betas=(0.9, 0.999))")
print(f"Scheduler    : Polynomial decay (linear), warmup={warmup_steps} steps")
print(f"Total steps  : {total_steps}")
print(f"AMP enabled  : {USE_AMP}")

## 6. Training & Evaluation Functions

In [None]:
def train_one_epoch(model, loader, optimizer, scheduler, scaler, device,
                    use_amp=None, clip_norm=None):
    """Train ``model`` for one pass over ``loader``.

    Each batch: forward under autocast, scaled backward, unscale, clip the
    global gradient norm, optimiser step through ``scaler``, then one scheduler
    step — but only if the optimiser really stepped.

    Args:
        model: classifier returning logits of shape ``(batch, num_classes)``.
        loader: iterable of ``(images, labels)`` batches.
        optimizer: optimiser stepped through ``scaler``.
        scheduler: per-update LR scheduler.
        scaler: ``GradScaler`` (may be disabled).
        device: device the batches are moved to.
        use_amp: enable autocast; ``None`` uses the global ``USE_AMP``.
        clip_norm: max global gradient norm; ``None`` uses the global ``CLIP_NORM``.

    Returns:
        ``(mean_loss, accuracy)`` averaged over all samples seen.

    Raises:
        ValueError: if ``loader`` yields no samples.
    """
    if use_amp is None:
        use_amp = USE_AMP
    if clip_norm is None:
        clip_norm = CLIP_NORM

    model.train()
    total_loss = 0.0
    total_correct = total = 0

    for imgs, labels in loader:
        imgs   = imgs.to(device, non_blocking=True)
        labels = labels.to(device, non_blocking=True)

        optimizer.zero_grad(set_to_none=True)

        with autocast(enabled=use_amp):
            logits = model(imgs)                      # (B, num_classes)
            loss   = F.cross_entropy(logits, labels)

        scaler.scale(loss).backward()
        # Unscale first so the clipping threshold (Graphormer: --clip-norm 5.0)
        # applies to the true gradients, not the loss-scaled ones.
        scaler.unscale_(optimizer)
        nn.utils.clip_grad_norm_(model.parameters(), clip_norm)

        scale_before = scaler.get_scale()
        scaler.step(optimizer)   # skipped internally when grads contain inf/NaN
        scaler.update()          # lowers the scale after a skipped step
        # Advance the LR schedule only for updates that actually happened, so
        # fp16 overflow steps do not consume warmup/decay budget (fairseq's
        # trainer likewise counts only non-overflow updates).
        if scaler.get_scale() >= scale_before:
            scheduler.step()

        preds = logits.argmax(dim=-1)
        total_loss    += loss.item() * labels.size(0)
        total_correct += (preds == labels).sum().item()
        total         += labels.size(0)

    if total == 0:
        raise ValueError("train_one_epoch() received an empty loader; nothing to train on.")
    return total_loss / total, total_correct / total


@torch.no_grad()
def evaluate(model, loader, device, use_amp=None):
    """Compute loss and accuracy of ``model`` over every sample in ``loader``.

    Runs in eval mode without gradient tracking. Each batch's mean
    cross-entropy is weighted by its batch size, so the result is the
    per-sample average loss.

    Args:
        model: classifier returning logits of shape ``(batch, num_classes)``.
        loader: iterable of ``(images, labels)`` batches.
        device: device the batches are moved to.
        use_amp: enable autocast; ``None`` uses the global ``USE_AMP``.

    Returns:
        ``(mean_loss, accuracy)``.

    Raises:
        ValueError: if ``loader`` yields no samples.
    """
    if use_amp is None:
        use_amp = USE_AMP

    model.eval()
    total_loss = 0.0
    total_correct = total = 0

    for imgs, labels in loader:
        imgs   = imgs.to(device, non_blocking=True)
        labels = labels.to(device, non_blocking=True)

        with autocast(enabled=use_amp):
            logits = model(imgs)
            loss   = F.cross_entropy(logits, labels)

        preds = logits.argmax(dim=-1)
        total_loss    += loss.item() * labels.size(0)
        total_correct += (preds == labels).sum().item()
        total         += labels.size(0)

    if total == 0:
        raise ValueError("evaluate() received an empty loader; cannot compute metrics.")
    return total_loss / total, total_correct / total


print("Training functions ready.")

## 7. Training Loop

In [None]:
import os
os.makedirs(SAVE_DIR, exist_ok=True)

best_val_acc = 0.0
history = {'train_loss': [], 'train_acc': [], 'val_loss': [], 'val_acc': [], 'lr': []}

# Column widths and two-space separators match the per-epoch row format below,
# so the header lines up with the numbers.
header = (f"{'Epoch':>6}  {'TrainLoss':>10}  {'TrainAcc':>9}  {'ValLoss':>9}  "
          f"{'ValAcc':>8}  {'LR':>8}  {'Time':>6}")
sep    = '─' * len(header)
print(sep)
print(header)
print(sep)

for epoch in range(1, EPOCHS + 1):
    t0 = time.time()

    tr_loss, tr_acc = train_one_epoch(model, train_loader, optimizer, scheduler, scaler, DEVICE)
    val_loss, val_acc = evaluate(model, val_loader, DEVICE)

    # LR in effect after this epoch's last scheduler step.
    current_lr = scheduler.get_last_lr()[0]
    elapsed = time.time() - t0

    # Log
    history['train_loss'].append(tr_loss)
    history['train_acc'].append(tr_acc)
    history['val_loss'].append(val_loss)
    history['val_acc'].append(val_acc)
    history['lr'].append(current_lr)

    print(f"{epoch:6d}  {tr_loss:10.4f}  {tr_acc*100:8.2f}%  {val_loss:9.4f}  {val_acc*100:7.2f}%  "
          f"{current_lr:8.2e}  {elapsed:5.1f}s")

    # Keep the checkpoint with the highest validation accuracy (earliest wins ties).
    if val_acc > best_val_acc:
        best_val_acc = val_acc
        ckpt = {
            'epoch': epoch,
            'model_state': model.state_dict(),
            'optimizer_state': optimizer.state_dict(),
            'val_acc': val_acc,
            'config': {
                'model_name': VIT_MODEL_NAME,
                'pretrained': USE_PRETRAINED,
                'seed': SEED,
                'epochs': EPOCHS,
                'batch_size': BATCH_SIZE,
                'lr': LR,
                'weight_decay': WEIGHT_DECAY,
                'augmentation': USE_AUGMENTATION,
                'amp': USE_AMP,
                'clip_norm': CLIP_NORM,
                'warmup_steps': warmup_steps,
                'img_size': IMG_SIZE,
            }
        }
        torch.save(ckpt, os.path.join(SAVE_DIR, 'best_model.pt'))
        print(f"         ✓ Best checkpoint saved (val_acc={val_acc*100:.2f}%)")

print(sep)
print(f"Training complete. Best val acc: {best_val_acc*100:.2f}%")

## 8. Final Test Evaluation

In [None]:
# Load the best-validation checkpoint, if training produced one.
best_path = os.path.join(SAVE_DIR, 'best_model.pt')
if os.path.exists(best_path):
    ckpt = torch.load(best_path, map_location=DEVICE)
    model.load_state_dict(ckpt['model_state'])
    print(f"Loaded best model from epoch {ckpt['epoch']} (val_acc={ckpt['val_acc']*100:.2f}%)")
else:
    # Without this notice the results below would silently describe whatever
    # weights are currently in memory rather than the best checkpoint.
    print(f"WARNING: {best_path} not found; evaluating the current in-memory model weights.")

test_loss, test_acc = evaluate(model, test_loader, DEVICE)

print(f"\n{'='*50}")
print(f"  FINAL TEST RESULTS")
print(f"{'='*50}")
print(f"  Test Loss    : {test_loss:.4f}")
print(f"  Test Acc     : {test_acc*100:.2f}%")
print(f"  Best Val Acc : {best_val_acc*100:.2f}%")

# Collect test predictions for per-class accuracy and the confusion matrix.
model.eval()
all_preds, all_targets = [], []

with torch.no_grad():
    for imgs, labels in test_loader:
        imgs = imgs.to(DEVICE, non_blocking=True)
        with autocast(enabled=USE_AMP):
            logits = model(imgs)
        preds = logits.argmax(dim=-1).cpu()
        all_preds.extend(preds.tolist())
        all_targets.extend(labels.tolist())

all_preds   = np.array(all_preds)
all_targets = np.array(all_targets)

print(f"\n  Per-class accuracy:")
print(f"  {'Class':15s}  {'Correct':>7}  {'Total':>7}  {'Acc':>8}")
print(f"  {'-'*42}")
for c, cls_name in enumerate(CLASS_NAMES):
    mask = all_targets == c
    correct = (all_preds[mask] == c).sum()
    total_c = mask.sum()
    acc_c   = correct / total_c * 100 if total_c > 0 else 0.0
    print(f"  {cls_name:15s}  {correct:7d}  {total_c:7d}  {acc_c:7.2f}%")
print(f"{'='*50}")

## 9. Confusion Matrix

In [None]:
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix, classification_report
import seaborn as sns

# Pin the label set to every class id. Otherwise, if some class never appears
# in the test targets or predictions, sklearn shrinks the matrix/report and the
# CLASS_NAMES tick labels / target_names no longer line up (or raise).
label_ids = list(range(len(CLASS_NAMES)))

cm = confusion_matrix(all_targets, all_preds, labels=label_ids)
fig, ax = plt.subplots(figsize=(6, 5))
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
            xticklabels=CLASS_NAMES, yticklabels=CLASS_NAMES, ax=ax)
ax.set_xlabel('Predicted')
ax.set_ylabel('True')
ax.set_title(f'ViT Confusion Matrix — Test Set\n({VIT_MODEL_NAME}, pretrained={USE_PRETRAINED})')
plt.tight_layout()
plt.savefig(os.path.join(SAVE_DIR, 'confusion_matrix.png'), dpi=150)
plt.show()

print("\nClassification Report:")
print(classification_report(all_targets, all_preds, labels=label_ids,
                            target_names=CLASS_NAMES, digits=4, zero_division=0))

## 10. Learning Curves

In [None]:
# Derive the x-axis from what was actually recorded, so the plots still work if
# training was interrupted or EPOCHS was edited after the training cell ran.
epochs_range = range(1, len(history['train_loss']) + 1)

fig, axes = plt.subplots(1, 3, figsize=(15, 4))

# Loss
axes[0].plot(epochs_range, history['train_loss'], label='Train', color='steelblue')
axes[0].plot(epochs_range, history['val_loss'],   label='Val',   color='orangered')
axes[0].set_title('Loss')
axes[0].set_xlabel('Epoch')
axes[0].set_ylabel('Cross-Entropy Loss')
axes[0].legend()
axes[0].grid(True, alpha=0.3)

# Accuracy
axes[1].plot(epochs_range, [a*100 for a in history['train_acc']], label='Train', color='steelblue')
axes[1].plot(epochs_range, [a*100 for a in history['val_acc']],   label='Val',   color='orangered')
axes[1].axhline(y=best_val_acc*100, color='green', linestyle='--', alpha=0.7, label=f'Best Val={best_val_acc*100:.1f}%')
axes[1].set_title('Accuracy')
axes[1].set_xlabel('Epoch')
axes[1].set_ylabel('Accuracy (%)')
axes[1].legend()
axes[1].grid(True, alpha=0.3)

# Learning rate (value recorded at the end of each epoch)
axes[2].plot(epochs_range, history['lr'], color='purple')
axes[2].set_title('Learning Rate (per epoch)')
axes[2].set_xlabel('Epoch')
axes[2].set_ylabel('LR')
axes[2].grid(True, alpha=0.3)

plt.suptitle(f'ViT Training — {VIT_MODEL_NAME} | pretrained={USE_PRETRAINED} | aug={USE_AUGMENTATION}', y=1.02)
plt.tight_layout()
plt.savefig(os.path.join(SAVE_DIR, 'learning_curves.png'), dpi=150, bbox_inches='tight')
plt.show()

## 11. Summary — Thông tin cho Paper

In [None]:
# Emit the whole summary in one call; each entry is one output line, so the
# printed text is the same as printing the lines one by one.
print("\n".join([
    "=" * 60,
    "  EXPERIMENT SUMMARY (for paper)",
    "=" * 60,
    f"  Model             : {VIT_MODEL_NAME}",
    f"  Pretrained        : {USE_PRETRAINED}",
    f"  Parameters        : {n_params:,}",
    f"  Image size        : {IMG_SIZE}×{IMG_SIZE}",
    f"  Batch size        : {BATCH_SIZE}",
    f"  Learning rate     : {LR}",
    f"  Weight decay      : {WEIGHT_DECAY}",
    f"  LR schedule       : Linear warmup ({warmup_steps} steps) + Polynomial decay",
    f"  Clip norm         : {CLIP_NORM}",
    f"  Mixed precision   : {USE_AMP}",
    f"  Augmentation      : {USE_AUGMENTATION}",
    f"  Epochs            : {EPOCHS}",
    f"  Seed              : {SEED}",
    f"  Dataset split     : 70/15/15 stratified",
    f"  Train / Val / Test: {len(train_dataset)} / {len(val_dataset)} / {len(test_dataset)}",
    f"  Classes           : {CLASS_NAMES}",
    "─" * 60,
    f"  Best Val Acc      : {best_val_acc*100:.2f}%",
    f"  Test Acc          : {test_acc*100:.2f}%",
    f"  Test Loss         : {test_loss:.4f}",
    "=" * 60,
]))