# 51 Causal modules Leiden on PPI

**Origin:** `5_1_causal_module_Leiden_ppi.ipynb`  
**Annotated on:** 2025-10-13 06:45

**High-level objective:**  
- Detect causal communities on the PPI using Leiden/CPM or RB; grid-search resolution; tag modules enriched for MR-significant genes.

**Notes:**  
- These comments are language-agnostic and focus on intent, inputs, and outputs.  
- Adjust hard-coded paths if needed; prefer `/results_*` for derived artifacts.

---


**Step 1:** Load network or tabular inputs (PPI/GraphML/TSV).

In [4]:
#below is code for pruning ppi network by top 1% of hubnodes. 

import pandas as pd
import numpy as np
import networkx as nx
from scipy.stats import hypergeom
import community as community_louvain  # Louvain
import os

# Load PPI
ppi_edges = pd.read_csv('/mnt/f/10_osteo_MR/datasets/ppi/ppi_all_nonduplicate.tsv' ) # ,
                       #  sep='\t', header=None)
ppi_edges.columns = ['inx', 'gene_u','gene_v']
G = nx.Graph()
G.add_edges_from(ppi_edges[['gene_u','gene_v']].itertuples(index=False, name=None))

# Load causal genes
mr_df = pd.read_csv('/mnt/f/10_osteo_MR/results_mr_ptrs/PTRS/bulk_crossmodal_meta_beta.tsv', sep='\t')
causal_genes = set(mr_df['gene'])

outdir = "/mnt/f/10_osteo_MR/results_network/"
os.makedirs(outdir, exist_ok=True)

# ==== Step 1. Exclude top 1% hubs ====
N = G.number_of_nodes()
top_k = int(N * 0.01)
deg_sorted = sorted(G.degree(), key=lambda x: x[1], reverse=True)
hub_nodes = {n for n,_ in deg_sorted[:top_k]}

# without pruning
G_pruned = G.copy()


# Keep causal genes present in the pruned graph
causal_genes_in = causal_genes & set(G_pruned.nodes())
total_nodes = G_pruned.number_of_nodes()
total_causal = len(causal_genes_in)

print(f"Pruned graph: |V|={total_nodes:,} |E|={G_pruned.number_of_edges():,}")
print(f"Causal genes in pruned graph: {total_causal:,} (of {len(causal_genes):,})")





Pruned graph: |V|=16,201 |E|=236,930
Causal genes in pruned graph: 810 (of 969)


**Step 2:** Community detection / resolution sweep on the PPI.

In [None]:


# ---------- Viz ----------
def visualize_sig_modules_network(G_nx, g, clustering, comm_df, causal_genes, beta_map,
                                  out_png, out_graphml,
                                  q_alpha=0.05, min_size=8, seed=42,
                                  max_nodes_to_draw=8000):
    import matplotlib.patches as mpatches
    sig_ids = set(comm_df.loc[(comm_df['is_significant']) &
                              (comm_df['q_value'] < q_alpha) &
                              (comm_df['size'] >= min_size), 'community'])
    if not sig_ids:
        print("No significant modules to plot.")
        return
    names = np.array(g.vs['name'])
    keep_vids = np.unique(np.concatenate([np.array(clustering[c], int) for c in sig_ids]))
    keep_genes = set(names[keep_vids])

    H = nx.Graph()
    gene2cid = {}
    for cid in sig_ids:
        for vid in clustering[cid]:
            gene2cid[names[vid]] = cid
    for gene in keep_genes:
        b = beta_map.get(gene, np.nan)
        H.add_node(gene,
                   community=int(gene2cid.get(gene, -1)),
                   is_causal=(gene in causal_genes),
                   beta=b,
                   beta_abs=(abs(b) if not np.isnan(b) else 0.0))
    for u,v in G_nx.edges():
        if u in keep_genes and v in keep_genes:
            H.add_edge(u,v)

    if H.number_of_nodes() > max_nodes_to_draw:
        keep = []
        for cid in sig_ids:
            members = [n for n,d in H.nodes(data=True) if d.get("community")==cid]
            members = sorted(members, key=lambda x: H.degree(x), reverse=True)
            keep += members[:max(50, int(0.2*len(members)))]
        H = H.subgraph(keep).copy()

    pos = nx.spring_layout(H, seed=seed)
    cids = sorted({H.nodes[n]['community'] for n in H.nodes()})
    cmap = plt.get_cmap("Set2"); cid2idx = {c:i for i,c in enumerate(cids)}
    def size_for(n):
        babs = H.nodes[n].get('beta_abs', 0.0); return 60 + 240*min(babs,3.0)/3.0

    causal_nodes = [n for n in H.nodes() if H.nodes[n].get('is_causal', False)]
    noncausal_nodes = [n for n in H.nodes() if not H.nodes[n].get('is_causal', False)]

    fig, ax = plt.subplots(figsize=(12,10), dpi=150)
    nx.draw_networkx_edges(H, pos, ax=ax, width=0.3, alpha=0.35, edge_color="#999")
    nx.draw_networkx_nodes(
        H, pos, nodelist=noncausal_nodes,
        node_size=[size_for(n) for n in noncausal_nodes],
        node_color=[cmap(cid2idx[H.nodes[n]['community']]%20) for n in noncausal_nodes],
        linewidths=0.2, edgecolors="none", alpha=0.95, ax=ax
    )
    nx.draw_networkx_nodes(
        H, pos, nodelist=causal_nodes,
        node_size=[size_for(n) for n in causal_nodes],
        node_color=[cmap(cid2idx[H.nodes[n]['community']]%20) for n in causal_nodes],
        linewidths=1.6, edgecolors="black", alpha=0.98, ax=ax
    )
    handles = [mpatches.Patch(color=cmap(cid2idx[c]%20), label=f"Module {c}") for c in cids[:10]]
    if handles: ax.legend(handles=handles, title="Significant modules", loc="upper right", frameon=True)
    ax.set_title("Significant modules (Leiden/CPM)\nCausal genes bolded")
    ax.set_axis_off(); fig.tight_layout()
    fig.savefig(out_png, dpi=300); plt.close(fig)
    nx.write_graphml(H, out_graphml)
    print(f"Saved: {out_png}\nSaved: {out_graphml}")


**Step 3:** Load network or tabular inputs (PPI/GraphML/TSV).

