In [3]:
from copy import deepcopy
import numpy as np
import matplotlib.pyplot as plt
from jax import grad, jit, jacfwd, jacrev
import jax.numpy as jnp

# import the needed modules 
from hessian_normalized import HessianCircadian
from hessian_normalized import Actogram
from hessian_normalized import ParameterRecovery
from lightschedules import RegularLight
from lightschedules import ShiftWorkLight
from lightschedules import ShiftWorkerThreeTwelves
from lightschedules import SocialJetLag
from lightschedules import SlamShift

# Circadian model instance and its parameter vector; both are reused by later cells.
sens = HessianCircadian()
params = sens.get_parameters_array()

# Time grid and light level shared by every schedule.
# NOTE(review): `24 * ndays` suggests the grid is in hours and 979 looks like a
# lux level -- confirm the units expected by the lightschedules functions.
ndays = 7
intensity = 979
ts = np.arange(0, 24 * ndays, 0.1)


def _sample_schedule(schedule_fn):
    """Evaluate ``schedule_fn(t, Intensity=intensity)`` at every ``t`` in ``ts``; return a JAX array."""
    return jnp.array([schedule_fn(t, Intensity=intensity) for t in ts])


lights_rl = _sample_schedule(RegularLight)
lights_sw = _sample_schedule(ShiftWorkLight)
lights_sw312 = _sample_schedule(ShiftWorkerThreeTwelves)
lights_sjl = _sample_schedule(SocialJetLag)
lights_ss = _sample_schedule(SlamShift)
# Constant darkness: zero light at every time point of the grid.
lights_dark = jnp.zeros(len(ts))

In [4]:
def ics_individual_schedules(final_state_diff, convergence_val, ics, lights, params,
                             max_iter=50, dt=0.10, model=None):
    """Find a repeating initial condition by re-simulating one light schedule.

    The model is stepped over the whole ``lights`` array starting from ``ics``,
    and the final state becomes the next starting state. Iteration stops when
    the change between successive starting states is ``<= convergence_val``, or
    after ``max_iter`` passes.

    Convergence metric::

        |new[0] - old[0]| + |wrap(new[1] - old[1])|

    Here ``wrap`` maps an angle difference into [-pi, pi). Component 1 is
    treated as an angle; component 0 is compared directly.
    NOTE(review): any further state components (e.g. index 2) are NOT part of
    the metric -- confirm this is intentional.

    Parameters
    ----------
    final_state_diff : float
        Seed value for the convergence metric. It only primes the loop: pass
        something larger than ``convergence_val`` (the notebook uses 100).
        Otherwise no simulation runs and ``ics`` is returned unchanged.
    convergence_val : float
        Tolerance; iteration stops once the metric is ``<= convergence_val``.
    ics : array-like
        Starting state, indexable at 0 and 1. It is passed as ``u0`` to
        ``model.step_n``.
    lights : array-like
        Light input, passed unchanged to ``model.step_n``.
    params : array-like
        Model parameters, passed unchanged to ``model.step_n``.
    max_iter : int, optional
        Maximum number of simulation passes (default 50, the previous
        hard-coded cap).
    dt : float, optional
        Step size passed to ``model.step_n`` (default 0.10, as before).
        NOTE(review): presumably this must match the spacing of the time grid
        used to build ``lights``. Confirm before changing either.
    model : object, optional
        Any object exposing ``step_n(u0=..., light=..., params=..., dt=...)``.
        Defaults to the notebook-global ``sens``.

    Returns
    -------
    The last state returned by ``model.step_n``, or ``ics`` if no pass ran.
    The phase component is returned as produced by the model; it is not
    wrapped here.

    Warns
    -----
    RuntimeWarning
        If the loop stops at ``max_iter`` without reaching ``convergence_val``.
    """
    import warnings

    if model is None:
        # Fall back to the HessianCircadian instance created in the first cell.
        model = sens

    u0 = ics
    count = 0
    while final_state_diff > convergence_val and count < max_iter:
        count += 1
        # Simulate the whole schedule; its end state seeds the next pass.
        statesfinal = model.step_n(u0=u0, light=lights, params=params, dt=dt)
        # Wrap the phase difference so a full 2*pi turn counts as no change.
        phase_step = np.mod(statesfinal[1] - u0[1] + np.pi, 2 * np.pi) - np.pi
        final_state_diff = abs(statesfinal[0] - u0[0]) + abs(phase_step)
        u0 = statesfinal

    if final_state_diff > convergence_val:
        # Previously this case returned silently; surface it so callers notice.
        warnings.warn(
            f"ics_individual_schedules: no convergence after {count} iterations "
            f"(last diff {float(final_state_diff):.3g} > {convergence_val})",
            RuntimeWarning,
            stacklevel=2,
        )
    return u0

# Seed values for the fixed-point search: a tolerance, an initial "difference"
# large enough to force at least one simulation pass, and a starting state
# shared by every schedule.
convergence_val = 10 ** -3
final_state_diff = 100
ics = jnp.array([0.70, 0.0, 0.0])

# Converge one initial condition per light schedule, in the order listed.
# TODO: look into mod 2pi phase -- the returned state's phase component is not wrapped.
ics_rl, ics_sw, ics_sw312, ics_sjl, ics_ss, ics_dark = (
    ics_individual_schedules(final_state_diff, convergence_val, ics, schedule, params)
    for schedule in (lights_rl, lights_sw, lights_sw312, lights_sjl, lights_ss, lights_dark)
)

In [5]:
# Persist the converged initial conditions with IPython's %store magic so that
# other notebooks/kernels can reload them with `%store -r <name>`.
%store ics_rl
%store ics_sw
%store ics_sw312
%store ics_sjl
%store ics_ss
%store ics_dark

Stored 'ics_rl' (DeviceArray)
Stored 'ics_sw' (DeviceArray)
Stored 'ics_sw312' (DeviceArray)
Stored 'ics_sjl' (DeviceArray)
Stored 'ics_ss' (DeviceArray)
Stored 'ics_dark' (DeviceArray)
