In [None]:
import jax
import jax.numpy as jnp

jax.config.update("jax_platform_name", "cpu")
jax.config.update("jax_enable_x64", True)

In [None]:
import pickle
from typing import Tuple

import diffrax
import equinox as eqx
import jaxopt
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import optax
from jaxtyping import Array, PyTree, Scalar
from tqdm.auto import tqdm as tq
from tqdm.auto import trange

import optimal_control.constraints as constraints
import optimal_control.controls as controls
import optimal_control.environments.examples as examples
import optimal_control.nn as nn
import optimal_control.solvers as solvers
import optimal_control.trainers as trainers
from optimal_control.solvers.base import build_control

In [None]:
# Figure widths are fractions of an A4 page width (inches).
a4_inches = (8.3, 11.7)  # (width, height)
plot_full_width = a4_inches[0]
plot_half_width = plot_full_width / 2
plot_third_width = plot_full_width / 3
plot_quarter_width = plot_full_width / 4

# Root directory for all saved figures and data.
result_base_dir = "../thesis-results/stress"

# "paper" style for the thesis, "talk" style for slides.
plot_style = "seaborn-v0_8-paper"
plot_style_talk = "seaborn-v0_8-talk"

# Parallel lists consumed by `styles()`: each figure is drawn once per style and
# the matching suffix is appended to the saved file name.
plot_styles = [plot_style, plot_style_talk]
plot_style_names = ["", "_talk"]

# Extra scale factor applied to some figure sizes.
plot_shrink_factor = 0.9

plt.style.use(plot_style)

In [None]:
def show(save_postfix: str, save_prefix: str = None):
    """Optionally save the current figure as PNG and SVG, then display it.

    Files go to ``result_base_dir + save_prefix + save_postfix`` with ``.png`` and
    ``.svg`` appended. Nothing is written when ``save_prefix`` is None.
    """
    if save_prefix is not None:
        stem = result_base_dir + save_prefix + save_postfix
        for extension in (".png", ".svg"):
            plt.savefig(stem + extension, bbox_inches="tight")

    plt.show()

def styles(plot_fn):
    """Call ``plot_fn(suffix)`` once per style in ``plot_styles``, inside that style's context."""
    for style_sheet, suffix in zip(plot_styles, plot_style_names):
        with plt.style.context(style_sheet):
            plot_fn(suffix)

# Worst, best & average control

In [None]:
import os

# Fixed PRNG seed so repeated runs use the same keys.
key = jax.random.PRNGKey(1234)

# Path to the measurement data. Set STRESS_DATA_PATH to run on another machine;
# the default is the original author's local checkout.
stress_data_path = os.environ.get(
    "STRESS_DATA_PATH",
    "/home/lena/master-thesis/repos/optimal-control/data/Repository_data_210919.mat",
)

environment: examples.StressEnvironment = examples.StressEnvironment(
    stress_data_path,
    use_updated_params=True,
)

environment_state = environment.init()

# Optimal-control solver driven by Adam (learning rate 3e-4).
solver = solvers.DirectSolver(optimizer=optax.adam(learning_rate=3e-4))

# Control parameterisation: a SIREN network of time, converted to a curve with
# "step" interpolation and 12 steps on [t_start, t_end].
# NOTE(review): t_end is 6 * 60 while later integrations run to t1 = 12 * 60;
# confirm how the curve behaves past t_end.
key, subkey = jax.random.split(key)
control = controls.ImplicitTemporalControl(
    implicit_fn=nn.Siren(
        in_features=1, out_features=1, hidden_features=64, hidden_layers=2, key=subkey
    ),
    t_start=0.0,
    t_end=6.0 * 60.0,
    to_curve=True,
    curve_interpolation="step",
    curve_steps=12,
)

In [None]:
# Display the environment's initial state (s0).
environment_state.s0

## Find the average ligand concentration that causes a peak of 50% SG activation

In [None]:
# Fit log10 of a constant ligand concentration so that the peak stressed
# fraction over t1 = 12 * 60 reaches 0.5 (squared-error loss, Adam).
peak_log_conc = jnp.zeros(1)
peak_optim = optax.adam(learning_rate=1e-1)
optim_state = peak_optim.init(peak_log_conc)


@eqx.filter_jit
def train_step(peak_log_conc, optim_state):
    """One Adam step on ``(max(sg) - 0.5) ** 2`` with respect to ``peak_log_conc``."""

    def squared_miss(log_conc):
        constant_control = controls.LambdaControl(lambda _, c: c, data=10**log_conc)
        _, sg_trace = environment.integrate(
            constant_control,
            environment_state,
            key,
            t1=12.0 * 60.0,
        )
        return (jnp.max(sg_trace) - 0.5) ** 2

    gradient = jax.grad(squared_miss)(peak_log_conc)
    updates, next_state = peak_optim.update(gradient, optim_state, params=peak_log_conc)
    return optax.apply_updates(peak_log_conc, updates), next_state


for i in trange(1024):
    peak_log_conc, optim_state = train_step(peak_log_conc, optim_state)

In [None]:
# Sanity check: simulate the fitted constant concentration (default horizon) and
# plot p-eIF2a (state index 1, log scale) and the stressed-cell fraction.
ys, sg = environment.integrate(
    controls.LambdaControl(lambda _, c: c, data=10**peak_log_conc),
    environment_state,
    key,
)

plt.figure()
plt.yscale("log")
plt.xlabel("Sample index")
plt.ylabel(r"p-eIF2$\mathrm{\alpha}$ [nmol/l]")
plt.plot(ys[..., 1])
plt.show()

plt.figure()
plt.xlabel("Sample index")
plt.ylabel("Frac. Stressed Cells")
plt.plot(sg)
plt.show()

## Average ligand scan

In [None]:
def _peak_sg_frac_for(control):
    """Integrate ``control`` up to t1 = 6 * 60 and return the maximum of ``sg``."""
    _, sg_trace = environment.integrate(
        control,
        environment_state,
        key,
        t1=6.0 * 60.0,
    )
    return jnp.max(sg_trace)


@jax.jit
def constant_sg_frac(ligand_log_conc):
    """Peak stressed fraction for a constant dose of ``10 ** ligand_log_conc``."""
    return _peak_sg_frac_for(
        controls.LambdaControl(lambda _, c: c, data=10**ligand_log_conc)
    )


@jax.jit
def burst_sg_frac(ligand_log_conc):
    """Peak stressed fraction when 12x the dose is applied only while t < 30.

    12 * 30 == 6 * 60, so the integrated dose equals the constant schedule's.
    """

    def _burst(state, data):
        return jnp.where(state["t"] < 30.0, data * 12, jnp.zeros_like(data))

    return _peak_sg_frac_for(controls.LambdaControl(_burst, data=10**ligand_log_conc))

In [None]:
# Peak stressed fraction as a function of log10 ligand concentration, for the
# constant and the bursted schedule.
ligand_log_concs = jnp.linspace(0, 5, 1024)

constant_sg_fracs = jnp.asarray(
    [constant_sg_frac(log_conc.reshape(1)) for log_conc in ligand_log_concs]
)
burst_sg_fracs = jnp.asarray(
    [burst_sg_frac(log_conc.reshape(1)) for log_conc in ligand_log_concs]
)

In [None]:
# Total dose (concentration x 6) whose peak stressed fraction is closest to 0.5
# (first such grid point), for the constant and bursted schedules.
tuple(
    10 ** ligand_log_concs[jnp.argmin(jnp.abs(fracs - 0.5))] * 6
    for fracs in (constant_sg_fracs, burst_sg_fracs)
)

In [None]:
def plot_sg_fracs_ref(save_prefix: str = None):
    """Plot peak stressed fraction vs. total thapsigargin dose for both schedules."""
    total_dose = 6 * 10**ligand_log_concs

    def draw(suffix):
        plt.figure(figsize=(plot_half_width, plot_third_width))
        plt.xscale("log")
        plt.xlabel("Thapsigargin [nmol*h/l]")
        plt.ylabel("Peak Frac. Stressed Cells")
        for fracs, label in ((constant_sg_fracs, "Constant"), (burst_sg_fracs, "Bursted")):
            plt.plot(total_dose, fracs, label=label)
        plt.legend()
        plt.tight_layout()
        show("/sg_response_const_burst_tha" + suffix, save_prefix)

    styles(draw)

plot_sg_fracs_ref("/general")

## Optimize

In [None]:
@eqx.filter_jit
def optimize(
    target: Array,
    maximum: Array,
    peak_weight: Array,
    mean_weight: Array,
    num_steps: int,
    *,
    stepsize_controller: diffrax.AbstractStepSizeController = diffrax.PIDController(
        atol=1e-5,
        rtol=1e-5,
        pcoeff=1.0,
        icoeff=1.0,
        dtmax=30,
    )
) -> Tuple[Array, controls.AbstractControl]:
    """Train the notebook-level ``control`` for a weighted stressed-cell reward.

    Reward of one simulation: ``mean_weight * mean(sg) + peak_weight * max(sg)``.
    Negative weights penalise the corresponding statistic, assuming the trainer
    maximises the reward (TODO confirm in ``trainers.solve_optimal_control_problem``).
    The control is constrained by ``LimitedRangeConstantIntegralConstraint(target, maximum)``.
    Simulations run to ``t1 = 12 * 60`` with ``stepsize_controller``. The
    ``*_lognormal_std`` keyword arguments below are also forwarded to
    ``environment.integrate``.

    Reads the module-level ``environment``, ``solver``, ``control`` and ``key``.

    Returns:
        ``(optimized_reward, optimized_control)`` from the trainer.
    """
    integrate_kwargs = dict(
        t1=12.0 * 60.0,
        stepsize_controller=stepsize_controller,
        tha_lognormal_std=1.0,
        k_lognormal_std=0.01,
        s0_lognormal_std=0.01,
    )

    def reward_fn(args: PyTree) -> Scalar:
        _, sg = args
        mean_term = mean_weight * jnp.mean(sg)
        peak_term = peak_weight * jnp.max(sg)
        return mean_term + peak_term

    integral_constraint = constraints.LimitedRangeConstantIntegralConstraint(
        target=target, maximum=maximum
    )
    constraint_chain = constraints.ConstraintChain(transformations=[integral_constraint])

    optimized_reward, optimized_control = trainers.solve_optimal_control_problem(
        num_train_steps=num_steps,
        environment=environment,
        reward_fn=reward_fn,
        constraint_chain=constraint_chain,
        solver=solver,
        control=control,
        key=key,
        pbar_interval=8,
        integrate_kwargs=integrate_kwargs,
    )
    return optimized_reward, optimized_control


