In [None]:
import jax

jax.config.update("jax_enable_x64", True)

from functools import partial
from typing import Callable, Tuple

import diffrax
import equinox as eqx
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np
from jaxtyping import Array, ArrayLike
from tqdm.auto import tqdm as tq
from tqdm.auto import trange

import optimal_control.constraints as constraints
import optimal_control.controls as controls
import optimal_control.environments as environments
import optimal_control.environments.examples as examples
import optimal_control.solvers as solvers
import optimal_control.trainers as trainers


In [None]:
# Forward integration benchmark: integrate the fibrosis model over t in [0, 300] under a
# constant control, vectorised (vmap) and compiled (jit) over a batch of control values.

environment = examples.FibrosisEnvironment()
key = jax.random.PRNGKey(1234)
state = environment.init()


@jax.jit
@jax.vmap
def integrate_benchmark(control_value: Array) -> Array:
    """Integrate from y0 = [1e6, 1e6, 0, 0] over [0, 300], saving at 301 evenly spaced times.

    Vmapped: ``control_value`` is one row of the batch passed by the caller. It is handed to
    ``LambdaControl`` as data together with ``lambda t, args: args``, so the control presumably
    evaluates to ``control_value`` at every time. Inflammation pulse and early stopping are
    both disabled.
    """
    save_at = diffrax.SaveAt(ts=jnp.linspace(0.0, 300.0, 301))
    constant_control = controls.LambdaControl(lambda t, args: args, control_value)
    return environment._integrate(
        t0=0.0,
        t1=300.0,
        y0=jnp.asarray([1e6, 1e6, 0.0, 0.0]),
        control=constant_control,
        inflammation_pulse=False,
        saveat=save_at,
        early_stopping=False,
    )


# Small smoke-test batch of 4 zero controls. A different batch size (as in the next cell)
# triggers a fresh compilation.
integrate_benchmark(jnp.zeros((4, 2)))

In [None]:
# Full-size run: 65,536 (= 2**16) zero controls in one vmapped, jitted call.
integrate_benchmark(jnp.zeros((1 << 16, 2)))


In [None]:
def reward_fn(x: Array):
    """Negative mean log of the first two state components over a trajectory.

    Steps:
    - ``+inf`` entries are replaced by 0.
    - The first two state components (indices 0 and 1, i.e. F and M in the
      ``y0 = (F, M, C, P)`` ordering used in this notebook) are floored at 1e2.
    - The result is ``-mean(log(...))`` over all remaining elements.

    The floor keeps ``log`` finite. It also means anything at or below 1e2, including
    former ``+inf`` entries, contributes exactly ``log(1e2)``.
    """
    x = jnp.where(jnp.isposinf(x), 0.0, x)
    # Equivalent to jnp.clip(..., a_min=1e2, a_max=None). The a_min/a_max keywords of
    # jnp.clip are deprecated in recent JAX; jnp.maximum works on every version.
    x = jnp.maximum(x[..., :2], 1e2)
    return -jnp.mean(jnp.log(x))


# Two identical log-spaced axes (10 points from 1e-2 to 1e1). The flattened meshgrid varies
# column 0 fastest (index k = 10 * row + col), and each value is scaled by 101.
concentrations = [10 ** jnp.linspace(-2, 1, 10) for _axis in range(2)]
integrals = 101 * jnp.stack(jnp.meshgrid(*concentrations), axis=-1).reshape(-1, 2)


In [None]:
environment = examples.FibrosisEnvironment()
solver = solvers.DirectSolver()
key = jax.random.PRNGKey(1234)

# Piecewise-linear InterpolationControl(2, 101, 0.0, 100.0). It presumably has 2 channels
# x 101 nodes on [0, 100], matching the (101, 2) arrays passed as `control=` later on.
control = controls.InterpolationControl(2, 101, 0.0, 100.0, method="linear")


