Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Unwrap variable values in all stateless calls. #19287

Merged
merged 1 commit into from Mar 12, 2024

Conversation

hertschuh
Copy link
Contributor

When variables are passed to stateless calls (stateless_call, stateless_update, stateless_apply), unwrap them to extract their value so that the mapping from variable to value in StatelessScope points to a value. This is to prevent an infinite recursion when performing operations (e.g. __add__) in a stateless scope.

Also fix issue where casting a tf.SparseTensor would lose the shape. The optimization in ops.cast is what revealed the stateless calls bug.

@codecov-commenter
Copy link

codecov-commenter commented Mar 11, 2024

Codecov Report

Attention: Patch coverage is 44.44444% with 5 lines in your changes are missing coverage. Please review.

Project coverage is 54.23%. Comparing base (c8700f4) to head (a27cb72).
Report is 88 commits behind head on master.

Files Patch % Lines
keras/backend/tensorflow/core.py 16.66% 4 Missing and 1 partial ⚠️
Additional details and impacted files
@@             Coverage Diff             @@
##           master   #19287       +/-   ##
===========================================
- Coverage   80.14%   54.23%   -25.91%     
===========================================
  Files         341      365       +24     
  Lines       36163    39804     +3641     
  Branches     7116     7719      +603     
===========================================
- Hits        28982    21587     -7395     
- Misses       5578    16645    +11067     
+ Partials     1603     1572       -31     
Flag Coverage Δ
keras 54.23% <44.44%> (-25.76%) ⬇️
keras-jax ?
keras-numpy 54.23% <44.44%> (-2.86%) ⬇️
keras-tensorflow ?
keras-torch ?

Flags with carried forward coverage won't be shown. Click here to find out more.

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

Copy link
Member

@fchollet fchollet left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the PR!

non_trainable_mapping = zip(
self.non_trainable_variables, non_trainable_variables
all_variables = map(
lambda v: v.value if isinstance(v, KerasVariable) else v,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For greater generality we should do this in StatelessScope I think?

What I don't get is that we already do v = backend.convert_to_tensor(v, dtype=k.dtype) in StatelessScope. This should grab the value, if I'm not mistaken?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For greater generality we should do this in StatelessScope I think?

Yes, I don't like the current duplication. And it's totally trivial to add.

What I don't get is that we already do v = backend.convert_to_tensor(v, dtype=k.dtype) in StatelessScope. This should grab the value, if I'm not mistaken?

Oh, now I fully understand why my change to tensorflow.core.convert_to_tensor triggered this.

The issue is that in Tensorflow, we have:

class Variable(
    KerasVariable,
    tf.__internal__.types.Tensor,
    tf.__internal__.tracking.Trackable,
):

Because of the tf.__internal__.types.Tensor, tf.is_tensor returns True and bypasses the conversion. However, we were lucky enough that tf.cast doesn't have a shortcut when the dtypes are identical, so it would cause the cast to happen always, at which point the variable was turned to an actual tf.Tensor.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

But isn't the issue with JAX specifically? I'm confused.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The issue is only with Tensorflow... So it's not a real use case, but the tests do fail.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ok, perhaps not fully related, but I did run into exactly this issue with JAX a bit earlier -- you cannot call model.stateless_call with Variables, you need to unwrap them. The nature of the issue is an infinite recursion, which seems to happen at the level of JAX.

Can you just add explicit if isinstance logic in StatelessScope? That should fix it.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, doing that, now it's a 1-line change basically. Thanks!

When variables are passed to stateless calls (`stateless_call`, `stateless_update`, `stateless_apply`), unwrap them to extract their value so that the mapping from variable to value in `StatelessScope` points to a value.
This is to prevent an infinite recursion when performing operations (e.g. `__add__`) in a stateless scope.

Also fix issue where casting a `tf.SparseTensor` would lose the shape. The optimization in `ops.cast` is what revealed the stateless calls bug.
Copy link
Member

@fchollet fchollet left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM, thank you!

@google-ml-butler google-ml-butler bot added kokoro:force-run ready to pull Ready to be merged into the codebase labels Mar 12, 2024
@fchollet fchollet merged commit ffa9d52 into keras-team:master Mar 12, 2024
6 checks passed
@google-ml-butler google-ml-butler bot removed awaiting review ready to pull Ready to be merged into the codebase kokoro:force-run labels Mar 12, 2024
@hertschuh hertschuh deleted the stateless_call branch March 12, 2024 20:21
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

4 participants