In [None]:
import jax
import jax.numpy as jnp

# Run everything on the CPU backend.
jax.config.update("jax_platform_name", "cpu")
# Use 64-bit floats/ints instead of JAX's 32-bit default.
jax.config.update("jax_enable_x64", True)

In [None]:
import pickle
from typing import Tuple

import diffrax
import equinox as eqx
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import optax
from jaxtyping import Array, PyTree, Scalar
from tqdm.auto import tqdm as tq
from tqdm.auto import trange

import optimal_control.constraints as constraints
import optimal_control.controls as controls
import optimal_control.environments.examples as examples
import optimal_control.nn as nn
import optimal_control.solvers as solvers
import optimal_control.trainers as trainers
from optimal_control.solvers.base import build_control

In [None]:
# A4 page size in inches (width, height) and the figure widths derived from it.
a4_inches = (8.3, 11.7)
plot_full_width = a4_inches[0]
plot_half_width = a4_inches[0] / 2
plot_third_width = a4_inches[0] / 3
plot_quarter_width = a4_inches[0] / 4

result_base_dir = "../thesis-results/fibrosis"

# Matplotlib 3.6 renamed its bundled seaborn styles to "seaborn-v0_8-*" and
# later releases dropped the old names, so fall back to the new name when the
# legacy one is not shipped with the installed Matplotlib.
plot_style = "seaborn-paper"
if plot_style not in plt.style.available:
    plot_style = "seaborn-v0_8-paper"

plot_shrink_factor = 0.9

plt.style.use(plot_style)

# Worst, best & average control

In [None]:
# Root PRNG key of the notebook; it is split below before being consumed.
key = jax.random.PRNGKey(1234)

# NOTE(review): absolute, machine-specific path -- the notebook only runs where
# this .mat file exists at exactly this location. Consider a configurable
# data directory.
environment: examples.StressEnvironment = examples.StressEnvironment(
    "/home/lena/master-thesis/repos/optimal-control/data/Repository_data_210919.mat",
    use_updated_params=True,
)

environment_state = environment.init()

# Solver used by `optimize` further down; Adam with a small learning rate.
solver = solvers.DirectSolver(optimizer=optax.adam(learning_rate=3e-4))

# `subkey` initialises the network weights; `key` stays available for later use.
key, subkey = jax.random.split(key)
# Control parameterised by a SIREN network with one input and one output.
# It is defined on [0, 6 * 60]; presumably it is discretised into a 12-step
# piecewise-constant curve (to_curve / "step" / curve_steps) -- TODO confirm
# against ImplicitTemporalControl.
control = controls.ImplicitTemporalControl(
    implicit_fn=nn.Siren(
        in_features=1, out_features=1, hidden_features=64, hidden_layers=2, key=subkey
    ),
    t_start=0.0,
    t_end=6.0 * 60.0,
    to_curve=True,
    curve_interpolation="step",
    curve_steps=12,
)

## Find the average ligand concentration that causes a peak of 50% SG activation

In [None]:
# Optimisation variable: log10 of the constant ligand concentration
# (train_step applies it as 10**peak_log_conc). Starts at 0, i.e. 10**0 = 1.
peak_log_conc = jnp.zeros(1)
# Adam optimiser and its state for fitting `peak_log_conc`.
peak_optim = optax.adam(learning_rate=1e-1)
optim_state = peak_optim.init(peak_log_conc)


@eqx.filter_jit
def train_step(peak_log_conc, optim_state):
    """Apply one optimiser update to the log10 ligand concentration.

    The loss is the squared distance between the maximum of the simulated SG
    signal and 0.5, where the simulation is driven by the constant control
    value ``10**peak_log_conc`` up to ``t1 = 12 * 60``.

    Args:
        peak_log_conc: Current log10 concentration (the optimised parameter).
        optim_state: State of the notebook-level ``peak_optim`` optimiser.

    Returns:
        Tuple ``(updated_peak_log_conc, updated_optim_state)``.
    """

    def squared_peak_error(log_conc):
        # Constant control: ignores its first argument and returns the data.
        constant_control = controls.LambdaControl(
            lambda _, c: c, data=10**log_conc
        )
        _, sg_trace = environment.integrate(
            constant_control,
            environment_state,
            None,
            t1=12.0 * 60.0,
        )
        return (jnp.max(sg_trace) - 0.5) ** 2

    gradient = jax.grad(squared_peak_error)(peak_log_conc)
    updates, next_optim_state = peak_optim.update(
        gradient, optim_state, params=peak_log_conc
    )
    next_log_conc = optax.apply_updates(peak_log_conc, updates)

    return next_log_conc, next_optim_state


