Skip to content

Commit

Permalink
Fix a potential slice bug in se_t descriptor (#1087)
Browse files Browse the repository at this point in the history
* fix a potential slice bug in se_t

* fix UT error

* address comments
  • Loading branch information
denghuilu committed Sep 3, 2021
1 parent c824ff6 commit aab124f
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 4 deletions.
2 changes: 1 addition & 1 deletion deepmd/descriptor/se_t.py
Original file line number Diff line number Diff line change
Expand Up @@ -469,6 +469,7 @@ def _filter(self,
inputs_i = tf.slice (inputs,
[ 0, start_index_i *4],
[-1, self.sel_a[type_i] *4] )
start_index_j = start_index_i
start_index_i += self.sel_a[type_i]
nei_type_i = self.sel_a[type_i]
shape_i = inputs_i.get_shape().as_list()
Expand All @@ -477,7 +478,6 @@ def _filter(self,
env_i = tf.reshape(inputs_i, [-1, nei_type_i, 4])
# with natom x nei_type_i x 3
env_i = tf.slice(env_i, [0, 0, 1], [-1, -1, -1])
start_index_j = 0
for type_j in range(type_i, self.ntypes):
# with natom x (nei_type_j x 4)
inputs_j = tf.slice (inputs,
Expand Down
6 changes: 3 additions & 3 deletions source/tests/test_model_se_t.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,9 +100,9 @@ def test_model(self):
np.savetxt('e.out', e.reshape([1, -1]))
np.savetxt('f.out', f.reshape([1, -1]), delimiter = ',')
np.savetxt('v.out', v.reshape([1, -1]), delimiter = ',')
refe = [4.826771866004193612e+01]
reff = [5.355088169393570574e+00,5.606772412401632266e+00,2.703270748296462966e-01,5.381408138049708967e+00,5.261355614357515975e+00,-4.079549918988090162e-01,-5.182324474551911919e+00,3.695481388907447262e-01,-5.238474288082559799e-02,1.665564584447352670e-01,-5.955401876564963892e+00,-2.217626865156164251e-01,-5.967343479332643419e+00,9.073821102416884665e-02,3.703103995504785639e-01,2.466151879965444438e-01,-5.373012500109097367e+00,4.146494691512622732e-02]
refv = [-1.336768232407933077e+01,4.818050125305787801e-01,3.589284283410607568e-01,4.818050125305786691e-01,-1.225345559839458964e+01,-1.701405121682751653e-01,3.589284283410607568e-01,-1.701405121682752486e-01,-3.428455515842296353e-02]
refe = [4.8436558582194039e+01]
reff = [5.2896335066946598e+00,5.5778402259211131e+00,2.6839994229557251e-01,5.3528786387686784e+00,5.2477755362164968e+00,-4.0486366542657343e-01,-5.1297084055340498e+00,3.4607112287117253e-01,-5.1800783428369482e-02,1.5557068351407846e-01,-5.9071343228741506e+00,-2.2012359669589748e-01,-5.9156735320857488e+00,8.8397615509389127e-02,3.6701215949753935e-01,2.4729910864238122e-01,-5.3529501776440211e+00,4.1375943757728552e-02]
refv = [-1.3159448660141607e+01,4.6952048725161544e-01,3.5482003698976106e-01,4.6952048725161577e-01,-1.2178990983673918e+01,-1.6867277410496895e-01,3.5482003698976106e-01,-1.6867277410496900e-01,-3.3986741457321945e-02]
refe = np.reshape(refe, [-1])
reff = np.reshape(reff, [-1])
refv = np.reshape(refv, [-1])
Expand Down

0 comments on commit aab124f

Please sign in to comment.