# Imports

In [1]:

import os
import sys
import numpy as np
import jax
import jax.numpy as jnp
import diffrax as dfx
import dynesty
from functools import partial
# jax.config.update('jax_platform_name', 'gpu')


if __package__ is None:

    module_path = os.path.abspath(os.path.join('..'))
    sys.path.append(module_path)

    __package__ = os.path.basename(module_path)


root_dir = '..'

ModuleNotFoundError: No module named 'dynesty'

In [None]:
from synbio_morpher.utils.modelling.deterministic import bioreaction_sim_dfx_expanded
from synbio_morpher.utils.misc.helper import vanilla_return
from synbio_morpher.utils.results.analytics.timeseries import get_precision, get_sensitivity, get_peaks
from bioreaction.simulation.manager import simulate_steady_states



# Simulation

In [None]:

def scale_rates(forward_rates, reverse_rates):
    rate_max = np.max([np.max(np.asarray(forward_rates)),
                        np.max(np.asarray(reverse_rates))])

    dt0 = 1 / (2 * rate_max)
    return dt0

In [None]:
def optimise_sp(s, p):
    s_lin = 1 / p
    return s - s_lin

In [None]:
def one_step_de_sim_expanded(spec_conc, inputs, outputs, forward_rates, reverse_rates):
    concentration_factors_in = jnp.prod(
        jnp.power(spec_conc, (inputs)), axis=1)
    concentration_factors_out = jnp.prod(
        jnp.power(spec_conc, (outputs)), axis=1)
    forward_delta = concentration_factors_in * forward_rates
    reverse_delta = concentration_factors_out * reverse_rates
    return (forward_delta - reverse_delta) @ (outputs - inputs)


def dummy_simfunc(
        y0, t0, t1, dt0,
        forward_rates,
        reverse_rates,
        inputs,
        outputs,
        threshold = 0.01
        # max_steps,
        # signal=vanilla_return, signal_onehot=1
    ):
    steps = int((t1 - t0) / dt0)
    y = np.zeros((steps, len(y0.squeeze())))
    time = np.arange(t0, t1, dt0)
    y[0] = y0
    for t in range(1, steps):
        yt = one_step_de_sim_expanded(
            spec_conc=y[t-1], inputs=inputs,
            outputs=outputs,
            forward_rates=forward_rates,
            reverse_rates=reverse_rates)
        y[t] = y[t-1] + yt * dt0
        if np.sum((np.abs(yt) - threshold ) * dt0) < 0:
            return y[:t], time[:t]
    return y, time

