# Setup

## Imports

In [None]:
import jax

jax.config.update("jax_enable_x64", True)
jax.config.update("jax_platform_name", "cpu")

import time
from functools import partial
from typing import Tuple

import diffrax
import equinox as eqx
import evosax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np
import optax
from adaptive import Learner2D
from adaptive.notebook_integration import notebook_extension
from jaxtyping import Array, ArrayLike, PRNGKeyArray, PyTree, Scalar
from IPython.display import clear_output

import optimal_control.constraints as constraints
import optimal_control.controls as controls
import optimal_control.environments.examples as examples
import optimal_control.nn as nn
import optimal_control.solvers as solvers
import optimal_control.trainers as trainers

In [None]:
# Enable adaptive's notebook integration; presumably required for the live
# `learner.plot(...)` display used in adaptive_2d.
notebook_extension()

In [None]:
# Figure sizes derived from the A4 page width (inches) so plots fit the
# document at full, half, third and quarter width.
a4_inches = (8.3, 11.7)
plot_full_width = a4_inches[0]
plot_half_width = a4_inches[0] / 2
plot_third_width = a4_inches[0] / 3
plot_quarter_width = a4_inches[0] / 4

result_base_dir = "../thesis-results/apoptosis"

# Matplotlib 3.6 renamed the bundled seaborn styles to "seaborn-v0_8-*" and
# 3.8 removed the old names, so plt.style.use("seaborn-paper") fails there.
# Prefer the new name when it exists; older matplotlib keeps the old one.
plot_style = "seaborn-paper"
if "seaborn-v0_8-paper" in plt.style.available:
    plot_style = "seaborn-v0_8-paper"

# Extra scale factor applied to some square figure sizes.
plot_shrink_factor = 0.9

plt.style.use(plot_style)

## General definitions

In [None]:
# Random key for reproducibility. NOTE: `key` is rebound below, and the
# (post-split) global is read directly by train_with_integral,
# evaluate_with_integral and later cells.
key = jax.random.PRNGKey(1234)

# Initialize environment
# Training environment. NOTE(review): the meaning of the positional arguments
# ([0, 500], 50) is defined by examples.ApoptosisEnvironment. They are
# presumably a range of initial conditions taken from the .mat file and a
# sample size; confirm against its signature.
environment: examples.ApoptosisEnvironment = examples.ApoptosisEnvironment(
    "../data/Initial_concentrations_CD95H_wtH.mat", [0, 500], 50
)
environment_state = environment.init()

# Evaluation environment: same data file, different range ([500, 1000]).
# NOTE(review): the -1 and True arguments are environment-specific; TODO confirm
# their meaning against examples.ApoptosisEnvironment.
eval_environment: examples.ApoptosisEnvironment = examples.ApoptosisEnvironment(
    "../data/Initial_concentrations_CD95H_wtH.mat", [500, 1000], -1, True
)
eval_environment_state = eval_environment.init()

# Build controller
# SIREN network mapping time (1 input feature) to one control value, defined on
# t in [0, 180]. With to_curve=True it is presumably sampled into a 181-point,
# linearly interpolated curve (one point per time unit).
key, subkey = jax.random.split(key)
control = controls.ImplicitTemporalControl(
    implicit_fn=nn.Siren(
        in_features=1, out_features=1, hidden_features=64, hidden_layers=2, key=subkey
    ),
    t_start=0.0,
    t_end=180.0,
    to_curve=True,
    curve_interpolation="linear",
    curve_steps=181,
)


# Define reward function
def proxy_reward_fn(
    args: Tuple[diffrax.Solution, Array], instantaneous: bool = False, norm: bool = True
):
    """Smooth proxy reward: tBID fraction capped at the apoptosis threshold.

    For every entry the tBID fraction ``ys[..., 12] / (ys[..., 3] + ys[..., 12])``
    is computed, capped from above, and summed over the last remaining axis
    (presumably cells; TODO confirm the ``ys`` layout against the environment).

    Args:
        args: ``(solution, thresh)``; ``solution.ys`` holds the state
            trajectories with species on the last axis.
        instantaneous: If True, return the per-step sums. Otherwise return
            their mean.
        norm: If True, divide the fraction by ``thresh`` and cap each term at
            1.0. Otherwise cap the raw fraction at ``thresh``.

    Returns:
        The per-step sums (``instantaneous=True``) or their mean.
    """
    solution, thresh = args
    ys = solution.ys

    tBID_frac = ys[..., 12] / (ys[..., 3] + ys[..., 12])

    # Upper clip only. jnp.minimum is equivalent to jnp.clip(x, a_max=m) and
    # avoids the a_min/a_max keywords, which are deprecated in recent JAX.
    if norm:
        frac_clipped = jnp.minimum(tBID_frac / thresh, 1.0)
    else:
        frac_clipped = jnp.minimum(tBID_frac, thresh)

    per_step = jnp.sum(frac_clipped, axis=-1)
    if instantaneous:
        return per_step
    return jnp.mean(per_step)


