# Inspect RESP Solver Failures

Use this notebook to diagnose configurations where the batch RESP aggregation flagged non-convergence. Adjust the solver settings and compare against TeraChem reference charges frame by frame.

In [1]:
from __future__ import annotations

import io
import tempfile
from contextlib import redirect_stdout
from pathlib import Path
from typing import List, Sequence, Tuple

import matplotlib.pyplot as plt
import numpy as np

from biliresp.linearESPcharges import explicit_solution, prepare_linear_system
from biliresp.resp.resp import HyperbolicRestraint, fit_resp_charges, plot_loss_history
from biliresp.resp_parser import ParseRespDotOut

# Notebook-wide NumPy display settings: 6 digits of precision, and suppress=True
# so small values are printed in fixed-point rather than scientific notation.
np.set_printoptions(precision=6, suppress=True)

In [2]:
def _element_from_atom_name(name_field: str) -> str:
    """Guess an element symbol from the PDB atom-name field (columns 13-16).

    The PDB format right-justifies the element symbol in columns 13-14, so a
    one-letter element leaves column 13 blank (" CA " is an alpha carbon) while a
    two-letter element starts in column 13 ("CA  " is calcium). Four-character
    hydrogen names such as "HD11" are the usual exception and also start in
    column 13.

    Returns an empty string when the field contains no letters. The result is
    not case-normalised.
    """
    field = name_field[:4].ljust(4)
    first_letter = next((ch for ch in field if ch.isalpha()), "")
    if not first_letter:
        return ""
    if not field[0].isalpha():
        # Column 13 is blank or a digit (" CA ", "1HB "): one-letter element.
        return first_letter
    if field[0].upper() == "H" and len(field.strip()) == 4:
        # Four-character hydrogen name such as "HD11" or "HE21".
        return "H"
    # Name starts in column 13: take columns 13-14 when both are letters
    # ("FE  ", "CL  "), otherwise only the first ("C1  ").
    return field[:2] if field[1].isalpha() else field[0]


def parse_pdb_atoms(pdb_path: Path) -> Tuple[List[str], List[str], List[Tuple[float, float, float]]]:
    """Read atom labels, element symbols and coordinates from a PDB file.

    Only ``ATOM`` and ``HETATM`` records are considered. The element symbol is
    taken from columns 77-78 when present; otherwise it is inferred from the
    atom-name field using the PDB column-alignment convention (see
    ``_element_from_atom_name``), so that e.g. " CA " yields "C" and not "Ca".

    Parameters
    ----------
    pdb_path:
        Path of the PDB file to read.

    Returns
    -------
    (labels, symbols, coords)
        ``labels`` are the stripped atom names, ``symbols`` the element symbols
        with the first letter upper-cased and the rest lower-cased, ``coords``
        the ``(x, y, z)`` tuples parsed from columns 31-54.

    Raises
    ------
    ValueError
        If an atom record has neither an element column nor any letter in its
        atom name, if a coordinate field cannot be parsed as a float, or if the
        file contains no ATOM/HETATM records.
    """
    labels: List[str] = []
    symbols: List[str] = []
    coords: List[Tuple[float, float, float]] = []
    with pdb_path.open() as handle:
        for line in handle:
            if not line.startswith(("ATOM", "HETATM")):
                continue
            name_field = line[12:16]
            symbol = line[76:78].strip()
            if not symbol:
                # Element columns missing: fall back to the atom-name field.
                symbol = _element_from_atom_name(name_field)
                if not symbol:
                    raise ValueError(f"Cannot infer element symbol from PDB line: {line.rstrip()}")
            symbol = symbol[0].upper() + symbol[1:].lower()
            x = float(line[30:38])
            y = float(line[38:46])
            z = float(line[46:54])
            labels.append(name_field.strip())
            symbols.append(symbol)
            coords.append((x, y, z))
    if not labels:
        raise ValueError(f"No ATOM/HETATM entries found in {pdb_path}")
    return labels, symbols, coords


def write_xyz(symbols: Sequence[str], coords: Sequence[Tuple[float, float, float]], path: Path) -> None:
    """Write an XYZ geometry file for the given atoms.

    The file starts with the atom count (``len(symbols)``) and a fixed comment
    line, followed by one ``symbol x y z`` row per ``(symbol, coordinate)`` pair,
    with coordinates formatted to eight decimals in 16-character columns.
    Any existing file at ``path`` is overwritten.
    """
    with path.open("w") as xyz_file:
        xyz_file.writelines(
            (
                f"{len(symbols)}\n",
                "Generated from PDB for RESP inspection\n",
            )
        )
        # Rows are paired positionally; zip stops at the shorter sequence.
        xyz_file.writelines(
            f"{symbol:<2s} {x:16.8f} {y:16.8f} {z:16.8f}\n"
            for symbol, (x, y, z) in zip(symbols, coords)
        )


