In [None]:
import numpy as np
import scanpy as sc
import scvelo as scv
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
import scipy
from scipy.stats import spearmanr
import sys
import os
import matplotlib
# Make modules in the parent directory importable from this notebook.
sys.path += [".."]
# Embed TrueType (Type 42) fonts in saved PDFs so text remains editable in vector editors.
matplotlib.rcParams.update({'pdf.fonttype': 42})

In [None]:
# Model configuration; every value is encoded into the prediction file names through model_suffix.
dataset_name = 'GCB'
sample_name = 'CtcfWT29'
model_type = 'DynaVelo'
num_hidden = 200
zx_dim = 50
zy_dim = 50
k_z0 = 1
k_t = 1
k_velocity = 10000
k_consistency = 10000
seed = 0

model_suffix = f'{dataset_name}_{sample_name}_{model_type}_num_hidden_{num_hidden}_zxdim_{zx_dim}_zydim_{zy_dim}_k_z0_{str(k_z0)}_k_t_{str(k_t)}_k_velocity_{str(k_velocity)}_k_consistency_{str(k_consistency)}_seed_{seed}'

# load predicted RNA
adata_rna_pred = sc.read_h5ad(f"/media/labuser/STORAGE/sc-multiome/data/MelnickLab_GerminalCenter_BCells/processed_data_rna/predicted/RNA_Matrix_Pred_{model_suffix}.h5ad")

# load predicted ATAC (motif matrix)
adata_atac_pred = sc.read_h5ad(f"/media/labuser/STORAGE/sc-multiome/data/MelnickLab_GerminalCenter_BCells/processed_data_atac/MotifMatrix/predicted/Motif_Matrix_Pred_{model_suffix}.h5ad")

# Both modalities must list the same cells in the same order. Index.equals also fails with a
# clean AssertionError (instead of a ValueError from `==`) when the lengths differ.
assert adata_rna_pred.obs_names.equals(adata_atac_pred.obs_names), 'RNA and ATAC predictions do not share the same cells in the same order'

# External latent time; it is attached to cells purely by row position (no cell-name join).
# NOTE(review): assumes the CSV rows follow obs_names order — confirm how the file was written.
latent_time_dynavelo = pd.read_csv(f'../files/{sample_name}_latent_time_dynavelo.csv', sep="\t", index_col=False)
assert len(latent_time_dynavelo) == adata_rna_pred.n_obs, 'latent-time CSV row count differs from the number of cells'

# Keep the latent time stored in the h5ad as 'latent_time_dynavelo', then overwrite 'latent_time' with the CSV values.
adata_rna_pred.obs['latent_time_dynavelo'] = adata_rna_pred.obs['latent_time'].copy()
adata_rna_pred.obs['latent_time'] = latent_time_dynavelo['latent_time'].values

adata_atac_pred.obs['latent_time_dynavelo'] = adata_atac_pred.obs['latent_time'].copy()
adata_atac_pred.obs['latent_time'] = latent_time_dynavelo['latent_time'].values

In [None]:
# Fixed display order for the coarse and fine celltype annotations.
# Labels not listed here become NaN when cast to these ordered categoricals.
for obs_key, categories in (
    ('final.celltype', ['Centroblast', 'Transitioning', 'Centrocyte', 'Plasmablast', 'Prememory']),
    ('fine.celltype', ['Centroblast', 'Transitioning_CB_CC', 'Centrocyte', 'Transitioning_Sphase', 'CB_Rec_Sphase', 'CB_S_G2M', 'CB_G2M', 'Recycling', 'CC_Rec', 'Plasmablast', 'Prememory']),
):
    adata_atac_pred.obs[obs_key] = pd.Categorical(adata_atac_pred.obs[obs_key], categories=categories, ordered=True)
# One colour per category, in the category order above.
adata_atac_pred.uns['final.celltype_colors'] = ['#74c476', '#9ecae1', '#3182bd', '#1be7ff', '#e2a0ff']
adata_atac_pred.uns['fine.celltype_colors'] = ['#deebf7', '#9ecae1', '#3182bd', '#edf8e9', '#bae4b3', '#74c476', '#238b45', '#d94701', '#fd8d3c', '#1be7ff', '#e2a0ff']

# The RNA object describes the same cells, so it reuses the ATAC annotations and palettes.
for obs_key in ('final.celltype', 'fine.celltype'):
    adata_rna_pred.obs[obs_key] = adata_atac_pred.obs[obs_key]
adata_rna_pred.uns['final.celltype_colors'] = adata_atac_pred.uns['final.celltype_colors']
adata_rna_pred.uns['fine.celltype_colors'] = adata_atac_pred.uns['fine.celltype_colors']

In [None]:
# Model configuration; every value is encoded into the prediction file names through model_suffix.
dataset_name = 'GCB'
sample_name = 'Arid1aHetM3'
model_type = 'DynaVelo'
num_hidden = 200
zx_dim = 50
zy_dim = 50
k_z0 = 1
k_t = 1
k_velocity = 10000
k_consistency = 10000
seed = 0

model_suffix = f'{dataset_name}_{sample_name}_{model_type}_num_hidden_{num_hidden}_zxdim_{zx_dim}_zydim_{zy_dim}_k_z0_{str(k_z0)}_k_t_{str(k_t)}_k_velocity_{str(k_velocity)}_k_consistency_{str(k_consistency)}_seed_{seed}'

# load predicted RNA
adata_rna_pred_Arid1aHet = sc.read_h5ad(f"/media/labuser/STORAGE/sc-multiome/data/MelnickLab_GerminalCenter_BCells/processed_data_rna/predicted/RNA_Matrix_Pred_{model_suffix}.h5ad")

# load predicted ATAC (motif matrix)
adata_atac_pred_Arid1aHet = sc.read_h5ad(f"/media/labuser/STORAGE/sc-multiome/data/MelnickLab_GerminalCenter_BCells/processed_data_atac/MotifMatrix/predicted/Motif_Matrix_Pred_{model_suffix}.h5ad")

# Both modalities must list the same cells in the same order. Index.equals also fails with a
# clean AssertionError (instead of a ValueError from `==`) when the lengths differ.
assert adata_rna_pred_Arid1aHet.obs_names.equals(adata_atac_pred_Arid1aHet.obs_names), 'RNA and ATAC predictions do not share the same cells in the same order'

# External latent time; it is attached to cells purely by row position (no cell-name join).
# NOTE(review): assumes the CSV rows follow obs_names order — confirm how the file was written.
latent_time_dynavelo = pd.read_csv(f'../files/{sample_name}_latent_time_dynavelo.csv', sep="\t", index_col=False)
assert len(latent_time_dynavelo) == adata_rna_pred_Arid1aHet.n_obs, 'latent-time CSV row count differs from the number of cells'

# Keep the latent time stored in the h5ad as 'latent_time_dynavelo', then overwrite 'latent_time' with the CSV values.
adata_rna_pred_Arid1aHet.obs['latent_time_dynavelo'] = adata_rna_pred_Arid1aHet.obs['latent_time'].copy()
adata_rna_pred_Arid1aHet.obs['latent_time'] = latent_time_dynavelo['latent_time'].values

adata_atac_pred_Arid1aHet.obs['latent_time_dynavelo'] = adata_atac_pred_Arid1aHet.obs['latent_time'].copy()
adata_atac_pred_Arid1aHet.obs['latent_time'] = latent_time_dynavelo['latent_time'].values

In [None]:
# Fixed display order for the coarse and fine celltype annotations.
# Labels not listed here become NaN when cast to these ordered categoricals.
for obs_key, categories in (
    ('final.celltype', ['Centroblast', 'Transitioning', 'Centrocyte', 'Plasmablast', 'Prememory']),
    ('fine.celltype', ['Centroblast', 'Transitioning_CB_CC', 'Centrocyte', 'Transitioning_Sphase', 'CB_Rec_Sphase', 'CB_S_G2M', 'CB_G2M', 'Recycling', 'CC_Rec', 'Plasmablast', 'Prememory']),
):
    adata_atac_pred_Arid1aHet.obs[obs_key] = pd.Categorical(adata_atac_pred_Arid1aHet.obs[obs_key], categories=categories, ordered=True)
