In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchsummary import summary


from torch.utils.data import DataLoader

from src.dataset.MI_dataset_all_subjects import MI_Dataset as MI_Dataset_all_subjects
from src.dataset.MI_dataset_single_subject import MI_Dataset as MI_Dataset_single_subject

from config.default import cfg


from models.eegnet import EEGNet

from utils.eval import accuracy

%load_ext autoreload
%autoreload 2


In [2]:
device = 'cpu'# torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device

'cpu'

In [3]:
subject = 1
train_runs = [0,1,2,3,4]
test_runs = [5]


train_dataset = MI_Dataset_single_subject(subject, train_runs, device = device)
test_dataset = MI_Dataset_single_subject(subject, test_runs, device = device)

train_dataloader = DataLoader(train_dataset,  batch_size=cfg['train']['batch_size'], shuffle=True, drop_last=True)
test_dataloader = DataLoader(test_dataset,  batch_size=cfg['train']['batch_size'], shuffle=False, drop_last=True)

In [4]:
print(f"Train dataset: {len(train_dataset)} samples")
print(f"Test dataset: {len(test_dataset)} samples")

for features, label in train_dataloader:
    print(features.shape)
    print(label)
    break
    


Train dataset: 240 samples
Test dataset: 48 samples
torch.Size([48, 22, 401])
tensor([2, 1, 1, 0, 2, 2, 1, 1, 2, 1, 0, 0, 3, 1, 3, 0, 1, 2, 2, 1, 0, 3, 2, 2,
        3, 3, 1, 3, 3, 2, 0, 0, 1, 1, 0, 3, 3, 0, 3, 0, 3, 3, 0, 1, 2, 2, 2, 2])


In [5]:
#model = EEGNetv4(3,cfg["model"]["n_classes"],input_window_samples=641)
#model = EEGNet()
#model = MyModel([64, 25, 401])

Chans = next(iter(train_dataloader))[0][0].shape[0]
Samples = next(iter(train_dataloader))[0][0].shape[1]
model = EEGNet(Chans = Chans, Samples = Samples, nb_classes = 4)
model.to(device)
summary(model, input_size=(5, 10, *next(iter(train_dataloader))[0][0].shape));



Layer (type:depth-idx)                   Param #
├─Conv2d: 1-1                            1,024
├─BatchNorm2d: 1-2                       32
├─Conv2d: 1-3                            704
├─BatchNorm2d: 1-4                       64
├─ELU: 1-5                               --
├─AvgPool2d: 1-6                         --
├─Dropout: 1-7                           --
├─Conv2d: 1-8                            512
├─Conv2d: 1-9                            1,024
├─BatchNorm2d: 1-10                      64
├─AvgPool2d: 1-11                        --
├─Dropout: 1-12                          --
├─Flatten: 1-13                          --
├─Linear: 1-14                           1,540
Total params: 4,964
Trainable params: 4,964
Non-trainable params: 0


In [6]:
# Test forward pass
model(next(iter(train_dataloader))[0]);

In [7]:
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=cfg['train']['learning_rate'], weight_decay=cfg['train']['weight_decay'])

# Training loop
for epoch in range(cfg['train']['n_epochs']):
    epoch_loss = 0.0

    for batch_features, batch_labels in train_dataloader:
        optimizer.zero_grad()
        outputs = model(batch_features)
        loss = criterion(outputs, batch_labels)
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item()

    if epoch % 10 == 9:
        train_accuracy = accuracy(model, train_dataloader)
        test_accuracy = accuracy(model, test_dataloader)
        print(f"Epoch {epoch + 1}/{cfg['train']['n_epochs']}, Loss: {epoch_loss}, Train accuracy: {train_accuracy:.2f}%, Test accuracy: {test_accuracy:.2f}%")

print("#"*50)
print(f'Final_loss: {epoch_loss}')
print(f'Final train accuracy: {accuracy(model, train_dataloader):.2f}%')
print(f'Final test accuracy: {accuracy(model, test_dataloader):.2f}%')

Epoch 10/200, Loss: 6.906140208244324, Train accuracy: 27.08%, Test accuracy: 14.58%
Epoch 20/200, Loss: 6.171535611152649, Train accuracy: 47.92%, Test accuracy: 35.42%
Epoch 30/200, Loss: 4.667204797267914, Train accuracy: 58.75%, Test accuracy: 43.75%
Epoch 40/200, Loss: 4.2737919092178345, Train accuracy: 61.67%, Test accuracy: 45.83%
Epoch 50/200, Loss: 4.071137011051178, Train accuracy: 64.17%, Test accuracy: 56.25%
Epoch 60/200, Loss: 3.8616331219673157, Train accuracy: 68.33%, Test accuracy: 66.67%
Epoch 70/200, Loss: 3.6464752554893494, Train accuracy: 70.42%, Test accuracy: 66.67%
Epoch 80/200, Loss: 3.4538474082946777, Train accuracy: 69.58%, Test accuracy: 72.92%
Epoch 90/200, Loss: 3.3149185180664062, Train accuracy: 69.58%, Test accuracy: 72.92%
Epoch 100/200, Loss: 3.2068055868148804, Train accuracy: 70.00%, Test accuracy: 75.00%
Epoch 110/200, Loss: 3.1474274396896362, Train accuracy: 70.83%, Test accuracy: 75.00%
Epoch 120/200, Loss: 3.0369244813919067, Train accuracy: