In [141]:
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [142]:
import scanpy as sc
import torch
import lineagevi as linvi
import numpy as np
import pandas as pd

In [143]:
adata_path = '/Users/lgolinelli/git/lineageVI/notebooks/data/outputs/pancreas_2025.08.17_12.43.17/adata_with_velocity.h5ad'
model_path = '/Users/lgolinelli/git/lineageVI/notebooks/data/outputs/pancreas_2025.08.17_12.43.17/vae_velocity_model.pt'

adata = sc.read_h5ad(adata_path)

model = linvi.trainer.LineageVI(
    adata,
)

model.load_state_dict(torch.load(model_path, map_location='cpu'))
model.eval()

LineageVI(
  (encoder): Encoder(
    (encoder): Sequential(
      (0): Linear(in_features=1805, out_features=128, bias=True)
      (1): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
      (2): ReLU()
    )
    (mean_layer): Linear(in_features=128, out_features=647, bias=True)
    (logvar_layer): Linear(in_features=128, out_features=647, bias=True)
  )
  (gene_decoder): MaskedLinearDecoder(
    (linear): Linear(in_features=647, out_features=1805, bias=True)
  )
  (velocity_decoder): VelocityDecoder(
    (shared_decoder): Sequential(
      (0): Linear(in_features=647, out_features=128, bias=True)
      (1): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
      (2): ReLU()
    )
    (gp_velocity_decoder): Linear(in_features=128, out_features=647, bias=True)
    (gene_velocity_decoder): Sequential(
      (0): Linear(in_features=128, out_features=5415, bias=True)
      (1): Softplus(beta=1.0, threshold=20.0)
    )
  )
)

# PERTURB GENES FUNCTION

In [144]:
adata.layers

Layers with keys: Ms, Mu, recon, spliced, unspliced, velocity, velocity_u

In [None]:
cell_idx = 0
gene_idx = 0
spliced = True
unspliced = False
both = False
mu_unperturbed = adata.layers['Mu'][cell_idx, :]
ms_unperturbed = adata.layers['Ms'][cell_idx, :]
print(ms_unperturbed.shape)

mu_unperturbed = mu_unperturbed.reshape(1, mu_unperturbed.shape[0])
ms_unperturbed = ms_unperturbed.reshape(1, ms_unperturbed.shape[0])


mu_perturbed = mu_unperturbed.copy()
ms_perturbed = ms_unperturbed.copy()

perturb_value = 0 

if unspliced:
    mu_perturbed[:,gene_idx]= perturb_value
if spliced:
    ms_perturbed[:,gene_idx]= perturb_value
if both:
    mu_perturbed[:,gene_idx]= perturb_value
    ms_perturbed[:,gene_idx]= perturb_value

mu_ms_unpert = np.concatenate([mu_unperturbed, ms_unperturbed], axis=1)
mu_ms_pert = np.concatenate([mu_perturbed, ms_perturbed], axis=1)

x_unpert = torch.tensor(mu_ms_unpert)
x_pert = torch.tensor(mu_ms_pert)

with torch.no_grad():
    model.first_regime = False
    out_unpert = model.forward(x_unpert)
    out_pert = model.forward(x_pert)

(1805,)


In [146]:
out_unpert.keys()

dict_keys(['recon', 'z', 'mean', 'logvar', 'velocity', 'velocity_gp', 'alpha', 'beta', 'gamma'])

In [147]:
recon_unpert = out_unpert['recon']
mean_unpert = out_unpert['mean']
logvar_unpert = out_unpert['logvar']
gp_velo_unpert = out_unpert['velocity_gp']
velo_concat_unpert = out_unpert['velocity']
velo_u_unpert, velo_unpert = np.split(velo_concat_unpert, 2, axis=1)
alpha_unpert = out_unpert['alpha']
beta_unpert = out_unpert['beta']
gamma_unpert = out_unpert['gamma']

In [148]:
recon_pert = out_pert['recon']
mean_pert = out_pert['mean']
logvar_pert = out_pert['logvar']
gp_velo_pert = out_pert['velocity_gp']
velo_concat_pert = out_pert['velocity']
velo_u_pert, velo_pert = np.split(velo_concat_pert, 2, axis=1)
alpha_pert = out_pert['alpha']
beta_pert = out_pert['beta']
gamma_pert = out_pert['gamma']

In [149]:
recon_unpert.shape, mean_unpert.shape, logvar_unpert.shape, gp_velo_pert.shape

(torch.Size([1, 1805]),
 torch.Size([1, 647]),
 torch.Size([1, 647]),
 torch.Size([1, 647]))

