In [None]:
import torch.distributed as dist
import numpy as np
import matplotlib as mpl
import matplotlib.pyplot as pl


In [None]:
#Notebook parameters (can be overridden from the command line with papermill, see the example further down).

#Experiment to analyse. Values handled by the import dispatch of the setup cell:
#'1-22', '4-23', '9-10' (30 dB atten M), '9-10_bis' (level below, 38 dB atten M), '12-13' (Q336)
expe_name='1-22'  #'1-22' #'4-23' #'9-10' (30 dB atten M)  #'9-10_bis' for level below (38 dB atten M)  
#12-13 (Q336)

#Centre frequency (Hz) whose notched-noise maskers are fitted in the 'CF specific' part
#CF=8000
CF=2200


#passed to config_mode.init() in the setup cell
mode_CAP='C+R' #'R' #'C+R'

#Distributed options: when True, the corresponding quantity is estimated by the main node (rank 0)
#of a torch.distributed group and this notebook acts as a worker.
E0_distributed=False #if True, E0 will be estimated from the main node of a distributed scheme (external process)
#(f_min, f_max and m are then read from E0_params.json)
Q10_distributed=False #if True, Q10 will be computed and estimated from the main node of a distributed scheme
#(the RBF network is then created from RBF_params.json)
I0_distributed=False #I0 for weibull cdf (RBF network created from RBF_I0_params.json)
plus_lambda=False #if I0_distributed is True and plus_lambda is true, 
#the output for the RBF network for I0 corresponds to I0 + lambda (scale)
load_wbcdf=False #if True, the Weibull CDF params are loaded from results_folder0 instead of running the first I/O optimisation

#if I0_distributed or load_wbcdf True , loads wb cdf params from results_folder0
results_folder0=f'./results/fit{expe_name}-distrib/'  

#torch.distributed settings of this worker (only used when one of the *_distributed flags is True)
backend=dist.Backend('GLOO')
n_workers=2 #world size passed to init_process_group
rank=1 #rank of this process (rank 0 is the main node)

filter_model='gammatone_4'  #'gaussian'

load_json_optim_params=True #if True load optim params from optim_params_{expe_name}.json if it exists, else optim_params.json
load_json_init_params=True #if True, will load ./init_params/{expe_name}/{CF}_init_params.json if exists

write_results=False #write ur, I/O func, Q10, lat params in files
#to run (distributed): papermill -p E0_distributed True -p Q10_distributed True -p n_workers 5 -p rank 1 -p CF 4000 Fit\ data.ipynb fitdata4000.ipynb

sig_exc_plot=0.6 #gauss sigma for excitation patterns in time (in number of bins). for plots only  #0 if no filtering
save_figs=False #if True, the optimisation figures are saved as svg files in the working directory

#Optional suffix for the output folder. If left blank, a default folder is chosen in the setup cell.
results_name=''  #if not blank, will save all the results in a folder with results_name (original note: "also loads param from this folder, like optim params" — NOTE(review): no such loading is visible in this notebook, confirm)
results_folder=None
if results_name != '':
    results_folder=f'./results/fit{expe_name}-{results_name}/'


In [None]:
import config_mode

config_mode.init(mode_CAP)

#Pull in the experiment-specific data, helpers and masker lists.
#The star imports are what define most of the names used below without a visible definition
#(presumably capData, process_signal, t2, I0, lat, the plot_* helpers, ... — see fit_data_*_common.py).
if expe_name == '4-23':
    from fit_data_4_23_common import *
    from fit_data_4_23_list_maskers import *
elif expe_name=='1-22':
    from fit_data_1_22_common import *
    from fit_data_1_22_list_maskers import *
elif expe_name=='9-10':
    from fit_data_9_10_common import *
    from fit_data_9_10_list_maskers import *
elif expe_name=='9-10_bis':
    from fit_data_9_10_common_lower_level import *
    from fit_data_9_10_list_maskers_lower_level import *
elif expe_name=='12-13':
    from fit_data_12_13_common import *
    from fit_data_12_13_list_maskers import *
else:
    #fail early: otherwise nothing is imported and the notebook breaks much later with an obscure NameError
    raise ValueError(f"Unknown expe_name {expe_name!r}; expected one of "
                     "'1-22', '4-23', '9-10', '9-10_bis', '12-13'")

#pl.style.use('fivethirtyeight')

#The bundled seaborn styles were renamed 'seaborn-v0_8-<name>' in matplotlib 3.6 and the old
#names removed in 3.8; use the new name when available and fall back to the legacy one otherwise.
pl.style.use('seaborn-v0_8-deep' if 'seaborn-v0_8-deep' in pl.style.available else 'seaborn-deep')


#mpl.rc('figure', figsize=(10,8))
#mpl.rcParams['axes.facecolor']='white'  
#mpl.rcParams['figure.facecolor'] = '1'

from scipy.ndimage import gaussian_filter1d
from optim import *

from rbf import RBFNet

import os
import json

import re
import datetime

#Default output folder when no explicit results_name was given in the parameters cell:
#  - not Q10-distributed           -> ./results/fit{expe_name}/
#  - Q10-distributed               -> ./results/fit{expe_name}-distrib/
#  - Q10- and I0-distributed       -> ./results/fit{expe_name}-distrib/I0_distrib/
if results_folder is None:
    if not Q10_distributed:
        results_folder=f'./results/fit{expe_name}/'
    elif I0_distributed:
        results_folder=f'./results/fit{expe_name}-distrib/I0_distrib/'
    else:
        results_folder=f'./results/fit{expe_name}-distrib/'

if write_results:
    print(f'writting results in {results_folder}')

In [None]:
from ipywidgets import interact, interactive, fixed, interact_manual
import ipywidgets as widgets

### Common to all CFs

In [None]:
#Overview plots common to all CFs. The helpers come from the star-imported
#experiment module (see fit_data_1_22_common.py).
plot_main_CAPs()
plot_CAP_with_window()

#this comparison is only plotted for experiment 1-22
if expe_name == '1-22':
    plot_CAP_w_wo_filter()

#NB: the plots below depend on the choice of the signal used for the estimation of ur
plot_raw_excitation_deconv()

**Narrowband analysis**

Below are: 
 1. plot of the CAP masking releases for the high-passed noise maskers
 2. plot of the CAP masking releases by bands ($\Delta CAP$ is computed as the CAP difference between two subsequent cut-off frequencies). First part (frequencies above 4 kHz)
 3. Same as 2, second part (frequencies below 4kHz)

In [None]:
#Narrowband analysis figures (helper not defined in this notebook; it comes from the star-imported experiment module)
plot_figures_narrowband_analysis()

The contributions by bands can be deconvolved using a rough approximation of the unitary response (in the example below, ur is the response to a notched-noise masker at 4 kHz). The deconvolution uses Newton's optimization method, with 'ridge'-like penalties to ensure stability. The excitation patterns are constrained to be non-negative.

In [None]:

#Deconvolution of the contributions by bands (helper from the star-imported experiment module).
plot_figures_narrowband_analysis_deconv()
## Reminder of what the function above does (kept as a comment, not executed here):
# for sig, label in list_sig:
#     E=deconv(sig, eps=1e-2)
#     E=deconv_newton(E, sig, alpha=0.005, nb_steps=50, eps_ridge=2e-1, t0=4.3e-3, t1=7e-3)
#     pl.plot(t2*1e3, E-0.25*i, label=label)
#     i+=1

The delays for the peaks are retrieved from the last plot. The best fit with a power-law is searched (dog leg method).  
$\label{eq:latencies} CF(\tau) = B (\tau-t_0)_+^\alpha$, parameters to fit: $t_0, B, \alpha$.

**Note: only a few points were used for this exp for estimation of latencies**

