## Generate Figure 5 for the sensitivity analysis manuscript
Here we generate the parameter recovery figure for the sensitivity analysis manuscript. 

This figure consists of three main components: a heatmap showing how well each parameter can be recovered under each light schedule, three example plots of recovery error for selected parameter–light schedule pairs, and a bar plot summarizing the overall recovery error for each parameter. We generate the results and plots for each of these components in the sections below.

In [1]:
from copy import deepcopy
import numpy as np
import matplotlib.pyplot as plt
from jax import grad, jit, jacfwd, jacrev
import jax.numpy as jnp

# import the needed modules 
from hessian_normalized import HessianCircadian
from hessian_normalized import ParameterRecovery

from lightschedules import RegularLight
from lightschedules import ShiftWorkLight
from lightschedules import ShiftWorkerThreeTwelves
from lightschedules import SocialJetLag
from lightschedules import SlamShift

# Sensitivity helper, recovery helper and the model's default parameter vector.
sens = HessianCircadian()
recov = ParameterRecovery()
params = sens.get_parameters_array()


# Simulation grid: ndays days (24 * ndays hours) sampled every dt hours.
ndays = 7
intensity = 979  # light intensity passed to every schedule below
dt = 0.1
ts = np.arange(0, 24 * ndays, dt)


def _sample_schedule(schedule):
    """Evaluate ``schedule(t, Intensity=intensity)`` at every t in ts and return a JAX array."""
    return jnp.array([schedule(t, Intensity=intensity) for t in ts])


lights_rl = _sample_schedule(RegularLight)
lights_sw = _sample_schedule(ShiftWorkLight)
lights_sw312 = _sample_schedule(ShiftWorkerThreeTwelves)
lights_sjl = _sample_schedule(SocialJetLag)
lights_ss = _sample_schedule(SlamShift)
lights_dark = jnp.zeros(len(ts))  # no light at any time point

Read in the stored initial conditions and the parameter distributions from an MCMC run of the model.

In [2]:
# get stored initial conditions, generated in ics_generate.ipynb
# (IPython %store magic: restores variables saved by another notebook, so this
# cell only runs under IPython/Jupyter.)
%store -r ics_rl
%store -r ics_sw
%store -r ics_sw312
%store -r ics_sjl
%store -r ics_ss
%store -r ics_dark

# read in the mcmc results 
import pandas as pd

# read in the data 
# Tab-separated, no header row. The perturbation loop indexes columns by
# mcmc_param_list.index(name), so the column order must match mcmc_param_list.
# Presumably one MCMC sample per row -- confirm against the MCMC output writer.
mcmc_runs = pd.read_csv('mcmc_new_run_2022.csv', header = None, delimiter = '\t').to_numpy()
# list: tau, K, A1, A2, BetaL1, BetaL2, sigma, p, I0, alpha0, Beta1, ent_angle, cost, dd_period
mcmc_param_list = ["tau", "K", "A1", "A2", "BetaL1", "BetaL2", "sigma", "p", "I0", "alpha0", "Beta1", "ent_angle", "cost", "dd_period"]
# Names of the entries of params, in order (the loop pairs model_param_list[j] with params[j]).
# NOTE(review): names are matched against mcmc_param_list exactly, so "alpha_0" here never
# matches "alpha0" above. alpha_0 therefore gets the percent-scaling perturbation instead of
# its MCMC posterior. Confirm this is intended.
model_param_list = ["tau", "K", "gamma", "Beta1", "A1", "A2", "BetaL1", "BetaL2", "sigma", "G", "alpha_0", "delta", "p", "I0"]

Generate the model states with default parameters and noisy light. 

In [3]:
# Noise parameters for the light input (mean and spread passed to recov.noisy_light).
mu, sigma = 0, 1


def _noisy_states(light):
    """Return ``recov.noisy_light`` states for ``light`` with noise (mu, sigma) on the grid ts."""
    return recov.noisy_light(mu, sigma, light, ts)


# Keep this call order: if noisy_light draws random numbers (not visible here),
# reordering the calls would change which noise each schedule receives.
states_rl_noisy_light = _noisy_states(lights_rl)
states_sw_noisy_light = _noisy_states(lights_sw)
states_sw312_noisy_light = _noisy_states(lights_sw312)
states_ss_noisy_light = _noisy_states(lights_ss)
states_sjl_noisy_light = _noisy_states(lights_sjl)
states_dark_noisy_light = _noisy_states(lights_dark)

  final_state_diff = abs(statesfinal[0] - u0[0]) + abs(np.mod(statesfinal[1] - u0[1] + np.pi,2*np.pi) - np.pi)


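Generate the model states with perturbed parameters and noise-free light, and compute the recovery error (loss) between each perturbed trajectory and the corresponding noisy-light trajectory. Parameters with an MCMC posterior are swept over its quantiles (shifted so that the posterior median equals the default value). All other parameters are scaled from 95.5% to 105% of their default value.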

In [None]:
num_iter = 20

# Perturbation sweep, index m = 0 .. num_iter-1:
#   * parameters named in mcmc_param_list take the (m+1)/num_iter quantile of their
#     MCMC posterior, shifted so that the posterior median lands on the default value;
#   * all other parameters are scaled by 0.96 + (m-1)*0.005, i.e. 0.955 ... 1.050.
# With num_iter = 20 both schemes hit the unperturbed default value at m = 9.

# model_param_list must name every entry of params, in the same order.
if len(model_param_list) != len(params):
    raise ValueError(
        f"model_param_list has {len(model_param_list)} names but params has "
        f"{len(params)} entries"
    )

# Names are matched exactly against mcmc_param_list; report those that fall back to
# percent scaling so that a spelling mismatch does not go unnoticed.
percent_scaled = [name for name in model_param_list if name not in mcmc_param_list]
print("Perturbed by percent scaling (no MCMC posterior):", percent_scaled)

# initialize the results: rows = sweep index m, columns = parameter j
state_diff_rl = np.zeros([num_iter, len(params)])
state_diff_sw = np.zeros([num_iter, len(params)])
state_diff_sw312 = np.zeros([num_iter, len(params)])
state_diff_ss = np.zeros([num_iter, len(params)])
state_diff_sjl = np.zeros([num_iter, len(params)])
state_diff_dark = np.zeros([num_iter, len(params)])

# Per schedule: (initial conditions, light input, noisy-light reference states,
# result array). The result arrays are filled in place.
schedules = [
    (ics_rl, lights_rl, states_rl_noisy_light, state_diff_rl),
    (ics_sw, lights_sw, states_sw_noisy_light, state_diff_sw),
    (ics_sw312, lights_sw312, states_sw312_noisy_light, state_diff_sw312),
    (ics_ss, lights_ss, states_ss_noisy_light, state_diff_ss),
    (ics_sjl, lights_sjl, states_sjl_noisy_light, state_diff_sjl),
    (ics_dark, lights_dark, states_dark_noisy_light, state_diff_dark),
]

# loop through the parameters and iterations
for j, param_name in enumerate(model_param_list):

    print(j)

    # Per-parameter work, hoisted out of the sweep loop.
    uses_mcmc = param_name in mcmc_param_list
    if uses_mcmc:
        mcmc_col = mcmc_runs[:, mcmc_param_list.index(param_name)]
        # Offset that moves the posterior median onto the model default.
        shift = np.quantile(mcmc_col, 0.50) - params[j]

    for m in range(num_iter):
        if uses_mcmc:
            new_val = np.quantile(mcmc_col, (m + 1) / num_iter) - shift
        else:
            new_val = params[j] * (0.96 + (m - 1) * 0.005)
        # JAX arrays are immutable: .at[j].set returns a copy and params is untouched.
        params_perturb = params.at[j].set(new_val)

        for ics, lights, states_noisy, state_diff in schedules:
            states_params = recov.perturbed_params(ics, params_perturb, lights, ts)
            state_diff[m, j] = recov.loss_recovery(states_noisy, states_params, ts)


0


## Generate heatmap figure 
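Next we generate a heatmap showing where along the perturbation sweep the recovery error is smallest, for each parameter and light schedule combination. Each cell holds $|i - 9|$. Here $i$ is the (0-based) sweep index with the smallest recovery error, and index 9 (the 10th sweep point) is the unperturbed default value. A value of 0 therefore signifies perfect recovery.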

In [None]:
import seaborn as sns 
num_schedules = 6

# Heatmap rows, top to bottom: RL, SW, SW312, SJL, SS, dark.
recovery_errors = [
    state_diff_rl,
    state_diff_sw,
    state_diff_sw312,
    state_diff_sjl,
    state_diff_ss,
    state_diff_dark,
]

# state_diff arrays are (num_iter, n_params), so axis 0 is the perturbation sweep:
# nanargmin picks, per parameter, the sweep index with the smallest recovery error.
best_index = np.stack([np.nanargmin(err, axis=0) for err in recovery_errors]).astype(float)

# Sweep index 9 is the unperturbed default, so 0 means the error minimum sits at the true value.
heatmap_v1 = np.abs(best_index - 9)

sns.heatmap(heatmap_v1)
#plt.savefig('figures_8_29/param_recov_hm1_v4.svg')

## Generate example recovery error plots 
Here we show three example plots of recovery error for selected light schedule and parameter combinations.

Specifically, we plot the recovery error for $\tau$ under the regular light schedule, for $K$ under the shift work light schedule, and for $\alpha_0$ under the shift work three twelves light schedule. Change the `state_diff_` array to change the light schedule, and change the indices to change the parameters.

In [None]:
# choose the three parameters to examine 
index1 = 0 # tau
index2 = 1 # K
index3 = 10 # alpha_0

# generate the example figures 
x_vals = np.linspace(1/num_iter, 1, num_iter)
plt.plot(x_vals,state_diff_rl[:,index1])
#plt.savefig('figures_8_29/rl_tau_recov_err_v4.svg')
plt.show()

plt.plot(x_vals,state_diff_sw[:,index2])
#plt.savefig('figures_8_29/sw_k_recov_err_v4.svg')
plt.show()

plt.plot(x_vals,state_diff_sw312[:,index3])
#plt.savefig('figures_8_29/sw_alph_recov_err_v4.svg')
plt.show()

In [None]:
# Average recovery error per parameter: first over the perturbation sweep (axis 0)
# within each schedule, then across the schedules.
schedule_errors = (
    state_diff_rl,
    state_diff_sw,
    state_diff_sw312,
    state_diff_sjl,
    state_diff_ss,
    state_diff_dark,
)
mean_param = np.vstack([np.nanmean(err, axis=0) for err in schedule_errors])
overall_mean = np.nanmean(mean_param, axis=0)

plt.bar(range(len(params)), overall_mean)
#plt.savefig('figures_8_29/mean_recov_err_param_v4.svg')
plt.show()

In [None]:
# Print the perturbation values used by the sweep for each index m:
#   val1 -- scale factor applied to parameters without an MCMC posterior
#   val2 -- MCMC quantile level used for parameters with a posterior
num_iter = 20
for m in range(num_iter):
    val1 = 0.96 + (m - 1) * 0.005
    val2 = (m + 1) / num_iter
    print(m, val1, val2)

In [None]:
# Quick look: recovery error across the sweep for column 0 ("tau") under zero light.
dark_tau_errors = state_diff_dark[:, 0]
plt.plot(dark_tau_errors)


In [None]:
# Inspect the recovery error across the sweep for column 12 ("p") under the SW312 schedule.
state_diff_sw312[:,12]

In [None]:
# Sanity check for p: print its posterior median, then the 95th-percentile value after
# shifting the posterior so that its median matches the model default.
index = mcmc_param_list.index("p")
med_val = np.quantile(mcmc_runs[:, index], 0.50)
print(med_val)
# Look the position up by name instead of hard-coding 12.
def_val = params[model_param_list.index("p")]
shift = med_val - def_val
print(np.quantile(mcmc_runs[:, index], 0.95) - shift)

In [None]:
# Display the default parameter vector (the perturbation loop pairs its entries with model_param_list by index).
params

In [None]:
# Display med_val as last assigned; in top-to-bottom execution this is p's posterior median.
med_val