Skip to content

Commit

Permalink
fix the bug of empty fitting net neuron. (#3458)
Browse files Browse the repository at this point in the history
fixes #3448

Co-authored-by: Han Wang <wang_han@iapcm.ac.cn>
  • Loading branch information
wanghan-iapcm and Han Wang committed Mar 13, 2024
1 parent da68686 commit 571ddec
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 2 deletions.
3 changes: 2 additions & 1 deletion deepmd/dpmodel/utils/network.py
Original file line number Diff line number Diff line change
Expand Up @@ -609,7 +609,8 @@ def __init__(
resnet_dt=resnet_dt,
precision=precision,
)
i_in, i_ot = neuron[-1], out_dim
i_in = neuron[-1] if len(neuron) > 0 else in_dim
i_ot = out_dim
self.layers.append(
T_NetworkLayer(
i_in,
Expand Down
4 changes: 3 additions & 1 deletion source/tests/pt/model/test_ener_fitting.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,12 +44,13 @@ def test_consistency(
)
atype = torch.tensor(self.atype_ext[:, :nloc], dtype=int, device=env.DEVICE)

for od, mixed_types, nfp, nap, et in itertools.product(
for od, mixed_types, nfp, nap, et, nn in itertools.product(
[1, 3],
[True, False],
[0, 3],
[0, 4],
[[], [0], [1]],
[[4, 4, 4], []],
):
ft0 = InvarFitting(
"foo",
Expand All @@ -60,6 +61,7 @@ def test_consistency(
numb_aparam=nap,
mixed_types=mixed_types,
exclude_types=et,
neuron=nn,
).to(env.DEVICE)
ft1 = DPInvarFitting.deserialize(ft0.serialize())
ft2 = InvarFitting.deserialize(ft0.serialize())
Expand Down

0 comments on commit 571ddec

Please sign in to comment.