# 62 Drug prioritization GNN

**Origin:** `6_2_drug_prioritization_GNN_.ipynb`  
**Annotated on:** 2025-10-13 06:45

**High-level objective:**  
- Train GNNs (GCN/GAT) over module graphs to impute/propagate MR betas; export predictions for DRS/TRS.

**Notes:**  
- These comments are language-agnostic and focus on intent, inputs, and outputs.  
- Adjust hard-coded paths if needed; prefer `/results_*` for derived artifacts.

---


**Step 1:** Community detection / resolution sweep on the PPI.

In [10]:
"""
Unit tests for gnn_from_graphml.py
Run with: pytest -q
"""

import os
import tempfile
import numpy as np
import networkx as nx
import torch

from gnn_from_graphml import (
    build_data_from_graphml,
    build_model,
    Trainer,
    set_seed,
)

def _make_tiny_graphml(tmpdir: str, fname: str = "toy.graphml") -> str:
    """
    Create a tiny undirected graph with:
      nodes: A(community=0, beta=0.4), B(0, beta=-0.2), C(1, beta=NaN), D(1, beta=NaN)
      edges: A-B (w=1.0), B-C (w=0.5), C-D (w=1.0)
    """
    G = nx.Graph()
    G.add_node("A", community=0, beta=0.4)
    G.add_node("B", community=0, beta=-0.2)
    G.add_node("C", community=1, beta=float("nan"))
    G.add_node("D", community=1, beta=float("nan"))

    G.add_edge("A", "B", weight=1.0)
    G.add_edge("B", "C", weight=0.5)
    G.add_edge("C", "D", weight=1.0)

    path = os.path.join(tmpdir, fname)
    nx.write_graphml(G, path)
    return path


def test_build_data_from_graphml_total():
    with tempfile.TemporaryDirectory() as td:
        gml = _make_tiny_graphml(td)
        data, idx2gene, gene2idx = build_data_from_graphml(
            gml,
            submodule=None,
            add_degree_features=True,
            use_label_as_feat=True,
            add_community_feature=True,
        )
        # nodes
        assert data.x.shape[0] == 4
        # features: degree(1) + prior(1) + is_label(1) + community one-hot (max id=1 => 2 cols) => 5 total
        assert data.x.shape[1] == 5
        # labels present on A,B only
        assert int(data.has_label.sum().item()) == 2
        # edge_index should exist and be even (undirected + self-loops)
        assert data.edge_index.shape[0] == 2
        assert data.edge_attr.shape[0] == data.edge_index.shape[1]


def test_build_data_from_graphml_submodule_filter():
    with tempfile.TemporaryDirectory() as td:
        gml = _make_tiny_graphml(td)
        # take submodule 1 (C,D)
        data, idx2gene, gene2idx = build_data_from_graphml(
            gml,
            submodule=1,
            add_degree_features=True,
            use_label_as_feat=True,
            add_community_feature=False,
        )
        assert data.x.shape[0] == 2  # only C and D
        # no labels in this submodule (C,D have NaN beta)
        assert int(data.has_label.sum().item()) == 0  # will raise if used in training without labels


def test_trainer_split_and_train_minimal_gcn():
    with tempfile.TemporaryDirectory() as td:
        gml = _make_tiny_graphml(td)
        data, idx2gene, gene2idx = build_data_from_graphml(
            gml,
            submodule=None,
            add_degree_features=True,
            use_label_as_feat=True,
            add_community_feature=False,
        )
        # Ensure at least 2 labeled nodes for split (A,B)
        assert int(data.has_label.sum().item()) >= 2

        # Build minimal GCN
        in_dim = data.x.shape[1]
        model = build_model("GCN", in_dim, {"hidden": 16, "layers": 2, "dropout": 0.1})
        trainer = Trainer(data, model, lr=5e-3, weight_decay=0.0, patience=20)

        set_seed(123)
        train_mask, val_mask = trainer.split_masks(val_size=0.5, seed=123)

        # Masks should be subset of has_label
        assert bool((train_mask & (~data.has_label)).any().item()) is False
        assert bool((val_mask   & (~data.has_label)).any().item()) is False

        # Train briefly; just ensure it runs and produces predictions
        trainer.train(epochs=50, verbose=False)
        pred = trainer.predict_all()
        assert isinstance(pred, torch.Tensor)
        assert pred.numel() == data.x.shape[0]


