In [None]:
import torch

import numpy as np
import matplotlib as mpl
import matplotlib.pyplot as pl
import scipy.signal as sg

pl.style.use('fivethirtyeight')

# White axes and figure backgrounds on top of the 'fivethirtyeight' style
mpl.rcParams.update({'axes.facecolor': 'white',
                     'figure.facecolor': '1'})

from scipy.ndimage  import gaussian_filter1d

import json
import re
import os

#import copy

In [None]:
from masking import *
from latencies import *
from excitation import *
from deconv import *
from ur import *
from tuning import *
from test import *
from ur import *

from data import CAPData


In [None]:
from ipywidgets import interact, interactive, fixed, interact_manual
import ipywidgets as widgets

In [None]:
cumsum_default = False  # default `cumsum` flag of process_signal / process_signal2 (integrate the signals?)

### Import data

In [None]:
data_folder = './Data/SP-2021_01_12-Q394_fmaskedCAP_data'

fs = 48828  # sampling frequency (Hz)

# Masker reference level (0 attenuation), in dB:
#   100 dB for rms=1, +10 dB for amp 5 (/sqrt(2)), -10 dB masker attenuation
I0_total = 100 +10 - 10
# Converted to a power spectral density: the power is spread over [0, fs/2]
I0 = 10*np.log10( 10**(I0_total/10)/(fs/2) )

print(f'reference masker power spectral density (0 attn): {I0:.2f} dB')

listFiles = os.listdir(data_folder)

In [None]:
# Load the recordings (presumably cropped to samples begin_ind:end_ind), 'C+R' mode
capData = CAPData(data_folder,
                  listFiles,
                  begin_ind=395,
                  end_ind=2000,
                  mode='C+R')

In [None]:
t = capData.t

# Three repetitions of the broadband-noise and no-masker conditions, averaged
broadband1 = capData.get_signal_by_name('broadband_noise')
broadband2 = capData.get_signal_by_name('2_broadband_noise')
broadband3 = capData.get_signal_by_name('3_broadband_noise')
broadband_avg = (broadband1+broadband2+broadband3)/3

nomasker1 = capData.get_signal_by_name('nomasker')
nomasker2 = capData.get_signal_by_name('2_nomasker')
nomasker3 = capData.get_signal_by_name('3_nomasker')
nomasker_avg = (nomasker1+nomasker2+nomasker3)/3

pl.figure()
pl.plot(t, broadband1)
pl.plot(t, broadband2)
pl.plot(t, broadband3)
# distinct labels: both averages used to be labelled 'avg', making the legend ambiguous
pl.plot(t, broadband_avg, label='broadband avg')

pl.plot(t, nomasker1)
pl.plot(t, nomasker2)
pl.plot(t, nomasker3)
pl.plot(t, nomasker_avg, label='no masker avg')

#pl.xlim([0.004, 0.007])
pl.xlabel('t (s)')
pl.legend()



In [None]:
# Tukey window isolating the CAP between t0 and t1; `win` is applied to every
# signal by process_signal.

t0 = 6e-3
t1 = 9.5e-3
ind0 = int(t0*fs)
ind1 = int(t1*fs)

# sg.tukey is deprecated/removed from the scipy.signal namespace; use sg.windows.tukey
win0 = sg.windows.tukey(ind1-ind0, alpha=0.4)

win = np.zeros_like(broadband_avg)
win[ind0:ind1] = win0

pl.figure()
pl.plot(t*1e3, broadband_avg, label='broadband avg')
pl.plot(t*1e3, win*np.amax(broadband_avg), label='window (scaled)')
pl.plot(t*1e3, 0.5*broadband_avg*win, label='windowed broadband avg (x0.5)')
pl.xlabel('t (ms)')
pl.legend()
pl.show()


In [None]:
def process_signal(sig, cumsum=cumsum_default, return_t=False, fs=48828, window=None):
    '''Windows a CAP signal and crops it to the [3 ms, 13 ms) analysis interval.

    Args:
        sig: 1D signal, or 2D array with one signal per row; time along the last
            axis, aligned with `window`.
        cumsum: if True, integrates the cropped signal (cumulative sum along time),
            tapers it with a Tukey window (alpha=0.3) and zeroes its last 50 samples.
        return_t: if True, also returns a time vector for the cropped interval.
        fs: sampling frequency (Hz), used to convert the interval bounds to indices.
        window: window multiplied with `sig` before cropping; defaults to the
            global `win`.
    Returns:
        sig2, or (t, sig2) if return_t is True.
    '''
    if window is None:
        window = win
    sig2 = sig*window

    t0 = 3e-3
    t1 = 13e-3
    ind0 = int(t0*fs)
    ind1 = int(t1*fs)

    # crop along the time (last) axis, for 1D and 2D inputs alike
    sig2 = sig2[..., ind0:ind1]
    if cumsum:
        # integrate along time only: np.cumsum without `axis` flattens 2D input
        sig2 = np.cumsum(sig2, axis=-1)
        n_tail = 50
        sig2[..., :-n_tail] *= sg.windows.tukey(sig2.shape[-1]-n_tail, 0.3)
        sig2[..., -n_tail:] = 0

    if return_t:
        # NOTE(review): linspace includes t1, so the spacing is (t1-t0)/(n-1),
        # slightly different from 1/fs — kept as is since downstream code uses it
        t = np.linspace(t0, t1, ind1-ind0)
        return t, sig2
    return sig2
    
    

    

In [None]:
t2, broadband_proc = process_signal(broadband_avg, cumsum=cumsum_default, return_t=True)
nomasker_proc = process_signal(nomasker_avg)

# difference between the unmasked response and the broadband-masked response
released = nomasker_proc - broadband_proc
pl.plot(t2*1e3, released)

In [None]:
def process_signal2(sig, cumsum=cumsum_default, gauss_sigma=0):
    '''Processes `sig` with process_signal, then subtracts the broadband-noise response.

    gauss_sigma: if nonzero, the difference is smoothed with a Gaussian filter
        (standard deviation in samples, applied along the last axis).
    NB: the global `broadband_proc` is assumed to have been computed with the
        same `cumsum` setting.
    '''
    diff = process_signal(sig, cumsum=cumsum) - broadband_proc
    return diff if gauss_sigma == 0 else gaussian_filter1d(diff, gauss_sigma)
    

**Estimation of the unitary response (ur)**  
Which masker to use for the estimate depends on the frequency region of interest.


In [None]:
# ur estimated from a single notch masker (alternative: '4_notch6000_bw2000_45dB')
notch_name = '2_notch3000_bw1700_55dB'
sig = capData.get_signal_by_name(notch_name)
sig2 = process_signal(sig)

#sigbis=capData.get_signal_by_name('3_notch3000_bw1700_50dB')
#_, sig2bis=process_signal(sigbis)

# Other views of the raw (unprocessed) signals, kept for reference:
#   pl.plot(t*1e3, sig); pl.plot(t*1e3, broadband_avg); pl.plot(t*1e3, sig-broadband_avg)

ur0 = sig2 - broadband_proc
pl.plot(t2*1e3, ur0)

#pl.plot(t2*1e3, sig2bis-broadband_proc)

# advance the response by 100 samples (circular shift)
ur0 = np.roll(ur0, -100)
pl.plot(t2*1e3, ur0)
#pl.xlim([4,6])
#pl.xlim([4,6])

In [None]:
def deconv(released_sig, ur0=ur0, eps=1e-2):
    '''Naive frequency-domain deconvolution of `released_sig` by the unitary response `ur0`.

    Computes irfft( rfft(released_sig) / (rfft(ur0) + eps) ).

    Args:
        released_sig: signal to deconvolve (time along the last axis)
        ur0: unitary response, same length as the signal (defaults to the global
            ur0 at definition time)
        eps: constant added to the spectrum of ur0 to avoid divisions by ~0
    Returns:
        the estimated excitation pattern, with the same length as `released_sig`
    '''
    n = np.shape(released_sig)[-1]
    released_sig_fft = np.fft.rfft(released_sig)
    ur0_fft = np.fft.rfft(ur0)
    # NOTE(review): eps is added to the complex spectrum, which is not a proper
    # Tikhonov/Wiener regularization (conj(U)/(|U|^2+eps)) — kept for compatibility
    E_fft = released_sig_fft/(ur0_fft+eps)
    # pass n explicitly: by default irfft returns 2*(len(E_fft)-1) samples, i.e.
    # one sample too few when the signal length is odd
    return np.fft.irfft(E_fft, n=n)
# no-masker minus broadband-masker CAP, deconvolved with the current ur0
masked_sig = nomasker_proc - broadband_proc
E0 = deconv(masked_sig, ur0=ur0)

Estimated raw excitation pattern (naive frequency-domain deconvolution)

In [None]:
pl.figure()
pl.plot(t2*1e3, E0)
pl.title('Estimated raw excitation pattern')
pl.xlabel('t (ms)')

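Trying deconvolution with projection: iterative updates, with the excitation pattern constrained to a time window and to non-negative values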

In [None]:
def proj_E(E, t0=4e-3, t1=8e-3, t=None):
    '''Constrains the excitation pattern E to the open interval (t0, t1).

    Samples outside (t0, t1) are set to 0; E itself is not modified.
    Args:
        E: excitation pattern(s), time along the last axis (1D or (k, n) arrays)
        t0, t1: interval bounds (s)
        t: time vector matching the last axis of E; defaults to the global t2
    Returns:
        a new array, E multiplied by the indicator of (t0, t1)
    '''
    if t is None:
        t = t2
    inside = (t > t0) & (t < t1)
    return E*inside

def deconv_newton(E0, released_sig, ur0=ur0, alpha=0.02, nb_steps=20, eps_ridge=1e-1, verbose=False, t0=4e-3, t1=7e-3):
    '''Projected iterative deconvolution of `released_sig` by the unitary response `ur0`.

    Starting from E0 restricted to (t0, t1), performs `nb_steps` updates
    E <- E - alpha*step, where `step` is returned by `deconv_newton_step`
    (ridge parameter `eps_ridge`). After each update, E is set to 0 outside
    (t0, t1) (proj_E) and its negative values are set to 0.

    Args:
        E0: initial excitation pattern (1D, aligned with the global t2)
        released_sig: CAP signal to deconvolve
        ur0: unitary response (defaults to the global ur0 at definition time)
        verbose: if True, plots the estimate every 5 steps
    Returns:
        the estimated excitation pattern (1D)
    '''
    E = proj_E(E0, t0=t0, t1=t1)

    sig_spectrum = np.fft.rfft(released_sig)
    ur_spectrum = np.fft.rfft(ur0)

    # deconv_newton_step is called with a (1, n) array
    E = E[np.newaxis, ...]

    for step in range(nb_steps):
        E -= alpha*deconv_newton_step(E, ur_spectrum, sig_spectrum, eps_ridge=eps_ridge)

        # projections: time window, then non-negativity
        E = proj_E(E, t0=t0, t1=t1)
        E[E < 0] = 0

        if verbose and step % 5 == 0:
            pl.plot(t2*1e3, E[0], label=f'step {step}')
            pl.xlabel('t (ms)')

    if verbose:
        pl.legend()
    return E[0]


# projected deconvolution of the same signal, initialized with the naive estimate E0
E = deconv_newton(E0=E0, released_sig=masked_sig, verbose=True)


    



Narrow-band analysis

In [None]:

# High-pass maskers with decreasing cut-off frequencies
hp_masker_names = ['1_hp_10000Hz', '2_hp_9000Hz', '3_hp_8000Hz', '4_hp_7000Hz',
                   '5_hp_6000Hz', '6_hp_5000Hz', '7_hp_4000Hz', '8_hp_3200Hz',
                   '9_hp_2400Hz', '10_hp_1800Hz', '11_hp_1500Hz', '12_hp_1200Hz']

hp_sigs = [capData.get_signal_by_name(name) for name in hp_masker_names]
(s1, s2, s3, s4, s5, s6,
 s7, s8, s9, s10, s11, s12) = hp_sigs

hp_sigs_proc = [process_signal2(raw) for raw in hp_sigs]
(s1_proc, s2_proc, s3_proc, s4_proc, s5_proc, s6_proc,
 s7_proc, s8_proc, s9_proc, s10_proc, s11_proc, s12_proc) = hp_sigs_proc

In [None]:
hp_procs = [s1_proc, s2_proc, s3_proc, s4_proc, s5_proc, s6_proc,
            s7_proc, s8_proc, s9_proc, s10_proc, s11_proc, s12_proc]

pl.figure()
for s_proc in hp_procs:
    pl.plot(s_proc)
pl.show()

# band k is the difference between high-pass conditions k and k+1
band_labels = ['9-10kHz', '8-9kHz', '7-8kHz', '6-7kHz', '5-6kHz', '4-5kHz',
               '3.2-4kHz', '2.4-3.2kHz', '1.8-2.4kHz', '1.5-1.8kHz', '1.2-1.5kHz']


def _plot_bands(band_indices, with_lowest=False):
    '''Plots the narrow-band CAPs for the given band indices (optionally the <1.2 kHz residual).'''
    pl.figure(figsize=(10,8))
    for k in band_indices:
        pl.plot(t2*1e3, hp_procs[k]-hp_procs[k+1], label=band_labels[k])
    if with_lowest:
        pl.plot(t2*1e3, hp_procs[-1], label='-1.2kHz')
    pl.xlim([3,12])
    pl.legend(bbox_to_anchor=(1.05, 1), loc='upper left')
    pl.show()


_plot_bands(range(0, 6))
_plot_bands(range(5, 11), with_lowest=True)

In [None]:
# Narrow-band excitation patterns: each band signal (difference between consecutive
# high-pass conditions) is deconvolved with `deconv`, then refined with `deconv_newton`.
# Bands are stacked with a vertical offset of 0.25 per band.
narrow_bands = [(s1_proc-s2_proc, '9-10kHz'),
                (s2_proc-s3_proc, '8-9kHz'),
                (s3_proc-s4_proc, '7-8kHz'),
                (s4_proc-s5_proc, '6-7kHz'),
                (s5_proc-s6_proc, '5-6kHz'),
                (s6_proc-s7_proc, '4-5kHz'),
                (s7_proc-s8_proc, '3.2-4kHz'),
                (s8_proc-s9_proc, '2.4-3.2kHz'),
                (s9_proc-s10_proc, '1.8-2.4kHz'),
                (s10_proc-s11_proc, '1.5-1.8kHz'),
                (s11_proc-s12_proc, '1.2-1.5kHz'),
                (s12_proc, '-1.2kHz')]

# (a stray empty pl.figure() used to be created before this one)
pl.figure(figsize=(10,8))
for i, (sig, label) in enumerate(narrow_bands):
    E = deconv(sig, eps=1e-2)
    E = deconv_newton(E, sig, alpha=0.005, nb_steps=50, eps_ridge=2e-1, t0=4e-3, t1=7e-3)
    pl.plot(t2*1e3, E-0.25*i, label=label)

pl.xlabel('t (ms)')
pl.legend(bbox_to_anchor=(1.05, 1), loc='upper left')
pl.xlim([4, 8])
pl.show()

In [None]:
#NB: some data not clear
t_max=np.array([43,40,44,48,52,55,41,45,54,61,73,94])
t_max_C=np.array([52,60,57,42,49,50,51,55,64,75,84,100])   #mod 0

t_max_R=np.array([48,53,46,49,48,51, 54,59,68,81,95,115]) 
#t_0lat=4e-3-5.8e-3+2e-3
t_0lat=4e-3+2.2e-3-5.8e-3
t_max=t_0lat+t_max*2*1e-5
t_max_C=t_0lat+t_max_C*2*1e-5

t_max_R=t_0lat+t_max_R*2*1e-5
freqs=np.array([9.5,8.5,7.5,6.5,5.5,4.5,3.6,2.8,2.1,1.65, 1.35, 1])
pl.plot(freqs, t_max*1e3, '+', markersize=12, label='C+R')
pl.plot(freqs, t_max_C*1e3, '+', markersize=12, label='C')

pl.plot(freqs, t_max_R*1e3, '+', markersize=12, label='R')
pl.ylabel('Estimated latencies (ms)')
pl.xlabel('freq (kHz)')

pl.legend()
#NB CM begins at 5.8 ms
#peak convol begins at 5.2-3 ms = 2.2ms

Fit the latencies with two power laws (above and below 4 kHz)

In [None]:
# Latencies are fitted with two PowerLawLatencies models: above and below 4 kHz.

#above 4kHz
freqs_pts = np.array([8.5,7.5,6.5,5.5,4.5])*1e3
t_max_pts = t_max[1:len(freqs_pts)+1]
# the hard-coded frequencies must match the bands of t_max[1:6]
assert np.allclose(freqs_pts, freqs[1:len(freqs_pts)+1]*1e3)

lat = PowerLawLatencies(1e6, alpha=1, t0=4e-3, mode='left')
lat.fit_data(t_max_pts, freqs_pts, init_with_new_values=False, bounds=[0.5, 2])

freqs_lin = np.linspace(4, 10)*1e3
pl.plot(freqs_lin, lat(freqs_lin)*1e3, label='fit > 4 kHz')
pl.plot(freqs_pts, t_max_pts*1e3, '+', markeredgewidth=3, markersize=10)

lat_above4k = lat

#below 4kHz
freqs_pts = freqs[6:]*1e3
t_max_pts = t_max[6:]

lat = PowerLawLatencies()
lat.fit_data(t_max_pts, freqs_pts)
lat_below4k = lat  # `lat` (this fit) is what the model below uses

freqs_lin = np.linspace(0.5, 5)*1e3
pl.plot(freqs_lin, lat(freqs_lin)*1e3, label='fit < 4 kHz')
pl.plot(freqs_pts, t_max_pts*1e3, '+', markeredgewidth=3, markersize=10)

pl.xlabel('freq (Hz)')
pl.ylabel('latency (ms)')
pl.legend()

In [None]:
# Display the last fitted PowerLawLatencies (`lat`, the below-4 kHz fit)
lat

### 2.2kHz

First estimation of the I/O curve

In [None]:

masker_list = ['1_notch2200_bw1500_60dB',
               '2_notch2200_bw1500_55dB',
               '3_notch2200_bw1500_50dB',
               '4_notch2200_bw1500_45dB',
               '5_notch2200_bw1500_40dB',
               '6_notch2200_bw1500_35dB',
               '7_notch2200_bw1500_32dB',
               '8_notch2200_bw1500_29dB',
               '9_notch2200_bw1500_26dB',
               '10_notch2200_bw1500_23dB']
# notch attenuations (dB), parsed from the masker names so that they cannot
# get out of sync with masker_list
attns = -np.array([int(re.search(r'_(\d+)dB$', masker).group(1)) for masker in masker_list])

# broadband-noise reference, computed once (loop invariant)
# NOTE(review): uses the 2nd repetition only, not broadband_avg/broadband_proc
# as elsewhere — confirm this is intended
broadband_sig_trunc = process_signal(broadband2)

cap = []
pl.figure()
for masker in masker_list:
    sig = process_signal(capData.get_signal_by_name(masker))
    diff = sig - broadband_sig_trunc
    # peak-to-peak amplitude of the CAP difference
    cap.append(np.max(diff) - np.min(diff))
    pl.plot(t2, diff, label=masker)

pl.xlabel('t (s)')
pl.legend(bbox_to_anchor=(1.05, 1), loc='upper left')
pl.show()

pl.figure(figsize=(8, 6))
pl.plot(attns, cap, '+', label='max-min')
pl.legend()
pl.xlabel('Notch attenuation')
pl.ylabel('Amplitude')


In [None]:
sigm = SigmoidIOFunc(0, 0)
# masking amount: 0 for the largest CAP amplitude, growing as the CAP shrinks
maskamount = 1 - np.asarray(cap)/np.amax(cap)

I_pts = I0 + attns  # notch attenuations -> power spectral density (dB)

# fit constrained at Iref = I0-20 dB (drawn in purple below; presumably
# enforces 100% masking there)
sigm.fit_data(I_pts, maskamount, constrained_at_Iref=True, Iref=I0-20, method='dogbox')

I = np.linspace(-15, 50)
pl.plot(I, sigm(torch.tensor(I)), label='fit sigm')

pl.suptitle('Amount of masking in response to broadband noise')
pl.title(' (as estimated with the notch method)', fontsize=10)
pl.xlabel('Power spectral density (dB)')

marker_style = dict(markersize=10, markeredgewidth=3)
pl.plot(I_pts, maskamount, '+', **marker_style)
pl.plot(I0-20, 1, '+', color='purple', **marker_style)

pl.ylabel('masking (ref max: no notch)')

#pl.xlim([0, 40])
#pl.savefig('amount_masking_5kHz.svg')

Setting up the model

In [None]:
# all notch maskers centred at 2200 Hz with bandwidth 1500 Hz
ntch_regex = '.*notch2200_bw1500'
ntch_maskerNames, ntch_maskingConds, ntch_signals = capData.get_batch_re(ntch_regex)
ntch_maskingConds.set_amp0_dB(I0)  # reference level of the masking conditions

In [None]:
dt2 = t2[1] - t2[0]
gauss_sigma = (1.5e-4)/dt2  # 0.15 ms of smoothing, expressed in samples
ntch_signals_proc = process_signal2(ntch_signals, gauss_sigma=gauss_sigma)

In [None]:
# sanity check: processed CAP differences for the notch maskers
for name_, trace in zip(ntch_maskerNames, ntch_signals_proc):
    pl.plot(t2, trace, label=name_)

pl.legend(bbox_to_anchor=(1.05, 1), loc='upper left')

In [None]:
# estimation of the unitary response (ur) from a single notch condition

sig = capData.get_signal_by_name('4_notch2200_bw1500_45dB')
sig2 = process_signal(sig)
cap_diff = sig2 - broadband_proc
t_plot_ms = (t2-3e-3)*1e3

pl.plot(t_plot_ms, cap_diff)

#pl.plot(t2*1e3, sig2bis-broadband_proc)

t_shift = 5.8e-3-3e-3 #excitation coincides with CM

# circular shift to the left by t_shift
ur0 = np.roll(cap_diff, -int(t_shift*48828))
pl.plot(t_plot_ms, ur0)
#pl.xlim([4,6])

In [None]:
# Initial raw excitation: a Tukey bump of m samples (128 samples ~ 2.6 ms at 48828 Hz),
# starting 0.5 ms before t_shift
m = 128
E0_temp = sg.windows.tukey(m, alpha=0.5)
ind_begin = int((t_shift-5e-4)*48828)
ind_end = ind_begin + m

E0 = np.zeros_like(t2)
E0[ind_begin:ind_end] = E0_temp

pl.plot(t2*1e3, E0)
pl.title('Init raw excitation')
pl.xlabel('t (ms)')
pl.ylabel('Amp')

In [None]:
E = ExcitationPatterns(t2, E0)  # no non-maskable part
# latency model: the `lat` fit, shifted with t0 = start of CM - 1 ms
lat_model = PowerLawLatencies.shift(lat, 5.8e-3-1e-3)

# constant 10-dB bandwidth of 1000 Hz (NOTE(review): the '6000' in the names
# looks like a leftover from a 6 kHz analysis; this section uses 2.2 kHz notches)
BW10_6000 = 1000
BW10_6000Func = constant_BW10(BW10_6000, requires_grad=False)

E.set_masking_model(lat_model, BW10_6000Func, ntch_maskingConds, sigm)

In [None]:
fig_EP = pl.figure(figsize=(10,20))
plotExcitationPatterns(E, plot_raw_excitation=True)  # optional: ylim_top=1
pl.show()

In [None]:
# Estimation of the unitary response (ur) from all notch conditions:
# iterative deconvolution of the CAPs by the excitation patterns, with two
# projections per step: ur forced to 0 outside [3.2, 7.5] ms, and a single ur
# shared by all conditions (weighted average).

maskAmounts, excs = E.get_tensors()
# numpy copy of the excitation patterns, used consistently for FFTs and weights
excs_np = excs.clone().detach().numpy()

nb_steps = 20
alpha = np.linspace(0.5, 0.05, nb_steps)  # decreasing step sizes

EPs_fft = np.fft.rfft(excs_np, axis=1)
CAPs_fft = np.fft.rfft(ntch_signals_proc, axis=1)

u1_mat = np.zeros_like(ntch_signals_proc)
# time samples where ur is forced to 0
filter_mat = (t2 > 7.5e-3) | (t2 < 3.2e-3)
filter_mat = np.tile(filter_mat, (ntch_maskingConds.n_conditions, 1))

# weight of each condition in the average: L2 norm of its excitation pattern
weights = np.sqrt(np.sum(excs_np**2, axis=1))

for i, step_size in enumerate(alpha, start=1):
    du = deconv_newton_step(u1_mat, EPs_fft, CAPs_fft, eps_ridge=0)   #TODO proj_fft
    u1_mat -= step_size*du
    # projection 1: zero outside the allowed time interval
    u1_mat[filter_mat] = 0
    # projection 2 (HACK waiting for proj_fft): same ur for every condition
    u1_mat_mean = np.average(u1_mat, axis=0, weights=weights)[None, :]
    u1_mat = np.repeat(u1_mat_mean, ntch_maskingConds.n_conditions, axis=0)

pl.figure()
pl.title(f'Step {nb_steps} (deconv + proj)')
pl.plot(t2, u1_mat[0], label='u0 (estimated)')
pl.xlabel('t (s)')
pl.legend()
pl.show()

Re-estimation of the I/O curve

In [None]:
# Per-condition excitation patterns, estimated with the shared ur (u1_mat[0]),
# summarized by their maximum and their L2 norm.
exc_max = []
exc_rms = []  # NB: L2 norm (not divided by the length), despite the name
attns_ = []

amps = ntch_maskingConds.amp_list[1]
pl.figure()
for i, (name, sig) in enumerate(zip(ntch_maskerNames, ntch_signals_proc)):
    E_i = deconv_newton(np.zeros_like(sig), sig, ur0=u1_mat[0], alpha=1, nb_steps=20,
                        t0=5.7e-3, t1=6.4e-3, eps_ridge=0.1)  #double peak?

    attns_.append(20*np.log10(amps[i]))  # amplitude -> dB
    exc_max.append(np.amax(E_i))
    exc_rms.append(np.sqrt(np.sum(E_i**2)))

    pl.plot(t2*1e3, E_i, label=name, color=f'C{i}')

pl.xlabel('t (ms)')
pl.legend(bbox_to_anchor=(1.05, 1), loc='upper left')

pl.figure()
pl.plot(attns_, exc_max, '+', label='max')
pl.plot(attns_, np.array(exc_rms)*0.2, '+', label='L2 norm x0.2')
pl.xlabel('Power spectral density (dB)')
pl.legend()

In [None]:
# Re-estimation of the I/O curve from the excitation-pattern norms

I_pts = np.array(attns_)  # I0 already included in the masking conditions (set_amp0_dB)

sigm = SigmoidIOFunc(0, 0)
exc_rms = np.array(exc_rms)
# masking amount: 0 for the largest excitation, growing as it shrinks
maskamount = 1 - exc_rms/np.amax(exc_rms)

# fit constrained at Iref = I0-20 dB (drawn in purple below; presumably
# enforces 100% masking there)
sigm.fit_data(I_pts, maskamount, constrained_at_Iref=True, Iref=I0-20, method='dogbox')

I = np.linspace(-15, 50)
pl.plot(I, sigm(torch.tensor(I)), label='fit sigm')

pl.suptitle('Amount of masking in response to broadband noise')
pl.title(' (as estimated with the notch method)', fontsize=10)
pl.xlabel('Power spectral density (dB)')

marker_style = dict(markersize=10, markeredgewidth=3)
pl.plot(I_pts, maskamount, '+', **marker_style)
pl.plot(I0-20, 1, '+', color='purple', **marker_style)

pl.ylabel('masking (ref max: no notch)')


In [None]:
# Excitation pattern for a single, different masker, using the shared ur
masker_name = '4_notch4800_bw1300_20dB'  # alternative: '7_notch5300_bw800_20dB'
sig = process_signal2(capData.get_signal_by_name(masker_name))

E_i = deconv_newton(np.zeros_like(sig), sig, ur0=u1_mat[0], alpha=0.1, nb_steps=20,
                    t0=5e-3, t1=6.2e-3, eps_ridge=0.1)  #double peak?

pl.figure()
# fixed colour: previously f'C{i}' depended on a leftover loop index from another cell
pl.plot(t2*1e3, E_i, label=masker_name, color='C0')
pl.xlabel('t (ms)')
pl.legend(bbox_to_anchor=(1.05, 1), loc='upper left')


Testing the model

In [None]:

def _convolve_and_align(exc_np, u, t, shift):
    '''Convolves an excitation pattern with u and keeps len(t) samples starting `shift` s after t[0].'''
    CAP = np.convolve(exc_np, u, mode='full')
    ind_time = np.sum(t < (t[0]+shift))
    # clip the start index so that len(t) samples are always available
    ind_time = min(ind_time, len(CAP)-len(t))
    return CAP[ind_time:ind_time+len(t)]


def plotSimulatedCAPs2(E, u=None, CAParray=None, axlist=None, shift=0, max_plots=8, ylim=None, reg_ex=None, title='Simulated CAPs', **kwargs):
    '''Plots simulated (or given) CAPs for each masking condition of E.

    Args:
        E: ExcitationPatterns object
        u: unitary response (numpy array)
        CAParray: array of CAP signals (if the convolution is done outside the function),
            must be of size (nb_conditions, len(E.t)). Either CAParray or u must be given.
        axlist: list of axes for the plots, indexed by plot position (e.g. the list
            returned by a previous call). If None, creates the axes.
        shift: time shift (s) for the convolution
        max_plots: maximum number of conditions plotted (masked case)
        ylim: interval to pass to matplotlib (opt.)
        reg_ex: regular expression to filter masker names (opt.)
        title: figure title (masked case)
        **kwargs: passed to ax.plot for the CAP curves (masked case)
    Returns:
        the list of axes used, in plot order
    '''
    assert not(u is None) or not(CAParray is None), 'either CAParray or u must be given'
    axlist2 = []
    if E.masked:
        excs = E.get_tensor()
        maskingConditions = E.maskingConditions
        pl.suptitle(title)
        nb_plots = min(maskingConditions.n_conditions, max_plots)
        t = E.t.numpy()
        ind = 0  # position of the next plot (differs from i when reg_ex filters conditions)
        for i, exc in zip(range(maskingConditions.n_conditions), excs):
            if ind == nb_plots:
                break
            if reg_ex is not None and not re.match(reg_ex, maskingConditions.names[i]):
                continue
            # index provided axes by plot position, consistently with the returned list
            ax = pl.subplot((nb_plots+1)//2, 2, ind+1) if axlist is None else axlist[ind]
            ax.set_title(maskingConditions.names[i], fontsize=10)

            if CAParray is not None:
                ax.plot(E.t*1e3, CAParray[i]*1e3, **kwargs)
                ax.grid(False)
            else:
                CAP = _convolve_and_align(exc.detach().numpy(), u, t, shift)
                ax.plot(E.t*1e3, CAP*1e3, **kwargs)
            if ylim is not None:
                ax.set_ylim(ylim)
            ax.set_xlim([5.8,9.5])
            ax.set_xlabel('Time (ms)')

            ax.set_ylabel('Amplitude difference (µV)')
            axlist2.append(ax)
            ind += 1
        pl.tight_layout()

    else:
        ax = pl.gca() if axlist is None else axlist[0]
        ax.plot(E.t*1e3, E.E0_nonmaskable, label='non maskable part', linestyle='--')
        p = ax.plot(E.t*1e3, E.E0_maskable, label='maskable part', linestyle='--', linewidth=1.5)
        E0 = E.E0_nonmaskable+E.E0_maskable
        ax2 = ax.twinx() if axlist is None else axlist[1]
        CAP = _convolve_and_align(E0.detach().numpy(), u, E.t.numpy(), shift)
        ax2.plot(E.t*1e3, CAP, color=p[0].get_color())
        ax2.grid(False)
        ax.set_xlabel('Time (ms)')
        ax.legend()
        axlist2.append(ax)
        axlist2.append(ax2)
    return axlist2

In [None]:
# model (C1) vs. recorded (C0) CAPs for the notch conditions
u1 = u1_mat[0]
pl.figure(figsize=(12,20))
ax_list = plotSimulatedCAPs2(E, u1, ylim=[-15, 15], max_plots=10, color='C1')
plotSimulatedCAPs2(E, CAParray=ntch_signals_proc, axlist=ax_list, max_plots=10, color='C0')
#pl.savefig('CAPs_notch_5khz.svg')  # must come before pl.show() to save the figure
pl.show()  # was pl.plot(), which draws nothing

In [None]:
# notch maskers with other bandwidths, used below to test the BW10 of the model
vbw_regex = '(1_notch2200_500_20dB|2_notch2300_700_20dB)'
vbw_maskerNames, vbw_maskingConds, vbw_signals = capData.get_batch_re(vbw_regex)
vbw_signals_proc = process_signal2(vbw_signals, gauss_sigma=gauss_sigma)
vbw_maskingConds.set_amp0_dB(I0)

In [None]:
E2 = ExcitationPatterns(t2, E0)  # same raw excitation as E, no non-maskable part
BW10_test = 1500
BW10_6000TestFunc = constant_BW10(BW10_test, requires_grad=False)

E2.set_masking_model(lat_model, BW10_6000TestFunc, vbw_maskingConds, sigm)

In [None]:
# model vs. recorded CAPs for the maskers with other bandwidths
u1 = u1_mat[0]
pl.figure(figsize=(10,4))
ax_list = plotSimulatedCAPs2(E2, u1, ylim=[-15, 15])
plotSimulatedCAPs2(E2, CAParray=vbw_signals_proc, axlist=ax_list)
pl.show()  # was pl.plot(), which draws nothing

In [None]:
# RMS error between model and recorded CAPs as a function of the model BW10.
# NB: E2's masking model is left set to the last bandwidth of the sweep.
bw_arr = np.linspace(500, 2500)
sigs_ref = vbw_signals_proc
errs = []
for bw in bw_arr:
    BW10_6000TestFunc = constant_BW10(bw, requires_grad=False)
    E2.set_masking_model(lat_model, BW10_6000TestFunc, vbw_maskingConds, sigm)
    excs = E2.get_tensor()
    n_conds = E2.maskingConditions.n_conditions
    n_t = len(E2.t)
    sq_err = 0
    for i, exc in zip(range(n_conds), excs):
        # no time shift (same alignment as plotSimulatedCAPs2 with shift=0).
        # (the former `t=E.t.numpy()` read the other model E and overwrote the
        # global raw time vector t; it was unused and has been removed)
        CAP = np.convolve(exc.detach().numpy(), u1, mode='full')[:n_t]
        sq_err += np.mean((CAP - sigs_ref[i])**2)
    errs.append(np.sqrt(sq_err/n_conds))

pl.figure(figsize=(6, 3.5))
pl.plot(bw_arr, np.array(errs)*1e3)
pl.xlabel('BW 10dB model (Hz)')
pl.ylabel('RMS error (µV)')
#pl.savefig('RMS_err_5kHz.svg')