# One colour per category, in the category order above.
adata_atac_pred_Arid1aHet.uns['final.celltype_colors'] = ['#74c476', '#9ecae1', '#3182bd', '#1be7ff', '#e2a0ff']
adata_atac_pred_Arid1aHet.uns['fine.celltype_colors'] = ['#deebf7', '#9ecae1', '#3182bd', '#edf8e9', '#bae4b3', '#74c476', '#238b45', '#d94701', '#fd8d3c', '#1be7ff', '#e2a0ff']

# The RNA object describes the same cells, so it reuses the ATAC annotations and palettes.
for obs_key in ('final.celltype', 'fine.celltype'):
    adata_rna_pred_Arid1aHet.obs[obs_key] = adata_atac_pred_Arid1aHet.obs[obs_key]
adata_rna_pred_Arid1aHet.uns['final.celltype_colors'] = adata_atac_pred_Arid1aHet.uns['final.celltype_colors']
adata_rna_pred_Arid1aHet.uns['fine.celltype_colors'] = adata_atac_pred_Arid1aHet.uns['fine.celltype_colors']

In [None]:
# Model configuration; every value is encoded into the prediction file names through model_suffix.
dataset_name = 'GCB'
sample_name = 'CtcfHet30'
model_type = 'DynaVelo'
num_hidden = 200
zx_dim = 50
zy_dim = 50
k_z0 = 1
k_t = 1
k_velocity = 10000
k_consistency = 10000
seed = 0

model_suffix = f'{dataset_name}_{sample_name}_{model_type}_num_hidden_{num_hidden}_zxdim_{zx_dim}_zydim_{zy_dim}_k_z0_{str(k_z0)}_k_t_{str(k_t)}_k_velocity_{str(k_velocity)}_k_consistency_{str(k_consistency)}_seed_{seed}'

# load predicted RNA
adata_rna_pred_CtcfHet = sc.read_h5ad(f"/media/labuser/STORAGE/sc-multiome/data/MelnickLab_GerminalCenter_BCells/processed_data_rna/predicted/RNA_Matrix_Pred_{model_suffix}.h5ad")

# load predicted ATAC (motif matrix)
adata_atac_pred_CtcfHet = sc.read_h5ad(f"/media/labuser/STORAGE/sc-multiome/data/MelnickLab_GerminalCenter_BCells/processed_data_atac/MotifMatrix/predicted/Motif_Matrix_Pred_{model_suffix}.h5ad")

# Both modalities must list the same cells in the same order. Index.equals also fails with a
# clean AssertionError (instead of a ValueError from `==`) when the lengths differ.
assert adata_rna_pred_CtcfHet.obs_names.equals(adata_atac_pred_CtcfHet.obs_names), 'RNA and ATAC predictions do not share the same cells in the same order'

# External latent time; it is attached to cells purely by row position (no cell-name join).
# NOTE(review): assumes the CSV rows follow obs_names order — confirm how the file was written.
latent_time_dynavelo = pd.read_csv(f'../files/{sample_name}_latent_time_dynavelo.csv', sep="\t", index_col=False)
assert len(latent_time_dynavelo) == adata_rna_pred_CtcfHet.n_obs, 'latent-time CSV row count differs from the number of cells'

# Keep the latent time stored in the h5ad as 'latent_time_dynavelo', then overwrite 'latent_time' with the CSV values.
adata_rna_pred_CtcfHet.obs['latent_time_dynavelo'] = adata_rna_pred_CtcfHet.obs['latent_time'].copy()
adata_rna_pred_CtcfHet.obs['latent_time'] = latent_time_dynavelo['latent_time'].values

adata_atac_pred_CtcfHet.obs['latent_time_dynavelo'] = adata_atac_pred_CtcfHet.obs['latent_time'].copy()
adata_atac_pred_CtcfHet.obs['latent_time'] = latent_time_dynavelo['latent_time'].values

In [None]:
# Fixed display order for the coarse and fine celltype annotations.
# Labels not listed here become NaN when cast to these ordered categoricals.
for obs_key, categories in (
    ('final.celltype', ['Centroblast', 'Transitioning', 'Centrocyte', 'Plasmablast', 'Prememory']),
    ('fine.celltype', ['Centroblast', 'Transitioning_CB_CC', 'Centrocyte', 'Transitioning_Sphase', 'CB_Rec_Sphase', 'CB_S_G2M', 'CB_G2M', 'Recycling', 'CC_Rec', 'Plasmablast', 'Prememory']),
):
    adata_atac_pred_CtcfHet.obs[obs_key] = pd.Categorical(adata_atac_pred_CtcfHet.obs[obs_key], categories=categories, ordered=True)
# One colour per category, in the category order above.
adata_atac_pred_CtcfHet.uns['final.celltype_colors'] = ['#74c476', '#9ecae1', '#3182bd', '#1be7ff', '#e2a0ff']
adata_atac_pred_CtcfHet.uns['fine.celltype_colors'] = ['#deebf7', '#9ecae1', '#3182bd', '#edf8e9', '#bae4b3', '#74c476', '#238b45', '#d94701', '#fd8d3c', '#1be7ff', '#e2a0ff']

# The RNA object describes the same cells, so it reuses the ATAC annotations and palettes.
for obs_key in ('final.celltype', 'fine.celltype'):
    adata_rna_pred_CtcfHet.obs[obs_key] = adata_atac_pred_CtcfHet.obs[obs_key]
adata_rna_pred_CtcfHet.uns['final.celltype_colors'] = adata_atac_pred_CtcfHet.uns['final.celltype_colors']
adata_rna_pred_CtcfHet.uns['fine.celltype_colors'] = adata_atac_pred_CtcfHet.uns['fine.celltype_colors']

#### Fig. 6A

In [None]:
# Output directory for all Fig. 6 panels (also used by the cells below).
fig_dir = '../figures/Fig6'

sample_name = 'CtcfWT29'

perturbed_gene_list = adata_rna_pred.uns['Perturbed_genes']

# kNN graph + 2-D UMAP on the predicted motif matrix; the motif-velocity streams below use basis='umap'.
sc.pp.neighbors(adata_atac_pred, n_neighbors=50, use_rep='X')
sc.tl.umap(adata_atac_pred, min_dist=1, spread=1, random_state=0, n_components=2)

# Loop-invariant ATAC setup (previously redone for every perturbed gene): velocity_graph reads the
# state from layers['Ms'], and motifs are renamed to their TF so they can be matched to perturbed genes.
# NOTE(review): this rename persists on adata_atac_pred, so the Fig. 6B-6E cells intersect these TF
# names with the mutant ATAC objects' own var_names — confirm that is intended.
adata_atac_pred.layers['Ms'] = adata_atac_pred.X.copy()
adata_atac_pred.var_names = adata_atac_pred.var['TF'].values