def true_reward_fn(args: Tuple[diffrax.Solution, Array], instantaneous: bool = False):
    """True reward: number of entries whose tBID fraction exceeds the threshold.

    The fraction is ``ys[..., 12] / (ys[..., 3] + ys[..., 12])``. The boolean
    "above threshold" mask is summed over the last remaining axis (presumably
    cells, i.e. the number of dead cells).

    Args:
        args: ``(solution, thresh)``; ``solution.ys`` holds the state
            trajectories with species on the last axis.
        instantaneous: If True, return the per-step counts. Otherwise return
            their mean.
    """
    sol, threshold = args
    states = sol.ys
    tbid = states[..., 12]

    # Strict comparison: exactly-at-threshold does not count.
    is_dead = tbid / (states[..., 3] + tbid) > threshold
    dead_count = jnp.sum(is_dead, axis=-1)

    return dead_count if instantaneous else jnp.mean(dead_count)


# Direct solver: optimises the control parameters with optax's Adam
# (learning rate 3e-4).
direct_solver = solvers.DirectSolver(optax.adam(learning_rate=3e-4))


# ES Solver factory (to allow for init parameters)
def make_es_solver(
    init_control: controls.AbstractControl, popsize: int = 64
) -> solvers.ESSolver:
    """Build an OpenES-based solver whose search space matches ``init_control``.

    Only the array leaves of ``init_control`` are used. They define the
    parameter layout (via ``evosax.ParameterReshaper``) and the search
    dimensionality. Their values are not passed to the solver here.

    Args:
        init_control: Control whose array leaves define the parameter layout.
        popsize: Number of candidates per ES generation (default 64).

    Returns:
        A ``solvers.ESSolver`` wrapping the strategy, its default params, the
        reshaper, and a centred-rank, maximising fitness shaper.
    """
    evo_control_params = eqx.filter(init_control, eqx.is_array)
    evo_parameter_reshaper = evosax.ParameterReshaper(evo_control_params)
    evo_fitness_shaper = evosax.FitnessShaper(centered_rank=True, maximize=True)
    num_dims = len(evo_parameter_reshaper.flatten_single(evo_control_params))

    # init == limit for both the learning rate and sigma, which presumably
    # disables any decay. An alternative that was tried is
    # evosax.LES(popsize=..., num_dims=..., net_ckpt_path="../data/2023_03_les_v1.pkl").
    evo_strategy = evosax.OpenES(
        popsize=popsize,
        lrate_init=3e-4,
        lrate_limit=3e-4,
        sigma_init=1e-3,
        sigma_limit=1e-3,
        num_dims=num_dims,
    )
    evo_strategy_params = evo_strategy.default_params

    evo_solver: solvers.ESSolver = solvers.ESSolver(
        evo_strategy, evo_strategy_params, evo_parameter_reshaper, evo_fitness_shaper
    )

    return evo_solver

In [None]:
# Training helper
@eqx.filter_jit
def train_with_integral(
    control: controls.AbstractControl,
    solver: solvers.AbstractSolver,
    target_integral: Array,
    reward_fn=proxy_reward_fn,
    num_steps: int = 1024,
    rng_key=None,
) -> Tuple[Scalar, controls.AbstractControl]:
    """Train ``control`` on the global training ``environment``.

    Before optimisation the control is wrapped in a
    ``NonNegativeConstantIntegralConstraint`` with target ``target_integral``.

    Args:
        control: Initial (unconstrained) control.
        solver: Optimiser, e.g. ``direct_solver`` or ``make_es_solver(...)``.
        target_integral: Target value for the constrained control's integral.
        reward_fn: Reward function passed to the trainer (default:
            ``proxy_reward_fn``).
        num_steps: Number of training steps.
        rng_key: PRNG key for the trainer. ``None`` (the default) uses the
            notebook-global ``key``, which keeps the original behaviour. Pass
            distinct keys to obtain independent runs.

    Returns:
        ``(opt_reward, opt_control)`` from
        ``trainers.solve_optimal_control_problem``.

    Note:
        Under ``eqx.filter_jit`` the globals read here (``environment`` and,
        when ``rng_key`` is None, ``key``) are captured at trace time.
    """
    if rng_key is None:
        rng_key = key

    constraint_chain = constraints.ConstraintChain(
        transformations=[
            constraints.NonNegativeConstantIntegralConstraint(target=target_integral)
        ]
    )

    opt_reward, opt_control = trainers.solve_optimal_control_problem(
        num_train_steps=num_steps,
        environment=environment,
        reward_fn=reward_fn,
        constraint_chain=constraint_chain,
        solver=solver,
        control=control,
        key=rng_key,
        pbar_interval=8,
        integrate_kwargs=dict(vmap="inner"),
    )

    return opt_reward, opt_control

