In [None]:
from pathlib import Path

import jax

jax.config.update("jax_enable_x64", True)
#jax.config.update("jax_check_tracer_leaks", True)

import equinox as eqx
import jax.numpy as jnp
import matplotlib.pyplot as plt

import optimal_control.constraints as constraints
import optimal_control.controls as controls
import optimal_control.environments as environments
import optimal_control.environments.examples as examples
import optimal_control.solvers as solvers
import optimal_control.trainers as trainers


In [None]:
# vmap test: map a two-output function over the rows of a small matrix.

def f(x):
    """Return ``(x . x**2, x . -x)`` for a 1-D vector ``x``."""
    cubic_sum = jnp.dot(x, x**2)
    negated_square_sum = jnp.dot(x, -x)
    return cubic_sum, negated_square_sum

# A single integer axis applies to the lone argument and to both outputs.
vf = jax.vmap(f, in_axes=0, out_axes=0)
vmap_demo_input = jnp.arange(16).reshape(4, 4)
vf(vmap_demo_input)

In [None]:
# Control test: build a 2-channel, 100-knot "step" InterpolationControl from two
# sine waves and sample it on t in [-10, 120], a wider window than the
# 0.0-100.0 range passed to the constructor.

sine_phases = jnp.stack(
    [jnp.linspace(0.0, 10.0, 100), jnp.linspace(0.0, 5.0, 100) + 0.33], axis=-1
)
control = jnp.sin(sine_phases)
t = jnp.linspace(-10.0, 120.0, 1000)

linear_control = controls.InterpolationControl(2, 100, 0.0, 100.0, "step", control)
control_signal = linear_control(t)

_, step_control_ax = plt.subplots()
step_control_ax.plot(t, control_signal)
plt.show()


In [None]:
# Environment test: print the fibrosis environment's initial state, then
# integrate it under a constant all-ones, 2-channel, 100-knot "linear" control.

environment = examples.FibrosisEnvironment()
environment_state = environment.init()
print(environment_state)

all_ones_knots = jnp.ones((100, 2))
control = controls.InterpolationControl(2, 100, 0.0, 100.0, "linear", all_ones_knots)
sequence = environment.integrate(control, environment_state)
print(sequence)


In [None]:
# Training test: optimise a 2-channel, 101-knot control for the fibrosis
# environment under an integral budget of 101.0 (= 1.0 x 101 knots).

environment = examples.FibrosisEnvironment()
control = controls.InterpolationControl(2, 101, 0.0, 100.0, control=jnp.ones((101, 2)))
_constraints = [constraints.NonNegativeConstantIntegralConstraint(101.0)]
solver = solvers.DirectSolver()


def rewards(x):
    """Negative mean log of the first two entries along the last axis of ``x``."""
    return -jnp.mean(jnp.log(x[..., :2]))


# 1024 is forwarded as-is; presumably the optimisation step count (TODO confirm).
reward, control = trainers.solve_optimal_control_problem(
    environment, rewards, _constraints, solver, control, 1024
)


In [None]:
# Make grid: sequential sweep over per-channel integral budgets. This is the
# non-vmapped reference for the batched cell below and yields results in the
# same order and format: `rewards` holds one entry per budget and `_controls`
# holds the stacked control arrays.

from jaxtyping import ArrayLike

def train_with_integral(integral: ArrayLike):
    """Optimise a 2-channel, 101-knot control under an integral budget.

    Args:
        integral: Budget handed to ``NonNegativeConstantIntegralConstraint``;
            in this cell a length-2 array, one entry per control channel
            (presumably a per-channel sum over the 101 knots - TODO confirm).

    Returns:
        ``(reward, control)`` exactly as returned by
        ``trainers.solve_optimal_control_problem``.
    """
    environment = examples.FibrosisEnvironment()
    control = controls.InterpolationControl(2, 101, 0.0, 100.0, control=jnp.ones((101, 2)))
    _constraints = [constraints.NonNegativeConstantIntegralConstraint(integral)]
    solver = solvers.DirectSolver()
    rewards = lambda x: -jnp.mean(jnp.log(x[..., :2]))

    # 1024 is forwarded unchanged; presumably the optimisation step count.
    reward, control = trainers.solve_optimal_control_problem(
        environment, rewards, _constraints, solver, control, 1024
    )

    return reward, control

budget_levels = jnp.linspace(0.1, 2.0, 10)
rewards = []
_controls = []
# Outer loop over the second channel and inner loop over the first. The flat
# index row * 10 + col then matches jnp.meshgrid's order in the batched cell and
# the reshape(10, 10) / imshow(origin="lower") layout of the plots.
for second_budget in budget_levels:
    for first_budget in budget_levels:
        # Scale per-knot levels by the 101 knots, as the batched cell and the
        # `1.0 * 101` training test do.
        reward, control = train_with_integral(
            jnp.asarray([first_budget, second_budget]) * 101
        )
        rewards.append(reward)
        # Keep the raw array (as the batched cell returns) so the plotting and
        # InterpolationControl(..., control=_controls[-1]) cells work unchanged.
        _controls.append(control.control)
rewards = jnp.stack(rewards)
_controls = jnp.stack(_controls)

In [None]:
# Batch training test: the same 10 x 10 budget sweep, vectorised with jax.vmap
# over all budgets at once.

from jaxtyping import ArrayLike