for idx_gene, perturbed_gene in enumerate(perturbed_gene_list):
    print('perturbed_gene: ', perturbed_gene)

    perturbed_path = fig_dir+f'/perturbation/{sample_name}/perturbed_gene_{perturbed_gene}/'
    os.makedirs(perturbed_path, exist_ok=True)

    # Per-cell norms of the predicted velocity changes for this perturbation;
    # only delta_vx_norm / delta_vy_norm are stored on the objects below.
    delta_vzx_norm = np.linalg.norm(adata_rna_pred.obsm['delta_vzx'][:,:,idx_gene], axis=1)
    delta_vzy_norm = np.linalg.norm(adata_rna_pred.obsm['delta_vzy'][:,:,idx_gene], axis=1)

    delta_vx_norm = np.linalg.norm(adata_rna_pred.obsm['delta_vx'][:,:,idx_gene], axis=1)
    delta_vy_norm = np.linalg.norm(adata_rna_pred.obsm['delta_vy'][:,:,idx_gene], axis=1)

    # RNA: use the predicted RNA-velocity change as the 'velocity' layer and draw it as a stream.
    adata_rna_pred.layers['velocity'] = adata_rna_pred.obsm['delta_vx'][:,:,idx_gene].copy()
    adata_rna_pred.obs['delta_time'] = adata_rna_pred.obsm['delta_latent_time'][:,idx_gene].copy()
    adata_rna_pred.obs['delta_vx_norm'] = delta_vx_norm

    scv.tl.velocity_graph(adata_rna_pred, xkey='Ms', vkey='velocity', gene_subset=adata_rna_pred.var_names, sqrt_transform=False)
    scv.pl.velocity_embedding_stream(adata_rna_pred, color=['fine.celltype'], color_map='viridis', perc=[2,98], wspace=.4, legend_loc='right margin', legend_fontsize=10, basis='umap', dpi=300, save=perturbed_path+f'delta_vx_over_x_obs_stream_{sample_name}.png')

    for tf in perturbed_gene_list:
        # Symmetric limits for the diverging 'bwr' map must come from the largest |delta|; the signed
        # max gives vmin >= vmax (an inverted colour scale) when the change is negative everywhere.
        vmax = np.abs(adata_rna_pred[:,tf].layers['velocity']).max()
        scv.pl.velocity_embedding_stream(adata_rna_pred, color=tf, layer='velocity', vmax=vmax, vmin=-vmax, color_map='bwr', wspace=.4, legend_loc='right margin', legend_fontsize=10, basis='umap', dpi=300, show=False, save=perturbed_path+f'delta_vx_over_x_obs_stream_delta_{tf}_{sample_name}.png')

    # ATAC: the predicted motif-velocity change (stored on the RNA object) becomes the 'velocity' layer.
    adata_atac_pred.layers['velocity'] = adata_rna_pred.obsm['delta_vy'][:,:,idx_gene].copy()
    adata_atac_pred.obs['delta_time'] = adata_rna_pred.obsm['delta_latent_time'][:,idx_gene].copy()
    adata_atac_pred.obs['delta_vy_norm'] = delta_vy_norm

    scv.tl.velocity_graph(adata_atac_pred, xkey='Ms', vkey='velocity', gene_subset=adata_atac_pred.var_names, sqrt_transform=False)
    scv.pl.velocity_embedding_stream(adata_atac_pred, color=['fine.celltype'], color_map='viridis', perc=[2,98], wspace=.4, legend_loc='right margin', legend_fontsize=10, basis='umap', dpi=300, save=perturbed_path+f'delta_vy_over_y_obs_stream_{sample_name}.png')

    # Only perturbed genes that are also motif (TF) names can be coloured on the ATAC object.
    for tf in np.intersect1d(perturbed_gene_list, adata_atac_pred.var_names):
        vmax = np.abs(adata_atac_pred[:,tf].layers['velocity']).max()
        scv.pl.velocity_embedding_stream(adata_atac_pred, color=tf, layer='velocity', vmax=vmax, vmin=-vmax, color_map='bwr', wspace=.4, legend_loc='right margin', legend_fontsize=10, basis='umap', dpi=300, show=False, save=perturbed_path+f'delta_vy_over_y_obs_stream_delta_{tf}_{sample_name}.png')


#### Fig. 6B

In [None]:
celltypes = ['Centroblast', 'Transitioning_CB_CC', 'Centrocyte', 'Transitioning_Sphase', 'CB_Rec_Sphase', 'CB_S_G2M', 'CB_G2M', 'Prememory']

# Genes present in both WT and Arid1aHet predictions, with their column positions in each object.
shared_genes, shared_genes_idx_wt, shared_genes_idx_mut = np.intersect1d(adata_rna_pred.var_names, adata_rna_pred_Arid1aHet.var_names, return_indices=True)

# Position of the Arid1a perturbation along the last axis of obsm['delta_vx'].
arid1a_idx = np.where(adata_rna_pred.uns['Perturbed_genes']=='Arid1a')[0][0]

# gene x celltype tables of t-statistics. An explicit float dtype keeps the product, the row-minimum
# filter and the heatmap numeric (an empty DataFrame otherwise starts with object columns).
df_ttest_vx_expr = pd.DataFrame(index=shared_genes, columns=celltypes, dtype=float)
df_ttest_vx_pred = pd.DataFrame(index=shared_genes, columns=celltypes, dtype=float)

for celltype in celltypes:
    # Experimental contrast: per-gene two-sample t-test of predicted RNA velocity, Arid1aHet vs WT.
    vx_wt = adata_rna_pred[adata_rna_pred.obs['fine.celltype']==celltype].layers['vx_pred_mean'][:, shared_genes_idx_wt]
    vx_mut = adata_rna_pred_Arid1aHet[adata_rna_pred_Arid1aHet.obs['fine.celltype']==celltype].layers['vx_pred_mean'][:, shared_genes_idx_mut]
    # NaN statistics (e.g. constant genes) are treated as "no change".
    df_ttest_vx_expr.loc[:, celltype] = np.nan_to_num(scipy.stats.ttest_ind(vx_mut, vx_wt, axis=0).statistic, nan=0)

    # In-silico contrast: one-sample t-test (against 0) of the predicted Arid1a-KO velocity change in WT cells.
    delta_vx = adata_rna_pred[adata_rna_pred.obs['fine.celltype']==celltype].obsm['delta_vx'][:, :, arid1a_idx][:, shared_genes_idx_wt]
    df_ttest_vx_pred.loc[:, celltype] = np.nan_to_num(scipy.stats.ttest_1samp(delta_vx, 0).statistic, nan=0)

# Keep genes whose experimental and in-silico t-statistics share a non-zero sign in every celltype.
df_ttest_vx_prod = df_ttest_vx_expr * df_ttest_vx_pred
df_ttest_vx_prod = df_ttest_vx_prod[df_ttest_vx_prod.min(1)>0]
df_ttest_vx_pred_sub = df_ttest_vx_pred.loc[df_ttest_vx_prod.index]

# Row clustering needs at least two rows; skip the figure instead of failing inside clustermap.
if len(df_ttest_vx_pred_sub) > 1:
    g = sns.clustermap(
        df_ttest_vx_pred_sub,
        row_cluster=True,
        col_cluster=False,
        cmap='bwr',
        xticklabels=True,
        yticklabels=True,
        center=0,
        cbar_pos=(0.92, 0.58, 0.05, 0.2),
        figsize=(4, 7),
        rasterized=True
    )
    g.ax_row_dendrogram.set_visible(False)
    g.cax.set_ylabel('t-test / delta RNA velocity (Arid1aKO-WT)')
    g.tick_params(left=True, labelleft=True, right=False, labelright=False, labelrotation=0)
    g.ax_heatmap.set_xticklabels(g.ax_heatmap.get_xticklabels(), rotation=90)
    plt.savefig(fig_dir+f'/ttest_delta_rna_velocity_WT_Arid1a_KO.pdf', bbox_inches="tight")
    plt.show()
    plt.close()
else:
    print(f'{len(df_ttest_vx_pred_sub)} concordant gene(s); heatmap skipped.')

In [None]:
# Motifs present in both WT and Arid1aHet predictions, with their column positions in each object.
shared_motifs, shared_motifs_idx_wt, shared_motifs_idx_mut = np.intersect1d(adata_atac_pred.var_names, adata_atac_pred_Arid1aHet.var_names, return_indices=True)