# Run a fixed budget of 1024 update steps; there is no convergence check.
for i in trange(1024):
    peak_log_conc, optim_state = train_step(peak_log_conc, optim_state)

In [None]:
# Re-simulate with the fitted constant concentration and inspect the result.
# NOTE(review): unlike the other integrate calls in this notebook, no t1 is
# passed here, so the environment's default horizon applies -- confirm this is
# intended.
ys, sg = environment.integrate(
    controls.LambdaControl(lambda _, c: c, data=10**peak_log_conc),
    environment_state,
    None,
)

# Second state component on a logarithmic y axis.
fig, ax = plt.subplots()
ax.set_yscale("log")
ax.plot(ys[..., 1])
plt.show()

# SG signal returned by the simulation.
fig, ax = plt.subplots()
ax.plot(sg)
plt.show()

## Average ligand scan

In [None]:
@jax.jit
def constant_sg_frac(ligand_log_conc):
    """Return the peak SG value under a constant ligand level.

    Args:
        ligand_log_conc: log10 of the ligand level; the control outputs
            ``10**ligand_log_conc`` at every time.

    Returns:
        Maximum of the SG signal simulated up to ``t1 = 12 * 60``.
    """
    steady_control = controls.LambdaControl(
        lambda _, c: c, data=10**ligand_log_conc
    )
    _, sg_trace = environment.integrate(
        steady_control,
        environment_state,
        None,
        t1=12.0 * 60.0,
    )
    return jnp.max(sg_trace)


@jax.jit
def burst_sg_frac(ligand_log_conc):
    """Return the peak SG value under an initial ligand burst.

    The control outputs ``12 * 10**ligand_log_conc`` while ``state["t"] < 30.0``
    and zero afterwards.

    Args:
        ligand_log_conc: log10 of the base ligand level that the burst scales.

    Returns:
        Maximum of the SG signal simulated up to ``t1 = 12 * 60``.
    """

    def burst_profile(state, data):
        # Twelve-fold level during the first 30 time units, nothing after.
        in_burst = state["t"] < 30.0
        return jnp.where(in_burst, data * 12, jnp.zeros_like(data))

    burst_control = controls.LambdaControl(burst_profile, data=10**ligand_log_conc)
    _, sg_trace = environment.integrate(
        burst_control,
        environment_state,
        None,
        t1=12.0 * 60.0,
    )
    return jnp.max(sg_trace)

In [None]:
# Scan 1024 log10 ligand levels between 0 and 5 (levels 1 to 1e5).
ligand_log_concs = jnp.linspace(0, 5, 1024)

# Each level is reshaped to shape (1,) before being passed to the jitted
# helpers, one call per level.
constant_sg_fracs = jnp.asarray(
    [constant_sg_frac(log_conc.reshape(1)) for log_conc in ligand_log_concs]
)

burst_sg_fracs = jnp.asarray(
    [burst_sg_frac(log_conc.reshape(1)) for log_conc in ligand_log_concs]
)

# Peak SG value against ligand level for both dosing schemes.
fig, ax = plt.subplots()
ax.set_xlabel("CD95L")
ax.set_ylabel("Peak SG frac.")
ax.set_xscale("log")
ax.plot(10**ligand_log_concs, constant_sg_fracs, label="Constant CD95L")
ax.plot(10**ligand_log_concs, burst_sg_fracs, label="Bursted CD95L")
ax.legend()
plt.show()

## Optimize