def test_trainer_with_gat():
    with tempfile.TemporaryDirectory() as td:
        gml = _make_tiny_graphml(td)
        data, idx2gene, gene2idx = build_data_from_graphml(
            gml,
            submodule=None,
            add_degree_features=True,
            use_label_as_feat=True,
            add_community_feature=False,
        )
        in_dim = data.x.shape[1]
        model = build_model("GAT", in_dim, {"hidden": 16, "layers": 2, "heads": 2, "dropout": 0.1, "attn_dropout": 0.0})
        trainer = Trainer(data, model, lr=1e-3, weight_decay=0.0, patience=10)
        trainer.split_masks(val_size=0.5, seed=7)
        trainer.train(epochs=30, verbose=False)
        df = trainer.export_predictions(os.path.join(td, "pred.tsv"))
        assert df.shape[0] == data.x.shape[0]
        assert {"gene", "pred_beta", "is_labeled", "true_beta"}.issubset(df.columns)


**Step 2:** Graph construction or GNN modeling (GCN/GAT).

In [11]:


from gnn_from_graphml import train_from_graphml

trainer, preds = train_from_graphml(
    graphml_path="/mnt/f/10_osteo_MR/results_network/largest_causal_subnet_A2_a6_g0.001832981/causal_modules.graphml",
    outdir="/mnt/f/10_osteo_MR/gnn_from_graphml_runs",
    submodule=None,        # or an int 0..12
    model_type="GAT",
    hidden=128,
    layers=3,
    heads=4,
    dropout=0.3,
    attn_dropout=0.1,
    lr=1e-3,
    weight_decay=5e-4,
    epochs=600,
    patience=80,
    val_size=0.2,
    add_degree_features=True,
    use_label_as_feat=True,
    add_community_feature=False,
)


Epoch    1 | train 0.309268 | val 0.061370
Epoch   25 | train 0.141663 | val 0.052051
Epoch   50 | train 0.086361 | val 0.033164
Epoch   75 | train 0.052538 | val 0.026549
Epoch  100 | train 0.051623 | val 0.024798
Epoch  125 | train 0.038426 | val 0.025476
Epoch  150 | train 0.051707 | val 0.024737
Epoch  175 | train 0.078164 | val 0.024373
Epoch  200 | train 0.043249 | val 0.024210
Epoch  225 | train 0.035632 | val 0.024167
Epoch  250 | train 0.035349 | val 0.024500
Epoch  275 | train 0.033880 | val 0.023448
Epoch  300 | train 0.039681 | val 0.023087
Epoch  325 | train 0.028943 | val 0.022000
Epoch  350 | train 0.023006 | val 0.021761
Epoch  375 | train 0.025155 | val 0.027420
Epoch  400 | train 0.032799 | val 0.027870
Epoch  425 | train 0.024576 | val 0.025586
Epoch  450 | train 0.020037 | val 0.022637
Epoch  475 | train 0.019316 | val 0.024492
Epoch  500 | train 0.025133 | val 0.022586
Early stopping at epoch 520 (best val=0.020520)


**Step 3:** Load network or tabular inputs (PPI/GraphML/TSV).

In [12]:


"""
Run per-submodule training over communities 0..12.
- Trains total-model once (for fallback + comparison).
- Trains each submodule; if too few labels, falls back to total predictions.
- Writes one TSV per submodule and a summary CSV.

Usage (example):
  python run_submodules.py
"""

import os
import json
import numpy as np
import pandas as pd
import networkx as nx
from typing import Dict, Tuple, Optional, List

from gnn_from_graphml import (
    train_from_graphml,
    build_data_from_graphml,
    build_model,
    Trainer,
    set_seed,
)

GRAPHML = "/mnt/f/10_osteo_MR/results_network/largest_causal_subnet_A2_a6_g0.001832981/causal_modules.graphml"
OUTDIR  = "/mnt/f/10_osteo_MR/gnn_from_graphml_runs_submodules"

