Lazy layer / Shape inference? #15
@lkhphuc: Would it be possible to support PyTorch's lazy layers, i.e. shape inference based on the input? One possible solution is to provide a sample input to the `init` method:

```python
module.init(rng=42, input=jnp.ones([64, 64, 3]))
```
Comments
Hey @lkhphuc! This question comes at a great time; I was thinking about this topic yesterday. There are 2 related topics here, shape inference and hooks, and a few ways to combine them:
### 1. Shape Inference only

Here we check whether the module is not yet initialized inside `__call__`, and create the parameters lazily from the input's shape:

```python
from typing import Optional

import jax
import jax.numpy as jnp
import treex as tx
from treex import Module, Parameter


class Linear(Module):
    kernel: Optional[jnp.ndarray] = Parameter.node()
    bias: Optional[jnp.ndarray] = Parameter.node()

    def __init__(self, features: int):
        self.features = features
        self.kernel = None
        self.bias = None

    def __call__(self, x):
        if not self.initialized:
            # first call: infer the input features from the data
            features_in = x.shape[-1]
            self.kernel = jax.random.uniform(tx.next_key(), shape=(features_in, self.features))
            self.bias = jnp.zeros(shape=(self.features,))
        return jnp.dot(x, self.kernel) + self.bias

class MLP(Module):
    def __init__(self):
        self.linear1 = Linear(32)
        self.linear2 = Linear(10)

    def __call__(self, x):
        x = self.linear1(x)
        x = jax.nn.relu(x)
        x = self.linear2(x)
        return x

module = MLP().init(42, inputs=X_train[:32])  # X_train: a sample batch of real data
```

The trick is that `init` receives sample `inputs` and runs the module once, so every layer gets to compute its parameter shapes from actual data before training starts.
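To make the mechanics concrete, here is a minimal, framework-free sketch of what such an `init` could do (every name here is illustrative, not Treex API): seed a key, call the module once on the sample batch so the parameters materialize with data-dependent shapes, then mark the module as initialized.

```python
import jax
import jax.numpy as jnp


class LazyLinear:
    """Illustrative only: a self-contained lazy layer, not a Treex Module."""

    def __init__(self, features: int):
        self.features = features
        self.kernel = None
        self.bias = None
        self.initialized = False

    def init(self, seed, inputs):
        self.key = jax.random.PRNGKey(seed)
        self(inputs)  # this first call creates the parameters
        self.initialized = True
        return self

    def __call__(self, x):
        if not self.initialized:
            features_in = x.shape[-1]  # the shape comes from the data itself
            self.kernel = jax.random.uniform(self.key, (features_in, self.features))
            self.bias = jnp.zeros((self.features,))
        return x @ self.kernel + self.bias


layer = LazyLinear(32).init(seed=42, inputs=jnp.ones((8, 16)))
assert layer.kernel.shape == (16, 32)  # inferred from the sample batch
```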
### 2. Shape Inference + Hooks

Adding hooks is very possible; in fact, Treeo already supports Compact Trees, so Treex just needs to leverage this and expand it to cover certain cases.

```python
class Linear(Module):
    kernel: Optional[jnp.ndarray] = Parameter.node()
    bias: Optional[jnp.ndarray] = Parameter.node()

    def __init__(self, features: int):
        self.features = features
        self.kernel = None
        self.bias = None

    def __call__(self, x):
        if not self.initialized:
            features_in = x.shape[-1]
            self.kernel = jax.random.uniform(tx.next_key(), shape=(features_in, self.features))
            self.bias = jnp.zeros(shape=(self.features,))
        return jnp.dot(x, self.kernel) + self.bias

class MLP(Module):
    @tx.compact
    def __call__(self, x):
        x = Linear(32)(x)
        x = jax.nn.relu(x)
        x = Linear(10)(x)
        return x

module = MLP().init(42, inputs=X_train[:32])
```

Notice that the definition of `Linear` is unchanged, while `MLP` no longer needs an `__init__`: with `@tx.compact`, submodules are declared inline inside `__call__` and created on the first call.
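For intuition, here is a toy, framework-free sketch of the core trick behind a compact-style hook (this is not Treeo's actual Compact Trees implementation; all names are made up): submodules created inline are recorded in call order on the first call and looked up by position on later calls, so their parameters persist even though the code looks like it re-creates the layers every time.

```python
import jax
import jax.numpy as jnp


class ToyLinear:
    def __init__(self, din, dout, key=jax.random.PRNGKey(0)):
        self.kernel = jax.random.uniform(key, (din, dout))
        self.bias = jnp.zeros((dout,))

    def __call__(self, x):
        return x @ self.kernel + self.bias


class CompactBase:
    """Toy stand-in for a compact-aware base class."""

    def __init__(self):
        self._submodules = []  # filled in call order on the first call
        self._cursor = 0

    def submodule(self, factory):
        if self._cursor == len(self._submodules):
            self._submodules.append(factory())  # first call: create and record
        module = self._submodules[self._cursor]  # later calls: reuse by position
        self._cursor += 1
        return module

    def __call__(self, x):
        self._cursor = 0  # replay the recorded submodules from the start
        return self.forward(x)


class ToyMLP(CompactBase):
    def forward(self, x):
        x = self.submodule(lambda: ToyLinear(x.shape[-1], 32))(x)
        x = jax.nn.relu(x)
        return self.submodule(lambda: ToyLinear(32, 10))(x)


mlp = ToyMLP()
y1 = mlp(jnp.ones((8, 16)))  # first call creates both layers
y2 = mlp(jnp.ones((8, 16)))  # second call reuses the same parameters
```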
### 3. Shape Inference via Hooks

One last possibility is not implementing shape inference at all, and instead leveraging the fact that hooks give you access to the data:

```python
class MLP(Module):
    @tx.compact
    def __call__(self, x):
        x = tx.Linear(x.shape[-1], 32)(x)
        x = tx.Linear(32, 10)(x)
        return x


module = MLP().init(42, inputs=X_train[:32])
```
The nice thing is that you can still support data-independent initialization (the current behavior) if you wish, and you get stronger shape guarantees since everything is explicit; the downside is that it's more verbose.
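For instance, assuming `tx.Linear` keeps its explicit `(features_in, features_out)` constructor and `init` keeps accepting a bare seed, data-independent initialization would still look like it does today:

```python
# Hypothetical: no sample batch needed because all shapes are spelled out.
layer = tx.Linear(3, 32).init(42)
y = layer(jnp.ones((8, 3)))
```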
### Discussion

There seem to be 2 independent choices here: implementing shape inference, and implementing hooks. I definitely like the idea of hooks since it makes writing composite modules simpler; however, I did a poll and people do like PyTorch's explicit API. WDYT?