In [150]:
velo_u_pert.shape, velo_pert.shape, alpha_unpert.shape, beta_unpert.shape, gamma_unpert.shape

(torch.Size([1, 1805]),
 torch.Size([1, 1805]),
 torch.Size([1, 1805]),
 torch.Size([1, 1805]),
 torch.Size([1, 1805]))

In [151]:
recon_diff = recon_pert - recon_unpert
mean_diff = mean_pert - mean_unpert
logvar_diff = mean_pert - mean_unpert
gp_velo_diff = gp_velo_pert - gp_velo_unpert
velo_u_diff = velo_u_pert - velo_u_unpert
velo_diff = velo_pert - velo_unpert
alpha_diff = alpha_pert - alpha_unpert
beta_diff = beta_pert - beta_unpert
gamma_diff = gamma_pert = gamma_unpert

to_numpy = lambda x : x.cpu().numpy().reshape(x.size(1))

df_gp = pd.DataFrame({
    'terms' : adata.uns['terms'],
    'mean' : to_numpy(mean_diff),
    'abs_mean' : abs(to_numpy(mean_diff)),
    'logvar' : to_numpy(logvar_diff),
    'abs_logvar' : abs(to_numpy(logvar_diff)),
    'gp_velocity' : to_numpy(gp_velo_diff),
    'abs_gp_velocity' : abs(to_numpy(gp_velo_diff)),
})

df_genes = pd.DataFrame({
    'genes' : adata.var_names,
    'recon' : to_numpy(recon_diff),
    'recon' : abs(to_numpy(recon_diff)),
    'unspliced_velocity' : to_numpy(velo_u_diff),
    'abs_unspliced_velocity' : abs(to_numpy(velo_u_diff)),
    'velocity' : to_numpy(velo_diff),
    'abs_velocity' : abs(to_numpy(velo_diff)),
    'alpha' : to_numpy(alpha_diff),
    'abs_alpha' : abs(to_numpy(alpha_diff)),
    'beta' : to_numpy(beta_diff),
    'abs_beta' : abs(to_numpy(beta_diff)),
    'gamma' : to_numpy(gamma_diff),
    'abs_gamma' : abs(to_numpy(gamma_diff)),
})

#df.sort_values('diff_abs',  ascending=False)

In [152]:
df_gp.sort_values('abs_gp_velocity',  ascending=False)

Unnamed: 0,terms,mean,abs_mean,logvar,abs_logvar,gp_velocity,abs_gp_velocity
261,IL6_IL1_VS_IL6_IL1_IL23_TREATE,0.000028,0.000028,0.000028,0.000028,-0.090058,0.090058
162,TREG_VS_TEFF_IN_IL2RB_KO_DN,0.000013,0.000013,0.000013,0.000013,-0.085409,0.085409
547,PUBERTAL_BREAST_4_5WK_UP,0.000045,0.000045,0.000045,0.000045,-0.082804,0.082804
604,RESPONSE_TO_FORSKOLIN_DN,-0.000015,0.000015,-0.000015,0.000015,0.080494,0.080494
38,FETAL_LIVER_MESOTHELIAL_CELLS,0.000014,0.000014,0.000014,0.000014,-0.074115,0.074115
...,...,...,...,...,...,...,...
425,PBMC_FLUARIX_AGE_50_74YO_CORR_,-0.000010,0.000010,-0.000010,0.000010,0.000369,0.000369
431,PRE_BI_TO_LARGE_PRE_BII_LYMPHO,-0.000002,0.000002,-0.000002,0.000002,0.000349,0.000349
138,DP_VS_CD8POS_THYMOCYTE_UP,0.000021,0.000021,0.000021,0.000021,0.000335,0.000335
577,A375_SOX10_TARGETS,-0.000007,0.000007,-0.000007,0.000007,0.000103,0.000103


In [153]:
df_gp.sort_values('abs_mean',  ascending=False)