In [56]:


# =========================
# Grid over (alpha, gamma) for A2 weights on FULL PPI
# =========================
import os, json, math, warnings
import numpy as np
import pandas as pd
import networkx as nx
from collections import defaultdict
from scipy.stats import hypergeom, chi2

import igraph as ig
import leidenalg as la
import matplotlib.pyplot as plt

# ------------ Paths ------------
PPI_PATH = "/mnt/f/10_osteo_MR/datasets/ppi/ppi_all_nonduplicate.tsv"
MR_PATH  = "/mnt/f/10_osteo_MR/results_mr_ptrs/PTRS/bulk_crossmodal_meta_beta.tsv"
OUTDIR   = "/mnt/f/10_osteo_MR/results_network/"
os.makedirs(OUTDIR, exist_ok=True)

# ------------ Helpers ------------
def bh_fdr(p):
    p = np.asarray(p, float)
    if p.size == 0:
        return p
    idx = np.argsort(p)
    ranks = np.empty_like(idx)
    ranks[idx] = np.arange(1, p.size + 1)
    q = p * p.size / ranks
    q_sorted = np.minimum.accumulate(q[idx][::-1])[::-1]
    out = np.empty_like(q)
    out[idx] = np.minimum(q_sorted, 1.0)
    return out

def load_ppi_and_mr(ppi_path, mr_path):
    # PPI -> NetworkX
    ppi = pd.read_csv(ppi_path)
    ppi.columns = ['inx','gene_u','gene_v']
    G = nx.Graph()
    G.add_edges_from(ppi[['gene_u','gene_v']].itertuples(index=False, name=None))

    # MR causal list with beta
    mr = pd.read_csv(mr_path, sep='\t')[['gene','meta_beta_common']].dropna()
    mr['gene'] = mr['gene'].astype(str)
    mr_agg = mr.groupby('gene', as_index=False)['meta_beta_common'].mean()
    causal = set(mr_agg['gene'])
    beta_map = dict(zip(mr_agg['gene'], mr_agg['meta_beta_common']))
    return G, causal, beta_map

def nx_to_igraph(G, weights=None):
    nodes = list(G.nodes())
    idx = {n:i for i,n in enumerate(nodes)}
    edges = [(idx[u], idx[v]) for u,v in G.edges()]
    g = ig.Graph(n=len(nodes), edges=edges, directed=False)
    g.vs['name'] = nodes
    if weights is not None:
        g.es['weight'] = list(weights)  # aligned to edges order above
    return g, nodes

# ---------- β-aware weights (A1 / A2) ----------
# A1: w = 1 + alpha*(|βu|+|βv|)
# A2: w = 1 + alpha*(|βu|+|βv|)*(1.0 if same sign else lambda_diff)
def build_weights(G, beta_map, scheme="A1", alpha=0.5, lambda_diff=0.5, beta_clip=3.0,
                  normalize_mean1=True, ensure_baseline=True, wmax=None):
    edges = list(G.edges())

    def b(n):
        x = beta_map.get(n, 0.0)
        if x is None or np.isnan(x):
            x = 0.0
        s = 1 if x > 0 else (-1 if x < 0 else 0)
        return s * min(abs(x), beta_clip)

    w = []
    for u, v in edges:
        bu, bv = b(u), b(v)
        mu, mv = abs(bu), abs(bv)
        same = (bu == 0 and bv == 0) or (bu >= 0 and bv >= 0) or (bu <= 0 and bv <= 0)
        if scheme == "A1":
            w_uv = 1.0 + alpha * (mu + mv)
        elif scheme == "A2":
            w_uv = 1.0 + alpha * (mu + mv) * (1.0 if same else lambda_diff)
        else:
            raise ValueError("scheme must be 'A1' or 'A2'")
        w.append(max(w_uv, 1e-9))

    w = np.array(w, float)
    if normalize_mean1 and w.mean() > 0:
        w = w / w.mean()
    if ensure_baseline:
        # enforce: edges touching any nonzero beta > 1, both-zero edges = 1
        beta_sig = np.array([(abs(beta_map.get(u, 0.0)) > 0) or (abs(beta_map.get(v, 0.0)) > 0)
                             for u, v in edges])
        w[~beta_sig] = 1.0
        w[ beta_sig] = np.maximum(w[beta_sig], 1.0 + 1e-3)
    if wmax is not None:
        w = np.minimum(w, float(wmax))
    return w

# ---------- Enrichment tests ----------
def enrich_partition(g, clustering, causal_set, min_size=5, fdr_alpha=0.05):
    rows = []
    N = g.vcount()
    names = np.array(g.vs['name'])
    # Universe causal count (over graph vertices)
    K = np.isin(names, list(causal_set)).sum()

    for cid, members in enumerate(clustering):
        size = len(members)
        if size < min_size:
            continue
        genes = names[members]
        k = np.isin(genes, list(causal_set)).sum()
        # Right-tail hypergeometric: P[X >= k]
        p = hypergeom.sf(k - 1, N, K, size)
        rows.append(dict(community=cid, size=size, causal_in_comm=k, p_value=p))

    df = pd.DataFrame(rows)
    if df.empty:
        df['q_value'] = []
        df['is_significant'] = []
        return df

    df['q_value'] = bh_fdr(df['p_value'].values)
    df['is_significant'] = (df['q_value'] < fdr_alpha) & (df['size'] >= min_size)
    #df['is_significant'] =  (df['causal_in_comm'] >= 0 ) #  min_size)
    
    return df

def union_enrichment(comm_df, g, clustering, causal_set):
    """
    Aggregate across significant modules:
      - Take the union of vertices from all FDR-significant modules
      - Compute a single hypergeometric P for (#causal in union) vs (union size)
    Returns dict with counts and p-value. If no significant modules, returns zeros and p=1.
    """
    names = np.array(g.vs['name'])
    N = g.vcount()
    K = np.isin(names, list(causal_set)).sum()

    sig = comm_df[comm_df['is_significant']]
    if sig.empty:
        return dict(union_nodes=0, union_causal=0, universe_N=N, universe_K=K,
                    aggregated_p=1.0)

    sig_ids = sig['community'].tolist()
    union_vids = set()
    for cid in sig_ids:
        union_vids.update(clustering[cid])
    union_vids = np.array(sorted(list(union_vids)), dtype=int)

    union_nodes = int(union_vids.size)
    union_genes = names[union_vids]
    union_causal = int(np.isin(union_genes, list(causal_set)).sum())

    # One-shot enrichment of the union
    aggregated_p = float(hypergeom.sf(union_causal - 1, N, K, union_nodes))

    return dict(union_nodes=union_nodes,
                union_causal=union_causal,
                universe_N=N,
                universe_K=K,
                aggregated_p=aggregated_p)

