In [21]:
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


# Run VAE models systematically

## Imports

In [22]:
import pandas as pd
import itertools
import numpy as np
import os
from evoscaper.scripts.cvae_scan import main as cvae_scan
from evoscaper.utils.preprocess import make_datetime_str
from synbio_morpher.utils.data.data_format_tools.common import write_json
from bioreaction.misc.misc import flatten_listlike

# jupyter nbconvert --to notebook --execute 03_cvae_multi.ipynb --output=03_cvae_multi_2.ipynb --ExecutePreprocessor.timeout=-1


## Create table of all VAE model training settings

Parameters for:
- Biological dataset generation
- Training data
    - Input
    - Output 
- Model architecture
- Training hyperparameters

### Initial parameters

In [23]:
data_dir = '../data'

hpos_architecture = {
    'seed_arch': 1,
    'hidden_size': 16,
    'enc_ls': 32,
    'dec_ls': 32,
    'num_enc_layers': 2,
    'num_dec_layers': 2,
    'factor_expanding_ls': 1,
    'factor_contracting_ls': 1,
    'model': 'CVAE',
    'use_sigmoid_decoder': False,
    'enc_init': 'HeNormal',
    'dec_init': 'HeNormal',
    'init_model_with_random': True,
    'activation': 'leaky_relu',
}

hpos_training = {
    'seed_train': 1,
    'batch_size': 256,
    'epochs': 100,
    'patience': 500,
    'threshold_early_val_acc': 0.995,
    'learning_rate': 1e-2,
    'loss_func': 'mse',
    'accuracy_func': 'accuracy_regression',
    'use_dropout': False,
    'dropout_rate': 0.1,
    'use_l2_reg': False,
    'l2_reg_alpha': 5e-2,
    'use_kl_div': True,
    # inspired by https://github.com/elttaes/VAE-MNIST-Haiku-Jax/blob/main/cVAE_mnist.ipynb
    'kl_weight': 2.5e-4,
    'use_grad_clipping': False,
    'use_contrastive_loss': False,
    'temperature': 1.5,
    'contrastive_func': 'info_nce',
    'threshold_similarity': 0.9,
    'power_factor_distance': 3
}
hpos_training['print_every'] = hpos_training['epochs'] // 50

hpos_optimization = {
    'seed_opt': 1,
    'opt_method': 'adam',
    'opt_min_lr': 1e-6,
    'opt_min_delta': 1e-4,
    'learning_rate_sched': 'cosine_decay',
    'use_warmup': True,
    'warmup_epochs': 20,
}

hpos_dataset = {
    'seed_dataset': 1,
    'include_diffs': False,
    # 'objective_col': ('Log sensitivity', 'Log precision'),
    # 'objective_col': 'adaptation',
    'objective_col': ('Log sensitivity',),
    'output_species': ('RNA_2',),
    'signal_species': ('RNA_0',),
    'filenames_train_config': f'{data_dir}/raw/summarise_simulation/2024_12_05_210221/ensemble_config.json',
    'filenames_train_table': f'{data_dir}/raw/summarise_simulation/2024_12_05_210221/tabulated_mutation_info.csv',
    'filenames_verify_config': f'{data_dir}/raw/summarise_simulation/2024_11_21_160955/ensemble_config.json',
    'filenames_verify_table': f'{data_dir}/raw/summarise_simulation/2024_11_21_160955/tabulated_mutation_info.csv',
    'use_test_data': False,
    # 'total_ds': None,   # TO BE RECORDED
    'total_ds_max': 5e6,
    'train_split': 0.8,
    'x_type': 'energies',
    # XY filtering:
    'filt_x_nans': True,
    'filt_y_nans': True,
    'filt_sensitivity_nans': True,
    'filt_precision_nans': True,
    'filt_n_same_x_max': 1,
    'filt_n_same_x_max_bins': 50,
    'filt_response_time_high': True,
    'filt_response_time_perc_max': 0.8,
    # XY preprocessing:
    'prep_x_standardise': False,
    'prep_y_standardise': False,
    'prep_x_min_max': True,
    'prep_y_min_max': True,
    'prep_x_robust_scaling': True,
    'prep_y_robust_scaling': True,
    'prep_x_logscale': False,
    'prep_y_logscale': False,
    'prep_x_categorical': False,
    'prep_y_categorical': False,
    'prep_x_categorical_onehot': False,
    'prep_y_categorical_onehot': False,
    'prep_x_categorical_n_bins': 10,
    'prep_y_categorical_n_bins': 10,
    'prep_x_categorical_method': 'quantile',
    'prep_y_categorical_method': 'quantile',
    'prep_x_negative': True,
    'prep_y_negative': False
}

hpos_biological = {
    'n_species': 3,
    'sequence_length': 20,
    'signal_function': 'step_function',
    'signal_target': 2,
    'starting_copynumbers_input': 200,
    'starting_copynumbers_output': 200,
    'starting_copynumbers_other': 200,
    'association_binding_rate': 1000000,
    'include_prod_deg': False,
}