In [None]:
def R(B11, B12, B13, B22, B23, B33):
    unbound_species = ['RNA_0', 'RNA_1', 'RNA_2']
    species = ['RNA_0', 'RNA_1', 'RNA_2', 'RNA_0-0', 'RNA_0-1', 'RNA_0-2', 'RNA_1-1', 'RNA_1-2', 'RNA_2-2']
    signal_species = 'RNA_0'
    output_species = 'RNA_1'
    s_idx = species.index(signal_species)
    output_idx = species.index(output_species)
    
    signal_target = 2
    k = 0.00150958097
    N0 = 100
    
    # Amounts
    y00 = np.array([[N0, N0, N0, 0, 0, 0, 0, 0, 0]])
    
    # Reactions
    inputs = np.array([
        [2, 0, 0, 0, 0, 0, 0, 0, 0],
        [1, 1, 0, 0, 0, 0, 0, 0, 0],
        [1, 0, 1, 0, 0, 0, 0, 0, 0],
        [0, 2, 0, 0, 0, 0, 0, 0, 0],
        [0, 1, 1, 0, 0, 0, 0, 0, 0],
        [0, 0, 2, 0, 0, 0, 0, 0, 0],
    ])
    outputs = np.array([
        [0, 0, 0, 1, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 1, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 1, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 1, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 1, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 1],
    ])
    
    # Rates
    reverse_rates = np.array([[B11, B12, B13, B22, B23, B33]])
    forward_rates = np.ones_like(reverse_rates) * k
    
    # Sim params
    t0 = 0
    t1 = 100
    dt0 = scale_rates(forward_rates, reverse_rates)
    max_steps = 16**4 * 10
    print('\n\nInput N:', y00)
    print('Input B:', reverse_rates)
    sim_func = jax.jit(partial(bioreaction_sim_dfx_expanded,
        t0=t0, t1=t1, dt0=dt0,
        signal=vanilla_return, signal_onehot=1,
        forward_rates=forward_rates,
        inputs=inputs,
        outputs=outputs,
        solver=dfx.Tsit5(),
        saveat=dfx.SaveAt(
            ts=np.linspace(t0, t1, 500)),  # int(np.min([500, self.t1-self.t0]))))
        max_steps=max_steps
        ))
    # sim_func = partial(dummy_simfunc,
    #     t0=t0, t1=t1, dt0=dt0,
    #     forward_rates=forward_rates,
    #     inputs=inputs,
    #     outputs=outputs)
    
    y0, t = simulate_steady_states(y0=y00, total_time=t1-t0, sim_func=sim_func, t0=t0, t1=t1, threshold=0.1, reverse_rates=reverse_rates)
    # y0, t = sim_func(y0=y00, reverse_rates=reverse_rates)
    y0 = np.array(y0.squeeze()[-1, :]).reshape(y00.shape)
    
    # Signal
    signal_onehot = np.zeros_like(y00) 
    signal_onehot[s_idx] = 1
    
    y0s = y0 * signal_onehot * signal_target + y0 * (signal_onehot == 0)
    y, t = simulate_steady_states(y0, total_time=t1-t0, sim_func=sim_func, t0=t0, t1=t1, threshold=0.1, reverse_rates=reverse_rates)
    # y, t = sim_func(y0=y0s, reverse_rates=reverse_rates)
    y = np.concatenate([y0s, y.squeeze()[:-1, :]], axis=0)
    y1 = np.array(y[-1, :])
    
    print('Output:', y1)
    
    # Calculate R
    peaks = get_peaks(initial_steady_states=y0, final_steady_states=y1,
                      maxa=y.max(axis=0), mina=y.min(axis=0)) 
    
    s = get_sensitivity(
        signal_idx=s_idx, peaks=peaks, starting_states=y0
    )
    p = get_precision(
            starting_states=y0,
            steady_states=y1,
            signal_0=y0[s_idx],
            signal_1=y1[s_idx].squeeze())
    print('Sensitivity:', s)
    print('Precision:', p)
    
    r = optimise_sp(
        s=s.squeeze()[output_idx], p=p.squeeze()[output_idx]
    )
    
    return r

### Check that R is working

In [None]:

dummy_B = np.array([
    899.999500, 899.999500, 0.090935, 899.999500, 0.000126, 899.999500])

r = R(*dummy_B)

r



Input N: [[100 100 100   0   0   0   0   0   0]]
Input B: [[8.999995e+02 8.999995e+02 9.093500e-02 8.999995e+02 1.260000e-04
  8.999995e+02]]
Done:  0:00:01.049104
Done:  0:00:01.023504


Input N: [100 100 100   0   0   0   0   0   0]
Input B: [[8.999995e+02 8.999995e+02 9.093500e-02 8.999995e+02 1.260000e-04
  8.999995e+02]]
Done:  0:00:00.865641
Done:  0:00:00.545558
Done:  0:00:00.545659
Done:  0:00:00.860932
Output: [[9.4636719e+01 8.6227551e+00 3.2730794e+00 1.5022194e-02 1.3687421e-03
  5.3562503e+00 1.2471246e-04 9.1370132e+01 1.7969280e-05]]
Sensitivity: [[1. 1. 1. 1. 1. 1. 1. 1. 1.]]
Precision: [[1. 1. 1. 1. 1. 1. 1. 1. 1.]]


ValueError: all the input arrays must have same number of dimensions, but the array at index 0 has 2 dimension(s) and the array at index 1 has 3 dimension(s)

# Log likelihood

In [None]:
ndim = 6

def loglike(B):
    
    L = - 1 / (R(*B) + 0.0001)
    
    print(L)
    
    return L


def ptform(u):
    R_max = 800
    x = R_max * u + 100
    
    return x

# Sampling

In [None]:

sampler = dynesty.NestedSampler(loglike, ptform, ndim)
sampler.run_nested()
sresults = sampler.results
from dynesty import plotting as dyplot
# initialize figure
import matplotlib.pyplot as plt
fig, axes = plt.subplots(6, 13, figsize=(25, 10))
axes = axes.reshape((6, 13))  # reshape axes

# add white space
[a.set_frame_on(False) for a in axes[:, 6]]
[a.set_xticks([]) for a in axes[:, 6]]
[a.set_yticks([]) for a in axes[:, 6]]

# plot initial run (res1; left)
fg, ax = dyplot.cornerpoints(sresults, cmap='plasma', truths=np.zeros(ndim),
                             kde=False, fig=(fig, axes[:, :2]))