# Import libraries and data

In [5]:
import numpy as np
np.int = np.int32
from sklearn.preprocessing import MinMaxScaler
from sklearn.cluster import DBSCAN
from tmap.tda import mapper, Filter
from tmap.tda.cover import Cover
from tmap.tda.metric import Metric
from tmap.tda.utils import optimize_dbscan_eps

from scipy.spatial.distance import squareform,pdist
import pandas as pd

import networkx as nx

In [7]:
from pathlib import Path

# Resolve the project layout relative to the notebook's working directory:
# the notebook runs from a sub-directory of the project root.
code_dir = Path.cwd()
project_dir = code_dir.parent
input_dir, tmp_dir = project_dir / "input", project_dir / "tmp"
output_dir = project_dir / "output" / "tda_sensitivity_2"

# Create the output directory (and any missing parents) up front.
output_dir.mkdir(parents=True, exist_ok=True)

In [None]:
# Load the sample metadata table and the genus-level oral microbiome table.
# Both use their first column as the row index (presumably sample IDs).
metadata, oral_microbiome_genus = (
    pd.read_csv(input_dir / "data" / csv_name, index_col=0)
    for csv_name in ("metadata_df.csv", "microbiome_genus.csv")
)

In [13]:
# Align the genus abundance table (X) with the sample metadata:
# - keep only the samples present in both tables;
# - put `metadata` rows in X's row order.
# Row order matters downstream. Mapper node "sample" attributes are used as
# positional row indices (see transform2node_data), so `metadata` and X must
# list the same samples in the same order.
# reindex() raises if `metadata` has duplicate sample IDs, instead of
# silently duplicating rows.
# NOTE(review): `metadata_variables` (the metadata columns to keep) is not
# defined in the visible cells; it must come from an earlier cell.
X = oral_microbiome_genus.loc[oral_microbiome_genus.index.isin(metadata.index)]
metadata = metadata.reindex(X.index)[metadata_variables]

In [14]:
# Tag every feature with a category label:
# - metadata columns get the text before their first underscore (the whole
#   name if there is none);
# - microbiome columns are all labelled "genus".
# NOTE(review): neither list is used in the visible cells.
metadata_categories = [column.split("_")[0] for column in metadata.columns]
microbiome_categories = ["genus"] * X.shape[1]

# Mapper

In [15]:
def transform2node_data(graph, data, mode='mean'):
    """Aggregate per-sample rows of ``data`` into one row per graph node.

    Each node's ``'sample'`` attribute is used as a list of positional row
    indices into ``data``. The selected rows are reduced column-wise.

    Parameters
    ----------
    graph : graph-like
        Object whose ``nodes`` mapping yields ``(node_id, attrs)`` pairs via
        ``.items()``, with ``attrs['sample']`` holding positional row indices.
    data : pandas.DataFrame or None
        Per-sample table. Its row order must match the positions stored in
        the graph, i.e. the row order of the data the graph was built from.
    mode : {'mean', 'sum'}
        ``'mean'`` uses ``np.nanmean`` (NaNs ignored); ``'sum'`` uses ``np.sum``.

    Returns
    -------
    pandas.DataFrame or None
        Indexed by node id, with the columns of ``data``. ``None`` when
        ``data`` is ``None``.

    Raises
    ------
    SyntaxError
        If ``mode`` is not supported. SyntaxError is kept for backward
        compatibility; ``ValueError`` would be the idiomatic type.
    """
    if mode not in ('sum', 'mean'):
        raise SyntaxError(
            f"Wrong provided parameters: mode must be 'sum' or 'mean', got {mode!r}."
        )
    aggregate = {'sum': np.sum, 'mean': np.nanmean}[mode]

    # Check for None before touching `data.values`, which would otherwise
    # raise AttributeError.
    if data is None:
        return None

    values = data.values
    node_data = {nid: aggregate(values[attr['sample'], :], axis=0)
                 for nid, attr in graph.nodes.items()}
    return pd.DataFrame.from_dict(node_data,
                                  orient='index',
                                  columns=data.columns)

In [None]:
from safepy import safe
import itertools

from sklearn.cluster import KMeans

