In [None]:
import numpy as np
import matplotlib.pyplot as plt
import scanpy as sc
import sys
import os

# Make the directory two levels above the working directory importable, since the
# local `plotting` module imported below is expected to live there.
# Guarded so that re-running this cell does not keep appending duplicate sys.path entries.
parent_directory = os.path.abspath(os.path.join(os.getcwd(), '../../'))
if parent_directory not in sys.path:
    sys.path.append(parent_directory)

from plotting import plot_velocity_expression_gp, plot_phase_plane_gp, plot_activation_or_velocity_two_gene_programs, gp_phase_plane_no_velocity


In [None]:
# Gene-program pairs to plot against each other, grouped per dataset.
pairs_pancreas = [
    ("LUMINAL_EPITHELIAL_CELLS", "DUCTAL_CELLS"),
    ("METABOLISM_OF_CARBOHYDRATES", "GLUCOSE_METABOLISM"),
    ("METABOLISM_OF_LIPIDS_AND_LIPOP", "METABOLISM_OF_CARBOHYDRATES"),
    ("DELTA_CELLS", "BETA_CELLS"),
]

pairs_forebrain = [("EMBRYONIC_STEM_CELLS", "NEURAL_STEM-PRECURSOR_CELLS")]

pairs_gastrulation_erythroid = [("ERYTHROID-LIKE_AND_ERYTHROID_P", "IRON_UPTAKE_AND_TRANSPORT")]

pairs_dentategyrus_lamanno_P5 = [
    ("NEURONS", "OLIGODENDROCYTE_PROGENITOR_CEL"),
    ("ASTROCYTES", "OLIGODENDROCYTE_PROGENITOR_CEL"),
    ("NEURONS", "NEURONAL_SYSTEM"),
    ("NEURONAL_SYSTEM", "PYRAMIDAL_CELLS"),
]

# Dataset name -> list of GP pairs. Insertion order defines the processing order below.
pairs_dic = dict(
    forebrain=pairs_forebrain,
    pancreas=pairs_pancreas,
    gastrulation_erythroid=pairs_gastrulation_erythroid,
    dentategyrus_lamanno_P5=pairs_dentategyrus_lamanno_P5,
)

# Dataset name -> obs column holding the cell-type labels for that dataset.
_cell_type_key_by_dataset = {
    "forebrain": "Clusters",
    "pancreas": "clusters",
    "gastrulation_erythroid": "celltype",
    "dentategyrus_lamanno_P5": "clusters",
}

# Parallel lists derived from the mappings above so they cannot fall out of sync.
datasets = list(pairs_dic)
cell_type_keys = [_cell_type_key_by_dataset[name] for name in datasets]
model_names = ["ivelo", "ivelo_filtered"]

In [None]:
# Render a GP-vs-GP phase plane for every (model, dataset, GP pair) combination.
# Set to True to additionally render the velocity-free phase plane. That call was
# previously disabled by wrapping it in a bare string literal, while its output
# directory was still created on every pass.
plot_no_velocity = False

for model_name in model_names:
    print(f"processing model: {model_name}")
    for dataset, cell_type_key in zip(datasets, cell_type_keys):
        print(f"processing dataset: {dataset}")
        adata = sc.read_h5ad(f"../../benchmark/{model_name}/{dataset}/adata_gp.h5ad")
        if model_name in ["ivelo", "ivelo_filtered"]:
            # Copy the `<cell_type_key>_colors` palette from the full model output into the GP AnnData.
            adata_colors = sc.read_h5ad(f"../../benchmark/{model_name}/{dataset}/{model_name}_{dataset}.h5ad")
            adata.uns[f"{cell_type_key}_colors"] = adata_colors.uns[f"{cell_type_key}_colors"].copy()
            del adata_colors

        for gp1, gp2 in pairs_dic[dataset]:
            print(f"gp1: {gp1}")
            print(f"gp2: {gp2}")

            # pandas Index supports `in` directly; no need to copy var_names into a list.
            flag1 = gp1 in adata.var_names
            flag2 = gp2 in adata.var_names
            if not (flag1 and flag2):
                print(f"gp1 present in dataset {dataset} model {model_name}: {flag1}")
                print(f"gp2 present in dataset {dataset} model {model_name}: {flag2}")
                print("skipping")
                continue

            if plot_no_velocity:
                no_velo_dir = f"plots/gp_phase_plane_no_velo/{dataset}/{model_name}/"
                os.makedirs(no_velo_dir, exist_ok=True)
                gp_phase_plane_no_velocity(adata, gp1, gp2, u_scale=.1, s_scale=.1,
                    cell_type_key=cell_type_key, dataset=dataset,
                    K=11, save_path=f"{no_velo_dir}{gp1}_{gp2}.png",
                    save_plot=True, scale_expression=1)

            os.makedirs(f"plots/gp_phase_plane/{dataset}/{model_name}/", exist_ok=True)
            plot_phase_plane_gp(adata, gp1, gp2, u_scale=1, s_scale=1,
                cell_type_key=cell_type_key, dataset=dataset,
                K=11, save_path=f"plots/gp_phase_plane/{dataset}/{model_name}/{gp1}_{gp2}.png",
                save_plot=True, scale_expression=1)



