# 02 — Model Training

In this notebook we:
1. Load HAM10000 with our custom PyTorch Dataset
2. Apply data augmentation (albumentations)
3. Fine-tune **EfficientNet-B0** using transfer learning
4. Handle class imbalance with weighted loss
5. Train with early stopping and save the best model

> Training on **CPU** — using 128x128 images and a lightweight model to keep it manageable.

---

In [None]:
import sys
sys.path.append('..')

import torch
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path

from src.config import (
    DATA_DIR, MODELS_DIR, RESULTS_DIR, SEED,
    CLASS_NAMES, CLASS_LABELS, NUM_CLASSES,
    IMAGE_SIZE, BATCH_SIZE, LEARNING_RATE, NUM_EPOCHS,
    MODEL_NAME, EARLY_STOPPING_PATIENCE,
)

print(f'PyTorch: {torch.__version__}')
print(f'Device: {"cuda" if torch.cuda.is_available() else "cpu"}')
print(f'Model: {MODEL_NAME}')
print(f'Image size: {IMAGE_SIZE}x{IMAGE_SIZE}')
print(f'Batch size: {BATCH_SIZE}')
print(f'Epochs: {NUM_EPOCHS}')
print(f'Learning rate: {LEARNING_RATE}')
print()
print('Setup complete!')

---
## 1. Prepare the Data

We need to do three things before training:

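1. **Find all images** — HAM10000 splits images across two folders (`part_1` and `part_2`), so any image must be findable by its ID no matter which folder holds it. The cell below checks both folders and counts the images; both folders are then passed to `HAM10000Dataset` through `image_dirs`.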

2. **Split the data** — We split into 70% train / 15% validation / 15% test, using **stratified splitting** (ensures each split has the same proportion of each skin condition).

3. **Handle class imbalance** — Since Melanocytic Nevus is 67% of the data, we compute **class weights** so the model pays more attention to rare conditions.
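A minimal sketch of an ID-to-path lookup across both folders, assuming each image is stored as `<image_id>.jpg`. `build_image_lookup` is a hypothetical helper for illustration; `HAM10000Dataset` may resolve paths differently.

```python
from pathlib import Path


def build_image_lookup(image_dirs):
    """Map image_id (file name without extension) -> full path, across all folders.

    Hypothetical helper: assumes every image is saved as '<image_id>.jpg'.
    """
    lookup = {}
    for d in image_dirs:
        d = Path(d)
        if not d.exists():
            continue  # skip folders that are missing instead of crashing
        for path in d.glob('*.jpg'):
            lookup[path.stem] = path
    return lookup


# Usage with the notebook's variables:
# lookup = build_image_lookup(image_dirs)
# missing = set(df['image_id']) - set(lookup)
# print(f'{len(lookup)} images indexed, {len(missing)} metadata rows without an image')
```

A dictionary makes each lookup O(1), which matters because the dataset resolves a path every time it loads an image.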

In [None]:
# The images are split across two folders — let's find them all
image_dirs = [
    DATA_DIR / 'HAM10000_images_part_1',
    DATA_DIR / 'HAM10000_images_part_2',
]

# Load the metadata CSV (contains image_id, diagnosis, age, sex, etc.)
metadata_path = DATA_DIR / 'HAM10000_metadata.csv'
df = pd.read_csv(metadata_path)

# Count images across both folders
total_images = sum(
    len([f for f in d.iterdir() if f.suffix == '.jpg'])
    for d in image_dirs if d.exists()
)

print(f'Found {total_images} images')
print(f'Metadata rows: {len(df)}')

### Stratified Train / Val / Test Split

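**Why stratified?** With a purely random split, rare classes like Dermatofibroma (1.1% of images) can end up noticeably over- or under-represented in a split, which makes the validation and test metrics for those classes unreliable. Stratified splitting keeps each class's proportion (almost) identical in every split.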

**Why 70/15/15?**
- **Train (70%)** — the model learns from these images
- **Validation (15%)** — we check performance after each epoch to detect overfitting
- **Test (15%)** — final evaluation on data the model has never seen
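The second split has to take 15% of the *total* out of the 85% left after the test split, so its fraction is 0.15 / 0.85 ≈ 0.176. The numpy-only sketch below shows that arithmetic plus a simple per-class stratified split; `stratified_split` and its seed are illustrative assumptions, not how scikit-learn's `train_test_split` is implemented internally.

```python
import numpy as np

# Second split: 15% of the total, expressed as a fraction of the remaining 85%
test_frac, val_frac = 0.15, 0.15
val_frac_of_remaining = val_frac / (1 - test_frac)  # about 0.1765


def stratified_split(labels, fracs=(0.70, 0.15, 0.15), seed=42):
    """Split indices so each class keeps (roughly) the same proportion in every part."""
    rng = np.random.default_rng(seed)
    labels = np.asarray(labels)
    parts = [[] for _ in fracs]
    for cls in np.unique(labels):
        # Shuffle this class's indices, then cut them at 70% and 85%
        idx = rng.permutation(np.flatnonzero(labels == cls))
        cuts = np.round(np.cumsum(fracs)[:-1] * len(idx)).astype(int)
        for part, chunk in zip(parts, np.split(idx, cuts)):
            part.extend(chunk.tolist())
    return [np.array(sorted(p)) for p in parts]


# Toy labels: a 70/10/20 class mix over 1,000 samples
labels = np.array([0] * 700 + [1] * 100 + [2] * 200)
train_idx, val_idx, test_idx = stratified_split(labels)
```

Because every class is cut separately, the minority class (100 samples) lands as 70/15/15 across the three parts, mirroring the overall proportions.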

In [None]:
from sklearn.model_selection import train_test_split
from src.dataset import HAM10000Dataset, get_transforms

# Stratified split: 70% train, 15% val, 15% test
train_val_df, test_df = train_test_split(
    df, test_size=0.15, stratify=df['dx'], random_state=SEED
)
train_df, val_df = train_test_split(
    train_val_df,
    test_size=0.15 / 0.85,  # 15% of the total, as a fraction of the remaining 85% (about 0.176)
    stratify=train_val_df['dx'],
    random_state=SEED,
)

print(f'Train: {len(train_df)} ({len(train_df)/len(df)*100:.1f}%)')
print(f'Val:   {len(val_df)} ({len(val_df)/len(df)*100:.1f}%)')
print(f'Test:  {len(test_df)} ({len(test_df)/len(df)*100:.1f}%)')
print()

# Verify stratification — percentages should be similar across all splits
print('Class distribution check (should be similar across splits):')
for cls in CLASS_NAMES:
    tr = (train_df['dx'] == cls).sum() / len(train_df) * 100
    va = (val_df['dx'] == cls).sum() / len(val_df) * 100
    te = (test_df['dx'] == cls).sum() / len(test_df) * 100
    print(f'  {CLASS_LABELS[cls]:30s}  train:{tr:5.1f}%  val:{va:5.1f}%  test:{te:5.1f}%')

### Create PyTorch Datasets and DataLoaders

Our custom `HAM10000Dataset` class does three things:
1. **Loads an image** from disk by its ID
2. **Applies augmentation** (only for training — flips, rotations, color changes)
3. **Returns a tensor** (numbers the model can process) + the label (which condition it is)

A **DataLoader** wraps the dataset and feeds images to the model in batches of 16.
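A map-style PyTorch Dataset only needs `__len__` and `__getitem__` returning `(tensor, label)`; the DataLoader stacks those into batches. The sketch below uses a hypothetical `ToyLesionDataset` with random in-memory arrays instead of JPEGs so it runs without the data; the real class lives in `src/dataset.py`. Its `transform` is a plain callable on a numpy image (albumentations pipelines are called as `transform(image=img)['image']` instead).

```python
import numpy as np
import torch
from torch.utils.data import DataLoader, Dataset


class ToyLesionDataset(Dataset):
    """Hypothetical in-memory dataset: random HxWx3 uint8 'images' with integer labels."""

    def __init__(self, num_items=40, image_size=128, num_classes=7, transform=None, seed=0):
        rng = np.random.default_rng(seed)
        self.images = rng.integers(0, 256, size=(num_items, image_size, image_size, 3), dtype=np.uint8)
        self.labels = rng.integers(0, num_classes, size=num_items)
        self.transform = transform  # augmentation would be passed for training only

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        img = self.images[idx]
        if self.transform is not None:
            img = self.transform(img)
        # HWC uint8 -> CHW float in [0, 1], the layout PyTorch models expect
        tensor = torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1).float() / 255.0
        return tensor, int(self.labels[idx])


loader = DataLoader(ToyLesionDataset(), batch_size=16, shuffle=True)
images, labels = next(iter(loader))
print(images.shape, labels.shape)  # torch.Size([16, 3, 128, 128]) torch.Size([16])
```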

In [None]:
from torch.utils.data import DataLoader

# Create datasets — training gets augmentation, val/test don't
train_dataset = HAM10000Dataset(train_df, image_dirs=image_dirs, transform=get_transforms('train'))
val_dataset = HAM10000Dataset(val_df, image_dirs=image_dirs, transform=get_transforms('val'))
test_dataset = HAM10000Dataset(test_df, image_dirs=image_dirs, transform=get_transforms('test'))

# Create dataloaders (feeds batches of 16 images to the model)
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=0)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=0)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=0)

# Test: grab one batch and check the shapes
images, labels = next(iter(train_loader))
print(f'Batch shape: {images.shape}')   # [16, 3, 128, 128] = 16 images, 3 color channels, 128x128 pixels
print(f'Labels shape: {labels.shape}')  # [16] = one label per image
print(f'Label values: {labels.tolist()}')  # Numbers 0-6 representing the 7 conditions

### Class Weights for Imbalanced Data

Remember from EDA: Melanocytic Nevus is 67% of the data while Dermatofibroma is only 1%.

Without weights, the model could reach about 67% accuracy simply by predicting "mole" (Melanocytic Nevus) for every image, while missing every rare condition.

**Class weights** tell the loss function: *"if you get a rare condition wrong, it counts much more."*

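Formula: `weight = total_samples / (num_classes * class_count)`. With the full-dataset counts (10,015 images):
- Common class (nv, 6705 images) -> low weight (~0.21)
- Rare class (df, 115 images) -> high weight (~12.4)

In the code, the weights are computed from the training split only; because the split is stratified, they come out almost identical.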
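The formula can be checked by hand on the full HAM10000 class counts. This is the same "balanced" heuristic scikit-learn's `compute_class_weight` uses. How `train_model` consumes the weights is an assumption here: a common choice is `nn.CrossEntropyLoss(weight=class_weights)`.

```python
# Full-dataset HAM10000 class counts (10,015 images in total)
counts = {'nv': 6705, 'mel': 1113, 'bkl': 1099, 'bcc': 514, 'akiec': 327, 'vasc': 142, 'df': 115}

n_samples = sum(counts.values())
n_classes = len(counts)

# weight = total_samples / (num_classes * class_count)
weights = {cls: n_samples / (n_classes * c) for cls, c in counts.items()}

for cls, w in weights.items():
    print(f'{cls:6s} count={counts[cls]:5d}  weight={w:6.2f}')
```

One way to read the result: `weight * count` is the same (n_samples / n_classes) for every class, so each class contributes equally to the total weighted loss regardless of its size.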

In [None]:
# Compute class weights for imbalanced data
label_counts = train_df['dx'].value_counts()
total = len(train_df)
class_weights = torch.tensor(
    [total / (NUM_CLASSES * label_counts.get(c, 1)) for c in CLASS_NAMES],
    dtype=torch.float32,
)

print('Class weights (higher = rarer class gets more attention):')
for cls, w in zip(CLASS_NAMES, class_weights):
    print(f'  {CLASS_LABELS[cls]:30s}: {w:.2f}')

### Visualize Augmented Samples

Data augmentation creates variations of training images (flips, rotations, color shifts).

**Why?** Each epoch sees a different random variant of every training image, which increases the effective diversity of the data and teaches the model that a mole is still a mole when the image is flipped, rotated, or slightly darker. This reduces overfitting (memorizing training images instead of learning general patterns).
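A numpy-only sketch of the kind of random transformations involved. `augment` is a hypothetical illustration; the notebook's actual pipeline is the albumentations transform returned by `get_transforms('train')`, whose exact operations and ranges may differ.

```python
import numpy as np


def augment(image, rng):
    """Return a randomly flipped, rotated, and brightness-shifted copy of an HxWx3 uint8 image."""
    out = image
    if rng.random() < 0.5:
        out = out[:, ::-1]  # horizontal flip
    if rng.random() < 0.5:
        out = out[::-1, :]  # vertical flip
    out = np.rot90(out, k=int(rng.integers(0, 4)))  # rotate by 0/90/180/270 degrees
    factor = rng.uniform(0.8, 1.2)  # brightness jitter of up to +/-20%
    return np.clip(out.astype(np.float32) * factor, 0, 255).astype(np.uint8)


rng = np.random.default_rng(0)
image = rng.integers(0, 256, size=(128, 128, 3), dtype=np.uint8)
variants = [augment(image, rng) for _ in range(4)]  # four different views of one "lesion"
```

The label is never touched: every variant still carries the original diagnosis. Note that 90-degree rotations keep the shape only for square images such as the 128x128 inputs used here.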

In [None]:
from src.preprocessing import denormalize

# Show a few augmented training samples
fig, axes = plt.subplots(2, 5, figsize=(15, 6))

for i in range(10):
    img, label = train_dataset[i]
    img_display = denormalize(img)  # Undo normalization so we can see the actual colors
    ax = axes[i // 5, i % 5]
    ax.imshow(img_display)
    ax.set_title(CLASS_NAMES[label], fontsize=10)
    ax.axis('off')

plt.suptitle('Augmented Training Samples (flipped, rotated, color-shifted)', fontsize=14, fontweight='bold')
plt.tight_layout()
plt.savefig(RESULTS_DIR / 'augmented_samples.png', dpi=150, bbox_inches='tight')
plt.show()

---
## 2. Build the Model

### What is Transfer Learning?

**EfficientNet-B0** was pre-trained on **ImageNet** (the ImageNet-1k benchmark: about 1.28 million training images across 1,000 everyday categories such as cats, dogs, and cars).

It already knows how to see edges, textures, shapes, and colors. We don't need to teach it that from scratch — we just need to teach it *our specific task*: "which of these 7 skin conditions is this?"

### Two-Phase Training Strategy:

1. **Phase 1 — Head only** (fast): Freeze the pretrained layers and train only the new classifier head. This teaches the head to map the existing features to our 7 classes.

2. **Phase 2 — Full fine-tuning** (slower): Unfreeze everything and fine-tune with a much smaller learning rate. This gently adapts the pretrained features to skin-lesion images.
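Freezing in PyTorch means setting `requires_grad = False` on parameters so the optimizer never updates them. The sketch below applies both phases to a tiny hypothetical network (not EfficientNet); `src.model.freeze_base` and `unfreeze_all` presumably do something similar on the real backbone. Using Adam, and creating a fresh optimizer for Phase 2, are assumptions for illustration.

```python
import torch
import torch.nn as nn


def count_trainable(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)


# Hypothetical stand-in: a small "backbone" followed by a 7-class head
model = nn.Sequential(
    nn.Conv2d(3, 32, kernel_size=3, padding=1),  # pretend this part is pretrained
    nn.ReLU(),
    nn.AdaptiveAvgPool2d(1),
    nn.Flatten(),
    nn.Linear(32, 7),  # new classifier head
)
backbone = model[:4]  # slicing shares the same modules (and parameters)

# Phase 1: freeze the backbone, train only the head with a higher LR
for p in backbone.parameters():
    p.requires_grad = False
phase1_trainable = count_trainable(model)
opt_phase1 = torch.optim.Adam((p for p in model.parameters() if p.requires_grad), lr=1e-3)

# Phase 2: unfreeze everything and use a fresh optimizer with a lower LR
for p in model.parameters():
    p.requires_grad = True
phase2_trainable = count_trainable(model)
opt_phase2 = torch.optim.Adam(model.parameters(), lr=1e-4)

print(phase1_trainable, phase2_trainable)  # 231 1127
```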

In [None]:
from src.model import create_model, freeze_base, unfreeze_all, get_model_summary

# Create the model — downloads pretrained weights from ImageNet
model = create_model(MODEL_NAME, NUM_CLASSES, pretrained=True)
print(f'Model: {MODEL_NAME}')
print(f'Output classes: {NUM_CLASSES}')
print()

# Phase 1: Freeze base layers — only the classifier head will train
model = freeze_base(model)
print('After freezing base:')
get_model_summary(model)
print()
print('Only ~9K parameters will train in Phase 1 (the head).')
print('The other 4M parameters stay frozen (pretrained ImageNet features).')

---
## 3. Train — Phase 1: Classifier Head Only

We train only the last layer (classifier head) for 5 epochs with a **higher learning rate (1e-3)**.

This is fast because gradients are computed and applied for only ~9,000 of the ~4 million parameters. Think of it as: the model already knows how to see, we're just teaching it what our 7 labels mean.

**What to watch for:**
- Training loss should drop quickly
- Validation accuracy should improve
- This phase takes just a few minutes on CPU
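The "~9K" figure follows from the head's shape, assuming the head is a single linear layer on top of EfficientNet-B0's 1280-dimensional pooled features (any dropout before it has no parameters):

```python
# EfficientNet-B0's pooled feature vector has 1280 channels; the new head maps it to 7 classes
in_features, num_classes = 1280, 7

# One Linear layer: a weight per (feature, class) pair plus one bias per class
head_params = in_features * num_classes + num_classes
print(head_params)  # 8967
```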

In [None]:
from src.train import train_model, set_seed

set_seed(SEED)

print('=== PHASE 1: Training classifier head (base frozen) ===')
print('This should be fast even on CPU...\n')

model, history_phase1 = train_model(
    model=model,
    train_loader=train_loader,
    val_loader=val_loader,
    class_weights=class_weights,
    num_epochs=5,           # Just a few epochs for the head
    lr=1e-3,                # Higher LR since only training the head
    patience=3,
    save_name='phase1_head.pth',
)

print('\nPhase 1 complete!')

---
## 4. Train — Phase 2: Full Fine-tuning

Now we unfreeze ALL layers and train the entire model with a **much lower learning rate (1e-4)**.

**Why a lower learning rate?** The pretrained features are already good — we don't want to destroy them with large updates. Small adjustments make the features slightly better for skin lesions specifically.

**Early stopping** watches validation loss. If it doesn't improve for 4 epochs in a row, training stops automatically to prevent overfitting.

**What to watch for:**
- Val accuracy should gradually improve beyond Phase 1
- If val loss starts rising while train loss keeps falling, the model is overfitting (early stopping catches this)
- This phase takes longer (~20-30 min on CPU)
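A minimal sketch of patience-based early stopping on validation loss. `EarlyStopping` is a hypothetical helper (`train_model` handles this internally), and the validation losses below are made-up numbers chosen to show an improve-then-overfit pattern.

```python
class EarlyStopping:
    """Signal a stop when val loss has not improved for `patience` consecutive epochs."""

    def __init__(self, patience=4, min_delta=0.0):
        self.patience = patience
        self.min_delta = min_delta
        self.best_loss = float('inf')
        self.best_epoch = -1
        self.bad_epochs = 0

    def step(self, epoch, val_loss):
        """Record one epoch's val loss; return True if training should stop."""
        if val_loss < self.best_loss - self.min_delta:
            self.best_loss, self.best_epoch = val_loss, epoch
            self.bad_epochs = 0  # new best: this is where a checkpoint would be saved
        else:
            self.bad_epochs += 1
        return self.bad_epochs >= self.patience


# Made-up val losses: improvement until epoch 3, then steady degradation
val_losses = [0.90, 0.75, 0.68, 0.66, 0.67, 0.69, 0.70, 0.72, 0.74]
stopper = EarlyStopping(patience=4)
stopped_at = None
for epoch, loss in enumerate(val_losses):
    if stopper.step(epoch, loss):
        stopped_at = epoch
        break

print(stopped_at, stopper.best_epoch)  # 7 3
```

Training stops four epochs after the last improvement, and the weights to keep are the ones from `best_epoch`, not the last epoch.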

In [None]:
# Unfreeze all layers for full fine-tuning
model = unfreeze_all(model)
print('After unfreezing all layers:')
get_model_summary(model)

print()
print('=== PHASE 2: Full fine-tuning ===')
print('This will take longer on CPU — grab a coffee...\n')

model, history_phase2 = train_model(
    model=model,
    train_loader=train_loader,
    val_loader=val_loader,
    class_weights=class_weights,
    num_epochs=NUM_EPOCHS,
    lr=LEARNING_RATE,        # Lower LR for fine-tuning
    patience=EARLY_STOPPING_PATIENCE,
    save_name='best_model.pth',
)

print('\nPhase 2 complete!')

---
## 5. Training Curves

These charts show how the model learned over time:

- **Loss chart (left)** — should go down for both train and val. If train keeps dropping but val goes up, the model is overfitting.
- **Accuracy chart (right)** — should go up. A large or widening gap between train and val accuracy is another sign of overfitting.

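A visible jump at the boundary between Phase 1 and Phase 2 is expected: the number of trainable parameters grows from ~9K to ~4M and the learning rate drops, so the curves change behaviour abruptly. The first Phase 2 epoch is at index `len(history_phase1['train_loss'])` in the combined history.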
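The two overfitting signals described above can also be checked numerically on a history dict with the same keys as `history`. `overfitting_signals` is a hypothetical helper, and the 6-epoch history is made up purely to show the output.

```python
def overfitting_signals(history):
    """Per-epoch train/val accuracy gap, plus epochs where val loss rose while train loss fell."""
    gaps = [tr - va for tr, va in zip(history['train_acc'], history['val_acc'])]
    diverging = [
        e for e in range(1, len(history['train_loss']))
        if history['train_loss'][e] < history['train_loss'][e - 1]
        and history['val_loss'][e] > history['val_loss'][e - 1]
    ]
    return gaps, diverging


# Made-up history, for illustration only
toy = {
    'train_loss': [1.2, 0.9, 0.7, 0.55, 0.45, 0.38],
    'val_loss':   [1.1, 0.95, 0.85, 0.80, 0.83, 0.88],
    'train_acc':  [0.55, 0.65, 0.72, 0.78, 0.82, 0.86],
    'val_acc':    [0.57, 0.64, 0.70, 0.73, 0.73, 0.72],
}
gaps, diverging = overfitting_signals(toy)
print([round(g, 2) for g in gaps])
print(diverging)  # [4, 5]
```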

In [None]:
# Combine histories from both phases into one timeline
history = {
    'train_loss': history_phase1['train_loss'] + history_phase2['train_loss'],
    'val_loss': history_phase1['val_loss'] + history_phase2['val_loss'],
    'train_acc': history_phase1['train_acc'] + history_phase2['train_acc'],
    'val_acc': history_phase1['val_acc'] + history_phase2['val_acc'],
}

from src.evaluate import plot_training_history

plot_training_history(history, save=True)
print(f'\nBest val accuracy: {max(history["val_acc"]):.4f} ({max(history["val_acc"])*100:.1f}%)')
print(f'Best val loss: {min(history["val_loss"]):.4f}')
print(f'Total epochs trained: {len(history["train_loss"])}')

---
## 6. Test Set Evaluation

Now the real test — how does the model perform on images it has **never seen during training**?

The **classification report** shows per-class metrics:
- **Precision** — "When the model says melanoma, how often is it right?"
- **Recall** — "Of all actual melanomas, how many did the model catch?"
- **F1-score** — The harmonic mean of precision and recall (high only when both are high)

For medical screening, **recall on dangerous classes such as melanoma usually matters more than precision**: missing a melanoma (false negative) is worse than flagging a harmless mole for review (false positive).
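The per-class metrics can be computed directly from true and predicted labels; scikit-learn's `classification_report` reports these same quantities. `per_class_prf` is a hypothetical helper, shown on a tiny made-up example.

```python
import numpy as np


def per_class_prf(y_true, y_pred, num_classes):
    """Per-class precision, recall, and F1 (set to 0 where a denominator is 0)."""
    y_true, y_pred = np.asarray(y_true), np.asarray(y_pred)
    precision, recall, f1 = np.zeros((3, num_classes))
    for c in range(num_classes):
        tp = np.sum((y_pred == c) & (y_true == c))
        fp = np.sum((y_pred == c) & (y_true != c))  # predicted c, but was something else
        fn = np.sum((y_pred != c) & (y_true == c))  # was c, but the model missed it
        precision[c] = tp / (tp + fp) if tp + fp else 0.0
        recall[c] = tp / (tp + fn) if tp + fn else 0.0
        p, r = precision[c], recall[c]
        f1[c] = 2 * p * r / (p + r) if p + r else 0.0
    return precision, recall, f1


# Made-up labels: class 2 is never predicted, so its recall is 0 (a "missed" condition)
p, r, f = per_class_prf([0, 0, 0, 1, 1, 2], [0, 0, 1, 1, 1, 0], num_classes=3)
```

Class 1 gets perfect recall (both true cases caught) but only 2/3 precision, because one class-0 image was also labelled as class 1.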

In [None]:
from src.evaluate import get_predictions, print_classification_report

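# Use the GPU if one is available (matching the device check at the top), else the CPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)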

# Get predictions on the test set
y_true, y_pred, y_probs = get_predictions(model, test_loader, device)

print('=== TEST SET RESULTS ===')
print()
print_classification_report(y_true, y_pred)

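# Overall accuracy (dominated by the majority class, so read it alongside the per-class report)
accuracy = (np.asarray(y_true) == np.asarray(y_pred)).mean()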
print(f'\nOverall test accuracy: {accuracy:.4f} ({accuracy*100:.1f}%)')

### Confusion Matrix

The confusion matrix shows exactly where the model gets confused.

- **Diagonal** (top-left to bottom-right) = correct predictions
- **Off-diagonal** = mistakes
- For example, if the model often confuses Melanoma with Melanocytic Nevus, you'll see a high number in that cell.
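A confusion matrix is just a 2-D count of (true class, predicted class) pairs. The sketch below builds one with numpy, using the rows-are-true, columns-are-predicted convention scikit-learn also uses; the labels are made up.

```python
import numpy as np


def confusion_matrix(y_true, y_pred, num_classes):
    """cm[i, j] = number of samples whose true class is i and predicted class is j."""
    cm = np.zeros((num_classes, num_classes), dtype=int)
    # np.add.at is unbuffered, so repeated (i, j) pairs accumulate correctly
    np.add.at(cm, (np.asarray(y_true), np.asarray(y_pred)), 1)
    return cm


cm = confusion_matrix([0, 0, 0, 1, 1, 2], [0, 0, 1, 1, 1, 0], num_classes=3)
print(cm)
correct = np.trace(cm)  # sum of the diagonal = correct predictions
per_class_recall = cm.diagonal() / cm.sum(axis=1)  # diagonal divided by row totals
```

Each row sums to the number of true examples of that class, so dividing the diagonal by the row sums gives per-class recall.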

In [None]:
from src.evaluate import plot_confusion_matrix

plot_confusion_matrix(y_true, y_pred, save=True)

In [None]:
# Save results to a JSON file for the README and future reference
import json

test_results = {
    'accuracy': float(accuracy),
    'best_val_loss': float(min(history['val_loss'])),
    'best_val_acc': float(max(history['val_acc'])),
    'total_epochs': len(history['train_loss']),
    'model': MODEL_NAME,
    'image_size': IMAGE_SIZE,
}

with open(RESULTS_DIR / 'training_results.json', 'w') as f:
    json.dump(test_results, f, indent=2)

print('Results saved to results/training_results.json')
print(json.dumps(test_results, indent=2))

---
## Summary

### What we did:
1. **Loaded** 10,015 images with stratified 70/15/15 split
2. **Augmented** training data (flips, rotations, color jitter) to prevent overfitting
3. **Weighted** the loss function so rare conditions get more attention
4. **Phase 1**: Trained only the classifier head (base frozen) — fast warm-up
5. **Phase 2**: Fine-tuned entire model with lower learning rate — full optimization
6. **Early stopping** halted training once validation loss stopped improving, and the best checkpoint was saved automatically

### Key concepts learned:
- **Transfer learning** — reusing knowledge from ImageNet instead of training from scratch
- **Two-phase training** — head first, then fine-tune all
- **Class imbalance handling** — weighted loss function
- **Early stopping** — prevents overfitting by stopping when validation loss stops improving
- **Stratified splitting** — preserves class proportions in train/val/test

### Saved files:
- `models/best_model.pth` — best model weights
- `results/augmented_samples.png` (grid of augmented training samples)
- `results/training_curves.png` — loss and accuracy plots
- `results/confusion_matrix.png` — test set confusion matrix
- `results/training_results.json` — metrics summary

### Next Steps

-> **03_evaluation.ipynb** — deeper evaluation with ROC curves and per-class analysis

-> **04_gradcam.ipynb** — visualize which regions of the image the model focuses on