In [None]:
%load_ext autoreload
%autoreload 2

# Imports

In [None]:
import numpy as np
import plotly.graph_objects as go
import os
import sys
import seaborn as sns
import matplotlib.pyplot as plt
import jax
from copy import deepcopy


# Force JAX onto the CPU backend for every computation in this notebook.
jax.config.update('jax_platform_name', 'cpu')

# A notebook kernel has no package context; in that case put the parent
# directory on sys.path and record its folder name as the package.
if __package__ is None:
    module_path = os.path.abspath('..')
    sys.path.append(module_path)
    __package__ = os.path.basename(module_path)
    

In [None]:
from synbio_morpher.scripts.parameter_based_simulation.run_parameter_based_simulation import make_interaction_matrices
from synbio_morpher.utils.common.testing.minimal_sim import mini_sim
from synbio_morpher.utils.data.data_format_tools.common import load_json_as_dict
from synbio_morpher.utils.misc.string_handling import prettify_keys_for_label
from synbio_morpher.utils.misc.numerical import make_symmetrical_matrix_from_sequence
from synbio_morpher.utils.parameter_inference.interpolation_grid import create_parameter_range

# Load analytics

In [None]:
# Directory of a finished parameter-based simulation run. Both the saved
# analytics (.npy) and the experiment config are read from here, so the
# run only has to be changed in one place.
sdir = '../data/parameter_based_simulation/2023_08_24_114212'
# NOTE(review): not referenced by the other visible cells of this notebook.
analytic_name = 'precision'
config = load_json_as_dict(os.path.join(sdir, 'experiment.json'))['config_params']
# Grid of interaction strengths the run was swept over.
param_range = create_parameter_range(config['parameter_based_simulation'])


In [None]:
# Load every saved analytic array in `sdir`, keyed by its file name with the
# `.npy` extension stripped (e.g. 'overshoot.npy' -> a['overshoot']).
# splitext keeps dotted stems intact (split('.')[0] would truncate them and
# could make two files collide on one key); sorting makes the order stable.
a = {}
for f in sorted(os.listdir(sdir)):
    stem, ext = os.path.splitext(f)
    if ext == '.npy':
        a[stem] = np.load(os.path.join(sdir, f))

# Find top circuits

In [None]:
def find_top_n_peaks(arrs: list, n: int):
    """Greedily pick the n best grid points across several same-shaped arrays.

    On every iteration each array is divided by its current maximum and the
    normalised arrays are summed into a single objective. The index of the
    objective's maximum is recorded and then set to 0 in *every* array so it
    cannot be picked again. The maxima are recomputed after zeroing, so the
    normalisation can shift from one pick to the next.

    Args:
        arrs: list of numpy arrays, all with the same shape. Modified in place
            (picked entries are zeroed), so pass a copy to keep the originals.
        n: number of indices to return. Capped at the number of elements;
            once every entry has been zeroed, further picks would only be
            meaningless repeats.

    Returns:
        List of index tuples (as produced by ``np.unravel_index``), best first.
    """
    # Take the shape from the first array explicitly; the old code relied on a
    # leftover loop variable, which is undefined when only one array is given.
    shape = arrs[0].shape
    n = min(n, arrs[0].size)
    top_peaks = []
    for _ in range(n):
        objective = sum(_normalise_by_max(arr) for arr in arrs)
        top_peak_idx = np.unravel_index(np.argmax(objective), shape)
        for arr in arrs:
            arr[top_peak_idx] = 0
        top_peaks.append(top_peak_idx)
    return top_peaks


def _normalise_by_max(arr):
    """Return arr / arr.max(), or zeros when the maximum is 0 (avoids 0/0 = nan)."""
    peak = arr.max()
    if peak == 0:
        return np.zeros(arr.shape)
    return arr / peak


