Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
270 changes: 136 additions & 134 deletions pixi.lock

Large diffs are not rendered by default.

9 changes: 5 additions & 4 deletions src/sampleworks/core/samplers/edm.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,8 @@ class EDMSamplerConfig:
r"""Config for an EDM sampler.

Default values match the AF3 parameterization of the EDM framework
(Karras et al., 2022).
(Karras et al., 2022) except for ``gamma_min``, which for now is set to 0.2
instead of 1.0.

Parameters
----------
Expand Down Expand Up @@ -113,7 +114,7 @@ class EDMSamplerConfig:
s_max: float = 160.0
s_min: float = 4e-4
p: float = 7.0
gamma_min: float = 0.2
gamma_min: float = 0.2 # this is not the default value from AF3! AF3 uses 1.0
Comment thread
k-chrispens marked this conversation as resolved.
gamma_0: float = 0.8
noise_scale: float = 1.003
step_scale: float = 1.5
Expand All @@ -139,7 +140,7 @@ class AF3EDMSampler:

Initialized with a single :class:`EDMSamplerConfig` object that holds all
schedule hyperparameters and runtime options. Default values in the config
match the AF3 parameterization.
match the AF3 parameterization, except for ``gamma_min`` which is set to 0.2 instead of 1.0.

This sampler implements the EDM (Karras et al.) style sampling
approach as used in AlphaFold3 and related models, which is the Euler
Expand Down Expand Up @@ -521,4 +522,4 @@ def step(
denoised=x_hat_0_working_frame_t,
loss=loss,
log_proposal_correction=log_proposal_correction,
)
)
24 changes: 16 additions & 8 deletions src/sampleworks/models/rf3/wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -348,14 +348,22 @@ def featurize(self, structure: dict) -> GenerativeModelInput[RF3Conditioning]:

num_atoms = len(pairformer_out["features"]["atom_to_token_map"])

# Build model atom array from non-hydrogen InferenceInput atoms
model_aa = cast(
AtomArray, inference_input.atom_array[inference_input.atom_array.element != "H"]
)
# RF3 feature assembly preserves inference_input atom order after hydrogen filtering.
# Any excess atoms are trailing entries not represented in atom_to_token_map.
if len(model_aa) > num_atoms:
model_aa = cast(AtomArray, model_aa[:num_atoms])
# Use the pipeline output array that RF3's native inference
# uses (rf3/inference_engines/rf3.py line 594). The pipeline automatically
# removes atoms (H, OXT, etc.) that the model doesn't operate on, so we
# use it as the "model_atom_array": the set of atoms that the model
# operates on during sampling.
if "atom_array" not in pipeline_output:
raise ValueError(
"pipeline_output is missing 'atom_array' key, cannot determine model_atom_array"
Comment thread
marcuscollins marked this conversation as resolved.
)
model_aa = pipeline_output["atom_array"].copy()
if len(model_aa) != num_atoms:
raise ValueError(
f"model_atom_array has {len(model_aa)} atoms but the model's "
f"atom_to_token_map has {num_atoms} entries. These must match exactly "
"for correct coordinate-to-atom mapping."
Comment thread
k-chrispens marked this conversation as resolved.
)

# atomworks's add_missing_atoms adds unresolved atoms with
# occupancy=0.0 and NaN coordinates when we get our atom array with
Expand Down
5 changes: 5 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -482,6 +482,11 @@ def structure_9bn8_density(resources_dir: Path) -> dict:
return parse(resources_dir / "9BN8" / "9BN8_single_001_density_input.cif", ccd_mirror_path=None)


@pytest.fixture(scope="session")
def structure_5i09_density(resources_dir: Path) -> dict:
    """Session-scoped fixture: the 5I09 density-input CIF from the test resources, parsed."""
    cif_path = resources_dir / "5I09" / "5I09_single_001_density_input.cif"
    return parse(cif_path, ccd_mirror_path=None)


@pytest.fixture(scope="session")
def structure_6b8x_with_altlocs(resources_dir: Path) -> AtomArray | AtomArrayStack:
return load_any(
Expand Down
113 changes: 113 additions & 0 deletions tests/integration/test_no_guidance_geometry.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
"""No guidance geometry validation across all model wrappers and structures.

