In [1]:
%load_ext autoreload
%autoreload 2

# Imports

In [2]:
import numpy as np
import jax
import jax.numpy as jnp
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
import diffrax as dfx
from typing import List

from functools import partial
import os
import sys

jax.config.update('jax_platform_name', 'gpu')

if __package__ is None:

    module_path = os.path.abspath(os.path.join('..'))
    sys.path.append(module_path)

    __package__ = os.path.basename(module_path)


np.random.seed(0)
jax.devices()

[cuda(id=0)]

In [3]:
from synbio_morpher.srv.parameter_prediction.simulator import make_piecewise_stepcontrol
from synbio_morpher.utils.misc.type_handling import flatten_listlike
from synbio_morpher.utils.modelling.physical import eqconstant_to_rates, equilibrium_constant_reparameterisation
from synbio_morpher.utils.modelling.deterministic import bioreaction_sim_dfx_expanded
from synbio_morpher.utils.modelling.solvers import get_diffrax_solver, make_stepsize_controller
from synbio_morpher.utils.results.analytics.timeseries import calculate_adaptation, compute_peaks, compute_adaptability_full


# Set up test circuits

In [4]:
def make_species_bound(species_unbound):
    return sorted(set(flatten_listlike([['-'.join(sorted([x, y])) for x in species_unbound] for y in species_unbound])))


# RNA circuit settings
species_unbound = ['RNA_0', 'RNA_1', 'RNA_2']
species_bound = make_species_bound(species_unbound)
species = species_unbound + species_bound
species_signal = ['RNA_0']
species_output = ['RNA_2']
species_nonsignal = [s for s in species_unbound if s not in species_signal]
idxs_signal = np.array([species.index(s) for s in species_signal])
idxs_output = np.array([species.index(s) for s in species_output])
idxs_unbound = np.array([species.index(s) for s in species_unbound])
idxs_bound = np.array([species.index(s) for s in species_bound])
signal_onehot = np.array([1 if s in idxs_signal else 0 for s in np.arange(len(species))])

# Dynamic Simulation parameters
k_a = 0.00150958097
signal_target = 2
t0 = 0
t1 = 200
ts = np.linspace(t0, t1, 500)
dt0 = 0.0005555558569638981
dt1_factor = 5
dt1 = dt0 * dt1_factor
max_steps = 16**4 * 10
use_sensitivity_func1 = False
sim_method = 'Dopri5'
stepsize_controller = 'adaptive'

# MC parameters
total_steps = 10
total_samples = 100
choose_max = 20
mutation_scale = 0.1
n_circuits_display = 30
N0 = 200
y00 = np.array([[N0] * len(species_unbound) + [0] * len(species_bound)]).astype(np.float32)
y00 = np.repeat(y00, repeats=total_samples, axis=0)

# Reactions
energies = np.random.rand(total_samples, len(species_unbound), len(species_unbound))
energies = np.interp(energies, (energies.min(), energies.max()), (-25, 0))
energies[np.tril_indices(len(species_unbound))] = energies[np.triu_indices(len(species_unbound))]
eqconstants = jax.vmap(equilibrium_constant_reparameterisation)(energies, y00[:, idxs_unbound])
forward_rates, reverse_rates = eqconstant_to_rates(eqconstants, k_a)
forward_rates = np.array(list(map(lambda r: r[np.triu_indices(len(species_unbound))], forward_rates)))
reverse_rates = np.array(list(map(lambda r: r[np.triu_indices(len(species_unbound))], reverse_rates)))

inputs = np.array([
    [2, 0, 0, 0, 0, 0, 0, 0, 0],
    [1, 1, 0, 0, 0, 0, 0, 0, 0],
    [1, 0, 1, 0, 0, 0, 0, 0, 0],
    [0, 2, 0, 0, 0, 0, 0, 0, 0],
    [0, 1, 1, 0, 0, 0, 0, 0, 0],
    [0, 0, 2, 0, 0, 0, 0, 0, 0],
], dtype=np.float64)
outputs = np.array([
    [0, 0, 0, 1, 0, 0, 0, 0, 0],
    [0, 0, 0, 0, 1, 0, 0, 0, 0],
    [0, 0, 0, 0, 0, 1, 0, 0, 0],
    [0, 0, 0, 0, 0, 0, 1, 0, 0],
    [0, 0, 0, 0, 0, 0, 0, 1, 0],
    [0, 0, 0, 0, 0, 0, 0, 0, 1],
], dtype=np.float64)