In [None]:
@eqx.filter_jit
def optimize(
    target: Array,
    maximum: Array,
    peak_weight: Array,
    mean_weight: Array,
    num_steps: int,
    *,
    stepsize_controller: diffrax.AbstractStepSizeController = diffrax.PIDController(
        atol=1e-5,
        rtol=1e-5,
        pcoeff=1.0,
        icoeff=1.0,
        dtmax=30,
    )
) -> Tuple[Array, controls.AbstractControl]:
    """Train the notebook-level control for a weighted mean/peak SG reward.

    Uses the notebook-level ``environment``, ``solver``, ``control`` and
    ``key``. The control is constrained by a
    ``LimitedRangeConstantIntegralConstraint`` built from ``target`` and
    ``maximum``, and every simulation runs up to ``t1 = 12 * 60``.

    Args:
        target: ``target`` argument of the integral constraint.
        maximum: ``maximum`` argument of the integral constraint.
        peak_weight: Weight of ``max(sg)`` in the reward.
        mean_weight: Weight of ``mean(sg)`` in the reward.
        num_steps: Number of training steps.
        stepsize_controller: Step size controller forwarded to the
            environment's integrate call.

    Returns:
        The two values returned by ``trainers.solve_optimal_control_problem``:
        the optimised reward and the optimised control.
    """

    def reward_fn(args: PyTree) -> Scalar:
        # Only the SG trace enters the reward; the state trajectory is unused.
        _, sg_trace = args
        mean_term = mean_weight * jnp.mean(sg_trace)
        peak_term = peak_weight * jnp.max(sg_trace)
        return mean_term + peak_term

    integral_constraint = constraints.LimitedRangeConstantIntegralConstraint(
        target=target, maximum=maximum
    )
    chain = constraints.ConstraintChain(transformations=[integral_constraint])

    integration_options = {
        "t1": 12.0 * 60.0,
        "stepsize_controller": stepsize_controller,
    }

    best_reward, best_control = trainers.solve_optimal_control_problem(
        num_train_steps=num_steps,
        environment=environment,
        reward_fn=reward_fn,
        constraint_chain=chain,
        solver=solver,
        control=control,
        key=key,
        pbar_interval=8,
        integrate_kwargs=integration_options,
    )

    return best_reward, best_control


def evaluate(control: controls.AbstractControl, target: Array, maximum: Array):
    """Simulate a control and a flattened variant of it, then plot both.

    The control is first constrained with a
    ``LimitedRangeConstantIntegralConstraint`` built from ``target`` and
    ``maximum``. A modified control is derived from the constrained curve by
    replacing the first node with the second one and adding
    ``(nodes[0] - nodes[1]) / 12`` to every node. Both controls are simulated
    up to ``t1 = 12 * 60`` and three figures are shown:

    1. ``ys[..., 1]`` of both runs on a logarithmic y axis.
    2. The first 128 SG samples of both runs, with dashed lines at the mean
       and dash-dotted lines at the peak of each full SG trace.
    3. Both control curves sampled on 1024 points in ``[0, 12 * 60]``.

    The means of the two sampled control curves are printed as well.

    Args:
        control: Control to evaluate (unconstrained).
        target: ``target`` argument of the integral constraint.
        maximum: ``maximum`` argument of the integral constraint.
    """
    constraint_chain = constraints.ConstraintChain(
        transformations=[
            constraints.LimitedRangeConstantIntegralConstraint(
                target=target, maximum=maximum
            )
        ]
    )

    constrained_control: controls.InterpolationCurveControl = build_control(
        control, constraint_chain
    )[0]

    # Replace the first node by the second and spread the removed difference
    # evenly over all nodes.
    # NOTE(review): the divisor 12 presumably equals the number of curve nodes
    # (curve_steps=12 on the control), which would keep the node sum
    # unchanged -- confirm.
    mod_control = eqx.tree_at(
        lambda pytree: pytree.curve.nodes,
        constrained_control,
        jnp.concatenate(
            (constrained_control.curve.nodes[1:2], constrained_control.curve.nodes[1:]),
            axis=0,
        )
        + (constrained_control.curve.nodes[0] - constrained_control.curve.nodes[1])
        / 12,
    )

    ys, sg = environment.integrate(
        constrained_control, environment_state, None, t1=12.0 * 60.0
    )
    ys_mod, sg_mod = environment.integrate(
        mod_control, environment_state, None, t1=12.0 * 60.0
    )

    ts = jnp.linspace(0.0, 12 * 60, 1024)
    cs = jax.vmap(constrained_control)(ts)
    cs_mod = jax.vmap(mod_control)(ts)

    print(jnp.mean(cs), jnp.mean(cs_mod))

    plt.figure()
    plt.yscale("log")
    plt.plot(ys[..., 1])  # p_EIF2a
    plt.plot(ys_mod[..., 1])
    plt.show()

    # NOTE(review): sg is plotted against the first 128 entries of `ts`, which
    # assumes the simulation output is sampled on that same grid -- confirm.
    plt.figure()
    plt.plot(ts[:128], sg[:128])
    plt.plot(ts[:128], sg_mod[:128])

    # Means
    plt.axhline(jnp.mean(sg), c="tab:blue", linestyle="--")
    plt.axhline(jnp.mean(sg_mod), c="tab:orange", linestyle="--")

    # Peaks (these previously used jnp.mean and so duplicated the mean lines)
    plt.axhline(jnp.max(sg), c="tab:blue", linestyle="-.")
    plt.axhline(jnp.max(sg_mod), c="tab:orange", linestyle="-.")

    plt.show()

    plt.figure()
    plt.plot(ts, cs)
    plt.plot(ts, cs_mod)
    plt.show()

