In [1]:
%load_ext autoreload
%autoreload 2

import warnings
warnings.filterwarnings("ignore")

import pandas as pd
import numpy as np
import jax.random as random
import pickle

## Set this to disable JAX from preallocating memory
import os
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"

from datasets import *
from training_setup import *
from training import *
from utils import *
from analysis_utils import *

In [2]:
# Input CSV, saved artefacts and pretrained checkpoints all default to the
# current working directory; point these elsewhere as needed.
DATA_PATH = SAVE_PATH = MODEL_PATH = "./"

In [3]:
# Load the normalized EMT expression table; the "time" column holds each
# cell's sampling time and every other column is a feature.
df = pd.read_csv(f"{DATA_PATH}/A549_EMT_heldout_1d_tp_normalized.csv", index_col=0)

# Re-index the raw time values to consecutive integers 0..T-1 in sorted order.
raw_times_sorted = sorted(df["time"].unique())
total_timepoints = len(raw_times_sorted)
timepoint_map = dict(enumerate(raw_times_sorted))                 # index -> raw time
rev_timepoint_map = {raw: idx for idx, raw in timepoint_map.items()}  # raw time -> index
df["time"] = df["time"].map(rev_timepoint_map)

# Time-point indices used for fitting vs. held out for validation.
train_tps_idx = [0, 1, 3, 4]
val_tps_idx = [2]

train_tps = train_tps_idx
val_tps = val_tps_idx
assert not set(train_tps) & set(val_tps), "train/val time points overlap"

timepoints_sorted = sorted(df["time"].unique().tolist())

# Boolean-mask selection already returns a new frame.
train_data = df[df["time"].isin(train_tps)]
val_data = df[df["time"].isin(val_tps)]

z_score = True

if z_score:
    # NOTE(review): train and val are normalized independently. If
    # z_score_norm computes per-feature statistics, val is scaled with its own
    # stats rather than the training stats — confirm this is intended.
    train_data = z_score_norm(train_data)
    val_data = z_score_norm(val_data)

train_data.shape, val_data.shape, train_data["time"].unique(), val_data["time"].unique()

((2345, 2001), (788, 2001), array([3, 4, 1, 0]), array([2]))

In [4]:
# Show which remapped time indices are present in the full table.
df.loc[:, "time"].unique()

array([3, 4, 1, 0, 2])

In [20]:
# Tuned hyperparameters.
# Set `train = False` to skip training and restore saved parameters from MODEL_PATH.
train = True

# VAE
vae_latent_dim = 50
vae_enc_hidden_dim = vae_dec_hidden_dim = [256, 128]
# assumes the last column of train_data is "time" (the only non-feature column)
vae_input_dim = train_data.iloc[:, :-1].shape[-1]
vae_lr = 1e-4
vae_epochs = 100
vae_batch_size = 64
vae_t_dim = 8  # NOTE(review): not passed to anything in this cell

# SDE / drift networks
t_dim = 16
lr = 1e-4
dec_hidden_dim = 200
td_sched = 2
epochs = 10
num_sde = 10
paths_reuse = 5
steps_num = 100
batch_size = 64

# Ferryman network
ferryman_lr = 1e-4
ferryman_hidden_dim = 64

# Dataset / loss settings that were previously inline magic numbers
death_importance_rate = 100
f_val = 2
SPLITTING_BIRTHS_FRAC = 0.9
REALITY_COEFFICIENT = 0.2
FERRYMAN_COEFF = 1
PRNG_SEED = 0

# The dataset, setup and trainer are built identically whether we train from
# scratch or restore saved parameters, so construct them once.
train_dataset = Input_Dataset(x=train_data, meta=None, meta_celltype_column=None,
                              splitting_births_frac=SPLITTING_BIRTHS_FRAC, steps_num=steps_num,
                              val_split=False, death_importance_rate=death_importance_rate,
                              f_val=f_val)