@eqx.filter_jit
def evaluate_with_integral(
    control: controls.AbstractControl,
    target_integral: Array,
    environment: examples.ApoptosisEnvironment,
    state: examples.ApoptosisState,
    rng_key=None,
) -> Tuple[PyTree, controls.AbstractControl]:
    """Constrain ``control`` to ``target_integral`` and integrate ``environment``.

    Args:
        control: Unconstrained control to evaluate.
        target_integral: Target value for the constrained control's integral.
        environment: Environment to integrate (e.g. ``eval_environment``).
        state: Initial environment state.
        rng_key: PRNG key for the integration. ``None`` (the default) uses the
            notebook-global ``key``, as before.

    Returns:
        ``(solution, constrained_control)``. Here ``solution`` is whatever
        ``environment.integrate`` returns; callers in this notebook unpack it
        as ``(solution, tBID_threshold)``.
    """
    if rng_key is None:
        rng_key = key

    constraint_chain = constraints.ConstraintChain(
        transformations=[
            constraints.NonNegativeConstantIntegralConstraint(target=target_integral)
        ]
    )

    constrained_control, _ = solvers.build_control(control, constraint_chain)

    solution = environment.integrate(constrained_control, state, rng_key, vmap="inner")
    return solution, constrained_control

# Run

## Comparison of gradient-based training, ES training, and proxy vs. true rewards

### Train networks

In [None]:
# Stage 1: gradient-based pretraining on the smooth proxy reward, with the
# control's integral constrained to 1.0.
pretrained_reward, pretrained_control = train_with_integral(
    control,
    direct_solver,
    jnp.asarray([1.0]),
    reward_fn=proxy_reward_fn,
    num_steps=1024,
)

In [None]:
# Stage 2: fine-tune the pretrained control with ES on the true reward. That
# reward is a thresholded count (a comparison), so it offers no useful gradient.
finetuned_reward, finetuned_control = train_with_integral(
    pretrained_control,
    make_es_solver(pretrained_control),
    jnp.asarray([1.0]),
    reward_fn=true_reward_fn,
    num_steps=256,
)

In [None]:
# Baseline: ES only, on the true reward, starting from the untrained `control`,
# for the same total number of steps as pretraining + fine-tuning (1024 + 256).
# make_es_solver only uses its argument for the parameter layout, so the solver
# is built from `control`. This keeps the baseline independent of the
# pretraining cell.
only_es_reward, only_es_control = train_with_integral(
    control,
    make_es_solver(control),
    jnp.asarray([1.0]),
    reward_fn=true_reward_fn,
    num_steps=1024 + 256,
)

### Evaluate solutions

In [None]:
# Evaluate each trained control on the evaluation environment.
# evaluate_with_integral returns ((solution, tBID threshold), constrained control).
# All three calls use the same eval environment/state, so the threshold is
# presumably identical and `tBID_thresh` is simply overwritten each time.
(pretrain_solution, tBID_thresh), pretrain_constrained_control = evaluate_with_integral(
    pretrained_control, jnp.asarray([1.0]), eval_environment, eval_environment_state
)

(
    finetuned_solution,
    tBID_thresh,
), finetune_constrained_control = evaluate_with_integral(
    finetuned_control, jnp.asarray([1.0]), eval_environment, eval_environment_state
)

(
    only_es_solution,
    tBID_thresh,
), only_es_constrained_control = evaluate_with_integral(
    only_es_control, jnp.asarray([1.0]), eval_environment, eval_environment_state
)

In [None]:
# Per-step proxy and true rewards (instantaneous=True) for each trained
# control; the proxy reward is always evaluated before the true reward.
pretrain_proxy_reward, pretrain_true_reward = (
    reward((pretrain_solution, tBID_thresh), instantaneous=True)
    for reward in (proxy_reward_fn, true_reward_fn)
)

