In [1]:
import argparse
import logging
import os.path
import sys
from functools import partial

import numpy as np
import torch
from torch import nn
from braindecode import EEGClassifier
from braindecode.datasets.moabb import MOABBDataset
from braindecode.preprocessing.preprocess import Preprocessor, preprocess
from braindecode.preprocessing.preprocess import exponential_moving_standardize
from braindecode.preprocessing.preprocess import exponential_moving_demean
from braindecode.preprocessing.windowers import create_windows_from_events
from braindecode.models import Deep4Net, EEGResNet
from braindecode.models import ShallowFBCSPNet
from braindecode.models.util import to_dense_prediction_model, get_output_shape
from braindecode.training.losses import CroppedLoss
from braindecode.util import set_random_seeds
from braindecode.visualization.gradients import compute_amplitude_gradients
from skorch.callbacks import LRScheduler
from skorch.helper import predefined_split
from torch.utils.data import Subset

import torchvision.transforms as transforms
import torch.nn.functional as F
from audiomentations import Compose, AddGaussianSNR, AddGaussianNoise, TimeStretch, PitchShift, Shift, AddGaussianSNR, Gain, GainTransition
from torchvision import models
import torch.nn as nn
import torch.nn.functional as F
import torch

import numpy as np
import cv2
from Encoder.image_encoder import load_image_encoder_eeg2image

log = logging.getLogger(__name__)

