From 82c787de0ea98e3c032368a5fd4ee46a1fac6283 Mon Sep 17 00:00:00 2001
From: Jinzhe Zeng
Date: Thu, 10 Feb 2022 20:47:24 -0500
Subject: [PATCH] do some small optimization to ops (#943)

* do some small optimization to ops

1. avoid concat or add in loops. Instead, append tensors to a list, and
   concat or accumulate_n after loops
2. remove a duplicated reshape

* revert unnecessary changes

* revert wfc.py as it has been deprecated
---
 deepmd/descriptor/se_a.py |  9 +++++----
 deepmd/fit/dipole.py      |  7 +++----
 deepmd/fit/ener.py        | 10 +++++-----
 deepmd/fit/polar.py       | 14 ++++++--------
 4 files changed, 19 insertions(+), 21 deletions(-)

diff --git a/deepmd/descriptor/se_a.py b/deepmd/descriptor/se_a.py
index 9883909082..cf218309bd 100644
--- a/deepmd/descriptor/se_a.py
+++ b/deepmd/descriptor/se_a.py
@@ -783,6 +783,7 @@ def _filter(
         type_i = 0
         # natom x 4 x outputs_size
         if type_embedding is None:
+            rets = []
             for type_i in range(self.ntypes):
                 ret = self._filter_lower(
                     type_i, type_input,
@@ -797,12 +798,12 @@
                     bavg = bavg,
                     trainable = trainable,
                     suffix = "_"+str(type_i))
-                if type_i == 0:
-                    xyz_scatter_1 = ret
-                elif (type_input, type_i) not in self.exclude_types:
+                if (type_input, type_i) not in self.exclude_types:
                     # add zero is meaningless; skip
-                    xyz_scatter_1+= ret
+                    rets.append(ret)
                 start_index += self.sel_a[type_i]
+            # faster to use accumulate_n than multiple add
+            xyz_scatter_1 = tf.accumulate_n(rets)
         else :
             xyz_scatter_1 = self._filter_lower(
                 type_i, type_input,
diff --git a/deepmd/fit/dipole.py b/deepmd/fit/dipole.py
index 44e84dac1f..bed93c181a 100644
--- a/deepmd/fit/dipole.py
+++ b/deepmd/fit/dipole.py
@@ -127,6 +127,7 @@ def build (self,
         rot_mat = tf.reshape(rot_mat, [-1, self.dim_rot_mat * natoms[0]])
 
         count = 0
+        outs_list = []
         for type_i in range(self.ntypes):
             # cut-out inputs
             inputs_i = tf.slice (inputs,
@@ -158,11 +159,9 @@
             final_layer = tf.reshape(final_layer, [tf.shape(inputs)[0], natoms[2+type_i], 3])
 
             # concat the results
-            if count == 0:
-                outs = final_layer
-            else:
-                outs = tf.concat([outs, final_layer], axis = 1)
+            outs_list.append(final_layer)
             count += 1
+        outs = tf.concat(outs_list, axis = 1)
 
         tf.summary.histogram('fitting_net_output', outs)
         return tf.reshape(outs, [-1])
diff --git a/deepmd/fit/ener.py b/deepmd/fit/ener.py
index 71657e1b98..63be3a8e3f 100644
--- a/deepmd/fit/ener.py
+++ b/deepmd/fit/ener.py
@@ -435,6 +435,7 @@ def build (self,
 
         if atype_embed is None:
             start_index = 0
+            outs_list = []
             for type_i in range(self.ntypes):
                 if bias_atom_e is None :
                     type_bias_ae = 0.0
@@ -454,12 +455,11 @@
                                     )
                     final_layer += self.atom_ener[type_i] - zero_layer
                 final_layer = tf.reshape(final_layer, [tf.shape(inputs)[0], natoms[2+type_i]])
-                # concat the results
-                if type_i == 0:
-                    outs = final_layer
-                else:
-                    outs = tf.concat([outs, final_layer], axis = 1)
+                outs_list.append(final_layer)
                 start_index += natoms[2+type_i]
+            # concat the results
+            # concat once may be faster than multiple concat
+            outs = tf.concat(outs_list, axis = 1)
         # with type embedding
         else:
             if len(self.atom_ener) > 0:
diff --git a/deepmd/fit/polar.py b/deepmd/fit/polar.py
index e5632fadbb..f139fcbd99 100644
--- a/deepmd/fit/polar.py
+++ b/deepmd/fit/polar.py
@@ -60,6 +60,7 @@ def build (self,
         rot_mat = tf.reshape(rot_mat, [-1, 9 * natoms[0]])
 
         count = 0
+        outs_list = []
         for type_i in range(self.ntypes):
             # cut-out inputs
             inputs_i = tf.slice (inputs,
@@ -93,11 +94,9 @@
             final_layer = tf.reshape(final_layer, [tf.shape(inputs)[0], natoms[2+type_i], 3, 3])
 
             # concat the results
-            if count == 0:
-                outs = final_layer
-            else:
-                outs = tf.concat([outs, final_layer], axis = 1)
+            outs_list.append(final_layer)
             count += 1
+        outs = tf.concat(outs_list, axis = 1)
 
         tf.summary.histogram('fitting_net_output', outs)
         return tf.cast(tf.reshape(outs, [-1]), GLOBAL_TF_FLOAT_PRECISION)
@@ -311,6 +310,7 @@ def build (self,
         rot_mat = tf.reshape(rot_mat, [-1, self.dim_rot_mat * natoms[0]])
 
         count = 0
+        outs_list = []
         for type_i in range(self.ntypes):
             # cut-out inputs
             inputs_i = tf.slice (inputs,
@@ -367,11 +367,9 @@
                 final_layer = final_layer + self.constant_matrix[sel_type_idx] * tf.eye(3, batch_shape=[tf.shape(inputs)[0], natoms[2+type_i]], dtype = GLOBAL_TF_FLOAT_PRECISION)
 
             # concat the results
-            if count == 0:
-                outs = final_layer
-            else:
-                outs = tf.concat([outs, final_layer], axis = 1)
+            outs_list.append(final_layer)
             count += 1
+        outs = tf.concat(outs_list, axis = 1)
 
         tf.summary.histogram('fitting_net_output', outs)
         return tf.reshape(outs, [-1])
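
Note (editor's addition, not part of the patch): the dipole.py, ener.py, and polar.py hunks all apply the same pattern: stop growing outs with a tf.concat inside the per-type loop and instead append each type's tensor to a Python list, concatenating once after the loop, so the graph contains a single concat op rather than a chain of ntypes - 1 of them. The sketch below shows the before/after shape of that code; it is illustrative only, assumes TensorFlow 2.x eager mode so it can be run standalone (the patched code builds a TF 1.x graph), and uses made-up shapes and stand-in tensors rather than deepmd's real fitting-net outputs.

# illustrative sketch only; assumes TensorFlow 2.x eager mode and made-up shapes
import tensorflow as tf

ntypes, nframes = 3, 2
natoms = [4, 5, 6]  # hypothetical per-type atom counts

# stand-ins for the per-type fitting-net outputs, one tensor per atom type
finals = [tf.random.normal([nframes, natoms[t]]) for t in range(ntypes)]

# before: grow the output with one concat op per iteration
outs = None
for type_i in range(ntypes):
    final_layer = finals[type_i]
    if type_i == 0:
        outs = final_layer
    else:
        outs = tf.concat([outs, final_layer], axis=1)  # ntypes - 1 chained concats

# after (the pattern this patch applies): collect, then concat once
outs_list = []
for type_i in range(ntypes):
    outs_list.append(finals[type_i])
outs_once = tf.concat(outs_list, axis=1)  # a single concat op

# same values, fewer ops in the graph
assert outs.shape == outs_once.shape == (nframes, sum(natoms))
assert bool(tf.reduce_all(outs == outs_once))

As the in-diff comment in ener.py puts it, one concat over the whole list may be faster than multiple pairwise concats, since the number of concat ops no longer grows with the number of atom types in that code path.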
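
The se_a.py hunk applies the same list-collecting idea to a sum rather than a concatenation: instead of accumulating xyz_scatter_1 with += on every iteration (and special-casing type 0), the per-type tensors are gathered into rets and reduced with a single tf.accumulate_n call, which the in-diff comment says is faster than multiple adds; tensors for excluded (type_input, type_i) pairs are simply not appended, since adding a zero contribution is meaningless. A small sketch under the same assumptions as above (tf.accumulate_n from TF 1.x is exposed as tf.math.accumulate_n in TF 2.x; shapes are made up):

# illustrative sketch only; TensorFlow 2.x eager mode, made-up shapes
import tensorflow as tf

ntypes = 4
# hypothetical per-type partial results, e.g. natom x 4 x outputs_size blocks
rets = [tf.random.normal([10, 4, 32]) for _ in range(ntypes)]

# before: one add op per type
xyz_scatter_1 = rets[0]
for type_i in range(1, ntypes):
    xyz_scatter_1 += rets[type_i]

# after (the pattern this patch applies): a single accumulate_n over the list
xyz_scatter_once = tf.math.accumulate_n(rets)

# identical element-wise sum, up to floating-point rounding
assert bool(tf.reduce_all(tf.abs(xyz_scatter_1 - xyz_scatter_once) < 1e-5))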