"ld. The threshold of eye blink is 80% of the average maximum. Also
while running program the baseline will be newly calculated every
3 epochs (6 seconds) to adjust a new threshold within limited boundaries.
After that the system also calculated the energy of eye behavior in each
subject to detect eye blinks and eye movements. If the signal is in the
range of eye behavior, eye blinks and eye movements will perform. If the
signal exceeds a range of eye behavior, artifacts will perform. Thereby,
our systems can detect eye behavior out of artifacts and use eye behavior
to detect state of drowsiness."

In [4]:
pip install torchsummary

Note: you may need to restart the kernel to use updated packages.


In [5]:
from typing import Dict, Any, Tuple, Union, Optional, Callable, List
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import pytorch_lightning as pl
from torchsummary import summary
from torchmetrics import MetricCollection, Accuracy, Precision, Recall, F1Score, AUROC
#from mind_ml.models.EEGNet import EEGNetLightning

from torchtyping import TensorType, patch_typeguard
#from typeguard import typechecked

#patch_typeguard()

In [None]:
https://www.eneuro.org/content/9/5/ENEURO.0160-22.2022
    https://github.com/yoshidan/pytorch-eyeblink-detection
        https://hal.science/hal-01917529/document
            https://arxiv.org/pdf/2101.10932.pdf
                https://www.ncbi.nlm.nih.gov/pmc/articles/PMC10079879/
                    https://github.com/berdakh/eeg-pytorch

## CNN

In [None]:
class CNN(nn.Module): #pl.LightningModule
    """
    Stacked 1D-convolutional classifier for multi-channel time series.

    The network is ``n_blocks`` repetitions of
    Conv1d -> BatchNorm1d -> GELU -> AdaptiveMaxPool1d -> Dropout, followed by
    a head of Flatten -> LazyLinear(2) -> GELU.  After every block the number
    of convolution filters and the pooled length are both halved (integer
    division).

    The attribute names ``conv_net`` and ``classifier_head`` mirror the ones
    used by the EEGNet models in this file so that code written against those
    models (e.g. per-part optimizer parameter groups) also works with this one.
    """
    def __init__(self,
                 channels: int,
                 dropout: float,
                 kernel_size: int,
                 sample_rate: int = 256,
                 chan_out: int = 256,
                 pool_out: int = 120,
                 n_blocks: int = 3,
                 **kwargs):
        """
        Parameters
        ----------
        channels : int
            Number of input channels (size of axis 1 of the input tensor).
        dropout : float
            Dropout probability applied at the end of every block.
        kernel_size : int
            Length of the 1D convolution kernel used in every block.
        sample_rate : int
            Stored on the instance only; not used to build the network.
        chan_out : int
            Number of filters in the first block; halved for each later block.
        pool_out : int
            Output length of the adaptive max pooling in the first block;
            halved for each later block.
        n_blocks : int
            Number of convolutional blocks.
        **kwargs
            Accepted and ignored, so the model can be built from a larger
            hyper-parameter dictionary.

        Raises
        ------
        ValueError
            If halving ``chan_out`` or ``pool_out`` would reach zero before
            ``n_blocks`` blocks have been built.
        """

        super().__init__()

        # Keep the constructor arguments (the *initial* sizes) on the instance.
        self.channels = channels
        self.sample_rate = sample_rate
        self.dropout = dropout
        self.kernel_size = kernel_size
        self.chan_out = chan_out
        self.pool_out = pool_out
        self.n_blocks = n_blocks

        # ``blocks`` is a plain dict, which nn.Module does not scan for
        # parameters; the blocks are registered through ``conv_net`` below.
        self.blocks = dict()
        block_in, block_out, block_pool = channels, chan_out, pool_out
        for n in range(n_blocks):
            # Integer halving eventually hits zero; fail loudly instead of
            # building a layer with no filters / no output positions.
            if block_out < 1 or block_pool < 1:
                raise ValueError(
                    f"block_{n} would have {block_out} output channels and a pooled "
                    f"length of {block_pool}; reduce n_blocks or increase chan_out/pool_out"
                )
            self.blocks[f"block_{n}"] = nn.Sequential(
                nn.Conv1d(in_channels=block_in,
                          out_channels=block_out,
                          kernel_size=kernel_size,
                          padding="same",  # keeps the temporal length unchanged
                          ),
                nn.LazyBatchNorm1d(),  # feature count is inferred on the first forward pass
                nn.GELU(),
                nn.AdaptiveMaxPool1d(output_size=block_pool),  # fixes the length regardless of input length
                nn.Dropout(p=dropout),
            )
            block_in = block_out
            block_out = block_out // 2
            block_pool = block_pool // 2

        # Chain the blocks in insertion order; this also registers their parameters.
        self.conv_net = nn.Sequential(*self.blocks.values())

        # NOTE(review): the trailing GELU is applied to the two class outputs,
        # so the values returned by forward() are bounded below. It is kept to
        # preserve existing behaviour — confirm this is intended for the loss in use.
        self.classifier_head = nn.Sequential(
            nn.Flatten(),  # (batch, channels, length) -> (batch, channels * length)
            nn.LazyLinear(out_features=2),  # input size is inferred on the first forward pass
            nn.GELU(),
        )

    def forward(self, x: TensorType["num_batches", "num_channels", "num_samples"]) -> TensorType["num_batches", "num_classes"]:
        """Run the convolutional blocks and the classifier head; returns a (batch, 2) tensor."""
        x = self.conv_net(x)
        x = self.classifier_head(x)
        return x

