In [1]:
import albumentations as A
from albumentations.pytorch import ToTensorV2
import os
import copy 
import pandas as pd
import numpy as np
import cv2
import requests
from datetime import datetime
import shutil
import random
from pathlib import Path
from dataclasses import dataclass, field
from typing import Dict, Any, Optional
import torch
import matplotlib.pyplot as plt
from torchvision.utils import make_grid
from torchvision.utils import save_image
from IPython.display import Image
from torch.utils.data import Dataset, DataLoader
from torchvision.datasets import ImageFolder
from torchvision.transforms import transforms
from torch.optim import lr_scheduler
import torch.optim as optim
from tqdm import tqdm
from torch import nn
from torchvision import models
from typing import List, Tuple, Optional, Dict, Any, Union
# Import the model weights - takes time
from torchvision.models import (
    ResNet18_Weights, ResNet34_Weights, ResNet50_Weights, ResNet101_Weights, ResNet152_Weights,
    EfficientNet_B0_Weights, EfficientNet_B1_Weights, EfficientNet_B2_Weights, EfficientNet_B3_Weights,
    EfficientNet_B4_Weights, EfficientNet_B5_Weights, EfficientNet_B6_Weights, EfficientNet_B7_Weights,
    EfficientNet_V2_S_Weights, EfficientNet_V2_M_Weights, EfficientNet_V2_L_Weights,
    ConvNeXt_Tiny_Weights, ConvNeXt_Small_Weights, ConvNeXt_Base_Weights, ConvNeXt_Large_Weights,
    DenseNet121_Weights, DenseNet161_Weights, DenseNet169_Weights, DenseNet201_Weights
)

# Seed function 
def seed_all(seed: int = 1234):
    """Seed every random number generator used in this notebook.

    Seeds Python's ``random`` module, NumPy's legacy global generator and
    PyTorch (CPU and all CUDA devices), and switches cuDNN to deterministic
    mode so repeated runs take the same code paths.

    Args:
        seed: Integer seed applied to every generator.
    """
    random.seed(seed)
    # Only affects interpreters started after this point (e.g. spawned worker
    # processes); the running interpreter reads PYTHONHASHSEED at startup.
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Seed every CUDA device, not just the current one; this call is silently
    # ignored when CUDA is unavailable.
    torch.cuda.manual_seed_all(seed)
    # Trade cuDNN autotuning speed for run-to-run reproducibility
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


