In [1]:
%matplotlib qt

import pyxdf
import mne
import os
import numpy as np
import pathlib
import sys
SCRIPT_DIR = pathlib.Path.cwd()
sys.path.append(os.path.dirname(SCRIPT_DIR))
from continuous_control_bci.util import channel_names
from continuous_control_bci.data.preprocessing import make_epochs, manual_clean_ica

import matplotlib.pyplot as plt

mne.set_log_level('warning')


In [23]:
from matplotlib.colors import TwoSlopeNorm

import mne
from mne.datasets import eegbci
from mne.io import concatenate_raws, read_raw_edf
from mne.stats import permutation_cluster_1samp_test as pcluster_test
from mne.time_frequency import tfr_multitaper


def _cluster_mask(data, alpha, cluster_kwargs):
    """Return a boolean (n_freqs, n_times) mask of time-frequency points inside significant clusters.

    Runs two one-sided cluster permutation tests on ``data`` -- presumably shaped
    (n_epochs, n_freqs, n_times), i.e. one channel of a non-averaged EpochsTFR -- for positive
    (tail=1) and negative (tail=-1) effects, and keeps every cluster with p <= ``alpha`` from
    either test. The two tests are NOT corrected for each other.

    Returns an all-False mask when neither test produced any cluster.
    """
    _, pos_clusters, pos_p, _ = pcluster_test(data, tail=1, **cluster_kwargs)
    _, neg_clusters, neg_p, _ = pcluster_test(data, tail=-1, **cluster_kwargs)
    clusters = list(pos_clusters) + list(neg_clusters)
    if not clusters:
        # np.stack raises on an empty sequence; nothing can be significant here.
        return np.zeros(data.shape[1:], dtype=bool)
    p_values = np.concatenate((pos_p, neg_p))
    # Stack the per-cluster masks along a new last axis, keep significant ones, merge them.
    return np.stack(clusters, axis=-1)[..., p_values <= alpha].any(axis=-1)


def plot_tfr(epochs, baseline=(-2, -1), tmin=-2, tmax=3.75, event_ids=dict(left=-1, rest=0, right=1), alpha=0.01):
    """Plot one figure of per-channel time-frequency maps for each condition in ``event_ids``.

    A multitaper TFR (10-34 Hz, n_cycles equal to the frequency, decim=2) is computed per epoch and
    cropped to [tmin, tmax]. With a ``baseline``, power is expressed as percent change from that
    window (ERDS). Each map is then masked to the clusters that pass a two-sided (2 x one-sided)
    cluster permutation test. Without a baseline, the plain averaged power is shown unmasked.

    Parameters
    ----------
    epochs : mne.Epochs
        Epochs to analyse; one panel is drawn per TFR channel.
    baseline : tuple of float | None
        Baseline window (s) for ``mode="percent"``; None disables baseline correction and masking.
    tmin, tmax : float
        Time range (s) the TFR is cropped to before baselining/plotting.
    event_ids : dict
        Only the keys are used, to select conditions via ``tfr[key]``. The dict is not modified.
    alpha : float
        Cluster p-value threshold used for the significance mask (default 0.01).
    """
    freqs = np.arange(10, 35)  # 10-34 Hz: np.arange excludes its stop value
    vmin, vmax = -1, 1  # colour limits of the ERDS maps
    cnorm = TwoSlopeNorm(vmin=vmin, vcenter=0, vmax=vmax)  # min, centre & max ERDS
    cluster_kwargs = dict(
        n_permutations=100, step_down_p=0.05, seed=1, buffer_size=None, out_type="mask"
    )

    tfr = tfr_multitaper(
        epochs,
        freqs=freqs,
        n_cycles=freqs,
        use_fft=True,
        return_itc=False,
        average=False,
        decim=2,
    )
    tfr.crop(tmin, tmax)
    if baseline is not None:
        tfr.apply_baseline(baseline, mode="percent")

    n_channels = len(tfr.ch_names)
    for event in event_ids:
        # select desired epochs for visualization
        tfr_ev = tfr[event]
        # One wide panel per channel plus a narrow colorbar axis at the end.
        fig, axes = plt.subplots(
            1,
            n_channels + 1,
            figsize=(5 * n_channels + 2, 4),
            gridspec_kw={"width_ratios": [10] * n_channels + [1]},
        )
        for ch, ax in enumerate(axes[:-1]):
            tfr_avg = tfr_ev.average()
            if baseline is not None:
                mask = _cluster_mask(tfr_ev.data[:, ch], alpha, cluster_kwargs)
                # plot TFR (ERDS map with masking)
                tfr_avg.plot(
                    [ch],
                    cmap="RdBu_r",
                    cnorm=cnorm,
                    axes=ax,
                    colorbar=False,
                    show=False,
                    mask=mask,
                    mask_style="mask",
                )
            else:
                # show=False so the figure is only shown once, after all panels are drawn.
                tfr_avg.plot(
                    [ch],
                    cmap="RdBu_r",
                    axes=ax,
                    colorbar=False,
                    show=False,
                )

            ax.set_title(tfr.ch_names[ch], fontsize=10)
            ax.axvline(0, linewidth=1, color="black", linestyle=":")  # event
            if ch != 0:
                ax.set_ylabel("")
                ax.set_yticklabels("")
        fig.colorbar(axes[0].images[-1], cax=axes[-1]).ax.set_yscale("linear")
        fig.suptitle(f"ERDS ({event})")
        plt.show()