# motif x celltype tables of t-statistics; float dtype keeps the downstream arithmetic numeric.
df_ttest_vy_expr = pd.DataFrame(index=shared_motifs, columns=celltypes, dtype=float)
df_ttest_vy_pred = pd.DataFrame(index=shared_motifs, columns=celltypes, dtype=float)

for celltype in celltypes:
    # Experimental contrast: per-motif two-sample t-test of predicted motif velocity, Arid1aHet vs WT.
    vy_wt = adata_atac_pred[adata_atac_pred.obs['fine.celltype']==celltype].layers['vy_pred_mean'][:, shared_motifs_idx_wt]
    vy_mut = adata_atac_pred_Arid1aHet[adata_atac_pred_Arid1aHet.obs['fine.celltype']==celltype].layers['vy_pred_mean'][:, shared_motifs_idx_mut]
    # NaN statistics (e.g. constant motifs) are treated as "no change".
    df_ttest_vy_expr.loc[:, celltype] = np.nan_to_num(scipy.stats.ttest_ind(vy_mut, vy_wt, axis=0).statistic, nan=0)

    # In-silico contrast: one-sample t-test (against 0) of the predicted Arid1a-KO motif-velocity change
    # in WT cells (delta_vy is stored on the RNA object, which shares the same cells).
    delta_vy = adata_rna_pred[adata_rna_pred.obs['fine.celltype']==celltype].obsm['delta_vy'][:, :, arid1a_idx][:, shared_motifs_idx_wt]
    df_ttest_vy_pred.loc[:, celltype] = np.nan_to_num(scipy.stats.ttest_1samp(delta_vy, 0).statistic, nan=0)

# Keep motifs whose experimental and in-silico t-statistics share a non-zero sign in every celltype.
df_ttest_vy_prod = df_ttest_vy_expr * df_ttest_vy_pred
df_ttest_vy_prod = df_ttest_vy_prod[df_ttest_vy_prod.min(1)>0]
df_ttest_vy_pred_sub = df_ttest_vy_pred.loc[df_ttest_vy_prod.index]

# Row clustering needs at least two rows; skip the figure instead of failing inside clustermap.
if len(df_ttest_vy_pred_sub) > 1:
    g = sns.clustermap(
        df_ttest_vy_pred_sub,
        row_cluster=True,
        col_cluster=False,
        cmap='bwr',
        xticklabels=True,
        yticklabels=True,
        center=0,
        cbar_pos=(0.92, 0.58, 0.05, 0.2),
        figsize=(4, 3),
        rasterized=True
    )
    g.ax_row_dendrogram.set_visible(False)
    g.cax.set_ylabel('t-test / delta motif velocity (Arid1aKO-WT)')
    g.tick_params(left=True, labelleft=True, right=False, labelright=False, labelrotation=0)
    g.ax_heatmap.set_xticklabels(g.ax_heatmap.get_xticklabels(), rotation=90)
    plt.savefig(fig_dir+f'/ttest_delta_motif_velocity_WT_Arid1a_KO.pdf', bbox_inches="tight")
    plt.show()
    plt.close()
else:
    print(f'{len(df_ttest_vy_pred_sub)} concordant motif(s); heatmap skipped.')

#### Fig. 6C

In [None]:
celltypes = ['Centroblast', 'Transitioning_CB_CC', 'Centrocyte', 'Transitioning_Sphase', 'CB_Rec_Sphase', 'CB_S_G2M', 'CB_G2M', 'Prememory']

# Genes present in both WT and CtcfHet predictions, with their column positions in each object.
shared_genes, shared_genes_idx_wt, shared_genes_idx_mut = np.intersect1d(adata_rna_pred.var_names, adata_rna_pred_CtcfHet.var_names, return_indices=True)

# Position of the Ctcf perturbation along the last axis of obsm['delta_vx'].
ctcf_idx = np.where(adata_rna_pred.uns['Perturbed_genes']=='Ctcf')[0][0]

# gene x celltype tables of t-statistics. An explicit float dtype keeps the product, the row-minimum
# filter and the heatmap numeric (an empty DataFrame otherwise starts with object columns).
df_ttest_vx_expr = pd.DataFrame(index=shared_genes, columns=celltypes, dtype=float)
df_ttest_vx_pred = pd.DataFrame(index=shared_genes, columns=celltypes, dtype=float)

for celltype in celltypes:
    # Experimental contrast: per-gene two-sample t-test of predicted RNA velocity, CtcfHet vs WT.
    vx_wt = adata_rna_pred[adata_rna_pred.obs['fine.celltype']==celltype].layers['vx_pred_mean'][:, shared_genes_idx_wt]
    vx_mut = adata_rna_pred_CtcfHet[adata_rna_pred_CtcfHet.obs['fine.celltype']==celltype].layers['vx_pred_mean'][:, shared_genes_idx_mut]
    # NaN statistics (e.g. constant genes) are treated as "no change".
    df_ttest_vx_expr.loc[:, celltype] = np.nan_to_num(scipy.stats.ttest_ind(vx_mut, vx_wt, axis=0).statistic, nan=0)

    # In-silico contrast: one-sample t-test (against 0) of the predicted Ctcf-KO velocity change in WT cells.
    delta_vx = adata_rna_pred[adata_rna_pred.obs['fine.celltype']==celltype].obsm['delta_vx'][:, :, ctcf_idx][:, shared_genes_idx_wt]
    df_ttest_vx_pred.loc[:, celltype] = np.nan_to_num(scipy.stats.ttest_1samp(delta_vx, 0).statistic, nan=0)

# Keep genes whose experimental and in-silico t-statistics share a non-zero sign in every celltype.
df_ttest_vx_prod = df_ttest_vx_expr * df_ttest_vx_pred
df_ttest_vx_prod = df_ttest_vx_prod[df_ttest_vx_prod.min(1)>0]
df_ttest_vx_pred_sub = df_ttest_vx_pred.loc[df_ttest_vx_prod.index]

# Row clustering needs at least two rows; skip the figure instead of failing inside clustermap.
if len(df_ttest_vx_pred_sub) > 1:
    g = sns.clustermap(
        df_ttest_vx_pred_sub,
        row_cluster=True,
        col_cluster=False,
        cmap='bwr',
        xticklabels=True,
        yticklabels=True,
        center=0,
        cbar_pos=(0.92, 0.58, 0.05, 0.2),
        figsize=(4, 7),
        rasterized=True
    )
    g.ax_row_dendrogram.set_visible(False)
    g.cax.set_ylabel('t-test / delta RNA velocity (CtcfKO-WT)')
    g.tick_params(left=True, labelleft=True, right=False, labelright=False, labelrotation=0)
    g.ax_heatmap.set_xticklabels(g.ax_heatmap.get_xticklabels(), rotation=90)
    plt.savefig(fig_dir+f'/ttest_delta_rna_velocity_WT_Ctcf_KO.pdf', bbox_inches="tight")
    plt.show()
    plt.close()
else:
    print(f'{len(df_ttest_vx_pred_sub)} concordant gene(s); heatmap skipped.')

In [None]:
# Motifs present in both WT and CtcfHet predictions, with their column positions in each object.
shared_motifs, shared_motifs_idx_wt, shared_motifs_idx_mut = np.intersect1d(adata_atac_pred.var_names, adata_atac_pred_CtcfHet.var_names, return_indices=True)

# motif x celltype tables of t-statistics; float dtype keeps the downstream arithmetic numeric.
df_ttest_vy_expr = pd.DataFrame(index=shared_motifs, columns=celltypes, dtype=float)
df_ttest_vy_pred = pd.DataFrame(index=shared_motifs, columns=celltypes, dtype=float)

