Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make FlatState a Mapping instead of a dict #3880

Merged
merged 1 commit into from
May 2, 2024

Conversation

NeilGirdhar
Copy link
Contributor

@NeilGirdhar NeilGirdhar commented Apr 23, 2024

Other changes:

  • Since nnx.State calls flatten_dict, make the traverse_util.flatten_dict accept mappings.
  • Add a type annotation to flatten_dict so that any future changes to the function (e.g., changing it back to work with dicts only) triggers a type error.
  • Add the appropriate type overloads to the signature to satisfy MyPy errors.
  • Minor tweaks to imports:
    • import from collections.abc instead of typing since the latter imports are deprecated.
    • Import from flax.typing instead of flax.core.scope since the latter has been moved.
  • Add from __future__ import annotations so that the annotations work on Python 3.9.
  • Add pytype: skip-file to work around Unsuported operands reported for type union operator despite future import pytype#1619.
  • Annotate some private functions to make them easier to understand.
  • When printing a type, print its __qualname__ since that's a bit easier to read (str instead of <class 'str'>).

Fixes #3879

@NeilGirdhar
Copy link
Contributor Author

Note that this may change the behavior of calling flatten_dict on a structure with mappings that are neither dict nor FrozenDict. In practice, that's probably extremely rare though?

@NeilGirdhar NeilGirdhar force-pushed the fix_mappings branch 7 times, most recently from 7361b33 to db3833f Compare April 23, 2024 07:45
@NeilGirdhar
Copy link
Contributor Author

The PyType error is just a PyType bug: google/pytype#1619

is_leaf: None | Callable[[tuple[Any, ...],
Mapping[Any, Any]], bool] = None,
*,
sep: str
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why is this overload needed?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure that there's a clean way to make the overloads work. Instead, I've changed it to make keep_empty_nodes, is_leaf, and sep keyword-only parameters since it appears that they are always called this way. Is it better this way?

@cgarciae cgarciae self-assigned this Apr 25, 2024
@cgarciae
Copy link
Collaborator

Thanks @NeilGirdhar ! I've been wanting to do this for a while.
I've approved but left a comment.

@codecov-commenter
Copy link

codecov-commenter commented Apr 25, 2024

Codecov Report

Attention: Patch coverage is 88.23529% with 2 lines in your changes are missing coverage. Please review.

Project coverage is 60.44%. Comparing base (2c7d7cd) to head (d1935c7).
Report is 2 commits behind head on main.

Files Patch % Lines
flax/traverse_util.py 86.66% 2 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##             main    #3880      +/-   ##
==========================================
+ Coverage   60.43%   60.44%   +0.01%     
==========================================
  Files         105      105              
  Lines       13263    13272       +9     
==========================================
+ Hits         8015     8022       +7     
- Misses       5248     5250       +2     

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

Also, make traverse_util.flatten_dict accept mappings.

Fixes google#3879
@copybara-service copybara-service bot merged commit d0e080d into google:main May 2, 2024
19 checks passed
@NeilGirdhar NeilGirdhar deleted the fix_mappings branch May 2, 2024 11:07
@NeilGirdhar
Copy link
Contributor Author

@tkoeppe Did this pull request break something? Was there a reason why this was backed out? Without it, there are problems in nnx (see top of the issue).

@tkoeppe
Copy link
Contributor

tkoeppe commented May 13, 2024

@NeilGirdhar Yes, I think this broke something, but I don't have details. Please ask the owners about that. We just saw a bunch of breakages and reverted the change before as a matter of course, and that got propagated to GitHub automatically.

@NeilGirdhar
Copy link
Contributor Author

NeilGirdhar commented May 13, 2024

@tkoeppe That's okay, but in the future, would you mind adding a message as to what broke? Making this change to quite a bit of time. Also, if there is behavior that you are counting on internally, a good idea is to add a tesst for that behavior so that future changes don't break it.

@cgarciae Can you look into this for me? I imagine that the problem was the narrowing of the parameter list to positional-only/keyword-only. I can revert that and we'll have to live with the more clumsy annotation.

@tkoeppe
Copy link
Contributor

tkoeppe commented May 13, 2024

@NeilGirdhar: I'm sorry, but I have nothing to do with this GitHub interaction. I'm maintaining unrelated, internal code and reverted a breaking change internally. The GitHub interaction is managed by those code's owners. I understand that the GitHub push happened automatically, but please take it up with the owners if you think they are making inappropriate or inadequate commits. If this is a problem, perhaps we need to have stronger review requirements or not automate these changes, but I'll leave that for them to decide.

@cgarciae
Copy link
Collaborator

cgarciae commented May 13, 2024

@NeilGirdhar I think you change is good but since this is used at many other code bases at Google it was very hard to try to fix it internally (I tried for a few hours). The main problem is that some users where passing some structures that behaved like Mappings that where now being traversed when before they where leaves.

I have a simpler suggestion, what if we create nnx/traversals.py and copy these over with the desired signature?

@NeilGirdhar
Copy link
Contributor Author

NeilGirdhar commented May 13, 2024

@tkoeppe I understand, and I don't want to create more work for you. It's just that I was waiting for this change in the next release to fix my code, and I was waiting for nothing. Just a note on the external pull request would help a lot 😄 and a copy of the error would help even more.

I have a simpler suggestion, what if we create nnx/traversals.py and copy these over with the desired signature?

Sounds great! Would you like me to do that, or do you want to do it?

The main problem is that some users where passing some structures that behaved like Mappings that where now being traversed when before they where leaves.

That should be fixable with:

flatten_dict(..., is_leaf=lambda xs: not isinstance(xs, (flax.core.FrozenDict, dict)))

to recover the old behavior. But I understand if these were buried in odd places since the error would be in a very different place than the problem.

@cgarciae
Copy link
Collaborator

Sounds great! Would you like me to do that, or do you want to do it?

If you wanna do it that would be great! Happy to review it.

That should be fixable with:

Maybe, sadly that was only some of the issues, the bigger issue is that this broke pytype in many places.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

No way to call nnx.State.from_flat_path
4 participants