In [3]:
import pyxdf

fname = "../data/pilot_1/runs/full_run.xdf"
streams, header = pyxdf.load_xdf(fname)

for stream in streams:
    print(stream['info']['name'])


def find_stream(streams, name):
    """Return the stream in ``streams`` whose XDF info name equals ``name``.

    pyxdf stores info fields as lists of strings (the names print as e.g. ['BioSemi']),
    so the first element is compared. Raises ValueError if no stream matches.
    """
    for candidate in streams:
        if candidate['info']['name'][0] == name:
            return candidate
    available = [s['info']['name'][0] for s in streams]
    raise ValueError(f"No XDF stream named {name!r}; available: {available}")


# Select streams by name instead of list position, which is not guaranteed to be stable across recordings.
prediction_stream = find_stream(streams, 'PredictionStream')["time_series"].T
eeg_streams = find_stream(streams, 'BioSemi')['time_series'].T

# Rows 1..40 are the 40 channels in `channel_names`; row 0 is skipped (presumably a trigger/status
# channel -- confirm). Divide by 1e6 to get volts, assuming the amplifier streams microvolts -- confirm.
raw = mne.io.RawArray(eeg_streams[1:41, :] / 1e6, info=mne.create_info(channel_names, sfreq=2048))

['PredictionStream']
['BioSemi']


In [5]:
# channel_names ends with 4 EMG and then 4 EOG channels; everything before those is EEG.
eeg_mapping = dict.fromkeys(channel_names[:-8], "eeg")
emg_mapping = dict.fromkeys(channel_names[-8:-4], "emg")
eog_mapping = dict.fromkeys(channel_names[-4:], "eog")

channel_type_mapping = {
    **eeg_mapping,
    **emg_mapping,
    **eog_mapping,
}
# Every channel must receive exactly one type (also catches duplicate channel names).
assert len(channel_type_mapping) == len(channel_names), "channel type mapping does not cover all channels"

raw.set_channel_types(channel_type_mapping)
raw.set_montage("biosemi32", on_missing='raise')
raw.set_eeg_reference()  # MNE default: average reference
# NOTE(review): with default picks MNE filters only data channels (EEG here), so EMG should keep the
# broadband content that make_precise_emg_events band-passes at 30-500 Hz -- confirm for this MNE version.
raw.filter(l_freq=1, h_freq=35)

  raw.set_channel_types(channel_type_mapping)