for celltype in celltypes:
    # Experimental contrast: per-motif two-sample t-test of predicted motif velocity, CtcfHet vs WT.
    vy_wt = adata_atac_pred[adata_atac_pred.obs['fine.celltype']==celltype].layers['vy_pred_mean'][:, shared_motifs_idx_wt]
    vy_mut = adata_atac_pred_CtcfHet[adata_atac_pred_CtcfHet.obs['fine.celltype']==celltype].layers['vy_pred_mean'][:, shared_motifs_idx_mut]
    # NaN statistics (e.g. constant motifs) are treated as "no change".
    df_ttest_vy_expr.loc[:, celltype] = np.nan_to_num(scipy.stats.ttest_ind(vy_mut, vy_wt, axis=0).statistic, nan=0)

    # In-silico contrast: one-sample t-test (against 0) of the predicted Ctcf-KO motif-velocity change
    # in WT cells (delta_vy is stored on the RNA object, which shares the same cells).
    delta_vy = adata_rna_pred[adata_rna_pred.obs['fine.celltype']==celltype].obsm['delta_vy'][:, :, ctcf_idx][:, shared_motifs_idx_wt]
    df_ttest_vy_pred.loc[:, celltype] = np.nan_to_num(scipy.stats.ttest_1samp(delta_vy, 0).statistic, nan=0)

# Keep motifs whose experimental and in-silico t-statistics share a non-zero sign in every celltype.
df_ttest_vy_prod = df_ttest_vy_expr * df_ttest_vy_pred
df_ttest_vy_prod = df_ttest_vy_prod[df_ttest_vy_prod.min(1)>0]
df_ttest_vy_pred_sub = df_ttest_vy_pred.loc[df_ttest_vy_prod.index]

# Row clustering needs at least two rows; skip the figure instead of failing inside clustermap.
if len(df_ttest_vy_pred_sub) > 1:
    g = sns.clustermap(
        df_ttest_vy_pred_sub,
        row_cluster=True,
        col_cluster=False,
        cmap='bwr',
        xticklabels=True,
        yticklabels=True,
        center=0,
        cbar_pos=(0.92, 0.58, 0.05, 0.2),
        figsize=(4, 3),
        rasterized=True
    )
    g.ax_row_dendrogram.set_visible(False)
    g.cax.set_ylabel('t-test / delta motif velocity (CtcfKO-WT)')
    g.tick_params(left=True, labelleft=True, right=False, labelright=False, labelrotation=0)
    g.ax_heatmap.set_xticklabels(g.ax_heatmap.get_xticklabels(), rotation=90)
    plt.savefig(fig_dir+f'/ttest_delta_motif_velocity_WT_Ctcf_KO.pdf', bbox_inches="tight")
    plt.show()
    plt.close()
else:
    print(f'{len(df_ttest_vy_pred_sub)} concordant motif(s); heatmap skipped.')

#### Fig. 6D

In [None]:
celltypes = ['Centroblast', 'Transitioning_CB_CC', 'Centrocyte', 'Transitioning_Sphase', 'CB_Rec_Sphase', 'CB_S_G2M', 'CB_G2M', 'Prememory']

# Genes shared by the WT and Arid1aHet RNA predictions, with their column positions in each object.
shared_genes, shared_genes_idx_wt, shared_genes_idx_mut = np.intersect1d(adata_rna_pred.var_names, adata_rna_pred_Arid1aHet.var_names, return_indices=True)
# Per in-silico perturbation of Arid1aHet: count of genes (this cell) and motifs (next cell) whose
# predicted change agrees in sign with the WT - Arid1aHet difference in every celltype.
df_correct_preds = pd.DataFrame(data=0, index=adata_rna_pred_Arid1aHet.uns['Perturbed_genes'], columns=['N_correct_genes', 'N_correct_motifs'])

# Experimental contrast (WT vs Arid1aHet) does not depend on the perturbed gene, so compute it once
# rather than once per perturbation. Float dtype keeps the tables numeric.
df_ttest_vx_expr = pd.DataFrame(index=shared_genes, columns=celltypes, dtype=float)
for celltype in celltypes:
    vx_wt = adata_rna_pred[adata_rna_pred.obs['fine.celltype']==celltype].layers['vx_pred_mean'][:, shared_genes_idx_wt]
    vx_mut = adata_rna_pred_Arid1aHet[adata_rna_pred_Arid1aHet.obs['fine.celltype']==celltype].layers['vx_pred_mean'][:, shared_genes_idx_mut]
    # NaN statistics (e.g. constant genes) are treated as "no change".
    df_ttest_vx_expr.loc[:, celltype] = np.nan_to_num(scipy.stats.ttest_ind(vx_wt, vx_mut, axis=0).statistic, nan=0)

for gene_perturbed_idx, gene_perturbed in enumerate(adata_rna_pred_Arid1aHet.uns['Perturbed_genes']):
    print('gene_perturbed: ', gene_perturbed)

    # In-silico contrast: one-sample t-test (against 0) of the predicted velocity change in Arid1aHet cells.
    df_ttest_vx_pred = pd.DataFrame(index=shared_genes, columns=celltypes, dtype=float)
    for celltype in celltypes:
        delta_vx = adata_rna_pred_Arid1aHet[adata_rna_pred_Arid1aHet.obs['fine.celltype']==celltype].obsm['delta_vx'][:, :, gene_perturbed_idx][:, shared_genes_idx_mut]
        df_ttest_vx_pred.loc[:, celltype] = np.nan_to_num(scipy.stats.ttest_1samp(delta_vx, 0).statistic, nan=0)

    # "Rescued" genes: in-silico and experimental t-statistics share a non-zero sign in every celltype.
    df_ttest_vx_prod = df_ttest_vx_expr * df_ttest_vx_pred
    df_ttest_vx_prod = df_ttest_vx_prod[df_ttest_vx_prod.min(1)>0]
    df_ttest_vx_pred_sub = df_ttest_vx_pred.loc[df_ttest_vx_prod.index]
    df_correct_preds.loc[gene_perturbed, 'N_correct_genes'] = len(df_ttest_vx_pred_sub)

    # Same guard as the motif heatmaps: clustermap's row clustering needs at least two rows,
    # otherwise a perturbation that rescues 0-1 genes would abort the whole loop.
    if len(df_ttest_vx_pred_sub) > 1:
        g = sns.clustermap(
            df_ttest_vx_pred_sub,
            row_cluster=True,
            col_cluster=False,
            cmap='bwr',
            xticklabels=True,
            yticklabels=True,
            center=0,
            cbar_pos=(0.92, 0.58, 0.05, 0.2),
            figsize=(4, 7),
            rasterized=True
        )
        g.ax_row_dendrogram.set_visible(False)
        g.cax.set_ylabel(f't-test / Arid1aHet {gene_perturbed} KO vs Arid1aHet')
        g.tick_params(left=True, labelleft=True, right=False, labelright=False, labelrotation=0)
        g.ax_heatmap.set_xticklabels(g.ax_heatmap.get_xticklabels(), rotation=90)
        plt.savefig(fig_dir+f'/ttest_rna_velocity_Arid1aHet_{gene_perturbed}_KO_vs_Arid1aHet.pdf', bbox_inches="tight")
        plt.close()

In [None]:
# Motifs shared by the WT and Arid1aHet ATAC predictions, with their column positions in each object.
shared_motifs, shared_motifs_idx_wt, shared_motifs_idx_mut = np.intersect1d(adata_atac_pred.var_names, adata_atac_pred_Arid1aHet.var_names, return_indices=True)

# Experimental contrast (WT vs Arid1aHet) does not depend on the perturbed gene, so compute it once
# rather than once per perturbation. Float dtype keeps the tables numeric.
df_ttest_vy_expr = pd.DataFrame(index=shared_motifs, columns=celltypes, dtype=float)
for celltype in celltypes:
    vy_wt = adata_atac_pred[adata_atac_pred.obs['fine.celltype']==celltype].layers['vy_pred_mean'][:, shared_motifs_idx_wt]
    vy_mut = adata_atac_pred_Arid1aHet[adata_atac_pred_Arid1aHet.obs['fine.celltype']==celltype].layers['vy_pred_mean'][:, shared_motifs_idx_mut]
    # NaN statistics (e.g. constant motifs) are treated as "no change".
    df_ttest_vy_expr.loc[:, celltype] = np.nan_to_num(scipy.stats.ttest_ind(vy_wt, vy_mut, axis=0).statistic, nan=0)

