## Generate supplemental figures for the sensitivity analysis manuscript
Import the needed modules, define parameters and light schedules. 

In [1]:
from copy import deepcopy
import numpy as np
import matplotlib.pyplot as plt
from jax import grad, jit, jacfwd, jacrev
import jax.numpy as jnp

# import the needed modules 
from hessian_normalized import HessianCircadian
from lightschedules import RegularLight
from lightschedules import ShiftWorkLight
from lightschedules import ShiftWorkerThreeTwelves
from lightschedules import SocialJetLag
from lightschedules import SlamShift

# Sensitivity model and the parameter array it exposes.
sens = HessianCircadian()
params = sens.get_parameters_array()

# Set to 'on' to write the figures to disk.
figure_save = 'off'

# Time grid: 24*ndays sampled every dt, plus the light intensity
# passed to every schedule below.
ndays = 7
dt = 0.1
ts = np.arange(0, 24 * ndays, dt)
intensity = 979


def _sample_schedule(schedule, times, level):
    """Evaluate the light-schedule function `schedule` at each time in `times`."""
    return jnp.array([schedule(t, Intensity=level) for t in times])


# define the light schedules
lights_rl = _sample_schedule(RegularLight, ts, intensity)
lights_sw = _sample_schedule(ShiftWorkLight, ts, intensity)
lights_sw312 = _sample_schedule(ShiftWorkerThreeTwelves, ts, intensity)
lights_sjl = _sample_schedule(SocialJetLag, ts, intensity)
lights_ss = _sample_schedule(SlamShift, ts, intensity)
lights_dark = jnp.zeros([len(ts),])  # constant darkness: zero light throughout

Analyze the overall sensitivity metric (the Frobenius norm of the Hessian) for each light schedule, first using each schedule's own initial conditions (as in the rest of the manuscript) and then using the regular light schedule's initial conditions for every schedule. 

In [None]:
# get stored initial conditions, generated in ics_generate.ipynb
%store -r ics_rl
%store -r ics_sw
%store -r ics_sw312
%store -r ics_sjl
%store -r ics_ss
%store -r ics_dark

# generate the sensitivity hessian results 
hessianVal_rl = sens.normalized_hessian(u0 = ics_rl,light = lights_rl)
hessianVal_sw = sens.normalized_hessian(u0 = ics_sw,light = lights_sw)
hessianVal_sw312 = sens.normalized_hessian(u0 = ics_sw312,light = lights_sw312)
hessianVal_sjl = sens.normalized_hessian(u0 = ics_sjl,light = lights_sjl)
hessianVal_ss = sens.normalized_hessian(u0 = ics_ss,light = lights_ss)
hessianVal_dark = sens.normalized_hessian(u0 = ics_dark,light = lights_dark)

# visualize 
plt.bar(np.arange(6),norm_3)
plt.xlabel('light schedule')
plt.ylabel('frobenius norm of hessian')
plt.xticks(np.arange(6), ('Reg', 'SW', 'SW312', 'SJL', 'SS','Dark'))
plt.title('Individual schedule ICs')
if figure_save == 'on':
    plt.savefig('figures_8_29/bar_plot_hessian_new_rl.svg')
plt.show()

# now with all RL ICs
hessianVal_rl = sens.normalized_hessian(u0 = ics_rl,light = lights_rl)
hessianVal_sw = sens.normalized_hessian(u0 = ics_rl,light = lights_sw)
hessianVal_sw312 = sens.normalized_hessian(u0 = ics_rl,light = lights_sw312)
hessianVal_sjl = sens.normalized_hessian(u0 = ics_rl,light = lights_sjl)
hessianVal_ss = sens.normalized_hessian(u0 = ics_rl,light = lights_ss)
hessianVal_dark = sens.normalized_hessian(u0 = ics_rl,light = lights_dark)
norm_3 = np.zeros(6)
norm_3[0] = np.linalg.norm(hessianVal_rl)
norm_3[1]  = np.linalg.norm(hessianVal_sw)
norm_3[2] = np.linalg.norm(hessianVal_sw312)
norm_3[3] = np.linalg.norm(hessianVal_sjl)
norm_3[4] = np.linalg.norm(hessianVal_ss)
norm_3[5] = np.linalg.norm(hessianVal_dark)

# visualize 
plt.bar(np.arange(6),norm_3)
plt.xlabel('light schedule')
plt.ylabel('frobenius norm of hessian')
plt.xticks(np.arange(6), ('Reg', 'SW', 'SW312', 'SJL', 'SS','Dark'))
plt.title('Regular light schedule ICs')
if figure_save == 'on':
    plt.savefig('figures_8_29/bar_plot_hessian_new_rl.svg')
plt.show()



Generate the initial condition heatmap result.

In [None]:
# Sweep initial conditions over a polar grid (row k -> phase, column j ->
# amplitude) and record the norm of the normalized Hessian for each schedule.
num_iter = 10
norm_rl_adjust = np.zeros([num_iter, num_iter])
norm_sw_adjust = np.zeros([num_iter, num_iter])
norm_sw312_adjust = np.zeros([num_iter, num_iter])
norm_sjl_adjust = np.zeros([num_iter, num_iter])
norm_ss_adjust = np.zeros([num_iter, num_iter])
norm_dark_adjust = np.zeros([num_iter, num_iter])

