In [None]:
import torch
from torch import nn
from torch.nn import functional as F
from torch.utils.data import DataLoader
from cereprocess.train.xloop import oneloop, get_datanpz, get_dataloaders
from cereprocess.train.misc import TrainElements, get_model_size, EarlyStopping, check_model, def_dev, def_hyp
from cereprocess.datasets.defaults import get_def_ds
from cereprocess.datasets.pytordataset import KFoldDataset
from cereprocess.datasets.pipeline import general_pipeline
from cereprocess.train.retrieve import load_picks, get_results                
from cereprocess.train.train import evaluate          
from cereprocess.train.callbacks import def_metrics       
from cereprocess.datasets.pytordataset import KFoldDataset
from models.neurogate import NeuroGATE
import pandas as pd
import numpy as np
import gc
import os

# Training

In [None]:
# Global run configuration: seeds, device, input geometry and dataset descriptors.
SEED = 42
torch.manual_seed(SEED)  # seeds the CPU and all CUDA RNGs (DataLoader shuffling, weight init)
np.random.seed(SEED)

device = def_dev()
mins = 10  # recording length in minutes (passed as length_minutes to the pipeline)
# NOTE(review): 3000 samples per minute, i.e. 50 samples/s — confirm this matches the pipeline's resampling.
SAMPLES_PER_MIN = 3000
# NOTE(review): presumably the number of EEG channels in the montage — confirm.
N_CHANNELS = 22
length = mins * SAMPLES_PER_MIN
input_size = (N_CHANNELS, length)
tuh, nmt, nmt_4k = get_def_ds(mins)

In [None]:
# Select the dataset used by the rest of the notebook: one of tuh / nmt / nmt_4k
# returned by get_def_ds in the setup cell. Later cells index it positionally
# (curr_data[0], curr_data[3]) — NOTE(review): confirm what each index holds.
curr_data = nmt_4k

In [None]:
# Prepare the .npz data for the selected dataset with the chosen preprocessing
# pipeline. Produces `data_dir` (fed to the dataloaders) and `data_description`
# (passed to `oneloop`). Swap the pipeline call here to change preprocessing.
# NOTE(review): `neurogate_pipeline` is not imported anywhere in this notebook —
# the import cell only brings in `general_pipeline` — so this line raises
# NameError on a fresh kernel. Import the intended pipeline function first.
# NOTE(review): dataset='NMT' is hard-coded; keep it consistent with `curr_data`.
data_dir, data_description = get_datanpz(curr_data[0], curr_data[3], neurogate_pipeline(dataset='NMT', length_minutes=mins), input_size)

In [None]:
# Train a single NeuroGATE model on the selected dataset.

# Hyperparameters for this run.
hyps = def_hyp(batch_size=8, epochs=60, lr=0.0003, accum_iter=1)
train_loader, eval_loader = get_dataloaders(data_dir, hyps['batch_size'], target_length=length)

# NOTE(review): patience equals the epoch budget (60), so early stopping is
# unlikely to ever end the run early; lower it if early stopping is wanted.
es = EarlyStopping(patience=60)

# Training components for this run; `criterion` and `history` on this object
# are what the evaluation cells read back.
train_elems = TrainElements(device, earlystopping=es)

model = NeuroGATE().to(device)
name = 'NeuroGate'
model_x = oneloop(device, model, train_loader, eval_loader, data_description, hyps, train_elems, curr_data[3], name)

# Collect unreachable Python objects first, then hand freed cached blocks back
# to the CUDA allocator (empty_cache is a no-op without CUDA).
gc.collect()
torch.cuda.empty_cache()

# Evaluating a trained model

In [None]:
# Rebuild the model and load trained weights for evaluation.
metrics = def_metrics(device)

# TODO: replace model_XX.pt with the checkpoint written by the training run
# (and the dataset folder if `curr_data` is not nmt_4k).
checkpoint_path = "results/nmt4k/models/model_XX.pt"

model = NeuroGATE().to(device)
# map_location remaps stored tensors onto `device`, so a checkpoint saved on a
# GPU can be loaded on a CPU-only machine. torch.load unpickles the file —
# only load checkpoints from trusted sources.
model.load_state_dict(torch.load(checkpoint_path, map_location=device))
train_elems = TrainElements(device)

In [None]:
# Rebuild the dataloaders for evaluation. The target length is derived from
# `mins` the same way as in the setup cell (3000 samples per minute), so it
# stays consistent with the recording length the data was prepared with.
length = mins * 3000
hyps = def_hyp(batch_size=8, epochs=60, lr=0.0003, accum_iter=1)
train_loader, eval_loader = get_dataloaders(data_dir, hyps['batch_size'], target_length=length)

In [None]:
# One evaluation pass of the loaded model over the held-out loader. The
# history object is passed in, so results are presumably recorded there
# (inspected in the next cells). ROC plotting is disabled.
evaluate(
    model,
    eval_loader,
    train_elems.criterion,
    device,
    metrics,
    train_elems.history,
    plot_roc=False,
)

In [None]:
# Display the full recorded metric history (presumably keyed by split, e.g. 'val').
train_elems.history.history

In [None]:
# Latest value of each validation metric: for per-epoch lists take the last
# entry; scalars (and empty lists) are shown as-is instead of raising.
{
    key: value[-1] if isinstance(value, list) and value else value
    for key, value in train_elems.history.history['val'].items()
}

# Running K-Fold Training

In [None]:
# Train one fresh NeuroGATE model per fold; each fold's results are saved by
# `oneloop` under the name KFold-<fold>.
# NOTE(review): folds are shuffled without a fixed seed on KFoldDataset itself;
# pass one if its signature supports it, so fold membership is reproducible.
kfold = KFoldDataset(root_dir=data_dir, n_splits=10, shuffle=True)
for fold, (trainset, evalset) in enumerate(kfold, start=1):
    train_loader = DataLoader(trainset, batch_size=hyps["batch_size"], shuffle=True)
    # Evaluate with the same batch size as training (the default would be 1);
    # order is irrelevant for metrics, so no shuffling.
    eval_loader = DataLoader(evalset, batch_size=hyps["batch_size"], shuffle=False)
    model = NeuroGATE().to(device)
    train_elems = TrainElements(device)
    name = f'KFold-{fold}'
    model_x = oneloop(device, model, train_loader, eval_loader, data_description, hyps, train_elems, curr_data[3], name)
    # Drop this fold's model before the next one is allocated; otherwise the
    # previous model stays referenced and both occupy GPU memory at once.
    del model, model_x
    gc.collect()
    torch.cuda.empty_cache()

# Check Results

In [None]:
from train.retrieve import get_results, get_paths, load_picks
import pickle
from datasets.defaults import get_def_ds

In [None]:
# Re-create the dataset descriptors (tuh, nmt, nmt_4k) so this results section
# can run on its own. Each descriptor stores dataset information, including the
# class-weight balancing; `current_selected` picks the one inspected below.
tuh, nmt, nmt_4k = get_def_ds(mins=mins)
current_selected = nmt_4k

In [None]:
# Collect the stored training results for the selected dataset; `clean=True`
# is forwarded to get_results (see its definition for the exact effect).
res = get_results(
    current_selected[3],
    clean=True,
)

In [None]:
# Preview the first rows of the collected results (presumably a pandas DataFrame).
res.head()