In [None]:
import jax

jax.config.update("jax_enable_x64", True)
#jax.config.update("jax_check_tracer_leaks", True)

from functools import partial
from typing import Callable, Tuple

import diffrax
import equinox as eqx
import imageio.v3 as iio
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np
import optax
from jaxtyping import Array, ArrayLike
from tqdm.auto import trange

import optimal_control.constraints as constraints
import optimal_control.controls as controls
import optimal_control.environments as environments
import optimal_control.environments.examples as examples
import optimal_control.solvers as solvers
import optimal_control.trainers as trainers

In [None]:
# What happens when max_steps is reached with throw=False and SaveAt is not fully populated?


def ode(t, y, args):
    """Exponential decay vector field, dy/dt = -y (``t`` and ``args`` are unused)."""
    dydt = -y
    return dydt


sol = diffrax.diffeqsolve(
    diffrax.ODETerm(ode),
    diffrax.Euler(),
    t0=0.0,
    t1=10.0,
    dt0=0.1,
    y0=jnp.asarray([1.0]),
    # Dense alternative: saveat=diffrax.SaveAt(ts=jnp.linspace(0.0, 10.0, 101)),
    saveat=diffrax.SaveAt(t1=True),
    # Reaching t1 = 10 with steps of 0.1 takes 100 Euler steps, so this budget runs out first.
    max_steps=25,
    # Do not raise if the solve fails; inspect the returned solution instead.
    throw=False,
)

# What ended up in the t1 save slot after the early stop?
sol.ys, sol.ts

In [None]:
# Fixed-seed (10, 2) table of standard-normal values and 1000 evaluation times on [0, 1].
key = jax.random.PRNGKey(1234)
c = jax.random.normal(key, shape=(10, 2))
t = jnp.linspace(0.0, 1.0, num=1000)

In [None]:
# First calls of both interpolation kernels (presumably also triggering JIT compilation
# before the timings in the next cell). Only the last expression, the linear result, is displayed.
controls.InterpolationControl.fast_interpolate_step(t, c, 0.1, 0.8)
controls.InterpolationControl.fast_interpolate_linear(t, c, 0.1, 0.8)

In [None]:
%timeit step = controls.InterpolationControl.fast_interpolate_step(t, c, 0.1, 0.8)
%timeit linear = controls.InterpolationControl.fast_interpolate_linear(t, c, 0.1, 0.8)

#%timeit linear2 = controls.InterpolationControl.interpolate(t, jnp.linspace(0.0, 1.0, c.shape[0]), c, "linear")


plt.figure()
plt.plot(t, step)
plt.plot(t, linear)
#plt.plot(t, linear2)
plt.show()

In [None]:
# Debugging a weird tracer leak


def debug_ode(t, y, args):
    """Vector field dy/dt = args(t), where ``args`` is a callable control.

    The state ``y`` is unused. Variants tried while chasing the tracer leak: closing
    over the control (``return control(t)``) and returning ``args`` directly.
    """
    control_fn = args
    return control_fn(t)


class DebugState(environments.EnvironmentState):
    """State for DebugEnvironment: just the initial condition used by its ODE solve."""

    # Initial condition, passed as diffrax's y0 by DebugEnvironment.integrate.
    y0: Array


class DebugEnvironment(environments.AbstractEnvironment):
    """Minimal environment with dynamics dy/dt = control(t), used to isolate a tracer leak."""

    def init(self):
        """Return the initial state, with y0 = [1.0, 1.0]."""
        initial_y = jnp.ones(2)
        return DebugState(initial_y)

    def integrate(self, control: controls.AbstractControl, state: DebugState) -> Array:
        """Solve dy/dt = control(t) on [0, 10] starting from ``state.y0``.

        The control reaches the vector field through diffrax's ``args`` rather than
        being bound with ``partial``. Uses Kvaerno5 with a PID step-size controller
        (rtol = atol = 1e-5) and an initial step of 1.0.

        Returns:
            The solution values saved at t = 0, 1, ..., 10.
        """
        # Alternative solver that was tried:
        #   diffrax.ImplicitEuler(nonlinear_solver=diffrax.NewtonNonlinearSolver(rtol=1e-5, atol=1e-5))
        adaptive_steps = diffrax.PIDController(rtol=1e-5, atol=1e-5)
        save_times = diffrax.SaveAt(ts=jnp.linspace(0.0, 10.0, 11))
        solution = diffrax.diffeqsolve(
            diffrax.ODETerm(debug_ode),
            diffrax.Kvaerno5(),
            t0=0.0,
            t1=10.0,
            dt0=1.0,
            y0=state.y0,
            args=control,
            saveat=save_times,
            stepsize_controller=adaptive_steps,
        )
        return solution.ys


# Module-level environment and initial state; the jitted `solve` in this cell reads both
# as globals (closure constants) rather than taking them as arguments.
environment = DebugEnvironment()
state = environment.init()