In [None]:

#Peak delays retrieved from the deconvolved patterns, then the power-law fit
#(both helpers come from the star-imported experiment module)
plot_estimated_latencies_deconv()
plot_latencies_fit()

Parameters after fitting:

In [None]:
#display the fitted latency model (`lat` is not defined in this notebook; it comes from the star imports)
lat

### CF specific

**First estimation of I/O masking curve**

In [None]:
#Create the output folder if needed. exist_ok=True avoids the check-then-create race
#when several workers (e.g. parallel papermill runs sharing a folder) start at the same time.
if write_results:
    os.makedirs(results_folder, exist_ok=True)

Plot of the masking releases for the notched noise maskers with varying atten. for the notch. The amount of masking is evaluated as the reduction of the CAP peak-to-peak amplitude.

In [None]:

#Masking release for the notched-noise maskers with varying notch attenuation.
#For each masker matching ntch_regexps[CF]:
#  - cap: peak-to-peak amplitude of (masker response - broadband-noise response)
#  - noise_rms: RMS of the smoothed difference signal over its first ind0 samples, averaged over maskers
cap=[]
rms=[]
masker_list=ntch_masker_lists[CF]  #, 'broadband_noise' 
masker_list=[st.replace('-', '_').replace('.json', '') for st in masker_list]


reg_exp=ntch_regexps[CF]

#REF broadband: does not depend on the masker, so it is processed once outside the loop
broadband_sig_trunc=process_signal(broadband2)

noise_rms=0

for masker in masker_list:
    #only the maskers matching the regexp are analysed (the signal is fetched after this filter)
    if not(re.match(reg_exp, masker)):
        continue
    sig=process_signal(capData.get_signal_by_name(masker))
    diff_sig=sig-broadband_sig_trunc

    cap_amp=np.max(diff_sig)-np.min(diff_sig)
    #HACK: the amplitude is counted as negative for the '17dB' condition
    if '17dB' in masker:
        cap_amp*=-1
    cap.append(cap_amp)

    diff_sig_proc=gaussian_filter1d(diff_sig, gauss_sigma) #noise amp computed on filtered version

    noise_rms+=np.mean(diff_sig_proc[:ind0]**2)
    pl.plot(t2*1e3, diff_sig_proc*1e3, label=masker)

#average over the maskers actually processed: skipped maskers must not dilute the estimate
#(max(..., 1) keeps the result at 0 instead of dividing by zero when nothing matched)
n_maskers_used=len(cap)
noise_rms=np.sqrt(noise_rms/max(n_maskers_used, 1))
print(f'noise rms: {noise_rms*1e3:.3f} μV')

pl.xlabel('t (ms)')
pl.ylabel('Amplitude difference (μV)')

pl.legend(bbox_to_anchor=(1.05, 1), loc='upper left')
pl.show()

pl.figure(figsize=(8, 6))

#notch attenuations (dB), one value per processed masker, in the same order as `cap`
if expe_name =='1-22':
    attns=-np.array([55, 50, 45,40,35,32,29,26,23])   # 20])
#elif expe_name == '4-23' or expe_name=='9-10' or expe_name=='9-10_bis':
else:
    attns=-attns_arrays[CF]
pl.plot(attns+20, cap, '+', label='max-min')  #20  REF

pl.legend()
pl.xlabel('Notch attenuation')
pl.ylabel('Amplitude difference')


In [None]:
#First guess of the I/O masking curve, fitted on the CAP amplitudes (`cap`, `attns`) of the previous cell.
sigm=SigmoidIOFunc(0, 0)
#maskamount=1-(  (cap-np.amin(cap)) /np.amax(cap-np.amin(cap)) )
maskamount=1-(cap/np.amax(cap)) #1 - amplitude normalised by the largest one (0 for the condition with the largest masking release)

I_pts=I0+attns #abscissa of the data points; NOTE(review): presumably levels in dB relative to I0 — confirm
#sigm.fit_data(I_pts, maskamount, set_mmax=True)

#HACK enforce masking=100% at attn20
#sigm.mmax.data*=1/sigm(I0-20)

#both fits use the same data and the same constraint at Iref=I0-20
sigm.fit_data(I_pts, maskamount, constrained_at_Iref=True, Iref=I0-20)

wb_cdf=WeibullCDF_IOFunc()

wb_cdf.fit_data(I_pts, maskamount, constrained_at_Iref=True, Iref=I0-20)

#save the data points and both first estimates
if write_results:
    np.savez(f'{results_folder}/maskamountCAP_{CF}.npz', I_pts=I_pts, maskamount=maskamount)
    sigm.write_to_npz(f'{results_folder}/sigmIO_1st_estim_{CF}.npz')
    wb_cdf.write_to_npz(f'{results_folder}/wbcfdIO_1st_estim_{CF}.npz')

In [None]:
#Compare the two first-guess I/O fits with the masking amounts derived from the CAP amplitudes.
I=np.linspace(-30, 25)

fig=pl.figure()
ax=fig.add_axes([0,0,1,1])

#fitted curves, in % of masking
ax.plot(I, sigm(torch.tensor(I))*100, label='fit sigmoid')
ax.plot(I, wb_cdf(torch.tensor(I))*100, label='fit Weibull CDF')

#plot after optim
#pl.plot(I, wb_cdf2(torch.tensor(I)).clone().detach().numpy()*100, label='   (after optim.)', color='C1', linestyle='--')

#data points, then the point at I0-20 where 100% masking is imposed
ax.plot(I_pts, maskamount*100, '+', markersize=10, markeredgewidth=3, label='based on ΔCAP \namplitude')
ax.plot(I0-20, 100, '+', markersize=10, markeredgewidth=3, color='purple')

ax.set_xlabel('Power spectral density (dB)')
ax.set_ylabel('Masking (%)')

#pl.xlim([-25, 17])
ax.set_ylim([0, 130])

#visible black frame around the axes
for spine in ax.spines.values():
    spine.set_visible(True)
    spine.set_linewidth(1.3)
    spine.set_edgecolor('black')

ax.grid(which='minor')

ax.legend()
#pl.savefig('IO_func_fit.svg')

Setting model

In [None]:

#reg_exp=ntch_regexps[CF]  #previous method

#file names of the notched-noise maskers (varying notch attenuation) for this CF
fln_list=ntch_masker_lists[CF]
def get_regexp(fln_list):
    """Build a regular expression matching any of the maskers listed in `fln_list`.

    Each file name is turned into a masker name by replacing '-' with '_' and
    removing '.json'; the names are then combined as alternatives, each in its
    own group: '(name1)|(name2)|...'. Names are inserted as-is (not escaped).
    An empty list gives '()'.
    """
    names=(fln.replace('-', '_').replace('.json', '') for fln in fln_list)
    alternatives=')|('.join(names)
    return '('+alternatives+')'
    
#regexp matching the maskers of fln_list
reg_exp=get_regexp(fln_list)

#names, masking conditions and signals of the matching maskers
ntch_maskerNames, ntch_maskingConds, ntch_signals =capData.get_batch_re(reg_exp)
ntch_maskingConds.set_amp0_dB(I0) #I0 comes from the star-imported experiment module

In [None]:
#HACK pad maskers >12e3 to avoid issues with latencies (equivalent to taking the difference
#  excitations of maskers - excitation 'broadband noise')
#NOTE(review): the comment above says 12e3 but the threshold passed below is f_thr=11000 — confirm which is intended
ntch_maskingConds.pad_maskers(f_thr=11000, f_max=1e5)
ntch_maskingConds.pad_maskers2() #same thing for low freqs

In [None]:
#gauss_sigma=(1e-4)/(t2[1]-t2[0])  #gauss_sigma defined in common.py
#processed signals used as targets for the fit of the I/O function
ntch_signals_proc=process_signal2(ntch_signals, gauss_sigma=gauss_sigma)