def train_with_integral(
    integral: ArrayLike,
    environment: environments.AbstractEnvironment,
    solver: solvers.AbstractSolver,
    reward_fn: Callable[[ArrayLike], Array],
    control: controls.AbstractControl,
    key: Array,
) -> Tuple[ArrayLike, Array]:
    """Optimise ``control`` under a fixed non-negative integral constraint.

    Args:
        integral: target integral passed to ``NonNegativeConstantIntegralConstraint``.
        environment, solver, reward_fn, control, key: forwarded to
            ``trainers.solve_optimal_control_problem``.

    Returns:
        ``(reward, control_values)``: the reward reported by the trainer and the trained
        control's ``.control`` attribute.

    ``key`` is annotated as ``Array`` because ``jax.random.KeyArray`` no longer exists in
    recent JAX releases. Annotations are evaluated when this ``def`` runs, so the old
    annotation raised ``AttributeError``.
    """
    _constraints = [constraints.NonNegativeConstantIntegralConstraint(integral)]

    # 1024 is passed positionally to the trainer (presumably the number of optimisation
    # steps; TODO confirm against trainers.solve_optimal_control_problem).
    reward, trained_control = trainers.solve_optimal_control_problem(
        environment, reward_fn, _constraints, solver, control, 1024, key
    )

    return reward, trained_control.control


# Vectorise over the leading axis of `integral` only; all other inputs are bound (and
# closed over) via `partial`, so every batch element shares them.
batched_train_with_integral = jax.jit(
    jax.vmap(
        partial(
            train_with_integral,
            environment=environment,
            solver=solver,
            reward_fn=reward_fn,
            control=control,
            key=key,
        ),
        in_axes=(0,),
        out_axes=(0, 0),
    )
)


In [None]:
# One optimisation per grid integral, run as a single batched (vmapped + jitted) call.
(rewards, _controls) = batched_train_with_integral(integrals)


In [None]:
# Persist the sweep so later cells can reload it without re-training.
jnp.savez(
    "../data/fibrosis.npz",
    **{"reward_array": rewards, "control_array": _controls},
)


In [None]:
# Reload the saved sweep. For a .npz file jnp.load returns an NpzFile that keeps the
# underlying file open until closed; the `with` block closes it once the arrays are copied.
with jnp.load("../data/fibrosis.npz") as data_dict:
    rewards = jnp.asarray(data_dict["reward_array"])
    _controls = jnp.asarray(data_dict["control_array"])


In [None]:
def get_trajectory(
    environment: environments.AbstractEnvironment,
    state: environments.EnvironmentState,
    control: controls.AbstractControl,
    key: Array,
) -> Array:
    """Integrate ``environment`` from ``state`` under ``control`` and return the result.

    This is a thin wrapper around ``environment.integrate(control, state, key)``.

    ``key`` is annotated as ``Array`` because ``jax.random.KeyArray`` has been removed from
    recent JAX releases. Annotations are evaluated when the ``def`` runs, so the old
    annotation made this cell fail.
    """
    return environment.integrate(control, state, key)


def is_treatment_successfull(final_state: Array) -> ArrayLike:
    """Return 1.0 where |s[0]| + |s[1]| < 0.1 over the last axis, else 0.0.

    Only the first two state components are inspected. Leading (batch) dimensions of
    ``final_state`` are preserved in the output.
    """
    residual = jnp.abs(final_state[..., :2]).sum(axis=-1)
    return jnp.where(residual < 1e-1, 1.0, 0.0)


In [None]:
# Trajectories of optimized controls. Each control is rebuilt with method="linear", the
# same interpolation the InterpolationControl used during training, so evaluation matches
# what was optimised.

environment = examples.FibrosisEnvironment()
state = environment.init()
key = jax.random.PRNGKey(1234)


optimal_trajectories = []
optimal_rewards = []
for i in trange(_controls.shape[0]):
    control = controls.InterpolationControl(
        2, 101, 0.0, 100.0, method="linear", control=_controls[i]
    )

    trajectory = get_trajectory(environment, state, control, key)
    reward = reward_fn(trajectory)

    optimal_trajectories.append(trajectory)
    optimal_rewards.append(reward)


In [None]:
# Trajectories of constant controls. For each grid integral, the (101, 2) signal is
# obtained by projecting all-ones onto NonNegativeConstantIntegralConstraint(integral),
# presumably a constant signal with that integral. method="linear" matches the
# InterpolationControl used for the optimized controls.

environment = examples.FibrosisEnvironment()
state = environment.init()
key = jax.random.PRNGKey(1234)


