# Import libraries and data

In [1]:
import numpy as np
np.int = np.int32
from sklearn.preprocessing import MinMaxScaler
from sklearn.cluster import DBSCAN
from tmap.tda import mapper, Filter
from tmap.tda.cover import Cover
from tmap.tda.metric import Metric
from tmap.tda.utils import optimize_dbscan_eps

from scipy.spatial.distance import squareform,pdist
import pandas as pd

import networkx as nx

In [2]:
from pathlib import Path

# Project layout, resolved relative to the directory the notebook runs in:
# the working directory is taken to be the code folder and its parent the
# project root.
code_dir = Path.cwd()
project_dir = code_dir.parent
input_dir = project_dir.joinpath("input")
output_dir = project_dir.joinpath("output/tda_sensitivity_1/")
tmp_dir = project_dir.joinpath("tmp")

# Create the output folder (and any missing parents); no error if it exists.
output_dir.mkdir(parents=True, exist_ok=True)

In [None]:
# Read the two input tables; the first CSV column of each becomes the index.
metadata = pd.read_csv(
    input_dir.joinpath("data/metadata_df.csv"),
    index_col=0,
)
oral_microbiome_genus = pd.read_csv(
    input_dir.joinpath("data/microbiome_genus.csv"),
    index_col=0,
)

In [8]:
# Use the genus table as the Mapper input X and restrict both tables to the
# samples they have in common.
# NOTE(review): `metadata_variables` is not defined in any cell visible here;
# it must already exist in the session (presumably a list of metadata columns).
X = oral_microbiome_genus

# Keep metadata rows whose index also occurs in X, then select the columns.
has_microbiome = metadata.index.isin(X.index)
metadata = metadata.loc[has_microbiome][metadata_variables]

# Keep X rows whose index also occurs in the filtered metadata.
has_metadata = X.index.isin(metadata.index)
X = X.loc[has_metadata]

In [9]:
# One category label per column: for metadata, the part of the column name
# before the first underscore; every column of X is labelled "genus".
metadata_categories = [name.split("_")[0] for name in metadata.columns]
microbiome_categories = ["genus" for _ in X.columns]

# Mapper

In [10]:
def transform2node_data(graph, data, mode='mean'):
    """Aggregate the rows of ``data`` into one row per graph node.

    Parameters
    ----------
    graph
        Object whose ``nodes`` attribute supports ``.items()`` and yields
        ``(node_id, attributes)`` pairs; each ``attributes['sample']`` is used
        as a positional row indexer into ``data.values``.
    data : pandas.DataFrame or None
        Per-sample table to aggregate.
    mode : {'mean', 'sum'}
        ``'mean'`` aggregates with ``np.nanmean`` and ``'sum'`` with
        ``np.sum``, both column-wise (axis 0).

    Returns
    -------
    pandas.DataFrame or None
        One row per node id with the columns of ``data``; ``None`` when
        ``data`` is ``None``.

    Raises
    ------
    SyntaxError
        If ``mode`` is neither ``'sum'`` nor ``'mean'``. The exception type is
        kept as-is so existing callers are unaffected.
    """
    aggregators = {'sum': np.sum,
                   'mean': np.nanmean}
    if mode not in aggregators:
        raise SyntaxError('Wrong provided parameters.')

    # Guard before touching ``data.values``: the original read the attribute
    # first, so its ``data is not None`` check could never take effect.
    if data is None:
        return None

    aggregate = aggregators[mode]
    values = data.values
    node_data = {nid: aggregate(values[attr['sample'], :], 0)
                 for nid, attr in graph.nodes.items()}
    return pd.DataFrame.from_dict(node_data,
                                  orient='index',
                                  columns=data.columns)

In [12]:
# Grid of subsampling fractions: starts at 0.1 and advances in increments of
# `step`; the stop value is padded by one step so the grid reaches ~1.0.
step = 0.01
robustness_array = np.arange(start=0.1, stop=1 + step, step=step)

In [None]:
import itertools

from safepy import safe
from sklearn.cluster import KMeans

# Subsampling robustness experiment.
# For every fraction in `robustness_array`, draw `iterations` subsamples of X
# (without replacement), rebuild the Mapper graph on the subsample, run SAFE on
# that graph and assign every subject to a cluster of the graph layout.
# Per-run results are collected in the two dictionaries below.

iterations = 100

# Filled by the loop; keys are "robustness_{i}_{j}_enrichment_ratio" and
# "robustness_{i}_{j}_clustering" (i = fraction in percent, j = repeat).
robustness_er_result_dict = {}
robustness_clustering_result_dict = {}

