In [None]:
import os

import jax
import jax.numpy as jnp

jax.config.update("jax_platform_name", "cpu")
jax.config.update("jax_enable_x64", True)

import diffrax
import matplotlib.pyplot as plt
import scipy.io

import optimal_control.constraints as constraints
import optimal_control.controls as controls
import optimal_control.environments.examples as examples

In [None]:
# Directory holding the MATLAB data files. The default is the location this
# notebook was originally written against; set OPTIMAL_CONTROL_DATA_DIR to run
# it on another machine without editing the cell.
DATA_DIR = os.environ.get(
    "OPTIMAL_CONTROL_DATA_DIR",
    "/home/lena/master-thesis/repos/optimal-control/data",
)

# scipy.io.loadmat returns a dict keyed by MATLAB variable name.
rep = scipy.io.loadmat(
    os.path.join(
        DATA_DIR,
        "Rep_PLE_1_2_3_4_5_6_7_Nf50_mv0_1_0_1_0_1_1_0_1_ct0_I12_s32_Ch1498.mat",
    )
)
data = scipy.io.loadmat(os.path.join(DATA_DIR, "Repository_data_210919.mat"))

In [None]:
# Print the position and array shape of every element in
# data["couples"][0, -1], to find which index holds which quantity.
for i, d in enumerate(data["couples"][0, -1]):
    print(f"{i} {d.shape}")

In [None]:
# Positions used below:
#   x0_idx / k_idx           -> elements of the selected "couples" entry
#   h_sg_idx / k_sg_tha_idx  -> entries of the flattened parameter vector k
#   p_eif2a_idx              -> state component later passed to f_sg
x0_idx, k_idx = 11, 9
h_sg_idx, k_sg_tha_idx = 4, 5
p_eif2a_idx = 1

couple = data["couples"][0, -1]

# Flatten to 1-D copies so individual entries can be read with one index.
# x0 is later used as the initial state and k as the ODE parameter vector.
x0, k = (couple[idx].flatten() for idx in (x0_idx, k_idx))

h_sg, k_sg_tha = k[h_sg_idx], k[k_sg_tha_idx]


def det_ThaKin_ld_C1_G0_1K_wP_kd_wp_67BF35A2(t, x, args):
    """Right-hand side of the 10-state ODE system.

    Args:
        t: Current time; only used to evaluate the control.
        x: State, indexed 0..9. Species names are given in the inline
            comments below.
        args: Pair ``(k, u)``. ``k`` is the parameter vector (entries 0..13
            are read); ``u`` is a callable ``t -> input`` of which element 0
            is used.

    Returns:
        The ten time derivatives stacked along the last axis via ``jnp.stack``.
    """
    params, control = args
    stim = control(t)[0]

    # Tr_inh: Hill-type function of p_eIF2a (exponent k[4], threshold k[5]).
    x1_pow = x[1] ** params[4]
    tr_inh = x1_pow / (params[5] ** params[4] + x1_pow)

    # Input-driven conversion of eIF2a into p_eIF2a; it enters the first two
    # equations with opposite signs.
    driven = params[1] * stim / (params[2] + stim) * x[0] / (params[3] + x[0])

    rate = params[12]

    derivs = [
        # eIF2a
        -params[0] * x[0] - driven + params[10] * x[1] * x[3] + params[11] * x[1],
        # p_eIF2a
        params[0] * x[0] + driven - params[10] * x[1] * x[3] - params[11] * x[1],
        # m_GADD34
        params[6] * x[9] - (params[7] * x[2]),
        # GADD34
        params[8] * x[2] - (params[9] * x[3]),
        # Pr_tot
        -rate * x[4] * tr_inh + (params[13] * x[9]),
        # Pr_delay_1
        rate * x[4] * tr_inh - (rate * x[5]),
    ]
    # Pr_delay_2 .. Pr_delay_4: each stage is fed by the previous one and
    # drained at the same rate k[12].
    derivs.extend(rate * x[i - 1] - (rate * x[i]) for i in (6, 7, 8))
    # Pr_delay_5: fed at k[12], drained at k[13].
    derivs.append(rate * x[8] - (params[13] * x[9]))

    return jnp.stack(derivs, axis=-1)


def f_sg(p_eif2a, h_sg, k_sg):
    """Hill-type function ``p**h / (k**h + p**h)``.

    Args:
        p_eif2a: Value(s) the function is evaluated at.
        h_sg: Exponent.
        k_sg: Threshold; the result is 0.5 where ``p_eif2a == k_sg``.

    Returns:
        ``p_eif2a**h_sg / (k_sg**h_sg + p_eif2a**h_sg)``.
    """
    powered = p_eif2a**h_sg
    return powered / (k_sg**h_sg + powered)

In [None]:
def integrate(
    control: controls.AbstractControl,
    x0,
    t1,
    params=None,
    n_saves: int = 1024,
    rtol: float = 1e-8,
    atol: float = 1e-8,
):
    """Solve the ODE system from time 0 to ``t1`` under a given control.

    Args:
        control: Control passed to the vector field as the second element of
            ``args``; the vector field calls it as ``control(t)``.
        x0: Initial state.
        t1: Final time of the integration.
        params: Parameter vector passed to the vector field. When ``None``
            the notebook-level ``k`` is used, which matches the behaviour
            before this parameter existed.
        n_saves: Number of equally spaced save times on ``[0, t1]``.
        rtol: Relative tolerance of the adaptive step-size controller.
        atol: Absolute tolerance of the adaptive step-size controller.

    Returns:
        The solution object returned by ``diffrax.diffeqsolve``; the saved
        times and states are available as ``.ts`` and ``.ys``.
    """
    if params is None:
        # Fall back to the notebook global so existing three-argument calls
        # keep working unchanged.
        params = k

    terms = diffrax.ODETerm(det_ThaKin_ld_C1_G0_1K_wP_kd_wp_67BF35A2)
    solver = diffrax.Kvaerno5()
    stepsize_controller = diffrax.PIDController(
        rtol=rtol, atol=atol, pcoeff=0.3, icoeff=0.3
    )

    sol = diffrax.diffeqsolve(
        terms=terms,
        solver=solver,
        t0=0.0,
        t1=t1,
        dt0=0.1,  # initial step only; the controller adapts it afterwards
        y0=x0,
        args=(params, control),
        saveat=diffrax.SaveAt(ts=jnp.linspace(0.0, t1, n_saves)),
        stepsize_controller=stepsize_controller,
    )

    return sol

In [None]:
# Run with the input held at 0.0, then keep the last saved state as the
# starting point for the next simulation.
control = controls.LambdaControl(lambda t: jnp.full(shape=(1,), fill_value=0.0))
sol = integrate(control=control, x0=x0, t1=10 * 24 * 60)

s0 = sol.ys[-1]

In [None]:
# Restart from s0 with the input held at 100.0.
control = controls.LambdaControl(lambda t: jnp.full(shape=(1,), fill_value=100.0))
sol = integrate(control=control, x0=s0, t1=10 * 60)

In [None]:
# All state trajectories of the most recent solve.
fig, ax = plt.subplots()
ax.plot(sol.ts, sol.ys)
plt.show()

# f_sg evaluated along the p_eIF2a trajectory.
fig, ax = plt.subplots()
ax.plot(sol.ts, f_sg(sol.ys[:, p_eif2a_idx], h_sg, k_sg_tha))
plt.show()