In [None]:
# Single-dataset deep dive: choose a dataset/model and one GP pair from `pairs_dic`.
dataset = "pancreas"
model_name = "ivelo_filtered"
# Look the cell-type column up from the shared config instead of re-typing it,
# so this cell cannot drift out of sync with `datasets` / `cell_type_keys`.
cell_type_key = dict(zip(datasets, cell_type_keys))[dataset]
plot_type = ""  # forwarded as `plot_type` to plot_velocity_expression_gp in a later cell
# Index into the dataset's pair list
# (for pancreas: METABOLISM_OF_LIPIDS_AND_LIPOP vs METABOLISM_OF_CARBOHYDRATES).
pair_index = 2
gp1, gp2 = pairs_dic[dataset][pair_index]

adata = sc.read_h5ad(f"../../benchmark/{model_name}/{dataset}/adata_gp.h5ad")

In [None]:
# Echo the selected gene-program pair as the cell output.
(gp1, gp2)

In [None]:
# Output folder for the velocity/expression-vs-time figures of a single GP.
out_dir = f"plots/velo_activation/{dataset}/{model_name}/"
os.makedirs(out_dir, exist_ok=True)

# Copy the `<cell_type_key>_colors` palette from the full model output into the GP AnnData.
if model_name in ("ivelo", "ivelo_filtered"):
    colors_source = sc.read_h5ad(f"../../benchmark/{model_name}/{dataset}/{model_name}_{dataset}.h5ad")
    colors_key = f"{cell_type_key}_colors"
    adata.uns[colors_key] = colors_source.uns[colors_key].copy()
    del colors_source

gp_used = gp1
# Render once colored by cell type and once without, saving each variant separately.
for use_cell_type_colors in (True, False):
    plot_velocity_expression_gp(
        adata,
        scale_velocity=20,
        shift_expression=3,
        gene_name=gp_used,
        plot_type=plot_type,
        use_cell_type_colors=use_cell_type_colors,
        cell_type_key=cell_type_key,
        save_path=f"{out_dir}{gp_used}_colors_{use_cell_type_colors}.png",
        save_plot=True,
        plot_shift=True,
        reverse_pseudotime=False,
    )

In [None]:
os.makedirs(f"plots/two_gp_vs_time/{dataset}/{model_name}/", exist_ok=True)

# Use a dedicated name rather than reassigning the notebook-level `plot_type`
# (set to "" in the setup cell). Reassigning it meant that re-running the
# velocity-activation cell after this one passed a different `plot_type`.
two_gp_plot_type = "activation"
for use_cell_type_colors in [True, False]:
    plot_activation_or_velocity_two_gene_programs(adata,
                            gene_name1=gp1,
                            gene_name2=gp2,
                            plot_type=two_gp_plot_type,
                            use_cell_type_colors=use_cell_type_colors,
                            scale_expr1=1,
                            shift_expr1=0,
                            scale_expr2=1,
                            shift_expr2=0,
                            # legend_loc="upper right",  # NOTE(review): previously commented out; presumably overrides legend placement
                            cell_type_key=cell_type_key,
                            save_path=f"plots/two_gp_vs_time/{dataset}/{model_name}/{gp1}_{gp2}_{two_gp_plot_type}_colors_{use_cell_type_colors}.png",
                            save_plot=True,
                            reverse_pseudotime=False)

