diff --git a/flax/linen/linear.py b/flax/linen/linear.py
index 04677e6cf..f5dbfd35a 100644
--- a/flax/linen/linear.py
+++ b/flax/linen/linear.py
@@ -37,7 +37,7 @@ def _normalize_axes(axes, ndim):
   # A tuple by convention. len(axes_tuple) then also gives the rank efficiently.
-  return tuple([ax if ax >= 0 else ndim + ax for ax in axes])
+  return tuple(sorted([ax if ax >= 0 else ndim + ax for ax in axes]))
 
 
 def _canonicalize_tuple(x):
@@ -107,6 +107,10 @@ def kernel_init_wrap(rng, shape, dtype=jnp.float32):
       return jnp.reshape(kernel, shape)
 
     batch_shape = tuple([inputs.shape[ax] for ax in batch_dims])
+    # batch and non-contracting dims of input with 1s for batch dims.
+    expanded_batch_shape = tuple(
+        inputs.shape[ax] if ax in batch_dims else 1
+        for ax in range(inputs.ndim) if ax not in axis)
     kernel_shape = tuple([inputs.shape[ax] for ax in axis]) + features
     kernel = self.param('kernel', kernel_init_wrap, batch_shape + kernel_shape)
     kernel = jnp.asarray(kernel, self.dtype)
@@ -117,6 +121,7 @@ def kernel_init_wrap(rng, shape, dtype=jnp.float32):
                           kernel,
                           ((axis, contract_ind), (batch_dims, batch_ind)),
                           precision=self.precision)
+    # dot_general output has shape [batch_dims/group_dims] + [feature_dims]
     if self.use_bias:
       def bias_init_wrap(rng, shape, dtype=jnp.float32):
         size_batch_dims = np.prod(shape[:n_batch_dims], dtype=np.int32)
@@ -126,12 +131,8 @@ def bias_init_wrap(rng, shape, dtype=jnp.float32):
         return jnp.reshape(bias, shape)
 
       bias = self.param('bias', bias_init_wrap, batch_shape + features)
-
-      # Reshape bias for broadcast.
-      expand_dims = sorted(
-          set(range(inputs.ndim)) - set(axis) - set(batch_dims))
-      for ax in expand_dims:
-        bias = jnp.expand_dims(bias, ax)
+      # expand bias shape to broadcast bias over batch dims.
+      bias = jnp.reshape(bias, expanded_batch_shape + features)
       bias = jnp.asarray(bias, self.dtype)
       out = out + bias
     return out
diff --git a/tests/linen/linen_linear_test.py b/tests/linen/linen_linear_test.py
index 6fd504d6c..790abed04 100644
--- a/tests/linen/linen_linear_test.py
+++ b/tests/linen/linen_linear_test.py
@@ -144,7 +144,7 @@ def _counter_init(rng, shape, dtype, state):
     np.testing.assert_allclose(y, target)
 
   @parameterized.parameters([((-2, 3), (), 'bijk,jklm->bilm'),
-                             ((3, -2), (), 'bijk,kjlm->bilm'),
+                             ((3, -2), (), 'bijk,jklm->bilm'),
                              ((-2, 3), (0,), 'bijk,bjklm->bilm')])
   def test_dense_general_vs_numpy(self, axis, batch_dims, einsum_expr):
     rng = dict(params=random.PRNGKey(0))
@@ -271,6 +271,32 @@ def test_embed(self):
     np.testing.assert_allclose(y, dummy_embedding[None])
     z = embed_module.apply(initial_params, jnp.ones((3,)), method=embed_module.attend)
     np.testing.assert_allclose(z, 3. * jnp.arange(4))
+
+  def test_non_final_axis(self):
+    class Foo(nn.Module):
+      @nn.compact
+      def __call__(self, x):
+        return nn.DenseGeneral(features=6, axis=1, name='dense')(x)
+
+    x = jnp.ones((2, 4, 8))
+    y, variables = Foo().init_with_output(random.PRNGKey(0), x)
+    self.assertEqual(jax.tree_map(jnp.shape, variables['params']), {
+        'dense': {'kernel': (4, 6), 'bias': (6,)}
+    })
+    self.assertEqual(y.shape, (2, 8, 6))
+
+  def test_non_final_axes(self):
+    class Foo(nn.Module):
+      @nn.compact
+      def __call__(self, x):
+        return nn.DenseGeneral(features=6, axis=(0, 1), name='dense')(x)
+
+    x = jnp.ones((2, 4, 8))
+    y, variables = Foo().init_with_output(random.PRNGKey(0), x)
+    self.assertEqual(jax.tree_map(jnp.shape, variables['params']), {
+        'dense': {'kernel': (2, 4, 6), 'bias': (6,)}
+    })
+    self.assertEqual(y.shape, (8, 6))
 
 
 if __name__ == '__main__':
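
For reference, here is a minimal standalone sketch of the behavior the patch enables, mirroring the new test_non_final_axis case above: DenseGeneral contracting over a non-final axis keeps the remaining input dimensions in order and broadcasts the bias over them. It assumes flax and jax are installed; the Foo module name and the shapes are taken directly from the test and are purely illustrative.

# Minimal usage sketch (not part of the patch), mirroring test_non_final_axis.
import jax
import jax.numpy as jnp
import flax.linen as nn
from jax import random


class Foo(nn.Module):
  @nn.compact
  def __call__(self, x):
    # Contract over axis 1 (a non-final axis) of the (2, 4, 8) input.
    return nn.DenseGeneral(features=6, axis=1, name='dense')(x)


x = jnp.ones((2, 4, 8))
y, variables = Foo().init_with_output(random.PRNGKey(0), x)
print(jax.tree_map(jnp.shape, variables['params']))
# {'dense': {'bias': (6,), 'kernel': (4, 6)}}
print(y.shape)
# (2, 8, 6): the non-contracted dims (2, 8) are kept, followed by the feature dim;
# the (6,) bias broadcasts over them via the reshape introduced in this patch.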