# Imports

In [86]:
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [87]:

import os
import sys
import numpy as np
import jax
import jax.numpy as jnp
import diffrax as dfx
from functools import partial
import dynesty
jax.config.update('jax_platform_name', 'cpu')


if __package__ is None:

    module_path = os.path.abspath(os.path.join('..'))
    sys.path.append(module_path)

    __package__ = os.path.basename(module_path)


root_dir = '..'

In [88]:
from synbio_morpher.utils.modelling.deterministic import bioreaction_sim_dfx_expanded
from synbio_morpher.utils.misc.helper import vanilla_return
from synbio_morpher.utils.results.analytics.timeseries import get_precision, get_sensitivity, get_peaks, generate_analytics
from bioreaction.simulation.manager import simulate_steady_states



# Simulation

In [89]:

def scale_rates(forward_rates, reverse_rates):
    rate_max = np.max([np.max(np.asarray(forward_rates)),
                        np.max(np.asarray(reverse_rates))])

    dt0 = 1 / (2 * rate_max)
    return dt0

In [90]:
def optimise_sp(s, p):
    s_lin = 1 / p
    return s - s_lin

In [91]:
def one_step_de_sim_expanded(spec_conc, inputs, outputs, forward_rates, reverse_rates):
    concentration_factors_in = jnp.prod(
        jnp.power(spec_conc, (inputs)), axis=1)
    concentration_factors_out = jnp.prod(
        jnp.power(spec_conc, (outputs)), axis=1)
    forward_delta = concentration_factors_in * forward_rates
    reverse_delta = concentration_factors_out * reverse_rates
    return (forward_delta - reverse_delta) @ (outputs - inputs)


def dummy_simfunc(
        y0, t0, t1, dt0,
        forward_rates,
        reverse_rates,
        inputs,
        outputs,
        threshold = 0.01
        # max_steps,
        # signal=vanilla_return, signal_onehot=1
    ):
    steps = int((t1 - t0) / dt0)
    y = np.zeros((steps, len(y0.squeeze())))
    time = np.arange(t0, t1, dt0)
    y[0] = y0
    for t in range(1, steps):
        yt = one_step_de_sim_expanded(
            spec_conc=y[t-1], inputs=inputs,
            outputs=outputs,
            forward_rates=forward_rates,
            reverse_rates=reverse_rates)
        y[t] = y[t-1] + yt * dt0
        if np.sum((np.abs(yt) - threshold ) * dt0) < 0:
            return y[:t], time[:t]
    return y, time

In [92]:
def compute_analytics(y, t, labels, signal_onehot):
    y = np.swapaxes(y, 0, 1)
    
    analytics_func = partial(
        generate_analytics, time=t, labels=labels,
        signal_onehot=signal_onehot, signal_time=0,
        ref_circuit_data=None)
    return analytics_func(data=y, time=t, labels=labels)

