# CNN Model

In [1]:
import os
import random
import numpy as np
from torch.utils.tensorboard import SummaryWriter
#not needed?
import hydra
from omegaconf import DictConfig
#not needed?
import sys
#not needed?
import yaml
import torch
import shutil
#might not need this
import argparse
import torch.nn as nn
from tqdm import tqdm
from operator import getitem
from functools import reduce
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision import transforms
from torch.utils.data import Dataset
from PIL import Image
import torchvision
from torch.utils.data.sampler import Sampler
import torcheval.metrics as tm
from hydra import initialize, compose
import timm



This cell imports the libraries and packages needed to set up a machine learning project with PyTorch, specifically for image processing tasks in OCT medical image analysis with a Convolutional Neural Network (CNN).

* **os, sys:** These are standard Python modules used to interact with the operating system and the Python runtime environment. They are useful for handling file and directory paths, and system-specific functions, respectively.

* **torch:** The main PyTorch library, which provides tensor operations, automatic differentiation (the computation of derivatives, or gradients, needed for training), and neural network layers.

* **torch.nn:** This PyTorch submodule contains the building blocks for neural networks, such as convolutional layers and activation functions.

* **torchvision:** A companion package to PyTorch that provides tools and datasets for image processing, including pre-defined transforms, common datasets, and pre-trained models.

* **torch.utils.data:** This includes utilities to load data, iterate through datasets, and make dataset management easier. 'DataLoader' is a critical class that provides an iterable over the given dataset, and 'Dataset' is a base class for making datasets.

* **PIL (Python Imaging Library):** Used for opening, manipulating, and saving many different image file formats. This is essential for processing raw image files for the model.

* **shutil:** Provides a number of high-level operations on files and collections of files. In this project, it is used for copying files or entire folder contents.

* **tqdm:** A library that displays a progress bar for Python loops, such as training loops or data loading.

* **torch.utils.tensorboard:** Integrates PyTorch with TensorBoard, a visualization toolkit for machine learning experimentation. SummaryWriter allows logging of experiments, such as visualizing the model graph, parameters, and metrics during training.

* **hydra, omegaconf:** These are configuration management tools. Hydra is used for elegantly configuring complex applications, and OmegaConf manages configurations as hierarchical structures, facilitating the handling of settings and parameters.

* **timm:** Short for PyTorch Image Models, this library provides pre-trained models, model components, and utilities for image classification.

* **torcheval.metrics (tm):** PyTorch library for evaluation metrics, used to assess the performance of the model with metrics such as accuracy, precision, and recall.

* **Sampler:** An abstract base class in torch.utils.data for specifying the strategy used to draw samples from a dataset.


In [3]:
# Input
# Note: `dir` shadows the Python builtin of the same name; it is kept because later cells use it.
dir = "C:/Users/baner/OneDrive/Desktop/New folder (2)/OCTDL"
dataset_folder = dir + "/OCTDL images/"
# Labels for each input
labels_path = dir + "/OCTDL_labels.csv"
# Output
output_folder = dir + "/output"
# Model save path
save_path = dir + "/run"
# Config file
cg = dir + "/configs/OCTDL.yaml"
config_path = dir + '/configs'

regression_loss = ['mean_square_error', 'mean_absolute_error', 'smooth_L1']

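This cell sets up the paths and configuration details used to train a deep learning model for Optical Coherence Tomography (OCT) image analysis. The regression_loss list names the loss functions that are treated as regression criteria later on.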


In [4]:
def set_random_seed(seed, deterministic=False):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = deterministic

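The function set_random_seed seeds the random number generators of Python's random module, NumPy, and PyTorch (CPU and the current CUDA device), and optionally switches cuDNN to deterministic mode. Fixing the seed gives every run the same starting point, which makes results reproducible.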
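The effect of seeding can be checked directly. The following is a minimal sketch, assuming a helper named draw (hypothetical, not part of the notebook) that takes one sample from each generator; it repeats the seeding steps of set_random_seed and compares two runs.

```python
import random

import numpy as np
import torch


def set_random_seed(seed, deterministic=False):
    # Same seeding steps as the notebook's function.
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)  # silently ignored when CUDA is unavailable
    torch.backends.cudnn.deterministic = deterministic


def draw():
    # Hypothetical helper: one draw from each of the three generators.
    return random.random(), np.random.rand(), torch.rand(3)


set_random_seed(42)
first = draw()
set_random_seed(42)
second = draw()

# All three generators should reproduce their first draw after re-seeding.
same = (
    first[0] == second[0]
    and first[1] == second[1]
    and torch.equal(first[2], second[2])
)
print(same)
```

Note that seeding alone does not guarantee bit-identical GPU results; that is what the deterministic flag is for.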

In [5]:
config_path =  dir + '/configs'
config_name = "OCTDL"
cwd = os.getcwd()
# Convert the absolute config path to a relative path
relative_cfg_path = os.path.relpath(config_path, cwd)
with initialize(config_path=relative_cfg_path):
    cfg = compose(config_name=config_name)

if cfg.base.random_seed != -1:
    seed = cfg.base.random_seed
    set_random_seed(seed, cfg.base.cudnn_deterministic)
log_path = os.path.join(cfg.base.save_path, 'log')
logger = SummaryWriter(log_path)

The version_base parameter is not specified.
Please specify a compatability version level, or None.
Will assume defaults for version 1.1
  with initialize(config_path=relative_cfg_path):


This cell handles configuration management and reproducibility, using Hydra to compose the configuration. It converts the absolute configuration directory into a path relative to the current working directory, because Hydra's initialize expects a relative config_path, and then composes a configuration object from OCTDL.yaml in that directory.

If the configuration provides a random seed (any value other than -1), the seed is applied with set_random_seed, which also switches cuDNN to deterministic mode when cfg.base.cudnn_deterministic is set. Finally, the cell creates a TensorBoard SummaryWriter that writes to the log subdirectory of cfg.base.save_path, so that training can be tracked and visualized. The warning printed above appears because version_base was not passed to initialize.

In [6]:
def select_out_features(num_classes, criterion):
    out_features = num_classes
    if criterion in regression_loss:
        out_features = 1
    return out_features

def get_terminal_col():
    try:
        return os.get_terminal_size().columns
    except OSError:
        return 80

def print_msg(msg, appendixs=[], warning=False):
    color = '\033[93m'
    end = '\033[0m'
    print_fn = (lambda x: print(color + x + end)) if warning else print

    max_len = len(max([msg, *appendixs], key=len))
    max_len = min(max_len, get_terminal_col())
    print_fn('=' * max_len)
    print_fn(msg)
    for appendix in appendixs:
        print_fn(appendix)
    print_fn('=' * max_len)

def build_model(cfg):
    network = cfg.train.network
    out_features = select_out_features(
        cfg.data.num_classes,
        cfg.train.criterion
    )

    if 'vit' in network or 'swin' in network:
        model = timm.create_model(
            network,
            img_size=cfg.data.input_size,
            in_chans=cfg.data.in_channels,
            num_classes=out_features,
            pretrained=cfg.train.pretrained,
        )
    else:
        model = timm.create_model(
            network,
            in_chans=cfg.data.in_channels,
            num_classes=out_features,
            pretrained=cfg.train.pretrained,
        )

    return model

def generate_model(cfg):
    model = build_model(cfg)

    if cfg.train.checkpoint:
        # Load onto the CPU first so that a checkpoint saved on a GPU
        # can also be restored on a machine without one.
        weights = torch.load(cfg.train.checkpoint, map_location='cpu')
        model.load_state_dict(weights, strict=True)
        print_msg('Load weights from {}'.format(cfg.train.checkpoint))
    model = model.to(cfg.base.device)

    return model

The following is a detailed explanation of each function, its parameters, and its role in configuring, building, and initializing a machine learning model for image processing tasks:<br>

**1. select_out_features(num_classes, criterion)**:<br>
Purpose: Determines the number of output features for the model based on the training criterion: **num_classes** for classification, or 1 for regression.<br>

Parameters: <br>
* num_classes: The total number of classes in the classification task
* criterion: The loss function or criterion being used for the model.

Role: This function checks whether the criterion is one of the regression losses listed in regression_loss. If it is, it sets the number of output features to 1. Otherwise, the task is treated as classification and the number of output features equals the number of classes.<br>
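A short usage sketch of this rule, reusing the function and the regression_loss list from the notebook; the criterion name 'cross_entropy' is the one mentioned later for this problem.

```python
regression_loss = ['mean_square_error', 'mean_absolute_error', 'smooth_L1']


def select_out_features(num_classes, criterion):
    # Regression criteria predict a single scalar; classification needs one output per class.
    out_features = num_classes
    if criterion in regression_loss:
        out_features = 1
    return out_features


classification_outputs = select_out_features(3, 'cross_entropy')
regression_outputs = select_out_features(3, 'mean_square_error')
print(classification_outputs)  # 3
print(regression_outputs)      # 1
```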

<br>**2. get_terminal_col()**:<br>
Purpose: Retrieves the width of the terminal in columns, falling back to 80 when the size cannot be determined.<br>

Role: Knowing the terminal width lets print_msg limit the length of its separator lines so that they fit within the window without wrapping.<br>

<br>**3. print_msg(msg, appendixs=[], warning=False)**:<br>
Purpose: Prints a formatted message, with options to include additional lines (appendixes) and a warning mode that changes text color.<br>

Parameters:<br>
* msg: The main message to print.
* appendixs: A list of additional strings to print after the main message.
* warning: A boolean flag that, if True, prints the message in a warning-specific color (yellow).

Role: This function displays important runtime information or warnings clearly, framed by separator lines.<br>

<br>**4. build_model(cfg)**:<br>
Purpose: Builds a machine learning model using configurations specified in cfg.<br>

Parameters:
* cfg: configuration object that has all the details and configs of the model.

Role: This function retrieves the model architecture name from the configuration and determines the number of output features using the select_out_features function. It checks whether the network is a transformer-based architecture (its name contains 'vit' or 'swin') and, if so, also passes the input image size to timm, since these models are built for a fixed input resolution. CNNs such as VGG16 take the other branch and are created without img_size. In both cases, the model is configured with the specified number of input channels, the number of output features, and the pretrained flag, which determines whether to initialize the model with previously trained weights.<br>

<br>timm.create_model(): The factory function of timm (**PyTorch Image Models**), used to instantiate a neural network model based on the specifications provided in the 'cfg' object. The arguments used here cover:<br>
* Selecting the Model Architecture 
* Number of Input Channels: Indicates the number of color channels in the input images. For instance, 3 for RGB images, 1 for grayscale.
* Number of Classes
* Pretrained Weights 
<br>timm provides access to a wide range of model architectures, including but not limited to traditional convolutional neural networks (CNNs) like ResNets, EfficientNets, and VGGs, as well as newer architectures such as Vision Transformers (ViT), MLP Mixers, and Swin Transformers. This variety makes it possible to choose the model that best fits a specific need or to compare different models experimentally.<br>

**5. generate_model(cfg)**
<br>Purpose: Builds the model with **build_model(cfg)**, optionally loads weights from a checkpoint file, and moves the model to the target device.<br>
<br>Parameters:
* cfg: A configuration object, typically loaded from a configuration file, that includes paths, model specifications, and other parameters.

Role: This function first builds the model using **build_model(cfg)**. It then checks whether cfg.train.checkpoint is set and, if so, loads the saved weights from that file with strict key matching. This is separate from the pretrained flag handled in build_model, which controls timm's own pretrained weights. Starting from previously trained weights is useful for transfer learning and for resuming or evaluating an earlier run. After loading the weights, it prints a formatted message and moves the model to the computing device given by cfg.base.device (for example a GPU). The result is a model ready for training or evaluation.
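The checkpoint step can be illustrated with a round trip through a temporary file. This is a sketch under stated assumptions: tiny_model is a hypothetical stand-in for the timm network, and the file name weights.pt is made up.

```python
import os
import tempfile

import torch
import torch.nn as nn


def tiny_model():
    # Hypothetical stand-in for the network returned by build_model(cfg).
    return nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 3))


source = tiny_model()
restored = tiny_model()  # same architecture, different random weights

with tempfile.TemporaryDirectory() as tmp_dir:
    checkpoint = os.path.join(tmp_dir, "weights.pt")
    torch.save(source.state_dict(), checkpoint)
    # map_location="cpu" lets the file load even if it was saved from a GPU.
    weights = torch.load(checkpoint, map_location="cpu")

# strict=True raises an error if any key is missing or unexpected.
restored.load_state_dict(weights, strict=True)

device = "cuda" if torch.cuda.is_available() else "cpu"
restored = restored.to(device)

x = torch.randn(2, 4)
match = torch.allclose(source(x), restored(x.to(device)).cpu(), atol=1e-5)
print(match)
```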



### VGG16

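VGG16 is a convolutional neural network introduced by the Visual Geometry Group at the University of Oxford (Simonyan and Zisserman, 2014). It has 16 weight layers: 13 convolutional layers with 3x3 kernels, arranged in five blocks that each end in max pooling, followed by 3 fully connected layers. With roughly 138 million parameters it is large by modern standards, but its simple, uniform structure makes it a common baseline. It is available in timm and is therefore one of the architectures that can be selected through cfg.train.network.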

In [7]:
def simple_transform(input_size):
    return transforms.Compose([
        transforms.Resize((input_size, input_size)),
        transforms.ToTensor()])


def mean_and_std(train_dataset, batch_size, num_workers):
    loader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        num_workers=num_workers,
        shuffle=False
    )

    num_samples = 0.
    channel_mean = torch.Tensor([0., 0., 0.])
    channel_std = torch.Tensor([0., 0., 0.])
    for samples in tqdm(loader):
        X, _ = samples
        channel_mean += X.mean((2, 3)).sum(0)
        num_samples += X.size(0)
    channel_mean /= num_samples

    for samples in tqdm(loader):
        X, _ = samples
        batch_samples = X.size(0)
        X = X.permute(0, 2, 3, 1).reshape(-1, 3)
        channel_std += ((X - channel_mean) ** 2).mean(0) * batch_samples
    channel_std = torch.sqrt(channel_std / num_samples)

    mean, std = channel_mean.tolist(), channel_std.tolist()
    print('mean: {}'.format(mean))
    print('std: {}'.format(std))
    return mean, std

def auto_statistics(data_path, input_size, batch_size, num_workers):
    print('Calculating mean and std of training set for data normalization.')
    transform = simple_transform(input_size)
    train_path = os.path.join(data_path, 'train')
    train_dataset = datasets.ImageFolder(train_path, transform=transform)

    return mean_and_std(train_dataset, batch_size, num_workers)

def random_apply(op, p):
    return transforms.RandomApply([op], p=p)

def data_transforms(cfg):
    data_aug = cfg.data.data_augmentation
    aug_args = cfg.data_augmentation_args

    operations = {
        'random_crop': random_apply(
            transforms.RandomResizedCrop(
                size=(cfg.data.input_size, cfg.data.input_size),
                scale=aug_args.random_crop.scale,
                ratio=aug_args.random_crop.ratio
            ),
            p=aug_args.random_crop.prob
        ),
        'horizontal_flip': transforms.RandomHorizontalFlip(
            p=aug_args.horizontal_flip.prob
        ),
        'vertical_flip': transforms.RandomVerticalFlip(
            p=aug_args.vertical_flip.prob
        ),
        'color_distortion': random_apply(
            transforms.ColorJitter(
                brightness=aug_args.color_distortion.brightness,
                contrast=aug_args.color_distortion.contrast,
                saturation=aug_args.color_distortion.saturation,
                hue=aug_args.color_distortion.hue
            ),
            p=aug_args.color_distortion.prob
        ),
        'rotation': random_apply(
            transforms.RandomRotation(
                degrees=aug_args.rotation.degrees,
                fill=aug_args.value_fill
            ),
            p=aug_args.rotation.prob
        ),
        'translation': random_apply(
            transforms.RandomAffine(
                degrees=0,
                translate=aug_args.translation.range,
                fill=aug_args.value_fill
            ),
            p=aug_args.translation.prob
        ),
        'grayscale': transforms.RandomGrayscale(
            p=aug_args.grayscale.prob
        ),
        'gaussian_blur': random_apply(
            transforms.GaussianBlur(
                kernel_size=aug_args.gaussian_blur.kernel_size,
                sigma=aug_args.gaussian_blur.sigma
            ),
            p=aug_args.gaussian_blur.prob
        )
    }

    augmentations = []
    for op in data_aug:
        if op not in operations:
            raise NotImplementedError('Not implemented data augmentation operations: {}'.format(op))
        augmentations.append(operations[op])

    normalization = [
        transforms.Resize((cfg.data.input_size, cfg.data.input_size)),
        transforms.ToTensor(),
        transforms.Normalize(cfg.data.mean, cfg.data.std)
    ]

    train_preprocess = transforms.Compose([
        *augmentations,
        *normalization
    ])

    test_preprocess = transforms.Compose(normalization)

    return train_preprocess, test_preprocess


def img_loader(path):
    with open(path, 'rb') as f:
        img = Image.open(f)
        return img.convert('RGB')
    
class CustomizedImageFolder(datasets.ImageFolder):
    def __init__(self, root, transform=None, target_transform=None, loader=img_loader):
        super(CustomizedImageFolder, self).__init__(root, transform, target_transform, loader=loader)

    def __getitem__(self, index):
        path, target = self.samples[index]
        sample = self.loader(path)
        if self.transform is not None:
            sample = self.transform(sample)
        if self.target_transform is not None:
            target = self.target_transform(target)

        return sample, target
    
def generate_dataset_from_folder(data_path, train_transform, test_transform):
    train_path = os.path.join(data_path, 'train')
    test_path = os.path.join(data_path, 'test')
    val_path = os.path.join(data_path, 'val')

    train_dataset = CustomizedImageFolder(train_path, train_transform, loader=img_loader)
    test_dataset = CustomizedImageFolder(test_path, test_transform, loader=img_loader)
    val_dataset = CustomizedImageFolder(val_path, test_transform, loader=img_loader)

    return train_dataset, test_dataset, val_dataset

def print_dataset_info(datasets):
    train_dataset, test_dataset, val_dataset = datasets
    print('=========================')
    print('Dataset Loaded.')
    print('Categories:\t{}'.format(len(train_dataset.classes)))
    print('Training:\t{}'.format(len(train_dataset)))
    print('Validation:\t{}'.format(len(val_dataset)))
    print('Test:\t\t{}'.format(len(test_dataset)))
    print('=========================')
    
def generate_dataset(cfg):
    if cfg.data.mean == 'auto' or cfg.data.std == 'auto':
        mean, std = auto_statistics(
            cfg.base.data_path,
            cfg.data.input_size,
            cfg.train.batch_size,
            cfg.train.num_workers
        )
        cfg.data.mean = mean
        cfg.data.std = std

    train_transform, test_transform = data_transforms(cfg)

    data_splits = generate_dataset_from_folder(
        cfg.base.data_path,
        train_transform,
        test_transform
    )

    print_dataset_info(data_splits)
    return data_splits

The code above handles image preprocessing, transformation, and loading within a machine learning pipeline, with a focus on image classification tasks. Here is an explanation of the key functions and their roles:<br>

**1. simple_transform(input_size):**<br>
Purpose: Sets up a basic image transformation pipeline that resizes an image to a square of the specified size and converts it to a tensor.<br>

Parameters:
* input_size: The target side length, in pixels, to which images will be resized.<br>

**2. mean_and_std(train_dataset, batch_size, num_workers):**<br>
Purpose: Computes the mean and standard deviation of all images in a dataset, which are critical for normalizing the dataset during training.<br>

Parameters:
* train_dataset: The dataset from which to compute the statistics.
* batch_size: Number of images to process in one batch.
* num_workers: Number of worker processes to use for loading the data.

<br>Role: Calculates the per-channel mean and standard deviation over all images in the dataset using two passes: the first pass accumulates the channel means, and the second accumulates the squared deviations from those means. These statistics are used to normalize images (subtracting the mean and dividing by the standard deviation), so that the model sees inputs on a standardized scale during training. The implementation assumes three-channel images of equal size, which the resize in simple_transform and ImageFolder's RGB loading provide.<br>
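The two-pass computation can be checked against a direct calculation. The sketch below uses random tensors as a hypothetical stand-in for the dataset (10 RGB images of 8x8 pixels, split into two batches) and follows the same accumulation steps as mean_and_std.

```python
import torch

torch.manual_seed(0)
# Hypothetical stand-in for a DataLoader: two batches of equally sized RGB images.
images = torch.rand(10, 3, 8, 8)
batches = [images[:6], images[6:]]

# Pass 1: average the per-image channel means.
num_samples = 0
channel_mean = torch.zeros(3)
for X in batches:
    channel_mean += X.mean((2, 3)).sum(0)
    num_samples += X.size(0)
channel_mean /= num_samples

# Pass 2: accumulate squared deviations from the global mean, weighted by batch size.
channel_var = torch.zeros(3)
for X in batches:
    pixels = X.permute(0, 2, 3, 1).reshape(-1, 3)
    channel_var += ((pixels - channel_mean) ** 2).mean(0) * X.size(0)
channel_std = torch.sqrt(channel_var / num_samples)

# Reference: the same statistics computed over all pixels at once.
all_pixels = images.permute(0, 2, 3, 1).reshape(-1, 3)
ref_mean = all_pixels.mean(0)
ref_std = all_pixels.std(0, unbiased=False)

print(torch.allclose(channel_mean, ref_mean, atol=1e-5))
print(torch.allclose(channel_std, ref_std, atol=1e-5))
```

The weighting by batch size is only exact because every image has the same number of pixels after resizing.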

**3. auto_statistics(data_path, input_size, batch_size, num_workers)**<br>
Purpose: Automates the calculation of mean and standard deviation for a dataset stored at a specified path.<br>

Parameters:
* data_path: Path to the dataset.
* input_size: Size to which images will be resized.
* batch_size: Number of images in each batch for processing.
* num_workers: Number of processes to use for data loading.

<br>Role: Builds the training dataset with simple_transform and datasets.ImageFolder, then passes it to mean_and_std. datasets.ImageFolder is a utility class from torchvision.datasets that loads images from a directory structure in which each subfolder represents a class. It automatically maps images to labels based on the folder names, which makes it convenient for training image classification models. Only the train subfolder of data_path is used, so the statistics are not influenced by validation or test images.<br>

**4. data_transforms(cfg)**<br>
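Purpose: Creates the training and test preprocessing pipelines, including data augmentation, based on the settings in cfg.<br>

Parameters:
* cfg: Configuration object containing settings for data augmentation and other preprocessing details.

<br>Role: Based on the configuration, this function builds the list of augmentations named in cfg.data.data_augmentation (random cropping, flipping, color distortion, rotation, translation, grayscale conversion, and Gaussian blur are available) and raises NotImplementedError for unknown names. Augmentation is applied only to the training data; both pipelines end with resizing, tensor conversion, and normalization. Augmenting the training data helps reduce overfitting and improves the model's ability to generalize.<br>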

**5. generate_dataset_from_folder(data_path, train_transform, test_transform)**<br>
Purpose: Loads image data from specified directories and applies the appropriate transformations for training, testing, and validation datasets.<br>

Parameters:
* data_path: The base directory containing the dataset.
* train_transform: Transformations to apply to the training data.
* test_transform: Transformations to apply to the testing and validation data.

<br>Role: Uses CustomizedImageFolder to load the train, test, and val subfolders of data_path, applying train_transform to the training split and test_transform to the other two. CustomizedImageFolder is a subclass of torchvision.datasets.ImageFolder that loads each image with img_loader, which opens the file and converts it to RGB. The normalization parameters inside the transforms come either from the configuration or from auto_statistics(), as arranged by generate_dataset().<br>

**6. print_dataset_info(datasets)**<br>
Purpose: Prints information about the loaded datasets to help verify their correctness and completeness.<br>

Parameters:
* datasets: A tuple containing the training, testing, and validation datasets.

<br>Role: Displays the number of categories and the size of each dataset split, providing a quick summary that can help in understanding the dataset structure and ensuring that data loading is correct.<br>

**7. generate_dataset(cfg)**<br>
Purpose: Coordinates the overall process of dataset preparation mentioned in the above functions.<br>

Parameters:
* cfg: Configuration object with settings for data path, transformations, and other necessary parameters.

<br>Role: Manages the flow of setting up datasets from initial statistics calculation to applying transformations and organizing data splits. This function ensures that the datasets are ready for use in training, testing, and validation phases of machine learning models.<br>

In [8]:
metrics_fn = {
    'acc': tm.MulticlassAccuracy,
    'f1': tm.MulticlassF1Score,
    'auc': tm.MulticlassAUROC,
    'precision': tm.MulticlassPrecision,
    'recall': tm.MulticlassRecall
}
available_metrics = metrics_fn.keys()
logits_required_metrics = ['auc']
regression_based_metrics = ['mean_square_error', 'mean_absolute_error', 'smooth_L1']

class Estimator():
    def __init__(self, metrics, num_classes, criterion, average='macro', thresholds=None):
        self.criterion = criterion
        self.num_classes = num_classes
        self.thresholds = [-0.5 + i for i in range(num_classes)] if not thresholds else thresholds

        if criterion in regression_based_metrics and 'auc' in metrics:
            metrics.remove('auc')
            print_msg('AUC is not supported for regression based metrics {}.'.format(criterion), warning=True)

        self.metrics = metrics
        self.metrics_fn = {m: metrics_fn[m](num_classes=num_classes, average=average) for m in metrics}
        self.conf_mat_fn = tm.MulticlassConfusionMatrix(num_classes=num_classes)

    def update(self, predictions, targets):
        targets = targets.data.cpu().long()
        logits = predictions.data.cpu()
        predictions = self.to_prediction(logits)

        # update metrics
        self.conf_mat_fn.update(predictions, targets)
        for m in self.metrics_fn.keys():
            if m in logits_required_metrics:
                self.metrics_fn[m].update(logits, targets)
            else:
                self.metrics_fn[m].update(predictions, targets)

    def get_scores(self, digits=-1):
        scores = {m: self._compute(m, digits) for m in self.metrics}
        return scores

    def _compute(self, metric, digits=-1):
        score = self.metrics_fn[metric].compute().item()
        score = score if digits == -1 else round(score, digits)
        return score
    
    def get_conf_mat(self):
        return self.conf_mat_fn.compute().numpy().astype(int)

    def reset(self):
        for m in self.metrics_fn.keys():
            self.metrics_fn[m].reset()
        self.conf_mat_fn.reset()
    
    def to_prediction(self, predictions):
        if self.criterion in regression_based_metrics:
            predictions = torch.tensor([self.classify(p.item()) for p in predictions]).long()
        else:
            predictions = torch.argmax(predictions, dim=1).long()

        return predictions

    def classify(self, predict):
        thresholds = self.thresholds
        predict = max(predict, thresholds[0])
        for i in reversed(range(len(thresholds))):
            if predict >= thresholds[i]:
                return i

The Estimator class is designed to manage and compute various metrics for evaluating machine learning models, particularly those used in classification tasks. It supports multiclass classification and, through thresholds, models trained with a regression criterion. Here is a detailed breakdown of the class and its functionalities:<br>

**Constructor: __init__(self, metrics, num_classes, criterion, average='macro', thresholds=None)**<br>
Purpose: Initializes the estimator with specific metrics, the number of classes, the type of criterion, and other optional settings.<br>

Parameters:
* metrics: A list of metric names that the estimator should compute. For this problem, these are AUC, F1, accuracy, precision, and recall.
* num_classes: The total number of classes in the classification problem.
* criterion: The criterion used for training the model. For this problem, it is cross_entropy.
* average: Specifies the averaging method for multiclass metrics (macro by default).
* thresholds: Custom thresholds for mapping numerical outputs to classes, used with regression criteria. By default they are placed halfway between consecutive class indices (-0.5, 0.5, 1.5, ...).<br>

**1. update(self, predictions, targets)**<br>
Purpose: Updates the metric calculations based on the latest batch of predictions and actual targets.<br>

Parameters:
* predictions: These are typically the raw output from the model, which can be logits (pre-activation outputs) for a classification task.
* targets: These are the ground truth labels against which the model’s outputs are compared.

<br>Role: The method first moves the targets and predictions to the CPU, converting the targets to long integers with targets.data.cpu().long(). The raw predictions are then converted into class labels using the to_prediction method. Finally, the confusion matrix and every configured metric are updated. Metrics that need class scores, such as AUC, receive the raw logits, while the others receive the predicted labels.<br>

**2. get_scores(self, digits=-1)**<br>
Purpose: Computes the current score for every metric and returns a dictionary of metric names and their corresponding scores. Scores are rounded to digits decimal places, or left unrounded when digits is -1.<br>

**3. get_conf_mat(self)**<br>
Purpose: Computes and returns the current state of the confusion matrix.<br>

**4. reset(self)**<br>
Purpose: Resets all metrics and the confusion matrix to start fresh, typically used between training epochs or different phases of model evaluation.<br>

**5. to_prediction(self, predictions)**<br>
Purpose: Converts raw model outputs (logits or regression outputs) into discrete class predictions. For classification tasks, it takes the argmax of the logits to determine the class. For regression criteria, it maps each scalar output to a class with the classify method, which returns the index of the highest threshold that the output reaches.
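The threshold rule used for regression outputs can be tried on its own. This sketch lifts the logic of classify into a standalone function (taking the thresholds as an argument is a change made for illustration) and applies it to a few made-up outputs.

```python
def classify(predict, thresholds):
    # Clamp from below so that very small outputs still fall into class 0.
    predict = max(predict, thresholds[0])
    # Walk the thresholds from highest to lowest and return the first one reached.
    for i in reversed(range(len(thresholds))):
        if predict >= thresholds[i]:
            return i


num_classes = 3
thresholds = [-0.5 + i for i in range(num_classes)]  # [-0.5, 0.5, 1.5]

outputs = [-3.0, 0.2, 0.7, 1.49, 5.0]  # made-up regression outputs
classes = [classify(p, thresholds) for p in outputs]
print(classes)  # [0, 0, 1, 1, 2]
```

With the default thresholds this amounts to rounding the output to the nearest class index and clipping it to the valid range.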


# cross entropy, AUC, F1, ACCURACY, Precision and Recall. Argmax

In [10]:
# train
model = generate_model(cfg)
train_dataset, test_dataset, val_dataset = generate_dataset(cfg)
estimator = Estimator(cfg.train.metrics, cfg.data.num_classes, cfg.train.criterion)

Calculating mean and std of training set for data normalization.


100%|██████████| 4/4 [00:08<00:00,  2.24s/it]
100%|██████████| 4/4 [00:08<00:00,  2.17s/it]

mean: [0.08801376819610596, 0.08801379799842834, 0.08801349997520447]
std: [0.16029858589172363, 0.16029863059520721, 0.16029848158359528]
Dataset Loaded.
Categories:	3
Training:	241
Validation:	59
Test:		118





**Configuration and Initialization:**
* **Configuration Setup**: Using Hydra, a configuration management library, the system initializes settings from configuration files stored in a directory. The absolute configuration path is converted to a relative one because Hydra's initialize expects a relative path, and the resulting configuration object (cfg) then drives every later step. Keeping all settings in one place makes runs consistent across environments and makes experiments easy to vary.
* **Random Seed Setting**: The system sets a random seed to ensure reproducibility of results. This includes seeding the Python random module, NumPy's random number generator, and PyTorch's random functions, as well as configuring CUDA's deterministic mode if required.<br>

**Model Preparation:**
* **Model Generation:** generate_model(cfg) constructs the neural network based on the specifications in cfg, which includes selecting the network architecture, loading checkpoint weights if specified, and moving the model to the designated computing device (e.g., GPU). This step makes it possible to use transfer learning and different architectures for the task.

**Data Management:**
* **Dataset Handling:** The generate_dataset(cfg) function orchestrates the loading and preprocessing of image data from the specified directories. It handles:
  * Calculating dataset-specific statistics (mean and standard deviation) for normalization when they are set to 'auto' in the configuration.
  * Applying the specified data transformations and augmentations to improve training and generalization. This includes resizing, cropping, flipping, and color adjustments.
  * Organizing data into training, validation, and testing datasets using a customized image folder class that loads every image as RGB.<br>

**Performance Monitoring and Evaluation:**
* **Estimator Setup:** An Estimator instance is initialized with configurations for performance metrics such as accuracy, F1-score, and others relevant to classification tasks. This tool is designed to continuously evaluate the model’s performance during training, providing insights into how well the model is learning and predicting across various classes.
* **Metrics Management:** The Estimator is capable of handling both traditional classification metrics and specific adjustments for tasks that might output or interpret results in a regression-like manner (using thresholds to classify outputs). It ensures that all evaluations are appropriately updated and reported, aiding in model tuning and decision-making.<br>

**Training Loop Execution:**
* **Training Execution:** Using the prepared model and datasets, training runs for multiple epochs: each batch is passed through the model, the loss is calculated, and the model weights are updated.
* **Validation and Adjustment:** At configured intervals, the model is evaluated against the validation set to monitor performance. The learning rate scheduler adjusts the learning rate as training proceeds, and the best-scoring weights are saved.
* **Testing and Final Evaluation:** Once training is complete, the model is tested against the unseen test dataset to gauge its generalization capabilities and compute the final performance metrics.


In [11]:
class WarmupLRScheduler():
    """Linearly ramps the learning rate up to initial_lr over warmup_epochs."""
    def __init__(self, optimizer, warmup_epochs, initial_lr):
        self.epoch = 0
        self.optimizer = optimizer
        self.warmup_epochs = warmup_epochs
        self.initial_lr = initial_lr

    def step(self):
        # Use < (not <=) so an extra call after warmup cannot push the
        # learning rate above initial_lr.
        if self.epoch < self.warmup_epochs:
            self.epoch += 1
            curr_lr = (self.epoch / self.warmup_epochs) * self.initial_lr
            for param_group in self.optimizer.param_groups:
                param_group['lr'] = curr_lr

    def is_finish(self):
        return self.epoch >= self.warmup_epochs


class ScheduledWeightedSampler(Sampler):
    def __init__(self, dataset, decay_rate):
        self.dataset = dataset
        self.decay_rate = decay_rate

        self.num_samples = len(dataset)
        self.targets = [sample[1] for sample in dataset.imgs]
        self.class_weights = self.cal_class_weights()

        self.epoch = 0
        self.w0 = torch.as_tensor(self.class_weights, dtype=torch.double)
        self.wf = torch.as_tensor([1] * len(self.dataset.classes), dtype=torch.double)
        self.sample_weight = torch.zeros(self.num_samples, dtype=torch.double)
        for i, _class in enumerate(self.targets):
            self.sample_weight[i] = self.w0[_class]

    def step(self):
        if self.decay_rate < 1:
            self.epoch += 1
            factor = self.decay_rate**(self.epoch - 1)
            self.weights = factor * self.w0 + (1 - factor) * self.wf
            for i, _class in enumerate(self.targets):
                self.sample_weight[i] = self.weights[_class]

    def __iter__(self):
        return iter(torch.multinomial(self.sample_weight, self.num_samples, replacement=True).tolist())

    def __len__(self):
        return self.num_samples

    def cal_class_weights(self):
        num_classes = len(self.dataset.classes)
        classes_idx = list(range(num_classes))
        class_count = [self.targets.count(i) for i in classes_idx]
        weights = [self.num_samples / class_count[i] for i in classes_idx]
        min_weight = min(weights)
        class_weights = [weights[i] / min_weight for i in classes_idx]
        return class_weights


class LossWeightsScheduler():
    def __init__(self, dataset, decay_rate):
        self.dataset = dataset
        self.decay_rate = decay_rate


        self.num_samples = len(dataset)
        self.targets = [sample[1] for sample in dataset.imgs]
        self.class_weights = self.cal_class_weights()

        self.epoch = 0
        self.w0 = torch.as_tensor(self.class_weights, dtype=torch.float32)
        self.wf = torch.as_tensor([1] * len(self.dataset.classes), dtype=torch.float32)

    def step(self):
        weights = self.w0
        if self.decay_rate < 1:
            self.epoch += 1
            factor = self.decay_rate**(self.epoch - 1)
            weights = factor * self.w0 + (1 - factor) * self.wf
        return weights

    def __len__(self):
        return self.num_samples

    def cal_class_weights(self):
        num_classes = len(self.dataset.classes)
        classes_idx = list(range(num_classes))
        class_count = [self.targets.count(i) for i in classes_idx]
        weights = [self.num_samples / class_count[i] for i in classes_idx]
        min_weight = min(weights)
        class_weights = [weights[i] / min_weight for i in classes_idx]
        return class_weights


class ClippedCosineAnnealingLR():
    def __init__(self, optimizer, T_max, min_lr):
        self.optimizer = optimizer
        self.scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=T_max)
        self.min_lr = min_lr
        self.finish = False

    def step(self):
        if not self.finish:
            self.scheduler.step()
            curr_lr = self.optimizer.param_groups[0]['lr']
            if curr_lr < self.min_lr:
                for param_group in self.optimizer.param_groups:
                    param_group['lr'] = self.min_lr
                self.finish = True

    def is_finish(self):
        return self.finish


The classes above implement the scheduling logic for learning rates, sampling strategies, and loss weights. Let’s break down each class to understand its purpose, role, and parameters:

### Learning Rate Scheduler:<br>

In machine learning, particularly in training deep neural networks, a learning rate scheduler adjusts the learning rate during training according to a pre-defined strategy. It systematically modifies the learning rate based on certain criteria, which could include the number of epochs completed, the rate of error reduction, or performance metrics on a validation set.

### Warmup in Learning Rate Scheduling<br>
A warmup scheduler is a type of learning rate scheduler that initially starts with a smaller learning rate and gradually increases it to a pre-defined target rate over a certain number of epochs or steps.

**1. WarmupLRScheduler:**<br>
Purpose: Linearly increases the learning rate over a given number of epochs, from initial_lr / warmup_epochs on the first step up to the specified initial learning rate on the last. This approach helps stabilize the model's training early on, preventing large gradient updates that could destabilize the optimizer.

Parameters:

* optimizer: The optimizer associated with the model, whose learning rate will be adjusted.
* warmup_epochs: The number of epochs over which the learning rate will increase.
* initial_lr: The target learning rate after the warmup period.

Role:
* On each step(), the learning rate for each parameter group in the optimizer is adjusted based on the current epoch, increasing until it reaches the initial learning rate after the specified number of warmup epochs.

* is_finish(): Checks if the warmup period is complete.
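
The linear ramp can be written out directly. This is a minimal sketch of the arithmetic only (no optimizer involved); the function name and the example values are illustrative assumptions.

```python
def warmup_lrs(initial_lr, warmup_epochs):
    """Learning rate set by each of the warmup_epochs calls to step()."""
    return [
        (epoch / warmup_epochs) * initial_lr
        for epoch in range(1, warmup_epochs + 1)
    ]

# Hypothetical settings: ramp up to 0.001 over 5 epochs.
lrs = warmup_lrs(initial_lr=0.001, warmup_epochs=5)
for epoch, lr in enumerate(lrs, start=1):
    print(f"epoch {epoch}: lr = {lr:.6f}")
```

The first epoch trains at one fifth of the target rate and the fifth epoch reaches the target, after which the main learning rate scheduler (if any) takes over.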

**2. ScheduledWeightedSampler:**<br>
Purpose: To adjust the sampling of training data according to class weights that change over time. Sampling starts class-balanced (rare classes are oversampled) and, when decay_rate is below 1, gradually moves toward the dataset's natural class distribution. This helps in dealing with imbalanced datasets.

Parameters:
* dataset: The dataset from which samples are drawn.
* decay_rate: A rate at which the influence of initial class weights decreases over time.

Role:
* Initialization: It starts by calculating initial weights for each class based on their frequency in the dataset. Classes that appear less frequently get higher weights, making it more likely for samples from these classes to be chosen during training. This helps to compensate for their lower natural frequency. It stores these weights and prepares to adjust them over time.
* Weight Adjustment Over Time: The sampler uses decay_rate to gradually decrease the influence of these initial weights. The idea is to start training with a strong correction for the imbalance by sampling rarer classes more frequently. As training progresses, the class weights shift toward uniform weights of 1, at which point every sample (rather than every class) has the same chance of being drawn, so classes appear in proportion to their natural frequency. This transition is controlled by decay_rate, a value between 0 and 1: smaller values shift faster, and a value of 1 keeps the initial class-balanced weights throughout.
* On each call to step() (once per epoch), it recalculates the class weights as factor * w0 + (1 - factor) * wf, where factor = decay_rate ** (epoch - 1), w0 holds the initial weights, and wf is all ones.
* Each time the sampler is iterated (once per epoch), it draws num_samples indices with replacement using these adjusted weights. The selection is weighted, meaning samples from classes with higher weights are more likely to be chosen.
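
The weight calculation and its decay can be sketched in plain Python. The class counts and decay rate below are hypothetical, and the helper names are chosen for illustration; the arithmetic follows cal_class_weights() and step() above.

```python
def class_weights(class_counts):
    """Inverse-frequency weights, rescaled so the most frequent class gets 1."""
    total = sum(class_counts)
    raw = [total / count for count in class_counts]
    smallest = min(raw)
    return [w / smallest for w in raw]

def decayed_weights(w0, decay_rate, epoch):
    """Blend the initial weights with uniform weights of 1 (epoch is 1-based)."""
    factor = decay_rate ** (epoch - 1)
    return [factor * w + (1 - factor) * 1.0 for w in w0]

# Hypothetical dataset: three classes with 100, 50 and 10 images.
w0 = class_weights([100, 50, 10])
for epoch in (1, 2, 20):
    weights = decayed_weights(w0, decay_rate=0.9, epoch=epoch)
    print(epoch, [round(w, 3) for w in weights])
```

With these counts the rarest class starts with roughly ten times the weight of the most frequent one, and the gap narrows every epoch as the factor shrinks.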

**3. LossWeightsScheduler:**<br>
Purpose: Similar to ScheduledWeightedSampler, but instead of adjusting sample weights, it adjusts the weights for the loss function to focus on certain classes more than others across training epochs. Rare classes are assigned higher weights, and frequent classes are given lower weights. 

Parameters:
* dataset: The dataset being used, including class information.
* decay_rate: The rate at which initial class weights decay to uniform.

Role:
step() computes the loss weights for the current epoch, which can then be used in a weighted loss function to help the model focus on underrepresented classes.
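
To make the effect of class weights on the loss concrete, here is a minimal pure-Python sketch of a weighted cross-entropy. It mirrors how nn.CrossEntropyLoss combines a weight tensor with its default mean reduction (a weighted average over the batch); the logits and weights are made-up values.

```python
import math

def weighted_cross_entropy(logits, targets, class_weights):
    """Weighted mean cross-entropy over a batch.

    Each sample's loss is scaled by the weight of its target class, and the
    sum is divided by the total weight of the targets in the batch.
    """
    total_loss, total_weight = 0.0, 0.0
    for row, target in zip(logits, targets):
        # Numerically stable log-sum-exp of the logits.
        largest = max(row)
        log_sum_exp = largest + math.log(sum(math.exp(v - largest) for v in row))
        sample_loss = log_sum_exp - row[target]
        total_loss += class_weights[target] * sample_loss
        total_weight += class_weights[target]
    return total_loss / total_weight

# Two samples, both predicted as class 0; the second one actually belongs to class 1.
logits = [[2.0, 0.0], [2.0, 0.0]]
targets = [0, 1]
unweighted = weighted_cross_entropy(logits, targets, [1.0, 1.0])
weighted = weighted_cross_entropy(logits, targets, [1.0, 9.0])
```

Giving class 1 a larger weight makes the mistake on the second sample dominate the batch loss, which is the mechanism by which rare classes receive more attention.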

**4. ClippedCosineAnnealingLR:**<br>
Purpose: The ClippedCosineAnnealingLR class provides a strategy to manage the learning rate during the training of a machine learning model, specifically adapting the learning rate according to a cosine annealing schedule but ensuring it doesn’t fall below a certain minimum threshold.

Parameters:
* optimizer: The optimizer being used.
* T_max: The number of scheduler steps (epochs here) over which the learning rate anneals from its initial value down to the minimum of the cosine curve.
* min_lr: The minimum learning rate to which the learning rate can decay.

Role:
On each step(), the learning rate is adjusted using the internal scheduler. If it falls below min_lr, it is clipped to min_lr and the scheduler stops making further changes. This prevents the learning rate from becoming too low, which could stall the training process.<br>
A learning rate that's too high can cause unstable training or convergence to suboptimal solutions, while a rate that’s too low can make convergence very slow. The ClippedCosineAnnealingLR class decays the learning rate along a cosine curve while preventing it from dropping below min_lr, which helps balance exploration of the solution space early in training against fine-tuning toward the end.<br>

**Cosine Annealing:**<br>
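Cosine annealing modifies the learning rate following a cosine curve over a predefined number of epochs. The curve starts at the initial learning rate and decreases smoothly to a lower rate by the end of the schedule (T_max steps). Variants with warm restarts reset to the higher rate at the start of each new cycle; the torch.optim.lr_scheduler.CosineAnnealingLR scheduler wrapped here is the variant without warm restarts.<br>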

**Clipping the Learning Rate:**<br>
The 'clipping' feature of this scheduler ensures that the learning rate never falls below a specified minimum threshold (min_lr). Even as the cosine annealing formula decreases the learning rate, this class checks if the rate has dropped below the minimum and, if so, sets it back to this minimum value. This is crucial because a very low learning rate can effectively freeze the model’s training, preventing significant updates to the weights and thus halting improvement.
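
The combined behaviour can be sketched with the closed-form cosine curve. This sketch assumes the annealing floor of the wrapped scheduler is 0 (the default eta_min of CosineAnnealingLR) and that nothing else modifies the learning rate; the function name and numbers are illustrative.

```python
import math

def clipped_cosine_schedule(base_lr, t_max, min_lr):
    """Learning rate after each step, stopping once the schedule is clipped."""
    lrs = []
    for t in range(1, t_max + 1):
        # Closed-form cosine annealing from base_lr toward 0 over t_max steps.
        lr = 0.5 * base_lr * (1 + math.cos(math.pi * t / t_max))
        if lr < min_lr:
            lrs.append(min_lr)  # clip and freeze, like finish = True
            break
        lrs.append(lr)
    return lrs

# Hypothetical settings: start at 0.1, anneal over 10 epochs, never go below 0.01.
schedule = clipped_cosine_schedule(base_lr=0.1, t_max=10, min_lr=0.01)
for step, lr in enumerate(schedule, start=1):
    print(f"step {step}: lr = {lr:.5f}")
```

With these settings the curve would cross 0.01 at the eighth step, so the learning rate is held at 0.01 from that point on.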



In [12]:
def initialize_optimizer(cfg, model):
    optimizer_strategy = cfg.solver.optimizer
    learning_rate = cfg.solver.learning_rate
    weight_decay = cfg.solver.weight_decay
    momentum = cfg.solver.momentum
    nesterov = cfg.solver.nesterov
    adamw_betas = cfg.solver.adamw_betas

    if optimizer_strategy == 'SGD':
        optimizer = torch.optim.SGD(
            model.parameters(),
            lr=learning_rate,
            momentum=momentum,
            nesterov=nesterov,
            weight_decay=weight_decay
        )
    elif optimizer_strategy == 'ADAM':
        optimizer = torch.optim.Adam(
            model.parameters(),
            lr=learning_rate,
            weight_decay=weight_decay
        )
    elif optimizer_strategy == 'ADAMW':
        optimizer = torch.optim.AdamW(
            model.parameters(),
            lr=learning_rate,
            betas=adamw_betas,
            weight_decay=weight_decay
        )
    else:
        raise NotImplementedError('Not implemented optimizer.')

    return optimizer

def initialize_sampler(cfg, train_dataset, val_dataset):
    sampling_strategy = cfg.data.sampling_strategy
    val_sampler = None
    if sampling_strategy == 'class_balanced':
        train_sampler = ScheduledWeightedSampler(train_dataset, 1)
    elif sampling_strategy == 'progressively_balanced':
        train_sampler = ScheduledWeightedSampler(train_dataset, cfg.data.sampling_weights_decay_rate)
    elif sampling_strategy == 'instance_balanced':
        train_sampler = None
    else:
        raise NotImplementedError('Not implemented resampling strategy.')

    return train_sampler, val_sampler

def initialize_lr_scheduler(cfg, optimizer):
    warmup_epochs = cfg.train.warmup_epochs
    learning_rate = cfg.solver.learning_rate
    scheduler_strategy = cfg.solver.lr_scheduler

    if not scheduler_strategy:
        lr_scheduler = None
    else:
        scheduler_args = cfg.scheduler_args[scheduler_strategy]
        if scheduler_strategy == 'cosine':
            lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, **scheduler_args)
        elif scheduler_strategy == 'multiple_steps':
            lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, **scheduler_args)
        elif scheduler_strategy == 'reduce_on_plateau':
            lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, **scheduler_args)
        elif scheduler_strategy == 'exponential':
            lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, **scheduler_args)
        elif scheduler_strategy == 'clipped_cosine':
            lr_scheduler = ClippedCosineAnnealingLR(optimizer, **scheduler_args)
        else:
            raise NotImplementedError('Not implemented learning rate scheduler.')

    if warmup_epochs > 0:
        warmup_scheduler = WarmupLRScheduler(optimizer, warmup_epochs, learning_rate)
    else:
        warmup_scheduler = None

    return lr_scheduler, warmup_scheduler

class WarpedLoss():
    def __init__(self, loss_function, criterion):
        self.loss_function = loss_function
        self.criterion = criterion

        # squeeze predictions only for the regression criteria
        self.squeeze = self.criterion in regression_loss

    def __call__(self, pred, target):
        if self.squeeze:
            pred = pred.squeeze()

        return self.loss_function(pred, target)


def initialize_loss(cfg, train_dataset):

    criterion = cfg.train.criterion
    criterion_args = cfg.criterion_args[criterion]

    weight = None
    loss_weight_scheduler = None
    loss_weight = cfg.train.loss_weight
    if criterion == 'cross_entropy':
        if loss_weight == 'balance':
            loss_weight_scheduler = LossWeightsScheduler(train_dataset, 1)
        elif loss_weight == 'dynamic':
            loss_weight_scheduler = LossWeightsScheduler(train_dataset, cfg.train.loss_weight_decay_rate)
        elif isinstance(loss_weight, list):
            assert len(loss_weight) == len(train_dataset.classes)
            weight = torch.as_tensor(loss_weight, dtype=torch.float32, device=cfg.base.device)
        loss = nn.CrossEntropyLoss(weight=weight, **criterion_args)
    elif criterion == 'mean_square_error':
        loss = nn.MSELoss(**criterion_args)
    elif criterion == 'mean_absolute_error':
        loss = nn.L1Loss(**criterion_args)
    elif criterion == 'smooth_L1':
        loss = nn.SmoothL1Loss(**criterion_args)
    elif criterion == 'kappa_loss':
        loss = KappaLoss(**criterion_args)
    elif criterion == 'focal_loss':
        loss = FocalLoss(**criterion_args)
    else:
        raise NotImplementedError('Not implemented loss function.')

    loss_function = WarpedLoss(loss, criterion)
    return loss_function, loss_weight_scheduler

def initialize_dataloader(cfg, train_dataset, val_dataset, train_sampler, val_sampler):
    batch_size = cfg.train.batch_size
    num_workers = cfg.train.num_workers
    pin_memory = cfg.train.pin_memory
    train_loader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=(train_sampler is None),
        sampler=train_sampler,
        num_workers=num_workers,
        drop_last=True,
        pin_memory=pin_memory
    )
    val_loader = DataLoader(
        val_dataset,
        batch_size=batch_size,
        shuffle=(val_sampler is None),
        sampler=val_sampler,
        num_workers=num_workers,
        drop_last=False,
        pin_memory=pin_memory
    )

    return train_loader, val_loader

def save_weights(model, save_path):
    if isinstance(model, nn.DataParallel) or isinstance(model, nn.parallel.DistributedDataParallel):
        state_dict = model.module.state_dict()
    else:
        state_dict = model.state_dict()
    torch.save(state_dict, save_path)

def inverse_normalize(tensor, mean, std):
    for t, m, s in zip(tensor, mean, std):
        t.mul_(s).add_(m)
    return tensor

def select_target_type(y, criterion):
    if criterion in ['cross_entropy', 'kappa_loss']:
        y = y.long()
    elif criterion in ['mean_square_error', 'mean_absolute_error', 'smooth_L1']:
        y = y.float()
    elif criterion in ['focal_loss']:
        y = y.to(dtype=torch.int64)
    else:
        raise NotImplementedError('Not implemented criterion.')
    return y

def train(cfg, model, train_dataset, val_dataset, estimator, logger=None):

    device = cfg.base.device
    optimizer = initialize_optimizer(cfg, model)
    train_sampler, val_sampler = initialize_sampler(cfg, train_dataset, val_dataset)
    lr_scheduler, warmup_scheduler = initialize_lr_scheduler(cfg, optimizer)
    loss_function, loss_weight_scheduler = initialize_loss(cfg, train_dataset)
    train_loader, val_loader = initialize_dataloader(cfg, train_dataset, val_dataset, train_sampler, val_sampler)

    # start training
    model.train()
    avg_loss = 0
    max_indicator = 0
    for epoch in range(1, cfg.train.epochs + 1):
        # resampling weight update
        if train_sampler:
            train_sampler.step()

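        # update loss weights
        if loss_weight_scheduler:
            weight = loss_weight_scheduler.step()
            # loss_function is a WarpedLoss wrapper, so the weight must be set
            # on the wrapped nn.CrossEntropyLoss to take effect
            loss_function.loss_function.weight = weight.to(device)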

        # warmup scheduler update
        if warmup_scheduler and not warmup_scheduler.is_finish():
            warmup_scheduler.step()

        epoch_loss = 0
        estimator.reset()
        progress = tqdm(enumerate(train_loader), total=len(train_loader)) if cfg.base.progress else enumerate(train_loader)
        for step, train_data in progress:
            X, y = train_data
            X = X.to(device)
            y = y.to(device)
            y = select_target_type(y, cfg.train.criterion)

            # forward
            y_pred = model(X)
            loss = loss_function(y_pred, y)

            # backward
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # metrics
            epoch_loss += loss.item()
            avg_loss = epoch_loss / (step + 1)

            estimator.update(y_pred, y)
            message = 'epoch: [{} / {}], loss: {:.6f}'.format(epoch, cfg.train.epochs, avg_loss)
            if cfg.base.progress:
                progress.set_description(message)

        if not cfg.base.progress:
            print(message)

        train_scores = estimator.get_scores(4)
        scores_txt = ', '.join(['{}: {}'.format(metric, score) for metric, score in train_scores.items()])
        print('Training metrics:', scores_txt)

        curr_lr = optimizer.param_groups[0]['lr']
        if logger:
            for metric, score in train_scores.items():
                logger.add_scalar('training {}'.format(metric), score, epoch)
            logger.add_scalar('training loss', avg_loss, epoch)
            logger.add_scalar('learning rate', curr_lr, epoch)

        if cfg.train.sample_view and logger:
            samples = torchvision.utils.make_grid(X)
            samples = inverse_normalize(samples, cfg.data.mean, cfg.data.std)
            logger.add_image('input samples', samples, epoch, dataformats='CHW')

        # validation performance
        if epoch % cfg.train.eval_interval == 0:
            eval(cfg, model, val_loader, cfg.train.criterion, estimator, device)
            val_scores = estimator.get_scores(6)
            scores_txt = ['{}: {}'.format(metric, score) for metric, score in val_scores.items()]
            print_msg('Validation metrics:', scores_txt)
            if logger:
                for metric, score in val_scores.items():
                    logger.add_scalar('validation {}'.format(metric), score, epoch)

            # save model
            indicator = val_scores[cfg.train.indicator]
            if indicator > max_indicator:
                save_weights(model, os.path.join(cfg.base.save_path, 'best_validation_weights.pt'))
                max_indicator = indicator
                print_msg('Best {} in validation set. Model save at {}'.format(cfg.train.indicator, cfg.base.save_path))

        if epoch % cfg.train.save_interval == 0:
            save_weights(model, os.path.join(cfg.base.save_path, 'epoch_{}.pt'.format(epoch)))

        # update learning rate
        if lr_scheduler and (not warmup_scheduler or warmup_scheduler.is_finish()):
            if cfg.solver.lr_scheduler == 'reduce_on_plateau':
                lr_scheduler.step(avg_loss)
            else:
                lr_scheduler.step()

    # save final model
    save_weights(model, os.path.join(cfg.base.save_path, 'final_weights.pt'))

    if logger:
        logger.close()

These components work together to provide a comprehensive, configurable, and adaptable framework for training a CNN on OCT images, with considerations for imbalanced data, dynamic learning rate adjustments, and robust loss management to ensure optimal training performance.

**1. initialize_optimizer:**<br>
Purpose: Configures and initializes the optimizer used for training the model.<br>
Parameters:
* cfg: Configuration object containing optimizer settings.
* model: The neural network model for which the optimizer is being set.

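Role: This function selects the optimizer based on cfg.solver.optimizer. It supports SGD (with momentum and optional Nesterov momentum), ADAM, and ADAMW (with configurable betas), each initialized with the configured learning rate and weight decay. Any other value raises NotImplementedError.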

**2. initialize_sampler:**<br>
Purpose: Sets up sampling strategies for the training dataset.<br>
Parameters:
* cfg: Configuration containing sampling strategy information.
* train_dataset: The dataset used for training.
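* val_dataset: The dataset used for validation (though not directly used here).

Role: Depending on cfg.data.sampling_strategy, this function initializes a sampler to handle class imbalance. 'class_balanced' uses ScheduledWeightedSampler() with a decay rate of 1 (fixed class-balanced weights), 'progressively_balanced' uses the configured decay rate so the weights change across training epochs, and 'instance_balanced' uses no sampler (plain shuffling). The validation sampler is always None.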

**3. initialize_lr_scheduler:**<br>
Purpose: Configures and creates a learning rate scheduler.<br>
Parameters:
* cfg: Configuration containing scheduler type and parameters.
* optimizer: The optimizer to which the scheduler will be applied.

Role: This function supports several learning rate schedulers ('cosine', 'multiple_steps', 'reduce_on_plateau', 'exponential', and 'clipped_cosine'), allowing dynamic adjustment of the learning rate based on training progress. When cfg.train.warmup_epochs is greater than 0, it also creates a WarmupLRScheduler() to gradually ramp up the learning rate.

**4. WarpedLoss:**<br>
Purpose: A wrapper class that adapts the model's predictions to the shape the chosen loss function expects.<br>
Parameters:
* loss_function: The base loss function (e.g., cross-entropy loss).
* criterion: The name of the loss type (e.g., 'cross_entropy'), used to decide whether the predictions need reshaping.

Role: For regression criteria (those listed in regression_loss), this class squeezes the predictions before passing them and the targets to the wrapped loss function, typically turning predictions of shape (N, 1) into (N,) so that they match the targets. For all other criteria the predictions are passed through unchanged.

**5. initialize_loss:**<br>
Purpose: Initializes the loss function and, optionally, a loss weight scheduler.<br>
Parameters:
* cfg: Configuration that specifies the type of loss and any special considerations like dynamic weights.
* train_dataset: Dataset to potentially use for dynamic loss weighting.

Role: Sets up the loss function according to cfg.train.criterion. For cross-entropy, cfg.train.loss_weight can request class-balanced weights ('balance') or weights that decay toward uniform ('dynamic'), both handled by LossWeightsScheduler(), or it can supply a fixed list of per-class weights.

**6. initialize_dataloader**<br>
Purpose: Sets up data loaders for training and validation datasets.<br>
Parameters:
* cfg: Training configuration including batch size and worker details.
* train_dataset: Training dataset.
* val_dataset: Validation dataset.
* train_sampler: Sampler for training dataset.
* val_sampler: Sampler for validation dataset (if any).

Role: Configures DataLoader() with appropriate batch size, samplers, and other parameters to efficiently load data for both training and validation phases.

**DataLoader()**<br>
In PyTorch, the DataLoader is a versatile utility that automates the process of loading, shuffling, and batching data for training or inference.

**7. save_weights:**<br>
Purpose: Saves the model's weights to a specified path.<br>
Parameters:
* model: The model whose weights are to be saved.
* save_path: File path where the weights should be saved.

Role: Handles saving of model weights, considering whether the model is wrapped in any parallel data processing wrappers like DataParallel.

**8. inverse_normalize:**<br>
Purpose: Applies inverse normalization to images (typically used for visualizing or processing outputs).<br>
Parameters:
* tensor: Image data tensor.
* mean: Mean used for normalization.
* std: Standard deviation used for normalization.

Role: Reverses the normalization process to bring images back to their original value range, useful for visualization or post-processing.
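
The per-channel arithmetic can be checked with a small round trip. This is a sketch on nested Python lists rather than tensors, with made-up mean and standard deviation values; unlike the in-place tensor version above, these hypothetical helpers return new lists.

```python
def normalize_channels(channels, mean, std):
    """Apply (x - mean) / std to every value of each channel."""
    return [[(v - m) / s for v in ch] for ch, m, s in zip(channels, mean, std)]

def denormalize_channels(channels, mean, std):
    """Undo the normalization: x * std + mean, channel by channel."""
    return [[v * s + m for v in ch] for ch, m, s in zip(channels, mean, std)]

# Hypothetical 3-channel "image" with two pixels per channel.
image = [[0.2, 0.8], [0.5, 0.5], [0.0, 1.0]]
mean = [0.4, 0.5, 0.6]
std = [0.2, 0.25, 0.3]

restored = denormalize_channels(normalize_channels(image, mean, std), mean, std)
```

Up to floating-point rounding, restored holds the original pixel values, which is why the de-normalized grid logged during training looks like the input images.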

**9. select_target_type:**<br>
Purpose: Adjusts the data type of targets based on the loss criterion.<br>
Parameters:
* y: Target labels.
* criterion: Specifies the loss function type.

Role: Ensures that target data types are compatible with the expectations of different loss functions, like converting to long integers for classification losses.

**10. train:**<br>
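Purpose: Orchestrates the entire training process. The call to model.train() only switches the model into training mode; the training loop itself is implemented in this function.<br>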
Parameters:
* cfg: All-encompassing configuration object.
* model: The CNN model to be trained.
* train_dataset, val_dataset: Datasets for training and validation.
* estimator: An object to estimate and report training metrics.
* logger: (Optional) Logger for recording training progress.

Role: Manages the training loop, including updating samplers, loss weights, learning rates, and logging training metrics. Also, it periodically evaluates the model on the validation set and saves model checkpoints.
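
The order in which the schedulers act within each epoch can be laid out with a small dry run. This sketch assumes a learning rate scheduler is configured and leaves out the sampler and loss-weight updates (which run at the start of every epoch); the function name and the interval values are illustrative.

```python
def epoch_plan(epochs, warmup_epochs, eval_interval, save_interval):
    """List the scheduling actions taken in each epoch (1-based)."""
    plan = {}
    for epoch in range(1, epochs + 1):
        actions = []
        if epoch <= warmup_epochs:
            actions.append('warmup_step')        # start of the epoch
        actions.append('train_batches')
        if epoch % eval_interval == 0:
            actions.append('validate')
        if epoch % save_interval == 0:
            actions.append('save_checkpoint')
        if epoch >= warmup_epochs:
            actions.append('lr_scheduler_step')  # end of the epoch, warmup finished
        plan[epoch] = actions
    return plan

# Hypothetical run: 4 epochs, 2 warmup epochs, validate every 2, checkpoint every 4.
plan = epoch_plan(epochs=4, warmup_epochs=2, eval_interval=2, save_interval=4)
for epoch, actions in plan.items():
    print(epoch, actions)
```

Note that the main scheduler takes its first step at the end of the last warmup epoch, because the warmup scheduler already reports itself as finished by then.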






In [13]:
train(
    cfg=cfg,
    model=model,
    train_dataset=train_dataset,
    val_dataset=val_dataset,
    estimator=estimator,
    logger=logger
)

  0%|          | 0/3 [00:05<?, ?it/s]


RuntimeError: DataLoader worker (pid(s) 24260, 15580, 11372, 14504, 8136, 19232, 13908, 12000) exited unexpectedly

The train function orchestrates the comprehensive training process of a convolutional neural network model on optical coherence tomography (OCT) image data. This function manages various aspects of the training cycle, including data loading, dynamic scheduling of learning rates and sampling, application of loss functions, logging, and validation.

**Parameters:**
* cfg: Configuration object containing all settings related to the model, training process, optimizer, scheduler, etc.
* model: The neural network model that will be trained.
* train_dataset: Dataset used for training the model.
* val_dataset: Dataset used for validating the model's performance.
* estimator: Tool for calculating and storing metrics during training.
* logger: Optional logging tool for tracking training progress and metrics.

**Process Overview:**
* Initialization: Initializes the optimizer and learning rate scheduler as specified in cfg. Configures data loaders for both training and validation datasets, using custom sampling strategies to address issues like class imbalance. Sets up the loss function, with dynamic weighting to handle class imbalances effectively.

* Training Loop: Runs for the specified number of epochs. In each epoch it updates the sampling weights for the training data, adjusts the class weights used in the loss function, applies a warm-up step to the learning rate if configured, processes the batches (forward propagation, loss computation, backpropagation, optimizer step), logs the training metrics, and steps the learning rate scheduler. At the configured interval it also validates the model on held-out data and saves the weights whenever the monitored validation metric improves.

* Post-Training: Saves the final model weights.


In [None]:
def eval(cfg, model, dataloader, criterion, estimator, device):
    model.eval()
    torch.set_grad_enabled(False)

    estimator.reset()
    for test_data in dataloader:
        X, y = test_data
        X = X.to(device)
        y = y.to(device)
        y = select_target_type(y, criterion)
        y_pred = model(X)
        estimator.update(y_pred, y)

    model.train()
    torch.set_grad_enabled(True)

def evaluate(cfg, model, test_dataset, estimator):
    test_sampler = None
    test_loader = DataLoader(
        test_dataset,
        shuffle=(test_sampler is None),
        sampler=test_sampler,
        batch_size=cfg.train.batch_size,
        num_workers=cfg.train.num_workers,
        pin_memory=cfg.train.pin_memory
    )

    print('Running on Test set...')
    eval(cfg, model, test_loader, cfg.train.criterion, estimator, cfg.base.device)

    print('================Finished================')
    test_scores = estimator.get_scores(6)
    for metric, score in test_scores.items():
        print('{}: {}'.format(metric, score))
    print('Confusion Matrix:')
    print(estimator.get_conf_mat())
    print('========================================')

The eval and evaluate functions are essential parts of the validation and testing process. These functions are designed to assess the performance of the model on a test or validation set, ensuring that it generalizes well beyond the training data. Below is a detailed explanation of each function, their parameters, purpose, and role in the model's evaluation process:<br>

**1. eval Function:**<br>
Purpose: Model Evaluation - Conducts a pass through the validation or test dataset to evaluate the model's performance. This function is critical for assessing the model under a non-training condition where no learning is taking place.<br>

Parameters:
* cfg: Configuration object containing all necessary settings.
* model: The CNN model being evaluated.
* dataloader: A DataLoader object that provides batches of data from the test or validation set.
* criterion: Specifies the loss function used during training, which influences how the data should be prepared or processed.
* estimator: An object used to calculate and record various performance metrics.
* device: Specifies the computing device (e.g., CPU or CUDA GPU) where the evaluation should run.

Role:
* Mode Setting: Switches the model to eval mode (model.eval()), which disables dropout and makes batch normalization layers use their running statistics instead of per-batch statistics.
* Gradient Handling: Disables gradient computation (torch.set_grad_enabled(False)), reducing memory usage and speeding up computation, as gradients are not needed for model evaluation.
* Metric Calculation: Iterates through the dataset using (dataloader), computes predictions, and uses the estimator to update performance metrics based on the predictions and actual labels.
* Reset State: After evaluation, resets the model to training mode and re-enables gradient computation to resume training if needed.

**2. evaluate Function:**<br>
Purpose: Test Evaluation - Specifically designed to handle the evaluation of the model on a test dataset and to print out the performance metrics and confusion matrix. This function sets up the testing environment, loads the test data, and calls the eval function.<br>
Parameters:
* cfg: Configuration object containing settings like batch size and device.
* model: The CNN model to be tested.
* test_dataset: Dataset used for testing the model.
* estimator: Metric calculator for evaluating model performance.

Role: 
* DataLoader Setup: Initializes a DataLoader for the test dataset. The DataLoader handles batching, shuffling (if applicable), and allocation of data for processing, based on the configuration settings.
* Evaluation Process: Calls the eval function to perform the actual evaluation, passing all necessary configurations and objects.
* Performance Output: After evaluation, prints detailed performance metrics and a confusion matrix to provide insights into the model's classification accuracy across different classes.
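
To make the performance output concrete, the sketch below builds a confusion matrix from predicted and true labels and derives overall accuracy and per-class recall from it. The notebook's estimator is not shown in this section, so confusion_matrix here is a hypothetical helper written with plain torch operations, and the labels are made-up toy values.

```python
import torch


def confusion_matrix(targets, preds, num_classes):
    """Rows are true classes, columns are predicted classes."""
    flat = targets * num_classes + preds          # unique index per (true, predicted) pair
    counts = torch.bincount(flat, minlength=num_classes ** 2)
    return counts.reshape(num_classes, num_classes)


# Toy labels for three classes, purely for illustration
targets = torch.tensor([0, 0, 1, 1, 2, 2])
preds = torch.tensor([0, 1, 1, 1, 2, 0])

cm = confusion_matrix(targets, preds, num_classes=3)
accuracy = cm.diagonal().sum().item() / cm.sum().item()
recall_per_class = (cm.diagonal().float() / cm.sum(dim=1).float()).tolist()
print(cm)
```

The diagonal holds the correctly classified samples, so off-diagonal entries show which classes are confused with each other.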

In [None]:
# test
print('Performance of the best validation model:')
checkpoint = os.path.join(cfg.base.save_path, 'best_validation_weights.pt')
cfg.train.checkpoint = checkpoint
model = generate_model(cfg)
evaluate(cfg, model, test_dataset, estimator)


Performance of the best validation model:
Load weights form C:/Users/baner/OneDrive/Desktop/New folder (2)/OCTDL/run\best_validation_weights.pt
Running on Test set...


RuntimeError: DataLoader worker (pid(s) 20188, 17164, 17452, 8472, 20120, 6856, 2552, 16460) exited unexpectedly

The cell above evaluates a trained convolutional neural network (CNN) model on the test dataset. This step assesses how well the model generalizes beyond the training data. In the run shown, the evaluation stopped with a RuntimeError because the DataLoader worker processes exited unexpectedly, so no metrics were printed. Here’s a detailed summary:<br>

**Purpose:**<br>
The cell loads a trained model from a saved checkpoint and then evaluates its performance on the test dataset. This gives an estimate of how the model may perform on unseen OCT (Optical Coherence Tomography) images.<br>

**Process Overview:**
* Checkpoint Retrieval: Constructs a file path to the best model weights saved during validation (best_validation_weights.pt). This checkpoint contains the state of the model that had the best performance on the validation set during training.
* Configuration Update: Sets cfg.train.checkpoint to this path so that the model is built from the checkpoint file.
* Model Initialization: Calls generate_model(cfg), which constructs the model architecture specified in the configuration, loads the model weights from the checkpoint, and transfers the model to the appropriate computational device (e.g., CPU or GPU).
* Model Evaluation: Calls the evaluate function with the model and the test dataset. evaluate prepares a DataLoader that batches the test data and then runs the evaluation.
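
The RuntimeError in the output above comes from the DataLoader's worker processes. One common thing to try in a notebook, especially on Windows, is to load data in the main process by setting num_workers=0. The sketch below shows such a DataLoader on a toy dataset. Whether this resolves the error in this run is an assumption, and how the notebook's configuration exposes the worker count is not shown in this section.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Toy stand-in for test_dataset: 10 random 3-channel 8x8 "images", 7 classes
toy_test = TensorDataset(torch.randn(10, 3, 8, 8), torch.randint(0, 7, (10,)))

test_loader = DataLoader(
    toy_test,
    batch_size=4,
    shuffle=False,   # sample order does not change test metrics
    num_workers=0,   # load batches in the main process, no worker subprocesses
)

n_seen = sum(y.numel() for _, y in test_loader)
n_batches = len(test_loader)
```

With num_workers=0 loading is slower on large datasets, but any exception raised inside the dataset's loading code surfaces directly instead of as an unexplained worker exit.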

In [None]:

print('Performance of the final model:')
checkpoint = os.path.join(cfg.base.save_path, 'final_weights.pt')
cfg.train.checkpoint = checkpoint
model = generate_model(cfg)
evaluate(cfg, model, test_dataset, estimator)

Performance of the final model:


FileNotFoundError: [Errno 2] No such file or directory: 'C:/Users/baner/OneDrive/Desktop/New folder (2)/OCTDL/run\\final_weights.pt'

**Final Verification:**<br>
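This cell is intended as the last check of the model's learning and generalization abilities: it evaluates the model with the weights saved at the end of training (final_weights.pt) on the test dataset. In the run shown, it failed with a FileNotFoundError because final_weights.pt was not present in the save directory, so no metrics for the final model were produced.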
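
A small guard can turn a missing checkpoint into a clear message instead of a traceback. The sketch below is one possible approach: find_checkpoint is a hypothetical helper, and a temporary directory stands in for cfg.base.save_path.

```python
import os
import tempfile


def find_checkpoint(save_path, name):
    """Return the full path to the checkpoint if the file exists, else None."""
    path = os.path.join(save_path, name)
    return path if os.path.isfile(path) else None


# Demonstration in a temporary directory standing in for cfg.base.save_path
with tempfile.TemporaryDirectory() as run_dir:
    # Create an empty placeholder for the one checkpoint that was saved
    open(os.path.join(run_dir, "best_validation_weights.pt"), "wb").close()

    found = find_checkpoint(run_dir, "best_validation_weights.pt")
    found_name = os.path.basename(found)
    missing = find_checkpoint(run_dir, "final_weights.pt")

if missing is None:
    print("final_weights.pt not found, skipping final-model evaluation")
```

In the notebook, the returned path would be assigned to cfg.train.checkpoint only when it is not None.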