In [1]:
import csv
import gc
import glob
import os

import matplotlib.pyplot as plt
import mne
import numpy as np
import pandas as pd
sfreq=500 # sampling frequency as specified in the Matlab scripts (presumably Hz -- TODO confirm)

import pandas as pd
import numpy as np
import cv2

# Root folder used to build data paths below: the current working directory
# at the moment this cell is executed.
path = os.getcwd()

## Read and Preprocess Data

In [None]:
def tsv_to_fif(subj, run, path='C:/Users/HP pavilion X360/Desktop/BCI'):
    """Convert one tab-separated recording to MNE Raw objects and save the filtered one.

    Reads the stimulus-code CSV and the ``.ascii`` recording of the given
    subject/run, keeps the first two data columns (labelled ``Fp1`` and
    ``Fp2`` here) together with the stimulus codes (``STIM101``), band-pass
    filters the EEG channels between 1 and 40 Hz and writes the filtered
    recording to ``<path>/BCI_HSE<subj[1:]>/preprocessed/``.

    Parameters
    ----------
    subj : str
        Subject identifier; its first character is dropped when building the
        subject folder name.
    run : str
        Run identifier as it appears after ``R`` in the file names.
    path : str
        Root data folder. NOTE(review): the default is a machine-specific
        absolute path; callers on other machines should pass it explicitly.

    Returns
    -------
    raw : mne.io.RawArray
        Unfiltered recording (two EEG channels + one stim channel).
    raw_filt : mne.io.RawArray
        Band-pass filtered copy of ``raw`` (also saved to disk).

    Raises
    ------
    ValueError
        If the recording file is empty (no header row).
    """
    stim_code = pd.read_csv(f'{path}/BCI_HSE{subj[1:]}/stim_code/stim_code_{subj}R{run}.csv')
    n_samples = stim_code.shape[0]

    # One row per stimulus-code sample; rows missing from the recording stay zero.
    data = np.zeros((n_samples, 13))
    # newline='' is what the csv module expects from file objects it reads.
    with open(f"{path}/BCI_HSE{subj[1:]}/tsv/BCI_HSE{subj}R{run}.ascii", newline='') as tsv:
        reader = csv.reader(tsv, dialect="excel-tab")
        columns = next(reader, None)  # first line holds the column names
        if columns is None:
            raise ValueError(f'Recording file for {subj}R{run} is empty')
        # zip() stops as soon as n_samples rows are filled, so the remainder
        # of the file is not read needlessly.
        for row_idx, line in zip(range(n_samples), reader):
            data[row_idx] = line

    smart = pd.DataFrame(data, columns=columns)
    smart['STIM101'] = stim_code  # appended as column index 13

    sfreq = 500
    ch_names = ['Fp1', 'Fp2', 'STIM101']
    ch_types = ['eeg'] * 2 + ['stim']
    info = mne.create_info(ch_names, ch_types=ch_types, sfreq=sfreq)
    # Columns 0 and 1 are the two EEG traces, column 13 is the stim channel.
    raw = mne.io.RawArray(smart.values[:, [0, 1, 13]].T, info)

    path_out = f'{path}/BCI_HSE{subj[1:]}/preprocessed/'
    os.makedirs(path_out, exist_ok=True)

    raw_filt = raw.copy().filter(
        1, 40, l_trans_bandwidth='auto', picks=['eeg'],
        h_trans_bandwidth='auto', filter_length='auto', phase='zero',
        fir_window='hamming', fir_design='firwin', n_jobs=4)
    raw_filt.save(f'{path_out}/{subj}R{run}_filt_raw.fif', overwrite=True)
    return raw, raw_filt

In [None]:
def run_id_from_path(tsv_path, subj):
    """Return the run ID encoded in a recording file name, or None if it does not match.

    Recording files are named ``BCI_HSE<subj>R<run>.ascii``; anything else
    found in the folder (other extensions, unrelated files) yields None.
    """
    stem, ext = os.path.splitext(os.path.basename(tsv_path))
    prefix = f'BCI_HSE{subj}R'
    if ext != '.ascii' or not stem.startswith(prefix):
        return None
    return stem[len(prefix):]


subject_list = []  # TODO: fill in the subject IDs to convert
for subj in subject_list:
    tsv_files = sorted(glob.glob(f'{path}/BCI_HSE{subj[1:]}/tsv/*'))
    # glob returns full file paths, but tsv_to_fif expects bare run IDs.
    runs = [run_id for run_id in (run_id_from_path(tsv_file, subj) for tsv_file in tsv_files)
            if run_id is not None]
    for run in runs:
        # Read from the same root that was globbed rather than the function's
        # machine-specific default.
        raw, raw_filt = tsv_to_fif(subj=subj, run=run, path=path)

