# Setup

## Imports

In [None]:
import jax

jax.config.update("jax_enable_x64", True)
jax.config.update("jax_platform_name", "cpu")

from typing import Optional, Tuple

import diffrax
import equinox as eqx
import evosax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np
import optax
from jaxtyping import Array, ArrayLike, Scalar

import optimal_control.constraints as constraints
import optimal_control.controls as controls
import optimal_control.environments.examples as examples
import optimal_control.nn as nn
import optimal_control.solvers as solvers
import optimal_control.trainers as trainers

## General definitions

In [None]:
# Seed one PRNG key up front so that every run of this notebook is reproducible.
key = jax.random.PRNGKey(1234)

# Apoptosis environment: initial concentrations are read from the .mat file.
# The other two positional arguments ([0, 500] and 50) are forwarded unchanged;
# presumably they are the simulated time span and a step/resolution count --
# TODO confirm against ApoptosisEnvironment.__init__.
environment: examples.ApoptosisEnvironment = examples.ApoptosisEnvironment(
    "../data/Initial_concentrations_CD95H_wtH.mat",
    [0, 500],
    50,
)
environment_state = environment.init()

# Split off a dedicated sub-key for initialising the network parameters and
# keep `key` itself for later use.
key, subkey = jax.random.split(key, num=2)

# Control: a SIREN with 1 input and 2 outputs (two hidden layers of width 64),
# defined on t in [0, 180]. It is converted to a linearly interpolated curve
# with curve_steps=181.
control = controls.ImplicitTemporalControl(
    t_start=0.0,
    t_end=180.0,
    to_curve=True,
    curve_interpolation="linear",
    curve_steps=181,
    implicit_fn=nn.Siren(
        in_features=1,
        out_features=2,
        hidden_features=64,
        hidden_layers=2,
        key=subkey,
    ),
)


# Define reward function
def reward_fn(args: Tuple[diffrax.Solution, Array]) -> Scalar:
    """Return the mean of the ratio y12 / (y3 + y12), capped from above by ``thresh``.

    ``y3`` and ``y12`` are channels 3 and 12 of the last axis of
    ``solution.ys``. The ratio is capped elementwise at ``thresh`` and then
    averaged over every element.

    Args:
        args: A ``(solution, thresh)`` pair.
            - ``solution.ys`` is indexed on its last axis, so it must have at
              least 13 channels there.
            - ``thresh`` is reshaped to ``(-1, 1)``, so it broadcasts as one
              cap per row of the leading axis of ``ys[..., 12]``, or as a
              single cap if it has one element. Presumably this is one
              threshold per batch entry -- TODO confirm against the caller.

    Returns:
        The scalar mean of the capped ratio.
    """
    solution, thresh = args
    ys = solution.ys

    # NOTE(review): the meanings of channels 3 and 12 come from the
    # environment's state layout, which is not visible here. If
    # y3 + y12 == 0, the ratio is NaN; that case is not guarded here.
    ratio = ys[..., 12] / (ys[..., 3] + ys[..., 12])

    # An upper-only clip is exactly an elementwise minimum; jnp.clip itself
    # uses jnp.minimum for the upper bound. Calling minimum directly avoids the
    # `a_min`/`a_max` keywords that newer JAX releases deprecate in favour of
    # `min`/`max`, so this line behaves the same on old and new JAX.
    capped = jnp.minimum(ratio, thresh.reshape(-1, 1))
    return jnp.mean(capped)


# Gradient-based solver: Adam with a fixed learning rate of 3e-4.
direct_solver = solvers.DirectSolver(optax.adam(learning_rate=3e-4))

# Evolution-strategy solver. Only the array leaves of `control` are kept for
# evolution; the reshaper converts between that pytree and the flat parameter
# vectors that evosax operates on.
evo_control_params = eqx.filter(control, eqx.is_array)
evo_parameter_reshaper = evosax.ParameterReshaper(evo_control_params)

# Fitness is maximised, with centred-rank fitness shaping.
evo_fitness_shaper = evosax.FitnessShaper(maximize=True, centered_rank=True)

# The OpenES search dimension equals the length of the flattened control
# parameters, and the population has 64 members.
evo_strategy = evosax.OpenES(
    num_dims=evo_parameter_reshaper.flatten_single(evo_control_params).shape[0],
    popsize=64,
)
evo_strategy_params = evo_strategy.params_strategy

evo_solver = solvers.ESSolver(
    evo_strategy,
    evo_strategy_params,
    evo_parameter_reshaper,
    evo_fitness_shaper,
)

In [None]:
# Training helper
@eqx.filter_jit
def train_with_integral(
    control: controls.AbstractControl,
    solver: solvers.AbstractSolver,
    target_integral: Array,
    *,
    num_train_steps: int = 1024,
    pbar_interval: int = 8,
    train_key: Optional[Array] = None,
) -> Tuple[Scalar, controls.AbstractControl]:
    """Optimise ``control`` subject to a non-negative constant-integral constraint.

    Builds a one-element ``ConstraintChain`` holding a
    ``NonNegativeConstantIntegralConstraint(target=target_integral)``. That
    chain, together with the module-level ``environment`` and ``reward_fn``,
    is passed to ``trainers.solve_optimal_control_problem``.

    Args:
        control: The initial control to optimise.
        solver: The solver used by the trainer, e.g. the direct or ES solver.
        target_integral: The target passed to the integral constraint.
        num_train_steps: The number of training steps. This is a Python int,
            so it is static under ``eqx.filter_jit``; a new value retraces.
        pbar_interval: The progress-bar interval forwarded to the trainer.
            It is also static; a new value retraces.
        train_key: The PRNG key used for training. ``None`` falls back to the
            module-level ``key``, which was the previous behaviour.
            Prefer passing the key explicitly. A global array read inside
            ``filter_jit`` is embedded as a constant at trace time, so
            reassigning the global ``key`` later is not seen by calls that
            hit an existing compilation.

    Returns:
        ``(opt_reward, opt_control)`` as returned by the trainer.
    """
    if train_key is None:
        # Backward-compatible fallback to the notebook-level key.
        train_key = key

    constraint_chain = constraints.ConstraintChain(
        transformations=[
            constraints.NonNegativeConstantIntegralConstraint(target=target_integral)
        ]
    )

    opt_reward, opt_control = trainers.solve_optimal_control_problem(
        num_train_steps=num_train_steps,
        environment=environment,
        reward_fn=reward_fn,
        constraint_chain=constraint_chain,
        solver=solver,
        control=control,
        key=train_key,
        pbar_interval=pbar_interval,
    )

    return opt_reward, opt_control

# Run

## Single run train and finetune

In [None]:
# Train the initial control with the gradient-based (direct) solver, using an
# integral-constraint target of 1.0.
direct_reward, direct_control = train_with_integral(
    control,
    direct_solver,
    jnp.array([1.0]),
)