In [None]:
%load_ext autoreload
%autoreload 2

# Imports

In [None]:
import numpy as np
import jax
import jax.numpy as jnp
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
import diffrax as dfx
from typing import List

from functools import partial
import os
import sys

# Run JAX computations on the GPU backend.
jax.config.update('jax_platform_name', 'gpu')

# When executed without an enclosing package (as in a notebook), append the
# parent directory to sys.path and adopt its basename as `__package__`.
if __package__ is None:
    module_path = os.path.abspath('..')
    sys.path.append(module_path)
    __package__ = os.path.basename(module_path)

# Seed NumPy's legacy global RNG so the random energies drawn later are reproducible.
np.random.seed(0)

# Show which devices JAX picked up (cell output).
jax.devices()

In [None]:
from synbio_morpher.srv.parameter_prediction.simulator import make_piecewise_stepcontrol
from synbio_morpher.utils.misc.type_handling import flatten_listlike
from synbio_morpher.utils.modelling.physical import eqconstant_to_rates, equilibrium_constant_reparameterisation
from synbio_morpher.utils.modelling.deterministic import bioreaction_sim_dfx_expanded
from synbio_morpher.utils.modelling.solvers import get_diffrax_solver, make_stepsize_controller
from synbio_morpher.utils.results.analytics.timeseries import calculate_adaptation, compute_peaks, compute_adaptability_full


# Set up test circuits

In [None]:
def make_species_bound(species_unbound):
    """Return the sorted, de-duplicated names of every pairwise complex.

    Every ordered pair of unbound species (including a species paired with
    itself) is named by joining the two names in alphabetical order with '-',
    so ('B', 'A') and ('A', 'B') both map to 'A-B'.
    """
    pair_names = {
        '-'.join(sorted((first, second)))
        for second in species_unbound
        for first in species_unbound
    }
    return sorted(pair_names)


# RNA circuit settings
species_unbound = ['RNA_0', 'RNA_1', 'RNA_2']
species_bound = make_species_bound(species_unbound)  # every pairwise complex, e.g. 'RNA_0-RNA_1'
species = species_unbound + species_bound
species_signal = ['RNA_0']
species_output = ['RNA_2']
species_nonsignal = [s for s in species_unbound if s not in species_signal]
idxs_signal = np.array([species.index(s) for s in species_signal])
idxs_output = np.array([species.index(s) for s in species_output])
idxs_unbound = np.array([species.index(s) for s in species_unbound])
idxs_bound = np.array([species.index(s) for s in species_bound])
# 1 at the position of each signal species in `species`, 0 elsewhere.
signal_onehot = np.array([1 if s in idxs_signal else 0 for s in np.arange(len(species))])

# Initial parameters
n_circuits = 1000
n_circuits_display = 30
k_a = 0.00150958097
N0 = 200
# Every circuit starts with N0 of each unbound species and no complexes. This is built
# from the species indices so it stays correct if `species_unbound` changes.
y00 = np.zeros((n_circuits, len(species)), dtype=np.float32)
y00[:, idxs_unbound] = N0

# Dynamic Simulation parameters
signal_target = 2  # factor applied to the signal species' steady-state level
t0 = 0
t1 = 200
ts = np.linspace(t0, t1, 500)
dt0 = 0.0005555558569638981
dt1_factor = 5
dt1 = dt0 * dt1_factor
max_steps = 16**4 * 10
use_sensitivity_func1 = False
sim_method = 'Dopri5'
stepsize_controller = 'adaptive'

# MC parameters
total_steps = 10

