In [1]:
%load_ext autoreload
%autoreload 2

import warnings
warnings.filterwarnings("ignore")

import pandas as pd
import numpy as np
import jax.random as random
import pickle
import seaborn as sns 
from statannot import add_stat_annotation
import scipy.stats as stats
import umap
from matplotlib.pyplot import cm

from datasets import *
from training_setup import *
from training import *
from utils import *
from analysis_utils import *
from evaluate import *

In [2]:
# Filesystem locations used by this notebook.
# Each one can be overridden with an environment variable so the notebook is not
# tied to a single machine; when the variable is unset, the original value is used.
import os

# Directory containing the processed dataset CSV that is read below.
DATA_PATH = os.environ.get(
    "UDSB_DATA_PATH",
    "/media/sayalialatkar/T9/Sayali/UDSB/UDSB/BroadInstitute_Zebrafish/processed",
)
# Directory that receives the pickled predictions / model parameters when saving is enabled.
SAVE_PATH = os.environ.get("UDSB_SAVE_PATH", "./")

# Directory containing the pickled pre-trained model parameters.
MODEL_PATH = os.environ.get(
    "UDSB_MODEL_PATH",
    "/media/sayalialatkar/T9/Sayali/UDSB/UDSB/src_vae_fbsde/resultsv1.2/params/zebrafish",
)

In [3]:
# Read the expression table; the first CSV column becomes the row index.
df = pd.read_csv(f"{DATA_PATH}/zebrafish_cells_all_timepoints_2000genes.csv", index_col=0)

# Distinct "time" labels, in order of first appearance in the file.
# NOTE(review): the index <-> label maps below follow that order, so the integer
# indices are chronological only if the CSV rows are already ordered by time -- confirm.
_time_labels = df["time"].unique()
total_timepoints = len(_time_labels)
_time_indices = np.arange(total_timepoints)
timepoint_map = dict(zip(_time_indices, _time_labels))      # integer index -> original label
rev_timepoint_map = dict(zip(_time_labels, _time_indices))  # original label -> integer index

# Replace the original labels in the "time" column by their integer index.
df["time"] = df["time"].map(rev_timepoint_map)

# Time indices that go into the training split vs. the validation split.
train_tps_idx = [0., 1., 2., 3., 5., 7., 9., 10., 11.]
val_tps_idx = [4., 6., 8.]

train_tps = train_tps_idx
val_tps = val_tps_idx

timepoints_sorted = sorted(df["time"].unique().tolist())

# Row masks selecting each split by its time index.
_is_train_row = df["time"].isin(train_tps)
_is_val_row = df["time"].isin(val_tps)

train_data = df[_is_train_row].iloc[:, :]
val_data = df[_is_val_row].iloc[:, :]

# Cell output: shapes of both splits and the time indices each one contains.
train_data.shape, val_data.shape, train_data["time"].unique(), val_data["time"].unique()

((23472, 2001),
 (15259, 2001),
 array([ 0,  1,  2,  3,  5,  7,  9, 10, 11]),
 array([4, 6, 8]))

In [5]:
# Optional preprocessing switch: when True, both splits are passed through the
# imported helper `z_score_norm` (presumably z-score normalisation -- see its definition).
# NOTE(review): each split is handed to the helper on its own; whether validation
# data should reuse the training statistics depends on the helper -- confirm.
z_score = False

if z_score:
    train_data, val_data = z_score_norm(train_data), z_score_norm(val_data)

In [21]:
# Switch between the two ways of obtaining a model:
#   True  -> train from scratch in this notebook,
#   False -> restore pre-trained parameters from a pickle under MODEL_PATH.
train=False

# VAE hyperparameters
vae_latent_dim=10
vae_enc_hidden_dim=vae_dec_hidden_dim=[256,128]
# Input width = every column of train_data except the last one
# (presumably the last column is "time" -- confirm against the CSV layout).
vae_input_dim=train_data.iloc[:,:-1].shape[-1]
vae_batch_size=64
vae_lr= 1e-3
vae_epochs=50 # set vae_epochs to 1 to avoid pre-training
vae_t_dim=8

# SB hyperparameters
td_sched = 2
epochs=10
num_sde=10
paths_reuse=5
steps_num=100
batch_size=256
dec_hidden_dim=200
t_dim=16
lr= 1e-3
f_val=2
ferryman_lr= 1e-3
ferryman_hidden_dim=64
death_importance_rate = 1

# The dataset wrapper, training setup and trainer are configured identically for
# the "train" and the "restore" path, so they are built once here and shared.
train_dataset = Dataset(x=train_data, meta=None, meta_celltype_column=None, splitting_births_frac=0.9,
                        steps_num=steps_num, val_split=False, death_importance_rate=death_importance_rate, f_val=f_val)

