-
Notifications
You must be signed in to change notification settings - Fork 4
fix(rf3): squashed RF3 bug that was blocking running on models with unresolved residues #140
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
4dcba73
a6213e3
eaa92c1
7c854b9
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -23,14 +23,17 @@ class RewardInputs: | |
| Contains all the information needed to call a RewardFunctionProtocol, | ||
| extracted from an atom array. This allows the caller to extract inputs | ||
| once and pass them to scale() methods without redundant extraction. | ||
|
|
||
| The atom array passed to :meth:`from_atom_array` must already be clean: | ||
| all coordinates finite and all occupancies positive. Wrappers are | ||
| responsible for ensuring this (e.g. replacing NaN coordinates with | ||
| noise and setting occupancy to 1.0 for model-operated atoms). | ||
| """ | ||
|
|
||
| elements: Float[torch.Tensor, "*batch n_atoms"] | ||
| b_factors: Float[torch.Tensor, "*batch n_atoms"] | ||
| occupancies: Float[torch.Tensor, "*batch n_atoms"] | ||
| input_coords: Float[torch.Tensor, "*batch n_atoms 3"] | ||
| reward_param_mask: np.ndarray | ||
| mask_like: Float[torch.Tensor, "*batch n_atoms"] | ||
|
|
||
| @classmethod | ||
| def from_atom_array( | ||
|
|
@@ -42,10 +45,15 @@ def from_atom_array( | |
| ) -> RewardInputs: | ||
| """Construct RewardInputs from a Biotite AtomArray. | ||
|
|
||
| The atom array must contain only valid atoms (finite coordinates, | ||
| positive occupancy). Callers are responsible for filtering | ||
| beforehand; no masking is applied here. | ||
|
|
||
| Parameters | ||
| ---------- | ||
| atom_array | ||
| Biotite AtomArray or AtomArrayStack containing structure data. | ||
| Must not have NaN coordinates and must have positive occupancy. | ||
| ensemble_size | ||
| Number of ensemble members (batch dimension). | ||
| num_particles | ||
|
|
@@ -58,13 +66,17 @@ def from_atom_array( | |
| RewardInputs | ||
| Dataclass containing all inputs needed for reward function computation. | ||
| """ | ||
| occupancy_mask = atom_array.occupancy > 0 | ||
| nan_mask = ~np.any(np.isnan(atom_array.coord), axis=-1) | ||
| reward_param_mask = occupancy_mask & nan_mask | ||
|
|
||
| elements_list = [ | ||
| ELEMENT_TO_ATOMIC_NUM[e.title()] for e in atom_array.element[reward_param_mask] | ||
| ] | ||
| # input validation: ensure atom_array has required annotations and valid values | ||
| if not hasattr(atom_array, "element"): | ||
| raise ValueError("Atom array must have 'element' annotation.") | ||
| if not hasattr(atom_array, "b_factor"): | ||
| raise ValueError("Atom array must have 'b_factor' annotation.") | ||
| if np.any(np.isnan(atom_array.coord)): | ||
| raise ValueError("Atom array contains NaN coordinates.") | ||
| if np.any((atom_array.occupancy <= 0) | (atom_array.occupancy > 1)): | ||
| raise ValueError("Atom array contains invalid occupancy values.") | ||
|
Comment on lines
+69
to
+77
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 🧩 Analysis chain🏁 Script executed: #!/bin/bash
set -euo pipefail
printf '\n=== to_reward_inputs implementation ===\n'
sed -n '1,220p' src/sampleworks/eval/structure_utils.py
printf '\n=== wrapper/model-atom-array sanitation sites ===\n'
rg -n -C3 'model_atom_array|occupanc|isnan|isfinite|filter_zero_occupancy|coord' \
src/sampleworks/models src/sampleworks/eval/structure_utils.pyRepository: diff-use/sampleworks Length of output: 50376 🏁 Script executed: #!/bin/bash
# Check Protenix wrapper for coordinate sanitization
rg -n -A10 'model_atom_array|model_aa\s*=|coord.*protenix' \
src/sampleworks/models/protenix/wrapper.py | head -100
# Also check if Protenix structure_processing has any sanitization
fd -e py protenix | xargs rg -l 'nan|isfinite|coord' | head -5Repository: diff-use/sampleworks Length of output: 2209 Move the NaN/occupancy validation to the final coordinate path, or require all wrappers to sanitize The validation at lines 69-77 checks 🧰 Tools🪛 Ruff (0.15.4)[warning] 71-71: Avoid specifying long messages outside the exception class (TRY003) [warning] 73-73: Avoid specifying long messages outside the exception class (TRY003) [warning] 75-75: Avoid specifying long messages outside the exception class (TRY003) [warning] 77-77: Avoid specifying long messages outside the exception class (TRY003) 🤖 Prompt for AI Agents |
||
|
|
||
| elements_list = [ELEMENT_TO_ATOMIC_NUM[e.title()] for e in atom_array.element] | ||
|
|
||
| total_batch_size = num_particles * ensemble_size if num_particles > 1 else ensemble_size | ||
|
|
||
|
|
@@ -80,7 +92,7 @@ def from_atom_array( | |
| ) | ||
| b_factors = einx.rearrange( | ||
| "n -> p e n", | ||
| torch.Tensor(atom_array.b_factor[reward_param_mask]), | ||
| torch.Tensor(atom_array.b_factor), | ||
| p=num_particles, | ||
| e=ensemble_size, | ||
| ) | ||
|
|
@@ -89,22 +101,20 @@ def from_atom_array( | |
| "... -> b ...", | ||
| coords_t, | ||
| b=total_batch_size, | ||
| )[..., reward_param_mask, :] | ||
| ) | ||
| else: | ||
| elements = einx.rearrange("n -> b n", torch.Tensor(elements_list), b=ensemble_size) | ||
| b_factors = einx.rearrange( | ||
| "n -> b n", | ||
| torch.Tensor(atom_array.b_factor[reward_param_mask]), | ||
| torch.Tensor(atom_array.b_factor), | ||
| b=ensemble_size, | ||
| ) | ||
| occupancies = torch.ones_like(b_factors) / ensemble_size | ||
| input_coords = einx.rearrange( | ||
| "... -> e ...", | ||
| coords_t, | ||
| e=ensemble_size, | ||
| )[..., reward_param_mask, :] | ||
|
|
||
| mask_like = torch.ones_like(input_coords[..., 0]) | ||
| ) | ||
|
|
||
| if isinstance(device, str): | ||
| device = torch.device(device) | ||
|
|
@@ -114,8 +124,6 @@ def from_atom_array( | |
| b_factors=b_factors.to(device), | ||
| occupancies=occupancies.to(device), | ||
| input_coords=input_coords.to(device), | ||
| reward_param_mask=reward_param_mask, | ||
| mask_like=mask_like.to(device), | ||
| ) | ||
|
|
||
|
|
||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.