ts = Training_Setup(dataset=train_dataset, dataset_name="EMT",
                    hidden_dim=[dec_hidden_dim, dec_hidden_dim],
                    dec_hidden_size=[dec_hidden_dim, dec_hidden_dim, 1],
                    vae_epochs=vae_epochs, epochs=epochs, num_sde=num_sde,
                    paths_reuse=paths_reuse, reality_coefficient=REALITY_COEFFICIENT,
                    ipf_mask_dead=True, t_dim=t_dim, batch_size=batch_size,
                    vae_batch_size=vae_batch_size, vae_input_dim=vae_input_dim,
                    vae_enc_hidden_dim=vae_enc_hidden_dim, vae_dec_hidden_dim=vae_dec_hidden_dim,
                    vae_latent_dim=vae_latent_dim,
                    ferryman_hidden_dim=[ferryman_hidden_dim, ferryman_hidden_dim])

tr = Trainer(dataset=train_dataset, ts=ts, key=random.PRNGKey(PRNG_SEED), lr=lr,
             vae_lr=vae_lr, ferryman_lr=ferryman_lr, ferryman_coeff=FERRYMAN_COEFF)

if train:
    tr_model = tr.train(td_schedule=[1] * td_sched, project_name="EMT")
    train_recon, train_latent, val_recon, val_latent = get_model_latents(train_data, val_data, ts, tr_model)
else:
    # NOTE(review): pickle.load can execute arbitrary code — only load files
    # you produced or otherwise trust.
    # NOTE(review): the save cell at the end of this notebook writes
    # "model_params.pkl", not this file name; rename one side to round-trip.
    with open(f"{MODEL_PATH}/train_emt_1d_heldout_z_scored.pkl", "rb") as f:
        model_params = pickle.load(f)

    tr_model = tr
    tr_model.vae_params = model_params["vae_params"]
    for _net_name in ("forward", "backward", "ferryman"):
        tr_model.training_setup.state[1][_net_name] = model_params[_net_name]

    train_recon, train_latent, val_recon, val_latent = get_model_latents(train_data, val_data, ts, tr_model)
    # Refresh dataset info and the SDE killing function from the restored
    # latents. NOTE(review): the training branch does not do this here —
    # presumably tr.train() handles it internally; confirm.
    train_dataset.update_data_info(train_latent, None)
    ts.sde.killer = train_dataset.killing_function()

In [32]:
# Evaluate prediction performance on the held-out time point(s).
# 1. Simulate trajectories from the trained model (the number of trajectories
#    is determined inside get_predictions; none is passed here).
predictions, predictions_all = get_predictions(
    train_data, val_data, train_tps, val_tps,
    train_dataset, train_latent, timepoints_sorted, tr_model, ts,
    vae_input_dim, vae_dec_hidden_dim, vae_latent_dim, t_0_orig=train_tps[0],
)
simulations = predictions["simulations"]

# 2. Report mean +- std of the Wasserstein-2 distance per validation time point.
# The observed feature matrix is taken by dropping the last column, so that
# column must be "time"; otherwise time would leak in as a feature.
assert val_data.columns[-1] == "time", "expected 'time' to be the last column of val_data"
for t in val_tps:
    observed = val_data.loc[val_data["time"] == t].to_numpy()[:, :-1]
    perf_df = get_metrics(simulations, observed, t)
    # NOTE(review): the saved output shows these values printing as Series,
    # not scalars — check the column layout returned by get_metrics.
    print("t =", t, ":", perf_df.mean()["w2"], "+-", perf_df.std()["w2"])

w2    45.120541
dtype: float64 w2    0.030677
dtype: float64


In [None]:
# Flip to True to persist the simulated trajectories and the trained
# parameters (VAE, forward/backward/ferryman networks, model config) to SAVE_PATH.
# NOTE(review): the restore branch of the training cell reads
# "train_emt_1d_heldout_z_scored.pkl" from MODEL_PATH, not the file written here.
save_predictions_model_params = False

if save_predictions_model_params:
    with open(f"{SAVE_PATH}/predicted_trajectories.pkl", "wb") as out_file:
        pickle.dump(predictions, out_file, protocol=pickle.HIGHEST_PROTOCOL)

    # Build the checkpoint in the same key order: vae_params, the three
    # networks, then config.
    model_params = {"vae_params": tr_model.vae_params}
    for net_name in ("forward", "backward", "ferryman"):
        model_params[net_name] = tr_model.training_setup.state[1][net_name]
    model_params["config"] = tr_model.get_model_configs()

    with open(f"{SAVE_PATH}/model_params.pkl", "wb") as out_file:
        pickle.dump(model_params, out_file, protocol=pickle.HIGHEST_PROTOCOL)