In [113]:
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [114]:
import scanpy as sc
import torch
import lineagevi as linvi
import numpy as np
import pandas as pd

In [115]:
adata_path = '/Users/lgolinelli/git/lineageVI/notebooks/data/outputs/pancreas_2025.08.17_12.43.17/adata_with_velocity.h5ad'
model_path = '/Users/lgolinelli/git/lineageVI/notebooks/data/outputs/pancreas_2025.08.17_12.43.17/vae_velocity_model.pt'

adata = sc.read_h5ad(adata_path)

model = linvi.trainer.LineageVI(
    adata,
)

model.load_state_dict(torch.load(model_path, map_location='cpu'))
model.eval()

LineageVI(
  (encoder): Encoder(
    (encoder): Sequential(
      (0): Linear(in_features=1805, out_features=128, bias=True)
      (1): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
      (2): ReLU()
    )
    (mean_layer): Linear(in_features=128, out_features=647, bias=True)
    (logvar_layer): Linear(in_features=128, out_features=647, bias=True)
  )
  (gene_decoder): MaskedLinearDecoder(
    (linear): Linear(in_features=647, out_features=1805, bias=True)
  )
  (velocity_decoder): VelocityDecoder(
    (shared_decoder): Sequential(
      (0): Linear(in_features=647, out_features=128, bias=True)
      (1): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
      (2): ReLU()
    )
    (gp_velocity_decoder): Linear(in_features=128, out_features=647, bias=True)
    (gene_velocity_decoder): Sequential(
      (0): Linear(in_features=128, out_features=5415, bias=True)
      (1): Softplus(beta=1.0, threshold=20.0)
    )
  )
)

# PERTURB GENES FUNCTION

In [116]:
adata.layers

Layers with keys: Ms, Mu, recon, spliced, unspliced, velocity, velocity_u

In [134]:
adata.var_names

Index(['Sntg1', 'Snhg6', 'Ncoa2', 'Sbspon', 'Ube2w', 'Mcm3', 'Fam135a',
       'Adgrb3', 'Tmem131', 'Tbc1d8',
       ...
       'Hsd17b10', 'Rragb', 'Map7d2', 'Sh3kbp1', 'Map3k15', 'Rai2', 'Rbbp7',
       'Ap1s2', 'Uty', 'Ddx3y'],
      dtype='object', name='index', length=1805)

In [139]:
genes_to_perturb = ['Sntg1', 'Snhg6']
idxs = np.where(adata.var_names.isin(genes_to_perturb))[0]
idxs

array([0, 1])

In [None]:
import numpy as np
import torch

# allow both int and list
cell_idx = [0, 1]   # could also be 0
gene_idx = idxs   # could also be 0
spliced = True
unspliced = False
both = False

# Always ensure we index properly
mu_unperturbed = adata.layers['Mu'][cell_idx, :]
ms_unperturbed = adata.layers['Ms'][cell_idx, :]

# Convert to 2D consistently
mu_unperturbed = np.atleast_2d(mu_unperturbed)
ms_unperturbed = np.atleast_2d(ms_unperturbed)

print("mu_unperturbed:", mu_unperturbed.shape)
print("ms_unperturbed:", ms_unperturbed.shape)

mu_perturbed = mu_unperturbed.copy()
ms_perturbed = ms_unperturbed.copy()

perturb_value = 0 

if unspliced:
    mu_perturbed[:, gene_idx] = perturb_value
if spliced:
    ms_perturbed[:, gene_idx] = perturb_value
if both:
    mu_perturbed[:, gene_idx] = perturb_value
    ms_perturbed[:, gene_idx] = perturb_value

# Concatenate along feature axis
mu_ms_unpert = np.concatenate([mu_unperturbed, ms_unperturbed], axis=1)
mu_ms_pert = np.concatenate([mu_perturbed, ms_perturbed], axis=1)

# Convert to torch tensors (float type for model)
x_unpert = torch.tensor(mu_ms_unpert, dtype=torch.float32)
x_pert = torch.tensor(mu_ms_pert, dtype=torch.float32)

