In [1]:
import scanpy as sc
import numpy as np
import sys
import torch
import scvelo as scv
from scipy.sparse import csr_matrix

SEED = 2024
np.random.seed(SEED)

sys.path.append('/home/liyr/Benchmark/VeloVAE-master')
import velovae as vv

In [2]:
adata = sc.read_h5ad("adata/redeem_young.h5ad")
print(adata)

adata.layers['spliced'] = csr_matrix(adata.layers['spliced'])
adata.layers['unspliced'] = csr_matrix(adata.layers['unspliced'])

AnnData object with n_obs × n_vars = 9144 × 2000
    obs: 'nCount_RNA', 'nFeature_RNA', 'nCount_ATAC', 'nFeature_ATAC', 'nCount_SCT', 'nFeature_SCT', 'SCT.weight', 'ATAC.weight', 'seurat_clusters', 'Sig.HSC1', 'Sig.Prog1', 'Sig.EarlyE1', 'Sig.LateE1', 'Sig.ProMono1', 'Sig.Mono1', 'Sig.ncMono1', 'Sig.cDC1', 'Sig.pDC1', 'Sig.ProB1', 'Sig.PreB1', 'Sig.B1', 'Sig.Plasma1', 'Sig.T1', 'Sig.CTL1', 'Sig.NK1', 'STD.CellType', 'STD_Cat', 'STD_Cat2', 'Sample', 'MitoCoverage', 'ClonalGroup', 'ClonalGroup.Prob', 'nCount_spliced', 'nFeature_spliced', 'nCount_unspliced', 'nFeature_unspliced', 'nCount_ambiguous', 'nFeature_ambiguous', 'CellType', 'initial_size_unspliced', 'initial_size_spliced', 'initial_size', 'n_counts', 'velocity_self_transition'
    var: 'name', 'gene_count_corr', 'means', 'dispersions', 'dispersions_norm', 'highly_variable', 'velocity_gamma', 'velocity_qreg_ratio', 'velocity_r2', 'velocity_genes'
    uns: 'CellType_colors', 'STD.CellType_colors', 'neighbors', 'umap', 'velocity_gra

In [3]:
torch.manual_seed(SEED)
np.random.seed(SEED)
vv.preprocess(adata, n_gene=2000)
vae = vv.VAE(adata, 
             tmax=20, 
             dim_z=len(set(adata.obs['CellType'])), 
             device='cuda:0')

Filtered out 4 cells with low counts.
Filtered out 1602 genes that are detected 10 counts (shared).
Skip filtering by dispersion since number of variables are less than `n_top_genes`.
computing moments based on connectivities
    finished (0:00:00) --> added 
    'Ms' and 'Mu', moments of un/spliced abundances (adata.layers)
Keep raw unspliced/spliced count data.
Estimating ODE parameters...


100%|██████████| 398/398 [00:03<00:00, 103.25it/s]


Detected 69 velocity genes.
Estimating the variance...


100%|██████████| 398/398 [00:00<00:00, 2445.82it/s]

Initialization using the steady-state and dynamical models.





Reinitialize the regular ODE parameters based on estimated global latent time.


100%|██████████| 398/398 [00:00<00:00, 967.40it/s]


3 clusters detected based on gene co-expression.
(0.43, 0.363907213095749), (0.57, 0.758439269382994)
KS-test result: [1. 1. 0.]
Initial induction: 294, repression: 104/398


In [4]:
print(vae)

<velovae.model.vae.VAE object at 0x2b5427d7ca90>


In [5]:
config = {
    # You can change any hyperparameters here!
    # 'learning_rate': 1e-3,
    # 'learning_rate_ode': 2e-3,
    # 'learning_rate_post': 1e-3
}
vae.train(adata,
          config=config,
          plot=False,
          #gene_plot=gene_plot,
          #figure_path=figure_path,
          embed='umap')

Learning Rate based on Data Sparsity: 0.0005
--------------------------- Train a VeloVAE ---------------------------
*********        Creating Training/Validation Datasets        *********
*********                      Finished.                      *********
*********                 Creating optimizers                 *********
*********                      Finished.                      *********
*********                    Start training                   *********
*********                      Stage  1                       *********
Total Number of Iterations Per Epoch: 50, test iteration: 98
*********       Stage 1: Early Stop Triggered at epoch 428.       *********
*********                      Stage  2                       *********
*********             Velocity Refinement Round 1             *********


