# Genomic Technologies, proof-of-concept pipeline, single-file script.
GWAS has identified many genotype–phenotype associations, but most signals lie in noncoding regions and dense LD blocks. This makes it difficult to move from a locus-level signal to a specific causal gene (Visscher et al., 2017). Closing this gap typically requires integrating three layers of evidence:

- statistical fine-mapping, which prioritizes likely causal variants through probabilistic measures such as posterior inclusion probabilities (PIPs);
- regulatory genomics, which links variants to gene expression and mechanism through eQTL mapping and colocalization frameworks such as COLOC (Giambartolomei et al., 2014; GTEx Consortium, 2020);
- functional perturbation evidence, such as genome-scale CRISPR knockout screens, which tests whether candidate genes have a measurable phenotypic effect (Doench et al., 2016; Meyers et al., 2017).

Translational genomics teams increasingly combine these heterogeneous signals, using reproducible workflows and machine learning to support prioritization and calibrated decisions. Yet complete, reproducible end-to-end examples remain scarce. This work addresses that gap with a simulation-based proof of concept. It generates a realistic, controllable dataset and runs an end-to-end pipeline that produces standard plots and output tables. It is intended both as a benchmark and as a template into which real-world data can be substituted.
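The PIPs and COLOC posteriors mentioned above both rest on Wakefield's approximate Bayes factor (ABF), under a single-causal-variant assumption. For reference, the quantities can be written as follows. Here $W$ is the prior variance of the true effect (`prior_var` in the script, 0.04 by default, used for both traits), and $p_1$, $p_2$, $p_{12}$ are the COLOC priors. Superscript (1) denotes the GWAS and (2) the eQTL. The notation is chosen here for illustration.

$$
\mathrm{ABF}_j = \sqrt{\frac{V_j}{V_j + W}}\,\exp\!\left(\frac{z_j^2}{2}\cdot\frac{W}{V_j + W}\right),
\qquad V_j = \mathrm{se}_j^2, \quad z_j = \frac{\hat{\beta}_j}{\mathrm{se}_j},
\qquad \mathrm{PIP}_j = \frac{\mathrm{ABF}_j}{\sum_k \mathrm{ABF}_k}
$$

$$
\begin{aligned}
H_0 &\propto 1, \qquad H_1 \propto p_1 S_1, \qquad H_2 \propto p_2 S_2,\\
H_3 &\propto p_1 p_2 \left(S_1 S_2 - S_{12}\right), \qquad H_4 \propto p_{12} S_{12},\\
S_1 &= \sum_j \mathrm{ABF}^{(1)}_j, \quad S_2 = \sum_j \mathrm{ABF}^{(2)}_j, \quad S_{12} = \sum_j \mathrm{ABF}^{(1)}_j\,\mathrm{ABF}^{(2)}_j
\end{aligned}
$$

The posterior for each hypothesis is its term divided by the sum of all five. The exponent grows with $z_j^2$, so an ABF overflows double precision once $\tfrac{z_j^2}{2}\cdot\tfrac{W}{V_j+W}$ exceeds roughly 709. Accumulating these sums on the log scale (log-sum-exp) is therefore the safer choice.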

## What this script does.
1) Simulates a single genomic locus with block-wise LD and a common-variant MAF range (0.05 to 0.50).
2) Simulates a quantitative trait with a polygenic architecture and small per-allele effects, plus an expression-mediated component.
3) Simulates lead cis-eQTL summary statistics for each gene in the locus, drawing log2 allelic fold change (aFC) from a two-component mixture.
4) Runs a univariate GWAS scan and produces summary statistics.
5) Fine-maps the locus with Wakefield approximate Bayes factors (ABFs), producing PIPs.
6) Performs COLOC-style (ABF) colocalization between the GWAS and each gene's eQTL signal.
7) Simulates gene-level CRISPR knockout screen log2 fold changes and p-values, with a left tail of essential genes.
8) Integrates the evidence into gene-level features, trains ML models with stratified cross-validation, and reports metrics and figures.
9) Saves publication-ready figures as PNG and PDF, and tables as CSV.
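Steps 5 and 6 can be illustrated without the rest of the pipeline. The sketch below is a minimal, standalone version that works on the log scale throughout. It assumes a single causal variant per trait, a uniform prior over SNPs, and the same default priors as the script (`prior_var=0.04`, `p1=p2=1e-4`, `p12=1e-5`). The helper names (`log_abf`, `logsumexp`, `pips_from_log_abf`, `coloc_posteriors`) and the toy data are hypothetical and not part of the script.

```python
import math

import numpy as np


def log_abf(beta: np.ndarray, se: np.ndarray, prior_var: float = 0.04) -> np.ndarray:
    """Natural-log Wakefield approximate Bayes factor (association vs null) per SNP."""
    v = se**2
    r = prior_var / (prior_var + v)  # shrinkage factor W / (V + W)
    z = beta / se
    return 0.5 * np.log(1.0 - r) + 0.5 * r * z**2


def logsumexp(x: np.ndarray) -> float:
    """Numerically stable log(sum(exp(x)))."""
    x = np.asarray(x, dtype=float)
    m = float(np.max(x))
    if not np.isfinite(m):
        return m
    return m + math.log(float(np.sum(np.exp(x - m))))


def pips_from_log_abf(labf: np.ndarray) -> np.ndarray:
    """PIPs under one causal variant and a uniform prior over SNPs."""
    return np.exp(labf - logsumexp(labf))


def coloc_posteriors(l1: np.ndarray, l2: np.ndarray, p1: float = 1e-4, p2: float = 1e-4, p12: float = 1e-5) -> dict:
    """PP.H0..PP.H4 from per-SNP log ABFs of two traits measured on the same SNPs."""
    s1, s2, s12 = logsumexp(l1), logsumexp(l2), logsumexp(l1 + l2)
    diff = min(s12 - (s1 + s2), 0.0)  # log(S12 / (S1 * S2)), never positive
    # log(S1*S2 - S12) = log(S1*S2) + log(1 - exp(diff))
    log_h3 = math.log(p1) + math.log(p2) + s1 + s2 + (math.log(-math.expm1(diff)) if diff < 0.0 else -math.inf)
    log_h = np.array([0.0, math.log(p1) + s1, math.log(p2) + s2, log_h3, math.log(p12) + s12])
    pp = np.exp(log_h - logsumexp(log_h))
    return {f"PP.H{i}": float(pp[i]) for i in range(5)}


# Toy locus: 200 null SNPs (z ~ N(0, 1)) plus one strong signal (z = 15).
rng = np.random.default_rng(7)
n_snps, se_val = 200, 0.02
se = np.full(n_snps, se_val)
beta_gwas = rng.normal(0.0, se_val, n_snps)
beta_eqtl_shared = rng.normal(0.0, se_val, n_snps)
beta_gwas[50] = 15 * se_val
beta_eqtl_shared[50] = 15 * se_val      # eQTL signal at the same SNP as the GWAS signal
beta_eqtl_distinct = beta_eqtl_shared.copy()
beta_eqtl_distinct[50] = 0.0            # move the eQTL signal to a different SNP
beta_eqtl_distinct[150] = 15 * se_val

l_gwas = log_abf(beta_gwas, se)
pip = pips_from_log_abf(l_gwas)
pp_shared = coloc_posteriors(l_gwas, log_abf(beta_eqtl_shared, se))
pp_distinct = coloc_posteriors(l_gwas, log_abf(beta_eqtl_distinct, se))

print("top SNP by PIP:", int(np.argmax(pip)), "PIP:", round(float(pip.max()), 4))
print("shared signal, PP.H4:", round(pp_shared["PP.H4"], 4))
print("distinct signals, PP.H3:", round(pp_distinct["PP.H3"], 4))
```

In the toy data, a shared z = 15 signal at SNP 50 should push PP.H4 close to 1. Moving the eQTL signal to SNP 150 should shift the posterior mass to PP.H3 (two distinct causal variants).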

## Benchmarking anchors used only to guide qualitative ranges and distributions.
- Visscher et al., 2017, Am J Hum Genet, polygenicity and small effect sizes for complex traits. https://pubmed.ncbi.nlm.nih.gov/28686856/
- Park et al., 2010, Nat Genet, effect size distribution and architecture concepts in GWAS. https://pubmed.ncbi.nlm.nih.gov/20562874/
- Park et al., 2011, PNAS, common variants and GWAS discovery power concepts. https://pubmed.ncbi.nlm.nih.gov/22003128/
- GTEx Consortium, 2020, Science, cis-eQTL effect sizes, and allelic fold change framing across tissues. https://pubmed.ncbi.nlm.nih.gov/32913098/
- Yao et al., 2020, Nat Genet, expression-mediated component of heritability (MESC). https://pubmed.ncbi.nlm.nih.gov/32424349/
- Giambartolomei et al., 2014, PLoS Genet, COLOC methodology using summary statistics. https://pubmed.ncbi.nlm.nih.gov/24830394/
- Meyers et al., 2017, Nat Genet, CRISPR screening characteristics, gene effects, CN bias discussion. https://pubmed.ncbi.nlm.nih.gov/29083409/
- Doench et al., 2016, Nat Biotechnol, guide design and pooled screen readout considerations. https://pubmed.ncbi.nlm.nih.gov/26780180/
- Lenoir et al., 2021, Nat Commun, essential gene effects, and screen phenotype distributions. https://pubmed.ncbi.nlm.nih.gov/34764293/

## Notes.
- This is a simulation: it does not reproduce any single dataset, but aims to be realistic in the magnitude and shape of the simulated effects.
- Designed to run in a notebook or as a script, without external files or internet access.
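The realism note above can be checked cheaply for step 1. Probit thresholds only reproduce the target MAFs if the liability being thresholded has unit variance. A mix `rho * latent + (1 - rho) * noise` of two independent standard normals has variance $\rho^2 + (1-\rho)^2$, which is below 1 for $0 < \rho < 1$. The standalone sketch below uses a hypothetical helper `simulate_block`, which mirrors the script's block model for a single block with rho = 0.5. It compares realized and target allele frequencies with and without rescaling the liability to unit variance.

```python
import numpy as np
from scipy import stats


def simulate_block(n: int, mafs: np.ndarray, rho: float, rng: np.random.Generator, rescale: bool = True) -> np.ndarray:
    """One LD block of 0/1/2 genotypes from a shared latent factor and probit thresholds."""
    latent = rng.normal(size=(n, 1))                        # per-individual block factor
    scale = np.sqrt(rho**2 + (1.0 - rho) ** 2) if rescale else 1.0
    thr = stats.norm.ppf(mafs)                              # P(Z < thr) = maf for a standard normal Z
    geno = np.zeros((n, mafs.size), dtype=np.int8)
    for _ in range(2):                                      # two haplotypes per individual
        liability = (rho * latent + (1.0 - rho) * rng.normal(size=(n, mafs.size))) / scale
        geno += (liability < thr).astype(np.int8)
    return geno


rng = np.random.default_rng(13)
target = np.linspace(0.05, 0.50, 10)
max_abs_error = {}
for rescale in (True, False):
    g = simulate_block(20_000, target, rho=0.5, rng=rng, rescale=rescale)
    realized = g.mean(axis=0) / 2.0                         # frequency of the counted allele
    max_abs_error[rescale] = float(np.max(np.abs(realized - target)))

print("max |realized - target| MAF, rescaled:", round(max_abs_error[True], 4))
print("max |realized - target| MAF, not rescaled:", round(max_abs_error[False], 4))
```

Without rescaling, low-frequency SNPs come out noticeably rarer than requested. At rho = 0.5, a target MAF of 0.05 lands near 0.01. Dividing by the liability standard deviation keeps every SNP within sampling error of its target.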

from __future__ import annotations

import os
import math
import json
import random
from dataclasses import dataclass
from typing import Dict, List, Tuple

import numpy as np
import pandas as pd

from scipy import stats
import matplotlib.pyplot as plt

from sklearn.model_selection import StratifiedKFold
from sklearn.metrics import roc_auc_score, average_precision_score, roc_curve, precision_recall_curve, brier_score_loss
from sklearn.calibration import calibration_curve
from sklearn.preprocessing import StandardScaler
from sklearn.pipeline import Pipeline
from sklearn.linear_model import LogisticRegression
from sklearn.ensemble import HistGradientBoostingClassifier, RandomForestClassifier
from sklearn.inspection import permutation_importance


# Plot defaults

plt.rcParams.update(
    {
        "font.size": 11,
        "axes.labelsize": 11,
        "axes.titlesize": 12,
        "legend.fontsize": 9,
        "axes.linewidth": 1.0,
        "xtick.major.width": 1.0,
        "ytick.major.width": 1.0,
    }
)


# Configuration

@dataclass
class Config:
    seed: int = 13

    # Genotypes and LD
    n_individuals: int = 20000
    n_snps: int = 4000
    n_ld_blocks: int = 20
    maf_min: float = 0.05
    maf_max: float = 0.50

    # Trait architecture
    n_causal_snps: int = 20
    beta_sd: float = 0.03
    trait_h2: float = 0.35

    # Genes, eQTL
    n_genes: int = 25
    prop_eqtl_over_2fold: float = 0.22  # GTEx v8 inspired; probability of drawing log2(aFC) from the wide (large-effect) component
    eqtl_N: int = 800

    # Expression mediation, qualitative, guided by MESC averages
    expr_mediated_fraction_of_h2: float = 0.11  # used only as a conceptual anchor
    expr_mix_weight: float = 0.25              # mixing coefficient for mediated component

    # COLOC priors, typical defaults for coloc ABF usage
    coloc_p1: float = 1e-4
    coloc_p2: float = 1e-4
    coloc_p12: float = 1e-5

    # Fine-mapping prior variance
    prior_var_finemap: float = 0.04

    # CRISPR screen simulation, essential left tail
    crispr_prop_essential: float = 0.12
    crispr_prop_prolif_suppressor: float = 0.03
    crispr_mean_nonessential: float = 0.0
    crispr_sd_nonessential: float = 0.18
    crispr_mean_essential: float = -0.9
    crispr_sd_essential: float = 0.35
    crispr_mean_psg: float = 0.30
    crispr_sd_psg: float = 0.12

    # Labels for ML, enforce enough positives for stable CV
    n_positive_genes: int = 8

    # Output
    outdir: str = "genomic_technologies_poc_outputs"
    dpi: int = 350


# Utilities

def set_seeds(seed: int) -> None:
    random.seed(seed)
    np.random.seed(seed)


def ensure_outdir(outdir: str) -> None:
    os.makedirs(outdir, exist_ok=True)


def savefig(fig: plt.Figure, path_no_ext: str, dpi: int = 300) -> None:
    fig.tight_layout()
    fig.savefig(path_no_ext + ".png", dpi=dpi, bbox_inches="tight")
    fig.savefig(path_no_ext + ".pdf", dpi=dpi, bbox_inches="tight")
    plt.close(fig)


def z_to_p(z: np.ndarray) -> np.ndarray:
    return 2.0 * stats.norm.sf(np.abs(z))


def safe_log10(x: np.ndarray, eps: float = 1e-300) -> np.ndarray:
    return np.log10(np.maximum(x, eps))


# Simulation, genotypes with LD

def simulate_ld_genotypes(cfg: Config) -> Tuple[pd.DataFrame, pd.DataFrame]:
    """
    Block-wise LD simulation with a latent factor model and probit thresholds to match MAF.
    Genotypes are 0/1/2 minor-allele counts.
    """
    n = cfg.n_individuals
    m = cfg.n_snps
    b = cfg.n_ld_blocks

    pos = np.linspace(1, 1_000_000, m).astype(int)
    snp_ids = [f"rsSIM{idx+1:05d}" for idx in range(m)]

    mafs = np.random.uniform(cfg.maf_min, cfg.maf_max, size=m)

    block_ids = np.repeat(np.arange(b), m // b)
    if len(block_ids) < m:
        block_ids = np.concatenate([block_ids, np.full(m - len(block_ids), b - 1)])

    latent = np.random.normal(0, 1, size=(n, b))

    G = np.zeros((n, m), dtype=np.int8)

    for blk in range(b):
        idx = np.where(block_ids == blk)[0]
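        rho = np.random.uniform(0.25, 0.65)
        noise1 = np.random.normal(0, 1, size=(n, len(idx)))
        noise2 = np.random.normal(0, 1, size=(n, len(idx)))

        # rho*latent + (1-rho)*noise has variance rho^2 + (1-rho)^2 < 1; rescale to unit variance
        # so that the standard-normal probit thresholds below reproduce the target MAFs.
        scale = math.sqrt(rho**2 + (1.0 - rho) ** 2)
        s1 = (rho * latent[:, [blk]] + (1.0 - rho) * noise1) / scale
        s2 = (rho * latent[:, [blk]] + (1.0 - rho) * noise2) / scale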

        thr = stats.norm.ppf(mafs[idx])
        a1 = (s1 < thr).astype(np.int8)
        a2 = (s2 < thr).astype(np.int8)
        G[:, idx] = a1 + a2

    geno = pd.DataFrame(G, columns=snp_ids)
    snp_info = pd.DataFrame({"snp": snp_ids, "pos": pos, "maf": mafs, "ld_block": block_ids})
    return geno, snp_info


# Simulation, eQTL summary stats

def simulate_eqtl_effects(cfg: Config, snp_info: pd.DataFrame) -> pd.DataFrame:
    """
    Simulate per-gene lead cis-eQTL signal.
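    Draw log2(aFC) from a two-component mixture: with probability prop_eqtl_over_2fold from a wide
    component (sd 0.75), otherwise from a narrow one (sd 0.18). Only the wide component produces
    |log2 aFC| > 1 (two-fold) with appreciable frequency, roughly 18% of its draws.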
    """
    genes = [f"GENE{g+1:02d}" for g in range(cfg.n_genes)]
    lead_snps = np.random.choice(snp_info["snp"].values, size=cfg.n_genes, replace=False)

    is_large = np.random.rand(cfg.n_genes) < cfg.prop_eqtl_over_2fold
    log2_afc = np.where(
        is_large,
        np.random.normal(0.0, 0.75, size=cfg.n_genes),
        np.random.normal(0.0, 0.18, size=cfg.n_genes),
    )

    # Simplified mapping log2(aFC) to standardized beta
    beta = log2_afc * 0.35
    se = np.random.uniform(0.03, 0.07, size=cfg.n_genes)
    z = beta / se
    p = z_to_p(z)

    eqtl = pd.DataFrame(
        {
            "gene": genes,
            "lead_snp": lead_snps,
            "log2_aFC": log2_afc,
            "beta": beta,
            "se": se,
            "z": z,
            "p": p,
            "N": cfg.eqtl_N,
            "large_eqtl_flag": is_large.astype(int),
        }
    )
    return eqtl


# Simulation, trait phenotype

def pick_causal_snps(cfg: Config) -> np.ndarray:
    return np.random.choice(np.arange(cfg.n_snps), size=cfg.n_causal_snps, replace=False)


def simulate_trait_from_genotypes(cfg: Config, geno: pd.DataFrame, causal_idx: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
    """
    y = G*beta + e, scale e to achieve target h2.
    """
    G = geno.values.astype(np.float32)
    m = G.shape[1]

    beta = np.zeros(m, dtype=np.float32)
    beta[causal_idx] = np.random.normal(0.0, cfg.beta_sd, size=len(causal_idx)).astype(np.float32)

    g = G @ beta
    var_g = float(np.var(g))
    var_g = max(var_g, 1e-12)

    var_e = var_g * (1.0 - cfg.trait_h2) / max(cfg.trait_h2, 1e-6)
    e = np.random.normal(0.0, math.sqrt(var_e), size=G.shape[0]).astype(np.float32)

    y = g + e
    y = (y - y.mean()) / y.std(ddof=0)
    return y.astype(np.float32), beta


def simulate_expression_mediation(cfg: Config, geno: pd.DataFrame, eqtl: pd.DataFrame, y_base: np.ndarray) -> Tuple[pd.DataFrame, np.ndarray]:
    """
    Simulate per-gene expression driven by each gene's lead cis-eQTL, combine the genes into a
    random-weight expression score, and mix that score into y_base with weight cfg.expr_mix_weight.
    Returns the standardized expression matrix and the re-standardized trait.
    """
    n = geno.shape[0]
    genes = eqtl["gene"].tolist()
    Xexpr = np.zeros((n, len(genes)), dtype=np.float32)

    snp_to_col = {s: i for i, s in enumerate(geno.columns)}

    for j, row in eqtl.iterrows():
        snp = row["lead_snp"]
        beta_expr = float(row["beta"])
        g_snp = geno.iloc[:, snp_to_col[snp]].values.astype(np.float32)
        g_snp = (g_snp - g_snp.mean()) / (g_snp.std(ddof=0) + 1e-8)
        noise = np.random.normal(0.0, 1.0, size=n).astype(np.float32)
        Xexpr[:, j] = beta_expr * g_snp + 0.85 * noise

    Xexpr = (Xexpr - Xexpr.mean(axis=0)) / (Xexpr.std(axis=0) + 1e-8)

    w = np.random.normal(0.0, 1.0, size=len(genes)).astype(np.float32)
    expr_score = Xexpr @ w
    expr_score = (expr_score - expr_score.mean()) / (expr_score.std(ddof=0) + 1e-8)

    lam = cfg.expr_mix_weight
    y = (1.0 - lam) * y_base + lam * expr_score
    y = (y - y.mean()) / y.std(ddof=0)

    return pd.DataFrame(Xexpr, columns=genes), y.astype(np.float32)


# GWAS scan, summary statistics

def gwas_scan(geno: pd.DataFrame, y: np.ndarray) -> pd.DataFrame:
    """
    Univariate regression y ~ G + intercept, computed efficiently.
    Returns snp, beta, se, z, p.
    """
    G = geno.values.astype(np.float32)
    y = y.astype(np.float32)

    n = G.shape[0]
    y0 = y - y.mean()
    y0 = y0 / (y0.std(ddof=0) + 1e-8)

    Gc = G - G.mean(axis=0, keepdims=True)
    var_g = (Gc**2).sum(axis=0) / (n - 1)
    cov_gy = (Gc * y0[:, None]).sum(axis=0) / (n - 1)

    beta = cov_gy / (var_g + 1e-12)
    r = cov_gy / (np.sqrt(var_g) + 1e-12)
    se = np.sqrt(np.maximum(1.0 - r**2, 1e-8) / (max(n - 2, 1) * (var_g + 1e-12)))
    z = beta / (se + 1e-12)
    p = z_to_p(z)

    return pd.DataFrame({"snp": geno.columns, "beta": beta.astype(float), "se": se.astype(float), "z": z.astype(float), "p": p.astype(float)})


# Fine-mapping, ABF and PIP

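def log_approximate_bayes_factor(beta: np.ndarray, se: np.ndarray, prior_var: float = 0.04) -> np.ndarray:
    """
    Natural-log Wakefield approximate Bayes factor (association vs null), single causal variant model.
    Working on the log scale avoids overflow in exp() when |z| is very large.
    """
    v = se**2
    z = beta / (se + 1e-12)
    r = prior_var / (prior_var + v)
    return 0.5 * np.log(1.0 - r) + 0.5 * r * z**2


def approximate_bayes_factor(beta: np.ndarray, se: np.ndarray, prior_var: float = 0.04) -> np.ndarray:
    """
    Wakefield approximate Bayes factor on the natural scale; overflows to inf for very large |z|.
    """
    return np.exp(log_approximate_bayes_factor(beta, se, prior_var=prior_var))


def log_sum_exp(x: np.ndarray) -> float:
    """Numerically stable log(sum(exp(x)))."""
    x = np.asarray(x, dtype=float)
    m = float(np.max(x))
    if not np.isfinite(m):
        return m
    return m + math.log(float(np.sum(np.exp(x - m))))


def compute_pips(sumstats: pd.DataFrame, prior_var: float = 0.04) -> pd.DataFrame:
    """
    PIPs under a single causal variant and a uniform prior over SNPs, normalized on the log scale.
    """
    log_abf = log_approximate_bayes_factor(sumstats["beta"].values, sumstats["se"].values, prior_var=prior_var)
    out = sumstats.copy()
    out["log_abf"] = log_abf
    out["pip"] = np.exp(log_abf - log_sum_exp(log_abf))
    return out


# COLOC-style colocalization

def build_eqtl_sumstats_for_coloc(cfg: Config, snp_info: pd.DataFrame, eqtl: pd.DataFrame) -> Dict[str, pd.DataFrame]:
    """
    For each gene, create eQTL sumstats across all locus SNPs.
    Only the lead SNP carries signal; every other SNP gets beta=0 and se=10, which gives an ABF
    close to 1 (uninformative). LD between the lead SNP and its neighbours is not modelled, so
    PP.H4 rises only when the eQTL lead SNP itself carries a large share of the GWAS ABF mass.
    """
    snps = snp_info["snp"].values
    out: Dict[str, pd.DataFrame] = {}

    for _, row in eqtl.iterrows():
        gene = row["gene"]
        lead = row["lead_snp"]

        beta = np.zeros(len(snps), dtype=float)
        se = np.full(len(snps), 10.0, dtype=float)

        lead_idx = int(np.where(snps == lead)[0][0])
        beta[lead_idx] = float(row["beta"])
        se[lead_idx] = float(row["se"])

        out[gene] = pd.DataFrame({"snp": snps, "beta": beta, "se": se})

    return out


def coloc_abf(
    gwas: pd.DataFrame,      # columns: snp, beta, se
    eqtl_gene: pd.DataFrame, # columns: snp, beta, se
    p1: float,
    p2: float,
    p12: float,
    prior_var_gwas: float = 0.04,
    prior_var_eqtl: float = 0.04,
) -> Dict[str, float]:
    """
    Simplified COLOC using ABFs (Giambartolomei et al., 2014), computed on the log scale.
    Returns posterior probabilities PP.H0..PP.H4.
    """
    df = gwas.merge(eqtl_gene[["snp", "beta", "se"]], on="snp", how="left", suffixes=("_gwas", "_eqtl"))

    l1 = log_approximate_bayes_factor(df["beta_gwas"].values, df["se_gwas"].values, prior_var=prior_var_gwas)

    beta2 = df["beta_eqtl"].fillna(0.0).values
    se2 = df["se_eqtl"].fillna(10.0).values
    l2 = log_approximate_bayes_factor(beta2, se2, prior_var=prior_var_eqtl)

    lse1 = log_sum_exp(l1)        # log sum_j ABF1_j
    lse2 = log_sum_exp(l2)        # log sum_j ABF2_j
    lse12 = log_sum_exp(l1 + l2)  # log sum_j ABF1_j * ABF2_j

    # log(S1*S2 - S12) = log(S1*S2) + log(1 - S12/(S1*S2)); S12 <= S1*S2 for non-negative ABFs
    diff = min(lse12 - (lse1 + lse2), 0.0)
    log_h3_sum = lse1 + lse2 + (math.log(-math.expm1(diff)) if diff < 0.0 else -math.inf)

    log_h = np.array(
        [
            0.0,                                       # H0
            math.log(p1) + lse1,                       # H1
            math.log(p2) + lse2,                       # H2
            math.log(p1) + math.log(p2) + log_h3_sum,  # H3
            math.log(p12) + lse12,                     # H4
        ]
    )
    pp = np.exp(log_h - log_sum_exp(log_h))
    return {f"PP.H{i}": float(pp[i]) for i in range(5)}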


# CRISPR screen simulation

def simulate_crispr_screen(cfg: Config, genes: List[str]) -> pd.DataFrame:
    n = len(genes)
    labels = np.array(["nonessential"] * n, dtype=object)

    idx_ess = np.random.choice(np.arange(n), size=max(1, int(cfg.crispr_prop_essential * n)), replace=False)
    remaining = np.setdiff1d(np.arange(n), idx_ess)
    idx_psg = np.random.choice(remaining, size=max(1, int(cfg.crispr_prop_prolif_suppressor * n)), replace=False)

    labels[idx_ess] = "essential"
    labels[idx_psg] = "prolif_suppressor"

    logfc = np.zeros(n, dtype=np.float32)
    for i in range(n):
        if labels[i] == "essential":
            logfc[i] = np.random.normal(cfg.crispr_mean_essential, cfg.crispr_sd_essential)
        elif labels[i] == "prolif_suppressor":
            logfc[i] = np.random.normal(cfg.crispr_mean_psg, cfg.crispr_sd_psg)
        else:
            logfc[i] = np.random.normal(cfg.crispr_mean_nonessential, cfg.crispr_sd_nonessential)

    z = (logfc - logfc.mean()) / (logfc.std(ddof=0) + 1e-8)
    p = z_to_p(z)

    return pd.DataFrame({"gene": genes, "log2FC": logfc, "z": z, "p": p, "class": labels})


# Gene-level feature assembly

def gene_level_features(
    gwas_sumstats: pd.DataFrame,
    pips: pd.DataFrame,
    eqtl: pd.DataFrame,
    coloc_pp: pd.DataFrame,
    crispr: pd.DataFrame,
) -> pd.DataFrame:
    gwas_map = gwas_sumstats.set_index("snp")
    pip_map = pips.set_index("snp")

    rows = []
    for _, g in eqtl.iterrows():
        gene = g["gene"]
        lead = g["lead_snp"]

        gwas_p = float(gwas_map.loc[lead, "p"])
        gwas_z = float(gwas_map.loc[lead, "z"])
        lead_pip = float(pip_map.loc[lead, "pip"])

        pp4 = float(coloc_pp.loc[coloc_pp["gene"] == gene, "PP.H4"].values[0])
        log2_afc = float(g["log2_aFC"])

        cr = crispr.loc[crispr["gene"] == gene].iloc[0]
        crispr_logfc = float(cr["log2FC"])
        crispr_p = float(cr["p"])

        rows.append(
            {
                "gene": gene,
                "lead_snp": lead,
                "gwas_p_lead": gwas_p,
                "gwas_z_lead": gwas_z,
                "pip_lead": lead_pip,
                "coloc_pp4": pp4,
                "log2_aFC": log2_afc,
                "crispr_log2FC": crispr_logfc,
                "crispr_p": crispr_p,
            }
        )

    df = pd.DataFrame(rows)
    df["gwas_neglog10p_lead"] = -safe_log10(df["gwas_p_lead"].values)
    df["crispr_neglog10p"] = -safe_log10(df["crispr_p"].values)
    return df


# ML training with robust CV

def train_and_evaluate_models(cfg: Config, features: pd.DataFrame, y_label: np.ndarray) -> Tuple[pd.DataFrame, Dict[str, Dict[str, float]], Dict[str, object]]:
    X = features.drop(columns=["gene", "lead_snp"]).copy()
    X = X.replace([np.inf, -np.inf], np.nan).fillna(0.0)

    models = {
        "logreg": Pipeline([("scaler", StandardScaler()), ("clf", LogisticRegression(max_iter=800, class_weight="balanced"))]),
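        # min_samples_leaf lowered from the default (20): with ~20 training genes per fold the default permits no splits
        "hgb": HistGradientBoostingClassifier(max_depth=3, learning_rate=0.08, max_iter=250, min_samples_leaf=3, l2_regularization=0.02, random_state=cfg.seed),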
        "rf": RandomForestClassifier(n_estimators=500, max_depth=5, min_samples_leaf=2, class_weight="balanced", random_state=cfg.seed),
    }

    n_pos = int(np.sum(y_label == 1))
    n_neg = int(np.sum(y_label == 0))
    minority = min(n_pos, n_neg)
    n_splits = min(5, minority)
    if n_splits < 2:
        raise ValueError(f"Not enough samples per class for CV, positives={n_pos}, negatives={n_neg}.")

    skf = StratifiedKFold(n_splits=n_splits, shuffle=True, random_state=cfg.seed)

    preds_all: List[pd.DataFrame] = []
    metrics: Dict[str, Dict[str, float]] = {}
    fitted_last: Dict[str, object] = {}

    def to_1d_proba(model, Xte: pd.DataFrame) -> np.ndarray:
        if hasattr(model, "predict_proba"):
            p = np.asarray(model.predict_proba(Xte))
            if p.ndim == 2 and p.shape[1] >= 2:
                p1 = p[:, 1]
            elif p.ndim == 2 and p.shape[1] == 1:
                p1 = p[:, 0]
            else:
                p1 = p
        else:
            s = np.asarray(model.decision_function(Xte)).reshape(-1)
            p1 = 1.0 / (1.0 + np.exp(-s))
        return np.asarray(p1).reshape(-1)

    for name, model in models.items():
        fold_metrics = []
        fold_pred_frames = []

        for fold, (tr, te) in enumerate(skf.split(X, y_label), start=1):
            Xtr, Xte = X.iloc[tr], X.iloc[te]
            ytr, yte = y_label[tr], y_label[te]

            model.fit(Xtr, ytr)
            p = np.clip(to_1d_proba(model, Xte), 0.0, 1.0)

            roc = np.nan
            if len(np.unique(yte)) == 2:
                roc = roc_auc_score(yte, p)
            ap = average_precision_score(yte, p)
            brier = brier_score_loss(yte, p)

            fold_metrics.append({"fold": fold, "roc_auc": roc, "avg_precision": ap, "brier": brier})

            fold_pred_frames.append(
                pd.DataFrame(
                    {
                        "model": [name] * len(te),
                        "fold": [fold] * len(te),
                        "gene": features.iloc[te]["gene"].values,
                        "y_true": yte.astype(int),
                        "y_pred": p.astype(float),
                    }
                )
            )

            fitted_last[name] = model

        preds_all.append(pd.concat(fold_pred_frames, ignore_index=True))
        metrics[name] = {
            "roc_auc_mean": float(np.nanmean([m["roc_auc"] for m in fold_metrics])),
            "avg_precision_mean": float(np.nanmean([m["avg_precision"] for m in fold_metrics])),
            "brier_mean": float(np.nanmean([m["brier"] for m in fold_metrics])),
            "n_splits": float(n_splits),
            "n_pos": float(n_pos),
            "n_neg": float(n_neg),
        }

    return pd.concat(preds_all, ignore_index=True), metrics, fitted_last


# Plotting, robust to duplicate columns

def plot_manhattan(cfg: Config, snp_info: pd.DataFrame, gwas_sumstats: pd.DataFrame) -> plt.Figure:
    s = snp_info[["snp", "pos"]].copy()
    g = gwas_sumstats[["snp", "p"]].copy()
    df = s.merge(g, on="snp", how="inner")
    y = -safe_log10(df["p"].values)

    fig = plt.figure(figsize=(10, 3.6))
    ax = fig.add_subplot(111)
    ax.scatter(df["pos"].values, y, s=8, alpha=0.75)
    ax.set_xlabel("Genomic position (simulated locus)")
    ax.set_ylabel(r"$-\log_{10}(p)$")
    ax.set_title("Locus Manhattan plot (simulated GWAS)")
    return fig


def plot_qq(cfg: Config, gwas_sumstats: pd.DataFrame) -> plt.Figure:
    p = np.clip(gwas_sumstats["p"].values, 1e-300, 1.0)
    p_sorted = np.sort(p)
    exp = -np.log10((np.arange(1, len(p_sorted) + 1) - 0.5) / len(p_sorted))
    obs = -np.log10(p_sorted)

    fig = plt.figure(figsize=(4.3, 4.3))
    ax = fig.add_subplot(111)
    ax.scatter(exp, obs, s=8, alpha=0.75)
    lim = max(exp.max(), obs.max())
    ax.plot([0, lim], [0, lim], linewidth=1)
    ax.set_xlabel("Expected -log10(p)")
    ax.set_ylabel("Observed -log10(p)")
    ax.set_title("QQ plot")
    return fig


def plot_pips(cfg: Config, snp_info: pd.DataFrame, pips: pd.DataFrame) -> plt.Figure:
    s = snp_info[["snp", "pos"]].copy()
    p = pips[["snp", "pip"]].copy()
    df = s.merge(p, on="snp", how="inner")

    fig = plt.figure(figsize=(10, 3.3))
    ax = fig.add_subplot(111)
    ax.scatter(df["pos"].values, df["pip"].values, s=10, alpha=0.8)
    ax.set_xlabel("Genomic position (simulated locus)")
    ax.set_ylabel("Posterior inclusion probability (PIP)")
    ax.set_title("Fine-mapping PIPs (ABF-based)")
    return fig


def plot_coloc_pp4(cfg: Config, coloc_pp: pd.DataFrame) -> plt.Figure:
    df = coloc_pp.sort_values("PP.H4", ascending=False).head(15).copy()
    x = np.arange(len(df))

    fig = plt.figure(figsize=(8.6, 4.0))
    ax = fig.add_subplot(111)
    ax.bar(x, df["PP.H4"].values)
    ax.set_ylabel("Posterior P(H4), shared causal variant")
    ax.set_title("Colocalization evidence, top genes (COLOC ABF)")
    # Fix tick positions before labelling them; avoids Matplotlib's FixedFormatter/FixedLocator warning
    ax.set_xticks(x)
    ax.set_xticklabels(df["gene"].tolist(), rotation=45, ha="right")
    ax.set_ylim(0, 1.0)
    return fig


def plot_crispr(cfg: Config, crispr: pd.DataFrame) -> plt.Figure:
    df = crispr.copy()
    df["neglog10p"] = -safe_log10(df["p"].values)

    fig = plt.figure(figsize=(7.2, 4.0))
    ax = fig.add_subplot(111)
    ax.scatter(df["log2FC"].values, df["neglog10p"].values, s=18, alpha=0.75)
    ax.set_xlabel("CRISPR log2 fold-change")
    ax.set_ylabel(r"$-\log_{10}(p)$")
    ax.set_title("CRISPR screen volcano (simulated)")
    return fig


def plot_model_curves(cfg: Config, preds: pd.DataFrame) -> Dict[str, plt.Figure]:
    figs: Dict[str, plt.Figure] = {}

    # ROC
    fig = plt.figure(figsize=(5.3, 4.6))
    ax = fig.add_subplot(111)
    for model_name in preds["model"].unique():
        sub = preds[preds["model"] == model_name].copy()
        y = sub["y_true"].values
        p = np.clip(sub["y_pred"].values, 0, 1)
        if len(np.unique(y)) < 2:
            continue
        fpr, tpr, _ = roc_curve(y, p)
        auc = roc_auc_score(y, p)
        ax.plot(fpr, tpr, label=f"{model_name}, AUC={auc:.2f}")
    ax.plot([0, 1], [0, 1], linewidth=1)
    ax.set_xlabel("False positive rate")
    ax.set_ylabel("True positive rate")
    ax.set_title("ROC curves")
    ax.legend(frameon=False)
    figs["roc"] = fig

    # PR
    fig = plt.figure(figsize=(5.3, 4.6))
    ax = fig.add_subplot(111)
    for model_name in preds["model"].unique():
        sub = preds[preds["model"] == model_name].copy()
        y = sub["y_true"].values
        p = np.clip(sub["y_pred"].values, 0, 1)
        prec, rec, _ = precision_recall_curve(y, p)
        ap = average_precision_score(y, p)
        ax.plot(rec, prec, label=f"{model_name}, AP={ap:.2f}")
    ax.set_xlabel("Recall")
    ax.set_ylabel("Precision")
    ax.set_title("Precision–Recall curves")
    ax.legend(frameon=False)
    figs["pr"] = fig

    # Calibration
    fig = plt.figure(figsize=(5.3, 4.6))
    ax = fig.add_subplot(111)
    for model_name in preds["model"].unique():
        sub = preds[preds["model"] == model_name].copy()
        y = sub["y_true"].values
        p = np.clip(sub["y_pred"].values, 0, 1)
        frac_pos, mean_pred = calibration_curve(y, p, n_bins=6, strategy="quantile")
        ax.plot(mean_pred, frac_pos, marker="o", label=model_name)
    ax.plot([0, 1], [0, 1], linewidth=1)
    ax.set_xlabel("Mean predicted probability")
    ax.set_ylabel("Observed fraction positive")
    ax.set_title("Calibration curves")
    ax.legend(frameon=False)
    figs["calibration"] = fig

    return figs


def plot_permutation_importance(cfg: Config, model, features: pd.DataFrame, y_label: np.ndarray, model_name: str) -> plt.Figure:
    X = features.drop(columns=["gene", "lead_snp"]).replace([np.inf, -np.inf], np.nan).fillna(0.0)

    result = permutation_importance(model, X, y_label, n_repeats=100, random_state=cfg.seed, scoring="roc_auc")
    imp = pd.DataFrame({"feature": X.columns, "importance": result.importances_mean}).sort_values("importance", ascending=False).head(12)

    fig = plt.figure(figsize=(7.6, 4.2))
    ax = fig.add_subplot(111)
    ax.barh(imp["feature"].values[::-1], imp["importance"].values[::-1])
    ax.set_xlabel("Permutation importance (Δ ROC-AUC)")
    ax.set_title(f"Feature importance, {model_name}")
    return fig


# Main

def main() -> None:
    cfg = Config()
    set_seeds(cfg.seed)
    ensure_outdir(cfg.outdir)

    # 1) Genotypes, locus metadata
    geno, snp_info = simulate_ld_genotypes(cfg)

    # 2) eQTLs for genes in the locus
    eqtl = simulate_eqtl_effects(cfg, snp_info)

    # 3) Trait phenotype with polygenic signal
    causal_idx = pick_causal_snps(cfg)
    y_base, beta_true = simulate_trait_from_genotypes(cfg, geno, causal_idx)

    # 4) Expression mediation, simulated expression matrix and updated trait
    expr_df, y = simulate_expression_mediation(cfg, geno, eqtl, y_base)

    # 5) GWAS scan, keep as pure sumstats without annotations
    gwas_sumstats = gwas_scan(geno, y)

    # 6) Fine-map, ABF and PIP
    pips = compute_pips(gwas_sumstats[["snp", "beta", "se", "z", "p"]].copy(), prior_var=cfg.prior_var_finemap)

    # 7) COLOC per gene, using ABFs
    eqtl_sumstats = build_eqtl_sumstats_for_coloc(cfg, snp_info, eqtl)

    coloc_rows = []
    gwas_for_coloc = gwas_sumstats[["snp", "beta", "se"]].copy()
    for gene, eq in eqtl_sumstats.items():
        pp = coloc_abf(
            gwas=gwas_for_coloc,
            eqtl_gene=eq,
            p1=cfg.coloc_p1,
            p2=cfg.coloc_p2,
            p12=cfg.coloc_p12,
            prior_var_gwas=cfg.prior_var_finemap,
            prior_var_eqtl=cfg.prior_var_finemap,
        )
        coloc_rows.append({"gene": gene, **pp})

    coloc_pp = pd.DataFrame(coloc_rows).sort_values("PP.H4", ascending=False)

    # 8) CRISPR screen simulation
    genes = eqtl["gene"].tolist()
    crispr = simulate_crispr_screen(cfg, genes)

    # 9) Gene-level integration
    feats = gene_level_features(gwas_sumstats, pips, eqtl, coloc_pp, crispr)

    # 10) Labels for ML, deterministic top-k to avoid CV failure
    score = (
        1.6 * feats["coloc_pp4"].values
        + 0.8 * feats["pip_lead"].values
        + 0.4 * feats["gwas_neglog10p_lead"].values / (feats["gwas_neglog10p_lead"].max() + 1e-8)
        + 0.25 * np.clip(np.abs(feats["log2_aFC"].values), 0, 2)
        + 0.15 * feats["crispr_neglog10p"].values / (feats["crispr_neglog10p"].max() + 1e-8)
    )
    score = (score - score.min()) / (score.max() - score.min() + 1e-12)

    topk = np.argsort(score)[-cfg.n_positive_genes :]
    y_label = np.zeros(len(feats), dtype=int)
    y_label[topk] = 1
    feats["label_causal_gene"] = y_label

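    # 11) ML modelling on explicit feature columns. The labels above are a deterministic function of
    #     these same features, so CV metrics measure recovery of that scoring rule, not causal-gene discovery.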
    feature_cols = [
        "gene",
        "lead_snp",
        "gwas_p_lead",
        "gwas_z_lead",
        "pip_lead",
        "coloc_pp4",
        "log2_aFC",
        "crispr_log2FC",
        "crispr_p",
        "gwas_neglog10p_lead",
        "crispr_neglog10p",
    ]
    preds, metrics, fitted_last = train_and_evaluate_models(cfg, feats[feature_cols], y_label)

    # 12) Save tables
    snp_info.to_csv(os.path.join(cfg.outdir, "snp_info.csv"), index=False)
    gwas_sumstats.to_csv(os.path.join(cfg.outdir, "gwas_sumstats.csv"), index=False)
    pips.to_csv(os.path.join(cfg.outdir, "finemap_pips.csv"), index=False)
    eqtl.to_csv(os.path.join(cfg.outdir, "eqtl_summary.csv"), index=False)
    coloc_pp.to_csv(os.path.join(cfg.outdir, "coloc_posteriors.csv"), index=False)
    crispr.to_csv(os.path.join(cfg.outdir, "crispr_screen.csv"), index=False)
    feats.to_csv(os.path.join(cfg.outdir, "gene_level_features.csv"), index=False)
    preds.to_csv(os.path.join(cfg.outdir, "ml_predictions_cv.csv"), index=False)
    expr_df.to_csv(os.path.join(cfg.outdir, "simulated_expression_matrix.csv"), index=False)

    with open(os.path.join(cfg.outdir, "ml_metrics.json"), "w") as f:
        json.dump(metrics, f, indent=2)

    # 13) Figures, publishable PNG + PDF
    fig = plot_manhattan(cfg, snp_info, gwas_sumstats)
    savefig(fig, os.path.join(cfg.outdir, "Fig1_locus_manhattan"), dpi=cfg.dpi)

    fig = plot_qq(cfg, gwas_sumstats)
    savefig(fig, os.path.join(cfg.outdir, "Fig2_qq_plot"), dpi=cfg.dpi)

    fig = plot_pips(cfg, snp_info, pips)
    savefig(fig, os.path.join(cfg.outdir, "Fig3_finemap_pips"), dpi=cfg.dpi)

    fig = plot_coloc_pp4(cfg, coloc_pp)
    savefig(fig, os.path.join(cfg.outdir, "Fig4_coloc_pp4"), dpi=cfg.dpi)

    fig = plot_crispr(cfg, crispr)
    savefig(fig, os.path.join(cfg.outdir, "Fig5_crispr_volcano"), dpi=cfg.dpi)

    figs = plot_model_curves(cfg, preds)
    savefig(figs["roc"], os.path.join(cfg.outdir, "Fig6A_ml_roc"), dpi=cfg.dpi)
    savefig(figs["pr"], os.path.join(cfg.outdir, "Fig6B_ml_pr"), dpi=cfg.dpi)
    savefig(figs["calibration"], os.path.join(cfg.outdir, "Fig6C_ml_calibration"), dpi=cfg.dpi)

    # Permutation importance uses the last-fold model on all genes, so it is partly in-sample.
    best_model_name = max(metrics.keys(), key=lambda k: metrics[k]["roc_auc_mean"])
    best_model = fitted_last[best_model_name]
    fig = plot_permutation_importance(cfg, best_model, feats[feature_cols], y_label, model_name=best_model_name)
    savefig(fig, os.path.join(cfg.outdir, "Fig7_feature_importance"), dpi=cfg.dpi)

    # 14) Console summary
    print("\nSaved outputs to:", os.path.abspath(cfg.outdir))
    print("\nClass balance, positives:", int(y_label.sum()), "negatives:", int((y_label == 0).sum()))
    print("\nML metrics, mean across CV folds.")
    for k, v in metrics.items():
        print(f"  {k}, ROC-AUC={v['roc_auc_mean']:.3f}, AP={v['avg_precision_mean']:.3f}, Brier={v['brier_mean']:.3f}, folds={int(v['n_splits'])}")

    print("\nTop colocalized genes (PP.H4).")
    print(coloc_pp[["gene", "PP.H4", "PP.H3", "PP.H1", "PP.H2"]].head(10).to_string(index=False))

    print("\nTop fine-mapped SNPs by PIP.")
    top_pip = pips.merge(snp_info[["snp", "pos", "maf"]], on="snp", how="left").sort_values("pip", ascending=False).head(10)[["snp", "pos", "maf", "pip"]]
    print(top_pip.to_string(index=False))


if __name__ == "__main__":
    main()




Saved outputs to: /Users/petalc01/GSK Flow Biomarker Assay + ML Evaluation Demo/gsk_genomic_technologies_poc_outputs

Class balance, positives: 8 negatives: 17

ML metrics, mean across CV folds.
  logreg, ROC-AUC=1.000, AP=1.000, Brier=0.055, folds=5
  hgb, ROC-AUC=0.500, AP=0.320, Brier=0.223, folds=5
  rf, ROC-AUC=1.000, AP=1.000, Brier=0.027, folds=5

Top colocalized genes (PP.H4).
  gene    PP.H4    PP.H3    PP.H1         PP.H2
GENE21 0.066666 0.266544 0.666790 6.292064e-221
GENE20 0.066666 0.266544 0.666790 6.292069e-221
GENE24 0.066665 0.266547 0.666787 6.292138e-221
GENE25 0.066665 0.266548 0.666787 6.292150e-221
GENE02 0.066665 0.266548 0.666787 6.292155e-221
GENE22 0.066665 0.266548 0.666787 6.292156e-221
GENE06 0.066665 0.266549 0.666786 6.292176e-221
GENE23 0.066665 0.266549 0.666786 6.292176e-221
GENE05 0.066665 0.266549 0.666786 6.292179e-221
GENE19 0.066665 0.266550 0.666785 6.292208e-221

Top fine-mapped SNPs by PIP.
       snp    pos      maf           pip
rsSIM03338 8