In [1]:
%matplotlib qt

import mne
import os
import numpy as np
import pathlib
import sys
SCRIPT_DIR = pathlib.Path.cwd()
sys.path.append(os.path.dirname(SCRIPT_DIR))
from continuous_control_bci.util import channel_names
import matplotlib.pyplot as plt

mne.set_log_level('warning')


In [13]:
from matplotlib.colors import TwoSlopeNorm

import mne
from mne.datasets import eegbci
from mne.io import concatenate_raws, read_raw_edf
from mne.stats import permutation_cluster_1samp_test as pcluster_test
from mne.time_frequency import tfr_multitaper


def plot_tfr(epochs, baseline=(-2, -1), tmin=-2, tmax=3.75, event_ids=dict(left=-1, rest=0, right=1)):
    """Plot ERDS time-frequency maps for the first two channels, one figure per event.

    A multitaper TFR (10-29 Hz) is computed per epoch, cropped to ``[tmin, tmax]``
    and, when ``baseline`` is given, converted to percent change relative to it.
    With a baseline, each map is masked to the time-frequency points that belong
    to clusters passing a one-sample cluster permutation test; without one, the
    unmasked average power is drawn.

    Parameters
    ----------
    epochs
        Epoched data handed to ``tfr_multitaper``; its ``ch_names`` label the
        panels. Only channels 0 and 1 are drawn because the figure layout has
        exactly two map axes plus one colorbar axis.
    baseline : tuple | None
        Interval passed to ``apply_baseline(..., mode="percent")``. ``None``
        skips both the baseline correction and the cluster test. The TFR is
        cropped before the baseline is applied, so the interval should lie
        inside ``[tmin, tmax]``.
    tmin, tmax : float
        Crop limits applied to the TFR.
    event_ids : dict
        Only the keys are used: each one selects epochs via ``tfr[key]`` and
        appears in the figure title. The default dict is never mutated here.

    Returns
    -------
    None
        Each figure is displayed with ``plt.show()``.
    """
    freqs = np.arange(10, 30)  # 10-29 Hz in 1 Hz steps
    vmin, vmax = -1, 1  # set min and max ERDS values in plot
    cnorm = TwoSlopeNorm(vmin=vmin, vcenter=0, vmax=vmax)  # min, center & max ERDS
    cluster_alpha = 0.01  # clusters with p <= cluster_alpha stay visible in the map

    kwargs = dict(
        n_permutations=100, step_down_p=0.05, seed=1, buffer_size=None, out_type="mask"
    )  # for cluster test

    tfr = tfr_multitaper(
        epochs,
        freqs=freqs,
        n_cycles=freqs,
        use_fft=True,
        return_itc=False,
        average=False,
        decim=2,
    )
    tfr.crop(tmin, tmax)
    if baseline is not None:
        tfr.apply_baseline(baseline, mode="percent")

    for event in event_ids:
        # select desired epochs for visualization
        tfr_ev = tfr[event]
        fig, axes = plt.subplots(
            1, 3, figsize=(12, 4), gridspec_kw={"width_ratios": [10, 10, 1]}
        )
        for ch, ax in enumerate(axes[:-1]):  # for each channel
            if baseline is not None:
                ch_data = tfr_ev.data[:, ch]
                # positive clusters
                _, c1, p1, _ = pcluster_test(ch_data, tail=1, **kwargs)
                # negative clusters
                _, c2, p2, _ = pcluster_test(ch_data, tail=-1, **kwargs)

                # Clusters from the two independent one-sided tests are pooled and
                # kept when p <= cluster_alpha; no correction is applied for
                # having run two tests.
                clusters = c1 + c2  # combined clusters
                p = np.concatenate((p1, p2))  # combined p-values
                if clusters:
                    c = np.stack(clusters, axis=2)
                    mask = c[..., p <= cluster_alpha].any(axis=-1)
                else:
                    # Neither test formed a cluster: np.stack would raise on an
                    # empty list, so build the equivalent "nothing significant" mask.
                    mask = np.zeros(ch_data.shape[1:], dtype=bool)

                # plot TFR (ERDS map with masking)
                tfr_ev.average().plot(
                    [ch],
                    cmap="RdBu_r",
                    cnorm=cnorm,
                    axes=ax,
                    colorbar=False,
                    show=False,
                    mask=mask,
                    mask_style="mask",
                )
            else:
                # No baseline: draw the unmasked average power with default color
                # scaling. show=False matches the baseline branch so the figure is
                # only shown once, by plt.show() below, after both panels are drawn.
                tfr_ev.average().plot(
                    [ch],
                    cmap="RdBu_r",
                    axes=ax,
                    colorbar=False,
                    show=False,
                )

            ax.set_title(epochs.ch_names[ch], fontsize=10)
            ax.axvline(0, linewidth=1, color="black", linestyle=":")  # event
            if ch != 0:
                ax.set_ylabel("")
                ax.set_yticklabels("")
        # Shared colorbar taken from the first panel's image, drawn in the narrow third axis.
        fig.colorbar(axes[0].images[-1], cax=axes[-1]).ax.set_yscale("linear")
        fig.suptitle(f"ERDS ({event})")
        plt.show()