# ---- Default model/training params (edit if you like) ----
MODEL_TYPE     = "GAT"     # "GAT" or "GCN"
HIDDEN         = 128
LAYERS         = 3
HEADS          = 4
DROPOUT        = 0.3
ATTN_DROPOUT   = 0.1       # GAT only
LR             = 1e-3
WEIGHT_DECAY   = 5e-4
EPOCHS         = 600
PATIENCE       = 80
VAL_SIZE       = 0.2
SEED           = 42

ADD_DEGREE_FEATURES   = True
USE_LABEL_AS_FEAT     = True
ADD_COMMUNITY_FEATURE = False   # per submodule this is not needed

def list_submodules(graphml_path: str, community_attr: str = "community") -> List[int]:
    G = nx.read_graphml(graphml_path)
    c = sorted({int(d.get(community_attr, -1)) for _, d in G.nodes(data=True)})
    # Filter out negatives if present
    return [x for x in c if x >= 0]

def _compute_val_rmse(trainer: Trainer) -> Optional[float]:
    """Compute RMSE on the trainer's val mask (None if not available)."""
    if trainer.val_mask is None:
        return None
    mask = trainer.val_mask.detach().cpu().numpy().astype(bool)
    if mask.sum() == 0:
        return None
    with torch.no_grad():
        pred = trainer.model(trainer.x, trainer.edge_index,
                             edge_weight=trainer.edge_weight,
                             edge_attr=trainer.edge_attr).squeeze(1).detach().cpu().numpy()
    y_true = trainer.y.squeeze(1).detach().cpu().numpy()[mask]
    y_pred = pred[mask]
    return float(np.sqrt(np.mean((y_true - y_pred) ** 2)))

def _load_total_predictions(preds_path: str) -> Dict[str, float]:
    df = pd.read_csv(preds_path, sep="\t")
    return dict(zip(df["gene"].astype(str), df["pred_beta"].astype(float)))

def _write_submodule_from_total(
    graphml_path: str,
    submodule: int,
    total_pred_map: Dict[str, float],
    outdir: str,
    filename_tag: str
) -> pd.DataFrame:
    """Fallback writer: slice the submodule and fill pred_beta from total preds."""
    data, idx2gene, gene2idx = build_data_from_graphml(
        graphml_path,
        submodule=submodule,
        add_degree_features=ADD_DEGREE_FEATURES,
        use_label_as_feat=USE_LABEL_AS_FEAT,
        add_community_feature=False,
    )
    genes = [idx2gene[i] for i in range(len(idx2gene))]
    pred_beta = [total_pred_map.get(g, np.nan) for g in genes]
    true_map = {}
    y_np = data.y.squeeze(1).cpu().numpy()
    has = data.has_label.cpu().numpy()
    for i, ok in enumerate(has):
        if ok:
            true_map[idx2gene[i]] = float(y_np[i])
    df = pd.DataFrame({
        "gene": genes,
        "pred_beta": pred_beta,
        "is_labeled": has.astype(bool),
        "true_beta": [true_map.get(g, np.nan) for g in genes],
    })
    out_path = os.path.join(outdir, f"{filename_tag}_sub{submodule}_beta_predictions.tsv")
    df.to_csv(out_path, sep="\t", index=False)
    return df

