In [None]:
import jax

jax.config.update("jax_enable_x64", True)

from typing import Callable

import diffrax
import equinox as eqx
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np
from jaxtyping import Array, ArrayLike
from tqdm.auto import tqdm as tq
from tqdm.auto import trange

import optimal_control.constraints as constraints
import optimal_control.controls as controls
import optimal_control.environments as environments
import optimal_control.environments.examples as examples
import optimal_control.solvers as solvers
import optimal_control.trainers as trainers


In [None]:
# Smoke test: integrate the fibrosis environment once under a constant control
# of ones on both control channels.

environment = examples.FibrosisEnvironment()
state = environment.init()


def _unit_control(x):
    # Ignores its input and always returns a length-2 vector of ones.
    return jnp.ones((2,))


control = controls.LambdaControl(_unit_control)
# NOTE(review): later cells pass a PRNG key as a third argument to
# `integrate`; confirm it is optional here.
seq = environment.integrate(control, state)


In [None]:
def train_with_integral(integral: ArrayLike):
    """Optimize an interpolated dosing control for one integral budget.

    The control starts as all ones, built as
    ``InterpolationControl(2, 101, 0.0, 100.0, control=jnp.ones((101, 2)))``
    (presumably 2 channels, 101 knots on [0, 100] -- TODO confirm). It is
    optimized by ``trainers.solve_optimal_control_problem`` subject to
    ``NonNegativeConstantIntegralConstraint(integral)``. The objective is the
    negative mean log of the first two state components.

    Args:
        integral: integral target handed to the constraint (presumably one
            entry per control channel -- TODO confirm).

    Returns:
        ``(reward, knots)``: the reward reported by the trainer and the
        ``control`` array of the optimized ``InterpolationControl``.
    """
    environment = examples.FibrosisEnvironment()
    initial_control = controls.InterpolationControl(
        2, 101, 0.0, 100.0, control=jnp.ones((101, 2))
    )

    def objective(x):
        return -jnp.mean(jnp.log(x[..., :2]))

    reward, optimized = trainers.solve_optimal_control_problem(
        environment,
        objective,
        [constraints.NonNegativeConstantIntegralConstraint(integral)],
        solvers.DirectSolver(),
        initial_control,
        128,  # passed through unchanged; its meaning is defined by the trainer
        jax.random.PRNGKey(1234),
    )
    return reward, optimized.control


# Vectorized trainer: maps over the leading axis of `integral` and stacks both
# outputs (reward, knots) along axis 0.
batched_train_with_integral = jax.vmap(train_with_integral, in_axes=0, out_axes=0)

In [None]:
def reward_fn(x):
    """Negative mean log of the first two state components of ``x``.

    Same objective as the one used inside ``train_with_integral``. The selected
    entries must be positive for the logarithm (and hence the reward) to be
    finite.
    """
    return -jnp.mean(jnp.log(x[..., :2]))

In [None]:
# Log-spaced grid of 10 concentrations in [1e-2, 1e1] for each control channel.
concentrations = 10 ** jnp.linspace(-2, 1, 10)

# Cartesian product of the two channels flattened to shape (100, 2). With the
# default "xy" meshgrid indexing, the first column varies fastest. Scaling by 101
# presumably turns a per-knot level into an integral over the 101 control
# knots -- TODO confirm.
_x_grid, _y_grid = jnp.meshgrid(concentrations, concentrations)
integrals = jnp.stack((_x_grid, _y_grid), axis=-1).reshape(-1, 2) * 101

# Optimize one control per grid point in a single vmapped call.
rewards, _controls = batched_train_with_integral(integrals)

