In [10]:
import anndata
import numpy as np
import sys
import torch
sys.path.append('../../../')
import velovae as vv
%load_ext autoreload
%autoreload 2

In [2]:
# Dataset identifier; the input file name and every output path in this
# notebook are derived from it.
dataset = 'BMMC'
# NOTE(review): absolute, machine-specific location — adjust when running
# the notebook on another system.
root = '/scratch/blaauw_root/blaauw1/gyichen'
# Load the preprocessed AnnData object. The file name is built from `dataset`
# rather than repeating the literal 'BMMC', so changing `dataset` above can no
# longer silently keep loading the BMMC file.
adata = anndata.read_h5ad(f'{root}/data/{dataset}_pp.h5ad')
# Expose the 'celltype.l2' annotation under 'clusters', the cluster key that
# later cells pass to the models and to the evaluation.
adata.obs["clusters"] = adata.obs['celltype.l2'].to_numpy()

In [None]:
# Optional one-off preprocessing step; both calls below are disabled by default.
# Uncomment this if data has not been preprocessed
# NOTE(review): the result is written to a path relative to the current working
# directory, whereas the loading cell reads the preprocessed file from
# f'{root}/data/' — move or re-point the file accordingly before reloading.
# NOTE(review): presumably `adata` must hold the raw (not yet preprocessed)
# data when this is run — confirm before uncommenting.
#vv.preprocess(adata, n_gene=2000, min_shared_counts=20, compute_umap=True)
#adata.write_h5ad(f'{dataset}_pp.h5ad')

In [3]:
# Genes handed to the training calls (gene_plot=...) and to the evaluation
# call (genes=...).
gene_plot = ['SPINK2', 'AZU1', 'MPO', 'LYZ', 'CD74', 'HBB']

# Per-dataset output locations. The model-specific cells append their own
# sub-folder to the two *_base paths; data_path is shared by all of them.
data_path = f'data/velovae/continuous/{dataset}'
figure_path_base = f'figures/{dataset}'
model_path_base = f'checkpoints/{dataset}'

# Vanilla VAE

In [4]:
# Output folders for the Vanilla VAE run.
model_path = f'{model_path_base}/Vanilla'
figure_path = f'{figure_path_base}/Vanilla'

# Seed both NumPy and PyTorch before the model is constructed.
np.random.seed(2023)
torch.manual_seed(2023)

vanilla_vae = vv.VanillaVAE(
    adata,
    tmax=20,
    device='cuda:0',
)

# plot=False: gene_plot / figure_path are still passed along with the call.
vanilla_vae.train(
    adata,
    figure_path=figure_path,
    gene_plot=gene_plot,
    plot=False,
)

# Persist the trained model, then write the results into `adata` under the
# 'vanilla' key (the key the evaluation cell looks up) and save it to disk.
# NOTE(review): 'encoder' / 'decoder' are presumably checkpoint file names — confirm.
vanilla_vae.save_model(model_path, 'encoder', 'decoder')
vanilla_vae.save_anndata(
    adata,
    'vanilla',
    data_path,
    file_name=f'{dataset}.h5ad',
)

Estimating ODE parameters...


  0%|          | 0/2000 [00:00<?, ?it/s]

Detected 716 velocity genes.
Estimating the variance...


  0%|          | 0/2000 [00:00<?, ?it/s]

Initialization using the steady-state and dynamical models.
Reinitialize the regular ODE parameters based on estimated global latent time.


  0%|          | 0/2000 [00:00<?, ?it/s]

Gaussian Prior.
------------------------- Train a Vanilla VAE -------------------------
*********        Creating Training/Validation Datasets        *********
*********                      Finished.                      *********
*********                 Creating optimizers                 *********
*********                      Finished.                      *********
*********                    Start training                   *********
Total Number of Iterations Per Epoch: 121, test iteration: 240
********* Early Stop Triggered at epoch 104. *********
*********              Finished. Total Time =   0 h :  5 m : 14 s             *********
Final: Train ELBO = 3577.028,           Test ELBO = 3569.452


# VeloVAE

In [11]:
# Output folders for the VeloVAE run.
model_path = f'{model_path_base}/VeloVAE'
figure_path = f'{figure_path_base}/VeloVAE'

# Seed both NumPy and PyTorch before the model is constructed.
np.random.seed(2022)
torch.manual_seed(2022)

vae = vv.VAE(
    adata,
    tmax=20,
    dim_z=5,
    device='cuda:0',
)