def evaluate(control: controls.AbstractControl, target: Array, maximum: Array):
    """Simulate a trained control and a "modified" variant of it, then plot both.

    The control first goes through the same
    ``LimitedRangeConstantIntegralConstraint(target, maximum)`` used in training.
    The modified schedule replaces the first node by the second. It then spreads
    the removed amount ``nodes[0] - nodes[1]`` evenly (``/ 12``) over all nodes and
    clips the result at zero so no node becomes negative.

    Prints the mean control of both schedules and shows four figures:
    - p-eIF2a (log scale);
    - the stressed fraction over the first 128 samples;
    - the stressed fraction over the full horizon (dashed = mean, dash-dot = peak
      of each trajectory);
    - both control schedules.

    Args:
        control: Unconstrained control as returned by ``optimize``.
        target: Integral target of the constraint.
        maximum: Upper bound of the constraint.
    """
    constraint_chain = constraints.ConstraintChain(
        transformations=[
            constraints.LimitedRangeConstantIntegralConstraint(
                target=target, maximum=maximum
            )
        ]
    )

    constrained_control: controls.InterpolationCurveControl = build_control(
        control, constraint_chain
    )[0]

    nodes = constrained_control.curve.nodes
    # NOTE: the divisor 12 presumes 12 curve nodes (curve_steps=12) — confirm.
    # Clipping at zero keeps concentrations physical.
    mod_nodes = jnp.maximum(
        jnp.concatenate((nodes[1:2], nodes[1:]), axis=0) + (nodes[0] - nodes[1]) / 12,
        0.0,
    )
    mod_control = eqx.tree_at(
        lambda pytree: pytree.curve.nodes, constrained_control, mod_nodes
    )

    def _integrate(ctrl):
        # Tighter tolerances than optimize()'s default (1e-5) for evaluation.
        return environment.integrate(
            ctrl,
            environment_state,
            key,
            t1=12.0 * 60.0,
            stepsize_controller=diffrax.PIDController(
                atol=1e-8,
                rtol=1e-8,
                pcoeff=1.0,
                icoeff=1.0,
                dtmax=30,
            ),
        )

    ys, sg = _integrate(constrained_control)
    ys_mod, sg_mod = _integrate(mod_control)

    ts = jnp.linspace(0.0, 12 * 60, 12 * 60 + 1)
    cs = jax.vmap(constrained_control)(ts)
    cs_mod = jax.vmap(mod_control)(ts)

    print(jnp.mean(cs), jnp.mean(cs_mod))

    plt.figure()
    plt.yscale("log")
    plt.plot(ys[..., 1])  # p_EIF2a
    plt.plot(ys_mod[..., 1])
    plt.show()

    def _plot_sg(window):
        # Reference lines always use statistics of the full trajectory.
        plt.figure()
        plt.plot(ts[window], sg[window])
        plt.plot(ts[window], sg_mod[window])
        for series, color in ((sg, "tab:blue"), (sg_mod, "tab:orange")):
            plt.axhline(jnp.mean(series), c=color, linestyle="--")  # mean
            plt.axhline(jnp.max(series), c=color, linestyle="-.")  # peak
        plt.show()

    _plot_sg(slice(None, 128))
    _plot_sg(slice(None))

    plt.figure()
    plt.plot(ts, cs)
    plt.plot(ts, cs_mod)
    plt.show()

In [None]:
# Upper bound passed as `maximum` to LimitedRangeConstantIntegralConstraint
# (presumably the 1200 nM stock concentration — confirm).
maximum = jnp.asarray([1200.0])
# Integral target: 12 x the fitted constant concentration (presumably one unit
# per control step, curve_steps=12 — confirm).
target_conc = 10**peak_log_conc * 12
print(target_conc)

In [None]:
# Equals 10**peak_log_conc, the fitted constant concentration.
target_conc / 12

In [None]:
# Fitted constant concentration times 720 (= 12 * 60, the t1 used above).
10**peak_log_conc * 720

In [None]:
# Reward = -max(sg): intended to minimise the peak stressed fraction.
min_peak_reward, min_peak_control = optimize(
    target=target_conc,
    maximum=maximum,
    peak_weight=jnp.float_(-1.0),
    mean_weight=jnp.float_(0.0),
    num_steps=1024 * 8,
    stepsize_controller=diffrax.PIDController(
        atol=1e-8,
        rtol=1e-8,
        pcoeff=1.0,
        icoeff=1.0,
        dtmax=30,
    ),
)

In [None]:
# Reward = +max(sg): intended to maximise the peak stressed fraction.
max_peak_reward, max_peak_control = optimize(
    target=target_conc,
    maximum=maximum,
    peak_weight=jnp.float_(1.0),
    mean_weight=jnp.float_(0.0),
    num_steps=1024 * 7,
    stepsize_controller=diffrax.PIDController(
        atol=1e-8,
        rtol=1e-8,
        pcoeff=1.0,
        icoeff=1.0,
        dtmax=30,
    ),
)

In [None]:
# Reward = -mean(sg): intended to minimise the mean stressed fraction.
# NOTE(review): num_steps is 128 * 3 here (the 1024 * 3 setting is commented
# out), far fewer steps than the other runs — confirm this is intentional.
min_mean_reward, min_mean_control = optimize(
    target=target_conc,
    maximum=maximum,
    peak_weight=jnp.float_(0.0),
    mean_weight=jnp.float_(-1.0),
    #num_steps=1024 * 3,
    num_steps=128*3,
    stepsize_controller=diffrax.PIDController(
        atol=1e-8,
        rtol=1e-8,
        pcoeff=1.0,
        icoeff=1.0,
        dtmax=30,
    ),
)

In [None]:
# Reward = +mean(sg): intended to maximise the mean stressed fraction.
max_mean_reward, max_mean_control = optimize(
    target=target_conc,
    maximum=maximum,
    peak_weight=jnp.float_(0.0),
    mean_weight=jnp.float_(1.0),
    num_steps=1024 * 8,
    #num_steps=512,
    stepsize_controller=diffrax.PIDController(
        atol=1e-8,
        rtol=1e-8,
        pcoeff=1.0,
        icoeff=1.0,
        dtmax=30,
    ),
)

In [None]:
# Diagnostic plots for the min-peak and max-mean controls (the other two are disabled).
evaluate(min_peak_control, target=target_conc, maximum=maximum)
#evaluate(max_peak_control, target=target_conc, maximum=maximum)
#evaluate(min_mean_control, target=target_conc, maximum=maximum)
evaluate(max_mean_control, target=target_conc, maximum=maximum)


In [None]:
# Stock solution: 1200 nM (the upper bound `maximum` used in the constraint)

In [None]:
def get_conc(control: controls.AbstractControl, target: Array, maximum: Array):
    """Return the curve nodes of ``control`` after applying the integral constraint."""
    chain = constraints.ConstraintChain(
        transformations=[
            constraints.LimitedRangeConstantIntegralConstraint(target=target, maximum=maximum)
        ]
    )
    built = build_control(control, chain)
    constrained: controls.InterpolationCurveControl = built[0]
    return constrained.curve.nodes

print(get_conc(min_peak_control, target=target_conc, maximum=maximum).flatten())
print(get_conc(max_mean_control, target=target_conc, maximum=maximum).flatten())

In [None]:
# Fitted constant concentration (linear scale).
10**peak_log_conc

In [None]:
# Total drug and medium needed

max_conc = 1.2  # uM
# NOTE(review): hard-coded; presumably target_conc / 12 converted to uM — consider deriving it.
mean_conc = 0.3281  # uM
flush_volume = 400  # uL
num_flushings = 12
initial_drug_flush = 100

# Share of each flush that must be drug stock to reach the mean concentration.
drug_share = mean_conc / max_conc

print(f"Drug vol. {drug_share * flush_volume * num_flushings + initial_drug_flush} uL")
print(
    f"Medium vol. {(2 + (1 - drug_share) * num_flushings) * flush_volume} uL"
)

# Pretty Plots

In [None]:
# Sub-directory of result_base_dir used to save/load the four trained controls.
# NOTE(review): optimize() above always passes non-zero *_lognormal_std kwargs,
# so "Without randomization" may be stale — confirm which folder matches.
#controller_folder = "controllers_final" # With randomized Tha, slight k, s0 randomization
controller_folder = "controllers" # Without randomization

## Save

In [None]:
# Persist the four trained controls as equinox leaf files.
eqx.tree_serialise_leaves(
    f"{result_base_dir}/{controller_folder}/min_peak_control.eqx", min_peak_control
)
eqx.tree_serialise_leaves(
    f"{result_base_dir}/{controller_folder}/max_peak_control.eqx", max_peak_control
)
eqx.tree_serialise_leaves(
    f"{result_base_dir}/{controller_folder}/min_mean_control.eqx", min_mean_control
)
eqx.tree_serialise_leaves(
    f"{result_base_dir}/{controller_folder}/max_mean_control.eqx", max_mean_control
)

## Load

In [None]:
# Restore the four controls; `control` from the setup cell provides the pytree
# structure that the stored leaves are loaded into.
min_peak_control, max_peak_control, min_mean_control, max_mean_control = (
    eqx.tree_deserialise_leaves(
        f"{result_base_dir}/{controller_folder}/{stem}.eqx", control
    )
    for stem in (
        "min_peak_control",
        "max_peak_control",
        "min_mean_control",
        "max_mean_control",
    )
)

## Evaluate

