# Parametrized co-occurrence networks from taxonomy
1. Load data
2. Remove high taxa (non-identified sequences)
3. Pivot table from sampling events to taxa.
4. Remove low abundance taxa
5. Rarefy, or normalize
6. Remove replicates
7. Split to groups on chosen factor
8. Calculate associations (Bray-Curtis dissimilarity, Spearman’s correlation, etc.)
9. False discovery rate correction
10. Build and analyse network per group

## Platform dependent part
- Resolve platform setup
- Differences from the local imports should be resolved by configuring the VRE packages correctly

In [None]:
import sys
import os
import logging
import subprocess
import warnings
warnings.filterwarnings('ignore')
from IPython import get_ipython
logger = logging.getLogger(name="Co-occurrence network analysis")

if 'google.colab' in str(get_ipython()):
    print('Setting Google colab, you will need a ngrok account to make the dashboard display over the tunnel. \
    https://ngrok.com/')
    # Clone the momics-demos repository to use it to load data.
    # A failed command is reported through its exit status, not through an
    # exception, so the status is checked explicitly. OSError is only raised
    # when the git executable itself cannot be started.
    try:
        _clone = subprocess.run(
            ['git', 'clone', 'https://github.com/palec87/momics-demos.git'],
            check=False,
        )
    except OSError as e:
        logger.error(f"An error occurred while cloning the repository: {e}")
    else:
        if _clone.returncode == 0:
            logger.info("Repository cloned")
        else:
            logger.error(
                "An error occurred while cloning the repository: "
                f"git exited with status {_clone.returncode}"
            )

    sys.path.insert(0,'/content/momics-demos')

    # This step takes time because of the many dependencies.
    # `sys.executable -m pip` installs into the interpreter running this notebook.
    _pip = subprocess.run(
        [sys.executable, '-m', 'pip', 'install', 'marine-omics'],
        check=False,
    )
    if _pip.returncode != 0:
        logger.error(f"Installing marine-omics failed with exit status {_pip.returncode}")

from momics.utils import (
    memory_load, reconfig_logger,
    init_setup, get_notebook_environment,
)

# Set up logging
reconfig_logger()

# Determine the notebook environment
env = get_notebook_environment()

init_setup()
logger.info(f"Environment: {env}")

INFO | root | Logging.basicConfig completed successfully
INFO | Diversity analysis app | Environment: vscode
INFO | Diversity analysis app | Environment: vscode


## Imports

In [None]:
# This needs to be repeated here for the Panel dashboard to work, WEIRD
# TODO: report as possible bug
import sys
import os
import io
import itertools

import numpy as np
import pandas as pd
import panel as pn
import networkx as nx
import holoviews as hv
from typing import Dict
import matplotlib.pyplot as plt
from scipy.stats import spearmanr, pearsonr
from statsmodels.stats.multitest import multipletests
from skbio.stats import subsample_counts
from skbio.diversity import beta_diversity
from momics.taxonomy import (
    pivot_taxonomic_data,
    separate_taxonomy)

from mgo.udal import UDAL

# All low level functions are imported from the momics package
from momics.loader import load_parquets_udal
from momics.metadata import get_metadata_udal, enhance_metadata
from momics.taxonomy import (
    remove_high_taxa,
    pivot_taxonomic_data,
    prevalence_cutoff,
    rarefy_table,
    split_metadata,
    split_taxonomic_data,
    split_taxonomic_data_pivoted,
    compute_bray_curtis,
    # fdr_pvals,
)
from momics.networks import interaction_to_graph, interaction_to_graph_with_pvals

import momics.plotting as pl
from momics.panel_utils import serve_app, close_server
from momics.loader import bytes_to_df

## Loading and setup

In [None]:
# Module-level configuration for the dashboard.
DEBUG = True  # enable stdout logging
# Hex colour constant; NOTE(review): not read anywhere in the visible notebook.
PLOT_FACE_COLOR = "#e6e6e6"
# NOTE(review): a `global` statement at module scope has no effect;
# `full_metadata` is assigned as a normal module variable further below.
global full_metadata
# UDAL client instance from the `mgo.udal` package.
udal = UDAL()