# plot=False: gene_plot / figure_path are still passed along with the call.
vae.train(
    adata,
    figure_path=figure_path,
    plot=False,
    gene_plot=gene_plot,
)

# Persist the trained model, then write the results into `adata` under the
# 'velovae' key (the key the evaluation cell looks up) and save it to disk.
# NOTE(review): 'encoder' / 'decoder' are presumably checkpoint file names — confirm.
vae.save_model(model_path, 'encoder', 'decoder')
vae.save_anndata(
    adata,
    'velovae',
    data_path,
    file_name=f'{dataset}.h5ad',
)

Estimating ODE parameters...


  0%|          | 0/2000 [00:00<?, ?it/s]

Detected 717 velocity genes.
Estimating the variance...


  0%|          | 0/2000 [00:00<?, ?it/s]

3 clusters detected based on gene co-expression.
(0.40, 0.38482108460541176), (0.60, 0.7297303195535904)
(0.32, 0.29062388893377694), (0.68, 0.730120876261714)
KS-test result: [0. 1. 0.]
Initial induction: 1381, repression: 619/2000
Initialization using the steady-state and dynamical models.
Reinitialize the regular ODE parameters based on estimated global latent time.


  0%|          | 0/2000 [00:00<?, ?it/s]

Gaussian Prior.
Learning Rate based on Data Sparsity: 0.0005
--------------------------- Train a VeloVAE ---------------------------
*********        Creating Training/Validation Datasets        *********
*********                      Finished.                      *********
*********                 Creating optimizers                 *********
*********                      Finished.                      *********
*********                    Start training                   *********
*********                      Stage  1                       *********
Total Number of Iterations Per Epoch: 121, test iteration: 240
*********       Stage 1: Early Stop Triggered at epoch 340.       *********
*********                      Stage  2                       *********
*********             Velocity Refinement Round 1              *********
Cell-wise KNN Estimation.


  0%|          | 0/22122 [00:00<?, ?it/s]

Percentage of Invalid Sets: 0.001
Average Set Size: 598
Finished. Actual Time:   0 h :  0 m : 27 s
*********       Round 1: Early Stop Triggered at epoch 497.       *********
Change in noise variance: 0.43300408124923706
*********             Velocity Refinement Round 2              *********
Cell-wise KNN Estimation.


  0%|          | 0/22122 [00:00<?, ?it/s]

Percentage of Invalid Sets: 0.001
Average Set Size: 598
Finished. Actual Time:   0 h :  0 m : 27 s
*********       Round 2: Early Stop Triggered at epoch 569.       *********
Change in noise variance: 0.00032681203447282314
Change in x0: 0.05324282213418404
*********             Velocity Refinement Round 3              *********
Cell-wise KNN Estimation.


  0%|          | 0/22122 [00:00<?, ?it/s]

Percentage of Invalid Sets: 0.001
Average Set Size: 598
Finished. Actual Time:   0 h :  0 m : 26 s
*********       Round 3: Early Stop Triggered at epoch 736.       *********
Change in noise variance: 0.0
Change in x0: 0.030318180037273743
*********             Velocity Refinement Round 4              *********
Stage 2: Early Stop Triggered at round 3.
*********              Finished. Total Time =   0 h : 28 m : 16 s             *********
Final: Train ELBO = 7323.688,	Test ELBO = 7285.974


# Full VB

In [12]:
# Output folders for the Full VB run.
model_path = f'{model_path_base}/FullVB'
figure_path = f'{figure_path_base}/FullVB'

# Seed both NumPy and PyTorch before the model is constructed.
np.random.seed(2022)
torch.manual_seed(2022)

# Same constructor call as the VeloVAE cell, plus full_vb=True.
full_vb = vv.VAE(
    adata,
    tmax=20,
    dim_z=5,
    full_vb=True,
    device='cuda:0',
)

# plot=False: gene_plot / figure_path are still passed along with the call.
full_vb.train(
    adata,
    figure_path=figure_path,
    gene_plot=gene_plot,
    plot=False,
)

# Persist the trained model, then write the results into `adata` under the
# 'fullvb' key and save it to disk.
# NOTE(review): 'encoder' / 'decoder' are presumably checkpoint file names — confirm.
full_vb.save_model(model_path, 'encoder', 'decoder')
full_vb.save_anndata(
    adata,
    'fullvb',
    data_path,
    file_name=f'{dataset}.h5ad',
)