In [None]:
def get_trajectory(
    environment: environments.AbstractEnvironment,
    state: environments.EnvironmentState,
    control: controls.AbstractControl,
    key: jax.Array,
) -> Array:
    """Integrate ``environment`` from ``state`` under ``control``.

    Thin wrapper around ``environment.integrate(control, state, key)``. The
    result is returned unchanged; later cells index it as ``trajectory[-1]``
    and ``trajectory[..., 0]`` (presumably shape (time, state_dim) -- TODO
    confirm).

    ``key`` is annotated as ``jax.Array`` rather than ``jax.random.KeyArray``.
    That alias was removed from recent JAX releases, and annotations are
    evaluated when the ``def`` runs, so the old annotation raised
    ``AttributeError``. Keys from ``jax.random.PRNGKey`` are ``jax.Array``
    instances.
    """
    return environment.integrate(control, state, key)

def is_treatment_successfull(final_state: Array, threshold: float = 1e-1) -> ArrayLike:
    """Flag states whose first two components are jointly below ``threshold``.

    Args:
        final_state: array whose last axis holds the state; only
            ``final_state[..., :2]`` is inspected.
        threshold: success cutoff on the L1 norm of those two components.
            Defaults to ``1e-1``, the previously hard-coded value.

    Returns:
        Float array over the leading axes: ``1.0`` where
        ``|s[0]| + |s[1]| < threshold``, else ``0.0``.
    """
    l1_norm = jnp.sum(jnp.abs(final_state[..., :2]), axis=-1)
    return jnp.where(l1_norm < threshold, 1.0, 0.0)

In [None]:
# Trajectories of optimized controls: roll out every optimized control from the
# same initial state and PRNG key, recording the trajectory and its reward.

environment = examples.FibrosisEnvironment()
state = environment.init()
key = jax.random.PRNGKey(1234)

optimal_trajectories, optimal_rewards = [], []
for i in trange(_controls.shape[0]):
    # Rebuild the interpolation control from the i-th set of optimized knots.
    control = controls.InterpolationControl(
        2, 101, 0.0, 100.0, control=_controls[i]
    )
    trajectory = get_trajectory(environment, state, control, key)
    reward = reward_fn(trajectory)
    optimal_trajectories += [trajectory]
    optimal_rewards += [reward]


In [None]:
# Trajectories of constant controls: for each grid point, a signal of ones is
# passed through the grid point's integral constraint (`project`) and rolled out
# from the same initial state and PRNG key as the optimized runs.

environment = examples.FibrosisEnvironment()
state = environment.init()
key = jax.random.PRNGKey(1234)

constant_trajectories, constant_rewards = [], []
for i in trange(_controls.shape[0]):
    constraint = constraints.NonNegativeConstantIntegralConstraint(integrals[i])
    control = controls.InterpolationControl(
        2, 101, 0.0, 100.0, control=constraint.project(jnp.ones((101, 2)))
    )
    trajectory = get_trajectory(environment, state, control, key)
    reward = reward_fn(trajectory)
    constant_trajectories += [trajectory]
    constant_rewards += [reward]


In [None]:
# Separatrix trajectories: integrate the uncontrolled system (zero control on both
# channels) from a log-spaced grid of (M, F) initial values and keep only the
# final state of each run.

control = controls.LambdaControl(lambda _: jnp.zeros((2,)))
environment = examples.FibrosisEnvironment()
key = jax.random.PRNGKey(1234)  # not used by this cell

y0s = 10 ** jnp.linspace(0, 7, 100)

# Model parameters. lambda1, lambda2, mu1, mu2 and K are not used in this cell.
lambda1 = 0.9
lambda2 = 0.8
mu1 = 0.3
mu2 = 0.3
K = 1e6
gamma = 2
beta3 = 240 * 1440
beta1 = 470 * 1440
beta2 = 70 * 1440
alpha1 = 940 * 1440
alpha2 = 510 * 1440
k1 = 6 * 1e8
k2 = 6 * 1e8

# Build every initial condition in one vectorized pass instead of once per grid
# point. The grid size now follows len(y0s) instead of a hard-coded 100.
# indexing="ij" keeps the original nested-loop order: the outer index picks M and
# the inner one picks F, so run k uses M = y0s[k // n], F = y0s[k % n].
M_grid, F_grid = jnp.meshgrid(y0s, y0s, indexing="ij")
M = M_grid.reshape(-1)
F = F_grid.reshape(-1)