In [None]:
# test: visual check of the processed signals, one curve per masker

for maskerName, sig in zip(ntch_maskerNames, ntch_signals_proc):
    pl.plot(t2, sig, label=maskerName)

pl.legend(bbox_to_anchor=(1.05, 1), loc='upper left')

In [None]:
# estimation ur
# Rough unitary response ur0: processed response to one fixed masker, shifted in time.
# In the visible part of the notebook ur0 is only plotted here (its other use is commented out).

sig=capData.get_signal_by_name('5_notch8000_bw2300_40dB')


#sig2=process_signal(sig)
#pl.plot((t2-3e-3)*1e3, sig2-broadband_proc)
#pl.plot(t2*1e3, sig2bis-broadband_proc)



sig2=process_signal2(sig, gauss_sigma=gauss_sigma)
pl.plot((t2-3e-3)*1e3, sig2)



t_shift=5.7e-3-3e-3 #excitation coincides with CM

#ur0=sig2-broadband_proc
ur0=sig2
#circular shift by t_shift
#NOTE(review): 48828 looks like the sampling rate in Hz (hard-coded here and in the E0 init cell) — confirm against t2
ur0=np.roll(ur0,  -int(t_shift*48828) )
pl.plot((t2-3e-3)*1e3, ur0)
#pl.xlim([4,6])
#pl.xlim([4,6])

In [None]:
#latencies model
#Starts from the power-law fit `lat` (alternative kept for reference: lat_above4k) and shifts it
#so that t0 is 1 ms before the reference event of the experiment.

if expe_name == '12-13':
    lat_ref_time=4.8e-3   #ref click -> t0: start click-1ms
else:
    lat_ref_time=6e-3     #ref CM -> t0: start CM-1ms
lat_shifted=PowerLawLatencies.shift(lat, lat_ref_time-1e-3)
lat_shifted.name='true latencies'

#with bincount the shifted latencies are used directly
use_bincount=True
if not use_bincount:
    #HACK as latencies are very small (sampling issues), manual dilatation
    #(not required with bincount)
    #lat_model=PowerLawLatencies.fromPts(0.0057, 9500, 0.0062, 6000)
    lat_model=PowerLawLatencies.fromPts(0.0056, 10000, 0.007, 800, name= 'dilatated (hack)')
else:
    lat_model=lat_shifted

In [None]:
#display the latency model selected in the previous cell
lat_model

In [None]:
#test single lat model
#Frequency range [f_min, f_max] (Hz) over which the raw excitation E0 is defined.
singleLat=False
if singleLat or use_bincount:
    #default range, chosen from CF
    if CF>6500:
        f_min, f_max=4000, 12000
    elif CF>=5500:
        f_min, f_max=2500, 9000
    elif CF>=4500:
        f_min, f_max=2200, 8000
    else:
        f_min, f_max=600, 7500

    #in the distributed scheme the range is imposed by E0_params.json
    if E0_distributed:
        with open('E0_params.json') as f:
            params = json.load(f)
        f_min=float(params['f_min'])
        f_max=float(params['f_max'])

    if singleLat:
        lat_model = SingleLatency(6e-3, f_min=f_min, f_max=f_max)


In [None]:
#Initial raw excitation E0.
if singleLat or use_bincount:
    #E0 is defined on m frequency bins between f_min and f_max, initialised at 1/2
    
    if E0_distributed:
        #number of bins imposed by E0_params.json
        with open('E0_params.json') as f:
            params = json.load(f)
            m=int(params['m'])
    else:
        m=400
    E0=1/2*np.ones((m,))
    
    pl.plot(np.linspace(f_min*1e-3, f_max*1e-3, m), E0)
    pl.xlabel('Frequency (kHz)')
    pl.ylabel('Init raw excitation')
    
    #ind for CF (can be useful later)
    #(used at the first iteration of the main loop to set E0_maskable_amp)
    ind_CF=int((CF-f_min)/(f_max-f_min)*m)
else:
    #E0 is defined in time: a Tukey window of m samples placed in a vector with the shape of t2
    m=72
    E0_temp=sg.windows.tukey(m, alpha=0.5) 
    E0=np.zeros_like(t2)
    #NOTE(review): 48828 looks like the sampling rate in Hz (hard-coded) — confirm against t2
    ind_begin=int((t_shift-1e-3)*48828)
    ind_end=int((t_shift-1e-3)*48828)+m
    E0[ind_begin:ind_end]=E0_temp

    pl.plot(t2*1e3, E0)
    pl.title('Init raw excitation')
    pl.xlabel('t (ms)')
    pl.ylabel('Amp')

In [None]:
#plot of the latency model in use (helper not defined in this notebook; it comes from the star imports)
plotLatencies(lat_model)

In [None]:
#Excitation-pattern model used for the estimation of ur (first guess, nothing is optimised here).
#NOTE(review): f_min/f_max are only defined when singleLat or use_bincount is True
E=ExcitationPatterns(t2, E0, use_bincount=use_bincount, bincount_fmin=f_min, bincount_fmax=f_max)  #no non-maskable part
#E=ExcitationPatterns(t2, E_temp.E0_maskable)

#NB: first model for estimation of ur, cte bandwith, fixed

#first guess Q10=2*(CF/1kHz)^0.5; BW10_0 is reused later as the initial value of the constant-bandwidth model
Q_10_0=2*(CF/1000)**0.5
BW10_0=CF/Q_10_0


#BW10_0Func=constant_BW10(BW10_0, requires_grad=False)   #constant BW

#Q10 defined by the power law above
BW10_0Func=Q10PowerLaw(2, 1000, 0.5, requires_grad=False)

print(f'BW10 for first guess: {BW10_0Func(CF):.1f} Hz')

#masking model: latencies + bandwidth + notched-noise masking conditions + first-guess Weibull I/O function
#E.set_masking_model(lat_model, BW10_0Func, ntch_maskingConds, sigm, filter_model=filter_model)
E.set_masking_model(lat_model, BW10_0Func, ntch_maskingConds, wb_cdf, filter_model=filter_model)

**Estimation of unitary response**

**TODO:** clean up this section.

The unitary response is estimated by deconvolution of the CAP masking releases $[\Delta CAP(t)]_i$ for the notched-noise maskers with varying attenuation for the notch. A first guess for the masking release patterns is used (after optimization of the model parameters, a re-estimation of ur can be done with `load_wbcdf=True`, before a second optimisation). The UR is taken as the average of the deconvolved signals, weighted by the quadratic sum of $[\Delta CAP(t)]_i$ for each condition.

In [None]:
#signals from which ur is estimated
#all notched noise maskers around CF?

#maskers used: varying notch attenuation + varying notch bandwidth
#fln_list=ntch_masker_lists[CF] #only varying atten maskers
fln_list=ntch_masker_lists[CF]+vbw_fln_lists[CF]

reg_exp=get_regexp(fln_list)

ur_estim_maskerNames, ur_estim_maskingConds, ur_estim_signals =capData.get_batch_re(reg_exp)
ur_estim_maskingConds.set_amp0_dB(I0)

#signals are smoothed twice as much for the deconvolution as for the fit
gauss_sigma_deconv=2*gauss_sigma 

ur_estim_signals_proc=process_signal2(ur_estim_signals, gauss_sigma=gauss_sigma_deconv)
#HACK (same padding of the maskers as for ntch_maskingConds; pad_maskers2 is not applied here)
ur_estim_maskingConds.pad_maskers(f_thr=11000, f_max=1e5)

In [None]:
#estimation ur
#Projected Newton iterations: one Newton deconvolution step per condition, then two projections
#(zero after 7.5 ms, same ur for all conditions).