for gene_perturbed_idx, gene_perturbed in enumerate(adata_rna_pred_Arid1aHet.uns['Perturbed_genes']):
    print('gene_perturbed: ', gene_perturbed)

    # In-silico contrast: one-sample t-test (against 0) of the predicted motif-velocity change
    # (delta_vy is stored on the RNA object, which shares the ATAC object's cells).
    df_ttest_vy_pred = pd.DataFrame(index=shared_motifs, columns=celltypes, dtype=float)
    for celltype in celltypes:
        delta_vy = adata_rna_pred_Arid1aHet[adata_rna_pred_Arid1aHet.obs['fine.celltype']==celltype].obsm['delta_vy'][:, :, gene_perturbed_idx][:, shared_motifs_idx_mut]
        df_ttest_vy_pred.loc[:, celltype] = np.nan_to_num(scipy.stats.ttest_1samp(delta_vy, 0).statistic, nan=0)

    # "Rescued" motifs: in-silico and experimental t-statistics share a non-zero sign in every celltype.
    df_ttest_vy_prod = df_ttest_vy_expr * df_ttest_vy_pred
    df_ttest_vy_prod = df_ttest_vy_prod[df_ttest_vy_prod.min(1)>0]
    df_ttest_vy_pred_sub = df_ttest_vy_pred.loc[df_ttest_vy_prod.index]
    df_correct_preds.loc[gene_perturbed, 'N_correct_motifs'] = len(df_ttest_vy_pred_sub)

    # clustermap's row clustering needs at least two rows.
    if len(df_ttest_vy_pred_sub)>1:
        g = sns.clustermap(
            df_ttest_vy_pred_sub,
            row_cluster=True,
            col_cluster=False,
            cmap='bwr',
            xticklabels=True,
            yticklabels=True,
            center=0,
            cbar_pos=(0.92, 0.58, 0.05, 0.2),
            figsize=(4, 3),
            rasterized=True
        )
        g.ax_row_dendrogram.set_visible(False)
        g.cax.set_ylabel(f't-test / Arid1aHet {gene_perturbed} KO vs Arid1aHet')
        g.tick_params(left=True, labelleft=True, right=False, labelright=False, labelrotation=0)
        g.ax_heatmap.set_xticklabels(g.ax_heatmap.get_xticklabels(), rotation=90)
        plt.savefig(fig_dir+f'/ttest_motif_velocity_Arid1aHet_{gene_perturbed}_KO_vs_Arid1aHet.pdf', bbox_inches="tight")
        plt.close()

In [None]:
# Rescued RNA velocities (x) vs rescued motif velocities (y), one point per in-silico perturbation.
fig, ax = plt.subplots()
ax.scatter(df_correct_preds['N_correct_genes'], df_correct_preds['N_correct_motifs'])

# A point is left unlabelled if it lies closer than this (in data units) to an already-labelled point.
min_distance = 5

def distance(x1, y1, x2, y2):
    """Euclidean distance between (x1, y1) and (x2, y2)."""
    return np.sqrt((x2 - x1)**2 + (y2 - y1)**2)

# Coordinates of labelled points, and the text artists created for them.
points = []
texts = []

for label, x_pos, y_pos in zip(df_correct_preds.index, df_correct_preds['N_correct_genes'], df_correct_preds['N_correct_motifs']):
    if any(distance(x_pos, y_pos, px, py) < min_distance for px, py in points):
        continue
    texts.append(ax.text(x_pos, y_pos, label, fontsize=9, ha='center', va='bottom'))
    points.append((x_pos, y_pos))

ax.set_xlabel('Number of rescued RNA velocities')
ax.set_ylabel('Number of rescued motif velocities')
ax.set_title('In-silico perturbations in Arid1aHet')
fig.savefig(fig_dir+f'/scatterplot_insilico_perturbations_in_Arid1aHet.pdf', bbox_inches="tight")

#### Fig. 6E

In [None]:
celltypes = ['Centroblast', 'Transitioning_CB_CC', 'Centrocyte', 'Transitioning_Sphase', 'CB_Rec_Sphase', 'CB_S_G2M', 'CB_G2M', 'Prememory']

# Genes shared by the WT and CtcfHet RNA predictions, with their column positions in each object.
shared_genes, shared_genes_idx_wt, shared_genes_idx_mut = np.intersect1d(adata_rna_pred.var_names, adata_rna_pred_CtcfHet.var_names, return_indices=True)
# Per in-silico perturbation of CtcfHet: count of genes (this cell) and motifs (next cell) whose
# predicted change agrees in sign with the WT - CtcfHet difference in every celltype.
df_correct_preds = pd.DataFrame(data=0, index=adata_rna_pred_CtcfHet.uns['Perturbed_genes'], columns=['N_correct_genes', 'N_correct_motifs'])

# Experimental contrast (WT vs CtcfHet) does not depend on the perturbed gene, so compute it once
# rather than once per perturbation. Float dtype keeps the tables numeric.
df_ttest_vx_expr = pd.DataFrame(index=shared_genes, columns=celltypes, dtype=float)
for celltype in celltypes:
    vx_wt = adata_rna_pred[adata_rna_pred.obs['fine.celltype']==celltype].layers['vx_pred_mean'][:, shared_genes_idx_wt]
    vx_mut = adata_rna_pred_CtcfHet[adata_rna_pred_CtcfHet.obs['fine.celltype']==celltype].layers['vx_pred_mean'][:, shared_genes_idx_mut]
    # NaN statistics (e.g. constant genes) are treated as "no change".
    df_ttest_vx_expr.loc[:, celltype] = np.nan_to_num(scipy.stats.ttest_ind(vx_wt, vx_mut, axis=0).statistic, nan=0)

for gene_perturbed_idx, gene_perturbed in enumerate(adata_rna_pred_CtcfHet.uns['Perturbed_genes']):
    print('gene_perturbed: ', gene_perturbed)

    # In-silico contrast: one-sample t-test (against 0) of the predicted velocity change in CtcfHet cells.
    df_ttest_vx_pred = pd.DataFrame(index=shared_genes, columns=celltypes, dtype=float)
    for celltype in celltypes:
        delta_vx = adata_rna_pred_CtcfHet[adata_rna_pred_CtcfHet.obs['fine.celltype']==celltype].obsm['delta_vx'][:, :, gene_perturbed_idx][:, shared_genes_idx_mut]
        df_ttest_vx_pred.loc[:, celltype] = np.nan_to_num(scipy.stats.ttest_1samp(delta_vx, 0).statistic, nan=0)

    # "Rescued" genes: in-silico and experimental t-statistics share a non-zero sign in every celltype.
    df_ttest_vx_prod = df_ttest_vx_expr * df_ttest_vx_pred
    df_ttest_vx_prod = df_ttest_vx_prod[df_ttest_vx_prod.min(1)>0]
    df_ttest_vx_pred_sub = df_ttest_vx_pred.loc[df_ttest_vx_prod.index]
    df_correct_preds.loc[gene_perturbed, 'N_correct_genes'] = len(df_ttest_vx_pred_sub)

    # Same guard as the motif heatmaps: clustermap's row clustering needs at least two rows,
    # otherwise a perturbation that rescues 0-1 genes would abort the whole loop.
    if len(df_ttest_vx_pred_sub) > 1:
        g = sns.clustermap(
            df_ttest_vx_pred_sub,
            row_cluster=True,
            col_cluster=False,
            cmap='bwr',
            xticklabels=True,
            yticklabels=True,
            center=0,
            cbar_pos=(0.92, 0.58, 0.05, 0.2),
            figsize=(4, 7),
            rasterized=True
        )
        g.ax_row_dendrogram.set_visible(False)
        g.cax.set_ylabel(f't-test / CtcfHet {gene_perturbed} KO vs CtcfHet')
        g.tick_params(left=True, labelleft=True, right=False, labelright=False, labelrotation=0)
        g.ax_heatmap.set_xticklabels(g.ax_heatmap.get_xticklabels(), rotation=90)
        plt.savefig(fig_dir+f'/ttest_rna_velocity_CtcfHet_{gene_perturbed}_KO_vs_CtcfHet.pdf', bbox_inches="tight")
        plt.close()

