## Notebook 3: Callbacks, Hooks & Training

Objective: Learn about PyTorch Lightning Callbacks, implement built-in and
           custom callbacks, set up and run a full training loop using the
           Trainer, and export the model using TorchScript.

In [None]:

# --- Imports ---
import torch
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping, LearningRateMonitor
import os

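# Assume the DataModule and LightningModule can be imported from the previous scripts.
# You may need to adjust sys.path or the package structure for this to work.
# Example (adjust paths as needed):
# from pathlib import Path
# import sys
# import importlib
# project_root = Path.cwd().parent  # __file__ is undefined in a notebook; assumes the notebook runs one level below the project root
# sys.path.append(str(project_root))  # Add project root to path
# # Module names starting with a digit cannot appear in an `import` statement,
# # so load them with importlib (or rename the files):
# data_mod = importlib.import_module("src.01_Dataset_Creation_Augmentation")
# PatchDataModule = data_mod.PatchDataModule  # likewise train_transforms, val_test_transforms, class_to_idx (need to refactor these out or pass config)
# HistologyClassifier = importlib.import_module("src.02_model_training").HistologyClassifier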

# --- Configuration (Placeholders - Load these properly) ---
# These would typically come from config files, CLI args, or imported modules
CHECKPOINT_DIR = "./lightning_checkpoints"
LOG_DIR = "./lightning_logs"
NUM_CLASSES = 7
LEARNING_RATE = 1e-4

# Example: Assume these are loaded or defined
# class_to_idx = {'TER': 0, 'Necrotic': 1, ...} # Load from previous step or config
# train_transforms = ... # Load from previous step or config
# val_test_transforms = ... # Load from previous step or config
# DATA_DIR = '../data/raw/patch_classification_dataset'

## 📞 Introduction to Callbacks

Callbacks are self-contained programs that can be added to your PyTorch Lightning
`Trainer`. They allow you to add custom logic at various stages of the training
process (e.g., at the beginning/end of an epoch, before/after a batch) without
cluttering your `LightningModule`.

**Why use Callbacks?**
- **Modularity:** Keep training logic separate from model definition.
- **Reusability:** Easily reuse common logic like checkpointing or early stopping across projects.
- **Extensibility:** Hook into specific points in the training loop for monitoring, logging, or other actions.

Lightning provides several useful built-in callbacks, and you can easily create
your own.
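
To make the hook idea concrete, here is a minimal pure-Python sketch of how a trainer might dispatch to callbacks. `ToyCallback`, `EpochRecorder`, and `ToyTrainer` are hypothetical stand-ins written for this illustration. They are not Lightning classes, and the real `Trainer` exposes many more hooks.

```python
class ToyCallback:
    """Base class with no-op hooks, loosely mirroring the idea behind pl.Callback."""

    def on_train_epoch_start(self, trainer):
        pass

    def on_train_epoch_end(self, trainer):
        pass


class EpochRecorder(ToyCallback):
    """Hypothetical custom callback: records which epochs it has seen."""

    def __init__(self):
        self.events = []

    def on_train_epoch_start(self, trainer):
        self.events.append(f"start:{trainer.current_epoch}")

    def on_train_epoch_end(self, trainer):
        self.events.append(f"end:{trainer.current_epoch}")


class ToyTrainer:
    """Stand-in for a trainer: runs epochs and calls every callback's hooks."""

    def __init__(self, max_epochs, callbacks):
        self.max_epochs = max_epochs
        self.callbacks = callbacks
        self.current_epoch = 0

    def fit(self):
        for epoch in range(self.max_epochs):
            self.current_epoch = epoch
            for cb in self.callbacks:
                cb.on_train_epoch_start(self)
            # ... the actual training steps would run here ...
            for cb in self.callbacks:
                cb.on_train_epoch_end(self)


recorder = EpochRecorder()
ToyTrainer(max_epochs=2, callbacks=[recorder]).fit()
print(recorder.events)  # ['start:0', 'end:0', 'start:1', 'end:1']
```

The model code never appears in `EpochRecorder`, which is the modularity argument above: the same callback could be attached to any training run.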

### `ModelCheckpoint`

This callback saves checkpoints (model weights and, by default, the rest of the training state) periodically during training.
Key parameters:
- `dirpath`: Directory to save checkpoints.
- `filename`: Naming pattern for checkpoint files (can include metrics).
- `monitor`: Metric to monitor for saving the 'best' model (e.g., 'val_loss_epoch').
- `mode`: 'min' or 'max' depending on whether the monitored metric should be minimized or maximized.
- `save_top_k`: Save the top 'k' best models according to the monitored metric.
- `save_last`: Also keep a copy of the most recent checkpoint (`last.ckpt`), regardless of the monitored metric.
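
How `monitor`, `mode`, and `save_top_k` interact can be sketched in a few lines of plain Python. This is a simplified illustration of the selection rule only, not Lightning's implementation; `update_top_k` and the accuracy values are made up for the example.

```python
def update_top_k(best, epoch, score, k=1, mode="min"):
    """Return the new top-k list of (score, epoch) pairs after one more score.

    `best` holds the (score, epoch) pairs kept from earlier epochs.
    """
    if mode not in ("min", "max"):
        raise ValueError("mode must be 'min' or 'max'")
    candidates = best + [(score, epoch)]
    # Best score first; the sort is stable, so ties keep the earlier epoch first.
    candidates.sort(key=lambda pair: pair[0], reverse=(mode == "max"))
    return candidates[:k]


val_acc_per_epoch = [0.71, 0.80, 0.78, 0.85]  # made-up values
best = []
for epoch, acc in enumerate(val_acc_per_epoch):
    best = update_top_k(best, epoch, acc, k=2, mode="max")

print(best)  # [(0.85, 3), (0.8, 1)]
```

With `k=1` only the single best epoch survives, which corresponds to the `save_top_k=1` setting.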


In [None]:

# Example configuration for ModelCheckpoint
# Saves the best model based on validation accuracy (higher is better)
checkpoint_callback_acc = ModelCheckpoint(
    dirpath=CHECKPOINT_DIR,
    filename='best-model-acc-{epoch:02d}-{val_acc:.2f}',
    monitor='val_acc', # Assuming 'val_acc' is logged in LightningModule
    mode='max',
    save_top_k=1, # Save only the single best model
    save_last=True, # Also save the latest model state
    verbose=True
)

# Example configuration for ModelCheckpoint
# Saves the best model based on validation loss (lower is better)
checkpoint_callback_loss = ModelCheckpoint(
    dirpath=CHECKPOINT_DIR,
    filename='best-model-loss-{epoch:02d}-{val_loss_epoch:.2f}',
    monitor='val_loss_epoch', # Monitor the epoch validation loss
    mode='min',
    save_top_k=1,
    save_last=True,
    verbose=True
)

print("ModelCheckpoint callbacks configured.")
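
The `filename` patterns above use ordinary Python format specifiers. The sketch below applies one of them with `str.format` to show what `{epoch:02d}` and `{val_acc:.2f}` produce. The metric values are made up. Note that `ModelCheckpoint` has an `auto_insert_metric_name` option (enabled by default) that prefixes each value with its metric name, so the file written to disk will likely look different from this plain-format result.

```python
template = "best-model-acc-{epoch:02d}-{val_acc:.2f}"

# Made-up values standing in for what would be logged during training.
name = template.format(epoch=3, val_acc=0.9123)

print(name)  # best-model-acc-03-0.91
```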

### `EarlyStopping`

This callback stops training early if a monitored metric stops improving,
preventing overfitting and saving computation time.
Key parameters:
- `monitor`: Metric to monitor (e.g., 'val_loss_epoch').
- `mode`: 'min' or 'max'.
- `patience`: Number of validation checks without improvement to wait before stopping (one check per epoch by default).
- `min_delta`: Minimum change in the monitored quantity to qualify as an improvement.
- `verbose`: Print messages when stopping.
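
The interplay of `patience` and `min_delta` is easiest to see in a small simulation. The function below is a simplified sketch of the `mode='min'` rule under the assumption of one validation check per epoch. It is not Lightning's code, and the loss values are invented.

```python
def check_at_which_training_stops(losses, patience, min_delta=0.0):
    """Return the index of the check that triggers the stop, or None.

    Simplified 'min' mode: a value only counts as an improvement if it is
    lower than the best value so far by more than `min_delta`.
    """
    best = float("inf")
    wait = 0
    for i, loss in enumerate(losses):
        if loss < best - min_delta:
            best = loss
            wait = 0  # improvement: reset the counter
        else:
            wait += 1
            if wait >= patience:
                return i
    return None


val_losses = [0.90, 0.80, 0.7995, 0.81, 0.80]  # made-up values
stop_at = check_at_which_training_stops(val_losses, patience=3, min_delta=0.001)
print(stop_at)  # 4
```

The third value (0.7995) is lower than the best so far (0.80), but by less than `min_delta`, so it does not reset the counter.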

In [None]:

# Example configuration for EarlyStopping
# Stops training if validation loss doesn't improve for 10 consecutive validation checks
early_stopping_callback = EarlyStopping(
    monitor='val_loss_epoch',
    mode='min',
    patience=10, # Increase patience for potentially noisy validation loss
    min_delta=0.001,
    verbose=True
)

print("EarlyStopping callback configured.")

### `LearningRateMonitor`

Automatically logs the learning rate used by the optimizer(s) at each step or epoch.
Very useful when using learning rate schedulers.
Key parameters:
- `logging_interval`: 'step' or 'epoch'.
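
With `logging_interval='epoch'`, one learning-rate value per optimizer is recorded each epoch. As a rough illustration of what such a logged series could look like, the sketch below computes a step-decay schedule in plain Python. The schedule, `step_size`, and `gamma` are assumptions chosen for the example; only the starting value of `1e-4` comes from the configuration above.

```python
def step_decay_lr(initial_lr, epoch, step_size, gamma):
    """Learning rate at `epoch` when it is multiplied by `gamma` every `step_size` epochs."""
    return initial_lr * gamma ** (epoch // step_size)


# One value per epoch, as an epoch-level monitor would record it.
logged = {
    epoch: step_decay_lr(1e-4, epoch, step_size=2, gamma=0.5)
    for epoch in range(5)
}

for epoch, lr in logged.items():
    print(f"epoch {epoch}: lr = {lr:.2e}")
```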

In [None]:

lr_monitor_callback = LearningRateMonitor(logging_interval='epoch')

print("LearningRateMonitor callback configured.")