In [1]:
""" 
Preproccess EEG data from BrainVision
Saves in .mat format, segmented by triggers
"""

################
## Imports
################
import mne
import os
import glob
import numpy as np
import pandas as pd
from scipy.io import savemat
import matplotlib.pyplot as plt

In [2]:
################
## EDIT THIS PART
################

# file = '../EEG_data_test/trig_test_2/Untitled2.vhdr'
file = '../../Data/Cindy/Realtime/Multi/multi1.vhdr'
log_path = '../../Data/Cindy/Realtime/Multi/trial_arrow_log.csv'

filename = file.split('/')[-1].split('.')[0]
exp_type = file.split('/')[-1].split('.')[0].split('_')[-1]
# exp_type = 'mixed' 

output_dir = '../../Data/Cindy/Realtime/Multi/Preprocessed/preprocessed_mixed_01_15Hz'
if not os.path.exists(output_dir):
    os.makedirs(output_dir)

# overwrite = False


In [3]:
################
## Parameters
################
plot = False
# FS_ORIG = 25000  # Hz

# Preprocessing
# Notch filtering
notch_applied = False
freq_notch = 50

# Bandpass filtering
bpf_applied = True
freq_low   = 0.1
freq_high  = 15
bandpass = str(freq_low) + '-' + str(freq_high)
ftype = 'butter'
order = 3

# Spherical interpolation
int_applied = False
interpolation = 'spline'

# Rereferencing using average of mastoids electrodes
reref_applied = True
reref_type = 'average'  #channels or average
reref_channels = None

# Downsampling
down_applied = True
downfreq = 128
if not down_applied:
    downfreq = 'N/A'


In [4]:
################
## Read and crop data
################
raw = mne.io.read_raw_brainvision(file, preload=True)

#get info from raw
FS_ORIG = raw.info['sfreq']
ch_names = raw.ch_names
events, event_id = mne.events_from_annotations(raw)

#crop to start
exp_start = events[0][0] / FS_ORIG
exp_end = raw.times[-1]
eeg = raw.copy().crop(tmin = exp_start,
                      tmax = exp_end)

Extracting parameters from ../../Data/Cindy/Realtime/Multi/multi1.vhdr...
Setting channel info structure...
Reading 0 ... 477129  =      0.000 ...   954.258 secs...
Used Annotations descriptions: ['New Segment/', 'Stimulus/S  1', 'Stimulus/S  2', 'Stimulus/S  3']


  raw = mne.io.read_raw_brainvision(file, preload=True)
['Soundwave']
Consider setting the channel types to be of EEG/sEEG/ECoG/DBS/fNIRS using inst.set_channel_types before calling inst.set_montage, or omit these channels when creating your montage.
  raw = mne.io.read_raw_brainvision(file, preload=True)


In [5]:
################
## SEGMENT DATA
################

#segment to 30s chunks of speech or music attended
trial_starts = events[events[:, 2] == 1]

