# BCGNet Demo

The following Python toolbox trains a neural network for removing ballistocardiogram (BCG) artifacts from EEG data recorded during simultaneous EEG-fMRI. More details about the method can be found in McIntosh et al., IEEE Transactions on Biomedical Engineering: https://ieeexplore.ieee.org/document/9124646

In [None]:
# standard library
import os
from pathlib import Path

# BCGNet modules; these resolve when the notebook is run from the package root
from config import get_config
from session import Session

## Path setup

The first step is to set up all the relevant paths. For the purpose of this demo, all paths are defined here; for custom use, however, it is recommended to set them in the YAML file found in the 'config' directory.

In [None]:
# get the absolute path to the root directory of the package
# (this assumes the notebook is run from the package root)
d_root = Path(os.getcwd())

# get the absolute path to the directory containing all data;
# all datasets should be in EEGLAB format
# the directory structure is presumed to be
# d_data / subXX / input_file_naming_format
# where input_file_naming_format is defined in the YAML file
d_data = d_root / 'example_data' / 'raw_data'

# get the absolute path to the directory in which to save all trained models
# the directory structure will be
# d_model / model_type / subXX / {model_type}_{time_stamp} / {model_type}_{time_stamp}.index
# (note: depending on the TensorFlow version, models are saved either in the
# newer TF checkpoint format or in the older HDF5 (.h5) format)
d_model = d_root / 'trained_model' / 'non_cv_model'

# get the absolute path to the directory in which to save all cleaned datasets
# the directory structure will be
# d_output / subXX / output_file_naming_format
d_output = d_root / 'cleaned_data' / 'non_cv_data'

# (Optional)
# a dataset cleaned by an alternative method can be provided to compare
# against the performance of BCGNet; here an OBS-cleaned (optimal basis set)
# dataset is used. The conventions are the same as for d_data, and all
# datasets should be in EEGLAB format

# get the absolute path to the directory containing all data
# cleaned by the alternative method
# the directory structure is again presumed to be
# d_eval / subXX / eval_file_naming_format
d_eval = d_root / 'example_data' / 'obs_cleaned_data'

# (Optional; relevant only if d_eval is provided)
# define the name of the alternative method
str_eval = 'OBS'

# generate a config (cfg) object from the YAML file
# the default hyperparameters are those used in the paper
cfg = get_config(filename=d_root / 'config' / 'default_config.yaml')

# override the paths (for custom use, it is recommended to set these in the YAML file instead)
cfg.d_root = d_root
cfg.d_data = d_data
cfg.d_model = d_model
cfg.d_output = d_output
cfg.d_eval = d_eval
cfg.str_eval = str_eval
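The directory conventions described in the comments above can be checked before training. The following is a minimal sketch, not part of the BCGNet API: `expected_input_paths`, `expected_model_dir` and `missing_inputs` are hypothetical helpers, and the naming format `{sub}_r0{run}_raw.set` and the time-stamp string are placeholders for the values actually defined in the YAML file and generated by the toolbox.

```python
from pathlib import Path

# Hypothetical placeholder; the real pattern is input_file_naming_format
# from the YAML file, which this sketch does not read.
INPUT_FMT = "{sub}_r0{run}_raw.set"


def expected_input_paths(d_data, str_sub, vec_idx_run, fmt=INPUT_FMT):
    """List the input files implied by the layout d_data / subXX / <file>."""
    d_sub = Path(d_data) / str_sub
    return [d_sub / fmt.format(sub=str_sub, run=idx) for idx in vec_idx_run]


def expected_model_dir(d_model, model_type, str_sub, time_stamp):
    """Mirror the layout d_model / model_type / subXX / {model_type}_{time_stamp}."""
    return Path(d_model) / model_type / str_sub / f"{model_type}_{time_stamp}"


def missing_inputs(d_data, str_sub, vec_idx_run, fmt=INPUT_FMT):
    """Return the expected input files that do not exist on disk."""
    return [p for p in expected_input_paths(d_data, str_sub, vec_idx_run, fmt)
            if not p.exists()]
```

For example, `missing_inputs(d_data, 'sub34', [1, 2])` returns an empty list when every expected run file is present (under the placeholder naming format).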

Once all of these variables have been set in the YAML file, only the following command is needed:

In [None]:
# cfg = get_config(d_root / 'config' / 'default_config.yaml')

In [None]:
# for a quick test only: train for just 5 epochs (remove this line for full training)
cfg.num_epochs = 5

## Initialize training session

All key hyperparameters relevant to preprocessing and training are set in the YAML file.

In [None]:
# provide the name of the subject
str_sub = 'sub34'

# provide the indices of the runs to be used for training
# for a single run, use e.g. [1] or [2]
# for multiple runs, use e.g. [1, 2]

# filenames are presumed to follow the pattern subXX_r0X_,
# e.g. a run from sub11 with run index 1 starts with sub11_r01_
vec_idx_run = [1, 2]


# str_arch specifies the type of model to be used
# if str_arch is not provided, the default model (the same as in the paper)
# is used. To define a custom model, see the example in
# models/gru_arch_000.py; the only caveat is that both the file name and
# the class name must match the model type, e.g. gru_arch_000

# random_seed ensures that the split of the entire dataset into
# training, validation and test sets is always the same, which is
# useful for model selection

# verbose sets the verbosity of Keras during model training
# 0=silent, 1=progress bar, 2=one line per epoch

# overwrite specifies whether or not to overwrite existing cleaned data

# cv_mode specifies whether or not to use cross validation mode
# more on this later
s1 = Session(str_sub=str_sub, vec_idx_run=vec_idx_run, str_arch='default_rnn_model',
             random_seed=1997, verbose=2, overwrite=False, cv_mode=False, num_fold=5, cfg=cfg)
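The requirement that the file name and the class name both match `str_arch` suggests that the architecture is looked up by name at runtime. A minimal sketch of such a lookup with `importlib`, assuming (this demo does not confirm it) that custom models live in a `models` package; `load_arch_class` is a hypothetical helper, not part of the BCGNet API.

```python
import importlib


def load_arch_class(str_arch, package="models"):
    """Import <package>.<str_arch> and return the class named <str_arch>.

    Hypothetical helper: it only illustrates why the module name and the
    class name must both equal the model type (e.g. gru_arch_000).
    """
    module = importlib.import_module(f"{package}.{str_arch}")
    try:
        return getattr(module, str_arch)
    except AttributeError:
        raise AttributeError(
            f"{package}/{str_arch}.py must define a class named {str_arch}"
        ) from None
```

Under this scheme, a file named `gru_arch_001.py` that defines a class called `MyGRU` would fail the lookup, which is exactly the caveat stated in the comments above.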

## Prepare for training

In [None]:
# load all datasets
s1.load_all_dataset()

# perform preprocessing of all datasets and initialize the model
s1.prepare_training()
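The `random_seed` argument passed to `Session` is what makes the train/validation/test split repeatable. A minimal sketch of a seeded split, assuming the data have already been cut into `n_epochs` segments; `split_indices` and `per_valid` are hypothetical names (only `per_test` appears in the config described in this demo), and the toolbox's own splitting code may differ.

```python
import numpy as np


def split_indices(n_epochs, per_valid, per_test, random_seed=1997):
    """Shuffle epoch indices with a fixed seed, then cut them into three sets."""
    rng = np.random.default_rng(random_seed)
    idx = rng.permutation(n_epochs)
    n_test = int(round(n_epochs * per_test))
    n_valid = int(round(n_epochs * per_valid))
    test = idx[:n_test]
    valid = idx[n_test:n_test + n_valid]
    train = idx[n_test + n_valid:]
    return train, valid, test
```

Calling it twice with the same seed returns identical index arrays, so models trained in separate sessions are compared on the same held-out data.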

## Model training and generating cleaned dataset

In [None]:
# train the model
s1.train()

# generate the cleaned datasets
s1.clean()
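Conceptually, the referenced paper trains a recurrent network to predict the BCG artifact in each EEG channel from the ECG channel, and cleaning subtracts that prediction from the recorded EEG. The toy sketch below illustrates only this predict-then-subtract step: it swaps the recurrent network for an ordinary least-squares fit with one weight per channel, which is a deliberate simplification and not BCGNet's model. `fit_linear_bcg` and `clean_by_subtraction` are hypothetical names.

```python
import numpy as np


def fit_linear_bcg(ecg, eeg):
    """Least-squares weight mapping the ECG channel to each EEG channel.

    ecg: (n_time,) reference channel; eeg: (n_channel, n_time).
    A linear stand-in for the trained network; returns (n_channel,) weights.
    """
    w, *_ = np.linalg.lstsq(ecg[:, None], eeg.T, rcond=None)
    return w[0]


def clean_by_subtraction(ecg, eeg, weights):
    """Subtract the predicted artifact (weight x ECG) from every channel."""
    predicted_bcg = weights[:, None] * ecg[None, :]
    return eeg - predicted_bcg
```

Real BCG artifacts vary in shape and timing across heartbeats and channels, so a single weight per channel is far too simple; the sketch only makes the data flow of `train()` and `clean()` concrete.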

## Evaluating the performance

In [None]:
# evaluate the performance of the model in terms of the RMS and the
# ratio of band power in the delta, theta and alpha bands of the
# cleaned data relative to the raw data

# mode specifies the set on which to evaluate performance:
# mode='train' evaluates on the training set
# mode='valid' evaluates on the validation set
# mode='test' evaluates on the test set
s1.evaluate(mode='test')
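A minimal NumPy sketch of the two metrics named above, assuming data arrays shaped (n_channel, n_time) and a sampling rate `fs` in Hz. The band edges are conventional values, not necessarily those used by `evaluate()`, and a plain periodogram is used where the toolbox may use a different spectral estimator. All function names here are hypothetical.

```python
import numpy as np

# Conventional EEG band edges in Hz (assumed; the toolbox may use others).
BANDS = {"delta": (0.5, 4.0), "theta": (4.0, 8.0), "alpha": (8.0, 13.0)}


def rms(x):
    """Root mean square over the time axis, per channel."""
    return np.sqrt(np.mean(np.square(x), axis=-1))


def band_power(x, fs, lo, hi):
    """Sum of periodogram power in [lo, hi) Hz, per channel."""
    freqs = np.fft.rfftfreq(x.shape[-1], d=1.0 / fs)
    psd = np.abs(np.fft.rfft(x, axis=-1)) ** 2
    mask = (freqs >= lo) & (freqs < hi)
    return psd[..., mask].sum(axis=-1)


def evaluate_cleaning(raw, cleaned, fs):
    """RMS ratio and per-band power ratio of cleaned versus raw data."""
    out = {"rms_ratio": rms(cleaned) / rms(raw)}
    for name, (lo, hi) in BANDS.items():
        out[name] = band_power(cleaned, fs, lo, hi) / band_power(raw, fs, lo, hi)
    return out
```

A ratio below 1 means that cleaning reduced the power in that band; because the BCG artifact concentrates its power at low frequencies, large reductions are expected mainly in the delta and theta bands.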

## Saving trained model and cleaned dataset

In [None]:
# save the trained model
s1.save_model()

# save the cleaned data in .mat files
# each saved .mat file has one field, 'data', which contains the
# n_channel by n_time_stamp matrix holding all cleaned data
s1.save_data()

# alternatively, save the cleaned data in Neuromag .fif format
# (note that EEGLAB support for the .fif format is limited)
# s1.save_dataset()
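The `.mat` files written by `save_data()` can be read back in Python with SciPy. A minimal sketch, assuming SciPy is installed; `load_cleaned_mat` is a hypothetical helper that relies only on the `'data'` field described above.

```python
import numpy as np
from scipy.io import loadmat


def load_cleaned_mat(path):
    """Return the n_channel x n_time_stamp array stored under 'data'."""
    mat = loadmat(path)
    data = np.asarray(mat["data"])
    if data.ndim != 2:
        raise ValueError(f"expected a 2-D 'data' field, got shape {data.shape}")
    return data
```

Pass it the path of a file written under `d_output / subXX`; the returned array has one row per channel.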

## Cross-validation mode

Alternatively, if cross-validation is deemed necessary, users can set up a cross-validation session with the following commands:

In [None]:
# first change the model and output directories
d_model = d_root / 'trained_model' / 'cv_model'
d_output = d_root / 'cleaned_data' / 'cv_data'
cfg.d_model = d_model
cfg.d_output = d_output

# it is recommended to set the num_fold argument, which specifies the
# number of cross-validation folds; in that case, the test and validation
# sets will each contain 1/num_fold of the data and the remaining data
# will form the training set, e.g.
s2 = Session(str_sub=str_sub, vec_idx_run=vec_idx_run, str_arch='default_rnn_model',
             random_seed=1997, verbose=2, overwrite=True, cv_mode=True, num_fold=5, cfg=cfg)

# otherwise, the number of cross-validation folds is inferred from the
# test-set fraction per_test in the config YAML file as 1/per_test
# s2 = Session(str_sub=str_sub, vec_idx_run=vec_idx_run, str_arch='default_rnn_model',
#              random_seed=1997, verbose=2, overwrite=True,
#              cv_mode=True, cfg=cfg)
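The fold arithmetic described above can be sketched as follows. This is an illustration under stated assumptions, not the toolbox's implementation: indices are kept unshuffled for readability, and using the block after the test block as the validation set is a hypothetical rotation scheme. `infer_num_fold`, `cv_fold_blocks` and `cv_split` are hypothetical names.

```python
def infer_num_fold(per_test):
    """Number of folds implied by the test-set fraction: 1 / per_test."""
    return int(round(1.0 / per_test))


def cv_fold_blocks(n_epochs, num_fold):
    """Split range(n_epochs) into num_fold contiguous, near-equal blocks."""
    base, extra = divmod(n_epochs, num_fold)
    blocks, start = [], 0
    for k in range(num_fold):
        stop = start + base + (1 if k < extra else 0)
        blocks.append(list(range(start, stop)))
        start = stop
    return blocks


def cv_split(n_epochs, num_fold, fold):
    """For one fold: its block is the test set, the next block is validation.

    The validation rotation is an assumption made for illustration only.
    """
    blocks = cv_fold_blocks(n_epochs, num_fold)
    i_valid = (fold + 1) % num_fold
    test = blocks[fold]
    valid = blocks[i_valid]
    train = [i for k, b in enumerate(blocks) if k not in (fold, i_valid) for i in b]
    return train, valid, test
```

With `num_fold=5`, each fold holds out 20% of the data for testing and 20% for validation and trains on the remaining 60%; across the five folds every epoch appears in a test set exactly once. At least three folds are needed for the training set to be non-empty.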

The remaining commands are the same as for the non-cross-validation session:

In [None]:
s2.load_all_dataset()
s2.prepare_training()

In [None]:
s2.train()
s2.clean()

In [None]:
s2.evaluate(mode='test')

In [None]:
s2.save_model()
s2.save_data()
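Since the non-cross-validation session `s1` and the cross-validation session `s2` go through the same sequence of calls, the workflow can be wrapped in a small helper. A minimal sketch that relies only on the `Session` method names used in this demo; `run_session` itself is hypothetical.

```python
# Steps shared by every session in this demo, in execution order.
PIPELINE = ("load_all_dataset", "prepare_training", "train", "clean")


def run_session(session, eval_mode="test", save=True):
    """Run the demo's steps in order on a Session-like object."""
    for step in PIPELINE:
        getattr(session, step)()
    # evaluate on the chosen split ('train', 'valid' or 'test')
    session.evaluate(mode=eval_mode)
    if save:
        session.save_model()
        session.save_data()
    return session
```

`run_session(s1)` or `run_session(s2)` reproduces the cells above; note that calling it on a session that has already been run trains and cleans again.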