## This notebook runs all of the analyses that generate the main-text sensitivity-analysis figures
The cells below walk through the code that generates each of the five main-text figures.

We first generate Figure 1, the schematic (methods) figure.

In [None]:
# Imports: standard library, third-party packages, then the project's model and light-schedule helpers
import os
from copy import deepcopy

import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from jax import grad, jit, jacfwd, jacrev

from hessian_normalized import HessianCircadian
from hessian_normalized import Actogram
from hessian_normalized import ParameterRecovery
from lightschedules import RegularLight
from lightschedules import ShiftWorkLight
from lightschedules import ShiftWorkerThreeTwelves
from lightschedules import SocialJetLag
from lightschedules import SlamShift

sens=HessianCircadian()
params = sens.get_parameters_array()

# define the light schedules 
ndays = 7
intensity = 979
dt = 0.1
ts=np.arange(0, 24*ndays, dt)
lights_rl = jnp.array([RegularLight(t, Intensity = intensity) for t in ts ]) # define the light schedules 
lights_sw = jnp.array([ShiftWorkLight(t,Intensity = intensity) for t in ts ])
lights_sw312 = jnp.array([ShiftWorkerThreeTwelves(t,Intensity = intensity) for t in ts ])
lights_sjl = jnp.array([SocialJetLag(t,Intensity = intensity) for t in ts ])
lights_ss = jnp.array([SlamShift(t,Intensity = intensity) for t in ts ])
lights_dark = jnp.zeros([len(ts),])

# here we generate the actogram plot: change light_schedule to the other light schedules to see all the plots
light_schedule = lights_rl
Actogram(ts, light_schedule)

# plot of default parameters
plt.bar(range(len(params)),np.log(abs(params)))#np.log(params))
plt.axis([-1, 14, -4, 8])
plt.xlabel('parameter index')
plt.ylabel('log(abs(parameters))')
plt.title('Log of Parameter Magnitude, Default')
plt.show()

# generate example perturbed parameters
rand_percent_vec = 0.5
k = 12
params_perturb = params
log_params_perturbed = np.log(abs(params_perturb))
log_params_perturbed[k] = rand_percent_vec*log_params_perturbed[k]
plt.bar(range(len(params)),log_params_perturbed)
plt.axis([-1, 14, -4, 8])
plt.xlabel('parameter index')
plt.ylabel('log(abs(parameters))')
plt.title('Log of Parameter Magnitude, Perturbed')
plt.show()

# get stored initial conditions, generated in ics_generate.ipynb
%store -r ics_rl
%store -r ics_sw
%store -r ics_sw312
%store -r ics_sjl
%store -r ics_ss
%store -r ics_dark

num_iter = 1

for j in range(len(params)):
    
    for m in range(num_iter):
        
        # perturb the parameters
        rand_percent_vec = 0.8+m*0.05
        params_perturb = params
        params_perturb = params_perturb.at[j].set(params[j]*rand_percent_vec)
        
        # generate perturbed model states 
        model_states_all_new_rl = sens.step_all_n(ics_rl, lights_rl, params_perturb, 0.10)
        plt.plot(ts/24,model_states_all_new_rl[:,0]*np.cos(model_states_all_new_rl[:,1]),'k',label = 'RL', alpha = 0.2)
        
        
       
        
model_states_all_rl = sens.step_all_n(ics_rl, lights_rl, params, 0.10)
plt.plot(ts/24,model_states_all_rl[:,0]*np.cos(model_states_all_new_rl[:,1]),'b',label = 'RL')
plt.xlabel('day')
plt.ylabel('Rcos(psi)')
plt.title('Model states, default and perturbed parameters')
plt.show()

Next, we generate Figure 2, a Hessian-based sensitivity analysis.

In [None]:
# get stored initial conditions, generated in ics_generate.ipynb
%store -r ics_rl
%store -r ics_sw
%store -r ics_sw312
%store -r ics_sjl
%store -r ics_ss
%store -r ics_dark

