# imports

In [None]:
%matplotlib inline
import eelbrain as eel
import numpy as np
import scipy, pathlib, importlib, mne, time, os, sys, statsmodels, statsmodels.stats.multitest
from tqdm import tqdm
import matplotlib.pyplot as plt
from matplotlib.lines import Line2D
mne.set_log_level(verbose='error')
import matplotlib
import plotting
from pathnames import *

# Four-colour palette, one per level trace; the figure cells index it reversed (colors[::-1]).
colors = [
    (0, 0.2, 0.7),
    (0, 0.5, 0.8),
    (0.4, 0.8, 1),
    (0.4, 0.9, 0.7),
]

# get rN

In [None]:
# Summarise the per-subject mean rN values stored by the GT-predictor analysis,
# skipping the excluded subjects, and print mean / std / max / min.
badsubjs = ['TP0015', 'TP0020', 'TP0022']
rN_folder = speech_path / 'test_A_predgtp_pndtrue0.3_nbins4_Ntrials-1'
rNs = []
for i in range(24):  # subjects TP0001 .. TP0024
    subject = f'TP{i+1:04d}'
    if subject in badsubjs:
        continue
    # Join with pathlib rather than a hard-coded Windows '\\' separator, which
    # yields a single bogus filename (containing a backslash) on Linux/macOS.
    rNs.append(np.mean(eel.load.unpickle(rN_folder / f'{subject}_rNs.pkl')))
print(np.mean(rNs), np.std(rNs), np.max(rNs), np.min(rNs))

# Visual abstract

In [None]:
# Load click ERPs and speech TRFs for every predictor, used by the visual abstract.
# Note: twin is 0.001 here (the figure cells below use 0.004).
importlib.reload(plotting)

twin = 0.001
pkwin = (0.004, 0.01)
badsubjs = ['TP0015', 'TP0020', 'TP0022']
predks = ['clicks', 'rect', 'gt', 'oss', 'ossa', 'zil']
# Per-predictor shifts, all offset by a common -0.0009.
shifts = [0, 0, 0.001, -0.004, -0.005, 0.001]
shifts = [s - 0.0009 for s in shifts]
# 'pkampnegs' is created here but not filled by this cell.
resA = dict(trfs={}, pklats={}, pkamps={}, pk2pkamps={}, pklatnegs={}, pkampnegs={})
peak_keys = ('pklats', 'pkamps', 'pk2pkamps', 'pklatnegs')
for k, shift in zip(predks, shifts):
    print(k, shift)
    if k == 'clicks':
        (clicks, click_pklats, click_pkamps, click_pk2pkamps,
         click_pklatnegs, click_pkampnegs, subjects1) = plotting.get_clicks(
            twin=twin, badsubjs=badsubjs, pkwin=pkwin, shift=shift)
        resA['trfs'][k] = clicks
        click_peaks = (click_pklats, click_pkamps, click_pk2pkamps, click_pklatnegs)
        for key, values in zip(peak_keys, click_peaks):
            resA[key][k] = np.asarray(values)
        continue
    respath = pathlib.Path(speech_path / f'test_A_pred{k}p_pndtrue0.3_nbins4_Ntrials-1')
    res = plotting.get_trfs(respath, k + 'p', badsubjs, shift, twin=twin, pkwin=pkwin)
    snrs1, spows1, npows1 = plotting.get_waveVsnr(res)
    resA['trfs'][k] = res['trfs'].copy()
    for key in peak_keys:
        resA[key][k] = np.asarray(res[key]).copy()