In [3]:
# Read the calibration recording and load its data into memory.
raw_calibration = mne.io.read_raw_gdf(
    "../data/pilot_1/calibration/horse_reighns_pilot_driving.gdf",
    preload=True,
)

# Labels the rename step maps from: "Channel 1".."Channel 32" followed by "EX 1".."EX 8".
original_channel_names = [
    *(f"Channel {number}" for number in range(1, 33)),
    *(f"EX {number}" for number in range(1, 9)),
]
# Pair each original label with the project's channel name at the same position.
renaming = dict(zip(original_channel_names, channel_names))

In [4]:
# Reference, type the external channels, rename everything and attach the montage.
# Every step hands back the Raw object, so the steps are chained in their original order.
# NOTE(review): the first set_eeg_reference() runs before the EX channels are
# re-typed as EMG/EOG; if the reader labels them as EEG at load time they take
# part in that first reference - confirm this is intended.
raw_calibration = (
    raw_calibration
    .set_eeg_reference()
    .set_channel_types(dict.fromkeys(["EX 1", "EX 2", "EX 3", "EX 4"], "emg"))
    .set_channel_types(dict.fromkeys(["EX 5", "EX 6", "EX 7", "EX 8"], "eog"))
    .rename_channels(renaming)
    # on_missing='raise': fail if a channel has no position in the montage.
    .set_montage("standard_1020", on_missing='raise')
    .set_eeg_reference()
)


In [5]:
# Band-pass 1-30 Hz on a copy. Raw.filter works in place and returns the same
# object, so filtering raw_calibration directly would make raw_calibration_erds an
# alias of it and re-running this cell would filter the data a second time.
raw_calibration_erds = raw_calibration.copy().filter(l_freq=1, h_freq=30)

In [6]:
from mne.preprocessing import ICA

# Fit ICA on the filtered recording with a fixed seed for repeatable components.
# fit() hands back the fitted ICA object, which is then shown as the cell output.
ica = ICA(random_state=42).fit(raw_calibration_erds)
ica

0,1
Method,fastica
Fit parameters,algorithm=parallel fun=logcosh fun_args=None max_iter=1000
Fit,103 iterations on raw data (882656 samples)
ICA components,31
Available PCA components,32
Channel types,eeg
ICA components marked for exclusion,—


In [7]:
# Ask the ICA for EOG-related components. The result is only displayed, not
# stored: the output lists the flagged component indices first, followed by
# arrays of per-component scores.
ica.find_bads_eog(raw_calibration_erds)