In [4]:
# Repository root: the clone location when running on Google Colab, otherwise
# the parent of the current working directory.
root_folder = os.path.abspath(
    '/content/momics-demos' if 'google.colab' in str(get_ipython()) else '../'
)

# Static assets (e.g. the sidebar logo) are resolved relative to the root.
assets_folder = os.path.join(root_folder, 'assets')

In [5]:
@pn.cache()
def get_data():
    """Return the result of ``load_parquets_udal()``; the call is wrapped in ``pn.cache``."""
    return load_parquets_udal()

# Load and merge metadata
@pn.cache()
def get_full_metadata():
    """Return the result of ``get_metadata_udal()``; the call is wrapped in ``pn.cache``."""
    return get_metadata_udal()

@pn.cache()
def get_valid_samples():
    """Read ``data/shipment_b1b2_181.csv`` (relative to ``root_folder``) into a DataFrame."""
    csv_path = os.path.join(root_folder, 'data/shipment_b1b2_181.csv')
    return pd.read_csv(csv_path)

# Load the metadata through the cached loader.
full_metadata = get_full_metadata()

# Filter the metadata to the valid 181 samples listed in the shipment CSV.
# The captured cell output shows that enhance_metadata also sets a
# "replicate_info" column, which the preprocessing step relies on.
valid_samples = get_valid_samples()
full_metadata = enhance_metadata(full_metadata, valid_samples)

# Taxonomy tables; indexed with the 'lsu' and 'ssu' keys when the widgets are built.
mgf_parquet_dfs = get_data()

