In [None]:
import pandas as pd
import scanpy as sc
import numpy as np
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
#%matplotlib inline
import seaborn as sns
import os, sys, shutil, importlib, glob
from tqdm import tqdm_notebook as tqdm
#%config InlineBackend.figure_format = 'retina'
# Default size applied to every figure created after this point.
plt.rcParams['figure.figsize'] = (15,7)
# Resolution used whenever a figure is written to disk with savefig.
plt.rcParams["savefig.dpi"] = 600
from celloracle import motif_analysis as ma
from celloracle.utility import save_as_pickled_object
import celloracle as co
import subprocess
import matplotlib

In [None]:
# 01 Sketched cells for Network Analysis  https://github.com/brianhie/geosketch
# Subsample the trophoblast cells with geosketch on the PCA embedding so the
# downstream network analysis runs on a fixed-size set of cells.
from geosketch import gs

adata = sc.read("/path/to/trophoblast.h5ad")  ## load the subset of trophoblast h5ad

N = 5000          # requested number of sketched cells
SKETCH_SEED = 0   # fixed seed so the same cells are selected on every run

X_dimred = adata.obsm['X_pca']

# Sampling without replacement cannot return more cells than exist, so cap
# the request at the number of available cells.
n_sketch = min(N, X_dimred.shape[0])
sketch_index = list(gs(X_dimred, n_sketch, seed=SKETCH_SEED, replace=False))
X_sketch = X_dimred[sketch_index]

# .copy() turns the AnnData view produced by slicing into an independent
# object, so later steps do not operate on a view of the full dataset.
adata = adata[adata.obs_names[sketch_index], :].copy()

In [None]:
#02.TSS annotation  refer to https://github.com/morris-lab/CellOracle/blob/master/docs/notebooks/01_ATAC-seq_data_processing/option1_scATAC-seq_data_analysis_with_cicero/02_preprocess_peak_data.ipynb
# Annotate snATAC peaks with TSS information, merge them with the Cicero
# co-accessibility table and keep the strongly co-accessible peak/gene pairs.
wdir = "/path/to/working/directory/"
os.chdir(wdir)

# All snATAC peaks called within trophoblasts (first CSV column is the index).
peaks = pd.read_csv(
    "/path/to/cicero/trophoblast_all_peaks.csv",
    index_col=0,
)
# Cicero connections table (first CSV column is the index).
cicero_connections = pd.read_csv(
    "/path/to/cicero_conns.csv",
    index_col=0,
)

# Take the first data column and swap "-" for "_" in every peak name; from
# here on `peaks` is a plain array of peak-name strings, not a DataFrame.
# NOTE(review): only the peak list is rewritten here, not cicero_connections.
# Confirm both files use the same separator convention.
peaks = peaks.iloc[:, 0].str.replace("-", "_").values

ref_genome = "hg38"
tss_annotated = ma.get_tss_info(
    peak_str_list=peaks,
    ref_genome=ref_genome,
)
integrated = ma.integrate_tss_peak_with_cicero(
    tss_peak=tss_annotated,
    cicero_connections=cicero_connections,
)

# Keep rows with coaccess >= 0.8, reduce to the two columns used downstream
# and renumber the rows from zero.
peak = integrated[integrated["coaccess"] >= 0.8]
peak = peak.loc[:, ["peak_id", "gene_short_name"]].reset_index(drop=True)
peak.to_csv("/path/to/processed_peak_file.csv")

In [None]:
#02 motif_scan
def decompose_chrstr(peak_str):
    """Split a peak name of the form ``<chr>_<start>_<end>`` into its parts.

    The last two underscore-separated fields are taken as start and end;
    everything before them is re-joined with underscores, so chromosome
    names that themselves contain underscores stay intact.

    Returns:
        tuple: ``(chr, start, end)``, all as strings.

    Raises:
        ValueError: if the name has fewer than two underscore-separated fields.
    """
    fields = peak_str.split("_")
    # Unpacking the trailing pair raises ValueError when a field is missing.
    start, end = fields[-2:]
    chrom = "_".join(fields[:-2])
    return chrom, start, end

from genomepy import Genome

def check_peak_format(peaks_df, ref_genome):
    """Drop peaks that have an unknown chromosome name or are too short.

    A peak is kept only when its chromosome name is one of the keys of
    ``Genome(ref_genome)`` and ``|end - start|`` is at least 5. A summary
    of how many peaks were removed is printed.

    Parameters
    ----------
    peaks_df : pandas.DataFrame
        Table with a ``peak_id`` column whose values look like
        ``<chr>_<start>_<end>`` (parsed with ``decompose_chrstr``).
    ref_genome : str
        Genome name handed to ``genomepy.Genome``.

    Returns
    -------
    pandas.DataFrame
        The rows of ``peaks_df`` that pass both filters, with their original
        index and columns.
    """
    n_threshold = 5  # minimum accepted |end - start|

    df = peaks_df.copy()
    n_peaks_before = df.shape[0]

    # Decompose peak names into chr / start / end. Building the frame from
    # the list with explicit column names also works for an empty table.
    decomposed = [decompose_chrstr(peak_str) for peak_str in df["peak_id"]]
    df_decomposed = pd.DataFrame(decomposed, columns=["chr", "start", "end"])
    # The builtin int replaces the np.int alias, which NumPy 1.24 removed.
    df_decomposed["start"] = df_decomposed["start"].astype(int)
    df_decomposed["end"] = df_decomposed["end"].astype(int)

    # Load genome data
    genome_data = Genome(ref_genome)
    all_chr_list = list(genome_data.keys())

    # Per-peak validity flags
    lengths = np.abs(df_decomposed["end"] - df_decomposed["start"])
    has_valid_length = lengths >= n_threshold
    has_valid_chr = df_decomposed["chr"].isin(all_chr_list)

    # Filter by position: df_decomposed carries a fresh 0..n-1 index that
    # need not match peaks_df's index, so use a plain boolean array.
    keep = (has_valid_length & has_valid_chr).to_numpy()
    df = df[keep]

    # Data counting
    n_invalid_length = int((~has_valid_length).sum())
    n_peaks_invalid_chr = n_peaks_before - has_valid_chr.sum()
    n_peaks_after = df.shape[0]

    print("Peaks before filtering: ", n_peaks_before)
    print("Peaks with invalid chr_name: ", n_peaks_invalid_chr)
    print("Peaks with invalid length: ", n_invalid_length)
    print("Peaks after filtering: ", n_peaks_after)

    return df