Verifies that running each model wrapper through the sampleworks sampling
pipeline without guidance produces structures with valid geometry.
"""

import pytest
import torch
from biotite.structure import filter_linear_bond_continuity, filter_peptide_backbone
from sampleworks.core.samplers.edm import AF3EDMSampler, EDMSamplerConfig
from sampleworks.utils.guidance_constants import StructurePredictor

from tests.conftest import (
annotate_structure_for_wrapper_type,
get_fixture_name_for_wrapper_type,
get_slow_wrappers,
STRUCTURES,
)


NO_GUIDANCE_STRUCTURES: list[str] = [*STRUCTURES, "structure_5i09_density"]


@pytest.mark.gpu
@pytest.mark.slow
@pytest.mark.parametrize("wrapper_type", get_slow_wrappers(), ids=lambda w: w.value)
@pytest.mark.parametrize(
"structure_fixture",
NO_GUIDANCE_STRUCTURES,
ids=lambda s: s.replace("structure_", ""),
)
class TestNoGuidanceGeometry:
"""Validate that no-guidance sampling produces valid peptide geometry.

Each test runs a short 20-step trajectory with no guidance and
checks that the resulting backbone bond lengths are physically reasonable.
"""

NUM_STEPS: int = 20

def test_no_guidance_produces_valid_peptide_geometry(
self,
wrapper_type: StructurePredictor,
structure_fixture: str,
temp_output_dir,
request,
):
"""
No-guidance sampling should produce structures with valid backbone bonds (N-CA, CA-C, C-N).
"""
wrapper = request.getfixturevalue(get_fixture_name_for_wrapper_type(wrapper_type))
structure = request.getfixturevalue(structure_fixture)
device = wrapper.device if hasattr(wrapper, "device") else torch.device("cpu")

annotated = annotate_structure_for_wrapper_type(
Comment thread
k-chrispens marked this conversation as resolved.
wrapper_type, structure, temp_output_dir, ensemble_size=1
)
features = wrapper.featurize(annotated)

sampler = AF3EDMSampler(
EDMSamplerConfig(device=device, augmentation=False, align_to_input=False)
) # NOTE: augmentation and alignment might be useful to parametrize here?
schedule = sampler.compute_schedule(num_steps=self.NUM_STEPS)

torch.manual_seed(42)
state = wrapper.initialize_from_prior(batch_size=1, features=features)

for i in range(self.NUM_STEPS):
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is there some way to run this without explicitly including this loop here? It would be better to test the actual code path we use during inference.

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is the inference code path: we run a for loop over the schedule (e.g. src/sampleworks/scalers/pure_guidance.py:108).

context = sampler.get_context_for_step(i, schedule)
step_output = sampler.step(
state=state,
model_wrapper=wrapper,
context=context,
scaler=None,
features=features,
)
state = step_output.state

cond = features.conditioning
model_aa = cond.model_atom_array
assert model_aa is not None, (
f"model_atom_array is None for wrapper={wrapper_type.value}, "
f"structure={structure_fixture}"
)

output_aa = model_aa.copy()
# state shape: (batch=1, atoms, 3)
output_aa.coord = state[0].detach().cpu().numpy()

# filter_peptide_backbone selects N, CA, C atoms; filter_linear_bond_continuity returns a
# boolean mask where True means the distance to the next atom is within [min_len, max_len].
bb_mask = filter_peptide_backbone(output_aa)
bb = output_aa[bb_mask]
assert len(bb) >= 6, ( # at least 2 residues (N, CA, C each)
f"Too few backbone atoms ({len(bb)}) for "
f"wrapper={wrapper_type.value}, structure={structure_fixture}"
)

# Generous bounds for a 20-step stochastic sample.
# Ideal backbone bonds are 1.33–1.52 Å;
# we're a little more generous because of the short trajectory.
con_mask = filter_linear_bond_continuity(bb, min_len=1.1, max_len=1.7)
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could we be more strict, say by looking at only the peptide bond?

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You can get strict ranges from RDKit, for instance (see the bond length script)

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, I just figure for speed of testing we might not expect all models to pass the most rigorous geometry checks (but 5I09 was failing this before)

Comment thread
k-chrispens marked this conversation as resolved.
# Last element is always True (no next atom to compare), exclude it.
n_valid = int(con_mask[:-1].sum())
n_total = len(con_mask) - 1
fraction_valid = n_valid / n_total

assert fraction_valid >= 0.9, (
f"Only {n_valid}/{n_total} ({fraction_valid:.1%}) backbone bonds "
"are within [1.1, 1.7] Å for "
f"wrapper={wrapper_type.value}, structure={structure_fixture}. "
"This suggests something is wrong (an atom ordering problem?)."
)
105 changes: 105 additions & 0 deletions tests/models/test_rf3_atom_ordering.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
"""Regression tests for RF3 atom ordering consistency.

