5 changes: 5 additions & 0 deletions docs/src/references/changelog.rst
@@ -23,6 +23,11 @@ changelog <https://keepachangelog.com/en/1.1.0/>`_ format. This project follows

`Unreleased <https://github.com/lab-cosmo/torch-pme/>`_
-------------------------------------------------------
Added
#####

* Add support for batched calculations


`Version 0.3.2 <https://github.com/lab-cosmo/torch-pme/releases/tag/v0.3.2>`_ - 2025-10-07
------------------------------------------------------------------------------------------
3 changes: 3 additions & 0 deletions examples/07-lode-demo.py
@@ -443,6 +443,9 @@ def forward(
        neighbor_indices: Optional[torch.Tensor] = None,
        neighbor_distances: Optional[torch.Tensor] = None,
        periodic: Optional[torch.Tensor] = None,
        node_mask: Optional[torch.Tensor] = None,
        pair_mask: Optional[torch.Tensor] = None,
        kvectors: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        # Update meshes
        assert self.potential.smearing is not None  # otherwise mypy complains
204 changes: 204 additions & 0 deletions examples/12-padding-example.py
Contributor:
We could also show a speed comparison for a batched one and a looped one...

Might be nice :-)

@@ -0,0 +1,204 @@
"""
Batched Ewald Computation with Padding
======================================

This example demonstrates how to compute Ewald potentials for a batch of systems with
different numbers of atoms using padding. The idea is to pad atomic positions, charges,
and neighbor lists to the same length and use masks to ignore padded entries during
computation. Note that batching systems of varying sizes in this way can increase the
computational cost during model training, since padded atoms are included in the batched
operations even though they don't contribute physically.
"""

# %%
import time

import torch
import vesin
from torch.nn.utils.rnn import pad_sequence

import torchpme

dtype = torch.float64
cutoff = 4.4
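
# %%
# Minimal sketch of the padding idea before the real systems are built
# (illustrative only): two toy tensors of lengths 2 and 3 are padded to a
# common length with ``pad_sequence``, and a boolean mask marks the real
# entries. The same pattern is applied to positions, charges, and neighbor
# lists below.
_toy = [torch.ones(2, dtype=dtype), torch.ones(3, dtype=dtype)]
_padded = pad_sequence(_toy, batch_first=True)  # shape [2, 3]; first row gets a trailing 0.0 pad
_lengths = torch.tensor([t.shape[0] for t in _toy])
_mask = torch.arange(_padded.shape[1])[None, :] < _lengths[:, None]
print(_padded)
print(_mask)  # [[True, True, False], [True, True, True]]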

# %%
# Example: a batch of five systems with different numbers of atoms
systems = [
    {
        "symbols": ("Cs", "Cl"),
        "positions": torch.tensor([(0, 0, 0), (0.5, 0.5, 0.5)], dtype=dtype),
        "charges": torch.tensor([[1.0], [-1.0]], dtype=dtype),
        "cell": torch.eye(3, dtype=dtype) * 3.0,
        "pbc": torch.tensor([True, True, True]),
    },
    {
        "symbols": ("Na", "Cl", "Cl"),
        "positions": torch.tensor(
            [(0, 0, 0), (0.5, 0.5, 0.5), (0.25, 0.25, 0.25)], dtype=dtype
        ),
        "charges": torch.tensor([[1.0], [-1.0], [-1.0]], dtype=dtype),
        "cell": torch.eye(3, dtype=dtype) * 4.0,
        "pbc": torch.tensor([True, True, True]),
    },
    {
        "symbols": ("K", "Br", "Br", "K"),
        "positions": torch.tensor(
            [(0, 0, 0), (0.5, 0.5, 0.5), (0.25, 0.25, 0.25), (0.75, 0.75, 0.75)],
            dtype=dtype,
        ),
        "charges": torch.tensor([[1.0], [-1.0], [-1.0], [1.0]], dtype=dtype),
        "cell": torch.eye(3, dtype=dtype) * 5.0,
        "pbc": torch.tensor([True, True, True]),
    },
    {
        "symbols": ("Mg", "O", "O", "Mg", "O"),
        "positions": torch.tensor(
            [
                (0, 0, 0),
                (0.5, 0.5, 0.5),
                (0.25, 0.25, 0.25),
                (0.75, 0.75, 0.75),
                (0.1, 0.1, 0.1),
            ],
            dtype=dtype,
        ),
        "charges": torch.tensor([[2.0], [-2.0], [-2.0], [2.0], [-2.0]], dtype=dtype),
        "cell": torch.eye(3, dtype=dtype) * 6.0,
        "pbc": torch.tensor([True, True, True]),
    },
    {
        "symbols": ("Al", "O", "O", "Al", "O", "O"),
        "positions": torch.tensor(
            [
                (0, 0, 0),
                (0.5, 0.5, 0.5),
                (0.25, 0.25, 0.25),
                (0.75, 0.75, 0.75),
                (0.1, 0.1, 0.1),
                (0.9, 0.9, 0.9),
            ],
            dtype=dtype,
        ),
        "charges": torch.tensor(
            [[3.0], [-2.0], [-2.0], [3.0], [-2.0], [-2.0]], dtype=dtype
        ),
        "cell": torch.eye(3, dtype=dtype) * 7.0,
        "pbc": torch.tensor([True, True, True]),
    },
]

# %%
# Compute neighbor lists for each system
i_list, j_list, d_list, pos_list, charges_list, cell_list, periodic_list = (
    [],
    [],
    [],
    [],
    [],
    [],
    [],
)

nl = vesin.NeighborList(cutoff=cutoff, full_list=False)

for system in systems:
    neighbor_indices, neighbor_distances = nl.compute(
        points=system["positions"],
        box=system["cell"],
        periodic=system["pbc"][0],
        quantities="Pd",
    )
    i_list.append(torch.tensor(neighbor_indices[:, 0], dtype=torch.int64))
    j_list.append(torch.tensor(neighbor_indices[:, 1], dtype=torch.int64))
    d_list.append(torch.tensor(neighbor_distances, dtype=dtype))
    pos_list.append(system["positions"])
    charges_list.append(system["charges"])
    cell_list.append(system["cell"])
    periodic_list.append(system["pbc"])

# %%
# Pad positions, charges, and neighbor lists
max_atoms = max(pos.shape[0] for pos in pos_list)
pos_batch = pad_sequence(pos_list, batch_first=True)
charges_batch = pad_sequence(charges_list, batch_first=True)
cell_batch = torch.stack(cell_list)
periodic_batch = torch.stack(periodic_list)
i_batch = pad_sequence(i_list, batch_first=True, padding_value=0)
j_batch = pad_sequence(j_list, batch_first=True, padding_value=0)
d_batch = pad_sequence(d_list, batch_first=True, padding_value=0.0)

# Masks for ignoring padded atoms and neighbor entries
node_mask = (
    torch.arange(max_atoms)[None, :]
    < torch.tensor([p.shape[0] for p in pos_list])[:, None]
)
pair_mask = (
    torch.arange(i_batch.shape[1])[None, :]
    < torch.tensor([len(i) for i in i_list])[:, None]
)
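
# %%
# Quick sanity check (illustrative): the row sums of the boolean masks should
# reproduce the per-system atom and neighbor-pair counts.
print("atoms per system:", node_mask.sum(dim=1).tolist())
print("pairs per system:", pair_mask.sum(dim=1).tolist())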
# %%
# Initialize Ewald calculator
calculator = torchpme.EwaldCalculator(
    torchpme.CoulombPotential(smearing=0.5),
    lr_wavelength=4.0,
)
calculator.to(dtype=dtype)

# %%
# Compute potentials in a batched manner using vmap
kvectors = torchpme.lib.compute_batched_kvectors(
    lr_wavelength=calculator.lr_wavelength, cells=cell_batch
)

potentials_batch = torch.vmap(calculator.forward)(
    charges_batch,
    cell_batch,
    pos_batch,
    torch.stack((i_batch, j_batch), dim=-1),
    d_batch,
    periodic_batch,
    node_mask,
    pair_mask,
    kvectors,
)

# %%
print("Batched potentials shape:", potentials_batch.shape)
print(potentials_batch)
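
# %%
# Cross-check (a sketch, not part of the original example): the masked rows of
# the batched result should match plain per-system calls up to numerical
# precision.
for k in range(len(pos_list)):
    ref = calculator.forward(
        charges_list[k],
        cell_list[k],
        pos_list[k],
        torch.stack((i_list[k], j_list[k]), dim=-1),
        d_list[k],
        periodic_list[k],
    )
    max_diff = (potentials_batch[k][node_mask[k]] - ref).abs().max()
    print(f"system {k}: max |batched - looped| = {max_diff:.2e}")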
# %%
# Compare performance of batched vs. looped computation
n_iter = 100

t0 = time.perf_counter()
for _ in range(n_iter):
    _ = torch.vmap(calculator.forward)(
        charges_batch,
        cell_batch,
        pos_batch,
        torch.stack((i_batch, j_batch), dim=-1),
        d_batch,
        periodic_batch,
        node_mask,
        pair_mask,
        kvectors,
    )
t_batch = (time.perf_counter() - t0) / n_iter

t0 = time.perf_counter()
for _ in range(n_iter):
    for k in range(len(pos_list)):
        _ = calculator.forward(
            charges_list[k],
            cell_list[k],
            pos_list[k],
            torch.stack((i_list[k], j_list[k]), dim=-1),
            d_list[k],
            periodic_list[k],
        )
t_loop = (time.perf_counter() - t0) / n_iter

print(f"Average time per batched call: {t_batch:.6f} s")
print(f"Average time per loop call: {t_loop:.6f} s")
print("Batched is faster" if t_batch < t_loop else "Loop is faster")
# %%
76 changes: 63 additions & 13 deletions src/torchpme/_utils.py
@@ -1,4 +1,4 @@
from typing import Union
from typing import Optional

import torch

@@ -9,15 +9,17 @@ def _validate_parameters(
    positions: torch.Tensor,
    neighbor_indices: torch.Tensor,
    neighbor_distances: torch.Tensor,
    smearing: Union[float, None],
    periodic: Union[torch.Tensor, None] = None,
    periodic: Optional[torch.Tensor] = None,
    pair_mask: Optional[torch.Tensor] = None,
    node_mask: Optional[torch.Tensor] = None,
    kvectors: Optional[torch.Tensor] = None,
) -> None:
    dtype = positions.dtype
    device = positions.device

    # check shape, dtype and device of positions
    num_atoms = len(positions)
    if list(positions.shape) != [len(positions), 3]:
    num_atoms = positions.shape[-2]
    if list(positions.shape) != [num_atoms, 3]:
        raise ValueError(
            "`positions` must be a tensor with shape [n_atoms, 3], got tensor "
            f"with shape {list(positions.shape)}"
        )
@@ -40,14 +42,6 @@
            f"device of `cell` ({cell.device}) must be same as that of the `positions` class ({device})"
        )

    if smearing is not None and torch.equal(
        cell.det(), torch.tensor(0.0, dtype=cell.dtype, device=cell.device)
    ):
        raise ValueError(
            "provided `cell` has a determinant of 0 and therefore is not valid for "
            "periodic calculation"
        )

    # check shape, dtype & device of `charges`
    if charges.dim() != 2:
        raise ValueError(
@@ -120,3 +114,59 @@
            f"device of `periodic` ({periodic.device}) must be same as that of "
            f"the `positions` class ({device})"
        )

    if pair_mask is not None:
        if pair_mask.shape != neighbor_indices[:, 0].shape:
            raise ValueError(
                "`pair_mask` must have one entry per neighbor pair, got tensor "
                f"with shape {list(pair_mask.shape)} while the number of "
                f"neighbor pairs is {neighbor_indices.shape[0]}"
            )

        if pair_mask.device != device:
            raise ValueError(
                f"device of `pair_mask` ({pair_mask.device}) must be same as that "
                f"of the `positions` class ({device})"
            )

        if pair_mask.dtype != torch.bool:
            raise TypeError(
                f"type of `pair_mask` ({pair_mask.dtype}) must be torch.bool"
            )

    if node_mask is not None:
        if node_mask.shape != (num_atoms,):
            raise ValueError(
                "`node_mask` must have shape [n_atoms], got tensor with shape "
                f"{list(node_mask.shape)} where n_atoms is {num_atoms}"
            )

        if node_mask.device != device:
            raise ValueError(
                f"device of `node_mask` ({node_mask.device}) must be same as that "
                f"of the `positions` class ({device})"
            )

        if node_mask.dtype != torch.bool:
            raise TypeError(
                f"type of `node_mask` ({node_mask.dtype}) must be torch.bool"
            )

    if kvectors is not None:
        if kvectors.shape[1] != 3:
            raise ValueError(
                "`kvectors` must be a tensor of shape [n_kvecs, 3], got "
                f"tensor with shape {list(kvectors.shape)}"
            )

        if kvectors.device != device:
            raise ValueError(
                f"device of `kvectors` ({kvectors.device}) must be same as that of "
                f"the `positions` class ({device})"
            )

        if kvectors.dtype != dtype:
            raise TypeError(
                f"type of `kvectors` ({kvectors.dtype}) must be same as that of the "
                f"`positions` class ({dtype})"
            )