# PyTorch Lightning implementation
In this notebook I implement a CNN model using PyTorch Lightning.
This model is more flexible than the model from `initial_experiments.ipynb`: it exposes more hyperparameters for training sessions.


In [ ]:
import torch
import torch.nn as nn
import pytorch_lightning as pl

class CNN(pl.LightningModule):
    """
    A Convolutional Neural Network (CNN) implemented using PyTorch Lightning.

    The network consists of ``conv_layers`` convolutional blocks
    (Conv2d -> activation -> MaxPool2d [-> Dropout]). These are followed by
    hidden fully connected layers (Linear -> activation [-> Dropout]) and a
    final Linear layer with ``out_classes`` outputs. The number of filters
    doubles with every convolutional block, starting from ``initial_filters``.

    Parameters
    ----------
    conv_layers : int
        The number of convolutional blocks
    fc_layer_sizes : tuple of int
        The output sizes of the hidden fully connected layers. The final
        output layer (of size ``out_classes``) is appended automatically.
    input_size : torch.Size
        The shape of a single input sample, ``(channels, height, width)``,
        without the batch dimension
    out_classes : int, optional
        The number of output classes, by default 2
    initial_filters : int, optional
        The number of filters in the first convolutional layer, by default 32
    hl_kernel_size : int, optional
        The kernel size for the convolutional layers, by default 5
    activation_func : type of nn.Module, optional
        The activation *class* (not an instance). A fresh instance is created
        for every layer. By default nn.ReLU.
    max_pool_kernel : int, optional
        The kernel size for max pooling, by default 2
    dropout_conv : bool, optional
        Whether to apply dropout after each convolutional block, by default False
    dropout_fc : bool, optional
        Whether to apply dropout after each hidden fully connected layer,
        by default False
    dropout_rate : float, optional
        The dropout rate, by default 0.5
    """

    def __init__(
            self,
            *,
            conv_layers: int,
            fc_layer_sizes: tuple[int, ...],
            input_size: torch.Size,
            out_classes: int = 2,
            initial_filters: int = 32,
            hl_kernel_size: int = 5,
            activation_func: type[nn.Module] = nn.ReLU,
            max_pool_kernel: int = 2,
            dropout_conv: bool = False,
            dropout_fc: bool = False,
            dropout_rate: float = 0.5,
    ) -> None:
        # The base class must be LightningModule (an nn.Module subclass).
        # LightningDataModule does not register submodules/parameters and has
        # no ``device`` attribute.
        super().__init__()
        hidden_layers = []
        fc_layers = []

        in_channels = input_size[0]

        for i in range(conv_layers):
            # Filters double with depth: initial_filters, 2x, 4x, ...
            out_channels = initial_filters * 2 ** i
            hidden_layers.append(nn.Conv2d(in_channels, out_channels, hl_kernel_size))
            hidden_layers.append(activation_func())
            hidden_layers.append(nn.MaxPool2d(max_pool_kernel))
            if dropout_conv:
                hidden_layers.append(nn.Dropout(dropout_rate))
            in_channels = out_channels

        self.hidden_layers = nn.Sequential(*hidden_layers)

        # Number of features the convolutional stack produces for one sample.
        # This is the input width of the first fully connected layer.
        # (Previously this was an nn.Flatten *module*, which nn.Linear cannot accept.)
        in_features = self._get_conv_out_shape(input_size)

        for out_features in fc_layer_sizes:
            fc_layers.append(nn.Linear(in_features, out_features))
            fc_layers.append(activation_func())
            if dropout_fc:
                fc_layers.append(nn.Dropout(dropout_rate))
            in_features = out_features

        fc_layers.append(nn.Linear(in_features, out_classes))
        self.fc_layers = nn.Sequential(*fc_layers)

    def _get_conv_out_shape(self, input_size: torch.Size) -> int:
        """
        Calculate the number of features produced by the convolutional layers.

        A zero tensor of shape ``(1, *input_size)`` is passed through
        ``self.hidden_layers`` without tracking gradients. The number of
        elements of the single resulting sample is returned.

        Parameters
        ----------
        input_size : torch.Size
            The shape of a single input sample, ``(channels, height, width)``,
            without the batch dimension

        Returns
        -------
        int
            The flattened size of one sample's convolutional output
        """
        with torch.no_grad():
            # Add a batch dimension of 1 so Conv2d/MaxPool2d always see a
            # 4-D (N, C, H, W) input.
            dummy = torch.zeros(1, *input_size, device=self.device)
            out = self.hidden_layers(dummy)
        # Drop the batch dimension and count the remaining elements.
        return int(out[0].numel())
    