In [None]:
# Automatically reload edited Python modules (e.g. evoscaper, common) before each cell executes
%load_ext autoreload
%autoreload 2

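# Visualise the output time series of a ruggedness run

Loads a saved ruggedness run and the model it was generated with, recomputes the ruggedness of each analytic, and plots the output trajectories of the original and perturbed circuits.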

In [None]:
from synbio_morpher.utils.data.data_format_tools.common import load_json_as_dict
from evoscaper.scripts.init_from_hpos import init_from_hpos
from evoscaper.utils.evolution import calculate_ruggedness_core
from evoscaper.utils.math import arrayise, make_flat_triangle, make_batch_symmetrical_matrices
from evoscaper.utils.preprocess import make_datetime_str
from evoscaper.run.ruggedness import get_perturbations
from common import load_stitch_analytics, make_df_rugg, load_stitch_ys

import os
import json
import sys
import numpy as np
import jax

import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt

# Run JAX on the CPU for this notebook.
jax.config.update('jax_platform_name', 'cpu')

# Plot styling applied to every figure below.
sns.set_style('whitegrid')
sns.set_context("notebook", font_scale=1.2)
custom_palette = sns.blend_palette(
    ['#ff9f9b', '#ffb482', '#fffea3', '#c1f6b8', '#a0e7e0', '#b9d3ee', '#d6a4ef', '#ff77a7'], n_colors=9)

# Make the parent directory importable and treat this notebook as part of that package.
# NOTE(review): this runs after the import cell, so it cannot help resolve those imports;
# if any of them depend on the parent directory, move these lines above the imports.
module_path = os.path.abspath('..')
sys.path.append(module_path)
__package__ = os.path.basename(module_path)

# One seed for both NumPy's global RNG and the JAX PRNG key.
SEED = 0
np.random.seed(SEED)
PRNG = jax.random.PRNGKey(SEED)

jax.devices()

# Headless execution (substitute this notebook's filename):
# jupyter nbconvert --to notebook --execute <notebook>.ipynb --output=<notebook>_2.ipynb --ExecutePreprocessor.timeout=-1

# Load the ruggedness run, the trained model and the simulated outputs

In [None]:
# Ruggedness run to visualise.
dir_src_rugg = os.path.join('data', 'ruggedness', '2025_07_05__15_10_09')

config_rugg = load_json_as_dict(os.path.join(dir_src_rugg, 'config.json'))
# Strip a 'notebooks/' prefix from the stored path — presumably it was recorded relative to
# the repo root while this notebook runs from inside notebooks/ (TODO confirm).
fn_df_hpos_loaded = config_rugg['fn_df_hpos_loaded'].replace('notebooks/', '')
hpos = pd.Series(load_json_as_dict(fn_df_hpos_loaded))
dir_src_nn = os.path.dirname(fn_df_hpos_loaded)

# os.listdir order is arbitrary: sort so the chosen saves file is deterministic, and fail
# with a clear message (instead of an IndexError) when the model directory has none.
saves_candidates = sorted(
    fn for fn in os.listdir(dir_src_nn) if fn.startswith('saves'))
if not saves_candidates:
    raise FileNotFoundError(
        f'No file starting with "saves" found in {dir_src_nn}')
fn_saves = os.path.join(dir_src_nn, saves_candidates[0])

# Passed as the output index to load_stitch_analytics / load_stitch_ys below.
idx_output = -1
# Position of the original (unperturbed) circuit along the perturbation axis.
idx_perturbations_og = -1

# Record the inputs of this visualisation run next to its outputs.
config = {'fn_saves': fn_saves,
          'dir_src_rugg': dir_src_rugg}
top_write_dir = os.path.join(
    'data', '21_visualise_rugged_ys', make_datetime_str())
os.makedirs(top_write_dir, exist_ok=True)
with open(os.path.join(top_write_dir, 'config.json'), 'w') as f:
    json.dump(config, f)
print('top_write_dir:', top_write_dir)

In [None]:
saves_loaded = load_json_as_dict(fn_saves)

# Rebuild configs, data splits, normalisers and the model architecture from the hyperparameters.
(
    rng, rng_model, rng_dataset,
    config_norm_x, config_norm_y, config_filter, config_optimisation, config_dataset, config_training, config_model,
    data, x_cols, df,
    x, cond, y, x_train, cond_train, y_train, x_val, cond_val, y_val,
    total_ds, n_batches, BATCH_SIZE, x_datanormaliser, x_methods_preprocessing, y_datanormaliser, y_methods_preprocessing,
    params, encoder, decoder, model, h2mu, h2logvar, reparam
) = init_from_hpos(hpos)

# Replace the freshly initialised params with those stored under the last key of the
# saves file (NOTE(review): presumably the final training checkpoint — confirm).
last_saved_key = str([*saves_loaded][-1])
params = arrayise(saves_loaded[last_saved_key]['params'])

# Config file referenced by the dataset config (its 'n_perturbs' is read further down).
config_bio = load_json_as_dict(config_dataset.filenames_verify_config)

In [None]:
# Circuits evaluated in the ruggedness run; the last axis has length n_species.
all_fake_circuits = np.load(os.path.join(
    dir_src_rugg, 'all_fake_circuits.npy'))
n_species = all_fake_circuits.shape[-1]

# Conditioning values the circuits were sampled with (only the first sampled_cond file is read).
all_sampled_cond = np.load(os.path.join(
    dir_src_rugg, 'sampled_cond', 'sampled_cond_0.npy'))

if config_rugg['perturb_once']:
    # Regenerate the perturbed circuits with the run's settings; eps and n_perturbs
    # both come from get_perturbations.
    n_samples = all_fake_circuits.shape[0]
    all_fake_circuits, eps, n_perturbs = get_perturbations(
        make_flat_triangle(all_fake_circuits), config_rugg['eps_perc'], n_samples,
        config_rugg['n_perturbs'], config_rugg['resimulate_analytics'], config_rugg['perturb_once'])
    # Flatten the leading axes into one batch axis and rebuild square interaction matrices.
    all_fake_circuits = make_batch_symmetrical_matrices(
        all_fake_circuits.reshape(-1, all_fake_circuits.shape[-1]), side_length=n_species)
else:
    # The saved circuits are used as-is. Define n_perturbs with the same formula the
    # ruggedness cell uses so the next cell does not hit a NameError.
    n_perturbs = make_flat_triangle(all_fake_circuits[0]).shape[-1] + config_rugg['resimulate_analytics']


In [None]:
analytics_rugg = load_stitch_analytics(dir_src_rugg, last_idx=idx_output)

# Group every analytic by circuit: (-1, n_perturbs, n_values); 1-D analytics get a trailing axis of 1.
for k, arr in analytics_rugg.items():
    n_values = arr.shape[-1] if arr.ndim > 1 else 1
    analytics_rugg[k] = arr.reshape(-1, n_perturbs, n_values)

In [None]:
# Perturbation layout: presumably one perturbation per interaction, plus the original
# circuit when analytics were resimulated.
n_interactions = make_flat_triangle(all_fake_circuits[0]).shape[-1]
n_perturbs = n_interactions + config_rugg['resimulate_analytics']
# Derive n_samples from the n_perturbs just computed (previously it used a stale value
# from an earlier cell, before n_perturbs was recomputed).
assert len(all_fake_circuits) % n_perturbs == 0, \
    f'{len(all_fake_circuits)} circuits cannot be grouped into {n_perturbs} perturbations each'
n_samples = len(all_fake_circuits) // n_perturbs
assert all(np.shape(v)[1] == n_perturbs for v in analytics_rugg.values()), \
    'analytics_rugg was reshaped with a different n_perturbs'
eps = config_rugg['eps_perc'] * np.abs(all_fake_circuits).max()

ruggedness = {
    analytic: calculate_ruggedness_core(analytics_rugg, None, analytic,
                                        config_rugg['resimulate_analytics'], n_samples, n_perturbs, eps)
    for analytic in analytics_rugg
}

# Analytics of the original circuits, taken at idx_perturbations_og along the perturbation axis.
if config_rugg['resimulate_analytics']:
    analytics_og = {k: np.array(v).reshape(n_samples, n_perturbs, -1)[:, idx_perturbations_og, :]
                    for k, v in analytics_rugg.items()}
else:
    analytics_og = {}

k_rugg = 'Log ruggedness (adaptation)'
# NOTE(review): stored under the "(adaptation)" key but computed from the 'Log sensitivity'
# ruggedness — confirm this pairing is intended.
ruggedness[k_rugg] = np.log10(ruggedness['Log sensitivity'])

In [None]:
n_interactions = make_flat_triangle(all_fake_circuits).shape[-1]
# NOTE(review): when perturb_once is False this takes config_bio['n_perturbs'], while the
# ruggedness cell always uses n_interactions + resimulate_analytics — confirm which is right.
if config_rugg['perturb_once']:
    n_perturbs = n_interactions + config_rugg['resimulate_analytics']
else:
    n_perturbs = config_bio['n_perturbs']

# Stitch the simulated output trajectories back together: one entry per circuit.
n_circuits_total = len(all_fake_circuits)
ys_out, ts = load_stitch_ys(dir_src_rugg, idx_output, n_circuits_total)
assert len(ys_out) == n_circuits_total, 'Probably stitched ys together wrong'
assert len(ys_out) == len(all_fake_circuits), 'Probably stitched ys together wrong'

In [None]:
# Early part of the output time series for one circuit: every perturbed variant in red,
# the original circuit (at idx_perturbations_og) in blue.
idx_circuit = 1          # which circuit to plot
time_fraction_denom = 10 # plot only the first len(ts) // 10 time points

ti = len(ts) // time_fraction_denom
ys_circuit = ys_out.reshape(-1, n_perturbs, *ys_out.shape[1:])[idx_circuit]
idx_og = range(n_perturbs)[idx_perturbations_og]  # resolve a possibly negative index

fig, ax = plt.subplots(figsize=(10, 5))
for ii in range(n_perturbs):
    is_og = ii == idx_og
    ax.plot(ts[:ti], ys_circuit[ii, :ti],
            color='b' if is_og else 'r',
            alpha=1 if is_og else 0.5,
            zorder=3 if is_og else 2)  # draw the original on top
ax.set(title=f'Circuit {idx_circuit} (blue: original, red: perturbed)',
       xlabel='Time', ylabel='Signal')
plt.show()


In [None]:
df_rugg = make_df_rugg(analytics_og, ruggedness, idx_output, all_sampled_cond,
                       y_datanormaliser, y_methods_preprocessing, config_dataset, k_rugg)

# Add "<col> norm" columns, normalised with the training-target preprocessing chain,
# to both the sampled-circuit table and the training data.
for k in ['adaptation', k_rugg]:
    norm_col = f'{k} norm'
    for table in (df_rugg, data):
        table[norm_col] = y_datanormaliser.create_chain_preprocessor(
            y_methods_preprocessing)(table[k].values, col=k, use_precomputed=True)

# Bin the training data's log ruggedness into 10 intervals, labelled by their rounded midpoints.
data[f'{k_rugg} bin'] = (
    pd.cut(data[k_rugg], bins=10)
    .apply(lambda interval: interval.mid)
    .astype(float)
    .round(2)
)

# Log-ruggedness condition values used for sampling, passed through the inverse preprocessing chain.
idx_cond_rugg = config_dataset.objective_col.index(k_rugg)
sampled_rugg = y_datanormaliser.create_chain_preprocessor_inverse(
    y_methods_preprocessing)(all_sampled_cond[..., idx_cond_rugg], col=k_rugg)
rugg_k = ruggedness[k_rugg]


In [None]:
# Output trajectories of the original circuits whose log ruggedness is negative.
assert ys_out.ndim == 3, 'Expected ys_out to be 3D: (n_samples * n_perturbs, time_steps, species out)'
# Positional row numbers of the selected circuits, used to index a NumPy array; this does
# not rely on df_rugg having a default 0..n-1 index. Assumes one df_rugg row per sample,
# in sample order — TODO confirm against make_df_rugg.
filt = np.flatnonzero((df_rugg[k_rugg] < 0).to_numpy())
ys_p = ys_out.reshape(n_samples, n_perturbs, *ys_out.shape[1:])[
    filt][:, idx_perturbations_og, ...]
# Flatten back to (n_selected, time_steps, species out): a no-op for an integer
# idx_perturbations_og, but keeps the shape right if it is a slice or list.
ys_p = ys_p.reshape(-1, *ys_out.shape[1:])

In [None]:
# Overlay the last output species of up to n_traj_max selected trajectories. Slicing caps
# the count, so fewer matches no longer raise an IndexError.
n_traj_max = 800
ys_plot = ys_p[:n_traj_max, :, -1]

fig, ax = plt.subplots(figsize=(8, 6))
ax.plot(ys_plot.T, color='b', alpha=0.1)  # one line per trajectory, x = time-step index
ax.set(title=f'Original circuits with {k_rugg} < 0 (n={len(ys_plot)})',
       xlabel='Time step', ylabel='Signal')
plt.show()