# Imports

In [148]:
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [149]:

import os
import sys
import numpy as np
import jax
import jax.numpy as jnp
import diffrax as dfx
from functools import partial
import dynesty
jax.config.update('jax_platform_name', 'cpu')


root_dir = '..'

# When the notebook is not running inside a package, put the parent directory
# on sys.path and use its basename as the package name. Presumably this is what
# lets the project packages imported in the next cell (synbio_morpher,
# bioreaction) be found.
if __package__ is None:
    module_path = os.path.abspath(root_dir)
    sys.path.append(module_path)
    __package__ = os.path.basename(module_path)

In [150]:
from synbio_morpher.utils.modelling.deterministic import bioreaction_sim_dfx_expanded
from synbio_morpher.utils.misc.helper import vanilla_return
from synbio_morpher.utils.results.analytics.timeseries import get_precision, get_sensitivity, get_peaks, generate_analytics
from bioreaction.simulation.manager import simulate_steady_states



# Simulation

In [151]:

def scale_rates(forward_rates, reverse_rates, cushioning: int = 4):
    """Choose an initial solver step size from the fastest rate constant.

    The step is ``1 / (cushioning * max_rate)``, where ``max_rate`` is the
    largest entry across both rate arrays. A larger ``cushioning`` gives a
    proportionally smaller step.
    """
    fastest_forward = np.asarray(forward_rates).max()
    fastest_reverse = np.asarray(reverse_rates).max()
    fastest_rate = np.maximum(fastest_forward, fastest_reverse)
    return 1 / (cushioning * fastest_rate)

In [152]:
def optimise_sp(s, p):
    """Objective combining sensitivity ``s`` and precision ``p``: ``s - 1 / p``.

    ``1 / p`` is treated as a baseline ("linear") sensitivity, and the result
    is how far ``s`` lies above it. Works element-wise on numpy arrays.
    """
    baseline_sensitivity = 1 / p
    return s - baseline_sensitivity

In [153]:
def one_step_de_sim_expanded(spec_conc, inputs, outputs, forward_rates, reverse_rates):
    """Mass-action time derivative of every species (one right-hand-side evaluation).

    For each reaction, the forward flux is the product of species amounts raised
    to their reactant stoichiometry (``inputs``), times the forward rate. The
    reverse flux is built the same way from ``outputs`` and the reverse rate.
    The net fluxes are projected onto species through ``outputs - inputs``.

    Args:
        spec_conc: amount of each species.
        inputs, outputs: reactant / product stoichiometry, one row per reaction
            (shaped (n_reactions, n_species) where this notebook builds them).
        forward_rates, reverse_rates: per-reaction rate constants.

    Returns:
        d(spec_conc)/dt as a jax array.
    """
    stoichiometry = outputs - inputs
    reactant_terms = jnp.prod(jnp.power(spec_conc, inputs), axis=1)
    product_terms = jnp.prod(jnp.power(spec_conc, outputs), axis=1)
    net_flux = reactant_terms * forward_rates - product_terms * reverse_rates
    return net_flux @ stoichiometry


