
pt: support dpa2 model parallel inference #3657

Merged
merged 104 commits into from
Apr 30, 2024
Commits
ae0f799
init
CaRoLZhangxy Apr 7, 2024
96c9309
Merge branch 'devel' of https://github.com/deepmodeling/deepmd-kit in…
CaRoLZhangxy Apr 7, 2024
bd1927f
init
CaRoLZhangxy Apr 8, 2024
8350372
fix
CaRoLZhangxy Apr 8, 2024
28ae599
finish
CaRoLZhangxy Apr 8, 2024
1afd8fc
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 8, 2024
7f6632a
Merge branch 'devel' of https://github.com/deepmodeling/deepmd-kit in…
CaRoLZhangxy Apr 15, 2024
29d1bec
Merge branch 'devel' of https://github.com/deepmodeling/deepmd-kit in…
CaRoLZhangxy Apr 17, 2024
2a7db1e
use google cuda define
CaRoLZhangxy Apr 17, 2024
6af0d63
update forward api
CaRoLZhangxy Apr 17, 2024
3020781
remove frozen model
CaRoLZhangxy Apr 17, 2024
420868f
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 17, 2024
c779828
Merge branch 'dis' of https://github.com/CaRoLZhangxy/deepmd-kit into…
CaRoLZhangxy Apr 17, 2024
7591dd3
be able to compile without mpi
CaRoLZhangxy Apr 17, 2024
3d0f14d
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 17, 2024
14a5fed
type to fix mpich
CaRoLZhangxy Apr 17, 2024
313a4b1
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 17, 2024
a39aff3
remove unused code
CaRoLZhangxy Apr 17, 2024
31a4f0d
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 17, 2024
cbb2916
update model
CaRoLZhangxy Apr 18, 2024
05686e8
upload smaller model
CaRoLZhangxy Apr 18, 2024
5dcf5f0
hack to resolve border_op problem
njzjz Apr 18, 2024
fd9177a
update dpa model
CaRoLZhangxy Apr 19, 2024
37989c7
use gpu memcpy
CaRoLZhangxy Apr 19, 2024
a13934b
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 19, 2024
06208d2
update ut data
CaRoLZhangxy Apr 19, 2024
48b9833
Merge branch 'dis' of https://github.com/CaRoLZhangxy/deepmd-kit into…
CaRoLZhangxy Apr 19, 2024
761c1c8
update dpa model
CaRoLZhangxy Apr 19, 2024
0ed6116
update ut data
CaRoLZhangxy Apr 19, 2024
6df987c
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 19, 2024
305245c
update ut data
CaRoLZhangxy Apr 19, 2024
8e5e41c
rollback ut and only apply new api to dpa2 model
CaRoLZhangxy Apr 19, 2024
44e0e6a
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 19, 2024
7bc66be
update ut data
CaRoLZhangxy Apr 19, 2024
3da55f3
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 19, 2024
50c0f46
add comments
CaRoLZhangxy Apr 19, 2024
38bcdd6
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 19, 2024
b49d91d
add ut file
CaRoLZhangxy Apr 19, 2024
46911c1
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 19, 2024
66303a0
fix bug
CaRoLZhangxy Apr 19, 2024
c9bc208
Merge branch 'dis' of https://github.com/CaRoLZhangxy/deepmd-kit into…
CaRoLZhangxy Apr 19, 2024
dca1202
fix type bug
CaRoLZhangxy Apr 19, 2024
60605a9
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 19, 2024
048c2af
try to fix mpich compile error
CaRoLZhangxy Apr 19, 2024
1ff60b2
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 19, 2024
b7a19cf
fix ut data
CaRoLZhangxy Apr 19, 2024
6eb03f4
low requirement at float
CaRoLZhangxy Apr 19, 2024
3dcb4ba
Merge branch 'devel' of https://github.com/deepmodeling/deepmd-kit in…
CaRoLZhangxy Apr 19, 2024
0b485f9
skip no balance test
CaRoLZhangxy Apr 19, 2024
7dc5815
update ut data
CaRoLZhangxy Apr 21, 2024
303644f
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 21, 2024
7867e13
update lmp test data
CaRoLZhangxy Apr 21, 2024
e0a08f3
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 21, 2024
993f05a
Update source/op/pt/comm.cc
CaRoLZhangxy Apr 22, 2024
535ade4
Update source/op/pt/comm.cc
CaRoLZhangxy Apr 22, 2024
f4b4481
Update source/op/pt/comm.cc
CaRoLZhangxy Apr 22, 2024
ffbc4db
Update source/op/pt/comm.cc
CaRoLZhangxy Apr 22, 2024
acf841d
Update source/op/pt/comm.cc
CaRoLZhangxy Apr 22, 2024
9473606
throw error when compiled with mpi without cuda support
CaRoLZhangxy Apr 22, 2024
bc02345
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 22, 2024
1952784
support mpich
CaRoLZhangxy Apr 22, 2024
fc2d61b
include errors.h
CaRoLZhangxy Apr 22, 2024
67b68aa
Merge branch 'devel' of https://github.com/deepmodeling/deepmd-kit in…
CaRoLZhangxy Apr 22, 2024
bc5f092
apply memcpy when cuda-aware = 0
CaRoLZhangxy Apr 23, 2024
e534e99
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 23, 2024
05cdd92
Merge branch 'devel' of https://github.com/deepmodeling/deepmd-kit in…
CaRoLZhangxy Apr 23, 2024
ff17514
Merge branch 'dis' of https://github.com/CaRoLZhangxy/deepmd-kit into…
CaRoLZhangxy Apr 23, 2024
8b23ebb
Merge branch 'devel' of https://github.com/deepmodeling/deepmd-kit in…
CaRoLZhangxy Apr 24, 2024
14b43aa
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 24, 2024
ffd89cd
fix no cuda error
CaRoLZhangxy Apr 24, 2024
09cc940
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 24, 2024
e406e6a
fix compile error
CaRoLZhangxy Apr 24, 2024
3959f19
print log.lammps to screen in test_cuda if failed
njzjz Apr 25, 2024
521aa28
skip dpa test on cuda ,add todo and fix codeql
CaRoLZhangxy Apr 25, 2024
1704230
Merge branch 'devel' of https://github.com/deepmodeling/deepmd-kit in…
CaRoLZhangxy Apr 25, 2024
a06a49c
make pre-commit.ci pass
njzjz Apr 25, 2024
63123de
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 25, 2024
f356ca5
add doc
CaRoLZhangxy Apr 25, 2024
5baefcc
Merge branch 'devel' of https://github.com/deepmodeling/deepmd-kit in…
CaRoLZhangxy Apr 26, 2024
51125b1
add doc
CaRoLZhangxy Apr 26, 2024
b09e857
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 26, 2024
32ba778
try runs-on t4
CaRoLZhangxy Apr 26, 2024
f3b55b4
Update deepmd/pt/model/descriptor/repformers.py
CaRoLZhangxy Apr 26, 2024
92ffb35
rename
CaRoLZhangxy Apr 26, 2024
8ed31c8
Merge branch 'dis' of https://github.com/CaRoLZhangxy/deepmd-kit into…
CaRoLZhangxy Apr 26, 2024
0269b81
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 26, 2024
124c2d6
fix
CaRoLZhangxy Apr 26, 2024
189fc6b
run c++ test only
CaRoLZhangxy Apr 26, 2024
0adc34c
deal with mpi not init
CaRoLZhangxy Apr 26, 2024
d368ccc
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 26, 2024
26b63e0
fix doc format
CaRoLZhangxy Apr 26, 2024
0c76246
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 26, 2024
86757bf
try to fix
CaRoLZhangxy Apr 26, 2024
ecdef3e
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 26, 2024
47242af
fix
CaRoLZhangxy Apr 26, 2024
d051e4b
Merge branch 'dis' of https://github.com/CaRoLZhangxy/deepmd-kit into…
CaRoLZhangxy Apr 26, 2024
8de1785
init mpi_init = 0
CaRoLZhangxy Apr 26, 2024
17b35e0
add world_size
CaRoLZhangxy Apr 26, 2024
74b21e4
add low version support
CaRoLZhangxy Apr 26, 2024
f784505
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 26, 2024
5fe1fd1
fix error
CaRoLZhangxy Apr 26, 2024
33c2798
add &
CaRoLZhangxy Apr 26, 2024
273a446
reset test.yml
CaRoLZhangxy Apr 26, 2024
beba142
add doc str in python
CaRoLZhangxy Apr 28, 2024
2 changes: 1 addition & 1 deletion .github/workflows/test_cuda.yml
@@ -76,7 +76,7 @@ jobs:
- run: |
export LD_LIBRARY_PATH=$GITHUB_WORKSPACE/dp_test/lib:$GITHUB_WORKSPACE/libtorch/lib:$CUDA_PATH/lib64:$LD_LIBRARY_PATH
export PATH=$GITHUB_WORKSPACE/dp_test/bin:$PATH
python -m pytest source/lmp/tests
python -m pytest -s source/lmp/tests || (cat log.lammps && exit 1)
python -m pytest source/ipi/tests
env:
OMP_NUM_THREADS: 1
6 changes: 5 additions & 1 deletion deepmd/pt/entrypoints/main.py
@@ -286,10 +286,14 @@

def freeze(FLAGS):
model = torch.jit.script(inference.Tester(FLAGS.model, head=FLAGS.head).model)
if '"type": "dpa2"' in model.model_def_script:
extra_files = {"type": "dpa2"}

else:
extra_files = {"type": "else"}
torch.jit.save(
model,
FLAGS.output,
{},
extra_files,
)


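The `extra_files` mapping written by `freeze` above can be read back when the frozen model is loaded, which is presumably how a consumer detects a DPA-2 model and selects the parallel code path. A minimal round-trip sketch (the toy module is illustrative, not part of the PR; note that PyTorch hands the values back as `bytes`):

```python
import io

import torch


class Toy(torch.nn.Module):  # stand-in for the real scripted model
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x + 1


buf = io.BytesIO()
# Same call shape as in freeze(): model, destination, extra_files dict.
torch.jit.save(torch.jit.script(Toy()), buf, _extra_files={"type": "dpa2"})

buf.seek(0)
extra_files = {"type": ""}  # keys to extract at load time
torch.jit.load(buf, _extra_files=extra_files)
assert extra_files["type"] == b"dpa2"  # values are returned as bytes
```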
4 changes: 4 additions & 0 deletions deepmd/pt/model/atomic_model/base_atomic_model.py
@@ -186,6 +186,7 @@ def forward_common_atomic(
mapping: Optional[torch.Tensor] = None,
fparam: Optional[torch.Tensor] = None,
aparam: Optional[torch.Tensor] = None,
comm_dict: Optional[Dict[str, torch.Tensor]] = None,
) -> Dict[str, torch.Tensor]:
"""Common interface for atomic inference.

@@ -207,6 +208,8 @@
frame parameters, shape: nf x dim_fparam
aparam
atomic parameter, shape: nf x nloc x dim_aparam
comm_dict
The data needed for communication for parallel inference.

Returns
-------
@@ -234,6 +237,7 @@
mapping=mapping,
fparam=fparam,
aparam=aparam,
comm_dict=comm_dict,
)
ret_dict = self.apply_out_stat(ret_dict, atype)

2 changes: 2 additions & 0 deletions deepmd/pt/model/atomic_model/dp_atomic_model.py
@@ -130,6 +130,7 @@ def forward_atomic(
mapping: Optional[torch.Tensor] = None,
fparam: Optional[torch.Tensor] = None,
aparam: Optional[torch.Tensor] = None,
comm_dict: Optional[Dict[str, torch.Tensor]] = None,
) -> Dict[str, torch.Tensor]:
"""Return atomic prediction.

@@ -163,6 +164,7 @@
extended_atype,
nlist,
mapping=mapping,
comm_dict=comm_dict,
)
assert descriptor is not None
# energy, force
1 change: 1 addition & 0 deletions deepmd/pt/model/atomic_model/linear_atomic_model.py
@@ -144,6 +144,7 @@ def forward_atomic(
mapping: Optional[torch.Tensor] = None,
fparam: Optional[torch.Tensor] = None,
aparam: Optional[torch.Tensor] = None,
comm_dict: Optional[Dict[str, torch.Tensor]] = None,
) -> Dict[str, torch.Tensor]:
"""Return atomic prediction.

1 change: 1 addition & 0 deletions deepmd/pt/model/atomic_model/pairtab_atomic_model.py
@@ -228,6 +228,7 @@ def forward_atomic(
fparam: Optional[torch.Tensor] = None,
aparam: Optional[torch.Tensor] = None,
do_atomic_virial: bool = False,
comm_dict: Optional[Dict[str, torch.Tensor]] = None,
) -> Dict[str, torch.Tensor]:
nframes, nloc, nnei = nlist.shape
extended_coord = extended_coord.view(nframes, -1, 3)
4 changes: 4 additions & 0 deletions deepmd/pt/model/descriptor/dpa1.py
@@ -1,6 +1,7 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
from typing import (
Callable,
Dict,
List,
Optional,
Tuple,
@@ -453,6 +454,7 @@ def forward(
extended_atype: torch.Tensor,
nlist: torch.Tensor,
mapping: Optional[torch.Tensor] = None,
comm_dict: Optional[Dict[str, torch.Tensor]] = None,
):
"""Compute the descriptor.

@@ -466,6 +468,8 @@
The neighbor list. shape: nf x nloc x nnei
mapping
The index mapping, not required by this descriptor.
comm_dict
The data needed for communication for parallel inference.

Returns
-------
19 changes: 13 additions & 6 deletions deepmd/pt/model/descriptor/dpa2.py
@@ -1,6 +1,7 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
from typing import (
Callable,
Dict,
List,
Optional,
Tuple,
@@ -395,6 +396,7 @@ def forward(
extended_atype: torch.Tensor,
nlist: torch.Tensor,
mapping: Optional[torch.Tensor] = None,
comm_dict: Optional[Dict[str, torch.Tensor]] = None,
):
"""Compute the descriptor.

@@ -408,6 +410,8 @@
The neighbor list. shape: nf x nloc x nnei
mapping
The index mapping, which maps the extended region index to the local region.
comm_dict
The data needed for communication for parallel inference.

Returns
-------
@@ -450,11 +454,13 @@ def forward(
# linear to change shape
g1 = self.g1_shape_tranform(g1)
# mapping g1
assert mapping is not None
mapping_ext = (
mapping.view(nframes, nall).unsqueeze(-1).expand(-1, -1, g1.shape[-1])
)
g1_ext = torch.gather(g1, 1, mapping_ext)
if comm_dict is None:
assert mapping is not None
mapping_ext = (
mapping.view(nframes, nall).unsqueeze(-1).expand(-1, -1, g1.shape[-1])
)
g1_ext = torch.gather(g1, 1, mapping_ext)
g1 = g1_ext
# repformer
g1, g2, h2, rot_mat, sw = self.repformers(
nlist_dict[
@@ -464,8 +470,9 @@
],
extended_coord,
extended_atype,
g1_ext,
g1,
mapping,
comm_dict,
)
if self.concat_output_tebd:
g1 = torch.cat([g1, g1_inp], dim=-1)
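In the serial branch above (`comm_dict is None`), per-local-atom features `g1` of shape `nf x nloc x ng1` are scattered onto the extended (local + ghost) region with `torch.gather`. A self-contained sketch of just that mapping step, with illustrative sizes:

```python
import torch

nframes, nloc, nall, ng1 = 1, 3, 5, 4  # illustrative sizes
g1 = torch.arange(nframes * nloc * ng1, dtype=torch.float32).view(
    nframes, nloc, ng1
)
# mapping: for each extended-region atom, the index of its local owner
mapping = torch.tensor([[0, 1, 2, 0, 1]])  # nf x nall

mapping_ext = mapping.view(nframes, nall).unsqueeze(-1).expand(-1, -1, ng1)
g1_ext = torch.gather(g1, 1, mapping_ext)  # nf x nall x ng1

assert g1_ext.shape == (nframes, nall, ng1)
assert torch.equal(g1_ext[0, 3], g1[0, 0])  # ghost atom 3 copies local atom 0
```

Ghost atoms simply copy the feature row of the local atom they map to; in the parallel branch that copy is instead handled by the `border_op` communication op.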
4 changes: 4 additions & 0 deletions deepmd/pt/model/descriptor/hybrid.py
@@ -168,6 +168,7 @@ def forward(
atype_ext: torch.Tensor,
nlist: torch.Tensor,
mapping: Optional[torch.Tensor] = None,
comm_dict: Optional[Dict[str, torch.Tensor]] = None,
):
"""Compute the descriptor.

@@ -181,6 +182,8 @@
The neighbor list. shape: nf x nloc x nnei
mapping
The index mapping, not required by this descriptor.
comm_dict
The data needed for communication for parallel inference.

Returns
-------
@@ -443,6 +446,7 @@ def forward(
extended_atype: torch.Tensor,
extended_atype_embd: Optional[torch.Tensor] = None,
mapping: Optional[torch.Tensor] = None,
comm_dict: Optional[Dict[str, torch.Tensor]] = None,
):
"""Calculate decoded embedding for each atom.

70 changes: 63 additions & 7 deletions deepmd/pt/model/descriptor/repformers.py
@@ -54,6 +54,27 @@ def torch_linear(*args, **kwargs):
mylinear = simple_linear


if not hasattr(torch.ops.deepmd, "border_op"):

def border_op(
argument0,
argument1,
argument2,
argument3,
argument4,
argument5,
argument6,
argument7,
argument8,
) -> torch.Tensor:
raise NotImplementedError(
"border_op is not available since customized PyTorch OP library is not built when freezing the model."
)

# Note: this hack cannot actually save a model that can be run with LAMMPS.
torch.ops.deepmd.border_op = border_op


@DescriptorBlock.register("se_repformer")
@DescriptorBlock.register("se_uni")
class DescrptBlockRepformers(DescriptorBlock):
@@ -234,9 +255,11 @@ def forward(
extended_atype: torch.Tensor,
extended_atype_embd: Optional[torch.Tensor] = None,
mapping: Optional[torch.Tensor] = None,
comm_dict: Optional[Dict[str, torch.Tensor]] = None,
):
assert mapping is not None
assert extended_atype_embd is not None
if comm_dict is None:
assert mapping is not None
assert extended_atype_embd is not None
nframes, nloc, nnei = nlist.shape
nall = extended_coord.view(nframes, -1).shape[1] // 3
atype = extended_atype[:, :nloc]
@@ -257,9 +280,13 @@
sw = sw.masked_fill(~nlist_mask, 0.0)

# [nframes, nloc, tebd_dim]
atype_embd = extended_atype_embd[:, :nloc, :]
assert list(atype_embd.shape) == [nframes, nloc, self.g1_dim]

if comm_dict is None:
assert isinstance(extended_atype_embd, torch.Tensor) # for jit
atype_embd = extended_atype_embd[:, :nloc, :]
assert list(atype_embd.shape) == [nframes, nloc, self.g1_dim]
else:
atype_embd = extended_atype_embd
assert isinstance(atype_embd, torch.Tensor) # for jit
g1 = self.act(atype_embd)
# nb x nloc x nnei x 1, nb x nloc x nnei x 3
if not self.direct_dist:
@@ -275,11 +302,40 @@
# if the a neighbor is real or not is indicated by nlist_mask
nlist[nlist == -1] = 0
# nb x nall x ng1
mapping = mapping.view(nframes, nall).unsqueeze(-1).expand(-1, -1, self.g1_dim)
if comm_dict is None:
assert mapping is not None
mapping = (
mapping.view(nframes, nall).unsqueeze(-1).expand(-1, -1, self.g1_dim)
)
for idx, ll in enumerate(self.layers):
# g1: nb x nloc x ng1
# g1_ext: nb x nall x ng1
g1_ext = torch.gather(g1, 1, mapping)
if comm_dict is None:
assert mapping is not None
g1_ext = torch.gather(g1, 1, mapping)
else:
n_padding = nall - nloc
g1 = torch.nn.functional.pad(
g1.squeeze(0), (0, 0, 0, n_padding), value=0.0
)
assert "send_list" in comm_dict
assert "send_proc" in comm_dict
assert "recv_proc" in comm_dict
assert "send_num" in comm_dict
assert "recv_num" in comm_dict
assert "communicator" in comm_dict
ret = torch.ops.deepmd.border_op(
comm_dict["send_list"],
comm_dict["send_proc"],
comm_dict["recv_proc"],
comm_dict["send_num"],
comm_dict["recv_num"],
g1,
comm_dict["communicator"],
torch.tensor(nloc),
torch.tensor(nall - nloc),
)
g1_ext = ret[0].unsqueeze(0)
g1, g2, h2 = ll.forward(
g1_ext,
g2,
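The parallel branch above first zero-pads `g1` from `nloc` to `nall` rows; `border_op` then overwrites the ghost rows with features received from neighboring MPI ranks. The padding step in isolation (sizes illustrative; the communication op itself is not reproduced here):

```python
import torch
import torch.nn.functional as F

nloc, nall, ng1 = 3, 5, 4  # illustrative sizes
g1 = torch.ones(1, nloc, ng1)  # nb x nloc x ng1; nb == 1 in this branch

n_padding = nall - nloc
# pad tuple is (last-dim left, last-dim right, row-dim top, row-dim bottom):
# keep the feature dimension intact, append n_padding zero rows.
g1_padded = F.pad(g1.squeeze(0), (0, 0, 0, n_padding), value=0.0)

assert g1_padded.shape == (nall, ng1)
assert torch.all(g1_padded[nloc:] == 0.0)  # ghost rows start zeroed
```

After the exchange, `ret[0].unsqueeze(0)` restores the leading batch dimension so the result matches the `g1_ext` produced by the serial `torch.gather` path.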
3 changes: 3 additions & 0 deletions deepmd/pt/model/descriptor/se_a.py
@@ -191,6 +191,7 @@ def forward(
atype_ext: torch.Tensor,
nlist: torch.Tensor,
mapping: Optional[torch.Tensor] = None,
comm_dict: Optional[Dict[str, torch.Tensor]] = None,
):
"""Compute the descriptor.

@@ -204,6 +205,8 @@
The neighbor list. shape: nf x nloc x nnei
mapping
The index mapping, not required by this descriptor.
comm_dict
The data needed for communication for parallel inference.

Returns
-------
2 changes: 2 additions & 0 deletions deepmd/pt/model/model/ener_model.py
@@ -83,6 +83,7 @@ def forward_lower(
fparam: Optional[torch.Tensor] = None,
aparam: Optional[torch.Tensor] = None,
do_atomic_virial: bool = False,
comm_dict: Optional[Dict[str, torch.Tensor]] = None,
):
model_ret = self.forward_common_lower(
extended_coord,
@@ -92,6 +93,7 @@
fparam=fparam,
aparam=aparam,
do_atomic_virial=do_atomic_virial,
comm_dict=comm_dict,
)
if self.get_fitting_net() is not None:
model_predict = {}
4 changes: 4 additions & 0 deletions deepmd/pt/model/model/make_model.py
@@ -211,6 +211,7 @@ def forward_common_lower(
fparam: Optional[torch.Tensor] = None,
aparam: Optional[torch.Tensor] = None,
do_atomic_virial: bool = False,
comm_dict: Optional[Dict[str, torch.Tensor]] = None,
):
"""Return model prediction. Lower interface that takes
extended atomic coordinates and types, nlist, and mapping
@@ -233,6 +234,8 @@
atomic parameter. nf x nloc x nda
do_atomic_virial
whether calculate atomic virial.
comm_dict
The data needed for communication for parallel inference.

Returns
-------
@@ -254,6 +257,7 @@
mapping=mapping,
fparam=fp,
aparam=ap,
comm_dict=comm_dict,
)
model_predict = fit_output_to_model_output(
atomic_ret,
39 changes: 38 additions & 1 deletion source/api_c/include/c_api.h
@@ -24,12 +24,49 @@ extern DP_Nlist* DP_NewNlist(int inum_,
int* ilist_,
int* numneigh_,
int** firstneigh_);
/**
* @brief Create a new neighbor list with communication capabilities.
* @details This function extends DP_NewNlist by adding support for parallel
* communication, allowing the neighbor list to be used in distributed
* environments.
* @param[in] inum_ Number of core region atoms.
* @param[in] ilist_ Array storing the core region atom's index.
* @param[in] numneigh_ Array storing the core region atom's neighbor atom
* number.
* @param[in] firstneigh_ Array storing the core region atom's neighbor index.
* @param[in] nswap Number of swaps to be performed in communication.
* @param[in] sendnum Array storing the number of atoms to send for each swap.
* @param[in] recvnum Array storing the number of atoms to receive for each
* swap.
* @param[in] firstrecv Index of the first receive operation for each swap.
* @param[in] sendlist List of atoms to be sent for each swap.
* @param[in] sendproc Array of processor IDs to send atoms to for each swap.
* @param[in] recvproc Array of processor IDs from which atoms are received for
* each swap.
* @param[in] world Pointer to the MPI communicator or similar communication
* world used for the operation.
* @returns A pointer to the initialized neighbor list with communication
* capabilities.
*/
extern DP_Nlist* DP_NewNlist_comm(int inum_,
int* ilist_,
int* numneigh_,
int** firstneigh_,
int nswap,
int* sendnum,
int* recvnum,
int* firstrecv,
int** sendlist,
int* sendproc,
int* recvproc,
void* world);

/**
* @brief Delete a neighbor list.
*
* @param nl Neighbor list to delete.
*/
extern void DP_DeleteNlist(DP_Nlist* nl);

/**
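For callers outside LAMMPS, the `DP_NewNlist_comm` prototype documented above could be bound from Python via `ctypes`. A hedged sketch — the library handle and any use of the returned pointer are assumptions; only the argument list follows the declared C signature:

```python
import ctypes

# Argument types mirroring the DP_NewNlist_comm declaration, in order.
ARGTYPES = [
    ctypes.c_int,                                  # inum_
    ctypes.POINTER(ctypes.c_int),                  # ilist_
    ctypes.POINTER(ctypes.c_int),                  # numneigh_
    ctypes.POINTER(ctypes.POINTER(ctypes.c_int)),  # firstneigh_
    ctypes.c_int,                                  # nswap
    ctypes.POINTER(ctypes.c_int),                  # sendnum
    ctypes.POINTER(ctypes.c_int),                  # recvnum
    ctypes.POINTER(ctypes.c_int),                  # firstrecv
    ctypes.POINTER(ctypes.POINTER(ctypes.c_int)),  # sendlist
    ctypes.POINTER(ctypes.c_int),                  # sendproc
    ctypes.POINTER(ctypes.c_int),                  # recvproc
    ctypes.c_void_p,                               # world (MPI communicator)
]


def bind_new_nlist_comm(lib: ctypes.CDLL):
    """Attach the documented signature to an already-loaded deepmd C library.

    The library path/name is left to the caller; this only sets the types.
    """
    fn = lib.DP_NewNlist_comm
    fn.argtypes = ARGTYPES
    fn.restype = ctypes.c_void_p  # opaque DP_Nlist*
    return fn
```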
Expand Down