Fix network precision under a specific situation #1391
Conversation
Codecov Report
@@            Coverage Diff            @@
##            devel    #1391      +/-  ##
==========================================
+ Coverage   75.53%   75.63%   +0.09%
==========================================
  Files          91       92       +1
  Lines        7506     7531      +25
==========================================
+ Hits         5670     5696      +26
+ Misses       1836     1835       -1
Continue to review full report at Codecov.
@@ -736,7 +736,7 @@ def _filter_lower(
             if (not self.uniform_seed) and (self.seed is not None): self.seed += self.seed_shift
         else:
             # we can safely return the final xyz_scatter filled with zero directly
-            return tf.cast(tf.fill((natom, 4, outputs_size[-1]), 0.), GLOBAL_TF_FLOAT_PRECISION)
+            return tf.cast(tf.fill((natom, 4, outputs_size[-1]), 0.), self.filter_precision)
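For context, a minimal sketch of the dtype pattern this hunk addresses, assuming illustrative shapes and precisions (the variable values below are hypothetical, not taken from the PR): the zero-filled early return should use the filter precision so it matches the dtype of the normally computed result, with a single cast back to the global precision at the boundary.

```python
# Illustrative sketch only, not the code in this PR.
import tensorflow as tf

GLOBAL_TF_FLOAT_PRECISION = tf.float64  # assumed project-wide default
filter_precision = tf.float32           # assumed (lower) filter precision

natom, last_output_size = 3, 8          # hypothetical shapes
# Early return inside the filter: keep the filter precision.
xyz_scatter = tf.cast(tf.fill((natom, 4, last_output_size), 0.), filter_precision)
# At the boundary (e.g. where _filter returns), cast back once to the global precision.
result = tf.cast(xyz_scatter, GLOBAL_TF_FLOAT_PRECISION)
```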
We cast the result back to global precision before returning.
I don't see any cast in line 722 or line 747.
Is there something wrong?
> I don't see any cast in line 722 or line 747.
Line 747 is where the result should be cast back...
Why do we need to cast at L722?
Instead of `_filter_lower`, I think we should cast back at line 839, before `_filter` returns.
I agree with you
Why not write a decorator to do that? It would cast the inputs of `_filter` to `filter_precision` and cast back to global precision when `_filter` returns.
Good idea!
"""A decorator that casts and casts back the input | ||
and output tensor of a method. |
Could you please write more on the logic behind `cast_precision`?
- It casts input tensors from `GLOBAL_TF_FLOAT_PRECISION` to the precision defined by the property `precision`.
- It casts output tensors from `precision` back to `GLOBAL_TF_FLOAT_PRECISION`.
- It checks the inputs and outputs, and only casts when an input or output is a tensor whose dtype matches `GLOBAL_TF_FLOAT_PRECISION` or `precision`, respectively (see the sketch below).
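A minimal sketch of what such a decorator could look like, assuming the decorated method lives on an object exposing a `precision` property and that `GLOBAL_TF_FLOAT_PRECISION` is the project-wide dtype. This is only an illustration of the logic described above, not necessarily the code merged in this PR.

```python
from functools import wraps

import tensorflow as tf

GLOBAL_TF_FLOAT_PRECISION = tf.float64  # assumed project-wide default


def cast_precision(func):
    """Cast tensor inputs from GLOBAL_TF_FLOAT_PRECISION to `self.precision`
    before calling `func`, and cast tensor outputs back afterwards.
    Non-tensor arguments and tensors of other dtypes pass through unchanged."""

    @wraps(func)
    def wrapper(self, *args, **kwargs):
        def cast_in(x):
            # cast only tensors that are currently in the global precision
            if tf.is_tensor(x) and x.dtype == GLOBAL_TF_FLOAT_PRECISION:
                return tf.cast(x, self.precision)
            return x

        def cast_out(x):
            # cast only tensors that are currently in the method precision
            if tf.is_tensor(x) and x.dtype == self.precision:
                return tf.cast(x, GLOBAL_TF_FLOAT_PRECISION)
            return x

        out = func(self,
                   *(cast_in(a) for a in args),
                   **{k: cast_in(v) for k, v in kwargs.items()})
        if isinstance(out, tuple):
            return tuple(cast_out(o) for o in out)
        return cast_out(out)

    return wrapper
```

With such a decorator applied to `_filter`, callers keep working in `GLOBAL_TF_FLOAT_PRECISION` while the filter internals run in the method precision, and no manual casts are needed at the call sites.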
@@ -392,15 +392,15 @@ def _pass_filter(self,
                                      [ 0, start_index* self.ndescrpt],
                                      [-1, natoms[2+type_i]* self.ndescrpt] )
             inputs_i = tf.reshape(inputs_i, [-1, self.ndescrpt])
-            layer = self._filter_r(tf.cast(inputs_i, self.filter_precision), type_i, name='filter_type_'+str(type_i)+suffix, natoms=natoms, reuse=reuse, trainable = trainable, activation_fn = self.filter_activation_fn)
+            layer = self._filter_r(self.filter_precision, type_i, name='filter_type_'+str(type_i)+suffix, natoms=natoms, reuse=reuse, trainable = trainable, activation_fn = self.filter_activation_fn)
The first argument should be `inputs_i` instead of `self.filter_precision`. It couldn't pass the UT, yet it was merged into devel.
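For clarity, a sketch of what the corrected line would look like, assuming `inputs_i` no longer needs a manual pre-cast here because the precision cast is handled inside `_filter_r` (e.g. via the decorator discussed above):

```python
# Corrected sketch: pass the reshaped input tensor, not the precision object.
layer = self._filter_r(inputs_i, type_i,
                       name='filter_type_' + str(type_i) + suffix,
                       natoms=natoms, reuse=reuse, trainable=trainable,
                       activation_fn=self.filter_activation_fn)
```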
In fact, this line is not covered by UT.
Yes, it's covered by source/tests/test_model_compression_se_r.py, a UT just added in PR #1361.
* Finish model compression for se_r descriptor!
* Improve error type.
* Update compress.md
* Improve error type.
* Improve error type.
* Improve exception handling and unittest.
* Add gtest for tabulate_fusion se_r.
* Fix variable mistake from #1391

Co-authored-by: huangliang <huangla@dp.tech>