# Load data

In [3]:
import anndata as ad
import pickle
from pathlib import Path


# --- paths ---
BASE = Path("/storage/users/job37yv/Projects/Franziska_faber")
DATA_PROCESSED_DIR = BASE / "data" / "processed_data"
DATA_PROCESSED_DIR.mkdir(parents=True, exist_ok=True)

H5AD_OUT   = DATA_PROCESSED_DIR / "host_bulk_fractional_counts_ensembl.h5ad"
DICT_OUT   = DATA_PROCESSED_DIR / "ensembl_mappings.pkl"


# load AnnData
adata_ens = ad.read_h5ad(H5AD_OUT)

# load dictionaries
with open(DICT_OUT, "rb") as f:
    maps = pickle.load(f)

ens2entrez = maps["ens2entrez"]
ens2uniprot = maps["ens2uniprot"]

print("TP53 Entrez ID:", ens2entrez.get("ENSG00000141510"))
print("TP53 UniProt:", ens2uniprot.get("ENSG00000141510"))

TP53 Entrez ID: None
TP53 UniProt: None


In [6]:
# filter out None/NaN/empty
ens2entrez  = {k: v for k, v in ens2entrez.items() if v is not None and str(v) != "nan"}
ens2uniprot = {k: v for k, v in ens2uniprot.items() if v not in [None, [], "nan"]}

# print first 5 entries
print(dict(list(ens2entrez.items())[:10]))
print(dict(list(ens2uniprot.items())[:10]))

{'ENSG00000278267.1': '102466751', 'ENSG00000284332.1': '100302278', 'ENSG00000237613.2': '645520', 'ENSG00000268020.3': '79504', 'ENSG00000240361.2': '403263', 'ENSG00000186092.6': '79501', 'ENSG00000233750.3': '100420257', 'ENSG00000222623.1': '124906683', 'ENSG00000273874.1': '102465909', 'ENSG00000228463.10': '728481'}
{'ENSG00000186092.6': 'Q8NH21', 'ENSG00000284733.1': 'Q6IEY1', 'ENSG00000284662.1': 'Q6IEY1', 'ENSG00000187634.12': 'Q96NU1', 'ENSG00000188976.11': 'Q9Y3T9', 'ENSG00000187961.14': 'Q6TDP4', 'ENSG00000187583.11': 'Q494U1', 'ENSG00000187642.9': 'Q5SV97', 'ENSG00000188290.11': 'Q9HCC6', 'ENSG00000187608.10': 'P05161'}


In [4]:
adata_ens

AnnData object with n_obs × n_vars = 12 × 67016
    obs: 'condition', 'tissue', 'replicate'
    var: 'ensembl_core', 'gene_symbol'
    layers: 'counts'

In [5]:
adata_ens.var_names

Index(['ENSG00000223972.5', 'ENSG00000227232.5', 'ENSG00000278267.1',
       'ENSG00000243485.5', 'ENSG00000284332.1', 'ENSG00000237613.2',
       'ENSG00000268020.3', 'ENSG00000240361.2', 'ENSG00000186092.6',
       'ENSG00000238009.6',
       ...
       'ENSG00000273739.1', 'ENSG00000276700.1', 'ENSG00000276312.1',
       'ENSG00000275757.1', 'ENSG00000278573.1', 'ENSG00000276017.1',
       'ENSG00000278817.1', 'ENSG00000277196.4', 'ENSG00000278625.1',
       'ENSG00000277374.1'],
      dtype='object', length=67016)

# Perform GSEA

## Reactome

In [8]:
# --- Reactome GSEA using gseapy (symbols) + NES heatmap ---
from pathlib import Path
import re, numpy as np, pandas as pd, matplotlib.pyplot as plt
from scipy.cluster.hierarchy import linkage, dendrogram, leaves_list
from scipy.spatial.distance import squareform

# Paths
BASE = Path("/storage/users/job37yv/Projects/Franziska_faber")
ANALYSIS_DIR = BASE / "analysis"
DE_DIR   = ANALYSIS_DIR / "DE"
GSEA_DIR = ANALYSIS_DIR / "GSEA" / "reactome"
FIG_DIR  = ANALYSIS_DIR / "figures"
GSEA_DIR.mkdir(parents=True, exist_ok=True)

# Contrasts present from your Python-only DE
contrasts = ["Bt_vs_Mock","Cd_vs_Mock","Co_vs_Mock","Bt_vs_Cd","Bt_vs_Co","Cd_vs_Co"]

def strip_ver(x: str) -> str:
    return re.sub(r"\.\d+$", "", x)

# ---------- Build Ensembl(core) -> SYMBOL map from adata_ens.var ----------
# prefer 'gene_symbol', else 'gene_symbols'
sym_col = "gene_symbol" if "gene_symbol" in adata_ens.var.columns else (
          "gene_symbols" if "gene_symbols" in adata_ens.var.columns else None)
if sym_col is None:
    raise ValueError("No 'gene_symbol' or 'gene_symbols' in adata_ens.var")

if "ensembl_core" in adata_ens.var.columns:
    core_ids = adata_ens.var["ensembl_core"].astype(str).values
else:
    core_ids = pd.Index(adata_ens.var_names.astype(str)).map(strip_ver).values

ens2sym = pd.Series(adata_ens.var[sym_col].astype(str).values, index=core_ids)

# ---------- Build preranked lists (t-stat) per contrast ----------
rank_files = {}
for c in contrasts:
    f = DE_DIR / f"python_voomlite_{c}.csv"
    if not f.exists():
        print(f"[WARN] missing DE file: {f} — skipping")
        continue
    df = pd.read_csv(f)
    df["gene"] = df["gene"].astype(str)
    df["gene_core"] = df["gene"].map(strip_ver)
    # map to symbol
    df["symbol"] = df["gene_core"].map(ens2sym)

    # ranking metric: t-stat preferred; fallback to signed -log10(p)
    if "t" in df.columns and np.isfinite(df["t"]).any():
        df["score"] = df["t"].astype(float)
    else:
        df["score"] = np.sign(df["logFC"].astype(float)) * (-np.log10(np.clip(df["pval"].astype(float), 1e-300, 1.0)))

    rnk = df.loc[df["symbol"].notna() & np.isfinite(df["score"]), ["symbol","score"]].copy()
    rnk = rnk.sort_values("score", key=lambda x: x.abs(), ascending=False).drop_duplicates("symbol")
    rnk = rnk.sort_values("score", ascending=False)

    out_rnk = GSEA_DIR / f"{c}.rnk"
    rnk.to_csv(out_rnk, sep="\t", header=False, index=False)
    rank_files[c] = out_rnk
    print(f"[OK] {c}: {len(rnk):,} ranked symbols -> {out_rnk}")

# ---------- Run gseapy prerank on Reactome ----------
import gseapy as gp

all_res = {}
for c, rnk_path in rank_files.items():
    outdir = GSEA_DIR / c
    outdir.mkdir(parents=True, exist_ok=True)
    prer = gp.prerank(
        rnk=str(rnk_path),
        gene_sets="Reactome_2022",   # Enrichr Reactome (symbols)
        processes=4,
        permutation_num=1000,
        min_size=10, max_size=1000,
        seed=42,
        outdir=str(outdir),
        format="png"                 # enrichment plots saved per pathway
    )
    res2d = prer.res2d.copy()
    res2d.rename(columns={"fdr":"FDR_q"}, inplace=True)
    res2d.to_csv(outdir / "gsea_reactome_results.csv", index=False)
    all_res[c] = res2d
    print(f"[OK] Reactome GSEA saved -> {outdir}")

# ---------- Build NES matrix & plot clustered heatmap (pathways x contrasts) ----------
# keep pathways significant in ≥1 contrast (FDR_q < 0.05)
keep = set()
for c, df in all_res.items():
    if not df.empty:
        keep.update(df.loc[df["FDR_q"] < 0.05, "Term"].tolist())

if not keep:
    print("[WARN] No Reactome pathways at FDR<0.05. Lower threshold or check inputs.")
else:
    pathways = sorted(keep)
    NES = pd.DataFrame(index=pathways, columns=contrasts, dtype=float)
    FDR = pd.DataFrame(index=pathways, columns=contrasts, dtype=float)
    for c in contrasts:
        if c in all_res and not all_res[c].empty:
            df = all_res[c].set_index("Term")
            NES.loc[pathways, c] = df.reindex(pathways)["NES"].values
            FDR.loc[pathways, c] = df.reindex(pathways)["FDR_q"].values

    # replace inf/nan
    NES = NES.astype(float).fillna(0.0).clip(-3.5, 3.5)  # clamp for color scale stability

    # cluster rows (pathways) and columns (contrasts) by correlation distance
    def corr_linkage(mat, axis=0, method="average"):
        # axis=0 -> rows; axis=1 -> cols
        X = mat if axis==0 else mat.T
        C = np.corrcoef(X)
        D = np.clip(1 - C, 0, 2)
        return linkage(squareform(D, checks=False), method=method)

    row_link = corr_linkage(NES.values, axis=0, method="average")
    col_link = corr_linkage(NES.values, axis=1, method="average")
    row_ord  = leaves_list(row_link)
    col_ord  = leaves_list(col_link)

    NES_ord = NES.values[row_ord][:, col_ord]
    row_lbl = [NES.index[i] for i in row_ord]
    col_lbl = [NES.columns[i] for i in col_ord]

    # plot
    plt.figure(figsize=(max(8, 0.45*len(col_lbl)+3), max(8, 0.22*len(row_lbl)+3)))
    # top dendrogram
    ax_top = plt.axes([0.28, 0.90, 0.62, 0.08])
    dendrogram(col_link, ax=ax_top, no_labels=True); ax_top.axis("off")
    # left dendrogram
    ax_left = plt.axes([0.07, 0.20, 0.20, 0.70])
    dendrogram(row_link, ax=ax_left, orientation="right", no_labels=True); ax_left.axis("off")
    # heatmap
    ax = plt.axes([0.28, 0.20, 0.62, 0.70])
    im = ax.imshow(NES_ord, aspect="auto", interpolation="nearest")
    ax.set_xticks(range(len(col_lbl))); ax.set_xticklabels(col_lbl, rotation=45, ha="right", fontsize=9)
    ax.set_yticks(range(len(row_lbl))); ax.set_yticklabels(row_lbl, fontsize=8)
    ax.set_title("Reactome GSEA (NES, FDR<0.05 pathways)")
    # colorbar
    cax = plt.axes([0.91, 0.20, 0.02, 0.70])
    cb = plt.colorbar(im, cax=cax); cb.set_label("NES")

    out_png = FIG_DIR / "reactome_gsea_NES_heatmap.png"
    plt.savefig(out_png, dpi=160, bbox_inches="tight"); plt.close()
    NES.to_csv(GSEA_DIR / "reactome_NES_matrix.csv")
    FDR.to_csv(GSEA_DIR / "reactome_FDR_matrix.csv")
    print(f"[OK] NES heatmap -> {out_png}")
    print(f"[OK] NES/FDR matrices -> {GSEA_DIR}")