finetune_proxy_reward, finetune_true_reward = (
    reward((finetuned_solution, tBID_thresh), instantaneous=True)
    for reward in (proxy_reward_fn, true_reward_fn)
)

only_es_proxy_reward, only_es_true_reward = (
    reward((only_es_solution, tBID_thresh), instantaneous=True)
    for reward in (proxy_reward_fn, true_reward_fn)
)

In [None]:
# Quick check of the true reward for the pretrained vs. fine-tuned control.
# First line: mean over all steps. Second line: the last entry (presumably the
# final time point). The ES-only run is not printed here.
print(jnp.mean(pretrain_true_reward), jnp.mean(finetune_true_reward))
print(pretrain_true_reward[-1], finetune_true_reward[-1])

### Save solutions

In [None]:
import os

# Create the output directory first. Opening a file inside a missing directory
# raises FileNotFoundError. Later cells save figures and learners to the same
# directory.
os.makedirs(result_base_dir + "/reward_comparison", exist_ok=True)

eqx.tree_serialise_leaves(result_base_dir + "/reward_comparison/pretrain_control.eqx", pretrained_control)
eqx.tree_serialise_leaves(result_base_dir + "/reward_comparison/finetune_control.eqx", finetuned_control)
eqx.tree_serialise_leaves(result_base_dir + "/reward_comparison/es_control.eqx", only_es_control)

### Plot solutions

In [None]:
# Per-step rewards of the three trained controls on the evaluation environment.
# Labels name the reward each control was trained on:
# "Proxy" = pretrained, "Proxy -> True" = pretrained + ES fine-tuned,
# "True" = ES only.
ts = pretrain_solution.ts

fig, ax = plt.subplots(2, 1, sharex=True, figsize=(plot_half_width, plot_half_width))

# Top panel: true reward (count above the tBID threshold).
ax[0].set_ylabel("True Reward")
ax[0].plot(ts, pretrain_true_reward, label="Proxy")
ax[0].plot(ts, finetune_true_reward, linestyle="--", label="Proxy -> True")
ax[0].plot(ts, only_es_true_reward, label="True")
ax[0].legend()  # one legend serves both panels (same line styles)

# Bottom panel: proxy reward.
ax[1].set_xlabel("Time [min.]")
ax[1].set_ylabel("Proxy Reward")
ax[1].plot(ts, pretrain_proxy_reward, label="Proxy")
ax[1].plot(ts, finetune_proxy_reward, linestyle="--", label="Proxy -> True")
ax[1].plot(ts, only_es_proxy_reward, label="True")

plt.savefig(
    result_base_dir + "/reward_comparison/proxy_and_true_reward.png",
    bbox_inches="tight",
)
plt.savefig(
    result_base_dir + "/reward_comparison/proxy_and_true_reward.svg",
    bbox_inches="tight",
)

plt.show()

In [None]:
# Sample each constrained control at the evaluation time points.
ts = pretrain_solution.ts
p_cs, f_cs, es_cs = (
    jax.vmap(constrained)(ts)
    for constrained in (
        pretrain_constrained_control,
        finetune_constrained_control,
        only_es_constrained_control,
    )
)

In [None]:
# Plot the three constrained control inputs (CD95L over time) and save the
# figure as PNG and SVG.
plt.figure(figsize=(plot_half_width, plot_half_width))

plt.xlabel("Time [min.]")
plt.ylabel("CD95L [ng/ml]")
plt.plot(ts, p_cs, label="Proxy")
plt.plot(ts, f_cs, linestyle="--", label="Proxy -> True")
plt.plot(ts, es_cs, label="True")
plt.legend()

plt.savefig(
    result_base_dir + "/reward_comparison/proxy_and_true_controls.png",
    bbox_inches="tight",
)
plt.savefig(
    result_base_dir + "/reward_comparison/proxy_and_true_controls.svg",
    bbox_inches="tight",
)

plt.show()

### Show reward landscapes

In [None]:
# Evaluate region around optimum