def fishers_method(pvals):
    """
    Fisher's combined probability test for a list of independent p-values.
    Useful for diagnostics; not used for ranking unless you want to.
    """
    pvals = [p for p in pvals if p > 0 and np.isfinite(p)]
    if len(pvals) == 0:
        return 1.0
    stat = -2.0 * np.sum(np.log(pvals))
    df = 2 * len(pvals)
    return float(chi2.sf(stat, df))

def coverage_metrics(comm_df, g, clustering, causal_set):
    """
    Returns:
      num_sig: # significant modules
      covered_causal: # unique causal genes in significant modules
      covered_nodes:  # total nodes (causal + non-causal) in significant modules
    """
    sig = comm_df[comm_df['is_significant']]
    num_sig = int(sig.shape[0])
    if num_sig == 0:
        return 0, 0, 0

    names = np.array(g.vs['name'])
    sig_ids = set(sig['community'].tolist())

    keep_vids = []
    for cid in sig_ids:
        keep_vids.extend(list(clustering[cid]))
    keep_vids = np.unique(np.array(keep_vids, dtype=int))

    covered_nodes = int(keep_vids.size)
    genes = names[keep_vids]
    covered_causal = int(np.isin(genes, list(causal_set)).sum())

    return num_sig, covered_causal, covered_nodes

# ------------ Load data (FULL PPI) ------------
G_full, causal_genes, beta_map = load_ppi_and_mr(PPI_PATH, MR_PATH)
print(f"Full PPI: V={G_full.number_of_nodes()} E={G_full.number_of_edges()}")

# ------------ Grid spec ------------
schemes = ["A2"]  # keep as a list for easy extension
alphas  = np.round(np.arange(0.5, 10.0 + 1e-9, 0.5), 3).tolist()  # 0.5 interval
gammas  = np.round(np.logspace(np.log10(1e-3), np.log10(1e-2), 20), 6).tolist()  # 20 log points

min_comm_size = 5
fdr_alpha = 0.05
seed = 42

rows = []
kept_partitions = {}   # (scheme, alpha, gamma) -> (g, clustering, comm_df)

# ------------ Grid ------------
for scheme in schemes:
    for alpha in alphas:
        # Prebuild weights per alpha
        w = build_weights(G_full, beta_map,
                          scheme=scheme, alpha=alpha,
                          lambda_diff=0.5, beta_clip=3.0,
                          normalize_mean1=True, ensure_baseline=True)
        g, nodes = nx_to_igraph(G_full, weights=w)

        for gamma in gammas:
            try:
                part = la.find_partition(
                    g,
                    la.CPMVertexPartition,
                    weights='weight',
                    resolution_parameter=gamma #,
                    #seed=seed
                )
            except Exception as e:
                warnings.warn(f"Leiden failed: scheme={scheme}, alpha={alpha}, gamma={gamma}: {e}")
                continue

            # Per-module enrichment & FDR
            comm_df = enrich_partition(g, part, causal_genes,
                                       min_size=min_comm_size,
                                       fdr_alpha=fdr_alpha)

            # Coverage over significant modules
            num_sig, covered_causal, covered_nodes = coverage_metrics(comm_df, g, part, causal_genes)

            # Aggregated (union) enrichment p-value across all significant modules
            agg = union_enrichment(comm_df, g, part, causal_genes)
            aggregated_p = agg['aggregated_p']
            union_nodes = agg['union_nodes']
            union_causal = agg['union_causal']

            # Optional: Fisher's method across *module-level* (raw) p-values of significant modules
            # Not used for ranking unless you decide to.
            fish_p = 1.0
            if num_sig > 0:
                fish_p = fishers_method(comm_df.loc[comm_df['is_significant'], 'p_value'].tolist())

            # Useful ratios
            total_causal_in_graph = int(agg['universe_K'])
            coverage_rate = (covered_causal / total_causal_in_graph) if total_causal_in_graph > 0 else 0.0
            precision = (union_causal / union_nodes) if union_nodes > 0 else 0.0

            rows.append(dict(
                scheme=scheme,
                alpha=float(alpha),
                gamma=float(gamma),
                n_modules=len(part),
                num_sig=num_sig,
                covered_causal=covered_causal,    # total causal across significant modules (union)
                covered_nodes=covered_nodes,      # total nodes across significant modules (union)
                union_causal=union_causal,        # same as covered_causal, but explicit naming
                union_nodes=union_nodes,
                coverage_rate=coverage_rate,      # union_causal / total causal in graph
                precision=precision,              # union_causal / union_nodes
                aggregated_p=aggregated_p,        # **PRIMARY selection stat**
                fishers_p=fish_p                  # optional diagnostic
            ))

            # Keep artifacts if you want to inspect later
            kept_partitions[(scheme, float(alpha), float(gamma))] = (g, part, comm_df)

# ------------ Collate & choose best ------------
grid_df = pd.DataFrame(rows).sort_values(['scheme', 'alpha', 'gamma']).reset_index(drop=True)

# Rank primarily by aggregated_p (smaller is better), then by coverage_rate (higher), then precision (higher)
grid_df['rank_key'] = list(zip(grid_df['aggregated_p'],
                               -grid_df['coverage_rate'],
                               -grid_df['precision']))

grid_df = grid_df.sort_values('rank_key', kind='mergesort').reset_index(drop=True)
best_row = grid_df.iloc[0].to_dict()

# Save outputs
csv_path = os.path.join(OUTDIR, "grid_leiden_A2_union_enrichment.csv")
grid_df.drop(columns=['rank_key']).to_csv(csv_path, index=False)

best_path = os.path.join(OUTDIR, "grid_leiden_A2_best.json")
with open(best_path, "w") as f:
    json.dump(best_row, f, indent=2)

