In [1]:
import os
import random
import numpy as np
from functools import partial
import pickle
from glob import glob
from tqdm.notebook import tqdm
import IPython.display as ipd
import matplotlib.pyplot as plt
import yaml

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import Adam, RMSprop, AdamW, lr_scheduler
from torch.utils.data import DataLoader
from torch.cuda.amp.grad_scaler import GradScaler

from ml_collections import ConfigDict
import wandb

from utils import *

import warnings
# NOTE(review): this silences *every* warning (including deprecation warnings from
# torch/numpy); consider narrowing the filter to the specific noisy warnings.
warnings.filterwarnings("ignore")


manual_seed(0)  # seeding helper, presumably provided by the `utils` star-import

# Fall back to the CPU when CUDA is unavailable so the notebook can still run
# end-to-end; a hard-coded 'cuda:0' makes every later `.to(device)` fail without a GPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

torch.cuda.is_available()

True

In [2]:
model_name = 'eegadanet'
model_path = 'ckpts/'+model_name

# Checkpoints for this run are written into `model_path`; the next cell also expects
# `config.yaml` to already be present there. `makedirs` creates the parent `ckpts/`
# directory if needed (os.mkdir raised FileNotFoundError in that case) and is a
# no-op when the directory already exists.
os.makedirs(model_path, exist_ok=True)

In [3]:
# Hyper-parameters live next to the checkpoints: read `<model_path>/config.yaml`
# and wrap the parsed mapping in a ConfigDict for attribute-style access.
with open(f'{model_path}/config.yaml') as f:
    _raw_config = yaml.load(f, Loader=yaml.FullLoader)
config = ConfigDict(_raw_config)

config

data:
  chunk_size: 512
  delta_snr:
  - -5
  - 5
  noise_scale: 5
  noise_type: EOG
  sample_rate: 200
model:
  activation: gelu
  att_dim: 128
  att_drop: 0.1
  att_heads: 8
  emb_freqs: 4
  kernel_size: 3
  norm: BatchNorm
  num_blocks:
  - 1
  - 1
  - 1
  - 2
  - 2
  - 4
  num_channels: 16
  scale: 2
training:
  batch_size: 32
  cv_seed: 0
  cv_splits: 3
  lr: 0.0002

## Dataset

In [4]:
def _load_pickle(path):
    """Unpickle and return the object stored at ``path``.

    The ``with`` block guarantees the file handle is closed afterwards.
    Unpickling can execute arbitrary code, so only load trusted files.
    """
    with open(path, 'rb') as fh:
        return pickle.load(fh)


eegs = _load_pickle('data/klados/EEG.pickle')

eogs = _load_pickle('data/klados/EOG.pickle')
# Scale every artifact recording by the configured factor before it is mixed
# with the clean EEG in the dataset classes.
eogs = [eog*config.data.noise_scale for eog in eogs]

# order of electrodes in Klados dataset
electrodes_klados = [
    'Fp1', 'Fp2', 'F3', 'F4', 
    'C3', 'C4', 'P3', 'P4', 
    'O1', 'O2', 'F7', 'F8', 
    'T3', 'T4', 'T5', 'T6', 
    'Fz', 'Cz', 'Pz'
]

C = len(electrodes_klados)  # number of channels

eegs[0].shape, eogs[0].shape, C

(torch.Size([19, 5601]), torch.Size([19, 5601]), 19)

In [5]:
from eeg_positions import get_elec_coords

_df = get_elec_coords(system='1020', dim='3d')

# electrodes_klados names four sites T3/T4/T5/T6, whereas get_elec_coords labels
# them T7/T8/P7/P8; those positions are stored under both labels.
_klados_alias = {'T7': 'T3', 'T8': 'T4', 'P7': 'T5', 'P8': 'T6'}

coord_dict = {}
for label, x, y, z in _df.values:
    xyz = (x, y, z)
    alias = _klados_alias.get(label)
    if alias is not None:
        coord_dict[alias] = xyz
    coord_dict[label] = xyz

# One (x, y, z) row per electrode, in Klados channel order.
coords_klados = torch.FloatTensor([coord_dict[e] for e in electrodes_klados])
coords_klados.shape

torch.Size([19, 3])

In [6]:

def chunk(x, chunk_size):
    """Return a random contiguous window of ``chunk_size`` samples along the last axis.

    The start offset is drawn uniformly over every valid position with numpy's
    global RNG. Works for any array/tensor that supports ``x[..., a:b]`` slicing.
    numpy raises ValueError when ``x`` is shorter than ``chunk_size``.
    """
    n_starts = x.shape[-1] - chunk_size + 1
    start = np.random.randint(n_starts)
    return x[..., start:start + chunk_size]