hpos_eval = {
    'eval_n_to_sample': 2, #1e3
}

info_to_be_recorded = {
    'filename_saved_model': 'TO_BE_RECORDED',
    'total_ds': 'TO_BE_RECORDED',
    'n_batches': 'TO_BE_RECORDED',
    'R2_train': 'TO_BE_RECORDED',
    'R2_test': 'TO_BE_RECORDED',
    'mutual_information_conditionality': 'TO_BE_RECORDED',
    'n_layers_enc': 'TO_BE_RECORDED',
    'n_layers_dec': 'TO_BE_RECORDED',
    'run_successful': 'TO_BE_RECORDED',
    'info_early_stop': 'TO_BE_RECORDED',
    'error_msg': 'TO_BE_RECORDED',
}

hpos_all = {}
for d in [hpos_architecture, hpos_training, hpos_optimization, hpos_dataset, hpos_eval, info_to_be_recorded]:
    hpos_all.update(d)
    

In [24]:
df_hpos = pd.DataFrame.from_dict(hpos_all, orient='index').T
assert df_hpos.columns.duplicated().sum() == 0, 'Change some column names, there are duplicates'
basic_setting = df_hpos.copy()
df_hpos

Unnamed: 0,seed_arch,hidden_size,enc_ls,dec_ls,num_enc_layers,num_dec_layers,factor_expanding_ls,factor_contracting_ls,model,use_sigmoid_decoder,...,total_ds,n_batches,R2_train,R2_test,mutual_information_conditionality,n_layers_enc,n_layers_dec,run_successful,info_early_stop,error_msg
0,1,16,32,32,2,2,1,1,CVAE,False,...,TO_BE_RECORDED,TO_BE_RECORDED,TO_BE_RECORDED,TO_BE_RECORDED,TO_BE_RECORDED,TO_BE_RECORDED,TO_BE_RECORDED,TO_BE_RECORDED,TO_BE_RECORDED,TO_BE_RECORDED


In [25]:
# for k, v in hpos_all.items():
#     if type(v) == tuple:
#         print(k, v)
#         df_hpos[k] = df_hpos[k].apply(lambda x: tuple(x))
df_hpos['filenames_train_config']        

0    ../data/raw/summarise_simulation/2024_12_05_21...
Name: filenames_train_config, dtype: object

### All parameters

In [26]:
hpos_to_vary_from_og = [{
    'hidden_size': np.arange(2, 32, 2),
    'learning_rate': [1e-1, 1e-2, 1e-3, 1e-4]
}]
hpos_to_vary_together = [{
    'objective_col': [('adaptation',), ('Log sensitivity',), ('Log sensitivity', 'Log precision')],
    'prep_y_categorical': [False, True],
    'use_kl_div': [True],
    'kl_weight': [5e-5, 1e-4, 2.5e-4, 4e-4, 5e-4],
    'threshold_early_val_acc': [0.995, 0.98, 0.96, 0.9],
},
    {
    'use_contrastive_loss': [True],
    'temperature': [0.1, 0.5, 1, 1.5, 2, 4, 8],
    'threshold_similarity': [0.95, 0.9, 0.7, 0.5, 0.3, 0.1],
    'power_factor_distance': [3, 4],
    'threshold_early_val_acc': [0.995, 0.9]
}
]

df_hpos.loc[df_hpos['objective_col'] ==
            'sensitivity_wrt_species-6', 'prep_y_logscale'] = True

In [27]:
def keep_equal(df):
    pairs = {
        'enc_ls': 'dec_ls',
        'num_enc_layers': 'num_dec_layers',
        'factor_expanding_ls': 'factor_contracting_ls',
    }
    for k1, k2 in pairs.items():
        if k1 in df.columns and k2 in df.columns:
            df[k2] = df[k1]
    return df


def add_combinatorial_keys(df_hpos, hpos_to_vary_together, basic_setting):
    keys_vary_together = sorted(hpos_to_vary_together.keys())
    for i, v in enumerate(itertools.product(*[hpos_to_vary_together[h] for h in keys_vary_together])):
        curr = basic_setting.assign(
            **{h: [vv] if type(vv) == tuple else vv for h, vv in zip(keys_vary_together, v)})
        df_hpos = pd.concat([df_hpos, curr], ignore_index=True)
    return df_hpos


def add_single_hpos(df_hpos, hpos_to_vary_from_og, basic_setting):
    for h, v in hpos_to_vary_from_og.items():
        try:
            df_hpos = pd.concat(
                [df_hpos] + [basic_setting.assign(**{h: vv}) for vv in v], ignore_index=True)
        except ValueError:
            for vv in v:
                b = basic_setting.copy()
                b.loc[0, h] = vv
                df_hpos = pd.concat([df_hpos, b], ignore_index=True)
    return df_hpos