# Mapper settings
resolution = 30
overlap = 0.75
mapper_lens = Filter.PCOA
mapper_distance_metric = "braycurtis"

# SAFE settings
safe_distance_thresh = 0.75
safe_neighborhood_radius = 0.1
num_permutations = 5000

# Number of KMeans clusters on the node layout. determine_cluster() below
# hard-codes the labels 0 and 1, so this must stay 2 unless that is adapted.
n_clusters = 2


def determine_cluster(row):
    """Map a subject's per-cluster node counts to a single label.

    ``row[0]`` and ``row[1]`` are the number of nodes of cluster 0 and
    cluster 1 that contain the subject.

    Returns -1 if the subject occurs in both clusters, 0 or 1 if it occurs in
    exactly one of them, and NaN if it occurs in neither.
    """
    if row[0] > 0 and row[1] > 0:
        return -1
    elif row[0] > 0:
        return 0
    elif row[1] > 0:
        return 1
    else:
        return np.nan


for frac in robustness_array:

    # Fraction as an integer percentage; round first so float noise such as
    # 28.999999999999996 does not truncate to the wrong integer.
    i = int(round(frac * 100, 1))

    for j in range(iterations):

        # NOTE(review): the seed is i + j, so neighbouring fractions reuse
        # seeds (e.g. i=10, j=5 and i=11, j=4 both use 15).
        X_resampled = X.sample(frac=frac, replace=False, random_state=i + j)

        ################
        #Mapper
        ################

        # TDA Step1. initiate a Mapper
        tm = mapper.Mapper(verbose=1)

        # TDA Step2. Projection: the lens is applied to a precomputed pairwise
        # distance matrix of the subsample.
        dm = squareform(pdist(X_resampled, metric=mapper_distance_metric))
        metric = Metric(metric="precomputed")
        lens = [mapper_lens(components=[0, 1], metric=metric, random_state=100)]
        projected_X = tm.filter(dm, lens=lens)

        # TDA Step3. Covering, clustering & mapping
        eps = optimize_dbscan_eps(X_resampled, threshold=95)
        clusterer = DBSCAN(eps=eps, min_samples=3)
        cover = Cover(projected_data=MinMaxScaler().fit_transform(projected_X),
                      resolution=resolution, overlap=overlap)
        graph = tm.map(data=X_resampled, cover=cover, clusterer=clusterer)
        print(graph.info())

        ################
        #SAFE
        ################

        # Start the spring layout from the positions stored on the graph.
        # NOTE(review): this keys the rows of graph.nodePos by 0..n-1, i.e. it
        # assumes node ids are consecutive integers in that order -- confirm.
        initial_nodepos = {idx: graph.nodePos[idx] for idx in range(graph.nodePos.shape[0])}
        pos = nx.spring_layout(graph, k=0.2, pos=initial_nodepos, seed=42)

        graph.nodePos = np.array([pos[key] for key in pos])

        # Store the layout on each node, addressed by node id rather than by
        # enumeration position.
        for node in graph.nodes:
            graph.nodes[node]["pos"] = pos[node].tolist()

        # Edge list with a constant "dist" column, written headerless and
        # tab-separated, then loaded by SAFE below.
        edgelist_3col = nx.to_pandas_edgelist(graph)
        edgelist_3col["dist"] = 1
        edgelist_3col.to_csv(output_dir/f"robustness_{i}_{j}_mapper_graph_3col.txt", sep="\t", index=False, header=None)

        # Node-level attribute table: per-node means of metadata and genus data.
        metadata_transformed = transform2node_data(graph, metadata.loc[X_resampled.index], mode="mean")
        oral_microbiome_genus_transformed = transform2node_data(graph, oral_microbiome_genus.loc[X_resampled.index], mode="mean")
        data_transformed = metadata_transformed.join(oral_microbiome_genus_transformed)
        data_transformed.to_csv(output_dir/f"robustness_{i}_{j}_mapper_graph_metadata.txt", sep="\t", index=True)

        sf = safe.SAFE(path_to_safe_data=f"{output_dir}/safe_robustness_{i}_{j}/")
        sf.attribute_distance_threshold = safe_distance_thresh
        sf.neighborhood_radius = safe_neighborhood_radius
        sf.load_network(network_file=f"{output_dir}/robustness_{i}_{j}_mapper_graph_3col.txt")
        sf.load_attributes(attribute_file=f"{output_dir}/robustness_{i}_{j}_mapper_graph_metadata.txt")
        sf.define_neighborhoods()

        sf.compute_pvalues(num_permutations=num_permutations)

        # Enrichment ratio per attribute: enriched neighborhoods / node count.
        safe_summary = sf.attributes.copy()
        safe_summary.set_index("name", inplace=True)
        robustness_er_result_dict[f"robustness_{i}_{j}_enrichment_ratio"] = safe_summary["num_neighborhoods_enriched"] / len(graph.nodes)

        ################
        #Clustering
        ################

        # Cluster the nodes on their 2-D layout coordinates. KMeans is fitted
        # before the label column is added, so only the coordinates are used.
        positions = pd.DataFrame(nx.get_node_attributes(graph, "pos")).T
        positions.columns = ["0", "1"]

        clustering = KMeans(n_clusters=n_clusters, random_state=42).fit(positions)
        positions["cluster"] = clustering.labels_

        # Subject x node membership matrix (1 = subject is part of the node).
        node_subject_mapping_dict = {node: list(graph.nodes[node]["sample_names"]) for node in graph.nodes}

        all_subject_indices = sorted(set(itertools.chain.from_iterable(node_subject_mapping_dict.values())))
        node_subject_df = pd.DataFrame(0, index=all_subject_indices, columns=list(graph.nodes))

        # One column-wise assignment per node instead of one per cell.
        for node, subjects in node_subject_mapping_dict.items():
            node_subject_df.loc[subjects, node] = 1

        # Reorder the rows to follow the order of metadata.index.
        node_subject_df = node_subject_df.loc[metadata.index[metadata.index.isin(node_subject_df.index)]]

        # For every subject, count its nodes within each node cluster.
        subject_group_df = node_subject_df.T.join(positions["cluster"]).groupby("cluster").sum().T

        subject_group_df["cluster"] = subject_group_df.apply(determine_cluster, axis=1)

        robustness_clustering_result_dict[f"robustness_{i}_{j}_clustering"] = subject_group_df["cluster"]

