In [None]:
import os
import random
import numpy as np
from functools import partial
import pickle
from glob import glob
from tqdm.notebook import tqdm
import IPython.display as ipd
import matplotlib.pyplot as plt
import yaml

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import Adam, RMSprop, AdamW, lr_scheduler
from torch.utils.data import DataLoader
from torch.cuda.amp.grad_scaler import GradScaler

from ml_collections import ConfigDict
import wandb

from utils import *

import warnings
warnings.filterwarnings("ignore")


# Use the first CUDA device when one is present; otherwise fall back to the
# CPU so the notebook still runs (slowly) on machines without a GPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# Displayed as the cell output so the reader can see which path was taken.
torch.cuda.is_available()

In [None]:
model_name = 'eegadanet'
model_path = 'ckpts/base/'+model_name

# makedirs also creates the missing parents ('ckpts', 'ckpts/base') and,
# with exist_ok=True, is a no-op when the directory is already there.
os.makedirs(model_path, exist_ok=True)

In [None]:
# Read the run configuration stored next to the checkpoints and wrap it in a
# ConfigDict so that entries can be accessed as attributes (config.data...).
# NOTE(review): yaml.FullLoader can construct more than plain data types;
# only load config files you trust, or confirm yaml.safe_load is sufficient.
with open(f'{model_path}/config.yaml') as f:
    _raw_config = yaml.load(f, Loader=yaml.FullLoader)

config = ConfigDict(_raw_config)

config

## Dataset

In [None]:
# NOTE(review): pickle.load can execute arbitrary code embedded in the file;
# only load these pickles from a trusted source.
# The files are opened with `with` so the handles are closed deterministically.
with open('data/klados/EEG.pickle', 'rb') as f:
    eegs = pickle.load(f)

with open('data/klados/EOG.pickle', 'rb') as f:
    eogs = pickle.load(f)

# Apply the global noise gain from the config to every EOG recording.
eogs = [eog*config.data.noise_scale for eog in eogs]

# order of electrodes in Klados dataset
electrodes_klados = [
    'Fp1', 'Fp2', 'F3', 'F4', 
    'C3', 'C4', 'P3', 'P4', 
    'O1', 'O2', 'F7', 'F8', 
    'T3', 'T4', 'T5', 'T6', 
    'Fz', 'Cz', 'Pz'
]

# Electrode subset used as a second evaluation montage (see `evaluation`).
electrodes_sleep = ['F3','F4','C3','C4','O1','O2']

# Number of channels in the full montage.
C = len(electrodes_klados)

sr = config.data.sample_rate

eegs[0].shape, eogs[0].shape, C

In [None]:
from eeg_positions import get_elec_coords

_df = get_elec_coords(system='1020', dim='3d')

# Electrode names used in `electrodes_klados` that are stored under a
# different label in the coordinate table: table label -> name used here.
_alias_of = {'T7': 'T3', 'T8': 'T4', 'P7': 'T5', 'P8': 'T6'}

# Map every electrode label (plus the aliases above) to its (x, y, z) position.
# Each row of the table is unpacked as (label, x, y, z).
coord_dict = {}
for label, px, py, pz in _df.values:
    position = (px, py, pz)
    if label in _alias_of:
        coord_dict[_alias_of[label]] = position
    coord_dict[label] = position

# One row of coordinates per electrode, in the order of `electrodes_klados`.
coords_klados = torch.FloatTensor([coord_dict[name] for name in electrodes_klados])
coords_klados.shape

In [None]:
from itertools import combinations

# Every non-empty combination of channel indices, grouped by subset size
# (1 .. C) and then merged into a single collection with `flatten`.
# TrainDataset draws one entry per sample to decide which channels are active.
all_ep = flatten([
    list(combinations(range(C), subset_size))
    for subset_size in range(1, C + 1)
])
len(all_ep)

In [None]:

def chunk(x, chunk_size):
    """Return a random contiguous window of `chunk_size` samples from `x`.

    The window is taken along the last axis; all leading axes are kept.
    The start offset is drawn uniformly from NumPy's global RNG.

    Args:
        x: array-like object exposing `.shape` and supporting `x[..., a:b]`.
        chunk_size: length of the window along the last axis.

    Returns:
        The slice `x[..., i:i+chunk_size]` for a random valid start `i`.

    Raises:
        ValueError: if `chunk_size` exceeds the length of the last axis.
    """
    length = x.shape[-1]
    if chunk_size > length:
        # Without this check NumPy raises an unrelated-looking "high <= 0".
        raise ValueError(
            f'chunk_size ({chunk_size}) is larger than the signal length ({length})'
        )
    # +1 so that the last valid start (length - chunk_size) can be drawn.
    start = np.random.randint(length - chunk_size + 1)
    return x[..., start:start + chunk_size]

def scale_noise(n, delta_snr):
    """Attenuate the noise signal `n` by `delta_snr` decibels.

    Args:
        n: noise signal (anything that supports multiplication by a float).
        delta_snr: either a single number (fixed attenuation in dB) or a
            `(low, high)` pair, in which case the attenuation is drawn
            uniformly from that range using NumPy's global RNG.

    Returns:
        `n * 10**(-d/20)`, where `d` is the fixed or sampled attenuation.
    """
    # Accept any real scalar (int, float, NumPy scalar), not only `float`:
    # a value such as `5` from a YAML file must not be unpacked as a range.
    if isinstance(delta_snr, (int, float, np.integer, np.floating)):
        d = delta_snr
    else:
        d = np.random.uniform(*delta_snr)
    # Amplitude ratio corresponding to a change of -d dB.
    scale = 10**(-d/20)
    return n * scale


class Dataset(torch.utils.data.Dataset):
    """Full-length (noisy, clean) pairs for a fixed list of trials.

    `X` holds trial ids that index the module-level `eegs` and `eogs`
    collections. Item `i` pairs the EEG and EOG of the *same* trial, with no
    cropping and no additional noise scaling.
    """

    def __init__(self, X):
        self.X = X

    def __len__(self):
        return len(self.X)

    def __getitem__(self, index):
        """Return `(eeg + eog, eeg)` for the trial stored at position `index`."""
        trial = self.X[index]
        clean = eegs[trial]
        artifact = eogs[trial]
        return clean + artifact, clean
        

class TrainDataset(torch.utils.data.Dataset):
    """Randomly synthesised training samples.

    Every item is generated on the fly and does not depend on the requested
    index: an EEG crop and an EOG crop are drawn from two independently
    chosen trials of `X`, the EOG crop is attenuated via `scale_noise`, and
    a random subset of channels is marked as active.

    Args:
        X: trial ids that index the module-level `eegs` / `eogs`.
        N: nominal number of samples per epoch (value reported by `len`).
    """

    def __init__(self, X, N=1000):
        self.X = X
        self.epoch_size = N

    def __len__(self):
        return self.epoch_size

    def __getitem__(self, index):
        """Return `(eeg + scaled_eog, eeg, channel_mask)`; `index` is unused."""
        # Sampling is with replacement, so both ids may refer to the same trial.
        eeg_trial, eog_trial = random.choices(self.X, k=2)

        clean = chunk(eegs[eeg_trial], chunk_size=config.data.chunk_size)
        artifact = chunk(eogs[eog_trial], chunk_size=config.data.chunk_size)
        artifact = scale_noise(artifact, config.data.delta_snr)

        # Boolean vector of length C, True for the channels in the drawn subset.
        active = random.choice(all_ep)
        channel_mask = torch.zeros(C, dtype=torch.bool)
        channel_mask[list(active)] = True

        return clean + artifact, clean, channel_mask

### Training

In [None]:
from eegadanet import UNet
# `scipy.stats.stats` is a deprecated private namespace; `pearsonr` is part
# of the public `scipy.stats` API.
from scipy.stats import pearsonr
from sklearn.model_selection import KFold


# Shuffled K-fold splitter shared by every call to `init_cv`; the fixed
# random_state makes the train/validation partition reproducible.
cv = KFold(n_splits=config.training.cv_splits, shuffle=True, random_state=config.training.cv_seed)


def _seed_numpy_worker(worker_id):
    """Seed NumPy's global RNG inside a DataLoader worker process.

    `chunk` and `scale_noise` draw from `np.random`. Depending on the PyTorch
    version, workers may start with identical copies of that RNG state and
    would then produce the same crops / noise levels. Deriving the seed from
    the per-worker torch seed gives every worker its own stream.
    """
    np.random.seed(torch.initial_seed() % 2**32)