@eqx.filter_jit
def sample_direction(network: PyTree, key: PRNGKeyArray) -> PyTree:
    """Sample a random unit-norm direction in the flattened parameter space.

    Args:
        network: Pytree (e.g. a control) whose array leaves define the space.
        key: PRNG key for the Gaussian noise.

    Returns:
        A 1-D array with the length and dtype of the flattened array leaves of
        ``network`` (as produced by ``evosax.ParameterReshaper.flatten_single``),
        scaled to unit L2 norm. No per-layer/filter normalisation is applied.
    """
    jax_params = eqx.filter(network, eqx.is_array)
    # Filter once and reuse; only the flat vector's shape and dtype are needed.
    flat_params = evosax.ParameterReshaper(jax_params).flatten_single(jax_params)
    noise = jax.random.normal(key, flat_params.shape, flat_params.dtype)
    return noise / jnp.linalg.norm(noise)


@eqx.filter_jit
def scan_2d(
    control: controls.AbstractControl, x: Array, y: Array, xs: Array, ys: Array
):
    """Evaluate ``control`` on a dense 2-D grid of parameter offsets.

    Grid point ``(i, j)`` uses flat parameters ``theta + ys[i] * y + xs[j] * x``.
    Here ``theta`` are the flattened array leaves of ``control``, and ``x``/``y``
    are flat direction vectors of the same length (e.g. from
    ``sample_direction``). Every grid point is evaluated on the evaluation
    environment in a single ``eqx.filter_vmap`` call, which is memory-hungry for
    large grids.

    Args:
        control: Control at the grid origin.
        x, y: Flat direction vectors spanning the slice.
        xs, ys: 1-D offsets along ``x`` and ``y``.

    Returns:
        ``(grid_solutions, grid_tbid_thresholds)``, batched over a leading axis
        of size ``len(ys) * len(xs)`` in row-major ``(ys, xs)`` order. So
        ``values.reshape(len(ys), len(xs))`` has rows indexed by ``ys``.
    """
    jax_params, jax_static = eqx.partition(control, eqx.is_array)
    reshaper = evosax.ParameterReshaper(jax_params)
    source_params = reshaper.flatten_single(jax_params)

    # Broadcast to (len(ys), len(xs), num_params), then flatten the grid axes.
    offsets = y[None, None] * ys[:, None, None] + x[None, None] * xs[None, :, None]
    offsets = offsets.reshape(-1, len(source_params))
    grid_params = source_params + offsets

    # Every array leaf of the batched control gains a leading grid axis.
    grid_controls = eqx.combine(reshaper.reshape(grid_params), jax_static)

    # If memory becomes the bottleneck, evaluate the grid in chunks (e.g. with
    # jax.lax.map over batches) instead of a single vmap.
    (
        grid_solutions,
        grid_tbid_thresholds,
    ), _ = eqx.filter_vmap(
        partial(
            evaluate_with_integral,
            target_integral=jnp.asarray([1.0]),
            environment=eval_environment,
            state=eval_environment_state,
        )
    )(grid_controls)

    return grid_solutions, grid_tbid_thresholds


@eqx.filter_jit
def sample_control(
    control: controls.AbstractControl, x_dir: Array, y_dir: Array, x: Scalar, y: Scalar
) -> controls.AbstractControl:
    """Return a copy of ``control`` shifted by ``x * x_dir + y * y_dir``.

    The shift is applied in the flat parameter space defined by
    ``evosax.ParameterReshaper`` over the array leaves of ``control``. The
    non-array (static) parts are reattached unchanged.
    """
    params, static = eqx.partition(control, eqx.is_array)
    flattener = evosax.ParameterReshaper(params)

    shifted = flattener.flatten_single(params) + (y_dir * y + x_dir * x)

    return eqx.combine(flattener.reshape_single(shifted), static)