print(f"[OK] Wrote grid results: {csv_path}")
print(f"[OK] Wrote best param set: {best_path}")
print("Best (scheme, alpha, gamma):", (best_row['scheme'], best_row['alpha'], best_row['gamma']))
print("Best aggregated_p =", best_row['aggregated_p'], 
      "| coverage_rate =", best_row['coverage_rate'], 
      "| precision =", best_row['precision'])


Full PPI: V=16201 E=236930
[OK] Wrote grid results: /mnt/f/10_osteo_MR/results_network/grid_leiden_A2_union_enrichment.csv
[OK] Wrote best param set: /mnt/f/10_osteo_MR/results_network/grid_leiden_A2_best.json
Best (scheme, alpha, gamma): ('A2', 8.0, 0.00336)
Best aggregated_p = 2.8157112461777197e-48 | coverage_rate = 0.31851851851851853 | precision = 0.12451737451737452


**Step 4:** Community detection / resolution sweep on the PPI.

In [60]:
import os, json, warnings
import numpy as np
import pandas as pd
import networkx as nx
from scipy.stats import hypergeom

import igraph as ig
import leidenalg as la
import matplotlib.pyplot as plt

# ---------- Paths & best params ----------
PPI_PATH = "/mnt/f/10_osteo_MR/datasets/ppi/ppi_all_nonduplicate.tsv"
MR_PATH  = "/mnt/f/10_osteo_MR/results_mr_ptrs/PTRS/bulk_crossmodal_meta_beta.tsv"
OUTDIR   = "/mnt/f/10_osteo_MR/results_network/"
os.makedirs(OUTDIR, exist_ok=True)

BEST_SCHEME = "A2"
BEST_ALPHA  = 6.0
BEST_GAMMA  = 0.001832981
MIN_COMM_SIZE = 5
FDR_ALPHA = 0.05
SEED = 42

# === Reuse your helpers: bh_fdr, load_ppi_and_mr, nx_to_igraph, build_weights, enrich_partition from previous cell ===
# If not already in scope, uncomment/paste those helpers here.


# ---------- Enrichment tests ----------
def enrich_partition_min(g, clustering, causal_set, min_size=5, fdr_alpha=0.05):
    rows = []
    N = g.vcount()
    names = np.array(g.vs['name'])
    # Universe causal count (over graph vertices)
    K = np.isin(names, list(causal_set)).sum()

    for cid, members in enumerate(clustering):
        size = len(members)
        if size < min_size:
            continue
        genes = names[members]
        k = np.isin(genes, list(causal_set)).sum()
        # Right-tail hypergeometric: P[X >= k]
        p = hypergeom.sf(k - 1, N, K, size)
        rows.append(dict(community=cid, size=size, causal_in_comm=k, p_value=p))

    df = pd.DataFrame(rows)
    if df.empty:
        df['q_value'] = []
        df['is_significant'] = []
        return df

    df['q_value'] = bh_fdr(df['p_value'].values)
    #df['is_significant'] = (df['q_value'] < fdr_alpha) & (df['size'] >= min_size)
    df['is_significant'] =  (df['causal_in_comm'] > 0 ) #  min_size)
    
    return df


# ---------- Utility: collect union of sig modules ----------
def union_vids_from_sig(comm_df, clustering):
    sig_ids = comm_df.loc[comm_df['is_significant'], 'community'].tolist()
    union_vids = set()
    for cid in sig_ids:
        union_vids.update(clustering[cid])
    return np.array(sorted(list(union_vids)), dtype=int)

# ---------- Visual (same signature as you shared; kept intact) ----------
def visualize_sig_modules_network(G_nx, g, clustering, comm_df, causal_genes, beta_map,
                                  out_png, out_graphml,
                                  q_alpha=0.05, min_size=8, seed=42,
                                  max_nodes_to_draw=8000):
    import matplotlib.patches as mpatches
    sig_ids = set(comm_df.loc[(comm_df['is_significant'])    &
                              # (comm_df['q_value'] < q_alpha) &
                               (comm_df['size'] >= min_size), 'community'])
    
    #print( comm_df ) 
    
    if not sig_ids:
        print("No significant modules to plot.")
        return
    
    names = np.array(g.vs['name'])
    keep_vids = np.unique(np.concatenate([np.array(clustering[c], int) for c in sig_ids]))

    keep_genes = set(names[keep_vids])

    H = nx.Graph()
    gene2cid = {}
    for cid in sig_ids:
        for vid in clustering[cid]:
            gene2cid[names[vid]] = cid

    print( 'gene2cid', set( gene2cid.values()  )) 
    
    for gene in keep_genes:
        b = beta_map.get(gene, np.nan)
        H.add_node(gene,
                   community=int(gene2cid.get(gene, -1)),
                   is_causal=(gene in causal_genes),
                   beta=b,
                   beta_abs=(abs(b) if not np.isnan(b) else 0.0))
        #if int(gene2cid.get(gene, -1)) in [ 10, 11 ] :
        #    print( gene, int(gene2cid.get(gene, -1)) )
    
    for u,v in G_nx.edges():
        if u in keep_genes and v in keep_genes:
            H.add_edge(u,v)

    print( H.number_of_nodes ())
    
    if H.number_of_nodes() > max_nodes_to_draw:
        keep = []
        for cid in sig_ids:
            members = [n for n,d in H.nodes(data=True) if d.get("community")==cid]
            members = sorted(members, key=lambda x: H.degree(x), reverse=True)
            
            keep += members[:max(50, int(0.2*len(members)))]
        H = H.subgraph(keep).copy()

    pos = nx.spring_layout(H, seed=seed)
    cids = sorted({H.nodes[n]['community'] for n in H.nodes()})

    cmap = plt.get_cmap("Set3"); cid2idx = {c:i for i,c in enumerate(cids)}
    
    def size_for(n):
        babs = H.nodes[n].get('beta_abs', 0.0); return 60 + 240*min(babs,3.0)/3.0

    causal_nodes = [n for n in H.nodes() if H.nodes[n].get('is_causal', False)]
    noncausal_nodes = [n for n in H.nodes() if not H.nodes[n].get('is_causal', False)]

    fig, ax = plt.subplots(figsize=(12,10), dpi=150)
    nx.draw_networkx_edges(H, pos, ax=ax, width=0.3, alpha=0.35, edge_color="#999")
    nx.draw_networkx_nodes(
        H, pos, nodelist=noncausal_nodes,
        node_size=[size_for(n) for n in noncausal_nodes],
        node_color=[cmap(cid2idx[H.nodes[n]['community']]%20) for n in noncausal_nodes],
        linewidths=0.2, edgecolors="none", alpha=0.95, ax=ax
    )
    nx.draw_networkx_nodes(
        H, pos, nodelist=causal_nodes,
        node_size=[size_for(n) for n in causal_nodes],
        node_color=[cmap(cid2idx[H.nodes[n]['community']]%20) for n in causal_nodes],
        linewidths=1.6, edgecolors="black", alpha=0.98, ax=ax
    )
    handles = [mpatches.Patch(color=cmap(cid2idx[c]%20), label=f"Module {c}") for c in cids[:10]]
    if handles: ax.legend(handles=handles, title="Significant modules", loc="upper right", frameon=True)
    ax.set_title("Significant modules (Leiden/CPM)\nCausal genes bolded")
    ax.set_axis_off(); fig.tight_layout()
    fig.savefig(out_png, dpi=300); plt.close(fig)
    nx.write_graphml(H, out_graphml)
    print(f"Saved: {out_png}\nSaved: {out_graphml}")

