In [1]:
%load_ext autoreload
%autoreload 2

# Conditional VAE for genetic circuits

This notebook follows the previous VAE notebook very closely, but implementing a conditional VAE instead. Loosely following [this blog post](https://agustinus.kristia.de/techblog/2016/12/17/conditional-vae/).

## Imports 

In [2]:
# %env XLA_PYTHON_CLIENT_ALLOCATOR=platform

from synbio_morpher.utils.data.data_format_tools.common import load_json_as_dict
from synbio_morpher.utils.results.analytics.naming import get_true_interaction_cols
from synbio_morpher.utils.data.data_format_tools.common import write_json
from synbio_morpher.utils.misc.string_handling import prettify_keys_for_label
from functools import partial

from sklearn.metrics import r2_score  
import os
import sys
import numpy as np
import haiku as hk
import jax
import optax

from sklearn.preprocessing import MinMaxScaler
from sklearn.utils import shuffle
                
import wandb

from datetime import datetime
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt

jax.config.update('jax_platform_name', 'gpu')


# if __package__ is None:

module_path = os.path.abspath(os.path.join('..'))
sys.path.append(module_path)

__package__ = os.path.basename(module_path)


jax.devices()

I0000 00:00:1702549660.546714  476195 tfrt_cpu_pjrt_client.cc:349] TfrtCpuClient created.


[gpu(id=0)]

In [3]:
from src.models.vae import CVAE
from src.models.mlp import MLP
from src.models.shared import arrayise
from src.losses.losses import loss_wrapper, compute_accuracy_regression, mse_loss, accuracy_regression
from src.utils.data_preprocessing import drop_duplicates_keep_first_n
from src.utils.optimiser import make_optimiser
from src.utils.train import make_saves



## Load data

In [4]:
fn = '../data/processed/ensemble_mutation_effect_analysis/2023_07_17_105328/tabulated_mutation_info.csv'
fn_test_data = '../data/raw/ensemble_mutation_effect_analysis/2023_10_03_204819/tabulated_mutation_info.csv'
data = pd.read_csv(fn)
try:
    data.drop(columns=['Unnamed: 0'], inplace=True)
except:
    pass

## Hyperparameters

In [5]:
# Architecture
HIDDEN_SIZE = 32
NUM_ENC_LAYERS = 3
NUM_DEC_LAYERS = 3


BATCH_SIZE = 128
N_BATCHES = 1200
TOTAL_DS = BATCH_SIZE * N_BATCHES
MAX_TOTAL_DS = TOTAL_DS
TRAIN_SPLIT = 0.8
SCALE_X = False
LEARNING_RATE = 5e-4
LEARNING_RATE_SCHED = 'cosine_decay'
# LEARNING_RATE_SCHED = 'constant'
WARMUP_EPOCHS = 20
L2_REG_ALPHA = 0.01
EPOCHS = 5
PRINT_EVERY = EPOCHS // 1000
SEED = 1
USE_CATEGORICAL = False
target_circ_func = 'sensitivity_wrt_species-6'
input_concat_diffs = False
input_concat_axis = 0

# Training
USE_DROPOUT = False
USE_L2_REG = False
USE_WARMUP = False

loss_fn = partial(
    loss_wrapper, loss_f=mse_loss, use_l2_reg=USE_L2_REG) 
compute_accuracy = compute_accuracy_regression

subtask = '_test'
save_path = str(datetime.now()).split(' ')[0].replace(
    '-', '_') + '__' + str(datetime.now()).split(' ')[-1].split('.')[0].replace(':', '_') + '_saves' + subtask
save_path = os.path.join('weight_saves', '09_vae', save_path)

rng = jax.random.PRNGKey(SEED)

# Initialise

## Init data

In [6]:
filt = data['sample_name'] == data['sample_name'].unique()[0]

# Balance the dataset
df = drop_duplicates_keep_first_n(data[filt], get_true_interaction_cols(
    data, 'energies', remove_symmetrical=True), n=100)

TOTAL_DS = np.min([TOTAL_DS, MAX_TOTAL_DS, len(df)])
N_BATCHES = TOTAL_DS // BATCH_SIZE
TOTAL_DS = N_BATCHES * BATCH_SIZE

x_cols = [get_true_interaction_cols(data, 'energies', remove_symmetrical=True)]
if input_concat_diffs:
    k = 'energies'
    x_cols = x_cols + \
        [[f'{i}_diffs' for i in get_true_interaction_cols(
            data, k, remove_symmetrical=True)]]

x = [df[i].iloc[:TOTAL_DS].values[:, :, None] for i in x_cols]
x = np.concatenate(x, axis=input_concat_axis+1).squeeze()

if SCALE_X:
    xscaler = MinMaxScaler()
    x = xscaler.fit_transform(x)

In [7]:
cond = df[target_circ_func].iloc[:TOTAL_DS].to_numpy()