def dummy_simfunc(
        y0, t0, t1, dt0,
        forward_rates,
        reverse_rates,
        inputs,
        outputs,
        threshold=0.01
    ):
    """Fixed-step forward-Euler simulation with an early stop on small changes.

    Repeatedly applies ``y[t] = y[t-1] + dt0 * f(y[t-1])``, where ``f`` is
    ``one_step_de_sim_expanded``, starting from ``y0`` at ``t0``.

    Args:
        y0: initial species amounts (squeezed to 1-D for sizing).
        t0, t1: start and end time; ``int((t1 - t0) / dt0)`` states are stored,
            but always at least one (the initial state).
        dt0: fixed Euler step.
        forward_rates, reverse_rates, inputs, outputs: passed through to
            ``one_step_de_sim_expanded``.
        threshold: early-stopping tolerance (see the note in the loop).

    Returns:
        ``(y, time)``: ``y`` has one row per stored state and ``time`` holds the
        matching time points, so both always have the same length.
    """
    n_species = len(y0.squeeze())
    # Keep at least the initial state, even when t1 - t0 < dt0. Without this,
    # y would be empty and y[0] = y0 would raise.
    steps = max(int((t1 - t0) / dt0), 1)
    # Build the time grid from `steps` so it has exactly as many points as `y`.
    # np.arange(t0, t1, dt0) can produce one extra point through float
    # rounding (e.g. t1=0.3, dt0=0.1 -> 2 steps but 3 arange points).
    time = t0 + dt0 * np.arange(steps)
    y = np.zeros((steps, n_species))
    y[0] = y0
    for t in range(1, steps):
        dydt = one_step_de_sim_expanded(
            spec_conc=y[t-1], inputs=inputs,
            outputs=outputs,
            forward_rates=forward_rates,
            reverse_rates=reverse_rates)
        y[t] = y[t-1] + dydt * dt0
        # NOTE(review): the sum runs over all entries. For dt0 > 0 this stops
        # when sum(|dy/dt|) < threshold * n_entries, not when every species has
        # individually settled; confirm that this is the intended criterion.
        # The state just computed (y[t]) is not included in the returned slice.
        if np.sum((np.abs(dydt) - threshold) * dt0) < 0:
            return y[:t], time[:t]
    return y, time

In [154]:
def compute_analytics(y, t, labels, signal_onehot):
    """Compute timeseries analytics for a simulated trajectory.

    A thin wrapper around ``generate_analytics``:
    - the first two axes of ``y`` are swapped before analysis (``R`` passes
      ``y`` with one row per time point);
    - ``signal_time=0`` is used;
    - no reference circuit data is supplied.

    Returns:
        Whatever ``generate_analytics`` returns (``R`` indexes it by string key).
    """
    data = np.swapaxes(y, 0, 1)
    return generate_analytics(
        data=data,
        time=t,
        labels=labels,
        signal_onehot=signal_onehot,
        signal_time=0,
        ref_circuit_data=None,
    )

