In [5]:
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


# Misc

In [74]:
import jax 
import jax.numpy as jnp
import numpy as np
import optax
from scipy.special import factorial
import diffrax as dfx
from functools import partial
from bioreaction.misc.misc import flatten_listlike

jax.config.update('jax_platform_name', 'cpu')


In [7]:
def one_step_de_sim_expanded(spec_conc, inputs, outputs, forward_rates, reverse_rates):
    concentration_factors_in = jnp.prod(
        jnp.power(spec_conc, (inputs)), axis=1)
    concentration_factors_out = jnp.prod(
        jnp.power(spec_conc, (outputs)), axis=1)
    forward_delta = concentration_factors_in * forward_rates
    reverse_delta = concentration_factors_out * reverse_rates
    return (forward_delta - reverse_delta) @ (outputs - inputs)


# ODE Terms
def bioreaction_sim_expanded(t, y,
                             args,
                             inputs, outputs,
                             #  signal, signal_onehot: jnp.ndarray,
                             forward_rates=None, reverse_rates=None):
    return one_step_de_sim_expanded(
        spec_conc=y, inputs=inputs,
        outputs=outputs,
        forward_rates=forward_rates,
        reverse_rates=reverse_rates)  # + signal(t) * signal_onehot


def bioreaction_sim_dfx_expanded(y0, t0, t1, dt0,
                                 inputs, outputs, forward_rates, reverse_rates,
                                 solver=dfx.Tsit5(),
                                 saveat=dfx.SaveAt(
                                     t0=True, t1=True, steps=True),
                                 max_steps=16**5,
                                 stepsize_controller=dfx.ConstantStepSize()):
    term = dfx.ODETerm(
        partial(bioreaction_sim_expanded,
                inputs=inputs, outputs=outputs,
                forward_rates=forward_rates.squeeze(), reverse_rates=reverse_rates.squeeze()
                )
    )
    return dfx.diffeqsolve(term, solver,
                           t0=t0, t1=t1, dt0=dt0,
                           y0=y0.squeeze(),
                           saveat=saveat, max_steps=max_steps,
                           stepsize_controller=stepsize_controller)

In [82]:
def f(y0, t0, t1, dt0,
      inputs, outputs, forward_rates, reverse_rates):
    s = bioreaction_sim_dfx_expanded(y0, t0, t1, dt0,
                                     inputs, outputs, forward_rates, reverse_rates)
    ys = s.ys[:, :np.argmax(s.ts >= np.inf), :]
    ts = s.ts[:, :np.argmax(s.ts >= np.inf), :]
    
    c = jnp.max(ys) - ys[-1]
    return c


n_species = 3
n_circuits = 5
n_reactions = int(factorial(n_species))
tot_species = n_species + n_reactions
y0 = np.array([200.0] * n_species + [0.0] * n_reactions)[None, :] * np.ones((n_circuits, 1))
t0, t1, dt0 = 0, 100, 0.001
inputs, outputs = np.zeros((n_reactions, tot_species)), np.zeros(
    (n_reactions, tot_species))
inds = flatten_listlike([[(i, j) for j in range(i, n_species)]
                        for i in range(n_species)])
for i in range(n_reactions):
    inputs[i, inds[i][0]] += 1
    inputs[i, inds[i][1]] += 1
    outputs[i, i+n_species] += 1
forward_rates, reverse_rates = np.random.rand(
    n_reactions), np.random.rand(n_circuits, n_reactions)

f_t = partial(f, t0=t0, t1=t1, dt0=dt0, inputs=inputs, outputs=outputs, forward_rates=forward_rates)
c = jax.vmap(f_t)(y0, reverse_rates=reverse_rates)

IndexError: Too many indices for array: 3 non-None/Ellipsis indices for dim 2.

In [72]:
c, grads = jax.value_and_grad(f_t)(y0, reverse_rates=reverse_rates)

l2_reg_alpha = 0.1
learning_rate = 0.001
warmup_epochs = 30
epochs = 1000
n_batches = 100
warmup_fn = optax.linear_schedule(
    init_value=0., end_value=learning_rate,
    transition_steps=warmup_epochs * n_batches)
cosine_epochs = max(epochs - warmup_epochs, 1)
cosine_fn = optax.cosine_decay_schedule(
    init_value=0.01, decay_steps=1000, alpha=l2_reg_alpha)
schedule_fn = optax.join_schedules(
    schedules=[warmup_fn, cosine_fn],
    boundaries=[warmup_epochs * n_batches])
optimiser = optax.sgd(learning_rate=schedule_fn)
optimiser_state = optimiser.init(reverse_rates)

updates, optimiser_state = optimiser.update(grads, optimiser_state)
params = optax.apply_updates(params, updates)


(Array(169.3576, dtype=float32),
 Array([5.3504956e-01, 3.1107688e-01, 2.6252866e-04,           nan,
                  nan,           nan,           nan,           nan,
                  nan], dtype=float32))