# OFFTOXv3 — Safety Pharmacology & NHR Prediction Workflow

End-to-end workflow for predicting compound activity against **13 safety/toxicology targets** using a 3-class scheme:

| Class | Label | Definition |
|-------|-------|------------|
| 2 | **potent** | pChEMBL >= 5.0 (IC50/Ki <= 10 µM) |
| 1 | **less_potent** | 4.0 <= pChEMBL < 5.0 (10–100 µM) |
| 0 | **inactive** | Confirmed inactive (tested >= 10 µM, no activity) |

**Targets covered:**
- Cardiac ion channels: hERG, Cav1.2, Nav1.5
- CYP enzymes: CYP3A4, CYP2D6, CYP1A2
- Nuclear Hormone Receptors: ERa, AR, PR, PPARg, RXRa, PXR, GR

---

## How to use this notebook

1. **Run sections 1–8** to train and evaluate models on the bundled training data.
2. **Section 9** (Predict New Compounds) accepts any CSV with `compound_id`, `smiles`, and `target` columns.
3. All visualizations render inline. Outputs are also saved to `outputs/`.

---
## 1. Setup & Configuration

In [None]:
# ── Install dependencies (run once) ──────────────────────────────────
# Uncomment the line below if running for the first time:
# !pip install "numpy<2" pandas scikit-learn xgboost lightgbm rdkit matplotlib seaborn scipy joblib

import json
import csv
import time
import warnings
import pickle
from pathlib import Path
from dataclasses import dataclass
from typing import Dict, List, Tuple, Optional
from collections import Counter

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import seaborn as sns
from scipy import stats

from rdkit import Chem
from rdkit.Chem import Crippen, Descriptors, Lipinski, MolSurf, rdFingerprintGenerator
from rdkit.Chem.Scaffolds import MurckoScaffold

from sklearn.calibration import CalibratedClassifierCV, calibration_curve
from sklearn.metrics import (
    average_precision_score,
    confusion_matrix,
    matthews_corrcoef,
    precision_recall_curve,
    roc_auc_score,
    roc_curve,
    classification_report,
)
from sklearn.model_selection import RandomizedSearchCV, RepeatedStratifiedKFold
from sklearn.neighbors import NearestNeighbors
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler
from sklearn.ensemble import RandomForestClassifier
from xgboost import XGBClassifier
from lightgbm import LGBMClassifier

# Suppress ALL Python warnings for the rest of the session (this also hides
# convergence/deprecation warnings from the ML libraries — re-enable when debugging).
warnings.filterwarnings("ignore")
sns.set_theme(style="whitegrid", font_scale=1.1)
# IPython magic: render matplotlib figures inline in the notebook output.
%matplotlib inline

# ── Paths ────────────────────────────────────────────────────────────
# All paths hang off the current working directory (resolved to an absolute
# path), so the notebook should be started from the project folder.
NOTEBOOK_DIR = Path(".").resolve()
DATA_PATH = NOTEBOOK_DIR / "data" / "safety_targets_bioactivity.csv"
OUTPUT_DIR = NOTEBOOK_DIR / "outputs"   # figures + workflow_summary.json
MODEL_DIR  = NOTEBOOK_DIR / "model"     # pickled model artifacts
# exist_ok=True makes re-running the cell safe; parents are not created.
OUTPUT_DIR.mkdir(exist_ok=True)
MODEL_DIR.mkdir(exist_ok=True)

# ── Constants ─────────────────────────────────────────────────────────
RANDOM_STATE = 42
# Integer class labels (as assigned by load_and_clean_data) -> display names.
ACTIVITY_CLASS_MAP = {0: "inactive", 1: "less_potent", 2: "potent"}
# Plot colour per class label, used consistently across all figures.
CLASS_COLORS = {0: "#2ecc71", 1: "#f39c12", 2: "#e74c3c"}
NUM_CLASSES = 3

print(f"Data path  : {DATA_PATH}")
print(f"Output dir : {OUTPUT_DIR}")
print(f"Model dir  : {MODEL_DIR}")
print("Setup complete.")

---
## 2. Data Loading & Exploration

In [None]:
# ── Core data structures ─────────────────────────────────────────────
@dataclass
class SplitData:
    """Row-index arrays describing one train / validation / test partition.

    Each field holds positions into the full feature matrix / label vector.
    """

    train_idx: np.ndarray  # rows used to fit models
    val_idx: np.ndarray    # rows held out for validation / model checks
    test_idx: np.ndarray   # rows held out for final evaluation


def load_and_clean_data(path: Path) -> List[dict]:
    """Load the training CSV and assign 3-class activity labels.

    Classes:
        2 - potent:      pChEMBL >= 5.0  (<= 10 uM)
        1 - less_potent: 4.0 <= pChEMBL < 5.0  (10-100 uM)
        0 - inactive:    confirmed-inactive compounds

    A record is kept when it has a SMILES string and either
    * is flagged inactive (``activity_class == "0"`` or
      ``activity_class_label == "inactive"``), or
    * has an exact measurement (``standard_relation == "="``) with a finite
      pChEMBL >= 4.0.
    Everything else (no SMILES, qualified relation, missing / unparseable /
    non-finite pChEMBL, pChEMBL < 4.0) is dropped.

    Duplicates are collapsed per (molecule, target) pair, keeping the record
    with the highest pChEMBL; an active record replaces an inactive one. When
    the ChEMBL IDs are blank the key falls back to the canonical SMILES /
    target common name, so unrelated ID-less compounds are not merged.

    Parameters
    ----------
    path : Path
        CSV file with a header row, read as UTF-8.

    Returns
    -------
    List[dict]
        The surviving CSV rows (string values) with ``pchembl_value`` replaced
        by a float (None for inactives) and ``activity_class`` by an int.
    """
    rows = []
    with path.open(newline="", encoding="utf-8") as fh:
        reader = csv.DictReader(fh)
        for row in reader:
            smi = row.get("canonical_smiles")
            if not smi:
                continue

            raw_class = row.get("activity_class", "")
            if raw_class == "0" or row.get("activity_class_label") == "inactive":
                row["pchembl_value"] = None
                row["activity_class"] = 0
                rows.append(row)
                continue

            if row.get("standard_relation") != "=":
                continue
            if not row.get("pchembl_value"):
                continue
            try:
                pchembl = float(row["pchembl_value"])
            except ValueError:
                continue
            # float() accepts "nan"/"inf"; NaN fails every comparison, so without
            # this check it would slip past `< 4.0` and be labelled less_potent.
            if not np.isfinite(pchembl) or pchembl < 4.0:
                continue
            row["pchembl_value"] = pchembl
            row["activity_class"] = 2 if pchembl >= 5.0 else 1
            rows.append(row)

    deduped: dict = {}
    for row in rows:
        # Fall back to SMILES / target name when an ID is blank; otherwise every
        # ID-less record for a target would share the key ("", target_id).
        key = (
            row.get("molecule_chembl_id") or row.get("canonical_smiles"),
            row.get("target_chembl_id") or row.get("target_common_name"),
        )
        existing = deduped.get(key)
        if existing is None:
            deduped[key] = row
            continue
        existing_p = existing.get("pchembl_value")
        current_p = row.get("pchembl_value")
        # Keep the most potent measurement; inactives (None) never displace actives.
        if current_p is not None and (existing_p is None or current_p > existing_p):
            deduped[key] = row
    return list(deduped.values())


# ── Load ──────────────────────────────────────────────────────────────
data = load_and_clean_data(DATA_PATH)
# Parallel per-record views: integer class labels and target display names.
labels_all = np.array([rec["activity_class"] for rec in data], dtype=int)
targets_all = [rec.get("target_common_name", "unknown") for rec in data]

n_records = len(data)
print(f"Loaded {n_records} compound-target records")
print(f"Targets: {sorted(set(targets_all))}")
print("\nClass distribution:")
for cls, cls_name in sorted(ACTIVITY_CLASS_MAP.items()):
    n = int(np.count_nonzero(labels_all == cls))
    print(f"  {cls} ({cls_name:>12s}): {n:>5d}  ({100*n/n_records:.1f}%)")

In [None]:
# ── Exploratory visualizations ────────────────────────────────────────
# Three panels: overall class balance, per-target class mix, pChEMBL histogram.
fig, axes = plt.subplots(1, 3, figsize=(18, 5))
ax_classes, ax_targets, ax_pchembl = axes

# Panel 1 — number of records in each activity class, annotated above each bar.
class_counts = Counter(labels_all)
present_classes = sorted(class_counts)
bars = ax_classes.bar(
    [ACTIVITY_CLASS_MAP[c] for c in present_classes],
    [class_counts[c] for c in present_classes],
    color=[CLASS_COLORS[c] for c in present_classes],
    edgecolor="black",
)
for bar, c in zip(bars, present_classes):
    x_mid = bar.get_x() + bar.get_width() / 2
    ax_classes.text(x_mid, bar.get_height() + 10, str(class_counts[c]),
                    ha="center", fontweight="bold")
ax_classes.set_title("Class Distribution")
ax_classes.set_ylabel("Count")

# Panel 2 — stacked horizontal bars of class counts per target.
target_class_df = pd.DataFrame({"target": targets_all, "class": labels_all})
target_order = sorted(set(targets_all))
class_by_target = (
    target_class_df.groupby(["target", "class"]).size()
    .unstack(fill_value=0)
    .reindex(columns=[0, 1, 2], fill_value=0)
)
class_by_target.columns = [ACTIVITY_CLASS_MAP[c] for c in class_by_target.columns]
class_by_target.loc[target_order].plot.barh(
    stacked=True, ax=ax_targets,
    color=[CLASS_COLORS[k] for k in (0, 1, 2)],
    edgecolor="black",
)
ax_targets.set_title("Compounds per Target")
ax_targets.set_xlabel("Count")
ax_targets.legend(title="Class", loc="lower right")

# Panel 3 — pChEMBL values of active records (inactives carry None).
pchembl_vals = [float(rec["pchembl_value"]) for rec in data if rec["pchembl_value"] is not None]
ax_pchembl.hist(pchembl_vals, bins=30, color="#3498db", edgecolor="black", alpha=0.8)
for x_line, line_color, line_label in (
    (5.0, "red", "Potent threshold (5.0)"),
    (4.0, "orange", "Less-potent threshold (4.0)"),
):
    ax_pchembl.axvline(x_line, color=line_color, ls="--", lw=2, label=line_label)
ax_pchembl.set_title("pChEMBL Value Distribution")
ax_pchembl.set_xlabel("pChEMBL")
ax_pchembl.set_ylabel("Count")
ax_pchembl.legend()

fig.tight_layout()
fig.savefig(OUTPUT_DIR / "01_data_exploration.png", dpi=150, bbox_inches="tight")
plt.show()
print("Saved: outputs/01_data_exploration.png")

---
## 3. Feature Engineering

In [None]:
def compute_descriptors(smiles: List[str]) -> Tuple[np.ndarray, List[str]]:
    """Compute 10 RDKit physicochemical descriptors per molecule.

    Parameters
    ----------
    smiles : List[str]
        SMILES strings, one per record.

    Returns
    -------
    (matrix, names)
        ``matrix`` has shape ``(len(smiles), 10)``; rows for SMILES that RDKit
        cannot parse are all NaN. An empty input gives a ``(0, 10)`` matrix
        (not a 1-D ``(0,)`` array) so it can still be concatenated column-wise
        with the other feature blocks. ``names`` lists the descriptor names in
        column order.
    """
    descriptor_functions = {
        "MW": Descriptors.MolWt,
        "LogP": Crippen.MolLogP,
        "HBA": Lipinski.NumHAcceptors,
        "HBD": Lipinski.NumHDonors,
        "TPSA": MolSurf.TPSA,
        "RotatableBonds": Lipinski.NumRotatableBonds,
        "AromaticRings": Lipinski.NumAromaticRings,
        "HeavyAtoms": Lipinski.HeavyAtomCount,
        "FractionCSP3": Lipinski.FractionCSP3,
        "MolMR": Crippen.MolMR,
    }
    n_desc = len(descriptor_functions)
    rows = []
    for smi in smiles:
        mol = Chem.MolFromSmiles(smi)
        if mol is None:
            rows.append([np.nan] * n_desc)
            continue
        rows.append([func(mol) for func in descriptor_functions.values()])
    # reshape is a no-op for non-empty input but fixes the empty case's shape.
    matrix = np.array(rows, dtype=float).reshape(len(smiles), n_desc)
    return matrix, list(descriptor_functions.keys())


def compute_morgan_fingerprints(smiles: List[str], n_bits: int = 2048) -> np.ndarray:
    """Compute Morgan fingerprints (radius 2, ECFP4-like) as a bit matrix.

    Parameters
    ----------
    smiles : List[str]
        SMILES strings, one per record.
    n_bits : int
        Fingerprint length (default 2048).

    Returns
    -------
    np.ndarray
        Integer 0/1 matrix of shape ``(len(smiles), n_bits)``. Unparseable
        SMILES get an all-zero row. An empty input gives ``(0, n_bits)`` so it
        still concatenates column-wise with the other feature blocks.
    """
    gen = rdFingerprintGenerator.GetMorganGenerator(radius=2, fpSize=n_bits)
    fps = []
    for smi in smiles:
        mol = Chem.MolFromSmiles(smi)
        if mol is None:
            fps.append(np.zeros(n_bits, dtype=int))
            continue
        fps.append(np.array(gen.GetFingerprint(mol)))
    return np.array(fps, dtype=int).reshape(len(smiles), n_bits)


def build_feature_matrix(
    rows: List[dict],
    selected_columns: Optional[List[str]] = None,
) -> Tuple[np.ndarray, np.ndarray, List[str]]:
    """Assemble descriptors + Morgan bits + one-hot target columns for *rows*.

    Training mode (``selected_columns`` is None): columns whose variance is
    <= 0.01 are dropped, NaNs are zero-filled, and the names of the kept
    columns are returned.

    Prediction mode (``selected_columns`` given): the output has exactly those
    columns, matched by name; columns not produced for these rows stay 0.

    Returns ``(matrix, labels, column_names)``; ``labels`` uses -1 for rows
    without an ``activity_class``.
    """
    smiles = [r["canonical_smiles"] for r in rows]
    targets = [r.get("target_common_name", r.get("target", "")) for r in rows]
    labels = np.array([r.get("activity_class", -1) for r in rows], dtype=int)

    descriptors, desc_names = compute_descriptors(smiles)
    fingerprints = compute_morgan_fingerprints(smiles)
    fp_names = [f"FP_{bit}" for bit in range(fingerprints.shape[1])]

    # One-hot target encoding over the sorted set of non-empty target names.
    target_names = sorted(set(filter(None, targets)))
    target_pos = {t: i for i, t in enumerate(target_names)}
    target_matrix = np.zeros((len(rows), len(target_names)), dtype=float)
    hits = [(i, target_pos[t]) for i, t in enumerate(targets) if t in target_pos]
    if hits:
        hit_rows, hit_cols = zip(*hits)
        target_matrix[list(hit_rows), list(hit_cols)] = 1.0

    feature_matrix = np.concatenate([descriptors, fingerprints, target_matrix], axis=1)
    columns = [*desc_names, *fp_names, *(f"target_{t}" for t in target_names)]

    if selected_columns is not None:
        # Prediction: align by column name to the training layout.
        lookup = {c: i for i, c in enumerate(columns)}
        aligned = np.zeros((len(rows), len(selected_columns)), dtype=float)
        for j, col in enumerate(selected_columns):
            src = lookup.get(col)
            if src is not None:
                aligned[:, j] = np.nan_to_num(feature_matrix[:, src], nan=0.0)
        return aligned, labels, selected_columns

    # Training: keep informative columns, then replace NaN descriptors by 0.
    keep = np.nanvar(feature_matrix, axis=0) > 0.01
    kept_columns = [c for c, k in zip(columns, keep) if k]
    return np.nan_to_num(feature_matrix[:, keep], nan=0.0), labels, kept_columns


# ── Build features ────────────────────────────────────────────────────
print("Computing features (descriptors + 2048-bit Morgan FP + target encoding)...")
t_start = time.time()
features, labels, selected_columns = build_feature_matrix(data)
elapsed = time.time() - t_start
print(f"  Done in {elapsed:.1f}s")
print(f"  Feature matrix shape: {features.shape}")
print(f"  Features retained after variance filter: {len(selected_columns)}")

---
## 4. Scaffold Split

In [None]:
def scaffold_split(smiles: List[str], y: np.ndarray, random_state: int = 42) -> SplitData:
    """Split records by Bemis-Murcko scaffold (~60/20/20) to avoid scaffold leakage.

    Records that share a scaffold always land in the same partition.

    Parameters
    ----------
    smiles : List[str]
        SMILES per record.
    y : np.ndarray
        Labels per record. Not used (the split is not stratified); kept for
        interface compatibility.
    random_state : int
        Seed for the scaffold-group shuffle.

    Returns
    -------
    SplitData
        Integer index arrays. Groups are assigned greedily in shuffled order:
        train fills up to 60% of records, validation up to 20%, and every group
        that fits in neither goes to test (so test can exceed 20%).
    """
    scaffolds: Dict[str, List[int]] = {}
    for idx, smi in enumerate(smiles):
        mol = Chem.MolFromSmiles(smi)
        # Unparseable SMILES share the "" group (as do ring-free molecules,
        # whose Murcko scaffold SMILES is empty).
        if mol is None:
            scaffold = ""
        else:
            scaffold = MurckoScaffold.MurckoScaffoldSmiles(mol=mol)
        scaffolds.setdefault(scaffold, []).append(idx)

    # The seeded shuffle determines the final group order; the size sort before
    # it is kept so splits stay identical for a given seed.
    scaffold_sets = sorted(scaffolds.values(), key=len, reverse=True)
    rng = np.random.default_rng(random_state)
    rng.shuffle(scaffold_sets)

    n_total = len(smiles)
    n_train = int(0.6 * n_total)
    n_val = int(0.2 * n_total)

    train_idx, val_idx, test_idx = [], [], []
    for group in scaffold_sets:
        if len(train_idx) + len(group) <= n_train:
            train_idx.extend(group)
        elif len(val_idx) + len(group) <= n_val:
            val_idx.extend(group)
        else:
            test_idx.extend(group)

    # dtype=int: np.array([]) defaults to float64, which cannot be used as an
    # index, so an empty partition would otherwise crash features[split.val_idx].
    return SplitData(
        train_idx=np.array(train_idx, dtype=int),
        val_idx=np.array(val_idx, dtype=int),
        test_idx=np.array(test_idx, dtype=int),
    )


smiles_all = [rec["canonical_smiles"] for rec in data]
split = scaffold_split(smiles_all, labels, RANDOM_STATE)

# Materialise the three partitions of features / labels.
X_train, y_train = features[split.train_idx], labels[split.train_idx]
X_val, y_val     = features[split.val_idx],   labels[split.val_idx]
X_test, y_test   = features[split.test_idx],  labels[split.test_idx]

print(f"Train : {len(X_train):>5d}  |  Val : {len(X_val):>5d}  |  Test : {len(X_test):>5d}")
for subset_name, y_sub in (("Train", y_train), ("Val", y_val), ("Test", y_test)):
    counts = Counter(y_sub)
    parts = ", ".join(
        f"{ACTIVITY_CLASS_MAP[c]}={counts.get(c, 0)}" for c in range(NUM_CLASSES)
    )
    print(f"  {subset_name:>5s}: {parts}")

---
## 5. Model Training & Cross-Validation

In [None]:
def ece_score_fn(y_true, y_prob, n_bins=10):
    """Return (Expected, Maximum) Calibration Error over equal-width bins.

    Parameters
    ----------
    y_true : array-like of {0, 1}
        Observed outcome per sample (e.g. 1 if the prediction was correct).
    y_prob : array-like of float in [0, 1]
        Predicted probability / confidence for the same event.
    n_bins : int
        Number of equal-width bins on [0, 1].

    Returns
    -------
    (ece, mce)
        ECE is the sample-weighted mean |confidence - accuracy| across
        non-empty bins; MCE is the largest such gap. Both are 0.0 for empty input.
    """
    y_true = np.asarray(y_true, dtype=float)
    y_prob = np.asarray(y_prob, dtype=float)
    bins = np.linspace(0, 1, n_bins + 1)
    # np.digitize puts p == 1.0 past the last edge (index n_bins after the -1),
    # which would silently drop fully-confident predictions; fold it into the
    # last bin instead.
    binids = np.minimum(np.digitize(y_prob, bins) - 1, n_bins - 1)
    ece, mce = 0.0, 0.0
    for i in range(n_bins):
        mask = binids == i
        if not np.any(mask):
            continue
        avg_conf = y_prob[mask].mean()
        avg_acc = y_true[mask].mean()
        gap = abs(avg_conf - avg_acc)
        ece += gap * mask.mean()  # weight by the fraction of samples in the bin
        mce = max(mce, gap)
    return ece, mce


def get_models(random_state: int) -> Dict[str, Tuple[Pipeline, Dict[str, list]]]:
    """Return the candidate pipelines and their hyperparameter search spaces.

    Each value is ``(pipeline, param_distributions)``; parameter names carry the
    ``model__`` prefix so they target the final pipeline step. The scaler uses
    ``with_mean=False`` (scaling only, no centring).

    Parameters
    ----------
    random_state : int
        Seed passed to every model for reproducibility.
    """
    return {
        "RandomForest": (
            Pipeline([
                ("scaler", StandardScaler(with_mean=False)),
                ("model", RandomForestClassifier(random_state=random_state, n_jobs=-1)),
            ]),
            {
                "model__n_estimators": [200, 500],
                "model__max_depth": [10, 20, None],
                "model__min_samples_split": [2, 5, 10],
                "model__max_features": ["sqrt", "log2", 0.3],
                "model__class_weight": ["balanced", None],
            },
        ),
        "XGBoost": (
            Pipeline([
                ("scaler", StandardScaler(with_mean=False)),
                ("model", XGBClassifier(
                    random_state=random_state,
                    objective="multi:softprob",
                    num_class=NUM_CLASSES,
                    eval_metric="mlogloss",
                    n_jobs=-1,
                    verbosity=0,
                )),
            ]),
            {
                "model__n_estimators": [200, 500],
                "model__max_depth": [3, 5, 7],
                "model__learning_rate": [0.01, 0.05, 0.1],
                "model__subsample": [0.6, 0.8, 1.0],
                "model__colsample_bytree": [0.6, 0.8, 1.0],
            },
        ),
        "LightGBM": (
            Pipeline([
                ("scaler", StandardScaler(with_mean=False)),
                ("model", LGBMClassifier(
                    random_state=random_state, n_jobs=-1, verbose=-1,
                    objective="multiclass",
                    num_class=NUM_CLASSES,
                    # LightGBM only applies `subsample` (bagging) when
                    # subsample_freq > 0; with the default of 0 the searched
                    # subsample values below would all be no-ops.
                    subsample_freq=1,
                )),
            ]),
            {
                "model__n_estimators": [200, 500],
                "model__max_depth": [-1, 5, 10],
                "model__learning_rate": [0.01, 0.05, 0.1],
                "model__num_leaves": [31, 63, 127],
                "model__subsample": [0.6, 0.8, 1.0],
            },
        ),
    }


# ── Train all models ──────────────────────────────────────────────────
from sklearn.base import clone  # unfitted copies for the CV folds (see below)

models = get_models(RANDOM_STATE)
cv = RepeatedStratifiedKFold(n_splits=3, n_repeats=2, random_state=RANDOM_STATE)

cv_summary = []
best_estimators = {}
calibration_metrics = {}
fold_scores: Dict[str, List[float]] = {}
train_times = {}

for name, (pipeline, param_grid) in models.items():
    print(f"\n{'='*60}")
    print(f"Training: {name}")
    print(f"{'='*60}")

    # Hyperparameter search. RandomizedSearchCV refits the best configuration on
    # all of X_train (refit=True by default), exposed as best_estimator_.
    search = RandomizedSearchCV(
        pipeline,
        param_distributions=param_grid,
        n_iter=5,
        scoring="roc_auc_ovr",
        cv=3,
        random_state=RANDOM_STATE,
        n_jobs=1,  # avoid over-subscription with model-level n_jobs=-1
    )
    t0 = time.time()
    search.fit(X_train, y_train)
    train_times[name] = time.time() - t0
    best_estimators[name] = search.best_estimator_
    print(f"  Best params: {search.best_params_}")
    print(f"  Train time : {train_times[name]:.1f}s")

    # Cross-validation evaluation of the chosen hyperparameters
    scores, pr_scores, mcc_scores = [], [], []
    for train_idx, test_idx in cv.split(X_train, y_train):
        X_tr, X_te = X_train[train_idx], X_train[test_idx]
        y_tr, y_te = y_train[train_idx], y_train[test_idx]
        # Fit a fresh clone per fold. Refitting search.best_estimator_ in place
        # would leave the stored best model trained on only the last fold's subset.
        est = clone(search.best_estimator_)
        est.fit(X_tr, y_tr)
        probs = est.predict_proba(X_te)
        preds = est.predict(X_te)
        scores.append(roc_auc_score(y_te, probs, multi_class="ovr", average="macro"))
        pr_per = [average_precision_score((y_te == c).astype(int), probs[:, c])
                  for c in range(NUM_CLASSES) if (y_te == c).sum() > 0]
        pr_scores.append(float(np.mean(pr_per)) if pr_per else 0.0)
        mcc_scores.append(matthews_corrcoef(y_te, preds))

    cv_summary.append({
        "model": name,
        "roc_auc_mean": np.mean(scores), "roc_auc_std": np.std(scores),
        "pr_auc_mean": np.mean(pr_scores), "pr_auc_std": np.std(pr_scores),
        "mcc_mean": np.mean(mcc_scores), "mcc_std": np.std(mcc_scores),
    })
    fold_scores[name] = scores

    # Validation calibration: top-label ECE, i.e. the confidence of the predicted
    # class versus whether that prediction was correct. (Scoring p(true class)
    # against an all-ones target would only measure 1 - mean p(true class).)
    val_probs = search.best_estimator_.predict_proba(X_val)
    val_pred = search.best_estimator_.classes_[val_probs.argmax(axis=1)]
    val_correct = (val_pred == y_val).astype(float)
    ece_val, _ = ece_score_fn(val_correct, val_probs.max(axis=1))
    calibration_metrics[name] = ece_val

    print(f"  CV ROC-AUC : {np.mean(scores):.4f} +/- {np.std(scores):.4f}")
    print(f"  CV PR-AUC  : {np.mean(pr_scores):.4f} +/- {np.std(pr_scores):.4f}")
    print(f"  CV MCC     : {np.mean(mcc_scores):.4f} +/- {np.std(mcc_scores):.4f}")

print(f"\nTraining complete. {len(models)} models evaluated.")

---
## 6. Model Evaluation & Visualizations

In [None]:
# ── Select best model & refit on train+val ────────────────────────────
# Pick the model with the best mean CV ROC-AUC and refit it on train + validation.
cv_summary_sorted = sorted(cv_summary, key=lambda r: r["roc_auc_mean"], reverse=True)
best_model_name = cv_summary_sorted[0]["model"]
best_model = best_estimators[best_model_name]
X_train_val = np.vstack([X_train, X_val])
y_train_val = np.hstack([y_train, y_val])
best_model.fit(X_train_val, y_train_val)

print("Cross-validation summary (sorted by ROC-AUC):")
print("-" * 75)
print(f"{'Model':<15s} {'ROC-AUC':>18s} {'PR-AUC':>18s} {'MCC':>18s}")
print("-" * 75)
for row in cv_summary_sorted:
    print(f"{row['model']:<15s} "
          f"{row['roc_auc_mean']:.4f} +/- {row['roc_auc_std']:.4f}  "
          f"{row['pr_auc_mean']:.4f} +/- {row['pr_auc_std']:.4f}  "
          f"{row['mcc_mean']:.4f} +/- {row['mcc_std']:.4f}")
print("-" * 75)
print(f"Best model: {best_model_name}")

# ── Test-set evaluation ───────────────────────────────────────────────
test_probs = best_model.predict_proba(X_test)
test_preds = best_model.predict(X_test)
test_roc = roc_auc_score(y_test, test_probs, multi_class="ovr", average="macro")
pr_per = [average_precision_score((y_test == c).astype(int), test_probs[:, c])
          for c in range(NUM_CLASSES) if (y_test == c).sum() > 0]
test_pr = float(np.mean(pr_per)) if pr_per else 0.0
test_mcc = matthews_corrcoef(y_test, test_preds)

# Calibration: isotonic recalibration fitted on train + validation.
calibrated = CalibratedClassifierCV(best_model, method="isotonic", cv=3)
calibrated.fit(X_train_val, y_train_val)
cal_probs = calibrated.predict_proba(X_test)
# Top-label ECE/MCE: confidence of the predicted class vs. whether it was right.
cal_pred = calibrated.classes_[cal_probs.argmax(axis=1)]
ece, mce = ece_score_fn((cal_pred == y_test).astype(float), cal_probs.max(axis=1))

print(f"\nTest Set Metrics ({best_model_name}):")
print(f"  ROC-AUC (macro) : {test_roc:.4f}")
print(f"  PR-AUC  (macro) : {test_pr:.4f}")
print(f"  MCC             : {test_mcc:.4f}")
print(f"  ECE (calibrated): {ece:.4f}")
print(f"  MCE (calibrated): {mce:.4f}")
print(f"\nClassification Report:")
# Explicit labels keep target_names aligned even if a class is missing from the test set.
print(classification_report(
    y_test, test_preds,
    labels=list(range(NUM_CLASSES)),
    target_names=[ACTIVITY_CLASS_MAP[c] for c in range(NUM_CLASSES)],
    digits=3,
    zero_division=0,
))

In [None]:
# ── Figure 2: ROC Curves (per-class) ──────────────────────────────────
fig, axes = plt.subplots(1, 3, figsize=(18, 5))

# One-vs-rest ROC curve per class, one panel each.
for cls in range(NUM_CLASSES):
    ax = axes[cls]
    label = ACTIVITY_CLASS_MAP[cls]
    binary_true = (y_test == cls).astype(int)
    if not binary_true.any():
        ax.set_title(f"ROC — {label} (no samples)")
    else:
        cls_scores = test_probs[:, cls]
        fpr, tpr, _ = roc_curve(binary_true, cls_scores)
        auc_val = roc_auc_score(binary_true, cls_scores)
        ax.plot(fpr, tpr, color=CLASS_COLORS[cls], lw=2, label=f"AUC = {auc_val:.3f}")
        ax.plot([0, 1], [0, 1], "k--", lw=1, alpha=0.5)  # chance diagonal
        ax.set_xlabel("False Positive Rate")
        ax.set_ylabel("True Positive Rate")
        ax.set_title(f"ROC — {label}")
        ax.legend(loc="lower right", fontsize=12)
        ax.set_xlim([-0.02, 1.02])
        ax.set_ylim([-0.02, 1.02])

fig.suptitle(f"Per-Class ROC Curves ({best_model_name})", fontsize=14, y=1.02)
fig.tight_layout()
fig.savefig(OUTPUT_DIR / "02_roc_curves.png", dpi=150, bbox_inches="tight")
plt.show()

In [None]:
# ── Figure 3: Precision-Recall Curves ────────────────────────────────
fig, axes = plt.subplots(1, 3, figsize=(18, 5))

# One-vs-rest precision-recall curve per class, with the prevalence baseline.
for cls in range(NUM_CLASSES):
    ax = axes[cls]
    label = ACTIVITY_CLASS_MAP[cls]
    binary_true = (y_test == cls).astype(int)
    if not binary_true.any():
        ax.set_title(f"PR — {label} (no samples)")
    else:
        cls_scores = test_probs[:, cls]
        precision, recall, _ = precision_recall_curve(binary_true, cls_scores)
        ap = average_precision_score(binary_true, cls_scores)
        ax.plot(recall, precision, color=CLASS_COLORS[cls], lw=2, label=f"AP = {ap:.3f}")
        # Fraction of positives = precision of a random ranking.
        baseline = binary_true.mean()
        ax.axhline(baseline, color="gray", ls="--", lw=1, alpha=0.5,
                   label=f"Baseline = {baseline:.3f}")
        ax.set_xlabel("Recall")
        ax.set_ylabel("Precision")
        ax.set_title(f"PR — {label}")
        ax.legend(loc="upper right", fontsize=11)
        ax.set_xlim([-0.02, 1.02])
        ax.set_ylim([-0.02, 1.02])

fig.suptitle(f"Per-Class Precision-Recall Curves ({best_model_name})", fontsize=14, y=1.02)
fig.tight_layout()
fig.savefig(OUTPUT_DIR / "03_pr_curves.png", dpi=150, bbox_inches="tight")
plt.show()

In [None]:
# ── Figure 4: Confusion Matrix ───────────────────────────────────────
cm = confusion_matrix(y_test, test_preds, labels=list(range(NUM_CLASSES)))
# Row-normalised percentages. A class absent from y_test has a zero row total;
# leave its row at 0% instead of dividing by zero (which yields NaN).
row_totals = cm.sum(axis=1, keepdims=True)
cm_pct = np.divide(
    cm.astype(float), row_totals,
    out=np.zeros(cm.shape, dtype=float), where=row_totals > 0,
) * 100

class_names = [ACTIVITY_CLASS_MAP[c] for c in range(NUM_CLASSES)]
fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# Raw counts
sns.heatmap(cm, annot=True, fmt="d", cmap="Blues", ax=axes[0],
            xticklabels=class_names, yticklabels=class_names)
axes[0].set_xlabel("Predicted")
axes[0].set_ylabel("Actual")
axes[0].set_title("Confusion Matrix (counts)")

# Percentages
sns.heatmap(cm_pct, annot=True, fmt=".1f", cmap="Blues", ax=axes[1],
            xticklabels=class_names, yticklabels=class_names)
axes[1].set_xlabel("Predicted")
axes[1].set_ylabel("Actual")
axes[1].set_title("Confusion Matrix (% per row)")

fig.suptitle(f"Test Set Confusion Matrix ({best_model_name})", fontsize=14, y=1.02)
fig.tight_layout()
fig.savefig(OUTPUT_DIR / "04_confusion_matrix.png", dpi=150, bbox_inches="tight")
plt.show()

In [None]:
# ── Figure 5: Calibration Curves ─────────────────────────────────────
fig, axes = plt.subplots(1, 3, figsize=(18, 5))

# Reliability curve of the isotonic-calibrated probabilities, one class per panel.
for cls in range(NUM_CLASSES):
    ax = axes[cls]
    label = ACTIVITY_CLASS_MAP[cls]
    binary_true = (y_test == cls).astype(int)
    cls_cal_probs = cal_probs[:, cls]
    if not binary_true.any():
        ax.set_title(f"Calibration — {label} (no samples)")
    else:
        prob_true, prob_pred = calibration_curve(binary_true, cls_cal_probs, n_bins=10)
        ax.plot(prob_pred, prob_true, "o-", color=CLASS_COLORS[cls], lw=2, label=f"{label}")
        ax.plot([0, 1], [0, 1], "k--", lw=1, alpha=0.5, label="Perfectly calibrated")
        ax.set_xlabel("Mean Predicted Probability")
        ax.set_ylabel("Fraction of Positives")
        ax.set_title(f"Calibration — {label}")
        ax.legend(loc="upper left", fontsize=11)
        ax.set_xlim([-0.02, 1.02])
        ax.set_ylim([-0.02, 1.02])

fig.suptitle(f"Calibration Curves (isotonic, {best_model_name})", fontsize=14, y=1.02)
fig.tight_layout()
fig.savefig(OUTPUT_DIR / "05_calibration_curves.png", dpi=150, bbox_inches="tight")
plt.show()
print(f"ECE = {ece:.4f}, MCE = {mce:.4f}")

In [None]:
# ── Figure 6: Feature Importance (top 20) ────────────────────────────
# Tree models expose feature_importances_; other estimators are skipped.
fitted_model = best_model.named_steps["model"]
if hasattr(fitted_model, "feature_importances_"):
    importances = fitted_model.feature_importances_
    # Never request more bars than there are features; otherwise the tick
    # positions (range(top_k)) and the bar values would differ in length.
    top_k = min(20, len(importances))
    indices = np.argsort(importances)[-top_k:]  # ascending, so the largest is drawn on top
    top_features = [selected_columns[i] for i in indices]
    top_values = importances[indices]

    fig, ax = plt.subplots(figsize=(10, 7))
    ax.barh(range(top_k), top_values, color="#3498db", edgecolor="black")
    ax.set_yticks(range(top_k))
    ax.set_yticklabels(top_features)
    ax.set_xlabel("Importance")
    ax.set_title(f"Top {top_k} Feature Importances ({best_model_name})")
    fig.tight_layout()
    fig.savefig(OUTPUT_DIR / "06_feature_importance.png", dpi=150, bbox_inches="tight")
    plt.show()
else:
    print("Feature importances not available for this model type.")

---
## 7. Statistical Comparison & MCDA

In [None]:
# ── Paired t-tests with Bonferroni correction ────────────────────────
# NOTE(review): scores from repeated k-fold CV share training data across folds,
# so a plain paired t-test tends to overstate significance; read "Significant"
# as indicative only.
model_names = [row["model"] for row in cv_summary_sorted]
stat_rows = []
for i, model_a in enumerate(model_names):
    for model_b in model_names[i + 1:]:
        sa = np.array(fold_scores.get(model_a, []))
        sb = np.array(fold_scores.get(model_b, []))
        if len(sa) == 0 or len(sb) == 0:
            continue
        t_stat, p_val = stats.ttest_rel(sa, sb)
        # Cohen's d with pooled SD sqrt((s_a^2 + s_b^2) / 2). The SD of the
        # concatenated scores is not a pooled SD: it also absorbs the gap between
        # the two means, shrinking d exactly when the models differ most.
        pooled = np.sqrt((sa.var(ddof=1) + sb.var(ddof=1)) / 2.0)
        cohen_d = (sa.mean() - sb.mean()) / pooled if np.isfinite(pooled) and pooled > 0 else 0.0
        stat_rows.append({
            "Model A": model_a, "Model B": model_b,
            "t-stat": t_stat, "p-value": p_val, "Cohen's d": cohen_d,
        })

# Bonferroni: alpha divided by the number of comparisons. Defined even when
# there are no pairs, so the print below cannot raise NameError.
bonferroni = 0.05 / len(stat_rows) if stat_rows else 0.05
for r in stat_rows:
    r["Significant"] = "Yes" if r["p-value"] < bonferroni else "No"
    r["Bonferroni alpha"] = bonferroni

stat_df = pd.DataFrame(stat_rows)
print("Statistical Comparison (paired t-test on CV ROC-AUC folds):")
print(f"Bonferroni-corrected alpha = {bonferroni:.4f}")
if stat_df.empty:
    print("No model pairs with CV fold scores to compare.")
else:
    display(stat_df.style.format({
        "t-stat": "{:.4f}", "p-value": "{:.6f}",
        "Cohen's d": "{:.4f}", "Bonferroni alpha": "{:.4f}",
    }).set_caption("Pairwise Model Comparison"))

# ── MCDA Ranking ─────────────────────────────────────────────────────
mcda_rows = []
for row in cv_summary_sorted:
    name = row["model"]
    mcda_rows.append({
        "model": name,
        "roc_auc": row["roc_auc_mean"],
        "pr_auc": row["pr_auc_mean"],
        "calibration": max(0.0, 1 - calibration_metrics.get(name, ece)),
        "robustness": max(0.0, 1 - row["roc_auc_std"]),
        "efficiency": 1.0 / (1.0 + train_times.get(name, 1.0)),
        "interpretability": 1.0 if name in {"RandomForest", "LightGBM", "XGBoost"} else 0.5,
    })

weights = {
    "roc_auc": 0.25, "pr_auc": 0.20, "calibration": 0.20,
    "robustness": 0.15, "efficiency": 0.10, "interpretability": 0.10,
}
for metric in weights:
    vals = [r[metric] for r in mcda_rows]
    mn, mx = min(vals), max(vals)
    for r in mcda_rows:
        r[metric] = (r[metric] - mn) / (mx - mn) if mx > mn else 1.0
for r in mcda_rows:
    r["composite"] = sum(r[m] * w for m, w in weights.items())
mcda_rows = sorted(mcda_rows, key=lambda r: r["composite"], reverse=True)

print("\nMCDA Ranking:")
mcda_df = pd.DataFrame(mcda_rows)
display(mcda_df.style.format("{:.4f}", subset=mcda_df.columns[1:]).set_caption(
    "Multi-Criteria Decision Analysis"
))

---
## 8. Uncertainty Quantification

In [None]:
# ── Conformal Prediction ──────────────────────────────────────────────
def conformal_prediction(probs, y_true, alpha=0.05):
    """Split-conformal prediction sets from class probabilities.

    Nonconformity score = 1 - p(true class). The threshold ``q`` is the
    ceil((n + 1)(1 - alpha))-th smallest score (the finite-sample-valid
    quantile); a class enters a sample's set when its probability is >= 1 - q.

    Parameters
    ----------
    probs : array-like, shape (n_samples, n_classes)
        Predicted class probabilities; column j must correspond to label j.
    y_true : array-like of int, shape (n_samples,)
        True class labels used to calibrate q.
    alpha : float
        Target miscoverage (0.05 -> 95% target coverage).

    Returns
    -------
    (prediction_sets, coverage, q)
        Boolean (n_samples, n_classes) set membership, the fraction of these
        same samples whose true class is in their set, and the threshold q.
        When n is too small for the requested alpha, q = 1.0 and every set
        contains all classes.

    Raises
    ------
    ValueError
        If no samples are given.
    """
    probs = np.asarray(probs, dtype=float)
    y_true = np.asarray(y_true, dtype=int)
    n = len(y_true)
    if n == 0:
        raise ValueError("conformal_prediction needs at least one labelled sample")
    scores = 1.0 - probs[np.arange(n), y_true]
    # The small epsilon stops float error from pushing an exact integer
    # (n + 1)(1 - alpha) up to the next rank.
    rank = max(1, int(np.ceil((n + 1) * (1 - alpha) - 1e-9)))
    if rank > n:
        q = 1.0  # threshold 1 - q = 0 keeps every class
    else:
        q = float(np.sort(scores)[rank - 1])
    prediction_sets = probs >= (1.0 - q)
    coverage = prediction_sets[np.arange(n), y_true].mean()
    return prediction_sets, coverage, q

# NOTE(review): the conformal threshold is fitted on the test labels and coverage
# is then measured on those same labels, so the printed coverage is in-sample
# (it meets the 1 - alpha target by construction). An honest estimate needs a
# separate calibration split.
pred_sets, coverage, q_threshold = conformal_prediction(cal_probs, y_test)
set_sizes = pred_sets.sum(axis=1)

# ── Applicability Domain (k-NN distance) ─────────────────────────────
nn = NearestNeighbors(n_neighbors=5)
nn.fit(X_train)
# kneighbors() without a query returns each training point's neighbours
# excluding itself. Querying with X_train would count every point as its own
# nearest neighbour (distance 0), shrinking training distances and the AD
# threshold, which inflates the out-of-domain rate for new compounds.
train_dists = nn.kneighbors()[0].mean(axis=1)
ad_threshold = np.percentile(train_dists, 95)
test_dists = nn.kneighbors(X_test)[0].mean(axis=1)
ood_rate = (test_dists > ad_threshold).mean()

print(f"Conformal Prediction (alpha=0.05):")
print(f"  Coverage          : {coverage:.4f}  (target: 0.95)")
print(f"  Avg set size      : {set_sizes.mean():.2f}")
print(f"  Quantile threshold: {q_threshold:.4f}")
print(f"\nApplicability Domain:")
print(f"  AD threshold (95th pct): {ad_threshold:.4f}")
print(f"  Out-of-domain rate     : {ood_rate:.2%}")

# ── Figure 7: Uncertainty plots ──────────────────────────────────────
fig, axes = plt.subplots(1, 3, figsize=(18, 5))

# Set size distribution
unique_sizes, counts = np.unique(set_sizes, return_counts=True)
axes[0].bar(unique_sizes.astype(str), counts, color="#9b59b6", edgecolor="black")
for s, c in zip(unique_sizes, counts):
    axes[0].text(str(s), c + 2, str(c), ha="center", fontweight="bold")
axes[0].set_xlabel("Prediction Set Size")
axes[0].set_ylabel("Count")
axes[0].set_title(f"Conformal Set Sizes (coverage={coverage:.2%})")

# Distance to training set
axes[1].hist(test_dists, bins=30, color="#1abc9c", edgecolor="black", alpha=0.8,
             label="Test compounds")
axes[1].axvline(ad_threshold, color="red", ls="--", lw=2,
                label=f"AD threshold ({ad_threshold:.2f})")
axes[1].set_xlabel("Mean k-NN Distance")
axes[1].set_ylabel("Count")
axes[1].set_title(f"Applicability Domain (OOD={ood_rate:.1%})")
axes[1].legend()

# Confidence vs correctness (uncalibrated test probabilities)
max_probs = test_probs.max(axis=1)
correct = (test_preds == y_test)
bins_edge = np.linspace(0, 1, 11)
n_rel_bins = len(bins_edge) - 1
# Half-open [lo, hi) bins, except that max_prob == 1.0 is folded into the last
# bin instead of falling outside every bin.
rel_bin_ids = np.minimum(np.digitize(max_probs, bins_edge) - 1, n_rel_bins - 1)
bin_accs, bin_confs = [], []
for b in range(n_rel_bins):
    mask = rel_bin_ids == b
    if mask.sum() > 0:
        bin_accs.append(correct[mask].mean())
        bin_confs.append(max_probs[mask].mean())
axes[2].plot(bin_confs, bin_accs, "o-", color="#e67e22", lw=2, label="Model")
axes[2].plot([0, 1], [0, 1], "k--", lw=1, alpha=0.5, label="Perfect")
axes[2].set_xlabel("Mean Confidence")
axes[2].set_ylabel("Accuracy")
axes[2].set_title("Reliability Diagram")
axes[2].legend()
axes[2].set_xlim([-0.02, 1.02])
axes[2].set_ylim([-0.02, 1.02])

fig.tight_layout()
fig.savefig(OUTPUT_DIR / "07_uncertainty.png", dpi=150, bbox_inches="tight")
plt.show()

In [None]:
# ── Save model artifacts for reuse ────────────────────────────────────
# Everything predict_compounds needs to score new molecules, in one pickle.
# Pickle files can execute code on load — only load artifacts you produced.
model_artifacts = dict(
    best_model=best_model,
    calibrated_model=calibrated,
    selected_columns=selected_columns,
    activity_class_map=ACTIVITY_CLASS_MAP,
    num_classes=NUM_CLASSES,
    ad_threshold=ad_threshold,
    nn_model=nn,
    conformal_q=q_threshold,
    best_model_name=best_model_name,
)
model_path = MODEL_DIR / "safety_model.pkl"
with model_path.open("wb") as fh:
    pickle.dump(model_artifacts, fh)
print(f"Model saved to: {model_path}")
print(f"  Model type       : {best_model_name}")
print(f"  Feature columns  : {len(selected_columns)}")
print(f"  AD threshold     : {ad_threshold:.4f}")
print(f"  Conformal q      : {q_threshold:.4f}")

# ── Save summary JSON ─────────────────────────────────────────────────
class_counts_dict = {
    ACTIVITY_CLASS_MAP[c]: int(np.sum(labels == c)) for c in range(NUM_CLASSES)
}
test_metrics = {
    "roc_auc_macro": test_roc,
    "pr_auc_macro": test_pr,
    "mcc": test_mcc,
    "ece": ece,
    "mce": mce,
}
summary = {
    "n_compounds": len(data),
    "targets": sorted(set(targets_all)),
    "class_distribution": class_counts_dict,
    "train_size": len(X_train),
    "val_size": len(X_val),
    "test_size": len(X_test),
    "best_model": best_model_name,
    "test_metrics": test_metrics,
    "conformal_coverage": coverage,
    "avg_prediction_set_size": float(set_sizes.mean()),
    "out_of_domain_rate": ood_rate,
}
summary_path = OUTPUT_DIR / "workflow_summary.json"
with summary_path.open("w") as fh:
    json.dump(summary, fh, indent=2)
print(f"\nWorkflow summary saved to: {summary_path}")

---
## 9. Predict New Compounds

Provide a CSV file with these columns:

| Column | Description |
|--------|-------------|
| `compound_id` | Your identifier for the compound |
| `smiles` | SMILES string |
| `target` | One of the trained targets (e.g. `hERG`, `CYP3A4`) |

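An example file is provided at `data/example_predictions.csv`. Point `PREDICTION_CSV` in the next cell at the file you want to score (it is currently set to `data/validation_compounds.csv`).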

The output will include:
- `prob_inactive`, `prob_less_potent`, `prob_potent` — predicted class probabilities
- `predicted_class` / `predicted_label` — most likely class
- `conformal_set` — set of plausible classes at 95% confidence
- `in_domain` — whether the compound falls within the model's applicability domain
- `max_confidence` — highest class probability (a rough quality indicator)

In [None]:
# ── SET YOUR INPUT FILE HERE ─────────────────────────────────────────
# CSV with compound_id, smiles and target columns (see the table above).
PREDICTION_CSV = NOTEBOOK_DIR.joinpath("data", "validation_compounds.csv")
# ─────────────────────────────────────────────────────────────────────

def _is_valid_smiles(smi) -> bool:
    """Return True if *smi* is a string that RDKit parses into a non-empty molecule."""
    # Non-string cells (e.g. NaN from an empty CSV field) would make RDKit raise
    # a TypeError, and an empty string parses to a zero-atom molecule; treat
    # both as invalid rather than crashing or featurizing nothing.
    if not isinstance(smi, str):
        return False
    mol = Chem.MolFromSmiles(smi)
    return mol is not None and mol.GetNumAtoms() > 0


def _format_conformal_set(member_mask, class_map, n_cls) -> str:
    """Render one boolean class-membership row as e.g. ``{inactive, potent}``."""
    return "{" + ", ".join(class_map[c] for c in range(n_cls) if member_mask[c]) + "}"


def predict_compounds(csv_path: Path, model_path: Path) -> pd.DataFrame:
    """Load a trained model and predict on new compounds from a CSV file.

    Parameters
    ----------
    csv_path : Path
        CSV with columns: compound_id, smiles, target. Any additional columns
        are copied through to the output unchanged.
    model_path : Path
        Path to the saved safety_model.pkl

    Returns
    -------
    pd.DataFrame
        One row per input row, in input order, with class probabilities,
        predicted class/label, conformal sets, and applicability-domain flags.
        Rows whose SMILES cannot be parsed get NaN probabilities, a missing
        ``predicted_class``, label ``"unknown"``, no conformal set,
        ``in_domain=False`` and NaN ``knn_distance``.

    Raises
    ------
    ValueError
        If the CSV lacks a required column, or if ``build_feature_matrix``
        returns a different number of rows than it was given.
    """
    # NOTE(security): pickle.load can execute arbitrary code embedded in the
    # file — only load model files you created yourself or otherwise trust.
    with open(model_path, "rb") as fh:
        arts = pickle.load(fh)

    model = arts["best_model"]
    sel_cols = arts["selected_columns"]
    class_map = arts["activity_class_map"]
    n_cls = arts["num_classes"]
    ad_thresh = arts["ad_threshold"]
    nn_model = arts["nn_model"]
    conf_q = arts["conformal_q"]

    # Read user CSV
    user_df = pd.read_csv(csv_path)
    required = {"compound_id", "smiles", "target"}
    missing = required - set(user_df.columns)
    if missing:
        raise ValueError(f"CSV is missing required columns: {missing}")

    # Validate SMILES up front so invalid rows can be kept out of featurization.
    user_df["valid_smiles"] = [_is_valid_smiles(smi) for smi in user_df["smiles"]]
    valid = user_df["valid_smiles"].to_numpy(dtype=bool)
    n_invalid = int((~valid).sum())
    if n_invalid > 0:
        print(f"WARNING: {n_invalid} compounds have invalid SMILES and will get NaN predictions.")

    # Preserve extra columns for downstream use
    extra_cols = [c for c in user_df.columns
                  if c not in {"compound_id", "smiles", "target", "valid_smiles"}]

    # Full-length output arrays pre-filled with "no prediction" values; only
    # the valid rows are overwritten below, which keeps every output row
    # aligned with its input row.
    n_rows = len(user_df)
    probs = np.full((n_rows, n_cls), np.nan)
    dists = np.full(n_rows, np.nan)
    in_domain = np.zeros(n_rows, dtype=bool)
    conformal_sets = np.zeros((n_rows, n_cls), dtype=bool)
    # Nullable integer dtype so invalid rows can hold <NA> instead of a fake class.
    predicted_class = pd.Series(pd.NA, index=user_df.index, dtype="Int64")

    if valid.any():
        valid_df = user_df.loc[valid]
        # Build rows in the format expected by build_feature_matrix
        pred_rows = [
            {
                "canonical_smiles": smi,
                "target_common_name": tgt,
                "activity_class": -1,  # unknown
            }
            for smi, tgt in zip(valid_df["smiles"], valid_df["target"])
        ]
        X_pred, _, _ = build_feature_matrix(pred_rows, selected_columns=sel_cols)
        if len(X_pred) != len(pred_rows):
            # Results are written back by position, so a dropped row would
            # silently shift predictions onto the wrong compounds.
            raise ValueError(
                f"build_feature_matrix returned {len(X_pred)} rows for "
                f"{len(pred_rows)} valid compounds; cannot align predictions."
            )

        # Predict
        probs[valid] = model.predict_proba(X_pred)
        predicted_class.loc[valid] = np.asarray(model.predict(X_pred)).astype(int)

        # Applicability domain: mean distance to the nearest neighbours found
        # by the saved nn_model, compared with the saved threshold.
        dists[valid] = nn_model.kneighbors(X_pred)[0].mean(axis=1)
        in_domain[valid] = dists[valid] <= ad_thresh

        # Conformal prediction sets: class c is included when p_c >= 1 - q.
        conformal_sets[valid] = probs[valid] >= (1.0 - conf_q)

    # Build output DataFrame
    results = user_df[["compound_id", "smiles", "target"]].copy()
    # Copy any extra metadata columns from the input unchanged.
    for col in extra_cols:
        results[col] = user_df[col].values
    for c in range(n_cls):
        results[f"prob_{class_map[c]}"] = probs[:, c]
    results["predicted_class"] = predicted_class
    results["predicted_label"] = [
        "unknown" if pd.isna(p) else class_map.get(int(p), "unknown")
        for p in predicted_class
    ]
    # NaN for invalid rows (np.max propagates NaN).
    results["max_confidence"] = probs.max(axis=1)
    results["conformal_set"] = [
        _format_conformal_set(cs, class_map, n_cls) if ok else None
        for cs, ok in zip(conformal_sets, valid)
    ]
    results["in_domain"] = in_domain
    results["knn_distance"] = dists
    results["valid_smiles"] = user_df["valid_smiles"]

    return results


# ── Run predictions ───────────────────────────────────────────────────
if not PREDICTION_CSV.exists():
    # Nothing to score yet: tell the user where the input file is expected.
    print(f"No prediction file found at: {PREDICTION_CSV}")
    print("Create a CSV with columns: compound_id, smiles, target")
    print(f"and place it at {PREDICTION_CSV}")
else:
    results = predict_compounds(PREDICTION_CSV, model_path)

    # Persist the full table, then echo a short summary and a formatted view.
    out_path = OUTPUT_DIR / "predictions.csv"
    results.to_csv(out_path, index=False)
    print(f"Predictions saved to: {out_path}")
    print(f"Compounds predicted: {len(results)}")
    print(f"In-domain: {results['in_domain'].sum()}/{len(results)}")
    print()
    display(
        results.style
        .format({
            col: "{:.4f}"
            for col in (
                "prob_inactive",
                "prob_less_potent",
                "prob_potent",
                "max_confidence",
                "knn_distance",
            )
        })
        .set_caption("Predictions")
    )

In [None]:
# ── Figure 8: Prediction summary visualization ───────────────────────
if PREDICTION_CSV.exists():
    fig, axes = plt.subplots(1, 3, figsize=(18, 5))
    ax_cls, ax_conf, ax_dom = axes

    # Panel 1: number of compounds per predicted label, in a fixed order
    # (labels absent from the predictions are shown as 0).
    label_order = ["inactive", "less_potent", "potent"]
    pred_counts = (
        results["predicted_label"]
        .value_counts()
        .reindex(label_order, fill_value=0)
    )
    bars = ax_cls.bar(
        pred_counts.index,
        pred_counts.values,
        color=[CLASS_COLORS[i] for i in range(NUM_CLASSES)],
        edgecolor="black",
    )
    for bar, val in zip(bars, pred_counts.values):
        x_mid = bar.get_x() + bar.get_width() / 2
        ax_cls.text(x_mid, bar.get_height() + 0.2, str(val),
                    ha="center", fontweight="bold")
    ax_cls.set_title("Predicted Class Distribution")
    ax_cls.set_ylabel("Count")

    # Panel 2: histogram of the top-class probability per compound.
    ax_conf.hist(results["max_confidence"], bins=20, color="#3498db",
                 edgecolor="black", alpha=0.8)
    ax_conf.axvline(0.5, color="red", ls="--", lw=1.5, label="50% threshold")
    ax_conf.set_xlabel("Max Class Probability")
    ax_conf.set_ylabel("Count")
    ax_conf.set_title("Prediction Confidence")
    ax_conf.legend()

    # Panel 3: share of compounds flagged in / out of the applicability domain.
    domain_counts = results["in_domain"].value_counts()
    n_in = domain_counts.get(True, 0)
    n_out = domain_counts.get(False, 0)
    ax_dom.pie(
        [n_in, n_out],
        labels=["In Domain", "Out of Domain"],
        colors=["#2ecc71", "#e74c3c"],
        autopct="%1.1f%%",
        startangle=90,
        textprops={"fontsize": 12},
    )
    ax_dom.set_title("Applicability Domain")

    fig.suptitle("Prediction Summary", fontsize=14, y=1.02)
    fig.tight_layout()
    fig.savefig(OUTPUT_DIR / "08_prediction_summary.png", dpi=150, bbox_inches="tight")
    plt.show()

---
## 10. Load a Saved Model (Standalone Prediction)

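If you already trained and saved a model (Section 8), you can skip training entirely
and jump straight to predictions: set your CSV path in the cell below and run it.
The current session must still have run the setup cells (Section 1), the cell that
defines `build_feature_matrix`, and the Section 9 cell that defines `predict_compounds`.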

In [None]:
# ── Standalone prediction from saved model ────────────────────────────
# Point YOUR_CSV at your input file and run this cell. Training does not need
# to be re-run when a saved model exists, but this cell still relies on names
# defined in earlier cells: NOTEBOOK_DIR, MODEL_DIR and OUTPUT_DIR,
# `predict_compounds` (Section 9), and the `build_feature_matrix` helper it calls.

YOUR_CSV = NOTEBOOK_DIR / "data" / "example_predictions.csv"  # <-- change this
SAVED_MODEL = MODEL_DIR / "safety_model.pkl"

if not SAVED_MODEL.exists():
    print(f"No saved model found at {SAVED_MODEL}. Run training cells first.")
elif not YOUR_CSV.exists():
    print(f"No input CSV found at {YOUR_CSV}. Update the YOUR_CSV variable above.")
else:
    # Both inputs are present: score the file, save it, and show the table.
    standalone_results = predict_compounds(YOUR_CSV, SAVED_MODEL)
    standalone_results.to_csv(OUTPUT_DIR / "standalone_predictions.csv", index=False)
    print(f"Predictions saved to: {OUTPUT_DIR / 'standalone_predictions.csv'}")
    display(standalone_results)