In [1]:
import os, warnings
os.environ["PYTHONWARNINGS"] = "ignore:pkg_resources is deprecated as an API.*:UserWarning:prody.utilities.misctools"
warnings.filterwarnings("ignore", message=r"pkg_resources is deprecated as an API.*", category=UserWarning)

In [4]:
# ===== Fast batch docking in Jupyter: process pool + per-process Vina cache =====
import os, glob, time, pandas as pd, numpy as np
from concurrent.futures import ProcessPoolExecutor, as_completed
import contextlib, sys
from vina import Vina

# ---- your setup ----
RECEPTOR   = "1H1Q_receptorH.pdbqt"
BOX_TXT    = "1H1Q_receptorH.box.txt"
LIG_GLOB   = "pdbqt_ligands_filtered/*.pdbqt"
OUTDIR     = "docked_out_api"
EXH        = 16
NUM_POSES  = 5
SEED       = 42
PROCESSES  = min(32, max(1, (os.cpu_count() or 4) - 1))   # processes
SCORING    = "vina"                               # or "ad4"
MAX_LIGS   = 10                                   # LIMIT to first N ligands

os.makedirs(OUTDIR, exist_ok=True)

def read_box_txt(path):
    cx=cy=cz=sx=sy=sz=None
    with open(path) as fh:
        for line in fh:
            t=line.strip().replace(","," ").replace("="," ")
            p=t.split(); lo=[w.lower() for w in p]
            if not p: continue
            if "center_x" in lo: cx=float(p[-1])
            elif "center_y" in lo: cy=float(p[-1])
            elif "center_z" in lo: cz=float(p[-1])
            elif "size_x"   in lo: sx=float(p[-1])
            elif "size_y"   in lo: sy=float(p[-1])
            elif "size_z"   in lo: sz=float(p[-1])
            elif lo[0]=="center" and len(p)>=4: cx,cy,cz = map(float, p[-3:])
            elif lo[0]=="size"   and len(p)>=4: sx,sy,sz = map(float, p[-3:])
    assert None not in (cx,cy,cz,sx,sy,sz), f"Bad box file: {path}"
    return (cx,cy,cz), (sx,sy,sz)

CENTER, SIZE = read_box_txt(BOX_TXT)

def get_best_energy(v: Vina, n_poses: int):
    """Version-proof extraction of top docking score."""
    res = v.energies(n_poses=n_poses)
    energies = res[0] if isinstance(res, (list, tuple)) and len(res)>=1 else []
    try:
        return float(energies[0][0]) if len(energies) else float("nan")
    except Exception:
        return float("nan")

# --- per-process cache (lives inside each worker process) ---
_global_vina = None
_global_cfg  = None

def _dock_worker(lig_path, receptor, center, size, scoring, seed, exhaus, nposes, outdir):
    # Limit native thread oversubscription per worker
    os.environ["OMP_NUM_THREADS"] = "1"
    os.environ["OPENBLAS_NUM_THREADS"] = "1"
    os.environ["MKL_NUM_THREADS"] = "1"
    os.environ["NUMEXPR_NUM_THREADS"] = "1"

    global _global_vina, _global_cfg
    try:
        name = os.path.splitext(os.path.basename(lig_path))[0]
        out_pdbqt = os.path.join(outdir, f"{name}_dock.pdbqt")
        if os.path.exists(out_pdbqt) and os.path.getsize(out_pdbqt) > 0:
            return {"CID": name, "status": "skipped", "out_pdbqt": out_pdbqt, "best_kcal_mol": np.nan, "error": ""}

        # Create/reuse a Vina instance with precomputed maps in THIS process
        cfg = (receptor, center, size, scoring, seed)
        if _global_vina is None or _global_cfg != cfg:
            v = Vina(sf_name=scoring, seed=seed)
            v.set_receptor(receptor)                  # positional = version-safe
            v.compute_vina_maps(center=center, box_size=size)
            _global_vina = v
            _global_cfg  = cfg

        v = _global_vina
        v.set_ligand_from_file(lig_path)

        # (optional) silence Vina’s progress bar per worker
        with open(os.devnull, "w") as _null, contextlib.redirect_stdout(_null), contextlib.redirect_stderr(_null):
            v.dock(exhaustiveness=exhaus, n_poses=nposes)
        v.write_poses(out_pdbqt, n_poses=nposes, overwrite=True)

        best = get_best_energy(v, n_poses=nposes)

        return {"CID": name, "status": "ok", "out_pdbqt": out_pdbqt, "best_kcal_mol": best, "error": ""}

    except Exception as e:
        return {"CID": os.path.basename(lig_path), "status": "error", "out_pdbqt": "", "best_kcal_mol": np.nan,
                "error": repr(e)[:400]}
# --- build ligand list (and actually LIMIT it) ---
ligs_all = sorted(glob.glob(LIG_GLOB))
assert ligs_all, f"No ligands matched: {LIG_GLOB}"
ligs = ligs_all[:MAX_LIGS]   # <- TRUE LIMIT
print(f"Docking {len(ligs)} ligands (of {len(ligs_all)} total). Using {PROCESSES} processes.")

# --- smoke test on first ligand (in-proc, no pool) ---
print("Smoke test on:", os.path.basename(ligs[0]))
print(_dock_worker(ligs[0], RECEPTOR, CENTER, SIZE, SCORING, SEED, EXH, NUM_POSES, OUTDIR))

# --- process pool run ---
rows = []
t0 = time.time()
with ProcessPoolExecutor(max_workers=PROCESSES) as ex:
    futs = [ex.submit(_dock_worker, p, RECEPTOR, CENTER, SIZE, SCORING, SEED, EXH, NUM_POSES, OUTDIR)
            for p in ligs]
    for i, f in enumerate(as_completed(futs), 1):
        res = f.result()
        rows.append(res)
        if i % 5 == 0:
            oks = sum(r["status"].startswith("ok") for r in rows)
            errs = sum(r["status"]=="error" for r in rows)
            print(f"[pool] {i}/{len(ligs)}  OK {oks}  Err {errs}", flush=True)

dt = time.time() - t0
print(f"Done {len(rows)} ligands in {dt:.1f}s  (~{len(rows)/max(dt,1):.2f} lig/s)")

# --- write manifests/scores ---
manifest = pd.DataFrame(rows, columns=["CID","status","out_pdbqt","best_kcal_mol","error"])
manifest.to_csv(os.path.join(OUTDIR, "_manifest_api.csv"), index=False)

scores = (manifest.dropna(subset=["best_kcal_mol"])
                  .sort_values("best_kcal_mol"))
scores.to_csv(os.path.join(OUTDIR, "_scores_api.csv"), index=False)
scores.head(10)


Docking 10 ligands (of 15687 total). Using 32 processes.
Smoke test on: 118704754.pdbqt
{'CID': '118704754', 'status': 'skipped', 'out_pdbqt': 'docked_out_api/118704754_dock.pdbqt', 'best_kcal_mol': nan, 'error': ''}
[pool] 5/10  OK 0  Err 0
[pool] 10/10  OK 0  Err 0
Done 10 ligands in 0.4s  (~10.00 lig/s)


Unnamed: 0,CID,status,out_pdbqt,best_kcal_mol,error