In [None]:
@eqx.filter_jit
def evaluate(
    control: controls.AbstractControl,
    target: Array,
    maximum: Array,
    tha_mult: Scalar = jnp.float_(1.0),
):
    """Simulate a trained control (nodes scaled by ``tha_mult``) and a modified variant.

    Redefines the earlier plotting ``evaluate``; this version returns data instead.
    The control goes through ``LimitedRangeConstantIntegralConstraint(target, maximum)``
    and its nodes are multiplied by ``tha_mult``. The modified schedule replaces
    the first node by the second and spreads ``nodes[0] - nodes[1]`` evenly
    (``/ 12``) over all nodes, clipped at zero.

    Returns:
        dict with ``ys``, ``sg`` and ``cs`` for the constrained control, the same
        keys suffixed ``_mod`` for the modified schedule, and ``ts``: the
        ``12 * 60 + 1`` points on ``[0, 12 * 60]`` at which ``cs``/``cs_mod`` are sampled.
    """
    constraint_chain = constraints.ConstraintChain(
        transformations=[
            constraints.LimitedRangeConstantIntegralConstraint(
                target=target, maximum=maximum
            )
        ]
    )

    constrained_control: controls.InterpolationCurveControl = build_control(
        control, constraint_chain
    )[0]

    constrained_control = eqx.tree_at(
        lambda pytree: pytree.curve.nodes,
        constrained_control,
        constrained_control.curve.nodes * tha_mult,
    )

    nodes = constrained_control.curve.nodes
    # jnp.maximum(x, 0.0) is what jnp.clip(x, a_min=0.0) computes, without the
    # `a_min` keyword that newer JAX versions deprecate.
    mod_nodes = jnp.maximum(
        jnp.concatenate((nodes[1:2], nodes[1:]), axis=0) + (nodes[0] - nodes[1]) / 12,
        0.0,
    )
    mod_control = eqx.tree_at(
        lambda pytree: pytree.curve.nodes, constrained_control, mod_nodes
    )

    def _integrate(ctrl):
        return environment.integrate(
            ctrl,
            environment_state,
            key,
            t1=12.0 * 60.0,
            stepsize_controller=diffrax.PIDController(
                atol=1e-8,
                rtol=1e-8,
                pcoeff=1.0,
                icoeff=1.0,
                dtmax=30,
            ),
        )

    ys, sg = _integrate(constrained_control)
    ys_mod, sg_mod = _integrate(mod_control)

    ts = jnp.linspace(0.0, 12 * 60, 12 * 60 + 1)
    cs = jax.vmap(constrained_control)(ts)
    cs_mod = jax.vmap(mod_control)(ts)

    return {
        "ys": ys,
        "sg": sg,
        "cs": cs,
        "ys_mod": ys_mod,
        "sg_mod": sg_mod,
        "cs_mod": cs_mod,
        "ts": ts,
    }

In [None]:
# Evaluate each trained control at the default tha_mult (1.0).
min_peak_eval = evaluate(min_peak_control, target=target_conc, maximum=maximum)
max_peak_eval = evaluate(max_peak_control, target=target_conc, maximum=maximum)
min_mean_eval = evaluate(min_mean_control, target=target_conc, maximum=maximum)
max_mean_eval = evaluate(max_mean_control, target=target_conc, maximum=maximum)

## Plot

In [None]:
def get_mod(kv: dict) -> dict:
    """Extract the modified-schedule entries of an ``evaluate`` result.

    Keeps every key containing ``"_mod"`` (with that substring removed) plus the
    shared ``"ts"`` entry, preserving iteration order.
    """
    selected = {}
    for key_name, value in kv.items():
        if key_name == "ts" or "_mod" in key_name:
            selected[key_name.replace("_mod", "")] = value
    return selected

In [None]:
# p-eIF2a & GADD34 -> d/dt p-eIF2a Matrix


def ddt(tot_eIF2a, frac_p_eIF2a, GADD34, k, u):
    """Rate of change of p-eIF2a for a given total eIF2a split, GADD34 and input.

    The pool ``tot_eIF2a`` is split into unphosphorylated
    ``tot * (1 - frac)`` and phosphorylated ``tot * frac`` parts. Only ``u[0]``
    and ``k[0..3]``, ``k[10]`` and ``k[11]`` are used. Works element-wise on
    broadcastable arrays.
    """
    eIF2a = tot_eIF2a * (1 - frac_p_eIF2a)
    p_eIF2a = tot_eIF2a * frac_p_eIF2a
    stimulus = u[0]

    production = k[0] * eIF2a + (
        k[1] * stimulus / (k[2] + stimulus) * eIF2a / (k[3] + eIF2a)
    )
    # Subtract the two loss terms one after another to keep the original
    # floating-point evaluation order.
    return production - k[10] * p_eIF2a * GADD34 - k[11] * p_eIF2a


# u @ dy/dt = 0
# NOTE(review): the triple-quoted string below is disabled exploratory code (a
# jaxopt.Broyden root-find, in log-space, for the input u at which the p-eIF2a
# derivative from `ddt` vanishes). It is an inert string expression and never runs.
"""
tot_eIF2a = environment_state.s0[0] + environment_state.s0[1]
frac_p_eIF2a = jnp.linspace(0.0, 0.1, 64)
GADD34 = jnp.geomspace(1e0, 1e1, 64)

frac_p_eIF2a, GADD34 = jnp.meshgrid(frac_p_eIF2a, GADD34)
u = jnp.zeros_like(GADD34)

zero_fn = lambda x: ddt(
    tot_eIF2a, frac_p_eIF2a, GADD34, environment_state.k, jnp.exp(x)
)

u = jaxopt.Broyden(zero_fn, maxiter=10000).run(u)[0]
"""


def at_u(u):
    """Heat map of d/dt log(p-eIF2a) over a (p-eIF2a fraction, GADD34) grid at input ``u``.

    Uses ``ddt`` with the total eIF2a of the initial state (``s0[0] + s0[1]``).
    The p-eIF2a fraction is sampled linearly on [0.05, 0.95] and GADD34
    geometrically on [1, 100]. The image extent follows that sampling: fraction
    on x, log10(GADD34) on y. The colour scale is symmetric about zero.
    """
    tot_eIF2a = environment_state.s0[0] + environment_state.s0[1]
    frac_axis = jnp.linspace(0.05, 0.95, 256)
    gadd34_axis = jnp.geomspace(1e0, 1e2, 256)

    # imshow spaces pixels uniformly, so the extent must match the sampling:
    # linear in the fraction, decades in GADD34. (Previously hard-coded as
    # (0, 1, 1, 2), which mislabeled both axes.)
    extent = (
        float(frac_axis[0]),
        float(frac_axis[-1]),
        float(jnp.log10(gadd34_axis[0])),
        float(jnp.log10(gadd34_axis[-1])),
    )

    frac_p_eIF2a, GADD34 = jnp.meshgrid(frac_axis, gadd34_axis)

    dydt = ddt(tot_eIF2a, frac_p_eIF2a, GADD34, environment_state.k, [u])

    # Dividing by p-eIF2a itself gives the logarithmic derivative.
    log_dydt = dydt / (tot_eIF2a * frac_p_eIF2a)
    vabsmax = jnp.max(jnp.abs(log_dydt))

    plt.figure()
    plt.xlabel("Frac. p-eIF2a")
    plt.ylabel(r"$\log_{10}$ GADD34 [nmol/l]")
    plt.imshow(
        log_dydt,
        vmin=-vabsmax,
        vmax=vabsmax,
        cmap="RdBu",
        origin="lower",
        extent=extent,
        aspect="auto",
        interpolation="antialiased",
    )
    cbar = plt.colorbar()
    cbar.set_label("p-EIF2a log-derivative")
    # Disabled arrow overlay; if re-enabled, plot against log10(GADD34) so it
    # lines up with the image extent:
    # plt.quiver(
    #     frac_p_eIF2a[::4, ::4],
    #     jnp.log10(GADD34[::4, ::4]),
    #     log_dydt[::4, ::4],
    #     jnp.zeros_like(log_dydt[::4, ::4]),
    #     angles="xy",
    # )
    plt.show()


def steady_states(save_prefix: str = None):
    """Map steady-state p-eIF2a and stressed-cell fraction over Tha x GADD34.

    Builds a 64 x 64 grid of log-spaced thapsigargin (1e-2..1e6) and GADD34
    (1e-4..1e4) values. For each pair, it integrates the eIF2a <-> p-eIF2a
    exchange from the initial state's ``s0[0]``/``s0[1]``. GADD34 is held fixed
    (zero derivative), and the solve runs until diffrax's ``SteadyStateEvent``
    terminates it.

    The final p-eIF2a and its Hill transform (``k[4]`` exponent, ``k[5]``
    half-saturation) are drawn as two heat maps. Each map is rendered in every
    style of ``styles`` and saved via ``show``.
    """
    grid_size = 64

    def eIF2a_ode(t, y, args):
        # Only states 0 (eIF2a) and 1 (p-eIF2a) evolve; states 2 and 3 are frozen.
        k, u = args
        u = [u]

        d_eIF2a = (
            -k[0] * y[0]
            - (k[1] * u[0] / (k[2] + u[0]) * y[0] / (k[3] + y[0]))
            + k[10] * y[1] * y[3]
            + k[11] * y[1]
        )
        d_p_eIF2a = (
            k[0] * y[0]
            + (k[1] * u[0] / (k[2] + u[0]) * y[0] / (k[3] + y[0]))
            - k[10] * y[1] * y[3]
            - k[11] * y[1]
        )
        return jnp.stack([d_eIF2a, d_p_eIF2a, 0, 0], axis=-1)

    def f_sg(p_eif2a: Array, h_sg: Scalar, k_sg: Scalar) -> Array:
        return p_eif2a**h_sg / (k_sg**h_sg + p_eif2a**h_sg)

    tha_grid, gadd34_grid = jnp.meshgrid(
        jnp.geomspace(1e-2, 1e6, grid_size), jnp.geomspace(1e-4, 1e4, grid_size)
    )

    eIF2a_0 = environment_state.s0[0]
    p_eIF2a_0 = environment_state.s0[1]

    @jax.vmap
    def solve_with(tha, gadd34):
        solution = diffrax.diffeqsolve(
            terms=diffrax.ODETerm(eIF2a_ode),
            solver=diffrax.Kvaerno5(),
            t0=0.0,
            t1=jnp.inf,
            dt0=None,
            y0=jnp.stack([eIF2a_0, p_eIF2a_0, 0, gadd34]),
            args=(environment_state.k, tha),
            max_steps=None,
            stepsize_controller=diffrax.PIDController(
                rtol=1e-8, atol=1e-8, pcoeff=1.0, icoeff=1.0, dtmax=30.0
            ),
            discrete_terminating_event=diffrax.SteadyStateEvent(),
        )
        return solution.ys[-1]

    final_states = solve_with(tha_grid.flatten(), gadd34_grid.flatten())
    p_eIF2a_ss = final_states[..., 1]
    sg_ss = f_sg(p_eIF2a_ss, environment_state.k[4], environment_state.k[5])

    side = plot_half_width * plot_shrink_factor
    extent = (-2, 6, -4, 4)  # log10 ranges of the Tha (x) and GADD34 (y) grids

    def heatmap(values, cbar_label, file_stem, norm=None):
        # Draw one heat map per plotting style and save each.
        def draw(suffix):
            fig, ax = plt.subplots(1, 1, figsize=(side, side))
            im = ax.imshow(
                values.reshape(grid_size, grid_size),
                cmap="magma",
                norm=norm,
                extent=extent,
                origin="lower",
                interpolation="bilinear",
            )
            cbar = plt.colorbar(im, ax=ax, fraction=0.04575, pad=0.04)
            cbar.set_label(cbar_label)

            ax.set_xlabel(r"Thapsigargin [nmol/l]")
            ax.set_ylabel(r"GADD34 [nmol/l]")
            for axis in (ax.xaxis, ax.yaxis):
                axis.set_major_formatter(
                    plt.FuncFormatter(lambda x, _: r"$10^{%d}$" % x)
                )

            show(file_stem + suffix, save_prefix)

        styles(draw)

    heatmap(
        p_eIF2a_ss,
        r"p-eIF2$\mathrm{\alpha}$ [nmol/l]",
        "/p_eIF2a_steady_state",
        norm="log",
    )
    heatmap(sg_ss, "Frac. Stressed Cells", "/sg_steady_state")


