In [1]:
import pandas

In [4]:
# Jupyter-friendly retrosynthesis tree visualizer & summarizer
# 依赖: pandas matplotlib networkx
# pip install pandas matplotlib networkx

import json
from pathlib import Path
import pandas as pd
import matplotlib.pyplot as plt
import networkx as nx
from typing import Dict, Any, List, Tuple
import textwrap
import argparse
import sys

# ---------- Configuration (edit the paths here if needed) ----------
# Target name -> path of the JSON file holding that target's routes.
DEFAULT_FILES = dict(
    Celecoxib="Celecoxib.json",
    Ibuprofen="Ibuprofen.json",
)
# Directory that receives the route PNGs and route_summaries.csv.
DEFAULT_OUTDIR = "./retrosyn_outputs"
# Upper bound on how many routes are plotted per target.
DEFAULT_MAXPLOTS = 3

# ---------- 工具函数 ----------
def load_routes(path: str) -> List[Dict[str, Any]]:
    """Load retrosynthesis routes from the JSON file at *path*.

    Three top-level layouts are accepted:
      * a list of route objects;
      * an object whose ``"routes"`` value is a list of route objects;
      * a single route object (any other object).

    Only entries that are JSON objects containing both a ``"nodes"`` and an
    ``"edges"`` key are returned; every other entry is dropped.

    Raises:
        OSError: if the file cannot be opened.
        ValueError: if the top-level JSON value is neither an object nor a
            list (``json.JSONDecodeError``, a ``ValueError`` subclass, is
            raised for malformed JSON).
    """
    with open(path, "r", encoding="utf-8") as f:
        data = json.load(f)
    if isinstance(data, dict):
        routes = data["routes"] if isinstance(data.get("routes"), list) else [data]
    elif isinstance(data, list):
        routes = data
    else:
        raise ValueError(f"Unexpected JSON structure in {path}")
    # Require dict entries: a bare string would otherwise pass the substring
    # test `"nodes" in r`, and a number would raise TypeError on it.
    return [r for r in routes if isinstance(r, dict) and "nodes" in r and "edges" in r]

def build_nx_graph(route: Dict[str, Any]) -> nx.DiGraph:
    """Build a directed graph from a route's ``nodes`` and ``edges`` lists.

    Each node dict is added under its ``"id"``, with every key of the dict
    (including ``"id"``) stored as node attributes. Edge endpoints are read
    from ``"from"``/``"to"``, falling back to ``"source"``/``"target"``.
    Edges missing either endpoint are skipped. The remaining edge keys become
    edge attributes.

    Raises:
        KeyError: if a node dict has no ``"id"``.
    """
    G = nx.DiGraph()
    for n in route.get("nodes", []):
        G.add_node(n["id"], **n)
    for e in route.get("edges", []):
        # Explicit None checks: `a or b` would discard a falsy id such as 0.
        frm = e.get("from")
        if frm is None:
            frm = e.get("source")
        to = e.get("to")
        if to is None:
            to = e.get("target")
        if frm is None or to is None:
            continue
        attrs = {k: v for k, v in e.items() if k not in ("from", "to", "source", "target")}
        G.add_edge(frm, to, **attrs)
    return G

def graph_metrics(route: Dict[str, Any]) -> Dict[str, Any]:
    """Extract the per-route summary metrics stored under ``route["graph"]``.

    Returns a dict with a fixed set of keys (in a fixed order). Each value is
    taken from the route's ``"graph"`` object, or is ``None`` when absent.
    A missing ``"graph"`` and an explicit ``"graph": null`` are both treated
    as an empty object.
    """
    # `or {}` also covers an explicit JSON null, which `.get("graph", {})`
    # would return as None and then crash on `.get`.
    g = route.get("graph") or {}
    keys = (
        "depth",
        "precursor_cost",
        "num_reactions",
        "atom_economy",
        "avg_score",
        "avg_plausibility",
        "min_score",
        "min_plausibility",
        "first_step_score",
        "first_step_plausibility",
        "cluster_id",
    )
    return {k: g.get(k) for k in keys}

def annotate_smiles(smiles: str, max_len: int = 28) -> str:
    """Return a compact single-line label for *smiles*, at most *max_len* chars.

    Runs of whitespace (including newlines) collapse to single spaces.
    Strings longer than *max_len* are cut and suffixed with "…", so the
    result is exactly *max_len* characters long. Empty or ``None`` input
    yields "".

    ``textwrap.shorten`` is deliberately not used. It only breaks at spaces,
    and a SMILES string has none, so any SMILES longer than the width would
    be reduced to the bare placeholder "…".

    Raises:
        ValueError: if *max_len* < 1 and *smiles* is non-empty.
    """
    if not smiles:
        return ""
    if max_len < 1:
        raise ValueError(f"max_len must be >= 1, got {max_len!r}")
    s = " ".join(smiles.split())
    if len(s) <= max_len:
        return s
    # Reserve one character for the ellipsis.
    return s[: max_len - 1] + "…"