# generate the new sensitivity hessian results 
normalized_hessian_rl = sens.normalized_hessian(u0 = ics_rl,light = lights_rl)
normalized_hessian_sw = sens.normalized_hessian(u0 = ics_sw,light = lights_sw)
normalized_hessian_sw312 = sens.normalized_hessian(u0 = ics_sw312,light = lights_sw312)
normalized_hessian_sjl = sens.normalized_hessian(u0 = ics_sjl,light = lights_sjl)
normalized_hessian_ss = sens.normalized_hessian(u0 = ics_ss,light = lights_ss)
normalized_hessian_dark = sens.normalized_hessian(u0 = ics_dark,light = lights_dark)

# generate figures for the hessian results--overall sensitivity figure 
norm_2 = np.zeros(6)
norm_2[0] = np.linalg.norm(normalized_hessian_rl)
norm_2[1]  = np.linalg.norm(normalized_hessian_sw)
norm_2[2] = np.linalg.norm(normalized_hessian_sw312)
norm_2[3] = np.linalg.norm(normalized_hessian_sjl)
norm_2[4] = np.linalg.norm(normalized_hessian_ss)
norm_2[5] = np.linalg.norm(normalized_hessian_dark)

# visualize 
plt.bar(np.arange(6),norm_2)
plt.xlabel('light schedule')
plt.ylabel('frobenius norm of hessian')
plt.xticks(np.arange(6), ('Reg', 'SW', 'SW312', 'SJL', 'SS','Dark'))
plt.title('Frobenius norm of the hessian')
plt.show()

Next, we generate the components of Figure 3, a one-at-a-time sensitivity analysis.

In [None]:
# Load the MCMC posterior samples: tab-separated, no header row, one column per entry of mcmc_param_list
import pandas as pd

mcmc_runs = pd.read_csv('mcmc_new_run.tsv', header=None, delimiter='\t').to_numpy()

# column order of mcmc_runs
mcmc_param_list = [
    "tau", "K", "gamma", "A1", "A2", "BetaL1", "BetaL2", "sigma",
    "p", "I0", "alpha0", "Beta1", "ent_angle", "cost", "dd_period",
]
# name of params[j] for each index j
# NOTE(review): the model entry is "alpha_0" but the MCMC column is "alpha0", so alpha_0 never
# matches by name and receives the multiplicative perturbation downstream — confirm this is intended.
model_param_list = [
    "tau", "K", "gamma", "Beta1", "A1", "A2", "BetaL1", "BetaL2",
    "sigma", "G", "alpha_0", "delta", "p", "I0",
]

# 2 x 3 grid of panels, one per light schedule
fig, axes = plt.subplots(2, 3, figsize=(18, 9.5))
num_iter = 10  # perturbed draws per parameter

# Default-parameter trajectories, one per schedule (RL, SW, SW312, SJL, SS, Dark)
(model_states_all_rl_def,
 model_states_all_sw_def,
 model_states_all_sw312_def,
 model_states_all_sjl_def,
 model_states_all_ss_def,
 model_states_all_dark_def) = [
    sens.step_all_n(u0, light, params, dt)
    for u0, light in (
        (ics_rl, lights_rl),
        (ics_sw, lights_sw),
        (ics_sw312, lights_sw312),
        (ics_sjl, lights_sjl),
        (ics_ss, lights_ss),
        (ics_dark, lights_dark),
    )
]

