# PyTorch Lightning implementation
In this notebook I implement a CNN model using PyTorch Lightning.
This model is more flexible than the model from `initial_experiments.ipynb`, exposing more hyperparameters for training sessions.


In [ ]:
import torch
import torch.nn as nn
import pytorch_lightning as pl
import torchmetrics as tm


class CNN(pl.LightningModule):
    """
    A Convolutional Neural Network (CNN) implemented using PyTorch Lightning.

    The network is a stack of ``conv -> activation -> max-pool [-> dropout]``
    blocks whose filter count doubles at each block, followed by a stack of
    fully connected layers and a final linear output layer (raw logits).

    Parameters
    ----------
    conv_layers : int
        The number of convolutional blocks.
    fc_layer_sizes : tuple of int
        The output sizes of the hidden fully connected layers (may be empty).
    input_size : torch.Size
        The size of a single, unbatched input sample, ``(channels, H, W)``.
    out_classes : int, optional
        The number of output classes, default is 2. With 2 classes binary
        metrics are computed from the softmax probability of class 1;
        otherwise multiclass metrics are used.
    initial_filters : int, optional
        The number of filters in the first convolutional layer, default is 32.
    hl_kernel_size : int, optional
        The kernel size for the convolutional layers, default is 5.
    activation_func : type of nn.Module, optional
        The activation class to instantiate after every layer, default is
        nn.ReLU.
    max_pool_kernel : int, optional
        The kernel size for max pooling, default is 2.
    dropout_conv : bool, optional
        Whether to apply dropout after each convolutional block, default is False.
    dropout_fc : bool, optional
        Whether to apply dropout after each hidden FC layer, default is False.
    dropout_rate : float, optional
        The dropout rate, default is 0.5.
    loss_func : nn.Module instance, nn.Module class, or callable, optional
        The loss, called as ``loss_func(logits, targets)``. A class (such as
        the default nn.CrossEntropyLoss) is instantiated with no arguments.
    optimizer : str, optional
        The optimizer to use, default is "Adam". Can be either "Adam" or "SGD".
    initial_learning_rate : float, optional
        The initial learning rate, default is 0.01.

    Attributes
    ----------
    accuracy : torchmetrics.Accuracy
        Metric to calculate accuracy.
    precision : torchmetrics.Precision
        Metric to calculate precision.
    recall : torchmetrics.Recall
        Metric to calculate recall.
    f1 : torchmetrics.F1Score
        Metric to calculate F1 score.
    roc_auc : torchmetrics.AUROC
        Metric to calculate ROC AUC score.
    hidden_layers : nn.Sequential
        The sequence of convolutional blocks.
    fc_layers : nn.Sequential
        The sequence of fully connected layers.
    """
    def __init__(
            self,
            *,
            conv_layers: int,
            fc_layer_sizes: tuple[int, ...],
            input_size: torch.Size,
            out_classes: int = 2,
            initial_filters: int = 32,
            hl_kernel_size: int = 5,
            activation_func: type[nn.Module] = nn.ReLU,
            max_pool_kernel: int = 2,
            dropout_conv: bool = False,
            dropout_fc: bool = False,
            dropout_rate: float = 0.5,
            loss_func: "nn.Module | type[nn.Module]" = nn.CrossEntropyLoss,
            optimizer: str = "Adam",
            initial_learning_rate: float = 0.01,
    ) -> None:

        super().__init__()

        # Initialize hyperparameters.
        # Calling a loss *class* with (logits, targets) would construct a new
        # module rather than compute a loss, so classes are instantiated here.
        self.loss_func = loss_func() if isinstance(loss_func, type) else loss_func
        self.initial_learning_rate = initial_learning_rate
        if optimizer == "Adam":
            self.optimizer = torch.optim.Adam
        elif optimizer == "SGD":
            self.optimizer = torch.optim.SGD
        else:
            raise ValueError("Invalid optimizer. Use 'Adam' or 'SGD'.")

        # Initialize metrics. Binary metrics expect one score per sample, so
        # for two classes the probability of class 1 is fed to them (see
        # _metric_preds); otherwise multiclass metrics get the full matrix.
        self._binary_metrics = out_classes == 2
        if self._binary_metrics:
            metric_kwargs = {"task": "binary"}
        else:
            metric_kwargs = {"task": "multiclass", "num_classes": out_classes}
        self.accuracy = tm.Accuracy(**metric_kwargs)
        self.precision = tm.Precision(**metric_kwargs)
        self.recall = tm.Recall(**metric_kwargs)
        self.f1 = tm.F1Score(**metric_kwargs)
        self.roc_auc = tm.AUROC(**metric_kwargs)

        # NOTE: device placement is left to Lightning. Assigning self._device
        # here would overwrite the attribute LightningModule uses for its
        # `device` property while the parameters are still on the CPU.

        # Initialize convolutional layers
        hidden_layers = []
        in_channels = input_size[0]

        for i in range(conv_layers):
            out_channels = initial_filters * 2 ** i
            hidden_layers.append(nn.Conv2d(in_channels, out_channels, hl_kernel_size))
            hidden_layers.append(activation_func())
            hidden_layers.append(nn.MaxPool2d(max_pool_kernel))
            if dropout_conv:
                hidden_layers.append(nn.Dropout(dropout_rate))
            in_channels = out_channels

        self.hidden_layers = nn.Sequential(*hidden_layers)

        # Initialize fully connected layers; each layer's input size is the
        # previous layer's output size, starting from the flattened conv output.
        in_features = self._get_conv_out_shape(input_size)
        fc_layers = []
        for out_features in fc_layer_sizes:
            fc_layers.append(nn.Linear(in_features, out_features))
            fc_layers.append(activation_func())
            if dropout_fc:
                fc_layers.append(nn.Dropout(dropout_rate))
            in_features = out_features

        fc_layers.append(nn.Linear(in_features, out_classes))
        self.fc_layers = nn.Sequential(*fc_layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass of the model.

        Parameters
        ----------
        x : torch.Tensor
            The input batch, ``(batch, channels, H, W)``.

        Returns
        -------
        torch.Tensor
            The raw logits, ``(batch, out_classes)``.
        """
        x = self.hidden_layers(x)
        x = torch.flatten(x, start_dim=1)
        x = self.fc_layers(x)
        return x

    def training_step(self, batch: tuple[torch.Tensor, torch.Tensor], batch_idx: int) -> torch.Tensor:
        """
        Training step of the model.

        Parameters
        ----------
        batch : tuple of torch.Tensor
            The ``(inputs, targets)`` pair.
        batch_idx : int
            The index of the batch.

        Returns
        -------
        torch.Tensor
            The loss.
        """
        x, y = batch
        y_pred = self(x)
        loss = self.loss_func(y_pred, y)
        self.log("train_step_loss", loss)
        return loss

    def validation_step(self, batch: tuple[torch.Tensor, torch.Tensor], batch_idx: int) -> torch.Tensor:
        """
        Validation step of the model; accumulates the evaluation metrics.

        Parameters
        ----------
        batch : tuple of torch.Tensor
            The ``(inputs, targets)`` pair.
        batch_idx : int
            The index of the batch.

        Returns
        -------
        torch.Tensor
            The loss.
        """
        loss = self._evaluation_step(batch)
        self.log("val_step_loss", loss)
        return loss

    def on_validation_epoch_end(self) -> None:
        """
        Log the accumulated metrics at the end of the validation epoch.
        """
        self._log_epoch_metrics("val")

    def test_step(self, batch: tuple[torch.Tensor, torch.Tensor], batch_idx: int) -> torch.Tensor:
        """
        Test step of the model; accumulates the evaluation metrics.

        Parameters
        ----------
        batch : tuple of torch.Tensor
            The ``(inputs, targets)`` pair.
        batch_idx : int
            The index of the batch.

        Returns
        -------
        torch.Tensor
            The loss.
        """
        loss = self._evaluation_step(batch)
        self.log("test_step_loss", loss)
        return loss

    def on_test_epoch_end(self) -> None:
        """
        Log the accumulated metrics at the end of the test epoch.
        """
        self._log_epoch_metrics("test")

    def configure_optimizers(self) -> torch.optim.Optimizer:
        """
        Configure the optimizer for the model.

        Returns
        -------
        torch.optim.Optimizer
            The optimizer.
        """
        return self.optimizer(self.parameters(), lr=self.initial_learning_rate)

    def _evaluation_step(self, batch: tuple[torch.Tensor, torch.Tensor]) -> torch.Tensor:
        """Compute the loss for a batch and update all evaluation metrics."""
        x, y = batch
        y_pred = self(x)
        loss = self.loss_func(y_pred, y)

        preds = self._metric_preds(y_pred)
        for metric in (self.accuracy, self.precision, self.recall, self.f1, self.roc_auc):
            metric.update(preds, y)
        return loss

    def _metric_preds(self, y_pred: torch.Tensor) -> torch.Tensor:
        """Convert logits to the prediction format the metrics expect."""
        probs = torch.softmax(y_pred, dim=1)
        # Binary torchmetrics take one score per sample: P(class 1).
        return probs[:, 1] if self._binary_metrics else probs

    def _log_epoch_metrics(self, prefix: str) -> None:
        """Log every evaluation metric under ``<prefix>_ep_<name>``."""
        self.log(f"{prefix}_ep_accuracy", self.accuracy, on_epoch=True)
        self.log(f"{prefix}_ep_precision", self.precision, on_epoch=True)
        self.log(f"{prefix}_ep_recall", self.recall, on_epoch=True)
        self.log(f"{prefix}_ep_f1", self.f1, on_epoch=True)
        self.log(f"{prefix}_ep_roc_auc", self.roc_auc, on_epoch=True)

    def _get_conv_out_shape(self, input_size: torch.Size) -> int:
        """
        Calculate the number of features produced by the convolutional layers.

        Parameters
        ----------
        input_size : torch.Size
            The size of a single, unbatched input sample.

        Returns
        -------
        int
            The flattened size of the convolutional output for one sample.
        """
        # Runs on the CPU, where freshly constructed layers live; a batch of
        # one is used so the layers see the standard batched layout.
        with torch.no_grad():
            dummy = torch.zeros(1, *input_size)
            features = self.hidden_layers(dummy)
        return int(features.numel())

    