def find_root(route: Dict[str, Any]) -> str:
    """Return the id of the route's root (target) node.

    Resolution order:
      1. the all-zero UUID, if some node carries it (presumably the id the
         upstream planner gives the target; NOTE(review): confirm);
      2. the first ``"chemical"`` node, in ``route["nodes"]`` order, that is
         not the destination of any edge;
      3. the first node in ``route["nodes"]``;
      4. with no node records at all: the first edge source that is never an
         edge destination, else the first edge source.

    Edge endpoints are read from ``"from"``/``"to"``, falling back to
    ``"source"``/``"target"``.

    Raises:
        ValueError: if the route has neither nodes nor usable edges.
    """
    zero_id = "00000000-0000-0000-0000-000000000000"
    nodes = route.get("nodes", [])
    if any(n["id"] == zero_id for n in nodes):
        return zero_id

    # Explicit None checks: `a or b` would discard a falsy id such as 0.
    endpoints = []
    for e in route.get("edges", []):
        frm = e.get("from")
        if frm is None:
            frm = e.get("source")
        to = e.get("to")
        if to is None:
            to = e.get("target")
        if frm is not None and to is not None:
            endpoints.append((frm, to))

    # In-degree 0 <=> never an edge destination. This also covers isolated
    # nodes, e.g. a route that is a single purchasable chemical.
    destinations = {to for _, to in endpoints}
    for n in nodes:
        if n.get("type") == "chemical" and n["id"] not in destinations:
            return n["id"]
    if nodes:
        return nodes[0]["id"]
    for frm, _ in endpoints:
        if frm not in destinations:
            return frm
    if endpoints:
        return endpoints[0][0]
    raise ValueError("route has no nodes or edges to choose a root from")

def layered_layout(
    G: nx.DiGraph,
    root_id: str,
    x_gap: float = 2.2,
    y_gap: float = 1.4,
) -> Dict[Any, Tuple[float, float]]:
    """Compute (x, y) positions that lay *G* out in columns by distance from the root.

    A node's column is its hop distance from *root_id* along edge
    direction, and its x coordinate is ``column * x_gap``. Nodes unreachable
    from *root_id* are placed in column 0 with the root, and so is every
    node when *root_id* is not in *G*. Within a column, nodes are stacked
    downward from y = 0 in steps of *y_gap*. Stacking follows topological
    order, or *G*'s node order if *G* has a cycle.
    """
    try:
        order = list(nx.topological_sort(G))
    except nx.NetworkXUnfeasible:
        order = list(G.nodes())

    # One BFS from the root yields every reachable node's distance, instead
    # of a separate shortest-path query per node.
    if root_id in G:
        dist_from_root = nx.single_source_shortest_path_length(G, root_id)
    else:
        dist_from_root = {}

    layers: Dict[int, List[Any]] = {}
    for node in order:
        layers.setdefault(dist_from_root.get(node, 0), []).append(node)

    pos = {}
    for layer_idx, nodes in sorted(layers.items()):
        for i, n in enumerate(nodes):
            pos[n] = (layer_idx * x_gap, -i * y_gap)
    return pos

def draw_route(route: Dict[str, Any], title: str, outpath: Path) -> None:
    """Render one route as a layered graph and save it as an image at *outpath*.

    Chemical nodes are drawn as circles labeled with a shortened SMILES (via
    ``annotate_smiles``). Reaction nodes are drawn as squares. Every
    non-chemical node is labeled "Rxn". Nodes whose ``type`` is neither
    "chemical" nor "reaction" get a label but no marker.

    If the route's graph is empty, nothing is drawn and no file is written.
    The parent directory of *outpath* is created if needed. *outpath* may be
    a ``str`` or a ``Path``.
    """
    outpath = Path(outpath)
    G = build_nx_graph(route)
    if len(G) == 0:
        return
    root = find_root(route)
    pos = layered_layout(G, root)

    chem_nodes = [n for n, d in G.nodes(data=True) if d.get("type") == "chemical"]
    rxn_nodes = [n for n, d in G.nodes(data=True) if d.get("type") == "reaction"]
    labels = {
        n: ("Rxn" if G.nodes[n].get("type") != "chemical"
            else annotate_smiles(G.nodes[n].get("smiles", "")))
        for n in G.nodes()
    }

    fig, ax = plt.subplots(figsize=(10, 6))
    try:
        nx.draw_networkx_edges(G, pos, ax=ax, arrows=True, arrowstyle='-|>', arrowsize=12)
        nx.draw_networkx_nodes(G, pos, ax=ax, nodelist=chem_nodes, node_shape='o', node_size=400)
        nx.draw_networkx_nodes(G, pos, ax=ax, nodelist=rxn_nodes, node_shape='s', node_size=400)
        nx.draw_networkx_labels(G, pos, ax=ax, labels=labels, font_size=8)

        ax.set_title(title)
        ax.axis('off')
        fig.tight_layout()
        outpath.parent.mkdir(parents=True, exist_ok=True)
        fig.savefig(outpath, dpi=200, bbox_inches="tight")
    finally:
        # Close even on failure so plotting many routes never piles up open figures.
        plt.close(fig)

