Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 44 additions & 4 deletions src/sampleworks/models/rf3/wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,13 +183,39 @@ def add_msa_to_chain_info(chain_info: dict, msa_path: str | Path | dict | None)
return updated_chain_info


def _cuda_index(device: torch.device | str) -> int:
"""Extract the CUDA device index from a ``torch.device`` or string.

Parameters
----------
device: torch.device | str
Device spec, e.g. ``"cuda:3"``, ``torch.device("cuda", 2)``, or ``"cuda"``.

Returns
-------
int
Device index (``0`` when unspecified, matching Torch defaults).

Raises
------
ValueError
If ``device`` is not a CUDA device. Currently, sampleworks only supports
CUDA systems, so non-CUDA devices will fail here.
"""
dev = torch.device(device)
if dev.type != "cuda":
raise ValueError(f"RF3Wrapper requires a CUDA device, got {dev!r}")
return dev.index if dev.index is not None else 0


class RF3Wrapper:
"""Wrapper for RosettaFold 3 (Baker Lab AlphaFold 3 replication)."""

def __init__(
self,
checkpoint_path: str | Path,
msa_manager: MSAManager | None = None,
device: torch.device | str | None = None,
):
"""
Parameters
Expand All @@ -198,17 +224,31 @@ def __init__(
Filesystem path to the checkpoint containing trained weights.
msa_manager: MSAManager | None
MSA manager for retrieving MSAs for input structures.
device: torch.device | str | None
CUDA device to bind the underlying Lightning Fabric to (e.g. ``"cuda:3"``).
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

To make sure I understood this, I went searching in the foundry repo and found this line: https://github.com/RosettaCommons/foundry/blob/b071919caa19ff334bc04b1b41145cac61eba819/src/foundry/trainers/fabric.py#L92

Would probably be worth referencing this here for posterity

When ``None``, Fabric picks the first available device. Required for
parallel jobs that must target distinct GPUs — passing an ``int``
to Fabric (the default) always resolves to GPU 0, which serialises
otherwise-parallel workers onto a single device.

References: https://lightning.ai/docs/fabric/stable/fundamentals/launch.html
devices argument to fabric run
https://github.com/RosettaCommons/foundry/blob/b071919caa19ff334bc04b1b41145cac61eba819/src/foundry/trainers/fabric.py#L92
Comment thread
k-chrispens marked this conversation as resolved.
"""
logger.info("Loading RF3 Inference Engine")

self.checkpoint_path = Path(checkpoint_path).expanduser().resolve()
self.msa_manager = msa_manager
self.msa_pairing_strategy = "greedy"

self.inference_engine = RF3InferenceEngine(
ckpt_path=str(self.checkpoint_path),
diffusion_batch_size=1,
)
engine_kwargs: dict[str, Any] = {
"ckpt_path": str(self.checkpoint_path),
"diffusion_batch_size": 1,
}
if device is not None:
engine_kwargs["devices_per_node"] = [_cuda_index(device)]

self.inference_engine = RF3InferenceEngine(**engine_kwargs)
self.inference_engine.initialize()

self.inference_engine.trainer = cast(
Expand Down
4 changes: 1 addition & 3 deletions src/sampleworks/utils/guidance_script_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,13 +207,11 @@ def get_model_and_device(
model_wrapper = RF3Wrapper(
checkpoint_path=validated_checkpoint_path,
msa_manager=MSAManager(),
device=device,
)
else:
raise ValueError(f"Unknown model type: {model_type}")

# RF3 currently manages its own device; prefer that when available.
device = getattr(model_wrapper, "device", device)

# (pyright doesn't think Boltz1Wrapper etc are "Any")
return device, model_wrapper

Expand Down
56 changes: 56 additions & 0 deletions tests/models/test_rf3_device_selection.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
"""Tests for RF3Wrapper device selection (issue #37).

Fabric's ``devices`` argument treats an ``int`` as "take this many GPUs, starting
from index 0", which pins every worker to ``cuda:0`` and serialises nominally
parallel jobs. The wrapper must accept an explicit device and bind Fabric to
that specific GPU index via ``devices_per_node=[idx]``.
"""

import pytest
import torch
from sampleworks.utils.imports import require_rf3, RF3_AVAILABLE


# Skip the entire module when the optional RF3 extra is not installed.
pytestmark = pytest.mark.skipif(not RF3_AVAILABLE, reason="RF3 dependencies not installed")

# Guarded import: collecting this module on a machine without the RF3 extra
# must not fail at import time — presumably wrapper.py pulls in RF3-only
# dependencies at module scope (TODO confirm).
if RF3_AVAILABLE:
    from sampleworks.models.rf3.wrapper import _cuda_index, RF3Wrapper


class TestCudaIndex:
    """Validate CUDA index extraction used to drive Fabric device selection.

    These tests only exercise ``torch.device`` parsing, so they run without
    any physical GPU present.
    """

    @pytest.mark.parametrize(
        ("device", "expected"),
        [
            ("cuda:0", 0),
            ("cuda:3", 3),
            (torch.device("cuda", 5), 5),
            # A bare "cuda" spec has no index; Torch defaults it to GPU 0.
            ("cuda", 0),
        ],
    )
    def test_returns_index(self, device, expected):
        # Accepts both string specs and torch.device objects.
        assert _cuda_index(device) == expected

    @pytest.mark.parametrize("device", ["cpu", torch.device("cpu")])
    def test_rejects_non_cuda(self, device):
        # Non-CUDA devices must raise rather than silently map to an index.
        with pytest.raises(ValueError, match="CUDA device"):
            _cuda_index(device)


@pytest.mark.gpu
@pytest.mark.slow
class TestRF3WrapperDeviceBinding:
    """Regression for issue #37: device must propagate to Fabric.

    Prior behaviour: RF3Wrapper ignored caller-specified device and always
    landed on ``cuda:0`` because Fabric defaults to ``devices=1``.
    """

    @require_rf3()
    def test_wrapper_honors_requested_cuda_index(self, rf3_checkpoint_path):
        # Binding to cuda:1 is only a meaningful regression check when a
        # second physical GPU exists; skip (not fail) otherwise.
        if torch.cuda.device_count() < 2:
            pytest.skip("multi-GPU regression test needs >= 2 CUDA devices")

        # NOTE(review): this asserts a ``device`` attribute on the wrapper;
        # confirm RF3Wrapper actually exposes one — the ``__init__`` shown in
        # this diff does not visibly set ``self.device``.
        wrapper = RF3Wrapper(checkpoint_path=rf3_checkpoint_path, device="cuda:1")
        assert wrapper.device == torch.device("cuda:1")
Loading