In [1]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader

In [2]:
import numpy as np
from torchvision.transforms import Compose
from Utilities.transforms import minmax
from Utilities.dataset import EEGDataset
from torch.utils.data import random_split

transforms = Compose([
    minmax,
])

dataset = EEGDataset("./dataset", transforms)

train_set, test_set, validation_set = random_split(dataset, [0.7, 0.2, 0.1])

train_loader = DataLoader(train_set, batch_size=32, num_workers=4, shuffle=True)
test_loader = DataLoader(test_set, batch_size=32, num_workers=4)
val_loader = DataLoader(validation_set, batch_size=32, num_workers=4)

In [3]:
import torchmetrics
from torch.optim import Adam
from torch.nn.functional import binary_cross_entropy, one_hot
import lightning.pytorch as pl

class SignalCNN(pl.LightningModule):
    channels = 16
    
    def __init__(self):
        super(SignalCNN, self).__init__()
        
        self.features = nn.Sequential(
            nn.Conv1d(
                in_channels = self.channels, 
                out_channels = 32, 
                kernel_size = 11,
                stride = 3,
                dilation = 3
            ),
            nn.ReLU(),
            nn.AvgPool1d(2),
            nn.BatchNorm1d(32),
            nn.Dropout1d(p=0.6),
            
            nn.Conv1d(
                in_channels = 32, 
                out_channels = 64, 
                kernel_size = 5,
                stride = 2,
                dilation = 1
            ),
            nn.ReLU(),
            nn.AvgPool1d(2),
            nn.BatchNorm1d(64),
            nn.Dropout1d(p=0.6),
            
            nn.Conv1d(
                in_channels = 64, 
                out_channels = 128, 
                kernel_size = 3,
                stride = 2,
                dilation = 1
            ),
            nn.ReLU(),
            nn.AvgPool1d(3),
            nn.BatchNorm1d(128),
            nn.Dropout1d(p=0.6),
            
            nn.Conv1d(
                in_channels = 128, 
                out_channels = 128, 
                kernel_size = 1,
                stride = 2,
                dilation = 1
            ),
            nn.ReLU(),
            nn.AvgPool1d(3),
            nn.BatchNorm1d(128),
            nn.Dropout1d(p=0.6),
        )
        
        self.classifier = nn.Sequential(
            nn.Flatten(),
            
            nn.Linear(1152, 512),
            nn.BatchNorm1d(512),
            nn.Dropout1d(p=0.6),
            
            nn.Linear(512, 128),
            nn.BatchNorm1d(128),
            nn.Dropout1d(p=0.6),
            
            nn.Linear(128, 3),
        )

    def forward(self, x):
        x = self.features(x)
        x = self.classifier(x)
        return x
    
    def configure_optimizers(self):
        optimizer = Adam(self.parameters(), lr=1e-3)
        return optimizer
    
    def training_step(self, batch, batch_idx):
        data, label = batch
        
        # get predictions
        output = self(data)
        
        # convert for loss calculation
        output = torch.sigmoid(output)
        label = one_hot(label,num_classes=3).to(torch.float32)
    
        # calculate loss
        loss = binary_cross_entropy(output, label)
        self.log("training_loss", loss)
        
        return loss
    
    def test_step(self, batch, batch_idx):
        data, label = batch
        
        # get predictions
        output = self(data)
        
        # convert for loss calculation
        output = torch.sigmoid(output)
        label = one_hot(label,num_classes=3).to(torch.float32)
    
        # calculate loss
        loss = binary_cross_entropy(output, label)
        self.log("test_loss", loss)
    
    def validation_step(self, batch, batch_idx):
        data, label = batch
        
        # get predictions
        output = self(data)
        
        # convert for loss calculation
        output = torch.sigmoid(output)
        label = one_hot(label,num_classes=3).to(torch.float32)
    
        # calculate loss
        loss = binary_cross_entropy(output, label)
        self.log("val_loss", loss)

In [4]:
model = SignalCNN()

In [5]:
# from torchsummary import summary

# summary(model, (16, 8192))

# Train

In [6]:
trainer = pl.Trainer(max_epochs=100,log_every_n_steps=4)
trainer.fit(model, train_loader, val_loader)
trainer.test(model, test_loader)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name       | Type       | Params
------------------------------------------
0 | features   | Sequential | 57.9 K
1 | classifier | Sequential | 657 K 
------------------------------------------
715 K     Trainable params
0         Non-trainable params
715 K     Total params
2.862     Total estimated model params size (MB)


Epoch 0: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 6/6 [00:00<00:00, 13.53it/s, v_num=5]
Validation: 0it [00:00, ?it/s][A
Validation:   0%|                                                                                                                                                                            | 0/1 [00:00<?, ?it/s][A
Validation DataLoader 0:   0%|                                                                                                                                                               | 0/1 [00:00<?, ?it/s][A
Epoch 0: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 6/6 [00:00<00:00,  8.84it/s, v_num=5][A
Epoch 1: 100%|███████████████████████████████████████████████████████████████████████████████████████████████

`Trainer.fit` stopped: `max_epochs=100` reached.


Epoch 99: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 6/6 [00:00<00:00,  8.43it/s, v_num=5]


LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Testing DataLoader 0: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00, 85.67it/s]


[{'test_loss': 0.6526675224304199}]