A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  metadata["replicate_info"] = (


In [7]:
# Turn every object-dtype metadata column into a categorical one and remember
# its name: these columns are the factors offered for splitting the samples.
factors = []
for col in full_metadata.columns:
    if full_metadata[col].dtype != 'object':
        # not an object column, leave untouched
        continue
    full_metadata[col] = full_metadata[col].astype('category')
    factors.append(col)

# Sanity check: the 'season' column must have been converted above.
if not isinstance(full_metadata['season'].dtype, pd.CategoricalDtype):
    raise ValueError(f"Column 'season' is not categorical (object dtype).")

A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  full_metadata[col] = full_metadata[col].astype('category')
A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  full_metadata[col] = full_metadata[col].astype('category')
A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  full_metadata[col] = full_metadata[col].astype('category')
A value is trying to be set 

## Content setup

In [None]:
# Shared state between the dashboard callbacks.
# SPLIT_TAXONOMY: bound to the result of split_taxonomic_data_pivoted() by
# preprocess_taxonomy(); iterated as {group: table} by calculate_associations().
SPLIT_TAXONOMY = {}
# SPEARMAN_TAXA: filled by calculate_associations() with one dict per group
# holding the 'correlation', 'p_vals' and 'p_vals_fdr' DataFrames.
SPEARMAN_TAXA = {}

### Side panel

In [None]:
# Load the Panel extensions used by the dashboard and enable notifications
# (pn.state.notifications is used by the callbacks further below).
pn.extension("tabulator", "mathjax", "filedropper")
pn.extension(notifications=True)
# Accent colour handed to the template in app().
ACCENT = "teal"

# CSS styles applied to the main tabs container.
styles = {
    "box-shadow": "rgba(50, 50, 93, 0.25) 0px 6px 12px -2px, rgba(0, 0, 0, 0.3) 0px 3px 7px -3px",
    "border-radius": "4px",
    "padding": "10px",
}

# Logo shown at the top of the sidebar.
image = pn.pane.JPG(os.path.join(assets_folder, "figs/logo_gecco.jpeg"),
                    width=100, height=100)

# Sidebar help text describing what the "Process taxonomy" button does.
md_prepare_table = pn.pane.Markdown(
"""
**Button triggers:**
- Removes high taxa
- Removes low prevalence taxa
- Rarefies/normalizes the table
- Removes replicates
- split to group by factor
"""
)

# Copies of the two taxonomy tables, keyed by the label shown in the selector.
tables = {
    "LSU": mgf_parquet_dfs['lsu'].copy(),
    "SSU": mgf_parquet_dfs['ssu'].copy(),
}

# The source dict is no longer needed once the copies exist.
del mgf_parquet_dfs

# Which taxonomy table to analyse.
select_table = pn.widgets.Select(
    name="Select table",
    options= list(tables.keys()),
    description="Select a table for network analysis",
)

# Metadata column used to split samples into groups; the options are the
# columns converted to categorical earlier in the notebook.
select_factor = pn.widgets.Select(
        name="Select factor",
        options=factors,
)

# Taxonomic level for remove_high_taxa(); the string 'None' skips that step.
select_high_taxon = pn.widgets.Select(
    name="Select high taxa to remove",
    options=['None', 'phylum', 'class', 'order', 'family', 'genus'],
    value='phylum',
    description="Taxa identified at this level or higher will be removed from the analysis",
)

# Percentage passed to prevalence_cutoff().
low_prevalence_cutoff = pn.widgets.FloatInput(
    name='Low prevalence cutoff [%]',
    value=10, step=1, start=0, end=100,
    description="Percentage of samples in which the taxon must be present not to be removed.",
)

# Triggers preprocess_taxonomy() (bound in the "Bindings" cell).
button_prepare_table = pn.widgets.Button(
    name="Process taxonomy",
    button_type="primary",
    width=200,
)

### Bindings

In [None]:
def preprocess_taxonomy(table, factor, high_taxon, prevalence_cutoff_value):
    """
    Run the taxonomy preprocessing pipeline and publish the per-group tables.

    Steps, in order: optionally remove high taxa, pivot the table, apply the
    prevalence cutoff, rarefy, de-duplicate the metadata on
    ``replicate_info``, split the metadata by ``factor``, drop groups with
    fewer than 3 samples and split the rarefied table by the remaining
    groups. The result is bound to the module-level ``SPLIT_TAXONOMY``.

    Args:
        table: Key into the module-level ``tables`` dict.
        factor: Metadata column handed to ``split_metadata``.
        high_taxon: Level passed to ``remove_high_taxa``; the string
            ``'None'`` skips that step.
        prevalence_cutoff_value: Percentage forwarded to ``prevalence_cutoff``.
    """
    global SPLIT_TAXONOMY
    SPLIT_TAXONOMY.clear()
    taxonomy = tables[table]
    logger.info("Preprocessing taxonomy...")
    if high_taxon != 'None':
        rows_before = taxonomy.shape[0]
        taxonomy = remove_high_taxa(taxonomy, tax_level=high_taxon)
        removed = rows_before - taxonomy.shape[0]
        message = f"Removed {removed} high taxa at level: {high_taxon}"
        logger.info(message)
        pn.state.notifications.success(message, duration=3000)

    # Pivot from sampling events to taxa, without normalisation or rarefaction.
    taxonomy = pivot_taxonomic_data(taxonomy, normalize=None, rarefy_depth=None)

    # Drop low-prevalence taxa; the first two columns are skipped.
    taxonomy = prevalence_cutoff(
        taxonomy, percent=prevalence_cutoff_value, skip_columns=2
    )

    # Rarefy only the abundance columns (everything after the first two).
    df_rarefied = taxonomy.copy()
    df_rarefied.iloc[:, 2:] = rarefy_table(taxonomy.iloc[:, 2:], depth=None, axis=1)

    # Keep the first row per replicate_info value, then split by the factor.
    deduplicated = full_metadata.copy().drop_duplicates(
        subset='replicate_info', keep='first'
    )
    groups = split_metadata(deduplicated, factor)

    # Groups with fewer than 3 members are too small for the statistics.
    for name in list(groups):
        size = len(groups[name])
        print(f"{name}: {size} samples")
        if size >= 3:
            continue
        del groups[name]
        print(f"Warning: {name} has less than 3 samples, therefore removed.")

    SPLIT_TAXONOMY = split_taxonomic_data_pivoted(df_rarefied, groups)


# Run the preprocessing with the widget values current at click time; the
# click event object itself is not used.
button_prepare_table.on_click(
    lambda event: preprocess_taxonomy(
        select_table.value,
        select_factor.value,
        select_high_taxon.value,
        low_prevalence_cutoff.value
    )
)

### Association tab

In [None]:
# Help text shown at the top of the associations tab.
md_associations = pn.pane.Markdown(
"""
**General hints:**
- Calculate associations between the selected factor and the taxonomic data.
- Perform FDR correction on the p-values.
"""
)

# NOTE(review): hist_fdr and viz_tab are not referenced by the layouts visible
# in this notebook (association_tab uses histogram_plot and fdr_plot).
hist_fdr = pn.pane.HoloViews(
    height=600,
    width=1000,
    name="Associations visualization",
    )

viz_tab = pn.Column(
    hist_fdr,
)

# Significance threshold; passed to calculate_associations() and calculate_network().
pval_cutoff = pn.widgets.FloatInput(
    name='P-value cutoff',
    value=0.05, step=0.01, start=0, end=1,
    description="P-value cutoff to identify significant associations.",
)

# Pane filled by plot_associations() with the correlation histograms.
histogram_plot = pn.pane.HoloViews(
    height=500,
    name="Histogram",
)

# Pane filled by plot_associations() with the raw vs. FDR-corrected p-value scatter.
fdr_plot = pn.pane.HoloViews(
    height=500,
    name="FDR Plot",
)

# Triggers calculate_associations() (bound after the method definitions).
button_associations = pn.widgets.Button(
    name="Calculate associations",
    button_type="primary",
    width=200,
)

### Methods

In [None]:
# this has to be updated in the momics-methods
def fdr_pvals(p_spearman_df: pd.DataFrame, pval_cutoff: float) -> pd.DataFrame:
    """
    Apply Benjamini/Hochberg FDR correction to the upper triangle of a
    p-value matrix.

    Only the values strictly above the main diagonal are corrected. NaN
    p-values (e.g. produced for constant inputs) are excluded from the
    correction and stay NaN in the result, so they neither break the
    mapping back into the matrix nor count as tests.

    Args:
        p_spearman_df (pd.DataFrame): DataFrame containing p-values.
        pval_cutoff (float): Forwarded to ``multipletests`` as ``alpha``;
            the rejection flags it produces are discarded here.

    Returns:
        pd.DataFrame: Same index and columns as the input, holding the
        corrected p-values above the diagonal and NaN on and below it.
    """
    raw = p_spearman_df.to_numpy(dtype=float)
    upper = np.triu_indices_from(raw, k=1)
    upper_pvals = raw[upper]

    # Correct only the p-values that are actually defined.
    testable = ~np.isnan(upper_pvals)
    corrected = np.full(upper_pvals.shape, np.nan)
    if testable.any():
        _rejected, adjusted, _, _ = multipletests(
            upper_pvals[testable], alpha=pval_cutoff, method="fdr_bh"
        )
        corrected[testable] = adjusted

    # Build the result in a fresh array instead of writing through
    # DataFrame.values, which is not guaranteed to be a writable view.
    result = np.full(raw.shape, np.nan)
    result[upper] = corrected
    return pd.DataFrame(
        result, index=p_spearman_df.index, columns=p_spearman_df.columns
    )

In [None]:
def values_below_diagonal_series(df: pd.DataFrame) -> pd.Series:
    """
    Return the entries strictly below the main diagonal of ``df`` as a flat
    Series with a default integer index, in row-major order.
    """
    matrix = df.values
    rows, cols = np.tril_indices_from(df, k=-1)
    below_diagonal = matrix[rows, cols]
    return pd.Series(below_diagonal)


def plot_associations(pval_cutoff: float):
    """
    Update the histogram and FDR panes from the results in ``SPEARMAN_TAXA``.

    One correlation histogram and one raw-vs-FDR p-value scatter is drawn per
    group, and the groups are overlaid. Dashed lines mark ``pval_cutoff`` on
    both axes of the scatter plot. Nothing is plotted when there are no
    results.

    Args:
        pval_cutoff (float): Position of the horizontal and vertical
            reference lines in the FDR plot.
    """
    if not SPEARMAN_TAXA:
        # Nothing calculated yet; indexing the empty plot lists would raise.
        logger.warning("No associations available to plot.")
        return

    # NOTE(review): `.hvplot` needs the hvplot pandas accessor to be
    # registered; presumably done by an imported package - confirm.
    histograms = None
    for factor, d in SPEARMAN_TAXA.items():
        # Use each taxon pair once: the flattened full matrix would count the
        # diagonal (self-correlations) and every pair twice.
        values = values_below_diagonal_series(d['correlation'])
        hist = values.hvplot.hist(
            bins=50,
            alpha=0.5,
            label=str(factor),  # plot labels must be strings
            xlabel="Correlation",
            ylabel="Frequency",
            title="Histogram of Correlation Values"
        )
        histograms = hist if histograms is None else histograms * hist
    histogram_plot.object = histograms.opts(
        show_legend=True,
        legend_position='top_left',
    )

    fdr_scatter = None
    for factor, d in SPEARMAN_TAXA.items():
        # Downsample (every 10th value) for better visibility and speed.
        df_pvals = pd.DataFrame({
            'raw_pval': d['p_vals'].values.flatten()[::10],
            'fdr_pval': d['p_vals_fdr'].values.flatten()[::10],
        })
        scatter = df_pvals.hvplot.scatter(
            x='raw_pval',
            y='fdr_pval',
            alpha=0.5,
            label=str(factor),
            xlabel="Raw p-value",
            ylabel="FDR-corrected p-value",
        )
        fdr_scatter = scatter if fdr_scatter is None else fdr_scatter * scatter

    # Reference lines at the p-value cutoff on both axes.
    hline = hv.HLine(pval_cutoff).opts(color='black', line_dash='dashed', line_width=2)
    vline = hv.VLine(pval_cutoff).opts(color='gray', line_dash='dashed', line_width=2)

    fdr_plot.object = (fdr_scatter * hline * vline).opts(
        show_legend=True,
        legend_position='bottom_right',
    )


def calculate_associations(pval_cutoff):
    """
    Compute Spearman correlations and FDR-corrected p-values per group.

    For every table in ``SPLIT_TAXONOMY`` the correlation between rows
    (columns from the third onwards are the observations) is calculated, and
    the correlation, raw p-value and FDR-corrected p-value matrices are
    stored in ``SPEARMAN_TAXA`` under the group key. Both axes of each matrix
    are labelled with the table's ``ncbi_tax_id`` column. Results of any
    earlier run are discarded first. Finally the association plots are
    refreshed.

    Args:
        pval_cutoff: Cutoff forwarded to ``fdr_pvals`` and to the plots.
    """
    # Clear in place so groups from a previous factor selection do not leak
    # into the plots and networks built from SPEARMAN_TAXA.
    SPEARMAN_TAXA.clear()

    if not SPLIT_TAXONOMY:
        message = "No taxonomy groups available; run 'Process taxonomy' first."
        logger.warning(message)
        if pn.state.notifications is not None:
            pn.state.notifications.warning(message, duration=3000)
        return

    for factor, df in SPLIT_TAXONOMY.items():
        # NOTE(review): per the SciPy documentation, spearmanr returns scalars
        # instead of matrices when only two variables are given - confirm
        # that groups with exactly two taxa are handled as intended.
        corr, p_spearman = spearmanr(df.iloc[:, 2:].T)
        assert corr.shape == p_spearman.shape, "Spearman correlation and p-values must have the same shape."

        taxa = df['ncbi_tax_id']
        p_spearman_df = pd.DataFrame(p_spearman, index=taxa, columns=taxa)
        SPEARMAN_TAXA[factor] = {
            'correlation': pd.DataFrame(corr, index=taxa, columns=taxa),
            'p_vals': p_spearman_df,
            # FDR correction of the raw p-values
            'p_vals_fdr': fdr_pvals(p_spearman_df, pval_cutoff=pval_cutoff),
        }

    # plot associations
    logger.info("Plotting associations...")
    plot_associations(pval_cutoff)


# Calculate the associations with the p-value cutoff current at click time;
# the click event object itself is not used.
button_associations.on_click(
    lambda event: calculate_associations(
        pval_cutoff.value
    )
)

In [None]:
# Layout of the "Associations" tab: help text, cutoff input, trigger button
# and the two result plots side by side.
association_tab = pn.Column(
    md_associations,
    pval_cutoff,
    button_associations,
    pn.Row(
        histogram_plot,
        fdr_plot,
    ),
    scroll=True,
    sizing_mode="stretch_both",
)

### Network tab

In [None]:
# Help text shown at the top of the network tab.
md_network = pn.pane.Markdown(
"""
**General hints:**
- Same p-value will be used as in previous section.
- You select the thresholds of the positive and negative associations, respectively.
- To add the edge to the network, the absolute value of the correlation must be above/below the threshold and p-value needs to be lower than the cutoff.
"""
)

# Threshold passed as `pos_cutoff` to interaction_to_graph_with_pvals().
pos_corr_cutoff = pn.widgets.FloatInput(
    name='Positive Correlation cutoff',
    value=0.75, step=0.05, start=0, end=1,
    description="Significant Positive Correlation Cutoff",
)

# Threshold passed as `neg_cutoff` to interaction_to_graph_with_pvals().
neg_corr_cutoff = pn.widgets.FloatInput(
    name='Negative Correlation cutoff',
    value=-0.70, step=0.05, start=-1, end=0,
    description="Significant Negative Correlation Cutoff",
)

# Triggers calculate_network() (bound after the function definition).
button_network = pn.widgets.Button(
    name="Create and evaluate network",
    button_type="primary",
    width=200,
)

# Result tables, filled by calculate_network().
overall_network_df = pn.widgets.Tabulator()
jaccard_pos = pn.widgets.Tabulator()
jaccard_neg = pn.widgets.Tabulator()

# Pane that receives the Matplotlib figure with one network per group.
network_plot = pn.pane.Matplotlib(
    name="Network plot",
    height=600,
    width=1000,
    # sizing_mode="stretch_both",
)

Computing the Jaccard similarity is more involved because it is done pairwise, for every pair of groups.

In [None]:
def pairwise_jaccard_lower_triangle(network_results, edge_type='edges_pos'):
    """
    Calculate the pairwise Jaccard similarity of edge sets between groups.

    Args:
        network_results: Mapping ``group -> dict``; each inner dict holds an
            iterable of hashable edges under ``edge_type``.
        edge_type: Key of the edge collection to compare.

    Returns:
        A square DataFrame whose index and columns are the group names. For
        each pair the similarity is stored once, at ``[g1, g2]`` where ``g1``
        comes before ``g2`` in ``network_results``; the diagonal and the
        mirrored cells stay NaN. A pair whose edge sets are both empty gets
        NaN.

    Note:
        Edges are compared exactly as given, so ``(a, b)`` and ``(b, a)``
        count as different edges.
    """
    groups = list(network_results)
    # Build each group's edge set once instead of once per pair.
    edge_sets = {group: set(network_results[group][edge_type]) for group in groups}

    pivoted = pd.DataFrame(index=groups, columns=groups)
    for g1, g2 in itertools.combinations(groups, 2):
        union = edge_sets[g1] | edge_sets[g2]
        if union:
            jaccard = len(edge_sets[g1] & edge_sets[g2]) / len(union)
        else:
            jaccard = float('nan')
        pivoted.loc[g1, g2] = jaccard

    return pivoted

In [None]:
def calculate_network(pos_corr_cutoff, neg_corr_cutoff, pval_cutoff):
    """
    Build one co-occurrence network per group and refresh the network tab.

    For every group in ``SPEARMAN_TAXA`` the nodes and the positive/negative
    edges are obtained from ``interaction_to_graph_with_pvals`` using the
    correlation matrix and the FDR-corrected p-values. The summary table,
    both Jaccard tables and the network figure are then updated. Nothing is
    done when no associations have been calculated yet.

    Args:
        pos_corr_cutoff: Passed as ``pos_cutoff`` to the graph builder.
        neg_corr_cutoff: Passed as ``neg_cutoff`` to the graph builder.
        pval_cutoff: Passed as ``p_val_cutoff`` to the graph builder.
    """
    if not SPEARMAN_TAXA:
        message = "No associations available; run 'Calculate associations' first."
        logger.warning(message)
        if pn.state.notifications is not None:
            pn.state.notifications.warning(message, duration=3000)
        return

    network_results = {}
    for factor, dict_df in SPEARMAN_TAXA.items():
        logger.info(f"Calculating network for factor: {factor}")
        pn.state.notifications.info(f"Calculating network for factor: {factor}")
        nodes, edges_pos, edges_neg = interaction_to_graph_with_pvals(
            dict_df['correlation'],
            dict_df['p_vals_fdr'],
            pos_cutoff=pos_corr_cutoff,
            neg_cutoff=neg_corr_cutoff,
            p_val_cutoff=pval_cutoff)
        logger.info(f"Number of nodes: {len(nodes)}")
        logger.info(f"Number of positive edges: {len(edges_pos)}")
        logger.info(f"Number of negative edges: {len(edges_neg)}")
        pn.state.notifications.info(
            f"Number of nodes: {len(nodes)}, "
            f"Number of positive edges: {len(edges_pos)}, "
            f"Number of negative edges: {len(edges_neg)}"
        )

        # Positive edges are drawn green, negative edges red.
        G = nx.Graph(mode=factor)
        G.add_nodes_from(nodes)
        G.add_edges_from(edges_pos, color='green')
        G.add_edges_from(edges_neg, color='red')

        degree_centrality = nx.degree_centrality(G)
        betweenness = nx.betweenness_centrality(G)
        network_results[factor] = {
            "graph": G,
            "nodes": nodes,
            "edges_pos": edges_pos,
            "edges_neg": edges_neg,
            # ten highest / lowest scoring nodes as (node, score) pairs
            "degree_centrality": sorted(
                degree_centrality.items(), key=lambda x: x[1], reverse=True
            )[:10],
            "top_betweenness": sorted(
                betweenness.items(), key=lambda x: x[1], reverse=True
            )[:10],
            "bottom_betweenness": sorted(
                betweenness.items(), key=lambda x: x[1]
            )[:10],
            "total_nodes": G.number_of_nodes(),
            "total_edges": G.number_of_edges(),
        }

    # Build the summary table in one go rather than concatenating per group.
    factor_column = select_factor.value
    summary_columns = [
        factor_column, 'centrality', 'top_betweenness',
        'bottom_betweenness', 'total_nodes', 'total_edges',
    ]
    records = [
        {
            factor_column: factor,
            'centrality': results['degree_centrality'],
            'top_betweenness': results['top_betweenness'],
            'bottom_betweenness': results['bottom_betweenness'],
            'total_nodes': results['total_nodes'],
            'total_edges': results['total_edges'],
        }
        for factor, results in network_results.items()
    ]
    overall_network_df.value = pd.DataFrame(records, columns=summary_columns)

    # Calculate Jaccard similarity for the networks
    jaccard_pos.value = pairwise_jaccard_lower_triangle(network_results, edge_type='edges_pos')
    jaccard_neg.value = pairwise_jaccard_lower_triangle(network_results, edge_type='edges_neg')

    # Plot the networks. squeeze=False keeps `axes` a 2-D array even for a
    # single group, where plt.subplots would otherwise return a bare Axes.
    group_names = list(network_results)
    fig, axes = plt.subplots(
        1, len(group_names), figsize=(6*len(group_names), 6), squeeze=False
    )

    for ax, factor in zip(axes.ravel(), group_names):
        G = network_results[factor]['graph']
        colors = nx.get_edge_attributes(G, 'color')
        # fixed seed keeps the layout reproducible between runs
        pos = nx.spring_layout(G, k=0.2, iterations=50, seed=42)
        nx.draw_networkx_nodes(G, pos, ax=ax, alpha=0.2, node_color='grey', node_size=15)
        nx.draw_networkx_edges(G, pos, ax=ax, alpha=0.2, edge_color=list(colors.values()))
        ax.set_title(factor)
        ax.axis('off')

    fig.tight_layout()
    # Close the figure in pyplot's registry; the pane keeps its own reference.
    plt.close(fig)
    network_plot.object = fig

# Build the networks with the widget values current at click time; the
# p-value cutoff is the one from the associations tab.
button_network.on_click(
    lambda event: calculate_network(
        pos_corr_cutoff.value,
        neg_corr_cutoff.value,
        pval_cutoff.value
    )
)

In [None]:
# Layout of the "Network analysis" tab: help text, the cutoff inputs with the
# trigger button, the summary table, both Jaccard tables and the network figure.
network_tab = pn.Column(
    md_network,
    pn.Row(
        pos_corr_cutoff,
        neg_corr_cutoff,
        button_network,
    ),
    overall_network_df,
    pn.Row("## Jaccard similarity (positive)", jaccard_pos),
    pn.Row("## Jaccard similarity (negative)", jaccard_neg),
    network_plot,
    scroll=True,
    sizing_mode="stretch_both",
)

In [None]:
# Main area of the dashboard: the two analysis tabs, styled with `styles`.
tabs = pn.Tabs(
    ("Associations", association_tab),
    ("Network analysis", network_tab),
    dynamic=True,
    styles=styles,
    sizing_mode="stretch_both",
    margin=10,
)

## APP setup

In [None]:
pn.extension("tabulator", "mathjax")

def app():
    """
    Assemble and return the dashboard template.

    The sidebar holds the logo, the help text, the preprocessing controls and
    the "Process taxonomy" button; the main area holds the analysis tabs.
    """
    controls = pn.Column(
        select_table,
        select_factor,
        select_high_taxon,
        low_prevalence_cutoff,
    )
    sidebar_items = [
        image,
        md_prepare_table,
        pn.layout.Divider(margin=(-20, 0, 0, 0)),
        controls,
        pn.layout.Divider(margin=(-20, 0, 0, 0)),
        button_prepare_table,
    ]
    main_items = [pn.Column(tabs, scroll=True)]

    return pn.template.FastListTemplate(
        title="Co-occurrence network taxonomic analysis",
        sidebar=sidebar_items,
        main=main_items,
        main_layout=None,
        accent=ACCENT,
    )

# Build the dashboard once at module level.
template = app()


# On Google Colab the app is started through serve_app() and the returned
# handle is kept in `s` for close_server(); elsewhere the template is only
# marked as servable.
if 'google.colab' in str(get_ipython()):  
    s = serve_app(template, env=env, name="Co-occurrence networks")
else:
    template.servable()

### Uncomment this if you are running an ngrok tunnel which you want to quit

In [None]:
# only use for the ngrok tunnel in GColab
# close_server(s, env=env)