
Lazy layer / Shape inference? #15

Closed
lkhphuc opened this issue Oct 25, 2021 · 1 comment

Comments

@lkhphuc
Contributor

lkhphuc commented Oct 25, 2021

Would it be possible to support PyTorch's lazy layers, i.e. shape inference based on the input?
One possible solution is to provide a sample input to the init method: module.init(rng=42, input=jnp.ones((64, 64, 3)))
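
For reference, a minimal sketch of the PyTorch lazy-layer behavior referred to here (torch.nn.LazyLinear infers in_features from the first input it sees; the tensor shapes are just illustrative):

import torch
import torch.nn as nn

# out_features is fixed up front; in_features is left unspecified
layer = nn.LazyLinear(out_features=10)

x = torch.ones(64, 3)
y = layer(x)  # the weight is materialized here once the input shape is known
print(layer.weight.shape)  # torch.Size([10, 3])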

@cgarciae
Owner

cgarciae commented Oct 25, 2021

Hey @lkhphuc! This question comes at a great time; I was thinking about this topic yesterday.

There are 2 related topics:

  • Shape Inference
  • Module Hooks (Flax's @compact behavior)

1. Shape Inference only

Here we check whether the module has been initialized and, if not, instantiate the parameters from the input's shape:

from typing import Optional

import jax
import jax.numpy as jnp
import treex as tx
from treex import Module, Parameter


class Linear(Module):
    kernel: Optional[jnp.ndarray] = Parameter.node()
    bias: Optional[jnp.ndarray] = Parameter.node()
    
    def __init__(self, features: int):
        self.features = features
        self.kernel = None
        self.bias = None
        
    def __call__(self, x):
        if not self.initialized:
            features_in = x.shape[-1]
            self.kernel = jax.random.uniform(tx.next_key(), shape=(features_in, self.features))
            self.bias = jnp.zeros(shape=(self.features,))
        
        return jnp.dot(x, self.kernel) + self.bias
        
class MLP(Module):
    def __init__(self):
        self.linear1 = Linear(32)
        self.linear2 = Linear(10)
        
    def __call__(self, x):
        x = self.linear1(x)
        x = jax.nn.relu(x)
        x = self.linear2(x)
        return x
        
# parameter shapes are inferred from a sample batch of the training data
module = MLP().init(42, inputs=X_train[:32])

The trick is that tx.next_key() is only available during .init, which is where the RNG key is supplied.
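
For intuition, here is a rough sketch (not Treex's actual implementation) of how a next_key helper could be restricted to .init: the key lives in module-level state that is only populated while the init trace runs, and next_key splits it on each call.

import jax

_INIT_KEY = None  # hypothetical global key state, set only while init runs

def next_key():
    global _INIT_KEY
    if _INIT_KEY is None:
        raise RuntimeError("next_key() can only be called during init")
    _INIT_KEY, subkey = jax.random.split(_INIT_KEY)
    return subkey

def run_init(module, seed, inputs):
    global _INIT_KEY
    _INIT_KEY = jax.random.PRNGKey(seed)
    try:
        module(inputs)  # parameters are created lazily inside __call__
    finally:
        _INIT_KEY = None
    return module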

2. Shape Inference + Hooks

Adding hooks is very possible; in fact, Treeo already supports Compact Trees, so Treex just needs to leverage this and expand it to cover certain cases.

class Linear(Module):
    kernel: Optional[jnp.ndarray] = Parameter.node()
    bias: Optional[jnp.ndarray] = Parameter.node()
    
    def __init__(self, features: int):
        self.features = features
        self.kernel = None
        self.bias = None
        
    def __call__(self, x):
        if not self.initialized:
            features_in = x.shape[-1]
            self.kernel = jax.random.uniform(tx.next_key(), shape=(features_in, self.features))
            self.bias = jnp.zeros(shape=(self.features,))
        
        return jnp.dot(x, self.kernel) + self.bias
        
class MLP(Module):
    @tx.compact
    def __call__(self, x):
        x = Linear(32)(x)
        x = jax.nn.relu(x)
        x = Linear(10)(x)
        return x

module = MLP().init(42, inputs=X_train[:32])

Notice that the definition of Linear doesn't change, but MLP gets shorter.

3. Shape Inference via Hooks

One last possibility is to not implement shape inference at all and instead leverage the fact that hooks have access to the data:

class MLP(Module):
    @tx.compact
    def __call__(self, x):
        x = tx.Linear(x.shape[-1], 32)(x)
        x = tx.Linear(32, 10)(x)
        return x
        
module = MLP().init(42, inputs=X_train[:32])

The nice thing is that you can still support data-independent initialization (the current behavior) if you wish, and you get more shape guarantees since you have to be explicit; the downside is that it's more verbose.
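
For comparison, a rough sketch of the data-independent initialization mentioned above, assuming the positional tx.Linear(features_in, features_out) constructor used in option 3 and that .init(42) works without sample inputs:

import jax.numpy as jnp
import treex as tx

# all shapes are specified up front, so no sample input is needed
model = tx.Linear(3, 32).init(42)
y = model(jnp.ones((64, 3)))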

Discussion

It seems there are 2 independent choices: whether to implement shape inference and whether to implement hooks. I definitely like the idea of hooks since it makes writing composite modules simpler; however, I did this poll and people do like PyTorch's explicit API, which option 3 keeps.

WDYT?