# Distance from the default trajectory for every (parameter, draw) pair; filled in by the loop that follows
(norm_diff_rl, norm_diff_sw, norm_diff_sw312,
 norm_diff_sjl, norm_diff_ss, norm_diff_dark) = (
    np.zeros([len(params), num_iter]) for _ in range(6)
)
# Mean squared distance between two trajectories in the (R cos psi, R sin psi) plane
def norm_diff_states(model_states_default, model_states_perturbed):
    """Mean squared Euclidean distance between two model trajectories after a polar-to-Cartesian map.

    Column 0 of each state array is used as the radius and column 1 as the angle. Each time point
    becomes (r*cos(angle), r*sin(angle)), and the squared point-wise distances are averaged.

    Parameters
    ----------
    model_states_default : array, shape (T, >=2)
        Reference trajectory.
    model_states_perturbed : array, shape (T, >=2)
        Trajectory to compare; assumed to have the same number of rows as the reference.

    Returns
    -------
    Scalar mean of the squared distances (the result of np.mean on the jax array).
    """
    def _cartesian(states):
        # polar (radius, angle) columns -> Cartesian coordinates
        radius = states[:, 0]
        angle = states[:, 1]
        return radius * jnp.cos(angle), radius * jnp.sin(angle)

    x_default, y_default = _cartesian(model_states_default)
    x_perturbed, y_perturbed = _cartesian(model_states_perturbed)
    squared_distance = (x_perturbed - x_default) ** 2 + (y_perturbed - y_default) ** 2
    return np.mean(squared_distance)

# One-at-a-time perturbation.
# For every parameter j, draw num_iter perturbed values and simulate all six light schedules.
# Perturbed Rcos(psi) traces are overlaid on each schedule's panel, and the distance from the
# default trajectory is stored in norm_diff_*[j, m].
schedule_labels = ['RL', 'SW', 'SW312', 'SJL', 'SS', 'Dark']
# (panel, initial condition, light schedule, default trajectory, result array), in schedule_labels order;
# the result arrays are the norm_diff_* globals and are filled in place
schedule_runs = [
    (axes[0][0], ics_rl, lights_rl, model_states_all_rl_def, norm_diff_rl),
    (axes[0][1], ics_sw, lights_sw, model_states_all_sw_def, norm_diff_sw),
    (axes[0][2], ics_sw312, lights_sw312, model_states_all_sw312_def, norm_diff_sw312),
    (axes[1][0], ics_sjl, lights_sjl, model_states_all_sjl_def, norm_diff_sjl),
    (axes[1][1], ics_ss, lights_ss, model_states_all_ss_def, norm_diff_ss),
    (axes[1][2], ics_dark, lights_dark, model_states_all_dark_def, norm_diff_dark),
]
num_schedules = len(schedule_runs)

for j in range(len(params)):
    print(j)
    param_name = model_param_list[j]
    use_posterior = param_name in mcmc_param_list
    if use_posterior:
        # MCMC-estimated parameter: use posterior quantiles, shifted so that the posterior
        # median coincides with the model's default value (loop-invariant, so computed once per j)
        posterior = mcmc_runs[:, mcmc_param_list.index(param_name)]
        shift = np.quantile(posterior, 0.50) - params[j]

    for m in range(num_iter):
        if use_posterior:
            new_value = np.quantile(posterior, (m + 0.5) / num_iter) - shift
        else:
            # no posterior column: scale the default by 0.96, 0.97, ... for m = 0, 1, ...
            new_value = params[j] * (0.96 + m * 0.01)
        params_perturb = params.at[j].set(new_value)

        for panel, u0, light, states_default, norm_diff in schedule_runs:
            states_new = sens.step_all_n(u0, light, params_perturb, dt)
            panel.plot(ts / 24, states_new[:, 0] * np.cos(states_new[:, 1]), 'k', alpha=0.2)
            norm_diff[j, m] = norm_diff_states(states_default, states_new)

# overlay the default-parameter trajectories and label every panel
for (panel, _, _, states_default, _), label in zip(schedule_runs, schedule_labels):
    panel.plot(ts / 24, states_default[:, 0] * np.cos(states_default[:, 1]), 'b', label=label)
    panel.set_title(label)
    panel.set_xlabel('day')
    panel.set_ylabel('Rcos(psi)')
plt.show()