# ---------- Main refine procedure ----------
def refine_on_best_params():
    # 1) Load full graph + MR
    G_full, causal_genes, beta_map = load_ppi_and_mr(PPI_PATH, MR_PATH)

    # 2) Best-params partition on FULL graph (to reconstruct the sig-union)
    w_full = build_weights(
        G_full, beta_map,
        scheme=BEST_SCHEME, alpha=BEST_ALPHA,
        lambda_diff=0.5, beta_clip=3.0,
        normalize_mean1=True, ensure_baseline=True
    )
    g_full, _ = nx_to_igraph(G_full, weights=w_full)

    base_part = la.find_partition(
        g_full,
        la.CPMVertexPartition,
        weights="weight",
        resolution_parameter=BEST_GAMMA,
        seed=SEED
    )
    base_comm = enrich_partition(
        g_full, base_part, causal_genes,
        min_size=MIN_COMM_SIZE, fdr_alpha=FDR_ALPHA
    )

    # 3) Build subnetwork = union of vertices from FDR-significant modules
    if base_comm.empty or not base_comm['is_significant'].any():
        print("[WARN] No significant modules with the best params; aborting refine.")
        return

    union_vids = union_vids_from_sig(base_comm, base_part)
    names_full = np.array(g_full.vs['name'])
    union_genes = set(names_full[union_vids])

    H_nx = G_full.subgraph(union_genes).copy()
    n_nodes = H_nx.number_of_nodes()
    n_edges = H_nx.number_of_edges()

    # Connectivity check
    components = list(nx.connected_components(H_nx))
    n_components = len(components)
    all_connected = (n_components == 1)
    largest_comp_size = max(len(c) for c in components) if n_components > 0 else 0

    print(f"[Subnetwork] |V|={n_nodes} |E|={n_edges} |components|={n_components} "
          f"|largest|={largest_comp_size} |all_connected|={all_connected}")

    # 4) Run Leiden on subnetwork with default optimiser on CPM partition
    #    (we use same gamma on the subgraph)
    w_sub = build_weights(
        H_nx, beta_map,
        scheme=BEST_SCHEME, alpha=BEST_ALPHA,
        lambda_diff=0.5, beta_clip=3.0,
        normalize_mean1=True, ensure_baseline=True
    )
    h_ig, h_nodes = nx_to_igraph(H_nx, weights=w_sub)

    # Initialise CPM partition and optimise with la.Optimiser()
    part = la.RBConfigurationVertexPartition(
        h_ig, weights="weight" # ,
        #resolution_parameter=BEST_GAMMA
    )

    # part = la.find_partition(h_ig, la.CPMVertexPartition, weights='weight'  )
    
    opt = la.Optimiser()
    # Optimise until convergence
    improved = True
    #while improved:
    improved = opt.optimise_partition( part  , n_iterations=100)


    # 5) Stats for the subnetwork partition
    comm_df = enrich_partition_min(
        h_ig, part, causal_genes,
        min_size=MIN_COMM_SIZE, fdr_alpha=FDR_ALPHA
    ).sort_values(['is_significant','q_value','size'], ascending=[False, True, False]).reset_index(drop=True)


    # Membership table
    membership = np.array(part.membership, dtype=int)
    genes = np.array(h_ig.vs['name'])
    mem_df = pd.DataFrame({'gene': genes, 'community': membership})

    sig_ids = set(comm_df.loc[ comm_df['is_significant'] , 'community'].tolist()) # 
    sig_membership_df = mem_df[mem_df['community'].isin(sig_ids)].copy()


    # 6) Modularity (igraph) and CPM "quality" (stability-like)
    # Modularity is a different objective, but useful as a descriptive metric.
    try:
        modularity = h_ig.modularity(membership, weights=h_ig.es['weight'])
    except Exception as e:
        modularity = np.nan
        warnings.warn(f"igraph modularity failed: {e}")

    # CPM quality (the optimised objective)
    cpm_quality = float(part.quality())

    # 7) Save outputs
    tag = f"subnet_A2_a{int(BEST_ALPHA)}_g{BEST_GAMMA:.9f}"
    subdir = os.path.join(OUTDIR, tag)
    os.makedirs(subdir, exist_ok=True)

    # Connectivity summary
    with open(os.path.join(subdir, "connectivity.json"), "w") as f:
        json.dump({
            "nodes": n_nodes,
            "edges": n_edges,
            "n_components": n_components,
            "largest_component_size": largest_comp_size,
            "all_connected": bool(all_connected)
        }, f, indent=2)

    comm_df_path = os.path.join(subdir, "subnet_partition_stats.csv")
    comm_df.to_csv(comm_df_path, index=False)

    sig_mem_path = os.path.join(subdir, "subnet_sig_membership.tsv")
    sig_membership_df.to_csv(sig_mem_path, sep="\t", index=False)

    all_mem_path = os.path.join(subdir, "subnet_all_membership.tsv")
    mem_df.to_csv(all_mem_path, sep="\t", index=False)

    with open(os.path.join(subdir, "partition_scores.json"), "w") as f:
        json.dump({
            "modularity": modularity,
            "cpm_quality": cpm_quality,
            "gamma": BEST_GAMMA,
            "alpha": BEST_ALPHA
        }, f, indent=2)

    # 8) Visualise significant modules on the subnetwork partition
    png_path = os.path.join(subdir, "subnet_sig_modules.png")
    graphml_path = os.path.join(subdir, "subnet_sig_modules.graphml")
    visualize_sig_modules_network(
        H_nx, h_ig, part, comm_df,
        causal_genes, beta_map,
        out_png=png_path, out_graphml=graphml_path,
        q_alpha=FDR_ALPHA, min_size=MIN_COMM_SIZE, seed=SEED
    )

    print(f"[OK] Community stats  : {comm_df_path}")
    print(f"[OK] Sig membership   : {sig_mem_path}")
    print(f"[OK] All membership   : {all_mem_path}")
    print(f"[OK] Scores           : {os.path.join(subdir, 'partition_scores.json')}")
    print(f"[OK] Viz (PNG/GraphML): {png_path} / {graphml_path}")

