In [1]:
import scipy.io
import mne
from scipy.signal import butter, filtfilt
import os
import numpy as np
from scipy import stats

from sklearn.svm import SVC
from sklearn.model_selection import train_test_split
from sklearn.model_selection import KFold
from sklearn.model_selection import GridSearchCV
from sklearn.svm import SVC
from sklearn.model_selection import StratifiedKFold
from sklearn.preprocessing import StandardScaler
from scipy.signal import stft

from matplotlib import pyplot as plt
mne.set_log_level('error')


In [2]:
data_root = 'C:/Data/UHD_EEG/'
subjects = ['S1', 'S2', 'S3', 'S4', 'S5']
dominant_hand = ['left','right','right','right','right']

mapping = {0: "No instruction", 1: "Rest", 2: "thumb", 3: "index", 4: "middle", 5: "ring", 6: "little"}

not_ROI_channels = ['c255', 'c256', 'c254', 'c251', 'c239', 'c240', 'c238', 'c235', 'c224', 'c222', 'c223', 'c219', 'c220', 'c221', 'c215', 'c216', 'c217', 'c213', 'c212', 'c211', 'c210', 'c209', 'c112', 'c110', 'c107', 'c108', 'c103', 'c104', 'c105', 'c101', 'c100', 'c99', 'c98', 'c97', 'c1', 'c2', 'c3', 'c4', 'c5', 'c6', 'c7', 'c8', 'c9', 'c10', 'c11', 'c12', 'c14', 'c15', 'c16', 'c23', 'c29', 'c26', 'c17', 'c18', 'c20', 'c19', 'c21', 'c24', 'c22', 'c25', 'c28', 'c33', 'c35', 'c38', 'c42', 'c81', 'c34', 'c37', 'c41', 'c45', 'c36', 'c40', 'c44', 'c39', 'c43', 'c145', 'c147', 'c150', 'c154', 'c157', 'c153', 'c149', 'c146', 'c93', 'c159', 'c156', 'c152', 'c148', 'c95', 'c160', 'c158', 'c155', 'c151', 'c96', 'c202', 'c198', 'c195', 'c193']

# Subject 1 bad channels:
bad_channels =['c65', 'c66', 'c67', 'c68', 'c69', 'c70', 'c71', 'c72', 'c73', 'c74', 'c75', 'c76', 'c77', 'c78', 'c79', 'c80']

In [3]:
def get_montage(hemishpere):
    mat = scipy.io.loadmat(os.path.join(data_root, 'montage', f'montage_256_{hemishpere}_hemisphere.mat'))
    return mat['pos_256']

left_handed_montage = get_montage('right')
right_handed_montage = get_montage('left')

In [4]:
def load_run(subject_id, run, describe=True):
    subject = subjects[subject_id]
    mat = scipy.io.loadmat(os.path.join(data_root, 'rawdata', subject, run))
    data = mat['y'][1:]  # remove timestamp
    ch_names = [f'c{i}' for i in range(1, 257)] + ['STIM']
    info = mne.create_info(ch_names=ch_names, sfreq=mat['SR\x00'][0][0])

    raw = mne.io.RawArray(data, info)
    ch_types = {ch: 'eeg' if ch != 'STIM' else 'stim' for ch in ch_names}
    raw.set_channel_types(ch_types)

    events = mne.find_events(raw, stim_channel='STIM')
    annot_from_events = mne.annotations_from_events(events, event_desc=mapping, sfreq=raw.info['sfreq'])
    raw.set_annotations(annot_from_events)
    raw.drop_channels(['STIM'])

    montage_positions = left_handed_montage if dominant_hand[subject_id] == 'left' else right_handed_montage
    montage = mne.channels.make_dig_montage(ch_pos=dict(zip(ch_names, montage_positions)), coord_frame='head')
    raw.set_montage(montage)

    if describe:
        raw.describe()
    return raw


In [5]:
def load_subject(subject_id, describe = True):
    runs = []
    run_files = os.listdir(os.path.join(data_root, 'rawdata', subjects[subject_id]))
    for file in run_files:
        runs.append(load_run(subject_id, file, describe))
    return runs

In [6]:
raw_runs = load_subject(subject_id = 0, describe = False)

In [7]:
def preprocess(orig):
    run = orig.copy()

    run.drop_channels(bad_channels)
    run.drop_channels(not_ROI_channels)

    run.filter(8, 25, fir_design='firwin')

    run.set_eeg_reference('average', projection=False)

    baseline = run._data[:, :-int(25*run.info['sfreq'])]
    run._data = run._data - baseline.mean(axis=1, keepdims=True)

    run = run.resample(200)

    return run

In [8]:
preprocessed_runs = []
for run in raw_runs:
    preprocessed_runs.append(preprocess(run))


In [9]:
preprocessed_runs[0]._data.shape

(142, 45186)