def resolve_frame_index(frame_index: int, total: int) -> int:
    """Translate a possibly negative frame index into a non-negative position.

    Negative values count from the end, as in Python sequence indexing.

    Raises
    ------
    ValueError
        If ``total`` is zero.
    IndexError
        If the resolved position falls outside ``[0, total)``.
    """
    if total == 0:
        raise ValueError("No frames available in file")
    if frame_index < 0:
        position = total + frame_index
    else:
        position = frame_index
    if 0 <= position < total:
        return position
    raise IndexError(f"frame_index {frame_index} out of range for {total} frames")


def load_reference_resp(resp_path: Path, number_of_atoms: int, frame_index: int):
    """Load the reference RESP and ESP charges of one frame from a resp.out file.

    Frames are obtained from ``ParseRespDotOut(resp_path, number_of_atoms)``;
    ``frame_index`` may be negative and is resolved with ``resolve_frame_index``.

    Returns
    -------
    (resp, esp)
        Float arrays of the selected frame's ``resp_charges`` and ``esp_charges``.
        An entry is ``None`` when that charge set is falsy (e.g. empty), and both
        are ``None`` when the parser yields no frames.
    """
    parsed_frames = ParseRespDotOut(resp_path, number_of_atoms).extract_frames()
    if not parsed_frames:
        return None, None

    selected = parsed_frames[resolve_frame_index(frame_index, len(parsed_frames))]

    def _as_float_array(values):
        # Truthiness check: presumably the parser returns plain lists — TODO confirm.
        return np.asarray(values, dtype=float) if values else None

    return _as_float_array(selected.resp_charges), _as_float_array(selected.esp_charges)

In [3]:
# --- Paths -------------------------------------------------------------------
INPUT_ROOT = Path("../input/microstates/APP")
RESULTS_ROOT = Path("../output/APP")
MICROSTATE_ROOT = INPUT_ROOT
NPZ_PATH = RESULTS_ROOT / "onestepRESP" / "charges.npz"

# --- Batch results -----------------------------------------------------------
# NOTE(security): allow_pickle=True makes np.load unpickle object arrays, which
# can execute arbitrary code. Only point NPZ_PATH at files you produced yourself.
# The context manager closes the archive's file handle once the two arrays have
# been read; `bundle` must not be indexed after this block.
with np.load(NPZ_PATH, allow_pickle=True) as bundle:
    configs = bundle["configs"].astype(str)
    converged = bundle["converged"]

# zip() would silently truncate on a length mismatch and hide failed entries.
if len(configs) != len(converged):
    raise ValueError(
        f"{NPZ_PATH}: 'configs' has {len(configs)} entries but 'converged' has {len(converged)}"
    )
failed_configs = [cfg for cfg, ok in zip(configs, converged) if not ok]

# --- Geometry template -------------------------------------------------------
# The first PDB (in sorted order) under the microstate root supplies the atom
# labels, element symbols and coordinates used by the rest of the notebook.
pdb_candidates = sorted(MICROSTATE_ROOT.glob("*.pdb"))
if not pdb_candidates:
    raise FileNotFoundError(f"No PDB found under {MICROSTATE_ROOT}")
PDB_PATH = pdb_candidates[0]

labels, symbols, coords = parse_pdb_atoms(PDB_PATH)
N_ATOMS = len(labels)

print(f"Microstate root: {MICROSTATE_ROOT}")
print(f"PDB: {PDB_PATH} with {N_ATOMS} atoms")
print(f"Failed configurations ({len(failed_configs)}): {failed_configs}")


Microstate root: ../input/microstates/APP
PDB: ../input/microstates/APP/app.pdb with 77 atoms
Failed configurations (0): []