## Resources used: 
[CNN Cheatsheet from Stanford CS 230](https://stanford.edu/~shervine/teaching/cs-230/cheatsheet-convolutional-neural-networks)

[Image classification: ResNet vs EfficientNet vs EfficientNet_v2 vs Compact Convolutional Transformers](https://medium.com/@enrico.randellini/image-classification-resnet-vs-efficientnet-vs-efficientnet-v2-vs-compact-convolutional-c205838bbf49). The code is essentially a generalized adaptation of this with some extra bells and whistles. 

[PyTorch docs](https://docs.pytorch.org/tutorials/). Self-explanatory

Deepseek for plotting functions in Python and reformatting my poor python syntax. OOP is new to me so I had a lot of issues with initial attempts at this. 

## Goal
- The goal of this was to become familiar with OOP/python through an applied example of making a generalized transfer-learning workflow to apply existing CNN models to MRI image classification. I used two datasets for this, the first a collection of brain tumor MRIs to classify the malignancy, and the second a collection of Alzheimer's MRIs to classify disease severity ([found here](https://www.kaggle.com/datasets/alifatahi/multi-class-neurological-disorder-mcnd-dataset/data)). In exploring this, I quickly got lost in the weeds due to the complexity and diversity of approaches you can use to solve this classification problem and the current state of the field. In the end, I chose to ignore these, as the goal wasn't to make the most groundbreaking/novel approach to this problem (something I'm ill-equipped to do) but to learn how to work through it and structure code to do this using existing methods. 

# Transformations for the training and test sets 
- This will be more important for other datasets/real data testing
- For the test dataset, use a simple transformation to resize and normalize the data before converting to a tensor
- For the training dataset, use a more complex transformation that introduces shifts and transformations without grossly distorting biology
- Note: Need to use ToTensorV2 at the end so they're the correct format

In [2]:
# Evaluation pipeline (applied to the validation and test splits): resize + normalization only, no augmentation
transform_test = A.Compose([
    # Bring every image to the 224x224 input size
    A.Resize(height=224, width=224),
    # Channel-wise normalization with the mean/std values conventionally used for ImageNet-pretrained models
    A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    # Convert the NumPy image into a PyTorch tensor
    ToTensorV2()
], seed = 1234)

# Training pipeline: random augmentations first, then the same resize/normalize/tensor steps as above
transform_train = A.Compose([
    # Spatial transformations: flip, rotate, scale/translate
    # NOTE(review): Rotate and Affine both rotate by up to 15 degrees, so an image can be
    # rotated twice when both fire — confirm this is intended
    A.HorizontalFlip(p=0.5),
    A.Rotate(limit=15, p=0.5),
    A.Affine(scale=(0.9, 1.1), translate_percent=0.05, rotate=(-15, 15), p=0.5),
    # Intensity, noise, and blur
    A.GaussNoise(std_range=(0, 0.01), p=0.3),
    A.Blur(blur_limit=3, p=0.2),
    # Elastic transformations
    A.ElasticTransform(alpha=1, sigma=50, p=0.3),
    # Resize, normalize and convert to tensor (ToTensorV2 must stay last)
    A.Resize(height=224, width=224),
    A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ToTensorV2()
], seed = 1234)

## Dataset and Model structures for consistent inputs to the modeling
### Dataset loaders
- AlbumentationsDataset: Dataset class that wraps torchvision's `ImageFolder` class and applies the transformations described above
- create_datasets: Function that takes in a base directory (which itself contains the training, validation, and test directories) to feed into the albumentation transformations before returning a tuple of these datasets for modeling
- create_dataloaders: Function that uses the built-in DataLoader util to load the datasets from the base_dir, shuffling only the training set (the validation and test loaders are not shuffled), and returns the dataloader objects together with the class names
### Model config/loaders 
- ModelConfig: Class that holds various points of model/dataset information that are important, including
    - Number of classes for classification
    - Model name for lookup/comparison with the dictionary
    - Number of nodes of the hidden layer
    - Dropout rate in the classifier head
    - Information about whether to freeze weights and how to unfreeze
    - Whether to load ImageNet weights or not
    - Whether to use the custom head or not
- ImageClassificationModel: Class that does the following
    - Holds a dictionary of usable models (excluding the ConvNeXt models for now)
    - Validate the input model is in the dictionary
    - Store information about the number of layers for unfreezing
    - Defines the classifier head, which allows these pretrained models to be used for specific purposes of classifying our images of interest
        - Individual classifier heads for the different built-in models
        - Generalized classifier heads for other models that output 1D or 2D features
    - Helper/debugging functions to troubleshoot model specific issues (mostly useful when using new models) 

In [3]:
# Make a dataset class to hold the results of the albumentations transformation
class AlbumentationsDataset(Dataset):
    """Dataset that reads images listed by an ImageFolder and applies an albumentations pipeline.

    The wrapped ``image_folder`` only supplies the ``(path, label)`` pairs
    (``.samples``) and the label list (``.targets``); images are loaded here
    with OpenCV so they reach the transform as RGB NumPy arrays.

    Args:
        image_folder: Object exposing ``samples`` and ``targets`` attributes
            (a torchvision ``ImageFolder`` in this notebook).
        transform: Optional callable invoked as ``transform(image=array)`` that
            returns a mapping with an ``'image'`` entry.
    """

    def __init__(self, image_folder, transform=None):
        self.image_folder = image_folder
        self.transform = transform

        # Get samples (image paths) and targets (labels) from the ImageFolder
        self.samples = self.image_folder.samples
        self.targets = self.image_folder.targets

    def __len__(self):
        """Return the number of (path, label) samples."""
        return len(self.samples)

    def __getitem__(self, idx):
        """Load, convert and transform the sample at ``idx``.

        Returns:
            ``(image, label)`` where ``image`` is the transform output, or the
            raw RGB array when no transform was given.

        Raises:
            FileNotFoundError: If OpenCV cannot read the image file.
        """
        # Get the path and label for the given index
        path, label = self.samples[idx]

        # cv2.imread signals a missing/corrupt file by returning None instead of
        # raising, so fail here with the offending path rather than deep inside cvtColor.
        image = cv2.imread(path)
        if image is None:
            raise FileNotFoundError(f"Could not read image file: {path}")

        # OpenCV loads BGR; convert to RGB
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        # Apply albumentations transforms (explicit None check: a pipeline object's
        # truthiness may depend on its length)
        if self.transform is not None:
            image = self.transform(image=image)['image']

        return image, label

def create_datasets(base_dir: str) -> Tuple[Dataset, Dataset, Dataset]:
    """Build the train, test and validation datasets found under ``base_dir``.

    ``base_dir`` must contain ``train``, ``test`` and ``val`` sub-directories
    laid out for ``ImageFolder``. The training split gets ``transform_train``;
    the test and validation splits both get ``transform_test``.

    Returns:
        ``(train_dataset, test_dataset, validation_dataset)`` — note the order.
    """
    # (sub-directory, albumentations pipeline), listed in the order they are returned
    split_plan = (
        ('train', transform_train),
        ('test', transform_test),
        ('val', transform_test),
    )

    # Read each split with a transform-free ImageFolder first, then wrap them
    folders = [
        ImageFolder(root=os.path.join(base_dir, split_name), transform=None)
        for split_name, _ in split_plan
    ]
    train_dataset, test_dataset, validation_dataset = (
        AlbumentationsDataset(image_folder=folder, transform=pipeline)
        for folder, (_, pipeline) in zip(folders, split_plan)
    )

    return train_dataset, test_dataset, validation_dataset

def seed_worker(worker_id):
    """DataLoader ``worker_init_fn`` that seeds NumPy and ``random`` in a worker.

    The seed is derived from ``torch.initial_seed()`` reduced to 32 bits;
    ``worker_id`` is accepted for the callback signature but not used.
    """
    seed_32bit = torch.initial_seed() % (1 << 32)
    for reseed in (np.random.seed, random.seed):
        reseed(seed_32bit)

def create_dataloaders(base_dir: str, batch_size: int = 32, 
                      num_workers: int = 4, seed: Optional[int] = 1234) -> Tuple[DataLoader, DataLoader, Optional[DataLoader], List[str]]:
    """Create the train/validation/test DataLoaders for the splits under ``base_dir``.

    Args:
        base_dir: Directory passed to ``create_datasets``.
        batch_size: Batch size used by every loader.
        num_workers: Number of worker processes per loader.
        seed: Seed for the loaders' ``torch.Generator`` and for per-worker
            seeding via ``seed_worker``. Pass ``None`` to leave both unset.

    Returns:
        ``(train_loader, val_loader, test_loader, class_names)``. Only the
        training loader shuffles. ``test_loader`` is ``None`` when the test
        dataset is empty. ``class_names`` comes from the training ImageFolder.
    """
    # Seeded generator (or none) that drives shuffling
    if seed is not None:
        g = torch.Generator()
        g.manual_seed(seed)
    else:
        g = None
    
    # Call the create_datasets function (returns train, test, val in that order)
    train_dataset, test_dataset, val_dataset = create_datasets(base_dir)
    
    # ImageFolder stores class names in its .classes attribute
    class_names = train_dataset.image_folder.classes
    
    # Define the DataLoader arguments shared by every split
    loader_args = {
        "batch_size": batch_size,
        "num_workers": num_workers,
        # Pinned host memory only helps host-to-GPU copies; requesting it on a
        # CPU-only machine just triggers a warning.
        "pin_memory": torch.cuda.is_available(),
        "generator": g,
        "worker_init_fn": seed_worker if seed is not None else None,
    }
    
    # Create dataloaders
    train_loader = DataLoader(train_dataset, shuffle=True, **loader_args)
    val_loader = DataLoader(val_dataset, shuffle=False, **loader_args) 
    
    # Skip the test loader when the test dataset is missing or empty
    test_loader = None
    if test_dataset is not None and len(test_dataset) > 0:
        test_loader = DataLoader(test_dataset, shuffle=False, **loader_args)

    return train_loader, val_loader, test_loader, class_names

# Use the dataclass decorator so that the _init_ is generated for the ModelConfig class
@dataclass
class ModelConfig:
    """Settings describing which backbone to build and how to configure its classifier head.

    Instances are created by ``ImageClassificationModel._create_config`` from a
    nested configuration dictionary. ``num_classes`` and ``model_name`` have no
    defaults and must always be supplied.
    """
    # Number of output classes; width of the final linear layer of the head
    num_classes: int
    # Key into ImageClassificationModel.MODEL_CONFIGS (e.g. 'resnet18')
    model_name: str
    # Width of the hidden linear layer in the classifier head
    n_nodes: int = 512
    # Dropout probability used inside the classifier head
    dropout: float = 0.2
    # When True, the backbone parameters are frozen before the new head is attached
    freeze_layers: bool = False
    # Stored on the config; not read by the model code in this section of the notebook
    freeze_strategy: str = "all"
    # When True, the backbone is built with the default pretrained weights
    pretrained: bool = True
    # Filled from the "data.size" config entry; not read by the model code in this section
    image_size: int = 224
    # Stored on the config; not read by the model code in this section of the notebook
    custom_head: bool = True
    # When True, the model prints a configuration summary after construction
    debug: bool = False

class ImageClassificationModel(nn.Module):
    """Torchvision backbone with its final classifier replaced by a small trainable head.

    The model is configured from a nested dictionary::

        {"model": {"num_classes": 4, "model_name": "resnet18", ...},
         "data": {"size": 224},
         "debug": False}

    ``model.num_classes`` and ``model.model_name`` are required; every other
    entry falls back to the defaults of ``ModelConfig``.

    Attributes:
        config: The validated ``ModelConfig``.
        model: The backbone with the replacement head attached.
        feature_dim: Number of features the backbone feeds into the head.
        layer_names / param_layers: Parameter names and tensors, in
            ``named_parameters()`` order, used for index-based unfreezing.
        total_layers: Number of parameter tensors tracked.
    """

    # Dictionary of the model configurations, using the default weights
    # (could be turned into .yaml or an external dictionary in the future)
    MODEL_CONFIGS = {
        # ResNet Models
        'resnet18': {'base_model': models.resnet18, 'weights': ResNet18_Weights.DEFAULT, 'feature_dim': 512, 'classifier_attr': 'fc', 'pooling_attr': 'avgpool'},
        'resnet34': {'base_model': models.resnet34, 'weights': ResNet34_Weights.DEFAULT, 'feature_dim': 512, 'classifier_attr': 'fc', 'pooling_attr': 'avgpool'},
        'resnet50': {'base_model': models.resnet50, 'weights': ResNet50_Weights.DEFAULT, 'feature_dim': 2048, 'classifier_attr': 'fc', 'pooling_attr': 'avgpool'},
        'resnet101': {'base_model': models.resnet101, 'weights': ResNet101_Weights.DEFAULT, 'feature_dim': 2048, 'classifier_attr': 'fc', 'pooling_attr': 'avgpool'},
        'resnet152': {'base_model': models.resnet152, 'weights': ResNet152_Weights.DEFAULT, 'feature_dim': 2048, 'classifier_attr': 'fc', 'pooling_attr': 'avgpool'},
        
        # EfficientNet Models
        'efficientnet_b0': {'base_model': models.efficientnet_b0, 'weights': EfficientNet_B0_Weights.DEFAULT, 'feature_dim': 1280, 'classifier_attr': 'classifier', 'pooling_attr': 'avgpool'},
        'efficientnet_b1': {'base_model': models.efficientnet_b1, 'weights': EfficientNet_B1_Weights.DEFAULT, 'feature_dim': 1280, 'classifier_attr': 'classifier', 'pooling_attr': 'avgpool'},
        'efficientnet_b2': {'base_model': models.efficientnet_b2, 'weights': EfficientNet_B2_Weights.DEFAULT, 'feature_dim': 1408, 'classifier_attr': 'classifier', 'pooling_attr': 'avgpool'},
        'efficientnet_b3': {'base_model': models.efficientnet_b3, 'weights': EfficientNet_B3_Weights.DEFAULT, 'feature_dim': 1536, 'classifier_attr': 'classifier', 'pooling_attr': 'avgpool'},
        'efficientnet_b4': {'base_model': models.efficientnet_b4, 'weights': EfficientNet_B4_Weights.DEFAULT, 'feature_dim': 1792, 'classifier_attr': 'classifier', 'pooling_attr': 'avgpool'},
        'efficientnet_b5': {'base_model': models.efficientnet_b5, 'weights': EfficientNet_B5_Weights.DEFAULT, 'feature_dim': 2048, 'classifier_attr': 'classifier', 'pooling_attr': 'avgpool'},
        'efficientnet_b6': {'base_model': models.efficientnet_b6, 'weights': EfficientNet_B6_Weights.DEFAULT, 'feature_dim': 2304, 'classifier_attr': 'classifier', 'pooling_attr': 'avgpool'},
        'efficientnet_b7': {'base_model': models.efficientnet_b7, 'weights': EfficientNet_B7_Weights.DEFAULT, 'feature_dim': 2560, 'classifier_attr': 'classifier', 'pooling_attr': 'avgpool'},
        'efficientnet_v2_s': {'base_model': models.efficientnet_v2_s, 'weights': EfficientNet_V2_S_Weights.DEFAULT, 'feature_dim': 1280, 'classifier_attr': 'classifier', 'pooling_attr': 'avgpool'},
        'efficientnet_v2_m': {'base_model': models.efficientnet_v2_m, 'weights': EfficientNet_V2_M_Weights.DEFAULT, 'feature_dim': 1280, 'classifier_attr': 'classifier', 'pooling_attr': 'avgpool'},
        'efficientnet_v2_l': {'base_model': models.efficientnet_v2_l, 'weights': EfficientNet_V2_L_Weights.DEFAULT, 'feature_dim': 1280, 'classifier_attr': 'classifier', 'pooling_attr': 'avgpool'},
        
        # ConvNeXt Models
        # NOTE: As currently implemented with the custom classifier head, the ConvNeXt models don't work because of their architecture. Will work to fix but it isn't a priority 
        'convnext_tiny': {'base_model': models.convnext_tiny, 'weights': ConvNeXt_Tiny_Weights.DEFAULT, 'feature_dim': 768, 'classifier_attr': 'classifier', 'pooling_attr': 'avgpool', 'classifier_is_sequential': True},
        'convnext_small': {'base_model': models.convnext_small, 'weights': ConvNeXt_Small_Weights.DEFAULT, 'feature_dim': 768, 'classifier_attr': 'classifier', 'pooling_attr': 'avgpool', 'classifier_is_sequential': True},
        'convnext_base': {'base_model': models.convnext_base, 'weights': ConvNeXt_Base_Weights.DEFAULT, 'feature_dim': 1024, 'classifier_attr': 'classifier', 'pooling_attr': 'avgpool', 'classifier_is_sequential': True},
        'convnext_large': {'base_model': models.convnext_large, 'weights': ConvNeXt_Large_Weights.DEFAULT, 'feature_dim': 1536, 'classifier_attr': 'classifier', 'pooling_attr': 'avgpool', 'classifier_is_sequential': True},
        
        # DenseNet Models
        'densenet121': {'base_model': models.densenet121, 'weights': DenseNet121_Weights.DEFAULT, 'feature_dim': 1024, 'classifier_attr': 'classifier', 'pooling_attr': None, 'requires_pooling': False},
        'densenet161': {'base_model': models.densenet161, 'weights': DenseNet161_Weights.DEFAULT, 'feature_dim': 2208, 'classifier_attr': 'classifier', 'pooling_attr': None, 'requires_pooling': False},
        'densenet169': {'base_model': models.densenet169, 'weights': DenseNet169_Weights.DEFAULT, 'feature_dim': 1664, 'classifier_attr': 'classifier', 'pooling_attr': None, 'requires_pooling': False},
        'densenet201': {'base_model': models.densenet201, 'weights': DenseNet201_Weights.DEFAULT, 'feature_dim': 1920, 'classifier_attr': 'classifier', 'pooling_attr': None, 'requires_pooling': False}
    }

    def __init__(self, cfg: Dict[str, Any]):
        """Validate ``cfg``, build the backbone with its new head and index its parameters.

        Raises:
            ValueError: If a required config key is missing or the model name
                is not in ``MODEL_CONFIGS``.
        """
        # Use super() to extend the behavior of the parent constructor
        super().__init__()
        # Create the validated configuration object, then the network itself
        self.config = self._create_config(cfg)
        self.model = self.get_model(self.config.model_name)

        # Expose the backbone feature width so code that inspects the model
        # (e.g. checkpoint metadata) can read it; get_model has already
        # validated the name, so this lookup cannot fail.
        self.feature_dim = self.MODEL_CONFIGS[self.config.model_name]['feature_dim']
        
        # Store layer information for unfreezing
        self._setup_layer_tracking()
        
        if self.config.debug:
            self._debug_info()

    def _create_config(self, cfg_dict: dict) -> ModelConfig:
        """Validate the input configuration and create the ``ModelConfig`` object.

        Raises:
            ValueError: If ``model.num_classes`` or ``model.model_name`` is missing.
        """
        # Check for the required inputs (name and number of classes) 
        required = ["num_classes", "model_name"]
        model_cfg = cfg_dict.get("model", {})
        
        for key in required:
            if key not in model_cfg:
                raise ValueError(f"Missing required config key: model.{key}")
        
        data_cfg = cfg_dict.get("data", {})

        # Create the ModelConfig w/ defaults 
        return ModelConfig(
            num_classes=model_cfg["num_classes"],
            model_name=model_cfg["model_name"],
            n_nodes=model_cfg.get("n_nodes", 512),
            dropout=model_cfg.get("dropout", 0.2),
            freeze_layers=model_cfg.get("freeze_layers", False),
            freeze_strategy=model_cfg.get("freeze_strategy", "all"),
            pretrained=model_cfg.get("pretrained", True),
            image_size=data_cfg.get("size", 224),
            custom_head=model_cfg.get("custom_head", True),
            debug=cfg_dict.get("debug", False)
        )

    def _setup_layer_tracking(self):
        """Record parameter names and tensors so they can be unfrozen by index later."""
        self.layer_names = []
        self.param_layers = []
        
        for name, param in self.model.named_parameters():
            self.layer_names.append(name)
            self.param_layers.append(param)
        
        # Count of parameter tensors (not modules), used by unfreezing strategies
        self.total_layers = len(self.param_layers)

    def forward(self, x):
        """Run ``x`` through the backbone and the replacement head."""
        return self.model(x)

    def freeze_layers_base_model(self, model):
        """Freeze every parameter of ``model`` by disabling its gradient."""
        for param in model.parameters():
            # Remove the gradient requirement to freeze the layer 
            param.requires_grad = False

    def unfreeze_layers(self, layer_start_unfreeze: Optional[int] = None):
        """Re-enable gradients for tracked parameters.

        Args:
            layer_start_unfreeze: Index into the tracked parameter list; every
                parameter at or after it is unfrozen. ``None`` unfreezes all.
        """
        if layer_start_unfreeze is not None:
            print(f"Unfreezing layers starting from index {layer_start_unfreeze}")
            for i, param in enumerate(self.param_layers):
                if i >= layer_start_unfreeze:
                    param.requires_grad = True
                    if self.config.debug:
                        print(f"Unfrozen: {self.layer_names[i]}")
        else:
            # Unfreeze all layers
            print("Unfreezing all layers")
            for param in self.model.parameters():
                param.requires_grad = True

    def get_trainable_parameters_count(self):
        """Return the number of scalar parameters that currently require gradients."""
        return sum(p.numel() for p in self.parameters() if p.requires_grad)

    def get_model(self, name_pretrained_model: str) -> nn.Module:
        """Build the named backbone and swap its classifier for a new head.

        The backbone is optionally loaded with pretrained weights and optionally
        frozen. Freezing happens before the new head is attached, so the head
        itself always stays trainable.

        NOTE(review): ``custom_head``, ``freeze_strategy`` and ``image_size``
        from the config are not consulted here — the replacement head is always
        installed.

        Raises:
            ValueError: If ``name_pretrained_model`` is not in ``MODEL_CONFIGS``.
        """
        # Check if the model name is in the dictionary above 
        if name_pretrained_model not in self.MODEL_CONFIGS:
            raise ValueError(f"Unsupported model: {name_pretrained_model}")
        
        config = self.MODEL_CONFIGS[name_pretrained_model]
        
        # Load the backbone, with or without pretrained weights
        if self.config.pretrained:
            base_model = config['base_model'](weights=config['weights'])
        else:
            base_model = config['base_model'](weights=None)
        
        # Freeze layers if configured
        if self.config.freeze_layers:
            print("Freezing layers of pretrained model")
            self.freeze_layers_base_model(base_model)
        
        # Replace the classifier head based on architecture
        classifier_attr = config['classifier_attr']
        
        # Match the input model with the heads defined below 
        if name_pretrained_model.startswith('resnet'):
            custom_head = self._create_resnet_head(config['feature_dim'])
        elif name_pretrained_model.startswith('efficientnet'):
            custom_head = self._create_efficientnet_head(config['feature_dim'])
        elif name_pretrained_model.startswith('convnext'):
            custom_head = self._create_convnext_head(config['feature_dim'])
        elif name_pretrained_model.startswith('densenet'):
            custom_head = self._create_densenet_head(config['feature_dim'])
        else:
            # Default to 1D head for other models (can change if needed) 
            custom_head = self._create_1d_classifier_head(config['feature_dim'])

        # Install the head on the attribute named in the dictionary ('fc' or 'classifier')
        setattr(base_model, classifier_attr, custom_head)
        return base_model

    # ResNet, EfficientNet and DenseNet all hand flattened, globally pooled
    # features to their classifier, so their heads are the same Linear/ReLU/
    # Dropout/Linear stack. The per-architecture entry points are kept so each
    # family can diverge later, but they share one implementation.
    def _create_resnet_head(self, feature_dim: int) -> nn.Sequential:
        """Head for ResNet backbones (1D features)."""
        return self._create_1d_classifier_head(feature_dim)
    
    def _create_efficientnet_head(self, feature_dim: int) -> nn.Sequential:
        """Head for EfficientNet backbones (1D features)."""
        return self._create_1d_classifier_head(feature_dim)
        
    def _create_densenet_head(self, feature_dim: int) -> nn.Sequential:
        """Head for DenseNet backbones (1D features)."""
        return self._create_1d_classifier_head(feature_dim)
    
    # ConvNeXt hands a 4D tensor (batch, channels, height, width) to its classifier,
    # so it requires pooling and flattening before going to linear layers 
    def _create_convnext_head(self, feature_dim: int) -> nn.Sequential:
        """Head for ConvNeXt backbones: pool, flatten, LayerNorm, then the MLP."""
        return nn.Sequential(
            nn.AdaptiveAvgPool2d((1, 1)),
            nn.Flatten(),
            nn.LayerNorm(feature_dim, eps=1e-6),
            nn.Linear(feature_dim, self.config.n_nodes),
            nn.ReLU(),
            nn.Dropout(self.config.dropout),
            nn.Linear(self.config.n_nodes, self.config.num_classes)
        )
    
    def _create_1d_classifier_head(self, feature_dim: int) -> nn.Sequential:
        """General head for models that output 1D features.

        Maps ``feature_dim`` -> ``n_nodes`` -> ``num_classes`` with ReLU and
        dropout in between.
        """
        return nn.Sequential(
            nn.Linear(feature_dim, self.config.n_nodes),
            nn.ReLU(),
            nn.Dropout(self.config.dropout),
            nn.Linear(self.config.n_nodes, self.config.num_classes)
        )
    
    def _create_2d_classifier_head(self, feature_dim: int) -> nn.Sequential:
        """General head for models that still need pooling (4D feature maps)."""
        return nn.Sequential(
            nn.AdaptiveAvgPool2d((1, 1)),
            nn.Flatten(),
            nn.Linear(feature_dim, self.config.n_nodes),
            nn.ReLU(),
            nn.Dropout(self.config.dropout),
            nn.Linear(self.config.n_nodes, self.config.num_classes)
        )

    def _debug_info(self):
        """Print the configuration and parameter counts of the built model."""
        print(f"\n=== Model Configuration ===")
        print(f"Architecture: {self.config.model_name}")
        print(f"Num classes: {self.config.num_classes}")
        print(f"Hidden layer size: {self.config.n_nodes}")
        print(f"Dropout: {self.config.dropout}")
        print(f"Frozen layers: {self.config.freeze_layers}")
        
        total_params = sum(p.numel() for p in self.parameters())
        trainable_params = self.get_trainable_parameters_count()
        print(f"Total parameters: {total_params:,}")
        print(f"Trainable parameters: {trainable_params:,}")
        print(f"Total layers: {self.total_layers}")
        print("=" * 30)

    def debug_model_structure(self, model_name: str):
        """Print the ORIGINAL classifier of ``model_name`` before any head replacement.

        A fresh copy of the backbone is instantiated for inspection (loading
        the pretrained weights when the config asks for them), so this can be
        slow. It exists to troubleshoot head replacement on new architectures.

        Raises:
            ValueError: If ``model_name`` is not in ``MODEL_CONFIGS``.
        """
        # Same validation and error type as get_model, instead of a bare KeyError
        if model_name not in self.MODEL_CONFIGS:
            raise ValueError(f"Unsupported model: {model_name}")

        config = self.MODEL_CONFIGS[model_name]
        
        if self.config.pretrained:
            base_model = config['base_model'](weights=config['weights'])
        else:
            base_model = config['base_model'](weights=None)
        
        print(f"\n=== Debugging {model_name} ===")
        print(f"Classifier attribute: {config['classifier_attr']}")
        
        classifier = getattr(base_model, config['classifier_attr'])
        print(f"Classifier type: {type(classifier)}")
        print(f"Classifier structure: {classifier}")
        
        # Check if it's sequential
        if isinstance(classifier, nn.Sequential):
            print("\nSequential layers:")
            for i, layer in enumerate(classifier):
                print(f"  [{i}] {layer.__class__.__name__}: {layer}")
        
        print(f"\nFeature dimension: {config['feature_dim']}")
        print("=" * 40)

## Definitions and classes below are components of the transfer learning function, called by train_model 

### EarlyStopping
- The EarlyStopping class helps to prevent overfitting in the model by monitoring the validation loss for each epoch and ending the training once the validation loss has failed to improve (by at least `min_delta`) for `patience` consecutive epochs
- This prevents the model from learning from noise within the training set and overfitting to the training set 
- This improves generalizability and performance on external data 

In [4]:
class EarlyStopping:
    """Stop training once the validation loss has stopped improving.

    Call the instance once per epoch with the validation loss. An epoch counts
    as "no improvement" when ``val_loss > best_loss - min_delta``. After
    ``patience`` such epochs in a row, ``early_stop`` becomes ``True`` and, if
    requested, the best weights seen so far are loaded back into the model.

    Args:
        patience: Consecutive non-improving epochs tolerated before stopping.
        min_delta: Minimum decrease in loss that counts as an improvement.
        restore_best_weights: Whether to snapshot the best weights and restore
            them when stopping.
    """

    def __init__(self, patience=7, min_delta=0, restore_best_weights=True):
        self.patience = patience
        self.min_delta = min_delta
        self.restore_best_weights = restore_best_weights
        self.counter = 0
        self.best_loss = None
        self.early_stop = False
        self.best_weights = None

    def __call__(self, val_loss, model=None):
        """Record this epoch's validation loss and update the stopping state.

        Args:
            val_loss: Validation loss of the epoch that just finished.
            model: Optional model whose weights are snapshotted on improvement
                and restored when stopping triggers.
        """
        if self.best_loss is None:
            # First epoch: it is the best by definition
            self.best_loss = val_loss
            self._snapshot(model)
        elif val_loss > self.best_loss - self.min_delta:
            self.counter += 1
            print(f'EarlyStopping counter: {self.counter} out of {self.patience}')
            if self.counter >= self.patience:
                self.early_stop = True
                if (model is not None and self.restore_best_weights
                        and self.best_weights is not None):
                    model.load_state_dict(self.best_weights)
        else:
            self.best_loss = val_loss
            self.counter = 0
            self._snapshot(model)

    def _snapshot(self, model):
        """Store an independent copy of the model weights when restoring is enabled."""
        if model is not None and self.restore_best_weights:
            # state_dict() returns tensors that share storage with the live
            # parameters, and dict.copy() is shallow, so a plain copy would keep
            # tracking the weights as training continues. Deep-copy to freeze them.
            self.best_weights = copy.deepcopy(model.state_dict())

### TrainingMetrics 
- The TrainingMetrics class stores information about the accuracy, loss, learning rate, and precision/recall for a given epoch

In [5]:
class TrainingMetrics:
    """Per-epoch history of training and validation metrics.

    Every call to ``update`` appends one entry to the loss, accuracy and
    learning-rate histories. The optional validation metrics are appended only
    when a value is supplied, so those lists may be shorter than the epoch count.
    """

    def __init__(self):
        # Histories that grow by one entry on every update
        self.train_losses = []
        self.train_accuracies = []
        self.val_losses = []
        self.val_accuracies = []
        self.learning_rates = []
        self.best_accuracy = 0.0

        # Optional validation metrics (appended only when provided)
        self.val_auroc_macro = []
        self.val_auroc_micro = []
        self.val_f1_macro = []
        self.val_f1_micro = []
        self.val_precision_macro = []
        self.val_recall_macro = []
        self.val_f1_per_class = []  # one dictionary per recorded epoch

        # Index of the epoch with the best validation accuracy (-1 until one is recorded)
        self.best_epoch = -1

    def update(self, train_loss, train_acc, val_loss, val_acc, lr,
               auroc_macro=None, auroc_micro=None,
               f1_macro=None, f1_micro=None,
               precision_macro=None, recall_macro=None,
               f1_per_class=None):
        """Append one epoch of metrics and refresh the best-accuracy bookkeeping."""
        always_recorded = (
            (self.train_losses, train_loss),
            (self.train_accuracies, train_acc),
            (self.val_losses, val_loss),
            (self.val_accuracies, val_acc),
            (self.learning_rates, lr),
        )
        for history, value in always_recorded:
            history.append(value)

        recorded_when_given = (
            (self.val_auroc_macro, auroc_macro),
            (self.val_auroc_micro, auroc_micro),
            (self.val_f1_macro, f1_macro),
            (self.val_f1_micro, f1_micro),
            (self.val_precision_macro, precision_macro),
            (self.val_recall_macro, recall_macro),
            (self.val_f1_per_class, f1_per_class),
        )
        for history, value in recorded_when_given:
            if value is not None:
                history.append(value)

        # A strictly better validation accuracy makes this the best epoch so far
        if val_acc > self.best_accuracy:
            self.best_accuracy = val_acc
            self.best_epoch = len(self.val_accuracies) - 1

    def get_summary(self) -> Dict[str, float]:
        """Return the best validation accuracy and the latest loss/accuracy values.

        The "final" entries are 0 when nothing has been recorded yet.
        """
        def latest(history):
            return history[-1] if history else 0

        summary = {'best_val_accuracy': self.best_accuracy}
        summary['final_train_loss'] = latest(self.train_losses)
        summary['final_val_loss'] = latest(self.val_losses)
        summary['final_train_acc'] = latest(self.train_accuracies)
        summary['final_val_acc'] = latest(self.val_accuracies)
        return summary

### unfreeze_model_layers
- By default the model layers are all unfrozen. When we run the training function we will freeze all of the model layers, which reduces the computational cost that would come from training each of the individual layers of all of the models
- When we freeze the layers, the weights won't be updated during the backpropagation step, meaning the representations already trained are being used
- Since we're tweaking image recognition models to a specific task, this makes sense because it allows us to fine tune the final layers (which are more task-specific) while preserving the initial layers which contain foundational image features like texture, edges, etc.
- In the future, this could be improved by being more specific with the conditional freezing using pytorch hooks on specific layers to freeze/unfreeze if the gradient magnitude changes by a certain threshold

In [6]:
def unfreeze_model_layers(model, epoch_start_unfreeze, layer_start_unfreeze, current_epoch):
    """Unfreeze model parameters once the configured epoch has been reached.

    Args:
        model: Module whose parameters are (partly) frozen.
        epoch_start_unfreeze: Epoch from which unfreezing applies; ``None``
            disables unfreezing entirely.
        layer_start_unfreeze: Index into ``model.named_parameters()``; parameters
            at or after it are unfrozen. ``None`` unfreezes every parameter.
        current_epoch: The epoch being run.

    Returns:
        ``True`` when nothing was unfrozen on this call, ``False`` when the
        unfreezing was applied.
    """
    # Guard clause: no unfreeze epoch configured, or it has not been reached yet
    if epoch_start_unfreeze is None or not (current_epoch >= epoch_start_unfreeze):
        return True

    print("Unfreezing base model weights")

    # Unfreezing works by re-enabling gradient tracking (requires_grad) on the parameters
    if layer_start_unfreeze is None:
        print("Unfreezing all model layers")
        for weight in model.parameters():
            weight.requires_grad = True
    else:
        print(f"Unfreezing layers >= {layer_start_unfreeze}")
        for position, (_, weight) in enumerate(model.named_parameters()):
            if position >= layer_start_unfreeze:
                weight.requires_grad = True

    return False

### save_checkpoint
- Straightforward, saves model information/parameters through the epochs 

In [7]:
def save_checkpoint(model, optimizer, scheduler, epoch, metrics, is_best, checkpoint_dir):
    """Write the current training state to ``latest.pth`` (and ``best.pth`` if ``is_best``).

    The saved dictionary holds the epoch index, the model/optimizer/scheduler
    state dicts (scheduler entry is ``None`` when no scheduler is given), a few
    model descriptors read from ``model.config``, a summary of the accuracies in
    ``metrics`` and a timestamped training-info section.

    Returns:
        Path of the ``latest.pth`` file inside ``checkpoint_dir``.
    """
    def _last_or_zero(history):
        # Most recent entry of a history list, or 0 when nothing was recorded yet
        return history[-1] if history else 0

    # Count parameters in a single pass over the model
    total_params = 0
    trainable_params = 0
    for tensor in model.parameters():
        size = tensor.numel()
        total_params += size
        if tensor.requires_grad:
            trainable_params += size

    epochs_recorded = len(metrics.train_losses)

    model_metadata = {
        'model_name': model.config.model_name,
        'num_classes': model.config.num_classes,
        'feature_dim': getattr(model, 'feature_dim', 'N/A'),
        'total_params': total_params,
        'trainable_params': trainable_params,
    }
    performance = {
        'best_accuracy': metrics.best_accuracy,
        'current_val_accuracy': _last_or_zero(metrics.val_accuracies),
        'current_train_accuracy': _last_or_zero(metrics.train_accuracies),
    }
    training_info = {
        'timestamp': datetime.now().isoformat(),
        'total_epochs': epochs_recorded,
        # Flagged when fewer epochs were recorded than the configured budget (100 if unset)
        'early_stopped': epochs_recorded < getattr(model.config, 'num_epochs', 100)
    }

    state = {
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'scheduler_state_dict': scheduler.state_dict() if scheduler else None,
        'model_metadata': model_metadata,
        'metrics': performance,
        'training_info': training_info,
    }

    # The latest checkpoint is always overwritten
    latest_path = os.path.join(checkpoint_dir, 'latest.pth')
    torch.save(state, latest_path)

    if not is_best:
        return latest_path

    # Keep a separate copy for the best-performing epoch
    best_path = os.path.join(checkpoint_dir, 'best.pth')
    torch.save(state, best_path)
    print(f"Best model updated and saved to: {best_path}")

    return latest_path

### train_epoch
- Single training epoch function, returns the loss and accuracy of that individual epoch
- For each batch in the dataloader, do the following 
  - Move batch to device
  - Zero gradients (optimizer.zero_grad())
  - Forward pass (model(inputs))
  - Compute loss (criterion(outputs, targets))
  - Backward pass (loss.backward())
  - Update weights (optimizer.step())
  - Accumulate batch metrics

In [8]:
def train_epoch(model, dataloader, criterion, optimizer, device):
    """Run one optimisation pass over ``dataloader``.

    Returns:
        ``(epoch_loss, epoch_acc)``: the sample-weighted mean loss and the
        fraction of correctly classified samples over the whole epoch.
    """
    model.train()

    loss_sum = 0.0
    hits = 0
    seen = 0

    # tqdm wraps the loader so every batch refreshes the progress bar
    batches = tqdm(dataloader, desc="Training")
    for images, targets in batches:
        images, targets = images.to(device), targets.to(device)
        batch_size = images.size(0)

        # Reset gradients left over from the previous step
        optimizer.zero_grad()

        # Forward pass and loss
        logits = model(images)
        loss = criterion(logits, targets)

        # Backward pass followed by the parameter update
        loss.backward()
        optimizer.step()

        # Batch statistics: the predicted class is the index of the largest logit
        batch_loss = loss.item()
        batch_hits = (logits.max(dim=1).indices == targets).sum().item()

        # criterion returns a batch mean, so weight it by the batch size
        loss_sum += batch_loss * batch_size
        hits += batch_hits
        seen += batch_size

        batches.set_postfix({
            'Loss': f'{batch_loss:.4f}',
            'Acc': f'{batch_hits/batch_size:.4f}'
        })

    # Epoch-wide averages over every sample seen
    return loss_sum / seen, hits / seen


### validate_epoch_with_metrics
- Similar to the training step, but is just a forward pass through the model without updating neuron weights for the validation step
- Also stores information about the validation accuracy for given epochs

In [9]:
def validate_epoch_with_metrics(model, dataloader, criterion, device, class_names=None):
    """Evaluate ``model`` on ``dataloader`` without updating any weights.

    Args:
        class_names: When given, the collected labels, softmax probabilities and
            predictions are passed to ``compute_comprehensive_metrics``.

    Returns:
        ``(avg_loss, accuracy, comp_metrics)`` where ``comp_metrics`` is an empty
        dict if ``class_names`` is ``None``.
    """
    model.eval()

    loss_sum = 0
    n_correct = 0
    n_seen = 0

    # Per-batch tensors, kept on the CPU until the whole loader has been consumed
    label_chunks, prob_chunks, pred_chunks = [], [], []

    with torch.no_grad():
        for images, targets in dataloader:
            images = images.to(device)
            targets = targets.to(device)
            batch_size = images.size(0)

            logits = model(images)
            # criterion returns a batch mean, so weight it by the batch size
            loss_sum += criterion(logits, targets).item() * batch_size

            # Class probabilities and hard predictions from the same logits
            probs = torch.softmax(logits, dim=1)
            preds = logits.max(dim=1).indices

            n_correct += (preds == targets).sum().item()
            n_seen += batch_size

            label_chunks.append(targets.cpu())
            prob_chunks.append(probs.cpu())
            pred_chunks.append(preds.cpu())

    # Stitch the batches together as numpy arrays
    all_true, all_probs, all_preds = (
        torch.cat(chunks, dim=0).numpy()
        for chunks in (label_chunks, prob_chunks, pred_chunks)
    )

    avg_loss = loss_sum / n_seen
    accuracy = n_correct / n_seen

    # The extended metrics need class names; without them only loss/accuracy are reported
    if class_names is None:
        return avg_loss, accuracy, {}

    comp_metrics = compute_comprehensive_metrics(
        all_true, all_probs, all_preds, class_names
    )
    return avg_loss, accuracy, comp_metrics

### compute_comprehensive_metrics 
- Compute and store metrics of the model performance, such as AUROC, F1 scores, precision and recall, and per-class probability 

In [10]:
def compute_comprehensive_metrics(true_labels, pred_probs, pred_labels, class_names):
    """Compute AUROC, F1, precision/recall and a confusion matrix for one evaluation pass.

    Args:
        true_labels: Array of integer class ids; ids are taken to be
            ``0 .. len(class_names) - 1`` in ``class_names`` order.
        pred_probs: Per-class probabilities, indexed as ``pred_probs[:, class_id]``.
        pred_labels: Array of predicted class ids.
        class_names: Class names, used for the per-class dictionaries.

    Returns:
        Dict with ``auroc_macro``, ``auroc_micro``, ``f1_macro``, ``f1_micro``,
        ``precision_macro``, ``recall_macro``, ``per_class_f1``,
        ``per_class_precision``, ``per_class_recall`` and ``confusion_matrix``.
        The macro AUROC (and both AUROC values in the binary case) is NaN when
        some class has no true samples, because it is undefined there.
    """
    # import sklearn when running
    from sklearn.metrics import (roc_auc_score, f1_score, precision_score,
                                 recall_score, confusion_matrix)

    n_classes = len(class_names)
    # Explicit label ids keep per-class outputs aligned with class_names even
    # when a class does not occur in this split
    label_ids = list(range(n_classes))
    metrics = {}

    present = set(np.unique(true_labels).tolist())
    all_classes_present = all(label in present for label in label_ids)

    # AUROC calculation
    if n_classes == 2:
        # Binary: score with the probability of the positive class (column 1)
        if all_classes_present:
            metrics['auroc_macro'] = roc_auc_score(true_labels, pred_probs[:, 1])
        else:
            # roc_auc_score raises when only one class is present in y_true
            metrics['auroc_macro'] = float('nan')
        metrics['auroc_micro'] = metrics['auroc_macro']
    else:
        from sklearn.preprocessing import label_binarize
        true_binary = label_binarize(true_labels, classes=label_ids)
        # Macro: average of per-class one-vs-rest AUROCs; needs positives for every class
        if all_classes_present:
            metrics['auroc_macro'] = roc_auc_score(true_binary, pred_probs,
                                                  average='macro', multi_class='ovr')
        else:
            metrics['auroc_macro'] = float('nan')
        # Micro: AUROC over the flattened one-vs-rest indicator matrix
        metrics['auroc_micro'] = roc_auc_score(true_binary, pred_probs,
                                              average='micro', multi_class='ovr')

    # F1 Scores
    metrics['f1_macro'] = f1_score(true_labels, pred_labels, average='macro')
    metrics['f1_micro'] = f1_score(true_labels, pred_labels, average='micro')

    # Precision and Recall
    metrics['precision_macro'] = precision_score(true_labels, pred_labels, average='macro')
    metrics['recall_macro'] = recall_score(true_labels, pred_labels, average='macro')

    # Per-class metrics: pin the label order so entry i always belongs to
    # class_names[i]; classes without samples score 0 instead of being dropped
    metrics['per_class_f1'] = dict(zip(class_names,
        f1_score(true_labels, pred_labels, labels=label_ids,
                 average=None, zero_division=0)))
    metrics['per_class_precision'] = dict(zip(class_names,
        precision_score(true_labels, pred_labels, labels=label_ids,
                        average=None, zero_division=0)))
    metrics['per_class_recall'] = dict(zip(class_names,
        recall_score(true_labels, pred_labels, labels=label_ids,
                     average=None, zero_division=0)))

    # Confusion matrix with one row/column per entry of class_names
    metrics['confusion_matrix'] = confusion_matrix(true_labels, pred_labels, labels=label_ids)

    return metrics

### plot_learning_curves
- Plot the learning curves for the epochs (done every 5 epochs)

In [11]:
def plot_learning_curves(epochs, metrics, save_path):
    """Save a 2x2 figure of loss, accuracy, learning rate and loss-vs-accuracy.

    Args:
        epochs: Number of epochs on the x-axis; the histories in ``metrics`` are
            plotted against ``1 .. epochs``.
        metrics: Object exposing ``train_losses``, ``val_losses``,
            ``train_accuracies``, ``val_accuracies`` and ``learning_rates``.
        save_path: Destination file for the figure.
    """
    def _label_panel(ax, title, x_label, y_label):
        # Shared decoration: title, axis labels, legend and grid
        ax.set_title(title)
        ax.set_xlabel(x_label)
        ax.set_ylabel(y_label)
        ax.legend()
        ax.grid(True)

    epoch_axis = range(1, epochs + 1)
    _, panels = plt.subplots(2, 2, figsize=(15, 10))
    loss_ax, acc_ax, lr_ax, combined_ax = panels.ravel()

    # Top-left: loss per epoch
    loss_ax.plot(epoch_axis, metrics.train_losses, label='Training Loss')
    loss_ax.plot(epoch_axis, metrics.val_losses, label='Validation Loss')
    _label_panel(loss_ax, 'Loss Curves', 'Epochs', 'Loss')

    # Top-right: accuracy per epoch
    acc_ax.plot(epoch_axis, metrics.train_accuracies, label='Training Accuracy')
    acc_ax.plot(epoch_axis, metrics.val_accuracies, label='Validation Accuracy')
    _label_panel(acc_ax, 'Accuracy Curves', 'Epochs', 'Accuracy')

    # Bottom-left: learning rate on a log axis
    lr_ax.plot(epoch_axis, metrics.learning_rates, label='Learning Rate', color='red')
    _label_panel(lr_ax, 'Learning Rate Schedule', 'Epochs', 'Learning Rate')
    lr_ax.set_yscale('log')

    # Bottom-right: trajectory through (loss, accuracy) space
    combined_ax.plot(metrics.train_losses, metrics.train_accuracies, 'o-',
                     label='Training', alpha=0.7)
    combined_ax.plot(metrics.val_losses, metrics.val_accuracies, 'o-',
                     label='Validation', alpha=0.7)
    _label_panel(combined_ax, 'Loss vs Accuracy', 'Loss', 'Accuracy')

    plt.tight_layout()
    plt.savefig(save_path, dpi=300, bbox_inches='tight')
    # Close the figure so repeated calls do not accumulate open figures
    plt.close()

    print(f"Learning curves saved to: {save_path}")

### plot_metrics
- Plot the evaluation metrics of the model performance, display, and save to file 
  - Loss curves
  - Accuracy curves
  - AUROC curves
  - F1 score curves
  - Precision-Recall curves
  - Learning rate (mostly unnecessary)

In [12]:
def plot_metrics(metrics, class_names, save_path=None):
    """Show (and optionally save) a 3x2 overview of the training run.

    Panels: loss, accuracy (with the best accuracy as a dashed line), validation
    AUROC, validation F1, validation precision/recall and the learning rate.
    The AUROC, F1 and precision/recall panels are only drawn when their history
    has one entry per epoch.

    Args:
        metrics: Object exposing the per-epoch history lists and ``best_accuracy``.
        class_names: Accepted for interface compatibility; not used here.
        save_path: When truthy, the figure is also written to this path.
    """
    n_epochs = len(metrics.val_losses)
    epoch_axis = range(1, n_epochs + 1)

    def _finish(ax, title, y_label, legend=True, log_scale=False):
        # Shared decoration; the optional log scale is applied before the grid
        ax.set_title(title)
        ax.set_xlabel('Epoch')
        ax.set_ylabel(y_label)
        if legend:
            ax.legend()
        if log_scale:
            ax.set_yscale('log')
        ax.grid(True, alpha=0.3)

    fig, axes = plt.subplots(3, 2, figsize=(15, 12))
    loss_ax, acc_ax, auroc_ax, f1_ax, pr_ax, lr_ax = axes.flatten()

    # 1. Loss
    loss_ax.plot(epoch_axis, metrics.train_losses, label='Train')
    loss_ax.plot(epoch_axis, metrics.val_losses, label='Val')
    _finish(loss_ax, 'Loss Curves', 'Loss')

    # 2. Accuracy, with a reference line at the best accuracy
    acc_ax.plot(epoch_axis, metrics.train_accuracies, label='Train')
    acc_ax.plot(epoch_axis, metrics.val_accuracies, label='Val')
    acc_ax.axhline(y=metrics.best_accuracy, color='r', linestyle='--',
                   label=f'Best: {metrics.best_accuracy:.4f}')
    _finish(acc_ax, 'Accuracy Curves', 'Accuracy')

    # 3. AUROC (only when recorded for every epoch)
    if metrics.val_auroc_macro and len(metrics.val_auroc_macro) == n_epochs:
        auroc_ax.plot(epoch_axis, metrics.val_auroc_macro, label='Macro')
        auroc_ax.plot(epoch_axis, metrics.val_auroc_micro, label='Micro')
        _finish(auroc_ax, 'Validation AUROC', 'AUROC')

    # 4. F1 (only when recorded for every epoch)
    if metrics.val_f1_macro and len(metrics.val_f1_macro) == n_epochs:
        f1_ax.plot(epoch_axis, metrics.val_f1_macro, label='Macro')
        f1_ax.plot(epoch_axis, metrics.val_f1_micro, label='Micro')
        _finish(f1_ax, 'Validation F1 Score', 'F1 Score')

    # 5. Precision and recall (only when recorded for every epoch)
    if (metrics.val_precision_macro and metrics.val_recall_macro and
            len(metrics.val_precision_macro) == n_epochs):
        pr_ax.plot(epoch_axis, metrics.val_precision_macro, label='Precision')
        pr_ax.plot(epoch_axis, metrics.val_recall_macro, label='Recall')
        _finish(pr_ax, 'Validation Precision & Recall (Macro)', 'Score')

    # 6. Learning rate on a log axis, no legend
    lr_ax.plot(epoch_axis, metrics.learning_rates, color='red')
    _finish(lr_ax, 'Learning Rate Schedule', 'Learning Rate', legend=False, log_scale=True)

    plt.tight_layout()

    if save_path:
        plt.savefig(save_path, dpi=300, bbox_inches='tight')

    plt.show()

### train_model
- The actual training function
- Initializes metrics for storing model and validation information (validation loss)
- Move the model to the GPU and set the freeze state
- Run the individual epoch steps
  - Print the epoch number
  - Check the epoch number and unfreeze if the number is higher than the unfreeze starting number
    - After unfreezing, make sure the optimizer also covers the newly trainable parameters while preserving the current learning rate. Can adjust in the future to have different learning rates for the classifier and base layers of the model
  - Run train_epoch (described above)
  - Run validate_epoch_with_metrics (described above) when class names are provided; otherwise fall back to the basic validate_epoch
  - Get the learning rate and store
  - Update the best accuracy parameter if the validation accuracy is higher for this epoch
  - Print information about the performance on the training and validation for this epoch
  - Check the validation loss and increase the early stopping value if there is no improvement 

In [13]:
def _add_unfrozen_params_to_optimizer(model, optimizer):
    """Register trainable parameters that ``optimizer`` is not tracking yet.

    Parameters with ``requires_grad=True`` that are absent from every existing
    param group are added as one new group starting at the current learning rate
    of the first group. Returns the number of parameter tensors added.
    """
    tracked_ids = {id(p) for group in optimizer.param_groups for p in group['params']}
    new_params = [p for p in model.parameters()
                  if p.requires_grad and id(p) not in tracked_ids]
    if new_params:
        optimizer.add_param_group({
            'params': new_params,
            'lr': optimizer.param_groups[0]['lr'],
        })
    return len(new_params)


def train_model(
    device, model, criterion, optimizer, lr_scheduler,
    train_dataloader, val_dataloader, num_epochs, checkpoint_dir,
    epoch_start_unfreeze=None, layer_start_unfreeze=None,
    early_stopping_patience=10, resume_from_checkpoint=None, class_names=None, seed = None):
    """Train ``model`` for up to ``num_epochs`` epochs with validation and checkpointing.

    Args:
        device: Device the model and batches are moved to.
        model: Model exposing ``config.freeze_layers``, ``unfreeze_layers(...)``
            and ``get_trainable_parameters_count()``.
        criterion: Loss function.
        optimizer: Optimizer for the currently trainable parameters. Parameters
            unfrozen during training are added to it as a new param group.
        lr_scheduler: Optional scheduler, stepped once per epoch.
        train_dataloader / val_dataloader: Loaders yielding ``(inputs, labels)``.
        num_epochs: Upper bound on the epoch index (exclusive).
        checkpoint_dir: Directory for checkpoints and learning-curve plots.
        epoch_start_unfreeze: Epoch index at which frozen layers are unfrozen
            (``None`` keeps them frozen).
        layer_start_unfreeze: Forwarded to ``model.unfreeze_layers``.
        early_stopping_patience: Patience handed to ``EarlyStopping``.
        resume_from_checkpoint: Path of a checkpoint to resume from, if it exists.
        class_names: When given, the extended validation metrics are computed.
        seed: When given, ``seed_all(seed)`` is called first.

    Returns:
        The ``TrainingMetrics`` object filled during training.
    """
    # Initialize the seed value
    if seed is not None:
        seed_all(seed)
        print(f"Seed val: {seed}")

    # Initialize metrics and early stopping
    metrics = TrainingMetrics()
    early_stopper = EarlyStopping(patience=early_stopping_patience)

    start_epoch = 0
    # Resume from a previous checkpoint when a valid path is supplied
    if resume_from_checkpoint and os.path.exists(resume_from_checkpoint):
        print(f"Loading checkpoint: {resume_from_checkpoint}")
        checkpoint = torch.load(resume_from_checkpoint, map_location=device)
        model.load_state_dict(checkpoint['model_state_dict'])
        # NOTE(review): this fails if the checkpoint was written after layers were
        # unfrozen, because the saved optimizer then tracks more parameters than a
        # freshly built one - confirm whether resuming past that point is needed
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        # save_checkpoint stores None when training ran without a scheduler
        if lr_scheduler and checkpoint.get('scheduler_state_dict') is not None:
            lr_scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
        # Continue with the epoch after the saved one
        start_epoch = checkpoint['epoch'] + 1
        # save_checkpoint writes a plain summary dict under 'metrics', which cannot
        # replace the TrainingMetrics object; only the best accuracy is carried over
        saved_metrics = checkpoint.get('metrics')
        if isinstance(saved_metrics, dict):
            if 'best_accuracy' in saved_metrics:
                metrics.best_accuracy = saved_metrics['best_accuracy']
        elif saved_metrics is not None:
            metrics = saved_metrics
        print(f"Resumed training from epoch {start_epoch}")

    # Send the model to the CPU/GPU
    model = model.to(device)
    # Tracks whether the base layers are still frozen
    freezed = model.config.freeze_layers

    # save_checkpoint and the plots write into this directory
    os.makedirs(checkpoint_dir, exist_ok=True)

    print("Starting training...")
    for epoch in range(start_epoch, num_epochs):
        # Decorative formatting for the epoch output
        print(f"\n{'='*60}")
        print(f"Epoch {epoch+1}/{num_epochs}")
        print(f"{'='*60}")

        # Unfreeze once, as soon as the configured epoch is reached
        if freezed and epoch_start_unfreeze is not None and epoch >= epoch_start_unfreeze:
            model.unfreeze_layers(layer_start_unfreeze)
            freezed = False

            # Extend the existing optimizer instead of replacing it: its state and
            # type are kept and the scheduler keeps driving the same optimizer.
            # This also covers the "unfreeze everything" case.
            # NOTE(review): schedulers that derive rates from their stored
            # base_lrs may not update the added group - confirm if StepLR is swapped
            n_added = _add_unfrozen_params_to_optimizer(model, optimizer)
            print(f"Added {n_added} newly unfrozen parameter tensors to the optimizer")

        # Training phase
        train_loss, train_acc = train_epoch(model, train_dataloader, criterion,
                                          optimizer, device)

        # Validation phase with metrics storage
        if class_names is not None:
            val_loss, val_acc, comp_metrics = validate_epoch_with_metrics(
                model, val_dataloader, criterion, device, class_names
            )
        else:
            # Fallback to basic validation (validate_epoch is expected to be
            # defined elsewhere in the notebook)
            val_loss, val_acc = validate_epoch(model, val_dataloader, criterion, device)
            comp_metrics = {}

        # Get current learning rate
        current_lr = optimizer.param_groups[0]['lr']

        # Update metrics
        # NOTE(review): compute_comprehensive_metrics returns 'per_class_f1', so the
        # 'f1_per_class' lookup below always yields None - confirm what
        # TrainingMetrics.update expects before changing the key
        metrics.update(
            train_loss, train_acc, val_loss, val_acc, current_lr,
            auroc_macro=comp_metrics.get('auroc_macro'),
            auroc_micro=comp_metrics.get('auroc_micro'),
            f1_macro=comp_metrics.get('f1_macro'),
            f1_micro=comp_metrics.get('f1_micro'),
            precision_macro=comp_metrics.get('precision_macro'),
            recall_macro=comp_metrics.get('recall_macro'),
            f1_per_class=comp_metrics.get('f1_per_class')
        )

        # This epoch is the best one when it matches or beats the best accuracy so far
        is_best = val_acc >= metrics.best_accuracy

        # Store per-class metrics if this is the best epoch
        if is_best and class_names is not None:
            metrics.per_class_f1 = comp_metrics.get('per_class_f1')
            metrics.per_class_precision = comp_metrics.get('per_class_precision')
            metrics.per_class_recall = comp_metrics.get('per_class_recall')
            metrics.best_confusion_matrix = comp_metrics.get('confusion_matrix')

        # Print epoch summary
        print(f"\nEpoch {epoch+1} Summary:")
        print(f"Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.4f}")
        print(f"Val Loss: {val_loss:.4f} | Val Acc: {val_acc:.4f}")
        print(f"Learning Rate: {current_lr:.2e}")
        print(f"Best Val Accuracy: {metrics.best_accuracy:.4f}")
        print(f"Trainable Parameters: {model.get_trainable_parameters_count():,}")

        # Save checkpoint
        save_checkpoint(model, optimizer, lr_scheduler, epoch,
                        metrics, is_best, checkpoint_dir)

        # Plot learning curves every 5 epochs
        if epoch % 5 == 0:
            plot_path = os.path.join(checkpoint_dir, f'learning_curves_epoch_{epoch+1:03d}.png')
            # Use the number of recorded epochs for the x-axis; after a resume the
            # history is shorter than epoch + 1 and the plot would otherwise fail
            plot_learning_curves(len(metrics.train_losses), metrics, plot_path)

        # Early stopping check
        early_stopper(val_loss, model)
        if early_stopper.early_stop:
            print(f"\nEarly stopping triggered at epoch {epoch+1}")
            break

        # Step learning rate scheduler
        if lr_scheduler:
            lr_scheduler.step()

        # Clear GPU cache
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    print("\n Training completed!")
    print(f"Best validation accuracy: {metrics.best_accuracy:.4f}")

    return metrics

### train_and_evaluate_model
- Since we're testing individual models it will be useful to have a wrapper function to train a model that we can iterate
- Create a model-specific directory to store the latest and best performing iterations of the model
- Create the dataloaders using the create_dataloaders function (described above)
- Set up the loss function, optimizer, and learning rate scheduler
- Call the training function (described above)

In [14]:
def train_and_evaluate_model(model_name, base_config, device):
    """Train one architecture end to end and evaluate it on the test split.

    Works on a deep copy of ``base_config`` whose checkpoint directory and model
    name are specialised for ``model_name``; ``base_config`` itself is not
    modified here.

    Returns:
        Dict with the model name, training metrics, test results (``None`` when
        there is no test loader), the specialised config, the trained model, the
        checkpoint directory, the class names and the seed.
    """
    # Seed first so model construction and data loading are reproducible
    seed = base_config.get('training', {}).get('seed', 1234)
    seed_all(seed)

    # Banner for this model's section of the log
    banner = '=' * 60
    print(f"\n{banner}")
    print(f"TRAINING MODEL: {model_name}")
    print(f"{banner}")

    # Per-model copy of the configuration with its own checkpoint directory
    config = copy.deepcopy(base_config)
    model_checkpoint_dir = os.path.join(base_config['training']['checkpoint_dir'], model_name)
    config['training']['checkpoint_dir'] = model_checkpoint_dir
    config['model']['model_name'] = model_name

    os.makedirs(model_checkpoint_dir, exist_ok=True)
    print(f"Checkpoint directory: {model_checkpoint_dir}")

    # Build the model
    model = ImageClassificationModel(config)
    print(f"Model created: {model_name}")
    print(f"Number of classes: {config['model']['num_classes']}")

    train_cfg = config['training']

    # Build the data loaders
    train_loader, val_loader, test_loader, class_names = create_dataloaders(
        base_dir=config['data']['base_dir'],
        batch_size=train_cfg['batch_size'],
        num_workers=2,
        seed=seed
    )

    # Loss, optimizer over the parameters that currently require gradients, and LR schedule
    criterion = nn.CrossEntropyLoss()
    trainable_params = list(filter(lambda p: p.requires_grad, model.parameters()))
    optimizer = optim.Adam(trainable_params, lr=train_cfg['learning_rate'])
    scheduler = lr_scheduler.StepLR(
        optimizer,
        step_size=train_cfg['scheduler_step_size'],
        gamma=train_cfg['scheduler_gamma']
    )

    # Run training; class names are passed so the extended validation metrics are computed
    metrics = train_model(
        device=device,
        model=model,
        criterion=criterion,
        optimizer=optimizer,
        lr_scheduler=scheduler,
        train_dataloader=train_loader,
        val_dataloader=val_loader,
        num_epochs=train_cfg['num_epochs'],
        checkpoint_dir=model_checkpoint_dir,
        epoch_start_unfreeze=train_cfg['epoch_start_unfreeze'],
        layer_start_unfreeze=train_cfg['layer_start_unfreeze'],
        early_stopping_patience=train_cfg['early_stopping_patience'],
        class_names=class_names,
        seed=seed
    )

    # Evaluate the model as it stands after training on the held-out test split
    # (not yet integrated with the confusion-matrix plotting)
    test_results = None
    if test_loader is not None:
        print("\nEvaluating on test set...")
        _, test_acc, test_comp_metrics = validate_epoch_with_metrics(
            model, test_loader, criterion, device, class_names
        )
        print(f"Test Accuracy: {test_acc:.4f}")
        print(f"Test AUROC (Macro): {test_comp_metrics.get('auroc_macro', 0):.4f}")
        print(f"Test F1 (Macro): {test_comp_metrics.get('f1_macro', 0):.4f}")

        test_results = {
            'accuracy': test_acc,
            'auroc_macro': test_comp_metrics.get('auroc_macro'),
            'f1_macro': test_comp_metrics.get('f1_macro'),
            'confusion_matrix': test_comp_metrics.get('confusion_matrix')
        }

    # Overview figure of the whole run
    if class_names is not None:
        plot_metrics(metrics, class_names, os.path.join(model_checkpoint_dir, 'all_metrics.png'))

    return {
        'model_name': model_name,
        'metrics': metrics,
        'test_results': test_results,
        'config': config,
        'model': model,
        'checkpoint_dir': model_checkpoint_dir,
        'class_names': class_names,
        'seed':seed
    }

### run_model_comparison
- Helper function to run the individual models 

In [15]:
def run_model_comparison(models_to_try, base_config, device, experiment_name=None):
    """Train and evaluate every model in ``models_to_try`` and collect the results.

    When ``experiment_name`` is given, ``base_config['training']['checkpoint_dir']``
    is overwritten in place with ``./checkpoints/<experiment_name>``. A model
    that raises is reported with its traceback and skipped.

    Returns:
        Dict mapping model name to the result of ``train_and_evaluate_model``.
    """
    import traceback

    # Seed once up front from the training section of the config
    seed_all(base_config.get('training', {}).get('seed', 1234))

    training_cfg = base_config['training']

    # The experiment name selects the parent directory for all per-model checkpoints
    if experiment_name:
        training_cfg['checkpoint_dir'] = f"./checkpoints/{experiment_name}"
        print(f"Experiment directory: {training_cfg['checkpoint_dir']}")

    os.makedirs(training_cfg['checkpoint_dir'], exist_ok=True)

    all_results = {}
    n_models = len(models_to_try)

    for position, model_name in enumerate(models_to_try, start=1):
        print(f"\n[{position}/{n_models}] Training {model_name}...")

        try:
            outcome = train_and_evaluate_model(model_name, base_config, device)
            all_results[model_name] = outcome

            # Short per-model summary
            print(f"{model_name}: Best Val Acc = {outcome['metrics'].best_accuracy:.4f}")
            print(f"Checkpoints saved to: {outcome['checkpoint_dir']}")

        except Exception as e:
            # One failing architecture must not abort the whole comparison
            print(f"{model_name} failed: {str(e)}")
            traceback.print_exc()

    return all_results

## Iterate through the different models on the brain tumor dataset 

In [None]:
# Candidate architectures, generated family by family (24 names in total)
all_models = [
    *(f'convnext_{size}' for size in ('tiny', 'small', 'base', 'large')),
    *(f'resnet{depth}' for depth in (18, 34, 50, 101, 152)),
    *(f'efficientnet_b{index}' for index in range(8)),
    *(f'efficientnet_v2_{size}' for size in ('s', 'm', 'l')),
    *(f'densenet{depth}' for depth in (121, 161, 169, 201)),
]

# Use the GPU when CUDA is available, otherwise fall back to the CPU
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

# Shared configuration; the model name is filled in per model during the comparison
base_config = dict(
    model=dict(
        num_classes=4,
        model_name="placeholder",
        n_nodes=512,
        dropout=0.2,
        freeze_layers=True,
        freeze_strategy="all",
        pretrained=True,
        custom_head=True,
    ),
    data=dict(
        size=224,
        base_dir="./mri_modeling/brain_tumor_data/",
        seed=1234,
    ),
    training=dict(
        batch_size=32,
        num_epochs=100,
        learning_rate=0.001,
        scheduler_step_size=15,
        scheduler_gamma=0.5,
        checkpoint_dir="./checkpoints",
        epoch_start_unfreeze=5,
        layer_start_unfreeze=140,
        early_stopping_patience=10,
        seed=1234,
    ),
    debug=False,
)

# Train and evaluate every candidate under one experiment directory
results = run_model_comparison(all_models, base_config, device, experiment_name="brain_tumor_v1")

Experiment directory: ./checkpoints/brain_tumor_v1

[1/24] Training convnext_tiny...

TRAINING MODEL: convnext_tiny
Checkpoint directory: ./checkpoints/brain_tumor_v1/convnext_tiny
Freezing layers of pretrained model
Model created: convnext_tiny
Number of classes: 4
Seed val: 1234
Starting training...

Epoch 1/100


Training:  15%|█▍        | 20/136 [01:00<05:08,  2.66s/it, Loss=0.3975, Acc=0.7812]

### evaluate_model_on_test
- Function that will be deprecated when refactored b/c testing was incorporated into the training function
- Used to run the trained models on the test set and return metrics for plotting and visualization 

In [None]:
# Run the trained models on the testing directory
from sklearn.metrics import classification_report, confusion_matrix, accuracy_score
import seaborn as sns

def evaluate_model_on_test(model, test_loader, device, model_name, class_names=None):
    """Run ``model`` over ``test_loader`` and summarise its test performance.

    Args:
        model: Trained model; switched to eval mode here.
        test_loader: Loader yielding ``(inputs, labels)`` batches.
        device: Device the batches are moved to.
        model_name: Name used in the printed output and the result dict.
        class_names: Class names in label-id order (id ``i`` is taken to be
            ``class_names[i]``). When ``None``, names ``Class_<id>`` are
            generated for every label id seen in the labels or predictions.

    Returns:
        Dict with accuracy, loss, classification report, confusion matrix, raw
        predictions/labels/probabilities, per-class accuracy (NaN for classes
        without test samples) and the number of evaluated samples.
    """
    model.eval()
    all_predictions = []
    all_labels = []
    all_probabilities = []
    test_loss = 0.0
    criterion = torch.nn.CrossEntropyLoss()

    # Decorative formatting for the evaluation print output
    print(f"\n{'='*60}")
    print(f"Testing: {model_name}")
    print(f"{'='*60}")

    # No gradients are needed for evaluation
    with torch.no_grad():
        # tqdm wraps the loader to show a progress bar
        for inputs, labels in tqdm(test_loader, desc="Testing"):
            # Move the image tensors and labels to the target device
            inputs, labels = inputs.to(device), labels.to(device)

            # Forward pass
            outputs = model(inputs)
            # criterion returns a batch mean, so weight it by the batch size
            loss = criterion(outputs, labels)
            test_loss += loss.item() * inputs.size(0)

            # Class probabilities via softmax; prediction is the largest logit
            probabilities = torch.nn.functional.softmax(outputs, dim=1)
            _, predictions = torch.max(outputs, 1)

            # Collect results on the CPU for the numpy/sklearn calculations
            all_predictions.extend(predictions.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())
            all_probabilities.extend(probabilities.cpu().numpy())

    # Normalise by the samples actually evaluated rather than the dataset size,
    # which can differ when the loader drops or subsamples items
    total_samples = len(all_labels)
    test_loss = test_loss / total_samples
    accuracy = accuracy_score(all_labels, all_predictions)

    # Pin the label ids so the report, the confusion matrix and class_names
    # stay aligned even if a class is missing from the labels or predictions
    if class_names is None:
        label_ids = sorted({int(v) for v in all_labels} | {int(v) for v in all_predictions})
        class_names = [f"Class_{i}" for i in label_ids]
    else:
        label_ids = list(range(len(class_names)))

    # Classification report
    report = classification_report(
        all_labels,
        all_predictions,
        labels=label_ids,
        target_names=class_names,
        output_dict=True,
        zero_division=0
    )

    # Confusion matrix with one row/column per entry of class_names
    cm = confusion_matrix(all_labels, all_predictions, labels=label_ids)

    # Per-class accuracy; classes without any true samples are reported as NaN
    row_totals = cm.sum(axis=1)
    per_class_accuracy = np.divide(
        cm.diagonal(), row_totals,
        out=np.full(row_totals.shape, np.nan),
        where=row_totals > 0
    )

    # Create results dictionary
    results = {
        'model_name': model_name,
        'test_accuracy': accuracy,
        'test_loss': test_loss,
        'classification_report': report,
        'confusion_matrix': cm,
        'predictions': np.array(all_predictions),
        'true_labels': np.array(all_labels),
        'probabilities': np.array(all_probabilities),
        'per_class_accuracy': dict(zip(class_names, per_class_accuracy)),
        'total_samples': total_samples
    }

    # Print summary
    print(f"\nResults - {model_name}")
    print(f"   Test Loss: {test_loss:.4f}")
    print(f"   Test Accuracy: {accuracy:.4f}")
    print(f"   Per-class Accuracy:")
    for class_name, acc in results['per_class_accuracy'].items():
        print(f"     {class_name}: {acc:.4f}")

    return results

### evaluate_all_models and save_model_results 
- Helper function to iterate through all of the models and run them on the test set and save the plots 

In [None]:
def evaluate_all_models(experiment_dir, base_config, device, test_loader, class_names=None):
    """Evaluate every checkpointed model found under ``experiment_dir`` on the test set.

    Each immediate sub-directory of ``experiment_dir`` is treated as one model and
    its directory name is used as the model name. ``best.pth`` is preferred and
    ``latest.pth`` is used as a fallback. A failure while loading or evaluating
    one model is printed (with traceback) and does not stop the remaining models.

    Args:
        experiment_dir: Directory containing one sub-directory per trained model.
        base_config: Configuration dict with a ``'model'`` sub-dict. It is deep-copied
            per model, so the caller's dict is never modified.
        device: Device passed to ``torch.load`` (``map_location``) and ``model.to``.
        test_loader: Loader handed to ``evaluate_model_on_test``.
        class_names: Optional class names forwarded to ``evaluate_model_on_test``.

    Returns:
        Tuple ``(all_results, summary_df)``: ``all_results`` maps model name to the
        results dict returned by ``evaluate_model_on_test``; ``summary_df`` has one
        row per evaluated model sorted by descending ``'Test Accuracy'``. When no
        model could be evaluated, ``summary_df`` is empty but still has its columns.
    """
    import traceback

    summary_columns = [
        'Model', 'Test Accuracy', 'Test Loss',
        'Training Accuracy', 'Training Loss', 'Checkpoint Path'
    ]

    # Find all model directories; sorted because os.listdir order is arbitrary
    model_dirs = sorted(
        d for d in os.listdir(experiment_dir)
        if os.path.isdir(os.path.join(experiment_dir, d))
    )

    # Print the number of models found
    print(f"Found {len(model_dirs)} models in {experiment_dir}")

    # Dictionary holding the full results per model and a list of summary rows
    all_results = {}
    summary_data = []

    # Iterate through the models, loading the best checkpoint (best.pth) and then
    # evaluating its performance on the test set
    for model_name in model_dirs:
        model_dir = os.path.join(experiment_dir, model_name)
        checkpoint_path = os.path.join(model_dir, 'best.pth')

        if not os.path.exists(checkpoint_path):
            print(f"No best checkpoint for {model_name}, trying latest.pth")
            checkpoint_path = os.path.join(model_dir, 'latest.pth')

        if not os.path.exists(checkpoint_path):
            print(f"No checkpoint found for {model_name}")
            continue

        try:
            print(f"\nLoading {model_name}...")

            # NOTE(review): torch.load unpickles the file; only load checkpoints you trust.
            checkpoint = torch.load(checkpoint_path, map_location=device)

            # Deep copy: a shallow .copy() would share the nested 'model' dict and
            # overwrite the caller's base_config['model']['model_name'].
            model_config = copy.deepcopy(base_config)
            model_config['model']['model_name'] = model_name

            # Create model and restore the trained weights
            model = ImageClassificationModel(model_config)
            model.load_state_dict(checkpoint['model_state_dict'])
            model = model.to(device)

            # Evaluate on test set
            results = evaluate_model_on_test(
                model, test_loader, device, model_name, class_names
            )

            all_results[model_name] = results

            # Add to summary; 'N/A' is used when the checkpoint lacks the stored value
            summary_data.append({
                'Model': model_name,
                'Test Accuracy': results['test_accuracy'],
                'Test Loss': results['test_loss'],
                'Training Accuracy': checkpoint.get('metrics', {}).get('best_accuracy', 'N/A'),
                'Training Loss': checkpoint.get('training_info', {}).get('best_loss', 'N/A'),
                'Checkpoint Path': checkpoint_path
            })

            # Save individual model results next to the checkpoint
            save_model_results(results, model_dir)

        except Exception as e:
            # Best effort: report the failure and continue with the next model
            print(f"Failed to evaluate {model_name}: {e}")
            traceback.print_exc()

    # Create summary DataFrame; explicit columns keep the schema stable (and the
    # sort valid) even when nothing was evaluated
    summary_df = pd.DataFrame(summary_data, columns=summary_columns)
    if not summary_df.empty:
        summary_df = summary_df.sort_values('Test Accuracy', ascending=False)

    return all_results, summary_df

def save_model_results(results, save_dir):
    """Write one model's test metrics and per-sample predictions to ``save_dir``.

    Two files are produced (the directory is created if needed):

    * ``test_results.json`` - model name, test accuracy, test loss, per-class
      accuracy and total sample count.
    * ``test_predictions.csv`` - one row per sample with the true label, the
      predicted label, the maximum class probability and one ``prob_class_<i>``
      column per class.

    Args:
        results: Dict with the keys ``'model_name'``, ``'test_accuracy'``,
            ``'test_loss'``, ``'per_class_accuracy'``, ``'total_samples'``,
            ``'true_labels'``, ``'predictions'`` and ``'probabilities'``
            (a 2-D array indexed as ``[sample, class]``).
        save_dir: Destination directory.
    """
    import json

    def _to_builtin(value):
        # json only calls this for objects it cannot serialize itself, e.g.
        # np.float32 / np.int64 scalars, which are not Python float/int subclasses.
        if isinstance(value, np.generic):
            return value.item()
        if isinstance(value, np.ndarray):
            return value.tolist()
        raise TypeError(f"Object of type {type(value).__name__} is not JSON serializable")

    os.makedirs(save_dir, exist_ok=True)

    # Save metrics as JSON
    metrics_to_save = {
        'model_name': results['model_name'],
        'test_accuracy': results['test_accuracy'],
        'test_loss': results['test_loss'],
        'per_class_accuracy': results['per_class_accuracy'],
        'total_samples': results['total_samples']
    }

    with open(os.path.join(save_dir, 'test_results.json'), 'w') as f:
        json.dump(metrics_to_save, f, indent=2, default=_to_builtin)

    # Save predictions as CSV, including the per-class probabilities
    probabilities = np.asarray(results['probabilities'])
    columns = {
        'true_label': results['true_labels'],
        'predicted_label': results['predictions'],
        'max_probability': np.max(probabilities, axis=1)
    }
    for i in range(probabilities.shape[1]):
        columns[f'prob_class_{i}'] = probabilities[:, i]

    predictions_df = pd.DataFrame(columns)
    predictions_df.to_csv(os.path.join(save_dir, 'test_predictions.csv'), index=False)

    print(f"Results saved to {save_dir}")

In [None]:
# Evaluate every checkpointed model of the brain-tumor experiment on the test split.

# Use the GPU when one is available; otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Directory scanned by evaluate_all_models; each sub-directory is treated as one model.
experiment_dir = "./checkpoints/brain_tumor_v1"

# Configuration used to rebuild each architecture before its weights are loaded.
# NOTE(review): these values presumably have to match the ones used at training
# time for load_state_dict to succeed - confirm against the training cell.
base_config = {
    "model": {
        "num_classes": 4,
        "model_name": "placeholder",  # evaluate_all_models sets the name per model directory
        "n_nodes": 512,
        "dropout": 0.2,
        "freeze_layers": False,
        "freeze_strategy": "all",
        "pretrained": True,
        "custom_head": True
    },
    "data": {
        "size": 224,
        "base_dir": "./mri_modeling/brain_tumor_data/",
    },
    "debug": False
}

# Only test_loader is used below; class_names returned here is overwritten a few lines down.
train_loader, val_loader, test_loader, class_names = create_dataloaders(
    base_dir=base_config['data']['base_dir'],  # Use the base_dir from config
    batch_size=32,  # Specify batch size
    num_workers=2    # Specify workers
)

# Get class names from the training dataset (ImageFolder derives them from the
# sub-directory names of <base_dir>/train)
train_dataset = ImageFolder(os.path.join(base_config['data']['base_dir'], 'train'))
class_names = train_dataset.classes
print(f"Classes: {class_names}")
print(f"Test samples: {len(test_loader.dataset)}")

# Evaluate all models; all_results and summary_df are reused by the plotting cell
print("\nEvaluating all trained models...")
all_results, summary_df = evaluate_all_models(
    experiment_dir=experiment_dir,
    base_config=base_config,
    device=device,
    test_loader=test_loader,
    class_names=class_names
)

### plot_confusion_matrices
- Function to plot a row-normalized confusion matrix for each model, showing per-class performance and where the misclassifications occur
### plot_model_comparison
- Function to plot the test accuracy, test loss, and per-class accuracy of all models side by side

In [None]:
def plot_confusion_matrices(all_results, class_names, save_dir=None):
    """Plot one row-normalized confusion-matrix heatmap per model in a grid.

    Args:
        all_results: Mapping of model name to a results dict containing
            ``'confusion_matrix'`` (square array of counts) and ``'test_accuracy'``.
        class_names: Tick labels used for both heatmap axes.
        save_dir: Optional directory; when given, the figure is also written to
            ``<save_dir>/confusion_matrices.png``.

    Returns:
        None. Nothing is drawn or saved when ``all_results`` is empty.
    """
    n_models = len(all_results)
    if n_models == 0:
        # Without this guard n_cols would be 0 and the row computation would
        # raise ZeroDivisionError.
        print("No results to plot")
        return

    # Lay the models out on a grid that is at most three columns wide
    n_cols = min(3, n_models)
    n_rows = (n_models + n_cols - 1) // n_cols

    # squeeze=False always returns a 2-D array of Axes, so a single model needs
    # no special-casing
    fig, axes = plt.subplots(n_rows, n_cols, figsize=(5*n_cols, 4*n_rows), squeeze=False)
    axes = axes.ravel()

    for ax, (model_name, results) in zip(axes, all_results.items()):
        cm = np.asarray(results['confusion_matrix'])

        # Normalize each row by the number of true samples of that class. A class
        # with no true samples yields NaN; the warning is silenced and the NaN kept
        # (seaborn's heatmap masks missing values) rather than inventing a value.
        with np.errstate(divide='ignore', invalid='ignore'):
            cm_normalized = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]

        sns.heatmap(cm_normalized, annot=True, fmt='.2f', cmap='Blues',
                    xticklabels=class_names, yticklabels=class_names,
                    ax=ax, cbar_kws={'label': 'Normalized Count'})

        ax.set_title(f'{model_name}\nAccuracy: {results["test_accuracy"]:.3f}')
        ax.set_xlabel('Predicted')
        ax.set_ylabel('True')

    # Hide unused subplots in the last grid row
    for ax in axes[n_models:]:
        ax.set_visible(False)

    plt.tight_layout()

    if save_dir:
        os.makedirs(save_dir, exist_ok=True)
        plt.savefig(os.path.join(save_dir, 'confusion_matrices.png'), dpi=300, bbox_inches='tight')

    plt.show()

def plot_model_comparison(all_results, save_dir=None, class_names=None, heatmap_vmin=0.8):
    """Plot test accuracy, test loss and per-class accuracy for all models.

    Produces a three-panel figure: a bar chart of test accuracy, a bar chart of
    test loss, and a model-by-class heatmap of per-class accuracy. Models are
    ordered by descending test accuracy in all panels.

    Args:
        all_results: Mapping of model name to a results dict containing
            ``'test_accuracy'``, ``'test_loss'`` and ``'per_class_accuracy'``
            (a mapping of class name to accuracy).
        save_dir: Optional directory; when given, the figure is also written to
            ``<save_dir>/model_comparison.png``.
        class_names: Optional x-axis labels for the heatmap. Defaults to the keys
            of the best model's ``'per_class_accuracy'``, so the function no longer
            depends on a notebook-global ``class_names``.
        heatmap_vmin: Lower bound of the heatmap colour scale (upper bound is 1).
            Accuracies below it are drawn with the lowest colour; lower this for
            experiments with weak per-class accuracy.

    Returns:
        None. Nothing is drawn or saved when ``all_results`` is empty.
    """
    if not all_results:
        print("No results to plot")
        return

    # Prepare data
    model_names = list(all_results.keys())
    test_accuracies = [all_results[name]['test_accuracy'] for name in model_names]
    test_losses = [all_results[name]['test_loss'] for name in model_names]

    # Sort by accuracy (best model first)
    sorted_indices = np.argsort(test_accuracies)[::-1]
    model_names = [model_names[i] for i in sorted_indices]
    test_accuracies = [test_accuracies[i] for i in sorted_indices]
    test_losses = [test_losses[i] for i in sorted_indices]

    if class_names is None:
        class_names = list(all_results[model_names[0]]['per_class_accuracy'].keys())

    # Create figure
    fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(15, 5))
    # Define x-axis positions for the ticks to quiet the Matplotlib warning
    x_pos = np.arange(len(model_names))

    # Plot 1: Test Accuracy
    bars1 = ax1.bar(x_pos, test_accuracies, color='skyblue', alpha=0.7)
    ax1.set_title('Test Accuracy Comparison')
    ax1.set_ylabel('Accuracy')
    ax1.set_xticks(x_pos)
    ax1.set_xticklabels(model_names, rotation=45, ha='right')
    ax1.set_ylim([0, 1])

    # Add value labels just above each bar
    for bar, acc in zip(bars1, test_accuracies):
        ax1.text(bar.get_x() + bar.get_width()/2, bar.get_height() + 0.01,
                 f'{acc:.4f}', ha='center', va='bottom', fontsize=9)

    # Plot 2: Test Loss
    ax2.bar(x_pos, test_losses, color='lightcoral', alpha=0.7)
    ax2.set_title('Test Loss Comparison')
    ax2.set_ylabel('Loss')
    ax2.set_xticks(x_pos)
    ax2.set_xticklabels(model_names, rotation=45, ha='right')

    # Plot 3: Per-class accuracy heatmap (rows = models, columns = classes)
    per_class_data = np.array([
        list(all_results[model_name]['per_class_accuracy'].values())
        for model_name in model_names
    ])

    im = ax3.imshow(per_class_data, cmap='YlOrRd', aspect='auto', vmin=heatmap_vmin, vmax=1)
    ax3.set_title('Per-Class Accuracy Heatmap')
    ax3.set_xlabel('Class')
    ax3.set_ylabel('Model')
    ax3.set_xticks(range(len(class_names)))
    ax3.set_xticklabels(class_names, rotation=45, ha='right')
    ax3.set_yticks(range(len(model_names)))
    ax3.set_yticklabels(model_names)

    # Add colorbar
    plt.colorbar(im, ax=ax3, label='Accuracy')

    # Add text annotations with the exact value of every cell
    for (i, j), value in np.ndenumerate(per_class_data):
        ax3.text(j, i, f'{value:.2f}',
                 ha='center', va='center', color='black', fontsize=8)

    plt.tight_layout()

    if save_dir:
        os.makedirs(save_dir, exist_ok=True)
        plt.savefig(os.path.join(save_dir, 'model_comparison.png'), dpi=300, bbox_inches='tight')

    plt.show()

In [None]:
# Render the per-model confusion matrices and the cross-model comparison for the
# results computed above; passing save_dir also writes confusion_matrices.png and
# model_comparison.png into the experiment directory.
plot_confusion_matrices(all_results, class_names, save_dir=experiment_dir)
plot_model_comparison(all_results, save_dir=experiment_dir)

Now, let's try a different condition - Alzheimer's disease. I don't anticipate that this will work as well because there is a class imbalance issue with the data being used to train (very few moderately demented images). When I've run it, it usually emits warnings for the first few epochs (`UndefinedMetricWarning: Precision is ill-defined and being set to 0.0 in labels with no predicted samples. Use zero_division parameter to control this behavior.`) because the moderately demented images are not being observed or predicted, but this local minimum is overcome with more epochs.

In [None]:
# Train and compare all architectures on the Alzheimer's dataset.

# Use the GPU when one is available; otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# NOTE(review): the keys below are consumed by run_model_comparison and the code it
# calls, which are defined outside this cell; their exact meaning should be
# confirmed there rather than inferred from the names.
base_config = {
    "model": {
        "num_classes": 4,
        "model_name": "placeholder",  # presumably overwritten per model by run_model_comparison - verify
        "n_nodes": 512,
        "dropout": 0.2,
        "freeze_layers": True,
        "freeze_strategy": "all",
        "pretrained": True,
        "custom_head": True
    },
    "data": {
        "size": 224,
        "base_dir": "./mri_modeling/alzheimers_data/",
    },
    "training": {
        "batch_size": 32,
        "num_epochs": 100,
        "learning_rate": 0.001,
        "scheduler_step_size": 15,
        "scheduler_gamma": 0.1,
        "checkpoint_dir": "./checkpoints",
        "epoch_start_unfreeze": 5,
        "layer_start_unfreeze": 140,
        "early_stopping_patience": 10
    },
    "debug": False
}

# all_models is defined in an earlier cell; the evaluation cell below reads the
# checkpoints of this experiment from ./checkpoints/alzheimers_v1.
results_alz = run_model_comparison(all_models, base_config, device, experiment_name="alzheimers_v1")

In [None]:
# Evaluate every checkpointed model of the Alzheimer's experiment on the test split.
# Relies on `device` defined in an earlier cell.

# Directory scanned by evaluate_all_models; each sub-directory is treated as one model.
experiment_dir = "./checkpoints/alzheimers_v1"

# Configuration used to rebuild each architecture before its weights are loaded.
# NOTE(review): "freeze_layers" is False here but True in the training cell above;
# presumably this does not affect loading the weights - confirm.
base_config = {
    "model": {
        "num_classes": 4,
        "model_name": "placeholder",  # evaluate_all_models sets the name per model directory
        "n_nodes": 512,
        "dropout": 0.2,
        "freeze_layers": False,
        "freeze_strategy": "all",
        "pretrained": True,
        "custom_head": True
    },
    "data": {
        "size": 224,
        "base_dir": "./mri_modeling/alzheimers_data/",
    },
    "debug": False
}

# Only test_loader is used below; class_names returned here is overwritten a few lines down.
train_loader, val_loader, test_loader, class_names = create_dataloaders(
    base_dir=base_config['data']['base_dir'], 
    batch_size=32, 
    num_workers=2 
)

# Get class names from the training dataset (ImageFolder derives them from the
# sub-directory names of <base_dir>/train)
train_dataset = ImageFolder(os.path.join(base_config['data']['base_dir'], 'train'))
class_names = train_dataset.classes
print(f"Classes: {class_names}")
print(f"Test samples: {len(test_loader.dataset)}")

# Evaluate all models; all_results and summary_df are reused by the plotting cell
print("\nEvaluating all trained models...")
all_results, summary_df = evaluate_all_models(
    experiment_dir=experiment_dir,
    base_config=base_config,
    device=device,
    test_loader=test_loader,
    class_names=class_names
)

In [None]:
# Render the per-model confusion matrices and the cross-model comparison for the
# Alzheimer's results; passing save_dir also writes confusion_matrices.png and
# model_comparison.png into the experiment directory.
plot_confusion_matrices(all_results, class_names, save_dir=experiment_dir)
plot_model_comparison(all_results, save_dir=experiment_dir)