with torch.no_grad():
    model.first_regime = False
    out_unpert = model.forward(x_unpert)
    out_pert = model.forward(x_pert)


mu_unperturbed: (2, 1805)
ms_unperturbed: (2, 1805)
[[0.00346039 0.2833065 ]
 [0.         0.7443291 ]]


In [118]:
recon_unpert = out_unpert['recon']
mean_unpert = out_unpert['mean']
logvar_unpert = out_unpert['logvar']
gp_velo_unpert = out_unpert['velocity_gp']
velo_concat_unpert = out_unpert['velocity']
velo_u_unpert, velo_unpert = np.split(velo_concat_unpert, 2, axis=1)
alpha_unpert = out_unpert['alpha']
beta_unpert = out_unpert['beta']
gamma_unpert = out_unpert['gamma']

In [119]:
recon_pert = out_pert['recon']
mean_pert = out_pert['mean']
logvar_pert = out_pert['logvar']
gp_velo_pert = out_pert['velocity_gp']
velo_concat_pert = out_pert['velocity']
velo_u_pert, velo_pert = np.split(velo_concat_pert, 2, axis=1)
alpha_pert = out_pert['alpha']
beta_pert = out_pert['beta']
gamma_pert = out_pert['gamma']

In [120]:
recon_unpert.shape, mean_unpert.shape, logvar_unpert.shape, gp_velo_pert.shape

(torch.Size([2, 1805]),
 torch.Size([2, 647]),
 torch.Size([2, 647]),
 torch.Size([2, 647]))

In [121]:
velo_u_pert.shape, velo_pert.shape, alpha_unpert.shape, beta_unpert.shape, gamma_unpert.shape

(torch.Size([2, 1805]),
 torch.Size([2, 1805]),
 torch.Size([2, 1805]),
 torch.Size([2, 1805]),
 torch.Size([2, 1805]))

In [122]:
recon_diff = recon_pert - recon_unpert
mean_diff = mean_pert - mean_unpert
logvar_diff = mean_pert - mean_unpert
gp_velo_diff = gp_velo_pert - gp_velo_unpert
velo_u_diff = velo_u_pert - velo_u_unpert
velo_diff = velo_pert - velo_unpert
alpha_diff = alpha_pert - alpha_unpert
beta_diff = beta_pert - beta_unpert
gamma_diff = gamma_pert = gamma_unpert

if recon_diff.shape[0] > 1:
    to_numpy = lambda x : x.cpu().numpy()

else:
    to_numpy = lambda x : x.cpu().numpy().reshape(x.size(1))

recon_diff = to_numpy(recon_diff).mean(0)
mean_diff = to_numpy(mean_diff).mean(0)
logvar_diff = to_numpy(logvar_diff).mean(0)
gp_velo_diff = to_numpy(gp_velo_diff).mean(0)
velo_u_diff = to_numpy(velo_u_diff).mean(0)
velo_diff = to_numpy(velo_diff).mean(0)
alpha_diff = to_numpy(alpha_diff).mean(0)
beta_diff = to_numpy(beta_diff).mean(0)
gamma_diff = to_numpy(gamma_diff).mean(0)

df_gp = pd.DataFrame({
    'terms' : adata.uns['terms'],
    'mean' : mean_diff,
    'abs_mean' : abs(mean_diff),
    'logvar' : logvar_diff,
    'abs_logvar' : abs(logvar_diff),
    'gp_velocity' : gp_velo_diff,
    'abs_gp_velocity' : abs(gp_velo_diff),
})

df_genes = pd.DataFrame({
    'genes' : adata.var_names,
    'recon' : recon_diff,
    'abs_recon' : abs(recon_diff),
    'unspliced_velocity' : velo_u_diff,
    'abs_unspliced_velocity' : abs(velo_u_diff),
    'velocity' : velo_diff,
    'abs_velocity' : abs(velo_diff),
    'alpha' : alpha_diff,
    'abs_alpha' : abs(alpha_diff),
    'beta' : beta_diff,
    'abs_beta' : abs(beta_diff),
    'gamma' : gamma_diff,
    'abs_gamma' : abs(gamma_diff),
})