# Sensitivity analysis: vary ONE Mapper/SAFE parameter at a time (values
# below) while every other parameter stays at its baseline value. Each run:
# 1. builds a Mapper graph;
# 2. computes SAFE enrichment ratios per attribute;
# 3. splits the graph into KMeans clusters of node layout positions and
#    assigns subjects to them.
sensitivity_parameter_dict = {
    "cover_overlap" : [0.75,0.5,0.99],
    "cover_resolution" : [10,20,40,50],
    "mapper_eps_threshold" : [99,90,85],
    "mapper_lens": [Filter.MDS, Filter.UMAP],
    "safe_distance_thresh": [0.5,0.99],
    "safe_neighborhood_radius":[0.05,0.15],
}

# Baseline value of every parameter, used whenever it is not the one varied.
sensitivity_baseline_params = {
    "cover_resolution": 30,
    "cover_overlap": 0.75,
    "mapper_eps_threshold": 95,
    "mapper_lens": Filter.PCOA,
    "mapper_distance_metric": "braycurtis",
    "safe_distance_thresh": 0.75,
    "safe_neighborhood_radius": 0.1,
}

num_permutations = 5000  # SAFE permutations per run
n_clusters = 2           # KMeans clusters on the node layout (determine_cluster assumes 2)

sensitivity_result_dict_er = {}
sensitivity_result_dict_cluster = {}
distance_matrix_cache = {}  # distance metric name -> square distance matrix of X


def _param_label(value):
    """Return a printable label for a parameter value.

    Classes (the Mapper lenses) are labelled by their ``__name__`` so result
    keys and file names do not contain ``"<class '...'>"``. All other values
    use ``str(value)``.
    """
    return getattr(value, "__name__", str(value))


def determine_cluster(row):
    """Map a subject's per-cluster node counts (clusters 0 and 1) to a label.

    Returns -1 if the subject occurs in nodes of both clusters, 0 or 1 if it
    occurs only in that cluster, and NaN if it occurs in neither.
    """
    in_cluster_0 = row.get(0, 0) > 0
    in_cluster_1 = row.get(1, 0) > 0
    if in_cluster_0 and in_cluster_1:
        return -1
    if in_cluster_0:
        return 0
    if in_cluster_1:
        return 1
    return np.nan


