In [1]:
%load_ext autoreload
%autoreload 2

In [85]:
import scanpy as sc
import torch
import lineagevi as linvi
import numpy as np
import pandas as pd

In [3]:
adata_path = '/Users/lgolinelli/git/lineageVI/notebooks/data/outputs/pancreas_2025.08.17_12.43.17/adata_with_velocity.h5ad'
model_path = '/Users/lgolinelli/git/lineageVI/notebooks/data/outputs/pancreas_2025.08.17_12.43.17/vae_velocity_model.pt'

adata = sc.read_h5ad(adata_path)

model = linvi.trainer.LineageVI(
    adata,
)

model.load_state_dict(torch.load(model_path, map_location='cpu'))
model.eval()

LineageVI(
  (encoder): Encoder(
    (encoder): Sequential(
      (0): Linear(in_features=1805, out_features=128, bias=True)
      (1): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
      (2): ReLU()
    )
    (mean_layer): Linear(in_features=128, out_features=647, bias=True)
    (logvar_layer): Linear(in_features=128, out_features=647, bias=True)
  )
  (gene_decoder): MaskedLinearDecoder(
    (linear): Linear(in_features=647, out_features=1805, bias=True)
  )
  (velocity_decoder): VelocityDecoder(
    (shared_decoder): Sequential(
      (0): Linear(in_features=647, out_features=128, bias=True)
      (1): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
      (2): ReLU()
    )
    (gp_velocity_decoder): Linear(in_features=128, out_features=647, bias=True)
    (gene_velocity_decoder): Sequential(
      (0): Linear(in_features=128, out_features=5415, bias=True)
      (1): Softplus(beta=1.0, threshold=20.0)
    )
  )
)

# PERTURB GENES FUNCTION

In [6]:
adata.layers

Layers with keys: Ms, Mu, recon, spliced, unspliced, velocity, velocity_u

In [16]:
mu.shape

(1805,)

In [22]:
x.shape

torch.Size([])

In [34]:
mu.shape

(1, 1805)

In [None]:
cell_idx = 0
gene_idx = 0
spliced = True
unspliced = False
both = False
pert_type = 'overexpression'
mu_unperturbed = adata.layers['Mu'][cell_idx, :]
ms_unperturbed = adata.layers['Ms'][cell_idx, :]

mu_unperturbed = mu_unperturbed.reshape(1, mu_unperturbed.shape[0])
ms_unperturbed = ms_unperturbed.reshape(1, ms_unperturbed.shape[0])


mu_perturbed = mu_unperturbed.copy()
ms_perturbed = ms_unperturbed.copy()

perturb_value = 0 if pert_type == 'ko' else max(adata.layers['Ms'][:,gene_idx])

if unspliced:
    mu_perturbed[:,gene_idx]= perturb_value
if spliced:
    ms_perturbed[:,gene_idx]= perturb_value
if both:
    mu_perturbed[:,gene_idx]= perturb_value
    ms_perturbed[:,gene_idx]= perturb_value

mu_ms_unpert = np.concatenate([mu_unperturbed, ms_unperturbed], axis=1)
mu_ms_pert = np.concatenate([mu_perturbed, ms_perturbed], axis=1)

x_unpert = torch.tensor(mu_ms_unpert)
x_pert = torch.tensor(mu_ms_pert)

with torch.no_grad():
    model.first_regime = False
    out_unpert = model.forward(x_unpert)
    out_pert = model.forward(x_pert)

In [106]:
out_unpert.keys()

dict_keys(['recon', 'z', 'mean', 'logvar', 'velocity', 'velocity_gp', 'alpha', 'beta', 'gamma'])

In [142]:
recon_unpert = out_unpert['recon']
mean_unpert = out_unpert['mean']
logvar_unpert = out_unpert['logvar']
gp_velo_unpert = out_unpert['velocity_gp']
velo_concat_unpert = out_unpert['velocity']
velo_u_unpert, velo_unpert = np.split(velo_concat_unpert, 2, axis=1)
alpha_unpert = out_unpert['alpha']
beta_unpert = out_unpert['beta']
gamma_unpert = out_unpert['gamma']