# ---- Run ----
refine_on_best_params()


[Subnetwork] |V|=2470 |E|=39835 |components|=1 |largest|=2470 |all_connected|=True
gene2cid {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10}
2455
Saved: /mnt/f/10_osteo_MR/results_network/subnet_A2_a6_g0.001832981/subnet_sig_modules.png
Saved: /mnt/f/10_osteo_MR/results_network/subnet_A2_a6_g0.001832981/subnet_sig_modules.graphml
[OK] Community stats  : /mnt/f/10_osteo_MR/results_network/subnet_A2_a6_g0.001832981/subnet_partition_stats.csv
[OK] Sig membership   : /mnt/f/10_osteo_MR/results_network/subnet_A2_a6_g0.001832981/subnet_sig_membership.tsv
[OK] All membership   : /mnt/f/10_osteo_MR/results_network/subnet_A2_a6_g0.001832981/subnet_all_membership.tsv
[OK] Scores           : /mnt/f/10_osteo_MR/results_network/subnet_A2_a6_g0.001832981/partition_scores.json
[OK] Viz (PNG/GraphML): /mnt/f/10_osteo_MR/results_network/subnet_A2_a6_g0.001832981/subnet_sig_modules.png / /mnt/f/10_osteo_MR/results_network/subnet_A2_a6_g0.001832981/subnet_sig_modules.graphml


**Step 5:** Load network or tabular inputs (PPI/GraphML/TSV).

In [10]:


import os, json, warnings
import numpy as np
import pandas as pd
import networkx as nx
from scipy.stats import hypergeom

import igraph as ig
import leidenalg as la
import matplotlib.pyplot as plt

# ------------ Paths & best params ------------
PPI_PATH = "/mnt/f/10_osteo_MR/datasets/ppi/ppi_all_nonduplicate.tsv"
MR_PATH  = "/mnt/f/10_osteo_MR/results_mr_ptrs/PTRS/bulk_crossmodal_meta_beta.tsv"
OUTDIR   = "/mnt/f/10_osteo_MR/results_network/"
os.makedirs(OUTDIR, exist_ok=True)

BEST_SCHEME = "A2"
BEST_ALPHA  = 6.0
BEST_GAMMA  = 0.001832981
SEED = 42

MIN_COMM_SIZE_FOR_UNION = 5   # used only to FIND the union subnetwork on full graph
FDR_ALPHA_FOR_UNION     = 0.05

# ------------ Helpers ------------
def bh_fdr(p):
    p = np.asarray(p, float)
    if p.size == 0: return p
    idx = np.argsort(p)
    ranks = np.empty_like(idx); ranks[idx] = np.arange(1, p.size+1)
    q = p * p.size / ranks
    q_sorted = np.minimum.accumulate(q[idx][::-1])[::-1]
    out = np.empty_like(q); out[idx] = np.minimum(q_sorted, 1.0)
    return out

def load_ppi_and_mr(ppi_path, mr_path):
    ppi = pd.read_csv(ppi_path)
    ppi.columns = ['inx','gene_u','gene_v']
    G = nx.Graph()
    G.add_edges_from(ppi[['gene_u','gene_v']].itertuples(index=False, name=None))

    mr = pd.read_csv(mr_path, sep='\t')[['gene','meta_beta_common']].dropna()
    mr['gene'] = mr['gene'].astype(str)
    mr_agg = mr.groupby('gene', as_index=False)['meta_beta_common'].mean()
    causal = set(mr_agg['gene'])
    beta_map = dict(zip(mr_agg['gene'], mr_agg['meta_beta_common']))
    return G, causal, beta_map

def nx_to_igraph(G, weights=None):
    nodes = list(G.nodes())
    idx = {n:i for i,n in enumerate(nodes)}
    edges = [(idx[u], idx[v]) for u,v in G.edges()]
    g = ig.Graph(n=len(nodes), edges=edges, directed=False)
    g.vs['name'] = nodes
    if weights is not None:
        g.es['weight'] = list(weights)
    return g, nodes

# β-aware weights
def build_weights(G, beta_map, scheme="A2", alpha=0.5, lambda_diff=0.5, beta_clip=3.0,
                  normalize_mean1=True, ensure_baseline=True, wmax=None):
    edges = list(G.edges())
    def b(n):
        x = beta_map.get(n, 0.0)
        if x is None or np.isnan(x): x = 0.0
        s = 1 if x > 0 else (-1 if x < 0 else 0)
        return s * min(abs(x), beta_clip)

    w = []
    for u,v in edges:
        bu, bv = b(u), b(v)
        mu, mv = abs(bu), abs(bv)
        same = (bu == 0 and bv == 0) or (bu >= 0 and bv >= 0) or (bu <= 0 and bv <= 0)
        if scheme == "A1":
            w_uv = 1.0 + alpha*(mu+mv)
        elif scheme == "A2":
            w_uv = 1.0 + alpha*(mu+mv)*(1.0 if same else lambda_diff)
        else:
            raise ValueError("scheme must be 'A1' or 'A2'")
        w.append(max(w_uv, 1e-9))
    w = np.array(w, float)
    if normalize_mean1 and w.mean() > 0:
        w = w / w.mean()
    if ensure_baseline:
        beta_sig = np.array([(abs(beta_map.get(u,0.0))>0) or (abs(beta_map.get(v,0.0))>0) for u,v in edges])
        w[~beta_sig] = 1.0
        w[ beta_sig] = np.maximum(w[beta_sig], 1.0 + 1e-3)
    if wmax is not None:
        w = np.minimum(w, float(wmax))
    return w