0,1
Measurement date,Unknown
Experimenter,Unknown
Participant,Unknown

0,1
Digitized points,35 points
Good channels,"32 EEG, 4 EMG, 4 EOG"
Bad channels,
EOG channels,"LHEOG, RHEOG, UVEOG, LVEOG"
ECG channels,Not available

0,1
Sampling frequency,2048.00 Hz
Highpass,1.00 Hz
Lowpass,35.00 Hz
Duration,00:12:27 (HH:MM:SS)


In [10]:
from mne.preprocessing import ICA

# Components picked by hand after inspecting the component plots. Leave empty to fall back to
# automatic EOG detection plus visual review.
exclude = [0, 1, 2, 5, 12, 13, 14, 16, 17, 18, 19]
ica = ICA(n_components=20, random_state=42)
ica.fit(raw)
if exclude:
    ica.exclude = exclude
else:
    bad_eog, _ = ica.find_bads_eog(raw)
    print(f"Bad EOG predicted: {bad_eog}")
    # Seed the exclusion list with the EOG-correlated components; previously they were only printed,
    # so ica.exclude stayed empty. With an interactive backend the plot may be used to adjust it.
    ica.exclude = bad_eog
    ica.plot_components()
    # NOTE(review): with `%matplotlib qt` plt.show() does not block, so the print below reflects the
    # selection before any interactive changes.
    plt.show()
print(f"Rejecting ICs: {ica.exclude}")


Bad EOG predicted: [0, 2]
Rejecting ICs: []


In [11]:
# Report which ICA components will be removed when the ICA is applied.
excluded_ics = ica.exclude
print(f"Rejecting ICs: {excluded_ics}")

Rejecting ICs: [0, 1, 2, 5, 12, 13, 14, 16, 17, 18, 19]


In [12]:
# Remove the components listed in ica.exclude from `raw` (modified in place); the returned Raw is displayed.
ica.apply(inst=raw)

0,1
Measurement date,Unknown
Experimenter,Unknown
Participant,Unknown

0,1
Digitized points,35 points
Good channels,"32 EEG, 4 EMG, 4 EOG"
Bad channels,
EOG channels,"LHEOG, RHEOG, UVEOG, LVEOG"
ECG channels,Not available

0,1
Sampling frequency,2048.00 Hz
Highpass,1.00 Hz
Lowpass,35.00 Hz
Duration,00:12:27 (HH:MM:SS)


In [18]:
import pickle
import scipy
import itertools