# C and P are the non-negative roots of quadratics in C and P whose coefficients
# depend on (M, F) -- presumably their quasi-steady-state levels; TODO confirm
# against the model equations.
c_linear = alpha1 / gamma * M + k2 - beta1 / gamma * F
C = -0.5 * c_linear + jnp.sqrt(0.25 * c_linear**2 + beta1 * k2 / gamma * F)
P = 0.5 * (beta2 / gamma * M + (beta3 - alpha2) / gamma * F - k1) + jnp.sqrt(
    0.25 * (k1 - beta2 / gamma * M - (beta3 - alpha2) / gamma * F) ** 2
    + (beta2 * M + beta3 * F) * k1 / gamma
)

# One row per run, state order (F, M, C, P).
y0_grid = jnp.stack((F, M, C, P), axis=-1)

seperatrix_trajectories = []
for y0 in tq(y0_grid):
    # NOTE(review): `_integrate` is a private environment method; the positional
    # arguments (presumably t0, t1, y0, control, <flag>, saveat) are passed
    # exactly as before -- confirm what the `False` flag means.
    final_state = environment._integrate(
        0.0, 300.0, y0, control, False, diffrax.SaveAt(t1=True)
    ).ys[-1]
    seperatrix_trajectories.append(final_state)


In [None]:
# Load a precomputed separatrix array from a MATLAB .mat file instead of
# relying on the simulated one above.

import scipy.io

SEPARATRIX_MAT_PATH = "../data/Separatrix_array_F06_M07.mat"
seperatrix_array = scipy.io.loadmat(SEPARATRIX_MAT_PATH)

In [None]:
# Summarize the loaded .mat contents (variable name -> array shape) instead of
# dumping every array; loadmat's "__header__"/"__version__"/"__globals__"
# metadata entries are skipped.
{name: value.shape for name, value in seperatrix_array.items() if not name.startswith("__")}

In [None]:
def plt_reward_grid(plt_rewards, x=10, y=10, extent=(0.1, 2.0, 0.1, 2.0), label="Reward"):
    """Show a flat array of per-grid-point values as an ``x`` by ``y`` heat map.

    Args:
        plt_rewards: array with ``x * y`` entries, reshaped row-major to
            ``(x, y)``. With ``origin="lower"``, row 0 is drawn at the bottom.
        x, y: grid shape. Defaults to 10 x 10 so single-argument calls work.
        extent: ``(left, right, bottom, top)`` forwarded to ``imshow``.
            Defaults to the previously hard-coded value. NOTE(review): this does
            not obviously match the log-spaced integral grid -- confirm.
        label: colour-bar label (e.g. something other than "Reward" when
            plotting a success indicator).
    """
    # "seaborn-paper" was renamed "seaborn-v0_8-paper" in matplotlib 3.6 and the
    # old name was removed in 3.8; use whichever name this installation provides.
    style = (
        "seaborn-v0_8-paper" if "seaborn-v0_8-paper" in plt.style.available else "seaborn-paper"
    )
    with plt.style.context(style):
        plt.figure(figsize=(5, 5))
        plt.xlabel("aPDGF int.")
        plt.ylabel("aCSF1 int.")
        plt.imshow(
            plt_rewards.reshape(x, y),
            extent=extent,
            origin="lower",
            aspect="equal",
            cmap="inferno",
        )
        plt.colorbar(fraction=0.0457, pad=0.04, label=label)
        #plt.savefig("../figures/fibrosis_opt_reward.png", bbox_inches="tight")
        #plt.savefig("../figures/fibrosis_opt_reward.svg", bbox_inches="tight")
        plt.show()

In [None]:
# Plot optimal reward grid, then the treatment-success indicator of each
# optimized run's final state, on the 10 x 10 concentration grid. The grid
# shape is passed explicitly because plt_reward_grid's signature takes it.