steady_states("/general")

In [None]:
def sg_response_curve(save_prefix: str = None, talk: bool = False):
    """Plot the Hill curve mapping p-eIF2a to the fraction of stressed cells.

    Uses ``k[4]`` as Hill exponent and ``k[5]`` as half-saturation constant.
    Prints both, plus the p-eIF2a levels giving 10 % and 90 % stressed cells,
    and marks the 10 / 90 / 50 % levels with dashed guide lines. ``talk`` only
    changes the saved file name; the caller selects the matplotlib style.
    """
    hill_exp = environment_state.k[4]
    half_sat = environment_state.k[5]

    def f_sg(p_eif2a: Array, h_sg: Scalar, k_sg: Scalar) -> Array:
        return p_eif2a**h_sg / (k_sg**h_sg + p_eif2a**h_sg)

    def inv_hill(y, h, k):
        # Solve y = x**h / (k**h + x**h) for x.
        return (-((y - 1) * k ** (-h)) / y) ** (-1 / h)

    print(hill_exp, half_sat)
    print(inv_hill(0.1, hill_exp, half_sat))
    print(inv_hill(0.9, hill_exp, half_sat))

    p_eIF2a = jnp.linspace(250, 1250, num=1024)
    sg = f_sg(p_eIF2a, hill_exp, half_sat)

    plt.figure(figsize=(plot_half_width, plot_half_width * 2 / 3))
    plt.xlabel(r"p-eIF2$\mathrm{\alpha}$ [nmol/l]")
    plt.ylabel("Frac. Stressed Cells")
    plt.plot(p_eIF2a, sg)

    guide_levels = (0.1, 0.9, 0.5)
    for level in guide_levels:
        plt.axhline(level, c="black", linestyle="--")
    for level in guide_levels:
        plt.axvline(inv_hill(level, hill_exp, half_sat), c="black", linestyle="--")

    show("/sg_response_curve_talk" if talk else "/sg_response_curve", save_prefix)


with plt.style.context(plot_style):
    sg_response_curve("/general")

with plt.style.context(plot_style_talk):
    sg_response_curve("/general", talk=True)

In [None]:
def plot_control_with_response(eval_data, comp_eval_data=None, save_prefix: str = None, name: str = None):
    """Four stacked panels: stressed fraction, p-eIF2a, GADD34 and the Tha control.

    ``eval_data`` is drawn solid. ``comp_eval_data``, if given, is overlaid dashed;
    only the top panel carries legend labels. Each style from ``styles`` is saved
    as ``/{name}_control_response<suffix>`` under ``save_prefix``.
    """
    # (y-axis label, series getter, log-scaled y-axis) per panel, top to bottom.
    panels = [
        ("Frac. Stressed", lambda d: d["sg"], False),
        (r"p-eIF2$\mathrm{\alpha}$ [nmol/l]", lambda d: d["ys"][:, 1], True),
        (r"GADD34 [nmol/l]", lambda d: d["ys"][:, 3], True),
        (r"Tha [nmol/l]", lambda d: d["cs"], False),
    ]

    def plot(style_name):
        fig, axes = plt.subplots(
            4, 1, sharex=True, figsize=(plot_half_width, plot_half_width * 1.5)
        )
        axes[-1].set_xlabel("Time [min]")

        for row, (ax, (ylabel, get_series, log_y)) in enumerate(zip(axes, panels)):
            is_top = row == 0
            if log_y:
                ax.set_yscale("log")
            ax.set_ylabel(ylabel)

            main_kwargs = {"label": "Trajectory"} if is_top else {}
            ax.plot(eval_data["ts"], get_series(eval_data), **main_kwargs)
            if comp_eval_data is not None:
                comp_kwargs = {"label": "Mod. Trajectory"} if is_top else {}
                ax.plot(
                    comp_eval_data["ts"],
                    get_series(comp_eval_data),
                    linestyle="--",
                    **comp_kwargs,
                )

        axes[0].legend()
        plt.tight_layout()
        show(f"/{name}_control_response" + style_name, save_prefix)

    styles(plot)

# Output sub-folder (appended to result_base_dir) for response figures and data.
#response_folder = "/responses_final"
response_folder = "/responses"

# Each control's trajectory (solid) vs. its modified schedule (dashed).
plot_control_with_response(min_peak_eval, get_mod(min_peak_eval), response_folder, "min_peak")
plot_control_with_response(max_peak_eval, get_mod(max_peak_eval), response_folder, "max_peak")
plot_control_with_response(min_mean_eval, get_mod(min_mean_eval), response_folder, "min_mean")
plot_control_with_response(max_mean_eval, get_mod(max_mean_eval), response_folder, "max_mean")

In [None]:
def save_response(eval, mod_eval, folder, name):
    """Save an evaluation dict and its modified counterpart as compressed ``.npz`` files.

    Files are ``eval_data_<name>`` and ``mod_eval_data_<name>`` in
    ``result_base_dir + folder``.
    """
    target_dir = result_base_dir + folder
    for file_prefix, arrays in (("/eval_data_", eval), ("/mod_eval_data_", mod_eval)):
        np.savez_compressed(target_dir + file_prefix + name, **arrays)

save_response(min_peak_eval, get_mod(min_peak_eval), response_folder, "min_peak")
save_response(max_peak_eval, get_mod(max_peak_eval), response_folder, "max_peak")
save_response(min_mean_eval, get_mod(min_mean_eval), response_folder, "min_mean")
save_response(max_mean_eval, get_mod(max_mean_eval), response_folder, "max_mean")

In [None]:
def _compare_controls(primary, secondary, primary_label, secondary_label):
    """Overlay two control schedules (``cs`` over ``ts``); ``secondary`` is dashed."""
    plt.figure(figsize=(plot_third_width, plot_third_width * 2 / 3))
    plt.xlabel("Time [min]")
    # `cs` is a concentration (labelled "Tha [nmol/l]" elsewhere), not an amount.
    plt.ylabel("Thapsigargin [nmol/l]")
    plt.plot(primary["ts"], primary["cs"], label=primary_label)
    plt.plot(secondary["ts"], secondary["cs"], linestyle="--", label=secondary_label)
    plt.legend()
    plt.show()


_compare_controls(min_peak_eval, min_mean_eval, "Min. peak", "Min. mean")
_compare_controls(max_peak_eval, max_mean_eval, "Max. peak", "Max. mean")

In [None]:
# Trajectories over different multipliers

# Batched `evaluate`: vectorised over tha_mult only (in_axes 0); control,
# target and maximum are shared. Outputs gain a leading batch axis (out_axes=0).
vmap_eval_fn = eqx.filter_vmap(evaluate, in_axes=(None, None, None, 0), out_axes=0)

In [None]:
# 128 evenly spaced Tha multipliers on [0, 2]; every array in the returned
# dicts is stacked along axis 0 in this order.
tha_mults = jnp.linspace(0.0, 2.0, 128)
min_peak_vmap_eval_data = vmap_eval_fn(min_peak_control, target_conc, maximum, tha_mults)
max_peak_vmap_eval_data = vmap_eval_fn(max_peak_control, target_conc, maximum, tha_mults)
min_mean_vmap_eval_data = vmap_eval_fn(min_mean_control, target_conc, maximum, tha_mults)
max_mean_vmap_eval_data = vmap_eval_fn(max_mean_control, target_conc, maximum, tha_mults)

In [None]:
def plot_vmap_traj(eval_data, tha_mults):
    """Heat map of the stressed fraction over time, one row per Tha multiplier.

    ``eval_data`` is the batched output of ``vmap_eval_fn``: row ``i`` of
    ``eval_data["sg"]`` is the trajectory for ``tha_mults[i]``. ``aspect="auto"``
    keeps the (n_mults x n_samples) image readable; the default equal aspect
    squashes it into a thin strip.
    """
    ts = jnp.asarray(eval_data["ts"])
    if ts.ndim > 1:
        # vmap broadcasts the shared time grid to every batch entry.
        ts = ts[0]

    fig, ax = plt.subplots(1, 1)
    im = ax.imshow(
        eval_data["sg"],
        cmap="magma",
        origin="lower",
        aspect="auto",
        extent=(float(ts[0]), float(ts[-1]), float(tha_mults[0]), float(tha_mults[-1])),
    )
    cbar = fig.colorbar(im, ax=ax)
    cbar.set_label("Frac. Stressed Cells")
    ax.set_xlabel("Time [min]")
    ax.set_ylabel("Tha multiplier")
    plt.show()

