# setup

In [None]:
%matplotlib inline
import eelbrain as eel
import numpy as np
import scipy, pathlib, importlib, mne, time, os, sys
from tqdm import tqdm
import matplotlib.pyplot as plt
import preprocessing as pre
import models as md
import plotting
from pathnames import *
importlib.reload(pre)
importlib.reload(md)
importlib.reload(plotting)

# Silence MNE: only messages at level 'error' are emitted.
mne.set_log_level(verbose='error')

# Participant identifiers 'part_001' ... 'part_025'.
subjects = ['part_{:03d}'.format(number) for number in range(1, 26)]
# Global debug-print switch; later cells reassign it.
verbose = False

# load BDF data and preprocess

In [None]:
# Load the raw BDF recordings of every subject, re-reference them and pickle
# the extracted EEG channel together with the Erg channel.
ch = 'Cz'
refs = ['EXG1', 'EXG2'] # mastoids

out_path_data.mkdir(exist_ok=True, parents=True)

for stimtype in ['s']: # change accordingly to 'in' or 's'
    speechfiles = [f'{stimtype}_speech_{i}a' for i in range(1,5)] + [f'{stimtype}_speech_{i}b' for i in range(1,5)]
    clickfile = f'{stimtype}_clicks_4c'
    # suffix shared by every click pickle written below (channel + reference names)
    click_suffix = f'{ch}_reref_{"".join(refs)}'

    print(speechfiles)

    # loop over subjects
    # NOTE(review): only subjects[12:] are processed here - looks like a resumed
    # run; confirm before relying on this cell to rebuild all subjects.
    for subject in tqdm(subjects[12:]):
        subject_folder = data_path / subject

        # speech files
        eegs = [] # eeg
        ergs = [] # Erg
        eegfs = [] # eeg file names
        fp_ergs = [] # first peak time in Erg channel (eeg and Erg start after this)
        for eegf in speechfiles:
            if subject == 'part_018' and eegf == 'in_speech_4b': # part_018 has missing data error in trial 4b
                continue
            if subject == 'part_024' and eegf in ['in_speech_2a', 'in_speech_4b']: # part_024 has flat ERG channels for trials 2a and 4b
                continue
            eegfile = subject_folder / f'P{subject[-3:]}C{eegf}.bdf'
            if not eegfile.exists():
                print(f'FILE {eegfile.stem} NOT FOUND')
                continue
            # load EEG, reref and extract channels
            eeg_a, erg_a, erg_start = pre.load_eeg(eegfile, ch=ch, refs=refs)
            eegs.append(eeg_a)
            eegfs.append(eegfile.stem)
            ergs.append(erg_a)
            fp_ergs.append(erg_start)
        eel.save.pickle(dict(eegs=eegs, filenames=eegfs, fp_ergs=fp_ergs, ergs=ergs), out_path_data / f'{subject}_{stimtype}_speech_{ch}_reref_ergs.pkl')

        # click files: the condition's own click recording, plus the 'io'
        # clicks when processing the 'in' condition
        click_jobs = [(clickfile, f'{subject}_{stimtype}_clicks_{click_suffix}.pkl')]
        if stimtype == 'in':
            click_jobs.append(('io_clicks_4c', f'{subject}_io_clicks_{click_suffix}.pkl'))

        for click_stem, out_name in click_jobs:
            eegfile = subject_folder / f'P{subject[-3:]}C{click_stem}.bdf'
            if not eegfile.exists():
                # same policy as for speech files: report and carry on
                print(f'FILE {eegfile.stem} NOT FOUND')
                continue
            # Always pass ch/refs explicitly so the data matches the channel
            # and reference encoded in the output file name (the 'io' clicks
            # were previously loaded with load_eeg's defaults).
            eeg_a, erg_a, erg_start = pre.load_eeg(eegfile, ch=ch, refs=refs) # preprocessing
            eel.save.pickle(dict(eeg=eeg_a, erg=erg_a, erg_start=erg_start), out_path_data / out_name)

