# BioAI 
## Translating ultra-high throughput perturbation screening outputs into high-confidence cardiometabolic disease (CMD) drug target candidates

This notebook simulates multi-modal perturbation screens and performs:
- Multi-modal embedding fusion (Cell Painting + DRUG-seq/Perturb-seq-like signatures)
- Batch correction + replicate QC
- Similarity graph construction
- **Graph Neural Network (GNN)** target scoring on the perturbation graph
- **Calibration** + uncertainty quantification
- **Bayesian evidence integration** into posterior target probabilities
- Automated **target evidence bundle reports** (HTML) for top hits

# Proof-of-Concept: Translating ultra-high throughput perturbation screening outputs into high-confidence drug target candidates for cardiometabolic disease (CMD).

This script simulates a realistic multi-modal perturbation dataset and implements
an end-to-end target discovery pipeline:

1) Simulate large-scale perturbation screens
   - Genetic perturbations (CRISPR KO / CRISPRi-like)
   - Compound perturbations (small molecules)
   - Cell Painting-like morphology embeddings (high-dimensional)
   - DRUG-seq / Perturb-seq-like transcriptomics (gene expression signatures)
   - Batch effects + technical noise + replicate structure

2) Quality control and normalization
   - Replicate consistency metrics
   - Batch correction (simple, practical approximation)
   - Feature scaling

3) Systems biology + phenotype AI integration
   - Learn a joint representation (PCA + multi-modal fusion)
   - Build a perturbation similarity graph
   - Cluster phenotypic neighborhoods
   - Map compound neighborhoods to genetic perturbations (MoA-style linking)

4) Target nomination and scoring
   - Evidence integration across modalities:
       (a) CMD relevance prediction (supervised, with simulated labels)
       (b) Phenotype strength and reproducibility
       (c) Cross-modal concordance (morphology ↔ transcriptomics)
       (d) Compound-genetic agreement (chemical profile matches gene perturbation)
       (e) Network centrality (graph importance in phenotype space)
   - Generate ranked high-confidence target candidates

5) Reporting
   - Produce publishable-quality figures
   - Export candidate tables as CSV
   - Export compound→gene similarity links

## Author: Mark I.R. Petalcorin (application-ready PoC)

In [3]:
from __future__ import annotations

import os
import json
import random
from dataclasses import dataclass
from typing import Dict, List, Tuple, Optional

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

from sklearn.decomposition import PCA
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import roc_auc_score, average_precision_score, brier_score_loss
from sklearn.model_selection import train_test_split
from sklearn.ensemble import RandomForestClassifier
from sklearn.neighbors import NearestNeighbors
from sklearn.cluster import KMeans
from sklearn.calibration import CalibratedClassifierCV, calibration_curve

# Reproducibility

# One seed shared by the stdlib and NumPy global generators so reruns are repeatable.
SEED = 7
random.seed(SEED)
np.random.seed(SEED)

# Configuration

@dataclass
class SimConfig:
    """
    Parameters for the simulated perturbation screen and the analysis pipeline.

    The simulator reads the size, dimensionality, effect and noise fields;
    the pipeline reads ``knn_k``, ``n_clusters`` and ``outdir``.
    """
    # --- Screen size and replicate layout ---
    n_genes: int = 1200                  # genetic perturbations (targets)
    n_compounds: int = 600               # compound perturbations
    n_controls: int = 80                 # DMSO / non-targeting guides
    n_batches: int = 6                   # each observation is assigned a batch uniformly at random
    n_cell_lines: int = 3                # phenotype context variability
    replicates_per_cond: int = 4         # replicates per (entity, cell line) pair

    # --- Feature dimensionality of each modality ---
    morph_dim: int = 512                 # Cell Painting embedding dimensionality
    tx_dim: int = 256                    # transcriptomic signature embedding dim

    # --- Ground truth used to label the simulation ---
    n_cmd_true_targets: int = 60         # number of true CMD-driving genes
    n_cmd_modules: int = 8               # latent pathway modules driving CMD phenotypes

    # --- Perturbation effect shape ---
    effect_sparsity: float = 0.18        # fraction of features affected per module
    gene_effect_scale: float = 1.2       # multiplier on the module effect vector
    compound_effect_scale: float = 1.0   # extra multiplier applied to compound effects

    # --- Nuisance variation (standard deviations of Gaussian draws) ---
    batch_effect_scale_morph: float = 0.50
    batch_effect_scale_tx: float = 0.45
    technical_noise_morph: float = 0.65
    technical_noise_tx: float = 0.70

    compound_gene_link_prob: float = 0.20  # compounds that map to a gene/module neighborhood
    off_target_noise: float = 0.30         # compound off-target phenotype complexity

    # --- Pipeline settings ---
    knn_k: int = 25                      # neighbours per node in the similarity graph
    n_clusters: int = 30                 # upper bound on the number of k-means clusters

    outdir: str = "cmd_target_discovery_poc_outputs"  # directory receiving CSVs, JSON and figures

# Utility functions

