In [None]:
### Import Libraries.

import os
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import scanpy as sc
import seaborn as sns
import anndata
from scipy import sparse
from sklearn.decomposition import TruncatedSVD
import celloracle as co

# Global matplotlib defaults: moderate on-screen figure size, high-resolution saved files.
plt.rcParams.update({
    "figure.figsize": [6, 4.5],
    "savefig.dpi": 800,
})

In [None]:
### Set Directories.
# Defaults are the original placeholder paths; set CELLORACLE_RESULTS_DIR /
# CELLORACLE_DATA_DIR in the environment to run elsewhere without editing this cell.

results_dir = os.environ.get("CELLORACLE_RESULTS_DIR", "/Results_Folder/")
os.makedirs(results_dir, exist_ok = True)
# Input directory: only read from, so it is not created here.
data_dir = os.environ.get("CELLORACLE_DATA_DIR", "/Data_Folder/")

# Plotting cells below write their figures into this subfolder of results_dir.
save_folder = os.path.join(results_dir, "Figures")
os.makedirs(save_folder, exist_ok = True)

In [None]:
### Load Data.

adata_path = os.path.join(data_dir, "adata.h5ad")
adata = anndata.read_h5ad(adata_path)

# Quick overview of what the file contains.
print(f"Cell number: {adata.n_obs}, Gene number: {adata.n_vars}")
print("Metadata columns:", adata.obs.columns.tolist())
print("Dimensional reduction:", list(adata.obsm))

# Put the "counts_RNA" layer into X; the next import step is CellOracle's *raw count* importer.
adata.X = adata.layers["counts_RNA"].copy()

In [None]:
### Load Base GRN and Initialize Oracle.

# Empty Oracle object; expression data and the base GRN are attached in the next cells.
oracle = co.Oracle()
# Human promoter-based base GRN bundled with CellOracle.
base_GRN = co.data.load_human_promoter_base_GRN()

In [None]:
### Import Data into CellOracle.

# Initial import of the full gene set; the TF-dictionary gene lists read by the
# next cell (all_target_genes_in_TFdict / all_regulatory_genes_in_TFdict) come from here.
oracle.import_anndata_as_raw_count(
    embedding_name = "X_umap",
    cluster_column_name = "Cluster_Column",
    adata = adata,
)
oracle.import_TF_data(
    TF_info_matrix = base_GRN,
)

In [None]:
### Filter Genes for GRN.

# Keep only genes that occur in the base GRN, either as a target or as a regulator.
tf_genes = set(oracle.all_target_genes_in_TFdict).union(set(oracle.all_regulatory_genes_in_TFdict))
common_genes = tf_genes.intersection(set(adata.var_names))
print(f"Number of common genes: {len(common_genes)}")

# An empty overlap would otherwise silently produce a zero-gene AnnData and fail
# much later with an obscure error; fail here with an actionable message instead.
if not common_genes:
    raise ValueError(
        "No genes overlap between the base GRN and adata.var_names; "
        "check that adata.var_names use the same gene identifiers as the base GRN."
    )

adata = adata[:, sorted(common_genes)].copy()
# Defensively reset X to the raw-count layer in case an earlier step modified it.
adata.X = adata.layers["counts_RNA"].copy()

In [None]:
### Re-import Filtered Data.

# Replace the oracle's data with the GRN-gene subset and re-attach the base GRN.
oracle.import_anndata_as_raw_count(
    embedding_name = "X_umap",
    cluster_column_name = "Cluster_Column",
    adata = adata,
)
oracle.import_TF_data(
    TF_info_matrix = base_GRN,
)

In [None]:
### Dimensionality Reduction (TruncatedSVD) for Large Datasets.
# TruncatedSVD accepts scipy sparse input directly, so the count matrix is no longer
# densified: a dense copy of a large cells x genes matrix can exhaust memory.
# NOTE(review): this runs on raw counts (X was reset to the "counts_RNA" layer), and
# TruncatedSVD does not mean-center, unlike a regular PCA; confirm this is intended.

X = adata.X
if sparse.issparse(X) and X.dtype.kind != "f":
    # Integer sparse counts -> float64 so the SVD and variance computations run in
    # floating point (the original dense path also ended up computing in float64).
    X = X.astype(np.float64)