plot_vmap_traj(min_peak_vmap_eval_data, tha_mults)
plot_vmap_traj(max_peak_vmap_eval_data, tha_mults)
plot_vmap_traj(min_mean_vmap_eval_data, tha_mults)
plot_vmap_traj(max_mean_vmap_eval_data, tha_mults)

In [None]:
# Measured fractions of stressed cells for the min-peak and max-mean schedules.
# As assumed by `match_traj` below, samples are 15 min apart and the first
# sample is taken before treatment.
min_peak_experiment_data = jnp.asarray(
    [
        0.0,
        0.0,
        0.0,
        0.04347826,
        0.04347826,
        0.04347826,
        0.04347826,
        0.04347826,
        0.04347826,
        0.04347826,
        0.04347826,
        0.04347826,
        0.02173913,
        0.02173913,
        0.04347826,
        0.04347826,
        0.04347826,
        0.04347826,
        0.02173913,
        0.06521739,
        0.06521739,
        0.06521739,
        0.04347826,
        0.23913043,
        0.34782609,
        0.32608696,
        0.36956522,
        0.34782609,
        0.23913043,
        0.08695652,
        0.04347826,
        0.02173913,
        0.02173913,
        0.02173913,
        0.0,
        0.0,
        0.04347826,
        0.02173913,
        0.0,
        0.0,
        0.0,
        0.02173913,
        0.0,
        0.02173913,
        0.02173913,
        0.02173913,
        0.02173913,
        0.02173913,
    ]
)

max_mean_experiment_data = jnp.asarray(
    [
        0.0,
        0.0,
        0.08695652,
        0.47826087,
        0.7826087,
        0.82608696,
        0.86956522,
        0.91304348,
        0.82608696,
        0.86956522,
        0.7826087,
        0.69565217,
        0.69565217,
        0.52173913,
        0.39130435,
        0.26086957,
        0.17391304,
        0.17391304,
        0.13043478,
    ]
)

In [None]:
def match_traj(pred_sg, exp_sg):
    """Pick the predicted trajectory closest (sum of squared errors) to measurements.

    ``pred_sg`` holds one trajectory per row; every 15th column is compared
    with ``exp_sg[1:]``, since the first measurement is pre-treatment.

    Returns:
        ``(min_idx, l2_error[min_idx])``.
    """
    measured = exp_sg[1:]  # First sample is pre-treatment
    n_samples = len(measured)
    sampled = pred_sg[:, : n_samples * 15 : 15]  # Each sample is 15 min
    sq_err = jnp.square(sampled - measured).sum(axis=1)
    best = jnp.argmin(sq_err)
    return best, sq_err[best]

# Index into tha_mults and squared error of the best-fitting trajectory.
min_peak_idx, min_peak_l2 = match_traj(min_peak_vmap_eval_data["sg"], min_peak_experiment_data)
max_mean_idx, max_mean_l2 = match_traj(max_mean_vmap_eval_data["sg"], max_mean_experiment_data)

In [None]:
# Best-matching Tha multipliers for the two experiments.
tha_mults[min_peak_idx], tha_mults[max_mean_idx]

In [None]:
def plot_match(pred_sg, exp_sg):
    """Scatter the measured stressed fractions over one predicted trajectory.

    ``exp_sg[0]`` is dropped (pre-treatment). The remaining samples are placed
    every 15 min starting at 0. The prediction is drawn against its sample index.
    """
    measured = exp_sg[1:]  # First sample is pre-treatment
    n_measured = len(measured)
    sample_times = jnp.linspace(0.0, n_measured * 15, n_measured, endpoint=False)

    fig, ax = plt.subplots(1, 1)
    ax.scatter(sample_times, measured)
    ax.plot(pred_sg)
    plt.show()

plot_match(min_peak_vmap_eval_data["sg"][min_peak_idx], min_peak_experiment_data)
plot_match(max_mean_vmap_eval_data["sg"][max_mean_idx], max_mean_experiment_data)

## Distributions

In [None]:
@eqx.filter_jit
def eval_control_with_mult(
    control: controls.AbstractControl,
    target: Array,
    maximum: Array,
    tha_mult: Scalar = jnp.float_(1.0),
):
    """Stressed-cell trajectory of ``control`` after constraining and scaling it.

    The control goes through ``LimitedRangeConstantIntegralConstraint(target, maximum)``.
    Its curve nodes are multiplied by ``tha_mult``, and it is integrated to
    ``t1 = 12 * 60`` with tight (1e-8) tolerances.
    """
    chain = constraints.ConstraintChain(
        transformations=[
            constraints.LimitedRangeConstantIntegralConstraint(target=target, maximum=maximum)
        ]
    )
    built: controls.InterpolationCurveControl = build_control(control, chain)[0]
    scaled = eqx.tree_at(
        lambda tree: tree.curve.nodes, built, built.curve.nodes * tha_mult
    )

    controller = diffrax.PIDController(
        atol=1e-8,
        rtol=1e-8,
        pcoeff=1.0,
        icoeff=1.0,
        dtmax=30,
    )
    _, sg = environment.integrate(
        scaled,
        environment_state,
        key,
        t1=12.0 * 60.0,
        stepsize_controller=controller,
    )
    return sg

from functools import partial
eval_base_fn = partial(eval_control_with_mult, target=target_conc, maximum=maximum)

In [None]:
def dist_for_control(control: controls.AbstractControl):
    """Stressed-cell trajectories of ``control`` for 128 log-spaced Tha multipliers.

    Multipliers come from ``jnp.geomspace(0.1, 10.0, 128)``. Returns a Python
    list with one trajectory per multiplier, in increasing order.
    """
    jitted = jax.jit(partial(eval_base_fn, control=control))
    multipliers = jnp.geomspace(0.1, 10.0, 128)
    return [jitted(tha_mult=mult) for mult in multipliers]

min_peak_sg = dist_for_control(min_peak_control)
max_mean_sg = dist_for_control(max_mean_control)

In [None]:
# Inspect the first 90 of the multipliers used by dist_for_control.
jnp.geomspace(0.1, 10.0, 128)[:90]

In [None]:
def _plot_magma_traces(trajectories):
    """Plot every trajectory as one line, coloured along magma in list order."""
    colors = plt.cm.magma(np.linspace(0, 1, len(trajectories)))
    with plt.style.context({"axes.prop_cycle": plt.cycler("color", colors)}):
        plt.figure()
        plt.plot(jnp.asarray(trajectories).T)
        plt.show()


_plot_magma_traces(min_peak_sg)
_plot_magma_traces(max_mean_sg)

## Calibration Experiment

In [None]:
# For calibration experiment

@jax.jit
def eval_constant(tha: Scalar):
    """Integrate the environment under a time-constant Tha input and return ``sg``.

    The control ignores its argument and always yields a length-1 array filled with
    ``tha``; integration runs until ``t1 = 12 * 60`` with a tight PID step-size
    controller (atol = rtol = 1e-8, dtmax = 30).
    """

    def constant_input(_):
        return jnp.full((1,), tha)

    step_controller = diffrax.PIDController(
        atol=1e-8,
        rtol=1e-8,
        pcoeff=1.0,
        icoeff=1.0,
        dtmax=30,
    )
    _, sg = environment.integrate(
        controls.LambdaControl(constant_input),
        environment_state,
        key,
        t1=12.0 * 60.0,
        stepsize_controller=step_controller,
    )
    return sg

@partial(jax.jit, static_argnames=("sample_every",))
def match_traj(pred_sg, exp_sg, sample_every: int = 15):
    """Score simulated trajectories against one measured time series.

    Args:
        pred_sg: 2-D array ``(candidates, steps)`` of simulated trajectories.
        exp_sg: 1-D array of measurements.
        sample_every: Simulation steps between consecutive measurements; 15 by
            default (each measurement sample is 15 min). Must be a Python int
            because it is used as a static slice step under ``jax.jit``.

    Returns:
        ``(l2_error, min_idx)``: the summed squared error of every candidate and
        the index of the best-matching candidate.
    """
    # Subsample every simulation onto the measurement grid.
    sampled = pred_sg[:, : len(exp_sg) * sample_every : sample_every]

    l2_error = jnp.sum(jnp.square(sampled - exp_sg), axis=1)
    return l2_error, jnp.argmin(l2_error)

In [None]:
max_factor = 1e2
tha_min = 1e2
tha_max = 1e3

# Log-spaced constant Tha inputs, widening [tha_min, tha_max] by max_factor on both ends.
tha_lower, tha_upper = tha_min / max_factor, tha_max * max_factor
tha_amounts = jnp.geomspace(tha_lower, tha_upper, num=1024)

batched_eval_constant = jax.vmap(eval_constant)
tha_sg = jax.jit(batched_eval_constant)(tha_amounts)

In [None]:
# Display the swept constant Tha amounts (1024 log-spaced values from the previous cell).
tha_amounts

In [None]:
# One colour per simulated Tha amount, running through magma in sweep order.
sweep_colors = plt.cm.magma(np.linspace(0, 1, tha_sg.shape[0]))

with plt.style.context({"axes.prop_cycle": plt.cycler("color", sweep_colors)}):
    plt.figure()
    plt.plot(tha_sg.T)
    plt.show()