def adaptive_2d(
    control: controls.AbstractControl,
    x_dir: Array,
    y_dir: Array,
    bounds: Tuple[Tuple[float, float], Tuple[float, float]],
    reward_fn,
    max_points: int = 1204,
):
    """Adaptively sample ``reward_fn`` over a 2-D slice of parameter space.

    Point ``(x, y)`` is evaluated by shifting ``control`` by
    ``x * x_dir + y * y_dir`` (see ``sample_control``). The shifted control is
    integrated on the evaluation environment, and ``reward_fn`` is applied to
    the resulting ``(solution, threshold)`` pair. An ``adaptive.Learner2D``
    chooses the points. A live plot is refreshed about every 15 seconds.

    Args:
        control: Control at the slice origin.
        x_dir, y_dir: Flat direction vectors spanning the slice.
        bounds: ``((x_min, x_max), (y_min, y_max))``.
        reward_fn: Called as ``reward_fn((solution, threshold))``; must return
            a scalar.
        max_points: Stop after this many evaluated points.
            NOTE(review): the default 1204 looks like a typo for 1024; all
            visible callers pass ``max_points`` explicitly.

    Returns:
        The learner. A KeyboardInterrupt stops sampling early and returns the
        partially filled learner.
    """
    # `display` is only implicitly defined inside IPython; import it
    # explicitly, like clear_output, so the dependency is visible.
    from IPython.display import display

    # The learner's own function is a placeholder: points are evaluated
    # manually via ask/tell so the JAX evaluation pipeline can be used.
    learner = Learner2D(lambda _: 0.0, bounds=bounds)
    learner.stack_size = 1

    target_integral = jnp.asarray([1.0])
    plt_interval = 15  # seconds between live plot refreshes
    last_plt_time = time.time()

    while learner.npoints < max_points:
        try:
            points, _ = learner.ask(1)
            point = points[0]
            x, y = point

            offset_control = sample_control(
                control, x_dir, y_dir, jnp.float_(x), jnp.float_(y)
            )
            # evaluate_with_integral returns ((solution, threshold), control);
            # reward_fn consumes the (solution, threshold) pair.
            solution_and_thresh, _ = evaluate_with_integral(
                offset_control,
                target_integral,
                eval_environment,
                eval_environment_state,
            )
            offset_reward = reward_fn(solution_and_thresh)
            learner.tell(point, float(offset_reward))

            if time.time() - last_plt_time >= plt_interval:
                clear_output(wait=True)
                display(learner.plot(tri_alpha=0.25))
                last_plt_time = time.time()
        except KeyboardInterrupt:
            break

    return learner

In [None]:
# Two random unit directions spanning the 2-D slice used by all landscape
# plots below. NOTE: split from the same global `key` the training/evaluation
# helpers read.
xkey, ykey = jax.random.split(key)

x = sample_direction(pretrained_control, xkey)
y = sample_direction(pretrained_control, ykey)

In [None]:
# This consumes lots of RAM
# (all 16 x 16 = 256 grid points are integrated in one vmapped call).
# Dense grids over [-1.5, 1.5]^2 around the pretrained and the ES-only control.
# Both use the same directions x, y.

grid_solutions, grid_tBID_thresholds = scan_2d(
    pretrained_control, x, y, jnp.linspace(-1.5, 1.5, 16), jnp.linspace(-1.5, 1.5, 16)
)

es_grid_solutions, es_grid_tBID_thresholds = scan_2d(
    only_es_control, x, y, jnp.linspace(-1.5, 1.5, 16), jnp.linspace(-1.5, 1.5, 16)
)

In [None]:
# Adaptive proxy- and true-reward landscapes around the pretrained control
# (1024 points each, same slice bounds as the dense grids).
pretrain_proxy_learner = adaptive_2d(
    pretrained_control,
    x,
    y,
    [(-1.5, 1.5), (-1.5, 1.5)],
    reward_fn=proxy_reward_fn,
    max_points=1024,
)

pretrain_true_learner = adaptive_2d(
    pretrained_control,
    x,
    y,
    [(-1.5, 1.5), (-1.5, 1.5)],
    reward_fn=true_reward_fn,
    max_points=1024,
)

In [None]:
# Same adaptive landscapes around the ES-only control, using the same
# directions so they are comparable with the pretrained ones.
es_proxy_learner = adaptive_2d(
    only_es_control,
    x,
    y,
    [(-1.5, 1.5), (-1.5, 1.5)],
    reward_fn=proxy_reward_fn,
    max_points=1024,
)

es_true_learner = adaptive_2d(
    only_es_control,
    x,
    y,
    [(-1.5, 1.5), (-1.5, 1.5)],
    reward_fn=true_reward_fn,
    max_points=1024,
)

#### Save learners

In [None]:
# Persist the four adaptive learners (.pickle files) so the landscapes can be
# re-plotted without re-sampling.
pretrain_proxy_learner.save(result_base_dir + "/reward_comparison/pretrain_proxy_learner.pickle")
pretrain_true_learner.save(result_base_dir + "/reward_comparison/pretrain_true_learner.pickle")
es_proxy_learner.save(result_base_dir + "/reward_comparison/es_proxy_learner.pickle")
es_true_learner.save(result_base_dir + "/reward_comparison/es_true_learner.pickle")

#### Plot from vmapped grids