#optionally start from a Weibull CDF saved by a previous run
if I0_distributed or load_wbcdf:
    wb_cdf=WeibullCDF_IOFunc.load_from_npz(f'{results_folder0}/wbcfdIO_{CF}.npz')

E.set_masking_model(lat_model, BW10_0Func, ur_estim_maskingConds, wb_cdf, filter_model=filter_model)

maskAmounts, excs = E.get_tensors() 

nb_steps=20
alpha=np.linspace(0.5, 0.05, nb_steps)   #decreasing step sizes, one per iteration

EPs_fft=np.fft.rfft(excs, axis=1)
CAPs_fft=np.fft.rfft(ur_estim_signals_proc, axis=1)

#one candidate ur per condition, initialised at zero
#u1_mat=np.tile(ur0, (ur_estim_maskingConds.n_conditions, 1))
u1_mat=np.zeros_like(ur_estim_signals_proc)

#samples forced to zero (t > 7.5 ms), replicated for every condition
#filter_mat  = (t2>7.5e-3)+(t2<3.2e-3)
filter_mat=np.tile(t2>7.5e-3, (ur_estim_maskingConds.n_conditions, 1))
#filter_mat=np.zeros_like(ntch_signals_proc, dtype=bool)
#proj_fft=E.get_projector_fft()

#weight of each condition in the average: L2 norm of its excitation pattern
weights=np.sqrt(np.sum(excs.clone().detach().numpy()**2, axis=1))

for i, step_size in enumerate(alpha, start=1):
    du=deconv_newton_step(u1_mat, EPs_fft, CAPs_fft, eps_ridge=0)   
    #du=deconv_grad(u1_mat, EPs_fft, CAPs_fft)
    u1_mat-=step_size*du

    #proj 1: zero outside the allowed time window
    u1_mat[filter_mat]=0

    #proj 2: all conditions share the same ur (weighted average, then replicated)
    #u1_mat_mean=np.mean(u1_mat, axis=0)[None, :]
    u1_mat_mean=np.average(u1_mat, axis=0, weights=weights)[None, :]
    u1_mat=np.repeat(u1_mat_mean, ur_estim_maskingConds.n_conditions, axis=0)

#result after the last step (i == nb_steps here)
pl.figure()
pl.title(f'Step {i} (deconv + proj)')
#pl.plot(t, u0, label='u0 (truth)')
#pl.plot(t2, u_temp, label='u0 (last save)')
pl.plot(t2, u1_mat[0], label='u0 (estimated)')
pl.legend()
#pl.savefig('ur_8kHz_Q395.svg')
pl.show()
u_temp=u1_mat[0]
#if write_results:
#    np.savez(f'{results_folder}/ur_{CF}.npz', t2=t2, ur=u1_mat[0])
#saved after normalization


Maskers and signals

In [None]:
#various freqs. (for estimating E0)
fln_list=vfreq_fln_lists[CF]

reg_exp=get_regexp(fln_list)

vfreq_maskerNames, vfreq_maskingConds, vfreq_signals =capData.get_batch_re(reg_exp)
vfreq_signals_proc=process_signal2(vfreq_signals, gauss_sigma=gauss_sigma)
vfreq_maskingConds.set_amp0_dB(I0)
#np.inf rather than np.Inf: the capitalised alias was removed in NumPy 2.0
vfreq_maskingConds.pad_maskers(f_thr=11000, f_max=np.inf)

#various bws (for estimating Q10)
fln_list=vbw_fln_lists[CF]

reg_exp=get_regexp(fln_list)

vbw_maskerNames, vbw_maskingConds, vbw_signals =capData.get_batch_re(reg_exp)
vbw_signals_proc=process_signal2(vbw_signals, gauss_sigma=gauss_sigma)
vbw_maskingConds.set_amp0_dB(I0)
#HACK (same padding as for the other notched-noise maskers)
vbw_maskingConds.pad_maskers(f_thr=11000, f_max=1e5)


**Fine-tuning of model parameters using gradient descent (I/O curve, Q10, frequency weights)**

In [None]:
# try more accurate estimation of i/o curve
#Trainable I/O functions (sigmoid and Weibull CDF) and their initial parameters.

signals_proc=ntch_signals_proc
maskingConds=ntch_maskingConds


io_func = 'weibull' 
#io_func= 'sigm'

#sigm2=SigmoidIOFunc(sigm.mu.numpy(), sigm.a.numpy(), Iref=I0-20, constrained_at_Iref=True, requires_grad=True)
sigm2=SigmoidIOFunc(5., 0.25, Iref=I0-20, constrained_at_Iref=True, requires_grad=True)

#Initial Weibull CDF parameters: from the CF-specific json file if allowed and present,
#otherwise from per-experiment defaults.
init_params_json=f'./init_params/{expe_name}/{CF}_init_params.json'
if load_json_init_params and os.path.exists(init_params_json):
    with open(init_params_json) as f:
        dic_params=json.load(f)
    k_cdf, I0_cdf, scale_cdf=(float(dic_params[key]) for key in ('k', 'I0', 'scale'))
else:
    #(k, I0) per experiment; '12-13' is in-between standard and fit 5khz (#scale=32.)
    default_k_I0={'4-23': (3., -24.),
                  '9-10': (7.5, -12.),
                  '9-10_bis': (7., -20.),
                  '12-13': (4., -23.)}
    k_cdf, I0_cdf=default_k_I0.get(expe_name, (5., -20.))
    scale_cdf=30.

wb_cdf2=WeibullCDF_IOFunc(I0=I0_cdf, scale=scale_cdf, k=k_cdf, mmax=1.,
                          requires_grad=True, constrained_at_Iref=True, Iref=I0-20)

#E2=ExcitationPatterns(t2, E0, requires_grad=True)  #no non-maskable part
#trainable copy of the excitation-pattern model E
E2=ExcitationPatterns.copyRaw(E, requires_grad=True)

#Every distributed mode exchanges tensors with the main node (rank 0) through torch.distributed
#(the main loop calls dist.isend as soon as one of the three flags is set), so the process group
#is initialised for any of them, I0_distributed included.
if Q10_distributed or E0_distributed or I0_distributed:
    #init group
    if not(dist.is_initialized()):
        dist.init_process_group(backend, init_method='tcp://127.0.0.1:1234', world_size=n_workers, rank=rank, 
                                timeout=datetime.timedelta(0, 80))  

#Bandwidth model. It must be defined in every configuration: it used to be left undefined
#when E0_distributed was True and Q10_distributed False.
if Q10_distributed:
    Q10rbf=Q10RBFNet.create_from_jsonfile('RBF_params.json')
    #update weights (have to be sent by main process)
    Q10rbf.update_weights()
    BW10_0TestFunc=Q10RBFNet_BW10(Q10rbf)
else:
    #constant bandwidth, initialised at the first guess BW10_0
    BW10_0TestFunc=constant_BW10(BW10_0, requires_grad=True)
#Not needed to load weights for E0 as should be initialized at 1 anyway

if I0_distributed:
    I0_rbf=RBFNet.create_from_jsonfile('RBF_I0_params.json')
    wb_cdf2.set_I0_w_RBFNet(I0_rbf, plus_lambda=plus_lambda)
    #update weights (have to be sent by main process)
    I0_rbf.update_weights()

In [None]:
#optim params
#NOTE: params not imported from json file for first optim(?)
#Step sizes of the gradient descent, stored in dictionaries keyed by the parameter tensors.
#These values are used for the first I/O optimisation; they can be overridden before the main loop.

alpha=30 #base step size (overwrites the array of step sizes used for the ur deconvolution)
alpha_Q10=3e7 #step size for the constant 10-dB bandwidth