# Make sure the reference genome is available locally before scanning.
ref_genome = "hg38"
genome_installation = ma.is_genome_installed(ref_genome=ref_genome)
print(ref_genome, "installation: ", genome_installation)
if not genome_installation:
    import genomepy
    genomepy.install_genome(ref_genome, "UCSC")
else:
    print(ref_genome, "is installed.")

# The TSS-annotation step rebinds `peaks` to a plain array of peak-name
# strings, which has no "peak_id" column. The peak_id / gene_short_name
# table that check_peak_format expects is `peak`, so filter that one.
peaks = check_peak_format(peak, ref_genome)

#load motifs
from gimmemotifs.motif import default_motifs
motifs = default_motifs()

# Scan the filtered peaks for TF binding motifs and save the scan result.
tfi = ma.TFinfo(peak_data_frame=peaks, ref_genome=ref_genome)

tfi.scan(fpr=0.02, motifs=motifs, verbose=True)  # long step
tfi.to_hdf5(file_path="celloracle.tfinfo")

# Reset any earlier filtering, then apply the motif score threshold of 10.
tfi.reset_filtering()
tfi.filter_motifs_by_score(threshold=10)
# Do post filtering process. Convert results into several file format.
tfi.make_TFinfo_dataframe_and_dictionary(verbose=True)
df = tfi.to_dataframe()
df.to_parquet("base_GRN_dataframe.parquet")


In [None]:
## 03  Make celloracle object
# Build the Oracle from the sketched AnnData (clusters taken from the
# "minor_class" column, embedding from "X_umap") and attach the TF
# information produced by the motif scan.
oracle = co.Oracle()
oracle.import_anndata_as_normalized_count(adata=adata,cluster_column_name="minor_class",
                                          embedding_name="X_umap")
oracle.import_TF_data(TF_info_matrix=df)

# NOTE(review): `df` and this dictionary both come from the same `tfi`
# object, so this second registration may be redundant - confirm it is intended.
TG_to_TF_dictionary = tfi.to_dictionary(dictionary_type="targetgene2TFs")
oracle.addTFinfo_dictionary(TG_to_TF_dictionary)

# Perform PCA
oracle.perform_PCA()

# Select important PCs
# Plot the cumulative explained variance of the first 100 components.
plt.plot(np.cumsum(oracle.pca.explained_variance_ratio_)[:100])
# np.diff of the cumulative sum recovers the per-component variance ratios;
# the outer np.diff on the boolean mask marks positions where the "> 0.002"
# test changes value, and the first such position is used as the cut-off.
# Raises IndexError if the test never changes value.
n_comps = np.where(np.diff(np.diff(np.cumsum(oracle.pca.explained_variance_ratio_))>0.002))[0][0]
plt.axvline(n_comps, c="k")
# NOTE(review): the "Agg" backend is selected at the top of the file, so this
# call presumably displays nothing; save the figure if it is needed.
plt.show()
print(n_comps)
# Never use more than 50 components.
n_comps = min(n_comps, 50)
n_cell = oracle.adata.shape[0]
# Neighbourhood size is 2.5% of the number of cells, truncated to an integer.
k = int(0.025*n_cell)
oracle.knn_imputation(n_pca_dims=n_comps, k=k, balanced=True, b_sight=k*8,b_maxl=k*4, n_jobs=4)
# Save the processed Oracle object to disk.
oracle.to_hdf5("tp.celloracle.oracle")

In [None]:
##04 GRN inference
# Infer one GRN per "minor_class" cluster and save the unfiltered result.
links = oracle.get_links(cluster_name_for_GRN_unit="minor_class", alpha=10,
                         verbose_level=10)
links.to_hdf5("links.celloracle.links")

# Filter the links (p=0.05, ranked by "coef_abs", at most 10000 kept).
links.filter_links(p=0.05, weight="coef_abs", threshold_number=10000)
# Calculate network scores.
links.get_network_score()
# Save Links object.
links.to_hdf5(file_path="filtered_links.celloracle.links")

## export the network for cytoscape visualization
# NOTE(review): this exports `links.links_dict`; if the filtered network is
# what should go to Cytoscape, confirm whether `links.filtered_links` is the
# intended source instead.
output_dir = "/path/to/output/"
# Create the target directory up front so to_csv does not fail on a fresh run.
os.makedirs(output_dir, exist_ok=True)
for cluster_name, cluster_links in links.links_dict.items():
    # str() keeps the export working when cluster labels are not strings.
    out_path = os.path.join(output_dir, str(cluster_name) + '.tsv')
    cluster_links.to_csv(out_path, sep="\t", index=True)
