In [1]:
%load_ext autoreload
%autoreload 2

# Run VAE models systematically

## Imports

In [2]:
import pandas as pd
import numpy as np
import itertools


## Create table of all VAE model training settings

Parameters for:
- Biological dataset generation
- Training data
    - Input
    - Output 
- Model architecture
- Training hyperparameters

In [3]:
'USE_SIGMOID_DECODER'.lower()

'use_sigmoid_decoder'

### Initial parameters

In [4]:
['PREP_X_CATEGORICAL_ONEHOT'.lower(),
'PREP_Y_CATEGORICAL_ONEHOT'.lower(),
'PREP_X_CATEGORICAL_NBINS'.lower(),
'PREP_Y_CATEGORICAL_NBINS'.lower()]

['prep_x_categorical_onehot',
 'prep_y_categorical_onehot',
 'prep_x_categorical_nbins',
 'prep_y_categorical_nbins']

In [5]:
hpos_architecture = {
    'seed_arch': 1,
    'hidden_size': 32,
    'enc_layers': [64, 64, 64],
    'dec_layers': [64, 64, 64],
    'model': 'CVAE',
    'use_sigmoid_decoder': False,
    'init_enc': 'HeNormal',
    'init_dec': 'HeNormal',
    'init_model_with_random': True,
    'activation': 'leaky_relu',
}


hpos_training = {
    'batch_size': 128,
    'epochs': 2000,
    'learning_rate': 1e-1,
    'learning_rate_sched': 'cosine_decay',
    'loss_func': 'mse_loss',
    'use_dropout': False,
    'dropout_rate': 0.1,
    'use_l2_reg': False,
    'l2_reg_alpha': 0.01,
    'use_kl_div': True,
    'kl_weight': 0.00025,  # inspired by https://github.com/elttaes/VAE-MNIST-Haiku-Jax/blob/main/cVAE_mnist.ipynb
    'use_warmup': True,
    'warmup_epochs': 20
}
hpos_training['print_every'] = hpos_training['epochs'] // 100

hpos_optimization = {
    'seed_opt': 1,
    'opt_method': 'adam',
    'opt_metric': 'mean_absolute_error',
    'opt_mode': 'min',
    'opt_patience': 100,
    'opt_factor': 0.5,
    'opt_min_lr': 1e-6,
    'opt_min_delta': 1e-4
}

hpos_dataset = {
    'seed_dataset': 1,
    'include_diffs': False,
    'objective_col': 'adaptability',
    'output_species': ['RNA_2'],
    # 'total_ds': None,   # TO BE RECORDED
    'total_ds_max': 3e6,
    'train_split': 0.8,
    'x_type': 'energies',
    # XY filtering:
    'filt_x_nans': True,
    'filt_y_nans': True,
    'filt_sensitivity_nans': True,
    'filt_precision_nans': True,
    'filt_n_same_x_max': 100,
    'filt_n_same_x_max_bins': 500,
    # XY preprocessing:
    'prep_x_standardise': False,
    'prep_y_standardise': False,
    'prep_x_min_max': False,
    'prep_y_min_max': False,
    'prep_x_robust_scaling': True,
    'prep_y_robust_scaling': True,
    'prep_x_log': False,
    'prep_y_log': False,
    'prep_x_categorical': False,
    'prep_y_categorical': False,
    'prep_x_categorical_onehot': False,
    'prep_y_categorical_onehot': False,
    'prep_x_categorical_n_bins': 10,
    'prep_y_categorical_n_bins': 10,
    'prep_x_categorical_method': 'quantile',
    'prep_y_categorical_method': 'quantile',
    'prep_x_negative': False,
    'prep_y_negative': False,
}