Verifies that ``model_atom_array`` built during ``RF3Wrapper.featurize()`` has
the same atom count and ordering as the model's internal feature tensors
(``atom_to_token_map``).
"""

import numpy as np
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This maybe doesn't need to be in here for purposes of this PR, but I feel like these tests are still not stringent enough. (I admit there are basically no tests yet for my CIF patching operations, which pose similar risks). We probably should actually check the alignment and make sure the atoms are correct and match each other.

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'll add an issue for this

import pytest
from sampleworks.utils.imports import require_rf3, RF3_AVAILABLE


if RF3_AVAILABLE:
from sampleworks.models.rf3.wrapper import annotate_structure_for_rf3


@pytest.mark.gpu
@pytest.mark.slow
class TestRF3AtomOrdering:
"""Validate model_atom_array matches model feature dimensions."""

@require_rf3()
@pytest.mark.parametrize(
    "structure_fixture",
    [
        "structure_5i09_density",
        "structure_5sop_density",
        "structure_9bn8_density",
        "structure_6ni6_density",
    ],
)
def test_model_atom_array_matches_feature_count(self, rf3_wrapper, structure_fixture, request):
    """Check that model_atom_array and atom_to_token_map agree on the atom count.

    Guards against the 5I09 regression, in which OXT atoms at chain breaks
    stayed in model_atom_array although the model's internal atom accounting
    did not include them, which misaligned the coordinates.
    """
    raw_structure = request.getfixturevalue(structure_fixture)
    prepared = annotate_structure_for_rf3(raw_structure, ensemble_size=1)
    conditioning = rf3_wrapper.featurize(prepared).conditioning

    model_atoms = conditioning.model_atom_array
    assert model_atoms is not None

    # Compare the annotated atom array against the feature tensor the model indexes by.
    expected_atoms = len(conditioning.features["atom_to_token_map"])
    actual_atoms = len(model_atoms)
    assert actual_atoms == expected_atoms, (
        f"model_atom_array has {actual_atoms} atoms but "
        f"atom_to_token_map has {expected_atoms}. Coordinate mapping "
        "will be misaligned."
    )

@require_rf3()
@pytest.mark.parametrize(
"structure_fixture",
[
"structure_5i09_density",
"structure_5sop_density",
"structure_9bn8_density",
"structure_6ni6_density",
],
)
def test_no_oxt_atoms_in_model_atom_array(self, rf3_wrapper, structure_fixture, request):
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Shouldn't there be one OXT per chain? Does RF3's pipeline actually remove the final terminal oxygen?

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There should be, but it is removed by most models (except Protenix).

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For reference, I presume this is done to standardize the residues so that a residue token always corresponds to the same set of atoms. I think Protenix does something a bit fancier, as its outputs do have OXT, but I don't know the specifics (only that I had to deal with Protenix requiring it).

"""model_atom_array must not contain OXT atoms.

The pipeline's ``RemoveTerminalOxygen`` transform removes these atoms.
"""
structure = request.getfixturevalue(structure_fixture)
annotated = annotate_structure_for_rf3(structure, ensemble_size=1)
features = rf3_wrapper.featurize(annotated)
cond = features.conditioning

assert cond.model_atom_array is not None
oxt_count = int(np.sum(cond.model_atom_array.atom_name == "OXT"))
assert oxt_count == 0, (
f"model_atom_array contains {oxt_count} OXT atoms that the pipeline "
"should have removed."
)

@require_rf3()
@pytest.mark.parametrize(
    "structure_fixture",
    [
        "structure_5i09_density",
        "structure_5sop_density",
        "structure_9bn8_density",
        "structure_6ni6_density",
    ],
)
def test_no_hydrogen_atoms_in_model_atom_array(self, rf3_wrapper, structure_fixture, request):
    """Check that model_atom_array is free of hydrogen atoms.

    Hydrogens are expected to have been stripped by the pipeline's
    ``RemoveHydrogens`` transform.
    """
    raw_structure = request.getfixturevalue(structure_fixture)
    prepared = annotate_structure_for_rf3(raw_structure, ensemble_size=1)
    conditioning = rf3_wrapper.featurize(prepared).conditioning

    model_atoms = conditioning.model_atom_array
    assert model_atoms is not None

    # Boolean mask over atoms whose element annotation is hydrogen.
    is_hydrogen = model_atoms.element == "H"
    h_count = int(np.sum(is_hydrogen))
    assert h_count == 0, (
        f"model_atom_array contains {h_count} hydrogen atoms that the pipeline "
        "should have removed."
    )
Loading
Loading