In [2]:
# Enable IPython's autoreload extension in mode 2 so that edits to imported
# modules are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

# Ruggedness metric

Questions
1. What is the topological stability of a circuit?
2. What is the ruggedness of each circuit generated by the model?

In [22]:
import jax.numpy as jnp
import numpy as np
import os
import pandas as pd
import jax
from synbio_morpher.utils.data.data_format_tools.common import load_json_as_dict
from synbio_morpher.utils.results.analytics.naming import get_true_interaction_cols
from evoscaper.utils.preprocess import make_datetime_str
from evoscaper.scripts.init_from_hpos import init_from_hpos
from evoscaper.scripts.verify import verify, setup_model, make_rates, prep_sim, sim, prep_cfg, make_batch_symmetrical_matrices
from evoscaper.utils.math import arrayise


In [4]:
# When True, the verification cell further down reads previously saved results
# from disk instead of calling `verify`.
use_loaded = True
# Output directory for this session; the last path component is whatever
# make_datetime_str() returns. The directory is created immediately.
top_write_dir = os.path.join('data', '07_ruggedness', make_datetime_str())
os.makedirs(top_write_dir, exist_ok=True)

# Hyperparameters loaded from JSON and wrapped in a Series; consumed by
# init_from_hpos and by the `eval_n_to_sample` lookup in later cells.
hpos = pd.Series(load_json_as_dict('data/01_cvae/2025_01_21__15_09_53/hpos_all.json'))
# Path of the saved-weights file that is read with load_json_as_dict in the next cell.
fn_saves = os.path.join('weight_saves', '01_cvae',
                        'saves_2025_01_17__16_01_57_sens_no_cat')

In [5]:
# Read the saved training state from the path configured in `fn_saves`.
saves_loaded = load_json_as_dict(fn_saves)

# Unpack everything init_from_hpos returns for these hyperparameters. Judging by
# the target names this covers RNG keys, configs, the dataset and its splits,
# normalisers and the model components. The unpacking is positional, so the
# order of names below must match the order of the returned values.
(
    rng, rng_model, rng_dataset,
    config_norm_x, config_norm_y, config_filter, config_optimisation, config_dataset, config_training, config_model,
    data, x_cols, df,
    x, cond, y, x_train, cond_train, y_train, x_val, cond_val, y_val,
    total_ds, n_batches, BATCH_SIZE, x_datanormaliser, x_methods_preprocessing, y_datanormaliser, y_methods_preprocessing,
    params, encoder, decoder, model, h2mu, h2logvar, reparam
) = init_from_hpos(hpos)

# Override the `params` returned above with the 'params' entry stored under the
# last key of the saves file (presumably the latest checkpoint; TODO confirm
# that key order in the file reflects training order).
params = arrayise(saves_loaded[str(list(saves_loaded.keys())[-1])]['params'])



xla_bridge.py:backends():900: Unable to initialize backend 'rocm': NOT_FOUND: Could not find registered platform with name: "rocm". Available platform names are: CUDA INFO
xla_bridge.py:backends():900: Unable to initialize backend 'tpu': INTERNAL: Failed to open libtpu.so: libtpu.so: cannot open shared object file: No such file or directory INFO


In [18]:
# Load the verification config referenced by the dataset config.
config_bio = load_json_as_dict(config_dataset.filenames_verify_config)
# Start from the 'generate_species_templates' section ...
config_bio_u = config_bio['base_configs_ensemble']['generate_species_templates']
# ... and merge 'mutation_effect_on_interactions_signal' into it. update() works
# in place, so the section inside the loaded config is modified too, and keys
# from the second section win on conflict (assuming plain dict semantics).
config_bio_u.update(config_bio['base_configs_ensemble']['mutation_effect_on_interactions_signal'])
# Unique non-null values of the 'sample_name' column.
input_species = data[data['sample_name'].notna()]['sample_name'].unique()
# NOTE: `config_bio` is rebound here, from the raw loaded config to the value
# returned by prep_cfg for the merged section.
config_bio = prep_cfg(config_bio_u, input_species)

In [19]:
# Directory holding the outputs of an earlier verification run; only read when
# `use_loaded` is True. Kept in one place so switching runs is a one-line change.
verify_dir = os.path.join('data', '02_cvae_verify', '2025_01_17__16_33_03')

