Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 25 additions & 17 deletions src/sampleworks/core/rewards/protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,14 +23,17 @@ class RewardInputs:
Contains all the information needed to call a RewardFunctionProtocol,
extracted from an atom array. This allows the caller to extract inputs
once and pass them to scale() methods without redundant extraction.

The atom array passed to :meth:`from_atom_array` must already be clean:
all coordinates finite and all occupancies positive. Wrappers are
responsible for ensuring this (e.g. replacing NaN coordinates with
noise and setting occupancy to 1.0 for model-operated atoms).
"""

elements: Float[torch.Tensor, "*batch n_atoms"]
b_factors: Float[torch.Tensor, "*batch n_atoms"]
occupancies: Float[torch.Tensor, "*batch n_atoms"]
input_coords: Float[torch.Tensor, "*batch n_atoms 3"]
reward_param_mask: np.ndarray
mask_like: Float[torch.Tensor, "*batch n_atoms"]

@classmethod
def from_atom_array(
Expand All @@ -42,10 +45,15 @@ def from_atom_array(
) -> RewardInputs:
"""Construct RewardInputs from a Biotite AtomArray.

The atom array must contain only valid atoms (finite coordinates,
positive occupancy). Callers are responsible for filtering
beforehand; no masking is applied here.

Parameters
----------
atom_array
Biotite AtomArray or AtomArrayStack containing structure data.
Must have not NaN coordinates and positive occupancy.
Comment thread
k-chrispens marked this conversation as resolved.
ensemble_size
Number of ensemble members (batch dimension).
num_particles
Expand All @@ -58,13 +66,17 @@ def from_atom_array(
RewardInputs
Dataclass containing all inputs needed for reward function computation.
"""
occupancy_mask = atom_array.occupancy > 0
nan_mask = ~np.any(np.isnan(atom_array.coord), axis=-1)
reward_param_mask = occupancy_mask & nan_mask

elements_list = [
ELEMENT_TO_ATOMIC_NUM[e.title()] for e in atom_array.element[reward_param_mask]
]
# input validation: ensure atom_array has required annotations and valid values
if not hasattr(atom_array, "element"):
raise ValueError("Atom array must have 'element' annotation.")
if not hasattr(atom_array, "b_factor"):
raise ValueError("Atom array must have 'b_factor' annotation.")
if np.any(np.isnan(atom_array.coord)):
raise ValueError("Atom array contains NaN coordinates.")
if np.any((atom_array.occupancy <= 0) | (atom_array.occupancy > 1)):
raise ValueError("Atom array contains invalid occupancy values.")
Comment on lines +69 to +77
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major

🧩 Analysis chain

🏁 Script executed:

#!/bin/bash
set -euo pipefail

printf '\n=== to_reward_inputs implementation ===\n'
sed -n '1,220p' src/sampleworks/eval/structure_utils.py

printf '\n=== wrapper/model-atom-array sanitation sites ===\n'
rg -n -C3 'model_atom_array|occupanc|isnan|isfinite|filter_zero_occupancy|coord' \
  src/sampleworks/models src/sampleworks/eval/structure_utils.py

Repository: diff-use/sampleworks

Length of output: 50376


🏁 Script executed:

#!/bin/bash
# Check Protenix wrapper for coordinate sanitization
rg -n -A10 'model_atom_array|model_aa\s*=|coord.*protenix' \
  src/sampleworks/models/protenix/wrapper.py | head -100

# Also check if Protenix structure_processing has any sanitization
fd -e py protenix | xargs rg -l 'nan|isfinite|coord' | head -5

Repository: diff-use/sampleworks

Length of output: 2209


Move the NaN/occupancy validation to the final coordinate path, or require all wrappers to sanitize model_atom_array beforehand.

The validation at lines 69-77 checks atom_array.coord before it's used in rewards, but RewardInputs.from_atom_array() uses this only for atom accounting. The actual reward coordinates come from SampleworksProcessedStructure.input_coords (line 97 of structure_utils.py), which are already clean and pre-aligned to the model space. Protenix's wrapper sets occupancy and b_factor on model_atom_array but does not sanitize its coordinates, so this validation will reject valid inputs where the template has placeholder NaNs but input_coords is finite. Either validate input_coords instead, or document that all wrappers must sanitize model_atom_array.coord before calling to_reward_inputs().

🧰 Tools
🪛 Ruff (0.15.4)

[warning] 71-71: Avoid specifying long messages outside the exception class

(TRY003)


[warning] 73-73: Avoid specifying long messages outside the exception class

(TRY003)


[warning] 75-75: Avoid specifying long messages outside the exception class

(TRY003)


[warning] 77-77: Avoid specifying long messages outside the exception class

(TRY003)

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@src/sampleworks/core/rewards/protocol.py` around lines 69 - 77, The current
validation in RewardInputs.from_atom_array() checks atom_array.coord (and
occupancy/b_factor) but the reward pipeline uses
SampleworksProcessedStructure.input_coords as the canonical coordinates; move
NaN/occupancy coordinate validation out of RewardInputs.from_atom_array() and
instead perform it where the final coordinate path is assembled (i.e., validate
SampleworksProcessedStructure.input_coords before rewards are computed), or
alternatively require and document that wrappers producing model_atom_array must
sanitize model_atom_array.coord (and occupancy/b_factor) before calling
to_reward_inputs(); update checks to reference input_coords or add an explicit
precondition in the RewardInputs.from_atom_array() docstring mentioning
sanitized model_atom_array.


elements_list = [ELEMENT_TO_ATOMIC_NUM[e.title()] for e in atom_array.element]

total_batch_size = num_particles * ensemble_size if num_particles > 1 else ensemble_size

Expand All @@ -80,7 +92,7 @@ def from_atom_array(
)
b_factors = einx.rearrange(
"n -> p e n",
torch.Tensor(atom_array.b_factor[reward_param_mask]),
torch.Tensor(atom_array.b_factor),
p=num_particles,
e=ensemble_size,
)
Expand All @@ -89,22 +101,20 @@ def from_atom_array(
"... -> b ...",
coords_t,
b=total_batch_size,
)[..., reward_param_mask, :]
)
else:
elements = einx.rearrange("n -> b n", torch.Tensor(elements_list), b=ensemble_size)
b_factors = einx.rearrange(
"n -> b n",
torch.Tensor(atom_array.b_factor[reward_param_mask]),
torch.Tensor(atom_array.b_factor),
b=ensemble_size,
)
occupancies = torch.ones_like(b_factors) / ensemble_size
input_coords = einx.rearrange(
"... -> e ...",
coords_t,
e=ensemble_size,
)[..., reward_param_mask, :]

mask_like = torch.ones_like(input_coords[..., 0])
)

if isinstance(device, str):
device = torch.device(device)
Expand All @@ -114,8 +124,6 @@ def from_atom_array(
b_factors=b_factors.to(device),
occupancies=occupancies.to(device),
input_coords=input_coords.to(device),
reward_param_mask=reward_param_mask,
mask_like=mask_like.to(device),
)


Expand Down
2 changes: 1 addition & 1 deletion src/sampleworks/core/samplers/edm.py
Original file line number Diff line number Diff line change
Expand Up @@ -526,7 +526,7 @@ def step(
) / (2 * noise_var)

# Euler step: x_{t-1} = x_t + step_scale * dt * delta
# pyright sees dt as float | None, but it will be float if check_context didn't raise
# ty sees dt as float | None, but it will be float if check_context didn't raise
next_state = noisy_state_working_frame_t + self.step_scale * dt * delta # ty: ignore[unsupported-operator]

return SamplerStepOutput(
Expand Down
4 changes: 2 additions & 2 deletions src/sampleworks/eval/generate_synthetic_density.py
Original file line number Diff line number Diff line change
Expand Up @@ -286,7 +286,7 @@ def _process_single_row(
# TODO: there's probably a more robust way to do this
atom_array = keep_polymer(keep_amino_acids(atom_array)) if strip_ligands else atom_array

altloc_info = detect_altlocs(atom_array) # pyright: ignore[reportArgumentType]
altloc_info = detect_altlocs(atom_array) # ty: ignore[invalid-argument-type]
if row.occ_values:
if occ_mode != "custom":
logger.warning(
Expand Down Expand Up @@ -324,7 +324,7 @@ def _process_single_row(
# Shift coordinates into the grid frame so the saved CIF aligns with
# the CCP4 map. CCP4 format (unlike MRC) cannot encode an arbitrary Cartesian
# origin, so we move the atoms instead. Possible the better way is to resample the map?
atom_array.coord = atom_array.coord - xmap_torch.origin # pyright: ignore[reportOptionalOperand]
atom_array.coord = atom_array.coord - xmap_torch.origin
structure_output_path = structure_path.parent / f"{structure_path.stem}_density_input.cif"
try:
save_structure_to_cif(atom_array, structure_output_path)
Expand Down
2 changes: 1 addition & 1 deletion src/sampleworks/eval/grid_search_eval_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ def scan_grid_search_results(
guidance_weight = float(params["guidance_weight"])
gd_steps = int(params["gd_steps"]) if params["gd_steps"] is not None else None

# Validate parameters to satisfy pyright
# Validate parameters to satisfy ty
if (
protein is None
or occ_a is None
Expand Down
32 changes: 14 additions & 18 deletions src/sampleworks/eval/structure_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from sampleworks.utils.atom_reconciler import AtomReconciler
from sampleworks.utils.framework_utils import match_batch


ATOMWORKS_COMPARISON_OPS = ("==", ">", "<", "<=", ">=", " in ")

try:
Expand Down Expand Up @@ -90,19 +91,13 @@ def to_reward_inputs(self, device: torch.device | str = "cpu") -> RewardInputs:
model_indices = self.reconciler.model_indices.to(device=reward_inputs.b_factors.device)
updated_b_factors[..., model_indices] = struct_b_factors

# input_coords stored on the processed structure are always full model atom
# references, apply reward_param_mask for reward calls.
masked_input_coords = self.input_coords[..., reward_inputs.reward_param_mask, :].to(
device=reward_inputs.input_coords.device,
dtype=reward_inputs.input_coords.dtype,
)
updated_mask_like = torch.ones_like(masked_input_coords[..., 0])

return replace(
reward_inputs,
b_factors=updated_b_factors,
input_coords=masked_input_coords,
mask_like=updated_mask_like,
input_coords=self.input_coords.to(
device=reward_inputs.input_coords.device,
dtype=reward_inputs.input_coords.dtype,
),
)


Expand Down Expand Up @@ -174,10 +169,13 @@ def process_structure_to_trajectory_input(
atom_array = _filter_zero_occupancy(atom_array)
atom_array = _add_terminal_oxt_atoms(atom_array, structure.get("chain_info", {}))

# Mask to valid atoms (nonzero occupancy, no NaN coords) in structure atom space
reward_param_mask = atom_array.occupancy > 0
reward_param_mask &= ~np.any(np.isnan(atom_array.coord), axis=-1)
atom_array = atom_array[reward_param_mask]
# Filter to valid atoms (nonzero occupancy, finite coords) in structure atom space.
# The deposited structure may have zero-occupancy or NaN-coordinate atoms from
# unresolved regions or altloc processing; these must be removed before
# building the reconciler against the model's (already-clean) atom array.
valid_atom_mask = atom_array.occupancy > 0
valid_atom_mask &= ~np.any(np.isnan(atom_array.coord), axis=-1)
atom_array = atom_array[valid_atom_mask]

# Build reconciler from model and structure atom arrays.
model_atom_array = (
Expand Down Expand Up @@ -406,13 +404,11 @@ def extract_selection_coordinates(
@overload
def get_asym_unit_from_structure(
structure: dict, atom_array_index: None = None
) -> AtomArrayStack: ...
) -> AtomArrayStack: ...


@overload
def get_asym_unit_from_structure(
structure: dict, atom_array_index: int
) -> AtomArray: ...
def get_asym_unit_from_structure(structure: dict, atom_array_index: int) -> AtomArray: ...


def get_asym_unit_from_structure(
Expand Down
29 changes: 26 additions & 3 deletions src/sampleworks/models/rf3/wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -289,7 +289,7 @@ def featurize(self, structure: dict) -> GenerativeModelInput[RF3Conditioning]:
) # since we're not batching, the loader returns a list of length 1

# (Hydra instantiation of pipeline means it is going to be hard to type check here)
pipeline_output = self.inference_engine.pipeline(input_spec.to_pipeline_input()) # type: ignore
pipeline_output = self.inference_engine.pipeline(input_spec.to_pipeline_input()) # ty: ignore[call-non-callable]
pipeline_output = trainer.fabric.to_device(pipeline_output)

features = trainer._assemble_network_inputs(pipeline_output)
Expand Down Expand Up @@ -317,8 +317,31 @@ def featurize(self, structure: dict) -> GenerativeModelInput[RF3Conditioning]:
# Any excess atoms are trailing entries not represented in atom_to_token_map.
if len(model_aa) > num_atoms:
model_aa = cast(AtomArray, model_aa[:num_atoms])
if not hasattr(model_aa, "occupancy") or model_aa.occupancy is None:
model_aa.set_annotation("occupancy", np.ones(len(model_aa), dtype=np.float32))

# atomworks's add_missing_atoms adds unresolved atoms with
# occupancy=0.0 and NaN coordinates when we get our atom array with
# InferenceInput.from_atom_array. RF3 operates on these atoms (they're
# in atom_to_token_map), so initialize their coordinates with noise and
# set occupancy to 1.0 so they participate in guidance and don't get
# masked out in reward functions.
nan_coord_mask = np.any(np.isnan(model_aa.coord), axis=-1)
if nan_coord_mask.any():
resolved_coords = model_aa.coord[~nan_coord_mask]
centroid = resolved_coords.mean(axis=0) if len(resolved_coords) > 0 else np.zeros(3)
n_nan = int(nan_coord_mask.sum())
noise = np.random.normal(loc=0.0, scale=1.0, size=(n_nan, 3)).astype(np.float32)
new_coords = model_aa.coord.copy()
new_coords[nan_coord_mask] = centroid + noise
model_aa.coord = new_coords
logger.info(
f"Initialized {n_nan} unresolved atoms with noise "
f"(had NaN coordinates from add_missing_atoms)"
)

# All atoms in model_aa are operated on by RF3 during diffusion.
# Set occupancy to 1.0 regardless of what atomworks assigned (unresolved
# atoms from add_missing_atoms get occupancy=0.0, but RF3 should use them)
model_aa.set_annotation("occupancy", np.ones(len(model_aa), dtype=np.float32))
if not hasattr(model_aa, "b_factor") or model_aa.b_factor is None:
model_aa.set_annotation("b_factor", np.full(len(model_aa), 20.0, dtype=np.float32))

Expand Down
26 changes: 13 additions & 13 deletions src/sampleworks/utils/atom_array_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -372,9 +372,9 @@ def select_altloc(
mask = atom_array.altloc_id == altloc_id

if isinstance(atom_array, AtomArrayStack):
return atom_array[:, mask] # pyright: ignore[reportReturnType]
return atom_array[:, mask]
else:
return atom_array[mask] # pyright: ignore[reportReturnType]
return atom_array[mask]


def select_non_hetero(atom_array: AtomArray | AtomArrayStack) -> AtomArray | AtomArrayStack:
Expand Down Expand Up @@ -410,9 +410,9 @@ def select_non_hetero(atom_array: AtomArray | AtomArrayStack) -> AtomArray | Ato
mask = ~hetero

if isinstance(atom_array, AtomArrayStack):
return atom_array[:, mask] # pyright: ignore[reportReturnType]
return atom_array[:, mask]
else:
return atom_array[mask] # pyright: ignore[reportReturnType]
return atom_array[mask]


def keep_polymer(
Expand Down Expand Up @@ -449,10 +449,10 @@ def keep_polymer(
# TODO: fix once this is fixed: https://github.com/biotite-dev/biotite/issues/865
if isinstance(atom_array, AtomArrayStack):
polymer_mask = filter_polymer(atom_array[0], pol_type=pol_type)
return atom_array[:, polymer_mask] # pyright: ignore[reportReturnType]
return atom_array[:, polymer_mask]
else:
polymer_mask = filter_polymer(atom_array, pol_type=pol_type)
return atom_array[polymer_mask] # pyright: ignore[reportReturnType]
return atom_array[polymer_mask]


def keep_amino_acids(
Expand Down Expand Up @@ -485,9 +485,9 @@ def keep_amino_acids(
amino_acid_mask = filter_amino_acids(atom_array)

if isinstance(atom_array, AtomArrayStack):
return atom_array[:, amino_acid_mask] # pyright: ignore[reportReturnType]
return atom_array[:, amino_acid_mask]
else:
return atom_array[amino_acid_mask] # pyright: ignore[reportReturnType]
return atom_array[amino_acid_mask]


def remove_hydrogens(atom_array: AtomArray | AtomArrayStack) -> AtomArray | AtomArrayStack:
Expand Down Expand Up @@ -521,9 +521,9 @@ def remove_hydrogens(atom_array: AtomArray | AtomArrayStack) -> AtomArray | Atom
mask = (element != "H") & (element != "D")

if isinstance(atom_array, AtomArrayStack):
return atom_array[:, mask] # pyright: ignore[reportReturnType]
return atom_array[:, mask]
else:
return atom_array[mask] # pyright: ignore[reportReturnType]
return atom_array[mask]


@overload
Expand All @@ -542,7 +542,7 @@ def remove_atoms_with_any_nan_coords(
"""
if not atom_array or not atom_array.shape[-1]:
raise ValueError("Cannot remove atoms from empty AtomArray|Stack")

if isinstance(atom_array, AtomArrayStack):
# Partially flatten the coords so that we remove any atom from all structures if any
# one of them has a NaN at that position.
Expand Down Expand Up @@ -586,9 +586,9 @@ def select_backbone(atom_array: AtomArray | AtomArrayStack) -> AtomArray | AtomA
mask = np.isin(atom_name, backbone_atoms)

if isinstance(atom_array, AtomArrayStack):
return atom_array[:, mask] # pyright: ignore[reportReturnType]
return atom_array[:, mask]
else:
return atom_array[mask] # pyright: ignore[reportReturnType]
return atom_array[mask]


def make_atom_id(arr: AtomArray | AtomArrayStack) -> np.ndarray:
Expand Down
Loading