for string, sensitivity_parameter_list in sensitivity_parameter_dict.items():
    if string not in sensitivity_baseline_params:
        raise KeyError(f"No baseline value for sensitivity parameter {string!r}")
    for sensitivity_parameter in sensitivity_parameter_list:
        # Baseline settings with the single varied parameter overridden.
        params = {**sensitivity_baseline_params, string: sensitivity_parameter}
        resolution = params["cover_resolution"]
        overlap = params["cover_overlap"]
        eps_threshold = params["mapper_eps_threshold"]
        mapper_lens = params["mapper_lens"]
        mapper_distance_metric = params["mapper_distance_metric"]
        safe_distance_thresh = params["safe_distance_thresh"]
        safe_neighborhood_radius = params["safe_neighborhood_radius"]

        # `param_label` keeps '.' and is used in result keys (e.g. "cover_overlap_0.75").
        # `run_tag` replaces '.' with 'p' and is used in file names.
        param_label = _param_label(sensitivity_parameter)
        run_tag = f"{string}_{param_label.replace('.', 'p')}"
        network_file = output_dir / f"{run_tag}_mapper_graph_3col.txt"
        attribute_file = output_dir / f"{run_tag}_mapper_graph_metadata.txt"

        ################
        #Mapper
        ################

        # TDA Step1. initiate a Mapper
        tm = mapper.Mapper(verbose=1)

        # TDA Step2. Projection of the precomputed distance matrix. The matrix
        # depends only on the distance metric, so it is computed once per metric.
        if mapper_distance_metric not in distance_matrix_cache:
            distance_matrix_cache[mapper_distance_metric] = squareform(
                pdist(X, metric=mapper_distance_metric))
        dm = distance_matrix_cache[mapper_distance_metric]
        metric = Metric(metric="precomputed")
        lens = [mapper_lens(components=[0, 1], metric=metric, random_state=100)]
        projected_X = tm.filter(dm, lens=lens)

        # TDA Step3. Covering, clustering & mapping
        eps = optimize_dbscan_eps(X, threshold=eps_threshold)
        clusterer = DBSCAN(eps=eps, min_samples=3)
        cover = Cover(projected_data=MinMaxScaler().fit_transform(projected_X),
                      resolution=resolution, overlap=overlap)
        graph = tm.map(data=X, cover=cover, clusterer=clusterer)
        print(graph.info())

        ################
        #SAFE
        ################

        # Refine the layout with a seeded spring layout started from the
        # graph's own node positions. Store the result on graph.nodePos and as
        # a "pos" attribute on every node.
        # NOTE(review): assumes node ids are 0..n-1, matching nodePos rows.
        initial_nodepos = {idx: graph.nodePos[idx] for idx in range(graph.nodePos.shape[0])}
        pos = nx.spring_layout(graph, k=0.2, pos=initial_nodepos, seed=42)
        graph.nodePos = np.array([pos[node] for node in pos])
        for node in graph.nodes:
            graph.nodes[node]["pos"] = pos[node].tolist()

        # Network file for SAFE: source, target (plus any edge attributes) and
        # a constant edge distance of 1, tab-separated, no header.
        edgelist_3col = nx.to_pandas_edgelist(graph)
        edgelist_3col["dist"] = 1
        edgelist_3col.to_csv(network_file, sep="\t", index=False, header=None)

        # Node attribute table for SAFE: per-node means of metadata and genus
        # abundances. Node "sample" attributes are positions into X (the data
        # the graph was built from), so the aligned X is aggregated here. This
        # assumes `metadata` rows are in the same order as X.
        metadata_transformed = transform2node_data(graph, metadata, mode="mean")
        oral_microbiome_genus_transformed = transform2node_data(graph, X, mode="mean")
        data_transformed = metadata_transformed.join(oral_microbiome_genus_transformed)
        data_transformed.to_csv(attribute_file, sep="\t", index=True)

        sf = safe.SAFE(path_to_safe_data=f"{output_dir}/safe_{run_tag}/")
        sf.random_seed = 0
        sf.attribute_distance_threshold = safe_distance_thresh
        sf.neighborhood_radius = safe_neighborhood_radius
        sf.load_network(network_file=str(network_file))
        sf.load_attributes(attribute_file=str(attribute_file))
        sf.define_neighborhoods()
        sf.compute_pvalues(num_permutations=num_permutations)

        # Enrichment ratio per attribute: number of enriched neighbourhoods
        # divided by the number of graph nodes.
        safe_summary = sf.attributes.drop("id", axis=1).set_index("name")
        er_key = f"{string}_{param_label}_enrichment_ratio"
        safe_summary[er_key] = safe_summary["num_neighborhoods_enriched"] / len(graph.nodes)
        sensitivity_result_dict_er[er_key] = safe_summary[er_key]

        ################
        #Clustering
        ################

        # Split the graph into `n_clusters` groups by KMeans on the 2-D node
        # layout positions.
        positions = pd.DataFrame(nx.get_node_attributes(graph, "pos")).T
        positions.columns = ["0", "1"]
        clustering = KMeans(n_clusters=n_clusters, random_state=42).fit(positions)
        positions["cluster"] = clustering.labels_

        # Subject x node membership matrix (1 = subject is a member of the node).
        node_subject_mapping_dict = {node: list(graph.nodes[node]["sample_names"])
                                     for node in graph.nodes}
        all_subject_indices = sorted(set(itertools.chain.from_iterable(node_subject_mapping_dict.values())))
        node_subject_df = pd.DataFrame(0, index=all_subject_indices, columns=list(graph.nodes))
        for node, subjects in node_subject_mapping_dict.items():
            node_subject_df.loc[subjects, node] = 1

        # Keep subjects in metadata order. Then count, per subject, how many
        # nodes of each KMeans cluster contain it, and derive a label.
        node_subject_df = node_subject_df.loc[metadata.index[metadata.index.isin(node_subject_df.index)]]
        subject_group_df = node_subject_df.T.join(positions["cluster"]).groupby("cluster").sum().T
        subject_group_df["cluster"] = subject_group_df.apply(determine_cluster, axis=1)

        sensitivity_result_dict_cluster[f"{string}_{param_label}_clustering"] = subject_group_df["cluster"]

