In [None]:
import jax

jax.config.update("jax_enable_x64", True)

from typing import Tuple

import jax.numpy as jnp
import matplotlib.pyplot as plt
from jaxtyping import Array

import optimal_control.constraints as constraints
import optimal_control.controls as controls
import optimal_control.environments.examples as examples
import optimal_control.solvers as solvers
import optimal_control.trainers as trainers

In [None]:
# Training: build the PRNG key, environment, solver, control and constraint
# list that are handed to the trainer.

key = jax.random.PRNGKey(1234)

environment = examples.ApoptosisEnvironment(
    "../data/Initial_concentrations_CD95H_wtH.mat",
    [0, 500],
    50,
)
solver = solvers.DirectSolver()
control = controls.InterpolationControl(1, 181, 0.0, 180.0)

# A single integral constraint with bound 2.5 * 181.
# NOTE(review): the 181 here presumably mirrors the 181 given to
# InterpolationControl above -- confirm and keep the two in sync.
_constraints = [
    constraints.NonNegativeConstantIntegralConstraint(jnp.asarray([2.5 * 181])),
]

def reward_fn(args: Tuple[Array, Array]):
    """Return the mean of a capped ratio computed from the last axis of ``ys``.

    The ratio is ``ys[..., 12] / (ys[..., 3] + ys[..., 12])``. Each value is
    capped from above by ``thresh`` (reshaped to a column) and the capped
    values are averaged into a single scalar.

    Args:
        args: A ``(ys, thresh)`` pair. ``ys`` must have at least 13 entries
            along its last axis. ``thresh`` is reshaped to ``(-1, 1)`` and
            broadcast against the ratio.
            NOTE(review): this presumably means one threshold per row of a
            2-D ratio array -- confirm against the environment's output.

    Returns:
        Scalar array holding the mean of the capped ratio.
    """
    ys, thresh = args
    ratio = ys[..., 12] / (ys[..., 3] + ys[..., 12])
    # An upper-only clip is exactly an element-wise minimum. Using
    # jnp.minimum avoids jnp.clip's deprecated `a_min` / `a_max` keywords.
    capped = jnp.minimum(ratio, thresh.reshape(-1, 1))
    return jnp.mean(capped)

# `key` is the PRNG key created earlier in this cell. It is deliberately not
# re-created here, so the seed is defined in exactly one place.
rewards = reward_fn

# Run the trainer. `control` is rebound to the control object it returns.
# NOTE(review): the meaning of the literal 1024 is set by the signature of
# `solve_optimal_control_problem` -- confirm and consider naming it.
reward, control = trainers.solve_optimal_control_problem(
    environment, rewards, _constraints, solver, control, 1024, key
)


In [None]:
# Integrate the environment under the trained control, starting from the
# state returned by `environment.init()` and reusing the PRNG key that was
# passed to the trainer.
state = environment.init()
rollout = environment.integrate(control, state, key)
ys, thresh = rollout

In [None]:
# Show the last entry along the final axis of `state.x0` as this cell's output.
state.x0[..., -1]

In [None]:
# Ratio of column 12 to the sum of columns 3 and 12 of the integrated output.
frac = ys[..., 12] / (ys[..., 3] + ys[..., 12])
# NOTE(review): an alternative threshold, `state.x0[..., -1] * 1.4897`, was
# previously tried here in place of the `thresh` returned by the integration.

# Figure 1: one curve per row of `frac`, plus one horizontal line per threshold.
fig_frac, ax_frac = plt.subplots()
for t in thresh:
    ax_frac.axhline(t, c="black")
ax_frac.plot(frac.T)
ax_frac.set_title("ys[..., 12] / (ys[..., 3] + ys[..., 12]) with thresholds")
ax_frac.set_xlabel("index along last axis of frac")
ax_frac.set_ylabel("ratio")
plt.show()

# Figure 2: the values stored on the control returned by the trainer.
fig_control, ax_control = plt.subplots()
ax_control.plot(control.control)
ax_control.set_title("Trained control")
ax_control.set_xlabel("index")
ax_control.set_ylabel("control.control")
plt.show()