hpos_biological = {
    'filenames_train_config': ['EvoScaper/data/raw/summarise_simulation/2024_12_05_210221/ensemble_config.json'], 
    'filenames_train_table': ['EvoScaper/data/raw/summarise_simulation/2024_12_05_210221/tabulated_mutation_info.csv'],
    'filenames_verify_config': ['EvoScaper/data/raw/summarise_simulation/2024_11_21_160955/ensemble_config.json'], 
    'filenames_verify_table': ['EvoScaper/data/raw/summarise_simulation/2024_11_21_160955/tabulated_mutation_info.csv'],
    'n_species': 3,
    'sequence_length': 20,
    'signal_function': 'step_function',
    'signal_target': 2,
    'starting_copynumbers_input': [200],
    'starting_copynumbers_output': [200],
    'starting_copynumbers_other': [200],
    'association_binding_rate': 1000000,
    'include_prod_deg': False,
}

info_to_be_recorded = {
    'filename_saved_model': 'TO_BE_RECORDED',
    'total_ds': 'TO_BE_RECORDED',
    'n_batches': 'TO_BE_RECORDED',
    'R2_train': 'TO_BE_RECORDED',
    'R2_test': 'TO_BE_RECORDED',
    'conditionality_fidelity': 'TO_BE_RECORDED',
    'n_layers_enc': 'TO_BE_RECORDED',
    'n_layers_dec': 'TO_BE_RECORDED',
}



In [6]:
df_hpos = pd.concat([pd.DataFrame.from_dict(hpos, orient='index').T for hpos in [hpos_architecture, hpos_training, hpos_optimization, hpos_dataset]], axis=1)
assert df_hpos.columns.duplicated().sum() == 0, 'Change some column names, there are duplicates'
basic_setting = df_hpos.copy()
df_hpos

Unnamed: 0,seed_arch,hidden_size,enc_layers,dec_layers,model,use_sigmoid_decoder,init_enc,init_dec,init_model_with_random,activation,...,prep_x_categorical,prep_y_categorical,prep_x_categorical_onehot,prep_y_categorical_onehot,prep_x_categorical_n_bins,prep_y_categorical_n_bins,prep_x_categorical_method,prep_y_categorical_method,prep_x_negative,prep_y_negative
0,1,32,"[64, 64, 64]","[64, 64, 64]",CVAE,False,HeNormal,HeNormal,True,leaky_relu,...,False,False,False,False,10,10,quantile,quantile,False,False


### All parameters

In [7]:

hpos_to_vary_from_og = {
    'total_ds_max': [1e4, 5e4, 1e5, 5e5, 1e6, 5e6],
    'seed_arch': [1, 2, 3, 4, 5],
}
hpos_to_vary_together = {
    'hidden_size': [16, 32, 64, 128, 256, 512],
    'objective_col': ['adaptability', 'sensitivity_wrt_species-6'],
    'x_type': ['energies', 'binding_rates_dissociation'],
    'learning_rate': [1e-2, 1e-3, 1e-4],
}

df_hpos.loc[df_hpos['objective_col'] == 'sensitivity_wrt_species-6', 'prep_y_logscale'] = True

In [8]:
for h, v in hpos_to_vary_from_og.items():
    df_hpos = pd.concat([df_hpos] + [basic_setting.assign(**{h: vv}) for vv in v], ignore_index=True)
df_hpos