@jax.jit
@jax.vmap
@jax.grad
def solve(c: Array) -> Array:
    """Mean of the debug trajectory under a constant control ``c``.

    The decorators turn this into a jitted gradient d mean(ys) / d c, vmapped over the
    leading axis of the input.

    NOTE: building the control as ``controls.LambdaControl(lambda t: c)`` causes a
    tracer leak; holding ``c`` at every knot of an InterpolationControl avoids it.
    """
    n_knots = 11
    knot_values = jnp.repeat(c.reshape(1, 2), n_knots, axis=0)
    control = controls.InterpolationControl(2, n_knots, 0.0, 10.0, control=knot_values)
    return jnp.mean(environment.integrate(control, state))


In [None]:
solve(jnp.ones(2))
%timeit solve(jnp.ones(2)).block_until_ready()

In [None]:
# Batch of 10 identical controls. The first call compiles `solve` for the (10, 2) input
# shape, so %timeit then measures execution only.
solve(jnp.ones((10, 2)))
%timeit solve(jnp.ones((10, 2))).block_until_ready()

In [None]:
environment = examples.FibrosisEnvironment()
state = environment.init()
key = jax.random.PRNGKey(1234)

# 16 x 16 grid of constant controls over [0.1, 100]^2, flattened to shape (256, 2).
c = jnp.stack(
    jnp.meshgrid(
        jnp.linspace(0.1, 100.0, 16),
        jnp.linspace(0.1, 100.0, 16),
        indexing="xy",
    ),
    axis=-1,
).reshape(-1, 2)

In [None]:
@jax.jit
def interp_fast(t: ArrayLike, c: Array, t0: float, t1: float) -> Array:
    """Piecewise-constant lookup of the rows of ``c`` at times ``t``.

    [t0, t1) is split into ``c.shape[0]`` equal bins, and each time selects the row of
    its bin. Times before t0, and at or after t1, give 0.

    Args:
        t: Query times.
        c: Table of values; its leading axis indexes the bins.
        t0: Start of the interval covered by ``c``.
        t1: End of the interval covered by ``c``.

    Returns:
        ``c[bin(t)]``, with zeros for out-of-range times.
    """
    n_bins = c.shape[0]
    frac = (t - t0) / (t1 - t0)
    idx = jnp.floor(frac * n_bins).astype(jnp.int32)
    # Negative indices would be interpreted from the end of `c`. Send them one past the
    # end instead, so the "fill" gather treats them as out of bounds.
    idx = jnp.where(idx >= 0, idx, n_bins)
    return c.at[idx].get(mode="fill", fill_value=0.0)

c1 = jnp.arange(101)
c2 = jnp.arange(101 * 2).reshape(2, 101).T
t_start, t_end = 0.0, 1.0
t = jnp.linspace(-0.1, 1.0, 11)  # starts before t_start to exercise the left fill

# Lookups for the (101,) and (101, 2) tables, then the jaxpr traced for each.
print(*(interp_fast(t, table, t_start, t_end) for table in (c1, c2)), sep="\n")
print(*(jax.make_jaxpr(interp_fast)(t, table, t_start, t_end) for table in (c1, c2)), sep="\n")

In [None]:
# Fibrosis benchmark

def ode(t, y, args):
    """Control-driven vector field, dy/dt = args(t); the state ``y`` is ignored."""
    control = args
    return control(t)


def reward_fn(x: Array) -> ArrayLike:
    """Scalar reward: the negative mean log of the first two state components.

    ``+inf`` entries are first replaced by 0. The first two components along the last
    axis are then floored at 1e2, so every log term is at least log(100).

    Returns:
        ``-mean(log(max(x[..., :2], 1e2)))``, computed after the ``+inf`` replacement.
    """
    x = jnp.where(jnp.isposinf(x), 0.0, x)
    # jnp.maximum is equivalent to jnp.clip(..., a_min=1e2, a_max=None). It avoids the
    # a_min/a_max keywords, which newer JAX releases deprecate in favour of min/max.
    floored = jnp.maximum(x[..., :2], 1e2)
    return -jnp.mean(jnp.log(floored))

@jax.jit
@jax.vmap
@jax.grad
def solve(c: Array) -> Array:
    """Fibrosis reward for a constant control ``c``, turned into a batched gradient.

    Integrates the module-level ``environment`` from ``state.y0`` over [0, 200], saving
    at t = 0, 1, ..., 200, with a LambdaControl whose function ignores ``t`` and returns
    ``c``. The result is ``mean(reward_fn(ys))``. The decorators make this a jitted
    gradient with respect to ``c``, vmapped over the leading axis of the input.

    Earlier variants (formerly left commented out): passing environment/state/key/
    reward_fn as arguments (via ``eqx.filter_jit`` or ``partial(jax.vmap,
    in_axes=(0, None, None))``), a closure control ``LambdaControl(lambda t: c)``, an
    ``InterpolationControl`` holding ``c`` at 101 knots on [0, 100], returning
    ``sol.stats["num_steps"]``, and a standalone ``diffrax.Dopri5`` solve of ``ode`` on
    [0, 10] from zeros.
    """
    # `c` goes through the control's args instead of being closed over by the lambda.
    constant_control = controls.LambdaControl(lambda t, c: c, c)
    save_at = diffrax.SaveAt(ts=jnp.linspace(0.0, 200.0, 201))
    sol = environment._integrate(
        0.0, 200.0, state.y0, constant_control, False, save_at, False
    )
    return jnp.mean(reward_fn(sol.ys))