#for estimation of E0
n_dim=7 #projection of gradient on n_dim first harmomics (Fourier basis)


#I/O function parameters
if io_func=='weibull':
    #alpha_dic={wb_cdf2.I0: 30*alpha, wb_cdf2.scale: 5*alpha, wb_cdf2.k: 30*alpha}
    alpha_dic={wb_cdf2.scale: alpha, wb_cdf2.k: 10*alpha}
    
    if I0_distributed:
        #I0 is given by the RBF network: its output-layer weights are updated instead of I0
        alpha_dic[wb_cdf2.rbfNet.l2.weight]=0.005*alpha
    else:
        alpha_dic[wb_cdf2.I0]=10*alpha
        
    #alpha_dic={wb_cdf2.I0: 0.1*alpha, wb_cdf2.scale: 0.05*alpha, wb_cdf2.k: 0.6*alpha}
else:
    alpha_dic={sigm2.mu: 0.01*alpha, sigm2.a: 0.005*alpha}

#the amplitude of E0 is updated together with the I/O function
#alpha_dic[BW10_0TestFunc.BW_10]=alpha
#alpha_dic[E2.E0_maskable]=0.1*alpha  #/!| with sum_grad_E0 set to True #previous method to modify E0 amp
alpha_dic[E2.E0_maskable_amp]=0.1*alpha  


#bandwidth (Q10) parameters
alpha_dic_Q10={}

if Q10_distributed:
    alpha_dic_Q10[BW10_0TestFunc.Q10RBFnet.l2.weight]=0.05*alpha
else:
    alpha_dic_Q10[BW10_0TestFunc.BW_10]=alpha_Q10 #cte bw
    


#raw excitation E0
alpha_dic_E0={E2.E0_maskable: 0.2*alpha}


A first loop of 50 gradient descent steps is performed to fit the I/O function on the responses for the notched-noise maskers with a varying attenuation for the notch.

In [None]:
#first optim I/O func (if not I0 distributed or load_wbcdf)

nb_stepsIO=50

#masking model on the notched-noise maskers (varying attenuation), with the selected I/O function
io_func_model=wb_cdf2 if io_func=='weibull' else sigm2
E2.set_masking_model(lat_model, BW10_0TestFunc, ntch_maskingConds, io_func_model, filter_model=filter_model)

if I0_distributed or load_wbcdf:
    #no optimisation: import params (k, scale, and I0 unless it is given by the RBF network)
    wb_cdf_temp=WeibullCDF_IOFunc.load_from_npz(f'{results_folder0}/wbcfdIO_{CF}.npz')
    if not(I0_distributed):
        wb_cdf2.I0.data=wb_cdf_temp.I0.data
    wb_cdf2.k.data=wb_cdf_temp.k.data
    wb_cdf2.scale.data=wb_cdf_temp.scale.data
else:
    #gradient descent on the I/O function parameters (and E0 amplitude)
    axes, ind_plots, err_list=optim_steps(E2, u1_mat[0], signals_proc, alpha_dic,
                                          nb_steps=nb_stepsIO, #sum_grad_E0=True,
                                          plot_masking_I0_graph=True,
                                          step_plots=5)



The unitary response is normalized (so that it is comparable across CFs, i.e. different optimization nodes)

In [None]:
#convention: normalize so that amp=0.0001 (/20 factor compared to others expes)
#norm_factor=0.0001/(np.amax(u1_mat[0])-np.amin(u1_mat[0]) ) #normalization peak-to-peak
norm_factor=0.00005/(-np.amin(u1_mat[0]) ) #normalization N1 (the minimum of ur becomes -0.00005)

u1_mat*=norm_factor #in place: all rows of u1_mat are rescaled

if write_results:
    np.savez(f'{results_folder}/ur_{CF}.npz', t2=t2, ur=u1_mat[0])

#E0 is divided by the same factor so that the product ur x E0 is unchanged
if not(E0_distributed):
    E2.E0_maskable.data=E2.E0_maskable/norm_factor
else:
    #send norm_factor to main node
    send_norm_factor_hand=dist.isend(torch.tensor(norm_factor, dtype=torch.float64),0, tag=99)
    send_norm_factor_hand.wait() 

**Params for main loop**

In [None]:
#note change params for I/O func
#TODO write in json?
#loads params from json file
#Step sizes and iteration counts for the main loop: defaults first, then overridden by a json file if requested.
if io_func=='weibull':
    #smaller steps for the I/O function than in the first optimisation
    if not(I0_distributed):
        alpha_dic[wb_cdf2.I0]=0.5*alpha
    alpha_dic[wb_cdf2.scale]= 0.05*alpha
    alpha_dic[wb_cdf2.k]= 0.5*alpha    

n_it=100  #100    #number of iterations of the alternate scheme
nb_steps=3 #5     #gradient steps per sub-problem and per iteration


if load_json_optim_params:
    #an experiment-specific file takes precedence over the generic one
    if os.path.exists(f'optim_params_{expe_name}.json'):
        optim_params_filename=f'optim_params_{expe_name}.json'
    else:
        optim_params_filename='optim_params.json'
    with open(optim_params_filename) as f:
        dic_params=json.load(f)
    
    n_it=dic_params['n_it']
    nb_steps=dic_params['nb_steps']
    n_dim=dic_params['n_dim']
    step_values=dic_params['alpha'] #step sizes, keyed by parameter name
    if io_func=='weibull':
        if I0_distributed:
            alpha_dic[wb_cdf2.rbfNet.l2.weight]=float(step_values['I0RBFweights'])
        else:
            alpha_dic[wb_cdf2.I0]=float(step_values['I0'])
        alpha_dic[wb_cdf2.scale]= float(step_values['scale'])
        alpha_dic[wb_cdf2.k]= float(step_values['k'] )
    else:
        alpha_dic[sigm2.mu]= float(step_values['sigm_mu'])
        alpha_dic[sigm2.a]= float(step_values['sigm_a'])
    alpha_dic[E2.E0_maskable_amp]=float(step_values['E0_amp'])
        

    if Q10_distributed:
        alpha_dic_Q10[BW10_0TestFunc.Q10RBFnet.l2.weight]=float(step_values['Q10RBFweights'])
    else:
        alpha_dic_Q10[BW10_0TestFunc.BW_10]=float(step_values['Q10']) #cte bw
                                  
    
    alpha_dic_E0[E2.E0_maskable]=float(step_values['E0'])

    

**Optim (main loop)**

The model parameters are fine-tuned with an alternating gradient scheme.
 1. Update of the weights R0 (E0 in the code). Notched-noise maskers with a notch placed over a broad range of frequencies around CF are used for the computation of the gradients. 3 steps.
 2. Update of the parameters of the I/O function (gradients computed over the notched-noise maskers with varying attenuation at the notch; notch centered at CF only). 3 steps.
 3. Update of Q_10 (notched-noise maskers with varying notch width; only maskers whose notch contains CF).

In [None]:

#Main loop: alternate gradient scheme. Each iteration runs nb_steps gradient steps on
#  1. E0 (maskers with notches at various frequencies),
#  2. the I/O function + E0 amplitude (maskers with varying notch attenuation),
#  3. Q10 (maskers with varying notch width).
tot_steps=3*n_it*nb_steps

errs=[]
errs_total=[]

#I/O function being optimised. It is selected once so that the three sub-steps always use the same
#object (the Q10 step used to pass the first-guess `sigm` instead of `sigm2` in the sigmoid case).
io_model=wb_cdf2 if io_func=='weibull' else sigm2

#True if this notebook runs as a worker of a distributed scheme
any_distributed=Q10_distributed or E0_distributed or I0_distributed

pl.figure(figsize=(6, 12))

