In [1]:
!pip install mne

In [2]:
!unzip /kaggle/input/grasp-and-lift-eeg-detection/train.zip
!unzip /kaggle/input/grasp-and-lift-eeg-detection/test.zip
!unzip /kaggle/input/grasp-and-lift-eeg-detection/sample_submission.csv.zip

In [3]:
import os
import mne
print(mne.__version__)

from mne.io import RawArray
from mne import pick_types
from mne.epochs import concatenate_epochs
from mne.channels import make_standard_montage
from mne import create_info, find_events, Epochs
from mne.viz.topomap import plot_topomap

import numpy as np
import pandas as pd

from scipy.signal import welch
from scipy.signal.windows import chebwin

In [4]:
local = False
if not local:
    # Kaggle: competition inputs live under /kaggle/input; the train/test archives
    # are unzipped into the working directory (presumably /kaggle/working, the
    # default Kaggle cwd — confirm if the unzip cell changes).
    PATH = "/kaggle/input/grasp-and-lift-eeg-detection/"
    train_PATH = "/kaggle/working/train/"
    test_PATH = "/kaggle/working/test/"
else:
    # Local run: expects ./data/train/ and ./data/test/ under the current directory.
    PATH = os.getcwd() + '/data/'
    train_PATH = PATH + 'train/'
    # Previously 'train/', which silently pointed the test set at the training data.
    test_PATH = PATH + 'test/'

In [5]:
# Names of the files in the train and test directories (order is whatever the OS returns).
train_files, test_files = (os.listdir(folder) for folder in (train_PATH, test_PATH))

In [6]:
# Band-pass cutoffs (Hz) passed to raw.filter() in get_Epochs.
l_freq = 1
h_freq = 70

In [7]:
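# Reference only (not used): EEG channel name -> index among the 32 EEG columns,
# presumably in the column order of the *_data.csv files. Despite its name, this
# dict maps channels, not events.
# event_dict = {"Fp1":0, "Fp2":1, "F7":2, "F3":3, 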
#               "Fz":4, "F4":5, "F8":6, "FC5":7, "FC1":8,
#               "FC2":9, "FC6":10, "T7":11, "C3":12, 
#               "Cz":13, "C4":14, "T8":15, "TP9":16, 
#               "CP5":17, "CP1":18, "CP2":19, "CP6":20,
#               "TP10":21, "P7":22, "P3":23, "Pz":24, 
#               "P4":25, "P8":26, "PO9":27, "O1":28,
#               "Oz":29, "O2":30, "PO10":31}

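# Reference only (not used): event label -> index among the six stim columns,
# presumably in the column order of the *_events.csv files.
# event_dict = {"HandStart":0, 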
#               "FirstDigitTouch":1,
#               "BothStartLoadPhase":2,
#               "LiftOff":3,
#               "Replace":4, 
#               "BothReleased":5}

In [31]:
def load_data(subj_num, series_num, data_dir=None):
    """Load one subject/series recording as a single samples x channels array.

    Reads ``subj{subj_num}_series{series_num}_data.csv`` (EEG) and the matching
    ``..._events.csv`` (event labels), drops the ``id`` column from both, and
    joins them column-wise: EEG columns first, then event columns.

    Parameters
    ----------
    subj_num, series_num : int
        Subject and series numbers used to build the file names.
    data_dir : str, optional
        Directory containing the CSV files. Defaults to the module-level
        ``train_PATH``.

    Returns
    -------
    data : numpy.ndarray
        Shape ``(n_samples, n_eeg + n_events)``.
    ch_names : list of str
        Column names, in the same order as ``data``'s columns.
    ch_types : list of str
        ``'eeg'`` for every data column, ``'stim'`` for every event column.

    Raises
    ------
    ValueError
        If the two files' ``id`` columns do not match row for row.
    """
    if data_dir is None:
        data_dir = train_PATH
    prefix = os.path.join(data_dir, f'subj{subj_num}_series{series_num}')
    eeg = pd.read_csv(prefix + '_data.csv')
    events = pd.read_csv(prefix + '_events.csv')

    # concat(axis=1) aligns on the row index, so verify both files describe the
    # same samples in the same order before joining them.
    if len(eeg) != len(events) or (eeg['id'].to_numpy() != events['id'].to_numpy()).any():
        raise ValueError(
            f'id columns of subj{subj_num}_series{series_num} data and events do not match')

    eeg_cols = eeg.drop(columns=['id'])
    event_cols = events.drop(columns=['id'])

    # Derive names/types from the files themselves instead of assuming 32 EEG + 6 stim.
    ch_names = list(eeg_cols.columns) + list(event_cols.columns)
    ch_types = ['eeg'] * eeg_cols.shape[1] + ['stim'] * event_cols.shape[1]

    data = pd.concat([eeg_cols, event_cols], axis=1).to_numpy()
    return data, ch_names, ch_types


def get_Raw(arr, ch_names, ch_types, sfreq=500):
    """Wrap a samples x channels array in an MNE ``RawArray``.

    Parameters
    ----------
    arr : numpy.ndarray
        Shape ``(n_samples, n_channels)``, as returned by ``load_data``. It is
        transposed here because ``RawArray`` expects ``(n_channels, n_times)``.
    ch_names, ch_types : list of str
        One entry per column of ``arr``.
    sfreq : float, optional
        Sampling frequency in Hz (500 by default, as previously hard-coded).

    Returns
    -------
    mne.io.RawArray
    """
    info = create_info(ch_names=ch_names, ch_types=ch_types, sfreq=sfreq)
    return RawArray(arr.T, info)


def get_Epochs(raw, epochs_tot, y, stim_channel='Replace',
               during_window=(-2, -0.5), after_window=(0.5, 2)):
    """Band-pass filter ``raw`` and cut two labelled sets of epochs around one event.

    Onsets are found on ``stim_channel``. Around each onset, a ``during_window``
    epoch (label 1) and an ``after_window`` epoch (label -1) are extracted.
    Both Epochs objects are appended to ``epochs_tot``, and one label per kept
    epoch is appended to ``y``.

    Parameters
    ----------
    raw : mne.io.Raw
        Continuous recording. ``raw.filter`` is called on it directly, so the
        caller's object ends up filtered between the module-level ``l_freq`` and
        ``h_freq``.
    epochs_tot : list
        Accumulator of Epochs objects; extended in place and returned.
    y : list
        Accumulator of labels (1 / -1); extended in place and returned.
    stim_channel : str, optional
        Stim channel whose onsets anchor the windows (default ``'Replace'``).
    during_window, after_window : tuple of float, optional
        ``(tmin, tmax)`` in seconds relative to each onset. They must yield the
        same number of samples per epoch, because the "after" epochs are
        relabelled onto the "during" time axis so the two can be concatenated.

    Returns
    -------
    tuple
        ``(epochs_tot, y)``.

    Raises
    ------
    ValueError
        If the two windows produce epochs of different lengths. The check runs
        before the accumulators are touched.
    """
    raw.filter(l_freq, h_freq, method='iir')
    onsets = find_events(raw, stim_channel=stim_channel)

    during_tmin, during_tmax = during_window
    epochs = Epochs(raw, onsets,
                    {'during': 1},
                    during_tmin, during_tmax,
                    baseline=None,   # no baseline correction
                    preload=True)

    after_tmin, after_tmax = after_window
    epochs_rest = Epochs(raw, onsets,
                         {'after': 1},
                         after_tmin, after_tmax,
                         baseline=None,
                         preload=True)

    if len(epochs_rest.times) != len(epochs.times):
        raise ValueError(
            f'during_window {during_window} and after_window {after_window} '
            'must span the same number of samples')

    # The time axes must match, or concatenating the epochs later fails with a
    # time-mismatch error; relabel the "after" epochs with the "during" times.
    # NOTE(review): _set_times is private MNE API and may change between versions.
    epochs_rest._set_times(epochs.times)

    # After preload, len() counts only the epochs actually kept.
    epochs_tot.append(epochs)
    y.extend([1] * len(epochs))
    epochs_tot.append(epochs_rest)
    y.extend([-1] * len(epochs_rest))
    return epochs_tot, y


def create_Plot(raw):
    """Attach the standard 10-20 montage to ``raw`` and plot its power spectrum.

    ``raw.set_montage`` stores channel positions on ``raw`` itself, so the
    caller's object is modified.
    """
    raw.set_montage(make_standard_montage('standard_1020'))
    # NOTE(review): plot_psd is a legacy API in recent MNE versions
    # (raw.compute_psd().plot() is the replacement); kept here for compatibility
    # with whichever version is installed.
    raw.plot_psd()

In [32]:
subj_len = range(1, 13)     # subjects 1..12
series_len = range(1, 9)    # series 1..8
# 500-point Dolph-Chebyshev window, 90 dB attenuation (not used in this cell).
window = chebwin(M=500, at=90)

# Only the first subject is processed for now. `y` and `epochs_tot` are reset
# per subject, so widening this slice would keep only the last subject's results.
for i in subj_len[:1]:
    y, epochs_tot = [], []
    for j in series_len:
        data, ch_names, ch_types = load_data(i, j)
        raw = get_Raw(data, ch_names, ch_types)
        epochs_tot, y = get_Epochs(raw, epochs_tot, y)
        create_Plot(raw)

    epochs = concatenate_epochs(epochs_tot, on_mismatch='warn')
    y = np.array(y)