def summarize_routes(name: str, routes: List[Dict[str, Any]]) -> pd.DataFrame:
    """Tabulate one summary row per route for the target *name*.

    Columns: target, route (1-based position in *routes*), depth,
    num_reactions, chem_nodes / rxn_nodes (counts of nodes typed "chemical" /
    "reaction"), precursor_cost, atom_economy, avg_score, avg_plausibility,
    first_step_score, first_step_plausibility. Metric values come from
    ``graph_metrics``.

    Non-empty tables are sorted by target, depth, precursor_cost, then route,
    with missing values last, and re-indexed from 0.
    """
    def _count(route: Dict[str, Any], kind: str) -> int:
        return sum(node.get("type") == kind for node in route.get("nodes", []))

    tail_metrics = (
        "precursor_cost",
        "atom_economy",
        "avg_score",
        "avg_plausibility",
        "first_step_score",
        "first_step_plausibility",
    )
    records = []
    for position, route in enumerate(routes, start=1):
        metrics = graph_metrics(route)
        records.append({
            "target": name,
            "route": position,
            "depth": metrics.get("depth"),
            "num_reactions": metrics.get("num_reactions"),
            "chem_nodes": _count(route, "chemical"),
            "rxn_nodes": _count(route, "reaction"),
            **{key: metrics.get(key) for key in tail_metrics},
        })

    table = pd.DataFrame(records)
    if table.empty:
        return table
    sort_cols = ["target", "depth", "precursor_cost", "route"]
    return table.sort_values(sort_cols, na_position="last").reset_index(drop=True)

def process_targets(files: Dict[str, str], out_dir: Path, max_routes_to_plot: int = 3) -> pd.DataFrame:
    """Load, summarize, and plot the routes of every target in *files*.

    Args:
        files: mapping of target name -> path to that target's route JSON.
        out_dir: directory for the route PNGs (created if missing). A
            ``str`` is accepted as well as a ``Path``.
        max_routes_to_plot: plot at most this many routes per target, taken
            in file order. Values <= 0 plot nothing.

    Returns:
        All per-target summaries concatenated, or an empty DataFrame if none
        were produced.

    Missing, unreadable, or malformed files, and files with no parsable
    routes, are reported on stderr and skipped. One bad input does not abort
    the remaining targets.
    """
    out_dir = Path(out_dir)
    out_dir.mkdir(parents=True, exist_ok=True)
    all_summaries = []
    for name, path in files.items():
        p = Path(path)
        if not p.exists():
            print(f"[WARN] File not found: {p}", file=sys.stderr)
            continue
        try:
            routes = load_routes(str(p))
        except (OSError, ValueError) as exc:  # JSONDecodeError is a ValueError
            print(f"[WARN] Could not read routes from {p}: {exc}", file=sys.stderr)
            continue
        if not routes:
            print(f"[WARN] No routes parsed: {p}", file=sys.stderr)
            continue

        all_summaries.append(summarize_routes(name, routes))

        # Clamp at 0: a negative slice bound would count from the end.
        for i, route in enumerate(routes[:max(max_routes_to_plot, 0)], start=1):
            img_path = out_dir / f"{name}_route_{i}.png"
            # draw_route silently writes nothing for an empty graph, so only
            # report success when there is something to draw.
            if len(build_nx_graph(route)) == 0:
                print(f"[WARN] Route {i} of {name} is empty; nothing plotted", file=sys.stderr)
                continue
            draw_route(route, f"{name} — Route {i}", img_path)
            print(f"[OK] Saved {img_path}")

    if all_summaries:
        return pd.concat(all_summaries, ignore_index=True)
    return pd.DataFrame()