In [None]:
from typing import Dict, Any, Tuple, Union, Optional, Callable, List
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import pytorch_lightning as pl
from torchsummary import summary
from torchmetrics import MetricCollection, Accuracy, Precision, Recall, F1Score, AUROC

from torchtyping import TensorType, patch_typeguard
from typeguard import typechecked

## DepthwiseConv2D, SeparableConv2D

In [None]:
# Keras offers a DepthwiseConv2D layer as well as a SeparableConv2D layer.
# The DepthwiseConv2D layer performs a depthwise convolution that acts separately on channels,
# while the SeparableConv2D performs a depthwise convolution that acts separately on channels, followed by a pointwise convolution that mixes channels.
# The pytorch equivalent is as follows:
class DepthwiseConv2d(nn.Module):
    """
    2D convolution in which every input channel is filtered on its own.

    This is a thin wrapper around ``torch.nn.Conv2d`` configured with
    ``groups=in_channels`` and ``out_channels=in_channels * depth_multiplier``.
    The ``torch.nn.Conv2d`` documentation calls exactly this configuration
    (``groups == in_channels`` and ``out_channels == K * in_channels``) a
    depthwise convolution with depth multiplier ``K``.

    Parameters
    ----------
    in_channels : int
        Number of channels of the input tensor.
    depth_multiplier : int
        Number of filters learned per input channel (``K`` above).
    **kwargs
        Forwarded unchanged to ``torch.nn.Conv2d`` (``kernel_size`` is
        required by ``Conv2d``; ``padding``, ``bias``, ... are optional).
    """

    def __init__(self, in_channels, depth_multiplier, **kwargs):
        super().__init__()
        total_filters = in_channels * depth_multiplier
        # One group per input channel means no mixing across channels.
        self.depthwise = nn.Conv2d(
            in_channels=in_channels,
            out_channels=total_filters,
            groups=in_channels,
            **kwargs,
        )

    def forward(self, x):
        """Apply the grouped convolution to ``x`` and return the result."""
        return self.depthwise(x)


class SeparableConv2d(nn.Module):
    """
    Depthwise convolution followed by a 1x1 pointwise convolution.

    The depthwise stage filters each input channel separately; the pointwise
    stage then mixes the resulting channels into ``out_channels`` channels.

    Parameters
    ----------
    in_channels : int
        Number of channels of the input tensor.
    out_channels : int
        Number of channels produced by the pointwise stage.
    depth_multiplier : int
        Filters per input channel in the depthwise stage.
    **kwargs
        Forwarded to the depthwise stage only (e.g. ``kernel_size``,
        ``padding``, ``bias``). The pointwise convolution always uses the
        ``torch.nn.Conv2d`` defaults, so it keeps its bias term even when
        ``bias=False`` is passed here.
    """

    def __init__(self, in_channels, out_channels, depth_multiplier=1, **kwargs):
        super().__init__()
        # Width of the tensor handed from the depthwise to the pointwise stage.
        intermediate_channels = in_channels * depth_multiplier
        self.depthwise = DepthwiseConv2d(in_channels, depth_multiplier, **kwargs)
        self.pointwise = nn.Conv2d(
            in_channels=intermediate_channels,
            out_channels=out_channels,
            kernel_size=(1, 1),
        )

    def forward(self, x):
        """Run the depthwise stage, then the pointwise stage."""
        return self.pointwise(self.depthwise(x))

