Skip to content

Commit

Permalink
revert changes to the interface of the EnerFitting
Browse files Browse the repository at this point in the history
  • Loading branch information
njzjz committed Jan 14, 2022
1 parent 8f2dc44 commit 4f95d01
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 8 deletions.
13 changes: 7 additions & 6 deletions deepmd/fit/ener.py
Original file line number Diff line number Diff line change
Expand Up @@ -332,7 +332,6 @@ def _build_lower(
def build (self,
inputs : tf.Tensor,
natoms : tf.Tensor,
nframes : tf.Tensor,
input_dict : dict = {},
reuse : bool = None,
suffix : str = '',
Expand All @@ -353,8 +352,6 @@ def build (self,
natoms[0]: number of local atoms
natoms[1]: total number of atoms held by this processor
natoms[i]: 2 <= i < Ntypes+2, number of type i atoms
nframes : tf.Tensor
The number of frames
reuse
The weights in the networks should be reused when get the variable.
suffix
Expand Down Expand Up @@ -401,11 +398,15 @@ def build (self,
trainable = False,
initializer = tf.constant_initializer(self.aparam_inv_std))

inputs = tf.reshape(inputs, [nframes, self.dim_descrpt * natoms[0]])
inputs = tf.reshape(inputs, [-1, self.dim_descrpt * natoms[0]])
if len(self.atom_ener):
# only for atom_ener
# like inputs, but we don't want to add a dependency on inputs
inputs_zero = tf.zeros((nframes, self.dim_descrpt * natoms[0]), dtype=self.fitting_precision)
nframes = input_dict.get('nframes')
if nframes is not None:
# like inputs, but we don't want to add a dependency on inputs
inputs_zero = tf.zeros((nframes, self.dim_descrpt * natoms[0]), dtype=self.fitting_precision)
else:
inputs_zero = tf.zeros((nframes, self.dim_descrpt * natoms[0]), dtype=self.fitting_precision)


if bias_atom_e is not None :
Expand Down
3 changes: 1 addition & 2 deletions deepmd/model/ener.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,7 @@ def build (self,

coord = tf.reshape (coord_, [-1, natoms[1] * 3])
atype = tf.reshape (atype_, [-1, natoms[1]])
nframes = tf.shape(coord)[0]
input_dict['nframes'] = tf.shape(coord)[0]

# type embedding if any
if self.typeebd is not None:
Expand Down Expand Up @@ -188,7 +188,6 @@ def build (self,

atom_ener = self.fitting.build (dout,
natoms,
nframes,
input_dict,
reuse = reuse,
suffix = suffix)
Expand Down

0 comments on commit 4f95d01

Please sign in to comment.