def train_all_submodules(
    graphml_path: str,
    outdir: str,
    submodules: Optional[List[int]] = None,
    model_type: str = MODEL_TYPE,
):
    import torch  # local import to avoid hard dep in text context
    os.makedirs(outdir, exist_ok=True)
    set_seed(SEED)

    # 1) Train total model once (for fallback use)
    print("== Training TOTAL causal module ==")
    total_trainer, total_df = train_from_graphml(
        graphml_path=graphml_path,
        outdir=outdir,
        submodule=None,
        model_type=model_type,
        seed=SEED,
        hidden=HIDDEN,
        layers=LAYERS,
        heads=HEADS,
        dropout=DROPOUT,
        attn_dropout=ATTN_DROPOUT,
        lr=LR,
        weight_decay=WEIGHT_DECAY,
        epochs=EPOCHS,
        patience=PATIENCE,
        val_size=VAL_SIZE,
        add_degree_features=ADD_DEGREE_FEATURES,
        use_label_as_feat=USE_LABEL_AS_FEAT,
        add_community_feature=False,
    )
    total_pred_path = os.path.join(outdir, f"{model_type}_suball_beta_predictions.tsv")
    total_df.to_csv(total_pred_path, sep="\t", index=False)
    total_pred_map = _load_total_predictions(total_pred_path)

    # 2) Determine submodule ids
    if submodules is None:
        submodules = list_submodules(graphml_path)
    print(f"Submodules to run: {submodules}")

    summary_rows = []
    filename_tag = f"{model_type}"

    # 3) Train each submodule (or fallback)
    for m in submodules:
        print(f"\n== Submodule {m} ==")
        try:
            # Try to build data and ensure at least 2 labeled nodes to split
            data, idx2gene, gene2idx = build_data_from_graphml(
                graphml_path,
                submodule=m,
                add_degree_features=ADD_DEGREE_FEATURES,
                use_label_as_feat=USE_LABEL_AS_FEAT,
                add_community_feature=False,
            )
            n_nodes = data.x.shape[0]
            n_edges = int(data.edge_index.shape[1] // 2)  # approx undirected no self-loops
            n_labeled = int(data.has_label.sum().item())
            print(f"Nodes={n_nodes}  Labeled={n_labeled}")

            if n_labeled < 2:
                print("  Not enough labeled nodes; using TOTAL-model fallback.")
                df = _write_submodule_from_total(
                    graphml_path, m, total_pred_map, outdir, filename_tag
                )
                trained = False
                val_rmse = np.nan
            else:
                in_dim = data.x.shape[1]
                hps = dict(hidden=HIDDEN, layers=LAYERS, dropout=DROPOUT, heads=HEADS, attn_dropout=ATTN_DROPOUT)
                model = build_model(model_type, in_dim, hps)
                trainer = Trainer(data, model, lr=LR, weight_decay=WEIGHT_DECAY, patience=PATIENCE)
                trainer.split_masks(val_size=VAL_SIZE, seed=SEED)
                trainer.train(epochs=EPOCHS, verbose=True)
                val_rmse = _compute_val_rmse(trainer)

                out_path = os.path.join(outdir, f"{filename_tag}_sub{m}_beta_predictions.tsv")
                df = trainer.export_predictions(out_path)
                trained = True

            summary_rows.append({
                "submodule": m,
                "n_nodes": n_nodes,
                "n_labeled": n_labeled,
                "trained": trained,
                "val_rmse": val_rmse,
                "out_path": os.path.join(outdir, f"{filename_tag}_sub{m}_beta_predictions.tsv"),
            })

        except Exception as e:
            print(f"  ERROR on submodule {m}: {e}")
            # As last resort, try to still write fallback
            try:
                df = _write_submodule_from_total(
                    graphml_path, m, total_pred_map, outdir, filename_tag
                )
                summary_rows.append({
                    "submodule": m,
                    "n_nodes": len(df),
                    "n_labeled": int(df["is_labeled"].sum()),
                    "trained": False,
                    "val_rmse": np.nan,
                    "out_path": os.path.join(outdir, f"{filename_tag}_sub{m}_beta_predictions.tsv"),
                })
            except Exception as e2:
                print(f"  Fallback failed for submodule {m}: {e2}")

    # 4) Save summary
    summary_df = pd.DataFrame(summary_rows).sort_values("submodule")
    summary_path = os.path.join(outdir, f"{MODEL_TYPE}_submodules_summary.csv")
    summary_df.to_csv(summary_path, index=False)
    print("\nWrote summary to:", summary_path)
    return summary_df

if __name__ == "__main__":
    summary = train_all_submodules(GRAPHML, OUTDIR, submodules=list(range(0, 13)), model_type=MODEL_TYPE)
    print(summary)


== Training TOTAL causal module ==
Epoch    1 | train 0.309268 | val 0.061352
Epoch   25 | train 0.132905 | val 0.050230
Epoch   50 | train 0.071826 | val 0.036703
Epoch   75 | train 0.051588 | val 0.028961
Epoch  100 | train 0.049184 | val 0.024228
Epoch  125 | train 0.036236 | val 0.025325
Epoch  150 | train 0.049714 | val 0.023745
Epoch  175 | train 0.084519 | val 0.023178
Epoch  200 | train 0.047027 | val 0.023093
Epoch  225 | train 0.037955 | val 0.023044
Epoch  250 | train 0.036443 | val 0.023222
Epoch  275 | train 0.034772 | val 0.022156
Epoch  300 | train 0.035943 | val 0.022067
Epoch  325 | train 0.027847 | val 0.021777
Epoch  350 | train 0.022879 | val 0.021960
Epoch  375 | train 0.024185 | val 0.026455
Early stopping at epoch 399 (best val=0.020910)
Submodules to run: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]

