### DynaVelo Demo

#### Usage
Two adata files are required: adata_rna and adata_atac. The first adata contains preprocessed RNA expression values and RNA velocity estimates from scVelo. As RNA velocities are prone to error and sensitive to gene sets, we have to first check if the overall trajectories make biological sense. The second adata contains TF motif accessibility z-scores from chromVAR. The shape of adata_rna is [n_cells, n_genes], and the shape of adata_atac is [n_cells, n_tfs].

#### Load data

In [None]:
import scanpy as sc

dataset_name = 'Ben'
#sample_list = ['WT-3', 'WT-13', 'Icn2Het-2', 'Icn2Het-10', 'SpenHet-1-2', 'SpenHet-15', 'SpenHet-Icn2Het-8', 'SpenHet-Icn2Het-12']
sample_name = 'Icn2Het-2'

adata_rna = sc.read_h5ad(f"/mnt/storage/Ben_data/analysis/outs/RNAMatrix/dynavelo/RNA_Matrix_{sample_name}_AiBC.h5ad")
adata_atac = sc.read_h5ad(f"/mnt/storage/Ben_data/analysis/outs/MotifMatrix/dynavelo/MotifMatrix_{sample_name}_AiBC.h5ad")

assert all(adata_rna.obs_names == adata_atac.obs_names)

In [19]:
adata_rna.obs['fine.celltype'].value_counts()

Prememory_Memory    5098
Prememory_Naive      742
CC_Rec               269
Plasma_cell           67
Recycling             40
CB_S_G2M              35
CB_Rec_Sphase         10
Name: fine.celltype, dtype: int64

In [20]:
adata_rna.obs['anno_clusters'].value_counts()

3_AiBC    2843
0_AiBC    2066
7_AiBC     762
9_AiBC     557
8_AiBC      33
Name: anno_clusters, dtype: int64

In [21]:
adata_atac.obs['fine.celltype'].value_counts()

Prememory_Memory    5098
Prememory_Naive      742
CC_Rec               269
Plasma_cell           67
Recycling             40
CB_S_G2M              35
CB_Rec_Sphase         10
Name: fine.celltype, dtype: int64

#### Build datasets and dataloaders
We first build Pytorch datasets and dataloaders from adata_rna and adata_atac to be used for training. MultiomeDataset is used to create the custom datasets. We randomly allocate 10% of the cells for the test and the rest for the training dataset.

In [22]:
import numpy as np
import torch
import torch.optim as optim
from torch.utils.data import DataLoader
import sys

sys.path.append("..")
sys.path.append("../dynavelo")
from dynavelo.models import MultiomeDataset
from dynavelo.models import DynaVelo

# set seed
seed = 0
np.random.seed(seed)
torch.manual_seed(seed)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(seed)

# set gpu
gpu = 1
device = torch.device('cuda:' + str(gpu) if torch.cuda.is_available() else 'cpu')

# datasets and dataloaders
dataset = MultiomeDataset(adata_rna, adata_atac, use_weights=False)
N_test = int(0.1 * len(dataset))
idx_random = np.random.permutation(len(dataset))
idx_test = idx_random[:N_test]
idx_train = idx_random[N_test:]
dataset_train = MultiomeDataset(adata_rna[idx_train], adata_atac[idx_train], use_weights=False)
dataset_test = MultiomeDataset(adata_rna[idx_test], adata_atac[idx_test], use_weights=False)

dataloader_train = DataLoader(dataset_train, batch_size=256, shuffle=True, num_workers=0, drop_last=True)
dataloader_test = DataLoader(dataset_test, batch_size=256, shuffle=False, num_workers=0)
dataloader = DataLoader(dataset, batch_size=256, shuffle=False, num_workers=0)

In [23]:
torch.cuda.is_available()

True

#### Initiate DynaVelo model

In [24]:
# model
model = DynaVelo(x_dim=adata_rna.shape[1], y_dim=adata_atac.shape[1], device=device, dataset_name=dataset_name, sample_name=sample_name).to(device)
optimizer = optim.Adam(model.parameters(), lr=1e-3)

n_params = sum([np.prod(p.size()) for p in model.parameters()])
print('number of parameters func: ', n_params)
for name, param in model.named_parameters():
    print(name, param.shape, param.dtype)
    print(name, param)