constant_trajectories = []
constant_rewards = []
for i in trange(_controls.shape[0]):
    constraint = constraints.NonNegativeConstantIntegralConstraint(integrals[i])
    control_signal = constraint.project(jnp.ones((101, 2)))
    control = controls.InterpolationControl(
        2, 101, 0.0, 100.0, method="linear", control=control_signal
    )

    trajectory = get_trajectory(environment, state, control, key)
    reward = reward_fn(trajectory)

    constant_trajectories.append(trajectory)
    constant_rewards.append(reward)


In [None]:
# Separatrix: final state of the zero-control system for every (M, F) initial condition on
# a 100 x 100 log-spaced grid (1 to 1e7). C and P are initialised from the closed-form
# expressions below.

# Zero control, written in the `lambda t, args: args` + data form that the other
# LambdaControl constructions in this notebook use. A one-argument lambda with no data
# presumably does not match the (t, args) call made by LambdaControl.
control = controls.LambdaControl(lambda t, args: args, jnp.zeros((2,)))
environment = examples.FibrosisEnvironment()
key = jax.random.PRNGKey(1234)

y0s = 10 ** jnp.linspace(0, 7, 100)

# Model parameters. Several rates are multiplied by 1440 (presumably a per-minute to
# per-day conversion; TODO confirm). lambda*, mu* and K are not used in this cell.
lambda1 = 0.9
lambda2 = 0.8
mu1 = 0.3
mu2 = 0.3
K = 1e6
gamma = 2
beta3 = 240 * 1440
beta1 = 470 * 1440
beta2 = 70 * 1440
alpha1 = 940 * 1440
alpha2 = 510 * 1440
k1 = 6 * 1e8
k2 = 6 * 1e8

# Holds the final state ys[-1] of each run (not full trajectories), ordered with i (M)
# outer and j (F) inner.
seperatrix_trajectories = []
for i in trange(100):
    M = y0s[i]
    for j in range(100):
        F = y0s[j]

        C = -0.5 * (alpha1 / gamma * M + k2 - beta1 / gamma * F) + jnp.sqrt(
            0.25 * (alpha1 / gamma * M + k2 - beta1 / gamma * F) ** 2
            + beta1 * k2 / gamma * F
        )
        P = 0.5 * (beta2 / gamma * M + (beta3 - alpha2) / gamma * F - k1) + jnp.sqrt(
            0.25 * (k1 - beta2 / gamma * M - (beta3 - alpha2) / gamma * F) ** 2
            + (beta2 * M + beta3 * F) * k1 / gamma
        )

        # State ordering is (F, M, C, P).
        y0 = jnp.stack((F, M, C, P), axis=-1)

        trajectory = environment._integrate(
            t0=0.0,
            t1=300.0,
            y0=y0,
            control=control,
            inflammation_pulse=False,
            saveat=diffrax.SaveAt(t1=True),
        ).ys[-1]

        seperatrix_trajectories.append(trajectory)


In [None]:
# Alternatively, load a precomputed separatrix map from disk instead of recomputing it.

import scipy.io

seperatrix_array = scipy.io.loadmat(file_name="../data/Separatrix_array_F06_M07.mat")


In [None]:
def plt_reward_grid(
    plt_rewards, x=10, y=10, extent=(0.1, 2.0, 0.1, 2.0), label="Reward"
):
    """Show a flat vector of per-grid-point values as a heat map.

    Args:
        plt_rewards: array with ``x * y`` elements. It is reshaped to ``(x, y)``, i.e. ``x``
            rows on the vertical axis (row 0 at the bottom because of ``origin="lower"``)
            by ``y`` columns.
        x, y: grid shape used for the reshape.
        extent: ``(left, right, bottom, top)`` passed to ``imshow``.
            NOTE(review): the default ``(0.1, 2.0, 0.1, 2.0)`` does not match the
            ``101 * 10 ** linspace(-2, 1, 10)`` integral grid built in this notebook.
            Pass ``extent`` explicitly for meaningful axis ticks.
        label: colour-bar label. Pass something other than "Reward" when plotting, for
            example, the 0/1 treatment-success maps.
    """
    # matplotlib 3.6 renamed the bundled seaborn styles to "seaborn-v0_8-*" and later
    # releases dropped the old names; use whichever this installation provides.
    style = (
        "seaborn-v0_8-paper"
        if "seaborn-v0_8-paper" in plt.style.available
        else "seaborn-paper"
    )
    with plt.style.context(style):
        plt.figure(figsize=(5, 5))
        plt.xlabel("aPDGF int.")
        plt.ylabel("aCSF1 int.")
        plt.imshow(
            plt_rewards.reshape(x, y),
            extent=extent,
            origin="lower",
            aspect="equal",
            cmap="inferno",
        )
        plt.colorbar(fraction=0.0457, pad=0.04, label=label)
        # plt.savefig("../figures/fibrosis_opt_reward.png", bbox_inches="tight")
        # plt.savefig("../figures/fibrosis_opt_reward.svg", bbox_inches="tight")
        plt.show()