== Submodule 0 ==
Nodes=450  Labeled=41
Epoch    1 | train 0.461596 | val 0.103306
Epoch   25 | train 0.174893 | val 0.054353
Epoch   50 | train 0.146266 | val 0.0740

**Step 4:** Community detection / resolution sweep on the PPI.

In [13]:


# === Load expression matrix ===
print("Loading expression matrix...")

with open("/mnt/f/10_osteo_MR/datasets/gse123568/gene_expr_and_labels.pkl", "rb") as f:
    data = pickle.load(f)

gene_expr_norm = data["gene_expr_norm"]
group_label = data["group_labels"]




Loading expression matrix...


**Step 6:** Load network or tabular inputs (PPI/GraphML/TSV).

In [34]:

import os
import pandas as pd
import numpy as np
import networkx as nx
from collections import deque, defaultdict

# -----------------------
# USER CONFIG
# -----------------------
GRAPHML = "/mnt/f/10_osteo_MR/results_network/largest_causal_subnet_A2_a6_g0.001832981/causal_modules.graphml"
SUB_BETA_DIR = "/mnt/f/10_osteo_MR/gnn_from_graphml_runs_submodules"  # where GAT_sub<m>_beta_predictions.tsv live
MODEL_TAG = "GAT"                         # change if you trained with another tag
DRUGMAP = "/mnt/f/10_osteo_MR/result_drug_target/drugmap_drug_gene_by_moa_status.csv"
OUT_BASE = "/mnt/f/10_osteo_MR/result_drug_target/module_sig_drug/ci_ptrs_partial_top10"
os.makedirs(OUT_BASE, exist_ok=True)

# Your top-10 drug IDs and their submodules (order per table provided)
TOP10_IDS = ['DMDYC4J','DM21WBH','DMCVJK9','DMF3DZX','DMN9YOB','DM38N2K','DM42PFT','DMBZMYT','DMBPNKT','DMH5RFU']

TOP10_SUBMODULES = [0, 0, 0, 2, 2, 5, 0, 0, 5, 0]  # from your table rows (in the same order)

TOP10_IDS = ['DMDYC4J',
'DMCVJK9',
'DM21WBH',
'DM42PFT',
'DMF3DZX',
'DMSFWT7',
'DMH5RFU',
'DMN9YOB',
'DMBZMYT',
'DM5DMCH']

TOP10_SUBMODULES = [6,
6,
6,
6,
0,
6,
6,
0,
6,
6]



DECAY_PER_HOP = 0.5       # weight = (DECAY_PER_HOP ** distance), distance in {0,1,2}
MAX_HOPS = 1          # include 0,1,2-hop

# IMPORTANT:
# Expect gene_expr_norm (DataFrame: genes x samples) and group_label (list[str]) to be already in memory.
# If you want to load from file, replace these with your loader.
# Example placeholders below (comment out if already in memory):
# gene_expr_norm = pd.read_parquet("/path/to/gene_expr_norm.parquet")  # genes x samples
# group_label = pd.read_csv("/path/to/group_labels.csv")["group"].tolist()

# -----------------------
# HELPERS
# -----------------------
def load_submodule_graph(graphml_path: str, submodule: int, community_attr: str = "community") -> nx.Graph:
    G = nx.read_graphml(graphml_path)
    keep = [n for n, d in G.nodes(data=True) if int(d.get(community_attr, -1)) == int(submodule)]
    return G.subgraph(keep).copy()

