# 02 — Model Training

In this notebook we:
1. Load HAM10000 with our custom PyTorch Dataset
2. Apply data augmentation (albumentations)
3. Fine-tune **EfficientNet-B0** using transfer learning
4. Handle class imbalance with weighted loss
5. Train with early stopping and save the best model

> Training on **CPU** — using 128x128 images and a lightweight model to keep it manageable.

---

In [None]:
import sys
sys.path.append('..')

import torch
import pandas as pd
import numpy as np
import matplotlib
matplotlib.use('TkAgg')  # For interactive plots on Windows
import matplotlib.pyplot as plt
from pathlib import Path

from src.config import (
    DATA_DIR, MODELS_DIR, RESULTS_DIR, SEED,
    CLASS_NAMES, CLASS_LABELS, NUM_CLASSES,
    IMAGE_SIZE, BATCH_SIZE, LEARNING_RATE, NUM_EPOCHS,
    MODEL_NAME, EARLY_STOPPING_PATIENCE,
)

print(f'PyTorch: {torch.__version__}')
print(f'Device: {"cuda" if torch.cuda.is_available() else "cpu"}')
print(f'Model: {MODEL_NAME}')
print(f'Image size: {IMAGE_SIZE}x{IMAGE_SIZE}')
print(f'Batch size: {BATCH_SIZE}')
print(f'Epochs: {NUM_EPOCHS}')
print(f'Learning rate: {LEARNING_RATE}')
print()
print('Setup complete!')

---
## 1. Prepare the Data

We'll use our custom `HAM10000Dataset` class together with:
- **Stratified splitting**: keeps the class proportions roughly equal across train/val/test, so all 7 classes appear in every split
- **Data augmentation**: flips, rotations, color jitter (train only)
- **Weighted loss**: class weights computed from the training-set class frequencies to handle imbalance
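
As a rough illustration of how weights can be derived from class frequencies, the sketch below uses the common inverse-frequency formula `w_c = N / (K * n_c)`, where `N` is the number of training samples, `K` the number of classes, and `n_c` the count for class `c`. The counts are made up for the example (they are not the real HAM10000 split), and `inverse_frequency_weights` is a hypothetical helper, not part of `src`.

```python
def inverse_frequency_weights(counts):
    """Return {class: N / (K * n_c)} for a dict of per-class sample counts."""
    total = sum(counts.values())      # N: all training samples
    num_classes = len(counts)         # K: number of classes
    return {cls: total / (num_classes * n) for cls, n in counts.items()}


# Made-up counts, for illustration only
toy_counts = {'nv': 700, 'mel': 200, 'df': 100}
toy_weights = inverse_frequency_weights(toy_counts)

for cls, w in toy_weights.items():
    print(f'{cls}: {w:.2f}')
```

With this formula every class contributes the same total weight, `n_c * w_c = N / K`, so a rare class is not drowned out by a frequent one.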

In [None]:
from src.dataset import prepare_dataloaders

# Build image path map (images are split across two folders)
# We need to consolidate the path for the dataset class
image_dirs = [
    DATA_DIR / 'HAM10000_images_part_1',
    DATA_DIR / 'HAM10000_images_part_2',
]

# Load metadata
metadata_path = DATA_DIR / 'HAM10000_metadata.csv'
df = pd.read_csv(metadata_path)

# Build a path lookup (image_id -> file) to check that the images are on disk
image_path_map = {}
for d in image_dirs:
    for f in d.iterdir():
        if f.suffix == '.jpg':
            image_path_map[f.stem] = f

print(f'Found {len(image_path_map)} images')
print(f'Metadata rows: {len(df)}')

In [None]:
from sklearn.model_selection import train_test_split
from src.dataset import HAM10000Dataset, get_transforms

# Stratified split: 70% train, 15% val, 15% test
train_val_df, test_df = train_test_split(
    df, test_size=0.15, stratify=df['dx'], random_state=SEED
)
train_df, val_df = train_test_split(
    train_val_df, test_size=0.176, stratify=train_val_df['dx'], random_state=SEED
    # 0.176 of 85% ≈ 15% of total
)

print(f'Train: {len(train_df)} ({len(train_df)/len(df)*100:.1f}%)')
print(f'Val:   {len(val_df)} ({len(val_df)/len(df)*100:.1f}%)')
print(f'Test:  {len(test_df)} ({len(test_df)/len(df)*100:.1f}%)')
print()

# Verify stratification
print('Class distribution check (should be similar across splits):')
for cls in CLASS_NAMES:
    tr = (train_df['dx'] == cls).sum() / len(train_df) * 100
    va = (val_df['dx'] == cls).sum() / len(val_df) * 100
    te = (test_df['dx'] == cls).sum() / len(test_df) * 100
    print(f'  {CLASS_LABELS[cls]:30s}  train:{tr:5.1f}%  val:{va:5.1f}%  test:{te:5.1f}%')

In [None]:
from torch.utils.data import DataLoader

# Create datasets with our custom class
# We pass image_dirs so the dataset can find images across both folders
train_dataset = HAM10000Dataset(train_df, image_dirs=image_dirs, transform=get_transforms('train'))
val_dataset = HAM10000Dataset(val_df, image_dirs=image_dirs, transform=get_transforms('val'))
test_dataset = HAM10000Dataset(test_df, image_dirs=image_dirs, transform=get_transforms('test'))

# Create dataloaders
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=0)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=0)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=0)

# Test a batch
images, labels = next(iter(train_loader))
print(f'Batch shape: {images.shape}')  # Expected: [BATCH_SIZE, 3, IMAGE_SIZE, IMAGE_SIZE], e.g. [16, 3, 128, 128]
print(f'Labels shape: {labels.shape}')
print(f'Label values: {labels.tolist()}')

In [None]:
# Compute class weights for imbalanced data
label_counts = train_df['dx'].value_counts()
total = len(train_df)
class_weights = torch.tensor(
    [total / (NUM_CLASSES * label_counts.get(c, 1)) for c in CLASS_NAMES],
    dtype=torch.float32,
)

print('Class weights (higher = rarer class gets more attention):')
for cls, w in zip(CLASS_NAMES, class_weights):
    print(f'  {CLASS_LABELS[cls]:30s}: {w:.2f}')

### Visualize Augmented Samples

Let's look at what the augmentation does to the images. Checking a few samples by eye confirms that the transforms are actually applied and that the images still look plausible.
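
For reference, undoing normalization could look like the sketch below. It assumes the transforms normalize with the standard ImageNet mean and standard deviation and return channel-first arrays; `denormalize_sketch` is a hypothetical stand-in, and the actual `src.preprocessing.denormalize` may differ.

```python
import numpy as np

# Standard ImageNet statistics (assumed to match the training transforms)
IMAGENET_MEAN = np.array([0.485, 0.456, 0.406], dtype=np.float32)
IMAGENET_STD = np.array([0.229, 0.224, 0.225], dtype=np.float32)


def denormalize_sketch(img_chw):
    """Map a normalized (C, H, W) array back to an (H, W, C) image in [0, 1]."""
    img = np.asarray(img_chw, dtype=np.float32)
    img = img.transpose(1, 2, 0)                # channel-first -> channel-last for imshow
    img = img * IMAGENET_STD + IMAGENET_MEAN    # invert (x - mean) / std
    return np.clip(img, 0.0, 1.0)               # guard against values just outside [0, 1]


# Round trip on a random image: normalize, then denormalize
rng = np.random.default_rng(0)
original = rng.random((128, 128, 3), dtype=np.float32)
normalized = ((original - IMAGENET_MEAN) / IMAGENET_STD).transpose(2, 0, 1)
restored = denormalize_sketch(normalized)

print(restored.shape, float(np.abs(restored - original).max()))
```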

In [None]:
from src.preprocessing import denormalize

# Show a few augmented training samples
fig, axes = plt.subplots(2, 5, figsize=(15, 6))

for i in range(10):
    img, label = train_dataset[i]
    img_display = denormalize(img)  # Undo normalization for display
    ax = axes[i // 5, i % 5]
    ax.imshow(img_display)
    ax.set_title(CLASS_NAMES[label], fontsize=10)
    ax.axis('off')

plt.suptitle('Augmented Training Samples', fontsize=14, fontweight='bold')
plt.tight_layout()
plt.show()

---
## 2. Build the Model

We use **EfficientNet-B0** pretrained on ImageNet:
1. First, **freeze** the base and only train the classifier head (fast)
2. Then, **unfreeze** everything and fine-tune with a lower learning rate
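
The freeze/unfreeze mechanics can be sketched on a toy network. Everything here is an assumption made for illustration: `TinyNet` stands in for EfficientNet-B0, the base is assumed to live under `features` and the head under `classifier`, and the `*_sketch` helpers are hypothetical, so the real `src.model` functions may work differently.

```python
import torch.nn as nn


class TinyNet(nn.Module):
    """Toy stand-in for a pretrained backbone plus classifier head."""

    def __init__(self, num_classes=7):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 8, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
        )
        self.classifier = nn.Linear(8, num_classes)

    def forward(self, x):
        return self.classifier(self.features(x))


def freeze_base_sketch(model):
    """Stop gradient updates for the base; the head stays trainable."""
    for param in model.features.parameters():
        param.requires_grad = False
    return model


def unfreeze_all_sketch(model):
    """Make every parameter trainable again."""
    for param in model.parameters():
        param.requires_grad = True
    return model


def count_trainable(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)


toy_model = TinyNet()
total_params = count_trainable(toy_model)
head_only_params = count_trainable(freeze_base_sketch(toy_model))
after_unfreeze_params = count_trainable(unfreeze_all_sketch(toy_model))
print(total_params, head_only_params, after_unfreeze_params)
```

Comparing the trainable-parameter count before and after freezing is a quick sanity check that only the head will be updated in the first phase.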

In [None]:
from src.model import create_model, freeze_base, unfreeze_all, get_model_summary

# Create the model
model = create_model(MODEL_NAME, NUM_CLASSES, pretrained=True)
print(f'Model: {MODEL_NAME}')
print(f'Output classes: {NUM_CLASSES}')
print()

# Freeze base layers — only train the classifier head first
model = freeze_base(model)
print('After freezing base:')
get_model_summary(model)

---
## 3. Train — Phase 1: Classifier Head Only

We first train only the classification head with the base frozen. This is fast, and it lets the newly added head adapt to the pretrained features before the base weights are updated.
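
Both training calls pass `class_weights` to `train_model`. How they are used inside `src.train` is not shown in this notebook; one plausible reading is that they feed the `weight` argument of `torch.nn.CrossEntropyLoss`. The sketch below shows the effect on made-up logits and weights.

```python
import torch
import torch.nn as nn

# Made-up weights: class 0 is common, class 1 is rare
toy_weights = torch.tensor([0.5, 5.0])

# Two samples that both lean towards class 0; the second is really class 1
logits = torch.tensor([[2.0, 0.0], [2.0, 0.0]])
targets = torch.tensor([0, 1])

plain_loss = nn.CrossEntropyLoss()(logits, targets)
weighted_loss = nn.CrossEntropyLoss(weight=toy_weights)(logits, targets)

# With reduction='mean', PyTorch divides by the sum of the target weights
per_sample = nn.CrossEntropyLoss(reduction='none')(logits, targets)
manual = (toy_weights[targets] * per_sample).sum() / toy_weights[targets].sum()

print(f'plain: {plain_loss.item():.4f}  weighted: {weighted_loss.item():.4f}')
```

The mistake on the rare class dominates the weighted loss, which is the intended effect when rare lesion types matter as much as common ones.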

In [None]:
from src.train import train_model, set_seed

set_seed(SEED)

print('=== PHASE 1: Training classifier head (base frozen) ===')
print('This should be fast even on CPU...\n')

model, history_phase1 = train_model(
    model=model,
    train_loader=train_loader,
    val_loader=val_loader,
    class_weights=class_weights,
    num_epochs=5,           # Just a few epochs for the head
    lr=1e-3,                # Higher LR since only training head
    patience=3,
    save_name='phase1_head.pth',
)

print('\nPhase 1 complete!')

---
## 4. Train — Phase 2: Full Fine-tuning

Now we unfreeze everything and fine-tune the entire model with a lower learning rate.
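
The `patience` argument controls early stopping. As a minimal sketch of the idea, assuming that validation loss is the monitored quantity, the hypothetical `EarlyStopper` below stops once the loss has failed to improve for `patience` consecutive epochs. The actual logic in `src.train` may differ.

```python
class EarlyStopper:
    """Track the best validation loss and signal when to stop."""

    def __init__(self, patience, min_delta=0.0):
        self.patience = patience
        self.min_delta = min_delta
        self.best_loss = float('inf')
        self.best_epoch = -1
        self.bad_epochs = 0

    def step(self, epoch, val_loss):
        """Record one epoch; return True when training should stop."""
        if val_loss < self.best_loss - self.min_delta:
            # Improvement: remember it (a checkpoint would be saved here)
            self.best_loss = val_loss
            self.best_epoch = epoch
            self.bad_epochs = 0
        else:
            self.bad_epochs += 1
        return self.bad_epochs >= self.patience


# Made-up validation losses
val_losses = [1.00, 0.80, 0.90, 0.85, 0.95, 0.70]
stopper = EarlyStopper(patience=3)
stopped_at = None
for epoch, loss in enumerate(val_losses):
    if stopper.step(epoch, loss):
        stopped_at = epoch
        break

print(f'stopped at epoch {stopped_at}, best epoch {stopper.best_epoch}')
```

In this toy run training stops before the final, lower loss is ever seen. That is the trade-off behind the choice of `patience`: a small value saves time but can stop before a late improvement.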

In [None]:
# Unfreeze all layers
model = unfreeze_all(model)
print('After unfreezing all layers:')
get_model_summary(model)

print()
print('=== PHASE 2: Full fine-tuning ===')
print('This will take longer on CPU — grab a coffee...\n')

model, history_phase2 = train_model(
    model=model,
    train_loader=train_loader,
    val_loader=val_loader,
    class_weights=class_weights,
    num_epochs=NUM_EPOCHS,
    lr=LEARNING_RATE,        # Lower LR for fine-tuning
    patience=EARLY_STOPPING_PATIENCE,
    save_name='best_model.pth',
)

print('\nPhase 2 complete!')

---
## 5. Training Curves

Let's visualize how training went across both phases.
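
When reading the combined curves it helps to know where phase 1 ends. Because early stopping can cut phase 1 short, the boundary is the length of the phase 1 history rather than a fixed epoch number. The hypothetical helper below (not part of `src`) locates the best validation loss and reports which phase it falls in, using made-up loss values.

```python
def locate_best_epoch(val_loss_phase1, val_loss_phase2):
    """Return (phase, epoch within that phase, loss) for the lowest validation loss."""
    combined = list(val_loss_phase1) + list(val_loss_phase2)
    best_idx = min(range(len(combined)), key=combined.__getitem__)
    boundary = len(val_loss_phase1)   # first index that belongs to phase 2
    if best_idx < boundary:
        return 'phase 1', best_idx + 1, combined[best_idx]
    return 'phase 2', best_idx - boundary + 1, combined[best_idx]


# Made-up losses: three head-only epochs, then three fine-tuning epochs
best = locate_best_epoch([1.2, 1.0, 0.9], [0.8, 0.6, 0.65])
print(best)
```

In the notebook this could be called as `locate_best_epoch(history_phase1['val_loss'], history_phase2['val_loss'])`.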

In [None]:
from src.evaluate import plot_training_history

# Combine histories from both phases
history = {
    key: history_phase1[key] + history_phase2[key]
    for key in ('train_loss', 'val_loss', 'train_acc', 'val_acc')
}

plot_training_history(history, save=True)
print(f'\nBest val accuracy: {max(history["val_acc"]):.4f}')
print(f'Best val loss: {min(history["val_loss"]):.4f}')

---
## 6. Quick Test Set Evaluation

Let's see how the trained model performs on the held-out test set.
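
With imbalanced classes, overall accuracy can look good even when a rare class is mostly missed, so it is worth reading it next to the per-class recall. The sketch below uses made-up labels and a hypothetical `per_class_recall` helper to show the gap between plain accuracy and balanced accuracy (the mean of the per-class recalls).

```python
import numpy as np


def per_class_recall(y_true, y_pred, num_classes):
    """Recall for each class; NaN for classes absent from y_true."""
    y_true = np.asarray(y_true)
    y_pred = np.asarray(y_pred)
    recalls = np.full(num_classes, np.nan)
    for c in range(num_classes):
        mask = y_true == c
        if mask.any():
            recalls[c] = (y_pred[mask] == c).mean()
    return recalls


# Made-up labels: 8 samples of a common class, 2 of a rare class
y_true_toy = np.array([0] * 8 + [1] * 2)
y_pred_toy = np.array([0] * 8 + [0, 1])   # one of the two rare samples is missed

accuracy_toy = (y_true_toy == y_pred_toy).mean()
recalls_toy = per_class_recall(y_true_toy, y_pred_toy, num_classes=2)
balanced_toy = np.nanmean(recalls_toy)

print(f'accuracy: {accuracy_toy:.2f}  balanced accuracy: {balanced_toy:.2f}')
```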

In [None]:
from src.evaluate import get_predictions, print_classification_report

device = torch.device('cpu')
model = model.to(device)

# Get predictions on test set
y_true, y_pred, y_probs = get_predictions(model, test_loader, device)

print('=== TEST SET RESULTS ===')
print()
print_classification_report(y_true, y_pred)

# Overall accuracy
accuracy = (np.asarray(y_true) == np.asarray(y_pred)).mean()  # asarray also handles plain lists
print(f'\nOverall test accuracy: {accuracy:.4f} ({accuracy*100:.1f}%)')

In [None]:
from src.evaluate import plot_confusion_matrix

plot_confusion_matrix(y_true, y_pred, save=True)

In [None]:
# Save summary metrics for later analysis
import json

test_results = {
    'accuracy': float(accuracy),
    'best_val_loss': float(min(history['val_loss'])),
    'best_val_acc': float(max(history['val_acc'])),
    'total_epochs': len(history['train_loss']),
    'model': MODEL_NAME,
    'image_size': IMAGE_SIZE,
}

results_path = RESULTS_DIR / 'training_results.json'
with open(results_path, 'w') as f:
    json.dump(test_results, f, indent=2)

print(f'Results saved to {results_path}')
print(json.dumps(test_results, indent=2))

---
## Summary

What we did:
1. Loaded 10,015 images with stratified 70/15/15 split
2. Applied augmentation (flips, rotations, color jitter) to training data
3. Used weighted cross-entropy loss to handle class imbalance
4. **Phase 1**: Trained classifier head only (base frozen) — fast warm-up
5. **Phase 2**: Fine-tuned entire model with lower learning rate
6. Early stopping saved the best model by validation loss

### Saved files:
- `models/best_model.pth` — best model weights
- `results/training_curves.png` — loss and accuracy plots
- `results/confusion_matrix.png` — test set confusion matrix
- `results/training_results.json` — metrics summary

### Next Steps

-> **03_evaluation.ipynb** — deeper evaluation with ROC curves

-> **04_gradcam.ipynb** — visualize what the model is looking at