diff --git a/src/sampleworks/core/rewards/protocol.py b/src/sampleworks/core/rewards/protocol.py index 9e794eda..26d3c7aa 100644 --- a/src/sampleworks/core/rewards/protocol.py +++ b/src/sampleworks/core/rewards/protocol.py @@ -23,14 +23,17 @@ class RewardInputs: Contains all the information needed to call a RewardFunctionProtocol, extracted from an atom array. This allows the caller to extract inputs once and pass them to scale() methods without redundant extraction. + + The atom array passed to :meth:`from_atom_array` must already be clean: + all coordinates finite and all occupancies positive. Wrappers are + responsible for ensuring this (e.g. replacing NaN coordinates with + noise and setting occupancy to 1.0 for model-operated atoms). """ elements: Float[torch.Tensor, "*batch n_atoms"] b_factors: Float[torch.Tensor, "*batch n_atoms"] occupancies: Float[torch.Tensor, "*batch n_atoms"] input_coords: Float[torch.Tensor, "*batch n_atoms 3"] - reward_param_mask: np.ndarray - mask_like: Float[torch.Tensor, "*batch n_atoms"] @classmethod def from_atom_array( @@ -42,10 +45,15 @@ def from_atom_array( ) -> RewardInputs: """Construct RewardInputs from a Biotite AtomArray. + The atom array must contain only valid atoms (finite coordinates, + positive occupancy). Callers are responsible for filtering + beforehand; no masking is applied here. + Parameters ---------- atom_array Biotite AtomArray or AtomArrayStack containing structure data. + Must have not NaN coordinates and positive occupancy. ensemble_size Number of ensemble members (batch dimension). num_particles @@ -58,13 +66,17 @@ def from_atom_array( RewardInputs Dataclass containing all inputs needed for reward function computation. """ - occupancy_mask = atom_array.occupancy > 0 - nan_mask = ~np.any(np.isnan(atom_array.coord), axis=-1) - reward_param_mask = occupancy_mask & nan_mask - - elements_list = [ - ELEMENT_TO_ATOMIC_NUM[e.title()] for e in atom_array.element[reward_param_mask] - ] + # input validation: ensure atom_array has required annotations and valid values + if not hasattr(atom_array, "element"): + raise ValueError("Atom array must have 'element' annotation.") + if not hasattr(atom_array, "b_factor"): + raise ValueError("Atom array must have 'b_factor' annotation.") + if np.any(np.isnan(atom_array.coord)): + raise ValueError("Atom array contains NaN coordinates.") + if np.any((atom_array.occupancy <= 0) | (atom_array.occupancy > 1)): + raise ValueError("Atom array contains invalid occupancy values.") + + elements_list = [ELEMENT_TO_ATOMIC_NUM[e.title()] for e in atom_array.element] total_batch_size = num_particles * ensemble_size if num_particles > 1 else ensemble_size @@ -80,7 +92,7 @@ def from_atom_array( ) b_factors = einx.rearrange( "n -> p e n", - torch.Tensor(atom_array.b_factor[reward_param_mask]), + torch.Tensor(atom_array.b_factor), p=num_particles, e=ensemble_size, ) @@ -89,12 +101,12 @@ def from_atom_array( "... -> b ...", coords_t, b=total_batch_size, - )[..., reward_param_mask, :] + ) else: elements = einx.rearrange("n -> b n", torch.Tensor(elements_list), b=ensemble_size) b_factors = einx.rearrange( "n -> b n", - torch.Tensor(atom_array.b_factor[reward_param_mask]), + torch.Tensor(atom_array.b_factor), b=ensemble_size, ) occupancies = torch.ones_like(b_factors) / ensemble_size @@ -102,9 +114,7 @@ def from_atom_array( "... -> e ...", coords_t, e=ensemble_size, - )[..., reward_param_mask, :] - - mask_like = torch.ones_like(input_coords[..., 0]) + ) if isinstance(device, str): device = torch.device(device) @@ -114,8 +124,6 @@ def from_atom_array( b_factors=b_factors.to(device), occupancies=occupancies.to(device), input_coords=input_coords.to(device), - reward_param_mask=reward_param_mask, - mask_like=mask_like.to(device), ) diff --git a/src/sampleworks/core/samplers/edm.py b/src/sampleworks/core/samplers/edm.py index 82b88023..cdf43672 100644 --- a/src/sampleworks/core/samplers/edm.py +++ b/src/sampleworks/core/samplers/edm.py @@ -526,7 +526,7 @@ def step( ) / (2 * noise_var) # Euler step: x_{t-1} = x_t + step_scale * dt * delta - # pyright sees dt as float | None, but it will be float if check_context didn't raise + # ty sees dt as float | None, but it will be float if check_context didn't raise next_state = noisy_state_working_frame_t + self.step_scale * dt * delta # ty: ignore[unsupported-operator] return SamplerStepOutput( diff --git a/src/sampleworks/eval/generate_synthetic_density.py b/src/sampleworks/eval/generate_synthetic_density.py index cfb27d26..693db47a 100644 --- a/src/sampleworks/eval/generate_synthetic_density.py +++ b/src/sampleworks/eval/generate_synthetic_density.py @@ -286,7 +286,7 @@ def _process_single_row( # TODO: there's probably a more robust way to do this atom_array = keep_polymer(keep_amino_acids(atom_array)) if strip_ligands else atom_array - altloc_info = detect_altlocs(atom_array) # pyright: ignore[reportArgumentType] + altloc_info = detect_altlocs(atom_array) # ty: ignore[invalid-argument-type] if row.occ_values: if occ_mode != "custom": logger.warning( @@ -324,7 +324,7 @@ def _process_single_row( # Shift coordinates into the grid frame so the saved CIF aligns with # the CCP4 map. CCP4 format (unlike MRC) cannot encode an arbitrary Cartesian # origin, so we move the atoms instead. Possible the better way is to resample the map? - atom_array.coord = atom_array.coord - xmap_torch.origin # pyright: ignore[reportOptionalOperand] + atom_array.coord = atom_array.coord - xmap_torch.origin structure_output_path = structure_path.parent / f"{structure_path.stem}_density_input.cif" try: save_structure_to_cif(atom_array, structure_output_path) diff --git a/src/sampleworks/eval/grid_search_eval_utils.py b/src/sampleworks/eval/grid_search_eval_utils.py index 1340b8e0..bcf957e0 100644 --- a/src/sampleworks/eval/grid_search_eval_utils.py +++ b/src/sampleworks/eval/grid_search_eval_utils.py @@ -103,7 +103,7 @@ def scan_grid_search_results( guidance_weight = float(params["guidance_weight"]) gd_steps = int(params["gd_steps"]) if params["gd_steps"] is not None else None - # Validate parameters to satisfy pyright + # Validate parameters to satisfy ty if ( protein is None or occ_a is None diff --git a/src/sampleworks/eval/structure_utils.py b/src/sampleworks/eval/structure_utils.py index b5941fed..18fc1496 100644 --- a/src/sampleworks/eval/structure_utils.py +++ b/src/sampleworks/eval/structure_utils.py @@ -16,6 +16,7 @@ from sampleworks.utils.atom_reconciler import AtomReconciler from sampleworks.utils.framework_utils import match_batch + ATOMWORKS_COMPARISON_OPS = ("==", ">", "<", "<=", ">=", " in ") try: @@ -90,19 +91,13 @@ def to_reward_inputs(self, device: torch.device | str = "cpu") -> RewardInputs: model_indices = self.reconciler.model_indices.to(device=reward_inputs.b_factors.device) updated_b_factors[..., model_indices] = struct_b_factors - # input_coords stored on the processed structure are always full model atom - # references, apply reward_param_mask for reward calls. - masked_input_coords = self.input_coords[..., reward_inputs.reward_param_mask, :].to( - device=reward_inputs.input_coords.device, - dtype=reward_inputs.input_coords.dtype, - ) - updated_mask_like = torch.ones_like(masked_input_coords[..., 0]) - return replace( reward_inputs, b_factors=updated_b_factors, - input_coords=masked_input_coords, - mask_like=updated_mask_like, + input_coords=self.input_coords.to( + device=reward_inputs.input_coords.device, + dtype=reward_inputs.input_coords.dtype, + ), ) @@ -174,10 +169,13 @@ def process_structure_to_trajectory_input( atom_array = _filter_zero_occupancy(atom_array) atom_array = _add_terminal_oxt_atoms(atom_array, structure.get("chain_info", {})) - # Mask to valid atoms (nonzero occupancy, no NaN coords) in structure atom space - reward_param_mask = atom_array.occupancy > 0 - reward_param_mask &= ~np.any(np.isnan(atom_array.coord), axis=-1) - atom_array = atom_array[reward_param_mask] + # Filter to valid atoms (nonzero occupancy, finite coords) in structure atom space. + # The deposited structure may have zero-occupancy or NaN-coordinate atoms from + # unresolved regions or altloc processing; these must be removed before + # building the reconciler against the model's (already-clean) atom array. + valid_atom_mask = atom_array.occupancy > 0 + valid_atom_mask &= ~np.any(np.isnan(atom_array.coord), axis=-1) + atom_array = atom_array[valid_atom_mask] # Build reconciler from model and structure atom arrays. model_atom_array = ( @@ -406,13 +404,11 @@ def extract_selection_coordinates( @overload def get_asym_unit_from_structure( structure: dict, atom_array_index: None = None -) -> AtomArrayStack: ... +) -> AtomArrayStack: ... @overload -def get_asym_unit_from_structure( - structure: dict, atom_array_index: int -) -> AtomArray: ... +def get_asym_unit_from_structure(structure: dict, atom_array_index: int) -> AtomArray: ... def get_asym_unit_from_structure( diff --git a/src/sampleworks/models/rf3/wrapper.py b/src/sampleworks/models/rf3/wrapper.py index 23e4643a..048d86b2 100644 --- a/src/sampleworks/models/rf3/wrapper.py +++ b/src/sampleworks/models/rf3/wrapper.py @@ -289,7 +289,7 @@ def featurize(self, structure: dict) -> GenerativeModelInput[RF3Conditioning]: ) # since we're not batching, the loader returns a list of length 1 # (Hydra instantiation of pipeline means it is going to be hard to type check here) - pipeline_output = self.inference_engine.pipeline(input_spec.to_pipeline_input()) # type: ignore + pipeline_output = self.inference_engine.pipeline(input_spec.to_pipeline_input()) # ty: ignore[call-non-callable] pipeline_output = trainer.fabric.to_device(pipeline_output) features = trainer._assemble_network_inputs(pipeline_output) @@ -317,8 +317,31 @@ def featurize(self, structure: dict) -> GenerativeModelInput[RF3Conditioning]: # Any excess atoms are trailing entries not represented in atom_to_token_map. if len(model_aa) > num_atoms: model_aa = cast(AtomArray, model_aa[:num_atoms]) - if not hasattr(model_aa, "occupancy") or model_aa.occupancy is None: - model_aa.set_annotation("occupancy", np.ones(len(model_aa), dtype=np.float32)) + + # atomworks's add_missing_atoms adds unresolved atoms with + # occupancy=0.0 and NaN coordinates when we get our atom array with + # InferenceInput.from_atom_array. RF3 operates on these atoms (they're + # in atom_to_token_map), so initialize their coordinates with noise and + # set occupancy to 1.0 so they participate in guidance and don't get + # masked out in reward functions. + nan_coord_mask = np.any(np.isnan(model_aa.coord), axis=-1) + if nan_coord_mask.any(): + resolved_coords = model_aa.coord[~nan_coord_mask] + centroid = resolved_coords.mean(axis=0) if len(resolved_coords) > 0 else np.zeros(3) + n_nan = int(nan_coord_mask.sum()) + noise = np.random.normal(loc=0.0, scale=1.0, size=(n_nan, 3)).astype(np.float32) + new_coords = model_aa.coord.copy() + new_coords[nan_coord_mask] = centroid + noise + model_aa.coord = new_coords + logger.info( + f"Initialized {n_nan} unresolved atoms with noise " + f"(had NaN coordinates from add_missing_atoms)" + ) + + # All atoms in model_aa are operated on by RF3 during diffusion. + # Set occupancy to 1.0 regardless of what atomworks assigned (unresolved + # atoms from add_missing_atoms get occupancy=0.0, but RF3 should use them) + model_aa.set_annotation("occupancy", np.ones(len(model_aa), dtype=np.float32)) if not hasattr(model_aa, "b_factor") or model_aa.b_factor is None: model_aa.set_annotation("b_factor", np.full(len(model_aa), 20.0, dtype=np.float32)) diff --git a/src/sampleworks/utils/atom_array_utils.py b/src/sampleworks/utils/atom_array_utils.py index dbbb7424..7d87cb94 100644 --- a/src/sampleworks/utils/atom_array_utils.py +++ b/src/sampleworks/utils/atom_array_utils.py @@ -372,9 +372,9 @@ def select_altloc( mask = atom_array.altloc_id == altloc_id if isinstance(atom_array, AtomArrayStack): - return atom_array[:, mask] # pyright: ignore[reportReturnType] + return atom_array[:, mask] else: - return atom_array[mask] # pyright: ignore[reportReturnType] + return atom_array[mask] def select_non_hetero(atom_array: AtomArray | AtomArrayStack) -> AtomArray | AtomArrayStack: @@ -410,9 +410,9 @@ def select_non_hetero(atom_array: AtomArray | AtomArrayStack) -> AtomArray | Ato mask = ~hetero if isinstance(atom_array, AtomArrayStack): - return atom_array[:, mask] # pyright: ignore[reportReturnType] + return atom_array[:, mask] else: - return atom_array[mask] # pyright: ignore[reportReturnType] + return atom_array[mask] def keep_polymer( @@ -449,10 +449,10 @@ def keep_polymer( # TODO: fix once this is fixed: https://github.com/biotite-dev/biotite/issues/865 if isinstance(atom_array, AtomArrayStack): polymer_mask = filter_polymer(atom_array[0], pol_type=pol_type) - return atom_array[:, polymer_mask] # pyright: ignore[reportReturnType] + return atom_array[:, polymer_mask] else: polymer_mask = filter_polymer(atom_array, pol_type=pol_type) - return atom_array[polymer_mask] # pyright: ignore[reportReturnType] + return atom_array[polymer_mask] def keep_amino_acids( @@ -485,9 +485,9 @@ def keep_amino_acids( amino_acid_mask = filter_amino_acids(atom_array) if isinstance(atom_array, AtomArrayStack): - return atom_array[:, amino_acid_mask] # pyright: ignore[reportReturnType] + return atom_array[:, amino_acid_mask] else: - return atom_array[amino_acid_mask] # pyright: ignore[reportReturnType] + return atom_array[amino_acid_mask] def remove_hydrogens(atom_array: AtomArray | AtomArrayStack) -> AtomArray | AtomArrayStack: @@ -521,9 +521,9 @@ def remove_hydrogens(atom_array: AtomArray | AtomArrayStack) -> AtomArray | Atom mask = (element != "H") & (element != "D") if isinstance(atom_array, AtomArrayStack): - return atom_array[:, mask] # pyright: ignore[reportReturnType] + return atom_array[:, mask] else: - return atom_array[mask] # pyright: ignore[reportReturnType] + return atom_array[mask] @overload @@ -542,7 +542,7 @@ def remove_atoms_with_any_nan_coords( """ if not atom_array or not atom_array.shape[-1]: raise ValueError("Cannot remove atoms from empty AtomArray|Stack") - + if isinstance(atom_array, AtomArrayStack): # Partially flatten the coords so that we remove any atom from all structures if any # one of them has a NaN at that position. @@ -586,9 +586,9 @@ def select_backbone(atom_array: AtomArray | AtomArrayStack) -> AtomArray | AtomA mask = np.isin(atom_name, backbone_atoms) if isinstance(atom_array, AtomArrayStack): - return atom_array[:, mask] # pyright: ignore[reportReturnType] + return atom_array[:, mask] else: - return atom_array[mask] # pyright: ignore[reportReturnType] + return atom_array[mask] def make_atom_id(arr: AtomArray | AtomArrayStack) -> np.ndarray: diff --git a/src/sampleworks/utils/guidance_script_utils.py b/src/sampleworks/utils/guidance_script_utils.py index 71e9c57a..ac1eba68 100644 --- a/src/sampleworks/utils/guidance_script_utils.py +++ b/src/sampleworks/utils/guidance_script_utils.py @@ -62,18 +62,13 @@ def save_trajectory( trajectory, atom_array, output_dir, - reward_param_mask, subdir_name, save_every=10, ): if scaler_type == GuidanceType.PURE_GUIDANCE: - _save_trajectory( - trajectory, atom_array, output_dir, reward_param_mask, subdir_name, save_every - ) + _save_trajectory(trajectory, atom_array, output_dir, subdir_name, save_every) elif scaler_type == GuidanceType.FK_STEERING: - _save_fk_steering_trajectory( - trajectory, atom_array, output_dir, reward_param_mask, subdir_name, save_every - ) + _save_fk_steering_trajectory(trajectory, atom_array, output_dir, subdir_name, save_every) else: # we shouldn't ever get here, since we can't have run guidance w/o this! raise ValueError(f"Invalid scaler type: {scaler_type}") @@ -81,33 +76,24 @@ def save_trajectory( def _write_coords_into_array( array_copy: AtomArrayStack, coords: np.ndarray, - reward_param_mask: np.ndarray, ) -> None: """**Mutates** ``array_copy.coord`` in-place with trajectory coordinates. - When the trajectory spans all atoms in the array (model trajectories during - a mismatch run, where the model's internal atom count differs from the input structure we are - aligning to), coords are assigned directly to ``.coord``. Otherwise the - ``reward_param_mask`` indexes the correct atom subset. + Coordinates must span all atoms in the array. Wrappers are responsible + for producing model atom arrays with valid coordinates for every atom. """ - n_atoms_array = array_copy.coord.shape[-2] # pyright: ignore[reportOptionalMemberAccess] + n_atoms_array = array_copy.coord.shape[-2] n_atoms_coords = coords.shape[-2] - if n_atoms_coords == n_atoms_array: - array_copy.coord = coords - elif n_atoms_coords == int(reward_param_mask.sum()): - array_copy.coord[:, reward_param_mask] = coords # pyright: ignore[reportOptionalSubscript] - else: + if n_atoms_coords != n_atoms_array: raise ValueError( - f"Trajectory coords ({n_atoms_coords} atoms) match neither " - f"the full atom array ({n_atoms_array}) nor the masked subset " - f"({int(reward_param_mask.sum())})" + f"Trajectory coords ({n_atoms_coords} atoms) don't match " + f"atom array ({n_atoms_array} atoms)" ) + array_copy.coord = coords -def _save_trajectory( - trajectory, atom_array, output_dir, reward_param_mask, subdir_name, save_every -): +def _save_trajectory(trajectory, atom_array, output_dir, subdir_name, save_every): output_dir = Path(output_dir / "trajectory" / subdir_name) output_dir.mkdir(parents=True, exist_ok=True) @@ -125,13 +111,11 @@ def _save_trajectory( continue array_copy = atom_array.copy() array_copy = stack([array_copy] * ensemble_size) - _write_coords_into_array(array_copy, coords.detach().numpy(), reward_param_mask) + _write_coords_into_array(array_copy, coords.detach().numpy()) save_structure(str(output_dir / f"trajectory_{i}.cif"), array_copy) -def _save_fk_steering_trajectory( - trajectory, atom_array, output_dir, reward_param_mask, subdir_name, save_every -): +def _save_fk_steering_trajectory(trajectory, atom_array, output_dir, subdir_name, save_every): output_dir = Path(output_dir / "trajectory" / subdir_name) output_dir.mkdir(parents=True, exist_ok=True) @@ -151,7 +135,7 @@ def _save_fk_steering_trajectory( array_copy = stack([array_copy] * ensemble_size) # we save only the first ensemble out of n_particles, since saving # each particle at every step would clog trajectory saving - _write_coords_into_array(array_copy, coords[0].detach().numpy(), reward_param_mask) + _write_coords_into_array(array_copy, coords[0].detach().numpy()) save_structure(str(output_dir / f"trajectory_{i}.cif"), array_copy) @@ -312,25 +296,16 @@ def save_everything( base_atom_array = ensure_atom_array_stack(refined_structure["asym_unit"])[0] # Use model's internal atom accounting template for mismatch runs when available - atom_array_for_masking: AtomArray = ( + # Wrappers must guarantee model atom arrays have valid coords and occupancy. + atom_array_for_saving: AtomArray = ( model_atom_array if model_atom_array is not None else base_atom_array ) - # Build occupancy mask properly as model atom arrays may lack occupancy annotation - if ( - hasattr(atom_array_for_masking, "occupancy") - and atom_array_for_masking.occupancy is not None - ): - occupancy_mask = atom_array_for_masking.occupancy > 0 - occupancy_mask &= ~np.any(np.isnan(atom_array_for_masking.coord), axis=-1) - else: - occupancy_mask = np.ones(len(atom_array_for_masking), dtype=bool) - if final_state is not None: ensemble_size = final_state.shape[0] - ensemble_array = stack([atom_array_for_masking.copy() for _ in range(ensemble_size)]) - _write_coords_into_array(ensemble_array, final_state.detach().cpu().numpy(), occupancy_mask) + ensemble_array = stack([atom_array_for_saving.copy() for _ in range(ensemble_size)]) + _write_coords_into_array(ensemble_array, final_state.detach().cpu().numpy()) atom_array = ensemble_array else: atom_array = base_atom_array @@ -343,18 +318,16 @@ def save_everything( save_trajectory( scaler_type, traj_denoised, # <--- the difference is here! - atom_array_for_masking, + atom_array_for_saving, output_dir, - occupancy_mask, "denoised", save_every=10, ) save_trajectory( scaler_type, traj_next_step, # <--- and here! - atom_array_for_masking, + atom_array_for_saving, output_dir, - occupancy_mask, "next_step", save_every=10, ) diff --git a/tests/integration/test_mismatch_integration.py b/tests/integration/test_mismatch_integration.py index dcf69b03..96ad205c 100644 --- a/tests/integration/test_mismatch_integration.py +++ b/tests/integration/test_mismatch_integration.py @@ -49,34 +49,6 @@ class StructurePreprocessExpectation: expected_hydrogen_atoms: int -@dataclass(frozen=True) -class RealPairExpectation: - """Reconciliation expectations for one real PDB/CIF pair. - - Parameters - ---------- - id - Case identifier used as pytest ID. - pdb_fixture - Fixture name for the full-structure PDB input. - cif_fixture - Fixture name for the density-input CIF model representation. - expected_n_model - Expected model atom count after filtering. - expected_n_struct - Expected structure atom count after filtering. - expected_n_common - Expected common atom count from :class:`sampleworks.utils.atom_reconciler.AtomReconciler`. - """ - - id: str - pdb_fixture: str - cif_fixture: str - expected_n_model: int - expected_n_struct: int - expected_n_common: int - - STRUCTURE_PREPROCESS_EXPECTATIONS: tuple[StructurePreprocessExpectation, ...] = ( StructurePreprocessExpectation( id="1vme_cif", @@ -141,42 +113,6 @@ class RealPairExpectation: ) -REAL_PAIR_EXPECTATIONS: tuple[RealPairExpectation, ...] = ( - RealPairExpectation( - id="2yl0", - pdb_fixture="structure_2yl0", - cif_fixture="structure_2yl0_density", - expected_n_model=955, - expected_n_struct=1727, - expected_n_common=677, - ), - RealPairExpectation( - id="5sop", - pdb_fixture="structure_5sop", - cif_fixture="structure_5sop_density", - expected_n_model=1264, - expected_n_struct=4873, - expected_n_common=1264, - ), - RealPairExpectation( - id="6ni6", - pdb_fixture="structure_6ni6", - cif_fixture="structure_6ni6_density", - expected_n_model=1678, - expected_n_struct=6812, - expected_n_common=1678, - ), - RealPairExpectation( - id="9bn8", - pdb_fixture="structure_9bn8", - cif_fixture="structure_9bn8_density", - expected_n_model=3260, - expected_n_struct=3330, - expected_n_common=3260, - ), -) - - ALL_MISMATCH_CASE_IDS: tuple[str, ...] = ( "identity_no_mismatch", "from_2yl0", @@ -275,8 +211,6 @@ def _model_space_reward_inputs(n_model: int, batch: int = 1) -> RewardInputs: b_factors=torch.ones(batch, n_model) * 20.0, occupancies=torch.ones(batch, n_model), input_coords=torch.randn(batch, n_model, 3), - reward_param_mask=np.ones(n_model, dtype=bool), - mask_like=torch.ones(batch, n_model), ) @@ -457,44 +391,6 @@ def structure_preprocess_case(request: pytest.FixtureRequest) -> StructurePrepro return request.param -@pytest.fixture(params=REAL_PAIR_EXPECTATIONS, ids=lambda exp: exp.id) -def real_pair_case(request: pytest.FixtureRequest) -> RealPairExpectation: - """Parametrized real pair expectation fixture.""" - return request.param - - -def _get_real_pair_arrays( - request: pytest.FixtureRequest, - pair_case: RealPairExpectation, -) -> tuple[dict[str, Any], AtomArray, AtomArray]: - """Load and filter atom arrays for a real PDB/CIF pair. - - Parameters - ---------- - request - Pytest request object used to resolve fixtures by name. - pair_case - Real pair expectation. - - Returns - ------- - tuple[dict[str, Any], AtomArray, AtomArray] - ``(pdb_structure_dict, pdb_filtered, cif_filtered)``. - """ - pdb_structure = request.getfixturevalue(pair_case.pdb_fixture) - cif_structure = request.getfixturevalue(pair_case.cif_fixture) - - pdb_raw = cast(AtomArray, ensure_atom_array_stack(pdb_structure["asym_unit"])[0]) - pdb_mask = (pdb_raw.occupancy > 0) & ~np.any(np.isnan(pdb_raw.coord), axis=-1) - pdb_filtered = cast(AtomArray, pdb_raw[pdb_mask]) - - cif_raw = cast(AtomArray, ensure_atom_array_stack(cif_structure["asym_unit"])[0]) - cif_mask = (cif_raw.occupancy > 0) & ~np.any(np.isnan(cif_raw.coord), axis=-1) - cif_filtered = cast(AtomArray, cif_raw[cif_mask]) - - return pdb_structure, pdb_filtered, cif_filtered - - class TestRealStructurePreprocessing: """Real-data preprocessing and reconciliation tests.""" @@ -545,204 +441,6 @@ def test_hydrogen_handling( assert hydrogen_count == structure_preprocess_case.expected_hydrogen_atoms - def test_reconciler_from_real_pair( - self, - request: pytest.FixtureRequest, - real_pair_case: RealPairExpectation, - ): - """Reconciler counts on real PDB/CIF pairs match ground truth.""" - _, pdb_filtered, cif_filtered = _get_real_pair_arrays(request, real_pair_case) - reconciler = AtomReconciler.from_arrays(cif_filtered, pdb_filtered) - - assert reconciler.n_model == real_pair_case.expected_n_model - assert reconciler.n_struct == real_pair_case.expected_n_struct - assert reconciler.n_common == real_pair_case.expected_n_common - - def test_common_atom_identity_matches( - self, - request: pytest.FixtureRequest, - real_pair_case: RealPairExpectation, - ): - """Common indices map atoms with matching normalized identity.""" - _, pdb_filtered, cif_filtered = _get_real_pair_arrays(request, real_pair_case) - reconciler = AtomReconciler.from_arrays(cif_filtered, pdb_filtered) - - model_ids = make_normalized_atom_id(cif_filtered) - struct_ids = make_normalized_atom_id(pdb_filtered) - - np.testing.assert_array_equal( - model_ids[reconciler.model_indices.detach().cpu().numpy()], - struct_ids[reconciler.struct_indices.detach().cpu().numpy()], - ) - - def test_process_structure_to_trajectory_input_with_real_pdb( - self, - request: pytest.FixtureRequest, - real_pair_case: RealPairExpectation, - ): - """Full preprocessing with real PDB input returns consistent model space outputs.""" - pdb_structure, pdb_filtered, cif_filtered = _get_real_pair_arrays(request, real_pair_case) - - case = MismatchCase( - id=f"real_{real_pair_case.id}", - description=f"Real preprocessing case for {real_pair_case.id}.", - model_atom_array=cif_filtered, - struct_atom_array=pdb_filtered, - expected_n_common=real_pair_case.expected_n_common, - expected_has_mismatch=True, - ) - wrapper = MismatchCaseWrapper(case) - - processed, _ = _preprocess(wrapper, _copy_structure_dict(pdb_structure)) - reconciler = processed.reconciler - - atom_coords = cast(np.ndarray, processed.atom_array.coord) - atom_occupancy = cast(np.ndarray, processed.atom_array.occupancy) - - assert processed.input_coords.shape == (1, real_pair_case.expected_n_model, 3) - assert reconciler.has_mismatch - assert np.isfinite(atom_coords).all() - assert np.all(atom_occupancy > 0) - assert processed.model_atom_array is not None - assert len(processed.model_atom_array) == real_pair_case.expected_n_model - - expected_common_coords = torch.as_tensor( - atom_coords[reconciler.struct_indices.detach().cpu().numpy()], - dtype=processed.input_coords.dtype, - ) - torch.testing.assert_close( - processed.input_coords[0, reconciler.model_indices], - expected_common_coords, - ) - - def test_reward_inputs_from_real_structure( - self, - request: pytest.FixtureRequest, - real_pair_case: RealPairExpectation, - ): - """Reward inputs from real preprocessing have model space counts and B-factor overrides.""" - pdb_structure, pdb_filtered, cif_filtered = _get_real_pair_arrays(request, real_pair_case) - - case = MismatchCase( - id=f"real_{real_pair_case.id}", - description=f"Real reward input case for {real_pair_case.id}.", - model_atom_array=cif_filtered, - struct_atom_array=pdb_filtered, - expected_n_common=real_pair_case.expected_n_common, - expected_has_mismatch=True, - ) - wrapper = MismatchCaseWrapper(case) - - processed, _ = _preprocess(wrapper, _copy_structure_dict(pdb_structure)) - reconciler = processed.reconciler - reward_inputs = processed.to_reward_inputs(device="cpu") - - assert reward_inputs.elements.shape[-1] == real_pair_case.expected_n_model - assert reward_inputs.input_coords.shape[-2] == real_pair_case.expected_n_model - assert reward_inputs.mask_like.shape == reward_inputs.input_coords.shape[:-1] - - atom_b_factors = cast(np.ndarray, processed.atom_array.b_factor) - expected_common_b_factors = torch.as_tensor( - atom_b_factors[reconciler.struct_indices.detach().cpu().numpy()], - dtype=reward_inputs.b_factors.dtype, - ) - torch.testing.assert_close( - reward_inputs.b_factors[0, reconciler.model_indices], - expected_common_b_factors, - ) - - def test_process_structure_roundtrip_coordinate_integrity( - self, - request: pytest.FixtureRequest, - real_pair_case: RealPairExpectation, - ): - """Every common struct coordinate appears in the correct model index after preprocessing.""" - pdb_structure, pdb_filtered, cif_filtered = _get_real_pair_arrays(request, real_pair_case) - - case = MismatchCase( - id=f"real_{real_pair_case.id}", - description=f"Real coordinate integrity case for {real_pair_case.id}.", - model_atom_array=cif_filtered, - struct_atom_array=pdb_filtered, - expected_n_common=real_pair_case.expected_n_common, - expected_has_mismatch=True, - ) - wrapper = MismatchCaseWrapper(case) - - processed, _ = _preprocess(wrapper, _copy_structure_dict(pdb_structure)) - reconciler = processed.reconciler - - atom_coords = cast(np.ndarray, processed.atom_array.coord) - for model_idx, struct_idx in zip( - reconciler.model_indices.tolist(), reconciler.struct_indices.tolist() - ): - torch.testing.assert_close( - processed.input_coords[0, model_idx], - torch.as_tensor( - atom_coords[struct_idx], - dtype=processed.input_coords.dtype, - ), - ) - - -class TestRealStructureAlignmentQuality: - """Alignment tests on real geometry from filtered PDB/CIF pairs.""" - - def test_alignment_on_real_geometry( - self, - request: pytest.FixtureRequest, - real_pair_case: RealPairExpectation, - ): - """Known rigid transforms are recovered on realistic atom geometry.""" - _, pdb_filtered, cif_filtered = _get_real_pair_arrays(request, real_pair_case) - reconciler = AtomReconciler.from_arrays(cif_filtered, pdb_filtered) - - struct_coords = torch.as_tensor(pdb_filtered.coord, dtype=torch.float64).unsqueeze(0) - model_template = torch.as_tensor(cif_filtered.coord, dtype=torch.float64).unsqueeze(0) - model_reference = reconciler.struct_to_model(struct_coords, model_template) - - theta = np.deg2rad(45.0) - rotation = torch.tensor( - [ - [np.cos(theta), -np.sin(theta), 0.0], - [np.sin(theta), np.cos(theta), 0.0], - [0.0, 0.0, 1.0], - ], - dtype=torch.float64, - ) - translation = torch.tensor([10.0, -3.0, 5.0], dtype=torch.float64) - - transformed = model_reference @ rotation.T + translation - - aligned, _ = reconciler.align(transformed, model_reference) - - common = reconciler.model_indices - rmsd = _rmsd(aligned[..., common, :], model_reference[..., common, :]) - assert rmsd.item() < 1e-4 - - def test_alignment_with_realistic_perturbation( - self, - request: pytest.FixtureRequest, - real_pair_case: RealPairExpectation, - ): - """Alignment lowers RMSD when realistic Gaussian perturbations are present.""" - _, pdb_filtered, cif_filtered = _get_real_pair_arrays(request, real_pair_case) - reconciler = AtomReconciler.from_arrays(cif_filtered, pdb_filtered) - - struct_coords = torch.as_tensor(pdb_filtered.coord, dtype=torch.float32).unsqueeze(0) - model_template = torch.as_tensor(cif_filtered.coord, dtype=torch.float32).unsqueeze(0) - model_reference = reconciler.struct_to_model(struct_coords, model_template) - - torch.manual_seed(0) - perturbed = model_reference + torch.randn_like(model_reference) * 0.5 - - common = reconciler.model_indices - rmsd_before = _rmsd(perturbed[..., common, :], model_reference[..., common, :]) - aligned, _ = reconciler.align(perturbed, model_reference) - rmsd_after = _rmsd(aligned[..., common, :], model_reference[..., common, :]) - - assert rmsd_after.item() < rmsd_before.item() - class TestReconcilerConstruction: """Coordinate identity invariants for reconciler construction.""" @@ -980,7 +678,7 @@ def test_reward_inputs_atom_count(self, mismatch_case: MismatchCase): assert reward_inputs.elements.shape[-1] == n_model assert reward_inputs.b_factors.shape[-1] == n_model assert reward_inputs.input_coords.shape[-2] == n_model - assert reward_inputs.mask_like.shape[-1] == n_model + assert reward_inputs.input_coords.shape[-2] == n_model def test_b_factor_override(self, mismatch_case: MismatchCase): """Structure B-factors override model template values on common atoms.""" diff --git a/tests/integration/test_pipeline_integration.py b/tests/integration/test_pipeline_integration.py index 114987b9..57d4073c 100644 --- a/tests/integration/test_pipeline_integration.py +++ b/tests/integration/test_pipeline_integration.py @@ -13,7 +13,6 @@ Tests are organized from fast (mock-based) to slow (real wrappers). """ -import numpy as np import pytest import torch from sampleworks.core.rewards.protocol import RewardInputs @@ -25,6 +24,7 @@ NoiseSpaceDPSScaler, NoScalingScaler, ) +from sampleworks.eval.structure_utils import process_structure_to_trajectory_input from sampleworks.utils.guidance_constants import ( StepScalers, StructurePredictor, @@ -94,8 +94,6 @@ def create_step_context_with_reward( b_factors=torch.ones(batch_size, num_atoms, device=device) * 20.0, occupancies=torch.ones(batch_size, num_atoms, device=device), input_coords=state.clone(), - reward_param_mask=np.ones(num_atoms, dtype=bool), - mask_like=torch.ones(batch_size, num_atoms, device=device), ) metadata = None @@ -324,8 +322,6 @@ def test_noise_space_dps_gradient_chain_rule(self, device: torch.device): b_factors=torch.ones(batch_size, num_atoms, device=device) * 20.0, occupancies=torch.ones(batch_size, num_atoms, device=device), input_coords=state.detach().clone(), - reward_param_mask=np.ones(num_atoms, dtype=bool), - mask_like=torch.ones(batch_size, num_atoms, device=device), ) context = StepParams( step_index=0, @@ -368,8 +364,6 @@ def test_scaler_scale_returns_correct_shapes( b_factors=torch.ones(batch_size, num_atoms, device=device) * 20.0, occupancies=torch.ones(batch_size, num_atoms, device=device), input_coords=state.detach().clone(), - reward_param_mask=np.ones(num_atoms, dtype=bool), - mask_like=torch.ones(batch_size, num_atoms, device=device), ) context = StepParams( step_index=0, @@ -1075,6 +1069,65 @@ def test_trajectory_scaler_returns_guidance_output( assert len(result.trajectory) == 3 +@pytest.mark.slow +@pytest.mark.parametrize("wrapper_type", get_slow_wrappers(), ids=lambda w: w.value) +@pytest.mark.parametrize("structure_fixture", STRUCTURES, ids=lambda s: s.replace("structure_", "")) +class TestRealWrapperPreprocessing: + """process_structure_to_trajectory_input feeds properly into to_reward_inputs with real wrappers + + Verifies that featurization and preprocessing produce reward_inputs whose + atom dimension matches the state dimension for every wrapper by structure + combination. This catches model-specific NaN coordinate or occupancy + issues (e.g., RF3 on 5I09) that only surface with real featurization. + """ + + def test_reward_inputs_are_valid( + self, + wrapper_type: StructurePredictor, + structure_fixture: str, + temp_output_dir, + request, + ): + """reward_inputs atom count equals the model state atom count. + + Wrappers must ensure model_atom_array has valid coordinates and + occupancy for all atoms. ``RewardInputs.from_atom_array`` uses + all atoms without masking, so any NaN or zero-occupancy atoms + would produce invalid reward tensors and a dimension mismatch + in the step scalers. + """ + wrapper = request.getfixturevalue(get_fixture_name_for_wrapper_type(wrapper_type)) + structure = request.getfixturevalue(structure_fixture) + device = wrapper.device if hasattr(wrapper, "device") else torch.device("cpu") + + annotated = annotate_structure_for_wrapper_type(wrapper_type, structure, temp_output_dir) + features = wrapper.featurize(annotated) + state = wrapper.initialize_from_prior(batch_size=1, features=features) + + processed = process_structure_to_trajectory_input( + structure=annotated, + coords_from_prior=state, + features=features, + ensemble_size=1, + ) + reward_inputs = processed.to_reward_inputs(device=device) + + n_state = state.shape[-2] + n_reward = reward_inputs.elements.shape[-1] + assert n_state == n_reward, ( + f"State atom count ({n_state}) != reward atom count ({n_reward}). " + f"wrapper={wrapper_type.value}, structure={structure_fixture}" + ) + assert torch.isfinite(reward_inputs.input_coords).all(), ( + f"Non-finite reward coordinates for " + f"wrapper={wrapper_type.value}, structure={structure_fixture}" + ) + assert torch.all(reward_inputs.occupancies > 0), ( + f"Non-positive reward occupancies for " + f"wrapper={wrapper_type.value}, structure={structure_fixture}" + ) + + @pytest.mark.slow @pytest.mark.parametrize("wrapper_type", get_slow_wrappers(), ids=lambda w: w.value) class TestRealWrapperNumericalStability: diff --git a/tests/utils/test_density_utils.py b/tests/utils/test_density_utils.py index 0758b8f4..1a274e8a 100644 --- a/tests/utils/test_density_utils.py +++ b/tests/utils/test_density_utils.py @@ -371,7 +371,7 @@ def test_stack_coords_match_per_model( """Coordinates in each batch entry should match the corresponding model.""" coords, _, _, _ = extract_density_inputs_from_atomarray(simple_atom_array_stack, device) for i in range(simple_atom_array_stack.stack_depth()): - expected = torch.from_numpy(simple_atom_array_stack.coord[i].copy()).to( # pyright: ignore[reportOptionalSubscript] + expected = torch.from_numpy(simple_atom_array_stack.coord[i].copy()).to( device, dtype=torch.float32 ) torch.testing.assert_close(coords[i], expected)