In [None]:
# Rescale each speech TRF (and its peak / peak-to-peak amplitudes) by the gain that
# makes its mean 0-0.02 `.std().mean()` level equal to that of the click ERPs.
# Gains are stored per predictor in `gn`.
gn = {}
for k in (name for name in predks if name != 'clicks'):
    click_level = np.mean([x.sub(time=(0, 0.02)).std().mean() for x in resA['trfs']['clicks']])
    speech_level = np.mean([x.sub(time=(0, 0.02)).std().mean() for x in resA['trfs'][k]])
    gn[k] = click_level / speech_level
    resA['trfs'][k] = [x * gn[k] for x in resA['trfs'][k]]
    for amp_key in ('pkamps', 'pk2pkamps'):
        resA[amp_key][k] = [p * gn[k] for p in resA[amp_key][k]]

In [None]:
# Visual abstract. Left panel: a speech excerpt drawn at four scaled levels.
# Right panel: the grand-average GT TRF traces. Every artist is drawn through
# the explicit Axes objects so the ConnectionPatch arrows (which are bound to
# ax1/ax2.transData) always refer to the axes that were actually drawn on.
from matplotlib import patches

gt_trfs = resA['trfs']['gt']
# Grand average over ALL loaded subjects. This was previously hard-coded to
# range(21), which silently depends on the length of the bad-subject list.
gtx = np.mean(np.asarray([trf.sub(time=(0, 0.015)).x for trf in gt_trfs]), axis=0)
gtx = gtx[::-1, :]  # reverse the first (presumably level) axis to pair with colors[::-1]
tdim = gt_trfs[0].sub(time=(0, 0.015)).time.times * 1000  # x-axis in ms (5-unit bar = '5 ms')
colors = [(0, 0.2, 0.7), (0, 0.5, 0.8), (0.4, 0.8, 1), (0.4, 0.9, 0.7)]

fig = plt.figure(figsize=(20, 10))
ax1 = fig.add_subplot(121)
ax2 = fig.add_subplot(122)

# Right panel: TRF traces stacked with a growing vertical offset.
for i in range(4):
    ax2.plot(tdim, gtx[i] + 5e-7 * i**1.4, color=colors[::-1][i], lw=5, zorder=1)
for spine in ax2.spines.values():
    spine.set_visible(False)
ax2.set_xticks([])
ax2.set_yticks([])
ax2.set_title('Estimated subcortical\nEEG response', fontsize=40)

# Left panel: normalised speech excerpt, scaled and offset once per level.
wav = eel.load.wav(wav_path / 'Male1_001_Simon_part002.wav').sub(channel=1)
speech_excerpt = wav.sub(time=(8.45, 8.9))
speechx = speech_excerpt.x.astype(float)
speechx /= np.max(speechx)
speechtdim = speech_excerpt.time.times
for i in range(4):
    ax1.plot(speechtdim, ((i + 1)**1.5) * 10 * speechx + i**1.4 * 50, color=colors[::-1][i])
for spine in ax1.spines.values():
    spine.set_visible(False)
ax1.set_xticks([])
ax1.set_yticks([])
ax1.set_ylabel('Intensity level', fontsize=40, labelpad=30)
ax1.set_title('Continuous speech stimuli', fontsize=40)

# Arrows, in data coordinates of the named axes.
arrow_specs = [
    # stimulus panel -> response panel
    dict(xyA=[8.95, 150], xyB=[-2, 1.31e-6],
         coordsA=ax1.transData, coordsB=ax2.transData, color="black"),
    # dashed red arrow within the response panel (pairs with the red 'Level dependency' label)
    dict(xyA=[5.8, 2.95e-6], xyB=[8, -0.2e-6],
         coordsA=ax2.transData, coordsB=ax2.transData, color="red", linestyle=(5, (5, 5))),
    # vertical arrow alongside the stacked stimulus traces
    dict(xyA=[8.41, 20], xyB=[8.41, 280],
         coordsA=ax1.transData, coordsB=ax1.transData, color="k"),
]
for spec in arrow_specs:
    arrow = patches.ConnectionPatch(arrowstyle="-|>", mutation_scale=30, linewidth=3, **spec)
    fig.patches.append(arrow)