norm_diff_all = [norm_diff for *_, norm_diff in schedule_runs]
norm_diff_mean = [np.mean(norm_diff) for norm_diff in norm_diff_all]

print('Mean norm difference for the six light schedules is given by')
print(norm_diff_mean)

plt.bar(range(num_schedules), norm_diff_mean)
plt.xlabel('light schedule')
plt.ylabel('mean norm difference')
plt.xticks(range(num_schedules), schedule_labels)
plt.title('Mean norm difference by light schedule')
plt.show()

# per-schedule (rows) x per-parameter (columns) mean distance over the num_iter draws
labels = model_param_list
param_sens2 = np.array([np.mean(norm_diff, axis=1) for norm_diff in norm_diff_all], dtype=float)
plt.figure(figsize=(15, 8))
# pass tick labels to heatmap directly: set_xticklabels after seaborn's 'auto' tick thinning can misalign
ax = sns.heatmap(param_sens2, annot=True, xticklabels=labels, yticklabels=schedule_labels)
plt.title('Parameter Perturbations')
plt.show()

We generate the real-data examples (formerly Figure 4) by applying the Hessian-based sensitivity analysis to individual subjects' step-count data.

In [None]:
# function for generating initial conditions
def ics_individual_schedules(final_state_diff, convergence_val, ics, lights, params,
                             dt=0.10, max_iter=50, model=None):
    """Repeat one light schedule until the model's end state stops changing, and return that state.

    Starting from `ics`, the model is simulated over `lights`. Its final state becomes the starting
    state of the next pass. Passes stop when the change between consecutive states is at most
    `convergence_val`. The change is |d component 0| + |d component 1|, where the component-1
    (phase) difference is wrapped into [-pi, pi). Passes also stop after `max_iter` passes, in
    which case a RuntimeWarning is issued.

    Parameters
    ----------
    final_state_diff : float
        Seed value for the change measure; no pass runs unless it exceeds `convergence_val`.
    convergence_val : float
        Convergence threshold on the change measure.
    ics : array
        Starting state; components 0 and 1 enter the change measure.
    lights : array
        Light schedule passed to `model.step_n`.
    params : array
        Model parameters passed to `model.step_n`.
    dt : float, optional
        Step size passed to `model.step_n` (default 0.10).
    max_iter : int, optional
        Maximum number of passes (default 50).
    model : object, optional
        Object with `step_n(u0=..., light=..., params=..., dt=...)`; defaults to the notebook's `sens`.

    Returns
    -------
    The state after the last pass, or `ics` unchanged if no pass ran.
    """
    import warnings

    if model is None:
        model = sens

    u0 = ics
    count = 0
    while final_state_diff > convergence_val and count < max_iter:
        count += 1
        # simulate one pass of the schedule; its final state seeds the next pass
        statesfinal = model.step_n(u0=u0, light=lights, params=params, dt=dt)
        # component-0 change plus component-1 change wrapped to [-pi, pi), so a full 2*pi turn counts as 0
        final_state_diff = (abs(statesfinal[0] - u0[0])
                            + abs(np.mod(statesfinal[1] - u0[1] + np.pi, 2 * np.pi) - np.pi))
        u0 = statesfinal

    if final_state_diff > convergence_val:
        warnings.warn(
            "ics_individual_schedules: not converged after %d passes (last change %.3g > %.3g)"
            % (count, float(final_state_diff), convergence_val),
            RuntimeWarning,
        )
    return u0

# generate the initial conditions for the individual schedules 
convergence_val = 10**(-3)
final_state_diff = 100

# filepath where real data is stored 
filepath = "/Users/calebmayer/Documents/MATLAB/Examples/matlab/ArrayIndexingGSExample/covid/data_s"

# define parameters 
count = 0
num_subjects = 10
start_pt = 1
num_days = 7
convergence_val = 10**(-3)
final_state_diff = 100
ics = jnp.array([0.70,0.0,0.0])
activity_scaling_factor = 10
dt = 0.1