In [93]:
def R(B11, B12, B13, B22, B23, B33):
    unbound_species = ['RNA_0', 'RNA_1', 'RNA_2']
    species = ['RNA_0', 'RNA_1', 'RNA_2', 'RNA_0-0', 'RNA_0-1', 'RNA_0-2', 'RNA_1-1', 'RNA_1-2', 'RNA_2-2']
    signal_species = ['RNA_0']
    output_species = ['RNA_1']
    s_idxs = [species.index(s) for s in signal_species]
    output_idxs = [species.index(s) for s in output_species]
    signal_onehot = np.array([1 if s in [species.index(ss) for ss in signal_species] else 0 for s in np.arange(len(species))])
    
    signal_target = 2
    k = 0.00150958097
    N0 = 200
    
    # Amounts
    y00 = np.array([[N0, N0, N0, 0, 0, 0, 0, 0, 0]])
    
    # Reactions
    inputs = np.array([
        [2, 0, 0, 0, 0, 0, 0, 0, 0],
        [1, 1, 0, 0, 0, 0, 0, 0, 0],
        [1, 0, 1, 0, 0, 0, 0, 0, 0],
        [0, 2, 0, 0, 0, 0, 0, 0, 0],
        [0, 1, 1, 0, 0, 0, 0, 0, 0],
        [0, 0, 2, 0, 0, 0, 0, 0, 0],
    ])
    outputs = np.array([
        [0, 0, 0, 1, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 1, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 1, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 1, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 1, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 1],
    ])
    
    # Rates
    reverse_rates = np.array([[B11, B12, B13, B22, B23, B33]])
    forward_rates = np.ones_like(reverse_rates) * k
    
    # Sim params
    t0 = 0
    t1 = 100
    dt0 = scale_rates(forward_rates, reverse_rates)
    max_steps = 16**4 * 10
    print('\n\nInput N:', y00)
    print('Input B:', reverse_rates)
    sim_func = jax.jit(partial(bioreaction_sim_dfx_expanded,
        t0=t0, t1=t1, dt0=dt0,
        signal=vanilla_return, signal_onehot=1,
        forward_rates=forward_rates,
        inputs=inputs,
        outputs=outputs,
        solver=dfx.Tsit5(),
        saveat=dfx.SaveAt(
            ts=np.linspace(t0, t1, 500)),  # int(np.min([500, self.t1-self.t0]))))
        max_steps=max_steps
        ))
    # sim_func = partial(dummy_simfunc,
    #     t0=t0, t1=t1, dt0=dt0,
    #     forward_rates=forward_rates,
    #     inputs=inputs,
    #     outputs=outputs)
    
    y0, t = simulate_steady_states(y0=y00, total_time=t1-t0, sim_func=sim_func, t0=t0, t1=t1, threshold=0.1, reverse_rates=reverse_rates)
    # y0, t = sim_func(y0=y00, reverse_rates=reverse_rates)
    y0 = np.array(y0.squeeze()[-1, :]).reshape(y00.shape)
    
    # Signal
    
    y0s = y0 * signal_onehot * signal_target + y0 * (signal_onehot == 0)
    y, t = simulate_steady_states(y0, total_time=t1-t0, sim_func=sim_func, t0=t0, t1=t1, threshold=0.1, reverse_rates=reverse_rates)
    # y, t = sim_func(y0=y0s, reverse_rates=reverse_rates)
    y = np.concatenate([y0s, y.squeeze()[:-1, :]], axis=0)
    y1 = np.array(y[-1, :])
        
    print('Output:', y1)
    
    analytics = compute_analytics(y, t, labels=np.arange(y.shape[-1]), signal_onehot=signal_onehot)
    
    # Calculate R
    # peaks = get_peaks(initial_steady_states=y0.squeeze(), final_steady_states=y1.squeeze(),
    #                   maxa=y.max(axis=0), mina=y.min(axis=0)) 
    
    # s = get_sensitivity(
    #     signal_idx=s_idxs[0], peaks=peaks, starting_states=y0
    # )
    # p = get_precision(
    #         starting_states=y0,
    #         steady_states=y1,
    #         signal_0=y0[s_idxs[0]],
    #         signal_1=y1[s_idxs[0]].squeeze())
    
    s = analytics['sensitivity_wrt_species-0']
    p = analytics['precision_wrt_species-0']
    print(f'Sensitivity {s_idxs[0]}:', s)
    print(f'Precision {s_idxs[0]}:', p)
    
    r = optimise_sp(
        s=s.squeeze()[tuple(output_idxs)], p=p.squeeze()[tuple(output_idxs)]
    )
    
    return r, analytics

### Check that R is working

In [94]:

dummy_B = np.array([0.000164, 899.999500, 0.000114, 899.999500, 899.9995, 899.9995])

# Min sensitivity:
# toy_mRNA_circuit_11232	RNA_2_m15-0
# For [0.000164, 899.999500, 0.000114, 899.999500, 899.9995, 899.9995] :
# s = 2.442697e-07


# Max sensitivity:
# For [899.999500, 899.999500, 0.090935, 899.999500, 0.000126, 899.999500], it should be
# p = 


r, analytics = R(*dummy_B)

analytics



Input N: [[200 200 200   0   0   0   0   0   0]]
Input B: [[1.640000e-04 8.999995e+02 1.140000e-04 8.999995e+02 8.999995e+02
  8.999995e+02]]
Done:  0:00:01.097894
Done:  0:00:01.072107
Output: [1.87220320e-01 1.99802567e+02 9.74028778e+01 4.86558609e+01
 6.27434274e-05 1.02530769e+02 6.69601336e-02 3.26427743e-02
 1.59132145e-02]
Sensitivity 0: [[1.        ]
 [0.        ]
 [0.02770725]
 [0.02773285]
 [0.02289429]
 [0.02705233]
 [0.        ]
 [0.02770715]
 [0.05503044]]
Precision 0: [[ 1.0115796]
 [       inf]
 [36.091633 ]
 [36.058323 ]
 [43.67902  ]
 [36.965397 ]
 [       inf]
 [36.091763 ]
 [18.17176  ]]