# Reactions
n_unbound = len(species_unbound)
# Row-major upper-triangle (i <= j) indices: one entry per binding reaction.
idxs_triu = np.triu_indices(n_unbound)
energies = np.random.rand(n_circuits, n_unbound, n_unbound)
energies = np.interp(energies, (energies.min(), energies.max()), (-25, 0))
# Make each circuit's (n_unbound, n_unbound) energy matrix symmetric by mirroring its
# strict upper triangle into the lower triangle. np.triu acts on the last two axes, so
# this is done per circuit. Indexing the 3-D array with a 2-D tril/triu index pair
# would instead address the (circuit, species) axes and corrupt circuits 0..n-1.
energies = np.triu(energies) + np.swapaxes(np.triu(energies, k=1), -1, -2)
eqconstants = jax.vmap(equilibrium_constant_reparameterisation)(energies, y00[:, idxs_unbound])
forward_rates, reverse_rates = eqconstant_to_rates(eqconstants, k_a)
# Keep only each circuit's upper triangle, giving (n_circuits, n_reactions) in idxs_triu order.
forward_rates = np.asarray(forward_rates)[:, idxs_triu[0], idxs_triu[1]]
reverse_rates = np.asarray(reverse_rates)[:, idxs_triu[0], idxs_triu[1]]

# Stoichiometry. Row k is the k-th upper-triangle pair (i, j) above. `inputs` consumes
# one each of species i and j (two of i when i == j). `outputs` produces the matching
# complex, i.e. column len(species_unbound) + k of `species`.
inputs = np.array([
    [2, 0, 0, 0, 0, 0, 0, 0, 0],
    [1, 1, 0, 0, 0, 0, 0, 0, 0],
    [1, 0, 1, 0, 0, 0, 0, 0, 0],
    [0, 2, 0, 0, 0, 0, 0, 0, 0],
    [0, 1, 1, 0, 0, 0, 0, 0, 0],
    [0, 0, 2, 0, 0, 0, 0, 0, 0],
], dtype=np.float64)
outputs = np.array([
    [0, 0, 0, 1, 0, 0, 0, 0, 0],
    [0, 0, 0, 0, 1, 0, 0, 0, 0],
    [0, 0, 0, 0, 0, 1, 0, 0, 0],
    [0, 0, 0, 0, 0, 0, 1, 0, 0],
    [0, 0, 0, 0, 0, 0, 0, 1, 0],
    [0, 0, 0, 0, 0, 0, 0, 0, 1],
], dtype=np.float64)
# The hard-coded stoichiometry must agree with the species list and the rate vectors.
assert inputs.shape == outputs.shape == (len(idxs_triu[0]), len(species))
assert forward_rates.shape[-1] == reverse_rates.shape[-1] == inputs.shape[0]

# Initialise and run simulations (steady state, then signal response)

In [None]:
# One jitted, circuit-vectorised simulator. Settings shared by all circuits are bound
# with `partial` (closed over, not mapped). Per-circuit arrays are passed at call time,
# so `vmap` maps them over their leading (circuit) axis; keyword arguments given to a
# vmapped function are mapped over axis 0.
# NOTE(review): `max_steps` (defined in the settings cell) is not passed here. Confirm
# whether bioreaction_sim_dfx_expanded should receive it.
sim_func = jax.jit(jax.vmap(
    partial(bioreaction_sim_dfx_expanded,
            t0=t0, t1=t1, dt0=dt0,
            inputs=inputs,
            outputs=outputs,
            solver=get_diffrax_solver(sim_method),
            saveat=dfx.SaveAt(ts=jnp.asarray(ts)),
            stepsize_controller=make_stepsize_controller(t0, t1, dt0, dt1,
                                                         choice=stepsize_controller)
            )))

# 1) Relax every circuit from y00. forward_rates is (n_circuits, n_reactions), like
#    reverse_rates, so it must be mapped per circuit rather than closed over whole.
sol_steady_states = sim_func(y00, reverse_rates, forward_rates=forward_rates)

# 2) Scale the signal species of each final state by `signal_target` and re-simulate.
y01 = np.array(sol_steady_states.ys[:, -1])
y01[:, idxs_signal] = y01[:, idxs_signal] * signal_target
sol_signal = sim_func(y01, reverse_rates, forward_rates=forward_rates)

