diff --git a/scripts/patch_input_cif_files.py b/scripts/patch_input_cif_files.py index a6f35de8..3af3c620 100644 --- a/scripts/patch_input_cif_files.py +++ b/scripts/patch_input_cif_files.py @@ -1,7 +1,6 @@ # utility script to put all header information from original PDB entry into our CIF files import fnmatch import re -import shutil from argparse import ArgumentParser from pathlib import Path @@ -43,6 +42,7 @@ def _crawl(current: Path, levels_left: int) -> None: for entry in current.iterdir(): if entry.is_file(): if fnmatch.fnmatch(entry.name, target_pattern): + logger.info(f"Found matching file: {entry}") results.append(entry) elif entry.is_dir() and levels_left > 0: _crawl(entry, levels_left - 1) @@ -57,7 +57,11 @@ def _crawl(current: Path, levels_left: int) -> None: def parse_args(): parser = ArgumentParser() parser.add_argument("--input-dir", required=True) - parser.add_argument("--cif-pattern", default="refined.cif") + parser.add_argument( + "--cif-pattern", + default="refined.cif", + help="Pattern used by fnmatch/glob for cif files to patch, default: 'refined.cif'", + ) parser.add_argument( "--rcsb-pattern", default="grid_search_results/(.{4})", @@ -67,53 +71,100 @@ def parse_args(): parser.add_argument( "--depth", type=int, default=4, help="Depth to search the directory tree below input-dir" ) + parser.add_argument("--grid-search-input-dir", required=True) + parser.add_argument( + "--input-pdb-pattern", default="{pdb_id}/{pdb_id}_single_001_density_input.cif" + ) args = parser.parse_args() return args def main( input_dir: str | Path, + grid_search_input_dir: str | Path, target_pattern: str, rcsb_regex: str = r"grid_search_results/(.{4})", depth: int = 4, + input_pdb_pattern: str = "{pdb_id}/{pdb_id}_single_001_density_input.cif", ) -> None: # make sure the cache exists SAMPLEWORKS_CACHE.mkdir(parents=True, exist_ok=True) cif_files_to_patch = crawl_dir_by_depth(input_dir, target_pattern, n_levels=depth) results = joblib.Parallel()( - joblib.delayed(patch_individual_cif_file)(f, rcsb_regex) for f in cif_files_to_patch + joblib.delayed(patch_individual_cif_file)( + f, rcsb_regex, Path(grid_search_input_dir), input_pdb_pattern + ) + for f in cif_files_to_patch ) results = [r for r in results if r] if results: - logger.warning("The following files could not be patched:") + logger.error("The following errors occurred:") for r in results: print(r) -def patch_individual_cif_file(cif_file: Path, rcsb_regex: str): +def patch_individual_cif_file( + cif_file: Path, rcsb_regex: str, reference_dir: Path, input_pdb_pattern: str +) -> str | None: # returns an error message if there was one cif_path = Path(cif_file) m = re.search(rcsb_regex, str(cif_path)) rcsb_id = m.group(1) if m else None if not m: - logger.warning( + msg = ( f"Unable to parse an RCSB structure id: from path {cif_file} with pattern {rcsb_regex}" ) - return cif_file - - # write a backup version of the input cif file - shutil.copy(cif_path, cif_path.parent / (cif_path.name + ".bak")) - - # fetch only downloads the file if it isn't already present. - rcsb_path = fetch(rcsb_id, format="cif", target_path=str(SAMPLEWORKS_CACHE)) - - # load the copy, and the new coordinates for it. + logger.warning(msg) + return msg + + # Get the offset for residue numbering in the reference structure + try: + reference_path = reference_dir / input_pdb_pattern.format(pdb_id=rcsb_id) + # fetch only downloads the file if it isn't already present. + rcsb_path = fetch(rcsb_id, format="cif", target_path=str(SAMPLEWORKS_CACHE)) + + reference = load_any(reference_path) + asym_unit = load_any(cif_file) + asym_unit = ensure_atom_array_stack(asym_unit) + except Exception as e: + msg = f"Unable to read and parse either/both of {reference_path}, {cif_file}" + logger.warning(msg) + return msg + + # get the unique residue numbers for each file + if reference.res_id is None or asym_unit.res_id is None: + msg = f"Residue numbers for {cif_path} and/or {reference_path} are missing." + logger.error(msg) + return msg + + # CodeRabbit improved this: we are now not chain agnostic, so we can handle multiple chains + ref_keys = list(dict.fromkeys(zip(reference.chain_id.tolist(), reference.res_id.tolist()))) + cif_keys = list(dict.fromkeys(zip(asym_unit.chain_id.tolist(), asym_unit.res_id.tolist()))) + + # There should be a single, unique mapping between them. If not, something is wrong. + if len(ref_keys) != len(cif_keys): + msg = f"Residue numbers in {cif_path} cannot be mapped to those in {reference_path}" + logger.error(msg) + return msg + + # patch the residue numbers to match the original pdb + mapping = {} + for cif_key, ref_key in zip(cif_keys, ref_keys, strict=True): + if cif_key[0] != ref_key[0]: + msg = f"Chain mismatch while remapping residues for {cif_path} vs {reference_path}" + logger.error(msg) + return msg + mapping[cif_key] = ref_key[1] + + atom_keys = list(zip(asym_unit.chain_id.tolist(), asym_unit.res_id.tolist())) + asym_unit.res_id = np.array([mapping[k] for k in atom_keys], dtype=asym_unit.res_id.dtype) + + # load the actual PDB, we'll copy the new coordinates to it. template = CIFFile.read(rcsb_path) - asym_unit = load_any(cif_file) - asym_unit = ensure_atom_array_stack(asym_unit) # remove any atoms with nan coordinates--these seem to come in because we sometimes use parse - # (from AtomWorks) which creates them. Still we'll do this here just in case. + # (from AtomWorks) which creates them. Still, we'll do this here just in case. + # I do the flattening so that we remove all copies of an atom if it is missing in any structure flat_coords = einx.rearrange("a b c -> b (a c)", asym_unit.coord) asym_unit = asym_unit[:, ~np.isnan(flat_coords).any(axis=1)] # pyright: ignore @@ -142,12 +193,30 @@ def patch_individual_cif_file(cif_file: Path, rcsb_regex: str): # Make sure the id field is unique to each atom template.block["atom_site"]["id"] = CIFColumn(np.arange(np.prod(asym_unit.shape))) + # make sure there are "occupancy" and "B_iso_or_equiv" annotations + if "occupancy" not in template.block["atom_site"].keys(): + template.block["atom_site"]["occupancy"] = CIFColumn( + [1.0] * len(template.block["atom_site"]["id"]) + ) + if "B_iso_or_equiv" not in template.block["atom_site"].keys(): + template.block["atom_site"]["B_iso_or_equiv"] = CIFColumn( + [20.0] * len(template.block["atom_site"]["id"]) + ) + template.block.name = cif_path.stem - template.write(cif_file) - logger.info(f"Wrote {cif_file}") + patched_cif_name = cif_path.parent / (cif_path.stem + "-patched.cif") + template.write(patched_cif_name) + logger.info(f"Wrote {patched_cif_name}") return None if __name__ == "__main__": args = parse_args() - main(args.input_dir, args.cif_pattern, args.rcsb_pattern, args.depth) + main( + args.input_dir, + args.grid_search_input_dir, + args.cif_pattern, + args.rcsb_pattern, + args.depth, + args.input_pdb_pattern, + )