Skip to content

Commit

Permalink
fix(tf): make se_atten_v2 masking smooth when davg is not zero (#3632)
Browse files Browse the repository at this point in the history
Currently, `se_atten_v2` is always masked to zero when `exclude_types`
is given. However, for the no neighbor case, the placeholder for a
virtual neighbor is `davg`. This causes discontinuity when
`set_davg_zero` is not set.

This PR uses `davg` for masking.

In production, we usually use `set_davg_zero` along with
`exclude_types`, so it hasn't caused a real problem.

I notice PT hasn't implemented `se_atten_v2` or `exclude_types`, but we
need attention in the future.

---------

Signed-off-by: Jinzhe Zeng <jinzhe.zeng@rutgers.edu>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
njzjz and pre-commit-ci[bot] committed Apr 2, 2024
1 parent cb08410 commit 63601b0
Showing 1 changed file with 11 additions and 1 deletion.
12 changes: 11 additions & 1 deletion deepmd/tf/descriptor/se_atten.py
Original file line number Diff line number Diff line change
Expand Up @@ -681,7 +681,17 @@ def _pass_filter(
tf.shape(inputs_i)[0],
self.nei_type_vec, # extra input for atten
)
inputs_i *= mask
if self.smooth:
inputs_i = tf.where(
tf.cast(mask, tf.bool),
inputs_i,
# (nframes * nloc, 1) -> (nframes * nloc, ndescrpt)
tf.tile(
tf.reshape(self.avg_looked_up, [-1, 1]), [1, self.ndescrpt]
),
)
else:
inputs_i *= mask
if nvnmd_cfg.enable and nvnmd_cfg.quantize_descriptor:
inputs_i = descrpt2r4(inputs_i, atype)
layer, qmat = self._filter(
Expand Down

0 comments on commit 63601b0

Please sign in to comment.