# The default (randomized) solver requires n_components <= n_features; clamp so
# a small post-filter gene set (or cell count) does not make fitting fail.
n_svd_components = min(100, min(X.shape))
svd = TruncatedSVD(n_components = n_svd_components, random_state = 0)
pcs = svd.fit_transform(X)

# Store the fitted model and scores on the oracle in place of oracle.perform_PCA();
# the next cell reads oracle.pca.explained_variance_ratio_.
oracle.pca = svd
oracle.pcs = pcs

In [None]:
### Determine Number of PCs.
# Follows the CellOracle tutorial's PC-selection recipe:
# 1. Threshold the per-component explained-variance ratio (the first difference of
#    the cumulative curve) at 0.002.
# 2. Take the first index where that boolean flips as the elbow.
# The threshold must be applied BEFORE the outer np.diff. Applying it to the second
# difference tests for a component explaining more than its predecessor, which never
# happens for a decreasing spectrum, so np.where(...)[0][0] raised IndexError.

cum_var_ratio = np.cumsum(oracle.pca.explained_variance_ratio_)
above_threshold = np.diff(cum_var_ratio) > 0.002
# np.diff on a boolean array is True exactly where consecutive values differ.
flip_idx = np.where(np.diff(above_threshold))[0]
# If the curve never crosses the threshold, fall back to all computed components.
n_comps_elbow = int(flip_idx[0]) if flip_idx.size > 0 else len(cum_var_ratio)
# Cap at 50 as before, and never hand 0 dimensions to knn_imputation.
n_comps = max(1, min(n_comps_elbow, 50))

fig, ax = plt.subplots()
ax.plot(cum_var_ratio[:100])
ax.axvline(n_comps, color = "k")
ax.set_xlabel("Component index")
ax.set_ylabel("Cumulative explained variance ratio")
ax.set_title(f"Selected number of components: {n_comps}")
plt.show()

In [None]:
### KNN Imputation.
# k is 2.5% of the cell count, floored at 1 so that small datasets (< 40 cells)
# do not end up with k = 0 (which would also zero b_sight and b_maxl).

n_cell = adata.shape[0]
k = max(1, int(0.025 * n_cell))
oracle.knn_imputation(
    n_pca_dims = n_comps,
    k = k,
    balanced = True,
    b_sight = k*8,
    b_maxl = k*4,
    # Keep the original upper bound of 45 workers, but never request more than this machine has.
    n_jobs = min(45, os.cpu_count() or 1),
)

In [None]:
### Save Oracle object.
oracle_path = os.path.join(results_dir, "adata.celloracle.oracle")
oracle.to_hdf5(oracle_path)

In [None]:
### GRN Construction and Analysis.

# One GRN unit per group of the "Cluster_Column" annotation.
# alpha = 10 is the CellOracle regularization setting used throughout this analysis.
links = oracle.get_links(
    cluster_name_for_GRN_unit = "Cluster_Column",
    alpha = 10,
    verbose_level = 10,
)

In [None]:
### Filter Links and Calculate Network Metrics.

# p-value cutoff 0.001, links weighted by absolute coefficient; threshold_number caps
# how many links are kept (presumably per cluster — confirm in the CellOracle docs).
links.filter_links(
    p = 0.001,
    weight = "coef_abs",
    threshold_number = 2000,
)
links.plot_degree_distributions(plot_model = True)
links.get_network_score()

links_path = os.path.join(results_dir, "adata.celloracle.links")
links.to_hdf5(links_path)

In [None]:
### Plots.

clusters_to_compare = ["Cluster_1", "Cluster_2"]

# Rank plot for the first cluster of the comparison pair. Tied to clusters_to_compare
# so that editing the pair above updates every plot in this cell.
links.plot_scores_as_rank(cluster = clusters_to_compare[0], n_gene = 30, save = save_folder)

# One 2D comparison per centrality metric (replaces three copy-pasted calls; same call order).
for centrality in ("eigenvector_centrality", "betweenness_centrality", "degree_centrality_all"):
    links.plot_score_comparison_2D(centrality, *clusters_to_compare, percentile = 98, save = save_folder)

