# EfficientNet B3 Training for Periapical Index (PAI) Classification

**Author:** Gerald Torgersen  
**Date:** 2025  
**GitHub:** [github.com/geraldOslo/PAI-meets-AI](https://github.com/geraldOslo/PAI-meets-AI)  

**License**  
SPDX-License-Identifier: MIT  
Copyright (c) 2025 Gerald Torgersen


## Overview

This notebook orchestrates the training of a deep learning model (default: EfficientNet-B3) to classify dental radiographs based on the Periapical Index (PAI) scale. The PAI scores (originally 1-5) are mapped to classes 0-4 for training purposes.
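The label mapping is a plain offset of one. Below is a minimal sketch. It assumes a metadata column named `PAI`, which is a hypothetical name (the real column names come from the project's CSV files). The helper `add_class_labels` is also illustrative and is not part of `data_utils`:

```python
import pandas as pd


def add_class_labels(df: pd.DataFrame, pai_column: str = "PAI") -> pd.DataFrame:
    """Return a copy of df with a zero-based 'label' column derived from PAI 1-5 (hypothetical helper)."""
    if not df[pai_column].between(1, 5).all():
        raise ValueError("PAI scores must lie in the range 1-5")
    out = df.copy()
    out["label"] = out[pai_column].astype(int) - 1  # PAI 1 -> class 0, ..., PAI 5 -> class 4
    return out


meta = pd.DataFrame({"filename": ["a.png", "b.png", "c.png"], "PAI": [1, 3, 5]})
print(add_class_labels(meta)["label"].tolist())  # [0, 2, 4]
```

At inference time, the mapping is reversed by adding 1 to the predicted class index.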

## Workflow Summary

This notebook serves as the main script to run the training process. Core functionalities are modularized into separate Python files (`.py`) for better organization and reusability:

-   **`config.py`**: Manages all hyperparameters, data paths, augmentation settings, and other configurations.
-   **`data_utils.py`**: Handles dataset loading (`CustomDataset`), data statistics display, and the creation of data augmentation pipelines.
-   **`model_utils.py`**: Defines the function (`get_model`) to load pretrained models (like EfficientNet) and adapt them for the PAI classification task (e.g., modify classifier head, apply fine-tuning strategy).
-   **`train_utils.py`**: Contains core training helper functions, including loss function definitions (CrossEntropy, FocalLoss), optimizer/scheduler creation (`get_optimizer`, `get_scheduler`), gradient clipping, Mixup (if enabled), checkpointing, and saving results (history JSON, summary YAML).
-   **`visualization_utils.py`**: Provides functions to generate and save plots for learning rates, GPU memory usage, and training/validation metrics.
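`train_utils.py` is not shown in this notebook. To illustrate what a focal loss computes, here is a minimal sketch of the standard multi-class formulation, where the loss is $-(1 - p_t)^\gamma \log p_t$ and $p_t$ is the predicted probability of the true class. The class name, its arguments, and the optional per-class `alpha` weights are assumptions and may differ from the project's `FocalLoss`:

```python
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F


class FocalLoss(nn.Module):
    """Multi-class focal loss: batch mean of -alpha_t * (1 - p_t)**gamma * log(p_t) (illustrative sketch)."""

    def __init__(self, gamma: float = 2.0, alpha: Optional[torch.Tensor] = None):
        super().__init__()
        self.gamma = gamma
        self.alpha = alpha  # optional per-class weights, shape [num_classes]

    def forward(self, logits: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
        log_probs = F.log_softmax(logits, dim=1)                       # [B, C]
        log_pt = log_probs.gather(1, targets.unsqueeze(1)).squeeze(1)  # log-prob of the true class, [B]
        pt = log_pt.exp()
        loss = -((1.0 - pt) ** self.gamma) * log_pt                    # down-weights easy, confident examples
        if self.alpha is not None:
            loss = loss * self.alpha.to(logits.device)[targets]
        return loss.mean()


logits = torch.randn(8, 5)
targets = torch.randint(0, 5, (8,))
# With gamma=0 and no alpha, the focal loss reduces to ordinary cross-entropy.
print(torch.allclose(FocalLoss(gamma=0.0)(logits, targets), F.cross_entropy(logits, targets)))  # True
```

Larger `gamma` values shift the loss toward hard, misclassified radiographs. This can help with rare PAI classes.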

The notebook imports these modules and executes the following steps:
1.  Load configuration from `config.py`.
2.  Load dataset metadata and perform train/validation split.
3.  Apply oversampling to the training set.
4.  Create PyTorch Datasets and DataLoaders.
5.  Initialize the model, loss function, optimizer, and scheduler using utility functions.
6.  Run the main training and validation loop.
7.  Save the best model checkpoint, training history, plots, and a final summary YAML file.

## Configuration

All settings are managed externally in `config.py`. Before running this notebook:

1.  Ensure `config.py`, `data_utils.py`, `model_utils.py`, `train_utils.py`, and `visualization_utils.py` are in the same directory as this notebook or accessible via Python's path.
2.  If starting fresh, copy `config_template.py` (if available) to `config.py`.
3.  Edit `config.py` to match your environment (data paths, desired hyperparameters, etc.).
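For orientation, here is a hypothetical `config.py` covering the names this notebook reads. The top-level names and the keys are taken from the notebook code. However, which dictionary holds which key, the normalization key names, and every value shown are illustrative assumptions. The defaults for `epochs`, `patience`, `min_delta`, and `batch_size` mirror the fallbacks used in the training-loop cell. Optimizer and scheduler keys read by `train_utils` are not shown. Treat `config_template.py` (if available) as the authoritative layout:

```python
# config.py (hypothetical sketch; values are placeholders, not the project's defaults)

DATA_PATHS = {
    "root_dirs": ["/path/to/images"],        # image root directories
    "csv_files": ["/path/to/labels.csv"],    # CSV files with image paths and PAI scores
}

CHECKPOINT_DIR = "./checkpoints"             # where _best.pth, plots, JSON and YAML files are written

MODEL_CONFIG = {
    "model": "efficientnet_b3",              # name passed to model_utils.get_model
    "num_classes": 5,                        # PAI 1-5 -> classes 0-4
    "epochs": 90,
    "dropout": 0.3,                          # placeholder
    "drop_path_rate": 0.2,                   # placeholder
    "finetune_blocks": 3,                    # placeholder
    "loss_function_type": "cross_entropy",   # accepted names depend on train_utils.get_criterion
    "use_class_weights": False,
    "class_weights": None,
    "use_amp": True,
    "accum_steps": 1,
    "grad_clip_max_norm": 1.0,               # placeholder
    "min_delta": 0.001,                      # minimum F1 gain that counts as an improvement
    "patience": 20,                          # epochs without improvement before early stopping
}

# ImageNet statistics; the key names are assumptions
NORMALIZATION = {"mean": [0.485, 0.456, 0.406], "std": [0.229, 0.224, 0.225]}

DATALOADER_CONFIG = {"batch_size": 64, "num_workers": 4, "use_oversampler": True}

DATA_TRANSFORM_CONFIG = {}  # augmentation keys are defined by data_utils.create_data_transforms
```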

## Key Training Techniques Implemented

-   Fine-tuning of models pre-trained on ImageNet (e.g., EfficientNet-B3).
-   Configurable loss functions (e.g., CrossEntropy with Label Smoothing, Focal Loss).
-   Mixed Precision Training (AMP) via `torch.amp` for faster training and reduced memory usage on compatible GPUs.
-   AdamW Optimizer.
-   OneCycleLR Learning Rate Scheduling for efficient convergence.
-   Gradient Clipping to prevent exploding gradients.
-   Early Stopping based on validation F1-score to prevent overfitting and save time.
-   Optional Mixup data augmentation (configurable in `config.py`).
-   Comprehensive logging (training history saved to JSON, run summary saved to YAML) and visualization (metrics, LR, GPU memory plots saved as PNG).
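The early-stopping rule fits in a few lines. In this minimal sketch, an epoch counts as an improvement only if its F1 exceeds the best F1 so far by more than `min_delta`, and non-finite F1 values are treated as 0, as in the training loop. How the patience counter is reset and compared is an assumption. The class name `EarlyStopping` is hypothetical; the notebook implements the logic inline with a `pat_ctr` counter:

```python
import math


class EarlyStopping:
    """Stop when validation F1 has not improved by more than min_delta for `patience` epochs."""

    def __init__(self, patience: int = 20, min_delta: float = 0.001):
        self.patience = patience
        self.min_delta = min_delta
        self.best_f1 = -1.0
        self.counter = 0

    def step(self, f1: float) -> tuple[bool, bool]:
        """Return (improved, should_stop) for this epoch's F1."""
        f1 = f1 if math.isfinite(f1) else 0.0    # treat NaN/inf as no improvement
        if f1 > self.best_f1 + self.min_delta:
            self.best_f1 = f1
            self.counter = 0
            return True, False                     # caller saves a checkpoint here
        self.counter += 1
        return False, self.counter >= self.patience


stopper = EarlyStopping(patience=2, min_delta=0.01)
for epoch, f1 in enumerate([0.50, 0.55, 0.555, 0.556]):
    improved, stop = stopper.step(f1)
    print(epoch + 1, improved, stop)
    if stop:
        break
# 1 True False
# 2 True False
# 3 False False
# 4 False True
```

Gains smaller than `min_delta` (epochs 3 and 4 above) do not reset the counter. This keeps noise-level F1 fluctuations from prolonging training.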

## Data Handling

-   Loads image paths and PAI labels (originally 1-5) from CSV file(s) specified in `config.py`.
-   Maps PAI labels to a 0-based index (0-4) for classification.
-   Checks for the existence of image files and filters the dataset accordingly.
-   Splits data into training and validation sets using stratified sampling based on PAI class.
-   Applies `RandomOverSampler` from `imblearn` to balance the class distribution **only in the training dataset**. The validation set retains its original (potentially imbalanced) distribution for realistic evaluation.
-   Applies configurable data augmentation (defined in `config.py` and implemented in `data_utils.py`) to the training set during training.
-   Applies standard resizing, center cropping, and normalization to the validation set.
-   Displays class distribution statistics for the initial, split, and resampled datasets.

## Requirements

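-   PyTorch 2.3+ (the notebook imports `GradScaler` from `torch.amp`, which earlier releases do not provide)
-   torchvision
-   scikit-learn
-   imbalanced-learn (imported as `imblearn`, for `RandomOverSampler`)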
-   pandas
-   numpy
-   matplotlib
-   tqdm
-   PyYAML (for saving summary file)
-   CUDA-capable GPU (highly recommended for reasonable training times)

## Usage

1.  Configure `config.py` for your dataset paths and desired hyperparameters.
2.  Ensure all required `.py` utility files are present.
3.  Run the cells of this notebook sequentially.
4.  The best model checkpoint (`_best.pth`), plots, history (`_history.json`), and summary (`_summary.yaml`) will be saved to the directory specified in `config.py` (`CHECKPOINT_DIR`). All filenames share a prefix built from the model name, the detected GPU name, and a timestamp (`{model}_{gpu}_{YYYYmmdd_HHMMSS}`).

## Model Application

The trained model (`_best.pth` file) can be loaded and used with separate inference scripts to classify new dental radiographs based on the PAI scale.
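The exact contents of `_best.pth` are not shown in this section. The training loop does assemble a `model_config_for_checkpoint` dictionary (model name, dropout, drop-path rate, fine-tuned blocks, number of classes) for checkpoint saving. For that reason, the sketch below does not assume any checkpoint keys and covers only the prediction step. Inputs must be preprocessed exactly like the validation set (resize, center crop, normalization), and the predicted class index is shifted back to the 1-5 PAI scale. `predict_pai` and the toy stand-in model are hypothetical:

```python
import torch
import torch.nn as nn


@torch.no_grad()
def predict_pai(model: nn.Module, images: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    """Return (PAI scores 1-5, class probabilities) for a batch of preprocessed images."""
    model.eval()
    probs = torch.softmax(model(images), dim=1)   # [B, 5]
    classes = probs.argmax(dim=1)                  # zero-based classes 0-4
    return classes + 1, probs                      # map back to the PAI scale 1-5


# Stand-in model for illustration only. In practice, rebuild the architecture with
# model_utils.get_model(...) and load the trained weights from the _best.pth checkpoint.
toy_model = nn.Sequential(nn.Flatten(), nn.Linear(3 * 8 * 8, 5))
pai, probs = predict_pai(toy_model, torch.randn(4, 3, 8, 8))
print(pai.tolist())
```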

In [1]:
# Standard Library Imports
import os
import json
import math
import datetime
import time
import traceback
import gc
import warnings
import sys
import copy

# Third-party Library Imports
import pandas as pd
import numpy as np
import torch
import torch.optim as optim
from torch.amp import GradScaler
from sklearn.metrics import accuracy_score, precision_recall_fscore_support
from imblearn.over_sampling import RandomOverSampler
from tqdm.notebook import tqdm
from collections import Counter
import matplotlib.pyplot as plt
from PIL import Image # Typically used by data_utils for image loading
import yaml

# Type Hinting Imports
from typing import Dict, List, Optional, Tuple, Union

# Local Project Utility Imports
from config import * # Imports all variables from config.py
import data_utils
import model_utils
import train_utils
import visualization_utils

# For inline plotting in Jupyter Notebooks
%matplotlib inline

# Suppress specific warnings, e.g., from OneCycleLR verbose or AMP anomaly detection
warnings.filterwarnings("ignore", message="The verbose parameter is deprecated.*", category=UserWarning)
warnings.filterwarnings("ignore", message="Anomaly Detection has been enabled.*", category=UserWarning)

## Configuration Setup

This section loads parameters from `config.py`, detects the available compute device (CPU/GPU), and sets up all necessary file paths for saving model checkpoints, training history, and plots. The `training_config` dictionary is assembled to consolidate all relevant settings for easy access and later summarization.

In [None]:
# --- Configuration Loading ---
# Data Paths loaded from config.py
root_dirs = DATA_PATHS["root_dirs"]
csv_files = DATA_PATHS["csv_files"]

# Combine relevant configs into a single dictionary for training script use
training_config = {}
training_config.update(MODEL_CONFIG)
training_config.update(NORMALIZATION) # Add normalization info
training_config.update(DATALOADER_CONFIG)
training_config.update(DATA_TRANSFORM_CONFIG)

# --- Runtime Device Detection ---
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
gpu_name_detected = "cpu"
if device.type == 'cuda':
    try:
        # Get GPU name, sanitize it for filenames
        gpu_name_detected = torch.cuda.get_device_name(0).replace(" ", "_").replace("-", "_")
    except Exception as e:
        print(f"Could not get GPU name: {e}")
        gpu_name_detected = "gpu_unknown"
print(f"Runtime Detected Device: {device} ({gpu_name_detected})")
training_config['device_info'] = f"{device.type} ({gpu_name_detected})" # Store for summary
print(f"Pytorch version: {torch.__version__}")

model_name = training_config['model']
print(f"Model name: {model_name}")

# --- Timestamp and Save Paths ---
timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
training_config['timestamp'] = timestamp  # Store for summary

# Create a unique file prefix for all saved assets based on model, device, and timestamp
file_prefix = f"{model_name}_{gpu_name_detected}_{timestamp}"
training_config['file_prefix'] = file_prefix  # Store for summary

# Define all output file paths within the CHECKPOINT_DIR
best_model_save_path = os.path.join(CHECKPOINT_DIR, f"{file_prefix}_best.pth")
history_json_path = os.path.join(CHECKPOINT_DIR, f"{file_prefix}_history.json")
summary_yaml_path = os.path.join(CHECKPOINT_DIR, f"{file_prefix}_summary.yaml")
gpu_plot_path = os.path.join(CHECKPOINT_DIR, f"{file_prefix}_gpu_memory.png")
lr_plot_path = os.path.join(CHECKPOINT_DIR, f"{file_prefix}_lr_schedule.png")
metrics_plot_path = os.path.join(CHECKPOINT_DIR, f"{file_prefix}_metrics.png")
validation_report_csv_path = os.path.join(CHECKPOINT_DIR, f"{file_prefix}_validation_report.csv")

# --- Create Checkpoint Directory if it doesn't exist ---
os.makedirs(CHECKPOINT_DIR, exist_ok=True)

# --- Display Effective Training Configuration ---
print("\n--- Effective Training Configuration ---")
for key, value in training_config.items():
    # Shorten long lists/dicts for display to keep console output tidy
    if isinstance(value, list) and len(value) > 10:
        print(f"  {key}: [List with {len(value)} items]")
    elif isinstance(value, dict) and len(value) > 5:
        print(f"  {key}: {{Dict with {len(value)} keys}}")
    else:
        print(f"  {key}: {value}")
print("-------------------------------------")
print(f"Best model will be saved to: {best_model_save_path}")
print(f"Results and logs base path: {CHECKPOINT_DIR}")

## Data Loading and Preparation

This section loads dataset metadata from the CSV files, performs a stratified train/validation split, applies data augmentation transforms, and sets up PyTorch DataLoaders. When `use_oversampler` is enabled, oversampling is applied to the training set only, to counter class imbalance. The validation set keeps its original distribution.
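The size check in the cell below looks for a sampler with an `indices` or `num_samples` attribute. One way to feed an oversampled index list to a `DataLoader` is a `SubsetRandomSampler`, which exposes `indices`. Whether `data_utils.create_dataloaders` actually works this way is an assumption. Here is a minimal sketch with a toy two-class dataset:

```python
import torch
from torch.utils.data import DataLoader, SubsetRandomSampler, TensorDataset

# Toy dataset: 10 samples, class 0 (8 samples) and class 1 (2 samples).
features = torch.randn(10, 4)
labels = torch.tensor([0] * 8 + [1] * 2)
dataset = TensorDataset(features, labels)

# Oversampled index list (e.g., produced by RandomOverSampler): minority indices repeated.
resampled_indices = list(range(8)) + [8, 9, 8, 9, 8, 9, 8, 9]  # 8 + 8 = 16 indices, balanced

# SubsetRandomSampler yields each listed index once per epoch in random order,
# so repeated indices make minority samples appear several times per epoch.
loader = DataLoader(dataset, batch_size=4, sampler=SubsetRandomSampler(resampled_indices))

print(len(loader.sampler.indices))  # 16: size of the resampled training epoch
print(len(loader.dataset))          # 10: size of the underlying dataset
print(len(loader))                  # 4 batches per epoch
```

This also shows why the notebook records two training sizes: `len(train_dataset)` counts unique images, while the sampler length counts draws per epoch.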

In [None]:
print("\n--- Loading Data & Creating Datasets using data_utils ---")

# 1. Create Data Transforms using the configuration loaded from config.py
print("Creating data transforms...")
data_transforms = data_utils.create_data_transforms(DATA_TRANSFORM_CONFIG, NORMALIZATION)
print("Data transforms creation complete.")

# 2. Load Train/Validation Datasets using the `data_utils.load_datasets` function.
# This function handles metadata loading, image file checking, PAI mapping (1-indexed to 0-indexed),
# and stratified train/validation splitting.
print("Loading train/validation datasets...")
train_dataset, val_dataset, _, class_names = data_utils.load_datasets(
    data_paths=DATA_PATHS,          # Dictionary of data root directories and CSV files
    transforms=data_transforms,     # Dictionary of training and validation transforms
    split_test_size=0.2,            # Proportion of data for the validation set
    split_random_state=42           # Random seed for reproducible data splitting
)

# Verify that datasets were successfully loaded
if train_dataset is None or val_dataset is None:
    raise RuntimeError("Failed to load train or validation datasets. Please check data paths and CSV integrity.")

# Store validation set size and class names in the training_config for later summary
training_config['val_set_size'] = len(val_dataset)
training_config['class_names'] = class_names

print(f"\nDataset loading and splitting complete via data_utils.")

# --- Create DataLoaders ---
# Sets up PyTorch DataLoaders for efficient batch processing during training and validation.
# `use_oversampler` flag from config.py controls whether `RandomOverSampler` is applied to the training set.
print("\n--- Creating DataLoaders ---")
use_oversampler_flag = training_config.get('use_oversampler', True)
print(f"Creating dataloaders (use_oversampler={use_oversampler_flag})...")

train_loader, val_loader, _ = data_utils.create_dataloaders(
    train_dataset=train_dataset,
    val_dataset=val_dataset,
    test_dataset=None,                  # No test loader needed in this training phase
    dataloader_config=DATALOADER_CONFIG, # Dataloader specific settings (batch size, workers, etc.)
    use_oversampler=use_oversampler_flag
)

# Verify DataLoaders were created
if train_loader is None or val_loader is None:
    raise RuntimeError("Failed to create train or validation dataloaders.")

# Store dataset sizes in training_config for final summary
training_config['original_train_set_size'] = len(train_dataset)
try:
    # Attempt to get the actual number of samples in the resampled training loader
    if hasattr(train_loader, 'sampler') and train_loader.sampler is not None:
        if hasattr(train_loader.sampler, 'indices'):
            training_config['resampled_train_set_size'] = len(train_loader.sampler.indices)
        elif hasattr(train_loader.sampler, 'num_samples'):
            training_config['resampled_train_set_size'] = train_loader.sampler.num_samples
        else:
            training_config['resampled_train_set_size'] = len(train_dataset) # Fallback to original if sampler info is ambiguous
    else:
        training_config['resampled_train_set_size'] = len(train_loader.dataset) if hasattr(train_loader, 'dataset') else len(train_dataset)
except Exception as e:
    print(f"Warning: Could not determine resampled train set size: {e}")
    training_config['resampled_train_set_size'] = len(train_dataset)

print(f"DataLoaders created.")
print(f"Original train size: {training_config['original_train_set_size']}")
print(f"Resampled train size (loader): {training_config['resampled_train_set_size']}")
print(f"Validation size: {training_config['val_set_size']}")

## Model and Training Components Initialization

This section initializes the model (EfficientNet-B3 by default), the loss function, the optimizer, and the learning rate scheduler. It also decides whether PyTorch Automatic Mixed Precision (AMP, via `torch.amp`) will be used, selecting bfloat16 when the GPU supports it and float16 otherwise, and enables cuDNN benchmarking on CUDA devices.
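With gradient accumulation, the optimizer (and therefore a per-step `OneCycleLR`) steps once per `accum_steps` batches. The schedule length is therefore `epochs × ceil(len(train_loader) / accum_steps)`, using the same `steps_per_epoch` formula as the cell below. The following minimal sketch assumes that `train_utils.get_scheduler` builds a `OneCycleLR` from `epochs` and `steps_per_epoch`. The `max_lr` value, batch count, and toy model are placeholders:

```python
import math

import torch
import torch.nn as nn

num_batches, accum_steps, epochs = 50, 4, 3
steps_per_epoch = math.ceil(num_batches / accum_steps)  # 13: 12 full cycles + 1 partial cycle

model = nn.Linear(10, 5)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
scheduler = torch.optim.lr_scheduler.OneCycleLR(
    optimizer, max_lr=1e-3, epochs=epochs, steps_per_epoch=steps_per_epoch
)

lrs = []
for _ in range(epochs * steps_per_epoch):  # exactly one scheduler step per optimizer step
    optimizer.step()                       # no gradients here; this only keeps the call order valid
    scheduler.step()
    lrs.append(optimizer.param_groups[0]["lr"])

print(scheduler.total_steps)  # 39
# One more scheduler.step() at this point would raise ValueError (stepped past total_steps).
```

If the scheduler is stepped fewer times than `total_steps`, the LR never completes its final annealing phase. If it is stepped more times, `OneCycleLR` raises an error.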

In [None]:
print("\n--- Initializing Training Components ---")

# --- Configuration Variables for Training Loop ---
batch_size        = training_config['batch_size']
accum_steps       = training_config.get('accum_steps', 1) # Default to 1 if not specified
grad_clip_maxnorm = training_config.get('grad_clip_max_norm', None)
use_amp           = training_config['use_amp'] and device.type == 'cuda'
num_epochs_total  = training_config['epochs']
num_classes       = training_config['num_classes']

# Determine AMP data type (bfloat16 if supported, else float16)
amp_dtype = torch.bfloat16 if use_amp and device.type=='cuda' and torch.cuda.is_bf16_supported() else torch.float16

# --- Get Model (from model_utils) ---
print("Building model...")
try:
    # Call `get_model` with parameters extracted from `training_config`
    # get_model returns the adapted model and a list of backbone parameter names
    # (fast_backbone_param_names), which is passed on to train_utils.get_optimizer below
    model, fast_backbone_param_names = model_utils.get_model(
        model_name=training_config['model'],
        num_classes=training_config['num_classes'],
        pretrained=True, # Assuming pretrained models are always desired for fine-tuning
        dropout_rate=training_config['dropout'],
        drop_path_rate=training_config['drop_path_rate'],
        finetune_blocks=training_config['finetune_blocks']
    )
    model = model.to(device) # Move the model to the detected compute device
    print("Model built and moved to device.\n")

except KeyError as e:
    # str(KeyError) already includes quotes around the key name
    print(f"Error: Missing required key {e} in training_config for model creation. Check config.py.\n")
    raise  # Re-raise the exception to stop execution
except Exception as e:
    print(f"An unexpected error occurred during model creation: {e}\n")
    traceback.print_exc()
    raise


# --- Get Criterion (Loss Function) from train_utils ---
print("Setting up loss function...")
try:
    # Class weighting (controlled by 'use_class_weights' and 'class_weights' in
    # training_config) is handled inside train_utils.get_criterion.

    # Initialize the loss function using train_utils.get_criterion
    criterion = train_utils.get_criterion(
        config=training_config,
        device=device
    )
    print(f"Loss function '{training_config['loss_function_type']}' set up.\n")
except Exception as e:
    print(f"Error setting up loss function: {e}\n")
    traceback.print_exc()
    raise


# --- Get Optimizer and Scheduler from train_utils ---
print("Setting up optimizer and scheduler...\n")
try:
    if 'train_loader' not in locals() or train_loader is None:
        raise NameError("train_loader is not defined. Ensure data loading cell was run successfully.")

    # Calculate steps per epoch for OneCycleLR scheduler based on training data size and accumulation steps
    steps_per_epoch = math.ceil(len(train_loader) / accum_steps)

    # Initialize optimizer using train_utils.get_optimizer, passing the backbone parameter
    # names returned by model_utils.get_model so they can be handled separately
    optimizer = train_utils.get_optimizer(
        model=model,
        config=training_config,
        fast_backbone_param_names=fast_backbone_param_names
    )

    # Initialize scheduler using train_utils.get_scheduler
    scheduler = train_utils.get_scheduler(
        optimizer=optimizer,
        config=training_config,
        steps_per_epoch=steps_per_epoch
    )

    # Store optimizer and scheduler types in training_config for summary
    training_config['optimizer_type'] = type(optimizer).__name__
    training_config['scheduler_type'] = type(scheduler).__name__ if scheduler else 'None'
    print(f"Optimizer '{training_config['optimizer_type']}' and Scheduler '{training_config['scheduler_type']}' set up.\n")
except NameError as ne:
    print(f"Error: {ne}\n")
    raise
except Exception as e:
    print(f"An unexpected error occurred setting up optimizer/scheduler: {e}\n")
    traceback.print_exc()
    raise


# --- Enable cuDNN benchmark (optional, for CUDA-enabled GPUs) ---
if device.type == 'cuda':
    print("\nEnabling cuDNN benchmark for optimized CUDA performance.\n")
    torch.backends.cudnn.benchmark = True

print("\n--- Initialization of Training Components Complete ---")

## Training and Validation Loop

This is the core training loop, iterating over a specified number of epochs. It includes:
-   Gradient accumulation for larger effective batch sizes.
-   Automatic Mixed Precision (AMP) for performance.
-   Gradient clipping to prevent exploding gradients.
-   Real-time monitoring of training and validation metrics (loss, accuracy, and weighted precision, recall, and F1-score).
-   Early stopping mechanism based on validation F1-score to prevent overfitting.
-   Logging of learning rate and GPU memory usage.
-   Checkpointing the best performing model.

In [None]:
# --- Debugging Flag ---
# Set to True to enable anomaly detection (slower, for debugging NaNs/inf), False for normal runs.
debug_anomaly = False

# --- ANSI Color Codes for enhanced console output ---
COLOR_GREEN = '\033[92m'
COLOR_YELLOW = '\033[93m'
COLOR_RED = '\033[91m'
COLOR_RESET = '\033[0m'

# --- Verify and Extract Training Configuration Parameters ---
num_epochs_total = training_config.get('epochs', 90)
accum_steps = training_config.get('accum_steps', 1)
use_amp = device.type == 'cuda' and training_config.get('use_amp', True)
min_delta = training_config.get('min_delta', 0.001)
patience  = training_config.get('patience', 20)
grad_clip_maxnorm = training_config.get('grad_clip_max_norm', None)
num_classes = training_config.get('num_classes', 5)
batch_size = training_config.get('batch_size', 64)

# Create model configuration dictionary for checkpoint saving
model_config_for_checkpoint = {
    'model_name': training_config.get('model', 'unknown'),
    'dropout_rate': training_config.get('dropout', 0.0),
    'drop_path_rate': training_config.get('drop_path_rate', 0.0),
    'finetune_blocks': training_config.get('finetune_blocks', 0),
    'num_classes': training_config.get('num_classes', 5)
}

# Determine AMP data type (bfloat16 if supported, else float16)
amp_dtype = torch.bfloat16 if use_amp and device.type=='cuda' and torch.cuda.is_bf16_supported() else torch.float16

print("\n" + "="*60)
print(f"▶️ Training for {num_epochs_total} epochs -- Accumulation Steps: {accum_steps}")
print(f"   Optimizer:  {training_config.get('optimizer_type', 'AdamW')}")
print(f"   Scheduler:  {training_config.get('scheduler_type', 'OneCycleLR')}")
print(f"   Device:     {device}  |  AMP Enabled: {use_amp} (dtype: {amp_dtype})")
print(f"   Batch size: {batch_size}   |  Effective batch size: {batch_size * accum_steps}")
if debug_anomaly:
    print(f"   {COLOR_YELLOW}Anomaly Detection: ENABLED (Slower execution){COLOR_RESET}")
print("="*60)

# --- Initialize Training State Variables ---
# `best_metrics_summary` tracks the best validation performance achieved so far.
best_metrics_summary = {'f1': -1.0, 'acc': 0.0, 'loss': float('inf'), 'epoch': -1}

# Lists to store predictions, labels, and filenames from the best validation epoch.
best_epoch_val_preds_final = []
best_epoch_val_labels_final = []
best_epoch_val_filenames_final = []

# `history` dictionary to log metrics for each epoch.
history = {k: [] for k in ['train_loss','train_acc','val_loss','val_acc', 'precision','recall','f1','lr','gpu_mem_used']}

pat_ctr = 0 # Patience counter for early stopping
start_epoch = 0 # Start epoch for the training loop (can be adjusted for resuming training)

# Ensure required components are initialized from previous cells
if 'model' not in locals() or 'criterion' not in locals() or 'optimizer' not in locals():
    raise NameError("Model, criterion, or optimizer objects were not initialized. Please run previous cells.")
scheduler = locals().get('scheduler', None) # Scheduler might be None if not configured

# Initialize GradScaler for Automatic Mixed Precision (AMP)
scaler = torch.amp.GradScaler(enabled=use_amp)

# Print header for the training progress table
# Column widths match the per-epoch row printed below (the '%' and ' GB' suffixes are included)
print("\n" + "="*106)
print(f"| {'Epoch':<7} | {'Train Loss':<10} | {'Train Acc':<10} | "
      f"{'Val Loss':<8} | {'Val Acc':<8} | {'Precision':<9} | "
      f"{'Recall':<6} | {'F1-Score':<8} | {'GPU Mem':<12} |")
print("|" + "-"*104 + "|")

try:
    for epoch in range(start_epoch, num_epochs_total):

        # ------------------ TRAINING PHASE ------------------
        model.train() # Set model to training mode
        run_loss_epoch = 0.0
        preds_ep, labs_ep = [], []
        skipped_batches_train = 0

        # Without accumulation, clear any stale gradients once at the start of the epoch;
        # with accumulation, gradients are cleared at the start of each accumulation cycle below
        if accum_steps <= 1:
            optimizer.zero_grad(set_to_none=True)

        train_pbar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs_total} [Train]",
                          leave=False, unit="batch", dynamic_ncols=True, ascii=True, file=sys.stdout)

        for b, batch_data in enumerate(train_pbar):
            # Gradient Accumulation: Zero gradients if starting a new accumulation cycle
            if b % accum_steps == 0 and accum_steps > 1:
                 optimizer.zero_grad(set_to_none=True)

            # Load inputs and labels to device
            try: 
                inputs, labels, _ = batch_data
                inputs, labels = inputs.to(device, non_blocking=True), labels.to(device, non_blocking=True, dtype=torch.long)
            except Exception as e:
                print(f"\n{COLOR_YELLOW}Warn: Skipping train batch {b} due to data loading error: {e}{COLOR_RESET}")
                skipped_batches_train+=1
                continue

            # Define a nested function for the forward and backward pass, optionally with anomaly detection
            def forward_backward_step():
                with torch.amp.autocast(device_type=device.type, enabled=use_amp, dtype=amp_dtype):
                    outputs = model(inputs)
                    loss = criterion(outputs, labels) / accum_steps # Scale loss by accumulation steps
                if not torch.isfinite(loss):
                    print(f"\n{COLOR_YELLOW}Warn: Non-finite loss ({loss.item()*accum_steps:.4f}) in train batch {b} epoch {epoch+1}. Skipping this batch.{COLOR_RESET}")
                    return False, None, None
                scaler.scale(loss).backward() # Scale loss and perform backward pass
                return True, outputs, loss

            step_successful = False
            outer_exception = None
            step_outputs, step_loss = None, None

            # Execute forward/backward step, optionally with autograd anomaly detection
            try:
                if debug_anomaly:
                    with torch.autograd.detect_anomaly(check_nan=True):
                        step_successful, step_outputs, step_loss = forward_backward_step()
                else:
                    step_successful, step_outputs, step_loss = forward_backward_step()
            except Exception as e:  # e.g. RuntimeError from CUDA OOM or a detected anomaly
                outer_exception = e
                step_successful = False

            # Handle unsuccessful steps (e.g., non-finite loss, runtime errors)
            if outer_exception is not None or not step_successful:
                if outer_exception:
                    print(f"\n{COLOR_RED}Error during train step (batch {b}, Epoch {epoch+1}): {outer_exception}{COLOR_RESET}")
                skipped_batches_train += 1
                # Discard gradients accumulated in this cycle (including partial gradients from a
                # failed backward pass) so they cannot leak into the next optimizer step
                optimizer.zero_grad(set_to_none=True)
                continue  # Skip to the next batch

            # Accumulate loss and predictions from successful steps
            unscaled_loss_val = step_loss.item() * accum_steps
            run_loss_epoch += unscaled_loss_val * inputs.size(0) # Total accumulated loss for the epoch
            if step_outputs is not None:
                 with torch.no_grad(): # Ensure argmax operation doesn't build a computation graph
                     preds_ep.extend(step_outputs.argmax(1).cpu().numpy())
                     labs_ep.extend(labels.cpu().numpy())
            train_pbar.set_postfix(loss=f"{unscaled_loss_val:.4f}")

            # --- Gradient Accumulation: Perform optimizer step if accumulation cycle is complete ---
            if (b + 1) % accum_steps == 0:
                try:
                    scaler.unscale_(optimizer) # Unscale gradients before clipping
                    if grad_clip_maxnorm and grad_clip_maxnorm > 0: 
                        torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip_maxnorm)
                    scaler.step(optimizer) # Perform optimizer step
                    scaler.update() # Update the scale for the next iteration
                    optimizer.zero_grad(set_to_none=True) # Zero gradients after the optimizer step
                    # Step LR scheduler (typically per batch for OneCycleLR)
                    if scheduler and isinstance(scheduler, torch.optim.lr_scheduler.OneCycleLR): 
                        scheduler.step()
                except Exception as step_e:
                    print(f"\n{COLOR_RED}Error during optimizer/scheduler step (batch {b}, Epoch {epoch+1}): {step_e}{COLOR_RESET}")
                    optimizer.zero_grad(set_to_none=True)  # Ensure gradients are cleared even on error

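        # Flush gradients left over from an incomplete final accumulation cycle.
        # steps_per_epoch = ceil(len(train_loader) / accum_steps) counts this partial step,
        # so a per-step OneCycleLR is stepped here too to keep its schedule in sync.
        # (Checking len(train_loader) also avoids referencing `b` if the loader was empty.)
        if accum_steps > 1 and len(train_loader) % accum_steps != 0:
            try:
                print(f"  (Performing final optimizer step for epoch {epoch+1} with a partial accumulation cycle)")
                scaler.unscale_(optimizer)
                if grad_clip_maxnorm and grad_clip_maxnorm > 0:
                    torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip_maxnorm)
                scaler.step(optimizer)
                scaler.update()
                optimizer.zero_grad(set_to_none=True)
                if scheduler and isinstance(scheduler, torch.optim.lr_scheduler.OneCycleLR):
                    scheduler.step()
            except Exception as final_step_e:
                print(f"\n{COLOR_RED}Error during final optimizer step (Epoch {epoch+1}): {final_step_e}{COLOR_RESET}")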

        # Calculate average training loss and accuracy for the epoch
        num_train_samples_processed = len(labs_ep)
        if num_train_samples_processed > 0:
            tr_loss = run_loss_epoch / num_train_samples_processed
            tr_acc = accuracy_score(labs_ep, preds_ep)
        else: 
            tr_loss = float('nan')
            tr_acc = 0.0
            print(f"{COLOR_YELLOW}Warning: No valid training samples processed in epoch {epoch+1}. Metrics for this epoch will be zero/NaN.{COLOR_RESET}")

        # Log training metrics and current learning rate
        history['train_loss'].append(tr_loss)
        history['train_acc'].append(tr_acc)
        history['lr'].append(optimizer.param_groups[0]['lr'])

        # ------------------ VALIDATION PHASE ------------------
        model.eval() # Set model to evaluation mode
        val_loss = 0.0
        v_preds, v_labs, v_files = [], [], []
        skipped_batches_val = 0

        with torch.no_grad(): # Disable gradient calculations for validation
            val_pbar = tqdm(val_loader, desc=f"Epoch {epoch+1}/{num_epochs_total} [Val]  ", 
                            leave=False, unit="batch", dynamic_ncols=True, ascii=True, file=sys.stdout)
            for x, y, fn in val_pbar:
                try:
                    x, y = x.to(device, non_blocking=True), y.to(device, non_blocking=True, dtype=torch.long)
                except Exception as e:
                    print(f"\n{COLOR_YELLOW}Warn: Skipping validation batch due to data loading error: {e}{COLOR_RESET}")
                    skipped_batches_val += 1
                    continue
                try:
                    with torch.amp.autocast(device_type=device.type, enabled=use_amp, dtype=amp_dtype):
                        outputs = model(x)
                        if not torch.all(torch.isfinite(outputs)):
                            print(f"\n{COLOR_YELLOW}Warn: Non-finite outputs in validation epoch {epoch+1}. Skipping this batch.{COLOR_RESET}")
                            skipped_batches_val += 1
                            continue
                        loss = criterion(outputs, y)
                    loss_value = loss.item()
                    if not math.isfinite(loss_value):
                        print(f"\n{COLOR_YELLOW}Warn: Non-finite validation loss in epoch {epoch+1}. Skipping this batch.{COLOR_RESET}")
                        skipped_batches_val += 1
                        continue

                    val_loss += loss_value * x.size(0)  # Accumulate validation loss
                    v_preds.extend(outputs.argmax(1).cpu().numpy())
                    v_labs.extend(y.cpu().numpy())
                    v_files.extend(fn if isinstance(fn, (list, tuple)) else [fn])
                except Exception as val_batch_e:
                    print(f"Error processing validation batch: {val_batch_e}")
                    skipped_batches_val += 1
                    continue

        # Calculate validation metrics
        num_val_samples_processed = len(v_labs)
        if num_val_samples_processed > 0:
            val_loss /= num_val_samples_processed # Average loss per sample
            val_acc = accuracy_score(v_labs, v_preds)
            # Calculate precision, recall, F1-score (weighted average for multi-class)
            prec, rec, f1, _ = precision_recall_fscore_support(v_labs, v_preds, average='weighted', zero_division=0, labels=list(range(num_classes)))
        else:
            val_loss, val_acc, prec, rec, f1 = float('nan'), 0.0, 0.0, 0.0, 0.0
            print(f"{COLOR_YELLOW}Warning: No valid validation samples processed in epoch {epoch+1}. Metrics for this epoch will be zero/NaN.{COLOR_RESET}")

        # Log validation metrics
        history['val_loss'].append(val_loss)
        history['val_acc'].append(val_acc*100.0) # Store as percentage
        history['precision'].append(prec)
        history['recall'].append(rec)
        history['f1'].append(f1)

        # --- Get GPU Memory Usage ---
        mem_used_gb = 0.0
        if device.type == 'cuda':
            try:
                # Use memory_reserved for a better estimate of total GPU footprint including cached memory
                mem_bytes = torch.cuda.memory_reserved(device)
                mem_used_gb = mem_bytes / (1024**3) # Convert bytes to Gigabytes
            except Exception:
                mem_used_gb = 0.0 # Default to 0 on error, without printing a warning every epoch
        history['gpu_mem_used'].append(mem_used_gb)

        # Print epoch summary in a formatted table row
        print(f"| {epoch+1:<7} | {tr_loss:<10.6f} | {tr_acc*100:<9.2f}% | {val_loss:<8.6f} | {val_acc*100:<7.2f}% | {prec:<9.4f} | {rec:<6.4f} | {f1:<8.6f} | {mem_used_gb:<9.2f} GB |")

        # Step epoch-based schedulers here; OneCycleLR is excluded because it is a per-batch scheduler
        if scheduler and not isinstance(scheduler, torch.optim.lr_scheduler.OneCycleLR):
            if isinstance(scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
                # Monitors validation F1 (higher is better), so the scheduler must be created with mode='max'
                scheduler.step(f1)
            else:
                scheduler.step() # Step without a specific metric

        # --------- Early Stopping Logic and Best Model Checkpointing ----------
        current_f1 = f1 if isinstance(f1, float) and math.isfinite(f1) else 0.0
        # Check for significant F1 improvement (current_f1 must be > best_f1 + min_delta)
        f1_improved = current_f1 > best_metrics_summary['f1'] + min_delta

        if f1_improved:
            improvement_msg = f"{COLOR_GREEN}✅ F1 improved from {best_metrics_summary['f1']:.4f} to {current_f1:.4f}. Saving model...{COLOR_RESET}"
            print(improvement_msg)
            # Update best metrics summary
            best_metrics_summary.update(f1=current_f1, acc=val_acc*100.0, loss=val_loss, epoch=epoch+1)
            pat_ctr = 0 # Reset patience counter on improvement
            try:
                # Save the model checkpoint using `train_utils.save_checkpoint`
                train_utils.save_checkpoint(
                    model=model,
                    optimizer=optimizer,
                    epoch=epoch,
                    best_metric_val=current_f1,
                    path=best_model_save_path,
                    model_config=model_config_for_checkpoint  # Model configuration saved together with the weights
                )
                
                # Store validation results from this best epoch for the final report
                best_epoch_val_preds_final = v_preds[:] 
                best_epoch_val_labels_final = v_labs[:] 
                best_epoch_val_filenames_final = v_files[:]
                
            except Exception as save_e: 
                print(f"  {COLOR_RED}❌ Error saving best model checkpoint: {save_e}{COLOR_RESET}")
                traceback.print_exc()
        else:
            pat_ctr += 1 # Increment patience counter if no improvement
            warning_msg = f"{COLOR_YELLOW}⚠️ No F1 improvement for {pat_ctr}/{patience} epochs. (Best F1: {best_metrics_summary['f1']:.4f}){COLOR_RESET}"
            print(warning_msg)
            if not (isinstance(f1, float) and math.isfinite(f1)): 
                print(f"   {COLOR_YELLOW}(Current epoch F1 score was NaN or invalid, treated as no improvement){COLOR_RESET}")

        # Check if early stopping criteria are met
        if pat_ctr >= patience:
            print(f"\n{COLOR_RED}🛑 Early stopping triggered. Patience limit ({patience} epochs) reached.{COLOR_RESET}")
            training_config['patience_counter'] = pat_ctr # Store the final patience counter value
            # If no improvement was ever recorded, use the last epoch's validation results for reporting
            if not best_epoch_val_labels_final and v_labs:
                print(f"{COLOR_YELLOW}   (Using validation results from last epoch {epoch+1} for final report as no prior improvement was recorded).{COLOR_RESET}")
                best_epoch_val_preds_final = v_preds[:]
                best_epoch_val_labels_final = v_labs[:]
                best_epoch_val_filenames_final = v_files[:]
            break # Exit the training loop

except RuntimeError as e:
    print(f"\n{COLOR_RED}--- Training Loop Runtime Error ---{COLOR_RESET}\n{COLOR_RED}{e}{COLOR_RESET}")
    traceback.print_exc()
except KeyboardInterrupt:
    print(f"\n{COLOR_YELLOW}--- Training Interrupted by User ---{COLOR_RESET}")
finally:
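    # Assign final results for the next cells using the tracked best epoch data
    # Count the epochs whose validation metrics were logged; unlike `epoch + 1`, this does not
    # raise a NameError if the loop never started or count an epoch that was interrupted mid-way
    epochs_completed = len(history.get('val_loss', []))
    training_config['epochs_completed'] = epochs_completed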
    
    best_epoch_val_labels = best_epoch_val_labels_final
    best_epoch_val_preds = best_epoch_val_preds_final
    best_epoch_val_filenames = best_epoch_val_filenames_final
    best_f1 = best_metrics_summary['f1']
    best_acc = best_metrics_summary['acc']
    best_val_loss = best_metrics_summary['loss']

    # Clean up GPU memory and Python objects. Each name is removed individually, so a name that was
    # never created (e.g. after an early exit) does not prevent the remaining ones from being freed.
    for _name in ('model', 'criterion', 'optimizer', 'scheduler', 'train_loader', 'val_loader',
                  'train_dataset', 'val_dataset', 'inputs', 'labels', 'x', 'y', 'outputs', 'loss'):
        globals().pop(_name, None)
    gc.collect() # Force garbage collection
    if device.type == 'cuda': 
        torch.cuda.empty_cache() # Clear CUDA cache

    print("\n" + "-"*99 + "|")
    print("🏁 Training loop finished.")

## Post-Training Analysis and Saving Results

After the training loop ends (either at the maximum number of epochs or through early stopping), this section saves the training history as JSON, the metric, learning-rate and GPU-memory plots, and a per-sample validation report (CSV) for the best epoch, so that the run can be inspected and reported on later.
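The history may contain NumPy scalars. The standard `json` module accepts `np.float64` (it subclasses Python `float`) but rejects types such as `np.float32` and `np.int64`. A minimal sketch, assuming the history is a dict of lists or arrays, of delegating the conversion to the `default` hook of `json.dump`/`json.dumps` instead of converting each list by hand; `to_builtin` and `example_history` are hypothetical names, not part of `train_utils`:

```python
import json

import numpy as np


def to_builtin(obj):
    """Fallback for json.dump(s): convert NumPy scalars and arrays to built-in Python types."""
    if isinstance(obj, np.generic):   # np.float32, np.int64, np.bool_, ...
        return obj.item()
    if isinstance(obj, np.ndarray):
        return obj.tolist()
    # Anything else is genuinely not serializable; let json report it as usual
    raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable")


# Illustrative values only, not results from this project
example_history = {
    "val_loss": [0.91, 0.74],
    "f1": [np.float32(0.61), np.float32(0.68)],
    "gpu_mem_used": [np.float64(3.2), 3.4],
    "epoch_ids": np.arange(2),
}

text = json.dumps(example_history, default=to_builtin, indent=2)
restored = json.loads(text)
print(restored["epoch_ids"])  # [0, 1]
```

`json` only calls `default` for objects it cannot encode itself, so plain Python floats and ints pass through untouched.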

In [None]:
print("\n" + "="*90)
print("📊 Post-Training: Saving Training History and Plots")
print("="*90)

# `output_files_for_summary` dictionary compiles all paths for easy reference in the final summary.
best_f1_achieved = best_metrics_summary.get('f1', -1.0)
output_files_for_summary = {
    'best_model_path': best_model_save_path if best_f1_achieved > -1.0 and os.path.exists(best_model_save_path) else None,
    'history_json_path': history_json_path,
    'metrics_plot_path': metrics_plot_path,
    'lr_plot_path': lr_plot_path,
    'gpu_plot_path': gpu_plot_path,
    'summary_yaml_path': summary_yaml_path,
    'validation_report_csv': validation_report_csv_path,
}
print(f"  Output file paths prepared.")

# --- Save Training History to JSON ---
print(f"\nAttempting to save training history to: {history_json_path}")
if history and history_json_path: 
    try:
        # Convert NumPy scalars/arrays in the history to built-in types so that json.dump accepts them
        serializable_history = {}
        for key, value in history.items():
            if isinstance(value, np.ndarray):
                serializable_history[key] = value.tolist()
            elif isinstance(value, list):
                # Convert element by element: a list may mix Python floats and NumPy scalars
                serializable_history[key] = [item.item() if isinstance(item, np.generic) else item for item in value]
            else:
                serializable_history[key] = value
        with open(history_json_path, 'w') as f:
            json.dump(serializable_history, f, indent=2)
        print(f"  History saved successfully.")
    except Exception as e:
        print(f"  Error saving history to JSON: {e}")
else:
    print("  Skipping history saving (history empty or path not defined).")

# --- Generate and Save Plots using `visualization_utils` ---
print("\nGenerating and saving plots...")
try:
    # Each visualization_utils helper renders one plot from `history` and saves it to the given path
    visualization_utils.plot_metrics_and_save(history, metrics_plot_path, model_name=training_config.get('model', 'Model'))
    visualization_utils.plot_lr_schedule_and_save(history, lr_plot_path)
    visualization_utils.plot_gpu_memory_and_save(history, gpu_plot_path)
    print("  Plots generated and saved successfully.")
except Exception as e:
    print(f"  Error generating or saving plots: {e}")
    traceback.print_exc()

# --- Save Validation Report CSV ---
print(f"\nAttempting to save detailed validation report to: {validation_report_csv_path}")
if best_epoch_val_labels and best_epoch_val_preds and validation_report_csv_path:
    try:
        # Save the best epoch's per-sample validation results (labels, predictions, filenames) to CSV
        train_utils.save_validation_report(
            labels=best_epoch_val_labels,
            predictions=best_epoch_val_preds,
            filenames=best_epoch_val_filenames,
            num_classes=num_classes,
            output_path=validation_report_csv_path
        )
        print("  Detailed validation report saved successfully.")
    except Exception as e:
        print(f"  Error saving validation report: {e}")
        traceback.print_exc()
else:
    print("  Skipping detailed validation report saving (no validation data or path not defined).")

print("\n--- Results Saving Complete ---")

## Test Set Evaluation (Optional)

This section evaluates the best checkpoint saved during training on a separate test set, if one is configured in `DATA_PATHS`. Because the test images were used neither for training nor for selecting the best epoch, the resulting metrics give an unbiased estimate of how well the model generalizes. The test metrics are saved to a JSON file.
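Validation precision, recall and F1 in this notebook use scikit-learn's `precision_recall_fscore_support` with `average='weighted'`, so frequent PAI classes weigh more than rare ones. To show what that average computes, the sketch below derives per-class F1 as 2TP / (2TP + FP + FN) from a confusion matrix (0 when undefined, as with `zero_division=0`) and weights each class by its number of true samples (support). `confusion_matrix_np` and `weighted_f1` are hypothetical helpers for illustration; they are not meant to replace the scikit-learn call or `train_utils.calculate_and_print_metrics`, whose internals are not shown here.

```python
import numpy as np


def confusion_matrix_np(labels, preds, num_classes):
    """Confusion matrix with rows = true class and columns = predicted class."""
    cm = np.zeros((num_classes, num_classes), dtype=np.int64)
    np.add.at(cm, (np.asarray(labels), np.asarray(preds)), 1)
    return cm


def weighted_f1(labels, preds, num_classes):
    """Support-weighted mean of per-class F1; classes without true samples get weight 0."""
    cm = confusion_matrix_np(labels, preds, num_classes)
    tp = np.diag(cm).astype(float)
    fp = cm.sum(axis=0) - tp   # predicted as class c, but truly another class
    fn = cm.sum(axis=1) - tp   # truly class c, but predicted as another class
    support = cm.sum(axis=1)   # number of true samples per class
    denom = 2 * tp + fp + fn
    f1_per_class = np.divide(2 * tp, denom, out=np.zeros_like(tp), where=denom > 0)
    if support.sum() == 0:
        return 0.0
    return float(np.sum(f1_per_class * support) / support.sum())


# Illustrative labels for PAI 1-5 mapped to classes 0-4 (not project data)
example_labels = [0, 0, 1, 1, 2]
example_preds = [0, 1, 1, 1, 2]
print(round(weighted_f1(example_labels, example_preds, num_classes=5), 4))  # 0.7867
```

Here class 0 has F1 = 2/3, class 1 has F1 = 4/5 and class 2 has F1 = 1, weighted by supports 2, 2 and 1.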

In [None]:
print("\n" + "="*60)
print("➡️ Starting Final Test Set Inference (Optional)")
print("="*60)

# --- Configuration and Paths for Test Data ---
test_data_configured = 'test_root_dir' in DATA_PATHS and 'test_csv_file' in DATA_PATHS
test_root_dir = DATA_PATHS.get('test_root_dir')
test_csv_file = DATA_PATHS.get('test_csv_file')
best_model_path_to_load = best_model_save_path # Path to the best model saved during training



# --- Check Prerequisites for Test Inference ---
best_model_exists = bool(best_model_path_to_load) and os.path.exists(best_model_path_to_load)
test_csv_exists = bool(test_csv_file) and os.path.exists(test_csv_file)
test_dir_exists = bool(test_root_dir) and os.path.isdir(test_root_dir)

print(f"Test Data Configured in config.py: {test_data_configured}")
print(f"Test Root Directory Exists: {test_dir_exists} ({test_root_dir})")
print(f"Test CSV File Exists: {test_csv_exists} ({test_csv_file})")
print(f"Best Model Checkpoint Exists: {best_model_exists} ({best_model_path_to_load})")

if test_data_configured and best_model_exists and test_csv_exists and test_dir_exists:
    print("\nAll prerequisites met for test set inference. Proceeding...")
    try:
        # --- Prepare Test Transforms (using validation transforms logic) ---
        print(f"Preparing test data transforms...")
        test_transforms_dict = data_utils.create_data_transforms(
             DATA_TRANSFORM_CONFIG, NORMALIZATION, create_train=False, create_val=True # Use validation transforms for test
        )
        test_transform = test_transforms_dict['val']

        # --- Load Test Dataset ---
        print(f"Loading test dataset from: {test_csv_file}...")
        _, _, test_dataset_final, test_class_names = data_utils.load_datasets(
            data_paths=DATA_PATHS,
            transforms={'val': test_transform},
            load_split='test'
        )

        final_class_names = training_config.get('class_names', test_class_names)
        num_classes = training_config.get('num_classes', len(final_class_names) if final_class_names else 0)

        if test_dataset_final and num_classes > 0:
            print(f"Test dataset loaded with {len(test_dataset_final)} samples.")
            # --- Create Test DataLoader ---
            print("Creating test dataloader...")
            # Use `test_batch_size` from DATALOADER_CONFIG if it is set; otherwise use twice the
            # training batch size, since inference stores no gradients. For large models or
            # high-resolution inputs the doubled size may not fit in GPU memory (e.g. 256 -> 512);
            # in that case set `test_batch_size` explicitly in config.py.
            test_loader_config = DATALOADER_CONFIG.copy()
            test_loader_config['batch_size'] = test_loader_config.get('test_batch_size', test_loader_config.get('batch_size', 64) * 2)

            test_loader_config['shuffle'] = False # No need to shuffle for inference

            _, _, test_loader_final = data_utils.create_dataloaders(
                train_dataset=None,
                val_dataset=None,
                test_dataset=test_dataset_final,
                dataloader_config=test_loader_config,
                use_oversampler=False
            )

            if test_loader_final:
                # --- Load Best Model ---
                print(f"\nLoading best model from: {best_model_path_to_load}")
                # Prefer the device recorded in training_config; fall back to the device used for training
                device = training_config.get('device', device)

                # get_model returns (model, fast_backbone_param_names); only the model is needed here.
                # pretrained=False because the trained weights are loaded from the checkpoint below.
                inference_model, _ = model_utils.get_model(
                    model_name=training_config['model'], num_classes=num_classes,
                    dropout_rate=training_config['dropout'], drop_path_rate=training_config['drop_path_rate'],
                    finetune_blocks=training_config['finetune_blocks'], pretrained=False
                )
                inference_model = inference_model.to(device)

                checkpoint = torch.load(best_model_path_to_load, map_location=device)
                state_dict = checkpoint.get('model_state_dict', checkpoint)
                # Strip the 'module.' prefix that nn.DataParallel / DistributedDataParallel add to parameter names
                if next(iter(state_dict)).startswith('module.'):
                    state_dict = {k[len('module.'):]: v for k, v in state_dict.items()}
                inference_model.load_state_dict(state_dict)
                inference_model.eval()
                print("Model weights loaded successfully and set to eval() mode.")

                # --- Run Inference ---
                # Predict on the whole test set; run_inference returns (predictions, labels), checked for None below
                test_preds, test_labels = train_utils.run_inference(
                    model=inference_model, dataloader=test_loader_final, device=device,
                    num_classes=num_classes, class_names=final_class_names,
                    description="Test Set Inference"
                )

                # Check if inference returned valid results
                if test_preds is not None and test_labels is not None:
                    # --- Calculate and Print Metrics ---
                    print("\nCalculating test metrics...")
                    # Compute and print the test metrics; the returned dict is saved to JSON below
                    test_metrics = train_utils.calculate_and_print_metrics(
                        labels=test_labels, preds=test_preds, num_classes=num_classes,
                        class_names=final_class_names, results_title="Final Test Set Metrics"
                    )
                    # --- Save Test Metrics ---
                    # Create a dedicated path for test metrics using the same pattern
                    test_metrics_path = os.path.join(CHECKPOINT_DIR, f"{file_prefix}_test_metrics.json")
                    print(f"\nAttempting to save test metrics to: {test_metrics_path}")
                    if test_metrics:
                        try:
                            # Convert numpy types before saving
                            if 'confusion_matrix' in test_metrics and isinstance(test_metrics['confusion_matrix'], np.ndarray): 
                                test_metrics['confusion_matrix'] = test_metrics['confusion_matrix'].tolist()
                            for key in ['binary_tn', 'binary_fp', 'binary_fn', 'binary_tp']:
                                if key in test_metrics: test_metrics[key] = int(test_metrics[key])
                            with open(test_metrics_path, 'w') as f: 
                                json.dump(test_metrics, f, indent=2)
                            print(f"  Test metrics saved successfully.")
                        except Exception as e: 
                            print(f"  Warning: Could not save test metrics to JSON: {e}")
                    else: 
                        print("  Skipping test metrics saving (metrics dictionary not available).")
                else:
                    print("\nSkipping metrics calculation and saving (inference failed to return results).")
            else:
                print("\nSkipping test inference: Failed to create test DataLoader.")
        elif not test_dataset_final:
            print("\nSkipping test inference: Failed to load test dataset.")
        else:
            print("\nSkipping test inference: Number of classes is zero or invalid.")
    except FileNotFoundError as fnf_err:
        print(f"\nSkipping test inference due to FileNotFoundError: {fnf_err}")
    except KeyError as key_err:
        print(f"\nSkipping test inference due to KeyError: {key_err}")
    except AttributeError as attr_err:
        print(f"\nSkipping test inference due to AttributeError: {attr_err}. Check if utility functions exist.")
        traceback.print_exc() # Added to help debug the actual AttributeError
    except Exception as e:
        print(f"\nAn unexpected error occurred during test inference: {e}")
        traceback.print_exc()
# Handle prerequisite failures
elif not test_data_configured:
    print("\nSkipping test inference: Test data not configured in DATA_PATHS.")
elif not best_model_exists:
    print(f"\nSkipping test inference: Best model file not found at {best_model_path_to_load} (likely no F1 improvement was recorded, or the path is wrong).")
elif not test_csv_exists:
    print(f"\nSkipping test inference: Test CSV file missing or not found ({test_csv_file}).")
elif not test_dir_exists:
    print(f"\nSkipping test inference: Test root directory missing or invalid ({test_root_dir}).")


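# Clean up test-time objects if they were created (they are undefined when test inference was skipped,
# so a plain `del test_loader_final` would raise a NameError)
for _name in ('test_loader_final', 'inference_model', 'checkpoint', 'state_dict'):
    globals().pop(_name, None)
gc.collect()
if torch.cuda.is_available():
    torch.cuda.empty_cache()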

print("\n🏁 Test Set Evaluation Complete.")

## Export Comprehensive Training Summary

This final section compiles a complete summary of the training run, including configuration parameters, best performance metrics, history highlights, dataset information, and file paths. This comprehensive report is exported to a YAML file, providing a single, human-readable record of the entire experiment.
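The `best_epoch` and `early_stopping_triggered` fields in this summary come from the F1-based early-stopping rule in the training loop: an epoch counts as an improvement only if its validation F1 exceeds the best F1 so far by more than `min_delta`, a NaN or otherwise invalid F1 is treated as 0, and training stops once `patience` consecutive epochs pass without improvement. A stand-alone sketch of that rule follows; the `EarlyStopper` class is hypothetical, the starting best value of minus infinity is an assumption (the notebook's initial `best_metrics_summary['f1']` is set outside this section), and the F1 values are illustrative.

```python
import math


class EarlyStopper:
    """Tracks the best F1 and counts consecutive epochs without a significant improvement."""

    def __init__(self, patience, min_delta=0.0):
        self.patience = patience
        self.min_delta = min_delta
        self.best_f1 = float("-inf")  # assumption: the first finite F1 always counts as an improvement
        self.best_epoch = None
        self.counter = 0

    def step(self, epoch, f1):
        """Record one epoch; return True if it is a new best (the notebook saves a checkpoint then)."""
        current = f1 if isinstance(f1, float) and math.isfinite(f1) else 0.0
        if current > self.best_f1 + self.min_delta:
            self.best_f1, self.best_epoch, self.counter = current, epoch, 0
            return True
        self.counter += 1
        return False

    @property
    def should_stop(self):
        return self.counter >= self.patience


stopper = EarlyStopper(patience=2, min_delta=0.01)
for epoch, f1 in enumerate([0.50, 0.58, 0.585, float("nan"), 0.70], start=1):
    stopper.step(epoch, f1)
    if stopper.should_stop:
        print(f"Stopped after epoch {epoch}; best F1 {stopper.best_f1:.3f} at epoch {stopper.best_epoch}")
        break
```

With `min_delta=0.01`, epoch 3 (0.585) is not a significant gain over 0.58 and the NaN in epoch 4 counts as 0, so patience runs out after epoch 4 and epoch 5 is never reached.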

In [None]:
print("\n" + "="*60)
print("📝 Exporting Comprehensive Training Summary")
print("="*60)

# Start with a deep copy of the `training_config` which contains most initial parameters
full_summary = copy.deepcopy(training_config)

# Add best performance metrics from the training loop
full_summary['best_metrics_achieved'] = {
    'f1_score': best_metrics_summary['f1'],
    'accuracy': best_metrics_summary['acc'],
    'val_loss': best_metrics_summary['loss'],
    'best_epoch': best_metrics_summary['epoch']
}

# Add a summary of training history (last and best values for key metrics)
if history:
    history_summary = {}
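    for key, values in history.items():
        if len(values) > 0:
            history_summary[f"{key}_final"] = values[-1] # Last recorded value
            if key in ['val_acc', 'precision', 'recall', 'f1', 'train_loss', 'val_loss']:
                # Ignore NaN/inf entries (e.g. epochs without valid validation batches):
                # max()/min() return NaN when a NaN value comes first in the list
                finite_values = [v for v in values if math.isfinite(v)]
                if finite_values:
                    best_fn = min if key in ['train_loss', 'val_loss'] else max # Min for losses, max for performance metrics
                    history_summary[f"{key}_best"] = best_fn(finite_values)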
    full_summary['history_summary'] = history_summary
else:
    full_summary['history_summary'] = "No training history available."

# Add detailed dataset information
full_summary['dataset_info'] = {
    'train_csv_file': DATA_PATHS.get('train_csv_file', 'not_configured'),
    'val_csv_file': DATA_PATHS.get('val_csv_file', 'not_configured'),
    'test_csv_file': DATA_PATHS.get('test_csv_file', 'not_configured'),
    'train_root_dir': DATA_PATHS.get('train_root_dir', 'not_configured'),
    'val_root_dir': DATA_PATHS.get('val_root_dir', 'not_configured'),
    'test_root_dir': DATA_PATHS.get('test_root_dir', 'not_configured'),
    'initial_train_size': training_config.get('original_train_set_size', 0),
    'resampled_train_size': training_config.get('resampled_train_set_size', 0),
    'val_size': training_config.get('val_set_size', 0),
    'test_size': len(test_dataset_final) if 'test_dataset_final' in globals() and test_dataset_final else 0,
    'num_classes': num_classes,
    'class_names': training_config.get('class_names', [])
}

# Add data transforms configuration (directly from the config dict)
full_summary['data_transforms_config'] = DATA_TRANSFORM_CONFIG

# Add dataloader configuration (directly from the config dict)
full_summary['dataloader_config'] = DATALOADER_CONFIG

# Add normalization information (directly from the config dict)
full_summary['normalization_values'] = {
    'mean': NORMALIZATION['mean'],
    'std': NORMALIZATION['std']
}

# Add model configuration details
full_summary['model_configuration'] = {
    'name': training_config.get('model', 'unknown'),
    'dropout': training_config.get('dropout', 0.0),
    'drop_path_rate': training_config.get('drop_path_rate', 0.0),
    'finetune_blocks': training_config.get('finetune_blocks', -1),
    'pretrained_on_imagenet': True, # Assuming models are pretrained unless specified otherwise
    'optimizer_type': training_config.get('optimizer_type', 'unknown'),
    'scheduler_type': training_config.get('scheduler_type', 'None'),
    'loss_function_type': training_config.get('loss_function_type', 'unknown')
}

# Add file paths for all generated outputs
test_metrics_json_path = os.path.join(CHECKPOINT_DIR, f"{file_prefix}_test_metrics.json")
full_summary['output_file_paths'] = {
    'base_checkpoint_dir': CHECKPOINT_DIR,
    'best_model_checkpoint': best_model_save_path,
    'training_history_json': history_json_path,
    'metrics_plot_png': metrics_plot_path,
    'learning_rate_plot_png': lr_plot_path,
    'gpu_memory_plot_png': gpu_plot_path,
    'run_summary_yaml': summary_yaml_path,
    'validation_report_csv': validation_report_csv_path,
    # Only list the test-metrics file if the test-evaluation cell actually wrote it
    'test_metrics_json': test_metrics_json_path if os.path.exists(test_metrics_json_path) else None
}

# Add runtime information
full_summary['runtime_info'] = {
    'date_completed': datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
    'device': str(device),
    'pytorch_version': torch.__version__,
    'peak_gpu_memory_gb': max(history.get('gpu_mem_used') or [0.0]), # `or` guards against a missing or empty list
    'epochs_completed': training_config.get('epochs_completed', 0),
    'early_stopping_triggered': training_config.get('patience_counter', 0) >= training_config.get('patience', 15)
}

# Recursively convert NumPy types (and other non-YAML types) to plain Python types for YAML serialization
def sanitize_for_yaml(obj):
    if isinstance(obj, np.bool_):
        return bool(obj)
    elif isinstance(obj, np.integer):
        return int(obj)
    elif isinstance(obj, np.floating):
        # PyYAML writes Python float NaN/inf as .nan/.inf itself; returning the strings '.nan'/'.inf'
        # would instead store them as quoted strings rather than numbers
        return float(obj)
    elif isinstance(obj, np.ndarray):
        return obj.tolist()
    elif isinstance(obj, dict):
        return {k: sanitize_for_yaml(v) for k, v in obj.items()}
    elif isinstance(obj, (list, tuple)):
        # YAML has no tuple type; store tuples (e.g. normalization mean/std) as lists
        return [sanitize_for_yaml(item) for item in obj]
    elif obj is None or isinstance(obj, (str, int, float, bool)):
        return obj
    else:
        # Fall back to the string form for other objects (e.g. torch.device)
        return str(obj)

# Export the comprehensive summary to a YAML file
comprehensive_yaml_path = os.path.join(CHECKPOINT_DIR, f"{file_prefix}_comprehensive_summary.yaml")
try:
    sanitized_summary = sanitize_for_yaml(full_summary)
    with open(comprehensive_yaml_path, 'w') as f:
        yaml.dump(sanitized_summary, f, default_flow_style=False, sort_keys=False)
    print(f"✅ Comprehensive model summary exported to: {comprehensive_yaml_path}")
except Exception as e:
    print(f"❌ Error exporting comprehensive model summary: {e}")
    traceback.print_exc()

print("="*60)
print("🏁 Notebook Execution Finished.")