In [None]:
# Measured fraction of stressed cells for the 300 nM and 1000 nM experiments,
# one value per 15-min sample (the spacing assumed by `match_traj`).
# NOTE(review): both names are reassigned by the next two cells, so under
# Restart & Run All only the last assignment (radius 50, atol 2.0) reaches the fit.
data_300nM = [0.04081632653061224,0.3469387755102041,0.6122448979591837,0.6530612244897959,0.7755102040816326,0.673469387755102,0.6530612244897959,0.4897959183673469,0.4489795918367347,0.3673469387755102,0.1836734693877551,0.08163265306122448,0.08163265306122448,0.10204081632653061,0.0,0.04081632653061224,0.04081632653061224,0.02040816326530612,0.02040816326530612,0.0,0.02040816326530612,0.02040816326530612,0.0,0.0,0.0,0.02040816326530612,0.02040816326530612,0.0,0.02040816326530612,0.0,0.0,0.0,0.0,0.02040816326530612,0.0,0.0,0.0,0.0,0.0]
data_1000nM = [0.6379310344827587,0.8620689655172413,0.9310344827586207,0.9310344827586207,0.9137931034482759,0.896551724137931,0.8793103448275862,0.8275862068965517,0.6896551724137931,0.5862068965517241,0.5172413793103449,0.3275862068965517,0.20689655172413793,0.1896551724137931,0.1206896551724138,0.05172413793103448,0.034482758620689655,0.0,0.034482758620689655,0.034482758620689655,0.034482758620689655,0.034482758620689655,0.05172413793103448,0.06896551724137931,0.034482758620689655,0.05172413793103448,0.034482758620689655,0.05172413793103448,0.06896551724137931,0.06896551724137931,0.05172413793103448,0.06896551724137931,0.06896551724137931,0.06896551724137931,0.06896551724137931,0.06896551724137931,0.08620689655172414,0.034482758620689655,0.05172413793103448]


In [None]:
# Alternative version of the measured series: radius 50, atol 2.5.
# NOTE(review): "radius"/"atol" look like settings of the analysis that produced
# these fractions — TODO confirm. Overwritten by the next cell under Run All.

data_300nM = [0.0,0.20408163265306123,0.6122448979591837,0.5510204081632653,0.5714285714285714,0.5918367346938775,0.5714285714285714,0.40816326530612246,0.40816326530612246,0.22448979591836735,0.10204081632653061,0.08163265306122448,0.04081632653061224,0.02040816326530612,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0]
data_1000nM = [0.1896551724137931,0.6206896551724138,0.7413793103448276,0.6724137931034483,0.6896551724137931,0.6724137931034483,0.6551724137931034,0.6379310344827587,0.5,0.39655172413793105,0.27586206896551724,0.2413793103448276,0.13793103448275862,0.10344827586206896,0.05172413793103448,0.017241379310344827,0.0,0.0,0.0,0.017241379310344827,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.017241379310344827,0.017241379310344827,0.017241379310344827,0.017241379310344827,0.034482758620689655,0.017241379310344827,0.034482758620689655,0.034482758620689655,0.017241379310344827,0.017241379310344827,0.0]

In [None]:
# Alternative version of the measured series: radius 50, atol 2.0.
# This is the last assignment of data_300nM / data_1000nM, so under Run All these
# are the values used by the calibration fit below.

data_300nM = [0.0,0.4489795918367347,0.7551020408163265,0.7959183673469388,0.7755102040816326,0.7755102040816326,0.7142857142857143,0.5918367346938775,0.46938775510204084,0.32653061224489793,0.14285714285714285,0.12244897959183673,0.08163265306122448,0.04081632653061224,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.04081632653061224,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0]
data_1000nM = [0.41379310344827586,0.7931034482758621,0.8793103448275862,0.8448275862068966,0.8793103448275862,0.7586206896551724,0.7758620689655172,0.7068965517241379,0.603448275862069,0.43103448275862066,0.39655172413793105,0.2413793103448276,0.15517241379310345,0.10344827586206896,0.06896551724137931,0.05172413793103448,0.017241379310344827,0.0,0.017241379310344827,0.017241379310344827,0.017241379310344827,0.0,0.017241379310344827,0.034482758620689655,0.0,0.0,0.017241379310344827,0.034482758620689655,0.034482758620689655,0.017241379310344827,0.034482758620689655,0.017241379310344827,0.034482758620689655,0.034482758620689655,0.034482758620689655,0.034482758620689655,0.017241379310344827,0.017241379310344827,0.017241379310344827]

In [None]:
# Convert both measured series to JAX arrays for `match_traj`.
data_300nM, data_1000nM = (jnp.asarray(series) for series in (data_300nM, data_1000nM))

In [None]:
# L2 error of every constant-Tha simulation against each measured series, plus best index.
l2_error_300nM, min_idx_300nM = match_traj(pred_sg=tha_sg, exp_sg=data_300nM)
l2_error_1000nM, min_idx_1000nM = match_traj(pred_sg=tha_sg, exp_sg=data_1000nM)

In [None]:
# Best-fitting constant Tha amount per experiment, and the applied concentration over it.
best_tha_300nM = tha_amounts[min_idx_300nM]
best_tha_1000nM = tha_amounts[min_idx_1000nM]
print(best_tha_300nM, best_tha_1000nM)
print(300 / best_tha_300nM, 1000 / best_tha_1000nM)

In [None]:
# L2 error versus constant Tha amount for both experiments.
fig, ax = plt.subplots()
ax.set_xscale("log")
ax.plot(tha_amounts, l2_error_300nM)
ax.plot(tha_amounts, l2_error_1000nM)
plt.show()

In [None]:
time = jnp.arange(tha_sg.shape[1])

# Overlay the best-fitting constant-Tha simulation on each measured series.
# Measurements are one sample every 15 simulation steps, so their x positions are
# every 15th entry of `time`, up to the length of *that* series (the 1000 nM plot
# previously used the 300 nM length by mistake).
for best_idx, measured in (
    (min_idx_300nM, data_300nM),
    (min_idx_1000nM, data_1000nM),
):
    plt.figure()
    plt.plot(time, tha_sg[best_idx])
    plt.plot(time[: measured.shape[0] * 15 : 15], measured)
    plt.show()

In [None]:
def plot_cal_fit(save_prefix: str = None):
    """Draw the calibration figures once per entry in ``plot_styles``.

    Two figures are produced: the L2 error of every constant-Tha simulation against
    the 300 nM and 1000 nM measurements (``/cal_l2``), and the best-fitting
    simulations overlaid on the measurements (``/cal_fits``). They are saved via
    ``show`` when ``save_prefix`` is given.
    """

    def _plot_l2_errors(style_suffix):
        plt.figure(figsize=(plot_half_width, plot_third_width))

        plt.xscale("log")
        plt.yscale("log")
        plt.xlabel("Tha Activity Factor")
        plt.ylabel("L2-Error")

        # x-axis: ratio of simulated Tha amount to the applied concentration.
        rel_300 = 1 / (300 / tha_amounts)
        rel_1000 = 1 / (1000 / tha_amounts)
        plt.plot(rel_300, l2_error_300nM, label=r"300 [nmol/l]")
        plt.plot(rel_1000, l2_error_1000nM, label=r"1000 [nmol/l]")

        # Mark the best-fitting factor of each experiment.
        plt.axvline(1 / (300 / tha_amounts[min_idx_300nM]), c="tab:blue")
        plt.axvline(
            1 / (1000 / tha_amounts[min_idx_1000nM]),
            c="tab:orange",
            linestyle="dashed",
        )

        plt.legend()
        plt.tight_layout()
        show("/cal_l2" + style_suffix, save_prefix=save_prefix)

    def _plot_best_fits(style_suffix):
        steps = jnp.arange(tha_sg.shape[1])

        fig, axes = plt.subplots(
            2, 1, sharex=True, figsize=(plot_half_width, plot_half_width)
        )

        # Measurements are one sample every 15 simulation steps.
        span_300 = data_300nM.shape[0] * 15
        axes[0].plot(steps[:span_300], tha_sg[min_idx_300nM][:span_300], label="Best Fit")
        axes[0].plot(steps[:span_300:15], data_300nM, label="Measured")
        axes[0].legend()

        span_1000 = data_1000nM.shape[0] * 15
        axes[1].plot(steps[:span_1000], tha_sg[min_idx_1000nM][:span_1000])
        axes[1].plot(steps[:span_1000:15], data_1000nM)

        axes[1].set_xlabel("Time [min]")
        axes[1].set_ylabel("Frac. Stressed")

        plt.tight_layout()
        show("/cal_fits" + style_suffix, save_prefix=save_prefix)

    styles(_plot_l2_errors)
    styles(_plot_best_fits)

plot_cal_fit("/calibration")

## Comparison plots: randomized vs. non-randomized control curves

In [None]:
def _load_eval_data(folder: str, objective: str):
    """Load ``eval_data_<objective>.npz`` from ``result_base_dir/<folder>``."""
    return np.load(result_base_dir + f"/{folder}/eval_data_{objective}.npz")


# Deterministic controllers' responses live in "responses", stochastic ones in "responses_final".
min_peak_det_data = _load_eval_data("responses", "min_peak")
max_peak_det_data = _load_eval_data("responses", "max_peak")
min_mean_det_data = _load_eval_data("responses", "min_mean")
max_mean_det_data = _load_eval_data("responses", "max_mean")

min_peak_stc_data = _load_eval_data("responses_final", "min_peak")
max_peak_stc_data = _load_eval_data("responses_final", "max_peak")
min_mean_stc_data = _load_eval_data("responses_final", "min_mean")
max_mean_stc_data = _load_eval_data("responses_final", "max_mean")

In [None]:
def plot_control_with_response_comparison(data, comp_data, mod_data, save_prefix: str = None, name: str = None):
    """Compare three runs of one objective in a 4-panel figure, once per plot style.

    Panels (top to bottom): fraction of stressed cells, ``ys[:, 1]`` (log scale),
    ``ys[:, 3]`` (log scale) and the control ``cs``, each against ``ts``. ``data`` is
    drawn solid ("Stochastic"), ``mod_data`` dashed ("Mod. Stoch.") and
    ``comp_data`` solid ("Deterministic"). Saved as ``/{name}_control_response``.
    """
    # (axis index, y-label, log y-scale?, series extractor) for the lower three panels.
    panels = (
        (1, r"p-eIF2$\mathrm{\alpha}$ [nmol/l]", True, lambda d: d["ys"][:, 1]),
        (2, r"GADD34 [nmol/l]", True, lambda d: d["ys"][:, 3]),
        (3, r"Tha [nmol/l]", False, lambda d: d["cs"]),
    )

    def _plot(style_suffix):
        fig, ax = plt.subplots(
            4, 1, sharex=True, figsize=(plot_half_width, plot_half_width * 1.5)
        )
        ax[-1].set_xlabel("Time [min]")

        ax[0].set_ylabel("Frac. Stressed")
        ax[0].plot(data["ts"], data["sg"], label="Stochastic")
        ax[0].plot(mod_data["ts"], mod_data["sg"], linestyle="--", label="Mod. Stoch.")
        ax[0].plot(comp_data["ts"], comp_data["sg"], label="Deterministic")
        ax[0].legend()

        for idx, ylabel, log_scale, extract in panels:
            if log_scale:
                ax[idx].set_yscale("log")
            ax[idx].set_ylabel(ylabel)
            ax[idx].plot(data["ts"], extract(data))
            ax[idx].plot(mod_data["ts"], extract(mod_data), linestyle="--")
            ax[idx].plot(comp_data["ts"], extract(comp_data))

        plt.tight_layout()
        show(f"/{name}_control_response" + style_suffix, save_prefix)

    styles(_plot)