number of parameters func:  495865
encoder_net_x.0.weight torch.Size([50, 2631]) torch.float32
encoder_net_x.0.weight Parameter containing:
tensor([[-0.0001,  0.0105, -0.0160,  ..., -0.0036,  0.0012, -0.0064],
        [-0.0071, -0.0168,  0.0011,  ...,  0.0045,  0.0072, -0.0139],
        [ 0.0153,  0.0139, -0.0059,  ...,  0.0020, -0.0052,  0.0025],
        ...,
        [ 0.0146, -0.0186,  0.0096,  ...,  0.0194, -0.0164, -0.0175],
        [-0.0137, -0.0035, -0.0165,  ..., -0.0160, -0.0098,  0.0137],
        [-0.0064, -0.0171, -0.0091,  ..., -0.0033, -0.0129, -0.0105]],
       device='cuda:1', requires_grad=True)
encoder_net_x.0.bias torch.Size([50]) torch.float32
encoder_net_x.0.bias Parameter containing:
tensor([-9.7715e-03,  1.2772e-02,  1.8205e-02, -1.1239e-02, -1.7011e-02,
         8.4122e-03,  1.1611e-02, -2.6513e-03, -1.4230e-02,  1.9178e-02,
        -1.4118e-02, -5.2934e-03, -1.0111e-02, -2.4669e-03, -1.4569e-02,
         1.0251e-02,  6.7808e-03, -6.1193e-03,  1.1878e-05,  1.8734e

#### Training the model

In [25]:
# train
model.fit(dataloader_train, dataloader_test, optimizer, max_epoch=200)

Epoch	loss_train	loss_test	nll_x	nll_y	loss_vel	loss_con	kl_z0	kl_t
1	50664.1216	43762.0564	38073.427	5426.6752	-0.3858	0.0104	4.0081	0.0073
2	39440.9824	35187.485	32190.6065	4597.216	-0.5479	0.0061	3.8126	0.0057
3	32725.2636	30225.7602	27655.5584	3921.3342	-0.5982	0.0075	4.5524	0.0032
4	28512.1097	27026.3698	25624.995	3329.9868	-0.6298	0.0044	4.3227	0.0027
5	25812.5805	24687.3325	24319.5096	2930.1483	-0.6544	0.0044	3.9356	0.0022
6	23896.739	23108.1465	23309.7959	2654.1318	-0.6786	0.0043	3.8849	0.0019
7	22456.162	21742.8543	22247.257	2489.9742	-0.6952	0.0048	3.9084	0.0018
8	21373.6883	20760.8313	21358.9242	2509.9053	-0.7032	0.0029	3.8926	0.0023
9	20284.6537	19973.5837	20992.2283	2361.5884	-0.7156	0.0023	3.7473	0.0056
10	19595.7893	19067.0815	20794.8586	2044.6135	-0.7355	0.0052	3.5198	0.0108
11	18832.7779	18195.4052	20192.1587	1907.1738	-0.7428	0.0064	3.4274	0.0324
12	18456.561	18059.4926	19898.9275	1933.7117	-0.7582	0.0027	3.7336	0.0479
13	17886.7823	17610.4466	19933.9981	1745.8582	-0.

#### Loading the model

In [26]:
model.load(optimizer)

Loaded ckpt from ../checkpoints/Ben/Icn2Het-2/Ben_Icn2Het-2_DynaVelo_num_hidden_200_zxdim_50_zydim_50_k_z0_1000_k_t_1000_k_velocity_10000_k_consistency_10000_seed_0.pth


#### Predicting velocities and latent times
After a DynaVelo model is trained, we use it to predict the latent times, RNA velocities, and motif velocities for all cells, saving the results in the adata files. The mode evaluation-sample means that we sample from the learned posterior probabilities of initial points in the latent space and latent times of cells `n_samples` times, then report the mean and variance of the predicted velocities. The variance can represent uncertainty in the velocities.


In [27]:
# evaluate
model.mode = 'evaluation-sample'
adata_rna_pred, adata_atac_pred = model.evaluate(adata_rna, adata_atac, dataloader, n_samples=50)

n:  0
n:  1
n:  2
n:  3
n:  4
n:  5
n:  6
n:  7
n:  8
n:  9
n:  10
n:  11
n:  12
n:  13
n:  14
n:  15
n:  16
n:  17
n:  18
n:  19
n:  20
n:  21
n:  22
Repetitive time
n:  23
n:  24
n:  25
Repetitive time
Repetitive time
n:  26
n:  27
n:  28
n:  29
n:  30
n:  31
n:  32
n:  33
n:  34
n:  35
n:  36
n:  37
n:  38
n:  39
n:  40
Repetitive time
n:  41
n:  42
n:  43
n:  44
n:  45
n:  46
n:  47
n:  48
n:  49


#### Calculate Jacobian matrices
To learn dynamic and cell-state-specific gene regulatory networks (GRNs), we calculate the Jacobian matrices of the trained DynaVelo model. There are four types of Jacobian matrices:

(1) J_vx_x, which measures the partial effects of RNA expression on RNA velocity and has the shape [n_cells, n_genes, n_genes].<br> 
(2) J_vy_x, which measures the partial effects of RNA expression on motif velocity and has the shape [n_cells, n_tfs, n_genes].<br> 
(3) J_vx_y, which measures the partial effects of TF motif accessibility on RNA velocity and has the shape [n_cells, n_genes, n_tfs].<br> 
(4) J_vy_y, which measures the partial effects of TF motif accessibility on motif velocity and has the shape [n_cells, n_tfs, n_tfs].

Since the Jacobians are 3D dense tensors, they take a lot of memory, so we choose a subset of genes we are interested in and calculate the Jacobians for them. We include all the TFs in the subset of `genes_of_interest`, as we are interested in understanding how TFs regulate each other. The mode evaluation-fixed means that we use the mean of the latent times and initial points of the cells in the latent space without sampling.


In [28]:
# Jacobians
TFs = adata_atac.var['TF'].values
genes_of_interest = ['Spen', 'Notch2']
genes_of_interest = [g.capitalize() for g in genes_of_interest]
genes_of_interest = list(np.intersect1d(adata_rna_pred.var_names, np.union1d(TFs, genes_of_interest)))
genes_of_interest.sort()
print(len(genes_of_interest))
print(genes_of_interest)

184
['Ahr', 'Arid3a', 'Arid3b', 'Arid5a', 'Arnt', 'Arntl', 'Atf1', 'Atf2', 'Atf4', 'Atf7', 'Bach1', 'Bach2', 'Batf', 'Bcl6', 'Bhlhe40', 'Bhlhe41', 'Cebpg', 'Clock', 'Creb1', 'Creb3', 'Creb3l2', 'Crem', 'Ctcf', 'Cux1', 'E2f1', 'E2f2', 'E2f3', 'E2f4', 'E2f5', 'E2f7', 'E2f8', 'Ebf1', 'Egr2', 'Egr3', 'Elf1', 'Elf2', 'Elf4', 'Elk3', 'Elk4', 'Esr1', 'Esrra', 'Ets1', 'Etv3', 'Etv6', 'Fli1', 'Fos', 'Fosb', 'Foxj2', 'Foxj3', 'Foxk1', 'Foxk2', 'Foxo1', 'Foxo3', 'Foxp1', 'Gabpa', 'Gfi1', 'Gmeb1', 'Gmeb2', 'Grhl1', 'Hes1', 'Hif1a', 'Hinfp', 'Hltf', 'Hmbox1', 'Hsf1', 'Id2', 'Ikzf1', 'Irf1', 'Irf2', 'Irf3', 'Irf4', 'Irf5', 'Irf8', 'Irf9', 'Junb', 'Klf12', 'Klf13', 'Klf3', 'Klf4', 'Klf6', 'Klf8', 'Lef1', 'Lin54', 'Lyl1', 'Max', 'Maz', 'Mbd2', 'Mecp2', 'Mef2a', 'Mef2b', 'Mef2c', 'Mef2d', 'Mga', 'Mlxip', 'Mnt', 'Mtf1', 'Mxi1', 'Myb', 'Mybl1', 'Mybl2', 'Myc', 'Nfat5', 'Nfatc1', 'Nfatc2', 'Nfatc3', 'Nfe2l2', 'Nfia', 'Nfic', 'Nfkb1', 'Nfkb2', 'Nfya', 'Nfyb', 'Nfyc', 'Notch2', 'Nr1d2', 'Nr2c1', 'Nr2c2', 'N

In [None]:

model.mode = 'evaluation-fixed'
adata_rna_pred = model.calculate_jacobians(adata_rna_pred, adata_atac_pred, dataloader, genes_of_interest, epsilon = 1e-4)

Batch 1/25
Batch 2/25
Batch 3/25
Batch 4/25
Batch 5/25
Batch 6/25
Batch 7/25
Batch 8/25
Batch 9/25
Batch 10/25
Batch 11/25
Batch 12/25
Batch 13/25
Batch 14/25
Batch 15/25
Batch 16/25
Batch 17/25
Batch 18/25
Batch 19/25
Batch 20/25
Batch 21/25
Batch 22/25
Batch 23/25


#### In-silico gene perturbation
One of the useful applications of DynaVelo is to perform in-silico gene perturbations and observe the resulting changes in RNA and motif velocities. This approach helps us understand how perturbing a gene can alter cell trajectories. Such insights are invaluable for identifying optimal perturbation targets to restore lost functions in diseases where normal trajectories have been disrupted. The `perturbed_genes` list specifies the genes for in-silico perturbations. If a gene is a TF, both its RNA expression and motif accessibility are set to the minimum value observed across all cells; otherwise, only RNA expression is perturbed.

In [15]:
# In-silico gene perturbation
TFs = adata_atac.var['TF'].values
#idx_tfs_sub = np.where((adata_rna_pred[:, TFs].X.toarray()>0).sum(0)>1000)[0]
#TFs_sub = TFs[idx_tfs_sub]
perturbed_genes = ['Spen', 'Notch2']
perturbed_genes = [g.capitalize() for g in perturbed_genes]
perturbed_genes = list(np.intersect1d(adata_rna_pred.var_names, np.union1d(TFs, perturbed_genes)))
perturbed_genes.sort()

model.mode = 'evaluation-fixed'
adata_rna_pred = model.predict_perturbation(adata_rna_pred, adata_atac_pred, dataloader, perturbed_genes)

idx: 0 / perturbed gene: Ahr
idx: 1 / perturbed gene: Arid3a
idx: 2 / perturbed gene: Arid3b
idx: 3 / perturbed gene: Arid5a
idx: 4 / perturbed gene: Arnt
idx: 5 / perturbed gene: Arntl
idx: 6 / perturbed gene: Atf1
idx: 7 / perturbed gene: Atf2
idx: 8 / perturbed gene: Atf4
n_rep: 1
idx: 9 / perturbed gene: Atf7
idx: 10 / perturbed gene: Bach2
idx: 11 / perturbed gene: Batf
idx: 12 / perturbed gene: Bcl6
idx: 13 / perturbed gene: Bhlhe40
idx: 14 / perturbed gene: Bhlhe41
n_rep: 1
idx: 15 / perturbed gene: Cebpg
idx: 16 / perturbed gene: Clock
idx: 17 / perturbed gene: Creb1
idx: 18 / perturbed gene: Creb3l2
idx: 19 / perturbed gene: Crem
n_rep: 1
idx: 20 / perturbed gene: Ctcf
n_rep: 1
idx: 21 / perturbed gene: Cux1
idx: 22 / perturbed gene: E2f1
idx: 23 / perturbed gene: E2f2
idx: 24 / perturbed gene: E2f3
idx: 25 / perturbed gene: E2f4
idx: 26 / perturbed gene: E2f5
idx: 27 / perturbed gene: E2f7
idx: 28 / perturbed gene: E2f8
idx: 29 / perturbed gene: Ebf1
idx: 30 / perturbed gene:

#### Saving the adatas

In [17]:
adata_rna_pred.write_h5ad(f"/mnt/storage/Ben_data/analysis/outs/RNAMatrix/predicted/RNA_Matrix_Pred_{sample_name}.h5ad")
adata_atac_pred.write_h5ad(f"/mnt/storage/Ben_data/analysis/outs/MotifMatrix/predicted/Motif_Matrix_Pred_{sample_name}.h5ad")