# Time scale bars: 0.1 s on the stimulus axis, 5 ms on the TRF axis.
ax1.add_patch(patches.Rectangle((8.5, -15), 0.1, -2, linewidth=1, edgecolor='none', facecolor='k'))
fig.text(0.175, 0.09, '100 ms', fontsize=20, color='k')
ax2.add_patch(patches.Rectangle((0, -0.2e-6), 5, -0.015e-6, linewidth=1, edgecolor='none', facecolor='k'))
fig.text(0.63, 0.11, '5 ms', fontsize=20, color='k')

fig.text(0.74, 0.08, 'Level dependency', fontsize=35, color='red')
fig.subplots_adjust(wspace=0.5)

# parents=True so this first figure cell also works when the output tree is new.
post_analysis_path.mkdir(exist_ok=True, parents=True)
fig.savefig(post_analysis_path / 'Visual_abstract.tiff', dpi=300, bbox_inches='tight')

# Fig 1

In [None]:
# Fig 1 data: click ERPs and speech TRFs for every predictor (twin=0.004).
importlib.reload(plotting)

twin = 0.004
pkwin = (0.004, 0.01)
badsubjs = ['TP0015', 'TP0020', 'TP0022']
predks = ['clicks', 'rect', 'gt', 'oss', 'ossa', 'zil']
shifts = [0, 0, 0.001, -0.004, -0.005, 0.001]
shifts = [s - 0.0009 for s in shifts]
# 'pk2pkamps' starts empty but must exist: the Fig 6 cells later write and read
# resA['pk2pkamps'][...], which otherwise raises KeyError on a fresh kernel.
resA = dict(trfs={}, pklats={}, pkamps={}, pk2pkamps={})
for k, shift in zip(predks, shifts):
    print(k, shift)
    if k == 'clicks':
        clicks, click_pklats, click_pkamps, click_pk2pkamps, click_pklatnegs, click_pkampnegs, subjects1 = plotting.get_clicks(twin=twin, badsubjs=badsubjs, pkwin=pkwin, shift=shift)
        resA['trfs'][k] = clicks
        resA['pklats'][k] = np.asarray(click_pklats)
        resA['pkamps'][k] = np.asarray(click_pkamps)
    else:
        respath = pathlib.Path(speech_path / f'test_A_pred{k}p_pndtrue0.3_nbins4_Ntrials-1')
        res = plotting.get_trfs(respath, k + 'p', badsubjs, shift, twin=twin, pkwin=pkwin)
        snrs1, spows1, npows1 = plotting.get_waveVsnr(res)
        resA['trfs'][k] = res['trfs'].copy()
        resA['pklats'][k] = np.asarray(res['pklats']).copy()
        resA['pkamps'][k] = np.asarray(res['pkamps']).copy()

In [None]:
# Per-predictor gain that equalises each speech TRF's mean 0-0.02 `.std().mean()`
# level with the click ERPs; applied to the TRFs and their peak amplitudes.
gn = {}
for k in (name for name in predks if name != 'clicks'):
    click_level = np.mean([x.sub(time=(0, 0.02)).std().mean() for x in resA['trfs']['clicks']])
    speech_level = np.mean([x.sub(time=(0, 0.02)).std().mean() for x in resA['trfs'][k]])
    gn[k] = click_level / speech_level
    resA['trfs'][k] = [x * gn[k] for x in resA['trfs'][k]]
    resA['pkamps'][k] = [p * gn[k] for p in resA['pkamps'][k]]

In [None]:
importlib.reload(plotting)
# One panel title per entry of predks, in the same order.
tstrs = [
    'Click ERP', 'Speech RS TRF', 'Speech GT TRF',
    'Speech OSS TRF', 'Speech OSSA TRF', 'Speech ZIL TRF',
]
ylim = [-5e-7, 9e-7]
post_analysis_path.mkdir(exist_ok=True, parents=True)
plotting.plot_fig1(
    resA, predks, tstrs,
    ylim=ylim,
    savefolder=post_analysis_path,
)

