Merged
111 changes: 90 additions & 21 deletions scripts/patch_input_cif_files.py
@@ -1,7 +1,6 @@
# utility script to put all header information from original PDB entry into our CIF files
import fnmatch
import re
import shutil
from argparse import ArgumentParser
from pathlib import Path

@@ -43,6 +42,7 @@ def _crawl(current: Path, levels_left: int) -> None:
for entry in current.iterdir():
if entry.is_file():
if fnmatch.fnmatch(entry.name, target_pattern):
logger.info(f"Found matching file: {entry}")
results.append(entry)
elif entry.is_dir() and levels_left > 0:
_crawl(entry, levels_left - 1)
@@ -57,7 +57,11 @@ def _crawl(current: Path, levels_left: int) -> None:
def parse_args():
parser = ArgumentParser()
parser.add_argument("--input-dir", required=True)
parser.add_argument("--cif-pattern", default="refined.cif")
parser.add_argument(
"--cif-pattern",
default="refined.cif",
help="Pattern used by fnmatch/glob for cif files to patch, default: 'refined.cif'",
)
parser.add_argument(
"--rcsb-pattern",
default="grid_search_results/(.{4})",
@@ -67,53 +71,100 @@ def parse_args():
parser.add_argument(
"--depth", type=int, default=4, help="Depth to search the directory tree below input-dir"
)
parser.add_argument("--grid-search-input-dir", required=True)
parser.add_argument(
"--input-pdb-pattern", default="{pdb_id}/{pdb_id}_single_001_density_input.cif"
)
args = parser.parse_args()
return args


def main(
input_dir: str | Path,
grid_search_input_dir: str | Path,
target_pattern: str,
rcsb_regex: str = r"grid_search_results/(.{4})",
depth: int = 4,
input_pdb_pattern: str = "{pdb_id}/{pdb_id}_single_001_density_input.cif",
) -> None:
# make sure the cache exists
SAMPLEWORKS_CACHE.mkdir(parents=True, exist_ok=True)

cif_files_to_patch = crawl_dir_by_depth(input_dir, target_pattern, n_levels=depth)
results = joblib.Parallel()(
joblib.delayed(patch_individual_cif_file)(f, rcsb_regex) for f in cif_files_to_patch
joblib.delayed(patch_individual_cif_file)(
f, rcsb_regex, Path(grid_search_input_dir), input_pdb_pattern
)
for f in cif_files_to_patch
)
results = [r for r in results if r]
if results:
logger.warning("The following files could not be patched:")
logger.error("The following errors occurred:")
for r in results:
print(r)


def patch_individual_cif_file(cif_file: Path, rcsb_regex: str):
def patch_individual_cif_file(
cif_file: Path, rcsb_regex: str, reference_dir: Path, input_pdb_pattern: str
) -> str | None: # returns an error message if there was one
cif_path = Path(cif_file)
m = re.search(rcsb_regex, str(cif_path))
rcsb_id = m.group(1) if m else None
if not m:
logger.warning(
msg = (
f"Unable to parse an RCSB structure id: from path {cif_file} with pattern {rcsb_regex}"
)
return cif_file

# write a backup version of the input cif file
shutil.copy(cif_path, cif_path.parent / (cif_path.name + ".bak"))

# fetch only downloads the file if it isn't already present.
rcsb_path = fetch(rcsb_id, format="cif", target_path=str(SAMPLEWORKS_CACHE))

# load the copy, and the new coordinates for it.
logger.warning(msg)
return msg

# Get the offset for residue numbering in the reference structure
try:
reference_path = reference_dir / input_pdb_pattern.format(pdb_id=rcsb_id)
# fetch only downloads the file if it isn't already present.
rcsb_path = fetch(rcsb_id, format="cif", target_path=str(SAMPLEWORKS_CACHE))

reference = load_any(reference_path)
asym_unit = load_any(cif_file)
asym_unit = ensure_atom_array_stack(asym_unit)
except Exception as e:
msg = f"Unable to read and parse either/both of {reference_path}, {cif_file}"
logger.warning(msg)
return msg
Comment on lines +121 to +132
⚠️ Potential issue | 🟠 Major

Guard the exception path against an unbound reference_path.

If Line 122 fails during input_pdb_pattern.format(...), Line 130 can raise UnboundLocalError, which breaks the intended “return error message” flow.
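
For context, a minimal standalone sketch of the failure mode described above (the names are illustrative only, not taken from the script):

from pathlib import Path

def build_path(reference_dir: Path, pattern: str, pdb_id: str) -> str:
    try:
        # If pattern.format(...) raises (e.g. the pattern references a field
        # that is not supplied), reference_path is never bound.
        reference_path = reference_dir / pattern.format(pdb_id=pdb_id)
        return str(reference_path)
    except Exception:
        # Interpolating reference_path here raises UnboundLocalError instead
        # of returning the intended error message.
        return f"failed to build {reference_path}"

# Raises UnboundLocalError rather than returning the message:
build_path(Path("/data"), "{pdb_id}/{missing_field}.cif", "1abc")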

Proposed fix
-    # Get the offset for residue numbering in the reference structure    
-    try:
-        reference_path = reference_dir / input_pdb_pattern.format(pdb_id=rcsb_id)
+    # Get the offset for residue numbering in the reference structure
+    reference_path: Path | None = None
+    try:
+        reference_path = reference_dir / input_pdb_pattern.format(pdb_id=rcsb_id)
         # fetch only downloads the file if it isn't already present.
         rcsb_path = fetch(rcsb_id, format="cif", target_path=str(SAMPLEWORKS_CACHE))
 
         reference = load_any(reference_path)
         asym_unit = load_any(cif_file)
         asym_unit = ensure_atom_array_stack(asym_unit)
-    except Exception as e:
-        msg = f"Unable to read and parse either/both of {reference_path}, {cif_file}"
+    except Exception as exc:
+        msg = (
+            f"Unable to fetch/read one of rcsb_id={rcsb_id}, "
+            f"reference_path={reference_path}, cif_file={cif_file}: {exc}"
+        )
         logger.warning(msg)
         return msg
🧰 Tools
🪛 Ruff (0.15.2)

[warning] 129-129: Do not catch blind exception: Exception

(BLE001)


[error] 129-129: Local variable e is assigned to but never used

Remove assignment to unused variable e

(F841)

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@scripts/patch_input_cif_files.py` around lines 121 - 132, The except block
can reference reference_path that may be unassigned if
input_pdb_pattern.format(...) raises; initialize or build the path safely before
the try or ensure the except uses a safe fallback. Specifically, declare
reference_path = None (or compute a safe reference_str) before the try that
contains input_pdb_pattern.format(...), then inside the except construct the msg
using that safe variable (or include the caught exception e) instead of directly
interpolating reference_path; adjust references around functions fetch,
load_any, cif_file, and ensure_atom_array_stack accordingly so the error path
never raises UnboundLocalError when logger.warning(msg) or return msg executes.


# get the unique residue numbers for each file
if reference.res_id is None or asym_unit.res_id is None:
msg = f"Residue numbers for {cif_path} and/or {reference_path} are missing."
logger.error(msg)
return msg

# CodeRabbit improved this: we are now not chain agnostic, so we can handle multiple chains
ref_keys = list(dict.fromkeys(zip(reference.chain_id.tolist(), reference.res_id.tolist())))
cif_keys = list(dict.fromkeys(zip(asym_unit.chain_id.tolist(), asym_unit.res_id.tolist())))

# There should be a single, unique mapping between them. If not, something is wrong.
if len(ref_keys) != len(cif_keys):
msg = f"Residue numbers in {cif_path} cannot be mapped to those in {reference_path}"
logger.error(msg)
return msg

# patch the residue numbers to match the original pdb
mapping = {}
for cif_key, ref_key in zip(cif_keys, ref_keys, strict=True):
if cif_key[0] != ref_key[0]:
msg = f"Chain mismatch while remapping residues for {cif_path} vs {reference_path}"
logger.error(msg)
return msg
mapping[cif_key] = ref_key[1]

atom_keys = list(zip(asym_unit.chain_id.tolist(), asym_unit.res_id.tolist()))
asym_unit.res_id = np.array([mapping[k] for k in atom_keys], dtype=asym_unit.res_id.dtype)
marcuscollins marked this conversation as resolved.

# load the actual PDB, we'll copy the new coordinates to it.
template = CIFFile.read(rcsb_path)
asym_unit = load_any(cif_file)
asym_unit = ensure_atom_array_stack(asym_unit)

# remove any atoms with nan coordinates--these seem to come in because we sometimes use parse
# (from AtomWorks) which creates them. Still we'll do this here just in case.
# (from AtomWorks) which creates them. Still, we'll do this here just in case.
# I do the flattening so that we remove all copies of an atom if it is missing in any structure
flat_coords = einx.rearrange("a b c -> b (a c)", asym_unit.coord)
asym_unit = asym_unit[:, ~np.isnan(flat_coords).any(axis=1)] # pyright: ignore

@@ -142,12 +193,30 @@ def patch_individual_cif_file(cif_file: Path, rcsb_regex: str):
# Make sure the id field is unique to each atom
template.block["atom_site"]["id"] = CIFColumn(np.arange(np.prod(asym_unit.shape)))

# make sure there are "occupancy" and "B_iso_or_equiv" annotations
if "occupancy" not in template.block["atom_site"].keys():
template.block["atom_site"]["occupancy"] = CIFColumn(
[1.0] * len(template.block["atom_site"]["id"])
)
if "B_iso_or_equiv" not in template.block["atom_site"].keys():
template.block["atom_site"]["B_iso_or_equiv"] = CIFColumn(
[20.0] * len(template.block["atom_site"]["id"])
)
marcuscollins marked this conversation as resolved.

template.block.name = cif_path.stem
template.write(cif_file)
logger.info(f"Wrote {cif_file}")
patched_cif_name = cif_path.parent / (cif_path.stem + "-patched.cif")
template.write(patched_cif_name)
logger.info(f"Wrote {patched_cif_name}")
return None


if __name__ == "__main__":
args = parse_args()
main(args.input_dir, args.cif_pattern, args.rcsb_pattern, args.depth)
main(
args.input_dir,
args.grid_search_input_dir,
args.cif_pattern,
args.rcsb_pattern,
args.depth,
args.input_pdb_pattern,
)
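
For readers skimming the diff, here is a small self-contained sketch of the residue-renumbering idea the patch implements: ordered (chain_id, res_id) keys from the reference structure are mapped onto the renumbered output. Toy numpy arrays stand in for the biotite AtomArray fields; this is illustrative only, not the script itself.

import numpy as np

# Toy stand-ins for reference.chain_id / reference.res_id and the
# renumbered asym_unit.chain_id / asym_unit.res_id.
ref_chain = np.array(["A", "A", "B", "B"])
ref_res = np.array([101, 102, 5, 6])    # numbering from the original PDB entry
cif_chain = np.array(["A", "A", "B", "B"])
cif_res = np.array([1, 2, 1, 2])        # renumbered-from-1 output to be patched

# Ordered, de-duplicated (chain, res_id) keys, one per residue.
ref_keys = list(dict.fromkeys(zip(ref_chain.tolist(), ref_res.tolist())))
cif_keys = list(dict.fromkeys(zip(cif_chain.tolist(), cif_res.tolist())))
assert len(ref_keys) == len(cif_keys)  # the script logs and returns an error instead

# Map each output residue back to the reference numbering, refusing to
# remap across chains.
mapping = {}
for cif_key, ref_key in zip(cif_keys, ref_keys, strict=True):
    assert cif_key[0] == ref_key[0], "chain mismatch while remapping"
    mapping[cif_key] = ref_key[1]

atom_keys = list(zip(cif_chain.tolist(), cif_res.tolist()))
patched_res = np.array([mapping[k] for k in atom_keys], dtype=ref_res.dtype)
print(patched_res)  # [101 102   5   6]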