if USE_CATEGORICAL:
    from src.utils.math import convert_to_scientific_exponent

    vectorized_convert_to_scientific_exponent = np.vectorize(
        convert_to_scientific_exponent)
    numerical_resolution = 2
    cond_map = {k: numerical_resolution for k in np.arange(int(f'{cond[cond != 0].min():.0e}'.split(
        'e')[1])-1, np.max([int(f'{cond.max():.0e}'.split('e')[1])+1, 0 + 1]))}
    cond_map[-6] = 1
    cond_map[-5] = 1
    cond_map[-4] = 4
    cond_map[-3] = 2
    cond_map[-1] = 3
    cond = jax.tree_util.tree_map(partial(
        vectorized_convert_to_scientific_exponent, numerical_resolution=cond_map), cond)
    cond = np.interp(cond, sorted(np.unique(cond)), np.arange(
        len(sorted(np.unique(cond))))).astype(int)
else:
    zero_log_replacement = -10.0
    cond = np.where(cond != 0, np.log10(cond), zero_log_replacement)

cond = cond[:, None]
N_HEAD = x.shape[-1]

x, cond = shuffle(x, cond, random_state=SEED)

if x.shape[0] < TOTAL_DS:
    print(
        f'WARNING: The filtered data is not as large as the requested total dataset size: {x.shape[0]} vs. requested {TOTAL_DS}')

In [8]:
np.concatenate([x, cond], axis=1).shape

(153600, 7)

## Init model

In [9]:
enc_layers = [64] * NUM_ENC_LAYERS
dec_layers = [64] * NUM_DEC_LAYERS

def VAE_fn(enc_layers: list, dec_layers: list, call_kwargs: dict = {}):
    encoder = MLP(layer_sizes=enc_layers, n_head=dec_layers[0], use_categorical=False, name='encoder')
    decoder = MLP(layer_sizes=dec_layers, n_head=x.shape[-1], use_categorical=False, name='decoder')
    model = CVAE(encoder=encoder, decoder=decoder, embed_size=HIDDEN_SIZE)
    
    def init(x: np.ndarray, cond: np.ndarray, deterministic: bool):
        h = model.encoder(np.concatenate([x, cond], axis=-1))

        mu = model.h2mu(h)
        logvar = model.h2logvar(h)
        z = model.reparameterize(mu, logvar, hk.next_rng_key(), deterministic)
        z_cond = np.concatenate([z, cond], axis=-1)

        y = model.decoder(z_cond)
        return y
        
    return init, (encoder, decoder, model) #model(x, **call_kwargs)