In [None]:
# One reward value per dense-grid point, around the pretrained control.
grid_proxy_rewards = eqx.filter_vmap(proxy_reward_fn)((grid_solutions, grid_tBID_thresholds))
grid_proxy_no_norm_rewards = eqx.filter_vmap(partial(proxy_reward_fn, norm=False))((grid_solutions, grid_tBID_thresholds))
grid_true_rewards = eqx.filter_vmap(true_reward_fn)((grid_solutions, grid_tBID_thresholds))

# Same rewards around the ES-only control; these must use the ES grid.
es_grid_proxy_rewards = eqx.filter_vmap(proxy_reward_fn)((es_grid_solutions, es_grid_tBID_thresholds))
es_grid_proxy_no_norm_rewards = eqx.filter_vmap(partial(proxy_reward_fn, norm=False))((es_grid_solutions, es_grid_tBID_thresholds))
es_grid_true_rewards = eqx.filter_vmap(true_reward_fn)((es_grid_solutions, es_grid_tBID_thresholds))

In [None]:
# Dense-grid landscapes around the pretrained control (red x = unperturbed
# control). scan_2d orders grid points row-major in (ys, xs), so after the
# reshape the rows index the y offsets. origin="lower" draws ys[0] = -1.5 at
# the bottom, matching `extent`; imshow's default origin="upper" would flip
# the y axis.
plt.figure()
plt.scatter([0.0], [0.0], c="red", marker="x")
plt.imshow(
    grid_proxy_rewards.reshape(16, 16),
    cmap="magma",
    origin="lower",
    extent=(-1.5, 1.5, -1.5, 1.5),
)
plt.colorbar()
plt.show()

plt.figure()
plt.scatter([0.0], [0.0], c="red", marker="x")
plt.imshow(
    grid_proxy_no_norm_rewards.reshape(16, 16),
    cmap="magma",
    origin="lower",
    extent=(-1.5, 1.5, -1.5, 1.5),
)
plt.colorbar()
plt.show()

plt.figure()
plt.scatter([0.0], [0.0], c="red", marker="x")
plt.imshow(
    grid_true_rewards.reshape(16, 16),
    cmap="magma",
    origin="lower",
    extent=(-1.5, 1.5, -1.5, 1.5),
)
plt.colorbar()
plt.show()

#### Plot from adaptive grids

In [None]:
# Resample each learner's triangulated data onto a regular grid. Each result
# is (xs, ys, zs); zs is presumably indexed [ix, iy] (adaptive's layout), so
# check the orientation before plotting it with imshow.
pretrain_proxy_reward_grid = pretrain_proxy_learner.interpolated_on_grid()
pretrain_true_reward_grid = pretrain_true_learner.interpolated_on_grid()
es_proxy_reward_grid = es_proxy_learner.interpolated_on_grid()
es_true_reward_grid = es_true_learner.interpolated_on_grid()

In [None]:
from resize_right import resize, interp_methods


def plot_reward_grid(grid_x, grid_y, grid_value, label: str, filepath: str):
    """Plot a 2-D reward landscape as a heat map and save it as PNG and SVG.

    Args:
        grid_x, grid_y: 1-D grid coordinates along X and Y.
        grid_value: 2-D values indexed ``[ix, iy]``, the layout returned by
            adaptive's ``Learner2D.interpolated_on_grid``.
        label: Colour-bar label.
        filepath: Output path relative to ``result_base_dir``, without extension.
    """
    plt.figure(
        figsize=(
            plot_half_width * plot_shrink_factor,
            plot_half_width * plot_shrink_factor,
        )
    )
    plt.xlabel("X Offset [a.u.]")
    plt.ylabel("Y Offset [a.u.]")
    # imshow expects [row=y, col=x]. Transpose, and draw row 0 at the bottom so
    # the image agrees with `extent` and the axis labels.
    plt.imshow(
        np.asarray(grid_value).T,
        cmap="magma",
        origin="lower",
        extent=(grid_x[0], grid_x[-1], grid_y[0], grid_y[-1]),
    )
    cbar = plt.colorbar(fraction=0.04575, pad=0.04)
    cbar.set_label(label)
    plt.savefig(result_base_dir + filepath + ".png", bbox_inches="tight")
    plt.savefig(result_base_dir + filepath + ".svg", bbox_inches="tight")
    plt.show()