# ---------- 两种运行方式 ----------
def cli_entry(argv=None):
    """Command-line entry point, also safe to run inside a Jupyter notebook.

    Parses the route-file paths, output directory and plot limit, runs
    ``process_targets``, and writes ``route_summaries.csv`` into the output
    directory. It then shows the first 20 summary rows: through IPython's
    ``display`` when IPython is importable, otherwise as plain text.
    Unknown arguments, such as the ``--f=...`` a notebook kernel injects, are
    ignored.

    Args:
        argv: argument list to parse. ``None`` (the default) means
            ``sys.argv[1:]``.
    """
    parser = argparse.ArgumentParser(description="Retrosynthesis visualizer", add_help=True)
    parser.add_argument("--celecoxib", type=str, default=DEFAULT_FILES["Celecoxib"])
    parser.add_argument("--ibuprofen", type=str, default=DEFAULT_FILES["Ibuprofen"])
    parser.add_argument("--outdir", type=str, default=DEFAULT_OUTDIR)
    parser.add_argument("--max_plots", type=int, default=DEFAULT_MAXPLOTS)
    # Ignore unknown arguments injected by the notebook (e.g. --f=...).
    args, _unknown = parser.parse_known_args(argv)

    files = {"Celecoxib": args.celecoxib, "Ibuprofen": args.ibuprofen}
    out_dir = Path(args.outdir)
    combined = process_targets(files, out_dir, max_routes_to_plot=args.max_plots)
    if combined.empty:
        print("[WARN] No summaries produced.")
        return

    csv_path = out_dir / "route_summaries.csv"
    combined.to_csv(csv_path, index=False)
    try:
        from IPython.display import display
    except ImportError:
        # Plain `python script.py` without IPython: print instead of crashing.
        print(combined.head(20).to_string())
    else:
        display(combined.head(20))
    print(f"[OK] Wrote {csv_path}")

def notebook_run(
    celecoxib_path: str = DEFAULT_FILES["Celecoxib"],
    ibuprofen_path: str = DEFAULT_FILES["Ibuprofen"],
    outdir: str = DEFAULT_OUTDIR,
    max_plots: int = DEFAULT_MAXPLOTS,
):
    """Run the whole pipeline directly from a notebook (recommended; no argparse).

    Processes both targets, writes ``route_summaries.csv`` into *outdir*
    when any summary was produced, and shows the full table with IPython's
    ``display``.

    Returns:
        The combined summary DataFrame, which is empty if nothing was
        produced.
    """
    target_files = {
        "Celecoxib": celecoxib_path,
        "Ibuprofen": ibuprofen_path,
    }
    destination = Path(outdir)
    summary = process_targets(target_files, destination, max_routes_to_plot=max_plots)
    if summary.empty:
        print("[WARN] No summaries produced.")
        return summary

    csv_file = destination / "route_summaries.csv"
    summary.to_csv(csv_file, index=False)
    from IPython.display import display
    display(summary)
    print(f"[OK] Wrote {csv_file}")
    return summary

# Inside a notebook __name__ is also "__main__", so simply running the cell
# invokes the CLI entry point; cli_entry uses parse_known_args, so arguments
# injected by the kernel (such as --f=...) do not crash it.
if __name__ == "__main__":
    cli_entry()



[OK] Saved retrosyn_outputs/Celecoxib_route_1.png
[OK] Saved retrosyn_outputs/Celecoxib_route_2.png
[OK] Saved retrosyn_outputs/Celecoxib_route_3.png
[OK] Saved retrosyn_outputs/Ibuprofen_route_1.png
[OK] Saved retrosyn_outputs/Ibuprofen_route_2.png
[OK] Saved retrosyn_outputs/Ibuprofen_route_3.png


Unnamed: 0,target,route,depth,num_reactions,chem_nodes,rxn_nodes,precursor_cost,atom_economy,avg_score,avg_plausibility,first_step_score,first_step_plausibility
0,Celecoxib,2,1,1,3,1,35.7,0.913638,0.259821,0.999989,0.259821,0.999989
1,Celecoxib,3,1,1,3,1,96.0,0.826625,0.001641,0.999985,0.001641,0.999985
2,Celecoxib,1,1,1,3,1,99.0,0.892244,2.3e-05,0.999997,2.3e-05,0.999997
3,Celecoxib,4,1,1,3,1,109.19,0.748693,0.000211,0.999942,0.000211,0.999942
4,Celecoxib,8,2,2,5,2,3.0,0.696746,0.02088,0.99237,0.040049,0.984803
5,Celecoxib,5,2,2,5,2,6.0,0.74279,0.020033,0.992401,0.040049,0.984803
6,Celecoxib,7,2,2,5,2,16.19,0.640547,0.020206,0.9924,0.040049,0.984803
7,Celecoxib,6,2,2,5,2,61.0,0.757559,0.195676,0.992401,0.040049,0.984803
8,Celecoxib,24,3,3,7,3,2.51,0.638384,0.015422,0.994913,0.040049,0.984803
9,Celecoxib,16,3,3,7,3,2.66,0.60402,0.013388,0.994925,0.040049,0.984803


[OK] Wrote retrosyn_outputs/route_summaries.csv