#df.sort_values('diff_abs',  ascending=False)

In [123]:
df_gp.sort_values('abs_gp_velocity',  ascending=False)

Unnamed: 0,terms,mean,abs_mean,logvar,abs_logvar,gp_velocity,abs_gp_velocity
39,FETAL_LIVER_STELLATE_CELLS,0.000446,0.000446,0.000446,0.000446,-0.039464,0.039464
367,MEMORY_VS_EXHAUSTED_CD8_TCELL_,0.000460,0.000460,0.000460,0.000460,-0.039395,0.039395
464,GATA2_TARGETS_UP,0.000552,0.000552,0.000552,0.000552,-0.035677,0.035677
259,IL6_IL1_IL23_VS_IL6_IL1_TGFB_T,-0.000282,0.000282,-0.000282,0.000282,0.035189,0.035189
249,UNSTIM_VS_48H_MBOVIS_BCG_STIM_,-0.000300,0.000300,-0.000300,0.000300,0.033176,0.033176
...,...,...,...,...,...,...,...
442,ABNORMALITY_OF_PANCREAS_PHYSIO,0.000803,0.000803,0.000803,0.000803,0.000213,0.000213
306,TCF1_KO_VS_WT_LIN_NEG_CELL_UP,-0.000015,0.000015,-0.000015,0.000015,-0.000205,0.000205
172,CD45RA_NEG_CD4_TCELL_VS_NONSUP,-0.000224,0.000224,-0.000224,0.000224,0.000162,0.000162
35,FETAL_KIDNEY_METANEPHRIC_CELLS,-0.000721,0.000721,-0.000721,0.000721,0.000113,0.000113


In [124]:
df_gp.sort_values('abs_mean',  ascending=False)

Unnamed: 0,terms,mean,abs_mean,logvar,abs_logvar,gp_velocity,abs_gp_velocity
52,FETAL_PANCREAS_ISLET_ENDOCRINE,-2.275586e-03,2.275586e-03,-2.275586e-03,2.275586e-03,-0.001083,0.001083
31,FETAL_INTESTINE_CHROMAFFIN_CEL,-2.221346e-03,2.221346e-03,-2.221346e-03,2.221346e-03,-0.006514,0.006514
319,LUPUS_VS_HEALTHY_DONOR_BCELL_D,-2.140388e-03,2.140388e-03,-2.140388e-03,2.140388e-03,-0.001242,0.001242
584,REGULATION_OF_BETA_CELL_DEVELO,2.080753e-03,2.080753e-03,2.080753e-03,2.080753e-03,-0.025152,0.025152
63,FETAL_STOMACH_NEUROENDOCRINE_C,2.059020e-03,2.059020e-03,2.059020e-03,2.059020e-03,-0.004569,0.004569
...,...,...,...,...,...,...,...
211,WT_VS_HEB_KO_DP_THYMOCYTE_DN,6.914139e-06,6.914139e-06,6.914139e-06,6.914139e-06,0.006400,0.006400
625,COMMD1_TARGETS_GROUP_3_UP,5.960464e-06,5.960464e-06,5.960464e-06,5.960464e-06,-0.001572,0.001572
223,MULTIPOTENT_PROGENITOR_VS_CDC_,-4.649162e-06,4.649162e-06,-4.649162e-06,4.649162e-06,0.000253,0.000253
192,DAY6_VS_DAY10_EFF_CD8_TCELL_UP,-3.576279e-06,3.576279e-06,-3.576279e-06,3.576279e-06,-0.009461,0.009461


In [125]:
df_genes.columns

Index(['genes', 'recon', 'abs_recon', 'unspliced_velocity',
       'abs_unspliced_velocity', 'velocity', 'abs_velocity', 'alpha',
       'abs_alpha', 'beta', 'abs_beta', 'gamma', 'abs_gamma'],
      dtype='object')