2024-03-06 00:08:26.084615: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
/home/exx/GithubClonedRepo/EEG-Research/.env/lib/python3.10/site-packages/moabb/pipelines/__init__.py:26: ModuleNotFoundError: Tensorflow is not installed. You won't be able to use these MOABB pipelines if you attempt to do so.
  warn(


In [2]:
import torch.backends.cudnn as cudnn
import torch
# from hyperoptim.parse import cartesian_dict_of_lists_product, \
#     product_of_list_of_lists_of_dicts
import logging
import time
import os

os.sys.path.insert(0, '/home/schirrmr/code/invertible-reimplement/')


logging.basicConfig(format='%(asctime)s | %(levelname)s : %(message)s')


log = logging.getLogger(__name__)
log.setLevel('INFO')


def get_templates():
    return {}
def cartesian_dict_of_lists_product(param_dict):
  """
  This function takes a dictionary where keys are parameter names and values are lists of possible values,
  and returns a list of dictionaries, each representing a combination of parameter values.
  """
  if not param_dict:
    return [{}]
  first_param_name, first_param_values = list(param_dict.items())[0]
  remaining_params = {key: value for key, value in param_dict.items() if key != first_param_name}
  return [
      {**combination, first_param_name: value}
      for value in first_param_values
      for combination in cartesian_dict_of_lists_product(remaining_params)
  ]
def product_of_list_of_lists_of_dicts(list_of_param_dicts):
  """
  This function takes a list of lists of dictionaries, where each inner list represents parameter combinations
  for a group, and returns a list of dictionaries, each containing combinations from all groups.
  """
  if not list_of_param_dicts:
    return [{}]
  first_group_params = list_of_param_dicts[0]
  remaining_groups = list_of_param_dicts[1:]
  return [
      {**combination, **group_param}
      for combination in product_of_list_of_lists_of_dicts(remaining_groups)
      for group_param in first_group_params
  ]


def get_grid_param_list():
    dictlistprod = cartesian_dict_of_lists_product

    save_params = [
        {
            'save_folder': '/home/schirrmr/data/exps/braindecode/hgd-decoding/',
        },
    ]

    debug_params = [{
        'debug': False,
    }]

    data_params = dictlistprod({
        'subject_id': range(1, 15),
        'low_cut_hz': [0, 4],
        'high_cut_hz': [None],
        'exponential_moving_fn': ['standardize', 'demean'],#standardize'],#'demean',#demean',
        'only_C_sensors': [True],
        'do_common_average_reference': [True],
        'use_final_eval': [False],#False
    })

    train_params = dictlistprod({
        'n_epochs': [800],
    })

    random_params = dictlistprod({
        'seed': range(0,3),#range(0, 3),
    })

    model_params = dictlistprod({
        'model_name': ['deep', 'shallow'],#'shallow',
    })

    store_params = dictlistprod({
        'save_amp_grads': [False],
        'save_model': [False],
    })

    grid_params = product_of_list_of_lists_of_dicts([
        save_params,
        data_params,
        train_params,
        debug_params,
        random_params,
        model_params,
        store_params,
    ])

    return grid_params


def sample_config_params(rng, params):
    return params


def run(
        ex,
        subject_id,
        low_cut_hz,
        high_cut_hz,
        exponential_moving_fn,
        n_epochs,
        model_name,
        seed,
        debug,
        only_C_sensors,
        do_common_average_reference,
        use_final_eval,
        save_amp_grads,
        save_model,
):
    kwargs = locals()
    kwargs.pop('ex')
    if not debug:
        log.setLevel('INFO')
    if debug:
        kwargs['n_epochs'] = 3

    file_obs = ex.observers[0]
    output_dir = file_obs.dir
    kwargs['output_dir'] = output_dir
    torch.backends.cudnn.benchmark = True
    import sys
    logging.basicConfig(format='%(asctime)s %(levelname)s : %(message)s',
                        level=logging.DEBUG, stream=sys.stdout)

    start_time = time.time()
    ex.info['finished'] = False
    from braindecode.experiments.hgd.run import run_exp

    clf = run_exp(**kwargs)

    end_time = time.time()
    run_time = end_time - start_time
    ex.info['finished'] = True
    ignore_keys = [
        'batches', 'epoch', 'train_batch_count', 'valid_batch_count',
        'train_loss_best',
        'valid_loss_best', 'train_trial_accuracy_best',
        'valid_trial_accuracy_best']
    results = dict([(key, val) for key, val in clf.history[-1].items() if
                    key not in ignore_keys])

    for key, val in results.items():
        ex.info[key] = float(val)
    ex.info['runtime'] = run_time

In [3]:
def load_preprocessed_data(subject_id, low_cut_hz, high_cut_hz, exponential_moving_fn,
                           only_C_sensors, do_common_average_reference, set_name):
    log.info("Load dataset...")
    if set_name == 'hgd':
        dataset = MOABBDataset(dataset_name="Schirrmeister2017", subject_ids=[subject_id])
    else:
        assert set_name == "bcic_iv_2a"
        dataset = MOABBDataset(dataset_name="BNCI2014001", subject_ids=[subject_id])

    C_sensors = [
        'FC5', 'FC1', 'FC2', 'FC6', 'C3', 'Cz', 'C4', 'CP5',
        'CP1', 'CP2', 'CP6', 'FC3', 'FCz', 'FC4', 'C5', 'C1', 'C2', 'C6',
        'CP3', 'CPz', 'CP4', 'FFC5h', 'FFC3h', 'FFC4h', 'FFC6h', 'FCC5h',
        'FCC3h', 'FCC4h', 'FCC6h', 'CCP5h', 'CCP3h', 'CCP4h', 'CCP6h', 'CPP5h',
        'CPP3h', 'CPP4h', 'CPP6h', 'FFC1h', 'FFC2h', 'FCC1h', 'FCC2h', 'CCP1h',
        'CCP2h', 'CPP1h', 'CPP2h']
    EEG_sensors = ['Fp1', 'Fp2', 'Fpz', 'F7', 'F3', 'Fz', 'F4', 'F8',
            'FC5', 'FC1', 'FC2', 'FC6', 'M1', 'T7', 'C3', 'Cz', 'C4', 'T8', 'M2',
            'CP5', 'CP1', 'CP2', 'CP6', 'P7', 'P3', 'Pz', 'P4', 'P8', 'POz', 'O1',
            'Oz', 'O2', 'AF7', 'AF3', 'AF4', 'AF8', 'F5', 'F1', 'F2', 'F6', 'FC3',
            'FCz', 'FC4', 'C5', 'C1', 'C2', 'C6', 'CP3', 'CPz', 'CP4', 'P5', 'P1',
            'P2', 'P6', 'PO5', 'PO3', 'PO4', 'PO6', 'FT7', 'FT8', 'TP7', 'TP8',
            'PO7', 'PO8', 'FT9', 'FT10', 'TPP9h', 'TPP10h', 'PO9', 'PO10', 'P9',
            'P10', 'AFF1', 'AFz', 'AFF2', 'FFC5h', 'FFC3h', 'FFC4h', 'FFC6h', 'FCC5h',
            'FCC3h', 'FCC4h', 'FCC6h', 'CCP5h', 'CCP3h', 'CCP4h', 'CCP6h', 'CPP5h',
            'CPP3h', 'CPP4h', 'CPP6h', 'PPO1', 'PPO2', 'I1', 'Iz', 'I2', 'AFp3h',
            'AFp4h', 'AFF5h', 'AFF6h', 'FFT7h', 'FFC1h', 'FFC2h', 'FFT8h', 'FTT9h',
            'FTT7h', 'FCC1h', 'FCC2h', 'FTT8h', 'FTT10h', 'TTP7h', 'CCP1h', 'CCP2h',
            'TTP8h', 'TPP7h', 'CPP1h', 'CPP2h', 'TPP8h', 'PPO9h', 'PPO5h', 'PPO6h',
            'PPO10h', 'POO9h', 'POO3h', 'POO4h', 'POO10h', 'OI1h', 'OI2h']
    if only_C_sensors:
        sensor_names = C_sensors
    else:
        sensor_names = EEG_sensors
    # Parameters for exponential moving standardization
    factor_new = 1e-3
    init_block_size = 1000

    log.info("Preprocess dataset...")

    moving_fn ={'standardize': exponential_moving_standardize,
                'demean': exponential_moving_demean}[exponential_moving_fn]
    preprocessors = [
        # keep only C sensors
        Preprocessor(fn='load_data'),
    ]
    if set_name == "hgd":
        preprocessors.append(Preprocessor(fn='pick_channels', ch_names=sensor_names, ordered=True))
    else:
        assert set_name == 'bcic_iv_2a'
        preprocessors.append(Preprocessor("pick_types", eeg=True, meg=False, stim=False))  # Keep EEG sensors

    preprocessors.append(Preprocessor(fn=lambda x: x * 1e6, apply_on_array=True))
    preprocessors.append(Preprocessor(fn=lambda x: np.clip(x, -800, 800), apply_on_array=True))

    if do_common_average_reference:
        preprocessors.append(Preprocessor(fn='set_eeg_reference', ref_channels='average'),)
    preprocessors.extend([
        Preprocessor(fn='resample', sfreq=250),
        # bandpass filter
        Preprocessor(fn='filter', l_freq=low_cut_hz, h_freq=high_cut_hz),
        # exponential moving standardization
        Preprocessor(fn=moving_fn, factor_new=factor_new,
                     init_block_size=init_block_size, apply_on_array=True),
    ])

    # Transform the data
    preprocess(dataset, preprocessors)
    return dataset

In [4]:
def cut_windows(dataset, input_window_samples, window_stride_samples):
    ######################################################################
    # Cut the data into windows
    # -------------------------
    ######################################################################
    # In contrast to trialwise decoding, we have to supply an explicit window size and window stride to the
    # ``create_windows_from_events`` function.
    trial_start_offset_seconds = -0.5
    # Extract sampling frequency, check that they are same in all datasets
    sfreq = dataset.datasets[0].raw.info['sfreq']
    print(sfreq)
    assert all([ds.raw.info['sfreq'] == sfreq for ds in dataset.datasets])

    # Calculate the trial start offset in samples.
    trial_start_offset_samples = int(trial_start_offset_seconds * sfreq)

    # Create windows using braindecode function for this. It needs parameters to define how
    # trials should be used.
    windows_dataset = create_windows_from_events(
        dataset,
        trial_start_offset_samples=trial_start_offset_samples,
        trial_stop_offset_samples=0,
        window_size_samples=input_window_samples,
        window_stride_samples=window_stride_samples,
        drop_last_window=False,
        preload=True,
        mapping={'left_hand': 0, 'right_hand': 1, 'feet': 2, 'rest': 3},
    )
    return windows_dataset

In [103]:
augment = Compose([
            AddGaussianNoise(min_amplitude=0.001, max_amplitude=0.015, p=1),
            Shift(p=0.5)
        ])

normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
transform = transforms.Compose([transforms.Resize((224, 224), antialias=None),
            normalize
        ])


In [5]:
def transform_to_grayscale_heatmap(dataset, augment=None):
    """
    Transforms a window dataset of EEG data into grayscale heatmaps.

    Args:
        window_dataset: The window dataset to be transformed.
        augment: A function for augmenting the EEG data. Defaults to None.
        downsample_rate: The desired sampling rate for the heatmaps. Defaults to 440.

    Returns:
        A new dataset containing the grayscale heatmaps.
    """
    eeg_dataset = [dataset[dataset_idx][0] for dataset_idx in len(dataset)]

    for eeg in eeg_dataset:
        eeg = eeg

        # # Downsample if needed
        # if downsample_rate != 440:  # Assuming original sampling rate is 440
        #     eeg = eeg[:, ::int(440 / downsample_rate)]  # Resample directly for efficiency

        # Apply augmentation (if provided)
        if augment is not None:
            eeg = np.array([augment(samples=eeg[i]) for i in range(eeg.shape[0])])

        eeg = torch.tensor(eeg, dtype=torch.float32)

        # Normalize and convert to grayscale image
        normalized_data = (eeg - eeg.min()) / (eeg.max() - eeg.min())
        grayscale_images = (normalized_data * 255).to(torch.uint8)

        # Resize to desired size
        eeg_heatmap = F.interpolate(grayscale_images.unsqueeze(0).unsqueeze(0), size=(4*128, 1000), mode='bilinear')

        # Apply edge detection and blending
        eeg_heatmap = eeg_heatmap.squeeze(0).squeeze(0).numpy()
        eeg_heatmap = cv2.GaussianBlur(eeg_heatmap, (3, 3), 0)
        edges = cv2.Canny(eeg_heatmap, 50, 120)
        alpha = 0.9
        eeg_heatmap = alpha * eeg_heatmap + (1 - alpha) * edges

        eeg_heatmap = torch.tensor(eeg_heatmap, dtype=torch.float32)

        eeg_heatmap = eeg_heatmap.squeeze(0).squeeze(0)
        eeg_heatmap = eeg_heatmap.unsqueeze(0).repeat(3,  1, 1)
        
        eeg_heatmap_resize = transform(eeg_heatmap)
        
    return eeg_heatmap

In [6]:
def split_into_train_valid(windows_dataset, use_final_eval):
    ######################################################################
    # Split the dataset
    # -----------------
    #
    # This code is the same as in trialwise decoding.
    #

    print("description", windows_dataset.description)
    if sum(windows_dataset.description.session == 'session_T') > 0:
        # BCIC IV 2a case
        splitted = windows_dataset.split("session")
        train_key = 'session_T'
        test_key = 'session_E'
    else:
        splitted = windows_dataset.split('run')
        train_key = '0train'
        test_key = '1test'
    print("splitted", splitted)
    if use_final_eval:
        train_set = splitted[train_key]
        valid_set = splitted[test_key]
    else:
        full_train_set = splitted[train_key]
        n_split = int(np.round(0.8 * len(full_train_set)))
        # ensure this is multiple of 2 (number of windows per trial)
        n_windows_per_trial = 2  # here set by hand
        n_split = n_split - (n_split % n_windows_per_trial)
        valid_set = Subset(full_train_set, range(n_split, len(full_train_set)))
        train_set = Subset(full_train_set, range(0, n_split))
    return train_set, valid_set

In [46]:
def set_parameter_requires_grad(model, feature_extracting):
    if feature_extracting:
        for param in model.parameters():
            param.requires_grad = False

In [5]:
def create_cropped_model(model_name, n_chans, resnet_init_a):
    ######################################################################
    # Now we create the model. To enable it to be used in cropped decoding
    # efficiently, we manually set the length of the final convolution layer
    # to some length that makes the receptive field of the ConvNet smaller
    # than ``input_window_samples`` (see ``final_conv_length=30`` in the model
    # definition).
    #

    cuda = torch.cuda.is_available()  # check if GPU is available, if True chooses to use it
    device = 'cuda' if cuda else 'cpu'
    if cuda:
        torch.backends.cudnn.benchmark = True
    seed = 20200220  # random seed to make results reproducible
    # Set random seed to be able to reproduce results
    set_random_seeds(seed=seed, cuda=cuda)

    n_classes = 4

    if model_name == 'shallow':
        model = ShallowFBCSPNet(
            n_chans,
            n_classes,
            input_window_samples=None, # no need to provide if final_conv_length given
            final_conv_length=30,
        )
    elif model_name == 'resnet':
        model = EEGResNet(
            n_chans,
            n_classes,
            input_window_samples=None, # no need to provide if final_conv_length given
            n_first_filters=48,
            final_pool_length=10,
            conv_weight_init_fn=partial(nn.init.kaiming_normal_, a=resnet_init_a))
    elif model_name == 'deep':
        model = Deep4Net(
            n_chans,
            n_classes,
            input_window_samples=None, # no need to provide if final_conv_length given
            final_conv_length=2,
        )
    elif model_name == "efficientnet_b0":
        """ Efficientnet B0
        """
        model = models.efficientnet_b0(pretrained=True)
        set_parameter_requires_grad(model, feature_extracting = False)
        num_ftrs = model.classifier[1].in_features
        model.classifier[1] = nn.Linear(num_ftrs, n_classes)
        input_size = 224
    elif model_name == "efficientnet_v2_s":
        """ Efficientnet V2-S
        """
        model = models.efficientnet_v2_s(pretrained=True)
        set_parameter_requires_grad(model, feature_extracting = False)
        num_ftrs = model.classifier[1].in_features
        model.classifier[1] = nn.Linear(num_ftrs, n_classes)
        input_size = 224
    # Send model to GPU
    if cuda:
        model.cuda()

    ######################################################################
    # And now we transform model with strides to a model that outputs dense
    # prediction, so we can use it to obtain predictions for all
    # crops.
    #
    if model_name in ["shallow", "deep"]:
        to_dense_prediction_model(model)
    return model



In [6]:
seed = 0
low_cut_hz = None  # low cut frequency for filtering
high_cut_hz = None  # high cut frequency for filtering
n_epochs = 100
model_name = 'deep'
output_dir = './results/HGD/'
exponential_moving_fn = "standardize"
only_C_sensors = False
do_common_average_reference = False
use_final_eval = True
save_amp_grads = False
save_model = False
resnet_lr = 1e-3
resnet_init_a = 1
resnet_weight_decay = 1e-5
debug = False
set_name = "hgd"
drop_channel_prob = 0.

logging.basicConfig(format='%(asctime)s %(levelname)s : %(message)s',
                        level=logging.DEBUG, stream=sys.stdout)

In [7]:
set_random_seeds(seed, True)
subject_id_1 = 1 
subject_id_2 = 2 
subject_id_3 = 3 
subject_id_4 = 4 
subject_id_5 = 5
subject_id_6 = 6 
subject_id_7 = 7 
subject_id_8 = 8 
subject_id_9 = 9 
subject_id_10 = 10
subject_id_11 = 11 
subject_id_12 = 12 
subject_id_13 = 13
subject_id_14 = 14
# log.info(f"Load and preprocess data for subject {subject_id}...")
dataset_1 = load_preprocessed_data(subject_id_1, low_cut_hz, high_cut_hz,
                                 exponential_moving_fn=exponential_moving_fn,
                                 only_C_sensors=only_C_sensors,
                                 do_common_average_reference=do_common_average_reference,
                                 set_name=set_name,
                                 )
dataset_2 = load_preprocessed_data(subject_id_2, low_cut_hz, high_cut_hz,
                                 exponential_moving_fn=exponential_moving_fn,
                                 only_C_sensors=only_C_sensors,
                                 do_common_average_reference=do_common_average_reference,
                                 set_name=set_name,
                                 )
dataset_3 = load_preprocessed_data(subject_id_3, low_cut_hz, high_cut_hz,
                                 exponential_moving_fn=exponential_moving_fn,
                                 only_C_sensors=only_C_sensors,
                                 do_common_average_reference=do_common_average_reference,
                                 set_name=set_name,
                                 )
dataset_4 = load_preprocessed_data(subject_id_4, low_cut_hz, high_cut_hz,
                                 exponential_moving_fn=exponential_moving_fn,
                                 only_C_sensors=only_C_sensors,
                                 do_common_average_reference=do_common_average_reference,
                                 set_name=set_name,
                                 )
dataset_5 = load_preprocessed_data(subject_id_5, low_cut_hz, high_cut_hz,
                                 exponential_moving_fn=exponential_moving_fn,
                                 only_C_sensors=only_C_sensors,
                                 do_common_average_reference=do_common_average_reference,
                                 set_name=set_name,
                                 )
dataset_6 = load_preprocessed_data(subject_id_6, low_cut_hz, high_cut_hz,
                                 exponential_moving_fn=exponential_moving_fn,
                                 only_C_sensors=only_C_sensors,
                                 do_common_average_reference=do_common_average_reference,
                                 set_name=set_name,
                                 )
dataset_7 = load_preprocessed_data(subject_id_7, low_cut_hz, high_cut_hz,
                                 exponential_moving_fn=exponential_moving_fn,
                                 only_C_sensors=only_C_sensors,
                                 do_common_average_reference=do_common_average_reference,
                                 set_name=set_name,
                                 )
dataset_8 = load_preprocessed_data(subject_id_8, low_cut_hz, high_cut_hz,
                                 exponential_moving_fn=exponential_moving_fn,
                                 only_C_sensors=only_C_sensors,
                                 do_common_average_reference=do_common_average_reference,
                                 set_name=set_name,
                                 )
dataset_9 = load_preprocessed_data(subject_id_9, low_cut_hz, high_cut_hz,
                                 exponential_moving_fn=exponential_moving_fn,
                                 only_C_sensors=only_C_sensors,
                                 do_common_average_reference=do_common_average_reference,
                                 set_name=set_name,
                                 )
dataset_10 = load_preprocessed_data(subject_id_10, low_cut_hz, high_cut_hz,
                                 exponential_moving_fn=exponential_moving_fn,
                                 only_C_sensors=only_C_sensors,
                                 do_common_average_reference=do_common_average_reference,
                                 set_name=set_name,
                                 )
dataset_11 = load_preprocessed_data(subject_id_11, low_cut_hz, high_cut_hz,
                                 exponential_moving_fn=exponential_moving_fn,
                                 only_C_sensors=only_C_sensors,
                                 do_common_average_reference=do_common_average_reference,
                                 set_name=set_name,
                                 )
dataset_12 = load_preprocessed_data(subject_id_12, low_cut_hz, high_cut_hz,
                                 exponential_moving_fn=exponential_moving_fn,
                                 only_C_sensors=only_C_sensors,
                                 do_common_average_reference=do_common_average_reference,
                                 set_name=set_name,
                                 )
dataset_13 = load_preprocessed_data(subject_id_13, low_cut_hz, high_cut_hz,
                                 exponential_moving_fn=exponential_moving_fn,
                                 only_C_sensors=only_C_sensors,
                                 do_common_average_reference=do_common_average_reference,
                                 set_name=set_name,
                                 )
dataset_14 = load_preprocessed_data(subject_id_14, low_cut_hz, high_cut_hz,
                                 exponential_moving_fn=exponential_moving_fn,
                                 only_C_sensors=only_C_sensors,
                                 do_common_average_reference=do_common_average_reference,
                                 set_name=set_name,
                                 )

2024-03-06 00:09:30,804 | INFO : Load dataset...


Extracting EDF parameters from /home/exx/mne_data/MNE-schirrmeister2017-data/robintibor/high-gamma-dataset/raw/master/data/train/1.edf...
EDF file detected
Channel 'EEG Fp1' recognized as type EEG (renamed to 'Fp1').
Channel 'EEG Fp2' recognized as type EEG (renamed to 'Fp2').
Channel 'EEG Fpz' recognized as type EEG (renamed to 'Fpz').
Channel 'EEG F7' recognized as type EEG (renamed to 'F7').
Channel 'EEG F3' recognized as type EEG (renamed to 'F3').
Channel 'EEG Fz' recognized as type EEG (renamed to 'Fz').
Channel 'EEG F4' recognized as type EEG (renamed to 'F4').
Channel 'EEG F8' recognized as type EEG (renamed to 'F8').
Channel 'EEG FC5' recognized as type EEG (renamed to 'FC5').
Channel 'EEG FC1' recognized as type EEG (renamed to 'FC1').
Channel 'EEG FC2' recognized as type EEG (renamed to 'FC2').
Channel 'EEG FC6' recognized as type EEG (renamed to 'FC6').
Channel 'EEG M1' recognized as type EEG (renamed to 'M1').
Channel 'EEG T7' recognized as type EEG (renamed to 'T7').
Chan

2024-03-06 00:09:33,762 | INFO : Preprocess dataset...


NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).


  warn('Preprocessing choices with lambda functions cannot be saved.')