## ICA artifact removal and epoching

In [None]:
def find_ica(subj, run):
    """Fit a two-component ICA on a filtered recording and save a plot of its sources.

    Loads ``BCI_HSE<subj[1:]>/preprocessed/<subj>R<run>_filt_raw.fif``
    (relative to the working directory), fits the ICA, plots the component
    time courses and writes the plot to ``ICA/<subj>_<run>_ica_comp.png``.

    Parameters
    ----------
    subj : str
        Subject identifier; its first character is dropped for the folder name.
    run : str
        Run identifier as it appears after ``R`` in the file name.

    Returns
    -------
    raw
        The loaded (preloaded) filtered recording.
    ica
        The fitted ICA object; no component is excluded yet.
    """
    # Plain-Python spelling of the `%matplotlib qt` line magic, so that the
    # function body stays valid Python; requires an IPython/Jupyter session.
    get_ipython().run_line_magic('matplotlib', 'qt')

    raw = mne.io.read_raw(f"BCI_HSE{subj[1:]}/preprocessed/{subj}R{run}_filt_raw.fif", preload=True)
    # Fixed random_state keeps the decomposition reproducible between runs.
    ica = mne.preprocessing.ICA(n_components=2, max_iter="auto", random_state=97)
    ica.fit(raw)

    fig = ica.plot_sources(raw, show_scrollbars=False, show=True)
    # The output folder is not created anywhere else; make sure it exists.
    os.makedirs('ICA', exist_ok=True)
    fig.savefig(f'ICA/{subj}_{run}_ica_comp.png', dpi=300, bbox_inches='tight')
    return raw, ica

In [None]:
subj = ''  # TODO: set the subject ID to process (same format as used for the file names)
if not subj:
    raise ValueError('Set `subj` to the subject ID before running this cell.')

# Peak-to-peak rejection threshold. The value is 0.1 in the data's own units;
# the earlier "100 µV" note did not match it (100 µV in volts would be 100e-6).
# NOTE(review): confirm the intended threshold against the recording's units.
reject_criteria = dict(eeg=10e-2)
flat_criteria = dict(eeg=1e-13)

# After recoding below: 1 = target stimulus, 0 = distractor, -1 = ignored.
event_id = {'stim': 1, 'distr': 0}

os.makedirs(f'BCI_HSE{subj[1:]}/epochs/', exist_ok=True)

epo_list = []
# `runs` must hold the run IDs of this subject (e.g. '02'); it is expected to
# be defined by an earlier cell.
for run in runs:
    raw, ica = find_ica(subj=subj, run=run)
    # Which component to drop is hard-coded per run.
    # NOTE(review): presumably chosen from the saved ICA source plots -- confirm per subject.
    ica.exclude = [1] if run == '02' else [0]
    reconst_raw = raw.copy()
    ica.apply(reconst_raw)

    events = mne.find_events(raw, stim_channel='STIM101', consecutive=True)
    events[:, 1] = 0

    # Codes 2, 10, 18, ... (step 8) mark targets. The largest code present is
    # the exclusive upper bound of that series.
    # NOTE(review): if the largest code is itself a target code it is treated
    # as a distractor -- confirm this is intended.
    codes = events[:, 2].copy()
    trigg = np.arange(2, np.unique(codes)[-1], 8)
    events[:, 2] = np.where(np.isin(codes, trigg), 1, np.where(codes == 0, -1, 0))

    epochs = mne.Epochs(
        reconst_raw, tmin=-1, tmax=2.5, events=events, event_id=event_id, baseline=(-.5, 0),
        reject_tmin=0, reject_tmax=1,
        reject=reject_criteria,
        flat=flat_criteria, preload=True)
    # overwrite=True keeps the cell re-runnable, matching how the filtered
    # raw files are saved.
    epochs.save(f'BCI_HSE{subj[1:]}/epochs/{subj}R{run}_epo.fif', overwrite=True)
    epo_list.append(epochs)

In [None]:
# Merge the per-run epochs, trim them to the -0.2 .. 0.8 window, then resample to 500.
epo_con = mne.concatenate_epochs(epo_list)
epo_con = epo_con.crop(tmin=-.2, tmax=.8)
epo_con = epo_con.resample(500)

In [None]:
# Average each condition over its epochs, then plot both evoked responses.
evoked_stim = epo_con['stim'].average()
evoked_distr = epo_con['distr'].average()
evoked_stim.plot()
evoked_distr.plot()

## Plot target vs. distractor time courses