[OK] Bt_vs_Mock: 60,966 ranked symbols -> /storage/users/job37yv/Projects/Franziska_faber/analysis/GSEA/reactome/Bt_vs_Mock.rnk
[OK] Cd_vs_Mock: 60,966 ranked symbols -> /storage/users/job37yv/Projects/Franziska_faber/analysis/GSEA/reactome/Cd_vs_Mock.rnk
[OK] Co_vs_Mock: 60,966 ranked symbols -> /storage/users/job37yv/Projects/Franziska_faber/analysis/GSEA/reactome/Co_vs_Mock.rnk
[OK] Bt_vs_Cd: 60,966 ranked symbols -> /storage/users/job37yv/Projects/Franziska_faber/analysis/GSEA/reactome/Bt_vs_Cd.rnk
[OK] Bt_vs_Co: 60,966 ranked symbols -> /storage/users/job37yv/Projects/Franziska_faber/analysis/GSEA/reactome/Bt_vs_Co.rnk
[OK] Cd_vs_Co: 60,966 ranked symbols -> /storage/users/job37yv/Projects/Franziska_faber/analysis/GSEA/reactome/Cd_vs_Co.rnk


  prer = gp.prerank(
The order of those genes will be arbitrary, which may produce unexpected results.
  prer = gp.prerank(
The order of those genes will be arbitrary, which may produce unexpected results.


[OK] Reactome GSEA saved -> /storage/users/job37yv/Projects/Franziska_faber/analysis/GSEA/reactome/Bt_vs_Mock


  prer = gp.prerank(
The order of those genes will be arbitrary, which may produce unexpected results.


[OK] Reactome GSEA saved -> /storage/users/job37yv/Projects/Franziska_faber/analysis/GSEA/reactome/Cd_vs_Mock


  prer = gp.prerank(
The order of those genes will be arbitrary, which may produce unexpected results.


[OK] Reactome GSEA saved -> /storage/users/job37yv/Projects/Franziska_faber/analysis/GSEA/reactome/Co_vs_Mock


  prer = gp.prerank(
The order of those genes will be arbitrary, which may produce unexpected results.


[OK] Reactome GSEA saved -> /storage/users/job37yv/Projects/Franziska_faber/analysis/GSEA/reactome/Bt_vs_Cd


  prer = gp.prerank(
The order of those genes will be arbitrary, which may produce unexpected results.


[OK] Reactome GSEA saved -> /storage/users/job37yv/Projects/Franziska_faber/analysis/GSEA/reactome/Bt_vs_Co
[OK] Reactome GSEA saved -> /storage/users/job37yv/Projects/Franziska_faber/analysis/GSEA/reactome/Cd_vs_Co


KeyError: 'FDR_q'

## KEGG

In [11]:
# --- OPTIONAL: KEGG (Entrez) or HALLMARK (symbols) GSEA setup ---
import re, numpy as np, pandas as pd, gseapy as gp
from pathlib import Path

BASE = Path("/storage/users/job37yv/Projects/Franziska_faber")
DE_DIR   = BASE / "analysis" / "DE"
GSEA_DIR = BASE / "analysis" / "GSEA"
(GSEA_DIR / "kegg").mkdir(parents=True, exist_ok=True)
(GSEA_DIR / "hallmark").mkdir(parents=True, exist_ok=True)

def strip_ver(x: str) -> str:
    return re.sub(r"\.\d+$", "", x)

# Clean your dictionaries (you already did this upstream)
# ens2entrez: keys are Ensembl WITH version (e.g., ENSG... .13) -> Entrez (string)
# ens2uniprot: (not used here but handy if a library requires UniProt)
def run_collection(collection_name: str, id_type: str, out_subdir: str):
    # id_type: "entrez" or "symbol"
    outbase = GSEA_DIR / out_subdir
    rank_files = {}
    for c in ["Bt_vs_Mock","Cd_vs_Mock","Co_vs_Mock","Bt_vs_Cd","Bt_vs_Co","Cd_vs_Co"]:
        f = DE_DIR / f"python_voomlite_{c}.csv"
        if not f.exists(): 
            print(f"[WARN] missing {f}"); 
            continue
        df = pd.read_csv(f)
        df["gene"] = df["gene"].astype(str)

        if id_type == "entrez":
            # map Ensembl(with version) -> Entrez directly via your dict
            # drop genes without Entrez
            df["id"] = df["gene"].map(ens2entrez).astype(str)
        elif id_type == "symbol":
            # use symbols from adata (as in the Reactome block)
            if "ensembl_core" in adata_ens.var.columns:
                core_ids = adata_ens.var["ensembl_core"].astype(str).values
            else:
                core_ids = pd.Index(adata_ens.var_names.astype(str)).map(strip_ver).values
            sym_col = "gene_symbol" if "gene_symbol" in adata_ens.var.columns else (
                      "gene_symbols" if "gene_symbols" in adata_ens.var.columns else None)
            assert sym_col is not None, "No symbols found in adata_ens.var"
            ens2sym = pd.Series(adata_ens.var[sym_col].astype(str).values, index=core_ids)
            df["id"] = df["gene"].map(strip_ver).map(ens2sym)
        else:
            raise ValueError("id_type must be 'entrez' or 'symbol'")

        # rank: t-stat preferred
        score = df["t"] if "t" in df.columns else np.sign(df["logFC"]) * (-np.log10(np.clip(df["pval"], 1e-300, 1.0)))
        rnk = pd.DataFrame({"id": df["id"], "score": pd.to_numeric(score, errors="coerce")})
        rnk = rnk.dropna().astype({"id": str, "score": float})
        rnk = rnk[~rnk["id"].isin(["", "nan", "None"])]
        rnk = rnk.sort_values("score", key=lambda x: x.abs(), ascending=False).drop_duplicates("id")
        rnk = rnk.sort_values("score", ascending=False)

        out_rnk = outbase / f"{c}.rnk"
        rnk.to_csv(out_rnk, sep="\t", header=False, index=False)
        rank_files[c] = out_rnk
        print(f"[OK] {collection_name} RNK for {c}: {len(rnk)} ids -> {out_rnk}")

    # Run gseapy.prerank
    for c, rnk in rank_files.items():
        outdir = outbase / c; outdir.mkdir(parents=True, exist_ok=True)
        prer = gp.prerank(
            rnk=str(rnk),
            gene_sets=collection_name,   # e.g., "KEGG_2021_Human" or "MSigDB_Hallmark_2020"
            processes=4,
            permutation_num=1000,
            min_size=10, max_size=1000,
            seed=42,
            outdir=str(outdir),
            format="png"
        )
        res = prer.res2d.copy()
        res.rename(columns={"fdr":"FDR_q"}, inplace=True)
        res.to_csv(outdir / f"gsea_{out_subdir}_results.csv", index=False)
        print(f"[OK] {collection_name} for {c} -> {outdir}")

# Example calls:
# KEGG via Entrez:
# run_collection(collection_name="KEGG_2021_Human", id_type="entrez",  out_subdir="kegg")

# HALLMARK via symbols:
# run_collection(collection_name="MSigDB_Hallmark_2020", id_type="symbol", out_subdir="hallmark")


In [12]:
# KEGG via Enrichr (expects GENE SYMBOLS)
run_collection(
    collection_name="KEGG_2021_Human",   # Enrichr KEGG library
    id_type="symbol",                    # <-- use symbols (not Entrez)
    out_subdir="kegg"                    # results to analysis/GSEA/kegg/
)


[OK] KEGG_2021_Human RNK for Bt_vs_Mock: 60966 ids -> /storage/users/job37yv/Projects/Franziska_faber/analysis/GSEA/kegg/Bt_vs_Mock.rnk
[OK] KEGG_2021_Human RNK for Cd_vs_Mock: 60966 ids -> /storage/users/job37yv/Projects/Franziska_faber/analysis/GSEA/kegg/Cd_vs_Mock.rnk
[OK] KEGG_2021_Human RNK for Co_vs_Mock: 60966 ids -> /storage/users/job37yv/Projects/Franziska_faber/analysis/GSEA/kegg/Co_vs_Mock.rnk
[OK] KEGG_2021_Human RNK for Bt_vs_Cd: 60966 ids -> /storage/users/job37yv/Projects/Franziska_faber/analysis/GSEA/kegg/Bt_vs_Cd.rnk
[OK] KEGG_2021_Human RNK for Bt_vs_Co: 60966 ids -> /storage/users/job37yv/Projects/Franziska_faber/analysis/GSEA/kegg/Bt_vs_Co.rnk


  prer = gp.prerank(
The order of those genes will be arbitrary, which may produce unexpected results.


[OK] KEGG_2021_Human RNK for Cd_vs_Co: 60966 ids -> /storage/users/job37yv/Projects/Franziska_faber/analysis/GSEA/kegg/Cd_vs_Co.rnk


  prer = gp.prerank(
The order of those genes will be arbitrary, which may produce unexpected results.


[OK] KEGG_2021_Human for Bt_vs_Mock -> /storage/users/job37yv/Projects/Franziska_faber/analysis/GSEA/kegg/Bt_vs_Mock


  prer = gp.prerank(
The order of those genes will be arbitrary, which may produce unexpected results.


[OK] KEGG_2021_Human for Cd_vs_Mock -> /storage/users/job37yv/Projects/Franziska_faber/analysis/GSEA/kegg/Cd_vs_Mock


  prer = gp.prerank(
The order of those genes will be arbitrary, which may produce unexpected results.


[OK] KEGG_2021_Human for Co_vs_Mock -> /storage/users/job37yv/Projects/Franziska_faber/analysis/GSEA/kegg/Co_vs_Mock


  prer = gp.prerank(
The order of those genes will be arbitrary, which may produce unexpected results.


[OK] KEGG_2021_Human for Bt_vs_Cd -> /storage/users/job37yv/Projects/Franziska_faber/analysis/GSEA/kegg/Bt_vs_Cd


  prer = gp.prerank(
The order of those genes will be arbitrary, which may produce unexpected results.


[OK] KEGG_2021_Human for Bt_vs_Co -> /storage/users/job37yv/Projects/Franziska_faber/analysis/GSEA/kegg/Bt_vs_Co
[OK] KEGG_2021_Human for Cd_vs_Co -> /storage/users/job37yv/Projects/Franziska_faber/analysis/GSEA/kegg/Cd_vs_Co


## Hallmark

In [14]:
# --- HALLMARK GSEA (symbols) + NES heatmap ---
from pathlib import Path
import re, numpy as np, pandas as pd, matplotlib.pyplot as plt
from scipy.cluster.hierarchy import linkage, dendrogram, leaves_list
from scipy.spatial.distance import squareform

# Paths
BASE = Path("/storage/users/job37yv/Projects/Franziska_faber")
ANALYSIS_DIR = BASE / "analysis"
DE_DIR   = ANALYSIS_DIR / "DE"
GSEA_DIR = ANALYSIS_DIR / "GSEA" / "hallmark"
FIG_DIR  = ANALYSIS_DIR / "figures"
GSEA_DIR.mkdir(parents=True, exist_ok=True)

# Your contrasts from the Python-only DE
contrasts = ["Bt_vs_Mock","Cd_vs_Mock","Co_vs_Mock","Bt_vs_Cd","Bt_vs_Co","Cd_vs_Co"]

def strip_ver(x: str) -> str:
    return re.sub(r"\.\d+$", "", x)

# ---------- Build Ensembl(core) -> SYMBOL map from adata_ens.var ----------
# Prefer 'gene_symbol', else 'gene_symbols'
sym_col = "gene_symbol" if "gene_symbol" in adata_ens.var.columns else (
          "gene_symbols" if "gene_symbols" in adata_ens.var.columns else None)
if sym_col is None:
    raise ValueError("No 'gene_symbol' or 'gene_symbols' in adata_ens.var")

# Use precomputed core IDs if present, else strip version from var_names
if "ensembl_core" in adata_ens.var.columns:
    core_ids = adata_ens.var["ensembl_core"].astype(str).values
else:
    core_ids = pd.Index(adata_ens.var_names.astype(str)).map(strip_ver).values

ens2sym = pd.Series(adata_ens.var[sym_col].astype(str).values, index=core_ids)

# ---------- Build preranked lists (t-stat) per contrast ----------
ranked = {}  # contrast -> DataFrame with columns ['symbol','score']
for c in contrasts:
    f = DE_DIR / f"python_voomlite_{c}.csv"
    if not f.exists():
        print(f"[WARN] missing DE file: {f} — skipping")
        continue

    df = pd.read_csv(f)
    df["gene"] = df["gene"].astype(str)
    df["gene_core"] = df["gene"].map(strip_ver)
    df["symbol"] = df["gene_core"].map(ens2sym)

    # rank by t-stat (preferred), else signed -log10(p)
    if "t" in df.columns and np.isfinite(df["t"]).any():
        score = pd.to_numeric(df["t"], errors="coerce")
    else:
        score = np.sign(pd.to_numeric(df["logFC"], errors="coerce")) * (
            -np.log10(np.clip(pd.to_numeric(df["pval"], errors="coerce"), 1e-300, 1.0))
        )

    rnk = pd.DataFrame({"symbol": df["symbol"], "score": score})
    # drop NA/blank, resolve duplicate symbols by max |score|
    rnk = rnk.dropna()
    rnk = rnk[rnk["symbol"].astype(str).str.len() > 0]
    rnk = rnk.sort_values("score", key=lambda x: x.abs(), ascending=False).drop_duplicates("symbol")
    rnk = rnk.sort_values("score", ascending=False)
    ranked[c] = rnk
    print(f"[OK] {c}: {len(rnk):,} ranked symbols")

# ---------- Run gseapy prerank on Hallmark ----------
import gseapy as gp

all_res = {}
for c, rnk in ranked.items():
    outdir = GSEA_DIR / c
    outdir.mkdir(parents=True, exist_ok=True)

    # Option A (default): Enrichr Hallmark collection (symbol-based)
    gene_sets = "MSigDB_Hallmark_2020"
    # Option B (licensed MSigDB GMT): e.g., "/path/to/h.all.v2023.2.Hs.symbols.gmt"
    # gene_sets = "/path/to/h.all.*.symbols.gmt"

    prer = gp.prerank(
        rnk=rnk,                    # pass DataFrame directly
        gene_sets=gene_sets,
        processes=4,
        permutation_num=1000,
        min_size=10, max_size=1000,
        seed=42,
        outdir=str(outdir),
        format="png"               # enrichment plots saved per term
    )
    res2d = prer.res2d.copy()
    res2d.rename(columns={"fdr":"FDR_q"}, inplace=True)
    res2d.to_csv(outdir / "gsea_hallmark_results.csv", index=False)
    all_res[c] = res2d
    print(f"[OK] Hallmark GSEA saved -> {outdir}")

# ---------- NES matrix & clustered heatmap (terms x contrasts) ----------
# keep terms significant in ≥1 contrast (FDR_q < 0.05)
keep = set()
for c, df in all_res.items():
    if not df.empty:
        keep.update(df.loc[df["FDR_q"] < 0.05, "Term"].tolist())

if not keep:
    print("[WARN] No Hallmark terms at FDR<0.05. Consider raising permutations or threshold.")
else:
    terms = sorted(keep)
    NES = pd.DataFrame(index=terms, columns=contrasts, dtype=float)
    FDR = pd.DataFrame(index=terms, columns=contrasts, dtype=float)
    for c in contrasts:
        if c in all_res and not all_res[c].empty:
            df = all_res[c].set_index("Term")
            NES.loc[terms, c] = df.reindex(terms)["NES"].values
            FDR.loc[terms, c] = df.reindex(terms)["FDR_q"].values

    NES = NES.astype(float).fillna(0.0).clip(-3.5, 3.5)

    # cluster rows & columns by correlation distance
    def corr_linkage(mat, axis=0, method="average"):
        X = mat if axis == 0 else mat.T
        C = np.corrcoef(X)
        D = np.clip(1 - C, 0, 2)
        return linkage(squareform(D, checks=False), method=method)

    row_link = corr_linkage(NES.values, axis=0, method="average")
    col_link = corr_linkage(NES.values, axis=1, method="average")
    row_ord  = leaves_list(row_link)
    col_ord  = leaves_list(col_link)

    NES_ord = NES.values[row_ord][:, col_ord]
    row_lbl = [NES.index[i] for i in row_ord]
    col_lbl = [NES.columns[i] for i in col_ord]

    # plot
    plt.figure(figsize=(max(8, 0.45*len(col_lbl)+3), max(8, 0.22*len(row_lbl)+3)))
    ax_top = plt.axes([0.28, 0.90, 0.62, 0.08]); dendrogram(col_link, ax=ax_top, no_labels=True); ax_top.axis("off")
    ax_left = plt.axes([0.07, 0.20, 0.20, 0.70]); dendrogram(row_link, ax=ax_left, orientation="right", no_labels=True); ax_left.axis("off")
    ax = plt.axes([0.28, 0.20, 0.62, 0.70])
    im = ax.imshow(NES_ord, aspect="auto", interpolation="nearest")
    ax.set_xticks(range(len(col_lbl))); ax.set_xticklabels(col_lbl, rotation=45, ha="right", fontsize=9)
    ax.set_yticks(range(len(row_lbl))); ax.set_yticklabels(row_lbl, fontsize=9)
    ax.set_title("HALLMARK GSEA (NES, FDR<0.05)")
    cax = plt.axes([0.91, 0.20, 0.02, 0.70]); cb = plt.colorbar(im, cax=cax); cb.set_label("NES")

    out_png = FIG_DIR / "hallmark_gsea_NES_heatmap.png"
    plt.savefig(out_png, dpi=160, bbox_inches="tight"); plt.close()
    NES.to_csv(GSEA_DIR / "hallmark_NES_matrix.csv")
    FDR.to_csv(GSEA_DIR / "hallmark_FDR_matrix.csv")
    print(f"[OK] NES heatmap -> {out_png}")
    print(f"[OK] NES/FDR matrices -> {GSEA_DIR}")


[OK] Bt_vs_Mock: 60,966 ranked symbols
[OK] Cd_vs_Mock: 60,966 ranked symbols
[OK] Co_vs_Mock: 60,966 ranked symbols
[OK] Bt_vs_Cd: 60,966 ranked symbols
[OK] Bt_vs_Co: 60,966 ranked symbols
[OK] Cd_vs_Co: 60,966 ranked symbols


  prer = gp.prerank(
The order of those genes will be arbitrary, which may produce unexpected results.
  prer = gp.prerank(
The order of those genes will be arbitrary, which may produce unexpected results.


[OK] Hallmark GSEA saved -> /storage/users/job37yv/Projects/Franziska_faber/analysis/GSEA/hallmark/Bt_vs_Mock


  prer = gp.prerank(
The order of those genes will be arbitrary, which may produce unexpected results.


[OK] Hallmark GSEA saved -> /storage/users/job37yv/Projects/Franziska_faber/analysis/GSEA/hallmark/Cd_vs_Mock


  prer = gp.prerank(
The order of those genes will be arbitrary, which may produce unexpected results.


[OK] Hallmark GSEA saved -> /storage/users/job37yv/Projects/Franziska_faber/analysis/GSEA/hallmark/Co_vs_Mock


  prer = gp.prerank(
The order of those genes will be arbitrary, which may produce unexpected results.


[OK] Hallmark GSEA saved -> /storage/users/job37yv/Projects/Franziska_faber/analysis/GSEA/hallmark/Bt_vs_Cd


  prer = gp.prerank(
The order of those genes will be arbitrary, which may produce unexpected results.


[OK] Hallmark GSEA saved -> /storage/users/job37yv/Projects/Franziska_faber/analysis/GSEA/hallmark/Bt_vs_Co
[OK] Hallmark GSEA saved -> /storage/users/job37yv/Projects/Franziska_faber/analysis/GSEA/hallmark/Cd_vs_Co


KeyError: 'FDR_q'

## Meta anaylsis 

In [17]:
from pathlib import Path
import pandas as pd
import numpy as np
import re

# ------------------ Config ------------------
BASE = Path("/storage/users/job37yv/Projects/Franziska_faber")
ANALYSIS_DIR = BASE / "analysis"
GSEA_BASE = ANALYSIS_DIR / "GSEA"
OUT_DIR = GSEA_BASE / "summary"
OUT_DIR.mkdir(parents=True, exist_ok=True)

# Your contrasts (adapt if you want to focus on vs-Mock only)
contrasts = ["Bt_vs_Mock","Cd_vs_Mock","Co_vs_Mock","Bt_vs_Cd","Bt_vs_Co","Cd_vs_Co"]
collections = {
    "reactome": {
        "dir": GSEA_BASE / "reactome",
        "filename": "gsea_reactome_results.csv"
    },
    "kegg": {
        "dir": GSEA_BASE / "kegg",
        "filename": "gsea_kegg_results.csv"
    },
    "hallmark": {
        "dir": GSEA_BASE / "hallmark",
        "filename": "gsea_hallmark_results.csv"
    },
}

FDR_THRESH = 0.9
TOP_N = 10   # top N per contrast x collection for the overview

# ------------------ Helpers ------------------
def _std_cols(df: pd.DataFrame) -> pd.DataFrame:
    """Robustly standardize GSEA result columns across gseapy versions/collections."""
    # Lowercase for matching
    lower = {c.lower(): c for c in df.columns}
    # Find likely columns
    term_col = lower.get("term", None)
    if term_col is None:
        # sometimes 'name' or 'pathway' occurs
        for cand in ["name", "pathway", "gs"]:
            if cand in lower:
                term_col = lower[cand]; break
    nes_col = lower.get("nes", None)
    if nes_col is None:
        for cand in ["normalized enrichment score", "normalized_enrichment_score"]:
            if cand in lower:
                nes_col = lower[cand]; break
    fdr_col = lower.get("fdr_q", None)
    if fdr_col is None:
        # gseapy often uses 'fdr'
        fdr_col = lower.get("fdr", None)
    p_col = lower.get("pval", None)
    if p_col is None:
        for cand in ["p-value", "pvalue", "p", "nominal p-value", "pval_nom"]:
            if cand in lower:
                p_col = lower[cand]; break
    lead_col = None
    # various names: 'lead_genes', 'leading_edge', 'ledge_genes', 'genes'
    for cand in ["lead_genes", "leading_edge", "ledge_genes", "genes", "hits"]:
        if cand in lower:
            lead_col = lower[cand]; break

    out = pd.DataFrame()
    if term_col is not None: out["Term"] = df[term_col].astype(str)
    if nes_col is not None:  out["NES"] = pd.to_numeric(df[nes_col], errors="coerce")
    if fdr_col is not None:  out["FDR_q"] = pd.to_numeric(df[fdr_col], errors="coerce")
    if p_col is not None:    out["pval"] = pd.to_numeric(df[p_col], errors="coerce")
    if lead_col is not None: out["lead_genes"] = df[lead_col].astype(str)
    # Keep anything else you want by merging later if needed
    return out

def _direction_label(nes):
    return "Up (first group)" if nes > 0 else "Down (second group)"

def _short_contrast_side(contrast: str):
    """Return 'first' vs 'second' names (e.g., 'Bt','Mock')."""
    if "_vs_" in contrast:
        a, b = contrast.split("_vs_", 1)
    else:
        parts = contrast.split("vs")
        a, b = parts[0], parts[1]
    return a, b

# ------------------ Load & unify ------------------
all_rows = []
missing = []

for coll, info in collections.items():
    for c in contrasts:
        f = info["dir"] / c / info["filename"]
        if not f.exists():
            missing.append(str(f))
            continue
        raw = pd.read_csv(f)
        df = _std_cols(raw)
        if df.empty or "Term" not in df.columns or "NES" not in df.columns:
            continue
        df["collection"] = coll
        df["contrast"] = c
        a, b = _short_contrast_side(c)
        df["group_first"] = a
        df["group_second"] = b
        # fallback if FDR_q missing: try to derive from pval (not ideal)
        if "FDR_q" not in df.columns or df["FDR_q"].isna().all():
            if "pval" in df.columns:
                from statsmodels.stats.multitest import multipletests
                df["FDR_q"] = multipletests(df["pval"].fillna(1.0), method="fdr_bh")[1]
            else:
                df["FDR_q"] = np.nan
        all_rows.append(df)

if not all_rows:
    raise RuntimeError("No GSEA result files found. Missing: \n" + "\n".join(missing[:10]))

gsea_all = pd.concat(all_rows, ignore_index=True)
# Clean terms (optional: shorten long Reactome names)
gsea_all["Term_clean"] = (gsea_all["Term"]
                          .str.replace("^REACTOME_", "", regex=True)
                          .str.replace("^HALLMARK_", "", regex=True)
                          .str.replace("_", " "))

# Significance filter
sig = gsea_all[(gsea_all["FDR_q"] < FDR_THRESH) & gsea_all["NES"].notna()].copy()
sig["Direction"] = np.where(sig["NES"] > 0, "Up (first group)", "Down (second group)")

# ------------------ Top N per contrast x collection ------------------
def top_by_contrast(df_sig, top_n=TOP_N):
    out = []
    for (c, coll), sub in df_sig.groupby(["contrast","collection"]):
        if sub.empty: continue
        # sort by FDR, then by |NES| desc, then by pval
        sub = sub.sort_values(["FDR_q", "NES", "pval"], ascending=[True, False, True])
        out.append(sub.head(top_n))
    return pd.concat(out, ignore_index=True) if out else pd.DataFrame()

topN = top_by_contrast(sig, TOP_N)
topN_cols = ["collection","contrast","group_first","group_second","Term_clean","NES","FDR_q","Direction","lead_genes"]
topN = topN[[c for c in topN_cols if c in topN.columns]]
topN.rename(columns={"Term_clean":"Term"}, inplace=True)
topN.to_csv(OUT_DIR / "GSEA_top_terms_by_contrast.csv", index=False)

# ------------------ Recurring pathways across contrasts ------------------
recurring = (sig.groupby(["collection","Term"])
                .agg(n_sig=("NES","size"),
                     mean_NES=("NES","mean"),
                     min_FDR=("FDR_q","min"))
                .reset_index()
                .sort_values(["n_sig","min_FDR"], ascending=[False, True]))
recurring.to_csv(OUT_DIR / "GSEA_recurring_terms.csv", index=False)

# ------------------ Tidy long (all significant) ------------------
sig_out = sig.copy()
sig_out = sig_out[["collection","contrast","group_first","group_second","Term_clean","NES","FDR_q","Direction","lead_genes"]]
sig_out.rename(columns={"Term_clean":"Term"}, inplace=True)
sig_out.to_csv(OUT_DIR / "GSEA_overview_long.csv", index=False)

# ------------------ Quick console summary (vs Mock only) ------------------
def brief_summary(df, contrast):
    sub = df[(df["contrast"]==contrast)]
    if sub.empty:
        print(f"\n{contrast}: no significant pathways at FDR<{FDR_THRESH}.")
        return
    a, b = _short_contrast_side(contrast)
    up = (sub[sub["NES"]>0]
          .sort_values(["FDR_q","NES"], ascending=[True,False])
          .groupby("collection")
          .head(5))
    down = (sub[sub["NES"]<0]
            .sort_values(["FDR_q","NES"], ascending=[True,True])  # most negative NES first
            .groupby("collection")
            .head(5))
    print(f"\n=== {contrast} (Up = enriched in {a}, Down = enriched in {b}) ===")
    if not up.empty:
        print("\nTop UP (by collection):")
        for _, r in up.iterrows():
            print(f"  [{r['collection']}] {r['Term']}  NES={r['NES']:.2f}  FDR={r['FDR_q']:.3g}")
    if not down.empty:
        print("\nTop DOWN (by collection):")
        for _, r in down.iterrows():
            print(f"  [{r['collection']}] {r['Term']}  NES={r['NES']:.2f}  FDR={r['FDR_q']:.3g}")

for c in ["Bt_vs_Mock","Cd_vs_Mock","Co_vs_Mock"]:
    brief_summary(sig_out, c)

print("\n[OK] Wrote:")
print(" -", OUT_DIR / "GSEA_top_terms_by_contrast.csv")
print(" -", OUT_DIR / "GSEA_recurring_terms.csv")
print(" -", OUT_DIR / "GSEA_overview_long.csv")



Bt_vs_Mock: no significant pathways at FDR<0.9.

Cd_vs_Mock: no significant pathways at FDR<0.9.

Co_vs_Mock: no significant pathways at FDR<0.9.

[OK] Wrote:
 - /storage/users/job37yv/Projects/Franziska_faber/analysis/GSEA/summary/GSEA_top_terms_by_contrast.csv
 - /storage/users/job37yv/Projects/Franziska_faber/analysis/GSEA/summary/GSEA_recurring_terms.csv
 - /storage/users/job37yv/Projects/Franziska_faber/analysis/GSEA/summary/GSEA_overview_long.csv


In [21]:
from pathlib import Path
from typing import Optional  # <-- fix for older Python
import pandas as pd
import numpy as np
import re
from statsmodels.stats.multitest import multipletests

# ------------------ Config ------------------
BASE = Path("/storage/users/job37yv/Projects/Franziska_faber")
ANALYSIS_DIR = BASE / "analysis"
GSEA_BASE = ANALYSIS_DIR / "GSEA"
OUT_DIR = GSEA_BASE / "summary"
OUT_DIR.mkdir(parents=True, exist_ok=True)

contrasts = ["Bt_vs_Mock","Cd_vs_Mock","Co_vs_Mock","Bt_vs_Cd","Bt_vs_Co","Cd_vs_Co"]
collections = {
    "reactome": {"dir": GSEA_BASE / "reactome", "fname": "gsea_reactome_results.csv"},
    "kegg":     {"dir": GSEA_BASE / "kegg",     "fname": "gsea_kegg_results.csv"},
    "hallmark": {"dir": GSEA_BASE / "hallmark", "fname": "gsea_hallmark_results.csv"},
}
ALT_REPORT = "gseapy.gene_set.prerank.report.csv"  # fallback

FDR_THRESH = 0.09
TOP_N      = 10

# ------------------ Robust standardization ------------------
def read_gsea_csv(path: Path) -> pd.DataFrame:
    # auto-detect delimiter (tab/comma)
    return pd.read_csv(path, sep=None, engine="python")

def _pick_col(cols_lower, patterns):
    for c_low, c_orig in cols_lower.items():
        for pat in patterns:
            if re.search(pat, c_low, flags=re.I):
                return c_orig
    return None

def _std_cols(df: pd.DataFrame) -> pd.DataFrame:
    cl = {c.lower(): c for c in df.columns}
    term = _pick_col(cl, [r"^term$", r"^name$", r"^pathway$", r"^gs$", r"description"])
    nes  = _pick_col(cl, [r"^nes$", r"normalized[ _-]?enrichment"])
    fdr  = _pick_col(cl, [r"^fdr(_?q)?$", r"^fdr[ _-]?q[ _-]?val(ue)?$", r"^q[ _-]?val"])
    pvl  = _pick_col(cl, [r"^pval$", r"^p[-_ ]?value$", r"^p$", r"^nom.*p.*val"])
    led  = _pick_col(cl, [r"^lead.*gene", r"leading.*edge", r"ledge.*gene", r"^genes$", r"^hits$"])

    out = pd.DataFrame()
    if term is not None: out["Term"] = df[term].astype(str)
    if nes  is not None: out["NES"]  = pd.to_numeric(df[nes], errors="coerce")
    if fdr  is not None: out["FDR_q"]= pd.to_numeric(df[fdr], errors="coerce")
    if pvl  is not None: out["pval"] = pd.to_numeric(df[pvl], errors="coerce")
    if led  is not None: out["lead_genes"] = df[led].astype(str)
    return out

def first_existing(*paths: Path) -> Optional[Path]:  # <-- fixed hint
    for p in paths:
        if p.exists():
            return p
    return None

# ------------------ Load all collections ------------------
all_rows, missing = [], []

for coll, info in collections.items():
    for c in contrasts:
        base = info["dir"] / c
        main = base / info["fname"]
        alt  = base / ALT_REPORT
        usef = first_existing(main, alt)
        if usef is None:
            missing.append(str(main))
            continue
        raw = read_gsea_csv(usef)
        df  = _std_cols(raw)
        if df.empty or "Term" not in df.columns or "NES" not in df.columns:
            continue
        # ensure FDR_q
        if "FDR_q" not in df.columns or df["FDR_q"].isna().all():
            if "pval" in df.columns:
                df["FDR_q"] = multipletests(df["pval"].fillna(1.0), method="fdr_bh")[1]
            else:
                df["FDR_q"] = np.nan

        df["collection"]   = coll
        df["contrast"]     = c
        a, b = c.split("_vs_", 1) if "_vs_" in c else c.split("vs", 1)
        df["group_first"]  = a
        df["group_second"] = b
        all_rows.append(df)

if not all_rows:
    raise RuntimeError("No GSEA result files found.\nExamples missing:\n" + "\n".join(missing[:10]))

gsea = pd.concat(all_rows, ignore_index=True)

# Normalize term names a bit
gsea["Term"] = (gsea["Term"]
                .str.replace(r"^REACTOME_", "", regex=True)
                .str.replace(r"^HALLMARK_", "", regex=True)
                .str.replace("_", " "))

# ------------------ Filter + simple summaries ------------------
sig = gsea[(gsea["NES"].notna()) & (gsea["FDR_q"].notna()) & (gsea["FDR_q"] < FDR_THRESH)].copy()
sig["Direction"] = np.where(sig["NES"] > 0, "Up (first group)", "Down (second group)")

def top_by_contrast(df_sig, top_n=TOP_N):
    out = []
    for (c, coll), sub in df_sig.groupby(["contrast","collection"]):
        sub = sub.sort_values(["FDR_q", "NES"], ascending=[True, False])
        out.append(sub.head(top_n))
    return pd.concat(out, ignore_index=True) if out else pd.DataFrame()

topN = top_by_contrast(sig, TOP_N)
topN_cols = ["collection","contrast","group_first","group_second","Term","NES","FDR_q","Direction","lead_genes"]
topN = topN[[c for c in topN_cols if c in topN.columns]]
topN.to_csv(OUT_DIR / "GSEA_top_terms_by_contrast.csv", index=False)

recurring = (sig.groupby(["collection","Term"])
               .agg(n_sig=("NES","size"),
                    mean_NES=("NES","mean"),
                    min_FDR=("FDR_q","min"))
               .reset_index()
               .sort_values(["n_sig","min_FDR"], ascending=[False, True]))
recurring.to_csv(OUT_DIR / "GSEA_recurring_terms.csv", index=False)

sig.to_csv(OUT_DIR / "GSEA_overview_long.csv", index=False)

# Quick console glance (vs Mock only)
for c in ["Bt_vs_Mock","Cd_vs_Mock","Co_vs_Mock"]:
    sub = sig[sig["contrast"]==c].sort_values(["FDR_q","NES"], ascending=[True,False])
    a, b = c.split("_vs_", 1)
    if sub.empty:
        print(f"{c}: no pathways at FDR<{FDR_THRESH}.")
    else:
        up = sub[sub["NES"]>0].groupby("collection").head(3)
        down = sub[sub["NES"]<0].sort_values("NES").groupby("collection").head(3)
        print(f"\n=== {c} (Up={a}, Down={b}) ===")
        if not up.empty:
            print(" Top UP:")
            for _, r in up.iterrows():
                print(f"  [{r['collection']}] {r['Term']}  NES={r['NES']:.2f}  FDR={r['FDR_q']:.3g}")
        if not down.empty:
            print(" Top DOWN:")
            for _, r in down.iterrows():
                print(f"  [{r['collection']}] {r['Term']}  NES={r['NES']:.2f}  FDR={r['FDR_q']:.3g}")

print("\n[OK] Wrote:")
print(" -", OUT_DIR / "GSEA_top_terms_by_contrast.csv")
print(" -", OUT_DIR / "GSEA_recurring_terms.csv")
print(" -", OUT_DIR / "GSEA_overview_long.csv")



=== Bt_vs_Mock (Up=Bt, Down=Mock) ===
 Top UP:
  [reactome] prerank  NES=2.31  FDR=0
  [reactome] prerank  NES=2.29  FDR=0
  [reactome] prerank  NES=2.28  FDR=0
  [kegg] prerank  NES=2.21  FDR=0
  [kegg] prerank  NES=2.12  FDR=0
  [kegg] prerank  NES=2.05  FDR=0
  [hallmark] prerank  NES=1.87  FDR=0.002
  [hallmark] prerank  NES=1.65  FDR=0.011
  [hallmark] prerank  NES=1.51  FDR=0.0266

=== Cd_vs_Mock (Up=Cd, Down=Mock) ===
 Top UP:
  [hallmark] prerank  NES=2.19  FDR=0
  [kegg] prerank  NES=2.08  FDR=0
  [kegg] prerank  NES=2.02  FDR=0
  [hallmark] prerank  NES=1.99  FDR=0
  [kegg] prerank  NES=1.98  FDR=0
  [hallmark] prerank  NES=1.93  FDR=0
  [reactome] prerank  NES=2.05  FDR=0.000548
  [reactome] prerank  NES=2.06  FDR=0.000822
  [reactome] prerank  NES=2.04  FDR=0.000822
 Top DOWN:
  [kegg] prerank  NES=-1.78  FDR=0.0736

=== Co_vs_Mock (Up=Co, Down=Mock) ===
 Top UP:
  [kegg] prerank  NES=1.74  FDR=0.0514
  [hallmark] prerank  NES=1.45  FDR=0.0669
  [kegg] prerank  NES=1.67  F

In [24]:
from pathlib import Path
from typing import Optional
import pandas as pd, numpy as np, re, json
from statsmodels.stats.multitest import multipletests

# ------------------ Config ------------------
BASE = Path("/storage/users/job37yv/Projects/Franziska_faber")
ANALYSIS_DIR = BASE / "analysis"
GSEA_BASE = ANALYSIS_DIR / "GSEA"
OUT_DIR = GSEA_BASE / "summary"
OUT_DIR.mkdir(parents=True, exist_ok=True)

contrasts = ["Bt_vs_Mock","Cd_vs_Mock","Co_vs_Mock","Bt_vs_Cd","Bt_vs_Co","Cd_vs_Co"]
collections = {
    "reactome": {"dir": GSEA_BASE / "reactome", "fname": "gsea_reactome_results.csv"},
    "kegg":     {"dir": GSEA_BASE / "kegg",     "fname": "gsea_kegg_results.csv"},
    "hallmark": {"dir": GSEA_BASE / "hallmark", "fname": "gsea_hallmark_results.csv"},
}
ALT_REPORT = "gseapy.gene_set.prerank.report.csv"  # fallback in each contrast folder

FDR_THRESH = 0.05
TOP_N = 10  # top per contrast x collection for overview

# ------------------ Helpers ------------------
def read_gsea_csv(path: Path) -> pd.DataFrame:
    return pd.read_csv(path, sep=None, engine="python")

def _pick_col(cols_lower, patterns):
    # prefer patterns (e.g., 'term') over whatever comes first in file (e.g., 'name')
    for pat in patterns:
        for c_low, c_orig in cols_lower.items():
            if re.search(pat, c_low, flags=re.I):
                return c_orig
    return None

def standardize_cols(df: pd.DataFrame) -> pd.DataFrame:
    cl = {c.lower(): c for c in df.columns}
    term = _pick_col(cl, [r"^term$", r"description", r"^pathway$", r"^gs$", r"^name$"])  # 'term' first!
    nes  = _pick_col(cl, [r"^nes$", r"normalized[ _-]?enrichment"])
    fdr  = _pick_col(cl, [r"^fdr(_?q)?$", r"fdr[ _-]?q[ _-]?val(ue)?", r"^q[ _-]?val"])
    pvl  = _pick_col(cl, [r"^pval$", r"^p[-_ ]?value$", r"^p$", r"^nom.*p.*val"])
    # leading edge sometimes appears under different headers
    led  = _pick_col(cl, [r"^lead.*gene", r"leading.*edge", r"^genes$", r"^hits$"])

    out = pd.DataFrame()
    if term is not None: out["Term"] = df[term].astype(str)
    if nes  is not None:  out["NES"]  = pd.to_numeric(df[nes], errors="coerce")
    if fdr  is not None:  out["FDR_q"]= pd.to_numeric(df[fdr], errors="coerce")
    if pvl  is not None:  out["pval"] = pd.to_numeric(df[pvl], errors="coerce")
    if led  is not None:  out["lead_genes"] = df[led].astype(str)
    return out

def first_existing(*paths: Path) -> Optional[Path]:
    for p in paths:
        if p.exists(): return p
    return None

# Parse GMT membership (term -> list of genes)
def load_gmt_membership(gmt_path: Path) -> dict[str, list[str]]:
    mem = {}
    if not gmt_path.exists(): return mem
    with gmt_path.open() as fh:
        for line in fh:
            parts = line.rstrip("\n").split("\t")
            if len(parts) >= 3:
                term = parts[0]
                genes = [g for g in parts[2:] if g]
                mem[term] = genes
    return mem

# Split a "lead_genes" cell into a list of gene symbols
_SPLIT = re.compile(r"[;,/\s]+")
def parse_leading_edge_cell(x: str):
    if not isinstance(x, str) or not x.strip(): return []
    toks = [t for t in _SPLIT.split(x) if t]
    # strip obvious non-gene tokens (e.g., 'tags=', 'list=' strings)
    toks = [t for t in toks if re.match(r"^[A-Za-z0-9_.\-]+$", t)]
    return toks

def groups_of(contrast: str):
    a, b = contrast.split("_vs_", 1) if "_vs_" in contrast else contrast.split("vs", 1)
    return a, b

# ------------------ Load all results + attach membership ------------------
rows = []
memberships = {}  # (collection, contrast) -> dict(term -> members)

for coll, info in collections.items():
    for c in contrasts:
        base = info["dir"] / c
        usef = first_existing(base / info["fname"], base / ALT_REPORT)
        if usef is None:
            continue
        raw = read_gsea_csv(usef)
        df  = standardize_cols(raw)
        if df.empty or "Term" not in df.columns or "NES" not in df.columns:
            continue

        # ensure FDR_q
        if "FDR_q" not in df.columns or df["FDR_q"].isna().all():
            if "pval" in df.columns:
                df["FDR_q"] = multipletests(df["pval"].fillna(1.0), method="fdr_bh")[1]
            else:
                df["FDR_q"] = np.nan

        # tidy names a little for display
        df["Term_display"] = (df["Term"]
            .str.replace(r"^REACTOME_", "", regex=True)
            .str.replace(r"^HALLMARK_", "", regex=True)
            .str.replace("_", " "))

        df["collection"] = coll
        df["contrast"]   = c
        a, b = groups_of(c)
        df["group_first"] = a
        df["group_second"]= b

        # attach full membership from the local GMT (if present)
        gmt = base / "gene_sets.gmt"
        mem = load_gmt_membership(gmt)
        memberships[(coll, c)] = mem

        # try to normalize membership term keys to match Term_display if needed
        # Build a fast lookup by a relaxed key
        def keyify(s: str) -> str:
            return re.sub(r"[\s_]+", " ", s.strip()).lower()
        mem_by_key = {keyify(k): v for k, v in mem.items()}

        df["members"] = df["Term_display"].map(lambda t: mem_by_key.get(keyify(t), []))
        df["set_size"] = df["members"].map(len)

        rows.append(df)

if not rows:
    raise RuntimeError("No GSEA result files parsed. Check file names or folders.")

gsea = pd.concat(rows, ignore_index=True)

# ------------------ Filter, top tables with names + genes ------------------
sig = gsea[(gsea["NES"].notna()) & (gsea["FDR_q"].notna()) & (gsea["FDR_q"] < FDR_THRESH)].copy()
sig["Direction"] = np.where(sig["NES"] > 0, "Up (first group)", "Down (second group)")

# if lead_genes missing, leave as empty list
sig["lead_genes_list"] = sig.get("lead_genes", pd.Series([""]*len(sig))).apply(parse_leading_edge_cell)
sig["lead_n"] = sig["lead_genes_list"].map(len)
sig["lead_preview"] = sig["lead_genes_list"].apply(lambda xs: ", ".join(xs[:15]) + (" ..." if len(xs)>15 else ""))

# Top N per contrast x collection (with Term names + lead genes)
def top_by_contrast(df_sig, top_n=TOP_N):
    out = []
    for (c, coll), sub in df_sig.groupby(["contrast","collection"]):
        sub = sub.sort_values(["FDR_q","NES"], ascending=[True, False]).head(top_n).copy()
        out.append(sub)
    return pd.concat(out, ignore_index=True) if out else pd.DataFrame()

topN = top_by_contrast(sig, TOP_N)
topN_out = topN[[
    "collection","contrast","group_first","group_second",
    "Term_display","NES","FDR_q","Direction","set_size","lead_n","lead_preview"
]].rename(columns={"Term_display":"Term"})
topN_out.to_csv(OUT_DIR / "GSEA_top_terms_with_genes.csv", index=False)

# Also dump full membership per top row (separate wide file to avoid huge CSV above)
mem_rows = []
for _, r in topN.iterrows():
    mem_rows.append({
        "collection": r["collection"],
        "contrast": r["contrast"],
        "Term": r["Term_display"],
        "set_size": r["set_size"],
        "members": ";".join(r["members"]),
        "leading_edge": ";".join(r["lead_genes_list"]) if isinstance(r["lead_genes_list"], list) else ""
    })
pd.DataFrame(mem_rows).to_csv(OUT_DIR / "GSEA_top_terms_membership.csv", index=False)

# ------------------ Build driver gene bags & overlaps (vs Mock only) ------------------
def driver_bag(df_sig: pd.DataFrame, contrast: str, direction: str) -> set[str]:
    a, b = groups_of(contrast)
    sub = df_sig[(df_sig["contrast"]==contrast) & (df_sig["NES"].notna())]
    sub = sub[sub["NES"] > 0] if direction=="up" else sub[sub["NES"] < 0]
    # union of leading-edge genes across all significant terms (any collection)
    genes = set()
    for xs in sub["lead_genes_list"]:
        if isinstance(xs, list):
            genes.update(xs)
    return genes

vs_mock = ["Bt_vs_Mock","Cd_vs_Mock","Co_vs_Mock"]

bags_up   = {c: driver_bag(sig, c, "up")   for c in vs_mock}
bags_down = {c: driver_bag(sig, c, "down") for c in vs_mock}

# Save bags
(OUT_DIR / "driver_bags").mkdir(exist_ok=True)
for c in vs_mock:
    pd.Series(sorted(bags_up[c])).to_csv(OUT_DIR / "driver_bags" / f"drivers_UP_{c}.txt", index=False, header=False)
    pd.Series(sorted(bags_down[c])).to_csv(OUT_DIR / "driver_bags" / f"drivers_DOWN_{c}.txt", index=False, header=False)

# Overlap/unique summary
def overlap_summary(bags: dict[str, set[str]], label: str):
    cs = list(bags.keys())
    A, B, C = cs
    a, b, c = bags[A], bags[B], bags[C]
    summary = {
        f"unique_{A}": sorted(a - b - c),
        f"unique_{B}": sorted(b - a - c),
        f"unique_{C}": sorted(c - a - b),
        f"shared_{A}_{B}": sorted((a & b) - c),
        f"shared_{A}_{C}": sorted((a & c) - b),
        f"shared_{B}_{C}": sorted((b & c) - a),
        f"shared_all": sorted(a & b & c),
    }
    # counts table
    counts = pd.DataFrame({
        "set": list(summary.keys()),
        "n_genes": [len(v) for v in summary.values()]
    })
    # write
    prefix = OUT_DIR / f"drivers_overlap_{label}"
    counts.to_csv(prefix + "_counts.csv", index=False)
    # detailed lists
    with open(prefix + "_lists.json", "w") as fh:
        json.dump(summary, fh, indent=2)
    # also explode to CSV
    long = []
    for k, genes in summary.items():
        for g in genes:
            long.append({"group": k, "gene": g})
    pd.DataFrame(long).to_csv(prefix + "_lists.csv", index=False)
    print(f"[OK] Overlap ({label}) -> {prefix+'_counts.csv'} & {prefix+'_lists.*'}")# --- replace your overlap_summary with this ---
    
def overlap_summary(bags: dict, label: str):
    cs = list(bags.keys())
    if len(cs) != 3:
        raise ValueError("Expected exactly 3 contrasts in bags (Bt_vs_Mock, Cd_vs_Mock, Co_vs_Mock).")
    A, B, C = cs
    a, b, c = bags[A], bags[B], bags[C]

    summary = {
        f"unique_{A}": sorted(a - b - c),
        f"unique_{B}": sorted(b - a - c),
        f"unique_{C}": sorted(c - a - b),
        f"shared_{A}_{B}": sorted((a & b) - c),
        f"shared_{A}_{C}": sorted((a & c) - b),
        f"shared_{B}_{C}": sorted((b & c) - a),
        f"shared_all": sorted(a & b & c),
    }

    # counts table
    counts = pd.DataFrame({
        "set": list(summary.keys()),
        "n_genes": [len(v) for v in summary.values()]
    })

    # build file paths correctly (Path, not string concatenation)
    counts_path = OUT_DIR / f"drivers_overlap_{label}_counts.csv"
    lists_json_path = OUT_DIR / f"drivers_overlap_{label}_lists.json"
    lists_csv_path  = OUT_DIR / f"drivers_overlap_{label}_lists.csv"

    counts.to_csv(counts_path, index=False)

    # detailed lists
    with open(lists_json_path, "w") as fh:
        json.dump(summary, fh, indent=2)

    long = []
    for k, genes in summary.items():
        for g in genes:
            long.append({"group": k, "gene": g})
    pd.DataFrame(long).to_csv(lists_csv_path, index=False)

    print(f"[OK] Overlap ({label}) -> {counts_path} & {lists_csv_path}")


overlap_summary(bags_up,   label="UP")
overlap_summary(bags_down, label="DOWN")

print("\n[OK] Wrote:")
print(" -", OUT_DIR / "GSEA_top_terms_with_genes.csv")
print(" -", OUT_DIR / "GSEA_top_terms_membership.csv")
print(" -", OUT_DIR / "driver_bags/ (UP/DOWN txt lists)")
print(" -", OUT_DIR / "drivers_overlap_UP_counts.csv and _lists.(csv|json)")
print(" -", OUT_DIR / "drivers_overlap_DOWN_counts.csv and _lists.(csv|json)")


[OK] Overlap (UP) -> /storage/users/job37yv/Projects/Franziska_faber/analysis/GSEA/summary/drivers_overlap_UP_counts.csv & /storage/users/job37yv/Projects/Franziska_faber/analysis/GSEA/summary/drivers_overlap_UP_lists.csv
[OK] Overlap (DOWN) -> /storage/users/job37yv/Projects/Franziska_faber/analysis/GSEA/summary/drivers_overlap_DOWN_counts.csv & /storage/users/job37yv/Projects/Franziska_faber/analysis/GSEA/summary/drivers_overlap_DOWN_lists.csv

[OK] Wrote:
 - /storage/users/job37yv/Projects/Franziska_faber/analysis/GSEA/summary/GSEA_top_terms_with_genes.csv
 - /storage/users/job37yv/Projects/Franziska_faber/analysis/GSEA/summary/GSEA_top_terms_membership.csv
 - /storage/users/job37yv/Projects/Franziska_faber/analysis/GSEA/summary/driver_bags/ (UP/DOWN txt lists)
 - /storage/users/job37yv/Projects/Franziska_faber/analysis/GSEA/summary/drivers_overlap_UP_counts.csv and _lists.(csv|json)
 - /storage/users/job37yv/Projects/Franziska_faber/analysis/GSEA/summary/drivers_overlap_DOWN_counts

In [25]:
# ==== GSEA x DEG overlap summary ====
from pathlib import Path
import pandas as pd, numpy as np, re

# ---------- config ----------
BASE = Path("/storage/users/job37yv/Projects/Franziska_faber")
ANALYSIS_DIR = BASE / "analysis"
DE_DIR  = ANALYSIS_DIR / "DE"
GSEA_DIR = ANALYSIS_DIR / "GSEA"
OUT_DIR  = GSEA_DIR / "summary"
TAB_DIR  = OUT_DIR / "overlap_tables"
OUT_DIR.mkdir(parents=True, exist_ok=True)
TAB_DIR.mkdir(parents=True, exist_ok=True)

contrasts = ["Bt_vs_Mock","Cd_vs_Mock","Co_vs_Mock","Bt_vs_Cd","Bt_vs_Co","Cd_vs_Co"]
collections = {
    "reactome": {"subdir": "reactome", "fname": "gsea_reactome_results.csv"},
    "kegg":     {"subdir": "kegg",     "fname": "gsea_kegg_results.csv"},
    "hallmark": {"subdir": "hallmark", "fname": "gsea_hallmark_results.csv"},
}
PATHWAY_FDR = 0.05           # pathway significance filter
TOP_PATHWAYS = 10            # number of pathways per contrast x collection to report
TOP_GENES = 10               # number of genes to show for each pathway
DE_FDR = 0.05                # DEG significance when picking driving genes

# ---------- helpers ----------
def read_any_csv(p: Path) -> pd.DataFrame:
    return pd.read_csv(p, sep=None, engine="python")

def pick_col(cols_lower, patterns):
    for pat in patterns:  # prefer specific patterns first
        for cl, co in cols_lower.items():
            if re.search(pat, cl, flags=re.I):
                return co
    return None

def std_gsea(df: pd.DataFrame) -> pd.DataFrame:
    cl = {c.lower(): c for c in df.columns}
    term = pick_col(cl, [r"^term$", r"description", r"^pathway$", r"^gs$", r"^name$"])
    nes  = pick_col(cl, [r"^nes$", r"normalized[ _-]?enrichment"])
    fdr  = pick_col(cl, [r"^fdr(_?q)?$", r"fdr[ _-]?q[ _-]?val(ue)?", r"^q[ _-]?val"])
    pvl  = pick_col(cl, [r"^pval$", r"^p[-_ ]?value$", r"^p$", r"^nom.*p.*val"])
    led  = pick_col(cl, [r"^lead.*gene", r"leading.*edge", r"ledge.*gene", r"^genes$", r"^hits$"])
    out = pd.DataFrame()
    if term is not None: out["Term"] = df[term].astype(str)
    if nes  is not None: out["NES"]  = pd.to_numeric(df[nes], errors="coerce")
    if fdr  is not None: out["FDR_q"]= pd.to_numeric(df[fdr], errors="coerce")
    if pvl  is not None: out["pval"] = pd.to_numeric(df[pvl], errors="coerce")
    if led  is not None: out["lead_genes"] = df[led].astype(str)
    return out

SPLIT = re.compile(r"[;,\|/\s]+")
def parse_leading_edge_cell(x: str):
    if not isinstance(x, str) or not x.strip():
        return []
    toks = [t for t in SPLIT.split(x) if t]
    return [t for t in toks if re.match(r"^[A-Za-z0-9_.\-]+$", t)]

def load_gmt(path: Path):
    mem = {}
    if not path.exists(): return mem
    with path.open() as fh:
        for line in fh:
            parts = line.rstrip("\n").split("\t")
            if len(parts) >= 3:
                mem[parts[0]] = [g for g in parts[2:] if g]
    return mem

def keyify(s: str) -> str:
    return re.sub(r"[\s_]+", " ", s.strip()).lower()

def members_from_gmt(mem: dict, term_raw: str, term_disp: str):
    # exact raw
    if term_raw in mem: return mem[term_raw]
    # rebuild typical prefixed forms from display
    disp_up = term_disp.replace(" ", "_").upper()
    for pref in ("REACTOME_", "HALLMARK_"):
        cand = pref + disp_up
        if cand in mem: return mem[cand]
    # relaxed key match
    bykey = {keyify(k): v for k, v in mem.items()}
    for cand in (term_raw, term_disp):
        v = bykey.get(keyify(cand))
        if v: return v
    return []

def clean_term(s: str) -> str:
    return (s.replace("REACTOME_", "")
             .replace("HALLMARK_", "")
             .replace("_", " "))

def safe_filename(s: str) -> str:
    s = re.sub(r"[^\w\-]+", "_", s)
    return re.sub(r"_+", "_", s).strip("_")

# ---------- main ----------
overview_rows = []
long_rows = []

for contrast in contrasts:
    # load DEG table
    de_path = DE_DIR / f"python_voomlite_{contrast}.csv"
    if not de_path.exists():
        print(f"[WARN] DEG missing for {contrast}: {de_path}")
        continue
    deg = pd.read_csv(de_path)
    # standardize columns
    assert {"gene_symbol","logFC","padj"}.issubset(deg.columns), f"DE file missing cols in {de_path}"
    deg["gene_symbol"] = deg["gene_symbol"].astype(str)
    deg["is_sig"] = (pd.to_numeric(deg["padj"], errors="coerce") < DE_FDR)

    for coll, info in collections.items():
        base = GSEA_DIR / info["subdir"] / contrast
        main = base / info["fname"]
        alt  = base / "gseapy.gene_set.prerank.report.csv"
        gsea_file = main if main.exists() else (alt if alt.exists() else None)
        if gsea_file is None:
            print(f"[WARN] GSEA missing for {coll} {contrast}")
            continue

        gdf_raw = read_any_csv(gsea_file)
        gdf = std_gsea(gdf_raw)
        if gdf.empty or "Term" not in gdf.columns or "NES" not in gdf.columns:
            print(f"[WARN] Unrecognized GSEA format for {coll} {contrast}: {gsea_file}")
            continue

        # ensure FDR_q
        if "FDR_q" not in gdf.columns or gdf["FDR_q"].isna().all():
            from statsmodels.stats.multitest import multipletests
            if "pval" in gdf.columns:
                gdf["FDR_q"] = multipletests(gdf["pval"].fillna(1.0), method="fdr_bh")[1]
            else:
                gdf["FDR_q"] = np.nan

        gdf["Term_display"] = gdf["Term"].astype(str).map(clean_term)
        gdf["lead_list"] = gdf.get("lead_genes", pd.Series([""]*len(gdf))).apply(parse_leading_edge_cell)

        # fallback to GMT membership if no leading-edge available
        gmt_path = base / "gene_sets.gmt"
        gmt_mem = load_gmt(gmt_path)

        # filter & pick top pathways
        sel = gdf[(pd.to_numeric(gdf["FDR_q"], errors="coerce") < PATHWAY_FDR) & gdf["NES"].notna()]
        sel = sel.sort_values(["FDR_q","NES"], ascending=[True, False]).head(TOP_PATHWAYS).copy()
        if sel.empty:
            continue

        for _, r in sel.iterrows():
            term_raw  = str(r["Term"])
            term_disp = str(r["Term_display"])
            nes  = float(r["NES"])
            qval = float(r["FDR_q"])

            lead = r["lead_list"]
            if not lead:  # fallback to entire set membership, then pick DE drivers
                lead = members_from_gmt(gmt_mem, term_raw, term_disp)

            # match to DEG by gene_symbol
            if not lead:
                matched = pd.DataFrame(columns=["gene_symbol","logFC","padj","is_sig"])
            else:
                m = pd.DataFrame({"gene_symbol": list(dict.fromkeys(lead))})
                matched = m.merge(deg[["gene_symbol","logFC","padj","is_sig"]],
                                  on="gene_symbol", how="left")

            # choose top drivers among leading-edge: prefer significant, then by padj then |logFC|
            matched["padj"] = pd.to_numeric(matched["padj"], errors="coerce")
            matched["logFC"] = pd.to_numeric(matched["logFC"], errors="coerce")
            matched["rank_sig"] = (~matched["is_sig"]).astype(int)  # 0 for sig first
            matched = matched.sort_values(
                ["rank_sig","padj","logFC"],
                ascending=[True, True, False],
                na_position="last"
            )
            top_genes = matched.head(TOP_GENES).copy()

            # summary row for overview
            gene_str = "; ".join([
                f"{g} (log2FC={lfc:+.2f}, q={q:.3g})"
                for g, lfc, q in zip(
                    top_genes["gene_symbol"].fillna("NA"),
                    top_genes["logFC"].fillna(np.nan),
                    top_genes["padj"].fillna(np.nan)
                )
            ])
            overview_rows.append({
                "contrast": contrast,
                "collection": coll,
                "pathway": term_disp,
                "NES": nes,
                "FDR_q": qval,
                "n_leading_edge": int(len(lead)),
                "n_leading_edge_sig": int(matched["is_sig"].fillna(False).sum()),
                "top_genes": gene_str
            })

            # detailed table per pathway
            det = matched.loc[:, ["gene_symbol","logFC","padj","is_sig"]].copy()
            det.insert(0, "collection", coll)
            det.insert(1, "contrast", contrast)
            det.insert(2, "pathway", term_disp)
            # write per-pathway file
            fname = f"{safe_filename(contrast)}__{safe_filename(coll)}__{safe_filename(term_disp)}.csv"
            det.to_csv(TAB_DIR / fname, index=False)
            # collect long (optional)
            for _, rr in det.iterrows():
                long_rows.append({
                    "contrast": contrast,
                    "collection": coll,
                    "pathway": term_disp,
                    "gene_symbol": rr["gene_symbol"],
                    "logFC": rr["logFC"],
                    "padj": rr["padj"],
                    "is_sig": rr["is_sig"]
                })

# write outputs
overview_df = pd.DataFrame(overview_rows).sort_values(
    ["contrast","collection","FDR_q","NES"],
    ascending=[True, True, True, False]
)
overview_path = OUT_DIR / "GSEA_DEG_overlap_top_pathways_and_genes.csv"
overview_df.to_csv(overview_path, index=False)

if long_rows:
    long_df = pd.DataFrame(long_rows)
    long_df.to_csv(OUT_DIR / "GSEA_DEG_overlap_long_per_pathway_gene.csv", index=False)

print("[OK] Wrote:")
print(" -", overview_path)
print(" -", TAB_DIR, " (per-pathway detailed gene tables)")
if long_rows:
    print(" -", OUT_DIR / "GSEA_DEG_overlap_long_per_pathway_gene.csv")


[OK] Wrote:
 - /storage/users/job37yv/Projects/Franziska_faber/analysis/GSEA/summary/GSEA_DEG_overlap_top_pathways_and_genes.csv
 - /storage/users/job37yv/Projects/Franziska_faber/analysis/GSEA/summary/overlap_tables  (per-pathway detailed gene tables)
 - /storage/users/job37yv/Projects/Franziska_faber/analysis/GSEA/summary/GSEA_DEG_overlap_long_per_pathway_gene.csv


In [None]:
# ==== Compact GSEA x DEG per-condition summaries ====
from pathlib import Path
from typing import Optional, Dict, List
import pandas as pd, numpy as np, re
from statsmodels.stats.multitest import multipletests

# ------------------ Config ------------------
BASE = Path("/storage/users/job37yv/Projects/Franziska_faber")
ANALYSIS_DIR = BASE / "analysis"
DE_DIR   = ANALYSIS_DIR / "DE"
GSEA_DIR = ANALYSIS_DIR / "GSEA"
OUT_DIR  = GSEA_DIR / "summary" / "compact"
OUT_DIR.mkdir(parents=True, exist_ok=True)

# contrasts to summarize
CONTRASTS = ["Bt_vs_Mock","Cd_vs_Mock","Co_vs_Mock","Bt_vs_Cd","Bt_vs_Co","Cd_vs_Co"]
COLLECTIONS = {
    "reactome": {"subdir": "reactome", "fname": "gsea_reactome_results.csv"},
    "kegg":     {"subdir": "kegg",     "fname": "gsea_kegg_results.csv"},
    "hallmark": {"subdir": "hallmark", "fname": "gsea_hallmark_results.csv"},
}
ALT_REPORT = "gseapy.gene_set.prerank.report.csv"   # fallback inside each contrast folder

# knobs
PATHWAY_FDR   = 0.09    # significance for pathways
TOP_PATHWAYS  = 5       # how many UP/DOWN pathways to show (overall, across collections)
TOP_GENES     = 10      # how many UP/DOWN driver genes to show
DE_FDR        = 0.05    # significance tag for DEG
WEIGHT_FLOOR  = 1e-300  # prevent log(0)

# ------------------ Helpers ------------------
def read_any_csv(p: Path) -> pd.DataFrame:
    return pd.read_csv(p, sep=None, engine="python")

def _pick_col(cols_lower, patterns):
    for pat in patterns:
        for cl, co in cols_lower.items():
            if re.search(pat, cl, flags=re.I):
                return co
    return None

def std_gsea(df: pd.DataFrame) -> pd.DataFrame:
    cl = {c.lower(): c for c in df.columns}
    term = _pick_col(cl, [r"^term$", r"description", r"^pathway$", r"^gs$", r"^name$"])
    nes  = _pick_col(cl, [r"^nes$", r"normalized[ _-]?enrichment"])
    fdr  = _pick_col(cl, [r"^fdr(_?q)?$", r"fdr[ _-]?q[ _-]?val(ue)?", r"^q[ _-]?val"])
    pvl  = _pick_col(cl, [r"^pval$", r"^p[-_ ]?value$", r"^p$", r"^nom.*p.*val"])
    led  = _pick_col(cl, [r"^lead.*gene", r"leading.*edge", r"ledge.*gene", r"^genes$", r"^hits$"])

    out = pd.DataFrame()
    if term is not None: out["Term"] = df[term].astype(str)
    if nes  is not None: out["NES"]  = pd.to_numeric(df[nes], errors="coerce")
    if fdr  is not None: out["FDR_q"]= pd.to_numeric(df[fdr], errors="coerce")
    if pvl  is not None: out["pval"] = pd.to_numeric(df[pvl], errors="coerce")
    if led  is not None: out["lead_genes"] = df[led].astype(str)
    return out

def load_gmt(path: Path) -> Dict[str, List[str]]:
    mem: Dict[str, List[str]] = {}
    if not path.exists(): return mem
    with path.open() as fh:
        for line in fh:
            parts = line.rstrip("\n").split("\t")
            if len(parts) >= 3:
                mem[parts[0]] = [g for g in parts[2:] if g]
    return mem

def keyify(s: str) -> str:
    return re.sub(r"[\s_]+", " ", s.strip()).lower()

def members_from_gmt(mem: Dict[str, List[str]], term_raw: str, term_disp: str):
    # exact raw
    if term_raw in mem: return mem[term_raw]
    # rebuild prefixed forms from display name
    disp_up = term_disp.replace(" ", "_").upper()
    for pref in ("REACTOME_", "HALLMARK_"):
        cand = pref + disp_up
        if cand in mem: return mem[cand]
    # relaxed key match
    bykey = {keyify(k): v for k, v in mem.items()}
    for cand in (term_raw, term_disp):
        v = bykey.get(keyify(cand))
        if v: return v
    return []

_SPLIT = re.compile(r"[;,\|/\s]+")
def parse_leading_edge_cell(x: str):
    if not isinstance(x, str) or not x.strip(): return []
    toks = [t for t in _SPLIT.split(x) if t]
    return [t for t in toks if re.match(r"^[A-Za-z0-9_.\-]+$", t)]

def clean_term(s: str) -> str:
    return (s.replace("REACTOME_", "")
             .replace("HALLMARK_", "")
             .replace("_", " "))

def safe_name(s: str) -> str:
    s = re.sub(r"[^\w\-]+", "_", s)
    return re.sub(r"_+", "_", s).strip("_")

def groups_of(contrast: str):
    return contrast.split("_vs_", 1) if "_vs_" in contrast else contrast.split("vs", 1)

# per-gene weight from a pathway row (for ranking driver genes)
def pathway_weight(nes: float, fdr: float) -> float:
    return abs(float(nes)) * (-np.log10(max(float(fdr), WEIGHT_FLOOR)))

# ------------------ Run per contrast ------------------
overview_all = []  # one line per pathway kept (for an all-conditions sheet)

for contrast in CONTRASTS:
    # 1) DEG
    de_path = DE_DIR / f"python_voomlite_{contrast}.csv"
    if not de_path.exists():
        print(f"[WARN] DEG missing for {contrast}: {de_path}")
        continue
    deg = pd.read_csv(de_path)
    if not {"gene_symbol","logFC","padj"}.issubset(deg.columns):
        print(f"[WARN] DEG columns missing in {de_path}, skipping.")
        continue
    deg["gene_symbol"] = deg["gene_symbol"].astype(str)
    deg["padj"] = pd.to_numeric(deg["padj"], errors="coerce")
    deg["logFC"] = pd.to_numeric(deg["logFC"], errors="coerce")
    deg["is_sig"] = deg["padj"] < DE_FDR

    # 2) GSEA (collect all collections)
    gsea_rows = []
    members_cache = {}  # (collection) -> membership dict
    for coll, info in COLLECTIONS.items():
        base = GSEA_DIR / info["subdir"] / contrast
        main = base / info["fname"]
        alt  = base / ALT_REPORT
        usef = main if main.exists() else (alt if alt.exists() else None)
        if usef is None:
            continue

        gdf_raw = read_any_csv(usef)
        gdf = std_gsea(gdf_raw)
        if gdf.empty or "Term" not in gdf.columns or "NES" not in gdf.columns:
            continue

        # ensure FDR_q
        if "FDR_q" not in gdf.columns or gdf["FDR_q"].isna().all():
            if "pval" in gdf.columns:
                gdf["FDR_q"] = multipletests(gdf["pval"].fillna(1.0), method="fdr_bh")[1]
            else:
                gdf["FDR_q"] = np.nan

        gdf["collection"]  = coll
        gdf["Term_raw"]    = gdf["Term"].astype(str)
        gdf["Term"]        = gdf["Term_raw"].map(clean_term)
        gdf["lead_list"]   = gdf.get("lead_genes", pd.Series([""]*len(gdf))).apply(parse_leading_edge_cell)

        gsea_rows.append(gdf)

        # cache membership dict
        if coll not in members_cache:
            members_cache[coll] = load_gmt(base / "gene_sets.gmt")

    if not gsea_rows:
        print(f"[WARN] No GSEA rows for {contrast}")
        continue
    gsea_all = pd.concat(gsea_rows, ignore_index=True)

    # 3) filter significant pathways & pick compact top UP/DOWN overall
    sig = gsea_all[(gsea_all["NES"].notna()) & (gsea_all["FDR_q"].notna()) & (gsea_all["FDR_q"] < PATHWAY_FDR)].copy()
    if sig.empty:
        print(f"[INFO] No pathways at FDR<{PATHWAY_FDR} for {contrast}. Skipping summary.")
        continue

    sig["Direction"] = np.where(sig["NES"]>0, "UP", "DOWN")
    # rank: FDR asc, then |NES| desc
    sig = sig.sort_values(["FDR_q","NES"], ascending=[True, False])

    top_up   = sig[sig["NES"]>0].head(TOP_PATHWAYS).copy()
    top_down = sig[sig["NES"]<0].sort_values(["FDR_q","NES"], ascending=[True, True]).head(TOP_PATHWAYS).copy()

    # 4) build driver genes per direction from TOP pathways only (keeps it compact)
    def drivers_from(top_df: pd.DataFrame) -> pd.DataFrame:
        if top_df.empty: 
            return pd.DataFrame(columns=["gene_symbol","driver_score","log2FC","q"])
        # assemble gene->weight sum
        weights = {}
        for _, r in top_df.iterrows():
            term = r["Term"]
            coll = r["collection"]
            nes  = float(r["NES"]); q = float(r["FDR_q"])
            w    = pathway_weight(nes, q)

            lead = r["lead_list"]
            if not lead:  # fallback to full membership from GMT
                mem = members_cache.get(coll, {})
                lead = members_from_gmt(mem, r["Term_raw"], term)

            for g in set(lead):
                weights[g] = weights.get(g, 0.0) + w

        if not weights:
            return pd.DataFrame(columns=["gene_symbol","driver_score","log2FC","q"])

        drv = (pd.DataFrame(list(weights.items()), columns=["gene_symbol","driver_score"])
                 .sort_values("driver_score", ascending=False))

        # join DEG stats
        drv = drv.merge(deg[["gene_symbol","logFC","padj"]], on="gene_symbol", how="left")
        drv.rename(columns={"logFC":"log2FC","padj":"q"}, inplace=True)
        # order by driver_score, then q, then |log2FC|
        drv["q"] = pd.to_numeric(drv["q"], errors="coerce")
        drv["log2FC"] = pd.to_numeric(drv["log2FC"], errors="coerce")
        drv = drv.sort_values(["driver_score","q","log2FC"], ascending=[False, True, False])
        return drv

    drv_up   = drivers_from(top_up).head(TOP_GENES)
    drv_down = drivers_from(top_down).head(TOP_GENES)

    # 5) write tiny CSVs + markdown one-pager
    a, b = groups_of(contrast)
    pre = OUT_DIR / f"{safe_name(contrast)}"

    # pathways tables
    pu = top_up.loc[:, ["collection","Term","NES","FDR_q"]].copy()
    pdn = top_down.loc[:, ["collection","Term","NES","FDR_q"]].copy()
    pu.to_csv(pre.with_suffix(".top_pathways_UP.csv"), index=False)
    pdn.to_csv(pre.with_suffix(".top_pathways_DOWN.csv"), index=False)

    # driver genes tables
    drv_up.to_csv(pre.with_suffix(".top_genes_UP.csv"), index=False)
    drv_down.to_csv(pre.with_suffix(".top_genes_DOWN.csv"), index=False)

    # overview (also collect for a global sheet)
    for df, direction in [(top_up,"UP"), (top_down,"DOWN")]:
        for _, r in df.iterrows():
            overview_all.append({
                "contrast": contrast,
                "direction": direction,
                "collection": r["collection"],
                "pathway": r["Term"],
                "NES": float(r["NES"]),
                "FDR_q": float(r["FDR_q"]),
            })

    # markdown one-pager
    def fmt_pathways(df, dirlabel):
        if df.empty: return f"_No significant pathways {dirlabel.lower()}._\n"
        lines = []
        for _, r in df.iterrows():
            lines.append(f"- [{r['collection']}] **{r['Term']}** (NES {r['NES']:.2f}, q {r['FDR_q']:.3g})")
        return "\n".join(lines) + "\n"

    def fmt_genes(df, dirlabel):
        if df.empty: return f"_No driver genes {dirlabel.lower()}._\n"
        parts = []
        for _, r in df.iterrows():
            g = r["gene_symbol"]; lfc = r["log2FC"]; q = r["q"]
            parts.append(f"**{g}** (log2FC {lfc:+.2f}, q {q:.3g})")
        return ", ".join(parts) + "\n"

    md = []
    md.append(f"# {contrast}: compact summary\n")
    md.append(f"_Up = enriched in **{a}**; Down = enriched in **{b}**._\n")
    md.append(f"## Top pathways — UP in {a}\n")
    md.append(fmt_pathways(pu, "UP"))
    md.append(f"## Top driver genes — UP in {a}\n")
    md.append(fmt_genes(drv_up, "UP"))
    md.append(f"## Top pathways — DOWN in {b}\n")
    md.append(fmt_pathways(pdn, "DOWN"))
    md.append(f"## Top driver genes — DOWN in {b}\n")
    md.append(fmt_genes(drv_down, "DOWN"))

    (pre.with_suffix(".summary.md")).write_text("\n".join(md), encoding="utf-8")

    print(f"[OK] {contrast} ->")
    print("   ", pre.with_suffix(".summary.md").name,
          pre.with_suffix(".top_pathways_UP.csv").name,
          pre.with_suffix(".top_pathways_DOWN.csv").name,
          pre.with_suffix(".top_genes_UP.csv").name,
          pre.with_suffix(".top_genes_DOWN.csv").name)

# Global compact sheet across all contrasts
if overview_all:
    pd.DataFrame(overview_all).sort_values(
        ["contrast","direction","FDR_q","NES"],
        ascending=[True, True, True, False]
    ).to_csv(OUT_DIR / "compact_pathway_overview_all.csv", index=False)

print("\n[OK] Wrote compact summaries to:", OUT_DIR)


[OK] Bt_vs_Mock ->
    Bt_vs_Mock.summary.md Bt_vs_Mock.top_pathways_UP.csv Bt_vs_Mock.top_pathways_DOWN.csv Bt_vs_Mock.top_genes_UP.csv Bt_vs_Mock.top_genes_DOWN.csv
[OK] Cd_vs_Mock ->
    Cd_vs_Mock.summary.md Cd_vs_Mock.top_pathways_UP.csv Cd_vs_Mock.top_pathways_DOWN.csv Cd_vs_Mock.top_genes_UP.csv Cd_vs_Mock.top_genes_DOWN.csv
[OK] Co_vs_Mock ->
    Co_vs_Mock.summary.md Co_vs_Mock.top_pathways_UP.csv Co_vs_Mock.top_pathways_DOWN.csv Co_vs_Mock.top_genes_UP.csv Co_vs_Mock.top_genes_DOWN.csv
[OK] Bt_vs_Cd ->
    Bt_vs_Cd.summary.md Bt_vs_Cd.top_pathways_UP.csv Bt_vs_Cd.top_pathways_DOWN.csv Bt_vs_Cd.top_genes_UP.csv Bt_vs_Cd.top_genes_DOWN.csv