Filtering raw data in 1 contiguous segment

FIR filter parameters
---------------------
Designing a one-pass, zero-phase, non-causal allpass filter:
- Windowed time-domain design (firwin) method
- Hamming window with 0.0194 passband ripple and 53 dB stopband attenuation
- Filter length: 1 samples (0.004 s)

NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
Filtering raw data in 1 contiguous segment

FIR filter parameters
---------------------
Designing a one-pass, zero-phase, non-causal allpass filter:
- Windowed time-domain design (firwin) method
- Hamming window with 0.0194 passband ripple and 53 dB stopband attenuation
- Filter length: 1 samples (0.004 s)



2024-03-06 00:09:47,239 | INFO : Load dataset...


Extracting EDF parameters from /home/exx/mne_data/MNE-schirrmeister2017-data/robintibor/high-gamma-dataset/raw/master/data/train/2.edf...
EDF file detected
Channel 'EEG Fp1' recognized as type EEG (renamed to 'Fp1').
Channel 'EEG Fp2' recognized as type EEG (renamed to 'Fp2').
Channel 'EEG Fpz' recognized as type EEG (renamed to 'Fpz').
Channel 'EEG F7' recognized as type EEG (renamed to 'F7').
Channel 'EEG F3' recognized as type EEG (renamed to 'F3').
Channel 'EEG Fz' recognized as type EEG (renamed to 'Fz').
Channel 'EEG F4' recognized as type EEG (renamed to 'F4').
Channel 'EEG F8' recognized as type EEG (renamed to 'F8').
Channel 'EEG FC5' recognized as type EEG (renamed to 'FC5').
Channel 'EEG FC1' recognized as type EEG (renamed to 'FC1').
Channel 'EEG FC2' recognized as type EEG (renamed to 'FC2').
Channel 'EEG FC6' recognized as type EEG (renamed to 'FC6').
Channel 'EEG M1' recognized as type EEG (renamed to 'M1').
Channel 'EEG T7' recognized as type EEG (renamed to 'T7').
Chan

2024-03-06 00:09:52,633 | INFO : Preprocess dataset...


NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).


  warn('Preprocessing choices with lambda functions cannot be saved.')


Filtering raw data in 1 contiguous segment