def plot_diff_grid(grid_x, grid_y, grid1_value, grid2_value, filepath: str):
    """Plot the difference of two min-max normalised reward landscapes.

    Both grids are linearly resampled to 64x64 (``resize_right``, edge
    padding) and scaled to [0, 1]. Then ``grid1 - grid2`` is shown with a
    zero-centred diverging colour map and saved as PNG and SVG.

    Args:
        grid_x, grid_y: 1-D grid coordinates along X and Y.
        grid1_value, grid2_value: 2-D values indexed ``[ix, iy]``, the layout
            returned by adaptive's ``Learner2D.interpolated_on_grid``.
        filepath: Output path relative to ``result_base_dir``, without extension.
    """

    def _resample_and_normalise(values):
        # Min-max scaling to [0, 1]. z-scoring was considered but the rewards
        # are not normally distributed. A constant grid (e.g. a true reward
        # that never changes within the slice) maps to zeros instead of
        # dividing by zero.
        values = np.asarray(
            resize(
                values,
                out_shape=(64, 64),
                interp_method=interp_methods.linear,
                pad_mode="edge",
            )
        )
        value_range = values.max() - values.min()
        if value_range == 0:
            return np.zeros_like(values, dtype=float)
        return (values - values.min()) / value_range

    grid_diff = _resample_and_normalise(grid1_value) - _resample_and_normalise(
        grid2_value
    )
    vabs = np.abs(grid_diff).max()  # symmetric colour limits around zero

    plt.figure(
        figsize=(
            plot_half_width * plot_shrink_factor,
            plot_half_width * plot_shrink_factor,
        )
    )
    plt.xlabel("X Offset [a.u.]")
    plt.ylabel("Y Offset [a.u.]")
    # imshow expects [row=y, col=x]. Transpose, and draw row 0 at the bottom so
    # the image agrees with `extent` and the axis labels.
    plt.imshow(
        grid_diff.T,
        cmap="RdBu",
        origin="lower",
        extent=(grid_x[0], grid_x[-1], grid_y[0], grid_y[-1]),
        vmin=-vabs,
        vmax=vabs,
    )
    cbar = plt.colorbar(fraction=0.04575, pad=0.04)
    cbar.set_label("Diff. between norm. rewards")
    plt.savefig(result_base_dir + filepath + ".png", bbox_inches="tight")
    plt.savefig(result_base_dir + filepath + ".svg", bbox_inches="tight")
    plt.show()


# Coordinates for the difference plots, taken from the first grid. All
# learners use the same bounds; presumably the grid resolutions also match
# (TODO confirm, since interpolated_on_grid picks n per learner by default).
grid_x, grid_y = pretrain_proxy_reward_grid[:2]

# Individual landscapes: each *_reward_grid is (xs, ys, zs) and is unpacked
# into plot_reward_grid's first three arguments.
plot_reward_grid(
    *pretrain_proxy_reward_grid,
    label="Proxy Reward",
    filepath="/reward_comparison/pretrain_proxy_reward_grid"
)
plot_reward_grid(
    *pretrain_true_reward_grid,
    label="True Reward",
    filepath="/reward_comparison/pretrain_true_reward_grid"
)

plot_reward_grid(
    *es_proxy_reward_grid,
    label="Proxy Reward",
    filepath="/reward_comparison/es_proxy_reward_grid"
)
plot_reward_grid(
    *es_true_reward_grid,
    label="True Reward",
    filepath="/reward_comparison/es_true_reward_grid"
)

# Proxy minus true (after normalisation) around each control.
plot_diff_grid(
    grid_x,
    grid_y,
    pretrain_proxy_reward_grid[2],
    pretrain_true_reward_grid[2],
    filepath="/reward_comparison/pretrain_diff_reward_grid",
)

plot_diff_grid(
    grid_x,
    grid_y,
    es_proxy_reward_grid[2],
    es_true_reward_grid[2],
    filepath="/reward_comparison/es_diff_reward_grid",
)

In [None]:
# Diagnostics for the pretrained control's landscapes, resampled to 64x64.
# Histogram: distribution of proxy reward values (motivates min-max rather
# than z-score normalisation in plot_diff_grid).
plt.figure()
plt.hist(
    resize(
        pretrain_proxy_reward_grid[2],
        out_shape=(64, 64),
        interp_method=interp_methods.linear,
    ).flatten(),
    bins=128,
)
plt.show()

# Scatter: proxy reward (x) vs. true reward (y) at the same grid points.
plt.figure()
plt.scatter(
    resize(
        pretrain_proxy_reward_grid[2],
        out_shape=(64, 64),
        interp_method=interp_methods.linear,
    ).flatten(),
    resize(
        pretrain_true_reward_grid[2],
        out_shape=(64, 64),
        interp_method=interp_methods.linear,
    ).flatten(),
)
plt.show()