# initialize record/list of hessian matrices 
hessianVal_rl = []

# iterate through the selected files 
for j in range(start_pt, start_pt + num_subjects):
    
    filename = os.listdir(filepath)[j] # num_subjects files for now--take the steps ones 
    print(j)
    
    if 'steps' in filename:

        data = pd.read_csv(filepath + '/' + filename)
        data_np = pd.DataFrame.to_numpy(data)
        
        # convert time to hours 
        data_np[:,0] = data_np[:,0] - data_np[0,0]
        data_np[:,0] = data_np[:,0]*24

        ts = data_np[:,0]
        lights = data_np[:,1]
        lights = lights[ts < num_days*24]
        ts = ts[ts < num_days*24]
        
        # resample the data 
        ts_new = np.arange(ts[0], ts[-1], dt)
        lights_new = np.zeros([len(ts_new)])
        
        for m in range(len(ts_new)-1):
            
            index1 = ts > ts_new[m]
            index2 = ts <= ts_new[m+1]
            index = np.logical_and(index1,index2) 
            light_vals = lights[index]
            lights_new[m] = np.sum(light_vals)*activity_scaling_factor # adjust this as needed
            
        # generate new initial conditions for each individual 
        ics_looped = ics_individual_schedules(final_state_diff, convergence_val, ics, lights_new, params)
        
        # apply the hessian method to the rescaled activity data and new ics 
        hessianVal_rl.append(sens.normalized_hessian(u0 = ics_looped, light=lights_new)) # use ending value as new initial condition 
        count = count+1

# choose the example subjects 
ind1 = 0
ind2 = 1
ind3 = 2

# compute the frobenius norm of the hessian, as an overall sensitivity metric 
print('Overall sensitivity of the example individuals are given by')
print(np.linalg.norm(hessianVal_rl[ind1]))
print(np.linalg.norm(hessianVal_rl[ind2]))
print(np.linalg.norm(hessianVal_rl[ind3]))

# find the principal eigenvector for the three example individuals 
evals, evecs = np.linalg.eig(hessianVal_rl[ind1])
evec1 = abs(evecs[:,np.argmax(evals)])
plt.bar(range(len(evecs)),abs(evecs[:,np.argmax(evals)]))
plt.title('Mishra et al., subject %i' % ind1)
plt.ylabel('magnitude')
plt.xlabel('component')
#plt.savefig('figures_8_29/real_data_ex0.svg')
plt.show()

evals, evecs = np.linalg.eig(hessianVal_rl[ind2])
evec2 = abs(evecs[:,np.argmax(evals)])
plt.bar(range(len(evecs)),abs(evecs[:,np.argmax(evals)]))
plt.title('Mishra et al., subject %i' % ind2)
plt.ylabel('magnitude')
plt.xlabel('component')
#plt.savefig('figures_8_29/real_data_ex1.svg')
plt.show()

evals, evecs = np.linalg.eig(hessianVal_rl[ind3])
evec3 = abs(evecs[:,np.argmax(evals)])
plt.bar(range(len(evecs)),abs(evecs[:,np.argmax(evals)]))
plt.title('Mishra et al., subject %i' % ind3)
plt.ylabel('magnitude')
plt.xlabel('component')
#plt.savefig('figures_8_29/real_data_ex2.svg')
plt.show()

# read in the stored eigenvalues, from clean_figure2.ipynb
%store -r evecs_rl
%store -r evecs_sw
%store -r evecs_sw312
%store -r evecs_sjl
%store -r evecs_ss
%store -r evecs_dark

num_schedules = 6
num_examples = 3