In [None]:
# `maximum` argument for the integral constraint used by optimize/evaluate.
maximum = jnp.asarray([1200.0])
# Constraint target derived from the fitted constant level.
# NOTE(review): the factor 12 presumably matches the 12 curve steps of the
# control -- confirm.
target_conc = 10**peak_log_conc * 12
print(target_conc)

In [None]:
# Reward = -max(sg): control that minimises the peak, assuming the trainer
# maximises the reward.
min_peak_reward, min_peak_control = optimize(
    target=target_conc,
    maximum=maximum,
    peak_weight=jnp.float_(-1.0),
    mean_weight=jnp.float_(0.0),
    num_steps=1024 * 64,
)

# Reward = +max(sg): control that maximises the peak, under the same assumption.
max_peak_reward, max_peak_control = optimize(
    target=target_conc,
    maximum=maximum,
    peak_weight=jnp.float_(1.0),
    mean_weight=jnp.float_(0.0),
    num_steps=1024 * 64,
)

In [None]:
# Reward = -mean(sg): control that minimises the mean, assuming the trainer
# maximises the reward.
min_mean_reward, min_mean_control = optimize(
    target=target_conc,
    maximum=maximum,
    peak_weight=jnp.float_(0.0),
    mean_weight=jnp.float_(-1.0),
    num_steps=1024 * 2,
)

In [None]:
# Reward = +mean(sg). This run overrides the default step size controller with
# tighter tolerances (1e-8 instead of optimize's default 1e-5).
max_mean_reward, max_mean_control = optimize(
    target=target_conc,
    maximum=maximum,
    peak_weight=jnp.float_(0.0),
    mean_weight=jnp.float_(1.0),
    num_steps=1024 * 4,
    stepsize_controller=diffrax.PIDController(
        atol=1e-8,
        rtol=1e-8,
        pcoeff=1.0,
        icoeff=1.0,
        dtmax=30,
    ),
)

In [None]:
# Only the max-mean control is evaluated; uncomment a line to inspect another.
#evaluate(min_peak_control, target=target_conc, maximum=maximum)
#evaluate(max_peak_control, target=target_conc, maximum=maximum)
#evaluate(min_mean_control, target=target_conc, maximum=maximum)
evaluate(max_mean_control, target=target_conc, maximum=maximum)


In [None]:
# 1200 nM stock

In [None]:
def get_conc(control: controls.AbstractControl, target: Array, maximum: Array):
    """Return the curve nodes of ``control`` after applying the constraint.

    Args:
        control: Control to constrain.
        target: ``target`` argument of the integral constraint.
        maximum: ``maximum`` argument of the integral constraint.

    Returns:
        ``curve.nodes`` of the first element returned by ``build_control``.
    """
    integral_constraint = constraints.LimitedRangeConstantIntegralConstraint(
        target=target, maximum=maximum
    )
    chain = constraints.ConstraintChain(transformations=[integral_constraint])

    built = build_control(control, chain)
    constrained_control: controls.InterpolationCurveControl = built[0]

    return constrained_control.curve.nodes

# Display the constrained node values of the max-mean control as a flat array.
get_conc(max_mean_control, target=target_conc, maximum=maximum).flatten()

In [None]:
# Display the fitted constant ligand level (undo the log10 parameterisation).
10**peak_log_conc

In [None]:
# Total drug and medium needed

max_conc = 1.2  # uM
mean_conc = 0.3281  # uM
flush_volume = 400  # uL
num_flushings = 12
initial_drug_flush = 100

# Share of each flush volume that is drug at the mean concentration.
drug_fraction = mean_conc / max_conc

# Drug over all flushings plus the initial drug flush.
drug_volume = drug_fraction * flush_volume * num_flushings + initial_drug_flush
# Medium share of all flushings plus two additional full flush volumes.
medium_volume = (2 + (1 - drug_fraction) * num_flushings) * flush_volume

print(f"Drug vol. {drug_volume} uL")
print(f"Medium vol. {medium_volume} uL")