In [5]:
def plot_stat_comparison_timecourse_1samp_new(comp1, comp2, time, y_low = None, y_high = None, title='demo_title',
                         comp1_label='comp1', comp2_label='comp2'):
    """Plot the mean time courses of two conditions with shaded standard-error bands.

    Parameters
    ----------
    comp1, comp2 : ndarray, shape (n_observations, n_times)
        Data of the two conditions; they are averaged over the first axis.
        Both must have as many time points as ``time``.
    time : sequence of float
        Time axis. The x label says milliseconds and ticks are placed every
        100 units, so the values are expected in ms.
    y_low, y_high : float | None
        Y-axis limits. Either may be omitted, in which case that side is
        autoscaled to the data.
    title : str
        Figure title.
    comp1_label, comp2_label : str
        Legend entries for the two conditions.

    Returns
    -------
    fig : matplotlib.figure.Figure
        The figure that was drawn (and shown).
    """
    assert comp1.shape[1] == comp2.shape[1] == len(time), \
        'comp1, comp2 and time must have the same number of time points'

    fig, ax1 = plt.subplots()
    # Transparent axes background, set on this axes only instead of mutating
    # the global rcParams (which affected only figures created afterwards).
    ax1.set_facecolor('none')

    ax1.set_xlim(time[0], time[-1])
    if y_low is not None or y_high is not None:
        ax1.set_ylim(y_low, y_high)

    # Reference lines through zero. axvline/axhline span the axes without
    # becoming part of the data limits, so they cannot distort autoscaling.
    ax1.axvline(0, color='k', linewidth=1, linestyle='--', zorder=1)
    ax1.axhline(0, color='k', linewidth=1, linestyle='--', zorder=1)

    for data, colour, label in ((comp1, 'turquoise', comp1_label),
                                (comp2, 'salmon', comp2_label)):
        mean = data.mean(axis=0)
        # Standard error of the mean (population std / sqrt(n)).
        sem = np.std(data, axis=0) / np.sqrt(data.shape[0])
        ax1.plot(time, mean, color=colour, linewidth=1.5, label=label)
        ax1.fill_between(time, mean - sem, mean + sem, color=colour, alpha=.2)

    ax1.set_xlabel('Time (ms)')
    ax1.set_xticks(np.arange(time[0], time[-1], 100))
    ax1.tick_params(labelsize=12)
    ax1.legend()
    ax1.set_title(title, fontsize=12)

    plt.show()
    return fig

In [None]:
y_high = .001
y_low = -.0008

# savefig does not create missing folders.
os.makedirs('output', exist_ok=True)

times_ms = epo_con.times * 1000
target_data = epo_con['stim'].get_data()
distractor_data = epo_con['distr'].get_data()

# One figure per EEG channel, both drawn with the same y-limits.
# (Channel 2 was previously called with y_high = y_low, collapsing its y-range.)
channel_figs = []
for ch_idx in range(2):
    ch_fig = plot_stat_comparison_timecourse_1samp_new(
        comp1=target_data[:, ch_idx], comp1_label='target',
        comp2=distractor_data[:, ch_idx], comp2_label='distractor',
        time=times_ms,
        y_high=y_high, y_low=y_low,
        title=f'Channel {ch_idx + 1}',
    )
    ch_fig.savefig(f'output/{subj}_{ch_idx + 1}_{len(epo_con)}epo_ica.png', dpi=300, bbox_inches='tight')
    channel_figs.append(ch_fig)

ch_1, ch_2 = channel_figs

## Concatenate the channel figures side by side

In [None]:
panel_paths = [f'output/{subj}_{ch}_{len(epo_con)}epo_ica.png' for ch in (1, 2)]

panels = []
for panel_path in panel_paths:
    panel = cv2.imread(panel_path)
    # cv2.imread signals a missing/unreadable file by returning None.
    if panel is None:
        raise FileNotFoundError(f'Could not read image: {panel_path}')
    panels.append(panel)

# hconcat needs equal heights; figures saved with a tight bounding box may
# differ by a few pixels, so pad the shorter one at the bottom with white.
target_height = max(panel.shape[0] for panel in panels)
img1, img2 = [
    cv2.copyMakeBorder(panel, 0, target_height - panel.shape[0], 0, 0,
                       cv2.BORDER_CONSTANT, value=(255, 255, 255))
    for panel in panels
]

im = cv2.hconcat([img1, img2])
# cv2.imwrite reports failure through its return value instead of raising.
if not cv2.imwrite(f'output/{subj}_ica.png', im):
    raise OSError(f'Could not write output/{subj}_ica.png')