# compute the dot product between principal eigenvectors of real and synthetic schedules 
dot_mat_subject_all = np.zeros([num_examples,num_schedules])
dot_mat_subject_all[0,0] = np.dot(np.abs(evec1),np.abs(evecs_rl))
dot_mat_subject_all[0,1] = np.dot(np.abs(evec1),np.abs(evecs_sw))
dot_mat_subject_all[0,2] = np.dot(np.abs(evec1),np.abs(evecs_sw312))
dot_mat_subject_all[0,3] = np.dot(np.abs(evec1),np.abs(evecs_sjl))
dot_mat_subject_all[0,4] = np.dot(np.abs(evec1),np.abs(evecs_ss))
dot_mat_subject_all[0,5] = np.dot(np.abs(evec1),np.abs(evecs_dark))

dot_mat_subject_all[1,0] = np.dot(np.abs(evec2),np.abs(evecs_rl))
dot_mat_subject_all[1,1] = np.dot(np.abs(evec2),np.abs(evecs_sw))
dot_mat_subject_all[1,2] = np.dot(np.abs(evec2),np.abs(evecs_sw312))
dot_mat_subject_all[1,3] = np.dot(np.abs(evec2),np.abs(evecs_sjl))
dot_mat_subject_all[1,4] = np.dot(np.abs(evec2),np.abs(evecs_ss))
dot_mat_subject_all[1,5] = np.dot(np.abs(evec2),np.abs(evecs_dark))

dot_mat_subject_all[2,0] = np.dot(np.abs(evec3),np.abs(evecs_rl))
dot_mat_subject_all[2,1] = np.dot(np.abs(evec3),np.abs(evecs_sw))
dot_mat_subject_all[2,2] = np.dot(np.abs(evec3),np.abs(evecs_sw312))
dot_mat_subject_all[2,3] = np.dot(np.abs(evec3),np.abs(evecs_sjl))
dot_mat_subject_all[2,4] = np.dot(np.abs(evec3),np.abs(evecs_ss))
dot_mat_subject_all[2,5] = np.dot(np.abs(evec3),np.abs(evecs_dark))

# generate the normalized heatmap figure 
dot_mat_subject_all_norm = (dot_mat_subject_all - np.mean(dot_mat_subject_all))/np.std(dot_mat_subject_all)
ax = sns.heatmap(dot_mat_subject_all_norm, annot = True)
ax.set_xticklabels(['RL','SW','SW312','SJL','SS','Dark'])
ax.set_yticklabels(['Ex. 1', 'Ex. 2', 'Ex. 3'])

Finally, we compute Figure 5, a parameter recovery analysis.

In [None]:
# Parameter recovery.
# recov.noisy_light(mu, sigma, lights, ts) produces a reference trajectory per schedule (presumably
# the model driven by light corrupted with noise of mean mu and scale sigma). Each parameter is then
# perturbed one at a time, and recov.loss_recovery compares the perturbed run against the reference.
# NOTE(review): assumes ParameterRecovery needs no constructor arguments (like HessianCircadian) — confirm.
recov = ParameterRecovery()

# Rebuild the simulation grid the lights_* schedules were sampled on, so this cell does not depend
# on whatever `ts` an earlier cell left in the kernel
ts = np.arange(0, 24 * ndays, dt)

# parameters to define noisy light
# NOTE(review): if noisy_light draws random numbers, seed its RNG here to make this figure reproducible
mu = 0
sigma = 100

# call order kept fixed in case noisy_light draws from a shared RNG
states_rl_noisy_light = recov.noisy_light(mu, sigma, lights_rl, ts)
states_sw_noisy_light = recov.noisy_light(mu, sigma, lights_sw, ts)
states_sw312_noisy_light = recov.noisy_light(mu, sigma, lights_sw312, ts)
states_ss_noisy_light = recov.noisy_light(mu, sigma, lights_ss, ts)
states_sjl_noisy_light = recov.noisy_light(mu, sigma, lights_sjl, ts)
states_dark_noisy_light = recov.noisy_light(mu, sigma, lights_dark, ts)

num_iter = 20