# Fig 2

In [None]:
importlib.reload(plotting)
# Panel titles, aligned with predks.
tstrs = [
    'Click ERP', 'Speech RS TRF', 'Speech GT TRF',
    'Speech OSS TRF', 'Speech OSSA TRF', 'Speech ZIL TRF',
]
plotting.plot_fig2(
    resA, predks, tstrs,
    savefolder=post_analysis_path,
    pkstr='pkamps',
    ylim=0.0015,
)

# Fig 3

In [None]:
importlib.reload(plotting)
# Fig 3 is drawn for the GT predictor only; the returned object is kept as fig3res.
fig3res = plotting.plot_fig3(
    resA,
    savefolder=post_analysis_path,
    k='gt',
    ampstr='pkamps',
    savestr='GT',
)

# Fig 4

In [None]:
# Fig 4 data: GT TRFs for the 'switch1' and 'fixed1' conditions, stored in resA
# under 'gtswitch1' / 'gtfixed1'.
twin = 0.004
pkwin = (0.004, 0.01)
badsubjs = ['TP0015', 'TP0020', 'TP0022']
predks = ['gt']
shifts = [0.001]
shifts = [s - 0.0009 for s in shifts]
for k, shift in zip(predks, shifts):
    for condition in ('switch1', 'fixed1'):
        respath = speech_path / f'test_{condition}_predgtp_pndtrue0.3_nbins4_Ntrials-1'
        res = plotting.get_trfs(respath, k + 'p', badsubjs, shift, twin=twin, pkwin=pkwin)
        resA['trfs'][k + condition] = res['trfs'].copy()
        resA['pklats'][k + condition] = np.asarray(res['pklats']).copy()
        resA['pkamps'][k + condition] = np.asarray(res['pkamps']).copy()


In [None]:
# Scale the switch/fixed GT TRFs (and peak amplitudes) so their mean 0-0.02
# `.std().mean()` level matches the click ERPs; gains are added to `gn`.
for k in ['gtswitch1', 'gtfixed1']:
    click_level = np.mean([x.sub(time=(0, 0.02)).std().mean() for x in resA['trfs']['clicks']])
    speech_level = np.mean([x.sub(time=(0, 0.02)).std().mean() for x in resA['trfs'][k]])
    gn[k] = click_level / speech_level
    resA['trfs'][k] = [trf * gn[k] for trf in resA['trfs'][k]]
    resA['pkamps'][k] = [amp * gn[k] for amp in resA['pkamps'][k]]
print(gn)

In [None]:
importlib.reload(plotting)
# Titles follow predks: 'gtfixed1' -> long duration, 'gtswitch1' -> short duration.
tstrs = ['Long duration (1 min)', 'Short duration (5 s)']
ylim = [-5e-7, 9e-7]
predks = ['gtfixed1', 'gtswitch1']
plotting.plot_fig4(
    resA, predks, tstrs,
    ylim=ylim,
    savefolder=post_analysis_path,
    ampstr='pkamps',
    ampylim=0.0015,
)

# Fig 5

In [None]:
# Fig 5 data: GT TRFs, wave-V SNR and signal/noise power as a function of the
# number of trials (4, 8, ..., 40) used in the 'switch' condition.
importlib.reload(plotting)
badsubjs = ['TP0015', 'TP0020', 'TP0022']
predk = 'gtp'
twin = 0.004
pkwin = (0.004, 0.01)  # same window for every Ntrials value
snrsA, spowsA, npowsA = [], [], []
trfsA, pklatsA, pkampsA = [], [], []
for Ntrials in range(4, 41, 4):
    speechfolder1 = speech_path / f'test_switch_pred{predk}_pndtrue0.3_nbins4_Ntrials{Ntrials}'
    res = plotting.get_trfs(speechfolder1, predk, badsubjs, 0.001-0.0009, twin=twin, pkwin=pkwin)
    snrs, spows, npows = plotting.get_waveVsnr(res)
    snrsA.append(snrs)
    spowsA.append(spows)
    npowsA.append(npows)
    trfsA.append(res['trfs'])
    trfs = res['trfs']
    pkamps = res['pkamps']
    pklats = res['pklats']
    pk2pkamps = res['pk2pkamps']
    pklatsA.append(pklats)
    pkampsA.append(pkamps)