In [None]:
# Motifs shared by the WT and CtcfHet ATAC predictions, with their column positions in each object.
shared_motifs, shared_motifs_idx_wt, shared_motifs_idx_mut = np.intersect1d(adata_atac_pred.var_names, adata_atac_pred_CtcfHet.var_names, return_indices=True)

# Experimental contrast (WT vs CtcfHet) does not depend on the perturbed gene, so compute it once
# rather than once per perturbation. Float dtype keeps the tables numeric.
df_ttest_vy_expr = pd.DataFrame(index=shared_motifs, columns=celltypes, dtype=float)
for celltype in celltypes:
    vy_wt = adata_atac_pred[adata_atac_pred.obs['fine.celltype']==celltype].layers['vy_pred_mean'][:, shared_motifs_idx_wt]
    vy_mut = adata_atac_pred_CtcfHet[adata_atac_pred_CtcfHet.obs['fine.celltype']==celltype].layers['vy_pred_mean'][:, shared_motifs_idx_mut]
    # NaN statistics (e.g. constant motifs) are treated as "no change".
    df_ttest_vy_expr.loc[:, celltype] = np.nan_to_num(scipy.stats.ttest_ind(vy_wt, vy_mut, axis=0).statistic, nan=0)

for gene_perturbed_idx, gene_perturbed in enumerate(adata_rna_pred_CtcfHet.uns['Perturbed_genes']):
    print('gene_perturbed: ', gene_perturbed)

    # In-silico contrast: one-sample t-test (against 0) of the predicted motif-velocity change
    # (delta_vy is stored on the RNA object, which shares the ATAC object's cells).
    df_ttest_vy_pred = pd.DataFrame(index=shared_motifs, columns=celltypes, dtype=float)
    for celltype in celltypes:
        delta_vy = adata_rna_pred_CtcfHet[adata_rna_pred_CtcfHet.obs['fine.celltype']==celltype].obsm['delta_vy'][:, :, gene_perturbed_idx][:, shared_motifs_idx_mut]
        df_ttest_vy_pred.loc[:, celltype] = np.nan_to_num(scipy.stats.ttest_1samp(delta_vy, 0).statistic, nan=0)

    # "Rescued" motifs: in-silico and experimental t-statistics share a non-zero sign in every celltype.
    df_ttest_vy_prod = df_ttest_vy_expr * df_ttest_vy_pred
    df_ttest_vy_prod = df_ttest_vy_prod[df_ttest_vy_prod.min(1)>0]
    df_ttest_vy_pred_sub = df_ttest_vy_pred.loc[df_ttest_vy_prod.index]
    df_correct_preds.loc[gene_perturbed, 'N_correct_motifs'] = len(df_ttest_vy_pred_sub)

    # clustermap's row clustering needs at least two rows.
    if len(df_ttest_vy_pred_sub)>1:
        g = sns.clustermap(
            df_ttest_vy_pred_sub,
            row_cluster=True,
            col_cluster=False,
            cmap='bwr',
            xticklabels=True,
            yticklabels=True,
            center=0,
            cbar_pos=(0.92, 0.58, 0.05, 0.2),
            figsize=(4, 3),
            rasterized=True
        )
        g.ax_row_dendrogram.set_visible(False)
        g.cax.set_ylabel(f't-test / CtcfHet {gene_perturbed} KO vs CtcfHet')
        g.tick_params(left=True, labelleft=True, right=False, labelright=False, labelrotation=0)
        g.ax_heatmap.set_xticklabels(g.ax_heatmap.get_xticklabels(), rotation=90)
        plt.savefig(fig_dir+f'/ttest_motif_velocity_CtcfHet_{gene_perturbed}_KO_vs_CtcfHet.pdf', bbox_inches="tight")
        plt.close()

In [None]:
# Rescued RNA velocities (x) vs rescued motif velocities (y), one point per in-silico perturbation.
fig, ax = plt.subplots()
ax.scatter(df_correct_preds['N_correct_genes'], df_correct_preds['N_correct_motifs'])

# A point is left unlabelled if it lies closer than this (in data units) to an already-labelled point.
min_distance = 5

def distance(x1, y1, x2, y2):
    """Euclidean distance between (x1, y1) and (x2, y2)."""
    return np.sqrt((x2 - x1)**2 + (y2 - y1)**2)

# Coordinates of labelled points, and the text artists created for them.
points = []
texts = []

for label, x_pos, y_pos in zip(df_correct_preds.index, df_correct_preds['N_correct_genes'], df_correct_preds['N_correct_motifs']):
    if any(distance(x_pos, y_pos, px, py) < min_distance for px, py in points):
        continue
    texts.append(ax.text(x_pos, y_pos, label, fontsize=9, ha='center', va='bottom'))
    points.append((x_pos, y_pos))

ax.set_xlabel('Number of rescued RNA velocities')
ax.set_ylabel('Number of rescued motif velocities')
ax.set_title('In-silico perturbations in CtcfHet')
fig.savefig(fig_dir+f'/scatterplot_insilico_perturbations_in_CtcfHet.pdf', bbox_inches="tight")

In [None]:
sample_name = 'Arid1aHetM3'

perturbed_gene_list = adata_rna_pred_Arid1aHet.uns['Perturbed_genes']

sc.pp.neighbors(adata_atac_pred_Arid1aHet, n_neighbors=50, use_rep='X')
sc.tl.umap(adata_atac_pred_Arid1aHet, min_dist=1, spread=1, random_state=0, n_components=2)

