Should Flax return FrozenDicts or regular dicts? #1223
I think the Python saying "We're all consenting adults here" is pretty fitting. In my view, trading convenience for safety is reasonable here because JAX users should know (or will quickly come to learn) that under the JAX transformations, they should not mutate state. Since Though I would prefer the user-facing API to just use
Explicit state management is one of my favourite aspects of Flax, as it gives me the ability to transparently manipulate modules/parameters without worrying about hidden side effects. I totally agree with @lucasb-eyer's point that it's counterproductive to provide explicit state without allowing the user to fully control it.
Hidden state is notoriously hard to reason about, and I think all ML frameworks are currently struggling with it. See for example the widely made mistake of mixing the default NumPy RNG and multiprocessing (link). That said, I don't think FrozenDict has proven to be a very effective safety tool for avoiding this kind of error. We should probably keep using it internally to avoid accidental reference sharing, but for users it seems too big a burden, and it doesn't prevent the more common issues of closing over mutable state (typically created by the user) or using things like np.random in a jitted function. I do think we should at least provide an easy way to clone a pytree if we allow it to contain mutable containers. Something like the following:
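A minimal sketch of such a clone helper, not necessarily what was proposed in this comment. It assumes the pytree is built only from nested `dict`/`list`/`tuple` containers (namedtuples and other custom node types would need extra handling), and `clone_tree` is a hypothetical name. Every container is rebuilt while leaves are shared, which is safe as long as the leaves are immutable, as JAX arrays are. In JAX itself, `jax.tree_util.tree_map(lambda x: x, tree)` should have a similar effect for registered container types, since it rebuilds the tree structure from its leaves.

```python
from typing import Any


def clone_tree(tree: Any) -> Any:
    """Return a copy of `tree` with fresh dict/list/tuple containers.

    Leaves (arrays, numbers, ...) are shared by reference, not copied;
    this assumes the leaves themselves are immutable.
    """
    if isinstance(tree, dict):
        return {k: clone_tree(v) for k, v in tree.items()}
    if isinstance(tree, list):
        return [clone_tree(v) for v in tree]
    if isinstance(tree, tuple):
        return tuple(clone_tree(v) for v in tree)
    return tree  # leaf: shared, not copied


params = {"dense": {"kernel": [1.0, 2.0], "bias": 0.5}}
cloned = clone_tree(params)
cloned["dense"]["bias"] = 9.9  # mutating the clone...
print(params["dense"]["bias"])  # ...leaves the original untouched: 0.5
```

Rebuilding only the containers keeps cloning cheap, since large parameter arrays are never copied.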
Also, we want to merge the chex and flax dataclass implementations. The most important difference is that chex dataclasses are mutable by default. I think we should keep the behaviour consistent, so ideally we would make these changes together.
This is a good point; especially as a codebase grows, it can sneak past you. Personally, if
I actually quite like the immutable dataclasses, since the
Great, very happy about this decision. I'd just like to add that
Is a complete red herring. This is about hidden global state, whereas this discussion is specifically about explicit, non-global state. It's actually more about RNG design than anything else, and what we are talking about doing here is already the "better" RNG design, where the user is explicitly given, and trusted to correctly handle, the state.
Is there any further development on this?
Sorry for the delay -- I was on parental leave. @jheek could you tell us whether any progress has been made on merging the chex and flax dataclasses?
What does merging the dataclasses consist of? Are flax dataclasses going to be inheriting the mapping interface?
The merging of dataclasses is taking much longer than originally anticipated. I'll bring this up in our next sync meeting, because I think we should start to move towards allowing mutability independently of actually merging the implementations with chex.
Sorry, but why would you do that?
Also, I still don't understand what this merge will consist of. Flax's dataclasses are well-designed: they are just frozen dataclasses that register as pytrees, have a field function that conveniently supports marking static fields, and add a Chex dataclasses are badly designed: they are not frozen, they can't mark static fields, and they unnecessarily expose the whole mapping interface, which means you can access fields as attributes or keys. They also expose a I was hoping to ditch tjax's dataclasses in favor of flax's, but if you're merging in any of chex's behavior, I won't be able to.
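A rough plain-Python illustration of frozen instances, functional updates, and static fields, as described above. This is not Flax's implementation: the `pytree_node` metadata key and the `split_fields` helper are hypothetical stand-ins for what `flax.struct.field(pytree_node=False)` expresses, and `dataclasses.replace` stands in for a `replace`-style functional update.

```python
import dataclasses


@dataclasses.dataclass(frozen=True)
class TrainState:
    step: int
    params: dict
    # Hypothetical convention mirroring flax.struct.field(pytree_node=False):
    # this field is treated as static metadata rather than a traced leaf.
    learning_rate: float = dataclasses.field(
        default=1e-3, metadata={"pytree_node": False}
    )


def split_fields(obj):
    """Split a dataclass instance into (dynamic, static) field dicts."""
    dynamic, static = {}, {}
    for f in dataclasses.fields(obj):
        target = dynamic if f.metadata.get("pytree_node", True) else static
        target[f.name] = getattr(obj, f.name)
    return dynamic, static


state = TrainState(step=0, params={"w": 1.0})
try:
    state.step = 1  # frozen: in-place attribute assignment is rejected
except dataclasses.FrozenInstanceError:
    print("cannot assign to a frozen dataclass")

new_state = dataclasses.replace(state, step=state.step + 1)  # functional update
print(state.step, new_state.step)  # 0 1
print(split_fields(new_state))
```

Note that `frozen=True` only prevents rebinding attributes; the `params` dict inside is still mutable, which is part of why FrozenDict exists at all.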
We won't be removing features like frozen, static fields, and replace. As for the mapping interface: this is actually what's blocking a merge. Chex dataclasses support tf.nest and dm-tree, which are alternatives to jax.tree_util that rely on the mapping interface and don't support custom types. This is also why chex cannot easily add static fields: tf.nest doesn't support them. We don't want to inherit the mapping interface because it limits functionality and is really mostly a hack to support custom tf.nest types.
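To make the mapping-interface point concrete, the toy `flatten` below stands in for a nest-style flattener; it is a hypothetical simplification, not tf.nest or dm-tree code. It only recurses into mappings and sequences and treats any other object as an opaque leaf. A dataclass becomes traversable only by implementing `collections.abc.Mapping`, and then every field is exposed as a key, leaving no obvious place to mark a field as static.

```python
import dataclasses
from collections.abc import Mapping


def flatten(tree):
    """Toy nest-style flattener: recurses into Mappings (sorted keys) and
    lists/tuples only; any other object is treated as an opaque leaf."""
    if isinstance(tree, Mapping):
        return [leaf for key in sorted(tree) for leaf in flatten(tree[key])]
    if isinstance(tree, (list, tuple)):
        return [leaf for item in tree for leaf in flatten(item)]
    return [tree]


@dataclasses.dataclass(frozen=True)
class Plain:
    a: float
    b: float


@dataclasses.dataclass(frozen=True)
class MappingLike(Mapping):
    a: float
    b: float

    # Mapping interface: every field becomes a key the flattener can see.
    def __getitem__(self, key):
        return getattr(self, key)

    def __iter__(self):
        return (f.name for f in dataclasses.fields(self))

    def __len__(self):
        return len(dataclasses.fields(self))


print(flatten(Plain(1.0, 2.0)))        # [Plain(a=1.0, b=2.0)]  (opaque leaf)
print(flatten(MappingLike(1.0, 2.0)))  # [1.0, 2.0]
```

By contrast, a registry-based approach like `jax.tree_util` lets a custom type decide which fields are children and which are auxiliary data, which is what makes static fields straightforward there.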
I understand. But the issue with that is that you open users to bugs by allowing impure methods. The reality is that JAX's decorated functions ( For statistics, in my 5500-line JAX project, I call
Instead of having a gigantic interface and passing the dataclass
I see. Why not create an
Yes! Thank you!
Yes, this is the tradeoff we have to think about, and we will discuss it further before making a final decision.
What is a particular form of this problem? Typically the code in the main training loop isn't pure anyway (and isn't meant to be pure, as it reports metrics, saves checkpoints, etc.). I understand the need to ensure frozen data structures within modules (and we're not proposing to change this --
By the way, I also think that Flax returning frozen dictionaries is extremely annoying. Changing this behaviour would also address google-deepmind/optax#160. Moreover, our (NetKet) users and students learning JAX/Flax often find it confusing that they keep getting this object that they have to "melt" before they can edit it.
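A rough illustration of the "melt to edit" friction, using `types.MappingProxyType` as a read-only stand-in for `FrozenDict` so the sketch has no dependencies (an assumption, not Flax's actual class). In Flax itself, the corresponding helpers are `flax.core.freeze` and `flax.core.unfreeze`; the `freeze`/`unfreeze` functions here are simplified local versions.

```python
from types import MappingProxyType


def freeze(tree):
    """Recursively wrap dicts in read-only MappingProxyType views."""
    if isinstance(tree, dict):
        return MappingProxyType({k: freeze(v) for k, v in tree.items()})
    return tree


def unfreeze(tree):
    """Recursively convert read-only views back into plain, editable dicts."""
    if isinstance(tree, MappingProxyType):
        return {k: unfreeze(v) for k, v in tree.items()}
    return tree


variables = freeze({"params": {"dense": {"bias": 0.0}}})
try:
    variables["params"]["dense"]["bias"] = 1.0  # read-only: raises TypeError
except TypeError:
    print("read-only: must unfreeze before editing")

editable = unfreeze(variables)
editable["params"]["dense"]["bias"] = 1.0  # plain dicts can be edited directly
```

Returning plain dicts would remove the `unfreeze` step from user code entirely, which is the change being argued for in this thread.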
Sorry, I'm not actually discussing the topic of the issue. I just noticed a comment about merging
I can't find the example, but I saw one with treex (which doesn't enforce frozen dataclasses) where someone was doing:

```python
import jax
def f(x, some_value):
    x.some_member = some_value  # mutates x in place
    return x
@jax.jit
def g(x, some_value):
    ...
    x = f(x, some_value)  # if you forget to assign to x, you will get different behavior for the jitted and unjitted function.
```
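The difference arises because, under `jax.jit`, the traced function roughly operates on a rebuilt copy of its pytree arguments, so an in-place mutation inside it is not visible to the caller, whereas in eager mode it is. The hypothetical `fake_jit` decorator below simulates only that one effect with `copy.deepcopy`; it is not how JAX is implemented, and `State` is an illustrative mutable class.

```python
import copy
from dataclasses import dataclass


@dataclass
class State:  # mutable, like a non-frozen dataclass-based module
    count: int = 0


def fake_jit(fn):
    """Crude simulation of one jit effect: the wrapped function sees a
    rebuilt copy of its arguments, so in-place mutations stay inside."""
    def wrapper(*args):
        return fn(*copy.deepcopy(args))
    return wrapper


def f(x):
    x.count += 1  # in-place mutation
    return x


def g(x):
    f(x)  # bug: should be `x = f(x)` and return x; relies on mutation instead


eager_state = State()
g(eager_state)
print(eager_state.count)  # 1: the in-place update leaked out of g

jitted_state = State()
fake_jit(g)(jitted_state)
print(jitted_state.count)  # 0: the update happened on a copy and was lost
```

With a frozen dataclass, the line `x.count += 1` would raise immediately in both modes, which is the safety argument made for immutability here.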
Could you point me to an example? It seems that in that case, you can use an ordinary dataclass from the standard library or an ordinary class.
@NeilGirdhar I just mean things like updating state and params and reporting metrics -- it's totally fine to directly manipulate the variables dict in the main training loop, and people have to jump through (IMHO unnecessary) hoops to achieve this: #1729 (comment).
@avital Fair enough. I need to learn Flax better before I can really suggest something. A couple other options: An
The Or maybe a context manager that provides the handle and automatically rolls it back in when it ends:
You'd still be jumping through hoops, but it's just one hoop.
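A minimal sketch of the context-manager idea, independent of whatever code was originally posted. It assumes variables are stored in a read-only form, again using `types.MappingProxyType` as a stand-in for `FrozenDict`; `VariableHolder` and `editing` are hypothetical names, not Flax API. The edited copy is folded back in only if the `with` block exits normally.

```python
from contextlib import contextmanager
from types import MappingProxyType


def freeze(tree):
    """Recursively wrap dicts in read-only views."""
    if isinstance(tree, dict):
        return MappingProxyType({k: freeze(v) for k, v in tree.items()})
    return tree


def unfreeze(tree):
    """Recursively convert read-only views back into plain dicts."""
    if isinstance(tree, MappingProxyType):
        return {k: unfreeze(v) for k, v in tree.items()}
    return tree


class VariableHolder:
    """Hypothetical container that keeps its variables in frozen form."""

    def __init__(self, variables):
        self.frozen = freeze(variables)


@contextmanager
def editing(holder):
    """Yield an editable copy; refreeze it into `holder` on normal exit.

    If the with-block raises, the code after `yield` does not run, so the
    holder keeps its previous frozen variables.
    """
    mutable = unfreeze(holder.frozen)
    yield mutable
    holder.frozen = freeze(mutable)


holder = VariableHolder({"params": {"bias": 0.0}})
with editing(holder) as v:
    v["params"]["bias"] = 1.0
print(holder.frozen["params"]["bias"])  # 1.0
```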
The problem with any hoop isn't its complexity -- it's that it's something you have to learn suddenly, when you "just wanted to try this one thing". So any hoop should be justified by the benefit it gives you (hopefully a lot). Maybe I'm just misunderstanding this, but I never understood the benefit of having
I guess another way to put it -- if someone really wants immutable data structures, they can always do, e.g.
And the answer is just plain dict, at least for this user here :)
+1 for this! I have a lot of code that immediately calls
Hey @NeilGirdhar! I believe you're looking for this example from Treex's User Guide.
Since this would be a breaking change, we should bump Flax's version to avoid breaking open-source users who rely on semantic versioning.
FYI: @chiamp is going to look into this |
Closing after #3193 landed. |
This topic is discussed regularly internally, and I feel we haven't reached a consensus here. Below are some arguments collected from users for both positions; feel free to add more.
Arguments in favor of FrozenDict
Arguments in favor of regular dicts
@lucasb-eyer: Flax tells me "here are these precious weights, please hold them for me and give them back to me later on, but DON'T TOUCH", which begs the question: why give them to me in the first place, if I'm not supposed to do anything with them?
@avital: I also think it'd be better for Flax to return normal Python dicts, but still use FrozenDict within modules (via the `mutable` argument to `apply`).