In [None]:
### Save Cluster-specific Filtered Links.

# Cluster labels are used verbatim in file names. Replace characters that are invalid
# in file names on common filesystems: "/" in a label (e.g. "CD4+/CD8+") would
# otherwise be treated as a non-existent sub-directory and the write would fail.
_unsafe_filename_chars = str.maketrans({c: "_" for c in '<>:"/\\|?*'})

for cluster_name, df in links.filtered_links.items():
    safe_name = str(cluster_name).translate(_unsafe_filename_chars)
    file_name = os.path.join(results_dir, f"Filtered_CellOracle_{safe_name}.xlsx")
    df.to_excel(file_name, index=False)
    print(f"Saved {file_name}")

In [None]:
### In-silico Perturbation.

goi = "TF_1"
# The gene set was restricted to base-GRN genes earlier; fail early with a clear
# message if the gene of interest did not survive that filter.
if goi not in oracle.adata.var_names:
    raise ValueError(f"{goi} is not among the genes kept in oracle.adata; choose a gene retained by the GRN filter.")

# The oracle was imported with embedding_name = "X_umap", so plot on that basis.
# sc.pl.draw_graph would require a draw_graph embedding that this data may not contain.
sc.pl.embedding(
    oracle.adata,
    basis = "X_umap",
    color = [goi, oracle.cluster_column_name],
    layer = "imputed_count",
    use_raw = False,
    cmap = "viridis",
)

# Fit the cluster-specific GRN models that simulate_shift propagates the perturbation
# through. Without this, the simulation has no fitted coefficients.
# alpha = 10 matches the value passed to oracle.get_links.
oracle.get_cluster_specific_TFdict_from_Links(links_object = links)
oracle.fit_GRN_for_simulation(alpha = 10, use_cluster_specific_TFdict = True)

# Set gene expression to 0 for KO simulation
oracle.simulate_shift(perturb_condition = {goi: 0}, n_propagation = 3)
oracle.estimate_transition_prob(n_neighbors = 200, knn_random = True, sampled_fraction = 1)
oracle.calculate_embedding_shift(sigma_corr = 0.05)

In [None]:
### Vector Field Visualization.

# Left panel: simulated KO shift vectors. Right panel: randomized-simulation vectors as a baseline.
fig, axes = plt.subplots(1, 2, figsize = [13, 6])
ax_ko, ax_random = axes

oracle.plot_quiver(scale = 25, ax = ax_ko)
ax_ko.set_title(f"Simulated cell identity shift vector: {goi} KO")

oracle.plot_quiver_random(scale = 25, ax = ax_random)
ax_random.set_title("Randomized simulation vector")

plt.show()

In [None]:
### Grid-based Simulation & Mass Filtering

# Grid-based density (p_mass) estimation over the embedding.
oracle.calculate_p_mass(
    n_grid = 50,
    n_neighbors = 200,
    smooth = 0.8,
)
# Displays candidate thresholds; min_mass below is meant to be picked from these.
oracle.suggest_mass_thresholds(n_suggestion = 12)
# NOTE(review): 2.3 is dataset-specific; re-check it against the suggestions whenever the input changes.
oracle.calculate_mass_filter(
    min_mass = 2.3,
    plot = True,
)

In [None]:
### Cluster-based Simulation Visualization.

fig, ax = plt.subplots(figsize = [8, 8])
# Background layer: all cells drawn by plot_cluster_whole (point size 10).
oracle.plot_cluster_whole(ax = ax, s = 10)
# Grid-level simulated flow on top. Its own background is disabled because the
# cluster scatter above already provides one.
oracle.plot_simulation_flow_on_grid(scale = 10, ax = ax, show_background = False)
ax.set_title(f"Simulated flow on grid: {goi} KO")
# Explicit show, consistent with the other plotting cells.
plt.show()

In [None]:
### Markov Chain Simulation & Value Distribution

# Markov chain simulation with CellOracle's default settings,
# then plot the simulated value distributions for 50 genes.
oracle.run_markov_chain_simulation()
oracle.evaluate_and_plot_simulation_value_distribution(
    n_genes = 50,
)