# Imports

In [1]:

import os
import sys
import numpy as np
import jax
import jax.numpy as jnp
import diffrax as dfx
import dynesty
from functools import partial
# jax.config.update('jax_platform_name', 'gpu')


if __package__ is None:

    module_path = os.path.abspath(os.path.join('..'))
    sys.path.append(module_path)

    __package__ = os.path.basename(module_path)


root_dir = '..'

In [2]:
from synbio_morpher.utils.modelling.deterministic import bioreaction_sim_dfx_expanded
from synbio_morpher.utils.misc.helper import vanilla_return
from synbio_morpher.utils.results.analytics.timeseries import get_precision, get_sensitivity, get_peaks
from bioreaction.simulation.manager import simulate_steady_states



# Simulation

In [3]:

def scale_rates(forward_rates, reverse_rates):
    rate_max = np.max([np.max(np.asarray(forward_rates)),
                        np.max(np.asarray(reverse_rates))])

    dt0 = 1 / (2 * rate_max)
    return dt0

In [4]:
16**4 * 10

655360

In [5]:
def optimise_sp(s, p):
    s_lin = 1 / p
    return s - s_lin

In [6]:


def one_step_de_sim_expanded(spec_conc, inputs, outputs, forward_rates, reverse_rates):
    concentration_factors_in = jnp.prod(
        jnp.power(spec_conc, (inputs)), axis=1)
    concentration_factors_out = jnp.prod(
        jnp.power(spec_conc, (outputs)), axis=1)
    forward_delta = concentration_factors_in * forward_rates
    reverse_delta = concentration_factors_out * reverse_rates
    return (forward_delta - reverse_delta) @ (outputs - inputs)


def dummy_simfunc(
        y0, t0, t1, dt0,
        forward_rates,
        reverse_rates,
        inputs,
        outputs,
        threshold = 0.1
        # max_steps,
        # signal=vanilla_return, signal_onehot=1
    ):
    steps = int((t1 - t0) / dt0)
    y = np.zeros((steps, len(y0.squeeze())))
    time = np.linspace(t0, t1, dt0)
    y[0] = y0
    for t in range(1, steps):
        yt = one_step_de_sim_expanded(
            spec_conc=y[t-1], inputs=inputs,
            outputs=outputs,
            forward_rates=forward_rates,
            reverse_rates=reverse_rates)
        y[t] += yt / dt0
        if yt / dt0 < threshold:
            return y[:t], time[:t]
    return y, time

In [11]:
def R(B11, B12, B13, B22, B23, B33):
    unbound_species = ['RNA_0', 'RNA_1', 'RNA_2']
    species = ['RNA_0', 'RNA_1', 'RNA_2', 'RNA_0-0', 'RNA_0-1', 'RNA_0-2', 'RNA_1-1', 'RNA_1-2', 'RNA_2-2']
    signal_species = 'RNA_0'
    output_species = 'RNA_1'
    s_idx = species.index(signal_species)
    output_idx = species.index(output_species)
    
    k = 0.001
    N0 = 100
    
    # Amounts
    y0 = np.array([[N0, N0, N0, 0, 0, 0, 0, 0, 0]])
    
    # Reactions
    inputs = np.array([
        [2, 0, 0, 0, 0, 0, 0, 0, 0],
        [1, 1, 0, 0, 0, 0, 0, 0, 0],
        [1, 0, 1, 0, 0, 0, 0, 0, 0],
        [0, 2, 0, 0, 0, 0, 0, 0, 0],
        [0, 1, 1, 0, 0, 0, 0, 0, 0],
        [0, 0, 2, 0, 0, 0, 0, 0, 0],
    ])
    outputs = np.array([
        [0, 0, 0, 1, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 1, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 1, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 1, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 1, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 1],
    ])
    
    # Rates
    reverse_rates = np.array([[B11, B12, B13, B22, B23, B33]])
    forward_rates = np.ones_like(reverse_rates) * k
    
    # Sim params
    t0 = 0
    t1 = 100
    dt0 = scale_rates(forward_rates, reverse_rates)
    max_steps = 16**4 * 10
    print('Input N:', y0)
    print('Input B:', reverse_rates)
    print('Time steps expected:', (t1 - t0) / (dt0))
    print('Max time steps set:', max_steps)
    # sim_func = partial(bioreaction_sim_dfx_expanded,
    #     t0=t0, t1=t1, dt0=dt0,
    #     signal=vanilla_return, signal_onehot=1,
    #     forward_rates=forward_rates,
    #     inputs=inputs,
    #     outputs=outputs,
    #     solver=dfx.Tsit5(),
    #     saveat=dfx.SaveAt(
    #         # t0=True, t1=True),
    #         ts=np.linspace(t0, t1, 500)),  # int(np.min([500, self.t1-self.t0]))))
    #         # ts=np.interp(np.logspace(0, 2, num=500), [1, np.power(10, 2)], [self.t0, self.t1])),  # Save more points early in the sim
    #     # stepsize_controller=self.make_stepsize_controller(choice='piecewise')
    #     max_steps=max_steps
    #     )
    sim_func = partial(dummy_simfunc,
        t0=t0, t1=t1, dt0=dt0,
        forward_rates=forward_rates,
        inputs=inputs,
        outputs=outputs)
    
    # y, t = simulate_steady_states(y0, total_time=t1-t0, sim_func=sim_func, t0=t0, t1=t1, threshold=0.1, reverse_rates=reverse_rates)
    y, t = sim_func(y0=y0, reverse_rates=reverse_rates)
    y1 = y[-1, :]
    
    print('\n\nOutput:', y)
    
    # Calculate R
    peaks = get_peaks(initial_steady_states=y0, final_steady_states=y1,
                      maxa=y1.max(axis=1), mina=y1.min(axis=1)) 
    
    s = get_sensitivity(
        signal_idx=s_idx, peaks=peaks, starting_states=y0
    )[output_idx], 
    p = get_precision(
            starting_states=y0,
            steady_states=y1,
            signal_0=y0[s_idx],
            signal_1=y1[s_idx].squeeze())
    print('Peaks:', peaks)
    print('Sensitivity:', s)
    print('Precision:', p)
    
    r = optimise_sp(
        s=s, p=p
    )
    
    return r

