In [None]:
import os, torch
import warnings

import numpy as np
import torch.nn as nn
from matplotlib import pyplot as plt
from sklearn.metrics import f1_score, recall_score, precision_score, accuracy_score, confusion_matrix
from torch.utils.data import DataLoader

from architecture.net import ClassifierNet
from data_preparation.dataset import SignalDataset
from pipeline.train import TrainingPipeline

# EEG (Electroencephalography) Data Modeling

In [None]:
# Settings shared by both dataset splits.
# NOTE(review): `sample_size` presumably is the fraction of segment files
# assigned to the training split, and `t_size` the window length -- confirm
# against SignalDataset.
train_size = 0.8
t_size = 100

# Training split of the EEG n-back recordings.
training_eeg_dataset = SignalDataset(
    "data/samples", task="nback", signal="EEG", t_size=t_size, sample_size=train_size
)

# Evaluation split: built from the same directory, with the segment files
# already used by the training split passed as `excluded`.
eval_eeg_dataset = SignalDataset(
    "data/samples",
    task="nback",
    signal="EEG",
    t_size=t_size,
    excluded=training_eeg_dataset.segment_files,
)

print(f"Number of train samples: {len(training_eeg_dataset)}")
print(f"Number of eval samples: {len(eval_eeg_dataset)}")

In [None]:
# Batch loaders for the two EEG dataset splits.
num_workers = 4
batch_size = 16

# Training batches are reshuffled every epoch.
train_eeg_dataloader = DataLoader(
    training_eeg_dataset,
    num_workers=num_workers,
    batch_size=batch_size,
    shuffle=True,
)

# Evaluation does not need shuffling; iterating in a fixed order keeps
# evaluation passes directly comparable between epochs and between runs.
eval_eeg_dataloader = DataLoader(
    eval_eeg_dataset,
    num_workers=num_workers,
    batch_size=batch_size,
    shuffle=False,
)

In [None]:
# Visual sanity check: fetch one training example and draw every row of the
# squeezed signal as its own line.
sample_signal, label = training_eeg_dataset[12]

print(f"sample shape: {sample_signal.shape}")

fig, ax = plt.subplots(figsize=(20, 5))
for ch in sample_signal.squeeze():
    ax.plot(ch)
ax.set_title("EEG signals")
ax.set_xlabel("Timesteps")
ax.set_ylabel("Channel readings")
plt.show()

In [None]:
# --- Model settings (passed to ClassifierNet) ---
network = "resnet18"
in_channels = 1
# One output class per label name reported by the training dataset.
num_classes = len(training_eeg_dataset.get_label_names())
dropout = 0.0
pretrained_weights = None  # "DEFAULT"
track_grads = True

# --- Optimizer settings (passed to torch.optim.Adam) ---
lr = 1e-3
betas = (0.9, 0.999)
weight_decay = 0.0

# --- Hardware: use the GPU when one is available, otherwise the CPU ---
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

In [None]:
# Classification loss; it does not depend on the model, so build it first.
lossfunc = nn.CrossEntropyLoss()

# Classifier built from the hyperparameters defined in the previous cell.
eeg_classifier = ClassifierNet(
    in_channels,
    num_classes,
    dropout,
    network,
    pretrained_weights,
    track_grads,
)

# Adam over all parameters of the classifier.
optimizer = torch.optim.Adam(
    eeg_classifier.parameters(),
    betas=betas,
    weight_decay=weight_decay,
    lr=lr,
)

# Bundle model, loss, optimizer and target device into the training pipeline.
pipeline = TrainingPipeline(
    eeg_classifier,
    lossfunc,
    optimizer,
    device,
    weight_init=False,
)

In [None]:
# Number of full passes over the training set.
epochs = 10

# Silence warnings only while the loop runs. `catch_warnings()` restores the
# previous filter state on exit, so warnings raised by later cells stay
# visible instead of being suppressed for the rest of the kernel session.
# `warnings` is imported in the notebook's top import cell.
with warnings.catch_warnings():
    warnings.filterwarnings("ignore")

    # Each epoch: one training pass, then one pass over the evaluation set.
    for epoch in range(epochs):
        pipeline.train(train_eeg_dataloader, verbose=True)
        pipeline.evaluate(eval_eeg_dataloader, verbose=True)