
FLIP: Make module instances semantically meaningful by not overriding module.__new__ #208

Closed
avital opened this issue Apr 17, 2020 · 32 comments

Comments

@avital
Contributor

avital commented Apr 17, 2020

Introduction

Currently, while Flax modules are defined by subclassing flax.nn.Module, those modules don't behave the same way that normal Python objects behave.

One of the large differences is that Flax Modules override __new__, meaning that module instances aren't a semantically meaningful thing in Flax at the moment. Right now, in Flax, what looks like module construction (nn.Dense(x, features=10)) actually does two things:

  1. Construct an instance of nn.Dense (using the undocumented API module.new_instance())
  2. Call the apply method on that instance and return its result (see the sketch below)
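A hedged sketch of that two-step behavior (illustrative only; the exact internals of new_instance and apply may differ):

```python
# Roughly what `nn.Dense(x, features=10)` does in current Flax (illustrative sketch,
# not the real implementation):
dense = nn.Dense.new_instance()   # step 1: construct an instance of nn.Dense
y = dense.apply(x, features=10)   # step 2: call `apply` on it and return the result
```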

Some upsides of the current approach are:

  1. Modules are defined as a single function, as opposed to, e.g., the style of other libraries such as Haiku, where you need to scroll up and down between __init__ and __call__ to fully understand what a module does.
  2. Calls to submodules are very concise, e.g. nn.Dense(x, features=10).

Some downsides of the current approach are:

  1. In order to reuse a module, you must use the module.shared() abstraction which has a confusing mental model -- what does module.shared() return? A module class? A module instance? Moreover, which arguments must be passed into module.shared() in order for the shared module to be usable? (Behind the scenes shared is implemented on top of partial)
  2. You can't instantiate a module directly, outside of another module. This leads to surprising things like new nn.Model(nn.Dense.partial(features=10), params) -- why do we need to use partial to instantiate a Model? What type does the first argument to nn.Model have? Is it a module class? Module instance?
  3. In a few spots in flax/nn/base.py there is code that does "kwarg mangling". Some of this code has had bugs before. It would be nice to reduce the need for kwarg mangling.
  4. In order to support multiple methods on a module, the module_method decorator turns methods that aren't apply into new Modules. This is surprising; for example, how would I do the equivalent of module.call(params, *args) but for a method foo that's not apply? That would be module.foo.call(params, *args), which is a pretty surprising mental model.
  5. Wanting to define shared parameters or submodules that work across multiple methods requires either non-traditional patterns and/or more complexity in Flax core (see discussion on FLIP: Support __init__ in Modules #161)
  6. apply is a special-cased method on modules.

Proposal

  1. No longer override __new__ in Modules
  2. Eliminate .partial()
  3. Potentially eliminate .shared() (though we may choose to keep it as a safeguard -- see below)
  4. Split up current module's apply methods into the controlled use of Python 3.7 dataclasses (for storing module hyperparameters) and a "vanilla Python" __call__ method (or actually, any name you want) that only takes in the module input(s)
  5. This may even allow module instances to directly refer to a read-only version of their parameters, e.g.:
class Foo(Module):
  def __call__(self, x):
    dense = nn.Dense(features=16)
    x = dense(x)
    # `dense.params` is defined here; maybe also `dense.params.kernel` and `dense.params.bias`

For example, a simple Dense layer may look like this:

@dataclass
class Dense(Module):
  features: int
  kernel_init: Callable = initializers.lecun_normal()
  bias_init: Callable = initializers.zeros

  def __call__(self, x):
    """Applies a linear transformation to the inputs along the last dimension."""
    kernel = self.param('kernel', (x.shape[-1], self.features), self.kernel_init)
    bias = self.param('bias', (self.features,), self.bias_init)
    return jnp.dot(x, kernel) + bias

Then, an MLP would look like this:

class MLP(Module):
  def __call__(self, x):
    x = nn.Dense(features=16)(x)
    x = nn.relu(x)
    x = nn.Dense(features=16)(x)
    return x

I believe this proposal keeps the conciseness of current Flax, while having the potential to significantly reduce both implementation complexity and mental-model complexity. The mental model in Flax then reduces to the same one as Keras (other than the fact that parameters are immutable).

For example, in this case re-using a module is trivial -- keep a reference to nn.Dense(features=16) and re-use that, as sketched below. (NOTE: We may choose to keep the safe-guarding behavior of .shared() that makes it hard to accidentally re-use modules via copy-pasted code. We can achieve that by having modules raise an error by default when __call__ is invoked a second time, unless .shared() was called on the module instance first.)
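A sketch of sharing under this proposal (assuming the dataclass-based Dense above; SiameseEncoder is just an illustrative name):

```python
class SiameseEncoder(Module):
  def __call__(self, x1, x2):
    dense = Dense(features=16)     # one instance ...
    return dense(x1), dense(x2)    # ... called twice, so both branches share parameters
```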

With this proposal, there's also no need for module.partial -- you can just use functools.partial(module.__call__) or functools.partial(module). (This is a bit different from current Flax because the return value of functools.partial is itself not a module but a function. But maybe it was always confusing to understand module.partial -- does it override kwargs for all module methods? Just apply?)
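A minimal sketch of replacing .partial() with functools.partial (assuming the dataclass-based Dense above and some input array x):

```python
import functools

# Instead of `nn.Dense.partial(features=16)`, bind the hyperparameter with functools:
make_dense16 = functools.partial(Dense, features=16)

dense = make_dense16()   # the partial itself is a plain callable, not a Module
y = dense(x)
```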

Possible transition plan

Given the non-trivial amount of code written using Flax, and the fact that this proposal would change every module written with Flax, we need an upgrade plan.

I propose adding, alongside every new module in flax.nn, a function with the same name but lower-cased, that operates the same as in current Flax. These functions would be deprecated-on-arrival. E.g., alongside Dense as shown above we would also have

def dense(x, features, kernel_init=initializers.lecun_normal(), bias_init=initializers.zeros):
  """DEPRECATED. Use the new Module API: http://link/to/upgrade/guide."""
  return Dense(features, kernel_init, bias_init)(x)

Then the first part of the upgrade process is mainly a search-and-replace of "Dense" -> "dense", etc. In addition, some more manual changes will possibly be necessary for uses of .partial and .shared. Later, users can transition to the new API at a time they see fit.
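A sketch of the two stages at a single call site (hypothetical code, following the shim above):

```python
# Stage 1 (mechanical search-and-replace): the old-style call moves onto the deprecated shim.
y = nn.dense(x, features=10)   # was: nn.Dense(x, features=10)

# Stage 2 (at the user's own pace): migrate to the proposed module API.
y = nn.Dense(features=10)(x)
```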

Current State

@avital has a messy work-in-progress branch checking the viability of using dataclasses in this setting. Results so far seem cautiously promising, but more work is needed before this proposal is ready to be acted on.

@avital avital added the FLIP label Apr 17, 2020
@akolesnikoff
Contributor

I would be in favor of this change, as the proposed way of creating new modules is better aligned with my default mental model: layers are classes and particular instances of these layers (with associated weights and parameters) are objects that process the data by being called. I also fully agree that variable sharing will become more intuitive.

Besides, I actually see a separation between __init__ and __call__ as a potential win. Conceptually, I imagine that __init__ should admit static parameters, like number of channels, and __call__ should admit actual data that is processed. Currently, these different types of parameters are all mixed together.

@avital
Contributor Author

avital commented Apr 17, 2020

I would be in favor of this change, as the proposed way of creating new modules is better aligned with my default mental model: layers are classes and particular instances of these layers (with associated weights and parameters) are objects that process the data by being called. I also fully agree that variable sharing will become more intuitive.

Yes. One caveat is that while layers have parameters on them, those parameters will be immutable and you'd still need to mutate your parameters at the top level rather than within your module. This is due to our desire to allow you to use vanilla transformations such as jit and pmap, which don't work with mutations.

Besides, I actually see a separation between init and call as a potential win. Conceptually, I imagine that init should admit static parameters, like number of channels, and call should admit actual data that is processed. Currently, these different types of parameters are all mixed together.

Yes. The issue is that by simply letting people use __init__ and __call__ arbitrarily, you often end up with things like this, where you really have to scroll up and down many times to fully follow the flow of the module's forward pass. Hence the restriction to dataclasses encourages the __init__ to be as dumb as possible.

@lucasb-eyer
Member

💯 to this change. It aligns the mental model with TF2's tf.Module and PyTorch's nn.Module a lot more, and both of these have converged to where they are now after many years of mistakes, so this is a good thing.

(NOTE: We may choose to keep the safe-guarding behavior of .shared() that makes it hard to accidentally re-use modules via copy-pasted code. We can achieve that by having modules raise an error by default when __call__ is invoked a second time, unless .shared() was called on the module instance first.)

Please don't. Re-using an instance is the common, intuitive, friction-less way of sharing weights; this would just add annoying overhead for the sake of avoiding a mistake which, frankly, I have never encountered. An explicit :share method was how it was done in Torch7, and it was annoying and painful and does not exist anymore in PyTorch.

Regarding the __init__ vs __call__ separation, I don't think that it makes good code impossible, so if someone creates a monster hydra of code because of that, it's probably the author's fault, not the library's. Using dataclass (or attr.s) for this is an interesting idea. However, usually what is done in __init__ is just normalization of convenience parameters, for example allowing a filter size to be passed as (3,3) or as 3, and then turning 3 into (3,3) in __init__, such that __call__ is cleaner to read; with that in mind, you can really skip reading __init__. I think this is a good thing.
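A minimal sketch of that kind of __init__ normalization (generic Python, not tied to any particular library):

```python
class Conv:
  def __init__(self, filter_size):
    # Accept the convenience form: `3` is shorthand for `(3, 3)`.
    if isinstance(filter_size, int):
      filter_size = (filter_size, filter_size)
    self.filter_size = filter_size

  def __call__(self, x):
    # __call__ can assume `self.filter_size` is always a tuple.
    ...
```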

Finally, I think you can have an even more convincing example for modules which have more than just the obvious __call__, like the VAE example here, which currently is not trivial to understand: I either have to do a lot of guess-work about Flax internals, or go back and read the whole docs, whereas after your proposal (and in PyTorch) it can be much more straightforward.

@danielsuo
Collaborator

Thanks for this proposal! I agree with the other comments:

  • How __init__ and __call__ might separate responsibility (user-created monster hydras notwithstanding, @lucasb-eyer)
  • Removing .shared(). I understand the rationale for keeping it (one less thing to debug), but in this case it makes sense to opt for less friction vs. more safety if that's the common user expectation. If we really wanted to be extra, we could provide some flax linting utilities (FL201: Did you mean to reuse a module?) :)

Yes. One caveat is that while layers have parameters on them, those parameters will be immutable and you'd still need to mutate your parameters at the top level rather than within your module. This is due to our desire to allow you to use vanilla transformations such as jit and pmap, which don't work with mutations.

Do you mean passing modules directly into jit? One of the things I tried to do away with during my weekend excursion was flax.nn.Model, given the constraint that flax.nn.Module must be immutable. The solution was not great: have an instance method that returns a new flax.nn.Module when you update parameters or state.
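One hedged sketch of that "return a new instance" pattern using frozen dataclasses (hypothetical names, not the actual weekend branch):

```python
import dataclasses
from typing import Any, Mapping

@dataclasses.dataclass(frozen=True)
class Model:
  module: Any
  params: Mapping[str, Any]

  def with_params(self, new_params):
    # Because the instance is frozen, "updating" produces a fresh Model rather than
    # mutating in place, which keeps it friendly to jit/pmap.
    return dataclasses.replace(self, params=new_params)
```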

@srush

srush commented Apr 20, 2020

Nice, I like this change. It is a good start.

However, if you are making such a breaking change, this feels too conservative.

Core Issues:

  • This function still violates Pythonic conventions. nn.Dense is seemingly making mutable changes to some internal state buffer that is invisible to the user and not transparent in the syntax. (I know this happens in TF, but flax should be better.)
  def __call__(self, x):
    x = nn.Dense(features=16)(x)
    x = nn.relu(x)
    x = nn.Dense(features=16)(x)

=> Does this mean?

  def __call__(self, x):
    x = nn.Dense(self, features=16)(x)
    x = nn.relu(x)
    x = nn.Dense(self, features=16)(x)

(Or alternatively pytorch / sonnet 2 syntax which both do this better)

  • Params are still treated differently than Layers, and use a string-based naming which seems dangerous and tempting for abuse.
bias = self.param('bias', (self.features,), self.bias_init) 

=> ?

bias = nn.Param(self, (self.features,), self.bias_init)

@shoyer
Member

shoyer commented Apr 20, 2020

In general I really like the look of this! I think it would be a significant improvement/simplification of Flax's mental model.

👍 for eliminating the use of __new__ in Modules

👍 for eliminating .partial().

👍 for eliminating .shared(). I don't think we need the safeguard -- it is quite common to intentionally reuse models in neural net code

👍 for encouraging the use of dataclasses (in particular, @dataclass(frozen=True) to enforce immutability)

👎 for requiring dataclasses, and not allowing __init__ methods to be written explicitly. Even if this were possible to enforce in a clean way (I have my doubts), sometimes __init__ can be a nice way to write this, as @lucasb-eyer writes in #208 (comment).

👍 for the proposed transition plan, which looks quite practical.

@shoyer
Member

shoyer commented Apr 21, 2020

One question arises: how does this change affect (if at all) the way we initialize Flax models? Do we still stick with Module.init and call methods, except these are now normal methods instead of class methods?

@jesseengel

jesseengel commented Apr 22, 2020

👍 to everything said by @lucasb-eyer, @srush, and @shoyer. I think having separate __init__ and __call__ is actually a huge net positive. It allows people to just think in Python instead of "thinking in Flax" like we have to do with TF.

FWIW, I don't see the hydra thing as much of a disadvantage. In many cases it requires people to be more explicit, and you can see what's going on in the submodule itself, instead of hiding things in implicit behind the scenes work. It also then makes it easier to access model attributes from outside the module if you want to hack things later, say in a colab notebook.

Also, I think it's great to allow access to __call__ directly, rather than redirecting to some other function like apply. I'm running into challenges with this in Keras at the moment, as I'm trying to work around some aspects of the forced programming model, but it's inflexible if I only have access to call and not __call__, and it requires me to dig deep into the Keras base layer code, which is a mess. Let's not make the same mistake for Flax.

@david-waterworth

This should also help with my confusion in #16 (comment), where not calling partial before create_by_shape results in the model being created with different parameters from the ones it was trained with.

@cghawthorne
Contributor

Can you add an example of how this would work with an equivalent to module_method?

@shoyer
Member

shoyer commented May 14, 2020

(Or alternatively pytorch / sonnet 2 syntax which both do this better)

@srush Could you kindly clarify what you mean by this?

Is this just a reference to how PyTorch / Sonnet 2 use explicit attribute assignment for submodules? e.g., self.dense = nn.Dense(features=16)?

This does make module hierarchies, and when mutation is happening, very clear. The downside is that layers get specified in __init__, which is separated from where they are used.
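For reference, a PyTorch-style sketch of that pattern (illustrative layer sizes):

```python
import torch
import torch.nn as nn

class MLP(nn.Module):
  def __init__(self):
    super().__init__()
    # Submodules are created (and therefore named and owned) here ...
    self.dense1 = nn.Linear(8, 16)
    self.dense2 = nn.Linear(16, 16)

  def forward(self, x):
    # ... and only used here, separate from where they were defined.
    return self.dense2(torch.relu(self.dense1(x)))
```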

@jesseengel

jesseengel commented May 14, 2020

Is this just a reference to how PyTorch / Sonnet 2 use explicit attribute assignment for submodules? e.g., self.dense = nn.Dense(features=16)?

This does make module hierarchies, and when mutation is happening, very clear. The downside is that layers get specified in __init__, which is separated from where they are used.

@shoyer To be clear, I think a lot of people actually consider that an upside. It separates creation/ownership from usage, so it's much clearer when reuse is happening, and easier to access submodules from outside the class itself for more creative routing of shared parameters.

The mental overhead of having a little boilerplate is a small price to pay for such explicit clarity and Python-native interaction paradigms (using Python's built-in object attributes vs. some behind-the-scenes implicit naming scheme).

@shoyer
Member

shoyer commented May 14, 2020

To be clear, I think a lot of people actually consider that an upside. It separates creation/ownership from usage, so it's much clearer when reuse is happening, and easier to access submodules from outside the class itself for more creative routing of shared parameters.

Absolutely, these are all real advantages. On the other hand, I've also had cases where separating initialization/use of layers made my code harder to read and modify because two different parts of the code need to be kept in sync. You also can't use input shapes to determine the shapes of variables. It is not clear to me (personally) which is better/worse in general. It may depend on the context.

Keras lets you write things both ways, which is convenient for users, but of course imposes an even higher cost in terms of complexity.

For JAX, there is one additional consideration, which is whether the module abstraction is amenable to functional transformations -- one of the core strengths of JAX. My understanding is that this is hard to do with Python's mutable object model.

@srush

srush commented May 14, 2020

I consider inline initialization a Keras design flaw. It mixes functional and structural concerns and makes it very hard to reason about, document, and analyze modules.

However, whether or not you agree with this, the fact that it is causing the library to have ill-defined semantics, with very minimal benefits ("less scrolling up?"), should be a red flag that it is maybe a problem.

@jheek
Member

jheek commented May 18, 2020

I think we should not use words like "normal" or "pythonic". They are really vague statements that essentially refer to similarity with existing programming paradigms that are common in the Python world. We shouldn't strive to please the status quo.

I think the points raised by @srush are important. Although sharing becomes clearer with explicit construction, it still isn't quite like an object that owns its parameters.

Consider the following example:

class MyModule(flax.Module):
  def apply_a(self, x):
    def inner_dense_factory():
      return nn.Dense(123)
    for i in range(3):
      x = inner_dense_factory()(x) 
    return x

  def apply_b(self, x):
    for i in range(3):
      x = self._dense_factory()(x) 
    return x

  def _dense_factory(self):
    return nn.Dense(123, self.my_fancy_init)

Clearly in apply_a we expect the Dense parameters to be shared when we call apply_a multiple times on an instance of MyModule but not between iterations of the loop. But what if we take the dense_factory from apply_a and turn it into a method (_dense_factory)? A seemingly innocent refactor will now cause all the Dense modules in apply_b to be shared.

Of course we could add annotation trickery to distinguish between module methods that have a scope and "inline methods". But the mental model is still significantly more complex than plain old Python objects.

@srush

srush commented May 19, 2020

Perhaps I am missing something, but I don't really understand the example above. The implied semantics feel really complicated to me as state seems to bind to functions in a way I cannot trace.

Btw, I don't know if it is helpful, but here is a proof-of-concept of the sort of pure world I like (not saying flax needs to go this way).

https://github.com/srush/parallax

# Everything is immutable: @module = dataclass(frozen=True, repr=False)
@module
class Dense(Module):

    # All parameter-holders are explicitly declared.
    weight : Parameter
    bias   : Parameter

    # Setup replaces __init__: it creates shapes and binds lazy initializers.
    @staticmethod
    def setup(in_size, out_size):
        return Dense.init(
            weight = Parameter.setup((out_size, in_size), init.xavier_normal_),
            bias   = Parameter.setup((out_size,), init.normal_))

    # Forward is just like standard pytorch.
    def forward(self, input):
        return self.weight @ input + self.bias

"Sharing" would requires a manual split of the parameter into two parts like this.

@module
class BinaryNetwork(Module):

    # No difference between modules and parameters
    dense1  : Dense
    dense2  : Dense
    dense3  : Dense
    dropout : Dropout

    @staticmethod
    def setup(input_size, hidden_size):
        return BinaryNetwork.init(
            dense1  = Dense.setup(input_size, hidden_size),
            dense2  = Dense.setup(hidden_size, hidden_size),
            dense3  = Dense.setup(hidden_size, 1),
            dropout = Dropout.setup(rate=0.2)
        )

    def forward(self, input):

        # Standard usage works out of the box.
        x = torch.tanh(self.dense1(input))

        # Stochastic modules (have random seed already)
        x = self.dropout(x)

        # Shared params / recurrence requires split (like RNG)
        dense2_a, dense2_b = self.dense2.split(2)
        x = torch.tanh(dense2_a(x))
        x = torch.tanh(dense2_b(x))

        return torch.sigmoid(self.dense3(torch.tanh(x)))

@shoyer
Member

shoyer commented May 19, 2020

Clearly in apply_a we expect the Dense parameters to be shared when we call apply_a multiple times on an instance of MyModule but not between iterations of the loop. But what if we take the dense_factory from apply_a and turn it into a method (_dense_factory)? A seemingly innocent refactor will now cause all the Dense modules in apply_b to be shared.

My expectation from reading this code is that all Dense parameters in both examples would be unshared. If you want to use the same parameters, you need to use the same Dense object.

@lucasb-eyer
Member

yep, was about to say the same as @shoyer: the example is convoluted, but we are creating a new Dense object each time, so I would definitely not expect weight sharing. Any sharing happening in that code would be weird magic happening under the hood that is very confusing.

@lucasb-eyer
Member

lucasb-eyer commented May 19, 2020

@srush I fail to see how your example semantically differs from plain PyTorch/nn code? It's "create object at init, use object to apply at forward" semantics, the remaining differences from plain PyTorch/nn look like mostly syntax to me? (edit: not saying this is bad, I like PyTorch/nn)

@srush

srush commented May 19, 2020

@lucasb-eyer Sorry, I should have explained better. The fact that it looks like PyTorch syntax is a red herring; unlike PyTorch, the implementation is pure/immutable.

It's "create declarative skeleton at init, (engine fills in tensors), (engine distributes RNG to module), use objects statelessly to apply at forward"

layer = BinaryNetwork.setup(5, 10)

# Initialize parameters -> stateful, hidden
rng = rng_state()
layer = layer.initialize(rng)

for i in range(10):
    rng = rng_state()
    layer = layer.init_state(rng, mode="train")
    grad = grad(layer.forward)(x)
    layer = layer.update(lambda a, b: a + b, grad)

@lucasb-eyer
Member

I see, yeah I was missing the "use it" code, should've checked your repo. My personal opinion is that classes are the wrong concept to build something pure/immutable/functional.

A few colleagues and I have an internal codebase built on jax, which uses flax in a completely pure/functional way, and flax was open to some design changes to make using flax in that way possible and nice. I think it is very close to your example code actually. We made a simplified version of it public just now, see here: https://github.com/google-research/big_transfer/tree/master/bit_jax

However, all of this pure, pretty, neat, readable stuff goes to 💣 💩 ⚡ the moment you want to add BatchNorm :)

@srush

srush commented May 21, 2020

Nice, I will check it out. Maybe what needs to happen is for the JAX community to just have an nn.functional-style module like PyTorch's, so different module systems can use the same layers.

@lucasb-eyer I am still just stuck on one point that is keeping me bothered by all these solutions: when you read the code below, what are the internal/informal semantics going on in your head? Particularly: where do you imagine that name is stored? Do you believe this code knows it is in an object? Do you have a type in your head for x? How do you reason about whether this line of code knows if it is the first or last time it is called? Could this code be tested independently of its system?

    x = nn.Dense(x, num_classes, name="conv_head", kernel_init=nn.initializers.zeros)

Until I can answer these questions, I just can't imagine this will be the final state of a reliable module system.

@lucasb-eyer
Member

but jax.lax and jax.numpy pretty much correspond to nn.functional :) The next step is deciding how bookkeeping of variables/parameters happens, and that is where all the frameworks' opinions differ (and mine differs again, and so does yours).

Regarding your second paragraph, I agree that the line has too much magic (also, where are the dense's w/b tracked? a global collection maybe? 😨). And my understanding is that @avital's proposal in the OP is exactly about reducing this magic and, effectively, being closer to "plain Python" or PyTorch semantics.

@srush

srush commented May 21, 2020

but jax.lax and jax.numpy pretty much correspond to nn.functional :)

That's not true: nn.functional is a set of clean, functional nn implementations of conv/dense/rnn/etc. that could be used with any module system, and none of that is in jax.lax or jax.numpy: https://pytorch.org/docs/stable/nn.functional.html

The next step is deciding how bookkeeping of variables/parameters happens,

I agree. That's what I'm interested in.

And my understanding is that @avital 's proposal in the OP is exactly about reducing this magic and, effectively, being closer to "plain python" or PyTorch semantics.

It gets halfway there, I'm arguing it needs to be really solved.

@srush

srush commented May 21, 2020

@lucasb-eyer Very neat paper though!

@lucasb-eyer
Member

lucasb-eyer commented May 21, 2020

That's not true: nn.functional is a set of clean, functional nn implementations of conv/dense/rnn/etc. that could be used with any module system, and none of that is in jax.lax or jax.numpy

Not true either. jax.lax has a pretty powerful implementation of conv (jax.lax.conv_general_dilated), and similarly for pooling (jax.lax.reduce_window) and linear (jax.lax.dot_general).

I was about to concede it's missing an RNN, but there is actually none in nn.functional either. The only remaining non-trivial entry of nn.functional that is missing from jax.{nn,lax,numpy} is ctc_loss, and I'm sure jaxers would happily accept a PR for jax.nn.ctc_loss. So I maintain my point that torch.nn.functional ⊆ jax.{nn,lax,numpy}.
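As an illustration (a hedged sketch, not from the thread), a dense layer written directly against jax.lax.dot_general:

```python
import jax.numpy as jnp
from jax import lax, random

def dense(x, kernel, bias):
  # Contract the last axis of `x` with the first axis of `kernel`,
  # i.e. the same computation as `x @ kernel + bias`.
  dims = (((x.ndim - 1,), (0,)), ((), ()))  # ((lhs_contract, rhs_contract), (lhs_batch, rhs_batch))
  return lax.dot_general(x, kernel, dimension_numbers=dims) + bias

key_w, key_x = random.split(random.PRNGKey(0))
kernel = random.normal(key_w, (8, 16))
bias = jnp.zeros((16,))
y = dense(random.normal(key_x, (4, 8)), kernel, bias)  # y.shape == (4, 16)
```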

It gets halfway there, I'm arguing it needs to be really solved.

I went back to read it, and I actually agree with the points in your first comment in this thread.

Thanks :)

@srush

srush commented May 21, 2020

Oh well, now I feel silly. It does seem like the lax functions are just much more general than the PyTorch implementations. I honestly never found reduce_window on my own (the doc of "Wraps XLA's ReduceWindow operator" doesn't really help). The Stax implementation does make it clear though.

@lucasb-eyer
Member

No worries. jax.lax is extremely powerful, I like its API a lot (reminds me of BLAS, but in times of XLA), and it is criminally under-documented!

@j-towns
Contributor

j-towns commented May 26, 2020

FYI (you may already know this) most of the ops in lax (things like reduce_window) are documented in more detail here. I guess we ought to copy more of those docs over to JAX.

@srush

srush commented May 26, 2020

Speaking as a teacher, the XLA docs scare me. They are very jargon-heavy. It would be like asking NumPy students to read the BLAS docs.

copybara-service bot pushed a commit to google-deepmind/dm-haiku that referenced this issue May 27, 2020
Inspired by google/flax#208 I wondered if we could similarly offer support for
@dataclass to define modules. It seems we can :) One thing I wanted to maintain
was support for custom names. To do this, in the examples below I add an optional
name property to the dataclasses (you can omit this if you prefer to take the
default name) and then we override it with the uniquified name as part of
`__post_init__`.

```python
@dataclass
class Linear(hk.Module):
  output_size: int
  name: Optional[str] = None

  def __call__(self, x):
    j, k = x.shape[-1], self.output_size
    w = hk.get_parameter("w", [j, k], init=jnp.ones)
    b = hk.get_parameter("b", [k], init=jnp.zeros)
    return x @ w + b

@dataclass
class MLP(hk.Module):
  output_sizes: Sequence[int]
  activation: Callable[[jnp.ndarray], jnp.ndarray] = jax.nn.relu
  name: Optional[str] = None

  def __call__(self, x):
    for i, output_size in enumerate(self.output_sizes):
      if i > 0:
        x = self.activation(x)
      x = Linear(output_size)(x)
    return x
```

PiperOrigin-RevId: 313392449
Change-Id: I4c30d34474bd78c567c4101958fc12c8c82f3efa
@avital
Contributor Author

avital commented Nov 20, 2020

It's been a while, and sorry for not posting more here. We've gone through a major API redesign aligned with the goals originally described in this thread.

Our new Linen API came out of many user group discussions, trying to find a solution that empowers our users while staying relatively simple and exposing the full power of JAX.

All of our examples have been ported, and multiple large projects have transitioned using our upgrade guide, so now we're making it the official API.

Please check it out! Please post any questions or suggestions for improvements on our discussion board.

The old flax.nn API is being deprecated.
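For readers landing here later, a minimal Linen-style sketch of the MLP from the original proposal (illustrative; see the linked intro notebook and upgrade guide for the authoritative API):

```python
import jax
import jax.numpy as jnp
import flax.linen as nn

class MLP(nn.Module):
  features: int = 16

  @nn.compact
  def __call__(self, x):
    x = nn.Dense(self.features)(x)
    x = nn.relu(x)
    return nn.Dense(self.features)(x)

model = MLP()
params = model.init(jax.random.PRNGKey(0), jnp.ones((1, 8)))   # initialize parameters
y = model.apply(params, jnp.ones((1, 8)))                      # apply the module
```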

@avital avital closed this as completed Nov 20, 2020
@srush

srush commented Nov 20, 2020

Thanks @avital !

I really like the new api, thanks for putting the work into it and being direct about the tradeoffs. I will definitely be using it for my next project. (probably without @nn.compact , but that is totally okay if they are compatible).

I found this helpful: https://colab.research.google.com/github/google/flax/blob/master/docs/notebooks/linen_intro.ipynb