Unnamed: 0,terms,mean,abs_mean,logvar,abs_logvar,gp_velocity,abs_gp_velocity
541,RB1_TARGETS_UP,-1.201630e-04,1.201630e-04,-1.201630e-04,1.201630e-04,0.033932,0.033932
24,FETAL_CEREBELLUM_VASCULAR_ENDO,1.182556e-04,1.182556e-04,1.182556e-04,1.182556e-04,-0.036567,0.036567
319,LUPUS_VS_HEALTHY_DONOR_BCELL_D,1.144409e-04,1.144409e-04,1.144409e-04,1.144409e-04,-0.038092,0.038092
478,MATURITY_ONSET_DIABETES_OF_THE,-1.144409e-04,1.144409e-04,-1.144409e-04,1.144409e-04,0.022282,0.022282
584,REGULATION_OF_BETA_CELL_DEVELO,-1.096725e-04,1.096725e-04,-1.096725e-04,1.096725e-04,0.043623,0.043623
...,...,...,...,...,...,...,...
617,DN,1.192093e-07,1.192093e-07,1.192093e-07,1.192093e-07,-0.017192,0.017192
137,DP_VS_CD8POS_THYMOCYTE_DN,0.000000e+00,0.000000e+00,0.000000e+00,0.000000e+00,0.034623,0.034623
102,FOXP3_TARGETS_CLUSTER_T7,0.000000e+00,0.000000e+00,0.000000e+00,0.000000e+00,0.032530,0.032530
504,ERBB2_BREAST_TUMORS_324_DN,0.000000e+00,0.000000e+00,0.000000e+00,0.000000e+00,-0.003120,0.003120


In [154]:
df_genes.columns

Index(['genes', 'recon', 'unspliced_velocity', 'abs_unspliced_velocity',
       'velocity', 'abs_velocity', 'alpha', 'abs_alpha', 'beta', 'abs_beta',
       'gamma', 'abs_gamma'],
      dtype='object')

In [155]:
df_genes.sort_values('abs_velocity',  ascending=False)

Unnamed: 0,genes,recon,unspliced_velocity,abs_unspliced_velocity,velocity,abs_velocity,alpha,abs_alpha,beta,abs_beta,gamma,abs_gamma
273,Pyy,0.167255,0.006774,0.006774,-2.436169,2.436169,-0.027311,0.027311,-0.063034,0.063034,0.380066,0.380066
954,Chgb,1.031860,-0.019316,0.019316,-1.079632,1.079632,-0.028360,0.028360,-0.049844,0.049844,0.249422,0.249422
776,Malat1,0.455452,-0.868658,0.868658,-0.493908,0.493908,0.001890,0.001890,0.010693,0.010693,0.508747,0.508747
790,Rbp4,0.293903,0.000556,0.000556,-0.354007,0.354007,-0.000124,0.000124,-0.010508,0.010508,0.522955,0.522955
1414,Iapp,0.133846,-0.000354,0.000354,-0.305811,0.305811,-0.000911,0.000911,-0.053768,0.053768,1.116701,1.116701
...,...,...,...,...,...,...,...,...,...,...,...,...
1568,Adamts18,0.028830,0.000208,0.000208,0.000000,0.000000,0.000208,0.000208,-0.020704,0.020704,1.309541,1.309541
1164,Clspn,0.000130,0.000233,0.000233,0.000000,0.000000,0.000233,0.000233,0.003508,0.003508,1.584048,1.584048
524,Slc39a2,0.002942,0.000422,0.000422,0.000000,0.000000,0.000422,0.000422,0.006061,0.006061,1.517614,1.517614
67,Dtl,0.034609,-0.000520,0.000520,0.000000,0.000000,-0.000520,0.000520,0.003645,0.003645,1.720535,1.720535


# PERTURB GPs function

In [160]:
cell_idx = 0
gp_idx = 0
mu = adata.layers['Mu'][cell_idx, :]
ms = adata.layers['Ms'][cell_idx, :]

mu_ms = torch.from_numpy(np.concatenate([mu_unperturbed, ms_unperturbed], axis=1))

perturb_value = 0

'''perturb_value = 0 if pert_type == 'ko' else max(adata.layers['Ms'][:,gene_idx])

if unspliced:
    mu_perturbed[:,gene_idx]= perturb_value
if spliced:
    ms_perturbed[:,gene_idx]= perturb_value
if both:
    mu_perturbed[:,gene_idx]= perturb_value
    ms_perturbed[:,gene_idx]= perturb_value

mu_ms_unpert = np.concatenate([mu_unperturbed, ms_unperturbed], axis=1)
mu_ms_pert = np.concatenate([mu_perturbed, ms_perturbed], axis=1)

x_unpert = torch.tensor(mu_ms_unpert)
x_pert = torch.tensor(mu_ms_pert)'''