def scale_noise(n, delta_snr):
    """Rescale a noise signal by an SNR offset ``d`` given in dB.

    Args:
        n: noise array/tensor; multiplied element-wise by the gain.
        delta_snr: either a single number (int, float or numpy scalar), used as a
            fixed offset ``d``, or a ``(low, high)`` pair from which ``d`` is drawn
            uniformly with numpy's global RNG.

    Returns:
        ``n * 10 ** (-d / 20)``: a positive ``d`` lowers the noise amplitude by ``d`` dB.
    """
    # isinstance (rather than ``type(...) == float``) so ints and numpy scalars are also
    # accepted as a fixed offset instead of crashing in ``np.random.uniform(*d)``.
    if isinstance(delta_snr, (int, float, np.number)):
        d = delta_snr
    else:
        d = np.random.uniform(*delta_snr)
    gain = 10 ** (-d / 20)
    return n * gain


class Dataset(torch.utils.data.Dataset):
    """Full-length evaluation trials yielding ``(contaminated, clean)`` pairs.

    Item ``index`` adds artifact trial ``X[index]`` of the module-level ``eogs`` to the
    same trial of ``eegs``. There is no cropping and no random SNR, so items are
    deterministic.
    """

    def __init__(self, X):
        self.X = X  # trial indices into the module-level ``eegs`` / ``eogs`` lists

    def __len__(self):
        return len(self.X)

    def __getitem__(self, index):
        trial = self.X[index]
        clean = eegs[trial]
        return clean + eogs[trial], clean
        

class TrainDataset(torch.utils.data.Dataset):
    """On-the-fly synthetic training pairs: a random EEG crop plus a random, rescaled EOG crop.

    ``index`` is ignored. Every item is a fresh random draw, so ``N`` only sets how
    many items make up one pass over the dataset.
    Each item is ``(clean + noise, clean, channel_mask)``, where ``channel_mask`` is a
    boolean tensor of length ``C`` with a random non-empty subset of channels set.
    """

    def __init__(self, X, N=1000):
        self.X = X            # pool of trial ids to sample from
        self.epoch_size = N   # number of items reported by __len__

    def __len__(self):
        return self.epoch_size

    def __getitem__(self, index):
        # EEG and EOG trials are drawn independently (they may differ).
        eeg_id, eog_id = random.choices(self.X, k=2)
        crop = config.data.chunk_size

        clean = chunk(eegs[eeg_id], chunk_size=crop)
        noise = scale_noise(chunk(eogs[eog_id], chunk_size=crop), config.data.delta_snr)

        # Random non-empty subset of channels (between 1 and C of them).
        n_active = np.random.randint(C) + 1
        active = random.sample(range(C), n_active)
        channel_mask = torch.zeros(C, dtype=torch.bool)
        channel_mask[active] = True

        return clean + noise, clean, channel_mask

## Models

In [7]:
from eegadanet import UNet

# Throw-away instance used only to report the architecture and its size; the
# per-fold models are built separately in the training section.
m = UNet(config.model)
# Stored on the config so it is also logged with the wandb run config when enabled.
config.model.num_params = num_params(m)

(config.model.num_params, m)

