# Pytorch Lightning implementation
In this notebook I implement a CNN model using PyTorch Lightning.
This model is more flexible than the model from `initial_experiments.ipynb`: it exposes more hyperparameters for training sessions.


In [1]:
import torch
import torch.nn as nn
import pytorch_lightning as pl
import torchmetrics as tm

from torch.nn.modules.loss import _Loss


class RetinoCNN(pl.LightningModule):
    """
    A Convolutional Neural Network (CNN) implemented using PyTorch Lightning.

    The network is a stack of ``Conv2d -> activation -> MaxPool2d`` blocks
    (optionally followed by dropout), then a stack of fully connected layers
    ending in a linear layer with ``out_classes`` outputs. ``forward`` returns
    raw logits; the loss is binary cross-entropy computed on those logits
    (``nn.BCEWithLogitsLoss``) and the optimizer is Adam.

    Parameters
    ----------
    conv_layers : int
        The number of convolutional layers.
    fc_layer_sizes : tuple of int
        The sizes of the fully connected layers.
    input_size : torch.Size
        The size of a single input sample without the batch dimension.
        The first entry is used as the number of input channels.
    out_classes : int, optional
        The number of output classes, default is 2.
    initial_filters : int, optional
        The number of filters in the first convolutional layer, default is 32.
        Each subsequent convolutional layer doubles the number of filters.
    hl_kernel_size : int, optional
        The kernel size for the hidden layers, default is 5.
    activation_func : type[nn.Module], optional
        The activation function *class* (not an instance); a fresh instance
        is created after every layer. Default is nn.ReLU.
    max_pool_kernel : int, optional
        The kernel size for max pooling, default is 2.
    dropout_conv : bool, optional
        Whether to apply dropout to the convolutional layers, default is False.
    dropout_fc : bool, optional
        Whether to apply dropout to the fully connected layers, default is False.
    dropout_rate : float, optional
        The dropout rate, default is 0.5.
    initial_learning_rate : float, optional
        The initial learning rate, default is 0.01.
    metrics : dict[str, tm.Metric]|None, optional
        The metrics to use, default is None. If None, default binary metrics
        (accuracy, precision, recall, F1, ROC AUC) will be used.

    Raises
    ------
    ValueError
        If any constructor argument fails validation.
    """
    def __init__(
            self,
            *,
            conv_layers: int,
            fc_layer_sizes: tuple[int, ...],
            input_size: torch.Size,
            out_classes: int = 2,
            initial_filters: int = 32,
            hl_kernel_size: int = 5,
            activation_func: type[nn.Module] = nn.ReLU,
            max_pool_kernel: int = 2,
            dropout_conv: bool = False,
            dropout_fc: bool = False,
            dropout_rate: float = 0.5,
            initial_learning_rate: float = 0.01,
            metrics: dict[str, tm.Metric]|None = None
    ) -> None:

        # Validate inputs before calling super().__init__() so that an invalid
        # configuration fails before any module state is created.
        self._validate_required_inputs(conv_layers, fc_layer_sizes, input_size)
        self._validate_default_inputs(
            out_classes,
            initial_filters,
            hl_kernel_size,
            activation_func,
            max_pool_kernel,
            dropout_conv,
            dropout_fc,
            dropout_rate,
            initial_learning_rate,
            metrics
        )
        super().__init__()

        # Initialize hyperparameters
        self._initial_learning_rate = initial_learning_rate

        # The network outputs raw logits, so the loss must apply the sigmoid
        # itself; plain BCELoss rejects inputs outside [0, 1].
        self._loss_fn = nn.BCEWithLogitsLoss()

        # Initialize metrics. They are stored in a ModuleDict (not a plain
        # dict) so they are registered as submodules and follow the model
        # when it is moved to another device.
        if metrics is None:
            metrics = {
                "accuracy": tm.Accuracy(task="binary"),
                "precision": tm.Precision(task="binary"),
                "recall": tm.Recall(task="binary"),
                "f1": tm.F1Score(task="binary"),
                "roc_auc": tm.AUROC(task="binary"),
            }
        self._metrics = nn.ModuleDict(metrics)

        # Initialize convolutional layers
        hidden_layers = []
        in_channels = input_size[0]

        for i in range(conv_layers):
            # The number of filters doubles with every convolutional layer.
            out_channels = initial_filters * 2 ** i
            hidden_layers.append(nn.Conv2d(in_channels, out_channels, hl_kernel_size))
            hidden_layers.append(activation_func())
            hidden_layers.append(nn.MaxPool2d(max_pool_kernel))
            in_channels = out_channels
            if dropout_conv:
                hidden_layers.append(nn.Dropout(dropout_rate))

        self._hidden_layers = nn.Sequential(*hidden_layers)

        # Initialize fully connected layers
        in_features = self._get_conv_out_shape(input_size)
        fc_layers = []
        for out_features in fc_layer_sizes:
            fc_layers.append(nn.Linear(in_features, out_features))
            fc_layers.append(activation_func())
            if dropout_fc:
                fc_layers.append(nn.Dropout(dropout_rate))
            in_features = out_features

        # Final projection to the output logits (no activation).
        fc_layers.append(nn.Linear(in_features, out_classes))
        self._fc_layers = nn.Sequential(*fc_layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass of the model.

        Parameters
        ----------
        x : torch.Tensor
            The input tensor (a batch of samples).

        Returns
        -------
        torch.Tensor
            The raw output logits, one row per sample.
        """
        x = self._hidden_layers(x)
        x = torch.flatten(x, start_dim=1)
        x = self._fc_layers(x)
        return x

    def training_step(
            self,
            batch: tuple[torch.Tensor, torch.Tensor],
            batch_idx: int
    ) -> torch.Tensor:
        """
        Training step of the model.

        Parameters
        ----------
        batch : tuple of torch.Tensor
            The ``(inputs, targets)`` pair
        batch_idx : int
            The index of the batch

        Returns
        -------
        torch.Tensor
            The loss
        """
        # No explicit self.train() here: the Lightning Trainer is relied on
        # to switch the module between train and eval mode.
        loss, _, _ = self._shared_step(batch)
        self.log("train_step_loss", loss)
        return loss

    def validation_step(
            self,
            batch: tuple[torch.Tensor, torch.Tensor],
            batch_idx: int
    ) -> torch.Tensor:
        """
        Validation step of the model.

        Parameters
        ----------
        batch : tuple of torch.Tensor
            The ``(inputs, targets)`` pair
        batch_idx : int
            The index of the batch

        Returns
        -------
        torch.Tensor
            The loss
        """
        loss, y_pred, y = self._shared_step(batch)
        self._calculate_metrics(y_pred, y)

        self.log("val_step_loss", loss)
        return loss

    def on_validation_epoch_end(self) -> None:
        """
        Log the metrics at the end of the validation epoch and reset them.
        """
        self._log_metrics("val_ep")
        self._reset_metrics()

    def test_step(
            self,
            batch: tuple[torch.Tensor, torch.Tensor],
            batch_idx: int
    ) -> torch.Tensor:
        """
        Test step of the model.

        Parameters
        ----------
        batch : tuple of torch.Tensor
            The ``(inputs, targets)`` pair
        batch_idx : int
            The index of the batch

        Returns
        -------
        torch.Tensor
            The loss
        """
        loss, y_pred, y = self._shared_step(batch)

        self._calculate_metrics(y_pred, y)
        self.log("test_step_loss", loss)

        return loss

    def on_test_epoch_end(self) -> None:
        """
        Log the metrics at the end of the test epoch and reset them.
        """
        self._log_metrics("test_ep")
        self._reset_metrics()

    def configure_optimizers(self) -> torch.optim.Optimizer:
        """
        Configure the optimizer for the model.

        Returns
        -------
        torch.optim.Optimizer
            The Adam optimizer using the initial learning rate.
        """
        return torch.optim.Adam(self.parameters(), lr=self._initial_learning_rate)

    def _shared_step(
            self,
            batch: tuple[torch.Tensor, torch.Tensor]
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Run the forward pass and compute the loss for one batch.

        Parameters
        ----------
        batch : tuple of torch.Tensor
            The ``(inputs, targets)`` pair. Targets must have the same shape
            as the model output.

        Returns
        -------
        tuple of torch.Tensor
            The loss, the raw logits and the targets (on the model's device).
        """
        x, y = batch
        # Moving to the module's device is a no-op when the batch is already
        # there; it keeps manual calls outside a Trainer working as well.
        x = x.to(self.device)
        y = y.to(self.device)
        y_pred = self(x)
        # The loss requires floating point targets.
        loss = self._loss_fn(y_pred, y.float())
        return loss, y_pred, y

    def _get_conv_out_shape(self, input_size: torch.Size) -> int:
        """
        Calculate the number of features produced by the convolutional layers.

        Parameters
        ----------
        input_size : torch.Size
            The size of a single input sample (without the batch dimension)

        Returns
        -------
        int
            The number of elements in the convolutional output for one sample
        """
        with torch.no_grad():
            # Use a dummy batch of one sample, so the element count of the
            # output equals the per-sample feature count.
            zeros = torch.zeros(1, *input_size)
            z = self._hidden_layers(zeros)
        return z.numel()

    def _validate_required_inputs(self, conv_layers, fc_layer_sizes, input_size) -> None:
        """Validate inputs with no default values."""

        if not isinstance(conv_layers, int) or conv_layers < 1:
            raise ValueError("conv_layers must be an integer greater than 0.")

        if not isinstance(fc_layer_sizes, tuple) or not all(isinstance(i, int) for i in fc_layer_sizes):
            raise ValueError("fc_layer_sizes must be a tuple of integers.")

        if not isinstance(input_size, torch.Size):
            raise ValueError("input_size must be a torch.Size object.")

    def _validate_default_inputs(self, out_classes, initial_filters, hl_kernel_size, activation_func, max_pool_kernel, dropout_conv, dropout_fc, dropout_rate, initial_learning_rate, metrics) -> None:
        """Validate inputs with default values."""

        if not isinstance(out_classes, int) or out_classes < 1:
            raise ValueError("out_classes must be an integer greater than 0.")

        if not isinstance(initial_filters, int) or initial_filters < 1:
            raise ValueError("initial_filters must be an integer greater than 0.")

        if not isinstance(hl_kernel_size, int) or hl_kernel_size < 1:
            raise ValueError("hl_kernel_size must be an integer greater than 0.")

        # activation_func is instantiated once per layer, so it has to be a
        # Module class (like the default nn.ReLU), not an instance.
        if not (isinstance(activation_func, type) and issubclass(activation_func, nn.Module)):
            raise ValueError("activation_func must be a subclass of torch.nn.Module.")

        if not isinstance(max_pool_kernel, int) or max_pool_kernel < 1:
            raise ValueError("max_pool_kernel must be an integer greater than 0.")

        if not isinstance(dropout_conv, bool):
            raise ValueError("dropout_conv must be a boolean.")

        if not isinstance(dropout_fc, bool):
            raise ValueError("dropout_fc must be a boolean.")

        if not isinstance(dropout_rate, float) or not 0 <= dropout_rate <= 1:
            raise ValueError("dropout_rate must be a float between 0 and 1.")

        if not isinstance(initial_learning_rate, float) or initial_learning_rate <= 0:
            raise ValueError("initial_learning_rate must be a float greater than 0.")

        if metrics is not None and (
            not isinstance(metrics, dict)
            or not all(
                isinstance(name, str) and isinstance(metric, tm.Metric)
                for name, metric in metrics.items()
            )
        ):
            raise ValueError("metrics must be a dictionary of string keys and torchmetrics.Metric objects.")

    def _calculate_metrics(self, y_pred: torch.Tensor, y: torch.Tensor) -> None:
        """
        Accumulate one batch into the state of every metric.

        Parameters
        ----------
        y_pred : torch.Tensor
            The predicted values (raw logits)
        y : torch.Tensor
            The true values
        """
        # Convert logits to probabilities explicitly instead of relying on
        # the metrics to guess whether the inputs are logits.
        probabilities = torch.sigmoid(y_pred.detach())
        # The default binary metrics expect integer labels, while the loss
        # needs floats. NOTE(review): this truncates soft labels - confirm
        # that targets are always hard 0/1 labels.
        labels = y.long()
        for metric in self._metrics.values():
            # update() only accumulates state; the Metric object itself must
            # stay in the dictionary so it can be computed and reset later.
            metric.update(probabilities, labels)

    def _log_metrics(self, prefix: str) -> None:
        """
        Compute every metric over the accumulated batches and log it.

        Parameters
        ----------
        prefix : str
            The prefix for the metric name
        """
        for name, metric in self._metrics.items():
            self.log(f"{prefix}_{name}", metric.compute())

    def _reset_metrics(self) -> None:
        """
        Reset the accumulated state of every metric.
        """
        for metric in self._metrics.values():
            metric.reset()
