In [1]:
import scipy.io
import mne
import os
import numpy as np
from scipy import stats

from sklearn.svm import SVC
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import GridSearchCV
from sklearn.model_selection import StratifiedKFold


from matplotlib import pyplot as plt
mne.set_log_level('error')


In [2]:
data_root = 'C:/Data/UHD_EEG/'
subjects = ['S1', 'S2', 'S3', 'S4', 'S5']
dominant_hand = ['left','right','right','right','right']

mapping = {0: "No instruction", 1: "Rest", 2: "thumb", 3: "index", 4: "middle", 5: "ring", 6: "little"}

not_ROI_channels = ['c255', 'c256', 'c254', 'c251', 'c239', 'c240', 'c238', 'c235', 'c224', 'c222', 'c223', 'c219', 'c220', 'c221', 'c215', 'c216', 'c217', 'c213', 'c212', 'c211', 'c210', 'c209', 'c112', 'c110', 'c107', 'c108', 'c103', 'c104', 'c105', 'c101', 'c100', 'c99', 'c98', 'c97', 'c1', 'c2', 'c3', 'c4', 'c5', 'c6', 'c7', 'c8', 'c9', 'c10', 'c11', 'c12', 'c14', 'c15', 'c16', 'c23', 'c29', 'c26', 'c17', 'c18', 'c20', 'c19', 'c21', 'c24', 'c22', 'c25', 'c28', 'c33', 'c35', 'c38', 'c42', 'c81', 'c34', 'c37', 'c41', 'c45', 'c36', 'c40', 'c44', 'c39', 'c43', 'c145', 'c147', 'c150', 'c154', 'c157', 'c153', 'c149', 'c146', 'c93', 'c159', 'c156', 'c152', 'c148', 'c95', 'c160', 'c158', 'c155', 'c151', 'c96', 'c202', 'c198', 'c195', 'c193']

S01_bad_channels = ['c69', 'c122', 'c170', 'c173', 'c189']

In [3]:
def get_montage(hemishpere):
    mat = scipy.io.loadmat(os.path.join(data_root, 'montage', f'montage_256_{hemishpere}_hemisphere.mat'))
    return mat['pos_256']

left_handed_montage = get_montage('right')
right_handed_montage = get_montage('left')

In [4]:
def load_run(subject_id, run, describe=True):
    subject = subjects[subject_id]
    mat = scipy.io.loadmat(os.path.join(data_root, 'rawdata', subject, run))
    data = mat['y'][1:]  # remove timestamp
    ch_names = [f'c{i}' for i in range(1, 257)] + ['STIM']
    info = mne.create_info(ch_names=ch_names, sfreq=mat['SR\x00'][0][0])

    raw = mne.io.RawArray(data, info)
    ch_types = {ch: 'eeg' if ch != 'STIM' else 'stim' for ch in ch_names}
    raw.set_channel_types(ch_types)

    events = mne.find_events(raw, stim_channel='STIM')
    annot_from_events = mne.annotations_from_events(events, event_desc=mapping, sfreq=raw.info['sfreq'])
    raw.set_annotations(annot_from_events)
    raw.drop_channels(['STIM'])

    montage_positions = left_handed_montage if dominant_hand[subject_id] == 'left' else right_handed_montage
    montage = mne.channels.make_dig_montage(ch_pos=dict(zip(ch_names, montage_positions)), coord_frame='head')
    raw.set_montage(montage)

    if describe:
        raw.describe()
    return raw


In [5]:
def load_subject(subject_id, describe = False):
    runs = []
    run_files = os.listdir(os.path.join(data_root, 'rawdata', subjects[subject_id]))
    for file in run_files:
        runs.append(load_run(subject_id, file, describe))
    return runs

In [6]:
raw_runs = load_subject(subject_id = 0, describe = False)


In [7]:
preprocessed_runs = raw_runs.copy()
for run in preprocessed_runs:
    run = run.resample(200)


    run = run.drop_channels(S01_bad_channels)
    run = run.set_eeg_reference('average', projection=False)
    run = run.drop_channels(not_ROI_channels)



