Skip to content

Commit

Permalink
do some small optimization to ops (#943)
Browse files Browse the repository at this point in the history
* do some small optimization to ops

1. avoid concat or add in loops. Instead, append tensors to a list, and concat or accumulate_n after loops
2. remove a duplicated reshape

* revert unnecessary changes

* revert wfc.py as it has been decrepated
  • Loading branch information
njzjz committed Feb 11, 2022
1 parent 0d8fe0a commit 82c787d
Show file tree
Hide file tree
Showing 4 changed files with 19 additions and 21 deletions.
9 changes: 5 additions & 4 deletions deepmd/descriptor/se_a.py
Original file line number Diff line number Diff line change
Expand Up @@ -783,6 +783,7 @@ def _filter(
type_i = 0
# natom x 4 x outputs_size
if type_embedding is None:
rets = []
for type_i in range(self.ntypes):
ret = self._filter_lower(
type_i, type_input,
Expand All @@ -797,12 +798,12 @@ def _filter(
bavg = bavg,
trainable = trainable,
suffix = "_"+str(type_i))
if type_i == 0:
xyz_scatter_1 = ret
elif (type_input, type_i) not in self.exclude_types:
if (type_input, type_i) not in self.exclude_types:
# add zero is meaningless; skip
xyz_scatter_1+= ret
rets.append(ret)
start_index += self.sel_a[type_i]
# faster to use accumulate_n than multiple add
xyz_scatter_1 = tf.accumulate_n(rets)
else :
xyz_scatter_1 = self._filter_lower(
type_i, type_input,
Expand Down
7 changes: 3 additions & 4 deletions deepmd/fit/dipole.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,7 @@ def build (self,
rot_mat = tf.reshape(rot_mat, [-1, self.dim_rot_mat * natoms[0]])

count = 0
outs_list = []
for type_i in range(self.ntypes):
# cut-out inputs
inputs_i = tf.slice (inputs,
Expand Down Expand Up @@ -158,11 +159,9 @@ def build (self,
final_layer = tf.reshape(final_layer, [tf.shape(inputs)[0], natoms[2+type_i], 3])

# concat the results
if count == 0:
outs = final_layer
else:
outs = tf.concat([outs, final_layer], axis = 1)
outs_list.append(final_layer)
count += 1
outs = tf.concat(outs_list, axis = 1)

tf.summary.histogram('fitting_net_output', outs)
return tf.reshape(outs, [-1])
Expand Down
10 changes: 5 additions & 5 deletions deepmd/fit/ener.py
Original file line number Diff line number Diff line change
Expand Up @@ -435,6 +435,7 @@ def build (self,

if atype_embed is None:
start_index = 0
outs_list = []
for type_i in range(self.ntypes):
if bias_atom_e is None :
type_bias_ae = 0.0
Expand All @@ -454,12 +455,11 @@ def build (self,
)
final_layer += self.atom_ener[type_i] - zero_layer
final_layer = tf.reshape(final_layer, [tf.shape(inputs)[0], natoms[2+type_i]])
# concat the results
if type_i == 0:
outs = final_layer
else:
outs = tf.concat([outs, final_layer], axis = 1)
outs_list.append(final_layer)
start_index += natoms[2+type_i]
# concat the results
# concat once may be faster than multiple concat
outs = tf.concat(outs_list, axis = 1)
# with type embedding
else:
if len(self.atom_ener) > 0:
Expand Down
14 changes: 6 additions & 8 deletions deepmd/fit/polar.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ def build (self,
rot_mat = tf.reshape(rot_mat, [-1, 9 * natoms[0]])

count = 0
outs_list = []
for type_i in range(self.ntypes):
# cut-out inputs
inputs_i = tf.slice (inputs,
Expand Down Expand Up @@ -93,11 +94,9 @@ def build (self,
final_layer = tf.reshape(final_layer, [tf.shape(inputs)[0], natoms[2+type_i], 3, 3])

# concat the results
if count == 0:
outs = final_layer
else:
outs = tf.concat([outs, final_layer], axis = 1)
outs_list.append(final_layer)
count += 1
outs = tf.concat(outs_list, axis = 1)

tf.summary.histogram('fitting_net_output', outs)
return tf.cast(tf.reshape(outs, [-1]), GLOBAL_TF_FLOAT_PRECISION)
Expand Down Expand Up @@ -311,6 +310,7 @@ def build (self,
rot_mat = tf.reshape(rot_mat, [-1, self.dim_rot_mat * natoms[0]])

count = 0
outs_list = []
for type_i in range(self.ntypes):
# cut-out inputs
inputs_i = tf.slice (inputs,
Expand Down Expand Up @@ -367,11 +367,9 @@ def build (self,
final_layer = final_layer + self.constant_matrix[sel_type_idx] * tf.eye(3, batch_shape=[tf.shape(inputs)[0], natoms[2+type_i]], dtype = GLOBAL_TF_FLOAT_PRECISION)

# concat the results
if count == 0:
outs = final_layer
else:
outs = tf.concat([outs, final_layer], axis = 1)
outs_list.append(final_layer)
count += 1
outs = tf.concat(outs_list, axis = 1)

tf.summary.histogram('fitting_net_output', outs)
return tf.reshape(outs, [-1])
Expand Down

0 comments on commit 82c787d

Please sign in to comment.