In [126]:
df_genes.sort_values('abs_velocity',  ascending=False)

Unnamed: 0,genes,recon,abs_recon,unspliced_velocity,abs_unspliced_velocity,velocity,abs_velocity,alpha,abs_alpha,beta,abs_beta,gamma,abs_gamma
273,Pyy,-0.636447,0.636447,0.005982,0.005982,1.305429e+00,1.305429e+00,0.018894,0.018894,0.025548,0.025548,1.505645,1.505645
977,Gnas,-0.562402,0.562402,-0.030262,0.030262,2.576006e-01,2.576006e-01,0.008976,0.008976,-0.004395,0.004395,0.326213,0.326213
954,Chgb,-0.421679,0.421679,-0.004527,0.004527,2.098818e-01,2.098818e-01,-0.001049,0.001049,0.025636,0.025636,0.712406,0.712406
790,Rbp4,-0.407507,0.407507,0.000622,0.000622,2.073747e-01,2.073747e-01,0.000818,0.000818,0.005355,0.005355,0.927102,0.927102
2,Ncoa2,-0.003013,0.003013,0.008553,0.008553,1.663958e-01,1.663958e-01,0.003738,0.003738,-0.007835,0.007835,1.805570,1.805570
...,...,...,...,...,...,...,...,...,...,...,...,...,...
789,Ankrd1,0.004129,0.004129,0.000062,0.000062,2.783258e-06,2.783258e-06,0.000064,0.000064,0.002435,0.002435,1.988521,1.988521
431,Diras2,0.013120,0.013120,0.000601,0.000601,1.487322e-06,1.487322e-06,0.000668,0.000668,0.007296,0.007296,1.580824,1.580824
398,Dgkb,-0.003569,0.003569,-0.000975,0.000975,9.164214e-07,9.164214e-07,-0.001079,0.001079,-0.000881,0.000881,1.554065,1.554065
1568,Adamts18,-0.005385,0.005385,0.000297,0.000297,0.000000e+00,0.000000e+00,0.000297,0.000297,0.000454,0.000454,1.726389,1.726389


In [127]:
ctype_indices = {}
adata.obs['numerical_idx_linvi'] = np.arange(len(adata))
for cluster, df in adata.obs.groupby('clusters'):
    ctype_indices[cluster] = df['numerical_idx_linvi']

  for cluster, df in adata.obs.groupby('clusters'):


In [128]:
idx = np.where(adata.uns['terms'] == 'YBX1_TARGETS_DN')[0]
idx

array([1])

# PERTURB GPs function

In [None]:
ctypes_to_perturb = 'Beta'
cell_idx = ctype_indices[ctypes_to_perturb]

gps_to_perturb = ['YBX1_TARGETS_DN', 'RESPONSE_TO_LPS_WITH_MECHANICA']
gp_idx = np.where(pd.Series(adata.uns['terms']).isin(gps_to_perturb))[0]

mu = adata.layers['Mu'][cell_idx, :]
ms = adata.layers['Ms'][cell_idx, :]

mu_ms = torch.from_numpy(np.concatenate([mu_unperturbed, ms_unperturbed], axis=1))

perturb_value = 0

with torch.no_grad():
    model.first_regime = False
    z_unpert, _, _ = model._forward_encoder(mu_ms)
    z_pert = z_unpert.clone()
    z_pert[:, gp_idx] = perturb_value
    velocity_unpert, velocity_gp_unpert, alpha_unpert, beta_unpert, gamma_unpert = model._forward_velocity_decoder(z_unpert, mu_ms)
    velocity_pert, velocity_gp_pert, alpha_pert, beta_pert, gamma_pert = model._forward_velocity_decoder(z_pert, mu_ms)
    x_dec_unpert = model._forward_gene_decoder(z_unpert)
    x_dec_pert = model._forward_gene_decoder(z_pert)