# Each result grid paired with the light schedule that fills it.
norm_grids_and_lights = (
    (norm_rl_adjust, lights_rl),
    (norm_sw_adjust, lights_sw),
    (norm_sw312_adjust, lights_sw312),
    (norm_sjl_adjust, lights_sjl),
    (norm_ss_adjust, lights_ss),
    (norm_dark_adjust, lights_dark),
)

for k in range(num_iter):
    print(k)  # progress indicator
    phase = k * 2 * np.pi / num_iter  # depends on k only
    for j in range(num_iter):
        amp = j / num_iter
        ics = jnp.array([amp, phase, 0.5])
        for norm_grid, light in norm_grids_and_lights:
            norm_grid[k, j] = np.linalg.norm(
                sens.normalized_hessian(u0=ics, light=light)
            )


# In[ ]:


# generate ics from around the unit circle--look at days/weeks for entrainment 
def weeks_convergence(final_state_diff, convergence_val, ics, lights, params):
    """Count how many passes through `lights` are needed before the state settles.

    Starting from `ics`, the model is advanced with `sens.step_n` over the
    light schedule again and again, each pass starting from the previous
    pass's final state. After every pass the change in component 0 plus the
    absolute change in component 1 (wrapped into [-pi, pi)) is measured.
    Iteration stops once that change is no longer above `convergence_val`,
    or after 50 passes.

    `final_state_diff` is the starting value of the change measure; if it is
    not above `convergence_val` no pass is run and 0 is returned.

    Returns the number of passes performed (50 may mean "did not converge").
    """
    max_passes = 50
    state = ics
    passes = 0
    while passes < max_passes and final_state_diff > convergence_val:
        passes += 1
        # Final state of this pass becomes the next pass's initial condition.
        next_state = sens.step_n(u0=state, light=lights, params=params, dt=0.10)
        amp_change = abs(next_state[0] - state[0])
        phase_change = abs(
            np.mod(next_state[1] - state[1] + np.pi, 2 * np.pi) - np.pi
        )
        final_state_diff = amp_change + phase_change
        state = next_state
    return passes

# Same polar grid of initial conditions as the Hessian-norm sweep
# (row k -> phase, column j -> amplitude); each cell stores the count
# returned by weeks_convergence for that schedule.
num_iter = 10
weeks_rl_adjust = np.zeros([num_iter, num_iter])
weeks_sw_adjust = np.zeros([num_iter, num_iter])
weeks_sw312_adjust = np.zeros([num_iter, num_iter])
weeks_sjl_adjust = np.zeros([num_iter, num_iter])
weeks_ss_adjust = np.zeros([num_iter, num_iter])
weeks_dark_adjust = np.zeros([num_iter, num_iter])
convergence_val = 10 ** (-3)  # stop once the per-pass state change is this small
final_state_diff = 100  # starting value, well above the threshold

# Each result grid paired with the light schedule that fills it.
weeks_grids_and_lights = (
    (weeks_rl_adjust, lights_rl),
    (weeks_sw_adjust, lights_sw),
    (weeks_sw312_adjust, lights_sw312),
    (weeks_sjl_adjust, lights_sjl),
    (weeks_ss_adjust, lights_ss),
    (weeks_dark_adjust, lights_dark),
)

for k in range(num_iter):
    print(k)  # progress indicator
    phase = k * 2 * np.pi / num_iter  # depends on k only
    for j in range(num_iter):
        amp = j / num_iter
        ics = jnp.array([amp, phase, 0.5])
        for weeks_grid, light in weeks_grids_and_lights:
            weeks_grid[k, j] = weeks_convergence(
                final_state_diff, convergence_val, ics, light, params
            )
     


Plot the initial-condition heatmaps: weeks to entrainment and overall sensitivity for each light schedule. 

In [None]:
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
import numpy as np
 
def colormap(z):
    """Draw `z` as a polar heatmap on a new figure.

    `z` is a 2-D array whose rows run along the angular direction and whose
    columns run along the radial direction. Entries above 10**6 are drawn
    as 0. The figure is left current so the caller can add a title, save it
    and show it. Returns None.

    Changes from the earlier version: the input array is no longer modified
    in place, the grid size follows `z.shape` instead of a fixed 10x10, and
    the unused 3-D axes is no longer created.

    NOTE(review): the notebook samples phase at k*2*pi/num_iter and amplitude
    at j/num_iter, while the grid here spans the closed ranges [0, 2*pi] and
    [0, 1] -- confirm this is the intended layout.
    """
    # Work on a copy so the caller's results are not clobbered.
    values = np.array(z, dtype=float)
    values[values > 10**6] = 0

    n_angles, n_radii = values.shape
    radii = np.linspace(0, 1, n_radii)
    angles = np.linspace(0, 2 * np.pi, n_angles)
    r, th = np.meshgrid(radii, angles)

    plt.figure()
    plt.subplot(projection="polar")
    plt.pcolormesh(th, r, values, cmap='Blues')

    plt.plot(angles, r, ls='none', color='k')
    plt.grid()
    plt.colorbar()
    return