for i in range(n_it):
    
    if any_distributed: #informs the main node that optim is still in process
        optim_done_hand=dist.isend(torch.tensor(nb_steps, dtype=torch.int32),0, tag=16)
        optim_done_hand.wait()
    
    #get the current values of the shared parameters from the main node
    if E0_distributed:  #update E0
        hand = dist.irecv(E2.E0_maskable, src=0, tag=8)
        hand.wait()
    if Q10_distributed:
        Q10rbf.update_weights()
    
    if I0_distributed:
        I0_rbf.update_weights()
       
    #E0
    if i==0:
        #HACK
        #try to have E0_amp and E0 values consistent with each other
        E2.E0_maskable_amp.data=E0[ind_CF]/E2.E0_maskable[ind_CF]*1/norm_factor
        
        #update E0_amp if E0_distributed for a few steps before anything else
        if E0_distributed:
            E2.set_masking_model(lat_model, BW10_0TestFunc, ntch_maskingConds, io_model, filter_model=filter_model)
    
            optim_steps(E2, u1_mat[0], signals_proc, {E2.E0_maskable_amp:alpha_dic[E2.E0_maskable_amp]}, 
                nb_steps=nb_steps*5)
        
    E2.set_masking_model(lat_model, BW10_0TestFunc, vfreq_maskingConds, io_model, filter_model=filter_model)

    if i==0:
        #the figure axes are created by the first call to optim_steps and reused afterwards
        axes=None
        ind_plots=None
        
    axes, ind_plots, err_list=optim_steps(E2, u1_mat[0], vfreq_signals_proc, alpha_dic_E0, 
        nb_steps=nb_steps, 
        n_dim_E0=n_dim, 
        E0_distributed=E0_distributed,                        
         #E0_t_min=t_min_E0, E0_t_max=t_max_E0, k_mode_E0=k_mode_E0,
        plot_E0_graph=True, plot_E0_amp_graph=True, plot_masking_I0_graph=True,
        plot_Q10=True, fc_ref_Q10=CF, step_plots=5, axes=axes, ind_plots=ind_plots, 
        step0=(3*i)*nb_steps, tot_steps=tot_steps) 
        
    err0=err_list[-1]

    #I/O Func (+ amp E0)
    E2.set_masking_model(lat_model, BW10_0TestFunc, ntch_maskingConds, io_model, filter_model=filter_model)
    
    axes, ind_plots, err_list=optim_steps(E2, u1_mat[0], signals_proc, alpha_dic, 
                nb_steps=nb_steps, #sum_grad_E0=True, 
                plot_E0_graph=True, plot_masking_I0_graph=True,
                plot_Q10=True, fc_ref_Q10=CF,
               step_plots=5, axes=axes, ind_plots=ind_plots, step0=(3*i+1)*nb_steps,
                 tot_steps=tot_steps, I0_distributed=I0_distributed)
    err1=err_list[-1]

    #Q10
    E2.set_masking_model(lat_model, BW10_0TestFunc, vbw_maskingConds, io_model, filter_model=filter_model)

    axes, ind_plots, err_list=optim_steps(E2, u1_mat[0], vbw_signals_proc, alpha_dic_Q10, 
            nb_steps=nb_steps, sum_grad_E0=True, 
            plot_E0_graph=True, plot_masking_I0_graph=True,
            plot_Q10=True, fc_ref_Q10=CF,
           step_plots=5, axes=axes, ind_plots=ind_plots, step0=(3*i+2)*nb_steps,
               tot_steps=tot_steps, #verbose=i%5,
               Q10_distributed=Q10_distributed)
    err2=err_list[-1]
    err_sum=(err1+err2)
    errs.append(err_sum.detach().numpy())  #errors are summed only on notched noise maskers (update I/O curve and Q10)
    #nb: possible duplicates
    errs_total.append( (err0+err1+err2))

    
if any_distributed: #informs the main node that optim is done
    #NOTE(review): this send is not waited for (the wait below was commented out in the original) — confirm it is intended
    optim_done_hand=dist.isend(torch.tensor(0, dtype=torch.int32),0, tag=16)
    #optim_done_hand.wait()
        
pl.tight_layout()
if save_figs:
    pl.savefig(f'fitdata{CF}_optim_steps.svg')

In [None]:
#Error along the iterations, as a percentage of the signal energy.
#rms_2: sum of squares of the signals of the notched-noise maskers (varying bandwidth + varying attenuation)
rms_2=np.sum(vbw_signals_proc**2)+np.sum(ntch_signals_proc**2)  
#rms_2b: same with the 'various frequencies' maskers added (only useful for the commented-out curve below)
rms_2b=rms_2+np.sum(vfreq_signals_proc**2)  

pl.figure()
pl.plot(np.arange(len(errs)), errs/rms_2*100, label='notched noise maskers')

#pl.plot(np.arange(len(errs)), errs_total/rms_2*100, label='all maskers (w/ duplicates)')

pl.xlabel('Iterations')
pl.ylabel('Error (% variance)')

pl.legend()

if save_figs:
    pl.savefig(f'fitdata{CF}_optim_steps_err.svg')
pl.show()



In [None]:
#write data
#Saves the fitted model (I/O function, E0, latencies, Q10), the optimisation parameters and the errors
#in results_folder. Nothing is written if write_results is False.
if write_results:

    
    #I/O function
    if io_func=='weibull':
        if I0_distributed:
            #I0 is given by the RBF network: evaluate it at CF to store a plain value
            I0_=wb_cdf2.rbfNet(torch.tensor([CF]))
            if plus_lambda:
                #the network outputs I0 + scale in that case
                I0_-=wb_cdf2.scale
            wb_cdf2.I0=I0_[0]
        wb_cdf2.write_to_npz(f'{results_folder}/wbcfdIO_{CF}.npz')
    else:
        sigm2.write_to_npz(f'{results_folder}/sigmIO_{CF}.npz')    
    
    #raw excitation E0 (+ latency model)
    if isinstance(lat_model, SingleLatency):
        np.savez(f'{results_folder}/E0_{CF}.npz', f=lat_model.get_f_linspace(len(t2)).detach().numpy(),
                E0=E2.E0_maskable.detach().numpy(), lat=lat_model.t0, E0_amp=E2.E0_maskable_amp.detach().numpy())
    else:
        if use_bincount: 
            np.savez(f'{results_folder}/E0_{CF}.npz', f=E2.bincount_f.detach().numpy(),
                E0=E2.E0_maskable.detach().numpy(), E0_amp=E2.E0_maskable_amp.detach().numpy())
            
       
        #save lat model
        lat_model.write_to_npz(f'{results_folder}/lat_{CF}.npz') #Note: normally lat does not depend on CF but it could
        
    if Q10_distributed:
        pass #placeholder, nothing specific is done in the distributed case

    #Q10 at CF after optimisation (CF divided by the 10-dB bandwidth given by the model)
    Q10optim= CF/E2.bw10Func(torch.tensor(CF, dtype=torch.float32))
    np.save(f'{results_folder}/Q10optim_{CF}.npy',
            Q10optim.detach().numpy() )
    
    
    #write params
    
    json_data={}
    json_data["n_it"]=n_it
    json_data["nb_steps"]=nb_steps
    json_data["tot_steps"]=tot_steps

    #NOTE(review): these two assignments are not used below in this cell (the step sizes written
    #are read from the alpha_dic* dictionaries) but they do overwrite the notebook-level variables
    alpha=30
    alpha_Q10=3e7


    #for estimation of E0
    json_data["n_dim"]=n_dim 

    #step sizes
    #NOTE(review): the cell loading optim_params*.json reads the keys 'I0RBFweights', 'sigm_mu' and 'sigm_a',
    #whereas 'I0_rbf_weights', 'mu' and 'a' are written here: a file written by this cell cannot be
    #reloaded as-is for these entries — confirm which naming is intended
    if io_func=='weibull':
        json_data_alpha={"scale": alpha_dic[wb_cdf2.scale],
                         "k": alpha_dic[wb_cdf2.k]}
        if I0_distributed:
            json_data_alpha["I0_rbf_weights"]=alpha_dic[wb_cdf2.rbfNet.l2.weight]
        else:
            json_data_alpha["I0"]=alpha_dic[wb_cdf2.I0]
    else:
        json_data_alpha={"mu": alpha_dic[sigm2.mu], "a": alpha_dic[sigm2.a]}

        
    #json_data_alpha["E0_amp"]=alpha_dic[E2.E0_maskable] #previous method
    json_data_alpha["E0_amp"]=alpha_dic[E2.E0_maskable_amp]

    
    json_data["Q10_distributed"]=Q10_distributed
    
    
    json_data["E0_distributed"]=E0_distributed
    
    
    json_data["I0_distributed"]=I0_distributed
    
    if Q10_distributed:
        json_data_alpha["Q10RBFweights"]=alpha_dic_Q10[BW10_0TestFunc.Q10RBFnet.l2.weight]
    else:
        json_data_alpha["Q10"]= alpha_dic_Q10[BW10_0TestFunc.BW_10] #cte bw



    json_data_alpha["E0"]=alpha_dic_E0[E2.E0_maskable]
    
    json_data["alpha"]=json_data_alpha
    
    
    with open(f'{results_folder}/optim_params_{CF}.json', 'w') as outfile:
        json.dump(json_data, outfile, indent=4)

    #errors: relative errors (.npy), then raw sums of squares with signal/noise levels (.npz)
    np.save(f'{results_folder}/err_list_{CF}.npy',
            np.array(errs)/rms_2 )
    
    #RMS of the signals: rms_2 divided by (number of conditions x number of samples per signal)
    sig_rms=rms_2*1/ (vbw_maskingConds.n_conditions+ntch_maskingConds.n_conditions)*1/np.shape(vbw_signals_proc**2)[1]
    sig_rms=np.sqrt(sig_rms)
    np.savez(f'{results_folder}/err_list_{CF}.npz', sum_sq_err=np.array(errs), sum_sq_sig=rms_2, 
            noise_rms=noise_rms, sig_rms=sig_rms, snr=sig_rms/noise_rms)  #+info noise level