# Click ERP

In [None]:
# Fit the click-evoked ERP for every stimulus type and subject and pickle it.
importlib.reload(md)
importlib.reload(pre)

filttype = 'FIR' # 'FIR' (eel.filter_data) or 'IIRsos' (Butterworth sos below)
lowc = 30
highc = 500
refC = 'mastoids' # 'mastoids' or 'earlobes'

# Resolve the input folder and the reference suffix of the click pickles once.
# Fail loudly on an unknown value instead of silently reusing a stale
# `in_path` left over from an earlier run of this cell.
if refC == 'earlobes':
    in_path = in_path_earlobes # PATH TO PREPROCESSED DATA
    ref_suffix = 'EXG3EXG4'
elif refC == 'mastoids':
    in_path = in_path_mastoids # PATH TO PREPROCESSED DATA
    ref_suffix = 'EXG1EXG2'
else:
    raise ValueError(f"refC must be 'earlobes' or 'mastoids', got {refC!r}")

if filttype not in ('FIR', 'IIRsos'):
    raise ValueError(f"filttype must be 'FIR' or 'IIRsos', got {filttype!r}")

out_path_clicks.mkdir(parents=True, exist_ok=True)

verbose = True # debug print outs
twin = (-0.02, 0.04) # time window for erps/trfs
# 1st-order Butterworth band-pass, only applied when filttype == 'IIRsos'
# (fs=16384 is hard-coded; presumably the recording rate - TODO confirm)
filtsos = scipy.signal.butter(1, [lowc, highc], btype='pass', analog=False, output='sos', fs=16384)

for stimtype in ['io','in','s']:
    for subject in tqdm(subjects):
        if verbose: print('loading', stimtype, subject)

        # load the click pickle matching the chosen reference
        datadict = eel.load.unpickle(in_path / f'{subject}_{stimtype}_clicks_Cz_reref_{ref_suffix}.pkl')

        eegk = 'eeg'

        if verbose: print('preprocessing')
        # ergp, ergn are positive and negative rectified Erg1 signals
        eeg, erg, ergp, ergn = pre.preprocess_eeg_clicks(datadict, verbose=verbose, eegk=eegk)

        res_dict = {} # store results in a dictionary

        if verbose: print('fitting ERP')
        # filter EEG before fitting
        if filttype == 'FIR':
            eegfilt = eel.filter_data(eeg, lowc, highc)
        else: # 'IIRsos' (validated above)
            eegfilt = eeg.copy()
            eegfilt.x = scipy.signal.sosfilt(filtsos, eegfilt.x)
        erp, triggers = md.fit_ERP(eegfilt.copy(), erg, twin[0], twin[1], verbose=verbose)
        # crop to the same window that was used for fitting
        erp = erp.sub(time=twin)
        res_dict['erp'] = erp

        if verbose: print('saving')
        eel.save.pickle(res_dict, out_path_clicks / f'{subject}_{stimtype}_click_erp.pkl')

# load predictors

## load rectified wav

In [None]:
# Build the rectified-waveform predictors from the Erg recordings of
# part_001 ('in' condition). Relies on `in_path` set by the Click ERP cell.
datadict = eel.load.unpickle(in_path / 'part_001_in_speech_Cz_reref_ergs.pkl')
# crop every trial to 1-245 on the time axis, remove its mean, then stack
ergs = eel.combine([
    cropped - cropped.mean()
    for cropped in (trial.sub(time=(1, 245)) for trial in datadict['ergs'])
])

rectps, rectns, wavs = [], [], []
for trial_idx in tqdm(range(8)):
    scaled = ergs[trial_idx].copy()
    scaled /= scaled.abs().max()  # normalise to unit peak amplitude
    # rebuild the time axis so it starts at 1, then resample to 16384
    rebased = eel.NDVar(scaled.x, eel.UTS(1, scaled.time.tstep, len(scaled)))
    scaled = eel.resample(rebased, 16384)
    rectps.append(scaled.clip(min=0))      # positive half-wave
    rectns.append(-(scaled.clip(max=0)))   # negative half-wave, sign-flipped
    wavs.append(scaled)