FIR filter parameters
---------------------
Designing a one-pass, zero-phase, non-causal allpass filter:
- Windowed time-domain design (firwin) method
- Hamming window with 0.0194 passband ripple and 53 dB stopband attenuation
- Filter length: 1 samples (0.004 s)

NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
Filtering raw data in 1 contiguous segment

FIR filter parameters
---------------------
Designing a one-pass, zero-phase, non-causal allpass filter:
- Windowed time-domain design (firwin) method
- Hamming window with 0.0194 passband ripple and 53 dB stopband attenuation
- Filter length: 1 samples (0.004 s)



2024-03-06 00:10:19,492 | INFO : Load dataset...


Extracting EDF parameters from /home/exx/mne_data/MNE-schirrmeister2017-data/robintibor/high-gamma-dataset/raw/master/data/train/3.edf...
EDF file detected
Channel 'EEG Fp1' recognized as type EEG (renamed to 'Fp1').
Channel 'EEG Fp2' recognized as type EEG (renamed to 'Fp2').
Channel 'EEG Fpz' recognized as type EEG (renamed to 'Fpz').
Channel 'EEG F7' recognized as type EEG (renamed to 'F7').
Channel 'EEG F3' recognized as type EEG (renamed to 'F3').
Channel 'EEG Fz' recognized as type EEG (renamed to 'Fz').
Channel 'EEG F4' recognized as type EEG (renamed to 'F4').
Channel 'EEG F8' recognized as type EEG (renamed to 'F8').
Channel 'EEG FC5' recognized as type EEG (renamed to 'FC5').
Channel 'EEG FC1' recognized as type EEG (renamed to 'FC1').
Channel 'EEG FC2' recognized as type EEG (renamed to 'FC2').
Channel 'EEG FC6' recognized as type EEG (renamed to 'FC6').
Channel 'EEG M1' recognized as type EEG (renamed to 'M1').
Channel 'EEG T7' recognized as type EEG (renamed to 'T7').
Chan

2024-03-06 00:10:25,111 | INFO : Preprocess dataset...


NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).


  warn('Preprocessing choices with lambda functions cannot be saved.')


Filtering raw data in 1 contiguous segment

FIR filter parameters
---------------------
Designing a one-pass, zero-phase, non-causal allpass filter:
- Windowed time-domain design (firwin) method
- Hamming window with 0.0194 passband ripple and 53 dB stopband attenuation
- Filter length: 1 samples (0.004 s)

NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
Filtering raw data in 1 contiguous segment

FIR filter parameters
---------------------
Designing a one-pass, zero-phase, non-causal allpass filter:
- Windowed time-domain design (firwin) method
- Hamming window with 0.0194 passband ripple and 53 dB stopband attenuation
- Filter length: 1 samples (0.004 s)



2024-03-06 00:10:52,359 | INFO : Load dataset...


Extracting EDF parameters from /home/exx/mne_data/MNE-schirrmeister2017-data/robintibor/high-gamma-dataset/raw/master/data/train/4.edf...
EDF file detected
Channel 'EEG Fp1' recognized as type EEG (renamed to 'Fp1').
Channel 'EEG Fp2' recognized as type EEG (renamed to 'Fp2').
Channel 'EEG Fpz' recognized as type EEG (renamed to 'Fpz').
Channel 'EEG F7' recognized as type EEG (renamed to 'F7').
Channel 'EEG F3' recognized as type EEG (renamed to 'F3').
Channel 'EEG Fz' recognized as type EEG (renamed to 'Fz').
Channel 'EEG F4' recognized as type EEG (renamed to 'F4').
Channel 'EEG F8' recognized as type EEG (renamed to 'F8').
Channel 'EEG FC5' recognized as type EEG (renamed to 'FC5').
Channel 'EEG FC1' recognized as type EEG (renamed to 'FC1').
Channel 'EEG FC2' recognized as type EEG (renamed to 'FC2').
Channel 'EEG FC6' recognized as type EEG (renamed to 'FC6').
Channel 'EEG M1' recognized as type EEG (renamed to 'M1').
Channel 'EEG T7' recognized as type EEG (renamed to 'T7').
Chan

2024-03-06 00:10:58,144 | INFO : Preprocess dataset...


NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).


  warn('Preprocessing choices with lambda functions cannot be saved.')


Filtering raw data in 1 contiguous segment

FIR filter parameters
---------------------
Designing a one-pass, zero-phase, non-causal allpass filter:
- Windowed time-domain design (firwin) method
- Hamming window with 0.0194 passband ripple and 53 dB stopband attenuation
- Filter length: 1 samples (0.004 s)

NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
Filtering raw data in 1 contiguous segment

FIR filter parameters
---------------------
Designing a one-pass, zero-phase, non-causal allpass filter:
- Windowed time-domain design (firwin) method
- Hamming window with 0.0194 passband ripple and 53 dB stopband attenuation
- Filter length: 1 samples (0.004 s)



2024-03-06 00:11:25,804 | INFO : Load dataset...


Extracting EDF parameters from /home/exx/mne_data/MNE-schirrmeister2017-data/robintibor/high-gamma-dataset/raw/master/data/train/5.edf...
EDF file detected
Channel 'EEG Fp1' recognized as type EEG (renamed to 'Fp1').
Channel 'EEG Fp2' recognized as type EEG (renamed to 'Fp2').
Channel 'EEG Fpz' recognized as type EEG (renamed to 'Fpz').
Channel 'EEG F7' recognized as type EEG (renamed to 'F7').
Channel 'EEG F3' recognized as type EEG (renamed to 'F3').
Channel 'EEG Fz' recognized as type EEG (renamed to 'Fz').
Channel 'EEG F4' recognized as type EEG (renamed to 'F4').
Channel 'EEG F8' recognized as type EEG (renamed to 'F8').
Channel 'EEG FC5' recognized as type EEG (renamed to 'FC5').
Channel 'EEG FC1' recognized as type EEG (renamed to 'FC1').
Channel 'EEG FC2' recognized as type EEG (renamed to 'FC2').
Channel 'EEG FC6' recognized as type EEG (renamed to 'FC6').
Channel 'EEG M1' recognized as type EEG (renamed to 'M1').
Channel 'EEG T7' recognized as type EEG (renamed to 'T7').
Chan

2024-03-06 00:11:30,619 | INFO : Preprocess dataset...


NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).


  warn('Preprocessing choices with lambda functions cannot be saved.')


Filtering raw data in 1 contiguous segment

FIR filter parameters
---------------------
Designing a one-pass, zero-phase, non-causal allpass filter:
- Windowed time-domain design (firwin) method
- Hamming window with 0.0194 passband ripple and 53 dB stopband attenuation
- Filter length: 1 samples (0.004 s)

NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
Filtering raw data in 1 contiguous segment

FIR filter parameters
---------------------
Designing a one-pass, zero-phase, non-causal allpass filter:
- Windowed time-domain design (firwin) method
- Hamming window with 0.0194 passband ripple and 53 dB stopband attenuation
- Filter length: 1 samples (0.004 s)



2024-03-06 00:11:57,459 | INFO : Load dataset...


Extracting EDF parameters from /home/exx/mne_data/MNE-schirrmeister2017-data/robintibor/high-gamma-dataset/raw/master/data/train/6.edf...
EDF file detected
Channel 'EEG Fp1' recognized as type EEG (renamed to 'Fp1').
Channel 'EEG Fp2' recognized as type EEG (renamed to 'Fp2').
Channel 'EEG Fpz' recognized as type EEG (renamed to 'Fpz').
Channel 'EEG F7' recognized as type EEG (renamed to 'F7').
Channel 'EEG F3' recognized as type EEG (renamed to 'F3').
Channel 'EEG Fz' recognized as type EEG (renamed to 'Fz').
Channel 'EEG F4' recognized as type EEG (renamed to 'F4').
Channel 'EEG F8' recognized as type EEG (renamed to 'F8').
Channel 'EEG FC5' recognized as type EEG (renamed to 'FC5').
Channel 'EEG FC1' recognized as type EEG (renamed to 'FC1').
Channel 'EEG FC2' recognized as type EEG (renamed to 'FC2').
Channel 'EEG FC6' recognized as type EEG (renamed to 'FC6').
Channel 'EEG M1' recognized as type EEG (renamed to 'M1').
Channel 'EEG T7' recognized as type EEG (renamed to 'T7').
Chan

2024-03-06 00:12:03,157 | INFO : Preprocess dataset...


NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).


  warn('Preprocessing choices with lambda functions cannot be saved.')


Filtering raw data in 1 contiguous segment

FIR filter parameters
---------------------
Designing a one-pass, zero-phase, non-causal allpass filter:
- Windowed time-domain design (firwin) method
- Hamming window with 0.0194 passband ripple and 53 dB stopband attenuation
- Filter length: 1 samples (0.004 s)

NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
Filtering raw data in 1 contiguous segment

FIR filter parameters
---------------------
Designing a one-pass, zero-phase, non-causal allpass filter:
- Windowed time-domain design (firwin) method
- Hamming window with 0.0194 passband ripple and 53 dB stopband attenuation
- Filter length: 1 samples (0.004 s)



2024-03-06 00:12:31,383 | INFO : Load dataset...


