Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 22 additions & 2 deletions scripts/patch_output_cif_files.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# utility script to put all header information from original PDB entry into our CIF files
import fnmatch
import json
import re
from argparse import ArgumentParser
from pathlib import Path
Expand All @@ -12,6 +13,7 @@
from biotite.structure.io.pdbx import CIFColumn, CIFFile, set_structure
from loguru import logger
from sampleworks.utils.atom_array_utils import remove_atoms_with_any_nan_coords
from sampleworks.utils.cif_utils import add_category_to_cif


SAMPLEWORKS_CACHE = Path("~/.sampleworks/rcsb").expanduser()
Expand Down Expand Up @@ -73,7 +75,11 @@ def parse_args():
)
parser.add_argument("--grid-search-input-dir", required=True)
parser.add_argument(
"--input-pdb-pattern", default="{pdb_id}/{pdb_id}_single_001_density_input.cif"
"--input-pdb-pattern",
default="{pdb_id}/{pdb_id}_single_001_density_input.cif",
help="Pattern used by fnmatch/glob for input pdb files. The complete path of the input "
"pdb must match f'{grid-search-input-dir}/{input-pdb-pattern}'. Defaults to "
"'{pdb_id}/{pdb_id}_single_001_density_input.cif'",
)
args = parser.parse_args()
return args
Expand Down Expand Up @@ -159,9 +165,23 @@ def patch_individual_cif_file(
atom_keys = list(zip(asym_unit.chain_id.tolist(), asym_unit.res_id.tolist()))
asym_unit.res_id = np.array([mapping[k] for k in atom_keys], dtype=asym_unit.res_id.dtype)

# load the actual PDB, we'll copy the new coordinates to it.
# load the actual PDB, we'll copy the new coordinates and metadata into it.
template = CIFFile.read(rcsb_path)

# Write sampleworks trial metadata to the CIF file, if we can find it
cif_data = CIFFile.read(cif_path)
if "sampleworks" in cif_data.block:
template.block["sampleworks"] = cif_data.block["sampleworks"]
elif (metadata_path := cif_path.parent / "job_metadata.json").exists():
with open(metadata_path, "r") as fp:
metadata = json.load(fp)
if metadata is not None:
add_category_to_cif(template, metadata, "sampleworks")
else:
logger.warning(f"Sampleworks metadata file at {metadata_path} is empty")
else:
logger.warning(f"No sampleworks metadata found for {cif_path}")

# remove any atoms with nan coordinates--these seem to come in because we sometimes use parse
# (from AtomWorks) which creates them. Still, we'll do this here just in case.
asym_unit = remove_atoms_with_any_nan_coords(asym_unit)
Expand Down
87 changes: 87 additions & 0 deletions src/sampleworks/utils/cif_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,12 @@
from collections import OrderedDict
from collections.abc import Iterable
from pathlib import Path
from typing import Any

import numpy as np
from atomworks.io.utils.io_utils import load_any
from biotite.structure import AtomArrayStack
from biotite.structure.io.pdbx.cif import CIFCategory, CIFFile
from loguru import logger

from sampleworks.utils.atom_array_utils import (
Expand Down Expand Up @@ -235,3 +237,88 @@ def resolve_mixed_hetatm_atom_altlocs(cif_path: Path | str) -> Path:
save_structure_to_cif(fixed_array, tmp_path)
logger.info(f"Wrote altloc-fixed CIF to temporary file: {tmp_path}")
return tmp_path


def add_category_to_cif(
    ciffile: CIFFile,
    data: dict[str, Any],
    category_name: str,
    overwrite: bool = False,
    block_name: str | None = None,
) -> None:
    """Add a custom category in-place to a CIFFile.

    Parameters
    ----------
    ciffile : CIFFile
        The CIF file object to modify.
    data : dict[str, Any]
        Dictionary with column names as keys and column data as values.
        ``None`` values (scalar or per-item) are replaced with the CIF null
        placeholder ``"?"`` before insertion.
    category_name : str
        Name of the category to add (e.g., "custom_data").
    overwrite : bool, optional
        If False and the category already exists, raise RuntimeError. Default is False.
    block_name : str | None, optional
        Name of the block to add the category to. If None, check that there is only
        one block and add to that block. Default is None.

    Raises
    ------
    RuntimeError
        If category already exists and overwrite is False.
    ValueError
        If block_name is None but the file has no blocks or multiple blocks,
        or if the specified block_name does not exist.

    Examples
    --------
    >>> from biotite.structure.io.pdbx.cif import CIFFile
    >>> ciffile = CIFFile.read("example.cif")  # assuming it contains a single block
    >>> data = {"id": [1, 2, 3], "value": ["a", "b", "c"]}
    >>> add_category_to_cif(ciffile, data, "my_custom_data")
    >>> print(ciffile.block["my_custom_data"].serialize())
    loop_
    _my_custom_data.id
    _my_custom_data.value
    1 a
    2 b
    3 c
    >>> data = {"sampleworks_version": "0.4.0", "pdb_id": "1L63"}
    >>> add_category_to_cif(ciffile, data, "sampleworks_metadata")
    >>> print(ciffile.block["sampleworks_metadata"].serialize())
    _sampleworks_metadata.sampleworks_version 0.4.0
    _sampleworks_metadata.pdb_id 1L63
    """
    # Determine which block to use
    if block_name is None:
        # CIFFile is a Mapping, so inherits .keys(), which ultimately iterates over blocks
        blocks = list(ciffile.keys())
        if len(blocks) == 0:
            raise ValueError("CIFFile has no blocks. Cannot add category.")
        elif len(blocks) > 1:
            raise ValueError(
                f"CIFFile has multiple blocks: {blocks}. Please specify block_name parameter."
            )
        block = ciffile[blocks[0]]
    else:
        if block_name not in ciffile:
            raise ValueError(f"Block '{block_name}' not found in CIFFile.")
        block = ciffile[block_name]

    # Check if a category with name category_name already exists
    if category_name in block and not overwrite:
        raise RuntimeError(
            f"Category '{category_name}' already exists in block with value: {block[category_name]}"
        )

    # Create and add the category--remove any None values, CIF requires non-null values
    category = CIFCategory(
        columns={k: _normalize_nulls(v) for k, v in data.items()}, name=category_name
    )
    block[category_name] = category


def _normalize_nulls(value: Any) -> Any:
if isinstance(value, Iterable) and not isinstance(value, str | bytes):
return ["?" if item is None else item for item in value]
return "?" if value is None else value
41 changes: 22 additions & 19 deletions src/sampleworks/utils/guidance_script_utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from __future__ import annotations

import argparse
import json
import os
import pickle
Expand Down Expand Up @@ -30,7 +29,7 @@
NoiseSpaceDPSScaler,
NoScalingScaler,
)
from sampleworks.utils.cif_utils import resolve_mixed_hetatm_atom_altlocs
from sampleworks.utils.cif_utils import add_category_to_cif, resolve_mixed_hetatm_atom_altlocs
from sampleworks.utils.guidance_constants import (
GuidanceType,
StructurePredictor,
Expand Down Expand Up @@ -265,7 +264,7 @@ def get_reward_function_and_structure(


def save_everything(
output_dir: str | Path,
args: GuidanceConfig,
losses: list[Any],
refined_structure: dict,
traj_denoised: list[Any],
Expand All @@ -283,8 +282,10 @@ def save_everything(

Parameters
----------
output_dir : str | Path
Directory to write all output files into. Created if it doesn't exist.
args : GuidanceConfig
The arguments for the guidance run. This method directly uses args.output_dir,
and creates that directory if it does not exist. The result of args.as_dict() is
written to a JSON file in the same directory, and inserted into the output CIF file.
losses : list[Any]
Per-step loss values (may contain ``None`` entries for unguided steps).
refined_structure : dict
Expand All @@ -304,7 +305,7 @@ def save_everything(
Optional model-space atom template. When provided (mismatch runs),
this template is used for final structure and trajectory saving.
"""
output_dir = Path(output_dir)
output_dir = Path(args.output_dir)
output_dir.mkdir(parents=True, exist_ok=True)

logger.info("Saving results")
Expand All @@ -327,10 +328,20 @@ def save_everything(
else:
atom_array = base_atom_array

metadata = args.as_dict()

final_structure = CIFFile()
set_structure(final_structure, atom_array)
add_category_to_cif(final_structure, metadata, category_name="sampleworks")
final_structure.write(str(output_dir / "refined.cif"))

# write out the job parameters to a JSON file in the same directory as the refined.cif file
# Even though this is technically duplicated, keep it around as a backup in case metadata
# is lost in some CIF transform.
with open(Path(output_dir) / "job_metadata.json", "w") as fp:
# use the GuidanceConfig's as_dict() method to avoid serializing PosixPath objects
json.dump(metadata, fp)

# Two calls to save_trajectory, very similar, but saving different trajectories!
save_trajectory(
scaler_type,
Expand Down Expand Up @@ -363,14 +374,12 @@ def save_everything(
# Methods for running model guidance in separate processes, avoiding reloading of the model.
#####################
# These args are passed from run_grid_search.py via GuidanceConfig.
def run_guidance(
args: GuidanceConfig | argparse.Namespace, guidance_type: str, model_wrapper, device
) -> JobResult:
def run_guidance(args: GuidanceConfig, guidance_type: str, model_wrapper, device) -> JobResult:
"""Wrapper around ``_run_guidance`` to redirect logs and generate a JobResult.

Parameters
----------
args : GuidanceConfig | argparse.Namespace
args : GuidanceConfig
Configuration for the guidance run.
guidance_type : str
Type of guidance/scaler to apply.
Expand Down Expand Up @@ -410,9 +419,7 @@ def run_guidance(


# "guidance_type" is also called "scaler" in many places
def _run_guidance(
args: GuidanceConfig | argparse.Namespace, guidance_type: str, model_wrapper, device
):
def _run_guidance(args: GuidanceConfig, guidance_type: str, model_wrapper, device):
reward_function, structure = get_reward_function_and_structure(
args.density, # str/path to a map file.
device, # this needs to come from the global context, not the args object.
Expand Down Expand Up @@ -565,7 +572,7 @@ def _run_guidance(
model_atom_array = result.metadata.get("model_atom_array") if result.metadata else None

save_everything(
args.output_dir,
args,
losses,
refined_structure,
traj_denoised,
Expand All @@ -587,7 +594,7 @@ def epoch_seconds(time_to_convert: datetime) -> float:


def get_job_result(
args: GuidanceConfig | argparse.Namespace,
args: GuidanceConfig,
device: torch.device,
started_at: datetime,
ended_at: datetime,
Expand Down Expand Up @@ -641,10 +648,6 @@ def run_guidance_job_queue(job_queue_path: str) -> list[JobResult]:
logger.info(f"Running job {i + 1}/{len(job_queue)}: {job}")

job_result = run_guidance(job, job.guidance_type, model_wrapper, device)
# write out the job parameters to a JSON file in the same directory as the refined.cif file
with open(Path(job_result.output_dir) / "job_metadata.json", "w") as fp:
# use the GuidanceConfig's as_dict() method to avoid serializing PosixPath objects
json.dump(job.as_dict(), fp)

job_results.append(job_result)
torch.cuda.empty_cache() # just in case
Expand Down
13 changes: 12 additions & 1 deletion tests/integration/test_mismatch_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from sampleworks.utils.atom_array_utils import make_normalized_atom_id
from sampleworks.utils.atom_reconciler import AtomReconciler
from sampleworks.utils.frame_transforms import apply_forward_transform
from sampleworks.utils.guidance_script_arguments import GuidanceConfig
from sampleworks.utils.guidance_script_utils import save_everything

from tests.mocks import MismatchCase, MismatchCaseWrapper
Expand Down Expand Up @@ -1011,8 +1012,18 @@ def test_save_with_model_template(self, tmp_path: Path):
refined = {"asym_unit": build_test_atom_array(n_atoms=n_struct)}
model_atom_array = build_test_atom_array(n_atoms=n_model, with_occupancy=False)

args = GuidanceConfig(
protein="1l63",
structure=Path("dummy"),
density=Path("dummy"),
model="boltz2",
guidance_type="pure_guidance",
log_path="dummy",
output_dir=str(tmp_path),
)

save_everything(
output_dir=tmp_path,
args,
losses=[0.5, 0.3],
refined_structure=refined,
traj_denoised=[],
Expand Down
Loading
Loading