# In [5]:  (Jupyter cell marker from notebook export)
# analyze.py — analysis and visualization utilities

import os
import numpy as np
import pandas as pd
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
from textwrap import fill

# Upper bounds for the drug-likeness rules, keyed by the short labels used on
# plots: Lipinski (MW, XLogP, HBA, HBD) and Veber (TPSA, RotB).
LIMITS = {
    "MW": 500,
    "XLogP": 5,
    "HBA": 10,
    "HBD": 5,
    "TPSA": 140,
    "RotB": 10,
}

# (DataFrame column, LIMITS key) pairs, in radar-chart spoke order.
DESCRIPTORS = [
    ("mass", "MW"),
    ("logP", "XLogP"),
    ("H_acceptors", "HBA"),
    ("H_donors", "HBD"),
    ("TPSA", "TPSA"),
    ("rotatable_bonds", "RotB"),
]


# ---------- Radar chart ----------

def _is_missing(value) -> bool:
    """Return True when *value* is None or a scalar NA marker (NaN, NaT, pd.NA)."""
    # pd.isna on a list-like returns an array, so only scalars are tested.
    return pd.api.types.is_scalar(value) and bool(pd.isna(value))


def _fmt_cell(x, nd=2, as_int=False):
    """Format *x* for the info table.

    Returns an em dash for missing/non-numeric input, otherwise the value as
    an int (truncated) when *as_int* is set, else with *nd* decimals.
    """
    x = pd.to_numeric(x, errors="coerce")
    if pd.isna(x):
        return "—"
    return f"{int(x)}" if as_int else f"{float(x):.{nd}f}"


def radar_for_row(row: pd.Series, outdir: str, idx: int):
    """Save a radar chart plus property table for one compound as a PNG.

    Each (column, limit) pair in DESCRIPTORS gets one spoke. Values are scaled
    so that radius 1.0 equals 1.2x the LIMITS value, and are clamped to
    [0, 1]. The gray polygon marks the limits themselves. The image is written
    to ``<outdir>/radar_<idx:03d>_cid<cid>.png``, using ``cidna`` when the row
    has no usable cid.

    Parameters
    ----------
    row : pd.Series
        One compound. Reads ``name``/``input`` (title), ``cid``, and the
        descriptor columns listed in DESCRIPTORS. Missing or non-numeric
        descriptors are drawn at 0 and shown as "—" in the table.
    outdir : str
        Directory to write into (not created here).
    idx : int
        Sequence number embedded in the filename.
    """
    values, limit_norm = [], []
    for col, label in DESCRIPTORS:
        limit_val = LIMITS[label]
        vmax = limit_val * 1.2
        v = pd.to_numeric(row.get(col), errors="coerce")
        # Clamp into [0, vmax] so each spoke stays on the fixed 0–1 radial
        # scale set below; negative inputs (e.g. logP < 0) would otherwise
        # fall below the r=0 limit.
        values.append(0.0 if pd.isna(v) else min(max(float(v), 0.0), vmax) / vmax)
        limit_norm.append(limit_val / vmax)

    # Repeat the first point so both polygons close.
    angles = np.linspace(0, 2 * np.pi, len(DESCRIPTORS), endpoint=False)
    angles = np.r_[angles, angles[0]]
    values = np.r_[values, values[0]]
    thr = np.r_[limit_norm, limit_norm[0]]

    # NaN is truthy, so a plain `or` would not fall through on missing data;
    # test for missingness explicitly.
    cid = row.get("cid")
    cid_label = "na" if _is_missing(cid) or not cid else cid
    name = next(
        (v for v in (row.get("name"), row.get("input")) if not _is_missing(v) and v),
        "compound",
    )
    title = f"{name} (cid={cid_label})"

    fig = plt.figure(figsize=(7, 4))
    try:
        ax_radar = fig.add_subplot(1, 2, 1, projection="polar")
        ax_info = fig.add_subplot(1, 2, 2)
        ax_info.axis("off")

        ax_radar.plot(angles, thr, color="gray")
        ax_radar.fill(angles, thr, color="gray", alpha=0.15)
        ax_radar.plot(angles, values, color="C0")
        ax_radar.fill(angles, values, color="C0", alpha=0.1)
        ax_radar.set_xticks(angles[:-1])
        ax_radar.set_xticklabels([lbl for _, lbl in DESCRIPTORS])
        ax_radar.set_ylim(0, 1.0)
        ax_radar.set_yticks(np.linspace(0.2, 1.0, 5))
        ax_radar.set_yticklabels([])

        # --- Info table ---
        ax_info.set_title(fill(title, width=36), loc="left", pad=8, fontsize=8)

        table_rows = [
            ("Molecular weight", _fmt_cell(row.get("mass"), nd=2)),
            ("LogP", _fmt_cell(row.get("logP"), nd=2)),
            ("H_acceptors", _fmt_cell(row.get("H_acceptors"), as_int=True)),
            ("H_donors", _fmt_cell(row.get("H_donors"), as_int=True)),
            ("TPSA", _fmt_cell(row.get("TPSA"), nd=2)),
            ("Rotatable bonds", _fmt_cell(row.get("rotatable_bonds"), as_int=True)),
        ]

        tbl = ax_info.table(
            cellText=table_rows,
            colLabels=["Property", "Value"],
            loc="upper left",
            colWidths=[0.64, 0.36],
            edges="closed",
        )
        tbl.auto_set_font_size(False)
        tbl.set_fontsize(9)
        tbl.scale(1.0, 1.15)
        for (r, c), cell in tbl.get_celld().items():
            if r == 0:
                # Header row.
                cell.set_text_props(fontweight="bold", ha="center")
            else:
                cell.set_text_props(ha=("left" if c == 0 else "right"))

        fname = f"radar_{idx:03d}_cid{cid_label}.png"
        fig.savefig(os.path.join(outdir, fname), dpi=120, bbox_inches="tight")
    finally:
        # Release the figure even if drawing or saving fails.
        plt.close(fig)