preds_in = {'rectp': rectps, 'rectn': rectns}

## load Zilany

In [None]:
# Load the Zilany HSR predictors, using the combined cache file when it
# exists (set force_make = True to rebuild it from the per-trial pickles).
force_make = False
zilany_filename = pathlib.Path(predfolder / f'part_001_in_speech_erg_short_wav_zilany_hsr_posneg_all.pkl')

cache_usable = zilany_filename.exists() and not force_make
if cache_usable:
    zilanyps, zilanyns = eel.load.unpickle(zilany_filename)
else:
    zilanyps, zilanyns = [], []
    for i in tqdm(range(8)):
        # 'pos0' files go to the positive list; time axis re-based to start at 1
        loaded = eel.load.unpickle(predfolder / f'part_001_in_speech_erg_short_wav_{i}_zilany_hsr_approx_pos0.pkl')
        zilanyps.append(eel.NDVar(loaded.x, eel.UTS(1, loaded.time.tstep, len(loaded))))
        # 'pos1' files go to the negative list (presumably the inverted-polarity input - TODO confirm)
        loaded = eel.load.unpickle(predfolder / f'part_001_in_speech_erg_short_wav_{i}_zilany_hsr_approx_pos1.pkl')
        zilanyns.append(eel.NDVar(loaded.x, eel.UTS(1, loaded.time.tstep, len(loaded))))
    eel.save.pickle([zilanyps, zilanyns], zilany_filename)

preds_in.update(zilany_hsrp=zilanyps, zilany_hsrn=zilanyns)

In [None]:
# Stack each predictor's trials into a single NDVar and crop it to 1-245
# on the time axis; print the resulting last time point as a sanity check.
for k in list(preds_in):
    stacked = eel.combine(preds_in[k])
    preds_in[k] = stacked.sub(time=(1, 245))
    print(k, preds_in[k].time.tmax)

# Account for delays in predictors

In [None]:
# Estimate how far each model predictor lags the rectified speech by
# locating the peak of their per-trial cross-correlation (latencies in ms).
import statistics


# summary statistics, keyed by predictor name + sign ('p' / 'n')
pred_corr_vals_mean = dict()
pred_corr_vals_std = dict()
pred_corr_lats_mean = dict()
pred_corr_lats_median = dict()
pred_corr_lats_mode = dict()
pred_corr_lats_std = dict()
# per-trial values behind the summaries
pred_corr_vals_all = dict()
pred_corr_lats_all = dict()

ks = ['zilany_hsr']

for k in ks:
    for sign in ['p', 'n']:
        name = k+sign
        x1 = preds_in[name].copy()
        x2 = preds_in['rect'+sign].copy()
        # sampling rates of both predictors, printed for a visual check
        fs1, fs2 = int(1/x1.time.tstep), int(1/x2.time.tstep)
        print(fs1, fs2)
        # lag axis for a full cross-correlation of two length-N signals
        N = len(x1[0].x)
        correlation_lags = scipy.signal.correlation_lags(N, N)
        corrvals1, corrlats1 = [], []
        for i in tqdm(range(8)):
            corrsig = scipy.signal.correlate(x1[i].x, x2[i].x)
            corrvals1.append(np.max(corrsig))
            peak_lag = correlation_lags[np.argmax(corrsig)]
            # samples -> milliseconds
            corrlats1.append(peak_lag*preds_in['rect'+sign][i].time.tstep*1000)
        pred_corr_vals_all[name] = corrvals1
        pred_corr_lats_all[name] = corrlats1
        pred_corr_vals_mean[name] = np.mean(corrvals1)
        pred_corr_vals_std[name] = np.std(corrvals1)
        pred_corr_lats_mean[name] = np.mean(corrlats1)
        pred_corr_lats_std[name] = np.std(corrlats1)
        pred_corr_lats_median[name] = statistics.median(corrlats1)
        # most frequent latency across trials
        pred_corr_lats_mode[name] = max(set(corrlats1), key=corrlats1.count)
        print(f'{k}{sign} corr val = {pred_corr_vals_mean[k+sign]:.4f} +- {pred_corr_vals_std[k+sign]:.4f}')
        print(f'corr lat mean = {pred_corr_lats_mean[k+sign]:.2f} +- {pred_corr_lats_std[k+sign]:.4f}, mode = {pred_corr_lats_mode[k+sign]:.2f}, median = {pred_corr_lats_median[k+sign]:.2f}')
    # average of the two polarities' median latencies
    print(f'{k} corr lat avg = {0.5*(pred_corr_lats_median[k+"p"]+pred_corr_lats_median[k+"n"]):.2f}')


