In [1]:
%load_ext autoreload
%autoreload 2

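# Misc: mass-action binding network simulated with diffrax, reverse rates tuned by gradient descent (optax)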

In [2]:
import jax 
import jax.numpy as jnp
import numpy as np
import optax
from scipy.special import factorial
import diffrax as dfx
from functools import partial
from bioreaction.misc.misc import flatten_listlike

# Pin JAX to the GPU backend for every computation in this notebook.
# NOTE(review): on a machine without a usable GPU this is expected to fail once JAX
# initialises its backend — remove this line (or use 'cpu') to run the notebook there.
jax.config.update('jax_platform_name', 'gpu')




In [3]:
def one_step_de_sim_expanded(spec_conc, inputs, outputs, forward_rates, reverse_rates):
    """Mass-action time derivative of every species for a set of reversible reactions.

    Reaction ``r`` consumes ``inputs[r, s]`` copies of species ``s`` and produces
    ``outputs[r, s]`` copies. Its forward flux is
    ``forward_rates[r] * prod_s spec_conc[s] ** inputs[r, s]`` and its reverse flux is
    ``reverse_rates[r] * prod_s spec_conc[s] ** outputs[r, s]``.

    Args:
        spec_conc: species concentrations, one entry per column of ``inputs``/``outputs``.
        inputs, outputs: stoichiometry matrices of shape ``(n_reactions, n_total_species)``.
        forward_rates, reverse_rates: per-reaction rate constants (broadcast against the
            ``(n_reactions,)`` mass-action terms).

    Returns:
        The time derivative of ``spec_conc``: net flux of each reaction weighted by the
        net stoichiometry ``outputs - inputs`` and summed over reactions.
    """
    stoichiometry = outputs - inputs  # net change in each species per reaction firing
    fwd_mass_action = jnp.prod(jnp.power(spec_conc, inputs), axis=1)
    rev_mass_action = jnp.prod(jnp.power(spec_conc, outputs), axis=1)
    net_flux = forward_rates * fwd_mass_action - reverse_rates * rev_mass_action
    return jnp.matmul(net_flux, stoichiometry)


# ODE Terms
def bioreaction_sim_expanded(t, y,
                             args,
                             inputs, outputs,
                             forward_rates=None, reverse_rates=None):
    """Vector field ``dy/dt`` in the ``(t, y, args)`` form expected by ``diffrax.ODETerm``.

    The network definition (``inputs``, ``outputs`` and the rate constants) is meant to
    be bound beforehand with ``functools.partial``. ``t`` and ``args`` exist only to
    satisfy the calling convention; the dynamics do not depend on either.

    TODO: an input-signal term ``signal(t) * signal_onehot`` was sketched for this
    function but is not implemented.
    """
    del t, args  # unused: the system is autonomous and takes no solver args
    return one_step_de_sim_expanded(y, inputs, outputs, forward_rates, reverse_rates)


def bioreaction_sim_dfx_expanded(y0, t0, t1, dt0,
                                 inputs, outputs, forward_rates, reverse_rates,
                                 solver=dfx.Tsit5(),
                                 saveat=dfx.SaveAt(
                                     t0=True, t1=True, steps=True),
                                 max_steps=16**5,
                                 stepsize_controller=dfx.ConstantStepSize()):
    """Integrate the mass-action network from ``y0`` over ``[t0, t1]`` with diffrax.

    Args:
        y0: initial species concentrations (squeezed before solving).
        t0, t1, dt0: integration interval and initial/fixed step size.
        inputs, outputs: stoichiometry matrices of the network.
        forward_rates, reverse_rates: per-reaction rate constants (squeezed before use).
        solver: diffrax solver, Tsit5 by default.
        saveat: which time points to record. The default records t0, t1 and every step.
            NOTE(review): with ``steps=True`` diffrax sizes the output buffer by
            ``max_steps`` (unused rows are padded), which is large for the default.
        max_steps: hard cap on the number of solver steps.
        stepsize_controller: step-size control; constant steps of ``dt0`` by default.

    Returns:
        The ``diffrax.diffeqsolve`` solution object (``.ts``, ``.ys``, ...).
    """
    vector_field = partial(
        bioreaction_sim_expanded,
        inputs=inputs,
        outputs=outputs,
        forward_rates=forward_rates.squeeze(),
        reverse_rates=reverse_rates.squeeze(),
    )
    solve_options = dict(
        t0=t0, t1=t1, dt0=dt0,
        y0=y0.squeeze(),
        saveat=saveat,
        max_steps=max_steps,
        stepsize_controller=stepsize_controller,
    )
    return dfx.diffeqsolve(dfx.ODETerm(vector_field), solver, **solve_options)


def f(reverse_rates, y0, t0, t1, dt0,
      inputs, outputs, forward_rates,
      output_idxs, n_save_points=100):
    """Scalar cost of one circuit's trajectory, as a function of its reverse rates.

    The network is simulated from ``y0`` over ``[t0, t1]`` and saved on an evenly spaced
    grid. For each species in ``output_idxs``, the cost measures how far its final value
    sits below the largest value it reached on that grid. The cost is the sum over the
    output species.

    Args:
        reverse_rates: per-reaction reverse rate constants (the argument that is
            differentiated and vmapped over).
        y0, t0, t1, dt0: initial state, integration interval and step size.
        inputs, outputs, forward_rates: network definition forwarded to the simulator.
        output_idxs: indices of the species that contribute to the cost.
        n_save_points: number of evenly spaced save times in ``[t0, t1]`` (both ends
            included). The peak is only resolved on this grid.

    Returns:
        ``sum_k (max_t y_k(t) - y_k(t1))`` over ``k`` in ``output_idxs``.

    NOTE(review): the save grid includes ``t0``. If an output species only decreases,
    its "peak" is its initial value and the cost reduces to the total drop from ``y0``.
    Confirm that this is the intended objective.
    """
    sol = bioreaction_sim_dfx_expanded(y0, t0, t1, dt0,
                                       inputs, outputs, forward_rates, reverse_rates,
                                       saveat=dfx.SaveAt(ts=np.linspace(t0, t1, n_save_points)))
    # Rows are save times, columns are the selected output species (for a 1-D y0).
    output_traj = sol.ys[:, output_idxs]
    # t1 is the last grid point, so max >= final value; abs() only guards against rounding.
    cost = jnp.sum(jnp.abs(jnp.max(output_traj, axis=0) - output_traj[-1]))
    return cost

In [4]:
# Seeded RNG so the random rate constants (and every downstream output) are reproducible.
SEED = 0
rng = np.random.default_rng(SEED)

n_species = 4
n_circuits = 5
# One reaction per unordered pair (i, j) with i <= j, i.e. n*(n+1)/2 reactions.
n_reactions = n_species * (n_species + 1) // 2
# Free species first, then one product species per reaction.
tot_species = n_species + n_reactions
signal_idxs = np.array([0])
output_idxs = np.array([1, 2])
y0 = np.array([200.0] * n_species + [0.0] * n_reactions)  # product species start at 0
t0, t1, dt0 = 0, 20, 0.001

inds = flatten_listlike([[(i, j) for j in range(i, n_species)]
                        for i in range(n_species)])
assert len(inds) == n_reactions, "pair list does not match the reaction count"

inputs, outputs = np.zeros((n_reactions, tot_species)), np.zeros(
    (n_reactions, tot_species))
for i in range(n_reactions):
    # For i == j both increments hit the same column, giving stoichiometry 2.
    inputs[i, inds[i][0]] += 1
    inputs[i, inds[i][1]] += 1
    outputs[i, i + n_species] += 1

forward_rates = rng.random(n_reactions)
# One row of reverse rates per circuit; these are the parameters optimised below.
reverse_rates = rng.random((n_circuits, n_reactions))

f_t = partial(f, y0=y0, t0=t0, t1=t1, dt0=dt0, inputs=inputs, outputs=outputs, forward_rates=forward_rates, output_idxs=output_idxs)

In [5]:
# Evaluate the cost for all circuits in one call: vmap maps f_t over axis 0 of
# reverse_rates (one row per circuit).
cost = jax.vmap(fun=f_t, in_axes=0)(reverse_rates)


I0000 00:00:1704707362.368089  317070 tfrt_cpu_pjrt_client.cc:349] TfrtCpuClient created.


In [6]:
# Per-circuit cost (one entry per row of reverse_rates) at the initial random rates.
cost

Array([389.75922, 387.81216, 393.3399 , 393.44666, 390.6364 ], dtype=float32)

In [7]:
l2_reg_alpha = 0.1  # NOTE(review): currently unused (only in the commented-out `alpha` below)
learning_rate = 0.001
warmup_epochs = 30
epochs = 1000
n_batches = 10

# optax schedules count optimiser steps, so both phases are expressed in steps.
warmup_steps = warmup_epochs * n_batches
cosine_epochs = max(epochs - warmup_epochs, 1)

warmup_fn = optax.linear_schedule(
    init_value=0., end_value=learning_rate,
    transition_steps=warmup_steps)
# Start the cosine phase at the warm-up's end value so the learning rate is continuous
# at the boundary (it previously jumped from learning_rate to 0.01).
cosine_fn = optax.cosine_decay_schedule(
    init_value=learning_rate, decay_steps=cosine_epochs * n_batches)  # , alpha=l2_reg_alpha)
schedule_fn = optax.join_schedules(
    schedules=[warmup_fn, cosine_fn],
    boundaries=[warmup_steps])
optimiser = optax.sgd(learning_rate=schedule_fn)
# Initialise on the full (n_circuits, n_reactions) array that the training loop updates,
# not a single row. Plain SGD keeps no per-parameter state, but momentum/Adam would.
optimiser_state = optimiser.init(reverse_rates)

In [8]:
def f_scan(inp, c):
    """One optimiser step on the batched reverse rates, used as a ``jax.lax.scan`` body.

    Args:
        inp: carry ``(reverse_rates, optimiser_state)``. ``reverse_rates`` is batched over
            circuits along axis 0, one row per circuit.
        c: per-step scan input; unused because the scan is driven by ``length``.

    Returns:
        ``((new_reverse_rates, new_optimiser_state), costs)``, where ``costs`` holds the
        per-circuit cost evaluated *before* this step's update.
    """
    del c  # no per-step input
    reverse_rates, optimiser_state = inp
    # Each row of reverse_rates gets its own cost and gradient.
    costs, grads = jax.vmap(jax.value_and_grad(f_t, argnums=0))(reverse_rates)
    updates, optimiser_state = optimiser.update(grads, optimiser_state)
    reverse_rates = optax.apply_updates(reverse_rates, updates)
    return (reverse_rates, optimiser_state), costs


n_opt_steps = 10
# lax.scan returns (final_carry, stacked_outputs). Unpack the carry explicitly so the
# optimised rates are not bound together with the optimiser state, and keep the
# initial reverse_rates intact so this cell can be re-run.
(reverse_rates_opt, optimiser_state_opt), cs = jax.lax.scan(
    f_scan, init=(reverse_rates, optimiser_state), xs=None, length=n_opt_steps)


In [9]:
# Cost per optimisation step (rows) and circuit (columns), each evaluated before that
# step's update. The first two rows coincide because the warm-up schedule starts at a
# learning rate of 0.
# NOTE(review): the saved output below has more than 10 rows although the scan runs for
# 10 steps, so it predates the current code — re-run to refresh it.
cs

Array([[389.75922, 387.81216, 393.3399 , 393.44666, 390.6364 ],
       [389.75922, 387.81216, 393.3399 , 393.44666, 390.6364 ],
       [389.7574 , 387.8116 , 393.33676, 393.44116, 390.6349 ],
       [389.75375, 387.81052, 393.3304 , 393.4302 , 390.63184],
       [389.74835, 387.80884, 393.32098, 393.41394, 390.62723],
       [389.7411 , 387.80664, 393.30847, 393.39246, 390.62122],
       [389.73206, 387.80392, 393.29288, 393.3661 , 390.61362],
       [389.7214 , 387.80063, 393.27432, 393.33508, 390.60452],
       [389.709  , 387.797  , 393.25287, 393.29977, 390.5939 ],
       [389.6949 , 387.79272, 393.22858, 393.26044, 390.5818 ],
       [389.67923, 387.7878 , 393.20154, 393.21747, 390.56805],
       [389.6619 , 387.78238, 393.17188, 393.1713 , 390.5528 ],
       [389.64313, 387.77643, 393.1397 , 393.12222, 390.53595],
       [389.6229 , 387.76996, 393.1051 , 393.07056, 390.51764],
       [389.60114, 387.76294, 393.06827, 393.01675, 390.49762],
       [389.57806, 387.75543, 393.02917,