# ---------- Scatter plot with threshold line ---------- 

def strip_scatter_static(df: pd.DataFrame, col: str, threshold: float,
                         out_path: str, xlabel: str, *, seed=0):
    """Save a 1-D jittered strip plot of ``df[col]`` against a threshold.

    Points at or below *threshold* use the default color and points above it
    are red. A vertical line marks the threshold, and the region to its right
    is shaded. Each point is labeled with its ``cid``, or ``"NA"`` when the
    column is absent or the value is missing. Nothing is written when *col*
    has no numeric values.

    Parameters
    ----------
    df : pd.DataFrame
        Source data; ``df[col]`` is coerced to numeric.
    col : str
        Column to plot.
    threshold : float
        Cutoff. It is always kept inside the visible x-range.
    out_path : str
        Output path passed to ``Figure.savefig``.
    xlabel : str
        X-axis label.
    seed : int or None, keyword-only
        Seed for the vertical jitter. The default makes repeated runs produce
        identical images. Pass None for fresh jitter on each call.
    """
    s = pd.to_numeric(df[col], errors="coerce")
    mask = s.notna()
    if not mask.any():
        return

    thr = float(threshold)
    x = s[mask].astype(float).to_numpy()
    # The jitter only separates overlapping points. A private Generator keeps
    # it reproducible and leaves NumPy's global RNG state untouched.
    y = (np.random.default_rng(seed).random(len(x)) - 0.5) * 0.15
    if "cid" in df.columns:
        ids = ["NA" if pd.isna(v) else str(v) for v in df.loc[mask, "cid"]]
    else:
        ids = ["NA"] * len(x)
    pass_mask = x <= thr

    fig, ax = plt.subplots()
    try:
        # Leave room on the right for the legend placed outside the axes.
        fig.subplots_adjust(right=0.80)

        ax.scatter(x[pass_mask], y[pass_mask], alpha=0.85, label=f"≤ {threshold}")
        ax.scatter(x[~pass_mask], y[~pass_mask], alpha=0.95, color="red", label=f"> {threshold}")

        for xi, yi, cid in zip(x, y, ids):
            ax.text(xi, yi, cid, fontsize=8, ha="left", va="bottom")

        # Include the threshold on both sides so its line stays visible even
        # when every point lies on one side of it.
        left = min(float(np.min(x)), thr)
        right = max(float(np.max(x)), thr)
        span = right - left
        pad = 0.05 * (span if span > 0 else abs(right) + 1)
        ax.set_xlim(left - pad, right + pad)
        ax.axvline(thr)
        ax.axvspan(thr, ax.get_xlim()[1], alpha=0.12, color="red", zorder=0)

        ax.set_yticks([])
        ax.set_xlabel(xlabel)
        ax.legend(loc="center left", bbox_to_anchor=(1.02, 0.5), borderaxespad=0.0)

        fig.savefig(out_path)
    finally:
        # Release the figure even if drawing or saving fails.
        plt.close(fig)