# Initialise simulations

In [5]:
sim_func = jax.jit(jax.vmap(
    partial(bioreaction_sim_dfx_expanded,
            t0=t0, t1=t1, dt0=dt0,
            forward_rates=forward_rates,
            inputs=inputs,
            outputs=outputs,
            solver=get_diffrax_solver(
                sim_method),
            saveat=dfx.SaveAt(
                ts=jnp.linspace(t0, t1, 500)),  # int(np.min([500, t1-t0]))))
            stepsize_controller=make_stepsize_controller(t0, t1, dt0, dt1,
                                                         choice=stepsize_controller)
            )))
# sol_steady_states = jax.vmap(bioreaction_sim_dfx_expanded)(y00, reverse_rates)
# y01 = np.array(sol_steady_states.ys[:, -1])
# y01[:, np.array(idxs_signal)] = y01[:, np.array(idxs_signal)] * signal_target
# sol_signal = jax.vmap(bioreaction_sim_dfx_expanded)(y01, reverse_rates)

In [6]:
# adaptability, sensitivity, precision = jax.vmap(partial(compute_adaptability_full, idx_sig=idxs_signal[0], use_sensitivity_func1=use_sensitivity_func1))(
#     sol_steady_states.ys, sol_signal.ys)
# sensitivity = np.array(sensitivity)
# precision = np.array(precision)

# Monte Carlo iterations

In [7]:
def choose_next(params, sol, idxs_signal, use_sensitivity_func1: bool, choose_max: int, n_samples_per_parent: int, mutation_scale=0.1):
    """ Choose circuits with highest adaptability """

    sol_steady_states, sol_signal = sol
    adaptability, sensitivity, precision = compute_adaptability_full(
        sol_steady_states.ys, sol_signal.ys, idxs_signal, use_sensitivity_func1)
    idxs_next = jnp.argsort(adaptability)[-choose_max:]
    # summary_data = {}
    # summary_data['adaptability'] = adaptability
    return params[idxs_next], adaptability


def mutate(parents: jnp.ndarray, n_samples_per_parent, mutation_scale):
    # Generate mutated samples from each parent
    mutated = jax.tree_util.tree_map(
        lambda x: x + mutation_scale * np.random.randn(n_samples_per_parent, *x.shape), parents)
    return mutated


def simulate(y00, reverse_rates):
    sol_steady_states = jax.vmap(
        bioreaction_sim_dfx_expanded)(y00, reverse_rates)
    y01 = np.array(sol_steady_states.ys[:, -1])
    y01[:, np.array(idxs_signal)] = y01[:, np.array(
        idxs_signal)] * signal_target
    sol_signal = jax.vmap(bioreaction_sim_dfx_expanded)(y01, reverse_rates)
    return sol_steady_states, sol_signal

In [8]:
gen_y = reverse_rates
n_samples_per_parent = total_samples//choose_max
params_all = np.zeros((total_steps, total_samples, *reverse_rates.shape))
adaptability_all = np.zeros((total_steps, total_samples))

for step in range(total_steps):

    print(f'\n\nStarting step {step+1} out of {total_steps}\n\n')

    sol = simulate(y00, gen_y)
    next_starting, adaptability = choose_next(params=gen_y, sol=sol, idxs_signal=idxs_signal,
                                              use_sensitivity_func1=use_sensitivity_func1, choose_max=choose_max, n_samples_per_parent=n_samples_per_parent)
    gen_z = mutate(next_starting, n_samples_per_parent, mutation_scale)
    params_all[step] = gen_z
    adaptability_all[step] = adaptability



Starting step 1 out of 10




TypeError: bioreaction_sim_dfx_expanded() missing 6 required positional arguments: 't1', 'dt0', 'inputs', 'outputs', 'forward_rates', and 'reverse_rates'