### Check that R is working

In [12]:
dummy_B = np.array([0.1, 0.1, 800, 30, 0.1, 0.1])
r = R(*dummy_B)

r

Input N: [[100 100 100   0   0   0   0   0   0]]
Input B: [[1.e-01 1.e-01 8.e+02 3.e+01 1.e-01 1.e-01]]
Time steps expected: 160000.0
Max time steps set: 655360


In [None]:
dfghj

# Log likelihood

In [None]:
ndim = 6

def loglike(B):
    
    L = - 1 / R(*B)
    
    print(L)
    
    return L


def ptform(u):
    R_max = 800
    x = R_max * u + 100
    
    return x

# Sampling

In [None]:

sampler = dynesty.NestedSampler(loglike, ptform, ndim)
sampler.run_nested()
sresults = sampler.results
from dynesty import plotting as dyplot
# initialize figure
import matplotlib.pyplot as plt
fig, axes = plt.subplots(6, 13, figsize=(25, 10))
axes = axes.reshape((6, 13))  # reshape axes

# add white space
[a.set_frame_on(False) for a in axes[:, 6]]
[a.set_xticks([]) for a in axes[:, 6]]
[a.set_yticks([]) for a in axes[:, 6]]

# plot initial run (res1; left)
fg, ax = dyplot.cornerpoints(sresults, cmap='plasma', truths=np.zeros(ndim),
                             kde=False, fig=(fig, axes[:, :2]))

Input N: [[100 100 100   0   0   0   0   0   0]]
Input B: [[185.40290543 105.15853949 107.57711556 742.60044632 368.11860316
  112.56012704]]
Time steps expected: 148520.08926350862
Max time steps set: 655360


Done:  0:00:41.584489


Output: [[[1.0000000e+02 1.0000000e+02 1.0000000e+02 ... 0.0000000e+00
   0.0000000e+00 0.0000000e+00]]

 [[9.9705696e+01 9.9851479e+01 9.9703972e+01 ... 1.3426222e-02
   2.7044520e-02 8.8316157e-02]]

 [[9.9705696e+01 9.9851479e+01 9.9703972e+01 ... 1.3426222e-02
   2.7044520e-02 8.8316157e-02]]

 ...

 [[9.9705696e+01 9.9851479e+01 9.9703972e+01 ... 1.3426222e-02
   2.7044520e-02 8.8316157e-02]]

 [[9.9705696e+01 9.9851479e+01 9.9703972e+01 ... 1.3426222e-02
   2.7044520e-02 8.8316157e-02]]

 [[9.9705696e+01 9.9851479e+01 9.9703972e+01 ... 1.3426222e-02
   2.7044520e-02 8.8316157e-02]]]
[[[nan nan nan ... nan nan nan]
  [nan nan nan ... nan nan nan]
  [nan nan nan ... nan nan nan]
  ...
  [nan nan nan ... nan nan nan]
  [nan nan nan ... nan nan nan]
  [nan nan nan ... nan nan nan]]

 [[nan nan nan ... nan nan nan]
  [nan nan nan ... nan nan nan]
  [nan nan nan ... nan nan nan]
  ...
  [nan nan nan ... nan nan nan]
  [nan nan nan ... nan nan nan]
  [nan nan nan