### DynaVelo Demo

#### Usage
Two AnnData files are required: `adata_rna` and `adata_atac`. The first contains preprocessed RNA expression values and RNA velocity estimates from scVelo. Because RNA velocities are error-prone and sensitive to the choice of gene set, first check that the overall trajectories they imply make biological sense. The second contains TF motif accessibility z-scores from chromVAR. `adata_rna` has shape [n_cells, n_genes] and `adata_atac` has shape [n_cells, n_tfs], and both must contain the same cells in the same order.
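Before training, it can help to check the input requirements explicitly. The sketch below is a hypothetical helper (`check_multiome_inputs` is not part of DynaVelo). It assumes the RNA velocities are stored under scVelo's default `velocity` layer, which you should adjust if your pipeline uses another name, and it only relies on `.shape`, `.obs_names` and `.layers`, which AnnData objects provide.

```python
import numpy as np


def check_multiome_inputs(adata_rna, adata_atac, velocity_layer="velocity"):
    """Sanity-check the RNA and motif inputs (hypothetical helper).

    Assumes RNA velocities live in adata_rna.layers[velocity_layer];
    'velocity' is scVelo's default layer name.
    Returns (n_cells, n_genes, n_tfs).
    """
    n_cells_rna, n_genes = adata_rna.shape
    n_cells_atac, n_tfs = adata_atac.shape
    if n_cells_rna != n_cells_atac:
        raise ValueError(f"cell counts differ: {n_cells_rna} vs {n_cells_atac}")
    # Both objects must list the same cells in the same order.
    if not np.array_equal(np.asarray(adata_rna.obs_names), np.asarray(adata_atac.obs_names)):
        raise ValueError("obs_names differ or are in a different order")
    if velocity_layer not in adata_rna.layers:
        raise KeyError(f"adata_rna.layers has no '{velocity_layer}' layer")
    if adata_rna.layers[velocity_layer].shape != (n_cells_rna, n_genes):
        raise ValueError("velocity layer shape does not match [n_cells, n_genes]")
    return n_cells_rna, n_genes, n_tfs
```

Calling `check_multiome_inputs(adata_rna, adata_atac)` right after loading extends the `obs_names` assertion in the next cell with shape and layer checks.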

#### Load data

In [14]:
import scanpy as sc

dataset_name = 'Ben'
#sample_list = ['WT-3', 'WT-13', 'Icn2Het-2', 'Icn2Het-10', 'SpenHet-1-2', 'SpenHet-15', 'SpenHet-Icn2Het-8', 'SpenHet-Icn2Het-12']
sample_name = 'WT-3'

adata_rna = sc.read_h5ad(f"/mnt/storage/Ben_data/analysis/outs/RNAMatrix/dynavelo/RNA_Matrix_{sample_name}_AiBC.h5ad")
adata_atac = sc.read_h5ad(f"/mnt/storage/Ben_data/analysis/outs/MotifMatrix/dynavelo/MotifMatrix_{sample_name}_AiBC.h5ad")

assert all(adata_rna.obs_names == adata_atac.obs_names)

In [15]:
adata_rna.obs['fine.celltype'].value_counts()

Prememory_Memory    5355
Prememory_Naive     1271
CC_Rec               191
CB_Rec_Sphase         19
Recycling             15
CB_S_G2M               6
CB_G2M                 5
Name: fine.celltype, dtype: int64

In [16]:
adata_rna.obs['anno_clusters'].value_counts()

3_AiBC    3616
0_AiBC    2776
9_AiBC     348
7_AiBC     121
8_AiBC       1
Name: anno_clusters, dtype: int64

In [17]:
adata_atac.obs['fine.celltype'].value_counts()

Prememory_Memory    5355
Prememory_Naive     1271
CC_Rec               191
CB_Rec_Sphase         19
Recycling             15
CB_S_G2M               6
CB_G2M                 5
Name: fine.celltype, dtype: int64

#### Build datasets and dataloaders
We first build PyTorch datasets and dataloaders from `adata_rna` and `adata_atac` for training, using the custom `MultiomeDataset` class. We randomly assign 10% of the cells to the test set and the remaining 90% to the training set.
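The split and batch arithmetic can be sketched in plain NumPy. This is an illustration, not the notebook's code: it uses NumPy's `Generator` API, so the exact cells drawn differ from the split produced by `np.random.seed(0)` below, and `split_indices` and `n_batches` are hypothetical helpers. The cell count 6862 is the sum of the `fine.celltype` counts shown above.

```python
import math

import numpy as np


def split_indices(n_cells, test_frac=0.1, seed=0):
    """Randomly split cell indices into disjoint train/test index arrays."""
    rng = np.random.default_rng(seed)
    idx = rng.permutation(n_cells)
    n_test = int(test_frac * n_cells)
    return idx[n_test:], idx[:n_test]


def n_batches(n_items, batch_size, drop_last):
    """Number of batches a DataLoader yields per pass over n_items."""
    return n_items // batch_size if drop_last else math.ceil(n_items / batch_size)


idx_train, idx_test = split_indices(6862)
print(len(idx_train), len(idx_test))         # 6176 686
print(n_batches(len(idx_train), 256, True))  # 24 (32 cells skipped per epoch)
print(n_batches(6862, 256, False))           # 27
```

With `drop_last=True`, the last incomplete training batch is skipped each epoch; because the training loader shuffles, a different handful of cells is skipped every time. The full `dataloader` yields 27 batches, which matches the `Batch x/27` progress shown when the Jacobians are computed.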

In [18]:
import numpy as np
import torch
import torch.optim as optim
from torch.utils.data import DataLoader
import sys

sys.path.append("..")
sys.path.append("../dynavelo")
from dynavelo.models import MultiomeDataset
from dynavelo.models import DynaVelo

# set seed
seed = 0
np.random.seed(seed)
torch.manual_seed(seed)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(seed)

# set gpu
gpu = 1
device = torch.device('cuda:' + str(gpu) if torch.cuda.is_available() else 'cpu')

# datasets and dataloaders
dataset = MultiomeDataset(adata_rna, adata_atac, use_weights=False)
N_test = int(0.1 * len(dataset))
idx_random = np.random.permutation(len(dataset))
idx_test = idx_random[:N_test]
idx_train = idx_random[N_test:]
dataset_train = MultiomeDataset(adata_rna[idx_train], adata_atac[idx_train], use_weights=False)
dataset_test = MultiomeDataset(adata_rna[idx_test], adata_atac[idx_test], use_weights=False)

dataloader_train = DataLoader(dataset_train, batch_size=256, shuffle=True, num_workers=0, drop_last=True)
dataloader_test = DataLoader(dataset_test, batch_size=256, shuffle=False, num_workers=0)
dataloader = DataLoader(dataset, batch_size=256, shuffle=False, num_workers=0)

In [19]:
torch.cuda.is_available()

True

#### Initialize the DynaVelo model

In [20]:
# model
model = DynaVelo(x_dim=adata_rna.shape[1], y_dim=adata_atac.shape[1], device=device, dataset_name=dataset_name, sample_name=sample_name).to(device)
optimizer = optim.Adam(model.parameters(), lr=1e-3)

n_params = sum([np.prod(p.size()) for p in model.parameters()])
print('number of parameters func: ', n_params)
for name, param in model.named_parameters():
    print(name, param.shape, param.dtype)
    print(name, param)

number of parameters func:  462947
encoder_net_x.0.weight torch.Size([50, 2419]) torch.float32
encoder_net_x.0.weight Parameter containing:
tensor([[-0.0002,  0.0109, -0.0167,  ..., -0.0104,  0.0076, -0.0173],
        [ 0.0101,  0.0017, -0.0170,  ...,  0.0203, -0.0168, -0.0158],
        [ 0.0128,  0.0202, -0.0196,  ...,  0.0028,  0.0170,  0.0004],
        ...,
        [ 0.0009, -0.0114,  0.0044,  ..., -0.0175,  0.0017,  0.0011],
        [ 0.0052, -0.0041,  0.0086,  ..., -0.0191,  0.0145,  0.0108],
        [ 0.0022, -0.0133,  0.0097,  ..., -0.0114,  0.0069,  0.0035]],
       device='cuda:1', requires_grad=True)
encoder_net_x.0.bias torch.Size([50]) torch.float32
encoder_net_x.0.bias Parameter containing:
tensor([-0.0050, -0.0060, -0.0015,  0.0012, -0.0017, -0.0003, -0.0019, -0.0171,
         0.0143,  0.0143, -0.0132,  0.0036,  0.0194,  0.0129, -0.0186,  0.0139,
        -0.0134,  0.0009,  0.0184,  0.0159,  0.0074,  0.0190,  0.0083,  0.0035,
        -0.0190, -0.0070, -0.0202, -0.0159, -0.

#### Training the model

In [21]:
# train
model.fit(dataloader_train, dataloader_test, optimizer, max_epoch=200)

Epoch	loss_train	loss_test	nll_x	nll_y	loss_vel	loss_con	kl_z0	kl_t
1	42260.5283	35126.9178	32353.439	3882.8097	-0.4268	0.0089	3.0625	0.0072
2	31762.3917	28251.0659	27633.8124	3055.6673	-0.5231	0.0054	2.7329	0.0056
3	26246.7608	23697.825	23470.7094	2493.5707	-0.5587	0.0036	3.2812	0.0038
4	22340.5836	20745.52	21145.7384	2099.1258	-0.5771	0.0034	3.2347	0.0036
5	19898.4334	18509.4917	19519.4799	1751.2543	-0.6106	0.0026	3.3137	0.0048
6	18004.2536	17005.6483	18611.5586	1635.7173	-0.6331	0.004	3.0447	0.0045
7	16830.1408	16188.6937	17974.3685	1477.8643	-0.6519	0.0026	3.2204	0.009
8	15765.5397	15113.8565	17296.9085	1426.4412	-0.663	0.003	2.9818	0.009
9	15040.0252	14690.6367	17065.0297	1353.3316	-0.6697	0.0099	2.8582	0.012
10	14527.1547	13866.0284	16191.5237	1300.0444	-0.6842	0.003	3.1637	0.0231
11	14027.9998	13928.7098	16648.9666	1205.4485	-0.6733	0.0034	2.7476	0.0249
INFO: Early stopping counter 1 of 10
12	13492.029	13364.6496	16393.6031	1132.1019	-0.6804	0.0028	2.5797	0.0356
13	13273.3741	12
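The checkpoint filename printed in the next section encodes the loss weights `k_z0_1000`, `k_t_1000`, `k_velocity_10000` and `k_consistency_10000`. Assuming (this is not confirmed by the DynaVelo source) that these multiply `kl_z0`, `kl_t`, `loss_vel` and `loss_con`, while `nll_x` and `nll_y` enter with weight 1, the weighted sum of the logged components reproduces the `loss_test` column to within rounding, which suggests the per-term columns are computed on the test set. The check below uses values copied from the log above.

```python
# Per-epoch values copied from the training log above.
log = {
    1: dict(loss_test=35126.9178, nll_x=32353.439, nll_y=3882.8097,
            loss_vel=-0.4268, loss_con=0.0089, kl_z0=3.0625, kl_t=0.0072),
    2: dict(loss_test=28251.0659, nll_x=27633.8124, nll_y=3055.6673,
            loss_vel=-0.5231, loss_con=0.0054, kl_z0=2.7329, kl_t=0.0056),
    10: dict(loss_test=13866.0284, nll_x=16191.5237, nll_y=1300.0444,
             loss_vel=-0.6842, loss_con=0.003, kl_z0=3.1637, kl_t=0.0231),
}

# Assumed mapping of the checkpoint-name weights to the log columns.
weights = dict(nll_x=1.0, nll_y=1.0, loss_vel=10000.0, loss_con=10000.0,
               kl_z0=1000.0, kl_t=1000.0)


def weighted_total(row):
    """Weighted sum of the logged loss components."""
    return sum(w * row[name] for name, w in weights.items())


for epoch, row in log.items():
    total = weighted_total(row)
    print(f"epoch {epoch}: weighted sum {total:.2f}, loss_test {row['loss_test']:.2f}")
```

The residual differences (below 0.5 here) come from the four-decimal rounding of `loss_vel` and `loss_con`, which are multiplied by 10000. The log is also consistent with early stopping that monitors `loss_test` with a patience of 10: the counter starts at epoch 11, the first epoch whose test loss exceeds the previous best.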

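#### Loading the model
`model.load(optimizer)` restores a saved DynaVelo checkpoint, for example to reuse a trained model without retraining it. The checkpoint path appears to be built from `dataset_name`, `sample_name`, and the model hyperparameters, so make sure `sample_name` matches the sample the checkpoint was trained on; the output below, for example, comes from a run with `sample_name = 'Icn2Het-2'` rather than the `'WT-3'` set above.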

In [26]:
model.load(optimizer)

Loaded ckpt from ../checkpoints/Ben/Icn2Het-2/Ben_Icn2Het-2_DynaVelo_num_hidden_200_zxdim_50_zydim_50_k_z0_1000_k_t_1000_k_velocity_10000_k_consistency_10000_seed_0.pth


#### Predicting velocities and latent times
After training, we use the DynaVelo model to predict latent times, RNA velocities, and motif velocities for all cells, and store the results in the AnnData objects. In the `evaluation-sample` mode, the model draws the cells' initial points in the latent space and their latent times from the learned posterior distributions `n_samples` times, then reports the mean and variance of the predicted velocities. The variance serves as a measure of uncertainty in the velocities.
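The sampling idea can be sketched as a Monte Carlo estimate. This is a simplified illustration, not DynaVelo's implementation: it assumes a Gaussian posterior with means `mu` and standard deviations `sigma`, lumps the initial latent point and latent time into one latent vector, and replaces the model with a hypothetical callable `predict_velocity`.

```python
import numpy as np


def mc_velocity_stats(predict_velocity, mu, sigma, n_samples=50, seed=0):
    """Monte Carlo mean and variance of predicted velocities.

    predict_velocity: hypothetical callable mapping latent samples
        [n_cells, d] to velocities [n_cells, n_features].
    mu, sigma: assumed Gaussian posterior parameters, shape [n_cells, d].
    """
    rng = np.random.default_rng(seed)
    draws = []
    for _ in range(n_samples):
        # Reparameterized draw from N(mu, sigma^2) for every cell.
        z = mu + sigma * rng.standard_normal(mu.shape)
        draws.append(predict_velocity(z))
    draws = np.stack(draws)  # [n_samples, n_cells, n_features]
    return draws.mean(axis=0), draws.var(axis=0)
```

With `sigma` set to zero every draw is identical, so the variance collapses to zero; wider posteriors yield larger variances, which is why the variance can be read as uncertainty.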


In [22]:
# evaluate
model.mode = 'evaluation-sample'
adata_rna_pred, adata_atac_pred = model.evaluate(adata_rna, adata_atac, dataloader, n_samples=50)

n:  0
n:  1
n:  2
n:  3
n:  4
n:  5
n:  6
n:  7
n:  8
n:  9
n:  10
n:  11
n:  12
n:  13
n:  14
n:  15
n:  16
n:  17
n:  18
n:  19
Repetitive time
n:  20
n:  21
n:  22
n:  23
Repetitive time
n:  24
n:  25
n:  26
n:  27
n:  28
n:  29
n:  30
n:  31
n:  32
n:  33
n:  34
n:  35
n:  36
n:  37
n:  38
n:  39
n:  40
n:  41
n:  42
n:  43
n:  44
n:  45
n:  46
n:  47
n:  48
n:  49


#### Calculate Jacobian matrices
To learn dynamic and cell-state-specific gene regulatory networks (GRNs), we calculate the Jacobian matrices of the trained DynaVelo model. There are four types of Jacobian matrices:

(1) J_vx_x, which measures the partial effects of RNA expression on RNA velocity and has the shape [n_cells, n_genes, n_genes].<br> 
(2) J_vy_x, which measures the partial effects of RNA expression on motif velocity and has the shape [n_cells, n_tfs, n_genes].<br> 
(3) J_vx_y, which measures the partial effects of TF motif accessibility on RNA velocity and has the shape [n_cells, n_genes, n_tfs].<br> 
(4) J_vy_y, which measures the partial effects of TF motif accessibility on motif velocity and has the shape [n_cells, n_tfs, n_tfs].
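Each entry of these tensors is a partial derivative, for example J_vx_x[c, i, j] ≈ ∂(RNA velocity of gene i)/∂(expression of gene j) at cell c. The `epsilon=1e-4` argument passed to `calculate_jacobians` further down suggests a finite-difference approximation, but that is an assumption. The sketch below shows a generic central-difference Jacobian for a single cell, with a hypothetical velocity function `f`; stacking the per-cell results gives the [n_cells, n_out, n_in] layout listed above.

```python
import numpy as np


def finite_difference_jacobian(f, x, epsilon=1e-4):
    """Central-difference Jacobian of f at a single point x.

    f maps a 1-D array of length n_in to a 1-D array of length n_out.
    Returns J with shape [n_out, n_in], where J[i, j] ~ d f_i / d x_j.
    """
    x = np.asarray(x, dtype=float)
    n_in = x.size
    n_out = np.asarray(f(x)).size
    J = np.empty((n_out, n_in))
    for j in range(n_in):
        step = np.zeros(n_in)
        step[j] = epsilon
        # Perturb input j up and down and take the symmetric difference.
        J[:, j] = (np.asarray(f(x + step)) - np.asarray(f(x - step))) / (2 * epsilon)
    return J
```

Each column needs two evaluations of `f`, so the cost grows linearly with the number of inputs; automatic differentiation (for example `torch.autograd.functional.jacobian`) is the usual alternative when the model is differentiable.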

Because the Jacobians are dense 3D tensors, they consume a lot of memory, so we compute them only for a subset of genes of interest. We add all TFs to `genes_of_interest` because we want to understand how TFs regulate each other. In the `evaluation-fixed` mode, the model uses the posterior means of the cells' latent times and initial latent points instead of sampling them.
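A quick estimate shows why the subset matters. The sketch assumes float32 values and that both gene axes of J_vx_x are restricted to the subset; it uses 6862 cells (the sum of the cell-type counts above), 2419 genes (the input width of `encoder_net_x.0.weight`), and the 178 genes of interest printed below. `dense_jacobian_gib` is a hypothetical helper.

```python
def dense_jacobian_gib(n_cells, n_out, n_in, bytes_per_value=4):
    """Memory in GiB of a dense tensor with shape [n_cells, n_out, n_in]."""
    return n_cells * n_out * n_in * bytes_per_value / 2**30


n_cells, n_genes, n_subset = 6862, 2419, 178
full = dense_jacobian_gib(n_cells, n_genes, n_genes)
subset = dense_jacobian_gib(n_cells, n_subset, n_subset)
print(f"all genes: {full:.1f} GiB")    # all genes: 149.6 GiB
print(f"178 genes: {subset:.2f} GiB")  # 178 genes: 0.81 GiB
```

Restricting J_vx_x to the 178 genes cuts its footprint by a factor of (2419/178)², roughly 185, from about 150 GiB to under 1 GiB.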


In [23]:
# Jacobians
TFs = adata_atac.var['TF'].values
genes_of_interest = ['Spen', 'Notch2']
genes_of_interest = [g.capitalize() for g in genes_of_interest]
genes_of_interest = list(np.intersect1d(adata_rna_pred.var_names, np.union1d(TFs, genes_of_interest)))
genes_of_interest.sort()
print(len(genes_of_interest))
print(genes_of_interest)

178
['Ahr', 'Arid3a', 'Arid3b', 'Arid5a', 'Arnt', 'Arntl', 'Atf1', 'Atf2', 'Atf4', 'Atf7', 'Bach2', 'Batf', 'Bcl6', 'Bhlhe40', 'Bhlhe41', 'Cebpg', 'Clock', 'Creb1', 'Creb3l2', 'Crem', 'Ctcf', 'Cux1', 'E2f1', 'E2f2', 'E2f3', 'E2f4', 'E2f5', 'E2f7', 'E2f8', 'Ebf1', 'Egr1', 'Egr2', 'Elf1', 'Elf2', 'Elf4', 'Elk3', 'Elk4', 'Esrra', 'Ets1', 'Etv3', 'Etv6', 'Fli1', 'Fos', 'Fosb', 'Foxj2', 'Foxj3', 'Foxk1', 'Foxk2', 'Foxo1', 'Foxo3', 'Foxp1', 'Gabpa', 'Gfi1', 'Gmeb1', 'Gmeb2', 'Hif1a', 'Hltf', 'Hmbox1', 'Id2', 'Ikzf1', 'Irf1', 'Irf2', 'Irf3', 'Irf4', 'Irf5', 'Irf8', 'Irf9', 'Junb', 'Klf13', 'Klf3', 'Klf4', 'Klf6', 'Klf8', 'Klf9', 'Lin54', 'Lyl1', 'Max', 'Mbd2', 'Mecp2', 'Mef2a', 'Mef2b', 'Mef2c', 'Mef2d', 'Mga', 'Mlxip', 'Mnt', 'Mtf1', 'Mxi1', 'Myb', 'Mybl1', 'Myc', 'Nfat5', 'Nfatc1', 'Nfatc2', 'Nfatc3', 'Nfe2l2', 'Nfia', 'Nfic', 'Nfkb1', 'Nfkb2', 'Nfya', 'Nfyb', 'Nfyc', 'Notch2', 'Nr1d2', 'Nr2c1', 'Nr2c2', 'Nr2f6', 'Nr3c1', 'Nr3c2', 'Nr4a1', 'Nr6a1', 'Nrf1', 'Patz1', 'Pax5', 'Pbx1', 'Pbx2', '

In [24]:

model.mode = 'evaluation-fixed'
adata_rna_pred = model.calculate_jacobians(adata_rna_pred, adata_atac_pred, dataloader, genes_of_interest, epsilon=1e-4)

Batch 1/27
Batch 2/27
Batch 3/27
Batch 4/27
Batch 5/27
Batch 6/27
Batch 7/27
Batch 8/27
Batch 9/27
Batch 10/27
Batch 11/27
Batch 12/27
Batch 13/27
Batch 14/27
Batch 15/27
Batch 16/27
Batch 17/27
Batch 18/27
Batch 19/27
Batch 20/27
n_rep: 1

#### In-silico gene perturbation
A useful application of DynaVelo is in-silico gene perturbation: perturbing a gene and observing the resulting changes in RNA and motif velocities shows how that gene can alter cell trajectories. Such analyses can help identify candidate perturbation targets for restoring normal trajectories in diseases where they have been disrupted. The `perturbed_genes` list specifies the genes to perturb; in this example it contains every TF that is also in `adata_rna_pred.var_names`, plus `Spen` and `Notch2`, and each gene is perturbed in turn, as the log below shows. If a gene is a TF, both its RNA expression and its motif accessibility are set to the minimum value observed across all cells; otherwise, only its RNA expression is perturbed.
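The perturbation rule described above can be sketched with the model treated as a black box. `predict_velocities` is a hypothetical stand-in for the trained model, mapping expression `X` [n_cells, n_genes] and motif z-scores `Y` [n_cells, n_tfs] to RNA and motif velocities; DynaVelo's `predict_perturbation` may differ internally, for example in how it passes inputs through the latent space.

```python
import numpy as np


def knockout_velocity_shift(predict_velocities, X, Y, gene_idx, tf_idx=None):
    """Velocity change after setting one gene to its minimum across cells.

    predict_velocities: hypothetical callable (X, Y) -> (V_x, V_y).
    X: [n_cells, n_genes] expression; Y: [n_cells, n_tfs] motif z-scores.
    tf_idx: column of the gene's motif in Y, or None if it is not a TF.
    """
    X_pert, Y_pert = X.copy(), Y.copy()
    # Clamp the gene's expression to its minimum over all cells.
    X_pert[:, gene_idx] = X[:, gene_idx].min()
    if tf_idx is not None:
        # For a TF, also clamp its motif accessibility.
        Y_pert[:, tf_idx] = Y[:, tf_idx].min()
    vx0, vy0 = predict_velocities(X, Y)
    vx1, vy1 = predict_velocities(X_pert, Y_pert)
    return vx1 - vx0, vy1 - vy0
```

Comparing the perturbed and unperturbed velocity fields cell by cell shows which trajectories a gene's loss would redirect.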

In [25]:
# In-silico gene perturbation
TFs = adata_atac.var['TF'].values
#idx_tfs_sub = np.where((adata_rna_pred[:, TFs].X.toarray()>0).sum(0)>1000)[0]
#TFs_sub = TFs[idx_tfs_sub]
perturbed_genes = ['Spen', 'Notch2']
perturbed_genes = [g.capitalize() for g in perturbed_genes]
perturbed_genes = list(np.intersect1d(adata_rna_pred.var_names, np.union1d(TFs, perturbed_genes)))
perturbed_genes.sort()

model.mode = 'evaluation-fixed'
adata_rna_pred = model.predict_perturbation(adata_rna_pred, adata_atac_pred, dataloader, perturbed_genes)

n_rep: 1
idx: 0 / perturbed gene: Ahr
idx: 1 / perturbed gene: Arid3a
idx: 2 / perturbed gene: Arid3b
idx: 3 / perturbed gene: Arid5a
idx: 4 / perturbed gene: Arnt
idx: 5 / perturbed gene: Arntl
idx: 6 / perturbed gene: Atf1
idx: 7 / perturbed gene: Atf2
idx: 8 / perturbed gene: Atf4
idx: 9 / perturbed gene: Atf7
idx: 10 / perturbed gene: Bach2
idx: 11 / perturbed gene: Batf
idx: 12 / perturbed gene: Bcl6
idx: 13 / perturbed gene: Bhlhe40
idx: 14 / perturbed gene: Bhlhe41
idx: 15 / perturbed gene: Cebpg
idx: 16 / perturbed gene: Clock
n_rep: 1
idx: 17 / perturbed gene: Creb1
idx: 18 / perturbed gene: Creb3l2
idx: 19 / perturbed gene: Crem
idx: 20 / perturbed gene: Ctcf
idx: 21 / perturbed gene: Cux1
idx: 22 / perturbed gene: E2f1
idx: 23 / perturbed gene: E2f2
idx: 24 / perturbed gene: E2f3
idx: 25 / perturbed gene: E2f4
idx: 26 / perturbed gene: E2f5
idx: 27 / perturbed gene: E2f7
idx: 28 / perturbed gene: E2f8
n_rep: 1
idx: 29 / perturbed gene: Ebf1
idx: 30 / perturbed gene: Egr1
idx

#### Saving the adatas

In [26]:
adata_rna_pred.write_h5ad(f"/mnt/storage/Ben_data/analysis/outs/RNAMatrix/predicted/RNA_Matrix_Pred_{sample_name}.h5ad")
adata_atac_pred.write_h5ad(f"/mnt/storage/Ben_data/analysis/outs/MotifMatrix/predicted/Motif_Matrix_Pred_{sample_name}.h5ad")