In [None]:
def print_top_N_gps_per_cell_type(adata, N: int, cell_type_key=None, dataset=None) -> None:
    """Print the N most and least activated gene programs for every cell type.

    For each cell type, activation is the mean of ``adata.X`` over that type's cells.
    Gene-program labels come from ``adata.uns["terms"]``, which is assumed to follow the
    same order as the columns of ``adata.X`` (TODO confirm).

    Parameters
    ----------
    adata : AnnData
        Object whose ``X`` (dense ndarray or scipy sparse matrix) holds per-cell
        gene-program activations, with cell-type labels in ``adata.obs``.
    N : int
        Number of programs to list in each direction. ``N <= 0`` lists none.
    cell_type_key : str, optional
        ``adata.obs`` column with cell-type labels. Defaults to the notebook-level
        ``cell_type_key`` global, which the original version always read.
    dataset : str, optional
        Dataset name shown in the headings. Defaults to the notebook-level
        ``dataset`` global.
    """
    # Backward compatibility: the original implementation implicitly read these globals.
    if cell_type_key is None:
        cell_type_key = globals()["cell_type_key"]
    if dataset is None:
        dataset = globals()["dataset"]

    # Clamp so N=0 lists nothing (argsort(...)[-0:] would have returned every program).
    n_top = max(int(N), 0)

    cell_types = adata.obs[cell_type_key]
    labels = cell_types.to_numpy()
    gene_program_names = np.array(adata.uns["terms"])

    for cell_type in cell_types.unique():
        cell_type_mask = labels == cell_type
        # np.asarray(...).ravel() flattens the (1, n_programs) np.matrix a sparse X returns.
        mean_activation = np.asarray(adata.X[cell_type_mask, :].mean(axis=0)).ravel()

        order = np.argsort(mean_activation)
        top_activated_idx = order[::-1][:n_top]  # highest mean first
        top_inactivated_idx = order[:n_top]      # lowest mean first

        print(f"Top {n_top} activated gene programs for cell type: {cell_type} in dataset: {dataset}")
        for idx in top_activated_idx:
            print(f"{gene_program_names[idx]}: {mean_activation[idx]}")

        print(f"Top {n_top} inactivated gene programs for cell type: {cell_type} in dataset: {dataset}")
        for idx in top_inactivated_idx:
            print(f"{gene_program_names[idx]}: {mean_activation[idx]}")

        print("-" * 50)
print_top_N_gps_per_cell_type(adata, N=10)  # 10 most / least activated GPs per cell type

In [None]:
# Display the loaded AnnData object as the cell output.
adata

In [None]:
def print_top_N_gps_overall(adata, N: int) -> None:
    """Print the N gene programs with the highest and lowest mean activation over all cells.

    Gene-program labels come from ``adata.uns["terms"]``, which is assumed to follow the
    same order as the columns of ``adata.X`` (TODO confirm).

    Parameters
    ----------
    adata : AnnData
        Object whose ``X`` (dense ndarray or scipy sparse matrix) holds per-cell
        gene-program activations.
    N : int
        Number of programs to list in each direction. ``N <= 0`` lists none.
    """
    # Clamp so N=0 lists nothing (argsort(...)[-0:] would have returned every program).
    n_top = max(int(N), 0)

    gene_program_names = np.array(adata.uns["terms"])

    # np.asarray(...).ravel() flattens the (1, n_programs) np.matrix a sparse X returns.
    mean_activation = np.asarray(adata.X.mean(axis=0)).ravel()

    order = np.argsort(mean_activation)
    top_N_activated_idx = order[::-1][:n_top]  # highest mean first
    top_N_inactivated_idx = order[:n_top]      # lowest mean first

    print(f"Top {n_top} activated gene programs in the entire dataset:")
    for idx in top_N_activated_idx:
        print(f"{gene_program_names[idx]}: {mean_activation[idx]}")

    print(f"Top {n_top} inactivated gene programs in the entire dataset:")
    for idx in top_N_inactivated_idx:
        print(f"{gene_program_names[idx]}: {mean_activation[idx]}")

    print("-" * 50)

print_top_N_gps_overall(adata, N=10)  # 10 most / least activated GPs across all cells