In [None]:
# Compute adaptability, sensitivity and precision for every circuit from its
# steady-state and post-signal trajectories, vmapped over the circuit axis.
adaptability, sensitivity, precision = jax.vmap(
    partial(
        compute_adaptability_full,
        idx_sig=idxs_signal[0],
        use_sensitivity_func1=use_sensitivity_func1,
    )
)(sol_steady_states.ys, sol_signal.ys)

# Copy the JAX results into host NumPy arrays for downstream pandas/plotting use.
sensitivity, precision = np.array(sensitivity), np.array(precision)

# Monte Carlo iterations

In [None]:


def choose_next(batch: list, data_writer, distance_func, choose_max: int = 4, target_species: List[str] = ['RNA_1', 'RNA_2'], use_diversity: bool = False):
    """Rank a simulated batch of circuits and pick those to carry into the next MC step.

    Args:
        batch: circuit objects. Each must expose `.name`, `.subname` and
            `.model.species` (items with a `.name`). Reference circuits have
            `subname == 'ref_circuit'`; mutants are read from `mutations/<subname>`.
        data_writer: object whose `top_write_dir` is the root of the per-circuit
            `report_signal.json` files.
        distance_func: currently unused; kept for interface compatibility.
        choose_max: maximum number of circuits to select.
        target_species: species whose analytics become DataFrame columns. Only the
            first one is used for ranking. The default list is never mutated.
        use_diversity: if True and every top pick is a reference circuit,
            re-select from a restricted pool of reference circuits.

    Returns:
        (circuits_chosen, data_1): the selected circuit objects, and a DataFrame of
        per-circuit analytics with 'Next selected' / 'Diversity selection' flags.

    NOTE(review): `load_json_as_dict`, `sp_prod` and `log_distance` are not defined
    or imported in this notebook. They must be provided elsewhere before calling.
    """

    def make_data(batch, batch_analytics, target_species: List[str]):
        """Build one row per circuit holding its analytics for each target species."""
        d = pd.DataFrame(
            data=np.concatenate(
                [
                    np.asarray([c.name for c in batch])[:, None],
                    np.asarray([c.subname for c in batch])[:, None]
                ], axis=1
            ),
            columns=['Name', 'Subname']
        )
        d['Circuit Obj'] = batch
        species_names = [s.name for s in batch[0].model.species]
        t_idxs = {s: species_names.index(s) for s in species_names if s in target_species}
        for t in target_species:
            t_idx = t_idxs[t]
            d[f'Sensitivity species-{t}'] = np.asarray([b['sensitivity_wrt_species-6'][t_idx] for b in batch_analytics])
            d[f'Precision species-{t}'] = np.asarray([b['precision_wrt_species-6'][t_idx] for b in batch_analytics])
            d[f'Overshoot species-{t}'] = np.asarray([b['overshoot'][t_idx] for b in batch_analytics])

            # Parent = the reference circuit sharing this row's Name. The first matching
            # reference row wins. Rows without a reference circuit get NaN. A Series.map
            # lookup replaces the old per-row filtering, which was quadratic.
            parents = d[d['Subname'] == 'ref_circuit'].drop_duplicates('Name').set_index('Name')
            d[f'Parent Sensitivity species-{t}'] = d['Name'].map(parents[f'Sensitivity species-{t}'])
            d[f'Parent Precision species-{t}'] = d['Name'].map(parents[f'Precision species-{t}'])

            d[f'dS species-{t}'] = np.asarray([b['sensitivity_wrt_species-6_diff_to_base_circuit'][t_idx] for b in batch_analytics])
            d[f'dP species-{t}'] = np.asarray([b['precision_wrt_species-6_diff_to_base_circuit'][t_idx] for b in batch_analytics])

            d[f'SP Prod species-{t}'] = sp_prod(s=d[f'Sensitivity species-{t}'].to_numpy(), p=d[f'Precision species-{t}'].to_numpy(),
                                                sp_factor=1,
                                                s_weight=0)
            d[f'Log Distance species-{t}'] = np.array(log_distance(s=d[f'Sensitivity species-{t}'].to_numpy(), p=d[f'Precision species-{t}'].to_numpy()))
            d[f'SP and distance species-{t}'] = d[f'Sensitivity species-{t}'] * d[f'Log Distance species-{t}']

        return d

    def select_next(data_1, choose_max, t, use_diversity: bool = False):
        """Flag and return the top `choose_max` circuits of `data_1` ranked for species `t`."""
        data_1['Diversity selection'] = False
        ranking = [f'SP and distance species-{t}', f'Log Distance species-{t}',
                   f'SP Prod species-{t}', 'Name', 'Subname']
        circuits_chosen = data_1.sort_values(by=ranking, ascending=False)['Circuit Obj'].iloc[:choose_max].to_list()
        # Work with the circuit objects themselves: `x in DataFrame` only checks column
        # labels, and iterating a DataFrame (e.g. inside `isin`) yields column labels too.
        prev_circuits = data_1.loc[data_1['Subname'] == 'ref_circuit', 'Circuit Obj'].to_list()
        keep_n = int(0.7 * choose_max)
        if use_diversity and circuits_chosen and all(c in prev_circuits for c in circuits_chosen) and (len(data_1) >= keep_n):
            # NOTE(review): this re-ranks only the first `keep_n` reference circuits.
            # Confirm this is the intended diversity rule.
            pool = data_1[data_1['Circuit Obj'].isin(prev_circuits[:keep_n])].copy()
            # Recurse with diversity off: the pool contains only reference circuits, so
            # leaving it on would re-enter this branch forever.
            _, circuits_chosen = select_next(pool, choose_max, t, use_diversity=False)
            data_1['Diversity selection'] = data_1['Circuit Obj'].isin(circuits_chosen)

        data_1['Next selected'] = data_1['Circuit Obj'].isin(circuits_chosen)
        return data_1, circuits_chosen

    def get_batch_analytics(batch, data_writer):
        """Load each circuit's `report_signal.json` and cast every leaf to float64."""
        batch_analytics = []
        for c in batch:
            if c.subname == 'ref_circuit':
                report_path = os.path.join(data_writer.top_write_dir, c.name, 'report_signal.json')
            else:
                report_path = os.path.join(data_writer.top_write_dir, c.name, 'mutations', c.subname, 'report_signal.json')
            batch_analytics.append(load_json_as_dict(report_path))
        return jax.tree_util.tree_map(lambda x: np.float64(x), batch_analytics)

    batch_analytics = get_batch_analytics(batch, data_writer)
    data_1 = make_data(batch, batch_analytics, target_species)

    # Ranking uses the first target species only.
    t = target_species[0]
    data_1, circuits_chosen = select_next(data_1, choose_max, t, use_diversity)
    return circuits_chosen, data_1


def mutate(circuits, *args, **kwargs):
    """Placeholder mutation step: return `circuits` unchanged.

    The Monte Carlo loop calls this as `mutate(starting, evolver, algorithm=...)`.
    Extra positional and keyword arguments are therefore accepted and ignored, so
    that call does not raise TypeError.
    TODO: implement the actual mutation step.
    """
    return circuits


In [None]:

# NOTE(review): `starting`, `evolver`, `config`, `simulate`, `modeller`, `data_writer`,
# `distance_func`, `choose_max`, `target_species` and `process_for_next_run` are not
# defined in this notebook. Define or load them before running this cell.
for step in range(total_steps):

    print(f'\n\nStarting step {step+1} out of {total_steps}\n\n')

    batch = mutate(starting, evolver, algorithm=config['mutations_args']['algorithm'])
    batch = simulate(batch, modeller, config)
    # Rank the batch that was just simulated. Previously this passed an undefined
    # `expanded_batchs`, so the simulated batch was never used.
    starting, summary_data = choose_next(batch=batch, data_writer=data_writer, distance_func=distance_func,
                                         choose_max=choose_max, target_species=target_species, use_diversity=config.get('use_diversity', False))
    starting = process_for_next_run(starting, data_writer=data_writer)
    