From 4dcba730ffe8dcd291927097682721e96a8151b5 Mon Sep 17 00:00:00 2001 From: Karson Chrispens Date: Thu, 5 Mar 2026 20:58:10 +0000 Subject: [PATCH 1/4] fix(ty): removing remaining pyright ignores --- .../eval/generate_synthetic_density.py | 4 +-- .../eval/grid_search_eval_utils.py | 2 +- src/sampleworks/utils/atom_array_utils.py | 26 +++++++++---------- 3 files changed, 16 insertions(+), 16 deletions(-) diff --git a/src/sampleworks/eval/generate_synthetic_density.py b/src/sampleworks/eval/generate_synthetic_density.py index cfb27d26..693db47a 100644 --- a/src/sampleworks/eval/generate_synthetic_density.py +++ b/src/sampleworks/eval/generate_synthetic_density.py @@ -286,7 +286,7 @@ def _process_single_row( # TODO: there's probably a more robust way to do this atom_array = keep_polymer(keep_amino_acids(atom_array)) if strip_ligands else atom_array - altloc_info = detect_altlocs(atom_array) # pyright: ignore[reportArgumentType] + altloc_info = detect_altlocs(atom_array) # ty: ignore[invalid-argument-type] if row.occ_values: if occ_mode != "custom": logger.warning( @@ -324,7 +324,7 @@ def _process_single_row( # Shift coordinates into the grid frame so the saved CIF aligns with # the CCP4 map. CCP4 format (unlike MRC) cannot encode an arbitrary Cartesian # origin, so we move the atoms instead. Possible the better way is to resample the map? 
- atom_array.coord = atom_array.coord - xmap_torch.origin # pyright: ignore[reportOptionalOperand] + atom_array.coord = atom_array.coord - xmap_torch.origin structure_output_path = structure_path.parent / f"{structure_path.stem}_density_input.cif" try: save_structure_to_cif(atom_array, structure_output_path) diff --git a/src/sampleworks/eval/grid_search_eval_utils.py b/src/sampleworks/eval/grid_search_eval_utils.py index 1340b8e0..bcf957e0 100644 --- a/src/sampleworks/eval/grid_search_eval_utils.py +++ b/src/sampleworks/eval/grid_search_eval_utils.py @@ -103,7 +103,7 @@ def scan_grid_search_results( guidance_weight = float(params["guidance_weight"]) gd_steps = int(params["gd_steps"]) if params["gd_steps"] is not None else None - # Validate parameters to satisfy pyright + # Validate parameters to satisfy ty if ( protein is None or occ_a is None diff --git a/src/sampleworks/utils/atom_array_utils.py b/src/sampleworks/utils/atom_array_utils.py index dbbb7424..7d87cb94 100644 --- a/src/sampleworks/utils/atom_array_utils.py +++ b/src/sampleworks/utils/atom_array_utils.py @@ -372,9 +372,9 @@ def select_altloc( mask = atom_array.altloc_id == altloc_id if isinstance(atom_array, AtomArrayStack): - return atom_array[:, mask] # pyright: ignore[reportReturnType] + return atom_array[:, mask] else: - return atom_array[mask] # pyright: ignore[reportReturnType] + return atom_array[mask] def select_non_hetero(atom_array: AtomArray | AtomArrayStack) -> AtomArray | AtomArrayStack: @@ -410,9 +410,9 @@ def select_non_hetero(atom_array: AtomArray | AtomArrayStack) -> AtomArray | Ato mask = ~hetero if isinstance(atom_array, AtomArrayStack): - return atom_array[:, mask] # pyright: ignore[reportReturnType] + return atom_array[:, mask] else: - return atom_array[mask] # pyright: ignore[reportReturnType] + return atom_array[mask] def keep_polymer( @@ -449,10 +449,10 @@ def keep_polymer( # TODO: fix once this is fixed: https://github.com/biotite-dev/biotite/issues/865 if isinstance(atom_array, 
AtomArrayStack): polymer_mask = filter_polymer(atom_array[0], pol_type=pol_type) - return atom_array[:, polymer_mask] # pyright: ignore[reportReturnType] + return atom_array[:, polymer_mask] else: polymer_mask = filter_polymer(atom_array, pol_type=pol_type) - return atom_array[polymer_mask] # pyright: ignore[reportReturnType] + return atom_array[polymer_mask] def keep_amino_acids( @@ -485,9 +485,9 @@ def keep_amino_acids( amino_acid_mask = filter_amino_acids(atom_array) if isinstance(atom_array, AtomArrayStack): - return atom_array[:, amino_acid_mask] # pyright: ignore[reportReturnType] + return atom_array[:, amino_acid_mask] else: - return atom_array[amino_acid_mask] # pyright: ignore[reportReturnType] + return atom_array[amino_acid_mask] def remove_hydrogens(atom_array: AtomArray | AtomArrayStack) -> AtomArray | AtomArrayStack: @@ -521,9 +521,9 @@ def remove_hydrogens(atom_array: AtomArray | AtomArrayStack) -> AtomArray | Atom mask = (element != "H") & (element != "D") if isinstance(atom_array, AtomArrayStack): - return atom_array[:, mask] # pyright: ignore[reportReturnType] + return atom_array[:, mask] else: - return atom_array[mask] # pyright: ignore[reportReturnType] + return atom_array[mask] @overload @@ -542,7 +542,7 @@ def remove_atoms_with_any_nan_coords( """ if not atom_array or not atom_array.shape[-1]: raise ValueError("Cannot remove atoms from empty AtomArray|Stack") - + if isinstance(atom_array, AtomArrayStack): # Partially flatten the coords so that we remove any atom from all structures if any # one of them has a NaN at that position. 
@@ -586,9 +586,9 @@ def select_backbone(atom_array: AtomArray | AtomArrayStack) -> AtomArray | AtomA mask = np.isin(atom_name, backbone_atoms) if isinstance(atom_array, AtomArrayStack): - return atom_array[:, mask] # pyright: ignore[reportReturnType] + return atom_array[:, mask] else: - return atom_array[mask] # pyright: ignore[reportReturnType] + return atom_array[mask] def make_atom_id(arr: AtomArray | AtomArrayStack) -> np.ndarray: From a6213e3a610056409838fe8cb533b2b6e7ce06b3 Mon Sep 17 00:00:00 2001 From: Karson Chrispens Date: Thu, 5 Mar 2026 21:16:53 +0000 Subject: [PATCH 2/4] fix(rf3): fixed bug where nan atoms were still appearing in the model_atom_array, causing indexing errors downstream --- src/sampleworks/models/rf3/wrapper.py | 30 ++++++++++++++++++++++++--- 1 file changed, 27 insertions(+), 3 deletions(-) diff --git a/src/sampleworks/models/rf3/wrapper.py b/src/sampleworks/models/rf3/wrapper.py index 23e4643a..8635b813 100644 --- a/src/sampleworks/models/rf3/wrapper.py +++ b/src/sampleworks/models/rf3/wrapper.py @@ -289,7 +289,7 @@ def featurize(self, structure: dict) -> GenerativeModelInput[RF3Conditioning]: ) # since we're not batching, the loader returns a list of length 1 # (Hydra instantiation of pipeline means it is going to be hard to type check here) - pipeline_output = self.inference_engine.pipeline(input_spec.to_pipeline_input()) # type: ignore + pipeline_output = self.inference_engine.pipeline(input_spec.to_pipeline_input()) # ty: ignore[call-non-callable] pipeline_output = trainer.fabric.to_device(pipeline_output) features = trainer._assemble_network_inputs(pipeline_output) @@ -317,8 +317,32 @@ def featurize(self, structure: dict) -> GenerativeModelInput[RF3Conditioning]: # Any excess atoms are trailing entries not represented in atom_to_token_map. 
if len(model_aa) > num_atoms: model_aa = cast(AtomArray, model_aa[:num_atoms]) - if not hasattr(model_aa, "occupancy") or model_aa.occupancy is None: - model_aa.set_annotation("occupancy", np.ones(len(model_aa), dtype=np.float32)) + + # atomworks's add_missing_atoms adds unresolved atoms with + # occupancy=0.0 and NaN coordinates when we get our atom array with + # InferenceInput.from_atom_array. RF3 operates on these atoms (they're + # in atom_to_token_map), so initialize their coordinates with noise and + # set occupancy to 1.0 so they participate in guidance and don't get + # masked out in reward functions. + nan_coord_mask = np.any(np.isnan(model_aa.coord), axis=-1) + if nan_coord_mask.any(): + n_nan = int(nan_coord_mask.sum()) + resolved_coords = model_aa.coord[~nan_coord_mask] + centroid = resolved_coords.mean(axis=0) if len(resolved_coords) > 0 else np.zeros(3) + rng = np.random.default_rng(42) + noise = rng.normal(scale=1.0, size=(n_nan, 3)).astype(np.float32) + new_coords = model_aa.coord.copy() + new_coords[nan_coord_mask] = centroid + noise + model_aa.coord = new_coords + logger.info( + f"Initialized {n_nan} unresolved atoms with noise " + f"(had NaN coordinates from add_missing_atoms)" + ) + + # All atoms in model_aa are operated on by RF3 during diffusion. + # Set occupancy to 1.0 regardless of what atomworks assigned (unresolved + # atoms from add_missing_atoms get occupancy=0.0, but RF3 should use them) + model_aa.set_annotation("occupancy", np.ones(len(model_aa), dtype=np.float32)) if not hasattr(model_aa, "b_factor") or model_aa.b_factor is None: model_aa.set_annotation("b_factor", np.full(len(model_aa), 20.0, dtype=np.float32)) From eaa92c1c4ca71c9fea778012cc33ae7984990840 Mon Sep 17 00:00:00 2001 From: Karson Chrispens Date: Thu, 5 Mar 2026 21:31:18 +0000 Subject: [PATCH 3/4] fix(reward): remove usage of reward_param_mask, we enforce that the reward must be computable on the whole model output. 
This is reasonable because the reconciler takes care of mismatches now, so we can now move the responsibility of ensuring the atoms are all valid for the model to the ModelWrapper itself. --- src/sampleworks/core/rewards/protocol.py | 42 +++++++----- src/sampleworks/core/samplers/edm.py | 2 +- src/sampleworks/eval/structure_utils.py | 32 ++++----- .../utils/guidance_script_utils.py | 65 ++++++------------- .../integration/test_mismatch_integration.py | 6 +- .../integration/test_pipeline_integration.py | 59 +++++++++++++++-- tests/utils/test_density_utils.py | 2 +- 7 files changed, 114 insertions(+), 94 deletions(-) diff --git a/src/sampleworks/core/rewards/protocol.py b/src/sampleworks/core/rewards/protocol.py index 9e794eda..26d3c7aa 100644 --- a/src/sampleworks/core/rewards/protocol.py +++ b/src/sampleworks/core/rewards/protocol.py @@ -23,14 +23,17 @@ class RewardInputs: Contains all the information needed to call a RewardFunctionProtocol, extracted from an atom array. This allows the caller to extract inputs once and pass them to scale() methods without redundant extraction. + + The atom array passed to :meth:`from_atom_array` must already be clean: + all coordinates finite and all occupancies positive. Wrappers are + responsible for ensuring this (e.g. replacing NaN coordinates with + noise and setting occupancy to 1.0 for model-operated atoms). """ elements: Float[torch.Tensor, "*batch n_atoms"] b_factors: Float[torch.Tensor, "*batch n_atoms"] occupancies: Float[torch.Tensor, "*batch n_atoms"] input_coords: Float[torch.Tensor, "*batch n_atoms 3"] - reward_param_mask: np.ndarray - mask_like: Float[torch.Tensor, "*batch n_atoms"] @classmethod def from_atom_array( @@ -42,10 +45,15 @@ def from_atom_array( ) -> RewardInputs: """Construct RewardInputs from a Biotite AtomArray. + The atom array must contain only valid atoms (finite coordinates, + positive occupancy). Callers are responsible for filtering + beforehand; no masking is applied here. 
+ Parameters ---------- atom_array Biotite AtomArray or AtomArrayStack containing structure data. + Must not have NaN coordinates and must have positive occupancy. ensemble_size Number of ensemble members (batch dimension). num_particles @@ -58,13 +66,17 @@ def from_atom_array( RewardInputs Dataclass containing all inputs needed for reward function computation. """ - occupancy_mask = atom_array.occupancy > 0 - nan_mask = ~np.any(np.isnan(atom_array.coord), axis=-1) - reward_param_mask = occupancy_mask & nan_mask - - elements_list = [ ELEMENT_TO_ATOMIC_NUM[e.title()] for e in atom_array.element[reward_param_mask] - ] + # input validation: ensure atom_array has required annotations and valid values + if not hasattr(atom_array, "element"): + raise ValueError("Atom array must have 'element' annotation.") + if not hasattr(atom_array, "b_factor"): + raise ValueError("Atom array must have 'b_factor' annotation.") + if np.any(np.isnan(atom_array.coord)): + raise ValueError("Atom array contains NaN coordinates.") + if np.any((atom_array.occupancy <= 0) | (atom_array.occupancy > 1)): + raise ValueError("Atom array contains invalid occupancy values.") + + elements_list = [ELEMENT_TO_ATOMIC_NUM[e.title()] for e in atom_array.element] total_batch_size = num_particles * ensemble_size if num_particles > 1 else ensemble_size @@ -80,7 +92,7 @@ def from_atom_array( ) b_factors = einx.rearrange( "n -> p e n", - torch.Tensor(atom_array.b_factor[reward_param_mask]), + torch.Tensor(atom_array.b_factor), p=num_particles, e=ensemble_size, ) @@ -89,12 +101,12 @@ def from_atom_array( "... 
-> b ...", coords_t, b=total_batch_size, - )[..., reward_param_mask, :] + ) else: elements = einx.rearrange("n -> b n", torch.Tensor(elements_list), b=ensemble_size) b_factors = einx.rearrange( "n -> b n", - torch.Tensor(atom_array.b_factor[reward_param_mask]), + torch.Tensor(atom_array.b_factor), b=ensemble_size, ) occupancies = torch.ones_like(b_factors) / ensemble_size @@ -102,9 +114,7 @@ def from_atom_array( "... -> e ...", coords_t, e=ensemble_size, - )[..., reward_param_mask, :] - - mask_like = torch.ones_like(input_coords[..., 0]) + ) if isinstance(device, str): device = torch.device(device) @@ -114,8 +124,6 @@ def from_atom_array( b_factors=b_factors.to(device), occupancies=occupancies.to(device), input_coords=input_coords.to(device), - reward_param_mask=reward_param_mask, - mask_like=mask_like.to(device), ) diff --git a/src/sampleworks/core/samplers/edm.py b/src/sampleworks/core/samplers/edm.py index 82b88023..cdf43672 100644 --- a/src/sampleworks/core/samplers/edm.py +++ b/src/sampleworks/core/samplers/edm.py @@ -526,7 +526,7 @@ def step( ) / (2 * noise_var) # Euler step: x_{t-1} = x_t + step_scale * dt * delta - # pyright sees dt as float | None, but it will be float if check_context didn't raise + # ty sees dt as float | None, but it will be float if check_context didn't raise next_state = noisy_state_working_frame_t + self.step_scale * dt * delta # ty: ignore[unsupported-operator] return SamplerStepOutput( diff --git a/src/sampleworks/eval/structure_utils.py b/src/sampleworks/eval/structure_utils.py index b5941fed..18fc1496 100644 --- a/src/sampleworks/eval/structure_utils.py +++ b/src/sampleworks/eval/structure_utils.py @@ -16,6 +16,7 @@ from sampleworks.utils.atom_reconciler import AtomReconciler from sampleworks.utils.framework_utils import match_batch + ATOMWORKS_COMPARISON_OPS = ("==", ">", "<", "<=", ">=", " in ") try: @@ -90,19 +91,13 @@ def to_reward_inputs(self, device: torch.device | str = "cpu") -> RewardInputs: model_indices = 
self.reconciler.model_indices.to(device=reward_inputs.b_factors.device) updated_b_factors[..., model_indices] = struct_b_factors - # input_coords stored on the processed structure are always full model atom - # references, apply reward_param_mask for reward calls. - masked_input_coords = self.input_coords[..., reward_inputs.reward_param_mask, :].to( - device=reward_inputs.input_coords.device, - dtype=reward_inputs.input_coords.dtype, - ) - updated_mask_like = torch.ones_like(masked_input_coords[..., 0]) - return replace( reward_inputs, b_factors=updated_b_factors, - input_coords=masked_input_coords, - mask_like=updated_mask_like, + input_coords=self.input_coords.to( + device=reward_inputs.input_coords.device, + dtype=reward_inputs.input_coords.dtype, + ), ) @@ -174,10 +169,13 @@ def process_structure_to_trajectory_input( atom_array = _filter_zero_occupancy(atom_array) atom_array = _add_terminal_oxt_atoms(atom_array, structure.get("chain_info", {})) - # Mask to valid atoms (nonzero occupancy, no NaN coords) in structure atom space - reward_param_mask = atom_array.occupancy > 0 - reward_param_mask &= ~np.any(np.isnan(atom_array.coord), axis=-1) - atom_array = atom_array[reward_param_mask] + # Filter to valid atoms (nonzero occupancy, finite coords) in structure atom space. + # The deposited structure may have zero-occupancy or NaN-coordinate atoms from + # unresolved regions or altloc processing; these must be removed before + # building the reconciler against the model's (already-clean) atom array. + valid_atom_mask = atom_array.occupancy > 0 + valid_atom_mask &= ~np.any(np.isnan(atom_array.coord), axis=-1) + atom_array = atom_array[valid_atom_mask] # Build reconciler from model and structure atom arrays. model_atom_array = ( @@ -406,13 +404,11 @@ def extract_selection_coordinates( @overload def get_asym_unit_from_structure( structure: dict, atom_array_index: None = None -) -> AtomArrayStack: ... +) -> AtomArrayStack: ... 
@overload -def get_asym_unit_from_structure( - structure: dict, atom_array_index: int -) -> AtomArray: ... +def get_asym_unit_from_structure(structure: dict, atom_array_index: int) -> AtomArray: ... def get_asym_unit_from_structure( diff --git a/src/sampleworks/utils/guidance_script_utils.py b/src/sampleworks/utils/guidance_script_utils.py index 71e9c57a..ac1eba68 100644 --- a/src/sampleworks/utils/guidance_script_utils.py +++ b/src/sampleworks/utils/guidance_script_utils.py @@ -62,18 +62,13 @@ def save_trajectory( trajectory, atom_array, output_dir, - reward_param_mask, subdir_name, save_every=10, ): if scaler_type == GuidanceType.PURE_GUIDANCE: - _save_trajectory( - trajectory, atom_array, output_dir, reward_param_mask, subdir_name, save_every - ) + _save_trajectory(trajectory, atom_array, output_dir, subdir_name, save_every) elif scaler_type == GuidanceType.FK_STEERING: - _save_fk_steering_trajectory( - trajectory, atom_array, output_dir, reward_param_mask, subdir_name, save_every - ) + _save_fk_steering_trajectory(trajectory, atom_array, output_dir, subdir_name, save_every) else: # we shouldn't ever get here, since we can't have run guidance w/o this! raise ValueError(f"Invalid scaler type: {scaler_type}") @@ -81,33 +76,24 @@ def save_trajectory( def _write_coords_into_array( array_copy: AtomArrayStack, coords: np.ndarray, - reward_param_mask: np.ndarray, ) -> None: """**Mutates** ``array_copy.coord`` in-place with trajectory coordinates. - When the trajectory spans all atoms in the array (model trajectories during - a mismatch run, where the model's internal atom count differs from the input structure we are - aligning to), coords are assigned directly to ``.coord``. Otherwise the - ``reward_param_mask`` indexes the correct atom subset. + Coordinates must span all atoms in the array. Wrappers are responsible + for producing model atom arrays with valid coordinates for every atom. 
""" - n_atoms_array = array_copy.coord.shape[-2] # pyright: ignore[reportOptionalMemberAccess] + n_atoms_array = array_copy.coord.shape[-2] n_atoms_coords = coords.shape[-2] - if n_atoms_coords == n_atoms_array: - array_copy.coord = coords - elif n_atoms_coords == int(reward_param_mask.sum()): - array_copy.coord[:, reward_param_mask] = coords # pyright: ignore[reportOptionalSubscript] - else: + if n_atoms_coords != n_atoms_array: raise ValueError( - f"Trajectory coords ({n_atoms_coords} atoms) match neither " - f"the full atom array ({n_atoms_array}) nor the masked subset " - f"({int(reward_param_mask.sum())})" + f"Trajectory coords ({n_atoms_coords} atoms) don't match " + f"atom array ({n_atoms_array} atoms)" ) + array_copy.coord = coords -def _save_trajectory( - trajectory, atom_array, output_dir, reward_param_mask, subdir_name, save_every -): +def _save_trajectory(trajectory, atom_array, output_dir, subdir_name, save_every): output_dir = Path(output_dir / "trajectory" / subdir_name) output_dir.mkdir(parents=True, exist_ok=True) @@ -125,13 +111,11 @@ def _save_trajectory( continue array_copy = atom_array.copy() array_copy = stack([array_copy] * ensemble_size) - _write_coords_into_array(array_copy, coords.detach().numpy(), reward_param_mask) + _write_coords_into_array(array_copy, coords.detach().numpy()) save_structure(str(output_dir / f"trajectory_{i}.cif"), array_copy) -def _save_fk_steering_trajectory( - trajectory, atom_array, output_dir, reward_param_mask, subdir_name, save_every -): +def _save_fk_steering_trajectory(trajectory, atom_array, output_dir, subdir_name, save_every): output_dir = Path(output_dir / "trajectory" / subdir_name) output_dir.mkdir(parents=True, exist_ok=True) @@ -151,7 +135,7 @@ def _save_fk_steering_trajectory( array_copy = stack([array_copy] * ensemble_size) # we save only the first ensemble out of n_particles, since saving # each particle at every step would clog trajectory saving - _write_coords_into_array(array_copy, 
coords[0].detach().numpy(), reward_param_mask) + _write_coords_into_array(array_copy, coords[0].detach().numpy()) save_structure(str(output_dir / f"trajectory_{i}.cif"), array_copy) @@ -312,25 +296,16 @@ def save_everything( base_atom_array = ensure_atom_array_stack(refined_structure["asym_unit"])[0] # Use model's internal atom accounting template for mismatch runs when available - atom_array_for_masking: AtomArray = ( + # Wrappers must guarantee model atom arrays have valid coords and occupancy. + atom_array_for_saving: AtomArray = ( model_atom_array if model_atom_array is not None else base_atom_array ) - # Build occupancy mask properly as model atom arrays may lack occupancy annotation - if ( - hasattr(atom_array_for_masking, "occupancy") - and atom_array_for_masking.occupancy is not None - ): - occupancy_mask = atom_array_for_masking.occupancy > 0 - occupancy_mask &= ~np.any(np.isnan(atom_array_for_masking.coord), axis=-1) - else: - occupancy_mask = np.ones(len(atom_array_for_masking), dtype=bool) - if final_state is not None: ensemble_size = final_state.shape[0] - ensemble_array = stack([atom_array_for_masking.copy() for _ in range(ensemble_size)]) - _write_coords_into_array(ensemble_array, final_state.detach().cpu().numpy(), occupancy_mask) + ensemble_array = stack([atom_array_for_saving.copy() for _ in range(ensemble_size)]) + _write_coords_into_array(ensemble_array, final_state.detach().cpu().numpy()) atom_array = ensemble_array else: atom_array = base_atom_array @@ -343,18 +318,16 @@ def save_everything( save_trajectory( scaler_type, traj_denoised, # <--- the difference is here! - atom_array_for_masking, + atom_array_for_saving, output_dir, - occupancy_mask, "denoised", save_every=10, ) save_trajectory( scaler_type, traj_next_step, # <--- and here! 
- atom_array_for_masking, + atom_array_for_saving, output_dir, - occupancy_mask, "next_step", save_every=10, ) diff --git a/tests/integration/test_mismatch_integration.py b/tests/integration/test_mismatch_integration.py index dcf69b03..83f4b3de 100644 --- a/tests/integration/test_mismatch_integration.py +++ b/tests/integration/test_mismatch_integration.py @@ -275,8 +275,6 @@ def _model_space_reward_inputs(n_model: int, batch: int = 1) -> RewardInputs: b_factors=torch.ones(batch, n_model) * 20.0, occupancies=torch.ones(batch, n_model), input_coords=torch.randn(batch, n_model, 3), - reward_param_mask=np.ones(n_model, dtype=bool), - mask_like=torch.ones(batch, n_model), ) @@ -639,7 +637,7 @@ def test_reward_inputs_from_real_structure( assert reward_inputs.elements.shape[-1] == real_pair_case.expected_n_model assert reward_inputs.input_coords.shape[-2] == real_pair_case.expected_n_model - assert reward_inputs.mask_like.shape == reward_inputs.input_coords.shape[:-1] + assert reward_inputs.input_coords.shape[-2] == real_pair_case.expected_n_model atom_b_factors = cast(np.ndarray, processed.atom_array.b_factor) expected_common_b_factors = torch.as_tensor( @@ -980,7 +978,7 @@ def test_reward_inputs_atom_count(self, mismatch_case: MismatchCase): assert reward_inputs.elements.shape[-1] == n_model assert reward_inputs.b_factors.shape[-1] == n_model assert reward_inputs.input_coords.shape[-2] == n_model - assert reward_inputs.mask_like.shape[-1] == n_model + assert reward_inputs.input_coords.shape[-2] == n_model def test_b_factor_override(self, mismatch_case: MismatchCase): """Structure B-factors override model template values on common atoms.""" diff --git a/tests/integration/test_pipeline_integration.py b/tests/integration/test_pipeline_integration.py index 114987b9..ca1f5949 100644 --- a/tests/integration/test_pipeline_integration.py +++ b/tests/integration/test_pipeline_integration.py @@ -13,7 +13,6 @@ Tests are organized from fast (mock-based) to slow (real wrappers). 
""" -import numpy as np import pytest import torch from sampleworks.core.rewards.protocol import RewardInputs @@ -25,6 +24,7 @@ NoiseSpaceDPSScaler, NoScalingScaler, ) +from sampleworks.eval.structure_utils import process_structure_to_trajectory_input from sampleworks.utils.guidance_constants import ( StepScalers, StructurePredictor, @@ -94,8 +94,6 @@ def create_step_context_with_reward( b_factors=torch.ones(batch_size, num_atoms, device=device) * 20.0, occupancies=torch.ones(batch_size, num_atoms, device=device), input_coords=state.clone(), - reward_param_mask=np.ones(num_atoms, dtype=bool), - mask_like=torch.ones(batch_size, num_atoms, device=device), ) metadata = None @@ -324,8 +322,6 @@ def test_noise_space_dps_gradient_chain_rule(self, device: torch.device): b_factors=torch.ones(batch_size, num_atoms, device=device) * 20.0, occupancies=torch.ones(batch_size, num_atoms, device=device), input_coords=state.detach().clone(), - reward_param_mask=np.ones(num_atoms, dtype=bool), - mask_like=torch.ones(batch_size, num_atoms, device=device), ) context = StepParams( step_index=0, @@ -368,8 +364,6 @@ def test_scaler_scale_returns_correct_shapes( b_factors=torch.ones(batch_size, num_atoms, device=device) * 20.0, occupancies=torch.ones(batch_size, num_atoms, device=device), input_coords=state.detach().clone(), - reward_param_mask=np.ones(num_atoms, dtype=bool), - mask_like=torch.ones(batch_size, num_atoms, device=device), ) context = StepParams( step_index=0, @@ -1075,6 +1069,57 @@ def test_trajectory_scaler_returns_guidance_output( assert len(result.trajectory) == 3 +@pytest.mark.slow +@pytest.mark.parametrize("wrapper_type", get_slow_wrappers(), ids=lambda w: w.value) +@pytest.mark.parametrize("structure_fixture", STRUCTURES, ids=lambda s: s.replace("structure_", "")) +class TestRealWrapperPreprocessing: + """process_structure_to_trajectory_input feeds properly into to_reward_inputs with real wrappers + + Verifies that featurization and preprocessing produce 
reward_inputs whose + atom dimension matches the state dimension for every wrapper × structure + combination. This catches model-specific NaN coordinate or occupancy + issues (e.g., RF3 on 5I09) that only surface with real featurization. + """ + + def test_reward_inputs_dimensions_match_state( + self, + wrapper_type: StructurePredictor, + structure_fixture: str, + temp_output_dir, + request, + ): + """reward_inputs atom count equals the model state atom count. + + Wrappers must ensure model_atom_array has valid coordinates and + occupancy for all atoms. ``RewardInputs.from_atom_array`` uses + all atoms without masking, so any NaN or zero-occupancy atoms + would produce invalid reward tensors and a dimension mismatch + in the step scalers. + """ + wrapper = request.getfixturevalue(get_fixture_name_for_wrapper_type(wrapper_type)) + structure = request.getfixturevalue(structure_fixture) + device = wrapper.device if hasattr(wrapper, "device") else torch.device("cpu") + + annotated = annotate_structure_for_wrapper_type(wrapper_type, structure, temp_output_dir) + features = wrapper.featurize(annotated) + state = wrapper.initialize_from_prior(batch_size=1, features=features) + + processed = process_structure_to_trajectory_input( + structure=annotated, + coords_from_prior=state, + features=features, + ensemble_size=1, + ) + reward_inputs = processed.to_reward_inputs(device=device) + + n_state = state.shape[-2] + n_reward = reward_inputs.elements.shape[-1] + assert n_state == n_reward, ( + f"State atom count ({n_state}) != reward atom count ({n_reward}). 
" + f"wrapper={wrapper_type.value}, structure={structure_fixture}" + ) + + @pytest.mark.slow @pytest.mark.parametrize("wrapper_type", get_slow_wrappers(), ids=lambda w: w.value) class TestRealWrapperNumericalStability: diff --git a/tests/utils/test_density_utils.py b/tests/utils/test_density_utils.py index 0758b8f4..1a274e8a 100644 --- a/tests/utils/test_density_utils.py +++ b/tests/utils/test_density_utils.py @@ -371,7 +371,7 @@ def test_stack_coords_match_per_model( """Coordinates in each batch entry should match the corresponding model.""" coords, _, _, _ = extract_density_inputs_from_atomarray(simple_atom_array_stack, device) for i in range(simple_atom_array_stack.stack_depth()): - expected = torch.from_numpy(simple_atom_array_stack.coord[i].copy()).to( # pyright: ignore[reportOptionalSubscript] + expected = torch.from_numpy(simple_atom_array_stack.coord[i].copy()).to( device, dtype=torch.float32 ) torch.testing.assert_close(coords[i], expected) From 7c854b9d6553ee21864129c03971c407b3207fd3 Mon Sep 17 00:00:00 2001 From: Karson Chrispens Date: Mon, 9 Mar 2026 05:23:34 +0000 Subject: [PATCH 4/4] chore(cr): remove redundant test code and address final cr nits --- src/sampleworks/models/rf3/wrapper.py | 5 +- .../integration/test_mismatch_integration.py | 300 ------------------ .../integration/test_pipeline_integration.py | 12 +- 3 files changed, 12 insertions(+), 305 deletions(-) diff --git a/src/sampleworks/models/rf3/wrapper.py b/src/sampleworks/models/rf3/wrapper.py index 8635b813..048d86b2 100644 --- a/src/sampleworks/models/rf3/wrapper.py +++ b/src/sampleworks/models/rf3/wrapper.py @@ -326,11 +326,10 @@ def featurize(self, structure: dict) -> GenerativeModelInput[RF3Conditioning]: # masked out in reward functions. 
nan_coord_mask = np.any(np.isnan(model_aa.coord), axis=-1) if nan_coord_mask.any(): - n_nan = int(nan_coord_mask.sum()) resolved_coords = model_aa.coord[~nan_coord_mask] centroid = resolved_coords.mean(axis=0) if len(resolved_coords) > 0 else np.zeros(3) - rng = np.random.default_rng(42) - noise = rng.normal(scale=1.0, size=(n_nan, 3)).astype(np.float32) + n_nan = int(nan_coord_mask.sum()) + noise = np.random.normal(loc=0.0, scale=1.0, size=(n_nan, 3)).astype(np.float32) new_coords = model_aa.coord.copy() new_coords[nan_coord_mask] = centroid + noise model_aa.coord = new_coords diff --git a/tests/integration/test_mismatch_integration.py b/tests/integration/test_mismatch_integration.py index 83f4b3de..96ad205c 100644 --- a/tests/integration/test_mismatch_integration.py +++ b/tests/integration/test_mismatch_integration.py @@ -49,34 +49,6 @@ class StructurePreprocessExpectation: expected_hydrogen_atoms: int -@dataclass(frozen=True) -class RealPairExpectation: - """Reconciliation expectations for one real PDB/CIF pair. - - Parameters - ---------- - id - Case identifier used as pytest ID. - pdb_fixture - Fixture name for the full-structure PDB input. - cif_fixture - Fixture name for the density-input CIF model representation. - expected_n_model - Expected model atom count after filtering. - expected_n_struct - Expected structure atom count after filtering. - expected_n_common - Expected common atom count from :class:`sampleworks.utils.atom_reconciler.AtomReconciler`. - """ - - id: str - pdb_fixture: str - cif_fixture: str - expected_n_model: int - expected_n_struct: int - expected_n_common: int - - STRUCTURE_PREPROCESS_EXPECTATIONS: tuple[StructurePreprocessExpectation, ...] = ( StructurePreprocessExpectation( id="1vme_cif", @@ -141,42 +113,6 @@ class RealPairExpectation: ) -REAL_PAIR_EXPECTATIONS: tuple[RealPairExpectation, ...] 
= ( - RealPairExpectation( - id="2yl0", - pdb_fixture="structure_2yl0", - cif_fixture="structure_2yl0_density", - expected_n_model=955, - expected_n_struct=1727, - expected_n_common=677, - ), - RealPairExpectation( - id="5sop", - pdb_fixture="structure_5sop", - cif_fixture="structure_5sop_density", - expected_n_model=1264, - expected_n_struct=4873, - expected_n_common=1264, - ), - RealPairExpectation( - id="6ni6", - pdb_fixture="structure_6ni6", - cif_fixture="structure_6ni6_density", - expected_n_model=1678, - expected_n_struct=6812, - expected_n_common=1678, - ), - RealPairExpectation( - id="9bn8", - pdb_fixture="structure_9bn8", - cif_fixture="structure_9bn8_density", - expected_n_model=3260, - expected_n_struct=3330, - expected_n_common=3260, - ), -) - - ALL_MISMATCH_CASE_IDS: tuple[str, ...] = ( "identity_no_mismatch", "from_2yl0", @@ -455,44 +391,6 @@ def structure_preprocess_case(request: pytest.FixtureRequest) -> StructurePrepro return request.param -@pytest.fixture(params=REAL_PAIR_EXPECTATIONS, ids=lambda exp: exp.id) -def real_pair_case(request: pytest.FixtureRequest) -> RealPairExpectation: - """Parametrized real pair expectation fixture.""" - return request.param - - -def _get_real_pair_arrays( - request: pytest.FixtureRequest, - pair_case: RealPairExpectation, -) -> tuple[dict[str, Any], AtomArray, AtomArray]: - """Load and filter atom arrays for a real PDB/CIF pair. - - Parameters - ---------- - request - Pytest request object used to resolve fixtures by name. - pair_case - Real pair expectation. - - Returns - ------- - tuple[dict[str, Any], AtomArray, AtomArray] - ``(pdb_structure_dict, pdb_filtered, cif_filtered)``. 
- """ - pdb_structure = request.getfixturevalue(pair_case.pdb_fixture) - cif_structure = request.getfixturevalue(pair_case.cif_fixture) - - pdb_raw = cast(AtomArray, ensure_atom_array_stack(pdb_structure["asym_unit"])[0]) - pdb_mask = (pdb_raw.occupancy > 0) & ~np.any(np.isnan(pdb_raw.coord), axis=-1) - pdb_filtered = cast(AtomArray, pdb_raw[pdb_mask]) - - cif_raw = cast(AtomArray, ensure_atom_array_stack(cif_structure["asym_unit"])[0]) - cif_mask = (cif_raw.occupancy > 0) & ~np.any(np.isnan(cif_raw.coord), axis=-1) - cif_filtered = cast(AtomArray, cif_raw[cif_mask]) - - return pdb_structure, pdb_filtered, cif_filtered - - class TestRealStructurePreprocessing: """Real-data preprocessing and reconciliation tests.""" @@ -543,204 +441,6 @@ def test_hydrogen_handling( assert hydrogen_count == structure_preprocess_case.expected_hydrogen_atoms - def test_reconciler_from_real_pair( - self, - request: pytest.FixtureRequest, - real_pair_case: RealPairExpectation, - ): - """Reconciler counts on real PDB/CIF pairs match ground truth.""" - _, pdb_filtered, cif_filtered = _get_real_pair_arrays(request, real_pair_case) - reconciler = AtomReconciler.from_arrays(cif_filtered, pdb_filtered) - - assert reconciler.n_model == real_pair_case.expected_n_model - assert reconciler.n_struct == real_pair_case.expected_n_struct - assert reconciler.n_common == real_pair_case.expected_n_common - - def test_common_atom_identity_matches( - self, - request: pytest.FixtureRequest, - real_pair_case: RealPairExpectation, - ): - """Common indices map atoms with matching normalized identity.""" - _, pdb_filtered, cif_filtered = _get_real_pair_arrays(request, real_pair_case) - reconciler = AtomReconciler.from_arrays(cif_filtered, pdb_filtered) - - model_ids = make_normalized_atom_id(cif_filtered) - struct_ids = make_normalized_atom_id(pdb_filtered) - - np.testing.assert_array_equal( - model_ids[reconciler.model_indices.detach().cpu().numpy()], - 
struct_ids[reconciler.struct_indices.detach().cpu().numpy()], - ) - - def test_process_structure_to_trajectory_input_with_real_pdb( - self, - request: pytest.FixtureRequest, - real_pair_case: RealPairExpectation, - ): - """Full preprocessing with real PDB input returns consistent model space outputs.""" - pdb_structure, pdb_filtered, cif_filtered = _get_real_pair_arrays(request, real_pair_case) - - case = MismatchCase( - id=f"real_{real_pair_case.id}", - description=f"Real preprocessing case for {real_pair_case.id}.", - model_atom_array=cif_filtered, - struct_atom_array=pdb_filtered, - expected_n_common=real_pair_case.expected_n_common, - expected_has_mismatch=True, - ) - wrapper = MismatchCaseWrapper(case) - - processed, _ = _preprocess(wrapper, _copy_structure_dict(pdb_structure)) - reconciler = processed.reconciler - - atom_coords = cast(np.ndarray, processed.atom_array.coord) - atom_occupancy = cast(np.ndarray, processed.atom_array.occupancy) - - assert processed.input_coords.shape == (1, real_pair_case.expected_n_model, 3) - assert reconciler.has_mismatch - assert np.isfinite(atom_coords).all() - assert np.all(atom_occupancy > 0) - assert processed.model_atom_array is not None - assert len(processed.model_atom_array) == real_pair_case.expected_n_model - - expected_common_coords = torch.as_tensor( - atom_coords[reconciler.struct_indices.detach().cpu().numpy()], - dtype=processed.input_coords.dtype, - ) - torch.testing.assert_close( - processed.input_coords[0, reconciler.model_indices], - expected_common_coords, - ) - - def test_reward_inputs_from_real_structure( - self, - request: pytest.FixtureRequest, - real_pair_case: RealPairExpectation, - ): - """Reward inputs from real preprocessing have model space counts and B-factor overrides.""" - pdb_structure, pdb_filtered, cif_filtered = _get_real_pair_arrays(request, real_pair_case) - - case = MismatchCase( - id=f"real_{real_pair_case.id}", - description=f"Real reward input case for {real_pair_case.id}.", - 
model_atom_array=cif_filtered, - struct_atom_array=pdb_filtered, - expected_n_common=real_pair_case.expected_n_common, - expected_has_mismatch=True, - ) - wrapper = MismatchCaseWrapper(case) - - processed, _ = _preprocess(wrapper, _copy_structure_dict(pdb_structure)) - reconciler = processed.reconciler - reward_inputs = processed.to_reward_inputs(device="cpu") - - assert reward_inputs.elements.shape[-1] == real_pair_case.expected_n_model - assert reward_inputs.input_coords.shape[-2] == real_pair_case.expected_n_model - assert reward_inputs.input_coords.shape[-2] == real_pair_case.expected_n_model - - atom_b_factors = cast(np.ndarray, processed.atom_array.b_factor) - expected_common_b_factors = torch.as_tensor( - atom_b_factors[reconciler.struct_indices.detach().cpu().numpy()], - dtype=reward_inputs.b_factors.dtype, - ) - torch.testing.assert_close( - reward_inputs.b_factors[0, reconciler.model_indices], - expected_common_b_factors, - ) - - def test_process_structure_roundtrip_coordinate_integrity( - self, - request: pytest.FixtureRequest, - real_pair_case: RealPairExpectation, - ): - """Every common struct coordinate appears in the correct model index after preprocessing.""" - pdb_structure, pdb_filtered, cif_filtered = _get_real_pair_arrays(request, real_pair_case) - - case = MismatchCase( - id=f"real_{real_pair_case.id}", - description=f"Real coordinate integrity case for {real_pair_case.id}.", - model_atom_array=cif_filtered, - struct_atom_array=pdb_filtered, - expected_n_common=real_pair_case.expected_n_common, - expected_has_mismatch=True, - ) - wrapper = MismatchCaseWrapper(case) - - processed, _ = _preprocess(wrapper, _copy_structure_dict(pdb_structure)) - reconciler = processed.reconciler - - atom_coords = cast(np.ndarray, processed.atom_array.coord) - for model_idx, struct_idx in zip( - reconciler.model_indices.tolist(), reconciler.struct_indices.tolist() - ): - torch.testing.assert_close( - processed.input_coords[0, model_idx], - torch.as_tensor( - 
atom_coords[struct_idx], - dtype=processed.input_coords.dtype, - ), - ) - - -class TestRealStructureAlignmentQuality: - """Alignment tests on real geometry from filtered PDB/CIF pairs.""" - - def test_alignment_on_real_geometry( - self, - request: pytest.FixtureRequest, - real_pair_case: RealPairExpectation, - ): - """Known rigid transforms are recovered on realistic atom geometry.""" - _, pdb_filtered, cif_filtered = _get_real_pair_arrays(request, real_pair_case) - reconciler = AtomReconciler.from_arrays(cif_filtered, pdb_filtered) - - struct_coords = torch.as_tensor(pdb_filtered.coord, dtype=torch.float64).unsqueeze(0) - model_template = torch.as_tensor(cif_filtered.coord, dtype=torch.float64).unsqueeze(0) - model_reference = reconciler.struct_to_model(struct_coords, model_template) - - theta = np.deg2rad(45.0) - rotation = torch.tensor( - [ - [np.cos(theta), -np.sin(theta), 0.0], - [np.sin(theta), np.cos(theta), 0.0], - [0.0, 0.0, 1.0], - ], - dtype=torch.float64, - ) - translation = torch.tensor([10.0, -3.0, 5.0], dtype=torch.float64) - - transformed = model_reference @ rotation.T + translation - - aligned, _ = reconciler.align(transformed, model_reference) - - common = reconciler.model_indices - rmsd = _rmsd(aligned[..., common, :], model_reference[..., common, :]) - assert rmsd.item() < 1e-4 - - def test_alignment_with_realistic_perturbation( - self, - request: pytest.FixtureRequest, - real_pair_case: RealPairExpectation, - ): - """Alignment lowers RMSD when realistic Gaussian perturbations are present.""" - _, pdb_filtered, cif_filtered = _get_real_pair_arrays(request, real_pair_case) - reconciler = AtomReconciler.from_arrays(cif_filtered, pdb_filtered) - - struct_coords = torch.as_tensor(pdb_filtered.coord, dtype=torch.float32).unsqueeze(0) - model_template = torch.as_tensor(cif_filtered.coord, dtype=torch.float32).unsqueeze(0) - model_reference = reconciler.struct_to_model(struct_coords, model_template) - - torch.manual_seed(0) - perturbed = 
model_reference + torch.randn_like(model_reference) * 0.5 - - common = reconciler.model_indices - rmsd_before = _rmsd(perturbed[..., common, :], model_reference[..., common, :]) - aligned, _ = reconciler.align(perturbed, model_reference) - rmsd_after = _rmsd(aligned[..., common, :], model_reference[..., common, :]) - - assert rmsd_after.item() < rmsd_before.item() - class TestReconcilerConstruction: """Coordinate identity invariants for reconciler construction.""" diff --git a/tests/integration/test_pipeline_integration.py b/tests/integration/test_pipeline_integration.py index ca1f5949..57d4073c 100644 --- a/tests/integration/test_pipeline_integration.py +++ b/tests/integration/test_pipeline_integration.py @@ -1076,12 +1076,12 @@ class TestRealWrapperPreprocessing: """process_structure_to_trajectory_input feeds properly into to_reward_inputs with real wrappers Verifies that featurization and preprocessing produce reward_inputs whose - atom dimension matches the state dimension for every wrapper × structure + atom dimension matches the state dimension for every wrapper by structure combination. This catches model-specific NaN coordinate or occupancy issues (e.g., RF3 on 5I09) that only surface with real featurization. """ - def test_reward_inputs_dimensions_match_state( + def test_reward_inputs_are_valid( self, wrapper_type: StructurePredictor, structure_fixture: str, @@ -1118,6 +1118,14 @@ def test_reward_inputs_dimensions_match_state( f"State atom count ({n_state}) != reward atom count ({n_reward}). " f"wrapper={wrapper_type.value}, structure={structure_fixture}" ) + assert torch.isfinite(reward_inputs.input_coords).all(), ( + f"Non-finite reward coordinates for " + f"wrapper={wrapper_type.value}, structure={structure_fixture}" + ) + assert torch.all(reward_inputs.occupancies > 0), ( + f"Non-positive reward occupancies for " + f"wrapper={wrapper_type.value}, structure={structure_fixture}" + ) @pytest.mark.slow