Extracting EDF parameters from /home/exx/mne_data/MNE-schirrmeister2017-data/robintibor/high-gamma-dataset/raw/master/data/train/7.edf...
EDF file detected
Channel 'EEG Fp1' recognized as type EEG (renamed to 'Fp1').
Channel 'EEG Fp2' recognized as type EEG (renamed to 'Fp2').
Channel 'EEG Fpz' recognized as type EEG (renamed to 'Fpz').
Channel 'EEG F7' recognized as type EEG (renamed to 'F7').
Channel 'EEG F3' recognized as type EEG (renamed to 'F3').
Channel 'EEG Fz' recognized as type EEG (renamed to 'Fz').
Channel 'EEG F4' recognized as type EEG (renamed to 'F4').
Channel 'EEG F8' recognized as type EEG (renamed to 'F8').
Channel 'EEG FC5' recognized as type EEG (renamed to 'FC5').
Channel 'EEG FC1' recognized as type EEG (renamed to 'FC1').
Channel 'EEG FC2' recognized as type EEG (renamed to 'FC2').
Channel 'EEG FC6' recognized as type EEG (renamed to 'FC6').
Channel 'EEG M1' recognized as type EEG (renamed to 'M1').
Channel 'EEG T7' recognized as type EEG (renamed to 'T7').
Chan

2024-03-06 00:12:36,997 | INFO : Preprocess dataset...


NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).


  warn('Preprocessing choices with lambda functions cannot be saved.')


Filtering raw data in 1 contiguous segment

FIR filter parameters
---------------------
Designing a one-pass, zero-phase, non-causal allpass filter:
- Windowed time-domain design (firwin) method
- Hamming window with 0.0194 passband ripple and 53 dB stopband attenuation
- Filter length: 1 samples (0.004 s)

NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
Filtering raw data in 1 contiguous segment

FIR filter parameters
---------------------
Designing a one-pass, zero-phase, non-causal allpass filter:
- Windowed time-domain design (firwin) method
- Hamming window with 0.0194 passband ripple and 53 dB stopband attenuation
- Filter length: 1 samples (0.004 s)



2024-03-06 00:13:04,381 | INFO : Load dataset...


Extracting EDF parameters from /home/exx/mne_data/MNE-schirrmeister2017-data/robintibor/high-gamma-dataset/raw/master/data/train/8.edf...
EDF file detected
Channel 'EEG Fp1' recognized as type EEG (renamed to 'Fp1').
Channel 'EEG Fp2' recognized as type EEG (renamed to 'Fp2').
Channel 'EEG Fpz' recognized as type EEG (renamed to 'Fpz').
Channel 'EEG F7' recognized as type EEG (renamed to 'F7').
Channel 'EEG F3' recognized as type EEG (renamed to 'F3').
Channel 'EEG Fz' recognized as type EEG (renamed to 'Fz').
Channel 'EEG F4' recognized as type EEG (renamed to 'F4').
Channel 'EEG F8' recognized as type EEG (renamed to 'F8').
Channel 'EEG FC5' recognized as type EEG (renamed to 'FC5').
Channel 'EEG FC1' recognized as type EEG (renamed to 'FC1').
Channel 'EEG FC2' recognized as type EEG (renamed to 'FC2').
Channel 'EEG FC6' recognized as type EEG (renamed to 'FC6').
Channel 'EEG M1' recognized as type EEG (renamed to 'M1').
Channel 'EEG T7' recognized as type EEG (renamed to 'T7').
Chan

2024-03-06 00:13:08,790 | INFO : Preprocess dataset...


NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).


  warn('Preprocessing choices with lambda functions cannot be saved.')


Filtering raw data in 1 contiguous segment

FIR filter parameters
---------------------
Designing a one-pass, zero-phase, non-causal allpass filter:
- Windowed time-domain design (firwin) method
- Hamming window with 0.0194 passband ripple and 53 dB stopband attenuation
- Filter length: 1 samples (0.004 s)

NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
Filtering raw data in 1 contiguous segment

FIR filter parameters
---------------------
Designing a one-pass, zero-phase, non-causal allpass filter:
- Windowed time-domain design (firwin) method
- Hamming window with 0.0194 passband ripple and 53 dB stopband attenuation
- Filter length: 1 samples (0.004 s)



2024-03-06 00:13:34,083 | INFO : Load dataset...


Extracting EDF parameters from /home/exx/mne_data/MNE-schirrmeister2017-data/robintibor/high-gamma-dataset/raw/master/data/train/9.edf...
EDF file detected
Channel 'EEG Fp1' recognized as type EEG (renamed to 'Fp1').
Channel 'EEG Fp2' recognized as type EEG (renamed to 'Fp2').
Channel 'EEG Fpz' recognized as type EEG (renamed to 'Fpz').
Channel 'EEG F7' recognized as type EEG (renamed to 'F7').
Channel 'EEG F3' recognized as type EEG (renamed to 'F3').
Channel 'EEG Fz' recognized as type EEG (renamed to 'Fz').
Channel 'EEG F4' recognized as type EEG (renamed to 'F4').
Channel 'EEG F8' recognized as type EEG (renamed to 'F8').
Channel 'EEG FC5' recognized as type EEG (renamed to 'FC5').
Channel 'EEG FC1' recognized as type EEG (renamed to 'FC1').
Channel 'EEG FC2' recognized as type EEG (renamed to 'FC2').
Channel 'EEG FC6' recognized as type EEG (renamed to 'FC6').
Channel 'EEG M1' recognized as type EEG (renamed to 'M1').
Channel 'EEG T7' recognized as type EEG (renamed to 'T7').
Chan

2024-03-06 00:13:39,681 | INFO : Preprocess dataset...


NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).


  warn('Preprocessing choices with lambda functions cannot be saved.')


Filtering raw data in 1 contiguous segment

FIR filter parameters
---------------------
Designing a one-pass, zero-phase, non-causal allpass filter:
- Windowed time-domain design (firwin) method
- Hamming window with 0.0194 passband ripple and 53 dB stopband attenuation
- Filter length: 1 samples (0.004 s)

NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
Filtering raw data in 1 contiguous segment

FIR filter parameters
---------------------
Designing a one-pass, zero-phase, non-causal allpass filter:
- Windowed time-domain design (firwin) method
- Hamming window with 0.0194 passband ripple and 53 dB stopband attenuation
- Filter length: 1 samples (0.004 s)



2024-03-06 00:14:07,078 | INFO : Load dataset...


Extracting EDF parameters from /home/exx/mne_data/MNE-schirrmeister2017-data/robintibor/high-gamma-dataset/raw/master/data/train/10.edf...
EDF file detected
Channel 'EEG Fp1' recognized as type EEG (renamed to 'Fp1').
Channel 'EEG Fp2' recognized as type EEG (renamed to 'Fp2').
Channel 'EEG Fpz' recognized as type EEG (renamed to 'Fpz').
Channel 'EEG F7' recognized as type EEG (renamed to 'F7').
Channel 'EEG F3' recognized as type EEG (renamed to 'F3').
Channel 'EEG Fz' recognized as type EEG (renamed to 'Fz').
Channel 'EEG F4' recognized as type EEG (renamed to 'F4').
Channel 'EEG F8' recognized as type EEG (renamed to 'F8').
Channel 'EEG FC5' recognized as type EEG (renamed to 'FC5').
Channel 'EEG FC1' recognized as type EEG (renamed to 'FC1').
Channel 'EEG FC2' recognized as type EEG (renamed to 'FC2').
Channel 'EEG FC6' recognized as type EEG (renamed to 'FC6').
Channel 'EEG M1' recognized as type EEG (renamed to 'M1').
Channel 'EEG T7' recognized as type EEG (renamed to 'T7').
Cha

2024-03-06 00:14:12,641 | INFO : Preprocess dataset...


NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).


  warn('Preprocessing choices with lambda functions cannot be saved.')


Filtering raw data in 1 contiguous segment

FIR filter parameters
---------------------
Designing a one-pass, zero-phase, non-causal allpass filter:
- Windowed time-domain design (firwin) method
- Hamming window with 0.0194 passband ripple and 53 dB stopband attenuation
- Filter length: 1 samples (0.004 s)

NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
Filtering raw data in 1 contiguous segment

FIR filter parameters
---------------------
Designing a one-pass, zero-phase, non-causal allpass filter:
- Windowed time-domain design (firwin) method
- Hamming window with 0.0194 passband ripple and 53 dB stopband attenuation
- Filter length: 1 samples (0.004 s)



2024-03-06 00:14:40,100 | INFO : Load dataset...