def train_with_integral(integral: ArrayLike):
    """Optimise one fibrosis control problem for a single integral budget.

    Args:
        integral: Budget handed to ``NonNegativeConstantIntegralConstraint``
            (one row of ``integrals`` below, i.e. a length-2 vector).

    Returns:
        ``(reward, solved.control)``: the reward from
        ``trainers.solve_optimal_control_problem`` and the raw array stored on
        the optimised control.
    """
    env = examples.FibrosisEnvironment()
    initial_control = controls.InterpolationControl(
        2, 101, 0.0, 100.0, control=jnp.ones((101, 2))
    )
    budget_constraints = [constraints.NonNegativeConstantIntegralConstraint(integral)]
    direct_solver = solvers.DirectSolver()

    def reward_fn(x):
        return -jnp.mean(jnp.log(x[..., :2]))

    reward, solved = trainers.solve_optimal_control_problem(
        env, reward_fn, budget_constraints, direct_solver, initial_control, 1024
    )
    return reward, solved.control

# A single integer axis maps the lone argument and both outputs along axis 0.
batched_train_with_integral = jax.vmap(train_with_integral, in_axes=0, out_axes=0)

sweep_levels = jnp.linspace(0.1, 2.0, 10)
# meshgrid's default "xy" indexing makes the first component vary fastest along
# the flattened axis.
first_budget_grid, second_budget_grid = jnp.meshgrid(sweep_levels, sweep_levels)
integrals = jnp.stack((first_budget_grid, second_budget_grid), axis=-1).reshape(-1, 2) * 101
rewards, _controls = batched_train_with_integral(integrals)

In [None]:
# Plot reward grid: one pixel per (aPDGF, aCSF1) budget. Pixels are centred on
# the sampled budget levels, and the figure is written to ../figures.

# matplotlib >= 3.6 renamed the bundled seaborn styles to "seaborn-v0_8-*", and
# 3.8 removed the old names. Fall back to the old name on older releases.
paper_style = (
    "seaborn-v0_8-paper" if "seaborn-v0_8-paper" in plt.style.available else "seaborn-paper"
)

# imshow's extent gives the outer pixel EDGES. Pad by half a grid step so each
# pixel centre sits on its sampled budget level (0.1, ..., 2.0).
reward_levels = jnp.linspace(0.1, 2.0, 10)
half_step = float(reward_levels[1] - reward_levels[0]) / 2
reward_lo = float(reward_levels[0]) - half_step
reward_hi = float(reward_levels[-1]) + half_step

figure_dir = Path("../figures")
figure_dir.mkdir(parents=True, exist_ok=True)

with plt.style.context(paper_style):
    fig, ax = plt.subplots(figsize=(5, 5))
    ax.set_xlabel("aPDGF int.")
    ax.set_ylabel("aCSF1 int.")
    # jnp.asarray accepts both the batched array and a list of scalar rewards.
    reward_image = ax.imshow(
        jnp.asarray(rewards).reshape(10, 10),
        extent=(reward_lo, reward_hi, reward_lo, reward_hi),
        origin="lower",
        aspect="equal",
        cmap="inferno",
    )
    fig.colorbar(reward_image, ax=ax, fraction=0.0457, pad=0.04, label="Reward")
    fig.savefig(figure_dir / "fibrosis_opt_reward.png", bbox_inches="tight")
    fig.savefig(figure_dir / "fibrosis_opt_reward.svg", bbox_inches="tight")
    plt.show()

In [None]:
# Plot dosage curve grid: one small panel per budget, laid out like the reward
# image. The first-channel budget increases left to right and the second-channel
# budget bottom to top. The figure is written to ../figures.

# matplotlib >= 3.6 renamed the bundled seaborn styles to "seaborn-v0_8-*", and
# 3.8 removed the old names. Fall back to the old name on older releases.
paper_style = (
    "seaborn-v0_8-paper" if "seaborn-v0_8-paper" in plt.style.available else "seaborn-paper"
)

figure_dir = Path("../figures")
figure_dir.mkdir(parents=True, exist_ok=True)

n_grid = 10
with plt.style.context(paper_style):
    fig, ax = plt.subplots(n_grid, n_grid, figsize=(10, 10), sharex=True, sharey=True)
    for i in range(n_grid):
        for j in range(n_grid):
            # Subplot row 0 is the TOP of the figure. Flip the row index so the
            # smallest second-channel budget is at the bottom (origin="lower").
            ax[i, j].plot(_controls[(n_grid - 1 - i) * n_grid + j])

    fig.savefig(figure_dir / "fibrosis_opt_traj.png", bbox_inches="tight")
    fig.savefig(figure_dir / "fibrosis_opt_traj.svg", bbox_inches="tight")
    plt.show()

In [None]:
# Inspect the last optimised control in the sweep, the one with the largest
# budget in both channels. It is sampled on t in [-10, 120], wider than the
# 0.0-100.0 range passed to the constructor.
last_knots = _controls[-1]
control = controls.InterpolationControl(2, 101, 0.0, 100.0, control=last_knots)

t = jnp.linspace(-10.0, 120.0, 1000)
control_signal = control(t)

_, last_control_ax = plt.subplots(figsize=(25, 5))
last_control_ax.plot(t, control_signal)
plt.show()


In [None]:
# Roll out `environment` (defined in an earlier cell) from its initial state
# under the current `control`, and plot the returned sequence.
initial_state = environment.init()
env_seq = environment.integrate(control, initial_state)

# NOTE(review): assumes integrate returns 101 samples on [0, 100] - confirm.
rollout_times = jnp.linspace(0.0, 100.0, 101)
plt.figure()
plt.plot(rollout_times, env_seq)
plt.show()