def set_plot_defaults() -> None:
    """Set the matplotlib defaults shared by all figures: on-screen DPI, saved DPI and font size."""
    plt.rcParams.update({
        "figure.dpi": 160,
        "savefig.dpi": 300,
        "font.size": 10,
    })


def zscore_by_group(X: np.ndarray, groups: np.ndarray) -> np.ndarray:
    """
    Standardize each feature within each group (lightweight batch correction).

    Every row is centred and scaled using the per-feature mean and standard
    deviation of the rows that share its group label.

    Args:
        X: Observations along axis 0. Floating inputs keep their dtype;
            any other dtype is promoted to float64.
        groups: One label per row of ``X`` (any array-like).

    Returns:
        A new array shaped like ``X``; the input is never modified.

    Raises:
        ValueError: If ``groups`` is not one-dimensional or its length
            differs from the number of rows in ``X``.
    """
    X = np.asarray(X)
    # A plain list would make `groups == g` a single bool and silently skip every row.
    groups = np.asarray(groups)
    if groups.ndim != 1 or groups.shape[0] != X.shape[0]:
        raise ValueError(
            f"groups must be 1-D with one label per row of X. Got {groups.shape} vs {X.shape}"
        )

    # Writing z-scores back into an integer array would truncate them towards zero.
    if np.issubdtype(X.dtype, np.floating):
        Xc = X.copy()
    else:
        Xc = X.astype(np.float64)

    for g in np.unique(groups):
        rows = groups == g
        block = Xc[rows]
        mu = block.mean(axis=0, keepdims=True)
        # The epsilon keeps constant features (sd == 0) at 0 instead of dividing by zero.
        sd = block.std(axis=0, keepdims=True) + 1e-8
        Xc[rows] = (block - mu) / sd
    return Xc


def cosine_sim_matrix(A: np.ndarray, B: np.ndarray) -> np.ndarray:
    """
    Pairwise cosine similarity between the rows of A and the rows of B.

    Returns an array of shape (rows of A, rows of B). Raises ValueError when
    the two inputs do not share the same number of columns.
    """
    dim_a, dim_b = A.shape[1], B.shape[1]
    if dim_a != dim_b:
        raise ValueError(f"Cosine similarity requires same feature dim. Got {A.shape} vs {B.shape}")

    unit_rows = []
    for mat in (A, B):
        lengths = np.linalg.norm(mat, axis=1, keepdims=True)
        # The small offset keeps all-zero rows finite (they map to similarity 0).
        unit_rows.append(mat / (lengths + 1e-9))

    left, right = unit_rows
    return left @ right.T


def robust_minmax(x: np.ndarray) -> np.ndarray:
    """Rescale x so its 5th percentile maps to 0 and its 95th to 1, clipping values outside."""
    lo, hi = np.percentile(x, [5, 95])
    # The small offset avoids a zero denominator when both percentiles coincide.
    span = hi - lo + 1e-9
    scaled = (x - lo) / span
    return np.clip(scaled, 0.0, 1.0)

# Simulation: perturbation screen