In [19]:
# Collect per-run results into tables (one column per run). Column labels get
# '.' replaced by 'p' so they are safe to use as identifiers/file names.
sensitivity_result_df_er = pd.DataFrame(sensitivity_result_dict_er).rename(
    columns=lambda label: label.replace(".", "p")
)
sensitivity_result_df_cluster = pd.DataFrame(sensitivity_result_dict_cluster).rename(
    columns=lambda label: label.replace(".", "p")
)


In [20]:
# Reference results, presumably produced by the baseline TDA notebook
# (output/tda); the sensitivity runs are compared against these.
original_enrichment_ratio = pd.read_csv(
    project_dir / "output" / "tda" / "metadata_safe_summary.csv", index_col=0
)
original_clustering = pd.read_csv(
    project_dir / "output" / "tda" / "cluster_analysis" / "subject_clustering.csv",
    index_col=0,
)

In [21]:
# Start from the reference clustering and append one column per sensitivity
# run. Each join is a left join on the row index (presumably subject IDs).
cluster_robustness_df = original_clustering.copy()
for run_label, run_clusters in sensitivity_result_df_cluster.items():
    cluster_robustness_df = cluster_robustness_df.join(
        run_clusters, rsuffix=f"_{run_label}"
    )

In [22]:
from scipy.stats import spearmanr
from sklearn.metrics import adjusted_rand_score

spearman_er_dict = {}
ari_clustering_dict = {}

# Enrichment-ratio robustness: Spearman correlation between the reference
# enrichment ratios and each sensitivity run. spearmanr() compares raw arrays
# position by position, so both series are first aligned on their index.
# Attributes missing (or NaN) on either side are dropped.
# NOTE(review): assumes the reference CSV is indexed by SAFE attribute name,
# like the sensitivity results; confirm.
reference_er = original_enrichment_ratio["enrichment_ratio"].rename("reference")
for i in sensitivity_result_df_er.keys():
    paired = pd.concat(
        [reference_er, sensitivity_result_df_er[i].rename("run")],
        axis=1,
        join="inner",
    ).dropna()
    if len(paired) < 2:
        # A rank correlation is undefined for fewer than two shared attributes.
        rho = np.nan
    else:
        rho = spearmanr(paired["reference"], paired["run"])[0]
    spearman_er_dict[f"sensitivity_{i}_spearman"] = rho

# Clustering robustness: adjusted Rand index between the reference subject
# clustering and each sensitivity run. Only subjects with a non-missing label
# in both are used.
for i in sensitivity_result_df_cluster.keys():
    cluster_robustness_i_df = cluster_robustness_df[["cluster", i]].dropna()
    ari_clustering_dict[f"sensitivity_{i}_ari"] = adjusted_rand_score(
        cluster_robustness_i_df["cluster"], cluster_robustness_i_df[i]
    )

In [None]:
def _run_tag(result_key, suffix):
    """Strip the "sensitivity_" prefix and *suffix* from a robustness-dict key."""
    prefix = "sensitivity_"
    if result_key.startswith(prefix):
        result_key = result_key[len(prefix):]
    if result_key.endswith(suffix):
        result_key = result_key[: -len(suffix)]
    return result_key


# Robustness summary, one row per sensitivity run (e.g. "cover_overlap_0p75"):
# - "Enrichment ratio": Spearman correlation with the reference enrichment ratios;
# - "Clustering": adjusted Rand index against the reference clustering.
# Built from the result dicts, so re-running this cell is safe.
corr_df = pd.DataFrame(
    {
        "Enrichment ratio": {
            _run_tag(key, "_enrichment_ratio_spearman"): value
            for key, value in spearman_er_dict.items()
        },
        "Clustering": {
            _run_tag(key, "_clustering_ari"): value
            for key, value in ari_clustering_dict.items()
        },
    }
)

# For each varied parameter, also expose the clustering of its first tested
# value under the bare parameter name (e.g. column "cover_overlap").
# NOTE(review): this intent is inferred from the original implementation;
# confirm what these columns are meant for.
for string in sensitivity_parameter_dict:
    family_columns = [
        column
        for column in sensitivity_result_df_cluster.columns
        if column.startswith(f"{string}_") and column.endswith("_clustering")
    ]
    if family_columns:
        sensitivity_result_df_cluster[string] = sensitivity_result_df_cluster[family_columns[0]]