In [None]:
velo_diff = to_numpy(velocity_pert - velocity_unpert)
velo_gp_diff = to_numpy(velocity_gp_pert - velocity_gp_unpert)
alpha_diff = to_numpy(alpha_pert - alpha_unpert)
beta_diff = to_numpy(beta_pert - beta_unpert)
gamma_diff = to_numpy(gamma_pert - gamma_unpert)
x_dec_diff = to_numpy(x_dec_pert - x_dec_unpert)

if velo_diff.shape[0] > 1:
    velo_diff = velo_diff.mean(0)
    velo_gp_diff = velo_gp_diff.mean(0)
    alpha_diff = alpha_diff.mean(0)
    beta_diff = beta_diff.mean(0)
    gamma_diff = gamma_diff.mean(0)
    x_dec_diff = x_dec_diff.mean(0)

In [131]:
velo_diff_u, velo_diff_s = np.split(velo_diff, 2)

In [None]:
genes_df = pd.DataFrame({
        'genes' : adata.var_names,
        'velo_diff_u' : velo_diff_u,
        'abs_velo_diff_u' : np.absolute(velo_diff_u),
        'velo_diff_s' : velo_diff_s,
        'abs_velo_diff_s' : np.absolute(velo_diff_s),
        'x_dec_diff' : x_dec_diff,
        'x_dec_diff_abs' : np.absolute(x_dec_diff),
        'alpha_diff' : alpha_diff,
        'alpha_diff_abs' : np.absolute(alpha_diff),
        'beta_diff' : beta_diff,
        'beta_diff_abs' : np.absolute(beta_diff),
        'gamma_diff' : gamma_diff,
        'gamma_diff_abs' : np.absolute(gamma_diff),
    })

gps_df = pd.DataFrame({
        'gene_programs' : adata.uns['terms'],
        'velo_gp' : velo_gp_diff,
        'abs_velo_gp' : velo_gp_diff,
    })

In [133]:
genes_df.sort_values('x_dec_diff_abs', ascending=False).head(10)

Unnamed: 0,genes,velo_diff_u,abs_velo_diff_u,velo_diff_s,abs_velo_diff_s,x_dec_diff,x_dec_diff_abs,alpha_diff,alpha_diff_abs,beta_diff,beta_diff_abs,gamma_diff,gamma_diff_abs
776,Malat1,-0.641241,0.641241,-0.254919,0.254919,-6.637493,6.637493,-0.091522,0.091522,0.006988,0.006988,0.016666,0.016666
875,Meis2,0.031781,0.031781,-0.055517,0.055517,-1.988783,1.988783,-0.000237,0.000237,-0.00484,0.00484,0.001561,0.001561
1166,Marcksl1,-0.000293,0.000293,-0.038032,0.038032,-1.909644,1.909644,-0.000304,0.000304,-0.013202,0.013202,0.008004,0.008004
1038,Tm4sf4,-0.005219,0.005219,-0.025519,0.025519,-1.430311,1.430311,5.4e-05,5.4e-05,0.02608,0.02608,0.007014,0.007014
1743,Xist,0.00044,0.00044,-0.043585,0.043585,-1.329419,1.329419,0.000263,0.000263,-0.007564,0.007564,0.007114,0.007114
1280,Spp1,0.005202,0.005202,-0.074396,0.074396,-1.175714,1.175714,0.004769,0.004769,-0.003334,0.003334,0.004401,0.004401
409,Fos,-0.000141,0.000141,0.033433,0.033433,-1.135473,1.135473,-0.000271,0.000271,-0.019622,0.019622,-0.009959,0.009959
1172,Camk2n1,0.0001,0.0001,0.026553,0.026553,-1.044331,1.044331,9.4e-05,9.4e-05,-0.005752,0.005752,-0.018241,0.018241
787,Rfx3,-0.020719,0.020719,0.01256,0.01256,-1.005717,1.005717,-0.005276,0.005276,0.014796,0.014796,0.007892,0.007892
1739,Pdk3,-0.004553,0.004553,-0.02111,0.02111,-0.922121,0.922121,-0.006719,0.006719,-0.007134,0.007134,0.023481,0.023481
