# 0. Import libraries

In [None]:
import os
import sys

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import scanpy as sc
import seaborn as sns

In [None]:
import celloracle as co
co.__version__

In [None]:
# visualization settings
%config InlineBackend.figure_format = 'retina'
%matplotlib inline

plt.rcParams['figure.figsize'] = [6, 4.5]
plt.rcParams["savefig.dpi"] = 300

# 1. Load data

In [None]:
# Read the preprocessed AnnData object.
# NOTE(review): absolute, user-specific path - adjust it when running on another machine.
adata = sc.read_h5ad('/home/jolien/Notebooks/data/preprocessed_data_day14.h5ad')


### Action: Configure folder names, depending on selection of the data

In [None]:
# Root folder of this GRN-inference run - define/call a new folder for each trial
current_trial = "day14_analysis"

# Sub-folders for the saved data objects and for the figures
save_folder_data = f"{current_trial}/data"
save_folder_figures = f"{current_trial}/figures"

# Create the folder tree; folders that already exist are left as they are
for folder in (current_trial, save_folder_data, save_folder_figures):
    os.makedirs(folder, exist_ok=True)

In [None]:
# Report the dimensions of the loaded dataset (first axis: cells, second axis: genes)
n_obs, n_vars = adata.shape
print(f"Cell number is :{n_obs}")
print(f"Gene number is :{n_vars}")

In [None]:
# Randomly downsample to 30K cells if the AnnData object includes more than 30K cells.
n_cells_downsample = 30000
if adata.shape[0] > n_cells_downsample:
    # Subsample with a fixed random_state so that the same cells are selected on every run
    sc.pp.subsample(adata, n_obs=n_cells_downsample, random_state=123)

In [None]:
# Cell count after the (conditional) downsampling step
print(f"Cell number is :{adata.shape[0]}")

In [None]:
# Load the TF info (base GRN) through CellOracle's human promoter base-GRN loader.
base_GRN = co.data.load_human_promoter_base_GRN()

# Check data
base_GRN.head()

# 2. Make Oracle object

In [None]:
# Instantiate an Oracle object; the expression data and the TF info are imported into it afterwards
oracle = co.Oracle()

In [None]:
# Check data in anndata: the metadata columns and embeddings listed here are the
# candidates for cluster_column_name and embedding_name when importing into Oracle
print("Metadata columns :", list(adata.obs.columns))
print("Dimensional reduction: ", list(adata.obsm.keys()))

In [None]:
# In this notebook, we use the unscaled mRNA count for the input of Oracle object.
# adata.X is overwritten with a copy of the "raw_count" layer.
adata.X = adata.layers["raw_count"].copy()

# Import the AnnData into the Oracle object, using "sample_type" as the cluster
# column and "X_draw_graph_fa" as the embedding.
oracle.import_anndata_as_raw_count(adata=adata,
                                   cluster_column_name="sample_type",
                                   embedding_name="X_draw_graph_fa")

In [None]:
# Load the TF info dataframe (base GRN) into the Oracle object.
oracle.import_TF_data(TF_info_matrix=base_GRN)

# Alternatively, if you saved the information as a dictionary, you can use the code below.
# oracle.import_TF_data(TFdict=TFinfo_dictionary)

# 3. KNN imputation

In [None]:
# Perform PCA
oracle.perform_PCA()

# Select important PCs
plt.plot(np.cumsum(oracle.pca.explained_variance_ratio_)[:100])
# np.diff(np.cumsum(r)) gives back the per-component explained variance ratio (from the
# second component on). Comparing it with 0.002 yields a boolean array, and the outer
# np.diff marks the positions where that boolean flips. n_comps is the first such
# position, i.e. where the per-component explained variance first crosses the threshold.
# NOTE(review): [0][0] raises IndexError if the ratio never crosses 0.002.
n_comps = np.where(np.diff(np.diff(np.cumsum(oracle.pca.explained_variance_ratio_))>0.002))[0][0]
plt.axvline(n_comps, c="k")
plt.show()
print(n_comps)
# Cap the number of PCs at 50 (the plot and the printout show the uncapped value)
n_comps = min(n_comps, 50)

In [None]:
# Number of cells held by the Oracle object
n_cell = oracle.adata.shape[0]
print(f"cell number is :{n_cell}")

In [None]:
# Number of neighbours for the KNN imputation: 2.5% of the cell count, truncated to an integer
k = int(0.025*n_cell)
print(f"Auto-selected k is :{k}")

