Unwrap variable values in all stateless calls. #19287
Conversation
Codecov Report

Attention: Patch coverage is

```
@@            Coverage Diff             @@
##           master    #19287       +/-   ##
============================================
- Coverage    80.14%    54.23%    -25.91%
============================================
  Files          341       365        +24
  Lines        36163     39804      +3641
  Branches      7116      7719       +603
============================================
- Hits         28982     21587      -7395
- Misses        5578     16645     +11067
+ Partials      1603      1572        -31
```
Thanks for the PR!
keras/layers/layer.py (outdated)

```python
non_trainable_mapping = zip(
    self.non_trainable_variables, non_trainable_variables
)
all_variables = map(
    lambda v: v.value if isinstance(v, KerasVariable) else v,
```
For greater generality we should do this in `StatelessScope` I think?

What I don't get is that we already do `v = backend.convert_to_tensor(v, dtype=k.dtype)` in `StatelessScope`. This should grab the value, if I'm not mistaken?
> For greater generality we should do this in `StatelessScope` I think?

Yes, I don't like the current duplication. And it's totally trivial to add.

> What I don't get is that we already do `v = backend.convert_to_tensor(v, dtype=k.dtype)` in `StatelessScope`. This should grab the value, if I'm not mistaken?
Oh, now I fully understand why my change to `tensorflow.core.convert_to_tensor` triggered this.

The issue is that in TensorFlow, we have:

```python
class Variable(
    KerasVariable,
    tf.__internal__.types.Tensor,
    tf.__internal__.tracking.Trackable,
):
```

Because of the `tf.__internal__.types.Tensor` base class, `tf.is_tensor` returns `True` and bypasses the conversion. However, we were lucky enough that `tf.cast` doesn't have a shortcut when the dtypes are identical, so the cast would always happen, at which point the variable was turned into an actual `tf.Tensor`.
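To make the bypass concrete, here is a minimal self-contained sketch (no real TensorFlow) of the dispatch problem described above. `FakeTensorType` and `FakeVariable` are hypothetical stand-ins for `tf.__internal__.types.Tensor` and the Keras TF `Variable`; the real APIs are not used.

```python
class FakeTensorType:
    """Hypothetical stand-in for the tf.__internal__.types.Tensor protocol class."""

class FakeVariable(FakeTensorType):
    """Hypothetical stand-in for Keras's TF Variable, which subclasses the Tensor type."""
    def __init__(self, value):
        self.value = value

def is_tensor(x):
    # Mimics tf.is_tensor: a type check against the Tensor protocol class.
    # A FakeVariable passes this check purely because of its base class.
    return isinstance(x, FakeTensorType)

def convert_to_tensor(x):
    # Mimics the shortcut: anything that "is a tensor" is returned
    # unchanged, so the variable is never unwrapped to its value.
    if is_tensor(x):
        return x
    return float(x)  # placeholder for real conversion

v = FakeVariable(3.0)
out = convert_to_tensor(v)
assert out is v  # the variable leaks through instead of its 3.0 value
```

This is why removing the always-executed cast exposed the bug: the cast was the only step that still turned the variable into a concrete tensor.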
But isn't the issue with JAX specifically? I'm confused.
The issue is only with TensorFlow... So it's not a real use case, but the tests do fail.
Ok, perhaps not fully related, but I did run into exactly this issue with JAX a bit earlier: you cannot call `model.stateless_call` with variables, you need to unwrap them. The nature of the issue is an infinite recursion, which seems to happen at the level of JAX.

Can you just add explicit `if isinstance` logic in `StatelessScope`? That should fix it.
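A hypothetical sketch of the infinite recursion mentioned above: if the stateless mapping sends a variable to a variable (rather than to a plain value), resolving `.value` inside an operation like `__add__` never terminates. `Var` and the dict-based scope below are illustrative only, not the Keras internals.

```python
class Var:
    scope = {}  # stand-in for the active StatelessScope mapping

    def __init__(self, value):
        self._value = value

    @property
    def value(self):
        # Look up the stateless value for this variable, if any.
        return Var.scope.get(id(self), self._value)

    def __add__(self, other):
        left = self.value
        if isinstance(left, Var):
            # The mapped "value" is itself a variable: resolve it again.
            return left + other  # recurses forever if the mapping loops
        return left + other

v = Var(1.0)
Var.scope[id(v)] = v  # BUG: mapping points at a variable, not a value
try:
    v + 1.0
except RecursionError:
    print("infinite recursion, as described in the discussion")

# The fix discussed above: unwrap variables when the mapping is built.
Var.scope[id(v)] = v._value
assert v + 1.0 == 2.0
```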
Yes, doing that, now it's a 1-line change basically. Thanks!
When variables are passed to stateless calls (`stateless_call`, `stateless_update`, `stateless_apply`), unwrap them to extract their value so that the mapping from variable to value in `StatelessScope` points to a value. This is to prevent an infinite recursion when performing operations (e.g. `__add__`) in a stateless scope. Also fix issue where casting a `tf.SparseTensor` would lose the shape. The optimization in `ops.cast` is what revealed the stateless calls bug.
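A minimal sketch (not the actual Keras source) of what "unwrap them to extract their value" means when the scope's mapping is built: any value that is itself a variable is replaced by its underlying value. `KerasVariable` here is a hypothetical minimal stand-in, and `build_state_mapping` is an illustrative helper, not a real Keras function.

```python
class KerasVariable:
    """Hypothetical minimal stand-in for a Keras variable."""
    def __init__(self, value):
        self.value = value

def build_state_mapping(variables, values):
    """Map each variable (by id) to a plain value, never to a variable."""
    mapping = {}
    for var, val in zip(variables, values):
        if isinstance(val, KerasVariable):
            val = val.value  # the one-line unwrap
        mapping[id(var)] = val
    return mapping

w = KerasVariable(0.5)
# Passing the variable itself still yields a plain value in the mapping:
m = build_state_mapping([w], [w])
assert m[id(w)] == 0.5
```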
c512562 to a27cb72
LGTM, thank you!