(0.12,
 UNet(
   (electrode_embedding): FourierEmb()
   (first_conv): Conv1d(1, 16, kernel_size=(3,), stride=(1,), padding=(1,), bias=False)
   (down_blocks): ModuleList(
     (0-2): 3 x Module(
       (tcb): TCB(
         (film): FiLM(
           (cond_proj): Linear(in_features=128, out_features=32, bias=True)
         )
         (layers): ModuleList(
           (0): Sequential(
             (0): BatchNorm1d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
             (1): GELU(approximate='none')
             (2): Conv1d(16, 16, kernel_size=(3,), stride=(1,), padding=(1,), bias=False)
           )
         )
       )
       (att): SpatialAttention(
         (Q): Linear(in_features=16, out_features=128, bias=False)
         (K): Linear(in_features=16, out_features=128, bias=False)
         (C): Conv1d(32, 16, kernel_size=(1,), stride=(1,), bias=False)
         (drop): Dropout(p=0.1, inplace=False)
       )
       (downsample): Conv1d(16, 16, kernel_size=(2,), strid

## Training

In [8]:
from sklearn.model_selection import KFold

cv = KFold(n_splits=config.training.cv_splits, shuffle=True, random_state=config.training.cv_seed)

epoch_size = 10000  # TrainDataset items drawn per fold in each training() pass


def _build_fold(train_ids, val_ids):
    """Create the independent model / optimizer / data bundle for one CV split."""
    fold_model = UNet(config.model).to(device)
    fold_optimizer = Adam(fold_model.parameters(), lr=config.training.lr)
    # shuffle=True does not change the content: TrainDataset ignores the index.
    fold_loader = DataLoader(
        TrainDataset(train_ids, N=epoch_size),
        batch_size=config.training.batch_size,
        num_workers=4,
        shuffle=True,
    )
    return {
        'model': fold_model,
        'optimizer': fold_optimizer,
        'train_loader': fold_loader,
        'valset': Dataset(val_ids),
    }


folds = [_build_fold(train_ids, val_ids) for train_ids, val_ids in cv.split(eegs)]

In [9]:
def training(fold):
    """Run one pass over ``fold['train_loader']``, updating ``fold['model']`` in place.

    Each batch is ``(y, x, c_mask)``: noisy input, clean target and a boolean
    per-channel mask. Prediction and target are both divided by the per-channel
    std of ``y``; the squared error is averaged over time, then over the masked
    channels of each sample, then over the batch.

    Returns:
        dict with ``'train_L2'``: the mean of the per-batch losses.
    """
    model = fold['model'].train()
    optimizer = fold['optimizer']
    elementwise_mse = nn.MSELoss(reduction='none')

    batch_losses = []
    for batch in tqdm(fold['train_loader']):
        y, x, c_mask = (tensor.to(device) for tensor in batch)
        sigma = y.std(-1, keepdim=True)

        x_ = model(y, coords=coords_klados, channel_mask=c_mask)

        per_channel = elementwise_mse(x_ / sigma, x / sigma).mean(-1)
        # Average only over channels selected in the mask (TrainDataset always selects >= 1).
        per_sample = (per_channel * c_mask).sum(1) / c_mask.sum(1)
        loss = per_sample.mean()

        loss.backward()
        optimizer.step()
        optimizer.zero_grad(set_to_none=True)

        batch_losses.append(loss.item())

    return {'train_L2': np.mean(batch_losses)}


@torch.no_grad()
def validation(fold):

    model = fold['model'].eval()

    losses = []
    for y, x in tqdm(fold['valset']):
        
        sigma = y.std(-1, keepdim=True)

        x_ = chunk_and_apply(
            partial(model, coords=coords_klados),
            y.to(device), 
            batch_size=config.training.batch_size, 
            chunk_size=config.data.chunk_size
        )

        loss = nn.MSELoss()(x_/sigma, x/sigma)

        losses.append(loss.item())

    return {
        'val_L2': np.mean(losses),
    }

In [10]:
use_wandb = False
if use_wandb:
    wandb.init(project='eeg_denoising', name=model_name, config=config.to_dict())


ckpt_freq = 10     # save all fold models every `ckpt_freq` epochs
max_epochs = None  # None = train until interrupted; set an int to stop automatically

epoch = 0
while max_epochs is None or epoch < max_epochs:
    ipd.clear_output(wait=True)

    print('Epochs trained:', epoch)

    # The cell output is cleared every epoch, so re-print the previous epoch's
    # fold-averaged metrics.
    if epoch > 0:
        print(cv_results)

    cv_results = {k: [] for k in ['train_L2', 'val_L2']}

    for fold_id, fold in enumerate(folds):
        print('-'*50 + '\nFold', fold_id)

        results = {**training(fold), **validation(fold)}
        for k, v in results.items():
            cv_results[k].append(v)

        print(results)

    # Average each metric over the CV folds.
    cv_results = {k: np.mean(v) for k, v in cv_results.items()}

    epoch += 1

    if use_wandb:
        wandb.log(cv_results, step=epoch)

    if epoch % ckpt_freq == 0:
        torch.save(
            [fold['model'].state_dict() for fold in folds],
            model_path+f'/epoch{epoch}.ckpt'
        )

# Only reached when `max_epochs` is set; close the run cleanly.
if use_wandb:
    wandb.finish()
                  

Epochs trained: 0
--------------------------------------------------
Fold 0


  0%|          | 0/313 [00:00<?, ?it/s]

  0%|          | 0/18 [00:00<?, ?it/s]

{'train_L2': 0.19608095590584576, 'val_L2': 0.054124309991796814}
--------------------------------------------------
Fold 1


  0%|          | 0/313 [00:00<?, ?it/s]

KeyboardInterrupt: 