In [14]:
# Persist both result dictionaries as pickle files in the output directory.
import pickle

for file_name, results in (
    ("robustness_er_result_dict.pkl", robustness_er_result_dict),
    ("robustness_clustering_result_dict.pkl", robustness_clustering_result_dict),
):
    with open(output_dir/file_name, "wb") as f:
        pickle.dump(results, f)

In [29]:
# Read the two result dictionaries back from the pickle files in output_dir.
# pickle.load executes arbitrary code on malicious input; only load files that
# were written by this notebook.
with (output_dir/"robustness_er_result_dict.pkl").open("rb") as handle:
    robustness_er_result_dict = pickle.load(handle)

with (output_dir/"robustness_clustering_result_dict.pkl").open("rb") as handle:
    robustness_clustering_result_dict = pickle.load(handle)

In [15]:
# Reference tables the resampled runs are compared against below
# (presumably the outputs of the full-data TDA analysis -- confirm).
original_enrichment_ratio = pd.read_csv(
    project_dir.joinpath("output/tda/metadata_safe_summary.csv"),
    index_col=0,
)
original_clustering = pd.read_csv(
    project_dir.joinpath("output/tda/cluster_analysis/subject_clustering.csv"),
    index_col=0,
)

In [16]:
# Put the reference clustering and the label Series of every resampled run
# side by side, one column per run named "cluster_{i}_{j}".
# The Series are gathered first and attached with a single concat + left join;
# joining them one at a time re-copies an ever wider frame on every run.
# The column names equal what the incremental join produced through its
# rsuffix, given that original_clustering already has a "cluster" column.
# Assumes each run's Series has unique subject ids in its index.
resampled_clusterings = {}
for frac in robustness_array:
    i = int(round(frac * 100, 1))  # fraction as integer percentage
    for j in range(iterations):
        resampled_clusterings[f"cluster_{i}_{j}"] = robustness_clustering_result_dict[f"robustness_{i}_{j}_clustering"]

# Left join: only subjects present in original_clustering are kept; subjects
# missing from a run are NaN in that run's column.
cluster_robustness_df = original_clustering.join(pd.concat(resampled_clusterings, axis=1))

In [17]:
from scipy.stats import spearmanr
from sklearn.metrics import adjusted_rand_score

# Agreement of every resampled run with the reference results:
#   - Spearman correlation between reference and resampled enrichment ratios
#   - adjusted Rand index between reference and resampled subject clusters
spearman_er_dict = {}
ari_clustering_dict = {}