snrsA = np.asarray(snrsA)
npowsA = np.asarray(npowsA)
spowsA = np.asarray(spowsA)

In [None]:
importlib.reload(plotting)
# Peak amplitudes/latencies and SNR per Ntrials value, from the Fig 5 data cell.
plotting.plot_fig5(
    pkampsA, pklatsA, snrsA,
    savefolder=post_analysis_path,
    savestr='pkamps',
)

# Fig 6

In [None]:
# Fig 6 data: GT TRFs split into 8 level bins (read via trfkey='ldtrfs',
# binskey='bins'), stored in resA under 'gt8lev'.
importlib.reload(plotting)
badsubjs = ['TP0015', 'TP0020', 'TP0022']
k = 'gt'
# NOTE(review): pkwin is defined but NOT passed to get_trfs below (every other
# data cell passes pkwin=pkwin), so get_trfs' own default, if any, applies. Confirm intent.
pkwin = (0.004, 0.01)
shift = 0.001 - 0.0009
trfkey = 'ldtrfs'
binskey = 'bins'
respath = speech_path / f'test_A_pred{k}p_pndpred0.3_nbins8_Ntrials-1'
res = plotting.get_trfs(respath, k+'p', badsubjs, shift, twin=0.004, pkwin_adjust=np.zeros(8), trfkey=trfkey, binskey=binskey)
# The Fig 1 cell rebuilds resA without 'pk2pkamps'; create it on demand so the
# assignment below cannot raise KeyError on Restart & Run All.
resA.setdefault('pk2pkamps', {})
resA['trfs'][k+'8lev'] = res['trfs'].copy()
resA['pklats'][k+'8lev'] = np.asarray(res['pklats']).copy()
resA['pkamps'][k+'8lev'] = np.asarray(res['pkamps']).copy()
resA['pk2pkamps'][k+'8lev'] = np.asarray(res['pk2pkamps']).copy()

In [None]:
# Level labels for the 8 bins: the subject-averaged bin values in dB, shifted so
# that the last bin sits at 72 (the legends format these as 'dBA').
bin_means = np.mean(np.asarray(res['bins']), axis=0)
bindB = [20 * np.log10(b) for b in bin_means]
# Round to the nearest dB; plain int() truncated (e.g. 65.98 -> 65).
bindB = [int(round(b - bindB[-1] + 72)) for b in bindB]
bindB

In [None]:
# Gain-match the 8-level GT TRFs (plus peak and peak-to-peak amplitudes) to the
# click ERPs' mean 0-0.02 `.std().mean()` level; the gain is recorded in `gn`.
for k in ['gt8lev']:
    click_level = np.mean([x.sub(time=(0, 0.02)).std().mean() for x in resA['trfs']['clicks']])
    speech_level = np.mean([x.sub(time=(0, 0.02)).std().mean() for x in resA['trfs'][k]])
    gn[k] = click_level / speech_level
    resA['trfs'][k] = [trf * gn[k] for trf in resA['trfs'][k]]
    for amp_key in ('pkamps', 'pk2pkamps'):
        resA[amp_key][k] = [amp * gn[k] for amp in resA[amp_key][k]]
print(gn)

In [None]:
importlib.reload(plotting)
post_analysis_path.mkdir(exist_ok=True)
# 8-level GT TRFs, labelled with the bindB levels.
plotting.plot_fig6(
    resA, bindB,
    k='gt8lev',
    tstr='Speech GT TRF',
    savefolder=post_analysis_path,
    levels=8,
    ampstr='pkamps',
    ampylim=0.0015,
)

