In [15]:
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


# Conditional VAE for genetic circuits

This notebook closely follows the previous VAE notebook, but implements a conditional VAE instead. The approach loosely follows [this blog post](https://agustinus.kristia.de/techblog/2016/12/17/conditional-vae/).

## Imports 

In [16]:
# %env XLA_PYTHON_CLIENT_ALLOCATOR=platform

from synbio_morpher.utils.data.data_format_tools.common import load_json_as_dict
from synbio_morpher.utils.results.analytics.naming import get_true_interaction_cols
from synbio_morpher.utils.data.data_format_tools.common import write_json
from synbio_morpher.utils.misc.string_handling import prettify_keys_for_label
from functools import partial

from sklearn.metrics import r2_score  
import os
import sys
import numpy as np
import haiku as hk
import jax

from sklearn.preprocessing import MinMaxScaler
from sklearn.utils import shuffle
                
import wandb

from datetime import datetime
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt

# Ask JAX to use the CPU platform for this notebook.
jax.config.update('jax_platform_name', 'cpu')


# if __package__ is None:

# Make the parent directory of the notebook importable so the `src.*`
# imports in the following cell resolve.
module_path = os.path.abspath(os.path.join('..'))
if module_path not in sys.path:
    # Guarded so that re-running this cell does not keep appending duplicates.
    sys.path.append(module_path)

__package__ = os.path.basename(module_path)


# Last expression of the cell: display the devices JAX will use.
jax.devices()

[CpuDevice(id=0)]

In [17]:
from src.models.vae import VAE
from src.models.mlp import MLP
from src.models.shared import arrayise
from src.losses.losses import loss_wrapper, compute_accuracy_regression
from src.utils.data_preprocessing import drop_duplicates_keep_first_n
from src.utils.optimiser import make_optimiser
from src.utils.train import train

## Load data

In [18]:
# Paths are relative to the notebook directory.
fn = '../data/processed/ensemble_mutation_effect_analysis/2023_07_17_105328/tabulated_mutation_info.csv'
fn_test_data = '../data/raw/ensemble_mutation_effect_analysis/2023_10_03_204819/tabulated_mutation_info.csv'

# Drop the stray index column written by a previous `to_csv` if it is present.
# `errors='ignore'` only tolerates the missing-column case, so genuine failures
# (e.g. an unreadable file) are no longer silently swallowed by a bare `except`.
data = pd.read_csv(fn).drop(columns=['Unnamed: 0'], errors='ignore')

## Hyperparameters

In [None]:
# Architecture
HIDDEN_SIZE = 32  # used as the VAE `embed_size` when the model is built
NUM_ENC_LAYERS = 3
NUM_DEC_LAYERS = 3


# Dataset sizing
BATCH_SIZE = 128
N_BATCHES = 1200
TOTAL_DS = BATCH_SIZE * N_BATCHES
MAX_TOTAL_DS = TOTAL_DS
TRAIN_SPLIT = 0.8
SCALE_X = False

# Optimisation
LEARNING_RATE = 5e-4
LEARNING_RATE_SCHED = 'cosine_decay'
# LEARNING_RATE_SCHED = 'constant'
WARMUP_EPOCHS = 20
L2_REG_ALPHA = 0.01
EPOCHS = 2
# `EPOCHS // 1000` is 0 whenever EPOCHS < 1000; clamp to 1 so the interval
# handed to `train(save_every=...)` is never zero.
PRINT_EVERY = max(1, EPOCHS // 1000)
SEED = 1

# Data selection
target_circ_func = 'sensitivity_wrt_species-6'
input_concat_diffs = False
input_concat_axis = 0

# Feature switches
USE_DROPOUT = False
USE_L2_REG = False
USE_WARMUP = True
# Read by the data-initialisation cell below but previously never defined.
# False matches the regression setup used here (`use_categorical=False`,
# `compute_accuracy_regression`).
USE_CATEGORICAL = False

# The original line was `partial(loss_fn, ...)`, which refers to `loss_fn`
# before it exists and raises NameError on a fresh kernel. The loss imported
# from `src.losses.losses` is `loss_wrapper`.
# NOTE(review): assumes `loss_wrapper` accepts `loss_type` and `use_l2_reg`
# keyword arguments - confirm against src/losses/losses.py.
loss_fn = partial(
    loss_wrapper, loss_type='mse', use_l2_reg=USE_L2_REG)
compute_accuracy = compute_accuracy_regression

subtask = '_test'
# Take a single timestamp so the date and time parts cannot disagree.
# Produces e.g. '2023_10_03__20_48_19_saves_test'.
save_path = datetime.now().strftime('%Y_%m_%d__%H_%M_%S') + '_saves' + subtask
save_path = os.path.join('weight_saves', '09_vae', save_path)

rng = jax.random.PRNGKey(SEED)

# Initialise

## Init data

In [None]:
# Keep only the rows belonging to the first `sample_name` found in the data.
filt = data['sample_name'] == data['sample_name'].unique()[0]

# Balance the dataset
# NOTE(review): presumably keeps at most n=100 rows per distinct combination
# of the interaction-energy columns - confirm in src/utils/data_preprocessing.
df = drop_duplicates_keep_first_n(data[filt], get_true_interaction_cols(
    data, 'energies', remove_symmetrical=True), n=100)

# Clamp the dataset size to what is available and round it down to a whole
# number of batches.
TOTAL_DS = np.min([TOTAL_DS, MAX_TOTAL_DS, len(df)])
N_BATCHES = TOTAL_DS // BATCH_SIZE
TOTAL_DS = N_BATCHES * BATCH_SIZE

x_cols = [get_true_interaction_cols(data, 'energies', remove_symmetrical=True)]

# Each column group becomes an array with a trailing singleton axis; the groups
# are then concatenated and the singleton axes squeezed away.
x = [df[i].iloc[:TOTAL_DS].values[:, :, None] for i in x_cols]
x = np.concatenate(x, axis=input_concat_axis+1).squeeze()

# Autoencoder setup: the reconstruction target is the input itself.
y = x

x, y = shuffle(x, y, random_state=SEED)

N_HEAD = x.shape[-1]


if x.shape[0] < TOTAL_DS:
    print(
        f'WARNING: The filtered data is not as large as the requested total dataset size: {x.shape[0]} vs. requested {TOTAL_DS}')

if SCALE_X:
    xscaler, yscaler = MinMaxScaler(), MinMaxScaler()
    x = xscaler.fit_transform(x)
    # Fit the dedicated target scaler; previously `xscaler` was refitted on y,
    # which left `yscaler` unfitted and unusable for any later transform.
    y = yscaler.fit_transform(y)

In [None]:
# NOTE(review): this cell rebuilds `x` and replaces the target `y` set by the
# previous cell with the (1-D) `target_circ_func` column. The batching cell
# further down reshapes `y` using `y.shape[-1]` and concatenates it with `x`
# along axis 1, which needs `y` to be 2-D with the same trailing size as `x` -
# confirm which of the two data cells is meant to feed training.
x_cols = [get_true_interaction_cols(data, 'energies', remove_symmetrical=True)]
if input_concat_diffs:
    # The original passed an undefined name `k` here; 'energies' is used for
    # consistency with the base columns above (TODO confirm intended key).
    x_cols = x_cols + \
        [[f'{i}_diffs' for i in get_true_interaction_cols(
            data, 'energies', remove_symmetrical=True)]]

x = [df[i].iloc[:TOTAL_DS].values[:, :, None] for i in x_cols]
x = np.concatenate(x, axis=input_concat_axis+1).squeeze()

y = df[target_circ_func].iloc[:TOTAL_DS].to_numpy()

if USE_CATEGORICAL:
    # NOTE(review): `numerical_resolution` and
    # `vectorized_convert_to_scientific_exponent` are not defined or imported
    # anywhere in the visible notebook; this branch cannot run as written.
    # One map entry per base-10 exponent between the smallest non-zero and the
    # largest value of y.
    y_map = {k: numerical_resolution for k in np.arange(int(f'{y[y != 0].min():.0e}'.split(
        'e')[1])-1, np.max([int(f'{y.max():.0e}'.split('e')[1])+1, 0 + 1]))}
    y_map[-6] = 1
    y_map[-5] = 1
    y_map[-4] = 4
    y_map[-3] = 2
    y_map[-1] = 3
    y = jax.tree_util.tree_map(partial(
        vectorized_convert_to_scientific_exponent, numerical_resolution=y_map), y)
    # Re-index the distinct converted values to consecutive integers 0..n-1.
    y = np.interp(y, sorted(np.unique(y)), np.arange(
        len(sorted(np.unique(y))))).astype(int)
else:
    # Regression target: log10 of y, with exact zeros replaced by a fixed value.
    zero_log_replacement = -10.0
    y = np.where(y != 0, np.log10(y), zero_log_replacement)

x, y = shuffle(x, y, random_state=SEED)

N_HEAD = len(np.unique(y)) if USE_CATEGORICAL else 1


if x.shape[0] < TOTAL_DS:
    print(
        f'WARNING: The filtered data is not as large as the requested total dataset size: {x.shape[0]} vs. requested {TOTAL_DS}')

## Init model

In [None]:
# Every encoder / decoder hidden layer is 64 units wide; only the depth varies.
enc_layers = [64 for _ in range(NUM_ENC_LAYERS)]
dec_layers = [64 for _ in range(NUM_DEC_LAYERS)]

def VAE_fn(enc_layers: list, dec_layers: list, call_kwargs: dict = None):
    """Build the VAE modules and return the pair expected by `hk.multi_transform`.

    Args:
        enc_layers: Layer sizes passed to the encoder MLP.
        dec_layers: Layer sizes passed to the decoder MLP. The first entry is
            also used as the encoder's `n_head`.
        call_kwargs: Optional extra keyword arguments. Currently unused by the
            body; kept for interface compatibility. Defaults to an empty dict.

    Returns:
        A tuple `(init, (encoder, decoder, model))`: `init` runs one pass
        through encoder, latent heads, reparameterisation and decoder, and the
        inner tuple holds the modules exposed as apply functions.
    """
    # Avoid a shared mutable default argument.
    if call_kwargs is None:
        call_kwargs = {}

    encoder = MLP(layer_sizes=enc_layers, n_head=dec_layers[0], use_categorical=False, name='encoder')
    # The decoder output width is read from the notebook-level array `x`.
    decoder = MLP(layer_sizes=dec_layers, n_head=x.shape[-1], use_categorical=False, name='decoder')
    model = VAE(encoder=encoder, decoder=decoder, embed_size=HIDDEN_SIZE)

    def init(x: np.ndarray, deterministic: bool):
        """Run encoder -> (mu, logvar) -> reparameterise -> decoder once."""
        h = model.encoder(x)

        mu = model.h2mu(h)
        logvar = model.h2logvar(h)
        z = model.reparameterize(mu, logvar, hk.next_rng_key(), deterministic)

        y = model.decoder(z)

        return y

    return init, (encoder, decoder, model) #model(x, **call_kwargs)

# Bind the layer configuration so Haiku receives a zero-argument builder.
model_fn = partial(
    VAE_fn,
    enc_layers=enc_layers,
    dec_layers=dec_layers,
    call_kwargs={'key': rng},
)
# `multi_transform` yields a single init plus the apply functions returned by VAE_fn.
model_t = hk.multi_transform(model_fn)
# Initialise the parameters from the first two rows of `x`.
params = model_t.init(rng, x[:2], deterministic=False)


  unscaled = jax.random.truncated_normal(
  param = init(shape, dtype)


In [None]:
# Unpack the transformed apply functions, in the order VAE_fn returned them:
# (encoder, decoder, model).
encoder, decoder, model = model_t.apply

In [None]:
# Run the transformed encoder on the whole of `x` with the initialised parameters.
h = encoder(params, rng, x)

In [None]:
# Call the transformed VAE on the whole of `x`; as the last expression of the
# cell its result is displayed.
# NOTE(review): `key=rng` is forwarded to the VAE call in addition to the rng
# Haiku receives as the second positional argument - confirm both are needed.
model(params, rng, x, key=rng)

Array([[ 1.2200505 ,  0.56759536, -0.17559402,  0.38360912,  0.91138595,
         0.9836584 ],
       [ 0.93897206,  0.39088452, -0.07145118,  0.36647606,  1.1440954 ,
         0.79225576],
       [ 0.9203401 ,  0.44798788, -0.15981165,  0.42806238,  0.79410064,
         0.8846526 ],
       ...,
       [ 1.0310407 ,  0.7607218 , -0.15593469,  0.427303  ,  1.2034185 ,
         0.8617574 ],
       [ 0.7840745 ,  0.49800652,  0.03180542,  0.27787185,  0.6302647 ,
         0.5866081 ],
       [ 1.2556453 ,  0.40852782, -0.12452329,  0.28039137,  1.1396322 ,
         0.5093984 ]], dtype=float32)

## Init optimiser

In [None]:
# Build the optimiser from the schedule, warm-up and regularisation settings.
optimiser = make_optimiser(LEARNING_RATE_SCHED, LEARNING_RATE,
                           EPOCHS, L2_REG_ALPHA, USE_WARMUP, WARMUP_EPOCHS, N_BATCHES)
# The optimiser state has to mirror the parameter pytree that gets updated,
# not the input data array `x` that was passed here before.
# NOTE(review): assumes `make_optimiser` returns an optax-style
# GradientTransformation whose `init` takes the parameters - confirm.
optimiser_state = optimiser.init(params)

# Train

In [None]:
# Target layout: [i_batch, xy, Batches, *content]
batched_shape = (N_BATCHES, 1, BATCH_SIZE)
x = x.reshape(*batched_shape, x.shape[-1])
y = y.reshape(*batched_shape, y.shape[-1])

# The leading TRAIN_SPLIT fraction of batches is used for training, the rest
# for validation.
n_train_batches = int(TRAIN_SPLIT * N_BATCHES)
x_train, x_val = x[:n_train_batches], x[n_train_batches:]
y_train, y_val = y[:n_train_batches], y[n_train_batches:]

# Stack inputs and targets along the `xy` axis.
xy_train = np.concatenate((x_train, y_train), axis=1)

In [None]:
# Fit the model; `train` returns the updated parameters and the saved records.
params, saves = train(
    params, rng, model, xy_train, x_val, y_val, optimiser, optimiser_state,
    l2_reg_alpha=L2_REG_ALPHA,
    epochs=EPOCHS,
    loss_fn=loss_wrapper,
    compute_accuracy=compute_accuracy_regression,
    save_every=PRINT_EVERY,
    include_params_in_saves=False,
)

[autoreload of src.losses.losses failed: Traceback (most recent call last):
  File "/home/wadh6511/Kode/env_evo/lib/python3.10/site-packages/IPython/extensions/autoreload.py", line 276, in check
    superreload(m, reload, self.old_objects)
  File "/home/wadh6511/Kode/env_evo/lib/python3.10/site-packages/IPython/extensions/autoreload.py", line 475, in superreload
    module = reload(module)
  File "/usr/lib/python3.10/importlib/__init__.py", line 169, in reload
    _bootstrap._exec(spec, module)
  File "<frozen importlib._bootstrap>", line 619, in _exec
  File "<frozen importlib._bootstrap_external>", line 883, in exec_module
  File "<frozen importlib._bootstrap>", line 241, in _call_with_frames_removed
  File "/home/wadh6511/Kode/EvoScaper/src/losses/losses.py", line 75, in <module>
    params, rng, model: MLP, x: Float[Array, "batch num_interactions"], y: Int[Array, " batch n_head"]
NameError: name 'MLP' is not defined
]


NameError: name 'loss_wrapper' is not defined