# Working with XDF data using MNE
The goal of this notebook is to import data saved using the LSL Lab Recorder, and run through some pre-processing steps using [MNE-Python](https://mne.tools/stable/index.html)!

In [None]:
import mne
import matplotlib.pyplot as plt
import pyxdf
import numpy as np
from glob import glob

## Path to dataset

In [None]:
# Subject whose recordings are analysed in this notebook
SUBJECT = 'ruomin'

# Folder holding this subject's recordings.
# NOTE(review): absolute, machine-specific path -- adjust before running elsewhere.
DATA_PATH = f'/Users/shashankbansal/UCSD/Research/Telluride23/EEG-data/car_racing_error/sub-{SUBJECT}/'

# Collect every .xdf recording under DATA_PATH; sorting the paths puts the
# runs in run-number order (assumes the run number is part of the path).
XDF_FILES = sorted(glob(DATA_PATH + '**/*/*.xdf'))
XDF_FILES

In [None]:
# Peek at the streams recorded in the first XDF file.
# `streams` is loaded here so this cell also works on a fresh kernel
# (it used to rely on a later cell having been run first).
streams, header = pyxdf.load_xdf(XDF_FILES[0])

for s in streams:
    print(s['info']['name'], s['info']['type'])
    
    


In [None]:
# Per-run containers, one entry appended per XDF file (in XDF_FILES order)
ALL_EEG = []          # EEG stream samples
ALL_Z = []            # impedance stream samples
ALL_KEYS = []         # (time stamps, events) of the 'Markers' stream
ALL_GAME_EVENTS = []  # (time stamps, events) of the 'Gaming' stream

# Colour of the vertical line drawn for each key-press type
KEY_COLOURS = {'left': 'pink', 'right': 'c', 'enter': 'r'}

# Load each XDF file for a given subject
for XDF in XDF_FILES:
    streams, header = pyxdf.load_xdf(XDF)

    # Get the first time stamp across all streams (read from time_stamps).
    # A stream that recorded no samples has no first time stamp, so skip it
    # instead of failing on the [0] lookup.
    first_timestamps = []

    for s in streams:
        s_name = s['info']['name']
        if len(s['time_stamps']) == 0:
            print('(empty)', '\t', s_name)
            continue

        t0 = s['time_stamps'][0]
        print(t0, '\t', s_name)
        first_timestamps.append(t0)

    first_timestamp = min(first_timestamps)
    print(first_timestamp, '\t', '<== earliest')

    # Key presses of THIS run only; kept apart from the game events so the
    # marker overlay below never draws game events or a previous run's keys.
    key_events, key_events_t = [], []

    # Identify EEG data and impedance streams
    for s in streams:
        s_name = s['info']['name'][0]
        s_type = s['info']['type'][0]
        print(f'Stream Name: {s_name}\tType: {s_type}')

        # Get the EEG data stream for CGX
        if ('CGX' in s_name) and (s_type == 'EEG'):
            eeg_data = s['time_series']
            eeg_t = s['time_stamps'] - first_timestamp  # offset first time stamp to t=0
            eeg_ch_names = [ch['label'][0] for ch in s['info']['desc'][0]['channels'][0]['channel']]
            eeg_ch_units = [ch['unit'][0] for ch in s['info']['desc'][0]['channels'][0]['channel']]
            eeg_sfreq = s['info']['effective_srate']
            print(f'Eff. Sampling Rate: {eeg_sfreq} Hz')
            print(eeg_data.shape)

            # Rescale to V for importing into MNE, and remember which unit
            # the plotted values are in so the axis label stays truthful.
            eeg_plot_unit = 'uV'
            if 'microvolts' in eeg_ch_units:
                eeg_data /= 1e6
                eeg_plot_unit = 'V'

            # NOTE(review): only the samples are stored; eeg_t is not. If the
            # EEG stream is not the earliest stream, event times (relative to
            # first_timestamp) are offset from the first EEG sample -- confirm
            # before using them as onsets.
            ALL_EEG.append(eeg_data)

        # Get the impedance data stream for CGX
        elif ('CGX' in s_name) and (s_type == 'Impeadance'):  # typo in the stream name?
            z_data = s['time_series']
            z_t = s['time_stamps'] - first_timestamp
            z_ch_names = [ch['label'][0] for ch in s['info']['desc'][0]['channels'][0]['channel']]
            z_ch_units = [ch['unit'][0] for ch in s['info']['desc'][0]['channels'][0]['channel']]
            z_sfreq = s['info']['effective_srate']
            print(f'Eff. Sampling Rate: {z_sfreq} Hz')

            ALL_Z.append(z_data)

        # Keyboard events
        elif s_type == 'Markers':
            key_events = s['time_series']
            key_events_t = s['time_stamps'] - first_timestamp

            ALL_KEYS.append((key_events_t, key_events))

        # Game events
        elif s_type == 'Gaming':
            game_events = s['time_series']
            game_events_t = s['time_stamps'] - first_timestamp

            ALL_GAME_EVENTS.append((game_events_t, game_events))

    # Plot EEG data
    fig, ax = plt.subplots(2, 1)

    ax[0].plot(eeg_t, eeg_data)
    ax[0].set_ylabel(eeg_plot_unit)
    ax[0].set_xlabel('Time Stamps [s]')
    ax[0].set_title('EEG data')

    # Mark each key press on the EEG axis (second field of an event = key name)
    for event_t, event in zip(key_events_t, key_events):
        colour = KEY_COLOURS.get(event[1])
        if colour is not None:
            ax[0].axvline(x=event_t, c=colour)

    ax[1].plot(z_t, z_data)
    ax[1].set_ylabel('kOhms')
    ax[1].set_title('Impedance')

    plt.tight_layout()

plt.show()

In [None]:
# Quick look at the collected events.
# Only the last expression of a cell is displayed, so ALL_KEYS[0] is evaluated
# but not shown; the output is ALL_GAME_EVENTS.
ALL_KEYS[0]
ALL_GAME_EVENTS

In [None]:
# Sanity check: each list should hold one entry per run
print(*(len(per_run) for per_run in (ALL_EEG, ALL_Z, ALL_KEYS, ALL_GAME_EVENTS)))

# Stitch all the datasets together
(skip for now)

In [None]:
# Stacking is disabled for now; set to True to stack every run into one array per stream.
STITCH_RUNS = False

if STITCH_RUNS:
    eeg_stacked = np.vstack(ALL_EEG)
    z_stacked = np.vstack(ALL_Z)  # impedance runs (previously stacked ALL_EEG by mistake)
    # NOTE(review): this still stacks the EEG runs, as in the original; which
    # data "lsl" is meant to hold is not clear from this notebook -- confirm.
    lsl_stacked = np.vstack(ALL_EEG)

# Prepare to import data to mne

In [None]:
# The loader cell leaves the last run's channel labels in scope; they are
# reused for every run since the setup is the same across runs.
# EEG-stream channels come first, impedance channels after.
ch_names = [*eeg_ch_names, *z_ch_names]
ch_names

(Manually) Identify each channel's type for importing data to mne:

In [None]:
# One MNE channel type per entry of ch_names, in the same order
ch_types = (
    ['eeg'] * (29 + 1)              # first 29 channels are EEG, and A2
    + ['misc'] * (2 + 3 + 1)        # ExG x2, ACC x3, Packet Counter
    + ['stim']                      # Trigger channel (unused)
    + ['misc'] * len(z_ch_names)    # all impedance channels
)

Make an info object for importing data to mne:

https://mne.tools/stable/generated/mne.create_info.html

In [None]:
# Measurement info shared by all runs.
# eeg_sfreq is the effective sampling rate left over from the last loaded run.
info = mne.create_info(ch_names, eeg_sfreq, ch_types)

## Preprocessing on the Gaming Events


In [None]:
# Code (third field of a game event) -> label, used when the second field is -1
event_mapping = {
    0: 'start',
    1: 'error_right',
    2: 'error_left',
    999: 'explosion'
}


def _relabel_game_event(triplet):
    """Return a list copy of `triplet` whose second field is a text label.

    A second field of -1 is replaced by the label that `event_mapping` gives
    for the third field; any other value is labelled 'obstacle'.
    """
    labelled = list(triplet)  # copy: the incoming row may be immutable
    labelled[1] = event_mapping[int(labelled[2])] if labelled[1] == -1 else 'obstacle'
    return labelled


# Same layout as ALL_GAME_EVENTS: one [time stamps, events] pair per run
ALL_GAME_EVENTS_REMAPPED = [
    [run_times, [_relabel_game_event(triplet) for triplet in run_events]]
    for run_times, run_events in ALL_GAME_EVENTS
]


In [None]:
# Number of runs with remapped game events
print(len(ALL_GAME_EVENTS_REMAPPED))

Import the data as a raw array:

https://mne.tools/stable/generated/mne.io.RawArray.html

In [None]:
# One annotated mne Raw object per run
ALL_RAWS = []

for run_idx in range(len(ALL_EEG)):

    # Samples of the EEG and impedance streams for this run
    eeg_data, z_data = ALL_EEG[run_idx], ALL_Z[run_idx]

    # Why is there a mismatch on one run?
    print(eeg_data.shape)
    print(z_data.shape)

    # Truncate both streams to the shorter one, then put them side by side
    # (samples along rows, channels along columns)
    n_eeg_z_samples = min(eeg_data.shape[0], z_data.shape[0])
    eeg_z_data = np.concatenate(
        [eeg_data[:n_eeg_z_samples, :], z_data[:n_eeg_z_samples, :]],
        axis=1,
    )
    print(eeg_z_data.shape)

    # Make an MNE object; mne expects (n_channels, n_times), hence the transpose
    raw = mne.io.RawArray(data=eeg_z_data.T, info=info)

    # Add montage
    raw.info.set_montage('standard_1020', match_case=False)

    # Annotations from key presses: zero duration, i.e. onsets only.
    # NOTE(review): onsets are relative to the earliest stream time stamp of
    # the run, while the Raw object starts at its first sample -- confirm the
    # two clocks coincide.
    k_events_t, k_events = ALL_KEYS[run_idx]
    key_annot = mne.Annotations(
        onset=k_events_t,
        duration=np.zeros_like(k_events_t),
        description=[ev[1] for ev in k_events],
    )

    # Annotations from the (remapped) game events, also zero duration
    g_events_t, g_events = ALL_GAME_EVENTS_REMAPPED[run_idx]
    game_annot = mne.Annotations(
        onset=g_events_t,
        duration=np.zeros_like(g_events_t),
        description=[ev[1] for ev in g_events],
    )

    # Attach both sets of annotations and keep a copy of this run
    raw.set_annotations(key_annot + game_annot)
    ALL_RAWS.append(raw.copy())

In [None]:
# Labels of the game events in the second run (index 1)
g_events_t, g_events = ALL_GAME_EVENTS_REMAPPED[1]
g_ev_id = []
for ev in g_events:
    g_ev_id.append(ev[1])
print(g_ev_id)

# Remove / label bad segments

In [None]:
# Crop out the first 10 s and the last 15 s of each run.
# NOTE: crop() works in place, so running this cell twice trims the runs twice.
for i, run_raw in enumerate(ALL_RAWS):
    tmin = 10  # start time to keep
    tmax = run_raw.times[-1] - 15  # end time to keep
    ALL_RAWS[i] = run_raw.crop(tmin, tmax)

In [None]:
# Browse the first run after cropping
ALL_RAWS[0].plot(block=False)

## Merge the raw objects

In [None]:
# Concatenate all runs into one Raw object.
# concatenate_raws extends the first object of the list in place, so copies
# are passed: ALL_RAWS stays untouched and re-running this cell does not
# append the runs to an already-merged ALL_RAWS[0].
raw_merged = mne.concatenate_raws([run_raw.copy() for run_raw in ALL_RAWS])

# Annotation labels present in the merged recording
print(np.unique(raw_merged.annotations.description))

## Visualize the 10-20 channel montage

In [None]:
# Fetch the montage attached to the merged recording and draw it
montage = raw_merged.info.get_montage()
mne.viz.plot_montage(montage)

# Time Series

In [None]:
# %matplotlib widget

# could go back to
# %matplotlib inline
# if the plot doubling becomes annoying

# using widget to make the plot interactive (for scrolling, bad channel selection, etc)

# scaling can help with visibility, especially if we have some bad channels / artifacts
scalings = {
    'eeg': 100e-6,
    'misc': 100
}

# In case you want to manually flag some channels as bad (could also click in the plot).
# The flags go on raw_merged, the object plotted below (they used to be set on
# `raw`, the leftover per-run object, which had no effect on this plot).
bad_chans = []
raw_merged.info['bads'] = list(bad_chans)

# Window shown when the browser opens, in seconds
PLOT_WINDOW_START_s = 0
PLOT_WINDOW_DUR_s = 20

raw_merged.plot(
    scalings = scalings,
    # butterfly = True,
    start = PLOT_WINDOW_START_s,
    duration = PLOT_WINDOW_DUR_s,
)

# Manually reject bad data segments

In [None]:
# bad_annot = mne.Annotations(
#     onset = [125, ],
#     duration = [20, ], # setting durations to 0 to just label event onsets
#     description = "bad",
# )

# #raw.set_annotations(raw.annotations + bad_annot)
# raw.set_annotations(lsl_annot + bad_annot)

In [None]:
# Remove any erroneous annotations (if needed)
# raw.annotations[5]
# raw.annotations.delete(5)

# Power Spectra
(before filtering)

In [None]:
# Power spectrum of the merged, unfiltered data; one line per channel (average=False)
raw_merged.compute_psd(n_fft = 1024).plot(average = False)

In [None]:
# Work on a copy so raw_merged stays as the unfiltered reference.
# Note: this rebinds `raw`, which earlier held the last per-run object.
raw = raw_merged.copy()

Let's filter the data:

In [None]:
# Frequencies removed by the notch filter, in Hz
NOTCH_FREQS = [60, 120, 180, 240]

# Filter a copy of `raw` (so `raw` itself is left unmodified):
# first the notch filter, then a 1 Hz high-pass (no upper cut-off).
raw_filt = raw.copy()
raw_filt = raw_filt.notch_filter(freqs=NOTCH_FREQS)
raw_filt = raw_filt.filter(l_freq=1, h_freq=None)

n_fft = 1024

# Power spectrum after filtering. compute_psd also accepts picks, fmin/fmax
# and tmin/tmax to restrict the channels, frequency range or time span.
psd_filt = raw_filt.compute_psd(n_fft=n_fft)
psd_filt.plot()

# Identify any bad channels

In [None]:
# ## Update scalings if needed
# scalings['eeg'] = 200e-6 
# # scalings['eog'] = 500e-6
# # scalings['misc'] = 100

# raw_filt.info.bads = [] #'Oz', 'O1', , 'Fp2', 'F7', 'P3'] # ideally could label bad segments instead of dropping channels

# # Drop bad channels from visual examination before running ICA
# raw_filt_drop_bad = raw_filt.copy().drop_channels(raw_filt.info.bads)

# raw_filt_drop_bad.plot(
#     scalings = scalings,
#     butterfly = True,
#     start = PLOT_WINDOW_START_s,
#     duration = PLOT_WINDOW_DUR_s,
# )

## Re-reference to average (or just A2)
https://mne.tools/stable/generated/mne.set_eeg_reference.html

In [None]:
# raw_filt_drop_bad_rref, _ = mne.set_eeg_reference(
#     raw_filt_drop_bad, 
#     ref_channels=['A2'], #'average',
# )

# raw_filt_drop_bad_rref.plot(
#     scalings=scalings,
#     butterfly=True,
#     start = PLOT_WINDOW_START_s,
#     duration = PLOT_WINDOW_DUR_s,
# )

# Run ICA to remove artifacts
https://mne.tools/stable/auto_tutorials/preprocessing/40_artifact_correction_ica.html#sphx-glr-auto-tutorials-preprocessing-40-artifact-correction-ica-py

In [None]:
from mne.preprocessing import ICA

# Seed for the RNG, so the decomposition is reproducible
ICA_RANDOM_STATE = 23

# Make an ICA object with default settings apart from the seed.
# Optional knobs left at their defaults:
#   n_components -- number of principal components passed to ICA (e.g. 29)
#   method       -- 'infomax', 'fastica' or 'picard'
ica = ICA(random_state=ICA_RANDOM_STATE)

# Run ICA on the filtered data
ica.fit(raw_filt)

## Plot ICA components over time

In [None]:
# %matplotlib widget
# Show the ICA source time courses over the current plot window
plot_window_stop_s = PLOT_WINDOW_START_s + PLOT_WINDOW_DUR_s
ica.plot_sources(raw_filt, start=PLOT_WINDOW_START_s, stop=plot_window_stop_s)

## Topomaps for each ICA component

In [None]:
# ICA components to inspect more closely:
#   0 -- eye blinks, 4 -- pulse, 8 -- very noisy, 11 -- (no note)
suspect_ica_components = [0, 4, 8, 11]

# Property plots for the selected components only
ica.plot_properties(raw_filt, picks=suspect_ica_components)

plt.show()

## Remove suspect ICA components

In [None]:
# New plot window (seconds) for comparing the signal before/after cleaning
PLOT_WINDOW_START_s = 35
PLOT_WINDOW_DUR_s = 30

# Components to remove: 0 -- eye blinks, 4 -- pulse, 8 -- very noisy
suspect_ica_components = [0, 4, 8]

# plot_overlay takes sample indices, not seconds, so convert the window
# using the sampling rate truncated to an integer
sfreq_int = int(raw_filt.info['sfreq'])
overlay_start = sfreq_int * PLOT_WINDOW_START_s
overlay_stop = sfreq_int * (PLOT_WINDOW_START_s + PLOT_WINDOW_DUR_s)

ica.plot_overlay(
    raw_filt,
    exclude=suspect_ica_components,
    start=overlay_start,
    stop=overlay_stop,
)


## Project back to channel space

In [None]:
# Project back to channel space without the suspect components.
# ICA.apply modifies the object it is given, so it is applied to a copy:
# raw_filt keeps the pre-ICA signal and raw_ica is an independent object.
raw_ica = ica.apply(
    raw_filt.copy(),
    exclude=suspect_ica_components, # ICA components to drop
)

In [None]:
# Browse the ICA-cleaned data over the current plot window
raw_ica.plot(
    scalings = scalings,
    # butterfly = True,
    duration = PLOT_WINDOW_DUR_s,
    start = PLOT_WINDOW_START_s,
)

# Let's find events from annotations

In [None]:
# Convert the annotations of the cleaned recording into an events array
# plus the mapping from annotation label to integer event code
events, event_id = mne.events_from_annotations(raw_ica)

In [None]:
# Display the events array
events

In [None]:
# Display the label -> event code mapping
event_id

In [None]:
from mne.time_frequency import tfr_morlet

# Epoch the data around every event: 0.5 s before to 1.5 s after, no baseline.
# Note: this epochs `raw` (the unfiltered merged copy); swap in raw_ica to use
# the cleaned data instead.
epoch_window = dict(tmin=-0.5, tmax=1.5)
eps = mne.Epochs(
    raw,
    events=events,
    event_id=event_id,
    baseline=None,
    event_repeated='merge',
    **epoch_window,
)

# 16 log-spaced frequencies between 6 and 70
freqs = np.logspace(*np.log10([6, 70]), num=16)
n_cycles = freqs / 2.0  # different number of cycle per frequency

# Morlet time-frequency decomposition: power and inter-trial coherence
morlet_options = dict(use_fft=True, return_itc=True, decim=3, n_jobs=None)
power, itc = tfr_morlet(eps, freqs=freqs, n_cycles=n_cycles, **morlet_options)

In [None]:
# Channels available in the time-frequency result
power.ch_names

In [None]:
# Channels to show individual time-frequency plots for
channels = ['PO7', 'O1', 'Oz', 'O2', 'P4', 'T8']

for channel in channels:
    # Look the channel up once; it is used both as the pick and for the title
    ch_idx = power.ch_names.index(channel)
    power.plot(
        [ch_idx],
        baseline=(-0.5, 0),
        mode='logratio',
        title=power.ch_names[ch_idx],
    )

In [None]:
# Pre-stimulus interval used as baseline in both plots
tfr_baseline = (-0.5, 0)

# Time-frequency plot per channel, laid out on the sensor topography
power.plot_topo(baseline=tfr_baseline, mode="logratio", title="Average power")

# (time, frequency) points at which plot_joint draws a topomap
joint_timefreqs = [(0.25, 11), (0.5, 13), (0.75, 18), (1, 36)]

power.plot_joint(
    baseline=tfr_baseline,
    mode="mean",
    tmin=-0.5,
    tmax=2,
    timefreqs=joint_timefreqs,
)

In [None]:
# Scalp maps of the epochs' power spectrum (no normalisation, no contour lines)
eps.compute_psd().plot_topomap(normalize=False, contours=0)

In [None]:
# Fit the same `ica` object again, this time on the epoched data.
# NOTE(review): `ica` was fitted on raw_filt earlier; presumably this refit
# replaces that decomposition -- confirm that is intended.
ica.fit(eps)

In [None]:
# Interactive matplotlib backend for the source browser
%matplotlib widget
# ICA source time courses for the epoched data (no start/stop window given)
ica.plot_sources(
    eps,
    #start = PLOT_WINDOW_START_s,
    #stop = PLOT_WINDOW_START_s + PLOT_WINDOW_DUR_s,
)

In [None]:
# Property plots for ICA components 0-27 on the epoched data
ica.plot_properties(
    eps,
    picks = range(28), # select ICA components to plot topomaps for
)


In [None]:
# Rejection threshold per channel type, used as `reject` when epoching
reject_criteria = dict(
    eeg=200e-6,  # 200 µV, expressed in volts
)

In [None]:
# Epoch the ICA-cleaned data from 0.5 s before to 0.5 s after each event,
# dropping epochs that exceed reject_criteria. Data are loaded right away
# (preload=True).
# Other rejection options not used here: reject_tmax, flat, reject_by_annotation.
epoch_options = dict(
    tmin=-0.5,
    tmax=0.5,
    reject=reject_criteria,
    preload=True,
)
epochs = mne.Epochs(raw_ica, events, **epoch_options)

# Summary of why epochs were dropped
epochs.plot_drop_log()

In [None]:
%matplotlib widget

for ev_id in [2, 3]:
    epochs[ev_id].average(
        # picks = []
    ).plot_joint()

# To Do
- Notch out power supply noise and look at EEG bands