Alternative for errors: computes errors on smaller interval (easier to compare to noise level)

In [None]:
# Error metrics restricted to the interval where the Tukey window is at 100%:
# alpha_tukey/2 of the window length is trimmed from each side of [t0, t1].
# The Tukey window is applied to the truncated signal, so both bounds are
# shifted by the first time stamp of E2.t.
t0_bis=t0+alpha_tukey/2*(t1-t0)+float(E2.t[0])
t1_bis=t1-alpha_tukey/2*(t1-t0)+float(E2.t[0])

print(f't0_bis: {t0_bis*1e3:.3f} ms, t1_bis: {t1_bis*1e3:.3f} ms')

u1=u1_mat[0]

# For each set of masking conditions, E2 is reconfigured with the I/O function
# selected by io_func, then get_sq_err_CAPs is evaluated on [t0_bis, t1_bis].

# ntch_maskingConds (various notch attenuations)
E2.set_masking_model(lat_model, BW10_0TestFunc, ntch_maskingConds,
                     wb_cdf2 if io_func=='weibull' else sigm2,
                     filter_model=filter_model)
sq_err_ntch, sq_sig_ntch=get_sq_err_CAPs(E2, u1, ntch_signals_proc, t0_bis, t1_bis)

# vbw_maskingConds (various notch widths)
E2.set_masking_model(lat_model, BW10_0TestFunc, vbw_maskingConds,
                     wb_cdf2 if io_func=='weibull' else sigm2,
                     filter_model=filter_model)
sq_err_vbw, sq_sig_vbw=get_sq_err_CAPs(E2, u1, vbw_signals_proc, t0_bis, t1_bis)

# vfreq_maskingConds
E2.set_masking_model(lat_model, BW10_0TestFunc, vfreq_maskingConds,
                     wb_cdf2 if io_func=='weibull' else sigm2,
                     filter_model=filter_model)
sq_err_vfreq, sq_sig_vfreq=get_sq_err_CAPs(E2, u1, vfreq_signals_proc, t0_bis, t1_bis)
sig_rms_vfreq=np.sqrt(np.sum(sq_sig_vfreq)/vfreq_maskingConds.n_conditions)
err_rms_vfreq=np.sqrt(np.sum(sq_err_vfreq)/vfreq_maskingConds.n_conditions)

# RMS of signal and of error for: vbw only, ntch only, and both sets pooled.
# eps1 / eps2 are 0-1 weights selecting the vbw / ntch set; the pooled values
# are normalised by the number of conditions that were selected.
sig_rms2_list=[]
errs2_rms_list=[]
print('On 100% Tukey window: ')
for (eps1, eps2, supp_text) in [(1, 0, '(various notch widths)'),
                                (0, 1, '(various notch atten)'),
                                (1, 1, 'overall')]:
    sig_rms2=np.sqrt(
        (eps1*np.sum(sq_sig_vbw)+eps2*np.sum(sq_sig_ntch))
        /(eps1*vbw_maskingConds.n_conditions+eps2*ntch_maskingConds.n_conditions))
    sig_rms2_list.append(sig_rms2)

    errs2_rms=np.sqrt(
        (eps1*np.sum(sq_err_vbw)+eps2*np.sum(sq_err_ntch))
        /(eps1*vbw_maskingConds.n_conditions+eps2*ntch_maskingConds.n_conditions))
    errs2_rms_list.append(errs2_rms)

    print(f'  signal RMS {supp_text} : {sig_rms2*1e3:.3f}  μV, mean error (RMS): {errs2_rms*1e3:.3f}  μV (estimated noise level: {noise_rms*1e3:.3f}  μV)')

# List index 0 -> vbw, 1 -> ntch, 2 -> overall (order of the loop above).
if write_results:
    np.savez(f'{results_folder}/err_list_{CF}_inside_window.npz',
             noise_rms=noise_rms,
             sig_rms=sig_rms2_list[2],
             sig_rms_ntch=sig_rms2_list[1],
             sig_rms_vbw=sig_rms2_list[0],
             err_rms=errs2_rms_list[2],
             err_rms_ntch=errs2_rms_list[1],
             err_rms_vbw=errs2_rms_list[0],
             sig_rms_vfreq=sig_rms_vfreq,
             err_rms_vfreq=err_rms_vfreq)
    

In [None]:
# Raw excitation patterns for the "various notch atten" conditions
# (ntch_maskingConds). To look at the "various notch widths" conditions
# instead, pass vbw_maskingConds to set_masking_model below.
with torch.no_grad():
    E2.set_masking_model(lat_model, BW10_0TestFunc, ntch_maskingConds,
                         wb_cdf2 if io_func=='weibull' else sigm2,
                         filter_model=filter_model)

    pl.figure(figsize=(10,20))
    plotExcitationPatterns(E2, plot_raw_excitation=True) # ylim_top=1
    pl.show()