100%|██████████| 9140/9140 [00:06<00:00, 1474.21it/s]


Percentage of Invalid Sets: 0.031
Average Set Size: 185
*********     Round 1: Early Stop Triggered at epoch 583.    *********
Change in noise variance: 0.2670
*********             Velocity Refinement Round 2             *********
*********     Round 2: Early Stop Triggered at epoch 872.    *********
Change in noise variance: 0.0008
Change in x0: 0.2679
*********             Velocity Refinement Round 3             *********
*********     Round 3: Early Stop Triggered at epoch 945.    *********
Change in noise variance: 0.0000
Change in x0: 0.1989
*********             Velocity Refinement Round 4             *********
*********     Round 4: Early Stop Triggered at epoch 1065.    *********
Change in noise variance: 0.0000
Change in x0: 0.1174
*********             Velocity Refinement Round 5             *********
*********     Round 5: Early Stop Triggered at epoch 1156.    *********
Change in noise variance: 0.0000
Change in x0: 0.1095
*********             Velocity Refinement Round 6 

In [6]:
vae.save_model("sup/velovae", 'encoder_vae', 'decoder_vae')
vae.save_anndata(adata, 'vae', './adata', file_name="veloVAE.h5ad")

In [8]:
adataVAE = sc.read_h5ad("adata/veloVAE.h5ad")
print(adataVAE)

AnnData object with n_obs × n_vars = 9140 × 398
    obs: 'nCount_RNA', 'nFeature_RNA', 'nCount_ATAC', 'nFeature_ATAC', 'nCount_SCT', 'nFeature_SCT', 'SCT.weight', 'ATAC.weight', 'seurat_clusters', 'Sig.HSC1', 'Sig.Prog1', 'Sig.EarlyE1', 'Sig.LateE1', 'Sig.ProMono1', 'Sig.Mono1', 'Sig.ncMono1', 'Sig.cDC1', 'Sig.pDC1', 'Sig.ProB1', 'Sig.PreB1', 'Sig.B1', 'Sig.Plasma1', 'Sig.T1', 'Sig.CTL1', 'Sig.NK1', 'STD.CellType', 'STD_Cat', 'STD_Cat2', 'Sample', 'MitoCoverage', 'ClonalGroup', 'ClonalGroup.Prob', 'nCount_spliced', 'nFeature_spliced', 'nCount_unspliced', 'nFeature_unspliced', 'nCount_ambiguous', 'nFeature_ambiguous', 'CellType', 'initial_size_unspliced', 'initial_size_spliced', 'initial_size', 'n_counts', 'velocity_self_transition', 'n_genes', 'vae_time', 'vae_std_t', 'vae_t0'
    var: 'name', 'gene_count_corr', 'means', 'dispersions', 'dispersions_norm', 'highly_variable', 'velocity_gamma', 'velocity_qreg_ratio', 'velocity_r2', 'velocity_genes', 'init_mode', 'w_init', 'vae_alpha', 'va

In [9]:
adataVAE.layers['vae_velocity']

array([[-7.60706469e-04, -1.26679884e-02, -1.85303350e-03, ...,
         9.06906293e-03,  3.16479801e-03,  3.45506505e-03],
       [-3.71925070e-03, -1.43751596e-02,  6.36338107e-02, ...,
         1.35679236e-02,  2.28931386e-03, -2.30208309e-03],
       [-1.87303641e-03, -2.80953341e-02,  6.67509708e-03, ...,
        -6.54340899e-03, -5.49355672e-04,  2.28924584e-02],
       ...,
       [-2.39410242e-04,  1.05378161e-01, -5.01172030e-03, ...,
        -2.09856474e-02,  2.78696572e-04, -7.42593750e-04],
       [-1.02914735e-02,  1.54561760e+00, -1.43869430e-02, ...,
        -8.22640457e-03,  3.41502945e-03,  1.31913089e-03],
       [-5.67826331e-03,  1.84767751e-02, -5.95292266e-03, ...,
         2.45077560e-02,  4.74749088e-03,  1.15889808e-03]])