Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/build.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -95,4 +95,4 @@ jobs:
- name: Run tests
shell: bash -l {0}
run: |
pytest -s tests
pytest -v tests --disable-warnings
34 changes: 17 additions & 17 deletions src/mattersim/forcefield/potential.py
Original file line number Diff line number Diff line change
Expand Up @@ -494,7 +494,7 @@ def predict_properties(
raise NotImplementedError
else:
graph_batch.to(self.device)
input = batch_to_dict(graph_batch)
input = batch_to_dict(graph_batch, device=self.device)
result = self.forward(
input,
include_forces=include_forces,
Expand Down Expand Up @@ -553,7 +553,7 @@ def train_one_epoch(
raise NotImplementedError
else:
graph_batch.to(self.device)
input = batch_to_dict(graph_batch)
input = batch_to_dict(graph_batch, device=self.device)
if mode == "train":
result = self.forward(
input,
Expand Down Expand Up @@ -707,7 +707,7 @@ def get_properties(
raise NotImplementedError
else:
graph_batch.to(self.device)
input = batch_to_dict(graph_batch)
input = batch_to_dict(graph_batch, device=self.device)
result = self.forward(
input,
include_forces=include_forces,
Expand Down Expand Up @@ -1076,19 +1076,19 @@ def get_description(self):
def batch_to_dict(graph_batch, model_type="m3gnet", device="cuda"):
if model_type == "m3gnet":
# TODO: key_list
atom_pos = graph_batch.atom_pos
cell = graph_batch.cell
pbc_offsets = graph_batch.pbc_offsets
atom_attr = graph_batch.atom_attr
edge_index = graph_batch.edge_index
three_body_indices = graph_batch.three_body_indices
num_three_body = graph_batch.num_three_body
num_bonds = graph_batch.num_bonds
num_triple_ij = graph_batch.num_triple_ij
num_atoms = graph_batch.num_atoms
atom_pos = graph_batch.atom_pos.to(device)
cell = graph_batch.cell.to(device)
pbc_offsets = graph_batch.pbc_offsets.to(device)
atom_attr = graph_batch.atom_attr.to(device)
edge_index = graph_batch.edge_index.to(device)
three_body_indices = graph_batch.three_body_indices.to(device)
num_three_body = graph_batch.num_three_body.to(device)
num_bonds = graph_batch.num_bonds.to(device)
num_triple_ij = graph_batch.num_triple_ij.to(device)
num_atoms = graph_batch.num_atoms.to(device)
num_graphs = graph_batch.num_graphs
num_graphs = torch.tensor(num_graphs)
batch = graph_batch.batch
num_graphs = torch.tensor(num_graphs, device=device)
batch = graph_batch.batch.to(device)

# Resemble input dictionary
input = {}
Expand Down Expand Up @@ -1212,7 +1212,7 @@ def calculate(
raise NotImplementedError
else:
graph_batch = graph_batch.to(self.device)
input = batch_to_dict(graph_batch)
input = batch_to_dict(graph_batch, device=self.device)

result = self.potential.forward(
input, include_forces=True, include_stresses=self.compute_stress
Expand Down Expand Up @@ -1370,7 +1370,7 @@ def calculate(
raise NotImplementedError
else:
graph_batch = graph_batch.to(self.device)
input = batch_to_dict(graph_batch)
input = batch_to_dict(graph_batch, device=self.device)

result = self.potential.forward(
input, include_forces=True, include_stresses=self.compute_stress
Expand Down
23 changes: 23 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
"""Shared pytest fixtures for mattersim tests."""

import pytest
import torch


def _available_devices():
"""Return all available torch devices on this machine."""
devices = ["cpu"]
if torch.cuda.is_available():
devices.append("cuda")
if torch.backends.mps.is_available():
devices.append("mps")
return devices


@pytest.fixture(
params=_available_devices(),
ids=lambda d: f"device={d}",
)
def device(request):
"""Yields each available device (cpu, cuda, mps)."""
return request.param
Empty file added tests/forcefield/__init__.py
Empty file.
101 changes: 101 additions & 0 deletions tests/forcefield/test_batch_to_dict.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
"""Tests for batch_to_dict — verifies that tensor device placement works
correctly (related to GitHub issue #113).
"""

import pytest
import torch
from types import SimpleNamespace

from mattersim.forcefield.potential import batch_to_dict

TENSOR_KEYS = [
"atom_pos",
"cell",
"pbc_offsets",
"atom_attr",
"edge_index",
"three_body_indices",
"num_three_body",
"num_bonds",
"num_triple_ij",
"num_atoms",
"num_graphs",
"batch",
]


def _make_graph_batch(device="cpu"):
"""Create a minimal mock graph_batch with all required tensor fields."""
return SimpleNamespace(
atom_pos=torch.randn(4, 3, device=device),
cell=torch.randn(1, 3, 3, device=device),
pbc_offsets=torch.zeros(6, 3, device=device),
atom_attr=torch.randn(4, 16, device=device),
edge_index=torch.randint(0, 4, (2, 6), device=device),
three_body_indices=torch.randint(0, 6, (3, 8), device=device),
num_three_body=torch.tensor([8], device=device),
num_bonds=torch.tensor([6], device=device),
num_triple_ij=torch.tensor([8], device=device),
num_atoms=torch.tensor([4], device=device),
num_graphs=1, # scalar, not a tensor on the batch object
batch=torch.zeros(4, dtype=torch.long, device=device),
)


class TestBatchToDict:
"""Tests for the batch_to_dict helper function."""

def test_all_tensors_on_target_device(self, device):
"""Every tensor in the returned dict must be on the requested device."""
batch = _make_graph_batch("cpu")
result = batch_to_dict(batch, device=device)

for key in TENSOR_KEYS:
assert key in result, f"Missing key: {key}"
assert result[key].device.type == device, (
f"'{key}' on {result[key].device}, expected {device}"
)

def test_cross_device_move(self, device):
"""Tensors created on CPU must end up on the target device —
the device parameter must not be silently ignored."""
batch = _make_graph_batch("cpu")
result = batch_to_dict(batch, device=device)

for key in TENSOR_KEYS:
assert result[key].device.type == device, (
f"'{key}' still on {result[key].device} instead of {device}"
)

def test_num_graphs_is_tensor_on_correct_device(self, device):
"""num_graphs (a plain int on the batch) must become a tensor
on the target device."""
batch = _make_graph_batch("cpu")
result = batch_to_dict(batch, device=device)

assert isinstance(result["num_graphs"], torch.Tensor)
assert result["num_graphs"].item() == 1
assert result["num_graphs"].device.type == device

def test_returns_all_expected_keys(self):
"""The returned dict must contain exactly the expected keys."""
batch = _make_graph_batch("cpu")
result = batch_to_dict(batch, device="cpu")

assert set(result.keys()) == set(TENSOR_KEYS)

def test_tensor_values_preserved(self):
"""Moving to the same device must not alter tensor values."""
batch = _make_graph_batch("cpu")
result = batch_to_dict(batch, device="cpu")

torch.testing.assert_close(result["atom_pos"], batch.atom_pos)
torch.testing.assert_close(result["edge_index"], batch.edge_index)

def test_unsupported_model_type_raises(self):
"""Non-m3gnet model types should raise NotImplementedError."""
batch = _make_graph_batch("cpu")
with pytest.raises(NotImplementedError):
batch_to_dict(batch, model_type="graphormer", device="cpu")
with pytest.raises(NotImplementedError):
batch_to_dict(batch, model_type="unknown", device="cpu")
Loading