# supplementary

## A

In [None]:
# Supplementary A: individual-subject TRFs for every predictor. Only TP0015 and
# TP0022 are excluded here, so TP0020 (dropped from the main figures) is included.
importlib.reload(plotting)
tstrs = ['Click ERP', 'Speech RS TRF', 'Speech GT TRF', 'Speech OSS TRF', 'Speech OSSA TRF', 'Speech ZIL TRF']

twin = 0.004
pkwin = (0.004, 0.01)
badsubjs = ['TP0015', 'TP0022']
predks = ['clicks', 'rect', 'gt', 'oss', 'ossa', 'zil']
shifts = [0, 0, 0.001, -0.004, -0.005, 0.001]
shifts = [s - 0.0009 for s in shifts]
resA = dict(trfs={}, pklats={}, pkamps={}, pk2pkamps={}, pklatnegs={}, pkampnegs={})
for k, shift in zip(predks, shifts):
    print(k, shift)
    if k == 'clicks':
        (clicks, click_pklats, click_pkamps, click_pk2pkamps,
         click_pklatnegs, click_pkampnegs, subjects1) = plotting.get_clicks(
            twin=twin, badsubjs=badsubjs, pkwin=pkwin, shift=shift)
        resA['trfs'][k] = clicks
        click_peaks = {
            'pklats': click_pklats,
            'pkamps': click_pkamps,
            'pk2pkamps': click_pk2pkamps,
            'pklatnegs': click_pklatnegs,
            'pkampnegs': click_pkampnegs,
        }
        for key, values in click_peaks.items():
            resA[key][k] = np.asarray(values)
    else:
        respath = speech_path / f'test_A_pred{k}p_pndtrue0.3_nbins4_Ntrials-1'
        res = plotting.get_trfs(respath, k + 'p', badsubjs, shift, twin=twin, pkwin=pkwin)
        resA['trfs'][k] = res['trfs'].copy()
        for key in ('pklats', 'pkamps', 'pklatnegs', 'pk2pkamps', 'pkampnegs'):
            resA[key][k] = np.asarray(res[key]).copy()

# Gain-match every speech predictor to the click ERPs (mean 0-0.02 `.std().mean()`).
gn = {}
for k in (name for name in predks if name != 'clicks'):
    click_level = np.mean([x.sub(time=(0, 0.02)).std().mean() for x in resA['trfs']['clicks']])
    speech_level = np.mean([x.sub(time=(0, 0.02)).std().mean() for x in resA['trfs'][k]])
    gn[k] = click_level / speech_level
    resA['trfs'][k] = [x * gn[k] for x in resA['trfs'][k]]
    for amp_key in ('pkamps', 'pk2pkamps', 'pkampnegs'):
        resA[amp_key][k] = [p * gn[k] for p in resA[amp_key][k]]

print(gn)

importlib.reload(plotting)
ylim = [-0.8e-6, 1.55e-6]
# With TP0015/TP0022 removed, index 18 of the subject list is TP0020.
plotting.plot_indiv_trfs(resA, predks, tstrs, ylim=ylim, savefolder=post_analysis_path, badsubj_idx=18)

## Switch (short, 5 s) vs. fixed (long, 1 min) duration

In [None]:
# Supplementary: individual-subject GT TRFs for the switch and fixed conditions,
# with only TP0015/TP0022 excluded (TP0020 included).
twin = 0.004
pkwin = (0.004, 0.011)
badsubjs = ['TP0015', 'TP0022']
predks = ['gt']
shifts = [0.001]
shifts = [s - 0.0009 for s in shifts]