Estimating ODE parameters...


  0%|          | 0/2000 [00:00<?, ?it/s]

Detected 717 velocity genes.
Estimating the variance...


  0%|          | 0/2000 [00:00<?, ?it/s]

3 clusters detected based on gene co-expression.
(0.40, 0.38482108460541176), (0.60, 0.7297303195535904)
(0.32, 0.29062388893377694), (0.68, 0.730120876261714)
KS-test result: [0. 1. 0.]
Initial induction: 1381, repression: 619/2000
Initialization using the steady-state and dynamical models.
Reinitialize the regular ODE parameters based on estimated global latent time.


  0%|          | 0/2000 [00:00<?, ?it/s]

Gaussian Prior.
Learning Rate based on Data Sparsity: 0.0005
--------------------------- Train a VeloVAE ---------------------------
*********        Creating Training/Validation Datasets        *********
*********                      Finished.                      *********
*********                 Creating optimizers                 *********
*********                      Finished.                      *********
*********                    Start training                   *********
*********                      Stage  1                       *********
Total Number of Iterations Per Epoch: 121, test iteration: 240
*********       Stage 1: Early Stop Triggered at epoch 131.       *********
*********                      Stage  2                       *********
*********             Velocity Refinement Round 1              *********
Cell-wise KNN Estimation.


  0%|          | 0/22122 [00:00<?, ?it/s]

Percentage of Invalid Sets: 0.000
Average Set Size: 566
Finished. Actual Time:   0 h :  0 m : 27 s
*********       Round 1: Early Stop Triggered at epoch 233.       *********
Change in noise variance: 0.40584495663642883
*********             Velocity Refinement Round 2              *********
Cell-wise KNN Estimation.


  0%|          | 0/22122 [00:00<?, ?it/s]

Percentage of Invalid Sets: 0.000
Average Set Size: 566
Finished. Actual Time:   0 h :  0 m : 27 s
*********       Round 2: Early Stop Triggered at epoch 323.       *********
Change in noise variance: 0.0006064710323698819
Change in x0: 0.2946686547118299
*********             Velocity Refinement Round 3              *********
Cell-wise KNN Estimation.


  0%|          | 0/22122 [00:00<?, ?it/s]

Percentage of Invalid Sets: 0.000
Average Set Size: 566
Finished. Actual Time:   0 h :  0 m : 27 s
*********       Round 3: Early Stop Triggered at epoch 379.       *********
Change in noise variance: 0.0
Change in x0: 0.2552834472536671
*********             Velocity Refinement Round 4              *********
Cell-wise KNN Estimation.


  0%|          | 0/22122 [00:00<?, ?it/s]

Percentage of Invalid Sets: 0.000
Average Set Size: 566
Finished. Actual Time:   0 h :  0 m : 27 s
*********       Round 4: Early Stop Triggered at epoch 397.       *********
Change in noise variance: 0.0
Change in x0: 0.2252053959281619
*********             Velocity Refinement Round 5              *********
Cell-wise KNN Estimation.


  0%|          | 0/22122 [00:00<?, ?it/s]

Percentage of Invalid Sets: 0.000
Average Set Size: 566
Finished. Actual Time:   0 h :  0 m : 27 s
*********       Round 5: Early Stop Triggered at epoch 433.       *********
Change in noise variance: 0.0
Change in x0: 0.18697603875055835
*********             Velocity Refinement Round 6              *********
Cell-wise KNN Estimation.


  0%|          | 0/22122 [00:00<?, ?it/s]

Percentage of Invalid Sets: 0.000
Average Set Size: 566
Finished. Actual Time:   0 h :  0 m : 27 s
*********       Round 6: Early Stop Triggered at epoch 475.       *********
Change in noise variance: 0.0
Change in x0: 0.1536832272016473
*********             Velocity Refinement Round 7              *********
Cell-wise KNN Estimation.


  0%|          | 0/22122 [00:00<?, ?it/s]

Percentage of Invalid Sets: 0.000
Average Set Size: 566
Finished. Actual Time:   0 h :  0 m : 27 s
*********       Round 7: Early Stop Triggered at epoch 535.       *********
Change in noise variance: 0.0
Change in x0: 0.12904914004331405
*********             Velocity Refinement Round 8              *********
Cell-wise KNN Estimation.


  0%|          | 0/22122 [00:00<?, ?it/s]

