-
Notifications
You must be signed in to change notification settings - Fork 4
fix(rf3): fix issues with 5I09 due to chain breaks and add associated tests #190
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Large diffs are not rendered by default.
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,113 @@ | ||
| """No guidance geometry validation across all model wrappers and structures. | ||
|
|
||
| Verifies that running each model wrapper through the sampleworks sampling | ||
| pipeline without guidance produces structures with valid geometry. | ||
| """ | ||
|
|
||
| import pytest | ||
| import torch | ||
| from biotite.structure import filter_linear_bond_continuity, filter_peptide_backbone | ||
| from sampleworks.core.samplers.edm import AF3EDMSampler, EDMSamplerConfig | ||
| from sampleworks.utils.guidance_constants import StructurePredictor | ||
|
|
||
| from tests.conftest import ( | ||
| annotate_structure_for_wrapper_type, | ||
| get_fixture_name_for_wrapper_type, | ||
| get_slow_wrappers, | ||
| STRUCTURES, | ||
| ) | ||
|
|
||
|
|
||
| NO_GUIDANCE_STRUCTURES: list[str] = [*STRUCTURES, "structure_5i09_density"] | ||
|
|
||
|
|
||
| @pytest.mark.gpu | ||
| @pytest.mark.slow | ||
| @pytest.mark.parametrize("wrapper_type", get_slow_wrappers(), ids=lambda w: w.value) | ||
| @pytest.mark.parametrize( | ||
| "structure_fixture", | ||
| NO_GUIDANCE_STRUCTURES, | ||
| ids=lambda s: s.replace("structure_", ""), | ||
| ) | ||
| class TestNoGuidanceGeometry: | ||
| """Validate that no-guidance sampling produces valid peptide geometry. | ||
|
|
||
| Each test runs a short 20 step trajectory with no guidance and | ||
| checks that the resulting peptide bond lengths are physically reasonable. | ||
| """ | ||
|
|
||
| NUM_STEPS: int = 20 | ||
|
|
||
| def test_no_guidance_produces_valid_peptide_geometry( | ||
| self, | ||
| wrapper_type: StructurePredictor, | ||
| structure_fixture: str, | ||
| temp_output_dir, | ||
| request, | ||
| ): | ||
| """ | ||
| No guidance sampling should produce structures with valid C-N bonds. | ||
| """ | ||
| wrapper = request.getfixturevalue(get_fixture_name_for_wrapper_type(wrapper_type)) | ||
| structure = request.getfixturevalue(structure_fixture) | ||
| device = wrapper.device if hasattr(wrapper, "device") else torch.device("cpu") | ||
|
|
||
| annotated = annotate_structure_for_wrapper_type( | ||
|
k-chrispens marked this conversation as resolved.
|
||
| wrapper_type, structure, temp_output_dir, ensemble_size=1 | ||
| ) | ||
| features = wrapper.featurize(annotated) | ||
|
|
||
| sampler = AF3EDMSampler( | ||
| EDMSamplerConfig(device=device, augmentation=False, align_to_input=False) | ||
| ) # NOTE: augmentation and alignment might be useful to parametrize here? | ||
| schedule = sampler.compute_schedule(num_steps=self.NUM_STEPS) | ||
|
|
||
| torch.manual_seed(42) | ||
| state = wrapper.initialize_from_prior(batch_size=1, features=features) | ||
|
|
||
| for i in range(self.NUM_STEPS): | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Is there some way to run this without explicitly including this loop here? It would be better to test the actual code path we use during inference.
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is the inference code path, we do a for loop over the schedule (e.g. |
||
| context = sampler.get_context_for_step(i, schedule) | ||
| step_output = sampler.step( | ||
| state=state, | ||
| model_wrapper=wrapper, | ||
| context=context, | ||
| scaler=None, | ||
| features=features, | ||
| ) | ||
| state = step_output.state | ||
|
|
||
| cond = features.conditioning | ||
| model_aa = cond.model_atom_array | ||
| assert model_aa is not None, ( | ||
| f"model_atom_array is None for wrapper={wrapper_type.value}, " | ||
| f"structure={structure_fixture}" | ||
| ) | ||
|
|
||
| output_aa = model_aa.copy() | ||
| # state shape: (batch=1, atoms, 3) | ||
| output_aa.coord = state[0].detach().cpu().numpy() | ||
|
|
||
| # Filter_peptide_backbone selects N, CA, C atoms, filter_linear_bond_continuity returns a | ||
| # boolean mask where True means the distance to the next atom is within [min_len, max_len] | ||
| bb_mask = filter_peptide_backbone(output_aa) | ||
| bb = output_aa[bb_mask] | ||
| assert len(bb) >= 6, ( # at least 2 residues (N, CA, C each) | ||
| f"Too few backbone atoms ({len(bb)}) for " | ||
| f"wrapper={wrapper_type.value}, structure={structure_fixture}" | ||
| ) | ||
|
|
||
| # Generous bounds for a 20-step stochastic sample. | ||
| # Ideal backbone bonds are 1.33–1.52 Å | ||
| # We're a little more generous because of the short trajectory. | ||
| con_mask = filter_linear_bond_continuity(bb, min_len=1.1, max_len=1.7) | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Could we be more strict, say by looking at only the peptide bond?
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. You can get strict ranges from RDKit, for instance (see the bond length script)
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yeah, I just figure for speed of testing we might not expect all models to pass the most rigorous geometry checks (but 5I09 was failing this before)
k-chrispens marked this conversation as resolved.
|
||
| # Last element is always True (no next atom to compare), exclude it. | ||
| n_valid = int(con_mask[:-1].sum()) | ||
| n_total = len(con_mask) - 1 | ||
| fraction_valid = n_valid / n_total | ||
|
|
||
| assert fraction_valid >= 0.9, ( | ||
| f"Only {n_valid}/{n_total} ({fraction_valid:.1%}) backbone bonds " | ||
| "are within [1.1, 1.7] Å for " | ||
| f"wrapper={wrapper_type.value}, structure={structure_fixture}. " | ||
| "This suggests something is wrong (an atom ordering problem?)." | ||
| ) | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,105 @@ | ||
| """Regression tests for RF3 atom ordering consistency. | ||
|
|
||
| Verifies that ``model_atom_array`` built during ``RF3Wrapper.featurize()`` has | ||
| the same atom count and ordering as the model's internal feature tensors | ||
| (``atom_to_token_map``). | ||
| """ | ||
|
|
||
| import numpy as np | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This maybe doesn't need to be in here for purposes of this PR, but I feel like these tests are still not stringent enough. (I admit there are basically no tests yet for my CIF patching operations, which pose similar risks). We probably should actually check the alignment and make sure the atoms are correct and match each other.
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I'll add an issue for this |
||
| import pytest | ||
| from sampleworks.utils.imports import require_rf3, RF3_AVAILABLE | ||
|
|
||
|
|
||
| if RF3_AVAILABLE: | ||
| from sampleworks.models.rf3.wrapper import annotate_structure_for_rf3 | ||
|
|
||
|
|
||
| @pytest.mark.gpu | ||
| @pytest.mark.slow | ||
| class TestRF3AtomOrdering: | ||
| """Validate model_atom_array matches model feature dimensions.""" | ||
|
|
||
| @require_rf3() | ||
| @pytest.mark.parametrize( | ||
| "structure_fixture", | ||
| [ | ||
| "structure_5i09_density", | ||
| "structure_5sop_density", | ||
| "structure_9bn8_density", | ||
| "structure_6ni6_density", | ||
| ], | ||
| ) | ||
| def test_model_atom_array_matches_feature_count(self, rf3_wrapper, structure_fixture, request): | ||
| """model_atom_array atom count must match atom_to_token_map length exactly. | ||
|
|
||
| Regression test for 5I09 where OXT atoms at chain breaks were retained | ||
| in model_atom_array but absent from the model's | ||
| internal atom accounting, causing coordinate misalignment. | ||
| """ | ||
| structure = request.getfixturevalue(structure_fixture) | ||
| annotated = annotate_structure_for_rf3(structure, ensemble_size=1) | ||
| features = rf3_wrapper.featurize(annotated) | ||
| cond = features.conditioning | ||
|
|
||
| assert cond.model_atom_array is not None | ||
|
|
||
| num_feature_atoms = len(cond.features["atom_to_token_map"]) | ||
| assert len(cond.model_atom_array) == num_feature_atoms, ( | ||
| f"model_atom_array has {len(cond.model_atom_array)} atoms but " | ||
| f"atom_to_token_map has {num_feature_atoms}. Coordinate mapping " | ||
| "will be misaligned." | ||
| ) | ||
|
|
||
| @require_rf3() | ||
| @pytest.mark.parametrize( | ||
| "structure_fixture", | ||
| [ | ||
| "structure_5i09_density", | ||
| "structure_5sop_density", | ||
| "structure_9bn8_density", | ||
| "structure_6ni6_density", | ||
| ], | ||
| ) | ||
| def test_no_oxt_atoms_in_model_atom_array(self, rf3_wrapper, structure_fixture, request): | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Shouldn't there be one OXT per chain? Does RF3's pipeline actually remove the final terminal oxygen?
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. There should, but it is removed by most models (except protenix)
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. As a reference, I presume this is done to standardize all the residues such that a residue token corresponds to a specific set of atoms always. I think protenix does something a bit fancier as their outputs do have OXT, but I don't know the specifics (only that I had to deal with them requiring it) |
||
| """model_atom_array must not contain OXT atoms. | ||
|
|
||
| The pipeline's ``RemoveTerminalOxygen`` transform removes these atoms. | ||
| """ | ||
| structure = request.getfixturevalue(structure_fixture) | ||
| annotated = annotate_structure_for_rf3(structure, ensemble_size=1) | ||
| features = rf3_wrapper.featurize(annotated) | ||
| cond = features.conditioning | ||
|
|
||
| assert cond.model_atom_array is not None | ||
| oxt_count = int(np.sum(cond.model_atom_array.atom_name == "OXT")) | ||
| assert oxt_count == 0, ( | ||
| f"model_atom_array contains {oxt_count} OXT atoms that the pipeline " | ||
| "should have removed." | ||
| ) | ||
|
|
||
| @require_rf3() | ||
| @pytest.mark.parametrize( | ||
| "structure_fixture", | ||
| [ | ||
| "structure_5i09_density", | ||
| "structure_5sop_density", | ||
| "structure_9bn8_density", | ||
| "structure_6ni6_density", | ||
| ], | ||
| ) | ||
| def test_no_hydrogen_atoms_in_model_atom_array(self, rf3_wrapper, structure_fixture, request): | ||
| """model_atom_array must not contain hydrogen atoms. | ||
|
|
||
| The pipeline's ``RemoveHydrogens`` transform removes these. | ||
| """ | ||
| structure = request.getfixturevalue(structure_fixture) | ||
| annotated = annotate_structure_for_rf3(structure, ensemble_size=1) | ||
| features = rf3_wrapper.featurize(annotated) | ||
| cond = features.conditioning | ||
|
|
||
| assert cond.model_atom_array is not None | ||
| h_count = int(np.sum(cond.model_atom_array.element == "H")) | ||
| assert h_count == 0, ( | ||
| f"model_atom_array contains {h_count} hydrogen atoms that the pipeline " | ||
| "should have removed." | ||
| ) | ||
Uh oh!
There was an error while loading. Please reload this page.