# In [None]:
# -*- coding: utf-8 -*-
# /root/autodl-tmp/DeepFRI/03_run_galaxy/002_prepare_npz.ipynb
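"""
Jupyter-ready: build DeepFRI NPZ files from PDB/mmCIF structures (plain or .gz).
- Input directory: IN_DIR
- Output directory: OUT_DIR
- Saving only the contact map (MODE = "contact") saves a lot of disk space
- NPZ files that already exist are skipped, so an interrupted run can be resumed
"""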

import os, sys, gzip, json, time, shutil
from pathlib import Path
from multiprocessing import Pool, cpu_count
import numpy as np

# ===================== Configuration (edit as needed) =====================
# IN_DIR  = Path("/root/autodl-tmp/DeepFRI/03_run_galaxy/02_pdb_over70")       # alternative PDB source directory
IN_DIR  = Path("/root/autodl-tmp/DeepFRI/03_run_galaxy/01_pdb_galaxy_4deepPF")
OUT_DIR = Path("/root/autodl-tmp/DeepFRI/03_run_galaxy/02_pdb_galaxy_npz_fix")  # output directory for the NPZ files
OUT_DIR.mkdir(parents=True, exist_ok=True)

WORKERS        = min(48, max(1, cpu_count() - 1))  # number of worker processes; capped at 48, leaves one core free
CONTACT_CUTOFF = 8.0                               # contact-map distance threshold (Å)
MODE           = "contact"                         # "contact"|"dist"|"both" - which arrays are written to each NPZ
DIST_DTYPE     = "float32"                         # "float32"|"float16" - dtype of the stored distance matrix
MIN_FREE_GB    = 10.0                              # stop when the output disk has less free space than this (GB)

# ===============================================================
# Three-letter -> one-letter codes for the 20 standard amino acids;
# residues whose name is not in this table are skipped by the parsers.
AA3_TO_1 = {
    'ALA':'A','ARG':'R','ASN':'N','ASP':'D','CYS':'C','GLU':'E','GLN':'Q','GLY':'G',
    'HIS':'H','ILE':'I','LEU':'L','LYS':'K','MET':'M','PHE':'F','PRO':'P','SER':'S',
    'THR':'T','TRP':'W','TYR':'Y','VAL':'V'
}
# File-name endings that find_all_inputs() searches for (matched via rglob patterns).
SUFFIXES = {".pdb", ".cif", ".pdb.gz", ".cif.gz"}

# Pick the structure-parsing backend: prefer gemmi, fall back to Biopython,
# and leave BACKEND as None when neither package can be imported
# (the main section checks for None and prints an error instead of running).
try:
    import gemmi
    BACKEND = "gemmi"
except Exception:
    try:
        from Bio.PDB import MMCIFParser, PDBParser, is_aa
        BACKEND = "biopython"
    except Exception:
        BACKEND = None

# ===================== 工具函数 =====================
def norm_acc_from_filename(p: Path) -> str:
    """Derive the accession (ACC) from a structure file name.

    Only the final path component is used. A trailing ".gz" is removed,
    then one trailing ".pdb" or ".cif", and finally everything from the
    first "__" onwards (e.g. "__AF" or "__1ABC_A" tags) is discarded.
    """
    name = p.name
    if name.endswith(".gz"):
        name = name[:-len(".gz")]
    # At most one structure extension is stripped; ".pdb" is tried first.
    ext = next((e for e in (".pdb", ".cif") if name.endswith(e)), "")
    if ext:
        name = name[:-len(ext)]
    # partition() returns the whole name when no "__" is present.
    return name.partition("__")[0]

def read_text_auto(p: Path) -> str:
    """Return the full text of `p`, decompressing it when the path ends in ".gz".

    The content is decoded as UTF-8 and undecodable bytes are dropped
    (errors="ignore") in both the plain and the gzip branch.
    """
    if not str(p).endswith(".gz"):
        return p.read_text(encoding="utf-8", errors="ignore")
    with gzip.open(p, "rt", encoding="utf-8", errors="ignore") as handle:
        return handle.read()

def parse_with_gemmi(p: Path):
    """Parse a structure with gemmi and return its longest standard-amino-acid chain.

    Every chain of every model is scanned. A residue contributes to a chain
    only when its name is one of the 20 entries of AA3_TO_1 and it has a CA
    atom; all other residues (water, ligands, modified residues, residues
    without CA) are skipped. This is the same selection rule the Biopython
    backend applies, so both backends agree on which residues count.

    Args:
        p: Path of the structure file; it is handed to gemmi.read_structure
           as a string.

    Returns:
        None when no chain yields any residue, otherwise a tuple
        (chain_name, seq, coords, res_ids) where seq is the one-letter
        sequence, coords is a float32 array with one CA position (x, y, z)
        per residue, and res_ids holds one "chain:number:icode" label per
        residue. On equal length the chain encountered first is kept.
    """
    structure = gemmi.read_structure(str(p))
    best = None
    for model in structure:
        for chain in model:
            seq, coords, res_ids = [], [], []
            for res in chain:
                aa = AA3_TO_1.get(res.name.upper().strip())
                if aa is None:
                    continue
                # '*' accepts any altloc label, so a CA that only exists as
                # alternate conformers is still found (first match wins).
                ca = res.find_atom("CA", "*")
                if ca is None:
                    continue
                seq.append(aa)
                pos = ca.pos
                coords.append([pos.x, pos.y, pos.z])
                res_ids.append(f"{chain.name}:{res.seqid.num}:{res.seqid.icode or ''}")
            if not seq:
                continue
            cand = (chain.name, "".join(seq), np.asarray(coords, dtype=np.float32), res_ids)
            # Strict '>' keeps the earliest chain among equally long ones.
            if best is None or len(cand[1]) > len(best[1]):
                best = cand
    return best

def parse_with_biopython(p: Path):
    """Fallback parser built on Biopython; same return contract as the gemmi parser.

    Paths ending in ".cif" / ".cif.gz" go through MMCIFParser, everything
    else through PDBParser. Gzipped files are decompressed into memory with
    read_text_auto() and parsed from a StringIO handle.

    Returns:
        None when no chain has a usable residue, otherwise
        (chain_id, seq, coords, res_ids) for the longest chain, where coords
        is a float32 array of CA positions and res_ids are
        "chain:number:icode" labels. Only residues that pass
        is_aa(standard=True), contain a CA atom and are named in AA3_TO_1
        are used.
    """
    from io import StringIO
    from Bio.PDB import MMCIFParser, PDBParser, is_aa

    path_str = str(p)
    if path_str.endswith((".cif", ".cif.gz")):
        parser = MMCIFParser(QUIET=True)
    else:
        parser = PDBParser(QUIET=True)

    # Compressed input is parsed from an in-memory text handle.
    source = StringIO(read_text_auto(p)) if path_str.endswith(".gz") else path_str
    structure = parser.get_structure("S", source)

    longest = None
    for model in structure:
        for chain in model:
            letters, xyz, labels = [], [], []
            for residue in chain:
                if not is_aa(residue, standard=True) or "CA" not in residue:
                    continue
                letter = AA3_TO_1.get(residue.get_resname().upper().strip())
                if letter is None:
                    continue
                letters.append(letter)
                ca_xyz = residue["CA"].get_coord()
                xyz.append([float(ca_xyz[0]), float(ca_xyz[1]), float(ca_xyz[2])])
                full_id = residue.id
                icode = full_id[2] if len(full_id) > 2 else ""
                labels.append(f"{chain.id}:{full_id[1]}:{icode or ''}")
            if not letters:
                continue
            candidate = (chain.id, "".join(letters), np.asarray(xyz, dtype=np.float32), labels)
            # Strictly longer only: the first chain wins a tie.
            if longest is None or len(candidate[1]) > len(longest[1]):
                longest = candidate
    return longest

def coords_to_dist_and_contacts(coords: np.ndarray, cutoff: float, dist_dtype: str):
    """Build the pairwise distance matrix and the contact map from CA coordinates.

    Args:
        coords: Array with one (x, y, z) row per residue.
        cutoff: Two different residues are in contact when their distance
            is <= cutoff.
        dist_dtype: "float16" stores the distance matrix as float16; any
            other value keeps float32.

    Returns:
        (dist, contact_map). dist is the N x N distance matrix with a zero
        diagonal; contact_map is an N x N uint8 matrix with a zero diagonal.
        (None, None) is returned when coords has no rows.
    """
    if coords.shape[0] == 0:
        return None, None
    # Distances do not change under translation. Centring (in float64) keeps
    # the squared norms small, so the float32 identity
    # |a-b|^2 = |a|^2 + |b|^2 - 2 a.b below does not lose the result to
    # cancellation when the structure sits far from the origin.
    centered = np.asarray(coords, dtype=np.float64)
    centered = centered - centered.mean(axis=0, keepdims=True)
    x = centered.astype(np.float32)
    sq_norm = np.sum(x * x, axis=1, keepdims=True)
    dist2 = sq_norm + sq_norm.T - 2.0 * (x @ x.T)
    # Rounding can leave tiny negative values; clamp before the square root.
    np.maximum(dist2, 0.0, out=dist2)
    dist = np.sqrt(dist2, out=dist2)
    np.fill_diagonal(dist, 0.0)
    # The contact map is thresholded on the float32 distances, before any
    # float16 down-cast, and self-contacts are removed.
    contact_map = (dist <= cutoff).astype(np.uint8)
    np.fill_diagonal(contact_map, 0)
    if dist_dtype == "float16":
        dist = dist.astype(np.float16)
    return dist, contact_map

def disk_free_gb(path: Path) -> float:
    """Return the free space, in GiB (bytes / 1024**3), reported by shutil.disk_usage for `path`."""
    free_bytes = shutil.disk_usage(path).free
    return free_bytes / (1 << 30)

def find_all_inputs(root: Path):
    """Recursively collect structure files below `root`.

    One rglob pass is made per entry of SUFFIXES. Only regular files whose
    name does not start with "." are kept, and the result is returned as a
    sorted list of paths.
    """
    hits = [hit for suffix in SUFFIXES for hit in root.rglob(f"*{suffix}")]
    return sorted(hit for hit in hits if hit.is_file() and not hit.name.startswith("."))

# Module-level settings store: the pool initializer fills it in each worker
# process so that process_one() can read its configuration without extra
# arguments (this keeps the parallel run usable from a Jupyter cell).
GLOBAL_CFG = {}


def init_worker(cfg):
    """Pool initializer: merge the entries of `cfg` into GLOBAL_CFG."""
    for key, value in dict(cfg).items():
        GLOBAL_CFG[key] = value

def _save_npz_atomic(out_npz: Path, arrays: dict) -> None:
    """Write `arrays` to `out_npz` through a private temp file and os.replace."""
    # The PID keeps two workers whose inputs map to the same accession from
    # writing to one shared temp file.
    tmp_path = out_npz.with_name(f"{out_npz.name}.{os.getpid()}.tmp")
    try:
        # An open handle is passed on purpose: when given a path that does
        # not end in ".npz", NumPy appends ".npz" to it, so the data would
        # not be written to tmp_path and the replace below would fail.
        with open(tmp_path, "wb") as fh:
            np.savez_compressed(fh, **arrays)
        os.replace(tmp_path, out_npz)
    finally:
        # After a successful replace the temp name no longer exists; after a
        # failure this removes the partial file. Either way errors here must
        # not mask the original outcome.
        try:
            tmp_path.unlink()
        except OSError:
            pass


def process_one(p: Path):
    """Convert one structure file into an NPZ file; never raises for per-file errors.

    Settings are read from GLOBAL_CFG (keys: out_dir, cutoff, mode,
    dist_dtype, backend, min_free_gb), which init_worker() fills in.

    Args:
        p: Path of the input structure file.

    Returns:
        A dict with the keys
          acc      - accession derived from the file name (None if the
                     record was produced before it was computed),
          npz_path - absolute path of the NPZ file, or None on failure,
          n_res    - number of residues written (-1 for "exists_skip"),
          ok       - True for "ok" and "exists_skip", False otherwise,
          msg      - "ok", "exists_skip", "low_disk_stop", "enospace",
                     "no_valid_chain", "too_short", "oserror:..." or
                     "error:...",
          src      - the input path as a string, so failures can be traced.
        The caller stops the whole run on "low_disk_stop" and "enospace".
    """
    import errno

    out_dir    = GLOBAL_CFG["out_dir"]
    cutoff     = GLOBAL_CFG["cutoff"]
    mode       = GLOBAL_CFG["mode"]
    dist_dtype = GLOBAL_CFG["dist_dtype"]
    backend    = GLOBAL_CFG["backend"]
    min_free   = GLOBAL_CFG["min_free_gb"]

    acc = None

    def record(ok, msg, npz_path=None, n_res=0):
        # Reads `acc` at call time, so records built after the accession is
        # known carry it automatically.
        return {"acc": acc, "npz_path": npz_path, "n_res": n_res,
                "ok": ok, "msg": msg, "src": str(p)}

    try:
        # Not enough free disk space: report it so the caller can stop the run.
        if disk_free_gb(out_dir) < min_free:
            return record(False, "low_disk_stop")

        acc = norm_acc_from_filename(p)
        out_npz = out_dir / f"{acc}.npz"
        if out_npz.exists():
            return record(True, "exists_skip", str(out_npz.resolve()), -1)

        if backend == "gemmi":
            parsed = parse_with_gemmi(p)
        else:
            parsed = parse_with_biopython(p)

        if parsed is None:
            return record(False, "no_valid_chain")

        chain_id, seq, coords, res_ids = parsed
        if len(seq) < 2:
            return record(False, "too_short", n_res=len(seq))

        dist, contact_map = coords_to_dist_and_contacts(coords, cutoff, dist_dtype)

        arrays = {
            "seq": np.array(seq),
            # dtype=object is stored pickled, so readers have to call
            # np.load(..., allow_pickle=True) to access "res_ids".
            "res_ids": np.array(res_ids, dtype=object),
            "chain_id": np.array(chain_id),
        }
        if mode in ("dist", "both"):
            arrays["dist"] = dist
        if mode in ("contact", "both"):
            arrays["contact_map"] = contact_map

        _save_npz_atomic(out_npz, arrays)

        return record(True, "ok", str(out_npz.resolve()), len(seq))

    except OSError as e:
        # Disk full is reported separately because it ends the whole run.
        if e.errno == errno.ENOSPC or "No space left on device" in str(e):
            return record(False, "enospace")
        return record(False, f"oserror:{e}")
    except Exception as e:
        # Worker boundary: any other per-file failure becomes a failure record.
        return record(False, f"error:{e}")

# ===================== Main run =====================
# Scans IN_DIR, converts every structure in a process pool, then writes
# structure_map_all.tsv (acc -> npz path) and npz_build_report.json to OUT_DIR.
if BACKEND is None:
    print("[ERROR] gemmi/biopython 都不可用，请先安装其一")
else:
    files = find_all_inputs(IN_DIR)
    print(f"[Config] BACKEND={BACKEND}")
    print(f"[Config] IN_DIR={IN_DIR}")
    print(f"[Config] OUT_DIR={OUT_DIR}")
    print(f"[Scan] found candidate structures: {len(files)}")
    print(f"[Parallel] workers={WORKERS}, mode={MODE}, dist_dtype={DIST_DTYPE}, min_free_gb={MIN_FREE_GB}")

    cfg = {
        "out_dir": OUT_DIR,
        "cutoff": CONTACT_CUTOFF,
        "mode": MODE,
        "dist_dtype": DIST_DTYPE,
        "backend": BACKEND,
        "min_free_gb": MIN_FREE_GB,
    }

    t0 = time.time()
    records = []
    stop = False
    # Running tallies for the progress line (avoids rescanning `records`).
    n_new = n_exist = n_fail = 0

    with Pool(processes=WORKERS, initializer=init_worker, initargs=(cfg,)) as pool:
        for i, rec in enumerate(pool.imap_unordered(process_one, files, chunksize=16), 1):
            records.append(rec)
            if not rec.get("ok"):
                n_fail += 1
            elif rec.get("msg") == "ok":
                n_new += 1
            elif rec.get("msg") == "exists_skip":
                n_exist += 1
            # A worker reported a full / nearly full disk: abandon the run.
            if rec.get("msg") in ("low_disk_stop", "enospace"):
                stop = True
                break
            if i % 200 == 0:
                print(f"[Progress] done={i} | ok_new={n_new} | exists={n_exist} | fail={n_fail}")
        # join() is only allowed after close() or terminate(); on the normal
        # path the pool must therefore be closed first.
        if stop:
            pool.terminate()
        else:
            pool.close()
        pool.join()

    ok_new   = [r for r in records if r.get("ok") and r.get("msg") == "ok"]
    ok_exist = [r for r in records if r.get("ok") and r.get("msg") == "exists_skip"]
    fail     = [r for r in records if not r.get("ok")]
    dt = time.time() - t0

    # Write the accession -> NPZ path map (new and previously existing files).
    map_tsv = OUT_DIR / "structure_map_all.tsv"
    with open(map_tsv, "w", encoding="utf-8") as fo:
        fo.write("acc\tnpz_path\n")
        for r in ok_new + ok_exist:
            if r.get("acc") and r.get("npz_path"):
                fo.write(f"{r['acc']}\t{r['npz_path']}\n")

    # Write the run report.
    rep_json = OUT_DIR / "npz_build_report.json"
    report = {
        "backend": BACKEND,
        "in_dir": str(IN_DIR.resolve()),
        "out_dir": str(OUT_DIR.resolve()),
        "mode": MODE,
        "dist_dtype": DIST_DTYPE,
        "contact_cutoff_A": CONTACT_CUTOFF,
        "n_files": len(files),
        "n_ok_new": len(ok_new),
        "n_ok_exists": len(ok_exist),
        "n_fail": len(fail),
        "time_sec": round(dt, 2),
        "stop_due_to_low_disk": stop,
        "fail_examples": fail[:10],
    }
    with open(rep_json, "w", encoding="utf-8") as f:
        json.dump(report, f, indent=2, ensure_ascii=False)

    print(f"[Done] new={len(ok_new)} | exists={len(ok_exist)} | fail={len(fail)} | elapsed={round(dt,1)}s")
    if stop:
        print("[HINT] 触发低磁盘保护：清理空间后重新运行此cell即可，已完成的npz会自动跳过。")
    print(f"[Saved] map -> {map_tsv}")
    print(f"[Saved] report -> {rep_json}")
