In [None]:
from functools import partial
from typing import Dict, List

import jax
import jax.numpy as jnp
import jaxopt
import matplotlib.pyplot as plt
import numpy as np
import optax
from jax import jacfwd, jacrev, lax
from jax.experimental.ode import odeint
from tqdm.auto import trange

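# NOTE: `integrate` below calls `odeint_trapezoidal_rule`, which is not defined in this
# notebook; presumably it comes from a local `odeint` helper module — TODO confirm and import it.
# from odeint import odeint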

# jax.config.update("jax_enable_x64", True)
# jax.config.update("jax_debug_nans", True)


# Integration Test

In [None]:
def k_helper(t):
    """Return the 14 model rate constants at time ``t`` (days) as a list.

    Rates quoted per minute are converted to per day by multiplying with 1440
    (minutes per day). Only entry 12, the inflammation pulse, depends on ``t``.
    Because that entry uses a Python ``if``, ``t`` must be a concrete number
    here, not a JAX tracer.
    """
    minutes_per_day = 1440
    inflammation = 140 * minutes_per_day if t < 4.0 else 0

    return [
        0.9,  # 0: lambda1, proliferation rate (1/day)
        0.8,  # 1: lambda2, proliferation rate (1/day)
        0.3,  # 2: mu_1 = mu_2, death rate (1/day)
        1e6,  # 3: carrying capacity (cells)
        2,  # 4: gamma, growth-factor degradation (1/day)
        240 * minutes_per_day,  # 5: beta_3, secretion rate (molecules/cell/min -> per day)
        470 * minutes_per_day,  # 6: beta_1, secretion rate
        70 * minutes_per_day,  # 7: beta_2, secretion rate
        940 * minutes_per_day,  # 8: alpha_1, CSF1 endocytosis rate
        510 * minutes_per_day,  # 9: alpha_2, PDGF endocytosis rate
        6e8,  # 10: k_1, PDGF binding affinity (molecules)
        6e8,  # 11: k_2, CSF binding affinity (molecules)
        inflammation,  # 12: inflammation pulse, active only for t < 4
        1e6,  # 13: not referenced by the ODEs in this notebook
    ]


@jax.jit
def ode(x, t):
    """Right-hand side of the uncontrolled fibroblast/macrophage model.

    State ``x`` = (fibroblasts, macrophages, CSF1, PDGF). Rates quoted per
    minute are converted to per day with the factor 1440 (minutes per day).
    A constant macrophage influx (the inflammation pulse) is added while t < 4.

    Returns the four time derivatives as a tuple, in the same order as ``x``.
    """
    minutes_per_day = 1440

    lam1, lam2 = 0.9, 0.8  # proliferation rates (1/day)
    mu = 0.3  # death rate shared by both cell types (1/day)
    capacity = 1e6  # fibroblast carrying capacity (cells)
    gamma = 2  # growth-factor degradation rate (1/day)
    beta3 = 240 * minutes_per_day  # PDGF secretion by fibroblasts
    beta1 = 470 * minutes_per_day  # CSF1 secretion by fibroblasts
    beta2 = 70 * minutes_per_day  # PDGF secretion by macrophages
    alpha1 = 940 * minutes_per_day  # CSF1 endocytosis by macrophages
    alpha2 = 510 * minutes_per_day  # PDGF endocytosis by fibroblasts
    k_pdgf = 6e8  # PDGF binding affinity (molecules)
    k_csf = 6e8  # CSF1 binding affinity (molecules)
    influx = jnp.where(t < 4, 140 * minutes_per_day, 0)  # inflammation pulse

    fib, mph, csf, pdgf = x[0], x[1], x[2], x[3]

    d_fib = fib * (lam1 * pdgf / (k_pdgf + pdgf) * (1 - fib / capacity) - mu)
    d_mph = mph * (lam2 * csf / (k_csf + csf) - mu) + influx
    d_csf = beta1 * fib - alpha1 * mph * csf / (k_csf + csf) - gamma * csf
    d_pdgf = (
        beta2 * mph + beta3 * fib - alpha2 * fib * pdgf / (k_pdgf + pdgf) - gamma * pdgf
    )

    return d_fib, d_mph, d_csf, d_pdgf


In [None]:
# Simulate 300 days of the uncontrolled model, starting from one fibroblast,
# one macrophage and no growth factors.
t = jnp.linspace(0.0, 300.0, num=300)
y = odeint(ode, jnp.array([1.0, 1.0, 0.0, 0.0]), t)


In [None]:
# One log-scale panel per state variable of the uncontrolled simulation.
# Zeros (e.g. the initial growth-factor levels) are not drawn on a log axis.
state_labels = [
    ("Fibroblasts", "cells"),
    ("Macrophages", "cells"),
    ("CSF1", "molecules"),
    ("PDGF", "molecules"),
]
for i, (name, unit) in enumerate(state_labels):
    fig, ax = plt.subplots()
    ax.set_yscale("log")
    ax.plot(t, y[..., i])
    ax.set(title=name, xlabel="time (days)", ylabel=f"{name} ({unit})")
    plt.show()


# Direct Optimization

In [None]:
@jax.jit
def ode(x, t):
    """Right-hand side of the uncontrolled fibroblast/macrophage model.

    Same model as the integration-test definition. It is re-declared here so
    this section can be run on its own.

    State ``x`` = (fibroblasts, macrophages, CSF1, PDGF). Rates quoted per
    minute are converted to per day with the factor 1440 (minutes per day).
    A constant macrophage influx (the inflammation pulse) is added while t < 4.

    Returns the four time derivatives as a tuple, in the same order as ``x``.
    """
    minutes_per_day = 1440

    lam1, lam2 = 0.9, 0.8  # proliferation rates (1/day)
    mu = 0.3  # death rate shared by both cell types (1/day)
    capacity = 1e6  # fibroblast carrying capacity (cells)
    gamma = 2  # growth-factor degradation rate (1/day)
    beta3 = 240 * minutes_per_day  # PDGF secretion by fibroblasts
    beta1 = 470 * minutes_per_day  # CSF1 secretion by fibroblasts
    beta2 = 70 * minutes_per_day  # PDGF secretion by macrophages
    alpha1 = 940 * minutes_per_day  # CSF1 endocytosis by macrophages
    alpha2 = 510 * minutes_per_day  # PDGF endocytosis by fibroblasts
    k_pdgf = 6e8  # PDGF binding affinity (molecules)
    k_csf = 6e8  # CSF1 binding affinity (molecules)
    influx = jnp.where(t < 4, 140 * minutes_per_day, 0)  # inflammation pulse

    fib, mph, csf, pdgf = x[0], x[1], x[2], x[3]

    d_fib = fib * (lam1 * pdgf / (k_pdgf + pdgf) * (1 - fib / capacity) - mu)
    d_mph = mph * (lam2 * csf / (k_csf + csf) - mu) + influx
    d_csf = beta1 * fib - alpha1 * mph * csf / (k_csf + csf) - gamma * csf
    d_pdgf = (
        beta2 * mph + beta3 * fib - alpha2 * fib * pdgf / (k_pdgf + pdgf) - gamma * pdgf
    )

    return d_fib, d_mph, d_csf, d_pdgf


# Run the uncontrolled model (including the early inflammation pulse) for 300
# days. Its final state, assumed to have settled near the inflamed ("ON")
# fixed point, is the starting state of the control problem below.
t = jnp.linspace(0.0, 300.0, num=300)
y0 = jnp.array([1.0, 1.0, 0.0, 0.0])
y_pre = odeint(ode, y0, t)

# Treat the starting state as a constant so no gradient flows back into it.
y1 = lax.stop_gradient(y_pre[-1, :])


In [None]:
@jax.jit
def ode(x, t, u, t_final=100.0):
    """Right-hand side of the fibroblast/macrophage model under a PDGF-antibody control.

    Args:
        x: state (fibroblasts, macrophages, CSF1, PDGF).
        t: time in days.
        u: control schedule of shape (n,) or (n, n_controls).
            - Row ``j`` is the control value at time ``t_final * (j + 1) / n``.
            - The control is fixed to zero at ``t = 0`` and linearly interpolated
              in between. ``jnp.interp`` holds it at the last value for
              ``t > t_final``.
            - Column 0 is the PDGF antibody dose. Further columns (e.g. a
              cytostatic drug) are accepted but currently unused.
            - A 1-D ``u`` is treated as a single control column.
        t_final: end of the control horizon in days. The default of 100 matches
            the [0, 100] day horizon used by ``integrate``.

    Returns:
        Array of shape (4,) with the time derivatives, in the same order as ``x``.
    """
    minutes_per_day = 1440

    # Model parameters; rates quoted per minute are converted to per day.
    lam1, lam2 = 0.9, 0.8  # proliferation rates (1/day)
    mu = 0.3  # death rate shared by both cell types (1/day)
    capacity = 1e6  # fibroblast carrying capacity (cells)
    gamma = 2  # growth-factor degradation rate (1/day)
    beta3 = 240 * minutes_per_day  # PDGF secretion by fibroblasts
    beta1 = 470 * minutes_per_day  # CSF1 secretion by fibroblasts
    beta2 = 70 * minutes_per_day  # PDGF secretion by macrophages
    alpha1 = 940 * minutes_per_day  # CSF1 endocytosis by macrophages
    alpha2 = 510 * minutes_per_day  # PDGF endocytosis by fibroblasts
    k_pdgf = 6e8  # PDGF binding affinity (molecules)
    k_csf = 6e8  # CSF1 binding affinity (molecules)
    influx = 0  # inflammation pulse is switched off in the treatment phase

    # Control schedule. The knot grid is derived from the number of rows, so
    # schedules of any length work; a fixed 100-row grid made jnp.interp raise
    # for any other length.
    if u.ndim == 1:  # shapes are static under jit, so this branch is resolved at trace time
        u = u[:, None]
    knots = jnp.linspace(0.0, t_final, u.shape[0] + 1)
    values = jnp.concatenate((jnp.zeros_like(u[:1]), u), axis=0)
    interp_columns = jax.vmap(jnp.interp, in_axes=(None, None, -1), out_axes=-1)
    u_at_t = interp_columns(t, knots, values)

    fib, mph, csf, pdgf = x[0], x[1], x[2], x[3]

    # PDGF antibody: extra PDGF clearance proportional to PDGF level and dose.
    k_ab = 1 * minutes_per_day  # 1 / (min * molecule), converted to per day
    pdgf_ab_deg = -k_ab * pdgf * u_at_t[0]

    # Cytostatic drug (disabled): lam1 = 0.9 * (1 - u_at_t[1] / (u_at_t[1] + 1.0))

    d_fib = fib * (lam1 * pdgf / (k_pdgf + pdgf) * (1 - fib / capacity) - mu)
    d_mph = mph * (lam2 * csf / (k_csf + csf) - mu) + influx
    d_csf = beta1 * fib - alpha1 * mph * csf / (k_csf + csf) - gamma * csf
    d_pdgf = (
        pdgf_ab_deg
        + beta2 * mph
        + beta3 * fib
        - alpha2 * fib * pdgf / (k_pdgf + pdgf)
        - gamma * pdgf
    )

    return jnp.array([d_fib, d_mph, d_csf, d_pdgf])

@jax.jit
def integrate(u, y0):
    """Integrate the controlled model from ``y0`` under control ``u`` over [0, 100] days.

    Solves on a fixed grid of 10000 evenly spaced time points using
    ``odeint_trapezoidal_rule``. That function is not defined in this notebook;
    presumably it comes from a local ``odeint`` helper module (TODO confirm).

    Alternatives that were tried:
    - ``odeint(ode, y0, t, u)`` uses adjoint gradients, which seemed unstable here.
    - RK4 (``odeint_rk4``) is adjoint-free.
    - ``odeint_backward_euler``.

    Returns the solver's trajectory; presumably one state row per grid point.
    """
    time_grid = jnp.linspace(0.0, 100.0, 10000)
    return odeint_trapezoidal_rule(ode, y0, time_grid, u)


@jax.jit
def loss(u, y0):
    """Control objective: mean of the squared cell counts along the trajectory.

    Integrates the controlled model from ``y0`` under ``u`` and averages the
    squared values of state columns 0 and 1 (fibroblasts and macrophages) over
    all time points. A control-effort penalty ``jnp.mean(jnp.square(u))`` is
    currently not included.
    """
    cells = integrate(u, y0)[..., :2]
    return jnp.mean(jnp.square(cells))


# Gradient of `loss` with respect to the control `u` (first argument). It is
# wrapped in jit so the forward solve and its backward pass are compiled into
# a single XLA program instead of being re-traced in Python on every call.
grad = jax.jit(jax.grad(loss))


In [None]:
u = jnp.ones((300, 1))
%timeit l = loss(u, y_pre[-1]).block_until_ready()

In [None]:
# Sanity check at zero control: loss value and its gradient w.r.t. the control.
# The control must be 2-D (rows = time knots, columns = controls); a 1-D
# vector makes the vmapped interpolation inside `ode` fail.
u = jnp.zeros((100, 1))
print(loss(u, y1))
print(grad(u, y1))


In [None]:
# Normalised gradient descent on the control schedule.
# - Each step scales the gradient so that its largest entry moves `u` by
#   exactly `lr`.
# - It is followed by optional L2 shrinkage and L1 soft-thresholding
#   (proximal steps).
lr = 1e-4
l2 = 0.0  # e.g. 1.0
l1 = 0.0  # e.g. 0.1

# One row per day of the 100-day control grid, one column per control.
u = jnp.zeros((100, 1))
for _ in trange(1024):
    u_grad = grad(u, y1)
    g_max = jnp.max(jnp.abs(u_grad))
    # Guard against an all-zero gradient, which would otherwise give 0 / 0 = NaN.
    step = jnp.where(g_max > 0, lr / jnp.where(g_max > 0, g_max, 1.0), 0.0)
    u = u - step * u_grad

    u = u - (lr * l2) * u  # L2
    u = jnp.where(jnp.abs(u) <= (lr * l1), 0.0, u - (lr * l1) * jnp.sign(u))  # L1


In [None]:
# Adam on the control schedule, followed by proximal steps for the penalties:
# L2 shrinkage and L1 soft-thresholding (which keeps the control sparse).
lr = 1e-2
l2 = 0.0  # e.g. 1.0
l1 = 0.1
n_iters = 1024 * 16

optimizer = optax.adam(learning_rate=lr)
# Column 0: PDGF antibody. Column 1: cytostatic drug, which `ode` currently
# ignores, so its gradient is zero and it stays at 0.
params = jnp.zeros((100, 2))
opt_state = optimizer.init(params)


@jax.jit
def adam_prox_step(params, opt_state):
    """One Adam step on `loss` (starting state `y1`), then L2 shrink and L1 soft-threshold."""
    grads = jax.grad(loss)(params, y1)
    updates, opt_state = optimizer.update(grads, opt_state, params=params)
    params = optax.apply_updates(params, updates)
    params = params - (lr * l2) * params  # L2
    params = jnp.where(
        jnp.abs(params) <= (lr * l1), 0.0, params - (lr * l1) * jnp.sign(params)
    )  # L1
    return params, opt_state


params_history = []
for _ in trange(n_iters):
    params, opt_state = adam_prox_step(params, opt_state)
    # JAX arrays are immutable, so no defensive copy is needed.
    params_history.append(params)

In [None]:
# Simulate 100 days under the optimised control, then 100 more days with the
# control switched off.
y_treated = integrate(params, y1)
y_after = integrate(jnp.zeros_like(params), y_treated[-1])
y = jnp.concatenate((y_treated, y_after), axis=0)

# `integrate` samples an evenly spaced grid over [0, 100] days. The second
# segment restarts at day 100, so its time axis is shifted by 100 days.
# A single linspace over [0, 200] would misplace those samples.
t_segment = np.linspace(0.0, 100.0, y_treated.shape[0])
t_plot = np.concatenate((t_segment, 100.0 + t_segment))

state_names = ["Fibroblasts", "Macrophages", "CSF1", "PDGF"]
for i, name in enumerate(state_names):
    fig, ax = plt.subplots()
    ax.set_yscale("log")
    ax.plot(t_plot, y[..., i])
    ax.axvline(100.0, color="gray", linestyle="--")  # control switched off
    ax.set(title=name, xlabel="time (days)")
    plt.show()

# Optimised control schedule: piecewise-linear, zero at t = 0, with the knot
# count taken from `params` instead of being hard-coded.
n_controls = params.shape[0]
vinterp = jax.vmap(jnp.interp, in_axes=(None, None, -1), out_axes=-1)
t_fine = np.linspace(0.0, 100.0, 10 * n_controls + 1)
fig, ax = plt.subplots()
ax.plot(
    t_fine,
    vinterp(
        jnp.asarray(t_fine),
        jnp.linspace(0.0, 100.0, n_controls + 1),
        jnp.concatenate((jnp.zeros_like(params[:1]), params), axis=0),
    ),
)
ax.set(title="Optimised control", xlabel="time (days)", ylabel="control u")
ax.legend(["PDGF antibody", "cytostatic drug (unused)"][: params.shape[1]])
plt.show()


In [None]:
# Snapshots of the control schedule at evenly spaced points of the whole
# optimisation run. A fixed range(0, 1024, 128) only covered the first 1024
# of the recorded iterations.
n_snapshots = 8
snapshot_stride = max(1, len(params_history) // n_snapshots)
vinterp = jax.vmap(jnp.interp, in_axes=(None, None, -1), out_axes=-1)

for i in range(0, len(params_history), snapshot_stride):
    snapshot = params_history[i]
    n_controls = snapshot.shape[0]
    t_fine = np.linspace(0.0, 100.0, 10 * n_controls + 1)
    fig, ax = plt.subplots()
    ax.plot(
        t_fine,
        vinterp(
            jnp.asarray(t_fine),
            jnp.linspace(0.0, 100.0, n_controls + 1),
            jnp.concatenate((jnp.zeros_like(snapshot[:1]), snapshot), axis=0),
        ),
    )
    # params_history[i] holds the parameters after optimisation step i + 1.
    ax.set(title=f"Control after step {i + 1}", xlabel="time (days)", ylabel="control u")
    plt.show()

In [None]:
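# TODO: compare trajectories for PDGF + CSF1 antibody schedules that share the same
#       total (time-integrated) dose:
# TODO:   - constant antibody dosing
# TODO:   - optimised antibody dosing
# TODO: plot the comparison as panels of a single figure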