In [None]:
# KNN imputation on the selected number of PCs; b_sight and b_maxl are set to 8x and 4x k
oracle.knn_imputation(n_pca_dims=n_comps, k=k, balanced=True, b_sight=k*8,
                      b_maxl=k*4, n_jobs=4)

# 4. Save oracle object

In [None]:
# Save the oracle object into the data folder of the current trial.
oracle.to_hdf5(os.path.join(save_folder_data,"Persister_cells_day14.celloracle.oracle"))

In [None]:
# Load file.
# oracle = co.load_hdf5(os.path.join(save_folder_data,"Persister_cells_day14.celloracle.oracle"))

# 5. GRN calculation

In [None]:
# Check clustering data: draw_graph embedding coloured by the "sample_type" column
sc.pl.draw_graph(oracle.adata, color="sample_type")

In [None]:
# Calculate GRN for each population in "sample_type" clustering unit.
# This step may take some time (~30 minutes).
# The returned object exposes one GRN per cluster through links.links_dict.
links = oracle.get_links(cluster_name_for_GRN_unit="sample_type", alpha=10,
                         verbose_level=10)

Get specific GRN

In [None]:
# Get clusters: the keys of links_dict are the cluster names for which a GRN is stored
links.links_dict.keys()

In [None]:
# Get the GRN (table of edges) for a specific cluster
links.links_dict["Cycling"] # replace the string with one of the keys from the previous output

In [None]:
## Export file

# Set the name of the cluster whose raw GRN should be exported
cluster = "Cycling"

# Save as csv (disabled - uncomment the line below to write the file)
#links.links_dict[cluster].to_csv(f"raw_GRN_for_{cluster}.csv")

In [None]:
# Show the contents of the palette - this stores the color information that is used when
# visualizing the clusters. Here we can change both the cluster colors and their order.
links.palette

# # Change the order of the palette
# order = ['Cycling','Moderate_cyclers','Non-cycling']
# links.palette = links.palette.loc[order]
# links.palette

In [None]:
# Save the Links object (GRNs before any edge filtering) into the data folder.
links.to_hdf5(file_path=os.path.join(save_folder_data,"day14_links.celloracle.links"))

In [None]:
# Check statistics of interactions - how many edges of each cluster's raw GRN are significant
for cluster in links.cluster:
    print(cluster)
    # Total number of edges in the unfiltered GRN of this cluster
    nr_edges_total = len(links.links_dict[cluster])
    # Label the count with the current cluster instead of a hard-coded cluster name
    print(f'nr edges in unpreprocessed grn of {cluster}', nr_edges_total)
    # Significant edges (p < 0.001): summing the boolean mask counts the True entries
    mask_pvalue_edges = links.links_dict[cluster]['p'] < 0.001
    nr_edges_significant = int(mask_pvalue_edges.sum())
    # A GRN without edges has no defined percentage - report it instead of dividing by zero
    if nr_edges_total == 0:
        print("no edges found, percentage of significant edges is undefined")
        continue
    percentage_significant_edges = nr_edges_significant / nr_edges_total * 100
    print("{:.1f}% of the edges is significant".format(percentage_significant_edges))

# 6. Network preprocessing

We need to remove weak or insignificant edges before doing the network structure analysis.

1. Remove uncertain network edges based on the p-value.
2. Remove weak network edges. In this tutorial, we keep the top 2000 edges ranked by edge strength.

In [None]:
# Filter network edges: p-value cutoff of 0.001, edge strength taken from "coef_abs",
# and 2000 as the number of edges to keep
links.filter_links(p=0.001, weight="coef_abs", threshold_number=2000)

In [None]:
# Examine the network degree distribution, drawn on a wider figure
plt.rcParams["figure.figsize"] = [9, 4.5]

# The plots are saved into a dedicated sub-folder of the figures folder
degree_distribution_folder = f"{save_folder_figures}/degree_distribution/"
links.plot_degree_distributions(plot_model=True, save=degree_distribution_folder)

In [None]:
# Calculate network scores.
links.get_network_score()

# The scores are stored in the attribute merged_score; show the first rows.
links.merged_score.head()

In [None]:
# Save processed GRNs (Links object after edge filtering and network scoring).
links.to_hdf5(file_path= os.path.join(save_folder_data, "day14_preprocessed_links.celloracle.links"))

In [None]:
# You can load files with the following command.
# NOTE: this replaces the in-memory `links` with the object stored in the processed links file.
links = co.load_hdf5(file_path= os.path.join(save_folder_data, "day14_preprocessed_links.celloracle.links"))

