In [2]:
import scipy.io
import mne
import os
import numpy as np
from mne.time_frequency import tfr_multitaper
from mne.time_frequency import tfr_morlet

from matplotlib.colors import TwoSlopeNorm
mne.set_log_level('error')
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

In [3]:
from sklearn.model_selection import train_test_split
from sklearn.model_selection import KFold
from sklearn.model_selection import GridSearchCV
from sklearn.svm import SVC
from sklearn.model_selection import StratifiedKFold
from sklearn.preprocessing import StandardScaler

In [4]:
data_root = 'C:/Data/UHD_EEG/'
subjects = ['S1', 'S2', 'S3', 'S4', 'S5']
dominant_hand = ['left','right','right','right','right']
mapping = {0: "No instruction", 1: "Rest", 2: "thumb", 3: "index", 4: "middle", 5: "ring", 6: "little"}

not_ROI_channels = ['c255', 'c256', 'c254', 'c251', 'c239', 'c240', 'c238', 'c235', 'c224', 'c222', 'c223', 'c219', 'c220', 'c221', 'c215', 'c216', 'c217', 'c213', 'c212', 'c211', 'c210', 'c209', 'c112', 'c110', 'c107', 'c108', 'c103', 'c104', 'c105', 'c101', 'c100', 'c99', 'c98', 'c97', 'c1', 'c2', 'c3', 'c4', 'c5', 'c6', 'c7', 'c8', 'c9', 'c10', 'c11', 'c12', 'c14', 'c15', 'c16', 'c23', 'c29', 'c26', 'c17', 'c18', 'c20', 'c19', 'c21', 'c24', 'c22', 'c25', 'c28', 'c33', 'c35', 'c38', 'c42', 'c81', 'c34', 'c37', 'c41', 'c45', 'c36', 'c40', 'c44', 'c39', 'c43', 'c145', 'c147', 'c150', 'c154', 'c157', 'c153', 'c149', 'c146', 'c93', 'c159', 'c156', 'c152', 'c148', 'c95', 'c160', 'c158', 'c155', 'c151', 'c96', 'c202', 'c198', 'c195', 'c193']

S01_bad_channels = ['c69', 'c122', 'c170', 'c173', 'c189']

In [5]:
def get_montage(hemishpere):
    mat = scipy.io.loadmat(os.path.join(data_root, 'montage', f'montage_256_{hemishpere}_hemisphere.mat'))
    return mat['pos_256']

left_handed_montage = get_montage('right')
right_handed_montage = get_montage('left')

In [6]:
def load_run(subject_id, run, describe=True):
    subject = subjects[subject_id]
    mat = scipy.io.loadmat(os.path.join(data_root, 'rawdata', subject, run))
    data = mat['y'][1:]  # remove timestamp
    ch_names = [f'c{i}' for i in range(1, 257)] + ['STIM']
    info = mne.create_info(ch_names=ch_names, sfreq=mat['SR\x00'][0][0])

    raw = mne.io.RawArray(data, info)
    ch_types = {ch: 'eeg' if ch != 'STIM' else 'stim' for ch in ch_names}
    raw.set_channel_types(ch_types)

    events = mne.find_events(raw, stim_channel='STIM')
    annot_from_events = mne.annotations_from_events(events, event_desc=mapping, sfreq=raw.info['sfreq'])
    raw.set_annotations(annot_from_events)
    raw.drop_channels(['STIM'])

    montage_positions = left_handed_montage if dominant_hand[subject_id] == 'left' else right_handed_montage
    montage = mne.channels.make_dig_montage(ch_pos=dict(zip(ch_names, montage_positions)), coord_frame='head')
    raw.set_montage(montage)

    if describe:
        raw.describe()
    return raw


In [7]:
def load_subject(subject_id, describe = True):
    runs = []
    run_files = os.listdir(os.path.join(data_root, 'rawdata', subjects[subject_id]))
    for file in run_files:
        runs.append(load_run(subject_id, file, describe))
    return runs

In [8]:
raw_runs = load_subject(subject_id = 0, describe = False)

In [9]:
preprocessed_runs = []
for run in raw_runs:
    run = run.drop_channels(S01_bad_channels)
    run = run.set_eeg_reference('average', projection=False)
    run = run.drop_channels(not_ROI_channels)

    run.resample(100)

    # baseline = run.get_data()[:, int(-25*run.info['sfreq']):]
    # baseline = np.mean(baseline, axis=1)
    # run._data = run.get_data() - baseline[:, None]
    
    preprocessed_runs.append(run)


In [10]:
epochs = []
for run in preprocessed_runs:
    events, event_ids = mne.events_from_annotations(run)
    asd = mne.Epochs(run, events, baseline = None, event_id= event_ids, tmin=-1.0, tmax=7, preload=True)
    epochs.append(asd)
epochs = mne.concatenate_epochs(epochs)
epochs

0,1
Number of events,260
Events,Rest: 10 index: 50 little: 50 middle: 50 ring: 50 thumb: 50
Time range,-1.000 – 7.000 sec
Baseline,off


In [21]:
freqs = np.arange(8, 25, 10)  # frequencies from 2-25Hz
print("freqs", freqs)
#tfr = tfr_morlet(epochs, freqs=freqs, n_cycles=freqs, use_fft=True)[0]
tfr = tfr_multitaper(epochs, freqs=freqs, n_cycles=freqs, use_fft=True,
                      return_itc=False, average=False, decim=2)
tfr.apply_baseline((-1.0, 0), mode='mean')
tfr.crop(0.5,1.5)

freqs [ 8 18]


<EpochsTFR | time : [0.500000, 1.500000], freq : [8.000000, 18.000000], epochs : 260, channels : 153, ~31.2 MB>

In [22]:
for event in event_ids:
    # select desired epochs for visualization
    print(event)
    tfr_ev = tfr[event]
    print(tfr_ev.data.shape)
    # Events, Channels, Frequencies, Time

Rest
(10, 153, 2, 51)
index
(50, 153, 2, 51)
little
(50, 153, 2, 51)
middle
(50, 153, 2, 51)
ring
(50, 153, 2, 51)
thumb
(50, 153, 2, 51)


In [23]:
middle = tfr['middle'].data
ring = tfr['ring'].data

X = np.concatenate((middle, ring))
X = X.reshape(X.shape[0], -1)

y = np.concatenate((np.zeros(middle.shape[0]), np.ones(ring.shape[0])))

print(X.shape)
print(y.shape)

(100, 15606)
(100,)


In [24]:
# Set the parameters by cross-validation
tuned_parameters = [{'kernel': ['rbf'], 'gamma': [1e-1, 1e-3],
                        'C': [0.001, 0.01,  1 ]}]
grid = GridSearchCV(SVC(), tuned_parameters, cv=StratifiedKFold(n_splits=10), scoring='accuracy')
grid.fit(X, y)
print(grid.best_params_)
print(grid.best_score_)


{'C': 0.001, 'gamma': 0.1, 'kernel': 'rbf'}
0.5
