In [None]:
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
from scipy.io import loadmat
from scipy.signal import welch

import pandas as pd

import mne
from autoreject import AutoReject, Ransac

# Use a serif font family for every figure drawn in this notebook.
plt.rcParams['font.family'] = 'serif'

In [None]:
# Notebook-wide configuration.

# Names of the EEG channels, stored as a NumPy array so that subsets can be
# selected with a list of indices.
ch_names = np.array([
    'F7', 'F3', 'Fz', 'F4', 'F8',
    'T7', 'C3', 'Cz', 'C4', 'T8',
])

# Root folder of the dataset.
folder = 'NewbornEEGData/'

# Identifiers of the subjects to process.
subjs = [
    5, 6, 9, 11, 12, 13, 14, 16, 17, 18,
    19, 20, 21, 22, 25, 27, 28, 29, 30, 31,
    33, 34, 35, 37, 39, 40, 41, 45, 46, 47,
    49, 52, 53, 56, 57, 58, 60, 62, 63, 64,
    65, 66, 67, 68, 69, 70, 71,
]

In [None]:
# Montage for plotting: the standard 10-20 layout restricted to our channels.
mont1020 = mne.channels.make_standard_montage('standard_1020')

# Positions, inside the montage channel list, of the channels listed in ch_names
# (montage order is kept, not the order of ch_names).
ind = [pos for pos, name in enumerate(mont1020.ch_names) if name in ch_names]

# The first three `dig` entries are the fiducial points, so the entry of the
# channel at position `pos` is `dig[pos + 3]`.
kept_channel_info = [mont1020.dig[pos + 3] for pos in ind]

# Copy of the montage holding only the fiducials and the selected channels.
mont1020_new = mont1020.copy()
mont1020_new.ch_names = [mont1020.ch_names[pos] for pos in ind]
mont1020_new.dig = mont1020.dig[:3] + kept_channel_info

In [None]:
# Subject used for the worked examples below.
subj = 13

# Load the two continuous "silence" recordings of this subject. The path is
# built from the configured `folder` so the dataset location is set in one place.
s1 = loadmat(f'{folder}BB{subj}_Filtered (1-100)/Silence (500)/BB{subj} silence1 (continuous).mat')
s2 = loadmat(f'{folder}BB{subj}_Filtered (1-100)/Silence (500)/BB{subj} silence2 (continuous).mat')

In [None]:
# Sampling frequency stored in the loaded recording. The functions defined
# below read this notebook-level value.
sfreq = s1['eegfs'][0][0]

# MNE measurement info (all channels typed as EEG) reused when building epochs.
info = mne.create_info(ch_names=list(ch_names), ch_types='eeg', sfreq=sfreq)