def init_cv(config, epoch_size=10000):
    """Create the per-fold training state for cross-validation.

    For each split produced by the module-level `cv` splitter over `eegs`,
    a fresh `UNet`, its Adam optimizer, a training DataLoader and a
    validation `Dataset` are built.

    Args:
        config: experiment configuration; reads `config.model`,
            `config.training.lr`, `config.training.batch_size` and
            `config.training.early_stop`.
        epoch_size: number of random samples per training epoch and fold.

    Returns:
        `(folds, early_stop)` where `folds` is a list of dicts with the keys
        'model', 'optimizer', 'train_loader' and 'valset', and `early_stop`
        is a new `EarlyStop(config.training.early_stop)`.

    Side effects:
        Calls `manual_seed(0)` and stores the parameter count of the last
        created model in `config.model.num_params`.
    """

    manual_seed(0)

    folds = []
    for trainX, valX in cv.split(eegs):
        
        model = UNet(config.model).to(device)
    
        folds.append({ 
            'model': model,
            'optimizer': Adam(model.parameters(), lr=config.training.lr),
            'train_loader': DataLoader(
                TrainDataset(trainX, N=epoch_size), 
                batch_size=config.training.batch_size, 
                num_workers=4,
                shuffle=True, 
                worker_init_fn=_seed_numpy_worker,
            ),
            'valset': Dataset(valX)
        })

    # All folds share one architecture, so the last model is representative.
    config.model.num_params = num_params(model)

    return folds, EarlyStop(config.training.early_stop)


def training(fold):
    """Train the fold's model for one pass over its training loader.

    The loss is a mean squared error between prediction and target after
    both are divided by the per-channel standard deviation of the noisy
    input. It is averaged over time, then over the channels selected by the
    batch's channel mask, then over the samples of the batch.

    Args:
        fold: dict providing 'model', 'optimizer' and 'train_loader'.

    Returns:
        `{'train_L2': mean of the per-batch losses}`.
    """
    model = fold['model'].train()
    optimizer = fold['optimizer']
    elementwise_mse = nn.MSELoss(reduction='none')

    batch_losses = []
    for batch in tqdm(fold['train_loader']):

        noisy, clean, mask = (tensor.to(device) for tensor in batch)
        # Per-channel scale of the noisy input, kept broadcastable over time.
        scale = noisy.std(-1, keepdim=True)

        restored = model(noisy, coords=coords_klados, channel_mask=mask)

        channel_err = elementwise_mse(restored/scale, clean/scale).mean(-1)
        # Only channels flagged in the mask contribute to a sample's loss.
        sample_err = (channel_err * mask).sum(1) / mask.sum(1)
        loss = sample_err.mean()

        loss.backward()
        optimizer.step()
        optimizer.zero_grad(set_to_none=True)

        batch_losses.append(loss.item())

    return {'train_L2': np.mean(batch_losses)}


def validation(fold):
    """Compute the normalised MSE of the fold's model on its validation set.

    Each validation item is processed with `chunk_and_apply` using the
    model in eval mode and the full-montage coordinates. Prediction and
    target are divided by the per-channel standard deviation of the noisy
    input before the mean squared error is taken.

    Args:
        fold: dict providing 'model' and 'valset' (yielding `(noisy, clean)`).

    Returns:
        `{'val_L2': mean of the per-item losses}`.
    """
    mse = nn.MSELoss()

    item_losses = []
    for noisy, clean in tqdm(fold['valset']):

        with torch.no_grad():
            denoise = partial(fold['model'].eval(), coords=coords_klados)
            restored = chunk_and_apply(
                denoise,
                noisy.to(device),
                batch_size=config.training.batch_size,
                chunk_size=config.data.chunk_size,
            )

        # Same per-channel normalisation as in the training loss.
        scale = noisy.std(-1, keepdim=True)
        item_losses.append(mse(restored/scale, clean/scale).item())

    return {'val_L2': np.mean(item_losses)}