def filter_top_peaks(top_indices, radius: int):
    """Thin a list of grid indices by dropping entries too similar to their predecessor.

    The first index is always kept. Every later index is kept only if it
    differs in *every* coordinate from the entry immediately before it in the
    input list (not from the last kept entry). Sharing any coordinate with
    that predecessor drops it.

    Args:
        top_indices: sequence of equal-length index tuples, e.g. the output of
            ``find_top_n_peaks``.
        radius: currently ignored; the selection below does not depend on it.
            Kept so existing calls keep working.

    Returns:
        List of the retained index tuples in input order; [] for empty input.
    """
    if len(top_indices) == 0:
        return []
    kept = [top_indices[0]]
    for prev, current in zip(top_indices[:-1], top_indices[1:]):
        if np.all(np.array(current) != np.array(prev)):
            kept.append(current)
    return kept

## Settings for selecting the best circuits

In [None]:
# Selection settings: number of candidate peaks to extract, the index along
# the analytics' first axis to optimise over, and the radius handed to
# filter_top_peaks (which does not currently use it).
n_top = 10000
i_spec = 1
radius = 1

# Precision with every value above 1e1 clipped to 1e1.
pp = np.minimum(a['precision_wrt_species-6'], 1e1)
# Overshoot divided by the change of entry 0 along the first axis (max amount
# minus initial steady state; presumably the signal species), set to 0 where
# that change is zero. errstate silences the divide-by-zero warnings produced
# by the branch np.where then discards.
signal_diff = a['max_amount'][0] - a['initial_steady_states'][0]
with np.errstate(divide='ignore', invalid='ignore'):
    oo = np.where(signal_diff == 0, 0, a['overshoot'] / signal_diff)
oo = np.where(oo < 1e-4, 0, oo)

# Sensitivity, clipped precision and normalised overshoot are all maximised.
arrs = [a['sensitivity_wrt_species-6'][i_spec], pp[i_spec], oo[i_spec]]
top_indices = find_top_n_peaks(deepcopy(arrs), n=n_top)

# Re-apply the filter until it stops removing entries.
n_ind = len(top_indices)
n_prev = 0
while n_ind != n_prev:
    n_prev = len(top_indices)
    top_indices = filter_top_peaks(top_indices, radius=radius)
    n_ind = len(top_indices)

print(len(top_indices))
# Show only the first few picks rather than dumping the whole list.
top_indices[:10]

In [None]:
# Print the stored analytics of each selected circuit for the same species
# (i_spec) that was used for selection. Overshoot here is the raw stored value,
# not the normalised `oo` used in the selection objective.
print('sensitivity \t overshoot \t precision')
for t in top_indices:
    idx = tuple(t)
    print(a['sensitivity_wrt_species-6'][i_spec][idx], '\t',
          a['overshoot'][i_spec][idx], '\t',
          a['precision_wrt_species-6'][i_spec][idx])

In [None]:
# Display the 'simulation' section of the experiment config loaded from the run directory.
config['simulation']


## Re-simulate the top circuits with `mini_sim`


In [None]:
# Re-run every selected circuit through mini_sim, passing the interaction
# strengths at that grid point, and keep its analytics and trajectory per run.
# The loop variables (r, analytics, y, t) remain as notebook globals afterwards.
saves = {}
for i, ti in enumerate(top_indices):
    strengths = param_range[np.array(ti)]
    r, analytics, y, t = mini_sim(*strengths)
    saves[i] = dict(analytics=analytics, y=y, t=t)

In [None]:
# Plot the trajectory of species index 7 for every re-simulated circuit on a
# grid with at most 6 columns. Also print a sensitivity recomputed from the raw
# analytics next to the stored overshoot and precision.
# NOTE(review): index 7 is treated as the output and index 6 as the signal;
# presumably RNA_1 / RNA_0 in mini_sim's species ordering, so confirm.
output_idx, signal_idx = 7, 6

if not saves:
    # Guard: an empty `saves` would otherwise produce a 0/0 grid size.
    print('No circuits were re-simulated; nothing to plot.')