# ONLY used to FIND the union subnetwork on the FULL graph; not used afterward
def enrich_partition_for_union(g, clustering, causal_set, min_size=5, fdr_alpha=0.05):
    rows = []
    N = g.vcount()
    names = np.array(g.vs['name'])
    K = np.isin(names, list(causal_set)).sum()
    for cid, members in enumerate(clustering):
        size = len(members)
        if size < min_size: continue
        genes = names[members]
        k = np.isin(genes, list(causal_set)).sum()
        p = hypergeom.sf(k-1, N, K, size)
        rows.append(dict(community=cid, size=size, causal_in_comm=k, p_value=p))
    df = pd.DataFrame(rows)
    if df.empty:
        df['q_value']=[]; df['is_significant']=[]
        return df
    df['q_value'] = bh_fdr(df['p_value'].values)
    df['is_significant'] = (df['q_value'] < fdr_alpha) & (df['size'] >= min_size)
    return df

def union_vids_from_sig(comm_df, clustering):
    sig_ids = comm_df.loc[comm_df['is_significant'], 'community'].tolist()
    union_vids = set()
    for cid in sig_ids:
        union_vids.update(clustering[cid])
    return np.array(sorted(list(union_vids)), dtype=int)

# ---- Visualization: modules with ≥1 causal gene (no p/q used here) ----
def visualize_causal_modules_network(G_nx, g, membership, causal_genes, beta_map,
                                     out_png, out_graphml,
                                     select_cids, seed=42, max_nodes_to_draw=8000):
    import matplotlib.patches as mpatches

    names = np.array(g.vs['name'])
    cid_by_vid = np.array(membership, dtype=int)
    cid_by_gene = {names[i]: int(cid_by_vid[i]) for i in range(len(names))}

    # Keep only nodes in selected communities
    keep_vids = np.array([i for i,c in enumerate(cid_by_vid) if c in select_cids], dtype=int)
    keep_genes = set(names[keep_vids])

    H = nx.Graph()
    for gene in keep_genes:
        b = beta_map.get(gene, np.nan)
        H.add_node(gene,
                   community=cid_by_gene[gene],
                   is_causal=(gene in causal_genes),
                   beta=b,
                   beta_abs=(abs(b) if not np.isnan(b) else 0.0))
    for u,v in G_nx.edges():
        if u in keep_genes and v in keep_genes:
            H.add_edge(u,v)

    # Downsample for readability if huge
    if H.number_of_nodes() > max_nodes_to_draw:
        keep = []
        for cid in sorted(select_cids):
            members = [n for n,d in H.nodes(data=True) if d.get("community")==cid]
            members = sorted(members, key=lambda x: H.degree(x), reverse=True)
            keep += members[:max(50, int(0.2*len(members)))]
        H = H.subgraph(keep).copy()

    if H.number_of_nodes() == 0:
        print("Nothing to plot (no nodes in selected communities).")
        return

    pos = nx.spring_layout(H, seed=seed)
    cids = sorted({H.nodes[n]['community'] for n in H.nodes()})
    cmap = plt.get_cmap("Set3"); cid2idx = {c:i for i,c in enumerate(cids)}
    def size_for(n):
        babs = H.nodes[n].get('beta_abs', 0.0); return 60 + 240*min(babs,3.0)/3.0

    causal_nodes = [n for n in H.nodes() if H.nodes[n].get('is_causal', False)]
    noncausal_nodes = [n for n in H.nodes() if not H.nodes[n].get('is_causal', False)]

    fig, ax = plt.subplots(figsize=(9,7), dpi=300)
    nx.draw_networkx_edges(H, pos, ax=ax, width=0.3, alpha=0.35, edge_color="#999")
    nx.draw_networkx_nodes(
        H, pos, nodelist=noncausal_nodes,
        node_size=[size_for(n) for n in noncausal_nodes],
        node_color=[cmap(cid2idx[H.nodes[n]['community']]%20) for n in noncausal_nodes],
        linewidths=0.2, edgecolors="none", alpha=0.95, ax=ax
    )
    nx.draw_networkx_nodes(
        H, pos, nodelist=causal_nodes,
        node_size=[size_for(n) for n in causal_nodes],
        node_color=[cmap(cid2idx[H.nodes[n]['community']]%20) for n in causal_nodes],
        linewidths=1.6, edgecolors="black", alpha=0.98, ax=ax
    )
    handles = [mpatches.Patch(color=cmap(cid2idx[c]%20), label=f"Module {c}") for c in cids[:12]]
    if handles: ax.legend(handles=handles, title="Modules (≥5 causal)", loc="upper right", frameon=True)
    ax.set_title("Largest causal subnetwork — modules with ≥5 causal gene")
    ax.set_axis_off(); fig.tight_layout()
    fig.savefig(out_png, dpi=300); plt.close(fig)
    nx.write_graphml(H, out_graphml)
    print(f"Saved: {out_png}\nSaved: {out_graphml}")