# Earlier variant, kept for reference: `solve` also took environment, state, key and
# reward_fn as arguments and was wrapped (optionally with jax.grad) as
#   jv_solve = jax.jit(jax.vmap(partial(solve, environment=environment, state=state,
#                                       key=key, reward_fn=reward_fn)))
# or as j_solve = eqx.filter_jit(solve), called as
#   j_solve(c[0], environment, state, key, reward_fn, ode)
# %timeit jv_solve(c).block_until_ready()

solve(c)


In [None]:
# Benchmark the jitted, vmapped gradient over every control in the grid `c`.
%timeit solve(c).block_until_ready()

In [None]:
# Forward solve with the first grid control held constant at 101 interpolation knots.
# NOTE(review): the control spans [0, 100] while the solve runs to t = 200. Confirm how
# InterpolationControl behaves after its end time.
control = controls.InterpolationControl(
    2, 101, 0.0, 100.0, control=jnp.repeat(c[0].reshape(1, 2), 101, axis=0)
)
sol = environment._integrate(
    0.0, 200.0, state.y0, control, False,
    diffrax.SaveAt(ts=jnp.linspace(0.0, 200.0, 201)), False,
)

In [None]:
# Saved trajectory of the interpolation-control solve (one entry per save time).
sol.ys

In [None]:
# Benchmarking test for vmapped solves


def ode(t, y, args):
    """Exponential decay vector field, dy/dt = -y (``t`` and ``args`` are unused)."""
    dydt = -y
    return dydt


def solve1(y0):
    """Solve dy/dt = ode(t, y, args) on [0, 10] from ``y0``, building everything inline.

    The ODE term, the Dopri5 solver and the save grid are constructed inside the
    function on every call.

    Returns:
        The solution values saved at t = 0, 1, ..., 10.
    """
    save_at = diffrax.SaveAt(ts=jnp.linspace(0.0, 10.0, 11))
    solution = diffrax.diffeqsolve(
        diffrax.ODETerm(ode),
        diffrax.Dopri5(),
        t0=0.0,
        t1=10.0,
        dt0=0.1,
        y0=y0,
        saveat=save_at,
    )
    return solution.ys


def solve2(y0, terms, solver, saveat):
    """Solve on [0, 10] from ``y0`` with the term, solver and SaveAt supplied by the caller.

    Uses an initial step of 0.1.

    Returns:
        ``ys`` of the diffrax solution.
    """
    # diffeqsolve's positional order is (terms, solver, t0, t1, dt0, y0).
    solution = diffrax.diffeqsolve(terms, solver, 0.0, 10.0, 0.1, y0, saveat=saveat)
    return solution.ys


jv_solve1 = jax.jit(jax.vmap(solve1))
jv_solve2 = jax.jit(
    jax.vmap(
        partial(
            solve2,
            terms=diffrax.ODETerm(ode),
            solver=diffrax.Dopri5(),
            saveat=diffrax.SaveAt(ts=jnp.linspace(0.0, 10.0, 11)),
        )
    )
)

%timeit jv_solve1(jnp.linspace(1.0, 10.0, 1024)).block_until_ready()
%timeit jv_solve2(jnp.linspace(1.0, 10.0, 1024)).block_until_ready()

%timeit jv_solve1(jnp.linspace(1.0, 10.0, 1024)).block_until_ready()
%timeit jv_solve2(jnp.linspace(1.0, 10.0, 1024)).block_until_ready()

In [None]:
# Test to see if terminating events still allow backprop (yes)


def ode(x):
    """Time-independent decay dynamics: returns -x."""
    negated = -x
    return negated


def cond_fn(state, **kwargs):
    """Termination condition: True once ``state.y[0]`` drops below 1e-3.

    Any extra keyword arguments from the caller are accepted and ignored.
    """
    first_component = state.y[0]
    return first_component < 1e-3


def solve(y0):
    """Mean of y sampled at t = 0, 1, ..., 10 for dy/dt = ode(y), with early termination.

    The solve stops once ``cond_fn`` returns True; this cell checks that gradients
    still flow through a discrete terminating event.
    """
    vector_field = diffrax.ODETerm(lambda t, y, args: ode(y))
    stop_event = diffrax.DiscreteTerminatingEvent(cond_fn)
    save_at = diffrax.SaveAt(ts=jnp.linspace(0.0, 10.0, 11))
    sol = diffrax.diffeqsolve(
        vector_field,
        diffrax.Dopri5(),
        t0=0.0,
        t1=10.0,
        dt0=0.1,
        y0=y0,
        saveat=save_at,
        discrete_terminating_event=stop_event,
    )
    return jnp.mean(sol.ys)


# Jitted value and gradient of `solve` with respect to y0.
jit_solve = jax.jit(jax.value_and_grad(solve))
# value_and_grad returns (solve(y0), d solve / d y0); `ys` is the scalar mean, not a trajectory.
ys, y0_grad = jit_solve(jnp.asarray([1.0]))

print(ys, y0_grad)