In [4]:
def inspect_configuration(
    stem: str,
    *,
    frame_index: int = -1,
    grid_frame: int = 0,
    solver_tol: float = 1e-11,
    maxiter: int = 100,
    a: float = 5.0e-4,
    b: float = 1.0e-3,
    q0: float = 0.0,
    restrain_all_atoms: bool = True,
    show_loss: bool = True,
):
    """Re-run the RESP fit for one configuration and compare it with TeraChem.

    Relies on the notebook-level ``MICROSTATE_ROOT``, ``symbols``, ``coords`` and
    ``N_ATOMS``. The geometry is written to a temporary XYZ file, the RESP fit is
    run with a ``HyperbolicRestraint(a, b, q0)``, the unrestrained linear system
    is solved separately, and both are compared against the reference charges
    parsed from the configuration's ``resp.out``. A summary is printed.

    Parameters
    ----------
    stem:
        Configuration name; selects ``terachem/respout/<stem>.resp.out`` and
        ``terachem/espxyz/<stem>.esp.xyz`` under ``MICROSTATE_ROOT``.
    frame_index, grid_frame:
        Frame selectors forwarded as ``frame_index`` / ``grid_frame_index``.
    solver_tol, maxiter, restrain_all_atoms:
        Forwarded unchanged to ``fit_resp_charges``.
    a, b, q0:
        Parameters of the hyperbolic restraint.
    show_loss:
        When true and a loss history is available, plot it.

    Returns
    -------
    dict or None
        ``None`` if ``fit_resp_charges`` raises ``RuntimeError``; otherwise a dict
        with the fitted, initial, linear and reference charges plus the loss
        history.

    Raises
    ------
    FileNotFoundError
        If the resp output or the ESP grid file is missing.
    """
    terachem_dir = MICROSTATE_ROOT / "terachem"
    resp_path = terachem_dir / "respout" / f"{stem}.resp.out"
    esp_path = terachem_dir / "espxyz" / f"{stem}.esp.xyz"
    if not resp_path.exists():
        raise FileNotFoundError(f"Missing resp output: {resp_path}")
    if not esp_path.exists():
        raise FileNotFoundError(f"Missing esp grid: {esp_path}")

    restraint = HyperbolicRestraint(a=a, b=b, q0=q0)
    # Both solvers must look at the same frame of the resp output and ESP grid.
    frame_selection = {"frame_index": frame_index, "grid_frame_index": grid_frame}

    with tempfile.TemporaryDirectory() as scratch_dir:
        geometry_path = Path(scratch_dir) / "geometry.xyz"
        write_xyz(symbols, coords, geometry_path)
        try:
            fit = fit_resp_charges(
                resp_path,
                esp_path,
                geometry_path,
                N_ATOMS,
                restraint=restraint,
                solver_tol=solver_tol,
                maxiter=maxiter,
                restrain_all_atoms=restrain_all_atoms,
                **frame_selection,
            )
            print("RESP solver converged.")
        except RuntimeError as exc:
            # Report the failure and let the caller carry on with other configurations.
            print(f"RESP solver failed: {exc}")
            return None

    # Unrestrained reference solution of the same linear ESP problem.
    design_matrix, potentials, total_charge, _ = prepare_linear_system(
        resp_path,
        esp_path,
        N_ATOMS,
        **frame_selection,
    )
    linear_fit = explicit_solution().fit(design_matrix, potentials, total_charge)

    ref_resp, ref_esp = load_reference_resp(resp_path, N_ATOMS, frame_index)

    def first_five(values):
        return np.array2string(values[:5], precision=6, separator=", ")

    print(f"RMSE: {fit['rmse']:.6e} | RRMS: {fit['rrms']:.6e}")
    print(f"Sum q: {fit['sum_q']:.6f} | Target: {fit['target_total_charge']:.6f}")
    print(f"λ (KKT): {fit['lagrange_multiplier']:.6e}")
    print("First 5 RESP charges:", first_five(fit["charges"]))
    print("First 5 initial (linear) charges:", first_five(fit["initial_charges"]))

    if ref_resp is None:
        print("Terachem RESP charges not available in resp.out.")
    else:
        resp_delta = fit["charges"] - ref_resp
        print(f"Max |Δ| to TeraChem RESP: {np.max(np.abs(resp_delta)):.3e}")
        print("First 5 TeraChem RESP charges:", first_five(ref_resp))

    if ref_esp is not None:
        esp_delta = linear_fit["q"] - ref_esp
        print(f"Max |Δ| linear vs TeraChem ESP: {np.max(np.abs(esp_delta)):.3e}")

    history = fit.get("loss_history", [])
    if show_loss and history:
        loss_axes = plot_loss_history(history, show=False)
        loss_axes.figure.suptitle(f"Loss history for {stem}")
        plt.show()

    return {
        "charges": fit["charges"],
        "initial_charges": fit["initial_charges"],
        "linear_charges": linear_fit["q"],
        "reference_resp": ref_resp,
        "reference_esp": ref_esp,
        "loss_history": history,
    }

In [5]:
# Hand-maintained list of configurations to re-inspect. This assignment replaces
# the `failed_configs` list derived from charges.npz in the setup cell.
failed_configs = [
    f"conf{number:04d}"
    for number in (
        154, 380, 709, 846, 1033, 1072,
        1127, 1243, 1295, 1305, 1480, 1511,
        1554, 1678, 1818, 2173, 2229,
    )
]

In [6]:
# Re-run every listed configuration with a looser tolerance and more iterations.
# The verbose per-configuration report is captured rather than shown, so the cell
# output stays a one-line status per configuration.
results = {}
captured_logs = {}  # stem -> everything inspect_configuration printed

for fail in failed_configs:
    buf = io.StringIO()
    with redirect_stdout(buf):
        result = inspect_configuration(
            fail,
            solver_tol=1e-9,
            maxiter=300,
            show_loss=False,  # avoid a flood of plots
        )
    captured_logs[fail] = buf.getvalue()
    status = "converged" if result is not None else "failed"
    print(f"{fail}: {status}")
    if result is None:
        # The solver's failure reason was printed inside the redirect; show it
        # instead of discarding it with the rest of the captured output.
        print(captured_logs[fail].rstrip())
    results[fail] = result

conf0154: converged


conf0380: converged


conf0709: converged


conf0846: converged


conf1033: converged


conf1072: converged


conf1127: converged


conf1243: converged


conf1295: converged


conf1305: converged


conf1480: converged


conf1511: converged


conf1554: converged


conf1678: converged


conf1818: converged


conf2173: converged


conf2229: converged