def make_precise_emg_events(raw, emg_model_path="../data/pilot_1/emg_model.pkl", interval=0.05, epoch_time=0.2,
                            min_duration=3.75):
    """Derive events from runs of constant predictions of a pre-trained EMG classifier.

    The 'emg' channels of ``raw`` are filtered causally, cut into windows of ``epoch_time`` seconds that
    advance by ``interval`` seconds, and each window is classified from its per-channel mean absolute
    amplitude. Every run of identical consecutive predictions lasting at least ``min_duration`` seconds
    produces one event, at the run's first prediction.

    Parameters
    ----------
    raw : mne.io.Raw
        Recording containing channels typed 'emg'.
    emg_model_path : str
        Path to a pickled classifier exposing ``predict``.
    interval : float
        Step between consecutive windows (s).
    epoch_time : float
        Window length (s).
    min_duration : float
        Minimum run duration (s) for a run to yield an event.

    Returns
    -------
    np.ndarray of int32, shape (n_events, 3)
        MNE-style events ``[sample, 0, code]``. Sample indices are absolute (they include ``raw.first_samp``).
        Model labels 0/2/1 are remapped to codes -1/0/1. The array has shape (0, 3) if no run qualifies.
    """
    # SECURITY: pickle.load can execute arbitrary code -- only load model files you trust.
    with open(emg_model_path, 'rb') as f:
        emg_model = pickle.load(f)

    sfreq = raw.info['sfreq']
    raw_emg = raw.copy().pick(['emg'])
    emg_data = raw_emg.get_data()

    # Causal (forward-only) IIR filters, to match the processing the model expects: a 30-500 Hz band-pass,
    # then a 49-51 Hz band-stop (create_filter builds a stop-band when l_freq > h_freq).
    band_pass = mne.filter.create_filter(emg_data, l_freq=30, h_freq=500, method='iir',
                                         phase='forward', sfreq=sfreq)
    band_stop = mne.filter.create_filter(emg_data, l_freq=51, h_freq=49, method='iir',
                                         phase='forward', sfreq=sfreq)
    emg_data = scipy.signal.sosfilt(band_pass['sos'], emg_data)
    emg_data = scipy.signal.sosfilt(band_stop['sos'], emg_data)
    # Keep raw's first_samp so the epoch event samples below stay aligned with `raw`.
    raw_emg = mne.io.RawArray(emg_data, raw_emg.info, first_samp=raw.first_samp)

    # Overlapping windows: consecutive windows start `interval` seconds apart.
    emg_fine_epochs = mne.make_fixed_length_epochs(
        raw_emg,
        duration=epoch_time,
        overlap=epoch_time - interval,
        reject_by_annotation=False,
    )

    # Classify each window from its per-channel mean absolute amplitude.
    features = np.abs(emg_fine_epochs.get_data()).mean(axis=2)
    preds = np.array(emg_model.predict(features))
    # Remap labels to event codes. Order matters: 0 -> -1 must happen before 2 -> 0. Label 1 stays 1.
    preds[preds == 0] = -1
    preds[preds == 2] = 0

    # Place one marker per window, (epoch_time - interval) s after the window start. The epochs' own
    # event samples guarantee exactly one timestamp per prediction, at the true sampling rate.
    onsets = emg_fine_epochs.events[:, 0] + (epoch_time - interval) * sfreq
    all_pred_events = np.column_stack([onsets, np.zeros_like(onsets), preds]).astype('int32')

    # Keep the first prediction of every sufficiently long run of identical predictions.
    starting_point_events = []
    index = 0
    for _, run in itertools.groupby(preds):
        run_length = sum(1 for _ in run)
        if run_length * interval >= min_duration:
            starting_point_events.append(all_pred_events[index])
        index += run_length

    if not starting_point_events:
        return np.empty((0, 3), dtype='int32')
    return np.array(starting_point_events, dtype='int32')


events = make_precise_emg_events(raw)


In [34]:
# Current source density of the cleaned recording. compute_current_source_density returns a new object,
# so binding it to its own name keeps `raw` untouched and makes this cell safe to re-run.
raw_csd = mne.preprocessing.compute_current_source_density(raw)

# Only left/right events are epoched; events with other codes (e.g. rest = 0) are ignored.
event_ids = dict(left=-1, right=1)

# Analysis window (s) around each event. The extra `buffer` on both sides is cropped away again by
# plot_tfr (via tmin/tmax), presumably to keep time-frequency edge effects out of the plotted range.
tmin = -3
tmax = 3.75
buffer = 0.5
epochs = mne.Epochs(
    raw_csd,
    events,
    event_ids,
    tmin=tmin - buffer,
    tmax=tmax + buffer,
    baseline=None,
    preload=True,
    picks=["C3", "C4"],
)
epochs

0,1
Number of events,61
Events,left: 34 right: 27
Time range,-3.500 – 4.250 s
Baseline,off


In [36]:
# Two passes over the same epochs: first percent-change ERDS maps relative to the pre-event window
# [tmin, -1.25] s, with cluster masks; then without baseline correction.
for tfr_baseline in ((tmin, -1.25), None):
    plot_tfr(epochs, baseline=tfr_baseline, tmin=tmin, tmax=tmax, event_ids=event_ids)