In [10]:
# epochs = []
# for run in preprocessed_runs:
#     events, event_ids = mne.events_from_annotations(run)
#     asd = mne.Epochs(run, events, baseline = (-1, -0.5), event_id= event_ids, tmin=-1, tmax=7, preload=True)
#     epochs.append(asd)
# epochs = mne.concatenate_epochs(epochs)
# epochs = epochs.crop(tmin=-1.48, tmax=1.52)

In [34]:
# No baseline correction
epochs = []
for run in preprocessed_runs:
    events, event_ids = mne.events_from_annotations(run)
    asd = mne.Epochs(run, events, baseline = None, event_id= event_ids, tmin=-0.5, tmax=2, preload=True)
    epochs.append(asd)
epochs = mne.concatenate_epochs(epochs)

In [35]:
epochs

0,1
Number of events,260
Events,Rest: 10 index: 50 little: 50 middle: 50 ring: 50 thumb: 50
Time range,-0.500 – 2.000 sec
Baseline,off


In [36]:
epochs[0].get_data().shape

(1, 142, 501)

In [68]:
X = epochs.get_data()
y = epochs.events[:, -1]


In [69]:
def to_stft(time_domain):
    bsd = []
    total_powers = []
    time_domain = time_domain**2
    for ch in time_domain:
        # Perform STFT
        time = np.linspace(0, 10, len(ch))
        
        # Perform STFT
        f, t, Zxx = stft(ch, fs=1/(time[1]-time[0]), nperseg=100)
        #power = np.abs(Zxx) ** 2  # Compute power spectrum
        power = np.abs(Zxx)

        # Define frequency bands
        freq_band_1 = (8, 12)  # Frequency band 1 (8-12 Hz)
        freq_band_2 = (13, 25)  # Frequency band 2 (13-25 Hz)

        # Find indices of frequency bins corresponding to the frequency bands
        band_1_indices = np.where((f >= freq_band_1[0]) & (f <= freq_band_1[1]))[0]
        band_2_indices = np.where((f >= freq_band_2[0]) & (f <= freq_band_2[1]))[0]

        total_power_band_1 = np.sum(power[band_1_indices, :], axis=0)
        total_power_band_2 = np.sum(power[band_2_indices, :], axis=0)

        
        # print(total_power_band_1.shape)
       
        # print(total_power_band_2.shape)
     

        total_power = np.concatenate((total_power_band_1, total_power_band_2))
        #total_power = np.clip(total_power, 1e-10, np.inf)
        #total_power = np.log10(total_power)
        total_powers.append(total_power)
    return np.array(total_powers)

In [70]:
powers = []
for i in range(len(X)):
    powers.append(to_stft(X[i]))
powers = np.array(powers)

In [71]:
X = powers

In [72]:
powers.shape

(260, 142, 24)

In [73]:

X = X.reshape(X.shape[0], -1) # reshape to SVM

print(X.shape)
print(y.shape)


(260, 3408)
(260,)


In [74]:
scaler = StandardScaler()
scaler.fit(X)
X = scaler.transform(X)

In [75]:
# thumb, index, middle, ring, and little finger respectively.

thumb_indices = np.where(y == 2)
index_indices = np.where(y == 3)
middle_indices = np.where(y == 4)
ring_indices = np.where(y == 5)
little_indices = np.where(y == 6)

thumb = X[thumb_indices]
index = X[index_indices]
middle = X[middle_indices]
ring = X[ring_indices]
little = X[little_indices]

In [76]:
middle_vs_ring = np.concatenate((middle, ring), axis=0)
y_middle_vs_ring = np.concatenate((np.zeros(middle.shape[0]), np.ones(ring.shape[0])), axis=0)
X = middle_vs_ring
y = y_middle_vs_ring


In [79]:
# Grid search for best parameters


# Set the parameters by cross-validation
tuned_parameters = [{'kernel': ['rbf','linear'], 'gamma': [1e-1, 1e-2, 1e-3, 1e-4],
                        'C': [0.001, 0.01, 0.1, 1, 10]}]
grid = GridSearchCV(SVC(), tuned_parameters, cv=StratifiedKFold(n_splits=5), scoring='accuracy')
grid.fit(X, y)
print(grid.best_params_)
print(grid.best_score_)



{'C': 10, 'gamma': 0.001, 'kernel': 'rbf'}
0.62


In [80]:
# K fold cross validation
kf = KFold(n_splits=10, shuffle=True, random_state=42)
accuracy_scores = []
for train_index, test_index in kf.split(X):
    X_train, X_test = X[train_index], X[test_index]
    y_train, y_test = y[train_index], y[test_index]
    svc = SVC()
    svc.fit(X_train, y_train)
    print(svc.score(X_test, y_test))
    accuracy_scores.append(svc.score(X_test, y_test))
avg_accuracy = np.mean(accuracy_scores)
print("Average accuracy: ", avg_accuracy)

0.6
0.8
0.6
0.6
0.5
0.5
0.3
0.4
0.7
0.4
Average accuracy:  0.54