In [None]:
# Optimized controls: the reward map, then a 0/1 treatment-success map computed from the
# final time step of each trajectory.

plt_reward_grid(jnp.stack(optimal_rewards))
plt_reward_grid(is_treatment_successfull(jnp.stack(optimal_trajectories)[:, -1]))


In [None]:
# Constant controls: the reward map, then a 0/1 treatment-success map computed from the
# final time step of each trajectory.

plt_reward_grid(jnp.stack(constant_rewards))
plt_reward_grid(is_treatment_successfull(jnp.stack(constant_trajectories)[:, -1]))


In [None]:
# Reward gained by optimizing: optimized minus constant reward at each grid point.

plt_reward_grid(jnp.subtract(jnp.stack(optimal_rewards), jnp.stack(constant_rewards)))


In [None]:
# Plot separatrix

with plt.style.context(
    # matplotlib 3.6 renamed the bundled seaborn styles to "seaborn-v0_8-*" and later
    # releases dropped the old names; use whichever this installation provides.
    "seaborn-v0_8-paper"
    if "seaborn-v0_8-paper" in plt.style.available
    else "seaborn-paper"
):
    # Log-spaced F (x) and M (y) axes. np.logspace takes the log10 bounds from lims_* and
    # the number of points from tsteps.
    x = np.logspace(
        seperatrix_array["lims_F"][0, 0],
        seperatrix_array["lims_F"][0, 1],
        seperatrix_array["tsteps"][0, 0],
    )
    y = np.logspace(
        seperatrix_array["lims_M"][0, 0],
        seperatrix_array["lims_M"][0, 1],
        seperatrix_array["tsteps"][0, 0],
    )

    plt.figure(figsize=(5, 5))
    plt.xlabel("F")
    plt.ylabel("M")
    plt.xscale("log")
    plt.yscale("log")
    plt.pcolor(x, y, 1 - seperatrix_array["S"], cmap="Greys", vmin=0.0, vmax=3.0)
    plt.show()


In [None]:
# Inspect the final state for the first grid point (i = j = 0, i.e. M = F = y0s[0]).
seperatrix_trajectories[0]


In [None]:
# 0/1 success flag for each of the separatrix final states.
is_treatment_successfull(jnp.stack(seperatrix_trajectories))


In [None]:
# Plot dosage curve grid: one log-scale panel per grid point. Figure row 0 shows grid row 9,
# so the layout matches the origin="lower" heat maps above.

with plt.style.context(
    # matplotlib 3.6 renamed the bundled seaborn styles to "seaborn-v0_8-*" and later
    # releases dropped the old names; use whichever this installation provides.
    "seaborn-v0_8-paper"
    if "seaborn-v0_8-paper" in plt.style.available
    else "seaborn-paper"
):
    fig, ax = plt.subplots(10, 10, figsize=(10, 10), sharex=True, sharey=True)
    for i in range(10):
        for j in range(10):
            ax[i, j].set_yscale("log")
            # Floor at 1e-2 so zero doses remain visible on the log axis. Positional
            # bounds work across NumPy versions (keywords were renamed in NumPy 2.1).
            ax[i, j].plot(np.clip(_controls[(9 - i) * 10 + j], 1e-2, None))

    # plt.savefig("../figures/fibrosis_opt_dosage.png", bbox_inches="tight")
    # plt.savefig("../figures/fibrosis_opt_dosage.svg", bbox_inches="tight")
    plt.show()