# ------------ Main ------------
def optimise_on_largest_causal_subnetwork():
    # (A) Load full, detect union subnetwork ONCE (uses significance *only to find the union*)
    G_full, causal_genes, beta_map = load_ppi_and_mr(PPI_PATH, MR_PATH)

    w_full = build_weights(
        G_full, beta_map, scheme=BEST_SCHEME, alpha=BEST_ALPHA,
        lambda_diff=0.5, beta_clip=3.0, normalize_mean1=True, ensure_baseline=True
    )
    g_full, _ = nx_to_igraph(G_full, weights=w_full)

    # Partition on FULL graph to get the union subnetwork
    base_part = la.find_partition(
        g_full, la.CPMVertexPartition,
        weights="weight", resolution_parameter=BEST_GAMMA, seed=SEED
    )
    base_comm = enrich_partition_for_union(
        g_full, base_part, causal_genes,
        min_size=MIN_COMM_SIZE_FOR_UNION, fdr_alpha=FDR_ALPHA_FOR_UNION
    )

    if base_comm.empty or not base_comm['is_significant'].any():
        raise RuntimeError("No significant modules found on the full graph with the best params.")

    union_vids = union_vids_from_sig(base_comm, base_part)
    names_full = np.array(g_full.vs['name'])
    union_genes = set(names_full[union_vids])

    H_union = G_full.subgraph(union_genes).copy()

    # (B) Choose the LARGEST connected component as the working subnetwork
    components = list(nx.connected_components(H_union))
    largest_comp = max(components, key=len) if components else set()
    H = H_union.subgraph(largest_comp).copy()

    n_nodes, n_edges = H.number_of_nodes(), H.number_of_edges()
    print(f"[Largest causal subnetwork] |V|={n_nodes} |E|={n_edges} |components|=1 (by construction)")

    # (C) Leiden on this component with default Optimiser (no significance tests here)
    w_sub = build_weights(
        H, beta_map, scheme=BEST_SCHEME, alpha=BEST_ALPHA,
        lambda_diff=0.5, beta_clip=3.0, normalize_mean1=True, ensure_baseline=True
    )
    h_ig, h_nodes = nx_to_igraph(H, weights=w_sub)

    # CPM partition + optimiser loop
    # part = la.CPMVertexPartition(h_ig, weights="weight" ) # , resolution_parameter=BEST_GAMMA)
    part = la.RBConfigurationVertexPartition(h_ig, weights="weight" ) # , resolution_parameter=BEST_GAMMA)
    
    opt = la.Optimiser()
    improved = True
    #while improved:
    improved = opt.optimise_partition(part, n_iterations=1000)

    membership = np.array(part.membership, dtype=int)
    genes      = np.array(h_ig.vs['name'])

    # (D) Summaries WITHOUT p/q: just size and causal_in_comm
    df = pd.DataFrame({"gene": genes, "community": membership})
    comm_counts = df.groupby("community")["gene"].count().rename("size").reset_index()
    comm_causal = df.assign(is_causal=df["gene"].isin(causal_genes)) \
                    .groupby("community")["is_causal"].sum().astype(int).rename("causal_in_comm").reset_index()
    comm_df = comm_counts.merge(comm_causal, on="community")
    comm_df = comm_df.sort_values(["causal_in_comm","size"], ascending=[False, False]).reset_index(drop=True)

    # Communities with ≥1 causal gene
    causal_cids = set(comm_df.loc[comm_df["causal_in_comm"] > 0, "community"].tolist())

    # (E) Scores: CPM quality, RBConfiguration quality (stability-like), and modularity
    cpm_quality = float(part.quality())
    try:
        rb_part = la.RBConfigurationVertexPartition(h_ig, resolution_parameter=BEST_GAMMA, initial_membership=membership)
        rb_quality = float(rb_part.quality())
    except Exception as e:
        rb_quality = float("nan")
        warnings.warn(f"RBConfiguration quality failed: {e}")

    try:
        modularity = h_ig.modularity(membership, weights=h_ig.es['weight'])
    except Exception as e:
        modularity = float("nan")
        warnings.warn(f"igraph modularity failed: {e}")

    # (F) Save outputs
    tag = f"largest_causal_subnet_A2_a{int(BEST_ALPHA)}_g{BEST_GAMMA:.9f}"
    subdir = os.path.join(OUTDIR, tag)
    os.makedirs(subdir, exist_ok=True)

    # Connectivity and size
    with open(os.path.join(subdir, "connectivity.json"), "w") as f:
        json.dump({"nodes": n_nodes, "edges": n_edges, "n_components": 1}, f, indent=2)

    # Community stats (no p/q)
    comm_df_path = os.path.join(subdir, "causal_partition_stats.csv")
    comm_df.to_csv(comm_df_path, index=False)

    # Membership tables
    all_mem_path   = os.path.join(subdir, "all_membership.tsv")
    df.to_csv(all_mem_path, sep="\t", index=False)

    causal_mem_path = os.path.join(subdir, "causal_membership.tsv")
    df[df["community"].isin(causal_cids)].to_csv(causal_mem_path, sep="\t", index=False)

    # Scores
    with open(os.path.join(subdir, "partition_scores.json"), "w") as f:
        json.dump({
            "alpha": BEST_ALPHA,
            "gamma": BEST_GAMMA,
            "cpm_quality": cpm_quality,
            "rbconfig_quality": rb_quality,
            "modularity": modularity
        }, f, indent=2)

    # (G) Visualise only modules with ≥1 causal gene
    png_path = os.path.join(subdir, "causal_modules.pdf")
    graphml_path = os.path.join(subdir, "causal_modules.graphml")
    visualize_causal_modules_network(
        H, h_ig, membership, causal_genes, beta_map,
        out_png=png_path, out_graphml=graphml_path,
        select_cids=causal_cids, seed=SEED
    )

    print(f"[OK] Community stats  : {comm_df_path}")
    print(f"[OK] All membership   : {all_mem_path}")
    print(f"[OK] Causal membership: {causal_mem_path}")
    print(f"[OK] Scores           : {os.path.join(subdir, 'partition_scores.json')}")
    print(f"[OK] Viz (PNG/GraphML): {png_path} / {graphml_path}")

# ---- Run ----
optimise_on_largest_causal_subnetwork()


[Largest causal subnetwork] |V|=2470 |E|=39835 |components|=1 (by construction)
Saved: /mnt/f/10_osteo_MR/results_network/largest_causal_subnet_A2_a6_g0.001832981/causal_modules.pdf
Saved: /mnt/f/10_osteo_MR/results_network/largest_causal_subnet_A2_a6_g0.001832981/causal_modules.graphml
[OK] Community stats  : /mnt/f/10_osteo_MR/results_network/largest_causal_subnet_A2_a6_g0.001832981/causal_partition_stats.csv
[OK] All membership   : /mnt/f/10_osteo_MR/results_network/largest_causal_subnet_A2_a6_g0.001832981/all_membership.tsv
[OK] Causal membership: /mnt/f/10_osteo_MR/results_network/largest_causal_subnet_A2_a6_g0.001832981/causal_membership.tsv
[OK] Scores           : /mnt/f/10_osteo_MR/results_network/largest_causal_subnet_A2_a6_g0.001832981/partition_scores.json
[OK] Viz (PNG/GraphML): /mnt/f/10_osteo_MR/results_network/largest_causal_subnet_A2_a6_g0.001832981/causal_modules.pdf / /mnt/f/10_osteo_MR/results_network/largest_causal_subnet_A2_a6_g0.001832981/causal_modules.graphml