# Polar heatmaps of the initial-condition sweeps: (data, title, output file).
# The data are the weeks_*_adjust / norm_*_adjust grids computed in the cells
# above; the earlier names (rl_weeks, rl_norm, ...) were never defined there.
heatmap_figures = (
    # weeks to entrainment
    (weeks_rl_adjust, 'RL Entrainment Weeks', 'figures_8_29/rl_weeks.svg'),
    (weeks_sw_adjust, 'SW Entrainment Weeks', 'figures_8_29/sw_weeks.svg'),
    (weeks_sw312_adjust, 'SW312 Entrainment Weeks', 'figures_8_29/sw312_weeks.svg'),
    (weeks_sjl_adjust, 'SJL Entrainment Weeks', 'figures_8_29/sjl_weeks.svg'),
    (weeks_ss_adjust, 'SS Entrainment Weeks', 'figures_8_29/ss_weeks.svg'),
    (weeks_dark_adjust, 'Dark Entrainment Weeks', 'figures_8_29/dark_weeks.svg'),
    # overall sensitivity
    (norm_rl_adjust, 'RL Overall Sensitivity', 'figures_8_29/rl_norm.svg'),
    (norm_sw_adjust, 'SW Overall Sensitivity', 'figures_8_29/sw_norm.svg'),
    (norm_sw312_adjust, 'SW312 Overall Sensitivity', 'figures_8_29/sw312_norm.svg'),
    (norm_sjl_adjust, 'SJL Overall Sensitivity', 'figures_8_29/sjl_norm.svg'),
    (norm_ss_adjust, 'SS Overall Sensitivity', 'figures_8_29/ss_norm.svg'),
    (norm_dark_adjust, 'Dark Overall Sensitivity', 'figures_8_29/dark_norm.svg'),
)

for grid, title, save_path in heatmap_figures:
    colormap(grid)
    plt.title(title)
    # Every figure honours the figure_save switch (the SW312 weeks figure
    # used to be saved unconditionally).
    if figure_save == 'on':
        plt.savefig(save_path)
    plt.show()

Generate the varying light intensity overall sensitivity results. 

In [None]:
# Sweep the light intensity and record the norm of the normalized Hessian for
# each schedule (rows of norm_2) at each intensity (columns of norm_2).
# ics_individual_schedules, final_state_diff and convergence_val must already
# be defined in the session; ics_individual_schedules is not defined in this
# part of the notebook.
num_iter = 20
norm_2 = np.zeros([6, num_iter])

schedule_functions = (RegularLight, ShiftWorkLight, ShiftWorkerThreeTwelves,
                      SocialJetLag, SlamShift)

for j in range(num_iter):

    print(j)  # progress indicator
    intensity = 50 * j

    # define the light schedules at this intensity
    lights_rl, lights_sw, lights_sw312, lights_sjl, lights_ss = (
        jnp.array([schedule(t, Intensity=intensity) for t in ts])
        for schedule in schedule_functions
    )
    # Here the "dark" schedule is a constant light level equal to the intensity.
    lights_dark = jnp.zeros([len(ts),]) + intensity
    sweep_lights = (lights_rl, lights_sw, lights_sw312,
                    lights_sjl, lights_ss, lights_dark)

    # Per-schedule initial conditions, all derived from the same starting state.
    ics = jnp.array([0.70, 0.0, 0.0])
    ics_rl, ics_sw, ics_sw312, ics_sjl, ics_ss, ics_dark = (
        ics_individual_schedules(final_state_diff, convergence_val, ics, light, params)
        for light in sweep_lights
    )
    sweep_ics = (ics_rl, ics_sw, ics_sw312, ics_sjl, ics_ss, ics_dark)

    norm_2[:, j] = [
        np.linalg.norm(sens.normalized_hessian(u0=u0, light=light))
        for u0, light in zip(sweep_ics, sweep_lights)
    ]


Plot the Frobenius norm of the Hessian for each light schedule as a function of light intensity. 

In [None]:
# One scatter panel per light schedule: Hessian norm against light intensity.
fig, axes = plt.subplots(2, 3, figsize=(17, 9.5))
xvals = np.linspace(0, 950, 20)  # intensities 0, 50, ..., 950

# Row i of norm_2 belongs to the i-th title; axes.flat walks the panels
# left to right, top row first.
panel_titles = ('RL', 'SW', 'SW312', 'SJL', 'SS', 'Dark')
for panel, schedule_norms, panel_title in zip(axes.flat, norm_2, panel_titles):
    panel.scatter(xvals, schedule_norms)
    panel.set_xlabel('intensity')
    panel.set_ylabel('frobenius norm')
    panel.set_title(panel_title)

if figure_save == 'on':
    plt.savefig('figures_8_29/intensity_all.svg')