Extracting EDF parameters from /home/exx/mne_data/MNE-schirrmeister2017-data/robintibor/high-gamma-dataset/raw/master/data/train/11.edf...
EDF file detected
Channel 'EEG Fp1' recognized as type EEG (renamed to 'Fp1').
Channel 'EEG Fp2' recognized as type EEG (renamed to 'Fp2').
Channel 'EEG Fpz' recognized as type EEG (renamed to 'Fpz').
Channel 'EEG F7' recognized as type EEG (renamed to 'F7').
Channel 'EEG F3' recognized as type EEG (renamed to 'F3').
Channel 'EEG Fz' recognized as type EEG (renamed to 'Fz').
Channel 'EEG F4' recognized as type EEG (renamed to 'F4').
Channel 'EEG F8' recognized as type EEG (renamed to 'F8').
Channel 'EEG FC5' recognized as type EEG (renamed to 'FC5').
Channel 'EEG FC1' recognized as type EEG (renamed to 'FC1').
Channel 'EEG FC2' recognized as type EEG (renamed to 'FC2').
Channel 'EEG FC6' recognized as type EEG (renamed to 'FC6').
Channel 'EEG M1' recognized as type EEG (renamed to 'M1').
Channel 'EEG T7' recognized as type EEG (renamed to 'T7').
Cha

2024-03-06 00:14:45,756 | INFO : Preprocess dataset...


NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).


  warn('Preprocessing choices with lambda functions cannot be saved.')


Filtering raw data in 1 contiguous segment

FIR filter parameters
---------------------
Designing a one-pass, zero-phase, non-causal allpass filter:
- Windowed time-domain design (firwin) method
- Hamming window with 0.0194 passband ripple and 53 dB stopband attenuation
- Filter length: 1 samples (0.004 s)

NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
Filtering raw data in 1 contiguous segment

FIR filter parameters
---------------------
Designing a one-pass, zero-phase, non-causal allpass filter:
- Windowed time-domain design (firwin) method
- Hamming window with 0.0194 passband ripple and 53 dB stopband attenuation
- Filter length: 1 samples (0.004 s)



2024-03-06 00:15:13,169 | INFO : Load dataset...


Extracting EDF parameters from /home/exx/mne_data/MNE-schirrmeister2017-data/robintibor/high-gamma-dataset/raw/master/data/train/12.edf...
EDF file detected
Channel 'EEG Fp1' recognized as type EEG (renamed to 'Fp1').
Channel 'EEG Fp2' recognized as type EEG (renamed to 'Fp2').
Channel 'EEG Fpz' recognized as type EEG (renamed to 'Fpz').
Channel 'EEG F7' recognized as type EEG (renamed to 'F7').
Channel 'EEG F3' recognized as type EEG (renamed to 'F3').
Channel 'EEG Fz' recognized as type EEG (renamed to 'Fz').
Channel 'EEG F4' recognized as type EEG (renamed to 'F4').
Channel 'EEG F8' recognized as type EEG (renamed to 'F8').
Channel 'EEG FC5' recognized as type EEG (renamed to 'FC5').
Channel 'EEG FC1' recognized as type EEG (renamed to 'FC1').
Channel 'EEG FC2' recognized as type EEG (renamed to 'FC2').
Channel 'EEG FC6' recognized as type EEG (renamed to 'FC6').
Channel 'EEG M1' recognized as type EEG (renamed to 'M1').
Channel 'EEG T7' recognized as type EEG (renamed to 'T7').
Cha

2024-03-06 00:15:18,781 | INFO : Preprocess dataset...


NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).


  warn('Preprocessing choices with lambda functions cannot be saved.')


Filtering raw data in 1 contiguous segment

FIR filter parameters
---------------------
Designing a one-pass, zero-phase, non-causal allpass filter:
- Windowed time-domain design (firwin) method
- Hamming window with 0.0194 passband ripple and 53 dB stopband attenuation
- Filter length: 1 samples (0.004 s)

NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
Filtering raw data in 1 contiguous segment

FIR filter parameters
---------------------
Designing a one-pass, zero-phase, non-causal allpass filter:
- Windowed time-domain design (firwin) method
- Hamming window with 0.0194 passband ripple and 53 dB stopband attenuation
- Filter length: 1 samples (0.004 s)



2024-03-06 00:15:46,270 | INFO : Load dataset...


Extracting EDF parameters from /home/exx/mne_data/MNE-schirrmeister2017-data/robintibor/high-gamma-dataset/raw/master/data/train/13.edf...
EDF file detected
Channel 'EEG Fp1' recognized as type EEG (renamed to 'Fp1').
Channel 'EEG Fp2' recognized as type EEG (renamed to 'Fp2').
Channel 'EEG Fpz' recognized as type EEG (renamed to 'Fpz').
Channel 'EEG F7' recognized as type EEG (renamed to 'F7').
Channel 'EEG F3' recognized as type EEG (renamed to 'F3').
Channel 'EEG Fz' recognized as type EEG (renamed to 'Fz').
Channel 'EEG F4' recognized as type EEG (renamed to 'F4').
Channel 'EEG F8' recognized as type EEG (renamed to 'F8').
Channel 'EEG FC5' recognized as type EEG (renamed to 'FC5').
Channel 'EEG FC1' recognized as type EEG (renamed to 'FC1').
Channel 'EEG FC2' recognized as type EEG (renamed to 'FC2').
Channel 'EEG FC6' recognized as type EEG (renamed to 'FC6').
Channel 'EEG M1' recognized as type EEG (renamed to 'M1').
Channel 'EEG T7' recognized as type EEG (renamed to 'T7').
Cha

2024-03-06 00:15:51,418 | INFO : Preprocess dataset...


NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).


  warn('Preprocessing choices with lambda functions cannot be saved.')


Filtering raw data in 1 contiguous segment

FIR filter parameters
---------------------
Designing a one-pass, zero-phase, non-causal allpass filter:
- Windowed time-domain design (firwin) method
- Hamming window with 0.0194 passband ripple and 53 dB stopband attenuation
- Filter length: 1 samples (0.004 s)

NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
Filtering raw data in 1 contiguous segment

FIR filter parameters
---------------------
Designing a one-pass, zero-phase, non-causal allpass filter:
- Windowed time-domain design (firwin) method
- Hamming window with 0.0194 passband ripple and 53 dB stopband attenuation
- Filter length: 1 samples (0.004 s)



2024-03-06 00:16:18,371 | INFO : Load dataset...


Extracting EDF parameters from /home/exx/mne_data/MNE-schirrmeister2017-data/robintibor/high-gamma-dataset/raw/master/data/train/14.edf...
EDF file detected
Channel 'EEG Fp1' recognized as type EEG (renamed to 'Fp1').
Channel 'EEG Fp2' recognized as type EEG (renamed to 'Fp2').
Channel 'EEG Fpz' recognized as type EEG (renamed to 'Fpz').
Channel 'EEG F7' recognized as type EEG (renamed to 'F7').
Channel 'EEG F3' recognized as type EEG (renamed to 'F3').
Channel 'EEG Fz' recognized as type EEG (renamed to 'Fz').
Channel 'EEG F4' recognized as type EEG (renamed to 'F4').
Channel 'EEG F8' recognized as type EEG (renamed to 'F8').
Channel 'EEG FC5' recognized as type EEG (renamed to 'FC5').
Channel 'EEG FC1' recognized as type EEG (renamed to 'FC1').
Channel 'EEG FC2' recognized as type EEG (renamed to 'FC2').
Channel 'EEG FC6' recognized as type EEG (renamed to 'FC6').
Channel 'EEG M1' recognized as type EEG (renamed to 'M1').
Channel 'EEG T7' recognized as type EEG (renamed to 'T7').
Cha

2024-03-06 00:16:23,948 | INFO : Preprocess dataset...


NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).


  warn('Preprocessing choices with lambda functions cannot be saved.')


Filtering raw data in 1 contiguous segment

FIR filter parameters
---------------------
Designing a one-pass, zero-phase, non-causal allpass filter:
- Windowed time-domain design (firwin) method
- Hamming window with 0.0194 passband ripple and 53 dB stopband attenuation
- Filter length: 1 samples (0.004 s)

NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
Filtering raw data in 1 contiguous segment

FIR filter parameters
---------------------
Designing a one-pass, zero-phase, non-causal allpass filter:
- Windowed time-domain design (firwin) method
- Hamming window with 0.0194 passband ripple and 53 dB stopband attenuation
- Filter length: 1 samples (0.004 s)



In [26]:
n_chans = dataset[0][0].shape[0]

In [27]:
print(n_chans)

128


In [8]:
print(len(dataset_1))

921500


In [9]:
# Extract number of chans from dataset
n_chans = dataset_1[0][0].shape[0]

log.info("Create cropped model...")
model = create_cropped_model(model_name, n_chans, resnet_init_a)

# Cut windows from the preprocessed data, using number of predictions
# per compute window to cut non-overlapping fully covering windows
# (except for overlap of last window to stay within trial bounds)
log.info("Cut windows from dataset ...")
input_window_samples = 1000

# To know the models’ receptive field, we calculate the shape of model
# output for a dummy input.
output_shape = get_output_shape(model, n_chans, input_window_samples)
n_preds_per_input = output_shape[2]
windows_dataset_1 = cut_windows(
    dataset_1, input_window_samples, window_stride_samples=n_preds_per_input)
windows_dataset_2 = cut_windows(
    dataset_2, input_window_samples, window_stride_samples=n_preds_per_input)
windows_dataset_3 = cut_windows(
    dataset_3, input_window_samples, window_stride_samples=n_preds_per_input)
windows_dataset_4 = cut_windows(
    dataset_4, input_window_samples, window_stride_samples=n_preds_per_input)
