Skip to content

Commit

Permalink
pt: support dpa2 model parallel inference (#3657)
Browse files Browse the repository at this point in the history
Signed-off-by: Jinzhe Zeng <jinzhe.zeng@rutgers.edu>
Signed-off-by: Lysithea <52808607+CaRoLZhangxy@users.noreply.github.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Jinzhe Zeng <jinzhe.zeng@rutgers.edu>
Co-authored-by: Han Wang <92130845+wanghan-iapcm@users.noreply.github.com>
  • Loading branch information
4 people committed Apr 30, 2024
1 parent ee47e75 commit d0fe13c
Show file tree
Hide file tree
Showing 26 changed files with 1,553 additions and 35 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/test_cuda.yml
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ jobs:
- run: |
export LD_LIBRARY_PATH=$GITHUB_WORKSPACE/dp_test/lib:$GITHUB_WORKSPACE/libtorch/lib:$CUDA_PATH/lib64:$LD_LIBRARY_PATH
export PATH=$GITHUB_WORKSPACE/dp_test/bin:$PATH
python -m pytest source/lmp/tests
python -m pytest -s source/lmp/tests || (cat log.lammps && exit 1)
python -m pytest source/ipi/tests
env:
OMP_NUM_THREADS: 1
Expand Down
6 changes: 5 additions & 1 deletion deepmd/pt/entrypoints/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -286,10 +286,14 @@ def train(FLAGS):

def freeze(FLAGS):
    """Script the model and save it as a TorchScript file.

    The model is obtained from ``inference.Tester(FLAGS.model, head=FLAGS.head)``,
    compiled with ``torch.jit.script`` and written to ``FLAGS.output``.
    One extra file named ``type`` is stored with the model: ``"dpa2"`` when
    the model definition script contains a ``"type": "dpa2"`` entry, and
    ``"else"`` otherwise.

    Parameters
    ----------
    FLAGS
        Options object; the attributes ``model``, ``head`` and ``output``
        are read here.
    """
    model = torch.jit.script(inference.Tester(FLAGS.model, head=FLAGS.head).model)
    # Plain substring match on the serialized model definition, so it relies
    # on the `"key": "value"` spacing used when the script was dumped.
    # NOTE(review): presumably the loader reads this tag to decide whether
    # communication data must be supplied for parallel inference -- confirm.
    model_type = "dpa2" if '"type": "dpa2"' in model.model_def_script else "else"
    extra_files = {"type": model_type}
    # torch.jit.save accepts (module, file, _extra_files): pass the tag dict
    # as the only extra-files argument instead of alongside an empty dict.
    torch.jit.save(
        model,
        FLAGS.output,
        extra_files,
    )


Expand Down
4 changes: 4 additions & 0 deletions deepmd/pt/model/atomic_model/base_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,6 +186,7 @@ def forward_common_atomic(
mapping: Optional[torch.Tensor] = None,
fparam: Optional[torch.Tensor] = None,
aparam: Optional[torch.Tensor] = None,
comm_dict: Optional[Dict[str, torch.Tensor]] = None,
) -> Dict[str, torch.Tensor]:
"""Common interface for atomic inference.
Expand All @@ -207,6 +208,8 @@ def forward_common_atomic(
frame parameters, shape: nf x dim_fparam
aparam
atomic parameter, shape: nf x nloc x dim_aparam
comm_dict
The data needed for communication for parallel inference.
Returns
-------
Expand Down Expand Up @@ -234,6 +237,7 @@ def forward_common_atomic(
mapping=mapping,
fparam=fparam,
aparam=aparam,
comm_dict=comm_dict,
)
ret_dict = self.apply_out_stat(ret_dict, atype)

Expand Down
2 changes: 2 additions & 0 deletions deepmd/pt/model/atomic_model/dp_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,7 @@ def forward_atomic(
mapping: Optional[torch.Tensor] = None,
fparam: Optional[torch.Tensor] = None,
aparam: Optional[torch.Tensor] = None,
comm_dict: Optional[Dict[str, torch.Tensor]] = None,
) -> Dict[str, torch.Tensor]:
"""Return atomic prediction.
Expand Down Expand Up @@ -163,6 +164,7 @@ def forward_atomic(
extended_atype,
nlist,
mapping=mapping,
comm_dict=comm_dict,
)
assert descriptor is not None
# energy, force
Expand Down
1 change: 1 addition & 0 deletions deepmd/pt/model/atomic_model/linear_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,7 @@ def forward_atomic(
mapping: Optional[torch.Tensor] = None,
fparam: Optional[torch.Tensor] = None,
aparam: Optional[torch.Tensor] = None,
comm_dict: Optional[Dict[str, torch.Tensor]] = None,
) -> Dict[str, torch.Tensor]:
"""Return atomic prediction.
Expand Down
1 change: 1 addition & 0 deletions deepmd/pt/model/atomic_model/pairtab_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,6 +228,7 @@ def forward_atomic(
fparam: Optional[torch.Tensor] = None,
aparam: Optional[torch.Tensor] = None,
do_atomic_virial: bool = False,
comm_dict: Optional[Dict[str, torch.Tensor]] = None,
) -> Dict[str, torch.Tensor]:
nframes, nloc, nnei = nlist.shape
extended_coord = extended_coord.view(nframes, -1, 3)
Expand Down
4 changes: 4 additions & 0 deletions deepmd/pt/model/descriptor/dpa1.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
from typing import (
Callable,
Dict,
List,
Optional,
Tuple,
Expand Down Expand Up @@ -453,6 +454,7 @@ def forward(
extended_atype: torch.Tensor,
nlist: torch.Tensor,
mapping: Optional[torch.Tensor] = None,
comm_dict: Optional[Dict[str, torch.Tensor]] = None,
):
"""Compute the descriptor.
Expand All @@ -466,6 +468,8 @@ def forward(
The neighbor list. shape: nf x nloc x nnei
mapping
The index mapping, not required by this descriptor.
comm_dict
The data needed for communication for parallel inference.
Returns
-------
Expand Down
19 changes: 13 additions & 6 deletions deepmd/pt/model/descriptor/dpa2.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
from typing import (
Callable,
Dict,
List,
Optional,
Tuple,
Expand Down Expand Up @@ -395,6 +396,7 @@ def forward(
extended_atype: torch.Tensor,
nlist: torch.Tensor,
mapping: Optional[torch.Tensor] = None,
comm_dict: Optional[Dict[str, torch.Tensor]] = None,
):
"""Compute the descriptor.
Expand All @@ -408,6 +410,8 @@ def forward(
The neighbor list. shape: nf x nloc x nnei
mapping
The index mapping, maps extended region index to local region.
comm_dict
The data needed for communication for parallel inference.
Returns
-------
Expand Down Expand Up @@ -450,11 +454,13 @@ def forward(
# linear to change shape
g1 = self.g1_shape_tranform(g1)
# mapping g1
assert mapping is not None
mapping_ext = (
mapping.view(nframes, nall).unsqueeze(-1).expand(-1, -1, g1.shape[-1])
)
g1_ext = torch.gather(g1, 1, mapping_ext)
if comm_dict is None:
assert mapping is not None
mapping_ext = (
mapping.view(nframes, nall).unsqueeze(-1).expand(-1, -1, g1.shape[-1])
)
g1_ext = torch.gather(g1, 1, mapping_ext)
g1 = g1_ext
# repformer
g1, g2, h2, rot_mat, sw = self.repformers(
nlist_dict[
Expand All @@ -464,8 +470,9 @@ def forward(
],
extended_coord,
extended_atype,
g1_ext,
g1,
mapping,
comm_dict,
)
if self.concat_output_tebd:
g1 = torch.cat([g1, g1_inp], dim=-1)
Expand Down
4 changes: 4 additions & 0 deletions deepmd/pt/model/descriptor/hybrid.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,7 @@ def forward(
atype_ext: torch.Tensor,
nlist: torch.Tensor,
mapping: Optional[torch.Tensor] = None,
comm_dict: Optional[Dict[str, torch.Tensor]] = None,
):
"""Compute the descriptor.
Expand All @@ -181,6 +182,8 @@ def forward(
The neighbor list. shape: nf x nloc x nnei
mapping
The index mapping, not required by this descriptor.
comm_dict
The data needed for communication for parallel inference.
Returns
-------
Expand Down Expand Up @@ -443,6 +446,7 @@ def forward(
extended_atype: torch.Tensor,
extended_atype_embd: Optional[torch.Tensor] = None,
mapping: Optional[torch.Tensor] = None,
comm_dict: Optional[Dict[str, torch.Tensor]] = None,
):
"""Calculate decoded embedding for each atom.
Expand Down
70 changes: 63 additions & 7 deletions deepmd/pt/model/descriptor/repformers.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,27 @@ def torch_linear(*args, **kwargs):
mylinear = simple_linear


# Fallback for environments where ``torch.ops.deepmd`` has no ``border_op``
# attribute: install a Python placeholder under that name so that code
# referencing ``torch.ops.deepmd.border_op`` can still resolve it.
if not hasattr(torch.ops.deepmd, "border_op"):

    def border_op(
        argument0,
        argument1,
        argument2,
        argument3,
        argument4,
        argument5,
        argument6,
        argument7,
        argument8,
    ) -> torch.Tensor:
        """Placeholder for the ``border_op`` custom operator; always raises.

        The nine positional arguments mirror the call made in this module:
        send_list, send_proc, recv_proc, send_num, recv_num, g1,
        communicator, nloc and the number of non-local atoms (nall - nloc).

        Raises
        ------
        NotImplementedError
            Always, because no real implementation is available.
        """
        raise NotImplementedError(
            "border_op is not available since customized PyTorch OP library is not built when freezing the model."
        )

    # Note: this hack cannot actually save a model that can be run using LAMMPS.
    torch.ops.deepmd.border_op = border_op


@DescriptorBlock.register("se_repformer")
@DescriptorBlock.register("se_uni")
class DescrptBlockRepformers(DescriptorBlock):
Expand Down Expand Up @@ -234,9 +255,11 @@ def forward(
extended_atype: torch.Tensor,
extended_atype_embd: Optional[torch.Tensor] = None,
mapping: Optional[torch.Tensor] = None,
comm_dict: Optional[Dict[str, torch.Tensor]] = None,
):
assert mapping is not None
assert extended_atype_embd is not None
if comm_dict is None:
assert mapping is not None
assert extended_atype_embd is not None
nframes, nloc, nnei = nlist.shape
nall = extended_coord.view(nframes, -1).shape[1] // 3
atype = extended_atype[:, :nloc]
Expand All @@ -257,9 +280,13 @@ def forward(
sw = sw.masked_fill(~nlist_mask, 0.0)

# [nframes, nloc, tebd_dim]
atype_embd = extended_atype_embd[:, :nloc, :]
assert list(atype_embd.shape) == [nframes, nloc, self.g1_dim]

if comm_dict is None:
assert isinstance(extended_atype_embd, torch.Tensor) # for jit
atype_embd = extended_atype_embd[:, :nloc, :]
assert list(atype_embd.shape) == [nframes, nloc, self.g1_dim]
else:
atype_embd = extended_atype_embd
assert isinstance(atype_embd, torch.Tensor) # for jit
g1 = self.act(atype_embd)
# nb x nloc x nnei x 1, nb x nloc x nnei x 3
if not self.direct_dist:
Expand All @@ -275,11 +302,40 @@ def forward(
# whether a neighbor is real or not is indicated by nlist_mask
nlist[nlist == -1] = 0
# nb x nall x ng1
mapping = mapping.view(nframes, nall).unsqueeze(-1).expand(-1, -1, self.g1_dim)
if comm_dict is None:
assert mapping is not None
mapping = (
mapping.view(nframes, nall).unsqueeze(-1).expand(-1, -1, self.g1_dim)
)
for idx, ll in enumerate(self.layers):
# g1: nb x nloc x ng1
# g1_ext: nb x nall x ng1
g1_ext = torch.gather(g1, 1, mapping)
if comm_dict is None:
assert mapping is not None
g1_ext = torch.gather(g1, 1, mapping)
else:
n_padding = nall - nloc
g1 = torch.nn.functional.pad(
g1.squeeze(0), (0, 0, 0, n_padding), value=0.0
)
assert "send_list" in comm_dict
assert "send_proc" in comm_dict
assert "recv_proc" in comm_dict
assert "send_num" in comm_dict
assert "recv_num" in comm_dict
assert "communicator" in comm_dict
ret = torch.ops.deepmd.border_op(
comm_dict["send_list"],
comm_dict["send_proc"],
comm_dict["recv_proc"],
comm_dict["send_num"],
comm_dict["recv_num"],
g1,
comm_dict["communicator"],
torch.tensor(nloc),
torch.tensor(nall - nloc),
)
g1_ext = ret[0].unsqueeze(0)
g1, g2, h2 = ll.forward(
g1_ext,
g2,
Expand Down
3 changes: 3 additions & 0 deletions deepmd/pt/model/descriptor/se_a.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,6 +191,7 @@ def forward(
atype_ext: torch.Tensor,
nlist: torch.Tensor,
mapping: Optional[torch.Tensor] = None,
comm_dict: Optional[Dict[str, torch.Tensor]] = None,
):
"""Compute the descriptor.
Expand All @@ -204,6 +205,8 @@ def forward(
The neighbor list. shape: nf x nloc x nnei
mapping
The index mapping, not required by this descriptor.
comm_dict
The data needed for communication for parallel inference.
Returns
-------
Expand Down
2 changes: 2 additions & 0 deletions deepmd/pt/model/model/ener_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@ def forward_lower(
fparam: Optional[torch.Tensor] = None,
aparam: Optional[torch.Tensor] = None,
do_atomic_virial: bool = False,
comm_dict: Optional[Dict[str, torch.Tensor]] = None,
):
model_ret = self.forward_common_lower(
extended_coord,
Expand All @@ -92,6 +93,7 @@ def forward_lower(
fparam=fparam,
aparam=aparam,
do_atomic_virial=do_atomic_virial,
comm_dict=comm_dict,
)
if self.get_fitting_net() is not None:
model_predict = {}
Expand Down
4 changes: 4 additions & 0 deletions deepmd/pt/model/model/make_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,6 +211,7 @@ def forward_common_lower(
fparam: Optional[torch.Tensor] = None,
aparam: Optional[torch.Tensor] = None,
do_atomic_virial: bool = False,
comm_dict: Optional[Dict[str, torch.Tensor]] = None,
):
"""Return model prediction. Lower interface that takes
extended atomic coordinates and types, nlist, and mapping
Expand All @@ -233,6 +234,8 @@ def forward_common_lower(
atomic parameter. nf x nloc x nda
do_atomic_virial
whether calculate atomic virial.
comm_dict
The data needed for communication for parallel inference.
Returns
-------
Expand All @@ -254,6 +257,7 @@ def forward_common_lower(
mapping=mapping,
fparam=fp,
aparam=ap,
comm_dict=comm_dict,
)
model_predict = fit_output_to_model_output(
atomic_ret,
Expand Down
39 changes: 38 additions & 1 deletion source/api_c/include/c_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,12 +24,49 @@ extern DP_Nlist* DP_NewNlist(int inum_,
int* ilist_,
int* numneigh_,
int** firstneigh_);
/**
* @brief Create a new neighbor list with communication capabilities.
* @details This function extends DP_NewNlist by adding support for parallel
* communication, allowing the neighbor list to be used in distributed
* environments.
* @param[in] inum_ Number of core region atoms.
* @param[in] ilist_ Array storing the core region atom's index.
* @param[in] numneigh_ Array storing the core region atom's neighbor atom
* number.
* @param[in] firstneigh_ Array storing the core region atom's neighbor index.
* @param[in] nswap Number of swaps to be performed in communication.
* @param[in] sendnum Array storing the number of atoms to send for each swap.
* @param[in] recvnum Array storing the number of atoms to receive for each
* swap.
* @param[in] firstrecv Index of the first receive operation for each swap.
* @param[in] sendlist List of atoms to be sent for each swap.
* @param[in] sendproc Array of processor IDs to send atoms to for each swap.
* @param[in] recvproc Array of processor IDs from which atoms are received for
* each swap.
* @param[in] world Pointer to the MPI communicator or similar communication
* world used for the operation.
* @returns A pointer to the initialized neighbor list with communication
* capabilities.
*/
extern DP_Nlist* DP_NewNlist_comm(int inum_,
int* ilist_,
int* numneigh_,
int** firstneigh_,
int nswap,
int* sendnum,
int* recvnum,
int* firstrecv,
int** sendlist,
int* sendproc,
int* recvproc,
void* world);

/**
* @brief Delete a neighbor list.
*
* @param nl Neighbor list to delete.
*/
extern void DP_DeleteNlist(DP_Nlist* nl);

/**
Expand Down

0 comments on commit d0fe13c

Please sign in to comment.