In [None]:
# Plot (F, M) trajectories of the optimized controls over the separatrix map

with plt.style.context(
    # matplotlib 3.6 renamed the bundled seaborn styles to "seaborn-v0_8-*" and later
    # releases dropped the old names; use whichever this installation provides.
    "seaborn-v0_8-paper"
    if "seaborn-v0_8-paper" in plt.style.available
    else "seaborn-paper"
):
    # Log-spaced F (x) and M (y) axes. np.logspace takes the log10 bounds from lims_* and
    # the number of points from tsteps.
    x = np.logspace(
        seperatrix_array["lims_F"][0, 0],
        seperatrix_array["lims_F"][0, 1],
        seperatrix_array["tsteps"][0, 0],
    )
    y = np.logspace(
        seperatrix_array["lims_M"][0, 0],
        seperatrix_array["lims_M"][0, 1],
        seperatrix_array["tsteps"][0, 0],
    )

    fig, ax = plt.subplots(10, 10, figsize=(10, 10), sharex=True, sharey=True)
    for i in range(10):
        for j in range(10):
            # Figure row 0 shows grid row 9, matching the origin="lower" heat maps.
            panel_trajectory = optimal_trajectories[(9 - i) * 10 + j]
            ax[i, j].set_xscale("log")
            ax[i, j].set_yscale("log")
            ax[i, j].pcolor(
                x, y, 1 - seperatrix_array["S"], cmap="Greys", vmin=0.0, vmax=3.0
            )
            # State components 0 and 1 are F and M (state ordering (F, M, C, P)).
            ax[i, j].plot(panel_trajectory[..., 0], panel_trajectory[..., 1])

    # plt.savefig("../figures/fibrosis_opt_traj.png", bbox_inches="tight")
    # plt.savefig("../figures/fibrosis_opt_traj.svg", bbox_inches="tight")
    plt.show()


In [None]:
# Plot (F, M) trajectories of the constant controls over the separatrix map

with plt.style.context(
    # matplotlib 3.6 renamed the bundled seaborn styles to "seaborn-v0_8-*" and later
    # releases dropped the old names; use whichever this installation provides.
    "seaborn-v0_8-paper"
    if "seaborn-v0_8-paper" in plt.style.available
    else "seaborn-paper"
):
    # Log-spaced F (x) and M (y) axes. np.logspace takes the log10 bounds from lims_* and
    # the number of points from tsteps.
    x = np.logspace(
        seperatrix_array["lims_F"][0, 0],
        seperatrix_array["lims_F"][0, 1],
        seperatrix_array["tsteps"][0, 0],
    )
    y = np.logspace(
        seperatrix_array["lims_M"][0, 0],
        seperatrix_array["lims_M"][0, 1],
        seperatrix_array["tsteps"][0, 0],
    )

    fig, ax = plt.subplots(10, 10, figsize=(10, 10), sharex=True, sharey=True)
    for i in range(10):
        for j in range(10):
            # Figure row 0 shows grid row 9, matching the origin="lower" heat maps.
            panel_trajectory = constant_trajectories[(9 - i) * 10 + j]
            ax[i, j].set_xscale("log")
            ax[i, j].set_yscale("log")
            ax[i, j].pcolor(
                x, y, 1 - seperatrix_array["S"], cmap="Greys", vmin=0.0, vmax=3.0
            )
            # State components 0 and 1 are F and M (state ordering (F, M, C, P)).
            ax[i, j].plot(panel_trajectory[..., 0], panel_trajectory[..., 1])

    # Distinct file names so this figure does not overwrite the optimized-trajectory one.
    # plt.savefig("../figures/fibrosis_const_traj.png", bbox_inches="tight")
    # plt.savefig("../figures/fibrosis_const_traj.svg", bbox_inches="tight")
    plt.show()


# Adaptive sampling

## Optimal control with adaptively sampled total drug dosages

In [None]:
import jax

jax.config.update("jax_enable_x64", True)

import time
from functools import partial
from typing import Callable, Tuple

