In [1]:
from trainer import Trainer
from datasets import EEGDataset
import torch
from torch.utils.tensorboard import SummaryWriter
from torch.nn import functional as F
import numpy as np
from sklearn.metrics import roc_auc_score
from early_stopping import EarlyStopping
from matplotlib import pyplot as plt
from sklearn.model_selection import KFold,GroupShuffleSplit,GroupKFold
from torch.utils.data import DataLoader, ConcatDataset
from models import EEGAutoencoder,EEG3DDeformAutoencoder, EEG3DAutoencoder,EEGDeformAutoencoder
from torch import nn
import pandas as pd
import random
import os

def warn(*args, **kwargs):
    pass


import warnings

warnings.warn = warn

In [2]:
SEED =0
torch.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
np.random.seed(SEED)
random.seed(SEED)
os.environ['PYTHONHASHSEED'] = str(SEED)

In [3]:
dataset = EEGDataset('../../../data/MDD_EEG/eeg/preprocessing/spectrums/full.csv', is_3D =False)

In [4]:
#dataset.data

In [5]:
def set_seed(seed):
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    np.random.seed(seed)
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)

In [6]:
def seed_worker(worker_id):
    # worker_seed = torch.initial_seed() % 2**32
    np.random.seed(SEED)
    random.seed(SEED)

In [7]:
def cross_validation(dataset,experiment_name,learning_rate =1e-4, k_folds=10, num_epochs=50, device='cuda:1'):
    random_seed = 0 # or any of your favorite number
    num_workers =8
    set_seed(random_seed)
    g = torch.Generator()
    g.manual_seed(random_seed)
    kfold = GroupKFold(n_splits=k_folds)
    best_scores, reconstr_losses =[], []
    for fold, (train_ids, test_ids) in enumerate(kfold.split(dataset, None, dataset.data['unique_id'])):
        print(f'FOLD {fold}')
        print('--------------------------------')
        print(test_ids)
        train_dataset =torch.utils.data.Subset(dataset,train_ids)
        val_dataset =torch.utils.data.Subset(dataset,test_ids)
        
        train_loader = torch.utils.data.DataLoader(
            train_dataset,
            num_workers=num_workers,
            worker_init_fn=seed_worker,
            generator=g,
            drop_last=True,
            batch_size=4)
        test_loader = torch.utils.data.DataLoader(
            val_dataset,
            num_workers=num_workers,
            worker_init_fn=seed_worker,
            generator=g,
            drop_last=True,
            batch_size=4)

        model = EEGDeformAutoencoder(hidden_layers=[32,64]).to(device)
        criterion = nn.MSELoss()
        cr_loss = nn.CrossEntropyLoss()
        optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate,
                                     weight_decay=1e-1)
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=1000)
        trainer = Trainer(num_epochs, f'{experiment_name}-fold-{fold}',
                 model, criterion, cr_loss, optimizer, scheduler,device =device,seed=random_seed)
        best_score_acc, reconst_loss =trainer.train(train_loader,test_loader)
        best_scores.append(best_score_acc)
        reconstr_losses.append(reconst_loss)
    return  best_scores,reconstr_losses

In [8]:
best_scores,reconstr_losses =cross_validation(dataset,"runs/deform-2D-autoencoder-cross-validation")

FOLD 0
--------------------------------
[  0   3  18  31  34  41  50  51  53  70  73  77 102 132 133 143 148 154
 162 164 211 215 219 231 243 259 262 263 292 296]
EEGDeformAutoencoder(
  (encoder): Sequential(
    (0): DeformableConv2d(
      (offset_conv): Conv2d(15, 18, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (modulator_conv): Conv2d(15, 9, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (regular_conv): Conv2d(15, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    )
    (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ELU(alpha=True)
    (3): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (4): DeformableConv2d(
      (offset_conv): Conv2d(32, 18, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (modulator_conv): Conv2d(32, 9, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (regular_conv): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=

In [9]:
np.mean(best_scores)

0.6892857142857143

In [10]:
best_scores

[0.75,
 0.7142857142857143,
 0.7857142857142857,
 0.7142857142857143,
 0.6071428571428571,
 0.6785714285714286,
 0.6071428571428571,
 0.5714285714285714,
 0.75,
 0.7142857142857143]

In [11]:
np.std(best_scores)

0.06785714285714287

In [12]:
np.mean([reconst_loss.detach().cpu().numpy() for reconst_loss in reconstr_losses])

0.5447323

In [13]:
np.std([reconst_loss.detach().cpu().numpy() for reconst_loss in reconstr_losses])

0.901809

In [14]:
[reconst_loss.detach().cpu().numpy() for reconst_loss in reconstr_losses]

[array(0.23343945, dtype=float32),
 array(0.40892613, dtype=float32),
 array(0.16485944, dtype=float32),
 array(0.23378865, dtype=float32),
 array(0.14494827, dtype=float32),
 array(0.31304044, dtype=float32),
 array(0.21584284, dtype=float32),
 array(0.3161478, dtype=float32),
 array(3.2403498, dtype=float32),
 array(0.17598024, dtype=float32)]