In [1]:
import random
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torchsummary import summary


from torch.utils.data import DataLoader
from torch.utils.data import ConcatDataset

from src.dataset.MI_dataset_single_subject import MI_Dataset as MI_Dataset_single_subject

from config.default import cfg


from models.conditioned_eegnet import ConditionedEEGNet

from utils.eval import accuracy
from utils.model import print_parameters

%load_ext autoreload
%autoreload 2


In [2]:
# Experiment configuration: the subjects to include and how each subject's
# recording runs are divided between training and evaluation.
subjects = list(range(1, 10))

# Subject 4 is the only one split differently (runs 0-1 for training, run 2 for
# testing); every other subject trains on runs 0-4 and is tested on run 5.
# NOTE(review): presumably subject 4 has fewer usable runs -- confirm against the dataset.
train_runs = {sid: [0, 1] if sid == 4 else [0, 1, 2, 3, 4] for sid in subjects}
test_runs = {sid: [2] if sid == 4 else [5] for sid in subjects}

# Mini-batch size used by both the train and the test DataLoader.
batch_size = 64

In [3]:
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
# Bare expression so the notebook displays the selected device.
device

device(type='cpu')

In [4]:
# Seed every random number generator this notebook relies on (Python, NumPy,
# PyTorch) with the same value so repeated runs draw the same random numbers.
for _seed_rng in (random.seed, np.random.seed, torch.manual_seed):
    _seed_rng(42)

if torch.cuda.is_available():
    # On GPU also seed the CUDA generator, ask cuDNN for deterministic
    # algorithms and turn off its benchmark-based algorithm selection.
    torch.cuda.manual_seed(42)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

In [5]:
# Build one training dataset per subject (restricted to that subject's training
# runs) and merge them into a single dataset.
train_datasets = [
    MI_Dataset_single_subject(subject, train_runs[subject], return_subject_id=True, device=device, verbose=False)
    for subject in subjects
]

# The model's input geometry is read from the last subject's dataset.
# NOTE(review): this assumes every subject shares the same channels/time_steps -- confirm.
dataset = train_datasets[-1]
channels, time_steps = dataset.channels, dataset.time_steps

train_dataset = ConcatDataset(train_datasets)

# Training batches are reshuffled on every pass over the data.
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
print(f"Train dataset: {len(train_dataset)} samples")

Train dataset: 2016 samples


In [6]:
# Build one held-out dataset per subject (restricted to that subject's test
# runs) and merge them into a single evaluation dataset.
test_datasets = [
    MI_Dataset_single_subject(subject, test_runs[subject], return_subject_id=True, device=device, verbose=False)
    for subject in subjects
]
test_dataset = ConcatDataset(test_datasets)

# Evaluation batches keep their order (no shuffling).
test_dataloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
print(f"Test dataset: {len(test_dataset)} samples")

Test dataset: 432 samples


In [7]:
# Sanity check: report the dataset sizes and show the second element of the
# feature tuple for the first training batch.
# NOTE(review): with return_subject_id=True that element looks like the
# per-sample subject index -- confirm against MI_Dataset.
print(f"Train dataset: {len(train_dataset)} samples")
print(f"Test dataset: {len(test_dataset)} samples")

first_batch = next(iter(train_dataloader), None)
if first_batch is not None:
    first_features, _first_labels = first_batch
    print(first_features[1])

Train dataset: 2016 samples
Test dataset: 432 samples
tensor([0, 2, 0, 4, 2, 0, 3, 6, 5, 5, 8, 6, 7, 0, 4, 2, 4, 8, 0, 1, 4, 8, 5, 4,
        5, 4, 4, 5, 0, 5, 1, 4, 7, 8, 7, 3, 2, 2, 8, 1, 6, 0, 8, 2, 1, 0, 0, 3,
        3, 4, 5, 5, 5, 7, 8, 2, 7, 1, 2, 4, 0, 4, 1, 7])


In [8]:
# Instantiate the subject-conditioned EEGNet. The input geometry (channels,
# time_steps) was read from the training datasets; the output has 4 classes.
model = ConditionedEEGNet(
    num_subjects=len(subjects),
    channels=channels,
    samples=time_steps,
    num_classes=4,
)
# Move the model to the selected device before any data is fed through it.
model.to(device)
# Report the model's parameters and their counts.
print_parameters(model)

eeg_processor.conv1.weight.... --> 1024
eeg_processor.bn1.weight...... --> 16
eeg_processor.bn1.bias........ --> 16
eeg_processor.dw_conv1.weight. --> 704
eeg_processor.bn2.weight...... --> 32
eeg_processor.bn2.bias........ --> 32
eeg_processor.sep_conv1.weight --> 512
eeg_processor.conv2.weight.... --> 1024
eeg_processor.bn3.weight...... --> 32
eeg_processor.bn3.bias........ --> 32
subject_processor.fn1.weight.. --> 144
subject_processor.fn1.bias.... --> 16
query.weight.................. --> 12288
key.weight.................... --> 512
value.weight.................. --> 12288
fn1.weight.................... --> 4096
fn1.bias...................... --> 128
fn2.weight.................... --> 512
fn2.bias...................... --> 4

Total Parameter Count:........ --> 33412


In [9]:
# Loss and optimiser. Learning rate, weight decay and the number of epochs are
# all read from the shared config.
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=cfg['train']['learning_rate'], weight_decay=cfg['train']['weight_decay'])

n_epochs = cfg['train']['n_epochs']
# Defined before the loop so the final report still works when n_epochs is 0.
epoch_loss = 0.0

# Training loop
for epoch in range(n_epochs):
    # Put the model (back) into training mode at the start of every epoch.
    # The periodic accuracy() calls below are opaque from this notebook; if
    # they switch the model to eval mode, the BatchNorm layers would otherwise
    # stay in inference mode for all following epochs.
    model.train()
    # Sum (not mean) of the per-batch losses over this epoch.
    epoch_loss = 0.0

    for batch_features, batch_labels in train_dataloader:
        optimizer.zero_grad()
        # batch_features is a pair: element 0 and element 1 are passed to the
        # model as its two inputs (EEG data and subject id, presumably -- confirm).
        outputs = model(batch_features[0], batch_features[1])
        loss = criterion(outputs, batch_labels)
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item()

    # Report progress every 10th epoch (epochs 10, 20, ...).
    if epoch % 10 == 9:
        train_accuracy = accuracy(model, train_dataloader)
        test_accuracy = accuracy(model, test_dataloader)
        print(f"Epoch {epoch + 1}/{n_epochs}, Loss: {epoch_loss}, Train accuracy: {train_accuracy:.2f}%, Test accuracy: {test_accuracy:.2f}%")

# Final summary after the last epoch.
print("#"*50)
print(f'Final_loss: {epoch_loss}')
print(f'Final train accuracy: {accuracy(model, train_dataloader):.2f}%')
print(f'Final test accuracy: {accuracy(model, test_dataloader):.2f}%')

Epoch 10/200, Loss: 44.31196594238281, Train accuracy: 26.59%, Test accuracy: 27.55%