for k, shift in zip(predks, shifts):
    for condition in ('switch1', 'fixed1'):
        respath = speech_path / f'test_{condition}_predgtp_pndtrue0.3_nbins4_Ntrials-1'
        res = plotting.get_trfs(respath, k + 'p', badsubjs, shift, twin=twin, pkwin=pkwin)
        resA['trfs'][k + condition] = res['trfs'].copy()
        resA['pklats'][k + condition] = np.asarray(res['pklats']).copy()
        resA['pkamps'][k + condition] = np.asarray(res['pkamps']).copy()

# Gain-match to the click ERPs' mean 0-0.02 `.std().mean()` level.
for k in ['gtswitch1', 'gtfixed1']:
    gn[k] = np.mean([x.sub(time=(0, 0.02)).std().mean() for x in resA['trfs']['clicks']]) / np.mean([x.sub(time=(0, 0.02)).std().mean() for x in resA['trfs'][k]])
    resA['trfs'][k] = [x * gn[k] for x in resA['trfs'][k]]
    resA['pkamps'][k] = [p * gn[k] for p in resA['pkamps'][k]]
print(gn)

importlib.reload(plotting)
ylim = [-0.8e-6, 1.2e-6]
predks = ['gtfixed1', 'gtswitch1']
# Titles must match predks. Previously the six predictor titles ('Click ERP',
# 'Speech RS TRF', ...) left over from the preceding cell were passed here,
# mislabelling the fixed/switch panels.
tstrs = ['Long duration (1 min)', 'Short duration (5 s)']
plotting.plot_indiv_trfs(resA, predks, tstrs, ylim=ylim, savefolder=post_analysis_path, badsubj_idx=18)

## inherent

In [None]:
# Supplementary: individual-subject 8-level GT TRFs, with only TP0015/TP0022
# excluded (TP0020 included).
importlib.reload(plotting)
badsubjs = ['TP0015', 'TP0022']

k = 'gt'
# NOTE(review): pkwin is not passed to get_trfs below, and unlike the main-text
# 8-level cell no trfkey='ldtrfs' / binskey='bins' is given. Confirm both are intended.
pkwin = (0.004, 0.011)
shift = 0.001 - 0.0009
respath = speech_path / 'test_A_predgtp_pndpred0.3_nbins8_Ntrials-1'
res = plotting.get_trfs(respath, k+'p', badsubjs, shift, twin=0.004, pkwin_adjust=np.zeros(8))
resA['trfs'][k+'8lev'] = res['trfs'].copy()
resA['pklats'][k+'8lev'] = np.asarray(res['pklats']).copy()
resA['pkamps'][k+'8lev'] = np.asarray(res['pkamps']).copy()

# Gain-match to the click ERPs' mean 0-0.02 `.std().mean()` level.
for k in ['gt8lev']:
    gn[k] = np.mean([x.sub(time=(0, 0.02)).std().mean() for x in resA['trfs']['clicks']]) / np.mean([x.sub(time=(0, 0.02)).std().mean() for x in resA['trfs'][k]])
    resA['trfs'][k] = [x * gn[k] for x in resA['trfs'][k]]
    resA['pkamps'][k] = [p * gn[k] for p in resA['pkamps'][k]]
print(gn)

ylim = [-0.8e-6, 1.2e-6]
predks = ['gt8lev']
# matplotlib.cm.get_cmap was deprecated in Matplotlib 3.7 and removed in 3.9
# (AttributeError); pyplot.get_cmap resolves the same registered colormap on all versions.
cmap = plt.get_cmap('winter')
colors8lev = [list(cmap(i / 8))[:3] + [0.75] for i in range(8)]  # RGB + fixed alpha 0.75
# Level labels come from bindB (computed earlier), listed in reverse bin order.
legendstr = [f'{int(b)} dBA' for b in bindB[::-1]]
tstrs = None
plotting.plot_indiv_trfs(resA, predks, tstrs, ylim=ylim, levels=8, colors1=colors8lev, savefolder=post_analysis_path, legendstr=legendstr, badsubj_idx=18)