plt_reward_grid(jnp.stack(optimal_rewards, axis=0), 10, 10)
plt_reward_grid(
    is_treatment_successfull(jnp.stack(optimal_trajectories, axis=0)[:, -1]), 10, 10
)

In [None]:
# Plot constant-dose reward grid, then the treatment-success indicator of each
# constant-dose run's final state, on the 10 x 10 concentration grid. The grid
# shape is passed explicitly because plt_reward_grid's signature takes it.

plt_reward_grid(jnp.stack(constant_rewards, axis=0), 10, 10)
plt_reward_grid(
    is_treatment_successfull(jnp.stack(constant_trajectories, axis=0)[:, -1]), 10, 10
)

In [None]:
# Plot difference reward grid: optimized minus constant-dose reward for each
# point of the 10 x 10 concentration grid (shape passed explicitly because
# plt_reward_grid's signature takes it).

plt_reward_grid(
    jnp.stack(optimal_rewards, axis=0) - jnp.stack(constant_rewards, axis=0), 10, 10
)

In [None]:
# Plot separatrix: the array "S" loaded from the .mat file, drawn on log-log
# (F, M) axes.

# "seaborn-paper" was renamed "seaborn-v0_8-paper" in matplotlib 3.6 and the old
# name was removed in 3.8; use whichever name this installation provides.
_paper_style = (
    "seaborn-v0_8-paper" if "seaborn-v0_8-paper" in plt.style.available else "seaborn-paper"
)
# np.logspace needs an integer sample count, but .mat scalars may be stored as
# doubles (MATLAB's default numeric type), so cast explicitly.
n_steps = int(seperatrix_array["tsteps"][0, 0])

with plt.style.context(_paper_style):
    x = np.logspace(seperatrix_array["lims_F"][0, 0], seperatrix_array["lims_F"][0, 1], n_steps)
    y = np.logspace(seperatrix_array["lims_M"][0, 0], seperatrix_array["lims_M"][0, 1], n_steps)

    plt.figure(figsize=(5, 5))
    plt.xlabel("F")
    plt.ylabel("M")
    plt.xscale("log")
    plt.yscale("log")
    # vmax=3.0 keeps (1 - S) in the light end of "Greys", assuming S lies in [0, 1].
    plt.pcolor(x, y, 1 - seperatrix_array["S"], cmap="Greys", vmin=0.0, vmax=3.0)
    plt.show()

In [None]:
# Inspect the final state of the first separatrix run, which started from the
# smallest grid values M = F = y0s[0]. The state is presumably in (F, M, C, P)
# order, matching the initial condition.
seperatrix_trajectories[0]

In [None]:
# Success indicator (1.0 / 0.0) for every separatrix run, in simulation order.
separatrix_final_states = jnp.stack(seperatrix_trajectories, axis=0)
is_treatment_successfull(separatrix_final_states)

In [None]:
# Plot dosage curve grid: the optimized control knots of every grid point.
# Column j follows the first integral component left-to-right and the rows
# follow the second component bottom-to-top, like the origin="lower" heat maps.

# "seaborn-paper" was renamed "seaborn-v0_8-paper" in matplotlib 3.6 and the old
# name was removed in 3.8; use whichever name this installation provides.
_paper_style = (
    "seaborn-v0_8-paper" if "seaborn-v0_8-paper" in plt.style.available else "seaborn-paper"
)
n_grid = 10  # side length of the concentration grid

with plt.style.context(_paper_style):
    fig, ax = plt.subplots(n_grid, n_grid, figsize=(10, 10), sharex=True, sharey=True)
    for i in range(n_grid):
        for j in range(n_grid):
            # Subplot row 0 is the top of the figure, hence the flipped row index.
            ax[i, j].plot(_controls[(n_grid - 1 - i) * n_grid + j])
    #ax_outer = plt.axes([0.1,0.1,2.0,2.0], facecolor=(1,1,1,0))

    #plt.savefig("../figures/fibrosis_opt_traj.png", bbox_inches="tight")
    #plt.savefig("../figures/fibrosis_opt_traj.svg", bbox_inches="tight")
    plt.show()

