In [1]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader

In [2]:
device = torch.device("cuda:0")

In [7]:
import numpy as np
from torchvision.transforms import Compose
from Utilities.transforms import minmax
from Utilities.dataloader import EEGDataLoader

transforms = Compose([
    minmax,
])

dataset = EEGDataLoader("./dataset", transforms)
train_loader = DataLoader(dataset, batch_size=4)

In [10]:
import torchmetrics
from torch.optim import AdamW
from torch.nn.functional import nll_loss
import lightning.pytorch as pl

class SignalCNN(pl.LightningModule):
    channels = 16
    
    def __init__(self):
        super(SignalCNN, self).__init__()
        
        self.features = nn.Sequential(
            nn.Conv1d(
                in_channels = self.channels, 
                out_channels = 32, 
                kernel_size = 11,
                stride = 3,
                dilation = 3
            ),
            nn.ReLU(),
            nn.AvgPool1d(2),
            nn.BatchNorm1d(32),
            nn.Dropout1d(p=0.6),
            
            nn.Conv1d(
                in_channels = 32, 
                out_channels = 64, 
                kernel_size = 5,
                stride = 2,
                dilation = 1
            ),
            nn.ReLU(),
            nn.AvgPool1d(2),
            nn.BatchNorm1d(64),
            nn.Dropout1d(p=0.6),
            
            nn.Conv1d(
                in_channels = 64, 
                out_channels = 128, 
                kernel_size = 3,
                stride = 2,
                dilation = 1
            ),
            nn.ReLU(),
            nn.AvgPool1d(3),
            nn.BatchNorm1d(128),
            nn.Dropout1d(p=0.6),
            
            nn.Conv1d(
                in_channels = 128, 
                out_channels = 128, 
                kernel_size = 1,
                stride = 2,
                dilation = 1
            ),
            nn.ReLU(),
            nn.AvgPool1d(3),
            nn.BatchNorm1d(128),
            nn.Dropout1d(p=0.6),
        )
        
        self.classifier = nn.Sequential(
            nn.Flatten(),
            
            nn.Linear(1152, 512),
            nn.BatchNorm1d(512),
            nn.Dropout1d(p=0.6),
            
            nn.Linear(512, 128),
            nn.BatchNorm1d(128),
            nn.Dropout1d(p=0.6),
            
            nn.Linear(128, 3),
        )
        
        self.val_metrics = torchmetrics.MetricCollection(
            [
                torchmetrics.Accuracy("multiclass", num_classes=3),
                torchmetrics.Precision(task='multiclass', num_classes=3, average='macro'),
                torchmetrics.Recall(task='multiclass', num_classes=3, average='macro'),
                torchmetrics.F1Score(task='multiclass', num_classes=3, average='macro'),
            ]
        )

    def forward(self, x):
        x = self.features(x)
        x = self.classifier(x)
        return x
    
    def configure_optimizers(self):
        return AdamW(self.parameters(), lr=7e-5, weight_decay=8e-2)
    
    def training_step(self, batch, batch_idx):
        data, label = batch
        output = self(data)
        loss = nll_loss(output, label)
        self.log("Training loss", loss, on_step=False, on_epoch=True, prog_bar=True)
        self.accuracy(output, label)
        self.log("Training accuracy", self.accuracy, on_step=True, on_epoch=False, prog_bar=True)
        return loss
    
    def validation_step(self, batch, batch_idx):
        data, label = batch
        output = self(data)
        loss = nll_loss(output, label)
        self.log("Val loss", loss, on_step=False, on_epoch=True, prog_bar=True)
        self.val_metrics.update(output, label.type(torch.int))
        return output, label

In [11]:
model = SignalCNN()
_ = model.to(device)

In [12]:
from torchsummary import summary

summary(model, (16, 8192))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv1d-1             [-1, 32, 2721]           5,664
              ReLU-2             [-1, 32, 2721]               0
         AvgPool1d-3             [-1, 32, 1360]               0
       BatchNorm1d-4             [-1, 32, 1360]              64
         Dropout1d-5             [-1, 32, 1360]               0
            Conv1d-6              [-1, 64, 678]          10,304
              ReLU-7              [-1, 64, 678]               0
         AvgPool1d-8              [-1, 64, 339]               0
       BatchNorm1d-9              [-1, 64, 339]             128
        Dropout1d-10              [-1, 64, 339]               0
           Conv1d-11             [-1, 128, 169]          24,704
             ReLU-12             [-1, 128, 169]               0
        AvgPool1d-13              [-1, 128, 56]               0
      BatchNorm1d-14              [-1, 

# Train

In [13]:
trainer = pl.Trainer(limit_train_batches=100, max_epochs=1)
trainer.fit(model=model, train_dataloaders=train_loader)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
  rank_zero_warn(
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name        | Type             | Params
-------------------------------------------------
0 | features    | Sequential       | 57.9 K
1 | classifier  | Sequential       | 657 K 
2 | val_metrics | MetricCollection | 0     
-------------------------------------------------
715 K     Trainable params
0         Non-trainable params
715 K     Total params
2.862     Total estimated model params size (MB)
  rank_zero_warn(


Epoch 0:   0%|                                                                                                                         | 0/60 [00:00<?, ?it/s]

../aten/src/ATen/native/cuda/Loss.cu:240: nll_loss_forward_reduce_cuda_kernel_2d: block: [0,0,0], thread: [0,0,0] Assertion `t >= 0 && t < n_classes` failed.


RuntimeError: CUDA error: device-side assert triggered
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1.
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.