if use_loaded:
    # Reuse the saved results instead of re-running `verify`.
    analytics = load_json_as_dict(os.path.join(verify_dir, 'analytics.json'))
    fake_circuits = np.load(os.path.join(verify_dir, 'fake_circuits.npy'))
    ts = np.load(os.path.join(verify_dir, 'ts.npy'))
    y0m = np.load(os.path.join(verify_dir, 'y0m.npy'))
    ys = np.load(os.path.join(verify_dir, 'ys.npy'))
else:
    # Global minimum and maximum over the interaction columns of `df`, looked up
    # once and passed to `verify` as `impose_final_range`.
    interaction_values = df[get_true_interaction_cols(
        df, config_dataset.x_type, remove_symmetrical=True, num_species=3)]
    interaction_range = (interaction_values.min().min(),
                         interaction_values.max().max())

    (
        analytics, ys, ts, y0m, y00s, ts0, fake_circuits, reverse_rates, model_brn, qreactions, ordered_species, input_species, z, sampled_cond
    ) = verify(params=params,
               rng=rng,
               decoder=decoder,
               df=df,
               cond=np.array([-0.1, 1.1]),
               config_bio=config_bio,
               config_norm_y=config_norm_y,
               config_dataset=config_dataset,
               config_model=config_model,
               x_datanormaliser=x_datanormaliser,
               x_methods_preprocessing=x_methods_preprocessing,
               y_datanormaliser=y_datanormaliser,
               output_species=config_dataset.output_species,
               signal_species=config_dataset.signal_species,
               # Recomputed from `data` rather than reusing the `input_species`
               # global, because that name is rebound by the unpacking above.
               input_species=data[data['sample_name'].notna()
                                  ]['sample_name'].unique(),
               n_to_sample=int(hpos['eval_n_to_sample']),
               visualise=False,
               top_write_dir=top_write_dir,
               return_relevant=True,
               impose_final_range=interaction_range)

In [23]:
def calculate_robustness(interactions, eps, analytic, input_species, config_dataset, config_bio):
    """Perturb each circuit, simulate the perturbed circuits and score the response.

    Parameters
    ----------
    interactions : array
        Batch of circuits; `create_perturbations` is vmapped over its leading axis.
    eps : array
        Perturbation sizes, vmapped over the leading axis together with
        `interactions`.
    analytic : str
        Key of the entry in the simulation analytics that is scored.
    input_species, config_dataset, config_bio
        Forwarded unchanged to `simulate_perturbations`.

    Returns
    -------
    tuple
        ``(robustness, analytics, ys, ts, y0m, y00s, ts0)``: the score followed
        by the six values returned by `simulate_perturbations`.
    """
    # One block of perturbed parameter vectors per circuit.
    perturb_batch = jax.vmap(create_perturbations)
    perturbed = perturb_batch(interactions, eps)

    sim_outputs = simulate_perturbations(
        perturbed, config_dataset, config_bio, input_species)
    analytics = sim_outputs[0]

    # Score only the requested analytic of the perturbed simulations.
    robustness = calculate_robustness_from_perturbations(
        jnp.array(analytics[analytic]), interactions, eps)

    return (robustness, *sim_outputs)
    
    
def create_perturbations(interactions, eps):
    """Build one perturbed copy of a parameter vector per parameter.

    For a vector ``interactions`` of length ``n`` the result is an ``(n, n)``
    array whose row ``i`` is ``interactions`` with a step added to element ``i``
    only. ``eps`` is multiplied onto an identity matrix, so it may be a length-``n``
    vector (step ``eps[i]`` for parameter ``i``) or anything else that broadcasts
    against an ``(n, n)`` matrix, such as a scalar.
    """
    n = len(interactions)
    # Every row starts out as an unmodified copy of the input vector.
    tiled = jnp.broadcast_to(interactions, (n, n))
    # Steps sit on the diagonal only: row i touches parameter i.
    steps = jnp.eye(n, n) * eps
    return tiled + steps
    
    
    
