# Create DataLoader

## K-fold

In [None]:
from pytorch_lightning import LightningDataModule
from torch.utils.data import DataLoader
from sklearn.model_selection import KFold


class KFoldDataModule(LightningDataModule):
    def __init__(
            self,
            data_dir: str = "./dataset/kuba",
            k: int = 1,  # fold number
            split_seed: int = 12345,  # split needs to be always the same for correct cross validation
            num_splits: int = 10,
            batch_size: int = 32,
            num_workers: int = 0,
            pin_memory: bool = False
        ):
        super().__init__()
        
        # this line allows to access init params with 'self.hparams' attribute
        self.save_hyperparameters(logger=False)

        # num_splits = 10 means our dataset will be split to 10 parts
        # so we train on 90% of the data and validate on 10%
        assert 1 <= self.k <= self.num_splits, "incorrect fold number"
        
        # data transformations
        pp = EEGDataProcessor()
        pp.DOWNSAMPLED_FREQ = 512
        self.transforms = Compose([
            pp.correct_offset,
            pp.filter,
            pp.downsample,
            pp.normalize,
        ])

        self.data_train: Dataset = None
        self.data_val: Dataset = None

    @property
    def num_classes() -> int:
        return 3

    def setup(self, stage=None):
        if not self.data_train and not self.data_val:

            dataset_full = EEGDataset(self.hparams.data_dir, self.transforms)

            # choose fold to train on
            kf = KFold(n_splits=self.hparams.num_splits, shuffle=True, random_state=self.hparams.split_seed)
            all_splits = [k for k in kf.split(dataset_full)]
            train_indexes, val_indexes = all_splits[self.hparams.k]
            train_indexes, val_indexes = train_indexes.tolist(), val_indexes.tolist()

            self.data_train, self.data_val = dataset_full[train_indexes], dataset_full[val_indexes]

    def train_dataloader(self):
        return DataLoader(dataset=self.data_train, batch_size=self.hparams.batch_size, num_workers=self.hparams.num_workers,
                          pin_memory=self.hparams.pin_memory, shuffle=True)

    def val_dataloader(self):
        return DataLoader(dataset=self.data_val, batch_size=self.hparams.batch_size, num_workers=self.hparams.num_workers,
                          pin_memory=self.hparams.pin_memory)

In [None]:
results = []
nums_folds = 5

for k in range(nums_folds):
    datamodule = KFoldDataModule(k=k, num_folds=nums_folds)
    datamodule.prepare_data()
    datamodule.setup()

    results.append(1)

score = sum(results) / num_folds
score

## Normal data loader

In [5]:
from dataset import EEGDataset
from utils.preprocessing import EEGDataProcessor
from torchvision.transforms import Compose
from torch.utils.data import random_split
from torch.utils.data import DataLoader

pp = EEGDataProcessor()

pp.DOWNSAMPLED_FREQ = 512

transforms = Compose([
  pp.correct_offset,
  pp.amplitude_conversion,
  pp.filter,
  pp.downsample,
  pp.normalize
])

dataset = EEGDataset("../dataset/kapi_splited", transforms, n_classes=4)

l = len(dataset)

train_set, test_set, validation_set = random_split(dataset, [0.7, 0.2, 0.1])

train_loader = DataLoader(train_set, batch_size=32, num_workers=4, shuffle=True)
test_loader = DataLoader(test_set, batch_size=32, num_workers=4)
val_loader = DataLoader(validation_set, batch_size=32, num_workers=4)

In [6]:
len(dataset)

484

In [7]:
dataset.__getitem__(3)[0].shape

(16, 2048)

In [18]:
dataset.__getitem__(1)[1]

3

In [8]:
dataset.n_classes

4

# Model

In [9]:
import torchmetrics
import torch
import torch.nn as nn
from torch.optim import Adam
import torch.nn.functional as F
from torch.nn.functional import binary_cross_entropy, one_hot
import pytorch_lightning as pl

import matplotlib.pyplot as plt
import numpy as np
from matplotlib.backends.backend_agg import FigureCanvasAgg

## Inception

In [10]:
class CustomPad(nn.Module):
    def __init__(self, padding):
        super().__init__()
        self.padding = padding

    def forward(self, x):
        return F.pad(x, self.padding)