In [None]:
# Plot MF trajectories of the optimized controls: state column 0 (F) against
# column 1 (M) on log-log axes, one panel per grid point, laid out like the
# dosage grid.

# "seaborn-paper" was renamed "seaborn-v0_8-paper" in matplotlib 3.6 and the old
# name was removed in 3.8; use whichever name this installation provides.
_paper_style = (
    "seaborn-v0_8-paper" if "seaborn-v0_8-paper" in plt.style.available else "seaborn-paper"
)
n_grid = 10  # side length of the concentration grid

with plt.style.context(_paper_style):
    fig, ax = plt.subplots(n_grid, n_grid, figsize=(10, 10), sharex=True, sharey=True)
    for i in range(n_grid):
        for j in range(n_grid):
            # Subplot row 0 is the top of the figure, hence the flipped row index.
            panel_trajectory = optimal_trajectories[(n_grid - 1 - i) * n_grid + j]
            ax[i, j].set_xscale("log")
            ax[i, j].set_yscale("log")
            ax[i, j].plot(panel_trajectory[..., 0], panel_trajectory[..., 1])
    #ax_outer = plt.axes([0.1,0.1,2.0,2.0], facecolor=(1,1,1,0))

    # File names differ from the dosage-grid export so the two figures do not overwrite each other.
    #plt.savefig("../figures/fibrosis_opt_mf_traj.png", bbox_inches="tight")
    #plt.savefig("../figures/fibrosis_opt_mf_traj.svg", bbox_inches="tight")
    plt.show()

In [None]:
# Plot MF trajectories of the constant-dose baseline over the separatrix
# background: state column 0 (F) against column 1 (M) on log-log axes, one panel
# per grid point, laid out like the dosage grid.

# "seaborn-paper" was renamed "seaborn-v0_8-paper" in matplotlib 3.6 and the old
# name was removed in 3.8; use whichever name this installation provides.
_paper_style = (
    "seaborn-v0_8-paper" if "seaborn-v0_8-paper" in plt.style.available else "seaborn-paper"
)
n_grid = 10  # side length of the concentration grid
# np.logspace needs an integer sample count, but .mat scalars may be stored as
# doubles (MATLAB's default numeric type), so cast explicitly.
n_steps = int(seperatrix_array["tsteps"][0, 0])

with plt.style.context(_paper_style):
    x = np.logspace(seperatrix_array["lims_F"][0, 0], seperatrix_array["lims_F"][0, 1], n_steps)
    y = np.logspace(seperatrix_array["lims_M"][0, 0], seperatrix_array["lims_M"][0, 1], n_steps)
    # Computed once and reused as the background of every panel.
    separatrix_shading = 1 - seperatrix_array["S"]

    fig, ax = plt.subplots(n_grid, n_grid, figsize=(10, 10), sharex=True, sharey=True)
    for i in range(n_grid):
        for j in range(n_grid):
            # Subplot row 0 is the top of the figure, hence the flipped row index.
            panel_trajectory = constant_trajectories[(n_grid - 1 - i) * n_grid + j]
            ax[i, j].set_xscale("log")
            ax[i, j].set_yscale("log")
            #ax[i, j].set_xlabel("F")
            #ax[i, j].set_ylabel("M")
            ax[i, j].pcolor(x, y, separatrix_shading, cmap="Greys", vmin=0.0, vmax=3.0)
            ax[i, j].plot(panel_trajectory[..., 0], panel_trajectory[..., 1])

    #ax_outer = plt.axes([0.1,0.1,2.0,2.0], facecolor=(1,1,1,0))

    # File names differ from the dosage-grid export so the two figures do not overwrite each other.
    #plt.savefig("../figures/fibrosis_const_mf_traj.png", bbox_inches="tight")
    #plt.savefig("../figures/fibrosis_const_mf_traj.svg", bbox_inches="tight")
    plt.show()