In [None]:
## Export file

# Write the filtered GRN of every cluster to its own csv file in the data folder
for cluster in links.cluster:
    csv_path = os.path.join(save_folder_data, f"processed_GRN_for_{cluster}.csv")
    links.filtered_links[cluster].to_csv(csv_path)


# 7. Network analysis; Network score for each gene

In [None]:
# Report the range of edge strengths (coef_abs) in the filtered GRN of each cluster
for cluster in links.cluster:
    print(cluster)
    coef_abs_values = links.filtered_links[cluster]['coef_abs']
    print('min coef_abs', np.min(coef_abs_values))
    print('max coef_abs', np.max(coef_abs_values))

In [None]:
# fig, axes = plt.subplots(1,3, figsize=(20,5), sharey=True)

# axes[0].bar(links.filtered_links['Cycling']['source'], links.filtered_links['Cycling']['coef_abs']);
# axes[0].tick_params(rotation=90);

# axes[1].bar(links.filtered_links['Moderate_cyclers']['source'], links.filtered_links['Moderate_cyclers']['coef_abs']);
# axes[1].tick_params(rotation=90);

# axes[2].bar(links.filtered_links['Non-cycling']['source'], links.filtered_links['Non-cycling']['coef_abs']);
# axes[2].tick_params(rotation=90);

In [None]:
# Visualize the top 30 genes with high network scores in the "Cycling" cluster.
# Saving is disabled: the save argument is kept as a trailing comment.
links.plot_scores_as_rank(cluster="Cycling", n_gene=30)#, save=f"{save_folder_figures}/ranked_score")

In [None]:
# Compare GRN score (eigenvector centrality) between two clusters: Cycling vs Non-cycling
links.plot_score_comparison_2D(value="eigenvector_centrality",
                               cluster1="Cycling", cluster2="Non-cycling",
                               percentile=98,
                               save=f"{save_folder_figures}/score_comparison")

In [None]:
# Compare GRN score (eigenvector centrality) between two clusters: Cycling vs Moderate_cyclers
links.plot_score_comparison_2D(value="eigenvector_centrality",
                               cluster1="Cycling", cluster2="Moderate_cyclers",
                               percentile=98,
                               save=f"{save_folder_figures}/score_comparison")

In [None]:
# Compare GRN score (eigenvector centrality) between two clusters: Non-cycling vs Moderate_cyclers
links.plot_score_comparison_2D(value="eigenvector_centrality",
                               cluster1="Non-cycling", cluster2="Moderate_cyclers",
                               percentile=98,
                               save=f"{save_folder_figures}/score_comparison")

In [None]:
# Visualize the network scores of the gene of interest E2F1 per cluster
links.plot_score_per_cluster(goi="E2F1", save=f"{save_folder_figures}/network_score_per_gene/")

In [None]:
# Visualize the network scores of the gene of interest ATF3 per cluster
links.plot_score_per_cluster(goi="ATF3", save=f"{save_folder_figures}/network_score_per_gene/")

In [None]:
# Check the filtered network edges of one cluster (first rows only)
cluster_name = "Cycling"
filtered_links_df = links.filtered_links[cluster_name]
filtered_links_df.head()

# 8. Network analysis; network score distribution

In [None]:
# Visualize the network score distributions to get insight into the global network trends.
# Set the default figure size to 6 x 4.5 for the distribution plots.
plt.rcParams["figure.figsize"] = [6, 4.5]

In [None]:
# Plot degree_centrality
# Margins and y-range are set through pyplot before the plotting call.
# NOTE(review): the method name is spelled "discributions" - presumably this matches
# the library's API; verify before renaming.
plt.subplots_adjust(left=0.15, bottom=0.3)
plt.ylim([0,0.040])
links.plot_score_discributions(values=["degree_centrality_all"],
                               method="boxplot",
                               save=f"{save_folder_figures}",
                              )

In [None]:
# Plot eigenvector_centrality
# Margins and y-range are set through pyplot before the plotting call.
# NOTE(review): the method name is spelled "discributions" - presumably this matches
# the library's API; verify before renaming.
plt.subplots_adjust(left=0.15, bottom=0.3)
plt.ylim([0, 0.28])
links.plot_score_discributions(values=["eigenvector_centrality"],
                               method="boxplot",
                               save=f"{save_folder_figures}")


In [None]:
# Plot the network entropy distributions; the figures folder is passed as the save location
plt.subplots_adjust(left=0.15, bottom=0.3)
links.plot_network_entropy_distributions(save=f"{save_folder_figures}")