windows_dataset_5 = cut_windows(
    dataset_5, input_window_samples, window_stride_samples=n_preds_per_input)
windows_dataset_6 = cut_windows(
    dataset_6, input_window_samples, window_stride_samples=n_preds_per_input)
windows_dataset_7 = cut_windows(
    dataset_7, input_window_samples, window_stride_samples=n_preds_per_input)
windows_dataset_8 = cut_windows(
    dataset_8, input_window_samples, window_stride_samples=n_preds_per_input)
windows_dataset_9 = cut_windows(
    dataset_9, input_window_samples, window_stride_samples=n_preds_per_input)
windows_dataset_10 = cut_windows(
    dataset_10, input_window_samples, window_stride_samples=n_preds_per_input)
windows_dataset_11 = cut_windows(
    dataset_11, input_window_samples, window_stride_samples=n_preds_per_input)
windows_dataset_12 = cut_windows(
    dataset_12, input_window_samples, window_stride_samples=n_preds_per_input)
windows_dataset_13 = cut_windows(
    dataset_13, input_window_samples, window_stride_samples=n_preds_per_input)
windows_dataset_14 = cut_windows(
    dataset_14, input_window_samples, window_stride_samples=n_preds_per_input)

2024-03-06 00:18:41,344 | INFO : Create cropped model...
  warn(
2024-03-06 00:18:42,486 | INFO : Cut windows from dataset ...


250.0
Used Annotations descriptions: ['feet', 'left_hand', 'rest', 'right_hand']
Used Annotations descriptions: ['feet', 'left_hand', 'rest', 'right_hand']
250.0
Used Annotations descriptions: ['feet', 'left_hand', 'rest', 'right_hand']
Used Annotations descriptions: ['feet', 'left_hand', 'rest', 'right_hand']
250.0
Used Annotations descriptions: ['feet', 'left_hand', 'rest', 'right_hand']
Used Annotations descriptions: ['feet', 'left_hand', 'rest', 'right_hand']
250.0
Used Annotations descriptions: ['feet', 'left_hand', 'rest', 'right_hand']
Used Annotations descriptions: ['feet', 'left_hand', 'rest', 'right_hand']
250.0
Used Annotations descriptions: ['feet', 'left_hand', 'rest', 'right_hand']
Used Annotations descriptions: ['feet', 'left_hand', 'rest', 'right_hand']
250.0
Used Annotations descriptions: ['feet', 'left_hand', 'rest', 'right_hand']
Used Annotations descriptions: ['feet', 'left_hand', 'rest', 'right_hand']
250.0
Used Annotations descriptions: ['feet', 'left_hand', 'rest

In [10]:
print(len(windows_dataset_1))
print(len(windows_dataset_2))
print(len(windows_dataset_3))
print(len(windows_dataset_4))
print(len(windows_dataset_5))
print(len(windows_dataset_6))
print(len(windows_dataset_7))
print(len(windows_dataset_8))
print(len(windows_dataset_9))
print(len(windows_dataset_10))
print(len(windows_dataset_11))
print(len(windows_dataset_12))
print(len(windows_dataset_13))
print(len(windows_dataset_14))

960
1946
2080
2114
1760
2080
2080
1628
2080
2080
2080
2080
1920
2080


In [11]:
from braindecode.datasets import BaseConcatDataset

In [12]:
ls_window = [windows_dataset_1, windows_dataset_2, windows_dataset_3, windows_dataset_4, windows_dataset_5, windows_dataset_6,
             windows_dataset_7, windows_dataset_8, windows_dataset_9, windows_dataset_10, windows_dataset_11, windows_dataset_12,
             windows_dataset_13, windows_dataset_14]

In [13]:
meta_dataset = BaseConcatDataset(ls_window)

In [14]:
print(len(meta_dataset))

26968


In [15]:
print(meta_dataset[4])

(array([[ 0.07717729,  0.37215105,  0.51180005, ..., -1.7510271 ,
        -1.7418437 , -1.4952111 ],
       [ 0.24559423,  0.24739473,  0.11590777, ..., -1.8094273 ,
        -1.6247057 , -1.305433  ],
       [ 0.5059485 ,  0.5193529 ,  0.3159461 , ..., -1.2553885 ,
        -1.3210361 , -0.92770123],
       ...,
       [-1.4092758 , -0.9729529 , -1.3785982 , ...,  0.11191165,
        -0.5466334 , -1.1093411 ],
       [-0.92409897, -0.33082068, -0.6130255 , ...,  0.10419627,
        -0.4491704 , -1.3548839 ],
       [-1.2649872 , -0.6935524 , -1.0574268 , ...,  0.06098427,
        -0.5013318 , -1.3377113 ]], dtype=float32), 0, [0, 5993, 6993])


In [16]:
meta_eeg_data = []  # Create an empty list to store all dictionaries

for i in range(len(meta_dataset)):
  eeg_data = meta_dataset[i][0]
  label = meta_dataset[i][1]
  sample_dict = {"eeg": eeg_data, "label": label}  # Create a dictionary for each sample
  meta_eeg_data.append(sample_dict)  # Append the dictionary to the list


In [17]:
print(len(meta_eeg_data))

26968


In [18]:
print(meta_eeg_data[1])

{'eeg': array([[-0.05097474, -0.03853722, -0.44336405, ...,  0.11838483,
         0.02127389,  0.37641716],
       [-0.04080858, -0.01586035, -0.39235893, ...,  0.23451884,
         0.16657886,  0.27804092],
       [ 0.11850397,  0.18304576, -0.06896071, ...,  0.14510281,
         0.05850426,  0.14260519],
       ...,
       [ 0.33444002,  0.30649388, -0.9044521 , ..., -0.14242806,
         0.12561147,  0.02040822],
       [ 0.00350811,  0.3602116 , -0.14849193, ..., -0.59028906,
        -0.40563998, -0.16179495],
       [ 0.28973633,  0.33587858, -0.77572596, ..., -0.18831319,
         0.00323256,  0.03320364]], dtype=float32), 'label': 3}


In [19]:
torch.save(meta_eeg_data, "/media/mountHDD1/LanxHuyen/high_gamma_dataset_deep.pth")

In [74]:
print(eeg_data.shape)

(128, 1000)


In [32]:
print(type(windows_dataset))

<class 'braindecode.datasets.base.BaseConcatDataset'>


In [88]:
# Split into train and valid, ignoring final evaluation for now
log.info("Split into train and valid...")
train_set, valid_set = split_into_train_valid(windows_dataset, use_final_eval=use_final_eval)

2024-03-04 17:22:28,321 | INFO : Split into train and valid...


description    subject session     run
0        1       0  0train
1        1       0   1test
splitted {'0train': <braindecode.datasets.base.BaseConcatDataset object at 0x7f54b52471c0>, '1test': <braindecode.datasets.base.BaseConcatDataset object at 0x7f54af986ec0>}


In [89]:
print(train_set)

print(len(dataset))

print(len(windows_dataset))
# input shape with batch size =1:
print(f"Input shape: (1, {n_chans}, {input_window_samples})")
#output shape
print(f"Output shape: {output_shape}")

<braindecode.datasets.base.BaseConcatDataset object at 0x7f54b52471c0>
921500
960
Input shape: (1, 128, 1000)
Output shape: torch.Size([1, 4, 479])


In [75]:
from torchinfo import summary

# summary(model, input_size=(1, n_chans, input_window_samples))
summary(model, input_size=(1,3,224,224))

Layer (type:depth-idx)                                  Output Shape              Param #
EfficientNet                                            [1, 4]                    --
├─Sequential: 1-1                                       [1, 1280, 7, 7]           --
│    └─Conv2dNormActivation: 2-1                        [1, 32, 112, 112]         --
│    │    └─Conv2d: 3-1                                 [1, 32, 112, 112]         864
│    │    └─BatchNorm2d: 3-2                            [1, 32, 112, 112]         64
│    │    └─SiLU: 3-3                                   [1, 32, 112, 112]         --
│    └─Sequential: 2-2                                  [1, 16, 112, 112]         --
│    │    └─MBConv: 3-4                                 [1, 16, 112, 112]         1,448
│    └─Sequential: 2-3                                  [1, 24, 56, 56]           --
│    │    └─MBConv: 3-5                                 [1, 24, 56, 56]           6,004
│    │    └─MBConv: 3-6                              

In [12]:
####
# 1. Import model from EEGEncoder. changes model such that the input shape is the same
# 2. Duplicate the previous dataloader and change the get() function to extract each sample from the windows_dataset
####

In [30]:
def run_training(model, model_name, train_set, valid_set, device, n_epochs, resnet_lr,
                 resnet_weight_decay, drop_channel_prob):
    assert model_name in ['deep', 'shallow', 'resnet']
    if model_name == 'shallow':
        # These values we found good for shallow network:
        lr = 0.0625 * 0.01
        weight_decay = 0
    elif model_name == 'resnet':
        # Guessing here
        # For deep4 they should be:
        lr = resnet_lr
        weight_decay = resnet_weight_decay

    else:
        assert model_name == 'deep'
        # For deep4 they should be:
        lr = 1 * 0.01
        weight_decay = 0.5 * 0.001

    batch_size = 64
    from braindecode.augmentation import AugmentedDataLoader, ChannelsDropout
    
    transforms = [ChannelsDropout(1, drop_channel_prob)]

    clf = EEGClassifier(
        model,
        cropped=True,
        criterion=CroppedLoss,
        criterion__loss_function=torch.nn.functional.nll_loss,
        iterator_train=AugmentedDataLoader,
        iterator_train__transforms=transforms,  # This sets the augmentations to use
        optimizer=torch.optim.AdamW,
        train_split=predefined_split(valid_set),
        optimizer__lr=lr,
        optimizer__weight_decay=weight_decay,
        iterator_train__shuffle=True,
        batch_size=batch_size,
        callbacks=[
            "accuracy", ("lr_scheduler", LRScheduler('CosineAnnealingLR', T_max=n_epochs - 1)),
        ],
        device=device,
        classes=["right", "left", "rest", "feet"],
    )
    # Model training for a specified number of epochs. `y` is None as it is already supplied
    # in the dataset.
    clf.fit(train_set, y=None, epochs=n_epochs)
    return clf

In [31]:
def compute_and_store_amp_grads(model, train_set, filename):
    amp_grads_per_filter = compute_amplitude_gradients(model, train_set, batch_size=64)
    # average across compute windows
    avg_amp_grads_per_filter = np.mean(amp_grads_per_filter, axis=1)
    np.save(filename, avg_amp_grads_per_filter)

In [32]:
def run_exp(
        seed,
        subject_id,
        low_cut_hz,
        high_cut_hz,
        exponential_moving_fn,
        n_epochs,
        model_name,
        output_dir,
        only_C_sensors,
        do_common_average_reference,
        use_final_eval,
        save_amp_grads,
        save_model,
        resnet_lr,
        resnet_weight_decay,
        resnet_init_a,
        debug,
        set_name,
        drop_channel_prob):
    assert model_name in ['deep', 'shallow', 'resnet']
    set_random_seeds(seed, True)
    log.info(f"Load and preprocess data for subject {subject_id}...")
    dataset = load_preprocessed_data(subject_id, low_cut_hz, high_cut_hz,
                                     exponential_moving_fn=exponential_moving_fn,
                                     only_C_sensors=only_C_sensors,
                                     do_common_average_reference=do_common_average_reference,
                                     set_name=set_name,
                                     )

    # Extract number of chans from dataset to create model
    n_chans = dataset[0][0].shape[0]
    log.info("Create cropped model...")
    model = create_cropped_model(model_name, n_chans, resnet_init_a)

    # Cut windows from the preprocessed data, using number of predictions
    # per compute window to cut non-overlapping fully covering windows
    # (except for overlap of last window to stay within trial bounds)
    log.info("Cut windows from dataset ...")
    input_window_samples = 1000
    # To know the models’ receptive field, we calculate the shape of model
    # output for a dummy input.
    n_preds_per_input = get_output_shape(model, n_chans, input_window_samples)[2]
    windows_dataset = cut_windows(
        dataset, input_window_samples, window_stride_samples=n_preds_per_input)

    # Split into train and valid, ignoring final evaluation for now
    log.info("Split into train and valid...")
    train_set, valid_set = split_into_train_valid(windows_dataset, use_final_eval=use_final_eval)

    # Run actual training
    log.info("Run training...")
    clf = run_training(model, model_name, train_set, valid_set, 'cuda', n_epochs, resnet_lr,
                       resnet_weight_decay, drop_channel_prob)

    if save_amp_grads:
        log.info("Compute and store amplitude gradients ...")
        amp_grads_filename = os.path.join(output_dir, f"{subject_id}_avg_amp_grads.npy")
        compute_and_store_amp_grads(model, train_set, filename=amp_grads_filename)

    if (not debug) and (save_model):
        log.info("Save model ...")
        # save model
        torch.save(model, os.path.join(output_dir, f"model.pth"))
    log.info("... Done.")

    return clf


In [38]:
if __name__ == '__main__':
    parser = argparse.ArgumentParser(
        description="""Launch an experiment from a YAML experiment file.
        Example: ./train_experiments.py configs/config.py """
    )
    parser.add_argument('subject_id', type=int,
                        help='''Run for subject id....''')
    # args = parser.parse_args()
    seed = 0
    for i in range(1,15):
        subject_id = i
    # subject_id = range(1, 15) #args.subject_id
    low_cut_hz = None  # low cut frequency for filtering
    high_cut_hz = None  # high cut frequency for filtering
    n_epochs = 800
    model_name = 'deep'
    output_dir = './results/'
    exponential_moving_fn = "standardize"
    only_C_sensors = False
    do_common_average_reference = False
    use_final_eval = True
    save_amp_grads = False
    save_model = False
    resnet_lr = 1e-3
    resnet_init_a = 1
    resnet_weight_decay = 1e-5
    debug = False
    set_name = "hgd"
    drop_channel_prob = 0.

    logging.basicConfig(format='%(asctime)s %(levelname)s : %(message)s',
                        level=logging.DEBUG, stream=sys.stdout)
    run_exp(
        seed,
        subject_id,
        low_cut_hz,
        high_cut_hz,
        exponential_moving_fn,
        n_epochs,
        model_name,
        output_dir,
        only_C_sensors,
        do_common_average_reference,
        use_final_eval,
        save_amp_grads,
        save_model,
        resnet_lr,
        resnet_weight_decay,
        resnet_init_a,
        debug,
        set_name,
        drop_channel_prob)

2024-03-02 14:20:04,980 INFO : Load and preprocess data for subject 14...
2024-03-02 14:20:04,981 INFO : Load dataset...
Extracting EDF parameters from /home/exx/mne_data/MNE-schirrmeister2017-data/robintibor/high-gamma-dataset/raw/master/data/train/14.edf...
EDF file detected
Channel 'EEG Fp1' recognized as type EEG (renamed to 'Fp1').
Channel 'EEG Fp2' recognized as type EEG (renamed to 'Fp2').
Channel 'EEG Fpz' recognized as type EEG (renamed to 'Fpz').
Channel 'EEG F7' recognized as type EEG (renamed to 'F7').
Channel 'EEG F3' recognized as type EEG (renamed to 'F3').
Channel 'EEG Fz' recognized as type EEG (renamed to 'Fz').
Channel 'EEG F4' recognized as type EEG (renamed to 'F4').
Channel 'EEG F8' recognized as type EEG (renamed to 'F8').
Channel 'EEG FC5' recognized as type EEG (renamed to 'FC5').
Channel 'EEG FC1' recognized as type EEG (renamed to 'FC1').
Channel 'EEG FC2' recognized as type EEG (renamed to 'FC2').
Channel 'EEG FC6' recognized as type EEG (renamed to 'FC6').


  warn('Preprocessing choices with lambda functions cannot be saved.')


Filtering raw data in 1 contiguous segment

FIR filter parameters
---------------------
Designing a one-pass, zero-phase, non-causal allpass filter:
- Windowed time-domain design (firwin) method
- Hamming window with 0.0194 passband ripple and 53 dB stopband attenuation
- Filter length: 1 samples (0.004 s)

NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
Filtering raw data in 1 contiguous segment

FIR filter parameters
---------------------
Designing a one-pass, zero-phase, non-causal allpass filter:
- Windowed time-domain design (firwin) method
- Hamming window with 0.0194 passband ripple and 53 dB stopband attenuation
- Filter length: 1 samples (0.004 s)

2024-03-02 14:20:37,589 INFO : Create cropped model...
2024-03-02 14:20:37,612 INFO : Cut windows from dataset ...
Used Annotations descriptions: ['feet', 'left_hand', 'rest', 'right_hand']
Used Annotations descriptions: ['feet', 'left_hand', 'rest', 'right_hand']
2024-03-02 14:20:37,726 INFO : Split 

  warn(


  epoch    train_accuracy    train_loss    valid_accuracy    valid_loss      lr     dur
-------  ----------------  ------------  ----------------  ------------  ------  ------
      1            [36m0.4943[0m        [32m1.5358[0m            [35m0.4500[0m        [31m1.6080[0m  0.0100  3.8655
      2            [36m0.6705[0m        [32m0.9226[0m            [35m0.4562[0m        [31m1.2959[0m  0.0100  0.8545
      3            0.6466        [32m0.8003[0m            0.4125        1.3973  0.0100  0.8391
      4            [36m0.7898[0m        [32m0.7132[0m            [35m0.5625[0m        1.4277  0.0100  0.8694
      5            [36m0.8034[0m        [32m0.5437[0m            [35m0.5687[0m        1.5075  0.0100  0.8534
      6            [36m0.9011[0m        [32m0.4425[0m            [35m0.5875[0m        1.5190  0.0100  0.8430
      7            [36m0.9830[0m        [32m0.2999[0m            [35m0.6875[0m        [31m1.1276[0m  0.0100  0.8199
      8   

In [12]:
eeg_path = '/media/mountHDD1/LanxHuyen/high_gamma_dataset.pth'

In [15]:
loaded_eeg = torch.load(eeg_path)

In [14]:
dataset, classes = [loaded_eeg[k] for k in ['dataset', 'labels']]

NameError: name 'loaded_eeg' is not defined