In [143]:
recon_pert = out_pert['recon']
mean_pert = out_pert['mean']
logvar_pert = out_pert['logvar']
gp_velo_pert = out_pert['velocity_gp']
velo_concat_pert = out_pert['velocity']
velo_u_pert, velo_pert = np.split(velo_concat_pert, 2, axis=1)
alpha_pert = out_pert['alpha']
beta_pert = out_pert['beta']
gamma_pert = out_pert['gamma']

In [144]:
recon_unpert.shape, mean_unpert.shape, logvar_unpert.shape, gp_velo_pert.shape

(torch.Size([1, 1805]),
 torch.Size([1, 647]),
 torch.Size([1, 647]),
 torch.Size([1, 647]))

In [145]:
velo_u_pert.shape, velo_pert.shape, alpha_unpert.shape, beta_unpert.shape, gamma_unpert.shape

(torch.Size([1, 1805]),
 torch.Size([1, 1805]),
 torch.Size([1, 1805]),
 torch.Size([1, 1805]),
 torch.Size([1, 1805]))

In [152]:
recon_diff = recon_pert - recon_unpert
mean_diff = mean_pert - mean_unpert
logvar_diff = mean_pert - mean_unpert
gp_velo_diff = gp_velo_pert - gp_velo_unpert
velo_u_diff = velo_u_pert - velo_u_unpert
velo_diff = velo_pert - velo_unpert
alpha_diff = alpha_pert - alpha_unpert
beta_diff = beta_pert - beta_unpert
gamma_diff = gamma_pert = gamma_unpert

to_numpy = lambda x : x.cpu().numpy().reshape(x.size(1))

df_gp = pd.DataFrame({
    'terms' : adata.uns['terms'],
    'mean' : to_numpy(mean_diff),
    'abs_mean' : abs(to_numpy(mean_diff)),
    'logvar' : to_numpy(logvar_diff),
    'abs_logvar' : abs(to_numpy(logvar_diff)),
    'gp_velocity' : to_numpy(gp_velo_diff),
    'abs_gp_velocity' : abs(to_numpy(gp_velo_diff)),
})

df_genes = pd.DataFrame({
    'genes' : adata.var_names,
    'recon' : to_numpy(recon_diff),
    'recon' : abs(to_numpy(recon_diff)),
    'unspliced_velocity' : to_numpy(velo_u_diff),
    'abs_unspliced_velocity' : abs(to_numpy(velo_u_diff)),
    'velocity' : to_numpy(velo_diff),
    'abs_velocity' : abs(to_numpy(velo_diff)),
    'alpha' : to_numpy(alpha_diff),
    'abs_alpha' : abs(to_numpy(alpha_diff)),
    'beta' : to_numpy(beta_diff),
    'abs_beta' : abs(to_numpy(beta_diff)),
    'gamma' : to_numpy(gamma_diff),
    'abs_gamma' : abs(to_numpy(gamma_diff)),
})

#df.sort_values('diff_abs',  ascending=False)

In [153]:
df_gp.sort_values('abs_gp_velocity',  ascending=False)

Unnamed: 0,terms,mean,abs_mean,logvar,abs_logvar,gp_velocity,abs_gp_velocity
571,BLOOD_CHAD63_KH_AGE_18_50YO_HI,-0.000230,0.000230,-0.000230,0.000230,-0.093851,0.093851
537,FOXP3_TARGETS_UP,0.000105,0.000105,0.000105,0.000105,-0.084915,0.084915
221,CONVENTIONAL_CDC_VS_PLASMACYTO,0.000100,0.000100,0.000100,0.000100,0.078448,0.078448
273,ADULT_VS_FETAL_DN3_THYMOCYTE_U,0.000134,0.000134,0.000134,0.000134,0.077340,0.077340
441,ABNORMAL_PANCREAS_SIZE,-0.000247,0.000247,-0.000247,0.000247,-0.076436,0.076436
...,...,...,...,...,...,...,...
74,MAIN_FETAL_MEGAKARYOCYTES,-0.000440,0.000440,-0.000440,0.000440,0.000056,0.000056
594,THYROID_CARCINOMA_ANAPLASTIC_D,-0.000590,0.000590,-0.000590,0.000590,0.000055,0.000055
288,EARLY_THYMIC_PROGENITOR_VS_DN3,-0.000336,0.000336,-0.000336,0.000336,-0.000038,0.000038
369,NAIVE_VS_DAY15_LCMV_ARMSTRONG_,0.000262,0.000262,0.000262,0.000262,-0.000034,0.000034