# Speech TRF - Leave one out

In [None]:
# Leave-one-out speech TRFs: for every subject and every growing subset of
# trials (first 2, first 3, ... all), fit a TRF on all-but-one trial, predict
# the held-out trial and correlate the prediction with the measured EEG.
# Optionally repeat with circularly shifted predictors as a null model.
importlib.reload(pre)
importlib.reload(md)
importlib.reload(plotting)

fsds = 4096 # downsampling rate; falsy value skips resampling
fit_null_model = True
save_trfdata = False
verbose = True
trfmethod = 'freq'

out_path.mkdir(parents=True, exist_ok=True)

ordering = eel.load.tsv('BalanceReceipe.csv')
stimnames = ['1a', '1b', '2a', '2b','3a', '3b', '4a', '4b',]

# one time shift per predictor in `ks` (indexed by position); presumably the
# predictor lags estimated in the cross-correlation cell - TODO confirm
shifts1 = [0, 0.0011]
ks = ['rect', 'zilany_hsr']

rNsA = dict(i=[], s=[])


def _realign_trf(trf, tstep, shift):
    """Put `trf` on a time axis starting at -5+shift with step `tstep`, cropped to (-4, 4)."""
    return eel.NDVar(trf.x, eel.UTS(-5+shift, tstep, len(trf))).sub(time=(-4, 4))


def _prediction_corr(trf, shift, test_predp, test_predn, test_eeg):
    """Correlate `test_eeg` with the 30-500 band-passed prediction of `trf` on the held-out predictors."""
    kernel = eel.filter_data(trf, 30, 500).sub(time=(-0.01-shift, 0.03-shift))
    ypredap = eel.filter_data(eel.convolve(kernel, test_predp), 30, 500)
    ypredan = eel.filter_data(eel.convolve(kernel, test_predn), 30, 500)
    return np.corrcoef(test_eeg.x, ypredap.x + ypredan.x)[0,1]


def _circular_shift(ndvar, nshift):
    """Return a copy of `ndvar` rotated by `nshift` samples along axis 0 (1-D data) or axis 1 (otherwise)."""
    shifted = ndvar.copy()
    axis = 0 if len(ndvar.x.shape) == 1 else 1
    shifted.x[...] = np.roll(ndvar.x, nshift, axis=axis)
    return shifted