class Inception(pl.LightningModule):
    def __init__(self, input_time=1000, fs=128, ncha=8, filters_per_branch=8,
                 scales_time=(500, 250, 125), dropout_rate=0.25,
                 activation=nn.ELU(inplace=True), n_classes=2):
        super(Inception, self).__init__()
        
        self.fs = fs

        input_samples = int(input_time * fs / 1000)
        scales_samples = [int(s * fs / 1000) for s in scales_time]

        # ========================== BLOCK 1: INCEPTION ========================== #
        self.inception1 = nn.ModuleList([
            nn.Sequential(
                CustomPad((0, 0, scales_sample // 2 - 1, scales_sample // 2, )),
                nn.Conv2d(1, filters_per_branch, (scales_sample, 1)),
                nn.BatchNorm2d(filters_per_branch),
                activation,
                nn.Dropout(dropout_rate),
                nn.Conv2d(filters_per_branch, filters_per_branch*2,
                          (1, ncha), bias=False, groups=filters_per_branch),  # DepthwiseConv2D
                nn.BatchNorm2d(filters_per_branch * 2),
                activation,
                nn.Dropout(dropout_rate),
            ) for scales_sample in scales_samples
        ])

        self.avg_pool1 = nn.AvgPool2d((4, 1))

        # ========================== BLOCK 2: INCEPTION ========================== #
        self.inception2 = nn.ModuleList([
            nn.Sequential(
                CustomPad((0, 0, scales_sample // 8 -
                           1, scales_sample // 8, )),
                nn.Conv2d(
                    len(scales_samples) * 2 * filters_per_branch,
                    filters_per_branch, (scales_sample // 4, 1),
                    bias=False
                ),
                nn.BatchNorm2d(filters_per_branch),
                activation,
                nn.Dropout(dropout_rate),
            ) for scales_sample in scales_samples
        ])

        self.avg_pool2 = nn.AvgPool2d((2, 1))

        # ============================ BLOCK 3: OUTPUT =========================== #
        self.output = nn.Sequential(

            CustomPad((0, 0, 4, 3)),
            nn.Conv2d(
                24, filters_per_branch * len(scales_samples) // 2, (8, 1),
                bias=False

            ),
            nn.BatchNorm2d(filters_per_branch * len(scales_samples) // 2),
            activation,
            nn.AvgPool2d((2, 1)),
            nn.Dropout(dropout_rate),

            CustomPad((0, 0, 2, 1)),
            nn.Conv2d(
                12, filters_per_branch * len(scales_samples) // 4, (4, 1),
                bias=False

            ),
            nn.BatchNorm2d(filters_per_branch * len(scales_samples) // 4),
            activation,
            # nn.Dropout(dropout_rate),
            nn.AvgPool2d((2, 1)),
            nn.Dropout(dropout_rate),
        )

        self.dense = nn.Sequential(
            # nn.Linear(4 * 1 * 6, n_classes), # to zmieniłem bo sie rozjeżdżało
            nn.Linear(int(self.fs*0.75), n_classes),
            nn.Softmax(1)
        )

        self.accuracy_train = torchmetrics.Precision(task="multiclass", average='macro', num_classes=3)
        self.accuracy_test = torchmetrics.Precision(task="multiclass", average='macro', num_classes=3)
        self.accuracy_val = torchmetrics.Precision(task="multiclass", average='macro', num_classes=3)

        self.f1 = torchmetrics.F1Score(task="multiclass", num_classes=3)
        self.confmat = torchmetrics.ConfusionMatrix(task="multiclass", num_classes=3)
        self.roc = torchmetrics.ROC(task='multilabel', num_labels=3)

    def forward(self, x):
        x = torch.unsqueeze(x, 1)
        x = x.permute((0,1,3,2))
        x = torch.cat([net(x) for net in self.inception1], 1)  # concat
        x = self.avg_pool1(x)
        x = torch.cat([net(x) for net in self.inception2], 1)
        x = self.avg_pool2(x)
        x = self.output(x)
        x = torch.flatten(x, 1)
        print(x.shape)
        x = self.dense(x)
        return x

    def configure_optimizers(self):
        optimizer = Adam(self.parameters(), lr=1e-3)
        return optimizer
  
    def training_step(self, batch, batch_idx):
        data, label_n = batch
        
        # get predictions
        output = self(data)

        # convert for loss calculation
        label = one_hot(label_n, num_classes=3)
    
        # calculate loss
        loss = binary_cross_entropy(output, label.to(torch.float32))
        self.log("train_loss", loss, on_step=True)

        # calculate f1 score
        self.f1.update(output, label_n)
        self.log('train_f1', self.f1)

        self.confmat.update(output, label_n)

        return loss

    def test_step(self, batch, batch_idx):
        data, label_n = batch
        
        # get predictions
        output = self(data)
        
        # convert for loss calculation
        label = one_hot(label_n, num_classes=3)
    
        # calculate loss
        loss = binary_cross_entropy(output, label.to(torch.float32))
        self.log("test_loss", loss)
                
        # calculate accuracy
        self.accuracy_test.update(output, label_n)
        self.log('test_acc', self.accuracy_test)

        # calculate f1 score
        self.f1.update(output, label_n)
        self.log('train_f1', self.f1)
    
    def validation_step(self, batch, batch_idx):
        data, label_n = batch
        
        # get predictions
        output = self(data)
        
        # convert for loss calculation
        label = one_hot(label_n, num_classes=3)
    
        # calculate loss
        loss = binary_cross_entropy(output, label.to(torch.float32))
        self.log("validation_loss", loss)
                
        # calculate accuracy
        self.accuracy_val.update(output, label_n)
        self.log('validation_acc', self.accuracy_val)

        # calculate f1 score
        self.f1.update(output, label_n)
        self.log('train_f1', self.f1, on_epoch=True)

    def on_train_epoch_end(self):
        cm = self.confmat.compute().detach().cpu().numpy()

        import seaborn as sn
        import pandas as pd
        import matplotlib
        matplotlib.use('agg')
        import matplotlib.pyplot as plt

        fig, ax1 = plt.subplots(1)
        df_cm = pd.DataFrame(cm, index = [i for i in "012"],
                          columns = [i for i in "012"])
        sn.heatmap(df_cm, annot=True, ax=ax1)

        # add the confusion matrix to TensorBoard
        self.logger.experiment.add_figure("Confusion Matrix", fig, self.current_epoch)


In [11]:
from torchsummary import summary

inception = Inception(
    input_time = 4000,
    fs = pp.DOWNSAMPLED_FREQ,
    ncha = 16,
    n_classes=dataset.n_classes,
)
# summary(inception.cuda(), (16, 2048))

In [12]:
from pytorch_lightning.loggers import TensorBoardLogger

logger = TensorBoardLogger("tb_logs", name="eeg_inception")

trainer = pl.Trainer(max_epochs=100, logger=logger, log_every_n_steps=10)
trainer.fit(inception, train_loader, val_loader)
trainer.test(inception, test_loader)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

   | Name           | Type                      | Params
--------------------------------------------------------------
0  | inception1     | ModuleList                | 4.5 K 
1  | avg_pool1      | AvgPool2d                 | 0     
2  | inception2     | ModuleList                | 43.1 K
3  | avg_pool2      | AvgPool2d                 | 0     
4  | output         | Sequential                | 2.6 K 
5  | dense          | Sequential                | 1.5 K 
6  | accuracy_train | MulticlassPrecision       | 0     
7  | accuracy_test  | MulticlassPrecision       | 0     
8  | accuracy_val   | MulticlassPrecision       | 0     
9  | f1             | MulticlassF1Score         | 0     
10 | confmat        | MulticlassConfusionMatrix | 0     
11 | roc            | MultilabelROC             | 0     
-

Sanity Checking: 0it [00:00, ?it/s]

torch.Size([32, 384])


../aten/src/ATen/native/cuda/ScatterGatherKernel.cu:365: operator(): block: [0,0,0], thread: [6,0,0] Assertion `idx_dim >= 0 && idx_dim < index_size && "index out of bounds"` failed.
../aten/src/ATen/native/cuda/ScatterGatherKernel.cu:365: operator(): block: [0,0,0], thread: [8,0,0] Assertion `idx_dim >= 0 && idx_dim < index_size && "index out of bounds"` failed.
../aten/src/ATen/native/cuda/ScatterGatherKernel.cu:365: operator(): block: [0,0,0], thread: [11,0,0] Assertion `idx_dim >= 0 && idx_dim < index_size && "index out of bounds"` failed.
../aten/src/ATen/native/cuda/ScatterGatherKernel.cu:365: operator(): block: [0,0,0], thread: [17,0,0] Assertion `idx_dim >= 0 && idx_dim < index_size && "index out of bounds"` failed.
../aten/src/ATen/native/cuda/ScatterGatherKernel.cu:365: operator(): block: [0,0,0], thread: [19,0,0] Assertion `idx_dim >= 0 && idx_dim < index_size && "index out of bounds"` failed.
../aten/src/ATen/native/cuda/ScatterGatherKernel.cu:365: operator(): block: [0,0,0

RuntimeError: CUDA error: device-side assert triggered
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1.
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.


## EEGNet

In [48]:
import torch
from torch import nn
from torch import optim
from torch.nn.functional import binary_cross_entropy, one_hot
from torch.optim import Adam
import pytorch_lightning as pl
import torchmetrics

class EEGNet(pl.LightningModule):
    def __init__(self):
        super(EEGNet, self).__init__()
        self.T = 512
        
        # Layer 1
        self.conv1 = nn.Conv2d(1, 16, (1, 16), padding = 0)
        self.batchnorm1 = nn.BatchNorm2d(16, False)
        
        # Layer 2
        self.padding1 = nn.ZeroPad2d((16, 17, 0, 1))
        self.conv2 = nn.Conv2d(1, 4, (2, 32))
        self.batchnorm2 = nn.BatchNorm2d(4, False)
        self.pooling2 = nn.MaxPool2d(2, 4)
        
        # Layer 3
        self.padding2 = nn.ZeroPad2d((2, 1, 4, 3))
        self.conv3 = nn.Conv2d(4, 4, (8, 4))
        self.batchnorm3 = nn.BatchNorm2d(4, False)
        self.pooling3 = nn.MaxPool2d((2, 4))
        
        self.elu = nn.ELU()
        self.dropout = nn.Dropout(p=0.25)
        
        self.softmax = nn.Softmax(-1)
        
        # FC Layer
        # NOTE: This dimension will depend on the number of timestamps per sample in your data.
        # I have 120 timepoints. 
        self.fc1 = nn.Linear(1024, 3)
        
        self.accuracy_train = torchmetrics.Precision(task="multiclass", average='macro', num_classes=3)
        self.accuracy_test = torchmetrics.Precision(task="multiclass", average='macro', num_classes=3)
        self.accuracy_val = torchmetrics.Precision(task="multiclass", average='macro', num_classes=3)

        self.f1 = torchmetrics.F1Score(task="multiclass", num_classes=3)
        self.confmat = torchmetrics.ConfusionMatrix(task="multiclass", num_classes=3)
        self.roc = torchmetrics.ROC(task='multilabel', num_labels=3)

    def forward(self, x):
        # Layer 1
        x = x.permute(0, 2, 1)
        x = torch.unsqueeze(x, 1)
        x = self.elu(self.conv1(x))
        x = self.batchnorm1(x)
        x = self.dropout(x)
        x = x.permute(0, 3, 1, 2)
        
        # Layer 2
        x = self.padding1(x)
        x = self.elu(self.conv2(x))
        x = self.batchnorm2(x)
        x = self.dropout(x)
        x = self.pooling2(x)
        
        # Layer 3
        x = self.padding2(x)
        x = self.elu(self.conv3(x))
        x = self.batchnorm3(x)
        x = self.dropout(x)
        x = self.pooling3(x)
        
        # FC Layer
        x = x.flatten(start_dim=1)
        x = self.fc1(x)
        x = self.softmax(x)
        
        return x

    def configure_optimizers(self):
        optimizer = Adam(self.parameters(), lr=1e-3)
        return optimizer
    
    def training_step(self, batch, batch_idx):
        data, label_n = batch
        
        # get predictions
        output = self(data)

        # convert for loss calculation
        label = one_hot(label_n, num_classes=3)
    
        # calculate loss
        loss = binary_cross_entropy(output, label.to(torch.float32))
        self.log("train_loss", loss, on_step=True)

        # calculate f1 score
        self.f1.update(output, label_n)
        self.log('train_f1', self.f1)

        self.confmat.update(output, label_n)

        return loss

    def test_step(self, batch, batch_idx):
        data, label_n = batch
        
        # get predictions
        output = self(data)
        
        # convert for loss calculation
        label = one_hot(label_n, num_classes=3)
    
        # calculate loss
        loss = binary_cross_entropy(output, label.to(torch.float32))
        self.log("test_loss", loss)
                
        # calculate accuracy
        self.accuracy_test.update(output, label_n)
        self.log('test_acc', self.accuracy_test)

        # calculate f1 score
        self.f1.update(output, label_n)
        self.log('train_f1', self.f1)
    
    def validation_step(self, batch, batch_idx):
        data, label_n = batch
        
        # get predictions
        output = self(data)
        
        # convert for loss calculation
        label = one_hot(label_n, num_classes=3)
    
        # calculate loss
        loss = binary_cross_entropy(output, label.to(torch.float32))
        self.log("validation_loss", loss)
                
        # calculate accuracy
        self.accuracy_val.update(output, label_n)
        self.log('validation_acc', self.accuracy_val)

        # calculate f1 score
        self.f1.update(output, label_n)
        self.log('train_f1', self.f1, on_epoch=True)

    def on_train_epoch_end(self):
        cm = self.confmat.compute().detach().cpu().numpy()

        import seaborn as sn
        import pandas as pd
        import matplotlib
        matplotlib.use('agg')
        import matplotlib.pyplot as plt

        fig, ax1 = plt.subplots(1)
        df_cm = pd.DataFrame(cm, index = [i for i in "012"],
                          columns = [i for i in "012"])
        sn.heatmap(df_cm, annot=True, ax=ax1)

        # add the confusion matrix to TensorBoard
        self.logger.experiment.add_figure("Confusion Matrix", fig, self.current_epoch)

In [49]:
eegnet = EEGNet()

In [50]:
from pytorch_lightning.loggers import TensorBoardLogger

logger = TensorBoardLogger("tb_logs", name="eeg_net")

trainer = pl.Trainer(max_epochs=100, logger=logger, log_every_n_steps=10)
trainer.fit(eegnet, train_loader, val_loader)
trainer.test(eegnet, test_loader)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

   | Name           | Type                      | Params
--------------------------------------------------------------
0  | conv1          | Conv2d                    | 272   
1  | batchnorm1     | BatchNorm2d               | 32    
2  | padding1       | ZeroPad2d                 | 0     
3  | conv2          | Conv2d                    | 260   
4  | batchnorm2     | BatchNorm2d               | 8     
5  | pooling2       | MaxPool2d                 | 0     
6  | padding2       | ZeroPad2d                 | 0     
7  | conv3          | Conv2d                    | 516   
8  | batchnorm3     | BatchNorm2d               | 8     
9  | pooling3       | MaxPool2d                 | 0     
10 | elu            | ELU                       | 0     
11 | dropout        | Dropout                   | 0     
1

Sanity Checking: 0it [00:00, ?it/s]

Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

`Trainer.fit` stopped: `max_epochs=100` reached.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Testing: 0it [00:00, ?it/s]

[{'test_loss': 0.3914906680583954,
  'test_acc': 0.8610930442810059,
  'train_f1': 0.8578431606292725}]

## Transformer

In [6]:
import os
import numpy as np
import math
import random
import time
import scipy.io

from torch.utils.data import DataLoader
from torch.autograd import Variable
from torchsummary import summary

import torch
import torch.nn.functional as F

from torch import nn
from torch import Tensor

from einops import rearrange, reduce, repeat
from einops.layers.torch import Rearrange, Reduce

import matplotlib.pyplot as plt
from torch.backends import cudnn


class PatchEmbedding(pl.LightningModule):
    def __init__(self, emb_size):
        # self.patch_size = patch_size
        super().__init__()
        self.projection = nn.Sequential(
            nn.Conv2d(1, 2, (1, 51), (1, 1)),
            nn.BatchNorm2d(2),
            nn.LeakyReLU(0.2),
            nn.Conv2d(2, emb_size, (16, 5), stride=(1, 5)),
            Rearrange('b e (h) (w) -> b (h w) e'),
        )
        self.cls_token = nn.Parameter(torch.randn(1, 1, emb_size))
        # self.positions = nn.Parameter(torch.randn((100 + 1, emb_size)))
        # self.positions = nn.Parameter(torch.randn((2200 + 1, emb_size)))

    def forward(self, x: Tensor) -> Tensor:
        b, _, _, _ = x.shape
        x = self.projection(x)
        cls_tokens = repeat(self.cls_token, '() n e -> b n e', b=b)

        # position
        # x += self.positions
        return x


class MultiHeadAttention(pl.LightningModule):
    def __init__(self, emb_size, num_heads, dropout):
        super().__init__()
        self.emb_size = emb_size
        self.num_heads = num_heads
        self.keys = nn.Linear(emb_size, emb_size)
        self.queries = nn.Linear(emb_size, emb_size)
        self.values = nn.Linear(emb_size, emb_size)
        self.att_drop = nn.Dropout(dropout)
        self.projection = nn.Linear(emb_size, emb_size)

    def forward(self, x: Tensor, mask: Tensor = None) -> Tensor:
        queries = rearrange(self.queries(x), "b n (h d) -> b h n d", h=self.num_heads)
        keys = rearrange(self.keys(x), "b n (h d) -> b h n d", h=self.num_heads)
        values = rearrange(self.values(x), "b n (h d) -> b h n d", h=self.num_heads)
        energy = torch.einsum('bhqd, bhkd -> bhqk', queries, keys)  # batch, num_heads, query_len, key_len
        if mask is not None:
            fill_value = torch.finfo(torch.float32).min
            energy.mask_fill(~mask, fill_value)

        scaling = self.emb_size ** (1 / 2)
        att = F.softmax(energy / scaling, dim=-1)
        att = self.att_drop(att)
        out = torch.einsum('bhal, bhlv -> bhav ', att, values)
        out = rearrange(out, "b h n d -> b n (h d)")
        out = self.projection(out)
        return out


class ResidualAdd(pl.LightningModule):
    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x, **kwargs):
        res = x
        x = self.fn(x, **kwargs)
        x += res
        return x


class FeedForwardBlock(nn.Sequential):
    def __init__(self, emb_size, expansion, drop_p):
        super().__init__(
            nn.Linear(emb_size, expansion * emb_size),
            nn.GELU(),
            nn.Dropout(drop_p),
            nn.Linear(expansion * emb_size, emb_size),
        )

class GELU(pl.LightningModule):
    def forward(self, input: Tensor) -> Tensor:
        return input*0.5*(1.0+torch.erf(input/math.sqrt(2.0)))


class TransformerEncoderBlock(nn.Sequential):
    def __init__(self,
                 emb_size,
                 num_heads=5,
                 drop_p=0.5,
                 forward_expansion=4,
                 forward_drop_p=0.5):
        super().__init__(
            ResidualAdd(nn.Sequential(
                nn.LayerNorm(emb_size),
                MultiHeadAttention(emb_size, num_heads, drop_p),
                nn.Dropout(drop_p)
            )),
            ResidualAdd(nn.Sequential(
                nn.LayerNorm(emb_size),
                FeedForwardBlock(
                    emb_size, expansion=forward_expansion, drop_p=forward_drop_p),
                nn.Dropout(drop_p)
            )
            ))


class TransformerEncoder(nn.Sequential):
    def __init__(self, depth, emb_size):
        super().__init__(*[TransformerEncoderBlock(emb_size) for _ in range(depth)])


class ClassificationHead(nn.Sequential):
    def __init__(self, emb_size, n_classes):
        super().__init__()
        self.clshead = nn.Sequential(
            Reduce('b n e -> b e', reduction='mean'),
            nn.LayerNorm(emb_size),
            nn.Linear(emb_size, n_classes)
        )

    def forward(self, x):
        out = self.clshead(x)
        return out
        # return x, out



class ViT(nn.Sequential):
    def __init__(self, emb_size=10, depth=3, n_classes=3, **kwargs):
        super().__init__(
            ResidualAdd(
                nn.Sequential(
                    nn.LayerNorm(2048),
                    channel_attention(),
                    nn.Dropout(0.5),
                )
            ),

            PatchEmbedding(emb_size),
            TransformerEncoder(depth, emb_size),
            ClassificationHead(emb_size, n_classes)
        )


class channel_attention(pl.LightningModule):
    def __init__(self, sequence_num=512, inter=30):
        super(channel_attention, self).__init__()
        self.sequence_num = sequence_num
        self.inter = inter
        self.extract_sequence = int(self.sequence_num / self.inter)  # You could choose to do that for less computation

        self.query = nn.Sequential(
            nn.Linear(16, 16),
            nn.LayerNorm(16),  # also may introduce improvement to a certain extent
            nn.Dropout(0.3)
        )
        self.key = nn.Sequential(
            nn.Linear(16, 16),
            nn.LayerNorm(16),
            nn.Dropout(0.3)
        )

        self.projection = nn.Sequential(
            nn.Linear(16, 16),
            nn.LayerNorm(16),
            nn.Dropout(0.3),
        )

        self.drop_out = nn.Dropout(0)
        self.pooling = nn.AvgPool2d(kernel_size=(1, self.inter), stride=(1, self.inter))

        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.xavier_normal_(m.weight)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0.0)

    def forward(self, x):
        temp = rearrange(x, 'b o c s->b o s c')
        temp_query = rearrange(self.query(temp), 'b o s c -> b o c s')
        temp_key = rearrange(self.key(temp), 'b o s c -> b o c s')

        channel_query = self.pooling(temp_query)
        channel_key = self.pooling(temp_key)

        scaling = self.extract_sequence ** (1 / 2)

        channel_atten = torch.einsum('b o c s, b o m s -> b o c m', channel_query, channel_key) / scaling

        channel_atten_score = F.softmax(channel_atten, dim=-1)
        channel_atten_score = self.drop_out(channel_atten_score)

        out = torch.einsum('b o c s, b o c m -> b o c s', x, channel_atten_score)
        '''
        projections after or before multiplying with attention score are almost the same.
        '''
        out = rearrange(out, 'b o c s -> b o s c')
        out = self.projection(out)
        out = rearrange(out, 'b o s c -> b o c s')
        return out


class Trans(pl.LightningModule):
    def __init__(self):
        super(Trans, self).__init__()
        self.model = ViT()

        self.accuracy_train = torchmetrics.Precision(task="multiclass", average='macro', num_classes=3)
        self.accuracy_test = torchmetrics.Precision(task="multiclass", average='macro', num_classes=3)
        self.accuracy_val = torchmetrics.Precision(task="multiclass", average='macro', num_classes=3)

        self.f1 = torchmetrics.F1Score(task="multiclass", num_classes=3)
        self.confmat = torchmetrics.ConfusionMatrix(task="multiclass", num_classes=3)
        self.roc = torchmetrics.ROC(task='multilabel', num_labels=3)
        
    def forward(self, x):
        x = torch.unsqueeze(x, 1)
        x = self.model(x)
        out = F.softmax(x, dim=-1)
        return out

    def configure_optimizers(self):
        optimizer = Adam(self.parameters(), lr=1e-3)
        return optimizer
  
    def training_step(self, batch, batch_idx):
        data, label_n = batch
        
        # get predictions
        output = self(data)

        # convert for loss calculation
        label = one_hot(label_n, num_classes=3)
    
        # calculate loss
        loss = binary_cross_entropy(output, label.to(torch.float32))
        self.log("train_loss", loss, on_step=True)

        # calculate f1 score
        self.f1.update(output, label_n)
        self.log('train_f1', self.f1)

        self.confmat.update(output, label_n)

        return loss

    def test_step(self, batch, batch_idx):
        data, label_n = batch
        
        # get predictions
        output = self(data)
        
        # convert for loss calculation
        label = one_hot(label_n, num_classes=3)
    
        # calculate loss
        loss = binary_cross_entropy(output, label.to(torch.float32))
        self.log("test_loss", loss)
                
        # calculate accuracy
        self.accuracy_test.update(output, label_n)
        self.log('test_acc', self.accuracy_test)

        # calculate f1 score
        self.f1.update(output, label_n)
        self.log('train_f1', self.f1)
    
    def validation_step(self, batch, batch_idx):
        data, label_n = batch
        
        # get predictions
        output = self(data)
        
        # convert for loss calculation
        label = one_hot(label_n, num_classes=3)
    
        # calculate loss
        loss = binary_cross_entropy(output, label.to(torch.float32))
        self.log("validation_loss", loss)
                
        # calculate accuracy
        self.accuracy_val.update(output, label_n)
        self.log('validation_acc', self.accuracy_val)

        # calculate f1 score
        self.f1.update(output, label_n)
        self.log('train_f1', self.f1, on_epoch=True)

    def on_train_epoch_end(self):
        cm = self.confmat.compute().detach().cpu().numpy()

        import seaborn as sn
        import pandas as pd
        import matplotlib
        matplotlib.use('agg')
        import matplotlib.pyplot as plt

        fig, ax1 = plt.subplots(1)
        df_cm = pd.DataFrame(cm, index = [i for i in "012"],
                          columns = [i for i in "012"])
        sn.heatmap(df_cm, annot=True, ax=ax1)

        # add the confusion matrix to TensorBoard
        self.logger.experiment.add_figure("Confusion Matrix", fig, self.current_epoch)

NameError: name 'pl' is not defined

In [97]:
vit = Trans()

In [98]:
from pytorch_lightning.loggers import TensorBoardLogger

logger = TensorBoardLogger("tb_logs", name="eeg_vit")

trainer = pl.Trainer(max_epochs=100, logger=logger, log_every_n_steps=10)
trainer.fit(vit, train_loader, val_loader)
trainer.test(vit, test_loader)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name           | Type                      | Params
-------------------------------------------------------------
0 | model          | ViT                       | 10.8 K
1 | accuracy_train | MulticlassPrecision       | 0     
2 | accuracy_test  | MulticlassPrecision       | 0     
3 | accuracy_val   | MulticlassPrecision       | 0     
4 | f1             | MulticlassF1Score         | 0     
5 | confmat        | MulticlassConfusionMatrix | 0     
6 | roc            | MultilabelROC             | 0     
-------------------------------------------------------------
10.8 K    Trainable params
0         Non-trainable params
10.8 K    Total params
0.043     Total estimated model params size (MB)


Sanity Checking: 0it [00:00, ?it/s]

Training: 0it [00:00, ?it/s]

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Testing: 0it [00:00, ?it/s]

[{'test_loss': 0.6379259824752808,
  'test_acc': 0.11986863613128662,
  'train_f1': 0.3499999940395355}]

## DeepConvNet

In [77]:
class DeepConvNet(pl.LightningModule):
    def __init__(self, n_output):
        super(DeepConvNet, self).__init__()
        self.conv1 = nn.Sequential(
            nn.Conv2d(1, 25, kernel_size=(1,10)),
            nn.Conv2d(25, 25, kernel_size=(16,25)),
            nn.BatchNorm2d(25),
            nn.ELU(alpha=0.4),
            nn.MaxPool2d(kernel_size=(1,3), stride=(1,3)),
            nn.Dropout(p=0.4)
        )
        
        self.conv2 = nn.Sequential(
            nn.Conv2d(25, 50, kernel_size=(1, 10), groups=25),
            nn.BatchNorm2d(50),
            nn.ELU(alpha=0.4),
            nn.MaxPool2d(kernel_size=(1,3), stride=(1,3)),
            nn.Dropout(p=0.4)
        )
        
        self.conv3 = nn.Sequential(
            nn.Conv2d(50, 100, kernel_size=(1, 10), groups=50),
            nn.BatchNorm2d(100),
            nn.ELU(alpha=0.4),
            nn.MaxPool2d(kernel_size=(1,3), stride=(1,3)),
            nn.Dropout(p=0.4)
        )
        
        self.conv4 = nn.Sequential(
            nn.Conv2d(100, 200, kernel_size=(1, 10), groups=100),
            nn.BatchNorm2d(200),
            nn.ELU(alpha=0.4),
            nn.MaxPool2d(kernel_size=(1,3), stride=(1,3)),
            nn.Dropout(p=0.4)
        )
        
        self.fc = nn.Sequential(
            nn.Flatten(),
            nn.Linear(4000, n_output, bias=True)
        )

        self.accuracy_train = torchmetrics.Precision(task="multiclass", average='macro', num_classes=3)
        self.accuracy_test = torchmetrics.Precision(task="multiclass", average='macro', num_classes=3)
        self.accuracy_val = torchmetrics.Precision(task="multiclass", average='macro', num_classes=3)

        self.f1 = torchmetrics.F1Score(task="multiclass", num_classes=3)
        self.confmat = torchmetrics.ConfusionMatrix(task="multiclass", num_classes=3)
        self.roc = torchmetrics.ROC(task='multilabel', num_labels=3)

    def forward(self, x):
        x = torch.unsqueeze(x, 1)
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.conv3(x)
        x = self.conv4(x)
        x = self.fc(x)
        x = F.softmax(x, dim=-1)
        return x

    def configure_optimizers(self):
        optimizer = Adam(self.parameters(), lr=1e-3)
        return optimizer
  
    def training_step(self, batch, batch_idx):
        data, label_n = batch
        
        # get predictions
        output = self(data)

        # convert for loss calculation
        label = one_hot(label_n, num_classes=3)
    
        # calculate loss
        loss = binary_cross_entropy(output, label.to(torch.float32))
        self.log("train_loss", loss, on_step=True)

        # calculate f1 score
        self.f1.update(output, label_n)
        self.log('train_f1', self.f1)

        return loss

    def test_step(self, batch, batch_idx):
        data, label_n = batch
        
        # get predictions
        output = self(data)
        
        # convert for loss calculation
        label = one_hot(label_n, num_classes=3)
    
        # calculate loss
        loss = binary_cross_entropy(output, label.to(torch.float32))
        self.log("test_loss", loss)
                
        # calculate accuracy
        self.accuracy_test.update(output, label_n)
        self.log('test_acc', self.accuracy_test)

        # calculate f1 score
        self.f1.update(output, label_n)
        self.log('train_f1', self.f1)
        
        self.confmat.update(output, label_n)
        
        cm = self.confmat.compute().detach().cpu().numpy()

        import seaborn as sn
        import pandas as pd
        import matplotlib
        matplotlib.use('agg')
        import matplotlib.pyplot as plt

        fig, ax1 = plt.subplots(1)
        df_cm = pd.DataFrame(cm, index = [i for i in "012"],
                          columns = [i for i in "012"])
        sn.heatmap(df_cm, annot=True, ax=ax1)

        # add the confusion matrix to TensorBoard
        self.logger.experiment.add_figure("Confusion Matrix", fig)
    
    def validation_step(self, batch, batch_idx):
        data, label_n = batch
        
        # get predictions
        output = self(data)
        
        # convert for loss calculation
        label = one_hot(label_n, num_classes=3)
    
        # calculate loss
        loss = binary_cross_entropy(output, label.to(torch.float32))
        self.log("validation_loss", loss)
                
        # calculate accuracy
        self.accuracy_val.update(output, label_n)
        self.log('validation_acc', self.accuracy_val)

        # calculate f1 score
        self.f1.update(output, label_n)
        self.log('train_f1', self.f1, on_epoch=True)

In [78]:
dcn = DeepConvNet(n_output=3)

In [79]:
from pytorch_lightning.loggers import TensorBoardLogger

logger = TensorBoardLogger("tb_logs", name="eeg_dcn")

trainer = pl.Trainer(max_epochs=100,logger=logger, log_every_n_steps=10)
trainer.fit(dcn, train_loader, val_loader)
trainer.test(dcn, test_loader)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

   | Name           | Type                      | Params
--------------------------------------------------------------
0  | conv1          | Sequential                | 250 K 
1  | conv2          | Sequential                | 650   
2  | conv3          | Sequential                | 1.3 K 
3  | conv4          | Sequential                | 2.6 K 
4  | fc             | Sequential                | 12.0 K
5  | accuracy_train | MulticlassPrecision       | 0     
6  | accuracy_test  | MulticlassPrecision       | 0     
7  | accuracy_val   | MulticlassPrecision       | 0     
8  | f1             | MulticlassF1Score         | 0     
9  | confmat        | MulticlassConfusionMatrix | 0     
10 | roc            | MultilabelROC             | 0     
----------------------------------------------------------

Sanity Checking: 0it [00:00, ?it/s]

Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

`Trainer.fit` stopped: `max_epochs=100` reached.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Testing: 0it [00:00, ?it/s]

[{'test_loss': 0.23422902822494507,
  'test_acc': 0.8558967113494873,
  'train_f1': 0.8529411554336548}]