for objective_tag, stc_data, det_data in (
    ("min_peak", min_peak_stc_data, min_peak_det_data),
    ("max_peak", max_peak_stc_data, max_peak_det_data),
    ("min_mean", min_mean_stc_data, min_mean_det_data),
    ("max_mean", max_mean_stc_data, max_mean_det_data),
):
    plot_control_with_response_comparison(
        stc_data, det_data, get_mod(stc_data), "/comparisons", objective_tag
    )

In [None]:
controller_folder_det = "controllers"
controller_folder_stc = "controllers_final"


def _load_controller(folder: str, objective: str):
    """Deserialise ``<objective>_control.eqx`` into the structure of the template ``control``."""
    return eqx.tree_deserialise_leaves(
        result_base_dir + f"/{folder}/{objective}_control.eqx", control
    )


min_peak_control_det = _load_controller(controller_folder_det, "min_peak")
max_peak_control_det = _load_controller(controller_folder_det, "max_peak")
min_mean_control_det = _load_controller(controller_folder_det, "min_mean")
max_mean_control_det = _load_controller(controller_folder_det, "max_mean")

min_peak_control_stc = _load_controller(controller_folder_stc, "min_peak")
max_peak_control_stc = _load_controller(controller_folder_stc, "max_peak")
min_mean_control_stc = _load_controller(controller_folder_stc, "min_mean")
max_mean_control_stc = _load_controller(controller_folder_stc, "max_mean")

In [None]:
tha_range = jnp.geomspace(0.1, 10.0, num=256)
# Batch `evaluate` over its 4th argument only; control, target and maximum are shared.
vmap_eval_fn = eqx.filter_vmap(evaluate, in_axes=(None, None, None, 0), out_axes=0)


def _sweep(controller):
    """Evaluate ``controller`` at every entry of ``tha_range``."""
    return vmap_eval_fn(controller, target_conc, maximum, tha_range)


min_peak_stc_vmap_data = _sweep(min_peak_control_stc)
max_peak_stc_vmap_data = _sweep(max_peak_control_stc)
min_mean_stc_vmap_data = _sweep(min_mean_control_stc)
max_mean_stc_vmap_data = _sweep(max_mean_control_stc)

min_peak_det_vmap_data = _sweep(min_peak_control_det)
max_peak_det_vmap_data = _sweep(max_peak_control_det)
min_mean_det_vmap_data = _sweep(min_mean_control_det)
max_mean_det_vmap_data = _sweep(max_mean_control_det)

In [None]:
@partial(jax.vmap, in_axes=(0, None, None))
def reward_fn(sg: Array, mean_weight: Scalar, peak_weight: Scalar) -> Scalar:
    """Weighted mean plus weighted peak of each trajectory, batched over axis 0 of ``sg``.

    The weights are shared across the batch; negative weights turn the reward into a
    minimisation objective, e.g. ``(0.0, -1.0)`` penalises the peak.
    """
    mean_term = mean_weight * jnp.mean(sg)
    peak_term = peak_weight * jnp.max(sg)
    return mean_term + peak_term

In [None]:
# Index into `tha_range` where the stochastic min-peak controller beats the deterministic one most.
min_peak_reward_gap = reward_fn(min_peak_stc_vmap_data["sg"], 0.0, -1.0) - reward_fn(
    min_peak_det_vmap_data["sg"], 0.0, -1.0
)
jnp.argmax(min_peak_reward_gap)

In [None]:
# Tha activity at which the stochastic min-peak controller's reward exceeds the
# deterministic one's by the largest margin (computed instead of the hard-coded 76,
# so the cell stays correct if the sweep or controllers change).
best_gap_idx = jnp.argmax(
    reward_fn(min_peak_stc_vmap_data["sg"], 0.0, -1.0)
    - reward_fn(min_peak_det_vmap_data["sg"], 0.0, -1.0)
)
tha_range[best_gap_idx]

In [None]:
# Rewards of both min-peak controllers at the Tha activity where the stochastic one
# leads most (index computed instead of the hard-coded 76).
stc_min_peak_reward = reward_fn(min_peak_stc_vmap_data["sg"], 0.0, -1.0)
det_min_peak_reward = reward_fn(min_peak_det_vmap_data["sg"], 0.0, -1.0)
best_gap_idx = jnp.argmax(stc_min_peak_reward - det_min_peak_reward)
stc_min_peak_reward[best_gap_idx], det_min_peak_reward[best_gap_idx]

In [None]:
def plot_tha_range_reward(
    stc_sg,
    det_sg,
    mean_weight,
    peak_weight,
    objective: str,
    save_prefix: str = "/comparisons",
):
    """Plot reward versus Tha activity for the stochastic and deterministic controllers.

    Args:
        stc_sg: Batched ``sg`` of the stochastic controller, one row per ``tha_range`` entry.
        det_sg: Same for the deterministic controller.
        mean_weight: Weight of the mean term passed to ``reward_fn``.
        peak_weight: Weight of the peak term passed to ``reward_fn``.
        objective: Tag used in the file name ``/{objective}_stc_det_tha_range_reward``.
        save_prefix: Sub-directory passed to ``show``.
    """

    def _plot(style_suffix):
        plt.figure(figsize=(plot_half_width, plot_third_width))
        plt.xlabel("Tha Activity")
        plt.ylabel("Reward")
        plt.xscale("log")
        plt.plot(tha_range, reward_fn(stc_sg, mean_weight, peak_weight), label="Stochastic")
        plt.plot(tha_range, reward_fn(det_sg, mean_weight, peak_weight), label="Deterministic")
        plt.legend()
        plt.tight_layout()
        show(f"/{objective}_stc_det_tha_range_reward" + style_suffix, save_prefix)

    styles(_plot)


# (objective, stochastic sweep, deterministic sweep, (mean_weight, peak_weight))
for objective, stc_data, det_data, weights in (
    ("min_peak", min_peak_stc_vmap_data, min_peak_det_vmap_data, (0.0, -1.0)),
    ("max_peak", max_peak_stc_vmap_data, max_peak_det_vmap_data, (0.0, 1.0)),
    ("min_mean", min_mean_stc_vmap_data, min_mean_det_vmap_data, (-1.0, 0.0)),
    ("max_mean", max_mean_stc_vmap_data, max_mean_det_vmap_data, (1.0, 0.0)),
):
    plot_tha_range_reward(stc_data["sg"], det_data["sg"], *weights, objective)

In [None]:
from jaxtyping import PRNGKeyArray


@eqx.filter_jit
def evaluate_stochastic(
    control: controls.AbstractControl,
    target: Array,
    maximum: Array,
    key: PRNGKeyArray,
    tha_mult: Scalar = jnp.float_(1.0),
):
    """Simulate ``control`` once with randomised model parameters and return ``sg``.

    The control is passed through a ``LimitedRangeConstantIntegralConstraint``
    (``target``/``maximum``), its curve nodes are scaled by ``tha_mult``, and the
    environment is integrated until ``t1 = 12 * 60`` using ``key`` with
    ``tha_lognormal_std=1.0`` and ``k_lognormal_std = s0_lognormal_std = 0.01``.
    """
    integral_constraint = constraints.LimitedRangeConstantIntegralConstraint(
        target=target, maximum=maximum
    )
    chain = constraints.ConstraintChain(transformations=[integral_constraint])

    built = build_control(control, chain)
    constrained_control: controls.InterpolationCurveControl = built[0]

    # Scale the whole control curve by the Tha activity multiplier.
    scaled_nodes = constrained_control.curve.nodes * tha_mult
    constrained_control = eqx.tree_at(
        lambda tree: tree.curve.nodes, constrained_control, scaled_nodes
    )

    step_controller = diffrax.PIDController(
        atol=1e-8,
        rtol=1e-8,
        pcoeff=1.0,
        icoeff=1.0,
        dtmax=30,
    )
    _, sg = environment.integrate(
        constrained_control,
        environment_state,
        key,
        t1=12.0 * 60.0,
        stepsize_controller=step_controller,
        tha_lognormal_std=1.0,
        k_lognormal_std=0.01,
        s0_lognormal_std=0.01,
    )
    return sg

In [None]:
eval_key = jax.random.PRNGKey(1234)
num_samples = 1024 * 16

min_peak_stc_sgs, max_peak_stc_sgs, min_mean_stc_sgs, max_mean_stc_sgs = [], [], [], []
min_peak_det_sgs, max_peak_det_sgs, min_mean_det_sgs, max_mean_det_sgs = [], [], [], []

# (controller, output list) in the exact order subkeys are drawn each iteration;
# reordering this tuple would change which noise each controller sees.
sampling_plan = (
    (min_peak_control_stc, min_peak_stc_sgs),
    (max_peak_control_stc, max_peak_stc_sgs),
    (min_mean_control_stc, min_mean_stc_sgs),
    (max_mean_control_stc, max_mean_stc_sgs),
    (min_peak_control_det, min_peak_det_sgs),
    (max_peak_control_det, max_peak_det_sgs),
    (min_mean_control_det, min_mean_det_sgs),
    (max_mean_control_det, max_mean_det_sgs),
)

for i in trange(num_samples):
    for controller, sink in sampling_plan:
        eval_key, subkey = jax.random.split(eval_key, num=2)
        sink.append(evaluate_stochastic(controller, target_conc, maximum, subkey))