ts = Training_Setup(dataset=train_dataset, dataset_name="Zebrafish", hidden_dim=[dec_hidden_dim,dec_hidden_dim,],
                    dec_hidden_size=[dec_hidden_dim, dec_hidden_dim, 1],vae_epochs=vae_epochs,
                    epochs=epochs,  num_sde=num_sde, paths_reuse=paths_reuse, reality_coefficient=0.2,
                    ipf_mask_dead=True, t_dim=t_dim, batch_size=batch_size, vae_batch_size=vae_batch_size,
                    vae_input_dim=vae_input_dim, vae_enc_hidden_dim=vae_enc_hidden_dim, vae_dec_hidden_dim=vae_dec_hidden_dim,
                    vae_latent_dim=vae_latent_dim, ferryman_hidden_dim=[ferryman_hidden_dim, ferryman_hidden_dim,])

# Fixed PRNG key so repeated runs start from the same random state.
tr = Trainer(dataset=train_dataset, ts=ts,key=random.PRNGKey(0), lr=lr, vae_lr=vae_lr, ferryman_lr=ferryman_lr, ferryman_coeff=1)

if train:
    # td_schedule is a list of `td_sched` ones.
    tr_model = tr.train(td_schedule=[1]*td_sched, project_name="Zebrafish")

    train_recon, train_latent, val_recon, val_latent = get_model_latents(train_data, val_data, ts, tr_model)
    # NOTE(review): unlike the restore path below, this path does not call
    # update_data_info / reassign ts.sde.killer; presumably tr.train takes care
    # of that internally -- confirm.
else: # load pre-trained model parameters
    # SECURITY: pickle.load can execute arbitrary code while unpickling;
    # only point MODEL_PATH at files from a trusted source.
    with open(f"{MODEL_PATH}/train_zebrafish_4_6_8_heldout.pkl","rb") as f:
        model_params = pickle.load(f)

    # The freshly built trainer is reused as the model object and its
    # parameters are overwritten with the restored ones.
    tr_model = tr
    tr_model.vae_params = model_params["vae_params"]
    tr_model.training_setup.state[1]["forward"] = model_params["forward"]
    tr_model.training_setup.state[1]["backward"] = model_params["backward"]
    tr_model.training_setup.state[1]["ferryman"] = model_params["ferryman"]
    train_recon, train_latent, val_recon, val_latent = get_model_latents(train_data, val_data, ts, tr_model)
    # Refresh the dataset with the latent representation of the training data,
    # then rebuild the SDE's killing function from the updated dataset.
    train_dataset.update_data_info(train_latent, None)
    ts.sde.killer = train_dataset.killing_function()

In [22]:
# Calculate prediction performance
# 1. Simulate trajectories with the trained model, starting from the first training timepoint.
#    NOTE(review): the original note mentioned 5 trajectories; the count is not set
#    in this cell, so it presumably comes from get_predictions -- confirm.
predictions = get_predictions(train_data, val_data, train_tps, val_tps,
                              train_dataset,train_latent,timepoints_sorted,tr_model,ts,
                              vae_input_dim, vae_dec_hidden_dim,vae_latent_dim, t_0_orig=train_tps[0])
simulations = predictions["simulations"]
# 2. Report mean and standard deviation of the "w2" metric at each validation timepoint.
for t in val_tps:
    # Validation cells observed at time t, with the last column dropped
    # (presumably "time" -- confirm against the column order).
    perf_df = get_metrics(simulations, val_data[val_data["time"]==t].values[:,:-1],t)
    # Select the "w2" column before aggregating: this avoids reducing every column
    # of perf_df (which fails on pandas >= 2 if any other column is non-numeric).
    # float(...) keeps the printed precision independent of the column dtype.
    print ("t =",t,":",float(perf_df["w2"].mean()),"+-", float(perf_df["w2"].std()))

t = 4.0 : 32.2402889251709 +- 0.33241317029633105
t = 6.0 : 29.325314712524413 +- 0.07004028912342827
t = 8.0 : 31.615964126586913 +- 0.3445967923964806


In [251]:
# Set to True to write the simulated trajectories and the current model
# parameters as pickle files under SAVE_PATH.
save_predictions_model_params = False

if save_predictions_model_params:
    # 1. Output of get_predictions.
    with open(f"{SAVE_PATH}/predicted_trajectories.pkl", "wb") as f:
        pickle.dump(predictions, f, protocol=pickle.HIGHEST_PROTOCOL)

    # 2. Model parameters plus the configuration reported by the model object.
    _saved_state = tr_model.training_setup.state[1]
    model_params = dict(
        vae_params=tr_model.vae_params,
        forward=_saved_state["forward"],
        backward=_saved_state["backward"],
        ferryman=_saved_state["ferryman"],
        config=tr_model.get_model_configs(),
    )
    with open(f"{SAVE_PATH}/model_params.pkl", "wb") as f:
        pickle.dump(model_params, f, protocol=pickle.HIGHEST_PROTOCOL)