# ---------- Add rule columns and generate plots ----------  

def add_rules_and_plots(df: pd.DataFrame,
                        plot_dir: str,
                        radar_dir: str) -> pd.DataFrame:
    """Annotate *df* with Lipinski/Veber rule columns and write diagnostic plots.

    Parameters
    ----------
    df : pd.DataFrame
        One row per compound.
        - Required: ``mass``, ``logP``, ``H_donors``, ``H_acceptors``,
          ``TPSA``, ``rotatable_bonds`` and ``status``.
        - Optional: ``exact_mass``, ``heavy_atoms`` and ``charge`` are
          coerced to numeric only when present.
    plot_dir : str
        Directory for the per-property strip plots (created if needed).
    radar_dir : str
        Directory for the radar charts (created if needed). One chart is
        written per row whose ``status == "ok"``, numbered by position among
        those rows.

    Returns
    -------
    pd.DataFrame
        A copy of *df* with numeric columns coerced and three columns added:
        - ``lipinski_violations``: "; "-joined list, or "ok".
        - ``lipinski_pass``: all four inputs present and no violation.
        - ``veber_pass``: both inputs present and within limits.
        An empty *df* is returned as-is, without writing any plots.
    """
    if df.empty:
        return df

    ok = df.copy()

    # Coerce to numeric. Rule inputs must exist (KeyError otherwise); the
    # informational columns are converted only when present.
    required = ["mass", "logP", "TPSA", "H_donors", "H_acceptors", "rotatable_bonds"]
    optional = ["exact_mass", "heavy_atoms", "charge"]
    for c in required + [c for c in optional if c in ok.columns]:
        ok[c] = pd.to_numeric(ok[c], errors="coerce")

    # --- Lipinski rule ---
    # (column, LIMITS key, message prefix), in the order messages are reported.
    lipinski_checks = [
        ("mass", "MW", "MW"),
        ("logP", "XLogP", "LogP"),
        ("H_donors", "HBD", "HBD"),
        ("H_acceptors", "HBA", "HBA"),
    ]
    # Derive messages from LIMITS so they cannot drift from the thresholds.
    messages = [f"{tag}>{LIMITS[key]}" for _, key, tag in lipinski_checks]
    # A missing value is never a violation; it fails the completeness check
    # in lipinski_pass instead.
    hits = np.column_stack([
        (ok[col].notna() & (ok[col] > LIMITS[key])).to_numpy(dtype=bool)
        for col, key, _ in lipinski_checks
    ])
    ok["lipinski_violations"] = [
        "; ".join(m for m, h in zip(messages, row_hits) if h) or "ok"
        for row_hits in hits
    ]
    ok["lipinski_pass"] = (
        ok[["mass", "logP", "H_donors", "H_acceptors"]].notna().all(axis=1)
        & (ok["lipinski_violations"] == "ok")
    )

    # --- Veber rule ---
    ok["veber_pass"] = (
        ok[["TPSA", "rotatable_bonds"]].notna().all(axis=1)
        & (ok["TPSA"] <= LIMITS["TPSA"])
        & (ok["rotatable_bonds"] <= LIMITS["RotB"])
    )

    # --- Scatter plots for each property ---
    pairs = [
        ("mass", LIMITS["MW"], "MW (Lipinski)"),
        ("logP", LIMITS["XLogP"], "XLogP (Lipinski)"),
        ("H_acceptors", LIMITS["HBA"], "HBA (Lipinski)"),
        ("H_donors", LIMITS["HBD"], "HBD (Lipinski)"),
        ("TPSA", LIMITS["TPSA"], "TPSA (Veber)"),
        ("rotatable_bonds", LIMITS["RotB"], "Rotatable bonds (Veber)"),
    ]

    os.makedirs(plot_dir, exist_ok=True)
    os.makedirs(radar_dir, exist_ok=True)

    for col, thr, label in pairs:
        fname = os.path.join(plot_dir, f"{label.replace(' ', '')}.png")
        strip_scatter_static(ok, col, thr, fname, label)

    # --- Radar charts for all "ok" records ---
    for i, row in ok[ok["status"] == "ok"].reset_index(drop=True).iterrows():
        radar_for_row(row, radar_dir, i)

    return ok