In [8]:
'''
The features extracted are band power features for the
mu and beta bands
'''
mus = []
betas = []
for i in range(len(preprocessed_runs)):
    mus.append(preprocessed_runs[i].copy().filter(l_freq=8, h_freq=12))
    betas.append(preprocessed_runs[i].copy().filter(l_freq=13, h_freq=25))

In [9]:
'''
The EEG data were band-pass filtered for the respective frequency
band, and the power was calculated by squaring each time
sample. 
Then the power was estimated in non-overlapping
0.25 s segments by averaging the power samples and applying
a centered moving average with a 0.75 s window length.
band power features were log-transformed
power shift compensation was applied, subtracting the mean band power over the last 25 s.
the band power features were epoched using 0.5 s pre-and 7 s post-cue.
'''


        
def my_feature(data, freq = 200):
    # Calculate power by squaring each time sample
    power_eeg_data = np.square(data)

    # Define segment parameters
    segment_length = int(0.25 * freq)  # 0.25 s segment length 
   
   
    power_eeg_data = power_eeg_data.reshape(power_eeg_data.shape[0], -1, segment_length)
    power_eeg_data = np.mean(power_eeg_data, axis=-1)



    def centered_moving_average(data, window_size):
        half_window = window_size // 2
        cumsum = np.cumsum(data, axis=-1)
        cumsum[..., window_size:] = cumsum[..., window_size:] - cumsum[..., :-window_size]
        return (cumsum[..., window_size - 1:-window_size + 1] / window_size)
    power_eeg_data = centered_moving_average(power_eeg_data, 3)


    power_eeg_data = np.log(power_eeg_data)
    return power_eeg_data


In [11]:
sfreq = raw_runs[0].info['sfreq']


tmin = -0.5
tmax = 7



def get_baseline(data, length = 25): # 25s
    end = data[..., -int(length*sfreq):]
    end = my_feature(end)
    return np.mean(end, axis=-1)



middles = [] # 4
rings = []   # 5 
for i in range(len(preprocessed_runs)):
    events, _ = mne.events_from_annotations(preprocessed_runs[i])

    mu_baseline = get_baseline(mus[i].get_data())
    beta_baseline = get_baseline(betas[i].get_data())

    for trigger in events:
        if trigger[-1] in [4,5]:
            # Epoching
            mu_data = mus[i].get_data()[...,trigger[0]-int(tmin*sfreq):trigger[0]+int(tmax*sfreq)] 
            beta_data = betas[i].get_data()[...,trigger[0]-int(tmin*sfreq):trigger[0]+int(tmax*sfreq)] 
          

            # feature extraction
            mu_data = my_feature(mu_data)
            beta_data = my_feature(beta_data)

            # power shift compensation
            mu_data-=mu_baseline[:,np.newaxis]
            beta_data-=beta_baseline[:,np.newaxis]


            # Format data
            mu_data = mu_data.flatten()
            beta_data = beta_data.flatten()
            data = np.concatenate((mu_data, beta_data), axis=0)
            

            if trigger[-1] == 4:
                middles.append(data)
            else:
                rings.append(data)
    

In [12]:
middles = np.array(middles)
rings = np.array(rings)

print(middles.shape)
print(rings.shape)

middle_vs_ring = np.concatenate((middles, rings), axis=0)
y_middle_vs_ring = np.concatenate((np.zeros(middles.shape[0]), np.ones(rings.shape[0])), axis=0)


(50, 6732)
(50, 6732)


In [16]:
scaler = StandardScaler()
scaler.fit(middle_vs_ring)
middle_vs_ring = scaler.transform(middle_vs_ring)

In [17]:
X = middle_vs_ring
y = y_middle_vs_ring
# Set the parameters by cross-validation
tuned_parameters = [{'kernel': ['rbf'], 'gamma': [1e-1, 1e-2, 1e-3, 1e-4, 1e-5, 1e-6],
                        'C': [0.001, 0.01, 0.1, 1, 10, 100, 1000]}]
grid = GridSearchCV(SVC(), tuned_parameters, cv=StratifiedKFold(n_splits=10), scoring='accuracy')
grid.fit(X, y)
print(grid.best_params_)

print(grid.best_score_)


{'C': 1, 'gamma': 0.0001, 'kernel': 'rbf'}
0.8200000000000001