for stimtype in ['s']:
    if stimtype == 'in':
        # subjects with excluded trials and the indices of the trials they keep
        special_subj = ['part_024', 'part_018']
        special_subj_idxs = [[0, 1, 3, 4, 5, 6], [0, 1, 2, 3, 4, 5, 6]]
        shifts = [s-0.001 for s in shifts1]
    else:
        special_subj = []
        special_subj_idxs = []
        shifts = [s-0.0043 for s in shifts1]
    # NOTE(review): tmin/tmax are only printed and stored in `settings`; the
    # preprocessing call below uses the literals tmin=2, tmax=242 - confirm
    # which pair is the intended one.
    tmin = preds_in[ks[0]+'p'].time.tmin
    tmax = preds_in[ks[0]+'p'].time.tmin + 240
    print(tmin, tmax)
    for subject in tqdm(subjects, f'running {stimtype}_speech'):
        if verbose: print(subject)
        if verbose: print('loading')
        datadict = eel.load.unpickle(in_path / f'{subject}_{stimtype}_speech_Cz_reref_ergs.pkl')
        # per predictor: [positive, negative] copies of the shared predictors
        preds = {k: [preds_in[k+'p'].copy(), preds_in[k+'n'].copy()] for k in ks}

        # presentation order of the 8 stimuli for this subject; the row is
        # parsed from the 3-digit subject number, as in the loading cell
        subj_row = int(subject[-3:]) - 1
        order_subj_name = [ordering[f'sord_ab_{i+1}'][subj_row] for i in range(8)]
        order_subj = [stimnames.index(name) for name in order_subj_name]

        print(order_subj, order_subj_name)
        # restrict the predictors to the trials this subject actually has
        ssi1 = range(8)
        for ss, ssi in zip(special_subj, special_subj_idxs):
            if subject == ss:
                print(ss, subject, ssi)
                for k in preds:
                    preds[k] = [[p[i] for i in ssi] for p in preds[k]]
                ssi1 = ssi
                break

        # permutation that puts the kept trials into presentation order
        order_ssi = np.argsort([order_subj.index(s) for s in ssi1])
        print(ssi1, order_ssi)
        if stimtype == 's' and subject == 'part_013':
            # zero-pad recording 4 of part_013 to 250 on the time axis
            eegnewT = eel.UTS(datadict['eegs'][4].time.tmin, datadict['eegs'][4].time.tstep, int(250/datadict['eegs'][4].time.tstep))
            eegnew = eel.NDVar(np.zeros(len(eegnewT)), eegnewT)
            eegnew.x[:len(datadict['eegs'][4])] = datadict['eegs'][4].x
            datadict['eegs'][4] = eegnew
        eegs, preds, rNs = pre.preprocess_eeg_speech_preds(datadict['eegs'], preds, verbose=verbose, tmin=2, tmax=242, filt_method='fir')
        rNsA[stimtype[0]].append(rNs)

        eegs = eel.combine([eegs[i] for i in order_ssi])
        for k in preds.keys():
            preds[k] = [eel.combine([preds[k][0][i] for i in order_ssi]), eel.combine([preds[k][1][i] for i in order_ssi])]
        print(order_subj_name, order_subj, ssi1, order_ssi)
        if fsds:
            if verbose: print('downsampling')
            eegs = eel.resample(eegs, fsds)
            for k in preds.keys():
                # clip after resampling so the rectified predictors stay non-negative
                preds[k] = [eel.resample(x, fsds).clip(min=0) for x in preds[k]]

        # scale EEG and predictors to unit standard deviation
        eegs /= eegs.std()
        for k in preds.keys():
            preds[k] = [x/x.std() for x in preds[k]]


        if verbose: print('fitting TRFs')

        trfsA = {}
        corrsA = {}
        print(eegs, eegs.time.tmin, eegs.time.tmax)
        permshift = 30 # null-model shift step (multiplied by ip and the sampling rate below)

        print(ks)
        # t+1 = number of trials used; leave-one-out needs at least two
        for t in tqdm(range(1, len(eegs))):
            for ik, k in enumerate(ks):
                print(k)
                trfcvs = []
                corrcvs = []
                trfcvsperm = []
                corrcvsperm = []
                cvpredp = preds[k][0][:t+1].copy()
                cvpredn = preds[k][1][:t+1].copy()
                cveegs = eegs[:t+1].copy()
                for cv in range(t+1):
                    # trial `cv` is held out, all others are used for training
                    train_idx = [cvi for cvi in range(t+1) if cvi!=cv]
                    train_predp = eel.combine([cvpredp[cvi] for cvi in train_idx])
                    train_predn = eel.combine([cvpredn[cvi] for cvi in train_idx])
                    train_eeg = eel.combine([cveegs[cvi] for cvi in train_idx])

                    test_predp = cvpredp[cv].copy()
                    test_predn = cvpredn[cv].copy()
                    test_eeg = eel.filter_data(cveegs[cv], 30, 500).copy()

                    trf1, trfp1, trfn1 = md.fit_trf_posneg(train_eeg, train_predp, train_predn, trfstr=f' {k} {t} {cv}')
                    tstep = trf1.time.tstep
                    trf1_a = _realign_trf(trf1, tstep, shifts[ik])
                    trfp1_a = _realign_trf(trfp1, tstep, shifts[ik])
                    trfn1_a = _realign_trf(trfn1, tstep, shifts[ik])

                    trfsA[f'trf {k} {t} {cv}'] = trf1_a.copy()
                    trfsA[f'trf {k} {t} {cv} pos'] = trfp1_a.copy()
                    trfsA[f'trf {k} {t} {cv} neg'] = trfn1_a.copy()
                    trfcvs.append(trf1_a)

                    # the same kernel (trf1) is applied to both polarities
                    corrsA[f'corr {k} {t} {cv}'] = _prediction_corr(trf1, shifts[ik], test_predp, test_predn, test_eeg)
                    corrcvs.append(corrsA[f'corr {k} {t} {cv}'])

                    if fit_null_model:
                        trfperms = []
                        corrperms = []
                        fs1 = 1/train_predp.time.tstep
                        for ip in range(1,4):
                            # misalign predictors and EEG by a circular shift
                            nshift = int(ip*permshift*fs1)
                            train_predp_perm = _circular_shift(train_predp, nshift)
                            train_predn_perm = _circular_shift(train_predn, nshift)

                            trf1, trfp1, trfn1 = md.fit_trf_posneg(train_eeg, train_predp_perm, train_predn_perm, trfstr=f' {k} {t} {cv} perm {ip}')
                            trfperms.append(_realign_trf(trf1, trf1.time.tstep, shifts[ik]))
                            corrperms.append(_prediction_corr(trf1, shifts[ik], test_predp, test_predn, test_eeg))

                        trfsA[f'trf {k} {t} {cv} null'] = eel.combine(trfperms).mean('case')
                        # NOTE(review): this correlation is stored under a
                        # 'trf ...' key (not 'corr ...'); kept as-is because
                        # saved results may already be read with this key.
                        corrsA[f'trf {k} {t} {cv} null'] = np.mean(corrperms)
                        trfcvsperm.append(trfsA[f'trf {k} {t} {cv} null'])
                        corrcvsperm.append(corrsA[f'trf {k} {t} {cv} null'])

                # averages over the cross-validation folds
                trfsA[f'trf {k} {t}'] = eel.combine(trfcvs).mean('case')
                corrsA[f'corr {k} {t}'] = np.mean(corrcvs)

                if fit_null_model:
                    trfsA[f'trf {k} {t} null'] = eel.combine(trfcvsperm).mean('case')
                    corrsA[f'corr {k} {t} null'] = np.mean(corrcvsperm)

            for k in ks:
                if fit_null_model:
                    printstr = f"{k} {t} AVG, corr = {corrsA[f'corr {k} {t}']}, null = {corrsA[f'corr {k} {t} null']}, corr-null = {corrsA[f'corr {k} {t}'] - corrsA[f'corr {k} {t} null']}"
                else:
                    # a stray '5' after {t} used to turn e.g. t=1 into '15'
                    printstr = f"{k} {t} AVG, corr = {corrsA[f'corr {k} {t}']}"
                print(printstr)

        settings = dict(order_ssi=order_ssi, order_subj_name=order_subj_name, order_subj=order_subj, shifts=shifts, tmin=tmin, tmax=tmax, fsds=fsds)
        eel.save.pickle(dict(trfsA=trfsA, corrsA=corrsA, settings=settings), out_path / f'{subject}_{stimtype}_res.pkl')