{'first_derivative': Array([[-1.8510143e-01, -9.2548713e-02,  4.0978193e-06, ...,
          4.3138862e-06,  4.3064356e-06,  4.3213367e-06],
        [ 0.0000000e+00,  0.0000000e+00,  0.0000000e+00, ...,
          0.0000000e+00,  0.0000000e+00,  0.0000000e+00],
        [ 0.0000000e+00, -1.3771057e-03, -2.7542114e-03, ...,
         -2.7427673e-03, -2.7465820e-03, -2.7465820e-03],
        ...,
        [ 0.0000000e+00,  0.0000000e+00,  0.0000000e+00, ...,
          0.0000000e+00,  0.0000000e+00,  0.0000000e+00],
        [ 0.0000000e+00, -4.6007335e-07, -9.2014670e-07, ...,
         -9.2014670e-07, -9.2014670e-07, -9.2014670e-07],
        [ 0.0000000e+00, -4.5541674e-07, -9.1176480e-07, ...,
         -8.9779496e-07, -8.9593232e-07, -8.9406967e-07]], dtype=float32),
 'initial_steady_states': Array([[3.7020287e-01],
        [1.9980257e+02],
        [9.8771217e+01],
        [4.9340031e+01],
        [6.2033323e-05],
        [1.0116243e+02],
        [6.6960134e-02],
        [3.3101346e-02],
     

In [95]:
asdfg

NameError: name 'asdfg' is not defined

# Log likelihood

In [None]:
ndim = 6

def loglike(B):
    
    L = - 1 / (R(*B) + 0.0001)
    
    print(L)
    
    return L


def ptform(u):
    R_max = 800
    x = R_max * u + 100
    
    return x

# Sampling

In [None]:

sampler = dynesty.NestedSampler(loglike, ptform, ndim)
sampler.run_nested()
sresults = sampler.results
from dynesty import plotting as dyplot
# initialize figure
import matplotlib.pyplot as plt
fig, axes = plt.subplots(6, 13, figsize=(25, 10))
axes = axes.reshape((6, 13))  # reshape axes

# add white space
[a.set_frame_on(False) for a in axes[:, 6]]
[a.set_xticks([]) for a in axes[:, 6]]
[a.set_yticks([]) for a in axes[:, 6]]

# plot initial run (res1; left)
fg, ax = dyplot.cornerpoints(sresults, cmap='plasma', truths=np.zeros(ndim),
                             kde=False, fig=(fig, axes[:, :2]))



Input N: [[100 100 100   0   0   0   0   0   0]]
Input B: [[231.57366282 179.11849832 516.61303364 498.78202482 248.64310843
  894.54357911]]


Done:  0:00:38.349748
Done:  0:00:38.290739
Output: [9.97572556e+01 9.97953262e+01 9.98767319e+01 6.48718253e-02
 8.39016438e-02 2.91138701e-02 3.01415827e-02 6.05137907e-02
 1.68338548e-02]
Sensitivity: [[1. 1. 1. 1. 1. 1. 1. 1. 1.]]
Precision: [[nan inf inf inf inf inf inf inf inf]]
-0.9999


Input N: [[100 100 100   0   0   0   0   0   0]]
Input B: [[267.27751787 547.53853283 172.8296925  529.01423898 311.04114566
  361.75569528]]
Done:  0:01:10.805630
Exception while calling loglikelihood function:
  params: [267.27751787 547.53853283 172.8296925  529.01423898 311.04114566
 361.75569528]
  args: []
  kwargs: {}
  exception:


Traceback (most recent call last):
  File "/usr/local/lib/python3.10/dist-packages/dynesty/dynesty.py", line 910, in __call__
    return self.func(np.asarray(x).copy(), *self.args, **self.kwargs)
  File "/tmp/ipykernel_20484/820544128.py", line 5, in loglike
    L = - 1 / (R(*B) + 0.0001)
  File "/tmp/ipykernel_20484/3261121794.py", line 71, in R
    y, t = simulate_steady_states(y0, total_time=t1-t0, sim_func=sim_func, t0=t0, t1=t1, threshold=0.1, reverse_rates=reverse_rates)
  File "/usr/lib/install_requirements/src/bioreaction/src/bioreaction/simulation/manager.py", line 45, in simulate_steady_states
    x_res = sim_func(y00, **sim_kwargs)
  File "/usr/local/lib/python3.10/dist-packages/jax/_src/traceback_util.py", line 166, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/jax/_src/pjit.py", line 253, in cache_miss
    outs, out_flat, out_tree, args_flat, jaxpr = _python_pjit_helper(
  File "/usr/local/lib/python3.10/

KeyboardInterrupt: 

: 