Skip to content

Commit

Permalink
test(hybrid): add ut for descriptor hybrid (#3711)
Browse files Browse the repository at this point in the history
Fix #3705.

Signed-off-by: Jinzhe Zeng <jinzhe.zeng@rutgers.edu>
  • Loading branch information
njzjz committed Apr 29, 2024
1 parent 981ce44 commit c6b7f17
Show file tree
Hide file tree
Showing 2 changed files with 67 additions and 0 deletions.
1 change: 1 addition & 0 deletions deepmd/dpmodel/descriptor/dpa1.py
Original file line number Diff line number Diff line change
Expand Up @@ -527,6 +527,7 @@ def call(
# nf x nloc x (ng x ng1 + tebd_dim)
if self.concat_output_tebd:
grrg = np.concatenate([grrg, atype_embd.reshape(nf, nloc, -1)], axis=-1)
gr = gr.reshape(nf, nloc, *gr.shape[1:])
return grrg, gr[..., 1:], None, None, sw

def serialize(self) -> dict:
Expand Down
66 changes: 66 additions & 0 deletions source/tests/common/dpmodel/test_descriptor_hybrid.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
import unittest

import numpy as np

from deepmd.dpmodel.descriptor.dpa1 import (
DescrptDPA1,
)
from deepmd.dpmodel.descriptor.hybrid import (
DescrptHybrid,
)
from deepmd.dpmodel.descriptor.se_e2_a import (
DescrptSeA,
)
from deepmd.dpmodel.descriptor.se_r import (
DescrptSeR,
)

from .case_single_frame_with_nlist import (
TestCaseSingleFrameWithNlist,
)


class TestDescrptHybrid(unittest.TestCase, TestCaseSingleFrameWithNlist):
def setUp(self):
unittest.TestCase.setUp(self)
TestCaseSingleFrameWithNlist.setUp(self)

def test_self_consistency(
self,
):
rng = np.random.default_rng()
nf, nloc, nnei = self.nlist.shape
davg = rng.normal(size=(self.nt, nnei, 4))
dstd = rng.normal(size=(self.nt, nnei, 4))
dstd = 0.1 + np.abs(dstd)

ddsub0 = DescrptSeA(
rcut=self.rcut,
rcut_smth=self.rcut_smth,
sel=self.sel,
)
ddsub0.davg = davg
ddsub0.dstd = dstd
ddsub1 = DescrptDPA1(
rcut=self.rcut,
rcut_smth=self.rcut_smth,
sel=np.sum(self.sel).item() - 1,
ntypes=len(self.sel),
)
ddsub1.davg = davg[:, :6]
ddsub1.dstd = dstd[:, :6]
ddsub2 = DescrptSeR(
rcut=self.rcut / 2,
rcut_smth=self.rcut_smth,
sel=[3, 1],
)
ddsub2.davg = davg[:, :4, :1]
ddsub2.dstd = dstd[:, :4, :1]
em0 = DescrptHybrid(list=[ddsub0, ddsub1, ddsub2])

em1 = DescrptHybrid.deserialize(em0.serialize())
mm0 = em0.call(self.coord_ext, self.atype_ext, self.nlist)
mm1 = em1.call(self.coord_ext, self.atype_ext, self.nlist)
for ii in [0, 1]:
np.testing.assert_allclose(mm0[ii], mm1[ii])

0 comments on commit c6b7f17

Please sign in to comment.