
# Baseline Experiment 1

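Pre-train a model on all subjects but one, then fine-tune it on increasing amounts of the held-out subject's data; repeat with each subject held out.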

Model: ShallowFBCSPNet

Dataset: BCI Competition IV 2a (BNCI2014_001), loaded via the MOABB library


In [2]:
import matplotlib.pyplot as plt
from braindecode.datasets import MOABBDataset
from numpy import multiply
from braindecode.preprocessing import (Preprocessor,
                                       exponential_moving_standardize,
                                       preprocess)
from braindecode.preprocessing import create_windows_from_events
import torch
from braindecode.models import ShallowFBCSPNet
from braindecode.util import set_random_seeds
from skorch.callbacks import LRScheduler
from skorch.helper import predefined_split
from braindecode import EEGClassifier
# from sklearn.metrics import confusion_matrix
import matplotlib.pyplot as plt
import pandas as pd
from scipy import stats
import os
import pickle
from matplotlib.lines import Line2D
# from braindecode.visualization import plot_confusion_matrix

from braindecode.datasets import BaseConcatDataset
from braindecode.datasets.base import EEGWindowsDataset
from braindecode.preprocessing.windowers import _create_windows_from_events
import numpy as np
import mne
import random

C:\Users\mengz\anaconda3\envs\hyperBCI\Lib\site-packages\moabb\pipelines\__init__.py:26: ModuleNotFoundError: Tensorflow is not installed. You won't be able to use these MOABB pipelines if you attempt to do so.
  warn(


## Loading and preparing the data




### Loading the dataset




In [3]:
# Fetch all nine BNCI2014_001 subjects (ids 1..9) through MOABB.
dataset = MOABBDataset(
    subject_ids=[*range(1, 10)],
    dataset_name="BNCI2014_001",
)

### Preprocessing




In [4]:
# Band-pass edges in Hz.
low_cut_hz, high_cut_hz = 4., 38.
# Exponential moving standardization settings.
factor_new, init_block_size = 1e-3, 1000
# Scale factor from volts to microvolts.
factor = 1e6

# Applied in order: channel selection, unit scaling, band-pass, standardization.
preprocessors = [
    Preprocessor('pick_types', eeg=True, meg=False, stim=False),   # EEG channels only
    Preprocessor(lambda data: multiply(data, factor)),               # V -> uV
    Preprocessor('filter', l_freq=low_cut_hz, h_freq=high_cut_hz),   # band-pass
    Preprocessor(
        exponential_moving_standardize,
        factor_new=factor_new,
        init_block_size=init_block_size,
    ),
]

# Run the pipeline on every recording of the dataset, using all available cores.
preprocess(dataset, preprocessors, n_jobs=-1)

  warn('Preprocessing choices with lambda functions cannot be saved.')


<braindecode.datasets.moabb.MOABBDataset at 0x227e1ea5f10>

### Extracting Compute Windows




In [5]:
# Windows start half a second before each trial's start.
trial_start_offset_seconds = -0.5

# A single sample offset is applied to every recording, so all must share one sampling rate.
sfreq = dataset.datasets[0].raw.info['sfreq']
assert all(recording.raw.info['sfreq'] == sfreq for recording in dataset.datasets)
trial_start_offset_samples = int(sfreq * trial_start_offset_seconds)

# Cut windows around the annotated trials; a stop offset of 0 leaves the trial end unchanged.
windows_dataset = create_windows_from_events(
    dataset,
    preload=True,
    trial_stop_offset_samples=0,
    trial_start_offset_samples=trial_start_offset_samples,
)

Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']
Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']
Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']
Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']
Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']
Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']
Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']
Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']
Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']
Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']
Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']
Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']
Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']

In [6]:
def get_subset(window_set, target_trial_num):
    """Return the first `target_trial_num` windows of `window_set`, in dataset order.

    Whole recordings are kept while they fit in the remaining budget. The recording in
    which the budget runs out is truncated to the remaining number of windows. If the
    budget exceeds the total number of windows, every recording is returned unchanged.

    Parameters
    ----------
    window_set : BaseConcatDataset
        Concatenation of windowed recordings, each exposing ``raw``, ``metadata``
        (one row per window) and ``description``.
    target_trial_num : int
        Number of windows (trials) to keep; must be positive.

    Returns
    -------
    BaseConcatDataset
        The selected windows.

    Raises
    ------
    ValueError
        If ``target_trial_num`` is not positive.
    """
    if target_trial_num <= 0:
        raise ValueError(f'target_trial_num must be positive, got {target_trial_num}')

    remaining = target_trial_num
    new_ds_lst = []
    for ds in window_set.datasets:
        cur_run_trial_num = len(ds.metadata)
        if remaining > cur_run_trial_num:
            # The whole recording fits in the remaining budget.
            new_ds_lst.append(ds)
            remaining -= cur_run_trial_num
        else:
            # Truncate only the per-window metadata. `description` is the recording-level
            # record (e.g. subject/run fields used by split()), not indexed by trial, so it
            # must be kept whole rather than sliced by the trial count.
            new_ds_lst.append(EEGWindowsDataset(ds.raw, ds.metadata[:remaining], description=ds.description))
            break

    return BaseConcatDataset(new_ds_lst)

## Fine-tune on each held-out subject after pre-training on all other subjects

In [None]:
results_dir = os.path.abspath(os.path.join(os.getcwd(), os.pardir, 'results'))
# save_params below writes into results_dir; create it up front instead of failing after pre-training.
os.makedirs(results_dir, exist_ok=True)
exp_name = 'baseline_2_2_pretrain'

### ---------- Pre-training parameters ----------
lr = 0.07 * 0.01
weight_decay = 0
batch_size = 64
n_epochs = 30

### ---------- Fine-tuning parameters ----------
finetune_lr = 0.07 * 0.01
finetune_weight_decay = 0
finetune_n_epochs = 20
# Upper bound on the fine-tuning batch size; smaller subsets use half their size instead.
finetune_max_batch_size = 64
# The reported fine-tuning score is the mean validation accuracy over this many final epochs.
n_final_epochs_averaged = 5

### ---------- Shared setup (identical for every held-out subject) ----------
cuda = torch.cuda.is_available()
device = 'cuda' if cuda else 'cpu'
if cuda:
    torch.backends.cudnn.benchmark = True
seed = 20200220

n_classes = 4
classes = list(range(n_classes))
# Extract number of chans and time steps from dataset
n_chans = windows_dataset[0][0].shape[0]
input_window_samples = windows_dataset[0][0].shape[1]

splitted_by_subj = windows_dataset.split('subject')

# Fine-tuning data amounts are swept in multiples of this many trials.
data_amount_step = 20
results_columns = ['valid_accuracy',]
# {holdout_subj_id: {n_finetune_trials: valid_accuracy}}; key 0 holds the accuracy before fine-tuning.
dict_results = {}


def pretrained_param_paths(holdout_subj_id):
    """Return the (params, optimizer, history) file paths of the model pre-trained without `holdout_subj_id`."""
    prefix = os.path.join(results_dir, f'{exp_name}_without_subj_{holdout_subj_id}')
    return f'{prefix}_model.pkl', f'{prefix}_opt.pkl', f'{prefix}_history.json'


def make_classifier(valid_set, clf_lr, clf_weight_decay, clf_batch_size, clf_n_epochs):
    """Build a freshly initialised ShallowFBCSPNet classifier.

    Uses AdamW and NLL loss, validates on the fixed `valid_set`, and anneals the learning
    rate with a cosine schedule over `clf_n_epochs - 1` scheduler steps.
    """
    model = ShallowFBCSPNet(
        n_chans,
        n_classes,
        input_window_samples=input_window_samples,
        final_conv_length='auto',
    )
    return EEGClassifier(
        model,
        criterion=torch.nn.NLLLoss,
        optimizer=torch.optim.AdamW,
        train_split=predefined_split(valid_set),  # fixed validation set instead of a random split
        optimizer__lr=clf_lr,
        optimizer__weight_decay=clf_weight_decay,
        batch_size=clf_batch_size,
        callbacks=[
            "accuracy", ("lr_scheduler", LRScheduler('CosineAnnealingLR', T_max=clf_n_epochs - 1)),
        ],
        device=device,
        classes=classes,
        warm_start=False,
    )


for holdout_subj_id in range(1, 10):

    print(f'Hold out data from subject {holdout_subj_id}')

    ### ---------- Split dataset into pre-train set and fine-tune (holdout) set ----------
    pre_train_set = BaseConcatDataset([splitted_by_subj.get(f'{i}') for i in range(1, 10) if i != holdout_subj_id])
    fine_tune_set = BaseConcatDataset([splitted_by_subj.get(f'{holdout_subj_id}'),])

    ### ---------- Split pre-train set into pre-train-train set and pre-train-test set ----------
    # The last `pre_train_test_set_size` run groups of every pre-training subject are used for validation.
    # NOTE(review): split('run') groups by the run label alone, so if a subject has several sessions a
    # "run" group presumably spans all of them — confirm that is the intended hold-out unit.
    pre_train_train_set_lst = []
    pre_train_test_set_lst = []
    pre_train_test_set_size = 1  # runs
    for subj_set in pre_train_set.split('subject').values():
        subj_splitted_lst_by_run = list(subj_set.split('run').values())
        pre_train_train_set_lst.extend(subj_splitted_lst_by_run[:-pre_train_test_set_size])
        pre_train_test_set_lst.extend(subj_splitted_lst_by_run[-pre_train_test_set_size:])

    pre_train_train_set = BaseConcatDataset(pre_train_train_set_lst)
    pre_train_test_set = BaseConcatDataset(pre_train_test_set_lst)

    ### ---------- Pre-training ----------
    # Re-seed per held-out subject so each pre-training run is reproducible on its own.
    set_random_seeds(seed=seed, cuda=cuda)
    cur_clf = make_classifier(pre_train_test_set, lr, weight_decay, batch_size, n_epochs)

    print(f'Pre-training model with data from all subjects but subject {holdout_subj_id}')
    _ = cur_clf.fit(pre_train_train_set, y=None, epochs=n_epochs)

    pretrained_params, pretrained_opt, pretrained_history = pretrained_param_paths(holdout_subj_id)
    cur_clf.save_params(f_params=pretrained_params, f_optimizer=pretrained_opt, f_history=pretrained_history)

    ### ---------- Split fine tune set into fine tune-train set and fine tune-valid set ----------
    # The last run group of the held-out subject is its validation set; the rest is the fine-tuning pool.
    finetune_splitted_lst_by_run = list(fine_tune_set.split('run').values())
    finetune_subj_train_set = BaseConcatDataset(finetune_splitted_lst_by_run[:-1])
    finetune_subj_valid_set = BaseConcatDataset(finetune_splitted_lst_by_run[-1:])

    ### Baseline accuracy on the finetune_valid set (pre-trained model, before any fine-tuning)
    finetune_valid_predicted = cur_clf.predict(finetune_subj_valid_set)
    finetune_valid_true = np.array(finetune_subj_valid_set.get_metadata().target)
    finetune_baseline_acc = np.mean(finetune_valid_predicted == finetune_valid_true)
    print(f'Before finetuning for subject {holdout_subj_id}, the baseline accuracy is {finetune_baseline_acc}')

    ### ---------- Fine tuning ----------
    dict_subj_results = {0: finetune_baseline_acc}

    ### Finetune with different amount of new data
    finetune_trials_num = len(finetune_subj_train_set.get_metadata())
    # `+ 1` so the sweep reaches the largest multiple of data_amount_step that fits in the pool;
    # np.arange excludes its stop value, which previously dropped that last (largest) amount.
    finetune_data_amounts = np.arange(1, finetune_trials_num // data_amount_step + 1) * data_amount_step
    for finetune_training_data_amount in finetune_data_amounts:

        ## Get current finetune samples
        cur_finetune_subj_train_subset = get_subset(finetune_subj_train_set, finetune_training_data_amount)
        finetune_batch_size = int(min(finetune_training_data_amount // 2, finetune_max_batch_size))

        # Re-seed so each (subject, data amount) result does not depend on the runs before it.
        set_random_seeds(seed=seed, cuda=cuda)
        new_clf = make_classifier(finetune_subj_valid_set, finetune_lr, finetune_weight_decay,
                                  finetune_batch_size, finetune_n_epochs)
        new_clf.initialize()

        ## Load the pre-trained weights only. Restoring the saved optimizer state would overwrite
        ## finetune_lr / finetune_weight_decay with the end-of-pre-training values, and restoring the
        ## history would make the LR scheduler resume where pre-training stopped instead of starting
        ## a fresh fine-tuning schedule.
        new_clf.load_params(f_params=pretrained_params)

        ## Continue training / finetuning (partial_fit keeps the loaded weights; fit would re-initialise them)
        print(f'Fine tuning model for subject {holdout_subj_id} with {finetune_training_data_amount} trials')
        _ = new_clf.partial_fit(cur_finetune_subj_train_subset, y=None, epochs=finetune_n_epochs)

        ## Get results after fine tuning: history now holds only the fine-tuning epochs
        df = pd.DataFrame(new_clf.history[:, results_columns], columns=results_columns)
        dict_subj_results[finetune_training_data_amount] = np.mean(df.tail(n_final_epochs_averaged).valid_accuracy)

    dict_results[holdout_subj_id] = dict_subj_results

In [None]:
file_name = 'ShallowFBCSPNet_BNCI2014_001_finetuning_2'
file_path = os.path.join(results_dir, f'{file_name}.pkl')

# Persist the {holdout_subj_id: {n_finetune_trials: accuracy}} results. Use the portable
# os.path-built path rather than a hard-coded Windows '\\' separator, and make sure the
# target directory exists first.
os.makedirs(results_dir, exist_ok=True)
with open(file_path, 'wb') as f:
    pickle.dump(dict_results, f)

In [None]:
# if os.path.exists(file_path) and os.path.getsize(file_path) > 0:
#     with open(file_path, 'rb') as f:
#         baseline_2_1 = pickle.load(f)
#     print("Dictionary loaded successfully.")
# else:
#     print(f"Error: File '{file_path}' does not exist or is empty.")

## Plotting Results




In [None]:
df_results = pd.DataFrame(dict_results)
display(df_results)

# Rows are fine-tuning data amounts and columns are held-out subjects, so aggregate across columns.
subject_averaged_df = df_results.mean(axis='columns')
# Standard error of the mean across subjects.
std_err_df = df_results.sem(axis='columns')
# Two-sided 95% Student-t interval with (number of subjects - 1) degrees of freedom.
conf_interval_df = stats.t.interval(
    0.95,
    df_results.shape[1] - 1,
    loc=subject_averaged_df,
    scale=std_err_df,
)

In [None]:
# Left: one accuracy curve per held-out subject. Right: subject average with its 95% CI.
# sharey=True puts both panels on the same accuracy scale so they can be compared directly.
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 5), sharey=True)  # 1 row, 2 columns

for subj_id, subj_res in dict_results.items():
    # Pass plain lists rather than dict views to matplotlib.
    ax1.plot(list(subj_res.keys()), list(subj_res.values()), label=f'Subject {subj_id}')

ax1.legend()
ax1.set_xlabel('Amount of fine tuning data (trials, 4 secs each)')
ax1.set_ylabel('Validation Accuracy')

ax2.plot(subject_averaged_df, label='Subject averaged')
ax2.fill_between(subject_averaged_df.index, conf_interval_df[0], conf_interval_df[1], color='b', alpha=0.3, label='95% CI')
ax2.legend()
ax2.set_xlabel('Amount of fine tuning data (trials, 4 secs each)')

fig.suptitle('ShallowFBCSPNet on BNCI2014_001 Dataset \n Fine-tuning (using each subject as holdout)')

# bbox_inches='tight' keeps the two-line suptitle and legends inside the saved image.
fig.savefig(os.path.join(results_dir, f'{file_name}.png'), bbox_inches='tight')