In [155]:
def R(B11, B12, B13, B22, B23, B33):
    """Objective for a 3-RNA binding circuit: output sensitivity minus 1 / precision.

    The model has three unbound RNAs that pairwise form six complexes. The six
    arguments are used as the reverse rates of those reactions, and every
    forward rate is the fixed constant ``k``. The circuit runs in two phases:
    1. It is simulated from ``N0`` of each unbound RNA.
    2. From the final state of phase 1, the signal species (RNA_0) is reset to
       ``signal_target * N0`` and the circuit is simulated again.
    Sensitivity and precision of the output species (RNA_1) are then read from
    ``compute_analytics`` and combined with ``optimise_sp``.

    Args:
        B11, B12, B13, B22, B23, B33: reverse rates for the complexes
            RNA_0-0, RNA_0-1, RNA_0-2, RNA_1-1, RNA_1-2, RNA_2-2 respectively.

    Returns:
        ``optimise_sp(s, p)``, i.e. ``s - 1 / p`` for the output species.
    """
    species = ['RNA_0', 'RNA_1', 'RNA_2', 'RNA_0-0', 'RNA_0-1', 'RNA_0-2', 'RNA_1-1', 'RNA_1-2', 'RNA_2-2']
    signal_species = ['RNA_0']
    output_species = ['RNA_1']
    s_idxs = [species.index(s) for s in signal_species]
    output_idxs = [species.index(s) for s in output_species]
    # Integer mask over species: 1 for signal species, 0 elsewhere.
    signal_onehot = np.zeros(len(species), dtype=int)
    signal_onehot[s_idxs] = 1

    signal_target = 2      # signal species is reset to signal_target * N0
    k = 0.00150958097      # forward rate used for every reaction
    N0 = 200               # initial amount of each unbound RNA

    # Amounts: unbound RNAs start at N0, complexes at 0
    y00 = np.array([[N0, N0, N0, 0, 0, 0, 0, 0, 0]])

    # Reactions: row i of `inputs` holds the reactant stoichiometry and row i
    # of `outputs` the product of reaction i (RNA_a + RNA_b -> RNA_a-b). Rows
    # are ordered to match the reverse rates below.
    inputs = np.array([
        [2, 0, 0, 0, 0, 0, 0, 0, 0],
        [1, 1, 0, 0, 0, 0, 0, 0, 0],
        [1, 0, 1, 0, 0, 0, 0, 0, 0],
        [0, 2, 0, 0, 0, 0, 0, 0, 0],
        [0, 1, 1, 0, 0, 0, 0, 0, 0],
        [0, 0, 2, 0, 0, 0, 0, 0, 0],
    ])
    outputs = np.array([
        [0, 0, 0, 1, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 1, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 1, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 1, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 1, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 1],
    ])

    # Rates
    reverse_rates = np.array([[B11, B12, B13, B22, B23, B33]])
    forward_rates = np.ones_like(reverse_rates) * k

    # Sim params; the initial step is sized from the fastest rate constant.
    t0 = 0
    t1 = 100
    dt0 = scale_rates(forward_rates, reverse_rates, cushioning=4)
    max_steps = 16**4 * 10
    print('\n\nInput N:', y00)
    print('Input B:', reverse_rates)
    # NOTE(review): a new partial is jitted on every call of R, so this is
    # likely recompiled for each likelihood evaluation. dt0 depends on the
    # reverse rates, which makes caching it non-trivial.
    sim_func = jax.jit(partial(bioreaction_sim_dfx_expanded,
        t0=t0, t1=t1, dt0=dt0,
        signal=vanilla_return, signal_onehot=1,
        forward_rates=forward_rates,
        inputs=inputs,
        outputs=outputs,
        solver=dfx.Tsit5(),
        saveat=dfx.SaveAt(
            ts=np.linspace(t0, t1, 500)),
        max_steps=max_steps
        ))

    # Phase 1: simulate the unsignalled circuit and keep its final state.
    y0, t = simulate_steady_states(y0=y00, total_time=t1-t0, sim_func=sim_func, t0=t0, t1=t1, threshold=0.1, reverse_rates=reverse_rates)
    y0 = np.array(y0.squeeze()[-1, :]).reshape(y00.shape)

    # Phase 2 (signal): non-signal species keep their phase-1 amounts, and the
    # signal species is set to signal_target times its initial amount.
    y0s = y0 * ((signal_onehot == 0) * 1) + y00 * signal_target * signal_onehot
    y, t = simulate_steady_states(y0s, total_time=t1-t0, sim_func=sim_func, t0=t0, t1=t1, threshold=0.1, reverse_rates=reverse_rates)
    # Prepend the pre-signal state and drop the last saved point, so the number
    # of rows is unchanged.
    y = np.concatenate([y0, y.squeeze()[:-1, :]], axis=0)
    y1 = np.array(y[-1, :])

    print('Output:', y1)

    analytics = compute_analytics(y, t, labels=np.arange(y.shape[-1]), signal_onehot=signal_onehot)

    # Analytics are keyed by the index of the species they are taken w.r.t.
    signal_idx = s_idxs[0]
    s = analytics[f'sensitivity_wrt_species-{signal_idx}']
    p = analytics[f'precision_wrt_species-{signal_idx}']
    print(f'Sensitivity {signal_idx}:', s[tuple(s_idxs)])
    print(f'Precision {signal_idx}:', p[tuple(s_idxs)])

    r = optimise_sp(
        s=s.squeeze()[tuple(output_idxs)], p=p.squeeze()[tuple(output_idxs)]
    )

    return r

### Check that R is working

In [156]:

# Reverse rates (B11, B12, B13, B22, B23, B33) of a reference circuit recorded
# as having minimum sensitivity:
#   circuit toy_mRNA_circuit_11232, RNA_2_m15-0 -> expected s = 2.442697e-07
#
# TODO: record the expected precision p for the max-sensitivity reference
#   [899.999500, 899.999500, 0.090935, 899.999500, 0.000126, 899.999500]
#
# NOTE(review): the "Sensitivity 0" that R prints is for the signal species
# itself, so it is not directly comparable with the expected s above.
dummy_B = np.asarray(
    [0.000164, 899.999500, 0.000114, 899.999500, 899.9995, 899.9995],
    dtype=np.float64,
)

r = R(*dummy_B)



Input N: [[200 200 200   0   0   0   0   0   0]]
Input B: [[1.640000e-04 8.999995e+02 1.140000e-04 8.999995e+02 8.999995e+02
  8.999995e+02]]