([0, 1, 5],
 [array([-2.77768677e-01,  8.32506353e-01,  7.09073498e-02, -1.22120009e-01,
          1.39722493e-01, -6.55119519e-02, -3.13742982e-02, -1.96705291e-02,
          7.05875921e-02,  7.95592112e-02,  7.70281786e-02,  6.43609716e-02,
          3.42519875e-03,  1.51242693e-02, -1.00778669e-01, -6.39984870e-02,
         -5.67863057e-02, -7.03212598e-02, -4.65570199e-04, -5.10025869e-02,
          1.51025091e-02, -1.90135370e-02,  1.74033394e-03,  1.08726265e-03,
         -3.10547238e-02,  2.92334561e-02, -4.96390076e-02,  5.19706855e-02,
         -5.26087978e-02,  6.02748299e-04, -5.07442773e-02]),
  array([ 0.12363429, -0.79791148,  0.1217339 , -0.07881652, -0.09470928,
         -0.28075331,  0.17330713, -0.08798991, -0.03821129, -0.02986716,
         -0.01210837,  0.01978525, -0.02933685, -0.01309335,  0.0114333 ,
         -0.13132274, -0.01218212, -0.1762796 ,  0.02790253,  0.04005649,
          0.00805683,  0.02066189,  0.02200921,  0.01540115, -0.04172575,
          0.03097

In [8]:
# ica.plot_components()
# Components to exclude. The first three indices (0, 1, 5) are the ones
# listed by find_bads_eog in the previous cell's output; the remaining ones are
# presumably hand-picked from visual inspection of the component plots -
# TODO confirm and document the selection criteria.
ica.exclude = [0, 1, 5, 2, 6, 4, 10, 11, 14, 19, 20, 22, 23]

In [9]:
# Echo the exclusion list as the cell output to confirm what was set.
ica.exclude

[0, 1, 5, 2, 6, 4, 10, 11, 14, 19, 20, 22, 23]

In [10]:
# Apply the ICA with the components listed in ica.exclude. The return value is not
# assigned, so later cells rely on apply() having modified raw_calibration_erds
# in place; re-running this cell applies the ICA to already-cleaned data.
ica.apply(raw_calibration_erds)
# Optional current-source-density step, currently disabled:
# raw_calibration_erds = mne.preprocessing.compute_current_source_density(raw_calibration_erds)

0,1
Measurement date,Unknown
Experimenter,Unknown
Participant,0x00000000

0,1
Digitized points,35 points
Good channels,"32 EEG, 4 EMG, 4 EOG"
Bad channels,
EOG channels,"LHEOG, RHEOG, UVEOG, LVEOG"
ECG channels,Not available

0,1
Sampling frequency,2048.00 Hz
Highpass,1.00 Hz
Lowpass,30.00 Hz
Filenames,horse_reighns_pilot_driving.gdf
Duration,00:07:11 (HH:MM:SS)


In [11]:
# Annotation descriptions found in the calibration recording.
LEFT_HAND_EVENT = "769"
RIGHT_HAND_EVENT = "770"
END_OF_TRIAL_EVENT = "800"  # Used for rests

# Epoch labels and their integer codes; the codes match the annotation mapping below.
event_ids = dict(left=0, right=1)

# Author's timing notes: movement is expected about 1 + 1.25 = 2.25 s into a trial
# (the exact timing may be learned) and should last until 6 s in, with the first
# and last second of the trial "empty". Those would be the optimal "during
# movement" timings; for visualisation, preparation or rebound effects may also
# be of interest, hence the wider window below.
tmin = -3
tmax = 1.25 + 3.75

# Only left- and right-hand annotations are turned into events.
calibration_events, _ = mne.events_from_annotations(
    raw_calibration_erds,
    event_id={LEFT_HAND_EVENT: 0, RIGHT_HAND_EVENT: 1},
)

# Epochs are cut 0.5 s wider on both sides than [tmin, tmax]; the plotting step
# later crops back to [tmin, tmax] (presumably to keep time-frequency edge
# effects outside the displayed window - confirm).
calibration_epochs = mne.Epochs(
    raw_calibration_erds,
    calibration_events,
    event_id=event_ids,
    tmin=tmin - 0.5,
    tmax=tmax + 0.5,
    baseline=None,
    picks=['C3', 'C4'],
    preload=True,
)

In [14]:
# Draw the ERDS maps twice: first baseline-corrected against (-3, 0) s with
# cluster masking, then without any baseline (raw average power).
for erds_baseline in ((-3, 0), None):
    plot_tfr(
        calibration_epochs,
        baseline=erds_baseline,
        tmin=tmin,
        tmax=tmax,
        event_ids=event_ids,
    )

  from .autonotebook import tqdm as notebook_tqdm