Unnamed: 0,seed_arch,hidden_size,enc_layers,dec_layers,model,use_sigmoid_decoder,init_enc,init_dec,init_model_with_random,activation,...,prep_y_categorical,prep_x_categorical_onehot,prep_y_categorical_onehot,prep_x_categorical_n_bins,prep_y_categorical_n_bins,prep_x_categorical_method,prep_y_categorical_method,prep_x_negative,prep_y_negative,prep_y_logscale
0,1,32,"[64, 64, 64]","[64, 64, 64]",CVAE,False,HeNormal,HeNormal,True,leaky_relu,...,False,False,False,10,10,quantile,quantile,False,False,
1,1,32,"[64, 64, 64]","[64, 64, 64]",CVAE,False,HeNormal,HeNormal,True,leaky_relu,...,False,False,False,10,10,quantile,quantile,False,False,
2,1,32,"[64, 64, 64]","[64, 64, 64]",CVAE,False,HeNormal,HeNormal,True,leaky_relu,...,False,False,False,10,10,quantile,quantile,False,False,
3,1,32,"[64, 64, 64]","[64, 64, 64]",CVAE,False,HeNormal,HeNormal,True,leaky_relu,...,False,False,False,10,10,quantile,quantile,False,False,
4,1,32,"[64, 64, 64]","[64, 64, 64]",CVAE,False,HeNormal,HeNormal,True,leaky_relu,...,False,False,False,10,10,quantile,quantile,False,False,
5,1,32,"[64, 64, 64]","[64, 64, 64]",CVAE,False,HeNormal,HeNormal,True,leaky_relu,...,False,False,False,10,10,quantile,quantile,False,False,
6,1,32,"[64, 64, 64]","[64, 64, 64]",CVAE,False,HeNormal,HeNormal,True,leaky_relu,...,False,False,False,10,10,quantile,quantile,False,False,
7,1,32,"[64, 64, 64]","[64, 64, 64]",CVAE,False,HeNormal,HeNormal,True,leaky_relu,...,False,False,False,10,10,quantile,quantile,False,False,
8,2,32,"[64, 64, 64]","[64, 64, 64]",CVAE,False,HeNormal,HeNormal,True,leaky_relu,...,False,False,False,10,10,quantile,quantile,False,False,
9,3,32,"[64, 64, 64]","[64, 64, 64]",CVAE,False,HeNormal,HeNormal,True,leaky_relu,...,False,False,False,10,10,quantile,quantile,False,False,


In [9]:
keys_vary_together = sorted(hpos_to_vary_together.keys())
for v in itertools.product(*[hpos_to_vary_together[h] for h in keys_vary_together]):
    curr = basic_setting.assign(**{h: vv for h, vv in zip(keys_vary_together, v)})
    df_hpos = pd.concat([df_hpos, curr], ignore_index=True)
print('All good if these are equal: ', len(df_hpos), len(list(itertools.product(*[hpos_to_vary_together[h] for h in keys_vary_together]))) + np.sum([len(v) for v in hpos_to_vary_from_og.values()]) + 1)

All good if these are equal:  84 84


# Use table to create dataset for training

In [None]:
hpos = df_hpos.iloc[0].to_dict()

fn = '../data/raw/summarise_simulation/2024_11_21_144918/tabulated_mutation_info.csv'
# fn = '../data/raw/summarise_simulation/2024_11_21_160955/tabulated_mutation_info.csv'
# fn = '../data/raw/summarise_simulation/2024_12_05_210221/tabulated_mutation_info.csv'
data = pd.read_csv(fn).iloc[:100]
len(data)

177619

In [11]:
from evoscaper.utils.preprocess import make_xcols
from common import init_data
from evoscaper.utils.tuning import make_settings


X_COLS = make_xcols(data, hpos['x_type'], hpos['include_diffs'])
x_norm_settings, y_norm_settings, filter_settings = make_settings(
    hpos, keys_dataset=sorted(hpos_dataset.keys()))
df, x, cond, TOTAL_DS, N_BATCHES, x_datanormaliser, x_methods_preprocessing, y_datanormaliser, y_methods_preprocessing = init_data(
    data, X_COLS, hpos['objective_col'],
    hpos['output_species'], hpos['total_ds_max'], hpos['batch_size'], hpos['seed_dataset'],
    x_norm_settings, y_norm_settings, filter_settings)

  df.loc[:, OBJECTIVE_COL] = df[OBJECTIVE_COL].apply(


# Set up model

In [None]:
model_fn = partial(VAE_fn, enc_layers=enc_layers, dec_layers=dec_layers, decoder_head=x.shape[-1], 
                   HIDDEN_SIZE=HIDDEN_SIZE, decoder_activation_final=jax.nn.sigmoid if USE_SIGMOID_DECODER else jax.nn.leaky_relu, 
                   enc_init=ENC_INIT, dec_init=DEC_INIT, activation=get_activation_fn(ACTIVATION))
model_t = hk.multi_transform(model_fn)
dummy_x = jax.random.normal(PRNG, x.shape)
dummy_cond = jax.random.normal(PRNG, cond.shape)
params = model_t.init(PRNG, dummy_x, dummy_cond, deterministic=False)