Percentage of Invalid Sets: 0.000
Average Set Size: 566
Finished. Actual Time:   0 h :  0 m : 28 s
*********       Round 8: Early Stop Triggered at epoch 565.       *********
Change in noise variance: 0.0
Change in x0: 0.09339124194214013
*********             Velocity Refinement Round 9              *********
Cell-wise KNN Estimation.


  0%|          | 0/22122 [00:00<?, ?it/s]

Percentage of Invalid Sets: 0.000
Average Set Size: 566
Finished. Actual Time:   0 h :  0 m : 28 s
*********       Round 9: Early Stop Triggered at epoch 573.       *********
Change in noise variance: 0.0
Change in x0: 0.07036224069542298
*********             Velocity Refinement Round 10              *********
Cell-wise KNN Estimation.


  0%|          | 0/22122 [00:00<?, ?it/s]

Percentage of Invalid Sets: 0.000
Average Set Size: 566
Finished. Actual Time:   0 h :  0 m : 27 s
*********       Round 10: Early Stop Triggered at epoch 593.       *********
Change in noise variance: 0.0
Change in x0: 0.058001221541610366
*********             Velocity Refinement Round 11              *********
Cell-wise KNN Estimation.


  0%|          | 0/22122 [00:00<?, ?it/s]

Percentage of Invalid Sets: 0.000
Average Set Size: 566
Finished. Actual Time:   0 h :  0 m : 27 s
*********       Round 11: Early Stop Triggered at epoch 609.       *********
Change in noise variance: 0.0
Change in x0: 0.05360666947224338
*********             Velocity Refinement Round 12              *********
Stage 2: Early Stop Triggered at round 11.
*********              Finished. Total Time =   0 h : 31 m : 15 s             *********
Final: Train ELBO = 7039.742,	Test ELBO = 6997.415


# Train a Branching ODE

In [None]:
# Output folders for the branching-ODE run.
model_path = f'{model_path_base}/BrODE'
figure_path = f'{figure_path_base}/BrODE'

# Seed both NumPy and PyTorch before the model is constructed.
np.random.seed(2022)
torch.manual_seed(2022)

# NOTE(review): 'fullvb_time' and 'fullvb_z' are presumably written into
# `adata` by the Full VB cell's save_anndata(adata, 'fullvb', ...) — confirm
# that cell has been run before this one.
brode = vv.BrODE(
    adata,
    'clusters',
    'fullvb_time',
    'fullvb_z',
)

brode.print_weight()

# plot=False: gene_plot / figure_path are still passed along with the call.
brode.train(
    adata,
    'fullvb_time',
    'clusters',
    figure_path=figure_path,
    gene_plot=gene_plot,
    plot=False,
)

# Persist the trained model, then write the results into `adata` under the
# 'brode' key and save it to disk.
brode.save_model(model_path, 'brode')
brode.save_anndata(
    adata,
    'brode',
    data_path,
    file_name=f'{dataset}.h5ad',
)

# The transition graph figure goes next to the per-model figure folders.
vv.plot_transition_graph(
    adata,
    save=f'{figure_path_base}/transition.png',
)

# Evaluation

In [13]:
# (source cluster, target cluster) pairs handed to the evaluation as
# `cluster_edges`.
# NOTE(review): labels must match the values in adata.obs['clusters'] exactly;
# the casing of 'cDc2' in particular looks unusual — confirm against the data.
cluster_edges = [
    ('HSC', 'LMPP'),
    ('LMPP', 'GMP'),
    ('GMP', 'CD14 Mono'),
    ('CD14 Mono', 'CD16 Mono'),
    ('Prog DC', 'cDc2'),
    ('Prog B 1', 'Prog B 2'),
    ('Prog MK', 'Prog RBC'),
]

# Compare the three trained models. The display names and the key list are
# positionally paired; each key is the one used in the matching
# save_anndata(adata, <key>, ...) call.
# NOTE(review): save_path points at the data directory (data_path), not at
# figure_path_base — confirm this is the intended output location.
vv.post_analysis(
    adata,
    'eval',
    ['Vanilla VAE', 'VeloVAE', 'FullVB'],
    ['vanilla', 'velovae', 'fullvb'],
    cluster_edges=cluster_edges,
    compute_metrics=True,
    genes=gene_plot,
    plot_type=['all'],
    grid_size=(2, 3),
    save_path=data_path,
)

Computing velocity embedding using scVelo
computing velocity graph (using 7/32 cores)


  0%|          | 0/22122 [00:00<?, ?cells/s]

KeyboardInterrupt: 