for idx_gene, perturbed_gene in enumerate(perturbed_gene_list):
    print('perturbed_gene: ', perturbed_gene)

    if not os.path.exists(fig_dir+f'/perturbation/{sample_name}/perturbed_gene_{perturbed_gene}'):
        os.makedirs(fig_dir+f'/perturbation/{sample_name}/perturbed_gene_{perturbed_gene}')
    perturbed_path = fig_dir+f'/perturbation/{sample_name}/perturbed_gene_{perturbed_gene}/'

    # norm of delta velocities
    delta_vzx_norm = np.linalg.norm(adata_rna_pred_Arid1aHet.obsm['delta_vzx'][:,:,idx_gene], axis=1)
    delta_vzy_norm = np.linalg.norm(adata_rna_pred_Arid1aHet.obsm['delta_vzy'][:,:,idx_gene], axis=1)

    delta_vx_norm = np.linalg.norm(adata_rna_pred_Arid1aHet.obsm['delta_vx'][:,:,idx_gene], axis=1)
    delta_vy_norm = np.linalg.norm(adata_rna_pred_Arid1aHet.obsm['delta_vy'][:,:,idx_gene], axis=1)

    # delta_vx_pred
    adata_rna_pred_Arid1aHet.layers['velocity'] = adata_rna_pred_Arid1aHet.obsm['delta_vx'][:,:,idx_gene].copy()
    adata_rna_pred_Arid1aHet.obs['delta_time'] = adata_rna_pred_Arid1aHet.obsm['delta_latent_time'][:,idx_gene].copy()
    adata_rna_pred_Arid1aHet.obs['delta_vx_norm'] = delta_vx_norm

    scv.tl.velocity_graph(adata_rna_pred_Arid1aHet, xkey='Ms', vkey='velocity', gene_subset=adata_rna_pred_Arid1aHet.var_names, sqrt_transform=False)
    scv.pl.velocity_embedding_stream(adata_rna_pred_Arid1aHet, color=['fine.celltype'], color_map='viridis', perc=[2,98], wspace=.4, legend_loc='right margin', legend_fontsize=10, basis='umap', dpi=300, save=perturbed_path+f'delta_vx_over_x_obs_stream_{sample_name}.png')

    for tf in perturbed_gene_list:
        vmax = adata_rna_pred_Arid1aHet[:,tf].layers['velocity'].max()
        scv.pl.velocity_embedding_stream(adata_rna_pred_Arid1aHet, color=tf, layer='velocity', vmax=vmax, vmin=-vmax, color_map='bwr', wspace=.4, legend_loc='right margin', legend_fontsize=10, basis='umap', dpi=300, show=False, save=perturbed_path+f'delta_vx_over_x_obs_stream_delta_{tf}_{sample_name}.png')
    
    # delta_vy_pred
    adata_atac_pred_Arid1aHet.layers['Ms'] = adata_atac_pred_Arid1aHet.X.copy()
    adata_atac_pred_Arid1aHet.var_names = adata_atac_pred_Arid1aHet.var['TF'].values
    adata_atac_pred_Arid1aHet.layers['velocity'] = adata_rna_pred_Arid1aHet.obsm['delta_vy'][:,:,idx_gene].copy()
    adata_atac_pred_Arid1aHet.obs['delta_time'] = adata_rna_pred_Arid1aHet.obsm['delta_latent_time'][:,idx_gene].copy()
    adata_atac_pred_Arid1aHet.obs['delta_vy_norm'] = delta_vy_norm

    scv.tl.velocity_graph(adata_atac_pred_Arid1aHet, xkey='Ms', vkey='velocity', gene_subset=adata_atac_pred_Arid1aHet.var_names, sqrt_transform=False)
    scv.pl.velocity_embedding_stream(adata_atac_pred_Arid1aHet, color=['fine.celltype'], color_map='viridis', perc=[2,98], wspace=.4, legend_loc='right margin', legend_fontsize=10, basis='umap', dpi=300, save=perturbed_path+f'delta_vy_over_y_obs_stream_{sample_name}.png')

    for tf in np.intersect1d(perturbed_gene_list, adata_atac_pred_Arid1aHet.var_names):
        vmax = adata_atac_pred_Arid1aHet[:,tf].layers['velocity'].max()
        scv.pl.velocity_embedding_stream(adata_atac_pred_Arid1aHet, color=tf, layer='velocity', vmax=vmax, vmin=-vmax, color_map='bwr', wspace=.4, legend_loc='right margin', legend_fontsize=10, basis='umap', dpi=300, show=False, save=perturbed_path+f'delta_vy_over_y_obs_stream_delta_{tf}_{sample_name}.png')


In [None]:
sample_name = 'CtcfHet30'

perturbed_gene_list = adata_rna_pred_CtcfHet.uns['Perturbed_genes']

sc.pp.neighbors(adata_atac_pred_CtcfHet, n_neighbors=50, use_rep='X')
sc.tl.umap(adata_atac_pred_CtcfHet, min_dist=1, spread=1, random_state=0, n_components=2)

for idx_gene, perturbed_gene in enumerate(perturbed_gene_list):
    print('perturbed_gene: ', perturbed_gene)

    if not os.path.exists(fig_dir+f'/perturbation/{sample_name}/perturbed_gene_{perturbed_gene}'):
        os.makedirs(fig_dir+f'/perturbation/{sample_name}/perturbed_gene_{perturbed_gene}')
    perturbed_path = fig_dir+f'/perturbation/{sample_name}/perturbed_gene_{perturbed_gene}/'

    # norm of delta velocities
    delta_vzx_norm = np.linalg.norm(adata_rna_pred_CtcfHet.obsm['delta_vzx'][:,:,idx_gene], axis=1)
    delta_vzy_norm = np.linalg.norm(adata_rna_pred_CtcfHet.obsm['delta_vzy'][:,:,idx_gene], axis=1)

    delta_vx_norm = np.linalg.norm(adata_rna_pred_CtcfHet.obsm['delta_vx'][:,:,idx_gene], axis=1)
    delta_vy_norm = np.linalg.norm(adata_rna_pred_CtcfHet.obsm['delta_vy'][:,:,idx_gene], axis=1)

    # delta_vx_pred
    adata_rna_pred_CtcfHet.layers['velocity'] = adata_rna_pred_CtcfHet.obsm['delta_vx'][:,:,idx_gene].copy()
    adata_rna_pred_CtcfHet.obs['delta_time'] = adata_rna_pred_CtcfHet.obsm['delta_latent_time'][:,idx_gene].copy()
    adata_rna_pred_CtcfHet.obs['delta_vx_norm'] = delta_vx_norm

    scv.tl.velocity_graph(adata_rna_pred_CtcfHet, xkey='Ms', vkey='velocity', gene_subset=adata_rna_pred_CtcfHet.var_names, sqrt_transform=False)
    scv.pl.velocity_embedding_stream(adata_rna_pred_CtcfHet, color=['fine.celltype'], color_map='viridis', perc=[2,98], wspace=.4, legend_loc='right margin', legend_fontsize=10, basis='umap', dpi=300, save=perturbed_path+f'delta_vx_over_x_obs_stream_{sample_name}.png')

    for tf in perturbed_gene_list:
        vmax = adata_rna_pred_CtcfHet[:,tf].layers['velocity'].max()
        scv.pl.velocity_embedding_stream(adata_rna_pred_CtcfHet, color=tf, layer='velocity', vmax=vmax, vmin=-vmax, color_map='bwr', wspace=.4, legend_loc='right margin', legend_fontsize=10, basis='umap', dpi=300, show=False, save=perturbed_path+f'delta_vx_over_x_obs_stream_delta_{tf}_{sample_name}.png')
    
    # delta_vy_pred
    adata_atac_pred_CtcfHet.layers['Ms'] = adata_atac_pred_CtcfHet.X.copy()
    adata_atac_pred_CtcfHet.var_names = adata_atac_pred_CtcfHet.var['TF'].values
    adata_atac_pred_CtcfHet.layers['velocity'] = adata_rna_pred_CtcfHet.obsm['delta_vy'][:,:,idx_gene].copy()
    adata_atac_pred_CtcfHet.obs['delta_time'] = adata_rna_pred_CtcfHet.obsm['delta_latent_time'][:,idx_gene].copy()
    adata_atac_pred_CtcfHet.obs['delta_vy_norm'] = delta_vy_norm

    scv.tl.velocity_graph(adata_atac_pred_CtcfHet, xkey='Ms', vkey='velocity', gene_subset=adata_atac_pred_CtcfHet.var_names, sqrt_transform=False)
    scv.pl.velocity_embedding_stream(adata_atac_pred_CtcfHet, color=['fine.celltype'], color_map='viridis', perc=[2,98], wspace=.4, legend_loc='right margin', legend_fontsize=10, basis='umap', dpi=300, save=perturbed_path+f'delta_vy_over_y_obs_stream_{sample_name}.png')

    for tf in np.intersect1d(perturbed_gene_list, adata_atac_pred_CtcfHet.var_names):
        vmax = adata_atac_pred_CtcfHet[:,tf].layers['velocity'].max()
        scv.pl.velocity_embedding_stream(adata_atac_pred_CtcfHet, color=tf, layer='velocity', vmax=vmax, vmin=-vmax, color_map='bwr', wspace=.4, legend_loc='right margin', legend_fontsize=10, basis='umap', dpi=300, show=False, save=perturbed_path+f'delta_vy_over_y_obs_stream_delta_{tf}_{sample_name}.png')