In [None]:
# Turn each list of per-sample trajectories into one array with samples along axis 0.
(
    min_peak_stc_sgs,
    max_peak_stc_sgs,
    min_mean_stc_sgs,
    max_mean_stc_sgs,
    min_peak_det_sgs,
    max_peak_det_sgs,
    min_mean_det_sgs,
    max_mean_det_sgs,
) = [
    jnp.stack(per_sample, axis=0)
    for per_sample in (
        min_peak_stc_sgs,
        max_peak_stc_sgs,
        min_mean_stc_sgs,
        max_mean_stc_sgs,
        min_peak_det_sgs,
        max_peak_det_sgs,
        min_mean_det_sgs,
        max_mean_det_sgs,
    )
]

In [None]:
# Persist all eight sample arrays so the expensive sampling cell need not be re-run.
comparison_samples = dict(
    min_peak_stc_sgs=min_peak_stc_sgs,
    max_peak_stc_sgs=max_peak_stc_sgs,
    min_mean_stc_sgs=min_mean_stc_sgs,
    max_mean_stc_sgs=max_mean_stc_sgs,
    min_peak_det_sgs=min_peak_det_sgs,
    max_peak_det_sgs=max_peak_det_sgs,
    min_mean_det_sgs=min_mean_det_sgs,
    max_mean_det_sgs=max_mean_det_sgs,
)
np.savez_compressed(result_base_dir + "/comparisons/samples.npz", **comparison_samples)

In [None]:
stochastic_data = np.load(result_base_dir + "/comparisons/samples.npz")

# Restore each sample array under the name it was saved with.
(
    min_peak_stc_sgs,
    max_peak_stc_sgs,
    min_mean_stc_sgs,
    max_mean_stc_sgs,
    min_peak_det_sgs,
    max_peak_det_sgs,
    min_mean_det_sgs,
    max_mean_det_sgs,
) = [
    stochastic_data[field]
    for field in (
        "min_peak_stc_sgs",
        "max_peak_stc_sgs",
        "min_mean_stc_sgs",
        "max_mean_stc_sgs",
        "min_peak_det_sgs",
        "max_peak_det_sgs",
        "min_mean_det_sgs",
        "max_mean_det_sgs",
    )
]

In [None]:
# (mean_weight, peak_weight) per objective; "min" objectives use negative weights so
# that a larger reward is always better.
objective_weights = {
    "min_peak": (0.0, -1.0),
    "max_peak": (0.0, 1.0),
    "min_mean": (-1.0, 0.0),
    "max_mean": (1.0, 0.0),
}

min_peak_stc_rewards = reward_fn(min_peak_stc_sgs, *objective_weights["min_peak"])
min_peak_det_rewards = reward_fn(min_peak_det_sgs, *objective_weights["min_peak"])

max_peak_stc_rewards = reward_fn(max_peak_stc_sgs, *objective_weights["max_peak"])
max_peak_det_rewards = reward_fn(max_peak_det_sgs, *objective_weights["max_peak"])

min_mean_stc_rewards = reward_fn(min_mean_stc_sgs, *objective_weights["min_mean"])
min_mean_det_rewards = reward_fn(min_mean_det_sgs, *objective_weights["min_mean"])

max_mean_stc_rewards = reward_fn(max_mean_stc_sgs, *objective_weights["max_mean"])
max_mean_det_rewards = reward_fn(max_mean_det_sgs, *objective_weights["max_mean"])

In [None]:
# Mean min-peak reward of the deterministic controller relative to the stochastic one.
det_mean_reward = jnp.mean(min_peak_det_rewards)
stc_mean_reward = jnp.mean(min_peak_stc_rewards)
print(det_mean_reward / stc_mean_reward)

In [None]:
# Lower and upper quartile of the min-peak rewards: stochastic first, then deterministic.
for rewards in (min_peak_stc_rewards, min_peak_det_rewards):
    print(jnp.quantile(rewards, 0.25), jnp.quantile(rewards, 0.75))

In [None]:
def plot_stc_det_hist_comp(stc_rewards, det_rewards, plot_name: str = None, save_prefix: str = None):
    """Overlay reward histograms of the stochastic and deterministic controllers.

    Prints the means and standard deviations of both reward samples once, then for
    every entry in ``plot_styles`` draws filled and outlined 128-bin histograms with
    vertical lines at the means, saved as ``/{plot_name}_stc_det_hist_comp<style>``.
    """
    stc_mean = jnp.mean(stc_rewards)
    det_mean = jnp.mean(det_rewards)

    # Statistics do not depend on the plot style; previously they were printed once per style.
    print(stc_mean, det_mean)
    print(jnp.std(stc_rewards), jnp.std(det_rewards))

    def _plot(style_suffix):
        plt.figure(figsize=(plot_half_width, plot_third_width))

        plt.xlabel("Reward")
        plt.ylabel("Counts")

        plt.hist(stc_rewards, bins=128, alpha=0.5, label="Stochastic")
        plt.hist(det_rewards, bins=128, alpha=0.5, label="Deterministic")

        # Draw step outlines on top of the translucent fills.
        plt.hist(stc_rewards, bins=128, linewidth=1.0, histtype="step", color="tab:blue")
        plt.hist(det_rewards, bins=128, linewidth=1.0, histtype="step", color="tab:orange")

        plt.axvline(stc_mean, c="tab:blue")
        plt.axvline(det_mean, c="tab:orange")

        plt.legend()
        plt.tight_layout()
        show(f"/{plot_name}_stc_det_hist_comp" + style_suffix, save_prefix=save_prefix)

    styles(_plot)

plot_stc_det_hist_comp(min_peak_stc_rewards, min_peak_det_rewards, "min_peak", "/comparisons")
plot_stc_det_hist_comp(max_peak_stc_rewards, max_peak_det_rewards, "max_peak", "/comparisons")
plot_stc_det_hist_comp(min_mean_stc_rewards, min_mean_det_rewards, "min_mean", "/comparisons")
plot_stc_det_hist_comp(max_mean_stc_rewards, max_mean_det_rewards, "max_mean", "/comparisons")

In [None]:
def _load_frac_stressed(path: str, num_timepoints: int = 42):
    """Return the first ``num_timepoints`` entries of ``frac_stressed_cells`` from an ``.npz`` file."""
    # Use the NpzFile as a context manager so its file handle is closed; the indexed
    # array is already loaded into memory, so the slice stays valid afterwards.
    with np.load(path) as summary:
        return summary["frac_stressed_cells"][:num_timepoints]


# Two replicate measurements per experiment, stacked along the last axis.
maxSGint_data = np.stack(
    [
        _load_frac_stressed("../measured-data/maxSGint_230923_summary_data.npz"),
        _load_frac_stressed("../measured-data/maxSGint_231003_summary_data.npz"),
    ],
    axis=-1,
)

minSGamp_data = np.stack(
    [
        _load_frac_stressed("../measured-data/minSGamp_230921_summary_data.npz"),
        _load_frac_stressed("../measured-data/minSGamp_230928_summary_data.npz"),
    ],
    axis=-1,
)

In [None]:
def l2_loss(measurement, theory, sample_every: int = 15):
    """Mean squared error between measured replicates and every simulated trajectory.

    Args:
        measurement: 2-D array ``(timepoints, replicates)``. The first timepoint is
            skipped; the remaining ``timepoints - 1`` are compared.
        theory: 2-D array ``(candidates, steps)`` of simulated trajectories.
        sample_every: Simulation steps between consecutive measurements (15 by
            default, i.e. one measurement every 15 steps).

    Returns:
        1-D array with one loss per candidate: per-replicate MSEs summed over replicates.
    """
    # Derive the compared length from the data instead of hard-coding 41.
    num_compared = measurement.shape[0] - 1
    sampled_theory = theory[:, : num_compared * sample_every : sample_every]

    l2 = 0
    for i in range(measurement.shape[1]):
        l2 += jnp.mean(jnp.square(measurement[None, 1:, i] - sampled_theory), axis=-1)

    return l2


# Per-Tha-activity losses of the stochastic sweeps against the two measured experiments.
minSGamp_loss = l2_loss(minSGamp_data, min_peak_stc_vmap_data["sg"])
maxSGint_loss = l2_loss(maxSGint_data, max_mean_stc_vmap_data["sg"])
total_loss = jnp.add(minSGamp_loss, maxSGint_loss)

In [None]:
def plot_fit(save_prefix: str = None):
    """Plot L2 loss versus Tha activity for both experiments and their sum.

    A vertical line in the matching colour marks the minimising Tha activity of each
    curve. Drawn once per entry in ``plot_styles`` and saved as ``/stc_tha_refit``.
    """
    curves = (
        ("Min. Peak", minSGamp_loss, "tab:blue"),
        ("Max. Integral", maxSGint_loss, "tab:orange"),
        ("Total", total_loss, "tab:green"),
    )

    def _plot(style_suffix):
        plt.figure(figsize=(plot_half_width, plot_third_width))

        plt.yscale("log")
        plt.xscale("log")
        plt.xlabel("Tha Activity")
        plt.ylabel("L2 Loss")

        for label, loss, _ in curves:
            plt.plot(tha_range, loss, label=label)
        for _, loss, color in curves:
            plt.axvline(tha_range[jnp.argmin(loss)], c=color)

        plt.legend()
        plt.tight_layout()

        show("/stc_tha_refit" + style_suffix, save_prefix=save_prefix)

    styles(_plot)


plot_fit("/stochastic_calibration")

In [None]:
# Best fit to the max-integral experiment: (loss, index into tha_range, Tha activity).
maxSGint_best_idx = jnp.argmin(maxSGint_loss)
jnp.min(maxSGint_loss), maxSGint_best_idx, tha_range[maxSGint_best_idx]

In [None]:
# NOTE(review): this inspects sweep index 45, while the next cell exports index 90 —
# confirm which Tha activity is intended.
fig, ax = plt.subplots()
ax.plot(min_peak_stc_vmap_data["sg"][45])
plt.show()

In [None]:
# Index into `tha_range` of the trajectories exported for the fit.
# NOTE(review): hard-coded; presumably the chosen best-fit Tha activity — TODO derive it
# from the losses above or document why 90.
fit_idx = 90
np.save("../measured-data/fit/min_peak_sgs.npy", min_peak_stc_vmap_data["sg"][fit_idx])
np.save("../measured-data/fit/max_mean_sgs.npy", max_mean_stc_vmap_data["sg"][fit_idx])