import diffrax
import equinox as eqx
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np
from adaptive import DataSaver, Learner2D, notebook_extension
from IPython.display import clear_output
from jaxtyping import Array, ArrayLike
from tqdm.auto import tqdm as tq
from tqdm.auto import trange

import optimal_control.constraints as constraints
import optimal_control.controls as controls
import optimal_control.environments as environments
import optimal_control.environments.examples as examples
import optimal_control.solvers as solvers
import optimal_control.trainers as trainers

# Enable adaptive's notebook integration; presumably needed before `learner.plot()` can
# render in the cells below.
notebook_extension()

In [None]:
key = jax.random.PRNGKey(1234)

environment = examples.FibrosisEnvironment()
state = environment.init()
solver = solvers.DirectSolver()

# Alternative parameterisation:
# control = controls.InterpolationControl(2, 101, 0.0, 100.0, method="linear")
key, subkey = jax.random.split(key)
control = controls.ImplicitControl(controls.Siren(1, 2, 32, 2, subkey), 0.0, 100.0)


def reward_fn(x: Array):
    """Return ``x[-1, -1]``: the last state component at the final saved time step."""
    return x[-1, -1]


def _train_with_integral(
    integral: ArrayLike,
    *,
    environment: environments.AbstractEnvironment,
    solver: solvers.AbstractSolver,
    reward_fn: Callable[[ArrayLike], Array],
    control: controls.AbstractControl,
    key: Array,
) -> Tuple[ArrayLike, Array]:
    """Optimise ``control`` under ``NonNegativeConstantIntegralConstraint(integral)``.

    Returns ``(reward, placeholder)``. The second output is ``jnp.zeros(1)``: returning
    ``_control.control`` is commented out in favour of it (presumably because the implicit
    Siren control has no such array). TODO: return the trained control signal.
    """
    _constraints = [constraints.NonNegativeConstantIntegralConstraint(integral)]

    # 512 is passed positionally to the trainer (presumably the number of optimisation
    # steps; TODO confirm).
    reward, _control = trainers.solve_optimal_control_problem(
        environment, reward_fn, _constraints, solver, control, 512, key
    )

    return reward, jnp.zeros(1)


# Bind the objects defined above now. A jitted function that read these as globals would
# pick up whatever they are bound to at (re-)trace time, and the cells below rebind
# `control` via `reward, control = train_with_integral(...)`.
train_with_integral = jax.jit(
    partial(
        _train_with_integral,
        environment=environment,
        solver=solver,
        reward_fn=reward_fn,
        control=control,
        key=key,
    )
)


@jax.jit
def evaluate_constant_dosage(
    integral: ArrayLike,
) -> ArrayLike:
    """Reward of a control that is handed ``integral`` as data and returns it at every time.

    Uses the global ``environment``, ``state``, ``key`` and ``reward_fn`` as they are when
    this function is first traced.
    """
    control = controls.LambdaControl(lambda t, args: args, data=integral)
    seq = environment.integrate(control, state, key)
    reward = reward_fn(seq)

    return reward

In [None]:
# Smoke test: a tiny constant dosage of 1e-3 for both drugs.
evaluate_constant_dosage(jnp.asarray([1e-3, 1e-3]))

In [None]:
# Single optimisation run with integral 101 for both components.
# NOTE(review): this rebinds the global `control` (previously the ImplicitControl defined
# above) to the second output of `train_with_integral`, which is currently the placeholder
# `jnp.zeros(1)`. Any later (re-)trace of a function that reads the global `control` will
# see this value.
reward, control = train_with_integral(jnp.asarray([1.0, 1.0])*101)

In [None]:
# Plot the constraint transform applied to `control`, which at this point is the second
# output of the training call in the previous cell.
plt.figure()
plt.plot(
    constraints.NonNegativeConstantIntegralConstraint(
        jnp.asarray([1.0, 1.0]) * 101
    ).transform(control)
)
plt.show()

In [None]:
from IPython.display import display

# Adaptively sample log10 integrals in [-3, 3]^2 and record the advantage of the optimized
# control over a constant control. Learner2D is used only as a sampler: its function is a
# dummy, and values are reported back through `learner.tell`.
learner = Learner2D(lambda x: 0, bounds=((-3, 3), (-3, 3)))