def load_submodule_betas(sub_beta_dir: str, model_tag: str, submodule: int) -> pd.Series:
    """
    Returns pd.Series beta_hat indexed by gene (str).
    Reads: {dir}/{MODEL_TAG}_sub{submodule}_beta_predictions.tsv
    """
    p = os.path.join(sub_beta_dir, f"{model_tag}_sub{submodule}_beta_predictions.tsv")
    if not os.path.exists(p):
        raise FileNotFoundError(f"Missing beta predictions for submodule {submodule}: {p}")
    df = pd.read_csv(p, sep="\t")
    if not {"gene","pred_beta"}.issubset(df.columns):
        raise ValueError(f"Missing columns in {p}")
    return pd.Series(df["pred_beta"].values, index=df["gene"].astype(str))

def load_drug_targets(drugmap_csv: str, drug_ids: list) -> dict:
    """
    Returns {DrugID: set(genes)} for requested IDs, filtered to Gene_clean not null.
    """
    dm = pd.read_csv(drugmap_csv)
    dm = dm.dropna(subset=["DrugID","Gene_clean"])
    dm["DrugID"] = dm["DrugID"].astype(str)
    dm["Gene_clean"] = dm["Gene_clean"].astype(str)
    targets = defaultdict(set)
    for did, g in zip(dm["DrugID"], dm["Gene_clean"]):
        if did in drug_ids:
            targets[did].add(g)
    return dict(targets)

def k_hop_weights(Gm: nx.Graph, seeds: set, max_hops: int = 2, decay: float = 0.5) -> dict:
    """
    Compute per-node distance from seed set within Gm (unweighted graph distance),
    then weight = decay ** dist for dist in [0..max_hops]. For nodes farther than max_hops, weight=0.
    If multiple seeds reach a node, we take the MAX weight (i.e., min distance).
    """
    if not seeds:
        return {}
    # Multi-source BFS
    dist = {n: np.inf for n in Gm.nodes()}
    q = deque()
    for s in seeds:
        if s in Gm:
            dist[s] = 0
            q.append(s)
    while q:
        u = q.popleft()
        if dist[u] >= max_hops:  # we don't need beyond max_hops
            continue
        for v in Gm.neighbors(u):
            if dist[v] > dist[u] + 1:
                dist[v] = dist[u] + 1
                if dist[v] <= max_hops:
                    q.append(v)
    weights = {}
    for n, d in dist.items():
        if d == np.inf or d > max_hops:
            continue
        weights[n] = (decay ** d)
    return weights

def compute_partial_ci_ptrs_for_drug(
    drug_id: str,
    submodule: int,
    G_full_path: str,
    sub_beta_dir: str,
    model_tag: str,
    drug_targets_map: dict,
    gene_expr_norm: pd.DataFrame,
    decay: float = 0.5,
    max_hops: int = 2,
) -> pd.Series:
    """
    Returns per-sample ci-PTRS (Series indexed by sample) for this drug/submodule.
    """
    # Subgraph and betas
    Gm = load_submodule_graph(G_full_path, submodule)
    beta_hat = load_submodule_betas(sub_beta_dir, model_tag, submodule)

    # Targets inside this submodule
    seeds = set(t for t in drug_targets_map.get(drug_id, set()) if t in Gm)
    if not seeds:
        # No nodes -> zero score for all samples (but keep output)
        return pd.Series(0.0, index=gene_expr_norm.columns, name=drug_id)

    # Neighborhood weights
    w = k_hop_weights(Gm, seeds, max_hops=max_hops, decay=decay)
    if not w:
        # Only isolated seeds (0-hop only if in Gm)
        w = {s: 1.0 for s in seeds if s in Gm}

    # Build aligned vectors over contributing genes
    genes = [g for g in w.keys() if g in beta_hat.index and g in gene_expr_norm.index]
    if not genes:
        return pd.Series(0.0, index=gene_expr_norm.columns, name=drug_id)

    b = beta_hat.loc[genes].values            # (n_genes,)
    a = np.array([w[g] for g in genes])       # (n_genes,)
    # Expression matrix slice: genes x samples
    X = gene_expr_norm.loc[genes].values      # (n_genes, n_samples)

    # ci-PTRS per sample = sum_i (b_i * a_i * x_{i,s})
    scores = (b[:, None] * a[:, None] * X).sum(axis=0)
    return pd.Series(scores, index=gene_expr_norm.columns, name=drug_id)