class PerturbationSimulator:
    """
    Generates a synthetic two-modality perturbation screen.

    Each latent module has a sparse, unit-norm effect vector per modality.
    Genes, compounds and controls are observed across cell lines and replicates,
    with additive batch, cell-line and technical noise.

    All randomness comes from NumPy's global generator, so the output depends on
    the seed set before construction and on the exact order of the draws below.
    """

    def __init__(self, cfg: SimConfig):
        """Draw module masks and effects, the true-target gene set and per-gene modules."""
        self.cfg = cfg

        # Binary masks: which features each module touches, per modality.
        self.module_masks_morph = self._make_sparse_masks(cfg.n_cmd_modules, cfg.morph_dim, cfg.effect_sparsity)
        self.module_masks_tx = self._make_sparse_masks(cfg.n_cmd_modules, cfg.tx_dim, cfg.effect_sparsity)

        # Unit-norm effect direction of each module, zero outside its mask.
        self.module_effects_morph = self._make_module_effects(cfg.n_cmd_modules, cfg.morph_dim, self.module_masks_morph)
        self.module_effects_tx = self._make_module_effects(cfg.n_cmd_modules, cfg.tx_dim, self.module_masks_tx)

        # Sorted gene indices labelled as true targets.
        self.cmd_targets = np.sort(np.random.choice(cfg.n_genes, size=cfg.n_cmd_true_targets, replace=False))
        # Default module of every gene; true targets get a different draw in simulate().
        self.gene_module = np.random.choice(cfg.n_cmd_modules, size=cfg.n_genes, replace=True)
        # Subset of modules (at least 2) from which true targets draw their module.
        self.cmd_modules = np.random.choice(cfg.n_cmd_modules, size=max(2, cfg.n_cmd_modules // 3), replace=False)

    @staticmethod
    def _make_sparse_masks(n_modules: int, dim: int, sparsity: float) -> np.ndarray:
        """Return an (n_modules, dim) 0/1 float mask with max(5, int(dim * sparsity)) ones per row."""
        masks = np.zeros((n_modules, dim), dtype=float)
        for m in range(n_modules):
            k = max(5, int(dim * sparsity))
            idx = np.random.choice(dim, size=k, replace=False)
            masks[m, idx] = 1.0
        return masks

    @staticmethod
    def _make_module_effects(n_modules: int, dim: int, masks: np.ndarray) -> np.ndarray:
        """Return Gaussian effect vectors zeroed outside ``masks`` and scaled to unit L2 norm per row."""
        effects = np.random.normal(0, 1.0, size=(n_modules, dim))
        effects *= masks
        norms = np.linalg.norm(effects, axis=1, keepdims=True) + 1e-9
        return effects / norms

    def simulate(self) -> Tuple[pd.DataFrame, np.ndarray, np.ndarray]:
        """
        Simulate every observation of the screen.

        Returns:
            ``(meta_df, X_morph, X_tx)``: one metadata row per observation, plus
            float32 morphology and transcriptomics matrices in the same row order
            (genes first, then compounds, then controls).
        """
        cfg = self.cfg

        gene_ids = [f"GENE_{i:04d}" for i in range(cfg.n_genes)]
        cmpd_ids = [f"CMPD_{i:04d}" for i in range(cfg.n_compounds)]
        ctrl_ids = [f"CTRL_{i:03d}" for i in range(cfg.n_controls)]

        # Per-compound flag and module draw, consumed in the compound loop below.
        cmpd_linked = np.random.rand(cfg.n_compounds) < cfg.compound_gene_link_prob
        cmpd_module = np.random.choice(cfg.n_cmd_modules, size=cfg.n_compounds, replace=True)

        # One additive offset vector per batch and per cell line, for each modality.
        batch_effect_morph = np.random.normal(0, cfg.batch_effect_scale_morph, size=(cfg.n_batches, cfg.morph_dim))
        batch_effect_tx = np.random.normal(0, cfg.batch_effect_scale_tx, size=(cfg.n_batches, cfg.tx_dim))
        cellline_effect_morph = np.random.normal(0, 0.30, size=(cfg.n_cell_lines, cfg.morph_dim))
        cellline_effect_tx = np.random.normal(0, 0.35, size=(cfg.n_cell_lines, cfg.tx_dim))

        # Accumulators filled by add_obs; the three lists stay aligned row by row.
        rows: List[dict] = []
        morph_list: List[np.ndarray] = []
        tx_list: List[np.ndarray] = []

        def add_obs(entity_type: str,
                    entity_id: str,
                    module_idx: Optional[int],
                    is_cmd_true: bool,
                    batch: int,
                    cell_line: int,
                    replicate: int,
                    potency: float) -> None:
            """Simulate one observation and append it to the three accumulators."""

            # Baseline biological variation.
            base_m = np.random.normal(0, 0.25, size=(cfg.morph_dim,))
            base_t = np.random.normal(0, 0.25, size=(cfg.tx_dim,))

            if module_idx is None:
                # No module (controls): no perturbation effect.
                eff_m = np.zeros(cfg.morph_dim)
                eff_t = np.zeros(cfg.tx_dim)
            else:
                eff_m = cfg.gene_effect_scale * potency * self.module_effects_morph[module_idx]
                eff_t = cfg.gene_effect_scale * potency * self.module_effects_tx[module_idx]

            if entity_type == "compound":
                # Off-target component: a unit-norm random mixture over all modules.
                mix = np.random.normal(0, cfg.off_target_noise, size=(cfg.n_cmd_modules,))
                mix = mix / (np.linalg.norm(mix) + 1e-9)
                # NOTE(review): eff_m/eff_t already carry gene_effect_scale * potency from
                # above, so the on-target term ends up scaled by potency twice and by
                # gene_effect_scale. Also, because mix is normalized, off_target_noise does
                # not change the size of the off-target term. Confirm both are intended.
                eff_m = cfg.compound_effect_scale * potency * (
                    eff_m + 0.45 * sum(mix[k] * self.module_effects_morph[k] for k in range(cfg.n_cmd_modules))
                )
                eff_t = cfg.compound_effect_scale * potency * (
                    eff_t + 0.45 * sum(mix[k] * self.module_effects_tx[k] for k in range(cfg.n_cmd_modules))
                )

            xm = base_m + eff_m + batch_effect_morph[batch] + cellline_effect_morph[cell_line]
            xt = base_t + eff_t + batch_effect_tx[batch] + cellline_effect_tx[cell_line]

            # Measurement noise on top of the structured signal.
            xm += np.random.normal(0, cfg.technical_noise_morph, size=(cfg.morph_dim,))
            xt += np.random.normal(0, cfg.technical_noise_tx, size=(cfg.tx_dim,))

            morph_list.append(xm.astype(np.float32))
            tx_list.append(xt.astype(np.float32))

            rows.append({
                "entity_type": entity_type,
                "entity_id": entity_id,
                "module": -1 if module_idx is None else int(module_idx),  # -1 marks "no module"
                "is_cmd_true_target": int(is_cmd_true),
                "batch": int(batch),
                "cell_line": int(cell_line),
                "replicate": int(replicate),
                "potency": float(potency),
            })

        cmd_target_set = set(self.cmd_targets.tolist())

        # Genes: true targets are moved to one of the CMD modules and given higher potency.
        for gi, gid in enumerate(gene_ids):
            is_cmd = int(gi in cmd_target_set)
            module_idx = int(self.gene_module[gi])

            if is_cmd:
                # The reassigned module is only recorded in the metadata; self.gene_module is not updated.
                module_idx = int(np.random.choice(self.cmd_modules))
                potency_base = np.random.uniform(0.9, 1.5)
            else:
                potency_base = np.random.uniform(0.2, 1.0)

            for cl in range(cfg.n_cell_lines):
                for r in range(cfg.replicates_per_cond):
                    batch = np.random.randint(0, cfg.n_batches)
                    potency = max(0.05, potency_base + np.random.normal(0, 0.08))
                    add_obs("gene", gid, module_idx, bool(is_cmd), batch, cl, r, potency)

        # Compounds: both branches pick a uniformly random module; "linked" only
        # changes where the draw comes from and widens the potency range.
        for ci, cid in enumerate(cmpd_ids):
            linked = bool(cmpd_linked[ci])
            module_idx = int(cmpd_module[ci]) if linked else int(np.random.choice(cfg.n_cmd_modules))
            potency_base = np.random.uniform(0.15, 1.25) if linked else np.random.uniform(0.05, 0.85)

            for cl in range(cfg.n_cell_lines):
                for r in range(cfg.replicates_per_cond):
                    batch = np.random.randint(0, cfg.n_batches)
                    potency = max(0.02, potency_base + np.random.normal(0, 0.10))
                    add_obs("compound", cid, module_idx, False, batch, cl, r, potency)

        # Controls: no module, so potency is recorded in the metadata but has no effect on the profile.
        for ctrl in ctrl_ids:
            for cl in range(cfg.n_cell_lines):
                for r in range(cfg.replicates_per_cond):
                    batch = np.random.randint(0, cfg.n_batches)
                    potency = np.random.uniform(0.0, 0.1)
                    add_obs("control", ctrl, None, False, batch, cl, r, potency)

        meta_df = pd.DataFrame(rows)
        X_morph = np.vstack(morph_list)
        X_tx = np.vstack(tx_list)
        return meta_df, X_morph, X_tx

# Pipeline: target discovery

class CMDTargetDiscoveryPoC:
    def __init__(self, cfg: SimConfig):
        """Apply the shared plotting defaults and keep ``cfg`` for ``run``."""
        set_plot_defaults()
        self.cfg = cfg

    @staticmethod
    def replicate_consistency(meta: pd.DataFrame, X_morph: np.ndarray, X_tx: np.ndarray) -> pd.DataFrame:
        out = []
        for eid, sub_idx in meta.groupby("entity_id").indices.items():
            idx = np.array(list(sub_idx))
            if len(idx) < 4:
                continue
            Xm = X_morph[idx]
            Xt = X_tx[idx]
            mean_m = Xm.mean(axis=0, keepdims=True)
            mean_t = Xt.mean(axis=0, keepdims=True)
            cm = cosine_sim_matrix(Xm, mean_m).flatten()
            ct = cosine_sim_matrix(Xt, mean_t).flatten()
            out.append({
                "entity_id": eid,
                "entity_type": meta.loc[idx[0], "entity_type"],
                "n_obs": int(len(idx)),
                "morph_repl_corr": float(np.mean(cm)),
                "tx_repl_corr": float(np.mean(ct)),
            })
        return (
            pd.DataFrame(out)
            .sort_values(["entity_type", "morph_repl_corr"], ascending=[True, False])
            .reset_index(drop=True)
        )

    @staticmethod
    def aggregate_profiles(meta: pd.DataFrame, X_morph: np.ndarray, X_tx: np.ndarray) -> pd.DataFrame:
        rows, mp, tp = [], [], []
        for eid, sub in meta.groupby("entity_id"):
            idx = sub.index.values
            mp.append(X_morph[idx].mean(axis=0))
            tp.append(X_tx[idx].mean(axis=0))
            rows.append({
                "entity_id": eid,
                "entity_type": sub["entity_type"].iloc[0],
                "module": int(sub["module"].iloc[0]),
                "is_cmd_true_target": int(sub["is_cmd_true_target"].max()),
                "n_obs": int(len(idx)),
                "potency_mean": float(sub["potency"].mean()),
            })
        prof = pd.DataFrame(rows)
        prof["morph_profile"] = list(mp)
        prof["tx_profile"] = list(tp)
        return prof

    @staticmethod
    def joint_embedding(
        morph_profiles: pd.Series,
        tx_profiles: pd.Series,
        n_morph_components: int = 50,
        n_tx_components: int = 40,
    ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
        """
        Reduce each modality with PCA, standardize the components and concatenate them.

        Args:
            morph_profiles: Series whose values are equal-length morphology vectors.
            tx_profiles: Series whose values are equal-length transcriptomic vectors.
            n_morph_components: Upper bound on morphology components kept.
            n_tx_components: Upper bound on transcriptomic components kept.

        Returns:
            ``(Z, Zm, Zt)`` as float32 arrays: the fused embedding followed by
            the morphology block and the transcriptomic block it is built from.
        """
        Xm = np.vstack(morph_profiles.values)
        Xt = np.vstack(tx_profiles.values)

        def reduce(X: np.ndarray, cap: int) -> np.ndarray:
            # PCA cannot return more components than min(n_samples, n_features);
            # bounding by both keeps small screens from raising.
            n_comp = min(cap, X.shape[0], X.shape[1])
            scores = PCA(n_components=n_comp, random_state=SEED).fit_transform(X)
            # Standardizing the components puts both modalities on the same scale.
            return StandardScaler().fit_transform(scores)

        Zm = reduce(Xm, n_morph_components)
        Zt = reduce(Xt, n_tx_components)

        Z = np.hstack([Zm, Zt]).astype(np.float32)
        return Z, Zm.astype(np.float32), Zt.astype(np.float32)

    @staticmethod
    def build_knn_graph(Z: np.ndarray, k: int) -> Dict[int, List[int]]:
        """
        Build a directed k-nearest-neighbour graph under cosine distance.

        Args:
            Z: Embedding with one row per node.
            k: Neighbours requested per node; capped at ``len(Z) - 1``.

        Returns:
            A mapping from each row index to its neighbours' row indices,
            nearest first. A node never lists itself. Empty input gives ``{}``.
        """
        n = len(Z)
        if n == 0:
            return {}

        # Ask for one extra neighbour because each point is also its own match.
        n_query = min(k + 1, n)
        nbrs = NearestNeighbors(n_neighbors=n_query, metric="cosine").fit(Z)
        _, inds = nbrs.kneighbors(Z)

        graph: Dict[int, List[int]] = {}
        for i, row in enumerate(inds):
            # Drop the point by identity rather than by assuming it sits in column 0:
            # duplicate or parallel vectors tie at distance 0 and may be listed first.
            neigh = [int(j) for j in row if j != i]
            graph[i] = neigh[: n_query - 1]
        return graph

    @staticmethod
    def cluster(Z: np.ndarray, n_clusters: int) -> np.ndarray:
        """Assign each row of Z a k-means label, using at most ``len(Z) // 10 + 2`` clusters."""
        size_cap = len(Z) // 10 + 2
        k_eff = min(n_clusters, size_cap)
        model = KMeans(n_clusters=k_eff, random_state=SEED, n_init="auto")
        labels = model.fit_predict(Z)
        return labels

    @staticmethod
    def train_cmd_classifier(X: np.ndarray, y: np.ndarray) -> Tuple[CalibratedClassifierCV, dict]:
        """
        Fit a random forest plus an isotonic-calibrated version and score both on a hold-out.

        A stratified 25% of the rows is held out. Both models are fitted on the
        remaining 75%; the calibrated one uses 3-fold internal cross-validation.

        Returns:
            ``(calibrated_model, perf)``. ``perf`` holds ROC AUC, average precision
            and Brier score for the raw and calibrated probabilities, plus the
            hold-out size.
        """
        X_fit, X_hold, y_fit, y_hold = train_test_split(
            X, y, test_size=0.25, random_state=SEED, stratify=y
        )

        forest_kwargs = dict(
            n_estimators=500,
            min_samples_split=4,
            min_samples_leaf=2,
            random_state=SEED,
            n_jobs=-1,
            class_weight="balanced_subsample",
        )
        base = RandomForestClassifier(**forest_kwargs)
        base.fit(X_fit, y_fit)

        cal = CalibratedClassifierCV(base, method="isotonic", cv=3)
        cal.fit(X_fit, y_fit)

        # Positive-class probabilities on the held-out rows.
        raw_scores = base.predict_proba(X_hold)[:, 1]
        cal_scores = cal.predict_proba(X_hold)[:, 1]

        perf = {
            "roc_auc_raw": float(roc_auc_score(y_hold, raw_scores)),
            "ap_raw": float(average_precision_score(y_hold, raw_scores)),
            "brier_raw": float(brier_score_loss(y_hold, raw_scores)),
            "roc_auc_cal": float(roc_auc_score(y_hold, cal_scores)),
            "ap_cal": float(average_precision_score(y_hold, cal_scores)),
            "brier_cal": float(brier_score_loss(y_hold, cal_scores)),
            "n_test": int(len(y_hold)),
        }
        return cal, perf

    @staticmethod
    def link_compounds_to_genes(prof: pd.DataFrame, Z: np.ndarray, top_n_links: int = 5) -> pd.DataFrame:
        idx_comp = np.where(prof["entity_type"].values == "compound")[0]
        idx_gene = np.where(prof["entity_type"].values == "gene")[0]

        sim_cg = cosine_sim_matrix(Z[idx_comp], Z[idx_gene])
        links = []
        for i_local, i in enumerate(idx_comp):
            sims = sim_cg[i_local]
            topk = np.argsort(-sims)[:top_n_links]
            for j in topk:
                g_idx = idx_gene[j]
                links.append({
                    "compound_id": prof.loc[i, "entity_id"],
                    "gene_id": prof.loc[g_idx, "entity_id"],
                    "cosine_similarity": float(sims[j]),
                    "compound_cluster": int(prof.loc[i, "cluster"]),
                    "gene_cluster": int(prof.loc[g_idx, "cluster"]),
                })
        return (
            pd.DataFrame(links)
            .sort_values(["compound_id", "cosine_similarity"], ascending=[True, False])
            .reset_index(drop=True)
        )

    def score_targets(
        self,
        prof: pd.DataFrame,
        Z: np.ndarray,
        Zm: np.ndarray,
        Zt: np.ndarray,
        graph: Dict[int, List[int]],
        links_df: pd.DataFrame
    ) -> pd.DataFrame:
        """
        Combine the evidence channels into one ranked table of gene targets.

        Five channels are computed per profile, rescaled to [0, 1] across genes
        with ``robust_minmax`` and combined as a fixed weighted sum:
        classifier probability, distance from the control centroid, cross-modal
        concordance, summed similarity of strongly linked compounds, and
        inbound degree in the kNN graph.

        Args:
            prof: One row per profile, aligned with the rows of ``Z``. Must
                already carry ``cmd_prob`` and ``cluster`` columns.
            Z: Fused embedding.
            Zm: Morphology block of the embedding.
            Zt: Transcriptomic block of the embedding.
            graph: Mapping from row index to the row indices of its neighbours.
            links_df: Compound-to-gene links with ``gene_id`` and
                ``cosine_similarity`` columns. A frame without them is treated
                as "no compound support".

        Returns:
            Gene rows only, sorted by descending ``integrated_score``.

        Raises:
            KeyError: If ``prof`` lacks ``cmd_prob`` or ``cluster``.
        """
        missing = [c for c in ("cmd_prob", "cluster") if c not in prof.columns]
        if missing:
            raise KeyError(f"prof is missing required column(s): {missing}")

        df = prof.copy()

        # Phenotype strength: distance from the control centroid in the fused space.
        ctrl_idx = np.flatnonzero(df["entity_type"].to_numpy() == "control")
        # The mean of an empty selection is NaN and would poison every score, so fall
        # back to the centroid of all profiles when no control is present.
        reference = Z[ctrl_idx] if len(ctrl_idx) else Z
        ctrl_centroid = reference.mean(axis=0, keepdims=True)
        df["phenotype_strength"] = np.linalg.norm(Z - ctrl_centroid, axis=1)

        # Cross-modal concordance: row-normalize each block, then take the dot product
        # over the leading components the two blocks have in common. The blocks may
        # differ in width, hence the truncation.
        # NOTE(review): component k of one block and component k of the other presumably
        # come from independently fitted decompositions, so treat this as a heuristic.
        ZmN = Zm / (np.linalg.norm(Zm, axis=1, keepdims=True) + 1e-9)
        ZtN = Zt / (np.linalg.norm(Zt, axis=1, keepdims=True) + 1e-9)
        d = min(ZmN.shape[1], ZtN.shape[1])
        df["cross_modal_concordance"] = np.sum(ZmN[:, :d] * ZtN[:, :d], axis=1)

        # Graph centrality: how many nodes list this one among their neighbours.
        inbound = np.zeros(len(df), dtype=int)
        for neigh in graph.values():
            for j in neigh:
                inbound[j] += 1
        df["graph_centrality"] = inbound.astype(float)

        # Compound support: only links above the 0.30 similarity cut-off count.
        if {"gene_id", "cosine_similarity"}.issubset(links_df.columns):
            strong = links_df[links_df["cosine_similarity"] > 0.30]
            support = strong.groupby("gene_id")["cosine_similarity"].sum()
            count = strong.groupby("gene_id").size()
        else:
            support = pd.Series(dtype=float)
            count = pd.Series(dtype=int)

        # Genes with no strong link map to NaN; they get zero support.
        df["compound_support"] = df["entity_id"].map(support).fillna(0.0).astype(float)
        df["compound_support_count"] = df["entity_id"].map(count).fillna(0).astype(int)

        genes = df[df["entity_type"] == "gene"].copy()

        # Weight of each normalized channel; the weights sum to 1.
        weights = {
            "cmd_prob_norm": 0.38,
            "phenotype_strength_norm": 0.18,
            "cross_modal_concordance_norm": 0.14,
            "compound_support_norm": 0.20,
            "graph_centrality_norm": 0.10,
        }
        # Rescale within genes only, so compounds and controls do not set the range.
        for norm_col in weights:
            raw_col = norm_col[: -len("_norm")]
            genes[norm_col] = robust_minmax(genes[raw_col].values)

        genes["integrated_score"] = sum(w * genes[col] for col, w in weights.items())

        cols = [
            "entity_id", "cluster", "module",
            "integrated_score",
            "cmd_prob",
            "phenotype_strength",
            "cross_modal_concordance",
            "compound_support", "compound_support_count",
            "graph_centrality",
            "is_cmd_true_target",
        ]
        return genes.sort_values("integrated_score", ascending=False).reset_index(drop=True)[cols]

    @staticmethod
    def make_figures(
        outdir: str,
        qc_df: pd.DataFrame,
        prof: pd.DataFrame,
        Z: np.ndarray,
        yte: np.ndarray,
        p_raw: np.ndarray,
        p_cal: np.ndarray,
        ranked: pd.DataFrame
    ) -> None:
        """
        Write the five summary figures as PNG files under ``<outdir>/figures``.

        The figures are: replicate-consistency histograms, a 2-D PCA of the
        joint embedding, calibration curves for raw and calibrated probabilities,
        integrated-score histograms split by the simulated label, and a bar chart
        of the top-ranked genes.
        """
        figdir = os.path.join(outdir, "figures")
        os.makedirs(figdir, exist_ok=True)

        def finish(filename: str) -> None:
            # Shared tail of every figure: tidy the layout, write the file, free the figure.
            plt.tight_layout()
            plt.savefig(os.path.join(figdir, filename))
            plt.close()

        # Fig 1: replicate consistency
        plt.figure(figsize=(6.2, 4.2))
        for column, label in (("morph_repl_corr", "Morph"), ("tx_repl_corr", "Tx")):
            plt.hist(qc_df[column], bins=40, alpha=0.8, label=label)
        plt.xlabel("Replicate-to-mean cosine similarity")
        plt.ylabel("Count (conditions)")
        plt.title("QC: Replicate consistency across modalities")
        plt.legend()
        finish("fig1_qc_replicate_consistency.png")

        # Fig 2: embedding landscape
        coords = PCA(n_components=2, random_state=SEED).fit_transform(Z)
        entity_types = prof["entity_type"].values
        plt.figure(figsize=(6.0, 5.0))
        for kind in ("control", "compound", "gene"):
            members = np.where(entity_types == kind)[0]
            plt.scatter(coords[members, 0], coords[members, 1], s=10, alpha=0.65, label=kind)
        plt.xlabel("PC1")
        plt.ylabel("PC2")
        plt.title("Joint embedding landscape (genes, compounds, controls)")
        plt.legend()
        finish("fig2_joint_embedding_pca.png")

        # Fig 3: calibration curve
        curves = [
            (calibration_curve(yte, probs, n_bins=10), label)
            for probs, label in ((p_raw, "Raw RF"), (p_cal, "Calibrated RF"))
        ]
        plt.figure(figsize=(5.5, 4.2))
        for (frac_pos, mean_pred), label in curves:
            plt.plot(mean_pred, frac_pos, marker="o", label=label)
        plt.plot([0, 1], [0, 1], linestyle="--", label="Ideal")
        plt.xlabel("Mean predicted probability")
        plt.ylabel("Fraction of positives")
        plt.title("CMD relevance probability calibration")
        plt.legend()
        finish("fig3_calibration_curve.png")

        # Fig 4: integrated score separation
        plt.figure(figsize=(6.2, 4.2))
        scores = ranked["integrated_score"].values
        ytrue = ranked["is_cmd_true_target"].values.astype(int)
        for flag, label in ((0, "Other genes"), (1, "True CMD genes (simulated)")):
            plt.hist(scores[ytrue == flag], bins=40, alpha=0.8, label=label)
        plt.xlabel("Integrated target score")
        plt.ylabel("Count (genes)")
        plt.title("Target prioritization separation (simulation)")
        plt.legend()
        finish("fig4_integrated_score_separation.png")

        # Fig 5: top targets
        top_n = min(20, len(ranked))
        top = ranked.head(top_n)
        positions = np.arange(len(top))
        plt.figure(figsize=(7.6, 4.2))
        plt.bar(positions, top["integrated_score"].values)
        plt.xticks(positions, top["entity_id"].values, rotation=70, ha="right")
        plt.ylabel("Integrated score")
        plt.title(f"Top {top_n} nominated gene targets (integrated evidence)")
        finish(f"fig5_top_{top_n}_targets.png")

    def run(self) -> None:
        """
        Run the whole pipeline and write all outputs under ``cfg.outdir``.

        Steps: simulate, replicate QC and filtering, per-batch standardization,
        aggregation to one profile per entity, joint embedding with kNN graph and
        clusters, classifier training on genes, compound-to-gene linking,
        integrated scoring, figures, and a printed summary of the top 20 genes.
        """
        cfg = self.cfg
        os.makedirs(cfg.outdir, exist_ok=True)

        def out_path(filename: str) -> str:
            return os.path.join(cfg.outdir, filename)

        # 1) simulate dataset
        meta, X_morph, X_tx = PerturbationSimulator(cfg).simulate()
        meta.to_csv(out_path("meta_raw.csv"), index=False)

        # 2) QC replicate consistency
        qc_df = self.replicate_consistency(meta, X_morph, X_tx)
        qc_df.to_csv(out_path("qc_replicate_consistency.csv"), index=False)

        # keep only entities whose replicates agree in both modalities
        passes_qc = (qc_df["morph_repl_corr"] > 0.10) & (qc_df["tx_repl_corr"] > 0.08)
        keep_entities = qc_df.loc[passes_qc, "entity_id"].tolist()
        keep_mask = meta["entity_id"].isin(keep_entities).values

        meta_f = meta.loc[keep_mask].reset_index(drop=True)
        X_morph_f, X_tx_f = X_morph[keep_mask], X_tx[keep_mask]

        # 3) batch correction + scaling
        batches = meta_f["batch"].values
        X_morph_s, X_tx_s = (
            StandardScaler().fit_transform(zscore_by_group(features, batches))
            for features in (X_morph_f, X_tx_f)
        )

        # 4) aggregate to condition profiles
        prof = self.aggregate_profiles(meta_f, X_morph_s, X_tx_s)

        # 5) joint embedding + graph + clusters
        Z, Zm, Zt = self.joint_embedding(prof["morph_profile"], prof["tx_profile"])
        graph = self.build_knn_graph(Z, k=cfg.knn_k)
        prof["cluster"] = self.cluster(Z, n_clusters=cfg.n_clusters)

        # 6) CMD classifier trained on genes only
        gene_mask = (prof["entity_type"].values == "gene")
        X_gene = Z[gene_mask]
        y_gene = prof.loc[gene_mask, "is_cmd_true_target"].values.astype(int)

        model, perf = self.train_cmd_classifier(X_gene, y_gene)
        with open(out_path("model_performance.json"), "w") as f:
            json.dump(perf, f, indent=2)

        # calibration plot data: repeat the split and raw forest with the same seed and
        # settings as train_cmd_classifier to recover the held-out rows and raw scores
        Xtr, Xte, ytr, yte = train_test_split(
            X_gene, y_gene, test_size=0.25, random_state=SEED, stratify=y_gene
        )
        forest_kwargs = dict(
            n_estimators=500,
            min_samples_split=4,
            min_samples_leaf=2,
            random_state=SEED,
            n_jobs=-1,
            class_weight="balanced_subsample",
        )
        base_rf = RandomForestClassifier(**forest_kwargs).fit(Xtr, ytr)
        p_raw = base_rf.predict_proba(Xte)[:, 1]
        p_cal = model.predict_proba(Xte)[:, 1]

        # predict across all profiles
        # NOTE(review): this includes the genes the model was fitted on, so their
        # cmd_prob is an in-sample prediction. Confirm that is acceptable for ranking.
        prof["cmd_prob"] = model.predict_proba(Z)[:, 1]

        # 7) compound→gene linking
        links_df = self.link_compounds_to_genes(prof, Z, top_n_links=5)
        links_df.to_csv(out_path("compound_gene_links.csv"), index=False)

        # 8) integrated scoring
        ranked = self.score_targets(prof, Z, Zm, Zt, graph, links_df)
        ranked.to_csv(out_path("ranked_targets.csv"), index=False)

        # 9) figures
        self.make_figures(cfg.outdir, qc_df, prof, Z, yte, p_raw, p_cal, ranked)

        # summary
        summary_cols = [
            "entity_id", "integrated_score", "cmd_prob", "phenotype_strength",
            "cross_modal_concordance", "compound_support", "graph_centrality", "cluster"
        ]
        print("\n=== Top 20 Target Candidates (Proof-of-Concept) ===")
        print(ranked.head(20)[summary_cols].to_string(index=False))
        print(f"\nOutputs written to: {cfg.outdir}/")
        print("Key outputs: ranked_targets.csv, compound_gene_links.csv, qc_replicate_consistency.csv, figures/*.png")

# Main

def main() -> None:
    """Run the proof-of-concept pipeline once with the default configuration."""
    pipeline = CMDTargetDiscoveryPoC(SimConfig())
    pipeline.run()


if __name__ == "__main__":
    main()


=== Top 20 Target Candidates (Proof-of-Concept) ===
entity_id  integrated_score  cmd_prob  phenotype_strength  cross_modal_concordance  compound_support  graph_centrality  cluster
GENE_0961          1.000000  1.000000           11.762787                 0.316419          1.290323              54.0       25
GENE_0684          1.000000  0.888889           11.187452                 0.359706          1.367465              52.0       23
GENE_0599          0.992164  0.814815           11.873228                 0.344218          1.079352              54.0       25
GENE_1111          0.958542  1.000000           10.826568                 0.378805          0.971952              64.0       23
GENE_0437          0.956163  0.999702           10.460979                 0.282492          1.153722              47.0       23
GENE_0898          0.952348  0.614815           11.356086                 0.325608          0.877319              41.0       23
GENE_0731          0.945459  0.888889           10.