# recovery loss for every (draw m, parameter j)
state_diff_rl = np.zeros([num_iter, len(params)])
state_diff_sw = np.zeros([num_iter, len(params)])
state_diff_sw312 = np.zeros([num_iter, len(params)])
state_diff_ss = np.zeros([num_iter, len(params)])
state_diff_sjl = np.zeros([num_iter, len(params)])
state_diff_dark = np.zeros([num_iter, len(params)])

# (initial condition, light schedule, noisy reference states, result array filled in place)
recovery_runs = [
    (ics_rl, lights_rl, states_rl_noisy_light, state_diff_rl),
    (ics_sw, lights_sw, states_sw_noisy_light, state_diff_sw),
    (ics_sw312, lights_sw312, states_sw312_noisy_light, state_diff_sw312),
    (ics_ss, lights_ss, states_ss_noisy_light, state_diff_ss),
    (ics_sjl, lights_sjl, states_sjl_noisy_light, state_diff_sjl),
    (ics_dark, lights_dark, states_dark_noisy_light, state_diff_dark),
]

# loop through the parameters and iterations
for j in range(len(params)):
    print(j)
    param_name = model_param_list[j]
    use_posterior = param_name in mcmc_param_list
    if use_posterior:
        # posterior quantiles shifted so the posterior median sits on the model's default value
        posterior = mcmc_runs[:, mcmc_param_list.index(param_name)]
        shift = np.quantile(posterior, 0.50) - params[j]

    for m in range(num_iter):
        if use_posterior:
            # quantile levels (m + 1) / num_iter, matching x_vals below
            new_value = np.quantile(posterior, (m + 1) / num_iter) - shift
        else:
            # scale the default by 0.955, 0.960, ... for m = 0, 1, ...
            new_value = params[j] * (0.96 + (m - 1) * 0.005)
        params_perturb = params.at[j].set(new_value)

        for u0, light, states_noisy, state_diff in recovery_runs:
            states_params = recov.perturbed_params(u0, params_perturb, light, ts)
            state_diff[m, j] = recov.loss_recovery(states_noisy, states_params, ts)

# display order of the schedules (rows of the summary plots)
schedule_labels = ['RL', 'SW', 'SW312', 'SJL', 'SS', 'Dark']
state_diff_by_schedule = [state_diff_rl, state_diff_sw, state_diff_sw312,
                          state_diff_sjl, state_diff_ss, state_diff_dark]
num_schedules = len(state_diff_by_schedule)

# for each parameter (column), the draw index m with the lowest recovery loss, ignoring NaN losses
heatmap_v1 = np.array([np.nanargmin(state_diff, axis=0) for state_diff in state_diff_by_schedule],
                      dtype=float)

# new figure, and show it before the line plots so they are not drawn onto the heatmap axes
plt.figure()
sns.heatmap(heatmap_v1, xticklabels=model_param_list, yticklabels=schedule_labels)
plt.title('Draw index with the lowest recovery loss')
plt.show()

# choose the three parameters to examine
index1 = 0  # tau
index2 = 1  # K
index3 = 10  # alpha_0

# generate the example figures; x_vals[m] = (m + 1) / num_iter
x_vals = np.linspace(1 / num_iter, 1, num_iter)
for state_diff, schedule, index in [(state_diff_rl, 'RL', index1),
                                    (state_diff_sw, 'SW', index2),
                                    (state_diff_sw312, 'SW312', index3)]:
    plt.plot(x_vals, state_diff[:, index])
    plt.xlabel('(m + 1) / num_iter')
    plt.ylabel('recovery loss')
    plt.title('%s schedule, parameter %s' % (schedule, model_param_list[index]))
    plt.show()

# average recovery loss per parameter, first over draws (per schedule) and then over schedules
mean_param = np.array([np.nanmean(state_diff, axis=0) for state_diff in state_diff_by_schedule])
plt.bar(range(len(params)), np.nanmean(mean_param, axis=0))
plt.xticks(range(len(params)), model_param_list)
plt.ylabel('mean recovery loss')
plt.title('Mean recovery loss per parameter')
plt.show()