else:
    n_rows = int(np.ceil(len(saves) / 6))
    n_cols = int(np.ceil(len(saves) / n_rows))
    fig = plt.figure(figsize=(7*n_cols, 6*n_rows))
    for i, run in saves.items():
        an = run['analytics']
        ax = fig.add_subplot(n_rows, n_cols, i+1)
        ax.plot(run['t'], run['y'][:, output_idx])
        ax.set_title('S = ' + str(an['sensitivity_wrt_species-6'][output_idx]) +
                     ', O = ' + str(an['overshoot'][output_idx]))
        ax.set_xlabel('Time')
        ax.set_ylabel(f'Species {output_idx}')

        # |relative change of output (min vs steady state)| divided by the
        # relative change of the signal (max vs steady state). np.divide keeps
        # numpy's zero-division semantics (inf/nan plus a warning).
        output_init = an['initial_steady_states'][output_idx]
        signal_init = an['initial_steady_states'][signal_idx]
        sensitivity = np.absolute(np.divide(
            np.divide(an['min_amount'][output_idx] - output_init, output_init),
            np.divide(an['max_amount'][signal_idx] - signal_init, signal_init)))
        print(i, '\t', sensitivity, '\t', an['overshoot'][output_idx], '\t',
              an['precision_wrt_species-6'][output_idx])

In [None]:
# Heatmap of log10 interaction strengths of the selected circuits: one row per
# circuit (sorted by grid index), one column per interaction strength.
# aspect='auto' keeps a tall N x 6 image readable.
top_params = param_range[sorted(top_indices)]
fig, ax = plt.subplots()
im = ax.imshow(np.log10(top_params), aspect='auto')
fig.colorbar(im, ax=ax, label='log10 interaction strength')
ax.set_xlabel('Interaction index')
ax.set_ylabel('Circuit (sorted by grid index)')
ax.set_title('Interaction strengths of the top circuits');


## Scratch (unresolved): compare stored analytics with fresh mini-sim runs

In [None]:

# Species ordering used for the mini_sim analytics, and the positions of the
# unbound species within it (RNA_0, RNA_1, RNA_2 -> indices 6, 7, 8).
# NOTE(review): `inds` is not used below yet.
unbound_species = ['RNA_0', 'RNA_1', 'RNA_2']
species = ['RNA_0-0', 'RNA_0-1', 'RNA_0-2', 'RNA_1-1', 'RNA_1-2', 'RNA_2-2', 'RNA_0', 'RNA_1', 'RNA_2']
inds = [i for i, s in enumerate(species) if s in unbound_species]

# `analytics` is whatever the most recent mini_sim call in the notebook left
# behind (hidden state). The sensitivity line used to be a discarded expression;
# display it explicitly so both arrays are shown.
display(analytics['sensitivity_wrt_species-6'])
analytics['precision_wrt_species-6']


In [None]:
# Regenerate interaction matrices for grid iterations starting at 0 (up to
# end_iteration=10), re-simulate each one, and compare the freshly computed
# sensitivity with the value stored in `a` for the same grid point.
num_species = 3
interaction_matrices, all_interaction_strength_choices = make_interaction_matrices(
    num_species=num_species, interaction_strengths=param_range,
    num_unique_interactions=6, starting_iteration=0, end_iteration=10)

for i, interaction_strength_choices in enumerate(all_interaction_strength_choices):
    # All entries of the first axis, then one grid index per interaction strength.
    idxs = [slice(0, num_species)] + [[strength_idx] for strength_idx in interaction_strength_choices]
    # Elsewhere in this notebook mini_sim is unpacked into four values; the
    # star-unpack accepts either that or a two-value return.
    r, analytics, *_ = mini_sim(*interaction_matrices[i][np.triu_indices(num_species)])

    print(f'\nInteraction {i}:')
    print('prev', a['sensitivity_wrt_species-6'][tuple(idxs)])
    print('now', analytics['sensitivity_wrt_species-6'])

n_species = num_species
r, analytics, *_ = mini_sim(*interaction_matrices[0][np.triu_indices(n_species)])


In [None]:
# Build the same `idxs` index list as the comparison loop, for the first
# interaction-strength choice only.
# NOTE(review): in the visible code nothing uses `idxs` yet; this cell may be unfinished.
for interaction_strength_choices in all_interaction_strength_choices[:1]:
    idxs = [slice(0, n_species)] + [[strength_idx] for strength_idx in interaction_strength_choices]