In [154]:
df_gp.sort_values('abs_mean',  ascending=False)

Unnamed: 0,terms,mean,abs_mean,logvar,abs_logvar,gp_velocity,abs_gp_velocity
541,RB1_TARGETS_UP,1.691818e-03,1.691818e-03,1.691818e-03,1.691818e-03,0.047035,0.047035
24,FETAL_CEREBELLUM_VASCULAR_ENDO,-1.658440e-03,1.658440e-03,-1.658440e-03,1.658440e-03,-0.033839,0.033839
319,LUPUS_VS_HEALTHY_DONOR_BCELL_D,-1.629829e-03,1.629829e-03,-1.629829e-03,1.629829e-03,-0.047757,0.047757
478,MATURITY_ONSET_DIABETES_OF_THE,1.610756e-03,1.610756e-03,1.610756e-03,1.610756e-03,0.035594,0.035594
584,REGULATION_OF_BETA_CELL_DEVELO,1.539230e-03,1.539230e-03,1.539230e-03,1.539230e-03,0.017361,0.017361
...,...,...,...,...,...,...,...
370,NAIVE_VS_DAY15_LCMV_CONE13_EFF,-3.337860e-06,3.337860e-06,-3.337860e-06,3.337860e-06,0.064858,0.064858
504,ERBB2_BREAST_TUMORS_324_DN,7.152557e-07,7.152557e-07,7.152557e-07,7.152557e-07,-0.065867,0.065867
617,DN,-5.960464e-07,5.960464e-07,-5.960464e-07,5.960464e-07,-0.021874,0.021874
580,DEVELOPMENTAL_CELL_LINEAGES_OF,-4.768372e-07,4.768372e-07,-4.768372e-07,4.768372e-07,-0.035391,0.035391


In [155]:
df_genes.columns

Index(['genes', 'recon', 'unspliced_velocity', 'abs_unspliced_velocity',
       'velocity', 'abs_velocity', 'alpha', 'abs_alpha', 'beta', 'abs_beta',
       'gamma', 'abs_gamma'],
      dtype='object')

In [157]:
df_genes.sort_values('abs_velocity',  ascending=False)

Unnamed: 0,genes,recon,unspliced_velocity,abs_unspliced_velocity,velocity,abs_velocity,alpha,abs_alpha,beta,abs_beta,gamma,abs_gamma
273,Pyy,0.088589,-0.013530,0.013530,-1.743633,1.743633,-0.045921,0.045921,-0.059900,0.059900,0.376129,0.376129
776,Malat1,0.286095,-1.042672,1.042672,-0.688868,0.688868,-0.026811,0.026811,0.012477,0.012477,0.499145,0.499145
954,Chgb,0.376650,-0.006716,0.006716,-0.371380,0.371380,-0.019874,0.019874,-0.072513,0.072513,0.258653,0.258653
412,Chga,0.197732,-0.004892,0.004892,-0.261398,0.261398,-0.012434,0.012434,-0.039355,0.039355,0.447275,0.447275
1414,Iapp,0.197994,-0.003386,0.003386,-0.209985,0.209985,-0.003498,0.003498,-0.010790,0.010790,1.131830,1.131830
...,...,...,...,...,...,...,...,...,...,...,...,...
67,Dtl,0.016374,0.000557,0.000557,0.000000,0.000000,0.000557,0.000557,0.000160,0.000160,1.721268,1.721268
789,Ankrd1,0.004981,0.000413,0.000413,0.000000,0.000000,0.000413,0.000413,0.012135,0.012135,1.632482,1.632482
264,Krt23,0.003612,-0.000447,0.000447,0.000000,0.000000,-0.000447,0.000447,-0.000352,0.000352,1.379743,1.379743
728,Rasgrp3,0.028836,0.000902,0.000902,0.000000,0.000000,0.000902,0.000902,0.012177,0.012177,1.340917,1.340917


# PERTURB GPs function