# Working with XDF data using MNE
The goal of this notebook is to import data recorded with the LSL LabRecorder (`.xdf` files) and run it through some pre-processing steps using [MNE-Python](https://mne.tools/stable/index.html)!

In [None]:
import mne
import matplotlib.pyplot as plt
import pyxdf
import numpy as np

## Path to dataset

In [None]:
DATA_PATH = '/Users/shashankbansal/UCSD/Research/Telluride23/EEG-data/car_racing_error/sub-karan/ses-S001/eeg/'
!ls $DATA_PATH

In [None]:
import os

# Load every run of the session: one list of stream dicts and one header per run.
runs = 2
streams, headers = [], []
for run in range(1, runs + 1):
    # Files are named ..._run-001_..., ..._run-002_...; '{:03d}' keeps the
    # three-digit padding correct for run 10 and beyond (where '00{}' would give run-0010).
    fname = 'sub-karan_ses-S001_task-Default_run-{:03d}_eeg.xdf'.format(run)
    run_streams, run_header = pyxdf.load_xdf(os.path.join(DATA_PATH, fname))
    streams.append(run_streams)
    headers.append(run_header)
    


In [None]:
# Header dict of the first run, returned by pyxdf.load_xdf alongside its streams
# (file-level metadata, not a stream object)
headers[0]

In [None]:
# First stream of the first run; later cells read its 'info', 'footer',
# 'time_series' and 'time_stamps' entries
streams[0][0]

In [None]:
# Work on the first run only (the next cell repeats this assignment)
stream = streams[0]

In [None]:
stream = streams[0]

# Every stream's footer records the time stamp of its first sample. The earliest
# of these becomes t=0 for all streams in the next cell.
# NOTE(review): footer stamps are the recorded (pre clock-sync) values, while the
# time_stamps they are subtracted from come from pyxdf — confirm they share a clock.
first_stamps = [float(st['footer']['info']['first_timestamp'][0]) for st in stream]
for stamp in first_stamps:
    print(stamp)

FIRST_TIME_STAMP = min(first_stamps)
print(FIRST_TIME_STAMP, '<== earliest')

# Collect stream info and data

In [None]:
# Identify the EEG data, impedance and marker streams of the current run.

def _channel_field(stream_dict, field):
    """Return `field` (e.g. 'label' or 'unit') for every channel in a stream's XDF description."""
    channels = stream_dict['info']['desc'][0]['channels'][0]['channel']
    return [ch[field][0] for ch in channels]

# Reset the outputs so values left over from an earlier execution cannot leak through.
eeg_data = z_data = events = None

for s in stream:
    # Only log name/type: printing the whole stream dict would dump every sample.
    s_name = s['info']['name'][0]
    s_type = s['info']['type'][0]
    print(f'Stream Name: {s_name}\tType: {s_type}')

    # Get the EEG data stream for CGX
    if ('CGX' in s_name) and (s_type == 'EEG'):
        eeg_data = s['time_series']
        eeg_t = s['time_stamps'] - FIRST_TIME_STAMP  # offset first time stamp to t=0
        eeg_ch_names = _channel_field(s, 'label')
        eeg_ch_units = _channel_field(s, 'unit')
        eeg_sfreq = s['info']['effective_srate']
        print(f'Channels: {eeg_ch_names}')
        print(f'Unit: {eeg_ch_units}')
        print(f'Eff. Sampling Rate: {eeg_sfreq} Hz')

    # Get the impedance data stream for CGX.
    # The type is matched as spelled here ('Impeadance'); presumably a typo in the recording — verify.
    elif ('CGX' in s_name) and (s_type == 'Impeadance'):
        z_data = s['time_series']
        z_t = s['time_stamps'] - FIRST_TIME_STAMP
        z_ch_names = _channel_field(s, 'label')
        z_ch_units = _channel_field(s, 'unit')
        z_sfreq = s['info']['effective_srate']
        print(f'Channels: {z_ch_names}')
        print(f'Unit: {z_ch_units}')
        print(f'Eff. Sampling Rate: {z_sfreq} Hz')

    # Keyboard events
    elif s_type == 'Markers':
        events = s['time_series']
        events_t = s['time_stamps'] - FIRST_TIME_STAMP

assert eeg_data is not None, 'No CGX EEG stream found in this run'

In [None]:
# Shape of the CGX EEG time series: rows are samples, columns are channels
# (it is transposed for MNE below)
print(eeg_data.shape)
eeg_data

In [None]:
# The last column of eeg_data is the trigger channel. Count its non-zero samples
# instead of summing them: a sum can cancel to 0 even when triggers are present.
trigger = eeg_data[:, -1]
print(np.count_nonzero(trigger))  # 0 => no events were written to the trigger channel
trigger

In [None]:
# LSL marker samples (keyboard events); element [1] of each is the event label used below
events

In [None]:
# Plot the EEG stream with keyboard-marker times overlaid, and the impedance stream below.
# The 'uV' label assumes eeg_data has not been rescaled yet (the rescaling cell comes later).
EVENT_COLORS = {'left': 'pink', 'right': 'c', 'enter': 'r'}  # marker label -> line colour

fig, ax = plt.subplots(2, 1, sharex=True)  # both panels use the same time axis

ax[0].plot(eeg_t, eeg_data)
ax[0].set_ylabel('uV')
ax[0].set_title('EEG data')

# Plot event markers: element [1] of each marker sample is its label;
# labels without a colour are skipped.
for event_t, marker in zip(events_t, events):
    color = EVENT_COLORS.get(marker[1])
    if color is not None:
        ax[0].axvline(x=event_t, c=color)

ax[1].plot(z_t, z_data)
ax[1].set_ylabel('kOhms')
ax[1].set_xlabel('Time Stamps [s]')  # shared x-axis, so label it on the bottom panel
ax[1].set_title('Impedance')

plt.show()

# Prepare to import the data into MNE

## Rescale EEG data to V
(only run this once!)

In [None]:
print(eeg_data[:, 0])

# Rescale the channels recorded in microvolts to V (MNE expects volts).
# Only channels whose own unit is 'microvolts' are divided; all other channels keep
# their native values.
# eeg_data is divided in place, which also changes the array stored in `streams`.
# Remember which array object has been rescaled, so re-running this cell on the same
# array is a no-op instead of dividing by 1e6 a second time.
if globals().get('_EEG_ARRAY_IN_VOLTS') is not eeg_data:
    uv_cols = [i for i, unit in enumerate(eeg_ch_units) if unit == 'microvolts']
    eeg_data[:, uv_cols] /= 1e6
    _EEG_ARRAY_IN_VOLTS = eeg_data

print(eeg_data[:, 0])

In [None]:
# Shape after transposing to the layout MNE uses
eeg_data.T.shape # mne expects (n_channels, n_times)

In [None]:
# Channel labels of the EEG stream, in column order (used to assign channel types below)
eeg_ch_names

(Manually) Identify each channel's type for importing data to mne:

In [None]:
# Channel type of every column, in stream order, for mne.create_info.
N_EEG = 29              # first 29 channels are EEG
N_MISC = 1 + 2 + 3 + 1  # A2 (assuming unused), ExG x2, ACC x3, Packet Counter
ch_types = ['eeg'] * N_EEG + ['misc'] * N_MISC + ['stim']  # last: trigger channel (probably unused)

# Fail early with a clear message if this hard-coded layout no longer matches the stream.
assert len(ch_types) == len(eeg_ch_names), (
    f'{len(ch_types)} channel types for {len(eeg_ch_names)} channels: update N_EEG / N_MISC'
)

Make an info object for importing data to mne:

https://mne.tools/stable/generated/mne.create_info.html

In [None]:
# Channel names, sampling rate (the stream's effective rate) and channel types, passed positionally
info = mne.create_info(eeg_ch_names, eeg_sfreq, ch_types)

Import the data as a raw array:

https://mne.tools/stable/generated/mne.io.RawArray.html

In [None]:
# The XDF series is (n_times, n_channels); MNE wants (n_channels, n_times), hence .T
raw = mne.io.RawArray(eeg_data.T, info)
raw

## Apply the 10-20 channel montage

In [None]:
# Attach standard 10-20 electrode positions (channel names matched case-insensitively)
raw.info.set_montage('standard_1020', match_case=False)

# Show the positions that were attached
montage = raw.info.get_montage()
mne.viz.plot_montage(montage)

# Add annotations from marker stream

In [None]:
# Label of every marker sample (its second element), used as annotation descriptions
ev_id = []
for marker in events:
    ev_id.append(marker[1])
ev_id

https://mne.tools/stable/auto_tutorials/raw/30_annotate_raw.html

In [None]:
# Build annotations from the LSL marker stream.
# events_t and eeg_t are both measured from FIRST_TIME_STAMP (the earliest stream start),
# but time 0 of `raw` is its first EEG sample, i.e. eeg_t[0]. Subtract it so that
# annotation onsets line up with the EEG samples.
annot_onsets = events_t - eeg_t[0]
lsl_annot = mne.Annotations(
    onset=annot_onsets,
    duration=np.zeros_like(annot_onsets),  # zero durations: only the event onsets are labelled
    description=ev_id,
)

raw.set_annotations(lsl_annot)

In [None]:
%matplotlib widget

# could go back to 
# %matplotlib inline 
# if the plot doubling becomes annoying

# using widget to make the plot interactive (for scrolling, bad channel selection, etc)

# scaling can help with visibility, especially if we have some bad channels / artifacts
scalings = {
    'eeg': 10e-3, 
    'eog': 500e-6,
    'misc': 100
} 

# In case you want to manually flag some channels as bad (could also click in the plot)
bad_chans = []
raw.info['bads'] = bad_chans

PLOT_WINDOW_START_s = 0
PLOT_WINDOW_DUR_s = 20

raw.plot(
    scalings = scalings,
    #butterfly = True,
    start = PLOT_WINDOW_START_s,
    duration = PLOT_WINDOW_DUR_s,
)

# Manually reject bad data segments

In [None]:
# Optional: mark bad segments by appending "bad" annotations (onsets/durations in seconds).
# bad_annot = mne.Annotations(
#     onset = [7, 11, ],
#     duration = [3, 19, ], # length of each bad segment
#     description = "bad",
# )
# raw.set_annotations(raw.annotations + bad_annot)

# Reset the annotations to the LSL markers only
raw.set_annotations(annotations=lsl_annot)

In [None]:
# Remove the first 20 seconds of data.
# Raw.crop works in place, so only crop a recording that has not been cropped yet
# (the RawArray was created with first_samp == 0). This keeps the cell safe to re-run
# instead of removing another 20 s each time.
CROP_START_s = 20
if raw.first_samp == 0:
    raw.crop(tmin=CROP_START_s)
raw

In [None]:
# Inspect the annotations that remain after cropping
raw.annotations

In [None]:
# Remove any erroneous annotations (if needed): inspect one by index, then delete it
# raw.annotations[5]
# raw.annotations.delete(5)

# Current annotations of the cropped recording
raw.annotations

In [None]:
# Browse the cropped, annotated recording again with all channels overlaid
raw.plot(
    butterfly=True,
    duration=PLOT_WINDOW_DUR_s,
    start=PLOT_WINDOW_START_s,
    scalings=scalings,
)

# Power Spectra
(before filtering)

In [None]:
# Per-channel power spectra of the unfiltered data, 1024-point FFT
psd_unfiltered = raw.compute_psd(n_fft=1024)
psd_unfiltered.plot(average=False)

Let's bandpass filter the data:

In [None]:
# Band-pass filter a copy of the data; raw itself stays unfiltered.
L_FREQ, H_FREQ = 0.1, 100  # lower / upper cut-off, in Hz
n_fft = 1024               # FFT length shared by every PSD in this cell

raw_filt = raw.copy().filter(l_freq=L_FREQ, h_freq=H_FREQ)

# And look at the PSD again
raw_filt.compute_psd(n_fft=n_fft).plot(average=False)

# PSD restricted to the pass-band
# picks = [''] # a list of channels that you want to focus on, if any
fig = raw_filt.compute_psd(
    # picks=picks,
    fmin=0,       # lower limit to plot
    fmax=H_FREQ,  # upper limit to plot: the filter's high cut-off
    # tmin=0, tmax=10,  # if only using a subset of the data for PSD computation
    n_fft=n_fft,
).plot()

# Some reformatting if needed
# fig.get_axes()[0].set_ylim(YLIM_MIN, YLIM_MAX)
fig.get_axes()[0].set_title('Zooming in on pass-band')
plt.show()

# Identify any bad channels

In [None]:
# Optional (disabled): flag bad channels by eye and drop them before ICA.
# Uncomment to use. Note that the ICA cells below fit on raw_filt, not on raw_filt_drop_bad.

# ## Update scalings if needed
# scalings['eeg'] = 200e-6 
# # scalings['eog'] = 500e-6
# # scalings['misc'] = 100

# ## How did you identify bad channels?
# # raw_filt.info['bads'] = ['F7', 'FC5'] #'Oz', 'PO8', 'Fp2', 'F7', 'P3'] # ideally could label bad segments instead of dropping channels

# # Drop bad channels from visual examination before running ICA
# raw_filt_drop_bad = raw_filt.copy().drop_channels(raw_filt.info['bads'])

# raw_filt_drop_bad.plot(
#     scalings = scalings,
#     butterfly = True,
#     start = PLOT_WINDOW_START_s,
#     duration = PLOT_WINDOW_DUR_s,
# )

## Re-reference to average
https://mne.tools/stable/generated/mne.set_eeg_reference.html

In [None]:
# Optional (disabled): average re-reference of raw_filt_drop_bad, which is produced by
# the disabled bad-channel cell above. The ICA cells below do not use this result.
# raw_filt_drop_bad_rref, _ = mne.set_eeg_reference(
#     raw_filt_drop_bad, 
#     ref_channels='average',
# )

# raw_filt_drop_bad_rref.plot(
#     scalings=scalings,
#     butterfly=True,
#     start = PLOT_WINDOW_START_s,
#     duration = PLOT_WINDOW_DUR_s,
# )

# Run ICA to remove artifacts
https://mne.tools/stable/auto_tutorials/preprocessing/40_artifact_correction_ica.html#sphx-glr-auto-tutorials-preprocessing-40-artifact-correction-ica-py

In [None]:
from mne.preprocessing import ICA

# Make an ICA object
ica = ICA(
    # n_components = 29,  # number of principal components passed to ICA
    random_state=23,      # fixed seed so the decomposition is reproducible
    # method = 'infomax', # can use 'fastica' or 'picard' too
)

# ICA is sensitive to slow drifts, so fit it on a 1 Hz high-passed copy of the
# band-passed data (the cut-off MNE's ICA tutorial recommends). The fitted
# unmixing is then applied to raw_filt itself in the cells below.
ICA_FIT_L_FREQ = 1.0
raw_for_ica = raw_filt.copy().filter(l_freq=ICA_FIT_L_FREQ, h_freq=None)

# Run ICA
ica.fit(raw_for_ica)

## Plot ICA components over time

In [None]:
%matplotlib widget
ica.plot_sources(
    raw_filt,
    start = PLOT_WINDOW_START_s,
    stop = PLOT_WINDOW_START_s + PLOT_WINDOW_DUR_s,
)

## Topomaps for each ICA component

In [None]:
# Components judged to be artifacts after looking at the sources above.
# These indices refer to the current ICA fit; re-check them whenever ICA is refit.
suspect_ica_components = [0, 9]  # 0: eye blinks; 9: very noisy (10 and 17 were also candidates)

# Diagnostic plots (incl. topomap) for each suspect component
ica.plot_properties(raw_filt, picks=suspect_ica_components)

plt.show()

## Preview the removal of suspect ICA components

In [None]:
PLOT_WINDOW_START_s = 15
PLOT_WINDOW_DUR_s = 30

# Overlay the signals with and without the suspect components.
# Integer start/stop are sample indices here (not seconds), so convert the window.
# Round the product instead of truncating sfreq: the effective rate may be non-integer.
sfreq = raw_filt.info['sfreq']
ica.plot_overlay(
    raw_filt,
    exclude=suspect_ica_components,
    start=int(round(sfreq * PLOT_WINDOW_START_s)),
    stop=int(round(sfreq * (PLOT_WINDOW_START_s + PLOT_WINDOW_DUR_s))),
)


## Project back to channel space

In [None]:
# Project back to channel space without the suspect components.
# ICA.apply modifies its input in place, so apply it to a copy. This keeps raw_filt
# the filtered but un-cleaned data that later cells offer as the alternative to raw_ica.
raw_ica = ica.apply(
    raw_filt.copy(),
    exclude=suspect_ica_components,  # ICA components to drop
)

In [None]:
%matplotlib widget
raw_ica.plot(
    scalings = scalings,
    butterfly = True,
    duration = PLOT_WINDOW_DUR_s,
    start = PLOT_WINDOW_START_s,
)

# Let's find events from annotations

In [None]:
# MNE events array + description -> id mapping from the cleaned data's annotations.
# NOTE: this rebinds `events`, which earlier held the raw LSL marker samples.
events, event_id = mne.events_from_annotations(raw=raw_ica)

In [None]:
# MNE events array and the annotation-description -> integer id mapping
events, event_id

In [None]:
# Event names available for selecting epochs below
event_id

In [None]:
eps = mne.Epochs(
    #raw_filt,
    raw_ica,
    events = events, 
    event_id = event_id, 
    tmin=-0.5,
    tmax=1.5,
    baseline=None,
)

%matplotlib inline

for ev_id in event_id:
    eps[ev_id].plot_image(
        #picks = [],
        combine='mean'
    )

In [None]:
%matplotlib widget

for ev_id in event_id:
    eps[ev_id].average(
        # picks = []
    ).plot_joint()

# Time-frequency analysis: power and inter-trial coherence

In [None]:
from mne.time_frequency import tfr_morlet

# Re-epoch for time-frequency analysis (the commented raw_ica is the alternative input).
eps = mne.Epochs(
    raw_filt,
    # raw_ica,
    events=events,
    event_id=event_id,
    tmin=-0.5,
    tmax=1.5,
    baseline=None,
)

# 16 log-spaced frequencies between 6 and 70 Hz. n_cycles = f / 2, so the number of
# cycles grows with frequency while each wavelet spans n_cycles / f = 0.5 s.
freqs = np.logspace(*np.log10([6, 70]), num=16)
n_cycles = freqs / 2.0

# Morlet-wavelet power and inter-trial coherence, decimated 3x in time
power, itc = tfr_morlet(
    eps, freqs=freqs, n_cycles=n_cycles,
    use_fft=True, return_itc=True, decim=3, n_jobs=None,
)

In [None]:
# The wavelet frequencies (Hz) used above
print(freqs)

In [None]:
# Channels present in the TFR result (channels are picked by name below)
power.ch_names

In [None]:
# Disabled: per-channel TFR plots selected by index; the next cell does the same by channel name.
# NOTE(review): the cause of the crash is not diagnosed here; each call opens its own figure.
# ### THIS IS Crashing the kernel for some reason
# import matplotlib.pyplot as plt

# # Define the indices of the channels to plot
# # channels = [1, 2, 3, 4, 5, 6, 7, 8, 9]
# channels = list(range(1, 5))
# # Create a figure with 9 subplots
# # fig, axes = plt.subplots(2, 2, figsize=(20, 10))

# # Loop over the channels
# for i, channel in enumerate(channels):
#     # ax = axes.flatten()[i]
#     power.plot([channel], baseline=(-0.5, 0), mode='logratio', title=power.ch_names[channel])

# # plt.tight_layout()
# # plt.show()

In [None]:
# One log-ratio TFR plot (baseline -0.5..0 s) per selected channel
channels = ['PO7', 'O1', 'Oz', 'O2', 'P4', 'T8']

for name in channels:
    idx = power.ch_names.index(name)
    power.plot([idx], baseline=(-0.5, 0), mode='logratio', title=name)

In [None]:
BASELINE = (-0.5, 0)  # seconds relative to event onset

power.plot_topo(baseline=BASELINE, mode="logratio", title="Average power")

# Joint TFR plot with topomaps at the given (time [s], frequency [Hz]) points.
# NOTE(review): tmax=2 reaches past the epoch end (tmax=1.5 in mne.Epochs above).
power.plot_joint(
    baseline=BASELINE,
    mode="mean",
    tmin=-0.5,
    tmax=2,
    timefreqs=[(0.25, 11), (0.5, 13), (0.75, 18), (1, 36)],
)

In [None]:
# Scalp topographies of the epochs' power spectrum (absolute power, no contour lines)
epochs_psd = eps.compute_psd()
epochs_psd.plot_topomap(normalize=False, contours=0)