In [1]:
%load_ext autoreload
%autoreload 2

In [71]:
from pathlib import Path

import torch
import pandas as pd
import matplotlib.pyplot as plt
import torchvision.transforms as t

from torch.nn import functional as F
from torch.utils.data import DataLoader, TensorDataset
from sklearn.model_selection import train_test_split
from tqdm import tqdm, trange
from torch.optim import Adam
from IPython.display import clear_output


from utils import *

import warnings
warnings.filterwarnings('ignore')

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [3]:
data_path = Path('eeg_data_csv/')

# parameters are acquired during EDA
common_electrodes = ['EEG FP1-R', 'EEG FP2-R', 'EEG F3-R', 'EEG F4-R', 'EEG C3-R', 'EEG C4-R', 
                     'EEG P3-R', 'EEG P4-R', 'EEG O1-R', 'EEG O2-R', 'EEG F7-R', 'EEG F8-R', 
                     'EEG T3-R', 'EEG T4-R', 'EEG T5-R', 'EEG T6-R', 'EEG FPZ-R', 'EEG FZ-R', 
                     'EEG CZ-R', 'EEG PZ-R', 'EEG OZ-R']


min_length = 21500
filter_params = {'l_freq': 1, 'h_freq': 50, 'method': 'iir', 'n_jobs': -1, 'verbose': False}
Fs = 500

In [4]:
def load_data(data_path):
    X, y = [], []

    for csv_file in (data_path/"Health").iterdir():
        table = pd.read_csv(csv_file)[common_electrodes]
        X.append(torch.tensor(table.values.T).float())
        y.append(0)

    for csv_file in (data_path/"MDD").iterdir():
        table = pd.read_csv(csv_file)[common_electrodes]
        X.append(torch.tensor(table.values.T).float())
        y.append(1)

    return X, torch.tensor(y)

In [69]:
X, y = load_data(data_path)
X_train, X_val, y_train, y_val = train_test_split(X, y, stratify=y)

train_set = EEGDataset(X_train, y_train, transform=t.RandomCrop((21, 534)))
eval_set = EEGDataset(X_val, y_val, transform=t.CenterCrop((21, 534)))

In [72]:
def split_eegs(X, y, size):
    n_splits = torch.tensor([(tensor.size(1) - 1) // size for tensor in X])
    splits = []
    for tensor in X:
        splits.extend(tensor.split(size, dim=-1)[:-1])

    return torch.stack(splits), y.repeat_interleave(n_splits)

X_chunks_train, y_chunks_train = split_eegs(X_train, y_train, 534)
X_chunks_val, y_chunks_val = split_eegs(X_val, y_val, 534)

train_set = TensorDataset(X_chunks_train, y_chunks_train)
eval_set = TensorDataset(X_chunks_val, y_chunks_val)

In [87]:
train_loader = DataLoader(train_set, batch_size=64, num_workers=2, shuffle=True)
eval_loader = DataLoader(eval_set, batch_size=64, num_workers=2, shuffle=False)

In [34]:
# model = ShallowConvNet(21, 2).to(device)
# model = DeepConvNet(21, 2).to(device)
# optimizer = Adam(model.parameters(), lr=1e-3)

# train_step = 0
# eval_step = 0

# writer = SummaryWriter(comment="")


# def train(model, dataloader):
#     global train_step
#     model.train()
#     for x_batch, y_batch in tqdm(dataloader, leave=False):
#         optimizer.zero_grad()
#         outputs = model(x_batch.to(device))
#         loss = F.cross_entropy(outputs, y_batch.to(device))
#         writer.add_scalar("train/loss", loss.item(), train_step)
#         train_step += 1
#         loss.backward()
#         optimizer.step()

# @torch.no_grad()
# def validate(model, dataloader):
#     global eval_step
#     n_correct = 0
#     model.eval()
#     for x_batch, y_batch in tqdm(dataloader, leave=False):
#         outputs = model(x_batch.to(device))
#         n_correct += (outputs.argmax(1) == y_batch.to(device)).int().sum().item()
#     writer.add_scalar("eval/accuracy", n_correct / len(dataloader.dataset), eval_step)
#     eval_step += 1

In [78]:
import pytorch_lightning as pl
import torchmetrics
from torch import nn
from torch.nn import functional as F

In [88]:
class EEGModel(pl.LightningModule):
    def __init__(self, kind="deep", num_channels=21) -> None:
        super().__init__()
        if kind == "deep":
            self.model = DeepConvNet(num_channels, 2)
        else:
            self.model = ShallowConvNet(num_channels, 2)

        self.metrics = nn.ModuleDict({
            "accuracy": torchmetrics.Accuracy(),
            "f1_score": torchmetrics.F1Score()
        })

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=0.02)

    def training_step(self, batch, batch_idx):
        inputs, labels = batch
        outputs = self.model(inputs)
        loss = F.cross_entropy(outputs, labels)
        preds = outputs.argmax(1)
        self.log("train/loss", loss)
        for name, metric in self.metrics.items():
            metric(preds, labels)
            self.log(f"train/{name}", metric)
        return loss

    def validation_step(self, batch, batch_idx):
        inputs, labels = batch
        outputs = self.model(inputs)
        loss = F.cross_entropy(outputs, labels)
        preds = outputs.argmax(1)
        self.log("eval/loss", loss, on_step=True, on_epoch=False)
        for name in self.metrics:
            self.metrics[name].update(preds, labels)
            self.log(f"eval/{name}", self.metrics[name])


In [89]:
trainer = pl.Trainer(max_epochs=100, log_every_n_steps=5)
model = EEGModel()

trainer.fit(model, train_loader, eval_loader)

GPU available: False, used: False
TPU available: False, using: 0 TPU cores

  | Name    | Type        | Params
----------------------------------------
0 | model   | DeepConvNet | 277 K 
1 | metrics | ModuleDict  | 0     
----------------------------------------
277 K     Trainable params
0         Non-trainable params
277 K     Total params
1.108     Total estimated model params size (MB)


Validation sanity check: 0it [00:00, ?it/s]

Training: 0it [00:00, ?it/s]

: 

: 