
FLIP: Support __init__ in Modules #161

Closed
jheek opened this issue Apr 1, 2020 · 7 comments
jheek (Member) commented Apr 1, 2020

Introduction

By default, Modules are defined using only an apply function plus unshared parameters and submodules. This makes it easy to write modules (even ones with control flow) in a concise and readable way.

However, some modules don't fit well into this abstraction. For example, consider an autoencoder. During model training we would like to take an observed example and encode and decode it. But we would also like to be able to call the encode and decode procedures as separate methods for other use cases besides training.

With the current Flax API a simple AutoEncoder can be written as follows:

from flax import nn

class AutoEncoder(nn.Module):

  def _create_modules(self, encoder_features, decoder_features):
    # Sharing by name ties all entry points to the same parameters.
    encoder = nn.Dense.shared(features=encoder_features, name='encoder')
    decoder = nn.Dense.shared(features=decoder_features, name='decoder')
    return encoder, decoder

  def apply(self, x, **hparams):
    encoder, decoder = self._create_modules(**hparams)
    z = encoder(x)
    return decoder(z)

  @nn.module_method
  def encode(self, x, **hparams):
    encoder, _ = self._create_modules(**hparams)
    return encoder(x)

  @nn.module_method
  def decode(self, z, **hparams):
    _, decoder = self._create_modules(**hparams)
    return decoder(z)

A number of issues can be noticed in this example:

  1. hyperparameters need to be passed around manually between all module methods
  2. _create_modules behaves a lot like a constructor, but it needs to be called manually
  3. we cannot directly call the module methods encode and decode from apply, leading to even more code duplication

Proposal

We would like to improve the syntax of modules that require multiple methods and reuse of parameters, submodules, and hyperparameters across those methods.

The proposed syntax allows us to rewrite the AutoEncoder module as follows:

class AutoEncoder(nn.Module):

  def setup(self, encoder_features, decoder_features, **kwargs):
    self.encoder = nn.Dense.shared(features=encoder_features, name='encoder')
    self.decoder = nn.Dense.shared(features=decoder_features, name='decoder')
    # Remaining kwargs are forwarded to apply and the module methods.
    return kwargs

  def apply(self, x):
    z = self.encode(x)
    return self.decode(z)

  @nn.module_method
  def encode(self, x):
    return self.encoder(x)

  @nn.module_method
  def decode(self, z):
    return self.decoder(z)

# assuming rng is a jax.random.PRNGKey and x is a batch of inputs
model_def = AutoEncoder.partial(encoder_features=4, decoder_features=8)
_, params = model_def.init(rng, x)
model = nn.Model(model_def, params)
# use the apply function for training
x_recon = model(x)
# two-step encode + decode
z = model.encode(x)
x_recon = model.decode(z)

A few differences w.r.t. the introduction example:

  1. a constructor (setup) defines shared modules and assigns them to fields.
  2. the constructor defines the hyperparameters, which are no longer passed around by other methods.
  3. apply reuses the module methods, avoiding code duplication.

A few changes are required to make the new syntax work:

  1. When a Module is constructed, we must first call the setup function. The setup function receives all kwargs and returns the remaining keyword arguments that should be passed to the module method (see the sketch below).

  2. When calling a module_method using self.some_module_method(...), it behaves as an ordinary Python method.

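A minimal sketch of the call flow in change 1, with hypothetical helper and argument names (not the actual implementation):

def call_module_method(module_instance, method_fn, *args, **kwargs):
  # setup consumes the construction kwargs and returns the keyword
  # arguments that should be forwarded to apply or a module method.
  remaining_kwargs = module_instance.setup(**kwargs)
  return method_fn(module_instance, *args, **remaining_kwargs)
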
An implementation of this proposal lives in draft PR #104

Alternatives

The main issue in this proposal is determining which arguments are passed to setup. There are a few variations that can be considered:

  1. Introspection is used to determine which keyword arguments belong to setup (see the sketch below).

  2. Require users to provide an explicit list of construction arguments.

  3. Pass all keyword arguments to setup. This would make the implementation easier but would require most module methods to include something like **unused_kwargs to work correctly.

  4. [CURRENT PROPOSAL] setup receives all keyword arguments and returns a dictionary of keyword arguments that should be passed to the apply method and other module methods.

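For illustration, alternative 1 could look roughly like this (a sketch using Python's inspect module; the helper name is hypothetical):

import inspect

def split_setup_kwargs(setup_fn, kwargs):
  # Keyword arguments matching setup's signature go to setup;
  # everything else is forwarded to the module methods.
  setup_params = inspect.signature(setup_fn).parameters
  setup_kwargs = {k: v for k, v in kwargs.items() if k in setup_params}
  method_kwargs = {k: v for k, v in kwargs.items() if k not in setup_params}
  return setup_kwargs, method_kwargs
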
jheek added the FLIP label Apr 1, 2020
avital changed the title "Support __init__ in Modules" to "FLIP: Support __init__ in Modules" Apr 2, 2020
jekbradbury (Contributor)

In the proposed syntax, why can’t the user write z = self.encoder(x) in apply? What happens if they do? And why do the encoder and decoder modules have to be shared? What happens if they’re not?

jheek (Member, Author) commented Apr 3, 2020

> why can’t the user write z = self.encoder(x) in apply?

That will work too. In this case it also doesn't matter, because the function is a one-liner. If it were more complicated, one might want to use the module method from within apply for code reuse.

> And why do the encoder and decoder modules have to be shared? What happens if they’re not?

In the example we use the encoder and decoder at most once, so in this case partial would also work.
The idea is to promote the use of shared in the constructor, because using a module multiple times would bite you otherwise (see the sketch below).

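To make that concrete, a minimal sketch of the difference inside some parent module's apply (pre-Linen API; assuming .partial assigns a fresh name per call):

# .partial only binds hyperparameters; each call still creates
# its own set of parameters:
dense = nn.Dense.partial(features=4)
y1 = dense(x)  # one set of parameters
y2 = dense(x)  # a second, independent set of parameters

# .shared additionally ties every call to the same parameters:
dense = nn.Dense.shared(features=4, name='dense')
y1 = dense(x)
y2 = dense(x)  # reuses the parameters from the first call
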
mattwescott

This would be useful for normalizing flows as well. In the meantime, there seems to be a namespace issue with the _shared_modules approach. By modifying this line to use cls.__name__, my implementation is working again.

import flax

def compose_transforms(transforms):

    class TransformSequence(Transform):  # Transform is the user's base Module

        def _shared_modules(self):
            return [t.shared() for t in transforms]

        @flax.nn.module_method
        def transform(self, x):
            # renamed local to avoid shadowing the closed-over list
            shared = self._shared_modules()
            for t in shared:
                x = t.transform(x)
            return x

        @flax.nn.module_method
        def inverse_and_log_det_jac(self, y):
            shared = self._shared_modules()
            log_det_jac = 0.0
            for t in reversed(shared):
                y, term = t.inverse_and_log_det_jac(y)
                log_det_jac += term
            return y, log_det_jac

    return TransformSequence

Is there a better workaround for now?

shoyer (Member) commented Apr 10, 2020

My main worry about supporting __init__ is that it could lead to incorrect assumptions about how Flax works, due to apparent similarity with normal class syntax. If I saw a call like the proposed AutoEncoder without knowing anything about Flax, I would expect the only valid arguments to AutoEncoder(*args, **kwargs) to be those that appear explicitly on __init__, but that isn't how Flax works.

Some ideas:

  • Use a different name from __init__, e.g., setup (we already use init for variables).
  • Consider (conditionally?) switching to an explicit setup/call split like Keras or Haiku: AutoEncoder(**params)(x) (sketched below). This would probably be more pervasive than you want.

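For comparison, the setup/call split would look roughly like this (hypothetical Flax syntax; this is how Keras and Haiku modules are used):

# Construction arguments and data arguments are fully separated:
model = AutoEncoder(encoder_features=4, decoder_features=8)
x_recon = model(x)
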
The magic separation of arguments between __init__ and apply also worries me a little bit. I don't know if there is a good way to do this, but I do think passing all **kwargs to __init__ (or setup) is a better alternative than using introspection.

jheek (Member, Author) commented Apr 14, 2020

> This would be useful for normalizing flows as well. In the meantime, there seems to be a namespace issue with the _shared_modules approach. [...] Is there a better workaround for now?

@mattwescott I think your issue was introduced in a recent change to the default name policy. This is probably a bug that should be fixed. I'll look into it ASAP.

jheek (Member, Author) commented Apr 14, 2020

> • Use a different name from __init__, e.g., setup (we already use init for variables).

I agree with this proposal.

> • Consider (conditionally?) switching to an explicit setup/call split like Keras or Haiku: AutoEncoder(**params)(x). This would probably be more pervasive than you want.

That's a very big change to make and takes away a key advantage of calling modules as functions.

> The magic separation of arguments between __init__ and apply also worries me a little bit. I don't know if there is a good way to do this, but I do think passing all **kwargs to __init__ (or setup) is a better alternative than using introspection.

I really dislike the introspection as well. Passing all kwargs to setup is probably fine; the main downside is that they are also passed to all the module methods, which then probably need something like **unused_kwargs.

avital (Contributor) commented Dec 12, 2020

This is no longer relevant since Linen has landed, so I'm closing this for now.

avital closed this as completed Dec 12, 2020