Done:  0:00:01.609192
Done:  0:00:01.617003
Output: [2.6023965e+00 1.9975958e+02 1.7810986e+01 2.0855815e+02 8.7195769e-04
 1.8209904e+02 6.6931315e-02 5.9677567e-03 5.3209940e-04]
Sensitivity 0: [1.]
Precision 0: [165.4514]


# Log likelihood

In [157]:
ndim = 6  # number of sampled parameters: the six reverse rates passed to R


def loglike(B):
    """Log-likelihood for the nested sampler: ``-1 / (R(*B) + eps)``.

    For ``R(*B) > -eps`` this increases monotonically with the objective, so
    parameter sets with a larger ``R`` get a higher likelihood.

    Args:
        B: sequence of the six reverse rates (B11, B12, B13, B22, B23, B33).

    Returns:
        The log-likelihood. It is ``-inf`` when ``R(*B) + eps`` is non-finite
        or not positive.
    """
    eps = 0.0001
    shifted = R(*B) + eps
    # -1 / (r + eps) has a pole at r = -eps, and below it the value becomes
    # large and *positive*, which would reward the worst objectives. Give that
    # region, and NaN/inf results from failed simulations, zero likelihood.
    if not np.isfinite(shifted) or shifted <= 0:
        L = -np.inf
    else:
        L = -1 / shifted

    print(L)

    return L


def ptform(u, R_min=0.0001, R_max=1):
    """Prior transform: map unit-cube coordinates to a uniform prior on the rates.

    Each coordinate of ``u`` in [0, 1] is rescaled linearly to
    [R_min, R_max].

    Args:
        u: point in the unit cube (array-like, one entry per parameter).
        R_min, R_max: bounds of the uniform prior on every reverse rate.

    Returns:
        numpy array with the same shape as ``u``.
    """
    u = np.asarray(u)
    # Scale by the interval *width*. Scaling by R_max alone would shift the
    # upper bound to R_max + R_min.
    return R_min + (R_max - R_min) * u

# Sampling

In [158]:

import matplotlib.pyplot as plt
from dynesty import plotting as dyplot

# Nested sampling over the six reverse rates, using the log-likelihood and
# prior transform defined above.
sampler = dynesty.NestedSampler(loglike, ptform, ndim)
sampler.run_nested()
sresults = sampler.results

# cornerpoints draws an (ndim - 1) x (ndim - 1) grid of pairwise scatter plots,
# so the axes it receives must have exactly that shape. No ground-truth rates
# are known here, so no `truths` lines are drawn.
fig, axes = plt.subplots(ndim - 1, ndim - 1, figsize=(20, 20))
fg, ax = dyplot.cornerpoints(sresults, cmap='plasma', kde=False,
                             fig=(fig, axes))



Input N: [[200 200 200   0   0   0   0   0   0]]
Input B: [[0.87067947 0.45007486 0.07680774 0.27877479 0.88274225 0.33653068]]
Done:  0:00:00.582056
Done:  0:00:00.515412
Output: [197.17014    77.86359    37.4245     67.403175   51.492893  145.02681
  32.830044    4.983254    6.2826653]
Sensitivity 0: [1.]
Precision 0: [2.6140387]
-109.03312


Input N: [[200 200 200   0   0   0   0   0   0]]
Input B: [[0.48784045 0.3240048  0.62954757 0.72352542 0.26148844 0.9335126 ]]
Done:  0:00:00.516630
Done:  0:00:00.512430
Output: [186.99042   73.78293   92.075775 108.197556  64.2806    41.285015
  11.358325  39.21981   13.709668]
Sensitivity 0: [1.]
Precision 0: [3.2202015]
-1435.7252


Input N: [[200 200 200   0   0   0   0   0   0]]
Input B: [[0.22794678 0.5296043  0.14960344 0.26403175 0.51293016 0.84148191]]
Done:  0:00:00.514960
Done:  0:00:00.525805
Output: [140.89642   79.55304   68.89348  131.46893   31.949316  97.94732
  36.183754  16.129957   8.514672]
Sensitivity 0: [1.]
Precision 