In [1]:
%load_ext autoreload
%autoreload 2

# Imports

In [2]:
from synbio_morpher.srv.io.manage.script_manager import script_preamble
from synbio_morpher.srv.parameter_prediction.simulator import RawSimulationHandling, make_piecewise_stepcontrol
from synbio_morpher.utils.results.analytics.timeseries import generate_analytics
from synbio_morpher.utils.common.setup import prepare_config, expand_config, expand_model_config
from synbio_morpher.utils.data.data_format_tools.common import load_json_as_dict
from synbio_morpher.utils.results.analytics.naming import get_true_interaction_cols
from synbio_morpher.utils.misc.numerical import symmetrical_matrix_length
from synbio_morpher.utils.misc.type_handling import flatten_listlike, get_unique
from synbio_morpher.utils.modelling.deterministic import bioreaction_sim_dfx_expanded
from bioreaction.model.data_tools import construct_model_fromnames
from bioreaction.model.data_containers import BasicModel, QuantifiedReactions
from bioreaction.simulation.manager import simulate_steady_states
from functools import partial

from scipy.cluster.vq import whiten
from scipy.special import factorial
from sklearn.manifold import TSNE
import os
import sys
import numpy as np
import haiku as hk
import jax
import diffrax as dfx

from sklearn.preprocessing import MinMaxScaler
from sklearn.utils import shuffle

from datetime import datetime
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt

jax.config.update('jax_platform_name', 'cpu')

module_path = os.path.abspath(os.path.join('..'))
sys.path.append(module_path)

__package__ = os.path.basename(module_path)


jax.devices()

np.random.seed(0)
rng = jax.random.PRNGKey(0)

I0000 00:00:1707258073.418828  500006 tfrt_cpu_pjrt_client.cc:349] TfrtCpuClient created.
xla_bridge.py:backends():513: Unable to initialize backend 'rocm': NOT_FOUND: Could not find registered platform with name: "rocm". Available platform names are: CUDA INFO
xla_bridge.py:backends():513: Unable to initialize backend 'tpu': INTERNAL: Failed to open libtpu.so: libtpu.so: cannot open shared object file: No such file or directory INFO


In [11]:
from src.models.vae import CVAE, sample_z, init_data, VAE_fn
from src.models.mlp import MLP
from src.models.shared import arrayise
from src.losses.losses import loss_wrapper, compute_accuracy_regression, mse_loss
from src.utils.data_preprocessing import drop_duplicates_keep_first_n
from src.utils.math import make_symmetrical_matrix_from_sequence_nojax

# Load Data

In [12]:
fn = '../data/processed/ensemble_mutation_effect_analysis/2023_07_17_105328/tabulated_mutation_info.csv'
fn_test_data = '../data/raw/ensemble_mutation_effect_analysis/2023_10_03_204819/tabulated_mutation_info.csv'
data = pd.read_csv(fn, index_col=0)

# Config

In [13]:
# Architecture
HIDDEN_SIZE = 32 # 64
ENC_LS = 64
DEC_LS = 64
NUM_ENC_LAYERS = 3
NUM_DEC_LAYERS = 3
enc_layers = [ENC_LS] * NUM_ENC_LAYERS # [128, 128, 64, 64] # 
dec_layers = [DEC_LS] * NUM_DEC_LAYERS # [64, 64, 128, 128] # 


BATCH_SIZE = 128
N_BATCHES = 1200
TOTAL_DS = BATCH_SIZE * N_BATCHES
MAX_TOTAL_DS = TOTAL_DS
TRAIN_SPLIT = 0.8
SCALE_X = False
USE_X_LOGSCALE = True
X_TYPE = 'binding_rates_dissociation' #  'energies' #
LEARNING_RATE = 5e-4
LEARNING_RATE_SCHED = 'cosine_decay'
# LEARNING_RATE_SCHED = 'constant'
WARMUP_EPOCHS = 20
L2_REG_ALPHA = 0.01
EPOCHS = 1000
PRINT_EVERY = EPOCHS // 100
SEED = 1

INPUT_SPECIES = 'RNA_1'
USE_CATEGORICAL = False
target_circ_func = 'sensitivity'
input_concat_diffs = False
input_concat_axis = 0

# Training
USE_DROPOUT = False
USE_L2_REG = False
USE_WARMUP = True
loss_fn = partial(
    loss_wrapper, loss_f=mse_loss, use_l2_reg=USE_L2_REG) 
compute_accuracy = compute_accuracy_regression


rng = jax.random.PRNGKey(SEED)

## Init data

In [14]:
x, cond, x_scaling, x_unscaling, x_cols, df, filt, N_HEAD = init_data(data,
                                                                      BATCH_SIZE, INPUT_SPECIES, MAX_TOTAL_DS,
                                                                      SCALE_X, SEED, TOTAL_DS, USE_CATEGORICAL,
                                                                      USE_X_LOGSCALE, X_TYPE,
                                                                      input_concat_axis, input_concat_diffs, target_circ_func)

## Init Model