for frac in robustness_array:
    i = int(round(frac * 100, 1))  # fraction as integer percentage
    for j in range(iterations):
        # NOTE(review): spearmanr pairs the two inputs by position, not by
        # index label -- confirm both tables list the attributes in the same
        # order.
        resampled_er = robustness_er_result_dict[f"robustness_{i}_{j}_enrichment_ratio"]
        correlation = spearmanr(original_enrichment_ratio["enrichment_ratio"], resampled_er)
        spearman_er_dict[f"robustness_{i}_{j}_spearman"] = correlation[0]

        # Compare labels only for subjects that have a value in both columns.
        label_pairs = cluster_robustness_df[["cluster", f"cluster_{i}_{j}"]].dropna()
        ari_clustering_dict[f"robustness_{i}_{j}_ari"] = adjusted_rand_score(
            label_pairs["cluster"],
            label_pairs[f"cluster_{i}_{j}"],
        )

In [19]:
# Shorten the keys "robustness_{i}_{j}_<metric>" to "{i}_{j}" and turn each
# dictionary into a one-column frame.
spearman_er_dict_fin = {
    "_".join(key.split("_")[1:3]): rho
    for key, rho in spearman_er_dict.items()
}
spearman_er_df = pd.DataFrame.from_dict(spearman_er_dict_fin, orient="index", columns=["spearman"])

ari_clustering_dict_fin = {
    "_".join(key.split("_")[1:3]): ari
    for key, ari in ari_clustering_dict.items()
}
ari_clustering_df = pd.DataFrame.from_dict(ari_clustering_dict_fin, orient="index", columns=["ari"])

# Combine both metrics, then split the "{i}_{j}" labels: the first part goes
# into a "percentage" column (still a string here), the second part becomes
# the new index.
robustness_df = ari_clustering_df.join(spearman_er_df)
label_parts = list(robustness_df.index.str.split("_"))
robustness_df["percentage"] = [parts[0] for parts in label_parts]
robustness_df.index = [parts[1] for parts in label_parts]

# Long format: one row per (index, percentage, metric) with its value.
robustness_df = robustness_df.reset_index().melt(
    id_vars=["index", "percentage"],
    value_vars=["ari", "spearman"],
    var_name="metric",
    value_name="value",
)

In [24]:
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go

# Work on a copy so the long-format table itself stays untouched.
df = robustness_df.copy()

# 'percentage' holds strings at this point; cast to int so the groups (and
# therefore the x axis) are ordered numerically rather than lexically.
df['percentage'] = df['percentage'].astype(int)

# Mean and 95% confidence half-width (1.96 * standard error) of each metric
# per sampling percentage. groupby sorts by (percentage, metric).
grouped = df.groupby(['percentage', 'metric'])['value'].agg(['mean', 'std', 'count'])
grouped['ci'] = 1.96 * grouped['std'] / grouped['count'] ** 0.5

# (metric, prefix for the band trace names, legend label, line colour, band fill)
trace_styles = (
    ('ari', 'ARI', 'Adjusted Rand Index (ARI) of clustering',
     'blue', 'rgba(0,0,255,0.2)'),
    ('spearman', 'Spearman', 'Spearman correlation of enrichment ratios',
     'red', 'rgba(255,0,0,0.2)'),
)

fig = go.Figure()

for metric, short_name, legend_name, line_color, band_color in trace_styles:
    metric_data = grouped.xs(metric, level='metric')

    # Mean line
    fig.add_trace(go.Scatter(
        x=metric_data.index,
        y=metric_data['mean'],
        mode='lines',
        name=legend_name,
        line=dict(color=line_color)
    ))
    # Upper bound: invisible line, added first so the lower bound can fill to it
    fig.add_trace(go.Scatter(
        x=metric_data.index,
        y=metric_data['mean'] + metric_data['ci'],
        mode='lines',
        name=f'{short_name} upper',
        line=dict(width=0),
        showlegend=False
    ))
    # Lower bound: 'tonexty' fills up to the previously added (upper) trace
    fig.add_trace(go.Scatter(
        x=metric_data.index,
        y=metric_data['mean'] - metric_data['ci'],
        mode='lines',
        name=f'{short_name} lower',
        line=dict(width=0),
        showlegend=False,
        fill='tonexty',
        fillcolor=band_color
    ))

fig.update_layout(
    title='',
    xaxis_title='Sample (%)',
    yaxis_title='',
    template='plotly_white'
)

# Explicit ticks: every 10 % from 10 to 100 on x, every 0.2 up to 1 on y.
# (The ticks used to be set twice; only the second setting, which includes
# 100, took effect.)
fig.update_xaxes(tickmode='array', tickvals=list(range(10, 101, 10)))
fig.update_yaxes(tickvals=[0.2, 0.4, 0.6, 0.8, 1])

fig.write_html(output_dir/"robustness.html")
fig.write_image(output_dir/"robustness.png", format="png", scale=10)
fig.write_image(output_dir/"robustness.svg", format="svg")

# Show plot
fig.show()