FLIP: Make module instances semantically meaningful by not overriding module.__new__
#208
Comments
I would be in favor of this change, as the proposed way of creating new modules is better aligned with my default mental model: layers are classes and particular instances of these layers (with associated weights and parameters) are objects that process the data by being called. I also fully agree that variable sharing will become more intuitive. Besides, I actually see a separation between …
Yes. One caveat is that while layers have parameters on them, those parameters will be immutable and you'd still need to mutate your parameters at the top level rather than within your module. This is due to our desire to allow you to use vanilla transformations such as …
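As a rough, self-contained sketch of what "mutating parameters at the top level" with vanilla JAX transformations means (apply_dense and the shapes here are illustrative, not Flax API):

```python
import jax
import jax.numpy as jnp

# Hypothetical pure apply function: params is just a pytree of arrays.
def apply_dense(params, x):
    return x @ params["kernel"] + params["bias"]

params = {"kernel": jnp.ones((3, 2)), "bias": jnp.zeros((2,))}
x, y = jnp.ones((4, 3)), jnp.zeros((4, 2))

def loss_fn(p):
    preds = apply_dense(p, x)
    return jnp.mean((preds - y) ** 2)

# Vanilla jax.grad works because nothing is mutated inside the module...
grads = jax.grad(loss_fn)(params)

# ...and the "mutation" is just building a new params pytree at the top level.
params = jax.tree_util.tree_map(lambda p, g: p - 0.01 * g, params, grads)
```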
Yes. The issue is that by simply letting people use …
💯 to this change. It aligns the mental model with TF2's …
Please don't. Re-using an instance is the common, intuitive, friction-less way of sharing weights; this would just add annoying overhead for the sake of avoiding a mistake which, frankly, I have never encountered. An explicit …

Regarding the …

Finally, I think you can have an even more convincing example for modules which have more than just the obvious …
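(For reference, the kind of instance re-use being argued for here, written in the style the proposal below uses -- a sketch, not a final API:)

```python
# Keep a reference to a single module instance...
dense = nn.Dense(features=16)

# ...and call it in two places: both calls share the same parameters.
y1 = dense(x1)
y2 = dense(x2)
```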
Thanks for this proposal! I agree with the other comments: …
Do you mean passing modules directly into jit? One of the things I tried to do away with during my weekend excursion was …
Nice, I like this change. It is a good start. However, if you are making such a breaking change, this feels too conservative. Core Issues:
```python
def __call__(self, x):
    x = nn.Dense(features=16)(x)
    x = nn.relu(x)
    x = nn.Dense(features=16)(x)
```

=> Does this mean?

```python
def __call__(self, x):
    x = nn.Dense(self, features=16)(x)
    x = nn.relu(x)
    x = nn.Dense(self, features=16)(x)
```

(Or alternatively pytorch / sonnet 2 syntax, which both do this better)

```python
bias = self.param('bias', (self.features,), self.bias_init)
```

=> ?

```python
bias = nn.Param(self, (self.features,), self.bias_init)
```
In general I really like the look of this! I think it would be a significant improvement/simplification of Flax's mental model.

👍 for eliminating the use of …
👍 for eliminating …
👍 for eliminating …
👍 for encouraging the use of …
👎 for requiring dataclasses, and not allowing …
👍 for the proposed transition plan, which looks quite practical.
One question arises: how does this change affect (if at all) the way we initialize Flax models? Do we still stick with …
👍 to everything said by @lucasb-eyer, @srush, and @shoyer. I think having separate …

FWIW, I don't see the hydra thing as much of a disadvantage. In many cases it requires people to be more explicit, and you can see what's going on in the submodule itself, instead of hiding things in implicit behind-the-scenes work. It also makes it easier to access model attributes from outside the module if you want to hack things later, say in a colab notebook. Also, I think it's great to allow access to …
This should also help with my confusion in #16 (comment), where not calling partial before create_by_shape results in the model being created with different parameters from what it was trained with.
Can you add an example of how this would work with an equivalent to …
@srush Could you kindly clarify what you mean by this? Is this just a reference to how PyTorch / Sonnet 2 use explicit attribute assignment for submodules? e.g., …

This does make module hierarchies, and when mutation is happening, very clear. The downside is that layers get specified in …
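Presumably the pattern meant here is the usual PyTorch-style one of assigning submodules as attributes up front (an illustrative sketch):

```python
import torch
import torch.nn as nn

class MLP(nn.Module):
    def __init__(self):
        super().__init__()
        # Submodules are created and assigned as attributes here...
        self.dense1 = nn.Linear(32, 16)
        self.dense2 = nn.Linear(16, 16)

    def forward(self, x):
        # ...and only used (never created) in forward.
        x = torch.relu(self.dense1(x))
        return self.dense2(x)
```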
@shoyer To be clear, I think a lot of people actually consider that an upside. It separates creation/ownership from usage, so it's much clearer when reuse is happening, and easier to access submodules from outside the class itself for more creative routing of shared parameters. The mental overhead of having a little boilerplate is a small price to pay for such explicit clarity and Python-native interaction paradigms (using Python's built-in object attributes, vs. some behind-the-scenes implicit naming schemes).
Absolutely, these are all real advantages. On the other hand, I've also had cases where separating initialization/use of layers made my code harder to read and modify because two different parts of the code need to be kept in sync. You also can't use input shapes to determine the shapes of variables. It is not clear to me (personally) which is better/worse in general. It may depend on the context. Keras lets you write things both ways, which is convenient for users, but of course imposes an even higher cost in terms of complexity. For JAX, there is one additional consideration, which is whether the module abstraction is amenable to functional transformations -- one of the core strengths of JAX. My understanding is that this is hard to do with Python's mutable object model.
I consider inline initialization a Keras design flaw. It mixes functional and structural concerns and makes it very hard to reason about, document, and analyze modules. However, whether or not you agree with this, the fact that it is causing the library to have ill-defined semantics, with very minimal benefits ("less scrolling up?"), should be a red flag that it is maybe a problem.
I think we should not use words like "normal" or "pythonic". They are really vague statements that essentially refer to similarity with existing programming paradigms that are common in the Python world. We shouldn't strive to please the status quo. I think the points raised by @srush are important. Although sharing becomes clearer with explicit construction, it still isn't quite like an object that owns its parameters. Consider the following example:
Clearly in …

Of course we can add annotation trickery to distinguish between module methods that have a scope and "inline methods". But the mental model is still significantly more complex than plain old Python objects.
Perhaps I am missing something, but I don't really understand the example above. The implied semantics feel really complicated to me, as state seems to bind to functions in a way I cannot trace.

Btw, I don't know if it is helpful, but here is a proof-of-concept of the sort of pure world I like (not saying flax needs to go this way): https://github.com/srush/parallax

```python
# Everything is immutable. @module = dataclass(frozen=True, repr=False)
@module
class Dense(Module):
    # All parameter-holders are explicitly declared.
    weight : Parameter
    bias : Parameter

    # Setup replaces __init__: it creates shapes and binds lazy initializers.
    @staticmethod
    def setup(in_size, out_size):
        return Dense.init(
            weight = Parameter.setup((out_size, in_size), init.xavier_normal_),
            bias = Parameter.setup((out_size,), init.normal_))

    # Forward is just like standard pytorch.
    def forward(self, input):
        return self.weight @ input + self.bias
```

"Sharing" would require a manual split of the parameter into two parts like this:

```python
@module
class BinaryNetwork(Module):
    # No difference between modules and parameters
    dense1 : Dense
    dense2 : Dense
    dense3 : Dense
    dropout : Dropout

    @staticmethod
    def setup(input_size, hidden_size):
        return BinaryNetwork.init(
            dense1 = Dense.setup(input_size, hidden_size),
            dense2 = Dense.setup(hidden_size, hidden_size),
            dense3 = Dense.setup(hidden_size, 1),
            dropout = Dropout.setup(rate=0.2)
        )

    def forward(self, input):
        # Standard usage works out of the box.
        x = torch.tanh(self.dense1(input))
        # Stochastic modules (have random seed already)
        x = self.dropout(x)
        # Shared params / recurrence requires split (like RNG)
        dense2_a, dense2_b = self.dense2.split(2)
        x = torch.tanh(dense2_a(x))
        x = torch.tanh(dense2_b(x))
        return torch.sigmoid(self.dense3(torch.tanh(x)))
```
My expectation from reading this code is that all …
Yep, was about to say the same as @shoyer: the example is convoluted, but we are creating a new Dense object each time, so I would definitely not expect weight sharing. Any sharing happening in that code would be weird magic happening under the hood that is very confusing.
@srush I fail to see how your example semantically differs from plain PyTorch/nn code? It's "create object at init, use object to apply at forward" semantics; the remaining differences from plain PyTorch/nn look like mostly syntax to me? (edit: not saying this is bad, I like PyTorch/nn)
@lucasb-eyer Sorry, I should have explained better. The fact that it looks like pytorch syntax is a red herring; unlike pytorch, the implementation is pure / immutable. It's "create declarative skeleton at init, (engine fills in tensors), (engine distributes RNG to module), use objects statelessly to apply at forward":

```python
layer = BinaryNetwork.setup(5, 10)

# Initialize parameters -> stateful, hidden
rng = rng_state()
layer = layer.initialize(rng)

for i in range(10):
    rng = rng_state()
    layer = layer.init_state(rng, mode="train")
    grad = grad(layer.forward)(x)
    layer = layer.update(lambda a, b: a + b, grad)
```
I see, yeah I was missing the "use it" code, should've checked your repo. My personal opinion is that classes are the wrong concept to build something pure/immutable/functional. A few colleagues and I have an internal codebase built on jax, which uses flax in a completely pure/functional way, and flax was open to some design changes to make using flax in that way possible and nice. I think it is very close to your example code actually. We made a simplified version of it public just now, see here: https://github.com/google-research/big_transfer/tree/master/bit_jax

However, all of this pure, pretty, neat, readable stuff goes to 💣 💩 ⚡ the moment you want to add BatchNorm :)
Nice, I will check it out. Maybe what needs to happen is for the jax community to just have an nn.functional module like pytorch, so different module systems can use the same layers.

@lucasb-eyer I am still stuck on one point that keeps me bothered by all these solutions: when you read the code below, what are the internal/informal semantics going on in your head? Particularly:

- Where do you imagine that name is stored?
- Do you believe this code knows it is in an object?
- Do you have a type in your head for x?
- How do you reason about whether this line of code knows if it is the first or last time it is called?
- Could this code be tested independently of its system?

```python
x = nn.Dense(x, num_classes, name="conv_head", kernel_init=nn.initializers.zeros)
```

Until I can answer these questions, I just can't imagine this will be the final state of a reliable module system.
but …

Regarding your second paragraph, I agree that the line has too much magic (also, where are the dense's w/b tracked? a global collection maybe? 😨). And my understanding is that @avital's proposal in the OP is exactly about reducing this magic and, effectively, being closer to "plain python" or PyTorch semantics.
That's not true; nn.functional is clean, functional nn implementations of conv/dense/rnn/etc. that could be used with any module system, and none of that is in jax.lax or jax.numpy: https://pytorch.org/docs/stable/nn.functional.html
I agree. That's what I'm interested in.
It gets halfway there; I'm arguing it needs to be really solved.
@lucasb-eyer Very neat paper though!
Not true either. I was about to concede it's missing an RNN, but there is actually none in …
I went back to read it, and I actually agree with the points in your first comment in this thread. Thanks :)
Oh well, now I feel silly. It does seem like the lax functions are just much more general than the pytorch implementations. I honestly never found reduce_window on my own (the doc of "Wraps XLA's ReduceWindow operator" doesn't really help). The Stax implementation does make it clear though.
No worries.
FYI (you may already know this) most of the ops in lax (things like reduce_window) are documented in more detail here. I guess we ought to copy more of those docs over to JAX. |
Speaking as a teacher, the XLA docs scare me. They are very jargon-heavy. It would be like asking numpy students to read the BLAS docs.
Inspired by google/flax#208 I wondered if we could similarly offer support for dataclasses to define modules. It seems we can :)

One thing I wanted to maintain was support for custom names. To do this, in the examples below I add an optional name property to the dataclasses (you can omit this if you prefer to take the default name) and then we override it with the uniquified name as part of `__post_init__`.

```python
@dataclass
class Linear(hk.Module):
  output_size: int
  name: Optional[str] = None

  def __call__(self, x):
    j, k = x.shape[-1], self.output_size
    w = hk.get_parameter("w", [j, k], init=jnp.ones)
    b = hk.get_parameter("b", [k], init=jnp.zeros)
    return x @ w + b


@dataclass
class MLP(hk.Module):
  output_sizes: Sequence[int]
  activation: Callable[[jnp.ndarray], jnp.ndarray] = jax.nn.relu
  name: Optional[str] = None

  def __call__(self, x):
    for i, output_size in enumerate(self.output_sizes):
      if i > 0:
        x = self.activation(x)
      x = Linear(output_size)(x)
    return x
```

PiperOrigin-RevId: 313392449
Change-Id: I4c30d34474bd78c567c4101958fc12c8c82f3efa
It's been a while, and sorry for not posting more in this thread. We've gone through a major API redesign aligned with the goals originally described in this thread. Our new Linen API came out of many user group discussions, trying to find a solution that empowers our users while staying relatively simple and exposing the full power of JAX. All of our examples have been ported, and multiple large projects have transitioned using our upgrade guide, so now we're making it the official API. Please check it out! Please ask any questions or suggestions for improvements on our discussion board. The old …
Thanks @avital! I really like the new API, thanks for putting the work into it and being direct about the tradeoffs. I will definitely be using it for my next project (probably without @nn.compact, but that is totally okay if they are compatible). I found this helpful: https://colab.research.google.com/github/google/flax/blob/master/docs/notebooks/linen_intro.ipynb
Introduction

Currently, while Flax modules are defined by subclassing `flax.nn.Module`, those modules don't behave the same way that normal Python objects behave.

One of the large differences is that Flax Modules override `__new__`, meaning that module instances aren't a semantically meaningful thing in Flax at the moment. Right now, in Flax, what looks like module construction (`nn.Dense(x, features=10)`) actually does two things:

1. Construct a new instance of `nn.Dense` (using the non-documented API `module.new_instance()`)
2. Call the `apply` method on that instance and return it.

Some upsides of the current approach are:
- There is no need to read both an `__init__` and a `__call__` method to fully understand what a module does.
- Modules are applied in a single concise expression: `nn.Dense(x, features=10)`.

Some downsides of the current approach are:
- Sharing a module requires the `module.shared()` abstraction, which has a confusing mental model -- what does `module.shared()` return? A module class? A module instance? Moreover, which arguments must be passed into `module.shared()` in order for the shared module to be usable? (Behind the scenes, `shared` is implemented on top of `partial`.)
- Creating a model requires `new nn.Model(nn.Dense.partial(features=10), params)` -- why do we need to use `partial` to instantiate a Model? What type does the first argument to `nn.Model` have? Is it a module class? A module instance?
- In `flax/nn/base.py` there is code that does "kwarg mangling". Some of this code had bugs before. It would be nice to reduce the need for kwarg mangling.
- The `module_method` decorator turns methods that aren't `apply` into new Modules. This is surprising; for example, how would I do the equivalent of `module.call(params, *args)` but to call a method `foo` that's not `apply`? That would be `module.foo.call(params, *args)`. That's a pretty surprising mental model, and it stems from the fact that `apply` was a special-cased method on modules.

Proposal

1. Stop overriding `__new__` in Modules.
2. Eliminate `.partial()`.
3. Eliminate `.shared()` (though we may choose to keep it as a safeguard -- see below).
4. Split the current `apply` methods into the controlled use of Python 3.7 dataclasses (for storing module hyperparameters) and a "vanilla Python" `__call__` method (or actually, any name you want) that only takes in the module input(s).

For example, a simple Dense layer may look like this:
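Roughly along these lines (a sketch reconstructed from fragments quoted earlier in the thread, e.g. the `self.param('bias', ...)` call; the kernel/bias field names are illustrative):

```python
@dataclass
class Dense(Module):
  features: int
  kernel_init: Callable = initializers.lecun_normal()
  bias_init: Callable = initializers.zeros

  def __call__(self, x):
    kernel = self.param('kernel', (x.shape[-1], self.features), self.kernel_init)
    bias = self.param('bias', (self.features,), self.bias_init)
    return jnp.dot(x, kernel) + bias
```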
Then, an MLP would look like this:
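Again only a sketch, following the `nn.Dense(features=16)(x)` usage pattern quoted in the comments above:

```python
@dataclass
class MLP(Module):
  def __call__(self, x):
    x = Dense(features=16)(x)
    x = nn.relu(x)
    x = Dense(features=16)(x)
    return x
```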
I believe that this proposal keeps the conciseness of current Flax, while having the potential to significantly reduce both implementation complexity and mental model complexity. The mental model in Flax now reduces to the same one as Keras (other than the fact that parameters are immutable).

For example, in this case re-using a module is trivial -- keep a reference to `nn.Dense(features=16)` and re-use that. (NOTE: We may choose to keep the safe-guarding behavior of `.shared()` that makes it hard to accidentally copy and paste code that accidentally re-uses modules. We can achieve that by having modules default to raising an error when `__call__` is invoked a second time, unless `.shared()` was called on the module instance first.)

With this proposal, there's also no need for `module.partial` -- you can just use `functools.partial(module.__call__)` or `functools.partial(module)`. (Though this is a bit different than in current Flax because the return value of `functools.partial` in itself isn't a module, rather a function. But maybe it was always confusing to understand `module.partial` -- does it override kwargs for all module methods? Just apply?)

Possible transition plan

Given the non-trivial amount of code written using Flax, and the fact that this proposal would change every module written with Flax, we need an upgrade plan.

I propose adding, alongside every new module in `flax.nn`, a function with the same name but lower-cased, that operates the same as in current Flax. These functions would be deprecated-on-arrival. E.g., alongside `Dense` as shown above we would also have:
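Presumably something along these lines (a sketch of the lower-cased compatibility wrapper):

```python
def dense(x, features, **kwargs):
  # Old-style call: construct the new-style module and apply it immediately.
  return Dense(features=features, **kwargs)(x)
```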
Then the first part of the upgrade process is mainly a search-and-replace of "Dense" -> "dense", etc. In addition, some more manual changes will possibly be necessary for uses of `.partial` and `.shared`. Later, users can transition to the new API at a time they see fit.

Current State
@avital has a messy work-in-progress branch checking the viability of using dataclasses in this setting. Results so far seem cautiously promising, but more work is needed before this proposal is ready to be acted on.