n_samples = 2500  # bounded so the cell (and Restart & Run All) terminates
plot_timer = time.time()
plot_interval = 15  # seconds between live plot refreshes
results = []
for _sample in range(n_samples):
    (x,), _ = learner.ask(1)

    integral = 10 ** jnp.asarray(x, dtype=jnp.float64)

    optimized_reward, control_points = train_with_integral(integral)
    # NOTE(review): train_with_integral treats `integral` as the total under
    # NonNegativeConstantIntegralConstraint. evaluate_constant_dosage instead uses it
    # directly as the constant control value. Earlier in this notebook a total of 101 * c
    # presumably corresponds to a constant value c, so the two rewards may not compare
    # equal total doses. Confirm before interpreting `advantage`.
    constant_reward = evaluate_constant_dosage(integral)
    advantage = optimized_reward - constant_reward

    learner.tell(x, float(advantage))
    results.append(
        {
            "x": x,
            "integral": integral,
            "optimized_reward": optimized_reward,
            "constant_reward": constant_reward,
            "advantage": advantage,
            "control_points": control_points,
        }
    )

    if time.time() - plot_timer >= plot_interval:
        plot_timer = time.time()

        clear_output(wait=True)
        display(learner.plot(tri_alpha=0.25))

# Final refresh so the figure reflects every sample.
clear_output(wait=True)
display(learner.plot(tri_alpha=0.25))

## Evaluation of adaptively sampled constant total drug dosages

In [None]:
# Fresh environment, initial state, solver and PRNG key for the constant-dosage sweep.
key = jax.random.PRNGKey(1234)
environment = examples.FibrosisEnvironment()
state = environment.init()
solver = solvers.DirectSolver()


#def reward_fn(x: Array):
#    x = jnp.where(jnp.isposinf(x), 0.0, x)
#    x = jnp.clip(x[..., :2], a_min=1e2, a_max=None)
#    x = -jnp.mean(jnp.log(x))
#
#    return x

#def reward_fn(x: Array):
#    return x[-1, -1]

def reward_fn(x: Array):
    """Binary reward: 1.0 if either of the first two components ends below 1e2, else 0.0.

    Only the final row ``x[-1]`` is inspected.
    """
    final_first_two = x[-1, :2]
    return jnp.where(jnp.any(final_first_two < 1e2), 1.0, 0.0)


@jax.jit
def evaluate_constant_dosage(
    integral: ArrayLike,
) -> ArrayLike:
    """Reward of a control that is handed ``integral`` as data and returns it at every time.

    Redefined here so that a new jitted function is traced against the ``reward_fn`` defined
    in this cell. The earlier jitted function would reuse its cached trace, which holds the
    previous ``reward_fn``.
    """
    constant_control = controls.LambdaControl(lambda t, args: args, data=integral)
    return reward_fn(environment.integrate(constant_control, state, key))

In [None]:
# Smoke test of the success reward with a small constant dosage of 0.01 for both drugs.
evaluate_constant_dosage(jnp.asarray([0.01, 0.01]))

In [None]:
# Sample 2500 log10 constant dosages in [-3, 3]^2. The learner's own function is a dummy;
# values are supplied through `tell`. The live plot refreshes at most every 15 seconds.
constant_dosage_learner = Learner2D(lambda x: 0, bounds=((-3, 3), (-3, 3)))

plot_timer = time.time()
plot_interval = 15
constant_dosage_results = []
for i in range(2500):
    (x,), _ = constant_dosage_learner.ask(1)

    integral = 10 ** jnp.asarray(x, dtype=jnp.float64)
    reward = evaluate_constant_dosage(integral)

    constant_dosage_learner.tell(x, float(reward))
    constant_dosage_results.append(
        {
            "x": x,
            "integral": integral,
            "reward": reward,
        }
    )

    if time.time() - plot_timer < plot_interval:
        continue
    plot_timer = time.time()
    clear_output(wait=True)
    display(constant_dosage_learner.plot(tri_alpha=0.1))

In [None]:
# Final plot of the sampled constant-dosage rewards (the loop above refreshes only every
# plot_interval seconds, so its last live plot may be stale).
display(constant_dosage_learner.plot(tri_alpha=0.1))