Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

do some small optimization to ops #943

Merged
merged 5 commits into from
Feb 11, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 5 additions & 4 deletions deepmd/descriptor/se_a.py
Original file line number Diff line number Diff line change
Expand Up @@ -783,6 +783,7 @@ def _filter(
type_i = 0
# natom x 4 x outputs_size
if type_embedding is None:
rets = []
for type_i in range(self.ntypes):
ret = self._filter_lower(
type_i, type_input,
Expand All @@ -797,12 +798,12 @@ def _filter(
bavg = bavg,
trainable = trainable,
suffix = "_"+str(type_i))
if type_i == 0:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Did we have a bug here? if type_i == 0 and (type_input, type_i) in self.exclude_types we had ret accumulated.

xyz_scatter_1 = ret
elif (type_input, type_i) not in self.exclude_types:
if (type_input, type_i) not in self.exclude_types:
# add zero is meaningless; skip
xyz_scatter_1+= ret
rets.append(ret)
start_index += self.sel_a[type_i]
# faster to use accumulate_n than multiple add
xyz_scatter_1 = tf.accumulate_n(rets)
else :
xyz_scatter_1 = self._filter_lower(
type_i, type_input,
Expand Down
7 changes: 3 additions & 4 deletions deepmd/fit/dipole.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,7 @@ def build (self,
rot_mat = tf.reshape(rot_mat, [-1, self.dim_rot_mat * natoms[0]])

count = 0
outs_list = []
for type_i in range(self.ntypes):
# cut-out inputs
inputs_i = tf.slice (inputs,
Expand Down Expand Up @@ -158,11 +159,9 @@ def build (self,
final_layer = tf.reshape(final_layer, [tf.shape(inputs)[0], natoms[2+type_i], 3])

# concat the results
if count == 0:
outs = final_layer
else:
outs = tf.concat([outs, final_layer], axis = 1)
outs_list.append(final_layer)
count += 1
outs = tf.concat(outs_list, axis = 1)

tf.summary.histogram('fitting_net_output', outs)
return tf.reshape(outs, [-1])
Expand Down
10 changes: 5 additions & 5 deletions deepmd/fit/ener.py
Original file line number Diff line number Diff line change
Expand Up @@ -435,6 +435,7 @@ def build (self,

if atype_embed is None:
start_index = 0
outs_list = []
for type_i in range(self.ntypes):
if bias_atom_e is None :
type_bias_ae = 0.0
Expand All @@ -454,12 +455,11 @@ def build (self,
)
final_layer += self.atom_ener[type_i] - zero_layer
final_layer = tf.reshape(final_layer, [tf.shape(inputs)[0], natoms[2+type_i]])
# concat the results
if type_i == 0:
outs = final_layer
else:
outs = tf.concat([outs, final_layer], axis = 1)
outs_list.append(final_layer)
start_index += natoms[2+type_i]
# concat the results
# concat once may be faster than multiple concat
outs = tf.concat(outs_list, axis = 1)
# with type embedding
else:
if len(self.atom_ener) > 0:
Expand Down
14 changes: 6 additions & 8 deletions deepmd/fit/polar.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ def build (self,
rot_mat = tf.reshape(rot_mat, [-1, 9 * natoms[0]])

count = 0
outs_list = []
for type_i in range(self.ntypes):
# cut-out inputs
inputs_i = tf.slice (inputs,
Expand Down Expand Up @@ -93,11 +94,9 @@ def build (self,
final_layer = tf.reshape(final_layer, [tf.shape(inputs)[0], natoms[2+type_i], 3, 3])

# concat the results
if count == 0:
outs = final_layer
else:
outs = tf.concat([outs, final_layer], axis = 1)
outs_list.append(final_layer)
count += 1
outs = tf.concat(outs_list, axis = 1)

tf.summary.histogram('fitting_net_output', outs)
return tf.cast(tf.reshape(outs, [-1]), GLOBAL_TF_FLOAT_PRECISION)
Expand Down Expand Up @@ -311,6 +310,7 @@ def build (self,
rot_mat = tf.reshape(rot_mat, [-1, self.dim_rot_mat * natoms[0]])

count = 0
outs_list = []
for type_i in range(self.ntypes):
# cut-out inputs
inputs_i = tf.slice (inputs,
Expand Down Expand Up @@ -367,11 +367,9 @@ def build (self,
final_layer = final_layer + self.constant_matrix[sel_type_idx] * tf.eye(3, batch_shape=[tf.shape(inputs)[0], natoms[2+type_i]], dtype = GLOBAL_TF_FLOAT_PRECISION)

# concat the results
if count == 0:
outs = final_layer
else:
outs = tf.concat([outs, final_layer], axis = 1)
outs_list.append(final_layer)
count += 1
outs = tf.concat(outs_list, axis = 1)

tf.summary.histogram('fitting_net_output', outs)
return tf.reshape(outs, [-1])
Expand Down