def simulate_perturbations(interactions, config_dataset, config_bio, input_species):
    """Simulate every perturbed circuit and return the simulator outputs.

    `interactions` is flattened to 2-D over its last axis, converted by
    `make_batch_symmetrical_matrices` into matrices with side length
    ``len(input_species)``, and then passed through `setup_model`, `make_rates`,
    `prep_sim` and `sim` in that order.

    Returns
    -------
    tuple
        ``(analytics, ys, ts, y0m, y00s, ts0)`` as produced by `sim`.
    """
    n_species = len(input_species)
    flat_params = interactions.reshape(-1, interactions.shape[-1])
    matrices = make_batch_symmetrical_matrices(
        flat_params, side_length=n_species)

    # Only the reactions and the post-processing step are needed below.
    _model_brn, qreactions, _ordered_species, postproc = setup_model(
        matrices, config_bio, input_species)

    forward_rates, reverse_rates = make_rates(
        config_dataset.x_type, matrices, postproc)

    # prep_sim also returns (possibly updated) rates, which replace the ones above.
    (
        signal_onehot, signal_target,
        y00, t0, t1, dt0, dt1,
        stepsize_controller, save_steps, max_steps,
        forward_rates, reverse_rates,
    ) = prep_sim(
        config_dataset.signal_species, qreactions, matrices,
        config_bio, forward_rates, reverse_rates)

    # Only the first entry of forward_rates is handed to sim (presumably the
    # forward rates are shared by all circuits; TODO confirm).
    analytics, ys, ts, y0m, y00s, ts0 = sim(
        y00, forward_rates[0], reverse_rates,
        qreactions,
        signal_onehot, signal_target,
        t0, t1, dt0, dt1,
        save_steps, max_steps,
        stepsize_controller)

    return analytics, ys, ts, y0m, y00s, ts0
    
    
def calculate_robustness_from_perturbations(perturbations_output, interactions, eps, axis=None):
    """Euclidean norm of the step-normalised differences.

    Computes ``dp = (perturbations_output - interactions) / eps`` element-wise
    and returns ``sqrt(sum(dp ** 2))``.

    Parameters
    ----------
    perturbations_output : array
        Values obtained for the perturbed inputs.
    interactions : array
        Reference values subtracted from `perturbations_output`; must broadcast
        against it.
        NOTE(review): if this is meant to be a finite-difference sensitivity,
        the reference should presumably be the analytic of the *unperturbed*
        circuit rather than the interaction parameters themselves; confirm.
    eps : array or scalar
        Step sizes used as the divisor. Entries equal to zero mean that no step
        was taken along that parameter; they are skipped (contribute 0) instead
        of producing NaN/inf that would poison the whole sum.
    axis : int, tuple of int or None, optional
        Axis or axes to sum over. The default ``None`` sums over everything and
        returns a single scalar, as before; pass e.g. ``-1`` for one value per row.

    Returns
    -------
    array
        The norm, scalar when ``axis`` is None.
    """
    eps = jnp.asarray(eps)
    has_step = eps != 0
    # Divide by a harmless placeholder where the step is zero, so no NaN/inf is
    # ever produced (this also keeps gradients finite), then mask those entries.
    safe_eps = jnp.where(has_step, eps, 1)
    dp = jnp.where(has_step, (perturbations_output - interactions) / safe_eps, 0)

    robustness = jnp.sqrt(jnp.sum(jnp.square(dp), axis=axis))

    return robustness

In [24]:
# Collapse all leading axes into one so the array is 2-D with the original last
# axis preserved (presumably one generated circuit per row).
fake_circuits_f = fake_circuits.reshape(np.prod(fake_circuits.shape[:-1]), -1)
# Relative step size: 0.1% of each interaction value.
# NOTE(review): an interaction that is exactly 0 gives eps == 0, i.e. no
# perturbation along that parameter.
eps = 1e-3 * fake_circuits_f

# NOTE(review): calculate_robustness returns a 7-tuple
# (robustness, analytics, ys, ts, y0m, y00s, ts0), so `robustness` is bound to
# the whole tuple here, not just the score.
# NOTE(review): the recorded run of this cell ended with RESOURCE_EXHAUSTED on the
# device. Every circuit is expanded into one perturbed copy per interaction
# before simulation, so the simulated batch is larger than `fake_circuits_f`.
robustness = calculate_robustness(fake_circuits_f, eps=eps, analytic='sensitivity_wrt_species-6',
                                  input_species=input_species, config_dataset=config_dataset, config_bio=config_bio)

Steady states found. Now calculating signal response


E0121 16:53:23.245316  252149 pjrt_stream_executor_client.cc:2985] Execution of replica 0 failed: RESOURCE_EXHAUSTED: Failed to allocate request for 103.00MiB (108000000B) on device ordinal 0


XlaRuntimeError: RESOURCE_EXHAUSTED: Failed to allocate request for 103.00MiB (108000000B) on device ordinal 0