# OFFTOXv3 — Safety Pharmacology & NHR Prediction Workflow

End-to-end workflow for predicting compound activity against the **24-target safety pharmacology panel** using a 2-class binding scheme:

| Class | Label | Definition |
|-------|-------|------------|
| 1 | **binding** | pChEMBL >= 5.0 (IC50/Ki <= 10 uM) |
| 0 | **non_binding** | pChEMBL < 5.0 (> 10 uM) or confirmed inactive |

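pChEMBL is the negative log10 of the molar activity. For a value in nM it equals 9 - log10(value), and the class boundaries follow from that. The sketch below shows the conversion; its helper names are illustrative and not part of the notebook:

```python
import math


def pchembl_from_nm(value_nm: float) -> float:
    """pChEMBL-style value for an IC50/Ki in nM: -log10(molar) = 9 - log10(nM)."""
    if value_nm <= 0:
        raise ValueError("activity value must be positive")
    return 9.0 - math.log10(value_nm)


def binding_class(value_nm: float, threshold: float = 5.0) -> int:
    """1 (binding) if pChEMBL >= threshold, else 0 (non_binding)."""
    return 1 if pchembl_from_nm(value_nm) >= threshold else 0


# 10 uM = 10,000 nM sits exactly on the boundary and counts as binding
for nm in (50, 10_000, 25_000, 100_000):
    print(f"{nm:>7,} nM -> pChEMBL {pchembl_from_nm(nm):.2f}, class {binding_class(nm)}")
```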
**24-Target Panel:**
- **Nuclear Hormone Receptors (14):** ERa, ER_beta, AR, GR, PR, MR, PPARg, PXR, CAR, LXRa, LXRb, FXR, RXRa, VDR
- **Cardiac ion channels (3):** hERG, Cav1.2, Nav1.5
- **CYP enzymes (5):** CYP3A4, CYP2D6, CYP2C9, CYP1A2, CYP2C19
- **Transporters (2):** P-gp, BSEP

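**Data filters:** exact IC50 and Ki measurements with pChEMBL >= 4.0 (<= 100 uM), plus confirmed non-binders reported as > 100 uM. The retrieval aims for at least 30 non-binders per target and keeps at most 40.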

---

## How to use this notebook

1. **Run cells 1-2** to set up and optionally refresh data from ChEMBL.
2. **Run cells 3-9** to train and evaluate models on the training data.
3. **Cell 10** evaluates on a held-out test set not used in training.
4. **Cell 11** (Predict New Compounds) accepts any CSV with `compound_id`, `smiles`, and `target` columns.
5. **Cell 12** generates the analysis report (Markdown).
6. All visualizations render inline. Outputs are also saved to `outputs/`.

---
## 1. Setup & Configuration

In [None]:
# ── Install dependencies (run once) ──────────────────────────────────
# Uncomment the line below if running for the first time:
# !pip install "numpy<2" pandas scikit-learn xgboost lightgbm rdkit matplotlib seaborn scipy joblib requests

import json
import csv
import time
import warnings
import pickle
import hashlib
import requests
from pathlib import Path
from dataclasses import dataclass
from typing import Dict, List, Tuple, Optional
from collections import Counter

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import seaborn as sns
from scipy import stats

from rdkit import Chem
from rdkit.Chem import Crippen, Descriptors, Lipinski, MolSurf, rdFingerprintGenerator
from rdkit.Chem.Scaffolds import MurckoScaffold

from sklearn.calibration import CalibratedClassifierCV, calibration_curve
from sklearn.metrics import (
    average_precision_score,
    confusion_matrix,
    matthews_corrcoef,
    precision_recall_curve,
    roc_auc_score,
    roc_curve,
    classification_report,
)
from sklearn.model_selection import RandomizedSearchCV, RepeatedStratifiedKFold
from sklearn.neighbors import NearestNeighbors
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler
from sklearn.ensemble import RandomForestClassifier
from xgboost import XGBClassifier
from lightgbm import LGBMClassifier

warnings.filterwarnings("ignore")
sns.set_theme(style="whitegrid", font_scale=1.1)
%matplotlib inline

# ── Paths ────────────────────────────────────────────────────────────
NOTEBOOK_DIR = Path(".").resolve()
DATA_PATH = NOTEBOOK_DIR / "data" / "safety_targets_bioactivity.csv"
TEST_PATH = NOTEBOOK_DIR / "data" / "test_compounds.csv"
OUTPUT_DIR = NOTEBOOK_DIR / "outputs"
MODEL_DIR  = NOTEBOOK_DIR / "model"
OUTPUT_DIR.mkdir(exist_ok=True)
MODEL_DIR.mkdir(exist_ok=True)

# ── 24-Target Safety Panel ───────────────────────────────────────────
TARGET_PANEL = {
    # Nuclear Hormone Receptors (14)
    "ERa":     {"chembl_id": "CHEMBL206",  "category": "Nuclear Hormone Receptor"},
    "ER_beta": {"chembl_id": "CHEMBL242",  "category": "Nuclear Hormone Receptor"},
    "AR":      {"chembl_id": "CHEMBL1871", "category": "Nuclear Hormone Receptor"},
    "GR":      {"chembl_id": "CHEMBL2034", "category": "Nuclear Hormone Receptor"},
    "PR":      {"chembl_id": "CHEMBL208",  "category": "Nuclear Hormone Receptor"},
    "MR":      {"chembl_id": "CHEMBL1994", "category": "Nuclear Hormone Receptor"},
    "PPARg":   {"chembl_id": "CHEMBL235",  "category": "Nuclear Hormone Receptor"},
    "PXR":     {"chembl_id": "CHEMBL3401", "category": "Nuclear Hormone Receptor"},
    "CAR":     {"chembl_id": "CHEMBL2248", "category": "Nuclear Hormone Receptor"},
    "LXRa":    {"chembl_id": "CHEMBL5231", "category": "Nuclear Hormone Receptor"},
    "LXRb":    {"chembl_id": "CHEMBL4309", "category": "Nuclear Hormone Receptor"},
    "FXR":     {"chembl_id": "CHEMBL2001", "category": "Nuclear Hormone Receptor"},
    "RXRa":    {"chembl_id": "CHEMBL2061", "category": "Nuclear Hormone Receptor"},
    "VDR":     {"chembl_id": "CHEMBL1977", "category": "Nuclear Hormone Receptor"},
    # Cardiac Safety (3)
    "hERG":    {"chembl_id": "CHEMBL240",  "category": "Cardiac Safety"},
    "Cav1.2":  {"chembl_id": "CHEMBL1940", "category": "Cardiac Safety"},
    "Nav1.5":  {"chembl_id": "CHEMBL1993", "category": "Cardiac Safety"},
    # Hepatotoxicity / CYP (5)
    "CYP3A4":  {"chembl_id": "CHEMBL340",  "category": "Hepatotoxicity"},
    "CYP2D6":  {"chembl_id": "CHEMBL289",  "category": "Hepatotoxicity"},
    "CYP2C9":  {"chembl_id": "CHEMBL3397", "category": "Hepatotoxicity"},
    "CYP1A2":  {"chembl_id": "CHEMBL3356", "category": "Hepatotoxicity"},
    "CYP2C19": {"chembl_id": "CHEMBL3622", "category": "Hepatotoxicity"},
    # Transporters (2)
    "P-gp":    {"chembl_id": "CHEMBL4302", "category": "Transporter"},
    "BSEP":    {"chembl_id": "CHEMBL4105", "category": "Transporter"},
}

# ── Constants (2-class: binding vs non-binding) ──────────────────────
RANDOM_STATE = 42
ACTIVITY_CLASS_MAP = {0: "non_binding", 1: "binding"}
CLASS_COLORS = {0: "#2ecc71", 1: "#e74c3c"}
NUM_CLASSES = 2
CHEMBL_BASE_URL = "https://www.ebi.ac.uk/chembl/api/data"

print(f"Data path  : {DATA_PATH}")
print(f"Test path  : {TEST_PATH}")
print(f"Output dir : {OUTPUT_DIR}")
print(f"Model dir  : {MODEL_DIR}")
print(f"Targets    : {len(TARGET_PANEL)} targets in panel")
print(f"Classes    : {NUM_CLASSES} ({', '.join(ACTIVITY_CLASS_MAP.values())})")
print("Setup complete.")

---
## 2. ChEMBL Data Retrieval (24-Target Panel)

Pull IC50 and Ki bioactivity data from ChEMBL for all 24 targets.
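Each request in the retrieval cell is a GET on the ChEMBL `activity.json` endpoint. The filters listed below are passed as query-string parameters, including the `__gte` range lookup on `pchembl_value`. The sketch below only builds the URL for one page of results, so the encoding is visible and no request is sent. `build_activity_url` is a hypothetical helper:

```python
from urllib.parse import urlencode

CHEMBL_BASE_URL = "https://www.ebi.ac.uk/chembl/api/data"


def build_activity_url(target_chembl_id: str, act_type: str,
                       pchembl_min: float = 4.0, limit: int = 500,
                       offset: int = 0) -> str:
    """Build the activity.json URL for one page of exact (`=`) measurements."""
    params = {
        "target_chembl_id": target_chembl_id,
        "standard_type": act_type,
        "pchembl_value__gte": pchembl_min,  # pChEMBL >= 4.0, i.e. <= 100 uM
        "standard_relation": "=",           # exact measurements only
        "limit": limit,
        "offset": offset,
        "format": "json",
    }
    return f"{CHEMBL_BASE_URL}/activity.json?{urlencode(params)}"


print(build_activity_url("CHEMBL240", "IC50"))  # hERG, first page
```

`urlencode` escapes the `=` relation as `%3D`. `requests` performs the same encoding when given a `params` dict.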

**Filters applied:**
- Activity types: IC50, Ki only
- Potency floor: pChEMBL >= 4.0 (<= 100 uM), which keeps data volume and pagination manageable
- Negatives: confirmed inactives reported with a `>` relation at >= 100 uM. The retrieval aims for at least 30 per target and keeps at most 40.
- Data quality: exact measurements only (`=` relation) for measured records; records without a SMILES string are skipped

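**Note:** By default the notebook reads the pre-built CSV at `data/safety_targets_bioactivity.csv`. To regenerate it from ChEMBL (requires internet access), set `REFRESH_FROM_CHEMBL = True` below. Failed requests are retried up to 4 times with exponential backoff. If a target's requests keep failing, only the records fetched before the failure are kept for that target.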

In [None]:
# ── ChEMBL data retrieval for 24-target panel ────────────────────────
# Set to True to pull fresh data from ChEMBL (requires internet access).
# Set to False to use the pre-built CSV.
REFRESH_FROM_CHEMBL = False

def fetch_chembl_activities(target_chembl_id, target_name, activity_types=("IC50", "Ki"),
                            pchembl_min=4.0, max_records=1000):
    """Fetch bioactivity data from ChEMBL REST API for one target.

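    Limits to exact IC50/Ki measurements with pChEMBL >= 4.0 (<= 100 uM) to
    control data volume, and pages through results up to `max_records` per
    activity type. If a request still fails after 4 attempts, the records
    collected so far are returned and any remaining activity types are skipped.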
    """
    all_records = []
    for act_type in activity_types:
        offset = 0
        limit = 500  # conservative page size
        while offset < max_records:
            params = {
                "target_chembl_id": target_chembl_id,
                "standard_type": act_type,
                "pchembl_value__gte": pchembl_min,
                "standard_relation": "=",
                "limit": limit,
                "offset": offset,
                "format": "json",
            }
            for attempt in range(4):
                try:
                    resp = requests.get(f"{CHEMBL_BASE_URL}/activity.json",
                                        params=params, timeout=60)
                    resp.raise_for_status()
                    break
                except Exception as e:
                    if attempt < 3:
                        time.sleep(2 ** (attempt + 1))
                    else:
                        print(f"  FAILED {target_name}/{act_type} after 4 retries: {e}")
                        return all_records
            data = resp.json()
            activities = data.get("activities", [])
            if not activities:
                break
            for act in activities:
                smi = act.get("canonical_smiles")
                pval = act.get("pchembl_value")
                if not smi or not pval:
                    continue
                all_records.append({
                    "molecule_chembl_id": act.get("molecule_chembl_id", ""),
                    "canonical_smiles": smi,
                    "standard_type": act.get("standard_type", act_type),
                    "standard_relation": "=",
                    "standard_value": act.get("standard_value", ""),
                    "standard_units": act.get("standard_units", "nM"),
                    "pchembl_value": pval,
                    "activity_comment": act.get("activity_comment", ""),
                    "assay_chembl_id": act.get("assay_chembl_id", ""),
                    "assay_type": act.get("assay_type", ""),
                    "target_chembl_id": target_chembl_id,
                    "target_pref_name": act.get("target_pref_name", ""),
                    "document_chembl_id": act.get("document_chembl_id", ""),
                    "src_id": act.get("src_id", ""),
                    "data_validity_comment": act.get("data_validity_comment", ""),
                    "safety_category": TARGET_PANEL[target_name]["category"],
                    "target_common_name": target_name,
                })
            if len(activities) < limit:
                break
            offset += limit
            time.sleep(0.5)  # rate limiting
        print(f"  {target_name}/{act_type}: {len([r for r in all_records if r['standard_type']==act_type])} records")
    return all_records


def fetch_chembl_inactives(target_chembl_id, target_name, min_inactive=30, max_records=2000):
    """Fetch confirmed inactives: right-censored IC50/Ki records (`>` relation) at >= 100 uM."""
    inactive = []
    offset = 0
    limit = 500
    while offset < max_records and len(inactive) < min_inactive * 3:
        params = {
            "target_chembl_id": target_chembl_id,
            "standard_relation": ">",
            "standard_type__in": "IC50,Ki",
            "limit": limit,
            "offset": offset,
            "format": "json",
        }
        for attempt in range(4):
            try:
                resp = requests.get(f"{CHEMBL_BASE_URL}/activity.json",
                                    params=params, timeout=60)
                resp.raise_for_status()
                break
            except Exception:
                if attempt < 3:
                    time.sleep(2 ** (attempt + 1))
                else:
                    return inactive
        data = resp.json()
        activities = data.get("activities", [])
        if not activities:
            break
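        for act in activities:
            smi = act.get("canonical_smiles")
            val = act.get("standard_value")
            # The >= 100 uM cut below assumes nM; skip records in other units
            if not smi or act.get("standard_units") != "nM":
                continue
            try:
                if float(val) >= 100000:  # >= 100 uM in nM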
                    inactive.append({
                        "molecule_chembl_id": act.get("molecule_chembl_id", ""),
                        "canonical_smiles": smi,
                        "standard_type": act.get("standard_type", ""),
                        "standard_relation": ">",
                        "standard_value": val,
                        "standard_units": "nM",
                        "pchembl_value": "",
                        "activity_comment": "Confirmed inactive (> 100 uM)",
                        "assay_chembl_id": act.get("assay_chembl_id", ""),
                        "assay_type": act.get("assay_type", ""),
                        "target_chembl_id": target_chembl_id,
                        "target_pref_name": act.get("target_pref_name", ""),
                        "document_chembl_id": act.get("document_chembl_id", ""),
                        "src_id": act.get("src_id", ""),
                        "data_validity_comment": "",
                        "safety_category": TARGET_PANEL[target_name]["category"],
                        "target_common_name": target_name,
                        "activity_class": "0",
                        "activity_class_label": "non_binding",
                    })
            except (ValueError, TypeError):
                continue
        if len(activities) < limit:
            break
        offset += limit
        time.sleep(0.5)

    # Deduplicate by molecule
    seen = set()
    deduped = []
    for r in inactive:
        mid = r["molecule_chembl_id"]
        if mid not in seen:
            seen.add(mid)
            deduped.append(r)
    return deduped[:min_inactive + 10]


if REFRESH_FROM_CHEMBL:
    print("Pulling data from ChEMBL for 24 targets...")
    all_chembl_rows = []
    for tname, tinfo in TARGET_PANEL.items():
        print(f"\n--- {tname} ({tinfo['chembl_id']}) ---")
        active = fetch_chembl_activities(tinfo["chembl_id"], tname)
        # Assign 2-class activity labels
        for r in active:
            try:
                p = float(r["pchembl_value"])
                r["activity_class"] = "1" if p >= 5.0 else "0"
                r["activity_class_label"] = "binding" if p >= 5.0 else "non_binding"
            except (ValueError, TypeError):
                continue
        all_chembl_rows.extend(active)

        inact = fetch_chembl_inactives(tinfo["chembl_id"], tname, min_inactive=30)
        all_chembl_rows.extend(inact)
        print(f"  Total: {len(active)} measured (pChEMBL >= 4.0) + {len(inact)} confirmed inactive")

    # Save. Skip the write if nothing was retrieved, so a failed pull does not
    # overwrite the existing CSV with a header-only file.
    fieldnames = [
        "molecule_chembl_id", "canonical_smiles", "standard_type",
        "standard_relation", "standard_value", "standard_units",
        "pchembl_value", "activity_comment", "assay_chembl_id",
        "assay_type", "target_chembl_id", "target_pref_name",
        "document_chembl_id", "src_id", "data_validity_comment",
        "safety_category", "target_common_name", "activity_class",
        "activity_class_label",
    ]
    if not all_chembl_rows:
        print(f"\nNo records retrieved; leaving {DATA_PATH} unchanged.")
    else:
        DATA_PATH.parent.mkdir(parents=True, exist_ok=True)
        with DATA_PATH.open("w", newline="", encoding="utf-8") as f:
            writer = csv.DictWriter(f, fieldnames=fieldnames)
            writer.writeheader()
            for row in all_chembl_rows:
                writer.writerow({k: row.get(k, "") for k in fieldnames})
        print(f"\nSaved {len(all_chembl_rows)} records to {DATA_PATH}")
else:
    print(f"Using pre-built dataset at {DATA_PATH}")
    print("Set REFRESH_FROM_CHEMBL = True to pull fresh data from ChEMBL.")
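Both fetch functions use the same pattern. They page through results by offset and stop on an empty or short page, or at a record cap. Each request is retried after waits of 2, 4 and 8 seconds. The sketch below isolates that pattern so it can run offline; the `fetch_page` callable and the injectable `sleep` function are hypothetical stand-ins for the HTTP call. One difference from the cell above: the cell returns whatever it has collected when retries run out, whereas this sketch re-raises the last error.

```python
import time
from typing import Callable, List


def with_retries(call: Callable[[], list], max_attempts: int = 4,
                 sleep: Callable[[float], None] = time.sleep) -> list:
    """Run `call()`, retrying on any exception after waits of 2, 4, 8, ... seconds."""
    for attempt in range(max_attempts):
        try:
            return call()
        except Exception:
            if attempt == max_attempts - 1:
                raise  # out of attempts: surface the last error
            sleep(2 ** (attempt + 1))
    return []  # only reached when max_attempts < 1


def fetch_all_pages(fetch_page: Callable[[int, int], list], limit: int = 500,
                    max_records: int = 1000,
                    sleep: Callable[[float], None] = time.sleep) -> list:
    """Offset pagination: stop on an empty or short page, or once max_records is reached."""
    records: List = []
    offset = 0
    while offset < max_records:
        page = with_retries(lambda: fetch_page(offset, limit), sleep=sleep)
        if not page:
            break
        records.extend(page)
        if len(page) < limit:
            break
        offset += limit
    return records


# Offline usage: 1,200 fake rows and a first request that fails once
fake_rows = list(range(1200))
calls = {"n": 0}


def fake_page(offset: int, limit: int) -> list:
    calls["n"] += 1
    if calls["n"] == 1:
        raise ConnectionError("transient failure")
    return fake_rows[offset:offset + limit]


waits: List[float] = []
out = fetch_all_pages(fake_page, limit=500, max_records=1000, sleep=waits.append)
print(len(out), waits)  # 1000 [2]
```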

---
## 3. Data Loading & Exploration

In [None]:
# ── Core data structures ─────────────────────────────────────────────
@dataclass
class SplitData:
    train_idx: np.ndarray
    val_idx: np.ndarray
    test_idx: np.ndarray


def load_and_clean_data(path: Path) -> List[dict]:
    """Load training CSV and assign 2-class activity labels.

    Classes:
        1 - binding:     pChEMBL >= 5.0  (<= 10 uM)
        0 - non_binding: pChEMBL < 5.0  (> 10 uM) or confirmed inactive
    """
    rows = []
    with path.open(newline="", encoding="utf-8") as fh:
        reader = csv.DictReader(fh)
        for row in reader:
            smi = row.get("canonical_smiles")
            if not smi:
                continue

            # Exact measurements: label from pChEMBL. Weak binders
            # (4.0 <= pChEMBL < 5.0) are class 0 but keep their pChEMBL value.
            pchembl = None
            if row.get("standard_relation") == "=" and row.get("pchembl_value"):
                try:
                    pchembl = float(row["pchembl_value"])
                except ValueError:
                    pchembl = None
            if pchembl is not None and pchembl >= 4.0:
                row["pchembl_value"] = pchembl
                row["activity_class"] = 1 if pchembl >= 5.0 else 0
                rows.append(row)
                continue

            # No usable pChEMBL: keep only records explicitly labeled inactive
            raw_class = row.get("activity_class", "")
            label = row.get("activity_class_label", "")
            if raw_class == "0" or label in ("inactive", "non_binding"):
                row["pchembl_value"] = None
                row["activity_class"] = 0
                rows.append(row)

    deduped: dict = {}
    for row in rows:
        key = (row.get("molecule_chembl_id"), row.get("target_chembl_id"))
        existing = deduped.get(key)
        if existing is None:
            deduped[key] = row
        else:
            existing_p = existing.get("pchembl_value")
            current_p = row.get("pchembl_value")
            if current_p is not None and (existing_p is None or current_p > existing_p):
                deduped[key] = row
    return list(deduped.values())


# ── Load ──────────────────────────────────────────────────────────────
data = load_and_clean_data(DATA_PATH)
labels_all = np.array([row["activity_class"] for row in data], dtype=int)
targets_all = [row.get("target_common_name", "unknown") for row in data]

print(f"Loaded {len(data)} compound-target records")
print(f"Targets: {sorted(set(targets_all))}")
print(f"\nClass distribution:")
for cls in sorted(ACTIVITY_CLASS_MAP):
    n = int((labels_all == cls).sum())
    print(f"  {cls} ({ACTIVITY_CLASS_MAP[cls]:>12s}): {n:>5d}  ({100*n/len(data):.1f}%)")
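`load_and_clean_data` keeps one record per (molecule, target) pair. It prefers a numeric pChEMBL over a confirmed-inactive record (stored as `None`), and a higher pChEMBL over a lower one. The sketch below applies the same rule in isolation to made-up records; the helper name is illustrative:

```python
from typing import Dict, List, Optional, Tuple


def dedupe_keep_max(records: List[dict]) -> List[dict]:
    """One row per (molecule, target): a numeric pChEMBL beats None, and higher beats lower."""
    best: Dict[Tuple[str, str], dict] = {}
    for row in records:
        key = (row["molecule_chembl_id"], row["target_chembl_id"])
        existing = best.get(key)
        if existing is None:
            best[key] = row
            continue
        current: Optional[float] = row.get("pchembl_value")
        previous: Optional[float] = existing.get("pchembl_value")
        if current is not None and (previous is None or current > previous):
            best[key] = row
    return list(best.values())


# Made-up records: M1 has a confirmed-inactive row plus two measurements
example_rows = [
    {"molecule_chembl_id": "M1", "target_chembl_id": "T1", "pchembl_value": None},
    {"molecule_chembl_id": "M1", "target_chembl_id": "T1", "pchembl_value": 5.4},
    {"molecule_chembl_id": "M1", "target_chembl_id": "T1", "pchembl_value": 4.8},
    {"molecule_chembl_id": "M2", "target_chembl_id": "T1", "pchembl_value": 6.1},
]
for r in dedupe_keep_max(example_rows):
    print(r["molecule_chembl_id"], r["pchembl_value"])  # M1 5.4, then M2 6.1
```

Keeping the maximum means a compound counts as binding if any single exact measurement reached pChEMBL 5.0. Aggregating with the median instead would be a more conservative choice.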

In [None]:
# ── Exploratory visualizations ────────────────────────────────────────
fig, axes = plt.subplots(1, 3, figsize=(18, 5))

# 1. Class distribution
class_counts = Counter(labels_all)
bars = axes[0].bar(
    [ACTIVITY_CLASS_MAP[c] for c in sorted(class_counts)],
    [class_counts[c] for c in sorted(class_counts)],
    color=[CLASS_COLORS[c] for c in sorted(class_counts)],
    edgecolor="black",
)
for bar, c in zip(bars, sorted(class_counts)):
    axes[0].text(bar.get_x() + bar.get_width() / 2, bar.get_height() + 10,
                 str(class_counts[c]), ha="center", fontweight="bold")
axes[0].set_title("Class Distribution")
axes[0].set_ylabel("Count")

# 2. Per-target breakdown
target_class_df = pd.DataFrame({"target": targets_all, "class": labels_all})
target_order = sorted(set(targets_all))
class_by_target = target_class_df.groupby(["target", "class"]).size().unstack(fill_value=0)
class_by_target = class_by_target.reindex(columns=list(range(NUM_CLASSES)), fill_value=0)
class_by_target.columns = [ACTIVITY_CLASS_MAP[c] for c in class_by_target.columns]
class_by_target.loc[target_order].plot.barh(
    stacked=True, ax=axes[1],
    color=[CLASS_COLORS[c] for c in range(NUM_CLASSES)],
    edgecolor="black",
)
axes[1].set_title("Compounds per Target")
axes[1].set_xlabel("Count")
axes[1].legend(title="Class", loc="lower right")

# 3. pChEMBL distribution (active compounds only)
pchembl_vals = [float(row["pchembl_value"]) for row in data if row["pchembl_value"] is not None]
axes[2].hist(pchembl_vals, bins=30, color="#3498db", edgecolor="black", alpha=0.8)
axes[2].axvline(5.0, color="red", ls="--", lw=2, label="Binding threshold (5.0)")
axes[2].set_title("pChEMBL Value Distribution")
axes[2].set_xlabel("pChEMBL")
axes[2].set_ylabel("Count")
axes[2].legend()

fig.tight_layout()
fig.savefig(OUTPUT_DIR / "01_data_exploration.png", dpi=150, bbox_inches="tight")
plt.show()
print("Saved: outputs/01_data_exploration.png")

---
## 4. Feature Engineering

In [None]:
def compute_descriptors(smiles: List[str]) -> Tuple[np.ndarray, List[str]]:
    """Compute 10 physicochemical descriptors per molecule."""
    descriptor_functions = {
        "MW": Descriptors.MolWt,
        "LogP": Crippen.MolLogP,
        "HBA": Lipinski.NumHAcceptors,
        "HBD": Lipinski.NumHDonors,
        "TPSA": MolSurf.TPSA,
        "RotatableBonds": Lipinski.NumRotatableBonds,
        "AromaticRings": Lipinski.NumAromaticRings,
        "HeavyAtoms": Lipinski.HeavyAtomCount,
        "FractionCSP3": Lipinski.FractionCSP3,
        "MolMR": Crippen.MolMR,
    }
    rows = []
    for smi in smiles:
        mol = Chem.MolFromSmiles(smi)
        if mol is None:
            rows.append([np.nan] * len(descriptor_functions))
            continue
        rows.append([func(mol) for func in descriptor_functions.values()])
    return np.array(rows, dtype=float), list(descriptor_functions.keys())


def compute_morgan_fingerprints(smiles: List[str], n_bits: int = 2048) -> np.ndarray:
    """Compute 2048-bit Morgan fingerprints (ECFP4, radius=2)."""
    gen = rdFingerprintGenerator.GetMorganGenerator(radius=2, fpSize=n_bits)
    fps = []
    for smi in smiles:
        mol = Chem.MolFromSmiles(smi)
        if mol is None:
            fps.append(np.zeros(n_bits, dtype=int))
            continue
        fps.append(np.array(gen.GetFingerprint(mol)))
    return np.array(fps)


def build_feature_matrix(
    rows: List[dict],
    selected_columns: Optional[List[str]] = None,
) -> Tuple[np.ndarray, np.ndarray, List[str]]:
    """Build combined feature matrix: descriptors + fingerprints + target encoding.

    When selected_columns is None (training), variance filtering is applied and
    the surviving column names are returned.  When selected_columns is provided
    (prediction), the matrix is aligned to those columns.
    """
    smiles = [row["canonical_smiles"] for row in rows]
    targets = [row.get("target_common_name", row.get("target", "")) for row in rows]
    labels = np.array([row.get("activity_class", -1) for row in rows], dtype=int)

    descriptors, desc_names = compute_descriptors(smiles)
    fingerprints = compute_morgan_fingerprints(smiles)
    fp_names = [f"FP_{i}" for i in range(fingerprints.shape[1])]

    target_names = sorted({t for t in targets if t})
    target_map = {name: idx for idx, name in enumerate(target_names)}
    target_matrix = np.zeros((len(rows), len(target_names)), dtype=float)
    for idx, target in enumerate(targets):
        if target in target_map:
            target_matrix[idx, target_map[target]] = 1.0

    feature_matrix = np.concatenate([descriptors, fingerprints, target_matrix], axis=1)
    columns = desc_names + fp_names + [f"target_{n}" for n in target_names]

    if selected_columns is None:
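        variances = np.nanvar(feature_matrix, axis=0)
        # Keep every target indicator: a target with roughly 1% or fewer of the
        # records would otherwise fall below the 0.01 cut and lose its column.
        is_target_col = np.array([col.startswith("target_") for col in columns])
        mask = (variances > 0.01) | is_target_col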
        feature_matrix = np.nan_to_num(feature_matrix[:, mask], nan=0.0)
        selected_columns = [col for col, keep in zip(columns, mask) if keep]
    else:
        col_index = {col: idx for idx, col in enumerate(columns)}
        aligned = np.zeros((len(rows), len(selected_columns)), dtype=float)
        for out_idx, col in enumerate(selected_columns):
            if col in col_index:
                aligned[:, out_idx] = np.nan_to_num(
                    feature_matrix[:, col_index[col]], nan=0.0
                )
        feature_matrix = aligned

    return feature_matrix, labels, selected_columns


# ── Build features ────────────────────────────────────────────────────
print("Computing features (descriptors + 2048-bit Morgan FP + target encoding)...")
t0 = time.time()
features, labels, selected_columns = build_feature_matrix(data)
print(f"  Done in {time.time() - t0:.1f}s")
print(f"  Feature matrix shape: {features.shape}")
print(f"  Features retained after variance filter: {len(selected_columns)}")
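Fingerprint bits are 0/1 columns, so the `variances > 0.01` cut in `build_feature_matrix` has a direct reading. `np.nanvar` computes the population variance, which for a column set in a fraction p of records is p(1 - p). Solving p(1 - p) = 0.01 gives the rarest bit frequency that survives the filter; by symmetry, the same root also bounds the most common. A small worked calculation:

```python
import math


def bernoulli_variance(p: float) -> float:
    """Population variance (ddof=0, as np.nanvar computes) of a 0/1 column set in a fraction p of rows."""
    return p * (1.0 - p)


def min_frequency_kept(threshold: float = 0.01) -> float:
    """Lower root of p^2 - p + threshold = 0, i.e. where p * (1 - p) == threshold."""
    return (1.0 - math.sqrt(1.0 - 4.0 * threshold)) / 2.0


p_min = min_frequency_kept(0.01)
print(f"kept if set in more than {100 * p_min:.2f}% and fewer than {100 * (1 - p_min):.2f}% of records")
```

A bit therefore survives only if it is set in roughly 1.01% to 98.99% of records. The same arithmetic applies to any other 0/1 column that passes through the filter.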

---
## 5. Scaffold Split

In [None]:
def scaffold_split(smiles: List[str], y: np.ndarray, random_state: int = 42) -> SplitData:
    """Split data by Murcko scaffold to avoid data leakage (60/20/20)."""
    scaffolds: Dict[str, List[int]] = {}
    for idx, smi in enumerate(smiles):
        mol = Chem.MolFromSmiles(smi)
        if mol is None:
            scaffold = ""
        else:
            scaffold = MurckoScaffold.MurckoScaffoldSmiles(mol=mol)
        scaffolds.setdefault(scaffold, []).append(idx)

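    # The seeded shuffle undoes the size sort, so groups are placed in random
    # order: a randomized scaffold split rather than largest-first.
    scaffold_sets = sorted(scaffolds.values(), key=len, reverse=True)
    rng = np.random.default_rng(random_state)
    rng.shuffle(scaffold_sets)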

    n_total = len(smiles)
    n_train = int(0.6 * n_total)
    n_val = int(0.2 * n_total)

    train_idx, val_idx, test_idx = [], [], []
    for group in scaffold_sets:
        if len(train_idx) + len(group) <= n_train:
            train_idx.extend(group)
        elif len(val_idx) + len(group) <= n_val:
            val_idx.extend(group)
        else:
            test_idx.extend(group)

    return SplitData(
        train_idx=np.array(train_idx),
        val_idx=np.array(val_idx),
        test_idx=np.array(test_idx),
    )


split = scaffold_split([row["canonical_smiles"] for row in data], labels, RANDOM_STATE)

X_train, y_train = features[split.train_idx], labels[split.train_idx]
X_val, y_val     = features[split.val_idx],   labels[split.val_idx]
X_test, y_test   = features[split.test_idx],  labels[split.test_idx]

print(f"Train : {len(X_train):>5d}  |  Val : {len(X_val):>5d}  |  Test : {len(X_test):>5d}")
for name, y_sub in [("Train", y_train), ("Val", y_val), ("Test", y_test)]:
    counts = Counter(y_sub)
    parts = ", ".join(f"{ACTIVITY_CLASS_MAP[c]}={counts.get(c, 0)}" for c in range(NUM_CLASSES))
    print(f"  {name:>5s}: {parts}")
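The greedy fill in `scaffold_split` never splits a scaffold group, so the 60/20/20 fractions are targets rather than guarantees. A large group that does not fit in train or validation overflows into test. The sketch below reproduces the assignment logic on pre-computed scaffold strings, so it runs without RDKit. `greedy_group_split` is an illustrative name. It uses Python's `random` module in place of NumPy's generator, so its partitions will not match the notebook's.

```python
import random
from typing import Dict, List, Tuple


def greedy_group_split(group_keys: List[str], frac_train: float = 0.6,
                       frac_val: float = 0.2,
                       seed: int = 42) -> Tuple[List[int], List[int], List[int]]:
    """Assign whole groups (e.g. Murcko scaffolds) to train/val/test; a group is never split."""
    groups: Dict[str, List[int]] = {}
    for idx, key in enumerate(group_keys):
        groups.setdefault(key, []).append(idx)

    ordered = list(groups.values())
    random.Random(seed).shuffle(ordered)  # seeded, randomized group order

    n = len(group_keys)
    n_train, n_val = int(frac_train * n), int(frac_val * n)
    train, val, test = [], [], []
    for group in ordered:
        if len(train) + len(group) <= n_train:
            train.extend(group)
        elif len(val) + len(group) <= n_val:
            val.extend(group)
        else:
            test.extend(group)  # everything that does not fit overflows to test
    return train, val, test


# Hypothetical scaffold keys; "" stands for acyclic molecules, as in scaffold_split
keys = ["c1ccccc1"] * 4 + ["c1ccncc1"] * 3 + [""] * 2 + ["C1CCCCC1"]
train, val, test = greedy_group_split(keys)
print(len(train), len(val), len(test))
```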

---
## 6. Model Training & Cross-Validation
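Training cost scales with the number of model fits. Each model goes through `RandomizedSearchCV(n_iter=5, cv=3)` and its default refit, then `RepeatedStratifiedKFold(n_splits=3, n_repeats=2)` for evaluation. That comes to 5 x 3 + 1 + 6 = 22 fits per model, and the 5 sampled configurations cover only a few percent of each grid. The grid sizes below are copied by hand from `get_models()` and would need updating if the search spaces change:

```python
from math import prod

# Values per hyperparameter in each search space of get_models() (copied by hand)
grid_sizes = {
    "RandomForest": [2, 3, 3, 3, 2],  # n_estimators, max_depth, min_samples_split, max_features, class_weight
    "XGBoost":      [2, 3, 3, 3, 3],  # n_estimators, max_depth, learning_rate, subsample, colsample_bytree
    "LightGBM":     [2, 3, 3, 3, 3],  # n_estimators, max_depth, learning_rate, num_leaves, subsample
}
N_ITER, SEARCH_FOLDS = 5, 3    # RandomizedSearchCV(n_iter=5, cv=3)
CV_SPLITS, CV_REPEATS = 3, 2   # RepeatedStratifiedKFold(n_splits=3, n_repeats=2)

fits_per_model = N_ITER * SEARCH_FOLDS + 1 + CV_SPLITS * CV_REPEATS  # search + refit + CV
for name, sizes in grid_sizes.items():
    combos = prod(sizes)
    print(f"{name:<12s} grid={combos:>3d}  sampled={100 * N_ITER / combos:.1f}%  fits={fits_per_model}")
```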

In [None]:
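def ece_score_fn(y_true, y_prob, n_bins=10):
    """Expected and Maximum Calibration Error for a binary classifier.

    y_true: 0/1 labels; y_prob: predicted probability of class 1 (binding).
    """
    bins = np.linspace(0, 1, n_bins + 1)
    # Clip so that y_prob == 1.0 lands in the last bin instead of being dropped
    binids = np.clip(np.digitize(y_prob, bins) - 1, 0, n_bins - 1)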
    ece, mce = 0.0, 0.0
    for i in range(n_bins):
        mask = binids == i
        if not np.any(mask):
            continue
        avg_conf = y_prob[mask].mean()
        avg_acc = y_true[mask].mean()
        gap = abs(avg_conf - avg_acc)
        ece += gap * mask.mean()
        mce = max(mce, gap)
    return ece, mce


def get_models(random_state: int) -> Dict[str, Tuple[Pipeline, Dict[str, list]]]:
    """Return three model pipelines with hyperparameter search spaces."""
    return {
        "RandomForest": (
            Pipeline([
                ("scaler", StandardScaler(with_mean=False)),
                ("model", RandomForestClassifier(random_state=random_state, n_jobs=-1)),
            ]),
            {
                "model__n_estimators": [200, 500],
                "model__max_depth": [10, 20, None],
                "model__min_samples_split": [2, 5, 10],
                "model__max_features": ["sqrt", "log2", 0.3],
                "model__class_weight": ["balanced", None],
            },
        ),
        "XGBoost": (
            Pipeline([
                ("scaler", StandardScaler(with_mean=False)),
                ("model", XGBClassifier(
                    random_state=random_state,
                    objective="binary:logistic",
                    eval_metric="logloss",
                    n_jobs=-1,
                    verbosity=0,
                )),
            ]),
            {
                "model__n_estimators": [200, 500],
                "model__max_depth": [3, 5, 7],
                "model__learning_rate": [0.01, 0.05, 0.1],
                "model__subsample": [0.6, 0.8, 1.0],
                "model__colsample_bytree": [0.6, 0.8, 1.0],
            },
        ),
        "LightGBM": (
            Pipeline([
                ("scaler", StandardScaler(with_mean=False)),
                ("model", LGBMClassifier(
                    random_state=random_state, n_jobs=-1, verbose=-1,
                    objective="binary",
                )),
            ]),
            {
                "model__n_estimators": [200, 500],
                "model__max_depth": [-1, 5, 10],
                "model__learning_rate": [0.01, 0.05, 0.1],
                "model__num_leaves": [31, 63, 127],
                "model__subsample": [0.6, 0.8, 1.0],
            },
        ),
    }


# ── Train all models ──────────────────────────────────────────────────
models = get_models(RANDOM_STATE)
cv = RepeatedStratifiedKFold(n_splits=3, n_repeats=2, random_state=RANDOM_STATE)

cv_summary = []
best_estimators = {}
calibration_metrics = {}
fold_scores: Dict[str, List[float]] = {}
train_times = {}

for name, (pipeline, param_grid) in models.items():
    print(f"\n{'='*60}")
    print(f"Training: {name}")
    print(f"{'='*60}")

    # Hyperparameter search
    search = RandomizedSearchCV(
        pipeline,
        param_distributions=param_grid,
        n_iter=5,
        scoring="roc_auc",
        cv=3,
        random_state=RANDOM_STATE,
        n_jobs=1,  # avoid over-subscription with model-level n_jobs=-1
    )
    t0 = time.time()
    search.fit(X_train, y_train)
    train_times[name] = time.time() - t0
    best_estimators[name] = search.best_estimator_
    print(f"  Best params: {search.best_params_}")
    print(f"  Train time : {train_times[name]:.1f}s")

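    # Cross-validation evaluation on a fresh clone of the tuned pipeline, so
    # search.best_estimator_ stays fitted on the full training split. These
    # folds reuse the data that drove the search, so scores are optimistic.
    from sklearn.base import clone
    scores, pr_scores, mcc_scores = [], [], []
    for train_idx, test_idx in cv.split(X_train, y_train):
        X_tr, X_te = X_train[train_idx], X_train[test_idx]
        y_tr, y_te = y_train[train_idx], y_train[test_idx]
        est = clone(search.best_estimator_)
        est.fit(X_tr, y_tr)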
        probs = est.predict_proba(X_te)
        preds = est.predict(X_te)
        scores.append(roc_auc_score(y_te, probs[:, 1]))
        pr_scores.append(average_precision_score(y_te, probs[:, 1]))
        mcc_scores.append(matthews_corrcoef(y_te, preds))

    cv_summary.append({
        "model": name,
        "roc_auc_mean": np.mean(scores), "roc_auc_std": np.std(scores),
        "pr_auc_mean": np.mean(pr_scores), "pr_auc_std": np.std(pr_scores),
        "mcc_mean": np.mean(mcc_scores), "mcc_std": np.std(mcc_scores),
    })
    fold_scores[name] = scores

    # Validation calibration: ECE of P(binding) against the true labels
    val_probs = search.best_estimator_.predict_proba(X_val)
    ece_val, _ = ece_score_fn(y_val, val_probs[:, 1])
    calibration_metrics[name] = ece_val

    print(f"  CV ROC-AUC : {np.mean(scores):.4f} +/- {np.std(scores):.4f}")
    print(f"  CV PR-AUC  : {np.mean(pr_scores):.4f} +/- {np.std(pr_scores):.4f}")
    print(f"  CV MCC     : {np.mean(mcc_scores):.4f} +/- {np.std(mcc_scores):.4f}")

print(f"\nTraining complete. {len(models)} models evaluated.")
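For a binary classifier, expected calibration error bins predictions by P(binding). Each bin compares the mean predicted probability with the observed fraction of binders, and the gaps are weighted by bin size. The pure-Python sketch below implements that definition; the `binary_ece` name is illustrative. It also covers the edge case where a probability of exactly 1.0 must be clamped into the last bin:

```python
from typing import List, Sequence, Tuple


def binary_ece(y_true: Sequence[int], p_pos: Sequence[float],
               n_bins: int = 10) -> Tuple[float, float]:
    """(ECE, MCE) for P(class 1): per bin, |mean predicted probability - observed positive rate|."""
    n = len(y_true)
    bins: List[List[Tuple[int, float]]] = [[] for _ in range(n_bins)]
    for y, p in zip(y_true, p_pos):
        # Equal-width bins; clamp so p == 1.0 falls in the last bin instead of being dropped
        b = min(int(p * n_bins), n_bins - 1)
        bins[b].append((y, p))
    ece, mce = 0.0, 0.0
    for members in bins:
        if not members:
            continue
        confidence = sum(p for _, p in members) / len(members)
        positive_rate = sum(y for y, _ in members) / len(members)
        gap = abs(confidence - positive_rate)
        ece += gap * len(members) / n  # weight each bin by its share of samples
        mce = max(mce, gap)
    return ece, mce


print(binary_ece([1, 0, 0, 0], [0.25] * 4))  # (0.0, 0.0): 25% predicted, 25% observed
print(binary_ece([0, 0], [1.0, 1.0]))        # (1.0, 1.0): fully confident and always wrong
```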

---
## 7. Model Evaluation & Visualizations

In [None]:
# ── Select best model & refit on train+val ────────────────────────────
cv_summary_sorted = sorted(cv_summary, key=lambda r: r["roc_auc_mean"], reverse=True)
best_model_name = cv_summary_sorted[0]["model"]
best_model = best_estimators[best_model_name]
best_model.fit(np.vstack([X_train, X_val]), np.hstack([y_train, y_val]))

print("Cross-validation summary (sorted by ROC-AUC):")
print("-" * 75)
print(f"{'Model':<15s} {'ROC-AUC':>18s} {'PR-AUC':>18s} {'MCC':>18s}")
print("-" * 75)
for row in cv_summary_sorted:
    print(f"{row['model']:<15s} "
          f"{row['roc_auc_mean']:.4f} +/- {row['roc_auc_std']:.4f}  "
          f"{row['pr_auc_mean']:.4f} +/- {row['pr_auc_std']:.4f}  "
          f"{row['mcc_mean']:.4f} +/- {row['mcc_std']:.4f}")
print("-" * 75)
print(f"Best model: {best_model_name}")

# ── Test-set evaluation ───────────────────────────────────────────────
test_probs = best_model.predict_proba(X_test)
test_preds = best_model.predict(X_test)
test_roc = roc_auc_score(y_test, test_probs[:, 1])
test_pr = average_precision_score(y_test, test_probs[:, 1])
test_mcc = matthews_corrcoef(y_test, test_preds)

# Calibration (isotonic; CalibratedClassifierCV refits clones on 3 internal folds)
calibrated = CalibratedClassifierCV(best_model, method="isotonic", cv=3)
calibrated.fit(np.vstack([X_train, X_val]), np.hstack([y_train, y_val]))
cal_probs = calibrated.predict_proba(X_test)
# ECE of calibrated P(binding) against the true test labels
ece, mce = ece_score_fn(y_test, cal_probs[:, 1])

print(f"\nTest Set Metrics ({best_model_name}):")
print(f"  ROC-AUC         : {test_roc:.4f}")
print(f"  PR-AUC          : {test_pr:.4f}")
print(f"  MCC             : {test_mcc:.4f}")
print(f"  ECE (calibrated): {ece:.4f}")
print(f"  MCE (calibrated): {mce:.4f}")
print(f"\nClassification Report:")
print(classification_report(
    y_test, test_preds,
    target_names=[ACTIVITY_CLASS_MAP[c] for c in range(NUM_CLASSES)],
    digits=3,
))
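`CalibratedClassifierCV(method="isotonic")` maps the model's scores to probabilities with a monotone, piecewise-constant function fitted on held-out folds. The textbook fitting procedure is pool-adjacent-violators (PAV). The sketch below is a minimal unweighted version on 0/1 labels already sorted by score. It is meant for intuition only and is not scikit-learn's implementation:

```python
from typing import List


def pav(y: List[float]) -> List[float]:
    """Pool-adjacent-violators: the non-decreasing sequence closest to y in squared error."""
    blocks: List[List[float]] = []  # each block is [mean, weight]
    for value in y:
        blocks.append([float(value), 1.0])
        # Merge backwards while the monotonicity constraint is violated
        while len(blocks) > 1 and blocks[-2][0] > blocks[-1][0]:
            m2, w2 = blocks.pop()
            m1, w1 = blocks.pop()
            w = w1 + w2
            blocks.append([(m1 * w1 + m2 * w2) / w, w])
    fitted: List[float] = []
    for mean, weight in blocks:
        fitted.extend([mean] * int(weight))
    return fitted


# Hypothetical labels, ordered by the model's uncalibrated score (lowest first)
labels_by_score = [0, 1, 0, 0, 1, 1]
print(pav(labels_by_score))
```

Each pooled block becomes one calibrated probability, so with few positives per fold the mapping can be coarse. The scikit-learn documentation notes that isotonic calibration is more prone to overfitting on small calibration sets, where `method="sigmoid"` is the usual alternative.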

In [None]:
# ── Figure 2: ROC Curves (per-class) ──────────────────────────────────
fig, axes = plt.subplots(1, 2, figsize=(12, 5))

for cls in range(NUM_CLASSES):
    label = ACTIVITY_CLASS_MAP[cls]
    binary_true = (y_test == cls).astype(int)
    if binary_true.sum() == 0:
        axes[cls].set_title(f"ROC — {label} (no samples)")
        continue
    fpr, tpr, _ = roc_curve(binary_true, test_probs[:, cls])
    auc_val = roc_auc_score(binary_true, test_probs[:, cls])
    axes[cls].plot(fpr, tpr, color=CLASS_COLORS[cls], lw=2,
                   label=f"AUC = {auc_val:.3f}")
    axes[cls].plot([0, 1], [0, 1], "k--", lw=1, alpha=0.5)
    axes[cls].set_xlabel("False Positive Rate")
    axes[cls].set_ylabel("True Positive Rate")
    axes[cls].set_title(f"ROC — {label}")
    axes[cls].legend(loc="lower right", fontsize=12)
    axes[cls].set_xlim([-0.02, 1.02])
    axes[cls].set_ylim([-0.02, 1.02])

fig.suptitle(f"Per-Class ROC Curves ({best_model_name})", fontsize=14, y=1.02)
fig.tight_layout()
fig.savefig(OUTPUT_DIR / "02_roc_curves.png", dpi=150, bbox_inches="tight")
plt.show()

In [None]:
# ── Figure 3: Precision-Recall Curves ────────────────────────────────
fig, axes = plt.subplots(1, 2, figsize=(12, 5))

for cls in range(NUM_CLASSES):
    label = ACTIVITY_CLASS_MAP[cls]
    binary_true = (y_test == cls).astype(int)
    if binary_true.sum() == 0:
        axes[cls].set_title(f"PR — {label} (no samples)")
        continue
    precision, recall, _ = precision_recall_curve(binary_true, test_probs[:, cls])
    ap = average_precision_score(binary_true, test_probs[:, cls])
    axes[cls].plot(recall, precision, color=CLASS_COLORS[cls], lw=2,
                   label=f"AP = {ap:.3f}")
    baseline = binary_true.mean()
    axes[cls].axhline(baseline, color="gray", ls="--", lw=1, alpha=0.5,
                      label=f"Baseline = {baseline:.3f}")
    axes[cls].set_xlabel("Recall")
    axes[cls].set_ylabel("Precision")
    axes[cls].set_title(f"PR — {label}")
    axes[cls].legend(loc="upper right", fontsize=11)
    axes[cls].set_xlim([-0.02, 1.02])
    axes[cls].set_ylim([-0.02, 1.02])

fig.suptitle(f"Per-Class Precision-Recall Curves ({best_model_name})", fontsize=14, y=1.02)
fig.tight_layout()
fig.savefig(OUTPUT_DIR / "03_pr_curves.png", dpi=150, bbox_inches="tight")
plt.show()

In [None]:
# ── Figure 4: Confusion Matrix ───────────────────────────────────────
cm = confusion_matrix(y_test, test_preds, labels=list(range(NUM_CLASSES)))
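row_totals = cm.sum(axis=1, keepdims=True)
# Row-normalise; a class with no test samples stays at 0 instead of producing NaN
cm_pct = np.divide(cm, row_totals, out=np.zeros(cm.shape, dtype=float), where=row_totals > 0) * 100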

fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# Raw counts
sns.heatmap(cm, annot=True, fmt="d", cmap="Blues", ax=axes[0],
            xticklabels=[ACTIVITY_CLASS_MAP[c] for c in range(NUM_CLASSES)],
            yticklabels=[ACTIVITY_CLASS_MAP[c] for c in range(NUM_CLASSES)])
axes[0].set_xlabel("Predicted")
axes[0].set_ylabel("Actual")
axes[0].set_title("Confusion Matrix (counts)")

# Percentages
sns.heatmap(cm_pct, annot=True, fmt=".1f", cmap="Blues", ax=axes[1],
            xticklabels=[ACTIVITY_CLASS_MAP[c] for c in range(NUM_CLASSES)],
            yticklabels=[ACTIVITY_CLASS_MAP[c] for c in range(NUM_CLASSES)])
axes[1].set_xlabel("Predicted")
axes[1].set_ylabel("Actual")
axes[1].set_title("Confusion Matrix (% per row)")

fig.suptitle(f"Test Set Confusion Matrix ({best_model_name})", fontsize=14, y=1.02)
fig.tight_layout()
fig.savefig(OUTPUT_DIR / "04_confusion_matrix.png", dpi=150, bbox_inches="tight")
plt.show()

In [None]:
# ── Figure 5: Calibration Curves ─────────────────────────────────────
fig, axes = plt.subplots(1, 2, figsize=(12, 5))

for cls in range(NUM_CLASSES):
    label = ACTIVITY_CLASS_MAP[cls]
    binary_true = (y_test == cls).astype(int)
    cls_cal_probs = cal_probs[:, cls]
    if binary_true.sum() == 0:
        axes[cls].set_title(f"Calibration — {label} (no samples)")
        continue
    prob_true, prob_pred = calibration_curve(binary_true, cls_cal_probs, n_bins=10)
    axes[cls].plot(prob_pred, prob_true, "o-", color=CLASS_COLORS[cls], lw=2,
                   label=f"{label}")
    axes[cls].plot([0, 1], [0, 1], "k--", lw=1, alpha=0.5, label="Perfectly calibrated")
    axes[cls].set_xlabel("Mean Predicted Probability")
    axes[cls].set_ylabel("Fraction of Positives")
    axes[cls].set_title(f"Calibration — {label}")
    axes[cls].legend(loc="upper left", fontsize=11)
    axes[cls].set_xlim([-0.02, 1.02])
    axes[cls].set_ylim([-0.02, 1.02])

fig.suptitle(f"Calibration Curves (isotonic, {best_model_name})", fontsize=14, y=1.02)
fig.tight_layout()
fig.savefig(OUTPUT_DIR / "05_calibration_curves.png", dpi=150, bbox_inches="tight")
plt.show()
print(f"ECE = {ece:.4f}, MCE = {mce:.4f}")

In [None]:
# ── Figure 6: Feature Importance (top 20) ────────────────────────────
if hasattr(best_model.named_steps["model"], "feature_importances_"):
    importances = best_model.named_steps["model"].feature_importances_
    top_k = 20
    indices = np.argsort(importances)[-top_k:]
    top_features = [selected_columns[i] for i in indices]
    top_values = importances[indices]

    fig, ax = plt.subplots(figsize=(10, 7))
    ax.barh(range(top_k), top_values, color="#3498db", edgecolor="black")
    ax.set_yticks(range(top_k))
    ax.set_yticklabels(top_features)
    ax.set_xlabel("Importance")
    ax.set_title(f"Top {top_k} Feature Importances ({best_model_name})")
    fig.tight_layout()
    fig.savefig(OUTPUT_DIR / "06_feature_importance.png", dpi=150, bbox_inches="tight")
    plt.show()
else:
    print("Feature importances not available for this model type.")

---
## 8. Statistical Comparison & MCDA
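
The first cell below compares models pairwise on their per-fold cross-validation ROC-AUC. A minimal standalone sketch of that comparison follows, assuming every model's fold scores are aligned by fold index. The helper name `pairwise_fold_tests` and the toy scores are hypothetical; Cohen's d here divides the mean difference by the square root of the mean of the two per-model fold variances.

```python
import itertools

import numpy as np
from scipy import stats


def pairwise_fold_tests(fold_scores, alpha=0.05):
    """Paired t-tests between every pair of models on aligned per-fold scores.

    fold_scores maps model name -> sequence of per-fold scores, where index k
    of every sequence refers to the same CV fold.
    Returns (rows, bonferroni_alpha).
    """
    pairs = list(itertools.combinations(sorted(fold_scores), 2))
    bonferroni_alpha = alpha / max(len(pairs), 1)  # divide alpha by number of tests
    rows = []
    for a, b in pairs:
        sa = np.asarray(fold_scores[a], dtype=float)
        sb = np.asarray(fold_scores[b], dtype=float)
        t_stat, p_val = stats.ttest_rel(sa, sb)  # paired: same folds for both models
        # Pooled SD = sqrt of the mean of the two per-model fold variances
        pooled_sd = np.sqrt((sa.var(ddof=1) + sb.var(ddof=1)) / 2)
        d = (sa.mean() - sb.mean()) / pooled_sd if pooled_sd > 0 else 0.0
        rows.append({
            "a": a, "b": b,
            "t": float(t_stat), "p": float(p_val), "d": float(d),
            "significant": bool(p_val < bonferroni_alpha),
        })
    return rows, bonferroni_alpha


# Toy usage with five hypothetical CV folds per model
toy_scores = {
    "LightGBM": [0.81, 0.83, 0.80, 0.82, 0.84],
    "RandomForest": [0.79, 0.80, 0.78, 0.80, 0.81],
    "SVM": [0.75, 0.78, 0.74, 0.77, 0.76],
}
for row in pairwise_fold_tests(toy_scores)[0]:
    print(f"{row['a']} vs {row['b']}: t={row['t']:.2f}, p={row['p']:.4f}, d={row['d']:.2f}")
```

With three models there are three pairs, so the Bonferroni-corrected threshold is 0.05 / 3. Because CV folds share most of their training data, fold scores are not independent, so these p-values are best read as a rough guide rather than an exact test.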

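The MCDA ranking in the same cell min-max normalises each criterion across models and then takes a weighted sum. The same idea as a small, dependency-free sketch, assuming higher raw values are better for every criterion (the function name `mcda_rank` and the two-model example are hypothetical):

```python
def mcda_rank(rows, weights):
    """Min-max normalise each criterion across models, then rank by weighted sum.

    rows: list of dicts with a "model" key plus one key per criterion in weights,
    where higher raw values are better for every criterion.
    """
    normed = [dict(r) for r in rows]  # copy so the caller's rows are not mutated
    for metric in weights:
        vals = [r[metric] for r in rows]
        lo, hi = min(vals), max(vals)
        for r in normed:
            # A criterion on which every model ties contributes its full weight
            r[metric] = (r[metric] - lo) / (hi - lo) if hi > lo else 1.0
    for r in normed:
        r["composite"] = sum(r[m] * w for m, w in weights.items())
    return sorted(normed, key=lambda r: r["composite"], reverse=True)


example_rows = [
    {"model": "A", "roc_auc": 0.80, "efficiency": 0.2},
    {"model": "B", "roc_auc": 0.78, "efficiency": 0.9},
]
for r in mcda_rank(example_rows, {"roc_auc": 0.7, "efficiency": 0.3}):
    print(r["model"], round(r["composite"], 3))
```

Because every criterion is rescaled to [0, 1] across only a handful of models, a small raw gap (0.80 vs 0.78 ROC-AUC above) is stretched to the full 0-to-1 range, so the composite score is sensitive to which models are included in the comparison.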
In [None]:
# ── Paired t-tests with Bonferroni correction ────────────────────────
model_names = [row["model"] for row in cv_summary_sorted]
stat_rows = []
for i, model_a in enumerate(model_names):
    for model_b in model_names[i + 1:]:
        sa = np.array(fold_scores.get(model_a, []))
        sb = np.array(fold_scores.get(model_b, []))
        if len(sa) == 0 or len(sb) == 0:
            continue
        t_stat, p_val = stats.ttest_rel(sa, sb)
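        # Pooled SD from the two per-model fold variances (not the SD of the concatenated scores)
        pooled = np.sqrt((sa.var(ddof=1) + sb.var(ddof=1)) / 2)
        cohen_d = (sa.mean() - sb.mean()) / pooled if pooled > 0 else 0.0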
        stat_rows.append({
            "Model A": model_a, "Model B": model_b,
            "t-stat": t_stat, "p-value": p_val, "Cohen's d": cohen_d,
        })

if stat_rows:
    bonferroni = 0.05 / len(stat_rows)
    for r in stat_rows:
        r["Significant"] = "Yes" if r["p-value"] < bonferroni else "No"
        r["Bonferroni alpha"] = bonferroni

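stat_df = pd.DataFrame(stat_rows)
if stat_rows:
    print("Statistical Comparison (paired t-test on CV ROC-AUC folds):")
    print(f"Bonferroni-corrected alpha = {bonferroni:.4f}")
    display(stat_df.style.format({
        "t-stat": "{:.4f}", "p-value": "{:.6f}",
        "Cohen's d": "{:.4f}", "Bonferroni alpha": "{:.4f}",
    }).set_caption("Pairwise Model Comparison"))
else:
    # bonferroni is only defined when at least one model pair was compared
    print("No model pairs with per-fold scores; skipping statistical comparison.")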

# ── MCDA Ranking ─────────────────────────────────────────────────────
mcda_rows = []
for row in cv_summary_sorted:
    name = row["model"]
    mcda_rows.append({
        "model": name,
        "roc_auc": row["roc_auc_mean"],
        "pr_auc": row["pr_auc_mean"],
        "calibration": max(0.0, 1 - calibration_metrics.get(name, ece)),
        "robustness": max(0.0, 1 - row["roc_auc_std"]),
        "efficiency": 1.0 / (1.0 + train_times.get(name, 1.0)),
        "interpretability": 1.0 if name in {"RandomForest", "LightGBM", "XGBoost"} else 0.5,
    })

weights = {
    "roc_auc": 0.25, "pr_auc": 0.20, "calibration": 0.20,
    "robustness": 0.15, "efficiency": 0.10, "interpretability": 0.10,
}
for metric in weights:
    vals = [r[metric] for r in mcda_rows]
    mn, mx = min(vals), max(vals)
    for r in mcda_rows:
        r[metric] = (r[metric] - mn) / (mx - mn) if mx > mn else 1.0
for r in mcda_rows:
    r["composite"] = sum(r[m] * w for m, w in weights.items())
mcda_rows = sorted(mcda_rows, key=lambda r: r["composite"], reverse=True)

print("\nMCDA Ranking:")
mcda_df = pd.DataFrame(mcda_rows)
display(mcda_df.style.format("{:.4f}", subset=mcda_df.columns[1:]).set_caption(
    "Multi-Criteria Decision Analysis"
))

---
## 9. Uncertainty Quantification
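
Split conformal prediction assumes the score quantile is estimated on calibration data that is disjoint from the compounds whose coverage is reported. A minimal NumPy sketch of that procedure, using the nonconformity score from the cell below (1 minus the probability of the true class) plus the standard finite-sample quantile correction; `split_conformal` is a hypothetical helper name, and the synthetic probabilities exist only for illustration.

```python
import numpy as np


def split_conformal(cal_probs, cal_labels, test_probs, alpha=0.05):
    """Split conformal classification with the score s = 1 - p(true class).

    Returns a boolean (n_test, n_classes) membership matrix and the quantile q_hat.
    """
    n = len(cal_labels)
    scores = 1.0 - cal_probs[np.arange(n), cal_labels]
    # Finite-sample corrected quantile level: ceil((n + 1)(1 - alpha)) / n
    level = min(1.0, np.ceil((n + 1) * (1 - alpha)) / n)
    q_hat = np.quantile(scores, level, method="higher")
    # Keep every class whose score 1 - p is at most q_hat
    sets = test_probs >= 1.0 - q_hat
    return sets, q_hat


# Synthetic, perfectly calibrated binary probabilities
rng = np.random.default_rng(0)
p1 = rng.uniform(size=2000)
y = (rng.uniform(size=2000) < p1).astype(int)
probs = np.column_stack([1 - p1, p1])
sets, q_hat = split_conformal(probs[:1000], y[:1000], probs[1000:])
coverage = sets[np.arange(1000), y[1000:]].mean()
print(f"q_hat={q_hat:.3f}, coverage={coverage:.3f}, mean set size={sets.sum(axis=1).mean():.2f}")
```

In a binary task, a set of size 2 means the model cannot rule out either class at this alpha, so the fraction of size-2 sets is a direct measure of how often the model is undecided.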

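The applicability-domain check in the cell below scores each compound by its mean distance to the k nearest training compounds and flags anything beyond the 95th percentile of the training distribution. One subtlety is that a training point queried against its own training set finds itself at distance 0, which biases the threshold low unless self-matches are excluded. A brute-force NumPy sketch with self-matches excluded (helper names are hypothetical; the notebook itself uses scikit-learn's `NearestNeighbors`, which scales better to large sets):

```python
import numpy as np


def knn_mean_distance(reference, queries=None, k=5):
    """Mean Euclidean distance from each query row to its k nearest reference rows.

    With queries=None the reference set is scored against itself and each
    point's zero distance to itself is excluded.
    """
    self_query = queries is None
    q = reference if self_query else queries
    # Brute-force pairwise distances, shape (n_queries, n_reference)
    dists = np.sqrt(((q[:, None, :] - reference[None, :, :]) ** 2).sum(axis=-1))
    if self_query:
        np.fill_diagonal(dists, np.inf)  # a point is not its own neighbour
    return np.sort(dists, axis=1)[:, :k].mean(axis=1)


def applicability_domain(train, test, k=5, pct=95):
    """Flag test rows whose mean k-NN distance is within the training percentile."""
    train_d = knn_mean_distance(train, k=k)
    threshold = np.percentile(train_d, pct)
    in_domain = knn_mean_distance(train, test, k=k) <= threshold
    return in_domain, threshold


rng = np.random.default_rng(1)
train = rng.normal(size=(200, 3))
test = np.vstack([rng.normal(size=(5, 3)), np.full((1, 3), 50.0)])  # last row is far away
in_domain, threshold = applicability_domain(train, test)
print(threshold, in_domain)
```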
In [None]:
# ── Conformal Prediction ──────────────────────────────────────────────
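def conformal_prediction(calib_probs, calib_y, eval_probs, eval_y, alpha=0.05):
    """Split conformal prediction with the score s = 1 - p(true class).

    The quantile is estimated on a calibration subset and coverage is measured
    on a disjoint evaluation subset; scoring both on the same compounds would
    make the reported coverage trivially >= 1 - alpha.
    """
    n = len(calib_y)
    scores = 1.0 - calib_probs[np.arange(n), calib_y]
    # Finite-sample corrected quantile level: ceil((n + 1)(1 - alpha)) / n
    q_level = min(1.0, np.ceil((n + 1) * (1 - alpha)) / n)
    q = np.quantile(scores, q_level, method="higher")
    prediction_sets = eval_probs >= (1.0 - q)
    coverage = prediction_sets[np.arange(len(eval_y)), eval_y].mean()
    return prediction_sets, coverage, q

# Split the internal test set into disjoint calibration / evaluation halves
y_test_arr = np.asarray(y_test)
perm = np.random.default_rng(42).permutation(len(y_test_arr))
conf_cal_idx, conf_eval_idx = perm[: len(perm) // 2], perm[len(perm) // 2:]
pred_sets, coverage, q_threshold = conformal_prediction(
    cal_probs[conf_cal_idx], y_test_arr[conf_cal_idx],
    cal_probs[conf_eval_idx], y_test_arr[conf_eval_idx],
)
set_sizes = pred_sets.sum(axis=1)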

# ── Applicability Domain (k-NN distance) ─────────────────────────────
nn = NearestNeighbors(n_neighbors=5)
nn.fit(X_train)
# Called without X, kneighbors excludes each training point from its own neighbours
train_dists = nn.kneighbors()[0].mean(axis=1)
ad_threshold = np.percentile(train_dists, 95)
test_dists = nn.kneighbors(X_test)[0].mean(axis=1)
ood_rate = (test_dists > ad_threshold).mean()

print(f"Conformal Prediction (alpha=0.05):")
print(f"  Coverage          : {coverage:.4f}  (target: 0.95)")
print(f"  Avg set size      : {set_sizes.mean():.2f}")
print(f"  Quantile threshold: {q_threshold:.4f}")
print(f"\nApplicability Domain:")
print(f"  AD threshold (95th pct): {ad_threshold:.4f}")
print(f"  Out-of-domain rate     : {ood_rate:.2%}")

# ── Figure 7: Uncertainty plots ──────────────────────────────────────
fig, axes = plt.subplots(1, 3, figsize=(18, 5))

# Set size distribution
unique_sizes, counts = np.unique(set_sizes, return_counts=True)
axes[0].bar(unique_sizes.astype(str), counts, color="#9b59b6", edgecolor="black")
for s, c in zip(unique_sizes, counts):
    axes[0].text(str(s), c + 2, str(c), ha="center", fontweight="bold")
axes[0].set_xlabel("Prediction Set Size")
axes[0].set_ylabel("Count")
axes[0].set_title(f"Conformal Set Sizes (coverage={coverage:.2%})")

# Distance to training set
axes[1].hist(test_dists, bins=30, color="#1abc9c", edgecolor="black", alpha=0.8,
             label="Test compounds")
axes[1].axvline(ad_threshold, color="red", ls="--", lw=2,
                label=f"AD threshold ({ad_threshold:.2f})")
axes[1].set_xlabel("Mean k-NN Distance")
axes[1].set_ylabel("Count")
axes[1].set_title(f"Applicability Domain (OOD={ood_rate:.1%})")
axes[1].legend()

# Confidence vs correctness
max_probs = test_probs.max(axis=1)
correct = (test_preds == y_test)
bins_edge = np.linspace(0, 1, 11)
bin_accs, bin_confs = [], []
for lo, hi in zip(bins_edge[:-1], bins_edge[1:]):
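    # Close the last bin on the right so confidences of exactly 1.0 are counted
    upper = (max_probs <= hi) if hi == bins_edge[-1] else (max_probs < hi)
    mask = (max_probs >= lo) & upper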
    if mask.sum() > 0:
        bin_accs.append(correct[mask].mean())
        bin_confs.append(max_probs[mask].mean())
axes[2].plot(bin_confs, bin_accs, "o-", color="#e67e22", lw=2, label="Model")
axes[2].plot([0, 1], [0, 1], "k--", lw=1, alpha=0.5, label="Perfect")
axes[2].set_xlabel("Mean Confidence")
axes[2].set_ylabel("Accuracy")
axes[2].set_title("Reliability Diagram")
axes[2].legend()
axes[2].set_xlim([-0.02, 1.02])
axes[2].set_ylim([-0.02, 1.02])

fig.tight_layout()
fig.savefig(OUTPUT_DIR / "07_uncertainty.png", dpi=150, bbox_inches="tight")
plt.show()

In [None]:
# ── Save model artifacts for reuse ────────────────────────────────────
model_artifacts = {
    "best_model": best_model,
    "calibrated_model": calibrated,
    "selected_columns": selected_columns,
    "activity_class_map": ACTIVITY_CLASS_MAP,
    "num_classes": NUM_CLASSES,
    "ad_threshold": ad_threshold,
    "nn_model": nn,
    "conformal_q": q_threshold,
    "best_model_name": best_model_name,
}
model_path = MODEL_DIR / "safety_model.pkl"
with open(model_path, "wb") as fh:
    pickle.dump(model_artifacts, fh)
print(f"Model saved to: {model_path}")
print(f"  Model type       : {best_model_name}")
print(f"  Feature columns  : {len(selected_columns)}")
print(f"  AD threshold     : {ad_threshold:.4f}")
print(f"  Conformal q      : {q_threshold:.4f}")

# ── Save summary JSON ─────────────────────────────────────────────────
class_counts_dict = {ACTIVITY_CLASS_MAP[c]: int((labels_all == c).sum()) for c in range(NUM_CLASSES)}
summary = {
    "n_compounds": len(data),
    "targets": sorted(set(targets_all)),
    "class_distribution": class_counts_dict,
    "train_size": len(X_train),
    "val_size": len(X_val),
    "test_size": len(X_test),
    "best_model": best_model_name,
    # float() casts NumPy scalars (e.g. float32) that json.dump cannot serialise
    "test_metrics": {
        "roc_auc": float(test_roc),
        "pr_auc": float(test_pr),
        "mcc": float(test_mcc),
        "ece": float(ece),
        "mce": float(mce),
    },
    "conformal_coverage": float(coverage),
    "avg_prediction_set_size": float(set_sizes.mean()),
    "out_of_domain_rate": float(ood_rate),
}
with open(OUTPUT_DIR / "workflow_summary.json", "w") as fh:
    json.dump(summary, fh, indent=2)
print(f"\nWorkflow summary saved to: {OUTPUT_DIR / 'workflow_summary.json'}")

---
## 10. Test Set Evaluation (Held-Out Compounds)

Evaluate the trained model on the held-out test set (`data/test_compounds.csv`).
These compounds were **not** used during training or cross-validation.
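
Before scoring, the cell below converts the `known_class` column, which uses the older 3-class labels, to the binary scheme defined at the top of the notebook, and marks rows without a label as -1 so they can be filtered out. The rule in isolation, as a small sketch (`map_legacy_labels` is a hypothetical name):

```python
import numpy as np


def map_legacy_labels(known_class):
    """Map old 3-class labels to the binary training scheme.

    Old class 2 -> 1 (binding); old classes 0 and 1 -> 0 (non_binding);
    missing or negative values -> -1 (unknown, excluded from evaluation).
    """
    out = []
    for value in known_class:
        if value is None or (isinstance(value, float) and np.isnan(value)):
            raw = -1
        else:
            raw = int(value)
        out.append(-1 if raw < 0 else (1 if raw == 2 else 0))
    return np.array(out, dtype=int)


labels = map_legacy_labels([0, 1, 2, float("nan"), None, -1, 2.0])
valid_mask = labels >= 0  # same filter the evaluation cell applies
print(labels, valid_mask.sum())
```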

In [None]:
# ── Evaluate on held-out test set ─────────────────────────────────────
if TEST_PATH.exists():
    test_df = pd.read_csv(TEST_PATH)
    print(f"Loaded {len(test_df)} test compounds from {TEST_PATH}")
    print(f"Test targets: {sorted(test_df['target'].unique())}")

    # Build feature matrix for test compounds
    test_rows_for_features = []
    for _, row in test_df.iterrows():
        # Map old 3-class labels to 2-class: old class 2 -> 1 (binding), old class 0,1 -> 0 (non_binding)
        raw_class = int(row.get("known_class", -1)) if pd.notna(row.get("known_class")) else -1
        if raw_class >= 0:
            mapped_class = 1 if raw_class == 2 else 0
        else:
            mapped_class = -1
        test_rows_for_features.append({
            "canonical_smiles": row["smiles"],
            "target_common_name": row["target"],
            "activity_class": mapped_class,
        })
    X_test_ext, y_test_ext, _ = build_feature_matrix(test_rows_for_features, selected_columns=selected_columns)

    # Filter to only compounds with known labels
    valid_mask = y_test_ext >= 0
    X_test_ext_valid = X_test_ext[valid_mask]
    y_test_ext_valid = y_test_ext[valid_mask]
    test_df_valid = test_df[valid_mask].copy()

    if len(y_test_ext_valid) > 0:
        # Predict
        ext_probs = best_model.predict_proba(X_test_ext_valid)
        ext_preds = best_model.predict(X_test_ext_valid)

        # Applicability domain
        ext_dists = nn.kneighbors(X_test_ext_valid)[0].mean(axis=1)
        ext_in_domain = ext_dists <= ad_threshold

        # Metrics
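        # ROC-AUC is undefined when only one class is present in the labels
        ext_roc = (roc_auc_score(y_test_ext_valid, ext_probs[:, 1])
                   if len(np.unique(y_test_ext_valid)) > 1 else float("nan"))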
        ext_mcc = matthews_corrcoef(y_test_ext_valid, ext_preds)
        ext_acc = (ext_preds == y_test_ext_valid).mean()

        print(f"\nHeld-out Test Set Metrics ({best_model_name}):")
        print(f"  Compounds evaluated: {len(y_test_ext_valid)}")
        print(f"  ROC-AUC            : {ext_roc:.4f}")
        print(f"  MCC                : {ext_mcc:.4f}")
        print(f"  Accuracy           : {ext_acc:.4f}")
        print(f"  In-domain          : {ext_in_domain.sum()}/{len(ext_in_domain)} ({ext_in_domain.mean():.1%})")

        # Per-target metrics
        test_target_metrics = {}
        for target in sorted(test_df_valid["target"].unique()):
            tmask = test_df_valid["target"].values == target
            if tmask.sum() < 2:
                continue
            t_true = y_test_ext_valid[tmask]
            t_pred = ext_preds[tmask]
            t_acc = (t_pred == t_true).mean()
            t_mcc = matthews_corrcoef(t_true, t_pred) if len(set(t_true)) > 1 else 0.0
            test_target_metrics[target] = {"n": int(tmask.sum()), "acc": t_acc, "mcc": t_mcc}

        print(f"\nPer-target test set performance:")
        print(f"  {'Target':>10s}  {'N':>4s}  {'Acc':>6s}  {'MCC':>6s}")
        print(f"  {'-' * 32}")
        for target, m in sorted(test_target_metrics.items()):
            print(f"  {target:>10s}  {m['n']:>4d}  {m['acc']:>6.3f}  {m['mcc']:>6.3f}")

        # Classification report
        print(f"\nClassification Report (held-out test set):")
        present_classes = sorted(set(y_test_ext_valid) | set(ext_preds))
        target_names_present = [ACTIVITY_CLASS_MAP[c] for c in present_classes]
        print(classification_report(
            y_test_ext_valid, ext_preds,
            labels=present_classes,
            target_names=target_names_present,
            digits=3, zero_division=0,
        ))

        # Build results DataFrame
        test_df_valid = test_df_valid.copy()
        test_df_valid["predicted_class"] = ext_preds
        test_df_valid["predicted_label"] = [ACTIVITY_CLASS_MAP.get(int(p), "?") for p in ext_preds]
        test_df_valid["correct"] = ext_preds == y_test_ext_valid
        test_df_valid["in_domain"] = ext_in_domain
        for c in range(NUM_CLASSES):
            test_df_valid[f"prob_{ACTIVITY_CLASS_MAP[c]}"] = ext_probs[:, c]
        test_df_valid["max_confidence"] = ext_probs.max(axis=1)

        # Save
        test_df_valid.to_csv(OUTPUT_DIR / "test_set_predictions.csv", index=False)
        print(f"\nTest set predictions saved to: {OUTPUT_DIR / 'test_set_predictions.csv'}")
    else:
        print("No valid test compounds with known labels found.")
        ext_roc = ext_mcc = ext_acc = float("nan")
        test_target_metrics = {}
else:
    print(f"No test file found at {TEST_PATH}")
    ext_roc = ext_mcc = ext_acc = float("nan")
    test_target_metrics = {}

In [None]:
# ── Figure 8: Test set results visualization ─────────────────────────
if TEST_PATH.exists() and len(y_test_ext_valid) > 0:
    fig, axes = plt.subplots(1, 3, figsize=(18, 5))

    # Class distribution: actual vs predicted
    actual_counts = Counter(y_test_ext_valid)
    pred_counts_ext = Counter(ext_preds)
    x_labels = [ACTIVITY_CLASS_MAP[c] for c in range(NUM_CLASSES)]
    x = np.arange(NUM_CLASSES)
    w = 0.35
    axes[0].bar(x - w/2, [actual_counts.get(c, 0) for c in range(NUM_CLASSES)],
                w, color=[CLASS_COLORS[c] for c in range(NUM_CLASSES)], edgecolor="black",
                alpha=0.7, label="Actual")
    axes[0].bar(x + w/2, [pred_counts_ext.get(c, 0) for c in range(NUM_CLASSES)],
                w, color=[CLASS_COLORS[c] for c in range(NUM_CLASSES)], edgecolor="black",
                alpha=0.4, label="Predicted", hatch="//")
    axes[0].set_xticks(x)
    axes[0].set_xticklabels(x_labels)
    axes[0].set_title("Test Set: Actual vs Predicted")
    axes[0].set_ylabel("Count")
    axes[0].legend()

    # Per-target accuracy
    if test_target_metrics:
        targets_sorted = sorted(test_target_metrics.keys(),
                                key=lambda t: test_target_metrics[t]["acc"], reverse=True)
        accs = [test_target_metrics[t]["acc"] for t in targets_sorted]
        ns = [test_target_metrics[t]["n"] for t in targets_sorted]
        colors = ["#3498db" if a >= 0.5 else "#e74c3c" for a in accs]
        bars = axes[1].barh(range(len(targets_sorted)), accs, color=colors, edgecolor="black")
        axes[1].set_yticks(range(len(targets_sorted)))
        axes[1].set_yticklabels([f"{t} (n={n})" for t, n in zip(targets_sorted, ns)])
        axes[1].axvline(0.5, color="gray", ls="--", lw=1)
        axes[1].set_xlabel("Accuracy")
        axes[1].set_title("Per-Target Test Accuracy")
        axes[1].set_xlim([0, 1.05])

    # Confidence distribution
    axes[2].hist(ext_probs.max(axis=1), bins=20, color="#9b59b6", edgecolor="black", alpha=0.8)
    axes[2].axvline(0.5, color="red", ls="--", lw=1.5, label="50% threshold")
    axes[2].set_xlabel("Max Confidence")
    axes[2].set_ylabel("Count")
    axes[2].set_title("Test Set Confidence")
    axes[2].legend()

    fig.suptitle("Held-Out Test Set Results", fontsize=14, y=1.02)
    fig.tight_layout()
    fig.savefig(OUTPUT_DIR / "08_test_set_results.png", dpi=150, bbox_inches="tight")
    plt.show()
    print(f"Saved: {OUTPUT_DIR / '08_test_set_results.png'}")

---
## 11. Predict New Compounds

Provide a CSV file with these columns:

| Column | Description |
|--------|-------------|
| `compound_id` | Your identifier for the compound |
| `smiles` | SMILES string |
| `target` | One of the 24 trained targets (e.g. `hERG`, `CYP3A4`, `ERa`, `P-gp`) |

An example file is provided at `data/example_predictions.csv`.
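
Whether `predict_compounds` validates its input is not shown in this section, so the following is only a hypothetical pre-flight check built from the column table above, using the standard library alone. The helper name and the rejection rules (blank SMILES, or a target outside the supplied set, matched exactly and case-sensitively) are assumptions.

```python
import csv
import io

REQUIRED_COLUMNS = {"compound_id", "smiles", "target"}


def read_prediction_input(text, allowed_targets):
    """Parse prediction-input CSV text into (usable_rows, rejected_rows).

    Raises ValueError if a required column is missing. A row is rejected when
    its SMILES is blank or its target is not in allowed_targets (exact,
    case-sensitive match in this sketch).
    """
    reader = csv.DictReader(io.StringIO(text))
    missing = REQUIRED_COLUMNS - set(reader.fieldnames or [])
    if missing:
        raise ValueError(f"missing required columns: {sorted(missing)}")
    usable, rejected = [], []
    for row in reader:
        smiles = (row.get("smiles") or "").strip()
        ok = bool(smiles) and row.get("target") in allowed_targets
        (usable if ok else rejected).append(row)
    return usable, rejected


example_csv = (
    "compound_id,smiles,target\n"
    "cpd1,CCO,hERG\n"
    "cpd2,c1ccccc1,NotATarget\n"
    "cpd3,,CYP3A4\n"
)
usable, rejected = read_prediction_input(example_csv, {"hERG", "CYP3A4", "ERa", "P-gp"})
print(len(usable), "usable;", [r["compound_id"] for r in rejected], "rejected")
```

For a file on disk, pass the text of `open(path, newline="").read()`. Checking the target names up front avoids silently scoring a compound against a target the model never saw.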

The output will include:
- `prob_non_binding`, `prob_binding` -- predicted class probabilities
- `predicted_class` / `predicted_label` -- most likely class
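- `conformal_set` -- set of plausible classes, chosen so that it contains the true class for about 95% of compounds on average (conformal prediction, alpha = 0.05)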
- `in_domain` -- whether the compound falls within the model's applicability domain
- `max_confidence` -- highest class probability (a rough quality indicator)
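
How `predict_compounds` builds these columns internally is not shown here. As an assumption-based illustration, the fields can be derived from class probabilities plus the two saved artifacts `conformal_q` and `ad_threshold` roughly as follows (`summarise_predictions` is a hypothetical helper, and the k-NN distances are assumed to be computed beforehand):

```python
import numpy as np

CLASS_LABELS = {0: "non_binding", 1: "binding"}


def summarise_predictions(probs, conformal_q, knn_dist, ad_threshold):
    """Turn class probabilities into the per-compound output fields listed above.

    probs: (n, 2) array of class probabilities; conformal_q: conformal quantile
    of the score 1 - p(true class); knn_dist: mean k-NN distance to training data.
    """
    rows = []
    for p, dist in zip(probs, knn_dist):
        pred = int(np.argmax(p))
        rows.append({
            "prob_non_binding": float(p[0]),
            "prob_binding": float(p[1]),
            "predicted_class": pred,
            "predicted_label": CLASS_LABELS[pred],
            # every class whose nonconformity score 1 - p is within the quantile
            "conformal_set": [CLASS_LABELS[c] for c in range(len(p)) if p[c] >= 1.0 - conformal_q],
            "in_domain": bool(dist <= ad_threshold),
            "max_confidence": float(p.max()),
        })
    return rows


example = summarise_predictions(
    np.array([[0.9, 0.1], [0.45, 0.55]]), conformal_q=0.6,
    knn_dist=np.array([0.5, 3.0]), ad_threshold=1.0,
)
for row in example:
    print(row)
```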

In [None]:
# ── Standalone prediction from saved model ────────────────────────────
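# Set YOUR_CSV below and run this cell.
# The training cells need not be run first, but the cells that define
# predict_compounds, NOTEBOOK_DIR, MODEL_DIR and OUTPUT_DIR must have been run.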

YOUR_CSV = NOTEBOOK_DIR / "data" / "example_predictions.csv"  # <-- change this
SAVED_MODEL = MODEL_DIR / "safety_model.pkl"

if SAVED_MODEL.exists() and YOUR_CSV.exists():
    standalone_results = predict_compounds(YOUR_CSV, SAVED_MODEL)
    standalone_results.to_csv(OUTPUT_DIR / "standalone_predictions.csv", index=False)
    print(f"Predictions saved to: {OUTPUT_DIR / 'standalone_predictions.csv'}")
    display(standalone_results)
elif not SAVED_MODEL.exists():
    print(f"No saved model found at {SAVED_MODEL}. Run training cells first.")
else:
    print(f"No input CSV found at {YOUR_CSV}. Update the YOUR_CSV variable above.")

---
## 12. Analysis Report (Markdown Output)

Generate a comprehensive analysis report as a Markdown file saved to the outputs directory.
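
The report cell below writes each Markdown table line by hand. A small helper such as the following (hypothetical; not part of the notebook) is one way to keep the header, alignment row and body rows consistent, and it escapes literal pipe characters so a cell value cannot split a column:

```python
def markdown_table(headers, rows, align=None):
    """Render a Markdown pipe table.

    align holds one of "l", "r", "c" per column (default: all left).
    Literal pipe characters in cells are escaped so they do not split columns.
    """
    align = align or ["l"] * len(headers)
    rule = {"l": "---", "r": "---:", "c": ":---:"}

    def cells(values):
        return "| " + " | ".join(str(v).replace("|", "\\|") for v in values) + " |"

    lines = [cells(headers), "|" + "|".join(rule[a] for a in align) + "|"]
    lines.extend(cells(row) for row in rows)
    return "\n".join(lines)


print(markdown_table(["Metric", "Value"], [["ROC-AUC", "0.8123"], ["MCC", "0.5012"]],
                     align=["l", "r"]))
```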

In [None]:
# ── Generate analysis Markdown report ─────────────────────────────────
import datetime

report_lines = []
report_lines.append("# OFFTOXv3 Analysis Report")
report_lines.append(f"\n**Generated:** {datetime.datetime.now().strftime('%Y-%m-%d %H:%M')}")
report_lines.append(f"**Best Model:** {best_model_name}")
report_lines.append(f"**Targets:** {len(TARGET_PANEL)} safety pharmacology targets")
report_lines.append(f"**Classification:** Binary (binding vs non-binding at 10 uM threshold)")
report_lines.append("")

# Dataset summary
report_lines.append("## 1. Dataset Summary")
report_lines.append("")
report_lines.append(f"- **Total training compounds:** {len(data)}")
report_lines.append(f"- **Unique targets:** {len(set(targets_all))}")
report_lines.append(f"- **Train/Val/Test split:** {len(X_train)}/{len(X_val)}/{len(X_test)} (scaffold-based)")
report_lines.append(f"- **Feature dimensions:** {features.shape[1]} (10 descriptors + 2048 Morgan FP + target encoding)")
report_lines.append("")

# Class distribution
report_lines.append("### Class Distribution")
report_lines.append("")
report_lines.append("| Class | Label | Count | Percentage |")
report_lines.append("|-------|-------|------:|----------:|")
for cls in range(NUM_CLASSES):
    n = int((labels_all == cls).sum())
    pct = 100 * n / len(data)
    report_lines.append(f"| {cls} | {ACTIVITY_CLASS_MAP[cls]} | {n} | {pct:.1f}% |")
report_lines.append("")

# Per-target breakdown
report_lines.append("### Per-Target Compound Counts")
report_lines.append("")
report_lines.append("| Target | Category | Total | Binding | Non-Binding |")
report_lines.append("|--------|----------|------:|--------:|------------:|")
target_df = pd.DataFrame({"target": targets_all, "class": labels_all})
for t in sorted(set(targets_all)):
    cat = TARGET_PANEL.get(t, {}).get("category", "?")
    t_data = target_df[target_df["target"] == t]
    total = len(t_data)
    binding = int((t_data["class"] == 1).sum())
    non_binding = int((t_data["class"] == 0).sum())
    report_lines.append(f"| {t} | {cat} | {total} | {binding} | {non_binding} |")
report_lines.append("")

# Cross-validation results
report_lines.append("## 2. Cross-Validation Results")
report_lines.append("")
report_lines.append("| Model | ROC-AUC | PR-AUC | MCC |")
report_lines.append("|-------|--------:|-------:|----:|")
for row in cv_summary_sorted:
    report_lines.append(
        f"| {row['model']} | "
        f"{row['roc_auc_mean']:.4f} +/- {row['roc_auc_std']:.4f} | "
        f"{row['pr_auc_mean']:.4f} +/- {row['pr_auc_std']:.4f} | "
        f"{row['mcc_mean']:.4f} +/- {row['mcc_std']:.4f} |"
    )
report_lines.append("")
report_lines.append(f"**Selected model:** {best_model_name} (highest mean cross-validated ROC-AUC)")
report_lines.append("")

# Internal test set metrics
report_lines.append("## 3. Internal Test Set Performance (Scaffold Split)")
report_lines.append("")
report_lines.append(f"| Metric | Value |")
report_lines.append(f"|--------|------:|")
report_lines.append(f"| ROC-AUC | {test_roc:.4f} |")
report_lines.append(f"| PR-AUC | {test_pr:.4f} |")
report_lines.append(f"| MCC | {test_mcc:.4f} |")
report_lines.append(f"| ECE (calibrated) | {ece:.4f} |")
report_lines.append(f"| MCE (calibrated) | {mce:.4f} |")
report_lines.append("")

# Confusion matrix
report_lines.append("### Confusion Matrix")
report_lines.append("")
cm_report = confusion_matrix(y_test, test_preds, labels=list(range(NUM_CLASSES)))
header = "| | " + " | ".join(f"Pred: {ACTIVITY_CLASS_MAP[c]}" for c in range(NUM_CLASSES)) + " |"
report_lines.append(header)
report_lines.append("|---|" + "---:|" * NUM_CLASSES)
for i in range(NUM_CLASSES):
    row_vals = " | ".join(str(cm_report[i, j]) for j in range(NUM_CLASSES))
    report_lines.append(f"| **{ACTIVITY_CLASS_MAP[i]}** | {row_vals} |")
report_lines.append("")

# Uncertainty
report_lines.append("## 4. Uncertainty Quantification")
report_lines.append("")
report_lines.append(f"- **Conformal coverage:** {coverage:.4f} (target: 0.95)")
report_lines.append(f"- **Average prediction set size:** {set_sizes.mean():.2f}")
report_lines.append(f"- **AD threshold (95th pct k-NN):** {ad_threshold:.4f}")
report_lines.append(f"- **Out-of-domain rate:** {ood_rate:.2%}")
report_lines.append("")

# Held-out test set
report_lines.append("## 5. Held-Out Test Set Evaluation")
report_lines.append("")
if not np.isnan(ext_acc):
    report_lines.append(f"- **Test compounds:** {len(y_test_ext_valid)}")
    report_lines.append(f"- **ROC-AUC:** {ext_roc:.4f}")
    report_lines.append(f"- **MCC:** {ext_mcc:.4f}")
    report_lines.append(f"- **Accuracy:** {ext_acc:.4f}")
    report_lines.append("")
    if test_target_metrics:
        report_lines.append("### Per-Target Test Performance")
        report_lines.append("")
        report_lines.append("| Target | N | Accuracy | MCC |")
        report_lines.append("|--------|--:|---------:|----:|")
        for target, m in sorted(test_target_metrics.items()):
            report_lines.append(f"| {target} | {m['n']} | {m['acc']:.3f} | {m['mcc']:.3f} |")
        report_lines.append("")
else:
    report_lines.append("No held-out test set was available for evaluation.")
    report_lines.append("")

# Statistical comparison
report_lines.append("## 6. Statistical Model Comparison")
report_lines.append("")
if stat_rows:
    report_lines.append(f"Bonferroni-corrected alpha = {bonferroni:.4f}")
    report_lines.append("")
    report_lines.append("| Model A | Model B | t-stat | p-value | Cohen's d | Significant |")
    report_lines.append("|---------|---------|-------:|--------:|----------:|:-----------:|")
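    for r in stat_rows:
        # Pull the key out first: backslashes inside f-string expressions are a
        # SyntaxError before Python 3.12
        cohen_d_val = r["Cohen's d"]
        report_lines.append(
            f"| {r['Model A']} | {r['Model B']} | "
            f"{r['t-stat']:.4f} | {r['p-value']:.6f} | "
            f"{cohen_d_val:.4f} | {r['Significant']} |"
        )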
    report_lines.append("")

# MCDA
report_lines.append("## 7. MCDA Ranking")
report_lines.append("")
report_lines.append("| Rank | Model | Composite Score |")
report_lines.append("|-----:|-------|----------------:|")
for i, r in enumerate(mcda_rows, 1):
    report_lines.append(f"| {i} | {r['model']} | {r['composite']:.4f} |")
report_lines.append("")

# Target panel reference
report_lines.append("## 8. Target Panel Reference")
report_lines.append("")
report_lines.append("| # | Target | ChEMBL ID | Category |")
report_lines.append("|--:|--------|-----------|----------|")
for i, (tname, tinfo) in enumerate(TARGET_PANEL.items(), 1):
    report_lines.append(f"| {i} | {tname} | {tinfo['chembl_id']} | {tinfo['category']} |")
report_lines.append("")

# Outputs
report_lines.append("## 9. Output Files")
report_lines.append("")
report_lines.append("| File | Description |")
report_lines.append("|------|-------------|")
report_lines.append("| `01_data_exploration.png` | Class distribution, per-target breakdown, pChEMBL histogram |")
report_lines.append("| `02_roc_curves.png` | Per-class ROC curves |")
report_lines.append("| `03_pr_curves.png` | Per-class Precision-Recall curves |")
report_lines.append("| `04_confusion_matrix.png` | Confusion matrices (counts and percentages) |")
report_lines.append("| `05_calibration_curves.png` | Per-class calibration curves |")
report_lines.append("| `06_feature_importance.png` | Top 20 feature importances |")
report_lines.append("| `07_uncertainty.png` | Conformal sets, AD distances, reliability diagram |")
report_lines.append("| `08_test_set_results.png` | Held-out test set results |")
report_lines.append("| `workflow_summary.json` | Machine-readable summary of all metrics |")
report_lines.append("| `test_set_predictions.csv` | Held-out test set predictions with probabilities |")
report_lines.append("| `analysis_report.md` | This report |")
report_lines.append("")

# Write the report
report_path = OUTPUT_DIR / "analysis_report.md"
with open(report_path, "w") as f:
    f.write("\n".join(report_lines))

print(f"Analysis report saved to: {report_path}")
print(f"Report length: {len(report_lines)} lines")