with torch.no_grad():
    model.first_regime = False
    z_unpert, _, _ = model._forward_encoder(mu_ms)
    z_pert = z_unpert.clone()
    z_pert[:, gp_idx] = perturb_value
    velocity_unpert, velocity_gp_unpert, alpha_unpert, beta_unpert, gamma_unpert = model._forward_velocity_decoder(z_unpert, mu_ms)
    velocity_pert, velocity_gp_pert, alpha_pert, beta_pert, gamma_pert = model._forward_velocity_decoder(z_pert, mu_ms)
    x_dec_unpert = model._forward_gene_decoder(z_unpert)
    x_dec_pert = model._forward_gene_decoder(z_pert)

In [161]:
velo_diff = to_numpy(velocity_pert - velocity_unpert)
velo_gp_diff = to_numpy(velocity_gp_pert - velocity_gp_unpert)
alpha_diff = to_numpy(alpha_pert - alpha_unpert)
beta_diff = to_numpy(beta_pert - beta_unpert)
gamma_diff = to_numpy(gamma_pert - gamma_unpert)
x_dec_diff = to_numpy(x_dec_pert - x_dec_unpert)

In [163]:
velo_diff_u, velo_diff_s = np.split(velo_diff, 2)

In [None]:
genes_df = pd.DataFrame({
        'genes' : adata.var_names,
        'velo_diff_u' : velo_diff_u,
        'abs_velo_diff_u' : np.absolute(velo_diff_u),
        'velo_diff_s' : velo_diff_s,
        'abs_velo_diff_s' : np.absolute(velo_diff_s),
        'x_dec_diff' : x_dec_diff,
        'x_dec_diff_abs' : np.absolute(x_dec_diff),
        'alpha_diff' : alpha_diff,
        'alpha_diff_abs' : np.absolute(alpha_diff),
        'beta_diff' : beta_diff,
        'beta_diff_abs' : np.absolute(beta_diff),
        'gamma_diff' : gamma_diff,
        'gamma_diff_abs' : np.absolute(gamma_diff),
    })

gps_df = pd.DataFrame({
        'gene_programs' : adata.uns['terms'],
        'velo_gp' : velo_gp_diff,
        'abs_velo_gp' : velo_gp_diff,
    })


In [167]:
genes_df.sort_values('x_dec_diff_abs', ascending=False).head(10)

Unnamed: 0,genes,velo_diff_u,abs_velo_diff_u,velo_diff_s,abs_velo_diff_s,x_dec_diff,x_dec_diff_abs,alpha_diff,alpha_diff_abs,beta_diff,beta_diff_abs,gamma_diff,gamma_diff_abs
1166,Marcksl1,-2.379343e-05,2.379343e-05,0.000529,0.000529,-0.027441,0.027441,-2.379343e-05,2.379343e-05,0.000262,0.000262,-0.000127,0.000127
493,Gch1,9.742379e-05,9.742379e-05,0.001593,0.001593,-0.020818,0.020818,7.075071e-05,7.075071e-05,-0.000111,0.000111,-0.000334,0.000334
298,Upp1,-2.723187e-06,2.723187e-06,0.000118,0.000118,-0.011443,0.011443,-2.723187e-06,2.723187e-06,-0.00011,0.00011,-0.000364,0.000364
1406,Usp18,-1.499057e-05,1.499057e-05,7e-06,7e-06,-0.008874,0.008874,-4.529953e-06,4.529953e-06,0.000133,0.000133,4e-06,4e-06
1344,Mxd1,1.132488e-06,1.132488e-06,2.5e-05,2.5e-05,0.005005,0.005005,2.935529e-06,2.935529e-06,2.4e-05,2.4e-05,-0.000249,0.000249
1133,Gem,8.046627e-07,8.046627e-07,-9e-06,9e-06,0.004108,0.004108,-2.533197e-07,2.533197e-07,-3e-06,3e-06,0.000431,0.000431
66,Atf3,-6.60494e-06,6.60494e-06,6e-06,6e-06,0.003002,0.003002,-6.955117e-06,6.955117e-06,-2.7e-05,2.7e-05,-9.1e-05,9.1e-05
164,Rdh5,-2.745539e-06,2.745539e-06,-1.1e-05,1.1e-05,-0.002745,0.002745,-1.089647e-05,1.089647e-05,-0.000156,0.000156,0.000146,0.000146
1737,Gk,-2.233312e-05,2.233312e-05,1.6e-05,1.6e-05,-0.001801,0.001801,-1.928955e-05,1.928955e-05,4.7e-05,4.7e-05,-0.000114,0.000114
610,Litaf,-1.331419e-05,1.331419e-05,-1e-05,1e-05,0.001469,0.001469,-1.28597e-05,1.28597e-05,3.8e-05,3.8e-05,4.2e-05,4.2e-05