def evaluation(folds, best_ckpt, result_path):
    """Evaluate the best checkpoints on the held-out data and log the scores.

    The evaluation is run twice: once with all electrodes of
    `electrodes_klados` and once with the `electrodes_sleep` subset. For each
    montage the validation items of every fold are denoised by that fold's
    model (after loading its checkpoint), then three per-channel scores are
    computed, averaged over channels, rounded to 3 decimals and appended as
    one line to `result_path`.

    Args:
        folds: list of fold dicts providing 'model' and 'valset'.
        best_ckpt: one state dict per fold, in the same order as `folds`.
        result_path: text file the result lines are appended to.

    Side effects:
        Overwrites the weights of every fold's model with its checkpoint and
        leaves the models in eval mode.
    """

    for electrodes in [electrodes_klados, electrodes_sleep]:

        # Positions of the selected electrodes inside the full montage.
        electrode_indices = [electrodes_klados.index(e) for e in electrodes]
        coords = coords_klados[electrode_indices]

        # Predictions (trailing underscore) and targets, both as tensors and
        # as the objects returned by `tensor2raw`.
        X_, rawX_ = [], []
        X, rawX = [], []   
        for fold, ckpt in zip(folds, best_ckpt):

            fold['model'].load_state_dict(ckpt)
            denoise = partial(fold['model'].eval(), coords=coords)
            
            for y, x in tqdm(fold['valset']):
        
                # Keep only the channels of the current montage.
                y, x = [eeg[electrode_indices] for eeg in [y,x]]

                with torch.no_grad():
                    x_ = chunk_and_apply(
                        denoise,
                        y.to(device), 
                        batch_size=config.training.batch_size, 
                        chunk_size=config.data.chunk_size
                    )
                    
                X_.append(x_)
                rawX_.append(tensor2raw(x_, sr, electrodes))
                X.append(x)
                rawX.append(tensor2raw(x, sr, electrodes))

        # metrics per channel
        rrmse, srrmse, cc = [], [], []
        for i, e in enumerate(electrodes):
            rrmse.append(np.mean([RRMSE(x_[i] ,x[i]) for x_, x in zip(X_, X)]))
            srrmse.append(np.mean([RRMSE(PSD(x_, e), PSD(x, e)) for x_, x in zip(rawX_, rawX)]))
            cc.append(np.mean([pearsonr(x_[i], x[i])[0] for x_, x in zip(X_, X)]))

        # average over channels
        rrmse, srrmse, cc = [round(np.mean(scores).item(), 3) for scores in [rrmse, srrmse, cc]]
        
        # The `with` block closes the file; no explicit close() is needed.
        with open(result_path, 'a') as f:
            f.write(f'C_{len(electrodes)}: RRMSE_t={rrmse} | RRMSE_s={srrmse} | CC={cc} \n')

In [None]:

result_path = model_path+'/results.txt'

folds, early_stop = init_cv(config)


# Train all folds in lock-step, one epoch at a time, until the early-stopping
# criterion on the fold-averaged validation loss fires.
epoch = 0
while True:

    cv_results = {}

    for fold_id, fold in enumerate(folds): 
        print('-'*80, '\nFold', fold_id)   

        results = {
            **training(fold), 
            **validation(fold)
        }
        # Collect each metric across folds.
        for k,v in results.items():
            cv_results.setdefault(k, []).append(v)
        print(results)

    epoch += 1

    # Average every metric over the folds.
    cv_results = {k: np.mean(v) for k,v in cv_results.items()}

    ipd.clear_output(wait=True)
    print('Epochs trained:', epoch)
    print(cv_results)

    early_stop.update(cv_results['val_L2'])
    
    if early_stop.best_epoch==epoch:
        # state_dict() returns references to the live parameter tensors, which
        # later optimizer steps modify in place. Clone them so the snapshot
        # really keeps the weights of the best epoch.
        early_stop.best_ckpt = [
            {
                name: tensor.detach().clone()
                for name, tensor in fold['model'].state_dict().items()
            }
            for fold in folds
        ]
        
    if early_stop.stop:
        break
        

# Persist the checkpoint first so that a failure during evaluation cannot
# lose the trained weights.
torch.save(early_stop.best_ckpt, model_path+f'/epoch{early_stop.best_epoch}.ckpt')

evaluation(folds, early_stop.best_ckpt, result_path)