In [None]:
# Simulated CAPs for the vfreq masking conditions, with the reference signals
# (vfreq_signals_proc) overlaid on the same axes.
with torch.no_grad():
    # Configure E2 with the I/O function selected by io_func
    E2.set_masking_model(lat_model, BW10_0TestFunc, vfreq_maskingConds,
                         wb_cdf2 if io_func=='weibull' else sigm2,
                         filter_model=filter_model)

    u1=u1_mat[0]
    pl.figure(figsize=(12,20))
    # model
    ax_list=plotSimulatedCAPs(E2, u1, max_plots=10, sig_exc=sig_exc_plot)
    # reference signals, drawn in colour C2 on the axes returned above
    plotSimulatedCAPs(E2, CAParray=vfreq_signals_proc, axlist=ax_list, max_plots=10,
                      plot_excitations=False, plotargs={"color":'C2'})

    # Save before showing so the figure is still the current one
    if save_figs:
        pl.savefig(f'fitdata{CF}_vfreq_maskConds.svg')

    # pl.show() instead of the former argument-less pl.plot(), which draws nothing
    pl.show()
    

In [None]:
#@interact_manual(I0=(-30, 0), scale=(10, 50), k=(0.5, 15), plot_only_learned=False)   #only works for weibull CDF
def plot_v_attn_notch(I0, scale, k, plot_only_learned):
    """Plot simulated CAPs for the notch-attenuation conditions (ntch_maskingConds).

    The learned masking I/O function (``wb_cdf2`` when ``io_func == 'weibull'``,
    ``sigm2`` otherwise) is always used for the main curves, and the reference
    signals ``ntch_signals_proc`` are overlaid in colour C2.

    When ``io_func == 'weibull'`` and ``plot_only_learned`` is False, a second
    Weibull CDF built from ``I0``, ``scale`` and ``k`` (constrained at the same
    ``Iref`` as ``wb_cdf2``) is also shown: first as an I/O curve next to the
    learned one, then as an extra set of simulated CAPs on the same axes.

    Args:
        I0, scale, k: parameters passed to ``WeibullCDF_IOFunc`` for the
            comparison curve; ignored when no comparison is drawn.
        plot_only_learned: if True, only the learned model is plotted.
    """
    is_weibull = io_func=='weibull'
    # Learned I/O function actually in use; avoids touching wb_cdf2 in sigmoid mode
    learned_io = wb_cdf2 if is_weibull else sigm2
    # The hand-tuned comparison only makes sense for the Weibull CDF
    compare_weibull = is_weibull and not plot_only_learned

    print('After learning: ')
    print(learned_io)
    with torch.no_grad():

        if compare_weibull:
            wb_cdf_temp=WeibullCDF_IOFunc(constrained_at_Iref=True, Iref=wb_cdf2._Iref, I0=I0,
                                      scale=scale, k=k)

            # Learned vs. hand-tuned I/O curves
            I=torch.linspace(-30, 30, 50)
            pl.figure()
            pl.plot(I, wb_cdf2(I))
            pl.plot(I, wb_cdf_temp(I))
            pl.xlim([-20, 30])
            pl.title('Masking IO Function')
            pl.xlabel('Power spectral density (dB)')
            pl.show()

        E2.set_masking_model(lat_model, BW10_0TestFunc, ntch_maskingConds, learned_io, filter_model=filter_model)

        u1=u1_mat[0]
        pl.figure(figsize=(12,20))
        ax_list=plotSimulatedCAPs(E2, u1, ylim=[-15, 15], max_plots=10, sig_exc=sig_exc_plot)

        if compare_weibull:
            # Work on a copy so that E2 keeps the learned masking model
            E_temp=ExcitationPatterns.copyRaw(E2)
            E_temp.set_masking_model(lat_model, BW10_0TestFunc, ntch_maskingConds, wb_cdf_temp, filter_model=filter_model)

            plotSimulatedCAPs(E_temp, u1, axlist=ax_list, max_plots=10, sig_exc=sig_exc_plot)

        plotSimulatedCAPs(E2, CAParray=ntch_signals_proc, axlist=ax_list, max_plots=10,
                          plot_excitations=False, plotargs={"color":'C2'})
    # Save before showing so the figure is still the current one
    if save_figs:
        pl.savefig(f'fitdata{CF}_ntch_maskConds.svg')
    # pl.show() instead of the former argument-less pl.plot(), which draws nothing
    pl.show()

plot_v_attn_notch(0, 15, 5, True) #hack learned curve (random params)

In [None]:
# Simulated CAPs for the "various notch widths" conditions (vbw_maskingConds),
# with the reference signals (vbw_signals_proc) overlaid on the same axes.
u1=u1_mat[0]
pl.figure(figsize=(10,14))
# Configure E2 with the I/O function selected by io_func
E2.set_masking_model(lat_model, BW10_0TestFunc, vbw_maskingConds,
                     wb_cdf2 if io_func=='weibull' else sigm2,
                     filter_model=filter_model)

with torch.no_grad():
    # model
    ax_list=plotSimulatedCAPs(E2, u1, ylim=[-10, 10], sig_exc=sig_exc_plot)
    # reference signals, drawn in colour C2 on the axes returned above
    plotSimulatedCAPs(E2, CAParray=vbw_signals_proc, axlist=ax_list, plot_excitations=False, plotargs={"color":'C2'})

# Save before showing so the figure is still the current one
if save_figs:
    pl.savefig(f'fitdata{CF}_vbw_maskConds.svg')
# pl.show() instead of the former argument-less pl.plot(), which draws nothing
# and left an empty list as the cell output
pl.show()

In [None]:
# Grid search on a constant BW10: for each candidate bandwidth, simulate the
# CAPs of the vbw masking conditions and compare them with vbw_signals_proc.
#
# NOTE(review): the grid spans 500-5000 Hz, but the number of points is derived
# from a 500-4000 Hz range with a 50 Hz step (71 points), so the actual spacing
# is not 50 Hz. Values are left unchanged here; confirm the intended range/step.
bw_arr=np.linspace(500, 5000, num= ((4000-500)//50+1) )
sigs_ref=vbw_signals_proc
# NOTE(review): this rebinds `errs`, which the results-writing code earlier in
# the notebook also reads (err_list_{CF}.npy / .npz). Do not re-run that code
# after this cell without re-running the optimisation first.
errs=[]
for bw in bw_arr:

    BW10_0TestFunc2=constant_BW10(bw, requires_grad=False)

    # Same masking model as the fitted one, except for the BW10 function
    E2.set_masking_model(lat_model, BW10_0TestFunc2, vbw_maskingConds,
                         wb_cdf2 if io_func=='weibull' else sigm2,
                         filter_model=filter_model)
    excs = E2.get_tensor()
    maskingConditions = E2.maskingConditions
    err=0
    for i, exc in zip(range(maskingConditions.n_conditions), excs):
        exc_np = exc.detach().numpy()
        # CAP = excitation convolved with u1, truncated to the length of E2.t
        CAP=np.convolve(exc_np, u1, mode='full')
        CAP=CAP[0:len(E2.t)]
        err+=np.mean( (CAP-sigs_ref[i])**2)
    # Mean squared error over conditions; the 1e6 factor becomes 1e3 after the
    # square root below (presumably mV -> μV, matching the axis label; confirm).
    errs.append(err/maskingConditions.n_conditions*1e6)
# NOTE: E2 is left configured with the last BW10 value of the grid; the other
# plotting cells call set_masking_model again before using it.

pl.plot(bw_arr, np.sqrt(errs))
pl.xlabel('BW10 model (Hz)')

pl.ylabel('Mean error (μV)')

if save_figs:
    pl.savefig(f'fitdata{CF}_BW10_errs.svg')


# Bandwidth with the smallest error on the grid
ind_min=np.argmin(errs)
print(f'estimated bw10: {bw_arr[ind_min]:.0f} Hz')

if write_results:
    np.savez(f'{results_folder}/Q10gridsearch_{CF}.npz', bw=bw_arr, errs=errs, bw10_est=bw_arr[ind_min])

In [None]:
# Show wb_cdf2 (the I/O function used when io_func=='weibull') as the cell output
wb_cdf2