In [None]:
def plot_data(data, ch_names, title, fmin=0, fmax=100):
    '''
    Plot an overview of a multi-channel recording.

    Three panels are drawn: the single-channel traces stacked vertically (top),
    all the traces overlaid together with their across-channel mean (bottom
    left) and the log10 Welch power spectrum of every channel together with
    the mean spectrum (bottom right).

    The sampling frequency is read from the notebook-level `sfreq`.

    Parameters
    ----------
    data : 2D array (channels x samples). Values are treated as microvolts:
        they are divided by 1e6 before the spectrum is computed.
    ch_names : List of channel names, used as tick labels of the stacked panel.
    title : Title of the stacked-traces panel.
    fmin : Minimum frequency to plot.
    fmax : Maximum frequency to plot.
    '''

    # Sample index divided by the sampling frequency, i.e. time in seconds.
    times = np.arange(data.shape[1]) / sfreq
    # Vertical distance between consecutive traces in the stacked panel.
    spacing = np.max(data)

    fig = plt.figure(figsize=(12,8))
    gs = gridspec.GridSpec(ncols=2, nrows=2, figure=fig)

    # Bottom left: overlaid traces and their mean.
    ax1 = fig.add_subplot(gs[1, :-1])
    for tmp in data:
        ax1.plot(times, tmp)
    ax1.plot(times, data.mean(axis=0), c='k', label='mean')
    ax1.set_xlabel('time [s]')
    ax1.set_ylabel(r'$\mu$V')
    ax1.set_xlim(0,times[-1])
    ax1.legend()

    # Top: one trace per channel, shifted by a multiple of `spacing`.
    ax2 = fig.add_subplot(gs[0, :])
    for i, tmp in enumerate(data):
        ax2.plot(times, tmp+spacing*i)
    ax2.set_xlim(0,times[-1])
    ax2.set_xlabel('time [s]')
    ax2.set_ylabel(r'$\mu$V')
    ax2.set_yticks([spacing*i for i in range(len(ch_names))])
    ax2.set_yticklabels(ch_names)
    ax2.set_title(title)

    # Bottom right: log10 of the Welch power spectral density.
    ax3 = fig.add_subplot(gs[1:, -1])
    psds, freqs = mne.time_frequency.psd_array_welch(data/1e6, sfreq=sfreq, fmin=fmin, fmax=fmax, n_fft=2048, n_overlap=1024//2, verbose=False)
    psds = np.log10(psds)
    ax3.plot(freqs, psds.T)
    ax3.plot(freqs, psds.mean(axis=0), c='black', label='mean')
    ax3.set_xlim(freqs[0],freqs[-1])
    ax3.legend()
    ax3.set_ylabel('Power spectrum')

    plt.tight_layout()
    plt.show()

In [None]:
def find_bad_channels(data, MAX_PEAK=200, MAX_PEAK_SIGMA=15, MAX_DIST_PS=4.5e-3, OFFSET=30, fmin=1, fmax=100, duration=180):
    '''
    Find bad channels based on the maximum peak and the distance from the mean of the PSD.

    A channel is rejected when its largest absolute amplitude exceeds MAX_PEAK
    or when the normalised squared distance between its log10 PSD and the
    across-channel mean log10 PSD exceeds MAX_DIST_PS. One report line is
    printed per channel. The channel names and the sampling frequency are read
    from the notebook-level `ch_names` and `sfreq`.

    Inputs:
        data: data to be analyzed (channels x samples, in uV)
        MAX_PEAK: maximum peak allowed (in uV)
        MAX_PEAK_SIGMA: maximum peak allowed (in number of standard deviations).
            Currently unused: the sigma criterion is disabled and the parameter
            is kept only so that existing calls keep working.
        MAX_DIST_PS: maximum distance from the mean of the PSD allowed
        OFFSET: time (in seconds) to be ignored at the beginning and at the end of the signal
        fmin: minimum frequency for the PSD
        fmax: maximum frequency for the PSD
        duration: nominal length of the recording (in seconds) used to locate
            its final part; the default reproduces the previously hard-coded value

    Returns:
        rej: list with the names of the rejected channels
        idx_down: sample index of the last above-threshold sample among the
            channels whose peaks all fall in the first OFFSET seconds (0 if none)
        idx_up: sample index of the first above-threshold sample among the
            channels whose peaks all fall after duration-OFFSET seconds
            (number of samples if none)
    '''

    idx_down, idx_up = 0, data.shape[1]
    ss = np.std(data)

    psds, freqs = mne.time_frequency.psd_array_welch(data/1e6, sfreq=sfreq, fmin=fmin, fmax=fmax, n_fft=2048, n_overlap=1024//2, verbose=False)
    psds = np.log10(psds)

    # Squared distance of every channel spectrum from the mean spectrum,
    # normalised by the squared norm of the channel spectrum (in percent).
    mean_psd = psds.mean(axis=0)
    dist = [((tmp-mean_psd)**2).sum() for tmp in psds]
    dist = np.array(dist) / (psds**2).sum(axis=1) * 1e2

    rej = []
    for i, ch in enumerate(ch_names):
        abs_trace = np.abs(data[i])
        delta = abs_trace.max()
        has_peak = delta>MAX_PEAK
        bad_ps = dist[i]>MAX_DIST_PS

        # Build the rejection note appended to the report line.
        str_rej = str()
        if has_peak or bad_ps:
            rej.append(ch)
            str_rej = '-> rejected'
            str_rej += ' ('
            if has_peak:
                str_rej += ' peak'
            if bad_ps:
                str_rej += ' ps'
            str_rej += ')'

        print(f'* {ch}: max peak = {np.round(delta, 2)} uV ({np.round(delta/ss, 2)} std); dist ps = {np.round(dist[i],3)} '+str_rej)

        if has_peak:
            # Samples of this channel above the amplitude threshold.
            idx = np.where(abs_trace>MAX_PEAK)[0]
            first_peak, last_peak = np.min(idx), np.max(idx)

            # All the peaks fall inside the first OFFSET seconds.
            if last_peak/sfreq<OFFSET:
                idx_down = np.max([last_peak,idx_down])
                print(f'[*] WARNING: Peak only in the initial part - time: {last_peak/sfreq} s')
            # All the peaks fall after duration-OFFSET seconds.
            if first_peak/sfreq>duration-OFFSET:
                idx_up = np.min([first_peak,idx_up])
                print(f'[*] WARNING: Peak only in the last part - time: {first_peak/sfreq} s')

    print('\nBad channels:', rej, '\n')

    return rej, idx_down, idx_up

In [None]:
def _detect_bad_channels(data, fmin, fmax):
    '''
    Run both bad-channel detectors (threshold/PSD method and RANSAC) on data.

    Returns (rej, idx_down, idx_up, ransac_bad): the three outputs of
    find_bad_channels followed by the channels flagged by RANSAC.
    '''
    print('### Rejection bad channels - our method...')
    rej, idx_down, idx_up = find_bad_channels(data, fmin=fmin, fmax=fmax)

    print('### Rejection bad channels - RANSAC')
    # The whole recording is passed as a single epoch, scaled by 1e-6.
    epochs = mne.EpochsArray(np.expand_dims(data, axis=0)/1e6, info)
    epochs.set_montage(mont1020_new)
    ransac = Ransac()
    ransac.fit(epochs)
    print('Bad channels:', ransac.bad_chs_)

    return rej, idx_down, idx_up, ransac.bad_chs_


def preprocess(data, subj, silence, fmin=1, fmax=100):
    '''
    Detect the bad channels of one recording and plot it before and after cleaning.

    The recording is plotted, both detectors are run and, when
    find_bad_channels reports peaks confined to the beginning and/or the end
    of the recording, that part is cut away and both detectors are run once
    more on the remaining samples. The channels not rejected by the
    threshold/PSD method are plotted at the end.

    Inputs:
        data: data to be analyzed (channels x samples)
        subj: subject number (only used in the log header)
        silence: label of the silence recording (only used in the log header)
        fmin: minimum frequency for the PSD
        fmax: maximum frequency for the PSD

    Returns:
        list of two elements: the channels rejected by the threshold/PSD
        method and the channels flagged by RANSAC
    '''

    print(f'[*] SUBJ {subj} - silence {silence}')

    ### Plot original data
    plot_data(data, ch_names, 'Original data')

    ### Reject channels (our method + RANSAC)
    rej, idx_down, idx_up, ransac_bad = _detect_bad_channels(data, fmin, fmax)

    n_samples = data.shape[1]
    if idx_down>0 or idx_up<n_samples:
        # Move one sample past the last early peak / before the first late peak.
        if idx_down>0:
            idx_down += 1
        if idx_up<n_samples:
            idx_up -= 1

        print('\n### REMOVE PART OF THE DATA')
        print('tmin:', idx_down/sfreq, ', tmax:', idx_up/sfreq, '\n')
        data = data[:,idx_down:idx_up]

        ### Repeat the rejection on the trimmed recording
        rej, _, _, ransac_bad = _detect_bad_channels(data, fmin, fmax)

    ### Plot clean data
    idx_to_keep = [i for i, ch in enumerate(ch_names) if ch not in rej]
    if len(idx_to_keep)>0:
        plot_data(data[idx_to_keep], ch_names[idx_to_keep], 'Cleaned data')

    ### Store results
    return [rej, ransac_bad]

## A couple of examples...

In [None]:
# Example: first silence recording of the selected subject.
preprocess(s1['eeg_rest'].astype(float), subj, silence=1)

In [None]:
# Example: second silence recording of the selected subject.
preprocess(s2['eeg_rest'].astype(float), subj, silence=2)

## Preprocess the whole dataset

In [None]:
# Bad channels of every subject, one entry per subject:
# mine* -> threshold/PSD method, ran* -> RANSAC; 1/2 -> silence recording.
mine1, mine2 = [], []
ran1, ran2 = [], []

### Loop over subjects
for i, subj in enumerate(subjs):
    print(f'\n\n########## SUBJ {subj} ({i+1}/{len(subjs)})##########')

    ### Load data (path built from the configured `folder`)
    s1 = loadmat(f'{folder}BB{subj}_Filtered (1-100)/Silence (500)/BB{subj} silence1 (continuous).mat')
    s2 = loadmat(f'{folder}BB{subj}_Filtered (1-100)/Silence (500)/BB{subj} silence2 (continuous).mat')

    ### Silence 1
    tmp = preprocess(s1['eeg_rest'].astype('float'), subj, 1)
    mine1.append(tmp[0])
    ran1.append(tmp[1])

    ### Silence 2
    tmp = preprocess(s2['eeg_rest'].astype('float'), subj, 2)
    mine2.append(tmp[0])
    ran2.append(tmp[1])

In [None]:
# Number of rejected channels per subject, for each method and recording.
rej_mine1 = list(map(len, mine1))
rej_mine2 = list(map(len, mine2))

rej_ran1 = list(map(len, ran1))
rej_ran2 = list(map(len, ran2))

# One row per subject, one column per (recording, method) pair.
df = pd.DataFrame(
    data=np.array([rej_mine1, rej_mine2, rej_ran1, rej_ran2]).T,
    index=subjs,
    columns=['S1-our', 'S2-our', 'S1-Ransac', 'S2-Ransac'],
)

In [None]:
# Show the table with a caption.
df.style.set_caption(caption="Number of removed channels")

In [None]:
# Subjects with more than five channels rejected by our method, per recording.
rej_s1 = [s for s, n_bad in zip(subjs, rej_mine1) if n_bad > 5]
rej_s2 = [s for s, n_bad in zip(subjs, rej_mine2) if n_bad > 5]

# A subject is discarded when either recording exceeds the limit.
rejected = sorted(set(rej_s1 + rej_s2))

print('subject rejected:', rejected)
print('N subjects:', len(subjs))
print('N rejected:', len(rejected))
print('N remained:', len(subjs)-len(rejected))

In [None]:
# Number of clean channels in the recordings with fewer than six rejected
# channels. The total is taken from ch_names instead of a literal 10.
n_channels = len(ch_names)

rem_s1 = [n_channels - n_bad for n_bad in rej_mine1 if n_bad < 6]
rem_s2 = [n_channels - n_bad for n_bad in rej_mine2 if n_bad < 6]

print('Clean channels S1:', np.mean(rem_s1), '+-', np.std(rem_s1))
print('Clean channels S2:', np.mean(rem_s2), '+-', np.std(rem_s2))

In [None]:
import matplotlib.font_manager
# Diagnostic: list the TrueType font files that matplotlib finds on this system.
matplotlib.font_manager.findSystemFonts(fontext='ttf')

# Plots for the figure in the supplementary material

In [None]:
# First subject of the list.
subj = subjs[0]

# Load its two silence recordings; the path is built from the configured `folder`.
s1 = loadmat(f'{folder}BB{subj}_Filtered (1-100)/Silence (500)/BB{subj} silence1 (continuous).mat')
s2 = loadmat(f'{folder}BB{subj}_Filtered (1-100)/Silence (500)/BB{subj} silence2 (continuous).mat')

In [None]:
def plot_eeg(data, ch_names, title, fmin=0, fmax=100, scale=0.9, save=True):
    '''
    Draw the traces of a recording stacked vertically, in black, with no frame.

    Parameters
    ----------
    data : 2D array (channels x samples) to plot.
    ch_names : Channel names used as y tick labels.
    title : Title of the plot.
    fmin, fmax : Not used by this function.
    scale : Factor applied to every trace before it is shifted.
    save : If True the figure is written to 'prep.svg'.

    The sampling frequency is read from the notebook-level `sfreq`.
    '''
    times = np.arange(data.shape[1]) / sfreq
    # Vertical distance between the baselines of consecutive traces.
    spacing = np.max(data)

    plt.figure(figsize=(12,8))
    ax = plt.subplot(1,1,1)

    for row, trace in enumerate(data):
        ax.plot(times, trace*scale+spacing*row, c='k')

    # Leave a margin of ten sampling steps on the left of the first sample.
    ax.set_xlim(-times[1]*10,times[-1])
    ax.set_yticks([spacing*row for row in range(len(ch_names))])
    ax.set_yticklabels(ch_names, size=26)
    ax.set_title(title)

    # Hide the frame and the whole x axis.
    for side in ('right', 'top', 'left', 'bottom'):
        ax.spines[side].set_visible(False)
    ax.xaxis.set_visible(False)

    if save:
        plt.savefig('prep.svg', bbox_inches='tight', transparent=True)

    plt.show()

In [None]:
# Window of the second silence recording to display, expressed in samples.
tmin = 1000
T = 1500
tmax = tmin + T

plot_eeg(s2['eeg_rest'][:, tmin:tmax], ch_names, title=None, scale=0.8)

In [None]:
# Subject used for the peak figure.
subj = 13

# Load its two silence recordings; the path is built from the configured `folder`.
s1 = loadmat(f'{folder}BB{subj}_Filtered (1-100)/Silence (500)/BB{subj} silence1 (continuous).mat')
s2 = loadmat(f'{folder}BB{subj}_Filtered (1-100)/Silence (500)/BB{subj} silence2 (continuous).mat')

In [None]:
# Row of the second silence recording to display and window limits (in samples).
ch = 1

tmin = 74000
T = 11500
tmax = tmin + T

# Limits (in samples) of the segment highlighted in red.
tpeak_on = tmin + 3150
tpeak_off = tpeak_on + 5300

plt.figure(figsize=(8, 3))
ax = plt.subplot(1, 1, 1)

# The x axis counts samples from the start of the window: the part after the
# highlighted segment, the part before it, then the segment itself.
ax.plot(np.arange(tpeak_off, tmax) - tmin, s2['eeg_rest'][ch, tpeak_off:tmax], c='k', lw=0.8)
ax.plot(s2['eeg_rest'][ch, tmin:tpeak_on], c='k', lw=0.8)
ax.plot(np.arange(tpeak_on, tpeak_off) - tmin, s2['eeg_rest'][ch, tpeak_on:tpeak_off], color="firebrick", lw=0.8)

# Keep only the left spine and hide the x axis.
ax.xaxis.set_visible(False)
ax.spines['bottom'].set_visible(False)
ax.spines['top'].set_visible(False)
ax.spines['right'].set_visible(False)

ax.set_ylim(-200, 200)
plt.yticks([-200, -100, 0, 100, 200], size=18)

# Unit label placed above the y axis.
ax.text(-900, 240, r'$\mu$V', size=18)

plt.savefig('peak.svg', bbox_inches='tight', transparent=True)
plt.show()

In [None]:
# Subject used for the power-spectrum figure.
subj = 13

# Load its two silence recordings; the path is built from the configured `folder`.
s1 = loadmat(f'{folder}BB{subj}_Filtered (1-100)/Silence (500)/BB{subj} silence1 (continuous).mat')
s2 = loadmat(f'{folder}BB{subj}_Filtered (1-100)/Silence (500)/BB{subj} silence2 (continuous).mat')

In [None]:
def adjust_spines(ax, spines):
    '''
    Show only the requested spines of an axes, moved 10 points outward.

    Parameters
    ----------
    ax : Axes to modify.
    spines : Collection of spine names to keep (e.g. ['left', 'bottom']).
        Every other spine gets the colour 'none'. The ticks of the y (x)
        axis are removed when 'left' ('bottom') is not requested.
    '''
    for loc, spine in ax.spines.items():
        if loc not in spines:
            # Spine not requested: do not draw it.
            spine.set_color('none')
            continue
        spine.set_position(('outward', 10))

    # y ticks only make sense next to a left spine.
    if 'left' not in spines:
        ax.yaxis.set_ticks([])
    else:
        ax.yaxis.set_ticks_position('left')

    # x ticks only make sense next to a bottom spine.
    if 'bottom' not in spines:
        ax.xaxis.set_ticks([])
    else:
        ax.xaxis.set_ticks_position('bottom')

def plot_ps(data, ch_names, title, fmin=0, fmax=100, ch_bad=None):
    '''
    Plot the log10 Welch power spectrum of every channel and save it to 'ps.svg'.

    Every channel spectrum is drawn in transparent black, the across-channel
    mean in solid black and, optionally, one channel is highlighted in red.
    The axis limits, ticks and unit labels are fixed inside the function.
    The sampling frequency is read from the notebook-level `sfreq`.

    Parameters
    ----------
    data : 2D array (channels x samples); it is divided by 1e6 before the
        spectrum is computed.
    ch_names : Not used by this function (kept for interface compatibility).
    title : Not used by this function (kept for interface compatibility).
    fmin : Minimum frequency of the spectrum.
    fmax : Maximum frequency of the spectrum.
    ch_bad : Row index of the channel to highlight, or None for no highlight.
    '''
    plt.figure(figsize=(4.5,3))
    ax3 = plt.subplot(1,1,1)

    psds, freqs = mne.time_frequency.psd_array_welch(data/1e6, sfreq=sfreq, fmin=fmin, fmax=fmax, n_fft=2048, n_overlap=1024//2, verbose=False)
    psds = np.log10(psds)
    ax3.plot(freqs, psds.T, c='k', alpha=0.3, lw=2)

    # `is not None` so that channel index 0 can be highlighted too.
    if ch_bad is not None:
        ax3.plot(freqs, psds[ch_bad], c="firebrick", lw=3)

    ax3.plot(freqs, psds.mean(axis=0), c='black', label='avg', lw=3)
    ax3.set_xlim(freqs[0],freqs[-1])
    ax3.legend(fontsize=16)

    # Keep only the left and bottom spines, detached from the plot area.
    ax3.spines['right'].set_visible(False)
    ax3.spines['top'].set_visible(False)
    adjust_spines(ax3, ['left', 'bottom'])

    plt.ylim(-13, -9)
    plt.yticks([-13, -11, -9], size=18)
    plt.xticks([0, 20, 40], size=18)

    # Unit labels placed by hand next to the axes.
    plt.text(-3,-8.7, 'dB', size=18)
    plt.text(42,-13.3, 'Hz', size=18)

    plt.savefig('ps.svg', bbox_inches='tight', transparent=True)
    plt.show()

In [None]:
# Spectrum of the first silence recording between 0.47 and 40 Hz, with the
# channel at row index 3 highlighted.
plot_ps(s1['eeg_rest'], ch_names, title=None, fmin=0.47, fmax=40, ch_bad=3)