In [15]:
model_fn = partial(VAE_fn, enc_layers=enc_layers, dec_layers=dec_layers, decoder_head=x.shape[-1], HIDDEN_SIZE=HIDDEN_SIZE)
model_t = hk.multi_transform(model_fn)
params = model_t.init(rng, x, cond, deterministic=False)
encoder, decoder, model, h2mu, h2logvar, reparam = model_t.apply

  unscaled = jax.random.truncated_normal(
  param = init(shape, dtype)


In [16]:
fn_saves = os.path.join('weight_saves', '10_cvae', '2024_02_04__15_54_15_saves_test')
saves_loaded = load_json_as_dict(fn_saves)

p = saves_loaded[str(list(saves_loaded.keys())[-1])]['params']
p = arrayise(p)

# Load circuit simulations

In [20]:
exp_dir = 'data/tests/2024_02_06_144827'
analytics = load_json_as_dict(os.path.join(exp_dir, 'analytics.json'))
analytics.keys()


In [25]:
fake_circuits = np.load(os.path.join(exp_dir, 'fake_circuits.npy'))

In [31]:
n_to_sample = 10000
cond_splits = 10
sampled_cond = np.interp(np.random.rand(
    n_to_sample//cond_splits, cond.shape[-1]), [0, 1], np.linspace(cond.min(), cond.max(), cond_splits)[0:2])[None, :]
for i in range(1, cond_splits):
    sampled_cond0 = np.interp(np.random.rand(
        n_to_sample//cond_splits, cond.shape[-1]), [0, 1], np.linspace(cond.min(), cond.max(), cond_splits+1)[i:i+2])[None, :]
    sampled_cond = np.concatenate([sampled_cond, sampled_cond0], axis=0)
mu = np.random.normal(size=(n_to_sample, HIDDEN_SIZE))
logvar = np.random.normal(size=(n_to_sample, HIDDEN_SIZE))
z = sample_z(mu=mu, logvar=logvar, key=rng)
z = np.concatenate([z, sampled_cond.reshape(np.prod(sampled_cond.shape[:-1]), sampled_cond.shape[-1])], axis=-1)

num_species = symmetrical_matrix_length(fake_circuits.shape[-1])
input_species = [f'RNA_{i}' for i in range(num_species)]
fake_circuits_reshaped = np.array(list(map(partial(make_symmetrical_matrix_from_sequence_nojax, side_length=num_species), fake_circuits)))
for fn in x_unscaling:

    fake_circuits_reshaped = fn(fake_circuits_reshaped)
    
fake_circuits_reshaped = np.where(fake_circuits_reshaped > df[x_cols[0]].max().max(), df[x_cols[0]].max().max(), fake_circuits_reshaped)

# Visualise

## TSNE

In [33]:
h = encoder(p, rng, np.concatenate([x, cond], axis=-1))

mu = h2mu(p, rng, h)
logvar = h2logvar(p, rng, h)
z = reparam(p, rng, mu, logvar, rng, deterministic=True)
z_cond = np.concatenate([z, cond], axis=-1)

tsne_inp = whiten(z_cond[:10000])
        
n_components = 2
tsne = TSNE(n_components, perplexity=300, learning_rate=100, n_iter=500)
tsne_result = tsne.fit_transform(z_cond)
# tsne_result = tsne.fit_transform(fake_circuits)
tsne_result.shape

(153600, 2)

In [38]:
output_idxs = np.array([1, 2])

In [39]:
fig = plt.figure(figsize=(12, 5))
fig.subplots_adjust(wspace=0.6)
# for i in range(n_components):
#     for j in range(i, n_components):
tsne_result_df = pd.DataFrame({'TSNE 1': tsne_result[:, 0], 'TSNE 2': tsne_result[:, 1], 'VAE Conditioning input': sampled_cond.flatten()[:10000], 'Log10 Sensitivity': np.log10(analytics['sensitivity'][:, output_idxs][:10000])})
ax = plt.subplot(1,int(factorial(n_components)),1)
sns.scatterplot(x='TSNE 1', y='TSNE 2', hue='VAE Conditioning input', data=tsne_result_df, s=20, palette='viridis', alpha=1)
sns.move_legend(ax, "upper left", bbox_to_anchor=(1, 1))
ax = plt.subplot(1,int(factorial(n_components)),2)
sns.scatterplot(x='TSNE 1', y='TSNE 2', hue='Log10 Sensitivity', data=tsne_result_df, s=20, palette='viridis', alpha=1)
sns.move_legend(ax, "upper left", bbox_to_anchor=(1, 1))
plt.suptitle('TSNE Generated circuits')

TypeError: list indices must be integers or slices, not tuple

<Figure size 1200x500 with 0 Axes>

### TSNE on real data

In [None]:
tsne_result2 = tsne.fit_transform(data[filt][x_cols[0]].iloc[:15000].to_numpy())

In [None]:
fig = plt.figure(figsize=(12, 5))
fig.subplots_adjust(wspace=0.6)
# for i in range(n_components):
#     for j in range(i, n_components):
tsne_result_df = pd.DataFrame({'TSNE 1': tsne_result2[:, 0], 'TSNE 2': tsne_result2[:, 1], 'Log10 Sensitivity': np.log10(data[filt][target_circ_func].iloc[:len(tsne_result2)])})
ax = plt.subplot(1,int(factorial(n_components)),1)
sns.scatterplot(x='TSNE 1', y='TSNE 2', hue='Log10 Sensitivity', data=tsne_result_df, s=20, palette='viridis', alpha=0.1)
sns.move_legend(ax, "upper left", bbox_to_anchor=(1, 1))
plt.title('TSNE Training circuits')


# Motifs

In [None]:
sns.histplot(fake_circuits.flatten(), bins=50, element='step')

In [None]:
data[get_true_interaction_cols(data, 'eqconstants', remove_symmetrical=True, num_species=3)]