def postproc(df_hpos):
    df_hpos = keep_equal(df_hpos)
    df_hpos.loc[df_hpos['x_type'] ==
                'binding_rates_dissociation', 'prep_x_negative'] = False
    df_hpos = df_hpos.drop_duplicates().reset_index(drop=True)
    return df_hpos


for h in hpos_to_vary_from_og:
    df_hpos = add_single_hpos(df_hpos, h, basic_setting)
for h in hpos_to_vary_together:
    df_hpos = add_combinatorial_keys(df_hpos, h, basic_setting)

df_hpos = postproc(df_hpos)

# Reorder columns
cols_priority = list(set(flatten_listlike([list(h.keys(
)) for h in hpos_to_vary_from_og] + [list(h.keys()) for h in hpos_to_vary_together])))
df_hpos = df_hpos[cols_priority +
                  [c for c in df_hpos.columns if c not in cols_priority]]

df_hpos.reset_index(drop=True)
len(df_hpos)

305

# Use table to create dataset for training

In [28]:
# fn = '../data/raw/summarise_simulation/2024_11_21_144918/tabulated_mutation_info.csv'
# # fn = '../data/raw/summarise_simulation/2024_11_21_160955/tabulated_mutation_info.csv'
# # fn = '../data/raw/summarise_simulation/2024_12_05_210221/tabulated_mutation_info.csv'
# data = pd.read_csv(fn).iloc[:100]
# len(data)

In [29]:
df_hpos_main = df_hpos.iloc[111:]
i = 0
df_hpos_main.iloc[i]

use_contrastive_loss             False
learning_rate                     0.01
temperature                        1.5
use_kl_div                        True
kl_weight                       0.0004
                             ...      
n_layers_enc            TO_BE_RECORDED
n_layers_dec            TO_BE_RECORDED
run_successful          TO_BE_RECORDED
info_early_stop         TO_BE_RECORDED
error_msg               TO_BE_RECORDED
Name: 111, Length: 93, dtype: object

In [None]:
from bioreaction.misc.misc import load_json_as_dict

top_dir = os.path.join('data', '03_cvae_multi', make_datetime_str())
os.makedirs(top_dir, exist_ok=True)

# df_hpos2 = pd.DataFrame(load_json_as_dict(
#     'data/03_cvae_multi/2025_01_15__10_59_22/df_hpos_main.json'))
# df_hpos_main = df_hpos.iloc[df_hpos2[df_hpos2['run_successful'].isin(
#     [False, 'TO_BE_RECORDED'])].index]
df_hpos_main = df_hpos.iloc[111:]
for i in range(len(df_hpos_main)):
    hpos = df_hpos_main.reset_index().iloc[i]
    top_write_dir = os.path.join(top_dir, 'cvae_scan', f'hpo_{hpos["index"]}')
    # hpos['use_grad_clipping'] = True
    hpos = cvae_scan(hpos, top_write_dir=top_write_dir)
    # try:
    #     try:
    #         hpos = cvae_scan(hpos, top_write_dir=top_write_dir)
    #         hpos.loc['run_successful'] = True
    #         hpos.loc['error_msg'] = ''
    #     except Exception as e:
    #         print(e)
    #         if 'nan' in e.lower() and (hpos['use_grad_clipping'] == False):
    #             try:
    #                 hpos['use_grad_clipping'] = True
    #                 hpos = cvae_scan(hpos, top_write_dir=top_write_dir)
    #                 hpos.loc['run_successful'] = True
    #                 hpos.loc['error_msg'] = ''
    #             except Exception as e:
    #                 print(e)
    #                 hpos.loc['run_successful'] = False
    #                 hpos.loc['error_msg'] = str(e)
    #         else:
    #             hpos.loc['run_successful'] = False
    #             hpos.loc['error_msg'] = str(e)
    # except:
    #     hpos.loc['run_successful'] = False
    #     hpos.loc['error_msg'] = 'sys exit'

    df_hpos_main.iloc[i] = pd.Series(hpos) if type(
        hpos) == dict else hpos.drop('index')
    # df_hpos_main.loc[i] = pd.DataFrame.from_dict(hpos).drop('index')
    if not os.path.exists(top_dir):
        os.makedirs(top_dir)
    df_hpos_main.to_csv(os.path.join(top_dir, 'df_hpos_main.csv'))
    write_json(df_hpos_main.to_dict(), os.path.join(
        top_dir, 'df_hpos_main.json'), overwrite=True)

train.py:train():187: Epoch 0 / 100 -		 Train loss: 0.13625994324684143	Val loss: 0.08772104233503342	Val accuracy: 0.33725741505622864 INFO
train.py:train():187: Epoch 10 / 100 -		 Train loss: 0.006428841967135668	Val loss: 0.006101085804402828	Val accuracy: 0.9633186459541321 INFO
train.py:train():197: Early stopping triggered after 11 epochs:
Train loss: 0.006428841967135668
Val loss: 0.006101085804402828
Val accuracy: 0.9633186459541321


Training complete: 0:00:09.269009