# -----------------------
# MAIN
# -----------------------
def main():
    # 1) Map drug -> submodule
    drug_to_m = {did: TOP10_SUBMODULES[i] for i, did in enumerate(TOP10_IDS)}

    # 2) Load DrugMap targets for these drugs
    drug_targets = load_drug_targets(DRUGMAP, TOP10_IDS)

    # 3) Compute per-drug per-sample ci-PTRS
    all_scores = []
    for did in TOP10_IDS:
        m = drug_to_m[did]
        s = compute_partial_ci_ptrs_for_drug(
            drug_id=did,
            submodule=m,
            G_full_path=GRAPHML,
            sub_beta_dir=SUB_BETA_DIR,
            model_tag=MODEL_TAG,
            drug_targets_map=drug_targets,
            gene_expr_norm=gene_expr_norm,    # must be defined in your session
            decay=DECAY_PER_HOP,
            max_hops=MAX_HOPS,
        )
        all_scores.append(s)

    scores_df = pd.concat(all_scores, axis=1)   # samples x drugs (columns=DrugID)
    scores_out = os.path.join(OUT_BASE, f"partial_ci_ptrs_top10_{MODEL_TAG}.tsv")
    scores_df.to_csv(scores_out, sep="\t")
    print("Wrote per-sample partial ci-PTRS:", scores_out)

    # 4) Add simple group summary if group_label is available
    try:
        groups = pd.Series(group_label, index=scores_df.index, name="group")
        summary = (scores_df.assign(group=groups)
                   .groupby("group").agg(["mean","std","count"]))
        summ_out = os.path.join(OUT_BASE, f"partial_ci_ptrs_top10_{MODEL_TAG}_by_group.tsv")
        summary.to_csv(summ_out, sep="\t")
        print("Wrote group summary:", summ_out)
    except Exception as e:
        print("Group summary skipped (need group_label aligned to columns):", e)

if __name__ == "__main__":
    main()




Wrote per-sample partial ci-PTRS: /mnt/f/10_osteo_MR/result_drug_target/module_sig_drug/ci_ptrs_partial_top10/partial_ci_ptrs_top10_GAT.tsv
Wrote group summary: /mnt/f/10_osteo_MR/result_drug_target/module_sig_drug/ci_ptrs_partial_top10/partial_ci_ptrs_top10_GAT_by_group.tsv


**Step 7:** Load network or tabular inputs (PPI/GraphML/TSV).

In [36]:
import matplotlib.pyplot as plt
import seaborn as sns

def plot_heatmaps(scores_df, outdir, tag="partial_ci_ptrs_top10"):

    print( scores_df.shape )
    vals = scores_df.values.flatten()
    vmin, vmax = np.percentile(vals, [5, 95])
    # sns.heatmap(scores_df.T, cmap="RdBu_r", center=0, vmin=vmin, vmax=vmax)
    vmin, vmx = -1, 1

    plt.figure(figsize=(10, 6))
    sns.heatmap(scores_df.T, cmap="RdBu_r", center=0, cbar_kws={'label': 'PTRS'}, 
                linewidth=1, vmin=vmin, vmax=vmax)
    plt.title("Partial ci-PTRS Heatmap (no clustering)")
    plt.ylabel("Drug ID")
    plt.xlabel("Sample")
    plt.tight_layout()
    path1 = os.path.join(outdir, f"{tag}_heatmap_no_clust.pdf")
    plt.savefig(path1, dpi=200)
    plt.close()
    print("Saved:", path1)

   
scores_out = os.path.join(OUT_BASE, f"partial_ci_ptrs_top10_{MODEL_TAG}.tsv")
scores_df = pd.read_csv(scores_out, sep="\t" , index_col = 0 )

plot_heatmaps( scores_df, OUT_BASE )



(40, 10)
Saved: /mnt/f/10_osteo_MR/result_drug_target/module_sig_drug/ci_ptrs_partial_top10/partial_ci_ptrs_top10_heatmap_no_clust.pdf