trial_log = pd.read_csv(log_path)
for i, trial in enumerate(trial_starts):
    print(f'Processing trial {i+1}/{len(trial_starts)}: {trial}')
    start_time = trial[0] / FS_ORIG
    end_time = trial_starts[i + 1][0] / FS_ORIG if i + 1 < len(trial_starts) else None
    eeg = raw.copy().crop(tmin = start_time, tmax = end_time)

    ################
    ## Preprocess
    ################
    df_pre = pd.DataFrame()

    ## -------------
    ## Select channels
    ## -------------
    eeg_channels = ch_names[0:31]
    eeg = eeg.pick_channels(eeg_channels)
    if plot:
        eeg.plot(start=100, duration=10, n_channels=len(raw.ch_names))

    ## -------------
    ## Notch filtering
    ## -------------
    df_pre['notch_applied'] = [notch_applied]
    if notch_applied:
        eeg = eeg.notch_filter(freqs=freq_notch)
        df_pre['notch'] = [freq_notch]
        if plot:
            eeg.plot()

    ## -------------
    ## BPFiltering
    ## -------------
    df_pre['bpf_applied'] = [bpf_applied]
    if bpf_applied:
        iir_params = dict(order=order, ftype=ftype)
        filter_params = mne.filter.create_filter(eeg.get_data(), eeg.info['sfreq'], 
                                                l_freq=freq_low, h_freq=freq_high, 
                                                method='iir', iir_params=iir_params)

        if plot:
            flim = (1., eeg.info['sfreq'] / 2.)  # frequencies
            dlim = (-0.001, 0.001)  # delays
            kwargs = dict(flim=flim, dlim=dlim)
            mne.viz.plot_filter(filter_params, eeg.info['sfreq'], compensate=True, **kwargs)
            # plt.savefig(os.path.join(output_dir, 'bpf_ffilt_shape.png'))

        eeg = eeg.filter(l_freq=freq_low, h_freq=freq_high, method='iir', iir_params=iir_params)
        df_pre['bandpass'] = [iir_params]
        df_pre['HPF'] = [freq_low]
        df_pre['LPF'] = [freq_high]
        if plot:
            eeg.plot()

    ## -------------
    ## Intrpolation
    ## -------------
    df_pre['int_applied'] = [int_applied]
    if int_applied: 
        eeg = eeg.interpolate_bads(reset_bads=False)  #, method=interpolation

        # Get the indices and names of the interpolated channels
        interp_inds = eeg.info['bads']
        interp_names = [eeg.info['ch_names'][i] for i in interp_inds]

        # Print the number and names of the interpolated channels
        print(f'{len(interp_inds)} channels interpolated: {interp_names}')

        df_pre['interpolation'] = [interpolation]
        df_pre['interp_inds'] = [interp_inds]
        df_pre['interp_names'] = [interp_names]

        if plot:
            eeg.plot()

    ## -------------
    ## Rereferencing
    ## -------------
    df_pre['reref_applied'] = [reref_applied]
    if reref_applied:
        if reref_type == 'average':
            # reref to average
            eeg = eeg.set_eeg_reference(ref_channels='average')
            df_pre['reref_type'] = [reref_type]
            df_pre['reref_channels'] = ['average']
            if plot:
                eeg.plot()

        elif reref_type == 'channels':
            # reref to a channel
            eeg = eeg.set_eeg_reference(ref_channels=reref_channels)
            df_pre['reref_type'] = [reref_type]
            df_pre['reref_channels'] = [reref_channels]
            if plot:
                eeg.plot()

    ## -------------
    ## Resampling
    ## -------------
    df_pre['down_applied'] = [down_applied]
    df_pre['downfreq'] = [downfreq]
    if down_applied:
        eeg = eeg.resample(sfreq=downfreq)
        print(eeg.info)
        if plot:
            eeg.plot()

    ## -------------
    ## Save preprocessing stages
    ## -------------
    df_pre.to_csv(os.path.join(output_dir, filename+'_pp_record.csv'), index=False)


    ################
    ## Crop into different trials
    ################
    eeg = eeg.get_data()
    print(eeg.shape)

    crop_params = [
        (0, 30*downfreq, 'FirstHalfAttend'),
        (30*downfreq, 60*downfreq, 'SecondHalfAttend')
    ]

    for tmin, tmax, stim_key in crop_params:
        stim = trial_log.iloc[i][stim_key]

        music_file = trial_log.iloc[i]['MusicFile'].split('/')[-1].split('.')[0]
        speech_file = trial_log.iloc[i]['SpeechFile'].split('/')[-1].split('.')[0]

        eeg_cropped = eeg[:, int(tmin):int(tmax)]
        filename_out = f"{filename}_{stim}_{i}.mat"
        filepath = os.path.join(output_dir, filename_out)
        savemat(filepath, {'eeg_data': eeg_cropped,
                           'stimuli_music': music_file,
                           'stimuli_speech': speech_file,
                           'stim_attended': stim,
                           'stim_attended_pos': stim_key})
    

Processing trial 1/5: [7280    0    1]
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
Setting up band-pass filter from 0.1 - 15 Hz

IIR filter parameters
---------------------
Butterworth bandpass zero-phase (two-pass forward and reverse) non-causal filter:
- Filter order 12 (effective, after forward-backward)
- Cutoffs at 0.10, 15.00 Hz: -6.02, -6.02 dB

Filtering raw data in 1 contiguous segment
Setting up band-pass filter from 0.1 - 15 Hz

IIR filter parameters
---------------------
Butterworth bandpass zero-phase (two-pass forward and reverse) non-causal filter:
- Filter order 12 (effective, after forward-backward)
- Cutoffs at 0.10, 15.00 Hz: -6.02, -6.02 dB

EEG channel type selected for re-referencing
Applying average reference.
Applying a custom ('EEG',) reference.
<Info | 9 non-empty values
 bads: []
 ch_names: Fp1, Fz, F3, F7, FT9, FC5, FC1, C3, T7, TP9, CP5, CP1, Pz, P3, ...
 chs: 31 EEG
 custom_ref_applied: True
 dig: 34 items (3 Cardinal, 3