## EEGNet Backbone

In [None]:
class EEGNetBackbone(nn.Module):
    """
    PyTorch implementation of the EEGNet backbone (convnet) from Lawhern et al. 2018.

    Reference implementation of EEGNet version 3 from the original authors:
    https://github.com/vlawhern/arl-eegmodels/blob/master/EEGModels.py

    Input tensors have shape (batch, 1, channels, samples): the PyTorch
    "height" axis holds the recording channels and the "width" axis holds time.
    The output has shape (batch, f2, 1, samples // (pool1 * pool2)).
    """

    def __init__(self,
                 channels= 1,
                 sample_length= 256,
                 kernel_size: int = 64,
                 dropout: float = 0.0,
                 f1: int = 8,
                 d: int = 2,
                 f2: int = 16,
                 sample_rate: int = 256,
                 adjust_for_sample_length: bool = False,
                 **kwargs):
        """
        Parameters
        ----------
        channels : int
            Number of recording channels in the input data.
        sample_length : int
            Number of time samples per input; used for ``output_shape`` and,
            optionally, to scale the second pooling size.
        kernel_size : int
            Length of the temporal convolution kernel in the first layer.
        dropout : float
            Dropout rate used at the end of both blocks.
        f1, f2 : int
            Number of temporal filters (F1) and number of pointwise filters (F2) to learn.
            Default: F1 = 8, F2 = F1 * D
        d : int
            Number of spatial filters to learn within each temporal convolution.
            Default: d = 2
        sample_rate : int
            Sample rate of the input data. The architecture is designed for
            256 Hz, but should work for higher sample rates.
        adjust_for_sample_length : bool
            If True, the second pooling size (8) is multiplied by
            ``sample_length / sample_rate`` so that longer inputs are pooled
            more aggressively. The result is truncated to an int and never
            goes below 1.
        **kwargs
            Accepted and ignored.
        """

        super().__init__()

        self.channels = channels
        self.sample_rate = sample_rate
        self.sample_length = sample_length
        self.dropout = dropout
        self.kernel_size = kernel_size
        self.f1 = f1
        self.d = d
        self.f2 = f2

        # The authors describe their updated model as follows:
        # There are two CNN blocks followed by a fully connected layer.
        # Block 1:
        #  - Vanilla 2D convolution with same padding and kernel size (1, kernel_size)
        #  - Batch normalization
        #  - Depthwise convolution with kernel size (channels, 1) and depth multiplier d
        #  - Batch normalization
        #  - ELU activation
        #  - Average pooling with kernel size (1, 4)
        #  - Dropout or spatial dropout
        # Block 2:
        #  - Depthwise separable convolution with output channels f2, kernel size (1, 16) and same padding
        #  - Batch normalization
        #  - ELU activation
        #  - Average pooling with kernel size (1, 8)
        #  - Dropout or spatial dropout
        # Flatten
        #
        # Original Keras implementation, kept for reference:
        #
        #   input1 = Input(shape = (Chans, Samples, 1))
        #   block1 = Conv2D(F1, (1, kernLength), padding = 'same',
        #                   input_shape = (Chans, Samples, 1),
        #                   use_bias = False)(input1)
        #   block1 = BatchNormalization()(block1)
        #   block1 = DepthwiseConv2D((Chans, 1), use_bias = False,
        #                            depth_multiplier = D,
        #                            depthwise_constraint = max_norm(1.))(block1)
        #   block1 = BatchNormalization()(block1)
        #   block1 = Activation('elu')(block1)
        #   block1 = AveragePooling2D((1, 4))(block1)
        #   block1 = dropoutType(dropoutRate)(block1)
        #
        #   block2 = SeparableConv2D(F2, (1, 16),
        #                            use_bias = False, padding = 'same')(block1)
        #   block2 = BatchNormalization()(block2)
        #   block2 = Activation('elu')(block2)
        #   block2 = AveragePooling2D((1, 8))(block2)
        #   block2 = dropoutType(dropoutRate)(block2)
        #
        #   flatten = Flatten(name = 'flatten')(block2)
        #
        # Unlike the Keras code, no max-norm constraint is applied to the
        # depthwise weights here.

        # PyTorch convolutions expect (batch, conv_channels, height, width).
        # Here height is the spatial axis (recording channels, not to be
        # confused with convolution channels) and width is the temporal axis.
        self.pool1 = 4
        # Block 1
        self.block1 = nn.Sequential(
            # shape (batch, 1, channels, samples)
            # temporal convolution with kernel (1, kernel_size); the spatial axis is untouched
            nn.Conv2d(in_channels=1,
                      out_channels=f1,
                      kernel_size=(1, kernel_size),
                      padding="same",
                      bias=False),
            # shape (batch, f1, channels, samples)
            nn.BatchNorm2d(num_features=f1),
            # depthwise convolution spanning the whole spatial axis, i.e. d
            # spatial filters per temporal filter
            DepthwiseConv2d(in_channels=f1,
                            depth_multiplier=d,
                            kernel_size=(channels, 1),
                            padding="valid",
                            bias=False),
            # shape (batch, f1 * d, 1, samples)
            nn.BatchNorm2d(num_features=f1 * d),
            nn.ELU(),
            # pooling over the temporal axis only
            nn.AvgPool2d(kernel_size=(1, self.pool1)),
            # shape (batch, f1 * d, 1, samples // 4)
            nn.Dropout(p=dropout),
        )

        # In previous versions the final pooling was hard-coded to 8,
        # assuming a sample rate of 256 Hz and one second of data.
        self.pool2 = 8

        # With longer inputs (more than one second of data) we want to pool
        # more aggressively so the flattened output keeps the same size;
        # otherwise the fully connected layer would grow with the input.
        if adjust_for_sample_length:
            ratio = sample_length / sample_rate
            # Truncation can give 0 for short inputs, which would be an invalid
            # pooling size and a division by zero below, so floor it at 1.
            self.pool2 = max(1, int(self.pool2 * ratio))

        # Block 2
        self.block2 = nn.Sequential(
            # shape (batch, f1 * d, 1, samples // 4)
            # temporal convolution again, this time depthwise separable
            SeparableConv2d(in_channels=f1 * d,
                            out_channels=f2,
                            depth_multiplier=1,
                            kernel_size=(1, 16),
                            padding="same",
                            bias=False),
            # shape (batch, f2, 1, samples // 4)
            nn.BatchNorm2d(num_features=f2),
            nn.ELU(),
            # pooling over the temporal axis only
            nn.AvgPool2d(kernel_size=(1, self.pool2)),
            # shape (batch, f2, 1, samples // (4 * pool2))
            nn.Dropout(p=dropout),
        )

        # Per-sample output size for an input of ``sample_length`` samples.
        # NOTE: the tuple is ordered (f2, time, 1) while the tensor returned by
        # forward() is (batch, f2, 1, time). The ordering is kept unchanged for
        # existing callers; the models in this file only use its product.
        self.output_shape = (f2, sample_length // (self.pool1*self.pool2), 1)

    def forward(self, x: TensorType["num_batches", 1, "num_channels", "num_samples"]) -> TensorType["num_batches", "f2", 1, "reduced_samples"]:
        """Apply block 1 and block 2; returns a (batch, f2, 1, reduced_samples) tensor."""
        x = self.block1(x)
        x = self.block2(x)
        return x


## EEGNet

In [None]:
class EEGNet(nn.Module):
    """
    PyTorch implementation of the EEGNet model from Lawhern et al. 2018.

    Reference implementation of EEGNet version 3 from the original authors:
    https://github.com/vlawhern/arl-eegmodels/blob/master/EEGModels.py

    The model is the two convolutional blocks of ``EEGNetBackbone`` followed
    by a fully connected classifier head that returns raw class scores (no
    softmax is applied, unlike the Keras reference).
    """

    def __init__(self,
                 num_classes= 2, #number of classes to predict 
                 classifier_hidden_units: Optional[int] = None, #number of hidden units(neurons) in each hidden layer of classifer
                 classifier_num_layers: int = 1, #hidden layers in the classifier
                 **backbone_kwargs): 
        """
        Parameters
        ----------
        num_classes : int
            Number of classes to predict.
        classifier_hidden_units : Optional[int]
            Width of every hidden linear layer in the classifier head.
            Must be given when ``classifier_num_layers`` is not 1.
        classifier_num_layers : int
            Total number of linear layers in the classifier head.
        **backbone_kwargs
            Passed on to ``EEGNetBackbone``.
        """

        super().__init__()

        self.num_classes = num_classes
        self.classifier_hidden_units = classifier_hidden_units
        self.classifier_num_layers = classifier_num_layers
        self.classifier_num_hidden_layers = classifier_num_layers - 1

        # A head with hidden layers is impossible to size without a hidden width.
        needs_hidden_width = self.classifier_num_hidden_layers != 0
        assert not needs_hidden_width or self.classifier_hidden_units is not None, "If classifier_num_layers > 1, classifier_hidden_units must be specified"

        # The backbone's blocks are re-wrapped in a plain nn.Sequential (rather
        # than keeping the EEGNetBackbone module itself) for backwards
        # compatibility: older checkpoints of this implementation store the
        # weights under the Sequential's parameter names.
        backbone = EEGNetBackbone(**backbone_kwargs)
        self.dropout = backbone.dropout
        self.conv_net = nn.Sequential(backbone.block1, backbone.block2)

        # Number of features left once the backbone output is flattened.
        flat_features = np.array(backbone.output_shape).prod()

        # Widths at every boundary of the head: input, hidden..., output.
        hidden_widths = [classifier_hidden_units] * self.classifier_num_hidden_layers
        layer_widths = [flat_features, *hidden_widths, num_classes]

        # Flatten, then Linear layers separated by ELU; the final Linear is
        # deliberately not followed by an activation.
        head_layers = [nn.Flatten()]
        for position, (fan_in, fan_out) in enumerate(zip(layer_widths, layer_widths[1:])):
            if position > 0:
                head_layers.append(nn.ELU())
            head_layers.append(nn.Linear(in_features=fan_in, out_features=fan_out))

        self.classifier_head = nn.Sequential(*head_layers)
        # output shape of classifier_head is (batch, num_classes)

    def forward(self, x: TensorType["num_batches", 1, "num_channels", "num_samples"]) -> TensorType["num_batches", "num_classes"]:
        """Return the class scores for a batch of inputs."""
        features = self.conv_net(x)
        return self.classifier_head(features)

    def embed(self, x: TensorType["num_batches", 1, "num_channels", "num_samples"]) -> TensorType["num_batches", "classifier_input_units"]:
        """Return the flattened backbone features, i.e. the input of the classifier head."""
        features = self.conv_net(x)
        batch_size = features.shape[0]
        return features.view(batch_size, -1)


## Conditional EEGNet

In [None]:
class ConditionalEEGNet(nn.Module):
    """
    A model based on EEGNet that conditions its prediction for a sample on
    samples from a calibration sequence.

    The architecture works as follows:
    - conceptually the input is a pair (x_calib, x_sample) with shapes
        - x_calib: (num_batches, num_calib_samples, 1, num_channels, num_samples)
        - x_sample: (num_batches, 1, num_channels, num_samples)
      For technical reasons ``forward`` takes a single tensor of shape
      (num_batches, num_sequences, 1, num_channels, num_samples) with
      num_sequences = num_calib_samples + 1; the last entry along axis 1 is
      the sample, all earlier entries are the calibration sequence.
    - both x_calib and x_sample are passed through the EEGNet backbone
    - the backbone outputs for x_calib are aggregated
      (``calibration_aggregation_method``)
    - the aggregated calibration output is combined with the output for
      x_sample (``pre_classifier_aggregation_method``)
    - the combined output is passed through a classifier head
    """

    # Supported values of ``calibration_aggregation_method``.
    _CALIBRATION_METHODS = ("mean", "max", "min", "none")
    # Supported values of ``pre_classifier_aggregation_method``. Distance-style
    # reductions (cosine similarity, euclidean distance) are not offered: they
    # collapse feature axes, and the classifier head is sized for the full
    # backbone output.
    _PRE_CLASSIFIER_METHODS = ("concat", "mean", "max", "min",
                               "difference", "abs_difference", "square_difference")

    def __init__(self,
                 num_classes: int,
                 calibration_aggregation_method: str, #aggregate outputs of EEGNet backbone for calibration sequence
                 pre_classifier_aggregation_method: str, #aggregate outputs of EEGNet backbone for sample and aggregated calibration sequence
                 classifier_hidden_units: Optional[int] = None,
                 classifier_num_layers: int = 1,
                 **backbone_kwargs):
        """
        Parameters
        ----------
        num_classes : int
            Number of classes to predict.
        calibration_aggregation_method : str
            How the backbone outputs of the calibration sequence are reduced:
            "mean", "max", "min", or "none". With "none" every calibration
            entry is paired with the sample individually and the resulting
            class scores are averaged.
        pre_classifier_aggregation_method : str
            How the (aggregated) calibration output and the sample output are
            combined: "concat", "mean", "max", "min", "difference",
            "abs_difference" or "square_difference". Differences are computed
            as calibration minus sample.
        classifier_hidden_units : Optional[int]
            Width of every hidden linear layer in the classifier head.
            Must be given when ``classifier_num_layers`` is not 1.
        classifier_num_layers : int
            Total number of linear layers in the classifier head.
        **backbone_kwargs
            Passed on to ``EEGNetBackbone``.

        Raises
        ------
        ValueError
            If either aggregation method is not one of the supported values.
        """

        super().__init__()

        self.num_classes = num_classes
        self.calibration_aggregation_method = calibration_aggregation_method
        self.pre_classifier_aggregation_method = pre_classifier_aggregation_method

        self.classifier_hidden_units = classifier_hidden_units
        self.classifier_num_layers = classifier_num_layers
        self.classifier_num_hidden_layers = classifier_num_layers - 1
        # ensure that if multiple hidden layers are specified, number of hidden units is also provided
        assert self.classifier_num_hidden_layers == 0 or self.classifier_hidden_units is not None, "If classifier_num_layers > 1, classifier_hidden_units must be specified"

        # Reject unsupported method names at construction time instead of on
        # the first forward pass.
        if calibration_aggregation_method not in self._CALIBRATION_METHODS:
            raise ValueError(f"Unknown calibration_aggregation_method: {calibration_aggregation_method}")
        if pre_classifier_aggregation_method not in self._PRE_CLASSIFIER_METHODS:
            raise ValueError(f"Unknown pre_classifier_aggregation_method: {pre_classifier_aggregation_method}")

        self.conv_net = EEGNetBackbone(**backbone_kwargs)
        conv_net_output_shape = self.conv_net.output_shape

        # "concat" keeps both the calibration and the sample features, so the
        # classifier sees twice the backbone output; every other method
        # reduces the pair to a single feature map.
        stacked_maps = 2 if pre_classifier_aggregation_method == "concat" else 1
        classifier_input_units = int(np.prod((stacked_maps, *conv_net_output_shape)))

        input_units = [classifier_input_units] + [classifier_hidden_units] * self.classifier_num_hidden_layers
        output_units = [classifier_hidden_units] * self.classifier_num_hidden_layers + [num_classes]

        # Flatten only the last four axes (pair, f2, 1, time) so that an extra
        # leading axis (present for calibration method "none") is preserved.
        classifier_layers = [nn.Flatten(start_dim=-4)]
        for in_features, out_features in zip(input_units, output_units):
            classifier_layers.append(nn.Linear(in_features=in_features, out_features=out_features))
            classifier_layers.append(nn.ELU())

        # Drop the trailing activation: the head returns raw class scores.
        classifier_layers = classifier_layers[:-1]

        self.classifier_head = nn.Sequential(*classifier_layers)
        # output shape of classifier_head is (batch, num_classes)

    def forward(self, x: TensorType["num_batches", "num_sequences", 1, "num_channels", "num_samples"]) -> TensorType["num_batches", "num_classes"]:
        """
        Return class scores for the last entry of each sequence, conditioned
        on the preceding calibration entries.

        Raises
        ------
        ValueError
            If the input holds no calibration entries (num_sequences < 2) or
            an aggregation method attribute holds an unsupported value.
        """
        num_batches, num_sequences = x.shape[:2]
        if num_sequences < 2:
            raise ValueError(
                f"Expected at least one calibration sample plus the sample itself along axis 1, "
                f"but got num_sequences={num_sequences}"
            )

        # The backbone works on 4D batches, so fold the sequence axis into the
        # batch axis, run the backbone, and unfold again. reshape (instead of
        # view) also accepts inputs whose memory layout does not allow a view.
        x = x.reshape(num_batches * num_sequences, *x.shape[2:])
        x = self.conv_net(x)
        x = x.reshape(num_batches, num_sequences, *x.shape[1:])

        # split into calibration entries and the sample (last entry)
        x_calib, x_sample = x[:, :-1], x[:, -1:]

        # Reduce the calibration sequence to one feature map per batch element
        # (axis 1 keeps size 1), or keep all of them for "none".
        calib_method = self.calibration_aggregation_method
        if calib_method == "mean":
            x_calib = x_calib.mean(dim=1, keepdim=True)
        elif calib_method == "max":
            x_calib = x_calib.max(dim=1, keepdim=True)[0]
        elif calib_method == "min":
            x_calib = x_calib.min(dim=1, keepdim=True)[0]
        elif calib_method == "none":
            # Keep every calibration entry and repeat the sample next to each
            # of them; both tensors get a new axis 1 for the concatenation below.
            x_sample = x_sample.expand(x_calib.shape).unsqueeze(1)
            x_calib = x_calib.unsqueeze(1)
        else:
            raise ValueError(f"Unknown calibration_aggregation_method: {self.calibration_aggregation_method}")

        # Pair calibration (index 0) and sample (index 1) along axis 1.
        x_pre_classifier = torch.cat([x_calib, x_sample], dim=1)
        assert x_pre_classifier.shape[
            1] == 2, f"Expected x_pre_classifier to have shape (batch, 2, ...), but got {x_pre_classifier.shape}"

        # Combine the pair; every method except "concat" leaves axis 1 with size 1.
        pair_method = self.pre_classifier_aggregation_method
        if pair_method == "mean":
            x_pre_classifier = x_pre_classifier.mean(dim=1, keepdim=True)
        elif pair_method == "max":
            x_pre_classifier = x_pre_classifier.max(dim=1, keepdim=True)[0]
        elif pair_method == "min":
            x_pre_classifier = x_pre_classifier.min(dim=1, keepdim=True)[0]
        elif pair_method in ("difference", "abs_difference", "square_difference"):
            delta = x_pre_classifier[:, 0] - x_pre_classifier[:, 1]
            if pair_method == "abs_difference":
                delta = torch.abs(delta)
            elif pair_method == "square_difference":
                delta = torch.square(delta)
            x_pre_classifier = delta.unsqueeze(1)
        elif pair_method == "concat":
            pass  # keep both feature maps
        else:
            raise ValueError(f"Unknown pre_classifier_aggregation_method: {self.pre_classifier_aggregation_method}")

        # 5D normally; 6D when the calibration method is "none"
        # (batch, pair, num_calib_samples, f2, 1, time).
        assert len(x_pre_classifier.shape) >= 5, \
            f"Expected x_pre_classifier to be at least 5-dimensional, but got {x_pre_classifier.shape}"

        if len(x_pre_classifier.shape) == 6:
            # Move the calibration axis in front of the pair axis, because the
            # classifier head operates on the last four axes.
            x_pre_classifier = x_pre_classifier.permute(0, 2, 1, 3, 4, 5)

        x = self.classifier_head(x_pre_classifier)

        if len(x.shape) == 3:
            # one prediction per calibration entry: average them
            x = x.mean(dim=1)

        return x

## EEGNetLightning

In [None]:
class ConditionalEEGNetLightning(EEGNetLightning):
    """
    Variant of ``EEGNetLightning`` whose ``eegnet`` attribute is a
    ``ConditionalEEGNet`` built from the same hyper-parameters.

    NOTE(review): ``EEGNetLightning`` is not defined in this notebook (its
    import is commented out), so everything other than ``forward`` is
    presumably inherited from it — confirm against that class.
    """

    def __init__(self, **hparams):
        super().__init__(**hparams)
        # Replace the network attribute with the conditional model.
        self.eegnet = ConditionalEEGNet(**hparams)

    def forward(self, x: TensorType["num_batches", "sequence_length", "num_samples", "num_channels"]) -> TensorType["num_batches", "num_classes"]:
        """Rearrange (..., samples, channels) input to (..., 1, channels, samples) and run the model."""
        channels_first = torch.swapaxes(x, -2, -1)
        # insert the singleton convolution-channel axis in front of (channels, samples)
        return self.eegnet(channels_first.unsqueeze(-3))

## CNN Lightning

In [None]:
class CNNLightning(EEGNetLightning):
    """
    Variant of ``EEGNetLightning`` whose ``eegnet`` attribute is the 1D ``CNN``
    model, trained with separate weight decay for the convolutional part and
    the classifier head and a one-cycle learning-rate schedule.

    NOTE(review): ``EEGNetLightning`` is not defined in this notebook (its
    import is commented out), so the training/validation steps are presumably
    inherited from it — confirm against that class.
    """

    def __init__(self, **hparams):
        super().__init__(**hparams)
        # Replace the network attribute with the CNN model.
        self.eegnet = CNN(**hparams)

    def forward(self, x: TensorType["num_batches", 1, "num_samples", "num_channels"]) -> TensorType["num_batches", "num_classes"]:
        """
        Drop the singleton axis 1, swap the last two axes and run the CNN.

        The CNN's first Conv1d expects the recording channels on axis 1, which
        after the squeeze and swap is the input's last axis; hence the input
        is assumed to be (batch, 1, samples, channels) — TODO confirm against
        the data loader.
        """
        x = x.squeeze(axis=1)
        x = torch.swapaxes(x, -2, -1)
        return self.eegnet(x)

    def configure_optimizers(self):
        """
        Build the optimizer named by ``hparams.optimizer_class`` and a
        per-step ``OneCycleLR`` schedule.

        Reads ``hparams.optimizer_class`` (name of a class in ``torch.optim``),
        ``hparams.learning_rate`` and, optionally, ``hparams.conv_weight_decay``
        and ``hparams.fc_weight_decay`` (both default to 0.0).

        Returns
        -------
        ([optimizer], [scheduler_config]) in the format PyTorch Lightning expects.
        """
        assert hasattr(torch.optim, self.hparams.optimizer_class),\
            f"{self.hparams.optimizer_class} is not a valid optimizer from torch.optim"
        optimizer_class = getattr(torch.optim, self.hparams.optimizer_class)
        optimizer = optimizer_class(
            [
                {  # conv layer parameters
                    "params": filter(lambda p: p.requires_grad, self.eegnet.conv_net.parameters()),
                    "weight_decay": getattr(self.hparams, "conv_weight_decay", 0.0),
                },
                {  # fc layer parameters
                    "params": filter(lambda p: p.requires_grad, self.eegnet.classifier_head.parameters()),
                    "weight_decay": getattr(self.hparams, "fc_weight_decay", 0.0),
                },
            ],
            lr=self.hparams.learning_rate)

        # Size the cycle with the trainer's own estimate of the total number of
        # optimizer steps. Deriving it as epochs * (estimate // max_epochs)
        # rounds down, so the scheduler could run out of steps (and raise)
        # before training finishes.
        lr = torch.optim.lr_scheduler.OneCycleLR(
            optimizer=optimizer,
            max_lr=self.hparams.learning_rate,
            total_steps=self.trainer.estimated_stepping_batches,
            cycle_momentum=True,
        )

        scheduler = {
            "scheduler": lr,
            "interval": "step",  # advance the schedule after every optimizer step
            "name": "Learning Rate Scheduling"
        }
        return [optimizer], [scheduler]
# pytorch lightning module of EEGNet

## Training

In [None]:
# Train for 5 epochs on whichever accelerator Lightning selects automatically.
# NOTE(review): `model` is not defined in the visible cells; it is presumably
# created in another cell — confirm before running.
trainer = pl.Trainer(accelerator="auto", max_epochs=5)
trainer.fit(model)