model_fn = partial(VAE_fn, enc_layers=enc_layers, dec_layers=dec_layers, call_kwargs={'key': rng})
# model = hk.transform(model_fn)
model_t = hk.multi_transform(model_fn)
params = model_t.init(rng, x, cond, deterministic=False)


  unscaled = jax.random.truncated_normal(
  param = init(shape, dtype)


In [10]:
encoder, decoder, model = model_t.apply

In [11]:
h = encoder(params, rng, np.concatenate([x, cond], axis=-1))

In [12]:
model(params, rng, x, cond)

Array([[ 0.82784826,  0.4404419 ,  0.12714294,  0.22062692,  1.0385995 ,
         0.64580923],
       [ 1.2034786 ,  0.40335447, -0.02307892,  0.4063047 ,  1.0215791 ,
         0.7701513 ],
       [ 1.2043823 ,  0.52515024,  0.05376959,  0.4019768 ,  0.81173486,
         0.94905376],
       ...,
       [ 0.86559814,  0.39743975,  0.2708708 ,  0.36654013,  0.9355792 ,
         0.8020615 ],
       [ 1.1065849 ,  0.11027247,  0.09968159,  0.5402143 ,  0.9874739 ,
         0.98084295],
       [ 1.1596351 ,  0.5367333 ,  0.05196697,  0.2566617 ,  0.9819162 ,
         0.93189126]], dtype=float32)

## Init optimiser

In [13]:
optimiser = make_optimiser(LEARNING_RATE_SCHED, LEARNING_RATE,
                           EPOCHS, L2_REG_ALPHA, USE_WARMUP, WARMUP_EPOCHS, N_BATCHES)
optimiser_state = optimiser.init(x)

# Train

In [14]:
# [i_batch, Batches, *content]

x = x.reshape(N_BATCHES, BATCH_SIZE, x.shape[-1])
cond = cond.reshape(N_BATCHES, BATCH_SIZE, cond.shape[-1])
y = x.reshape(N_BATCHES, BATCH_SIZE, x.shape[-1])

x_train, cond_train, y_train = x[:int(TRAIN_SPLIT * N_BATCHES)], cond[:int(TRAIN_SPLIT * N_BATCHES)], y[:int(TRAIN_SPLIT * N_BATCHES)]
x_val, cond_val, y_val = x[int(TRAIN_SPLIT * N_BATCHES):], cond[int(TRAIN_SPLIT * N_BATCHES):], y[int(TRAIN_SPLIT * N_BATCHES):]

In [15]:
def train_step(params, x, y, cond, optimiser_state, model, rng, use_l2_reg, l2_reg_alpha, optimiser, loss_fn):

    pred_y = model(params, rng, x, cond)
    loss, grads = jax.value_and_grad(loss_fn)(
        params, pred_y, y, use_l2_reg=use_l2_reg, l2_reg_alpha=l2_reg_alpha)

    updates, optimiser_state = optimiser.update(grads, optimiser_state)
    params = optax.apply_updates(params, updates)

    return params, optimiser_state, loss, grads


def eval_step(params, rng, model, x, y, cond, use_l2_reg, l2_reg_alpha, loss_fn, compute_accuracy):
    """ Return the average of loss and accuracy on validation data """
    pred_y = model(params, rng, x, cond)
    loss = loss_fn(params, pred_y, y, use_l2_reg=use_l2_reg, l2_reg_alpha=l2_reg_alpha)
    acc = compute_accuracy(pred_y, y)
    return acc, loss


def run_batches(params, model, rng,
                x_batch, y_batch, cond_batch,
                use_l2_reg, l2_reg_alpha, optimiser, optimiser_state, loss_fn):

    f_train_step = partial(train_step, model=model, rng=rng,
                           use_l2_reg=use_l2_reg, l2_reg_alpha=l2_reg_alpha, optimiser=optimiser,
                           loss_fn=loss_fn)

    def f(carry, inp):

        params, optimiser_state = carry[0], carry[1]
        x_batch, y_batch, cond_batch = inp[0], inp[1], inp[2]

        params, optimiser_state, loss, grads = f_train_step(
            params, x_batch, y_batch, cond_batch, optimiser_state)
        return (params, optimiser_state), (loss, grads)

    # for x_batch, y_batch in xy_train:
    (params, optimiser_state), (train_loss, grads) = jax.lax.scan(
        f, (params, optimiser_state), (x_batch, y_batch, cond_batch))
    return params, optimiser_state, train_loss, grads


def train(params, rng, model,
          x_train, cond_train, y_train, x_val, cond_val, y_val,
          optimiser, optimiser_state,
          use_l2_reg, l2_reg_alpha, epochs,
          loss_fn, compute_accuracy,
          save_every, include_params_in_saves):

    def f(carry, _):
        params, optimiser_state = carry[0], carry[1]

        params, optimiser_state, train_loss, grads = run_batches(
            params, model, rng, x_train, y_train, cond_train, use_l2_reg, l2_reg_alpha, optimiser, optimiser_state, loss_fn)

        val_acc, val_loss = eval_step(
            params, rng, model, x_val, y_val, cond_val, use_l2_reg, l2_reg_alpha, loss_fn, compute_accuracy)

        return (params, optimiser_state), (params, grads, train_loss, val_loss, val_acc)

    saves = {}
    for e in range(epochs):
        (params, optimiser_state), (params_stack, grads, train_loss,
                                    val_loss, val_acc) = f((params, optimiser_state), None)

        if np.mod(e, save_every) == 0:
            saves[e] = make_saves(
                train_loss, val_loss, val_acc, include_params_in_saves, params_stack, grads)
            print(
                f'Epoch {e} / {epochs} -\t\t Train loss: {np.mean(train_loss)}\tVal loss: {val_loss}\tVal accuracy: {val_acc}')

    saves[list(saves.keys())[-1]]['params'] = params
    return params, saves

In [17]:
params, saves = train(params, rng, model, 
                      x_train, cond_train, y_train, x_val, cond_val, y_val, 
                      optimiser, optimiser_state,
                      use_l2_reg=USE_L2_REG, l2_reg_alpha=L2_REG_ALPHA, epochs=EPOCHS,
                      loss_fn=loss_fn, compute_accuracy=accuracy_regression,
                      save_every=PRINT_EVERY, include_params_in_saves=False) 

  if np.mod(e, save_every) == 0:


Epoch 0 / 30 -		 Train loss: 19.778484344482422	Val loss: 19.715730667114258	Val accuracy: 0.07190755754709244
Epoch 1 / 30 -		 Train loss: 19.778484344482422	Val loss: 19.715730667114258	Val accuracy: 0.07190755754709244
Epoch 2 / 30 -		 Train loss: 19.778484344482422	Val loss: 19.715730667114258	Val accuracy: 0.07190755754709244
Epoch 3 / 30 -		 Train loss: 19.778484344482422	Val loss: 19.715730667114258	Val accuracy: 0.07190755754709244
Epoch 4 / 30 -		 Train loss: 19.778484344482422	Val loss: 19.715730667114258	Val accuracy: 0.07190755754709244
Epoch 5 / 30 -		 Train loss: 19.778484344482422	Val loss: 19.715730667114258	Val accuracy: 0.07190755754709244
Epoch 6 / 30 -		 Train loss: 19.778484344482422	Val loss: 19.715730667114258	Val accuracy: 0.07190755754709244
Epoch 7 / 30 -		 Train loss: 19.778484344482422	Val loss: 19.715730667114258	Val accuracy: 0.07190755754709244
Epoch 8 / 30 -		 Train loss: 19.778484344482422	Val loss: 19.715730667114258	Val accuracy: 0.07190755754709244
E