15 changes: 8 additions & 7 deletions flax/linen/linear.py
@@ -37,7 +37,7 @@

def _normalize_axes(axes, ndim):
# A tuple by convention. len(axes_tuple) then also gives the rank efficiently.
- return tuple([ax if ax >= 0 else ndim + ax for ax in axes])
+ return tuple(sorted([ax if ax >= 0 else ndim + ax for ax in axes]))


def _canonicalize_tuple(x):
@@ -107,6 +107,10 @@ def kernel_init_wrap(rng, shape, dtype=jnp.float32):
return jnp.reshape(kernel, shape)

batch_shape = tuple([inputs.shape[ax] for ax in batch_dims])
# batch and non-contracting dims of input with 1s for batch dims.
expanded_batch_shape = tuple(
inputs.shape[ax] if ax in batch_dims else 1
for ax in range(inputs.ndim) if ax not in axis)
kernel_shape = tuple([inputs.shape[ax] for ax in axis]) + features
kernel = self.param('kernel', kernel_init_wrap, batch_shape + kernel_shape)
kernel = jnp.asarray(kernel, self.dtype)
@@ -117,6 +121,7 @@ def kernel_init_wrap(rng, shape, dtype=jnp.float32):
kernel,
((axis, contract_ind), (batch_dims, batch_ind)),
precision=self.precision)
# dot_general output has shape [batch_dims/group_dims] + [feature_dims]
if self.use_bias:
def bias_init_wrap(rng, shape, dtype=jnp.float32):
size_batch_dims = np.prod(shape[:n_batch_dims], dtype=np.int32)
@@ -126,12 +131,8 @@ def bias_init_wrap(rng, shape, dtype=jnp.float32):
return jnp.reshape(bias, shape)

bias = self.param('bias', bias_init_wrap, batch_shape + features)
-
- # Reshape bias for broadcast.
- expand_dims = sorted(
-     set(range(inputs.ndim)) - set(axis) - set(batch_dims))
- for ax in expand_dims:
-   bias = jnp.expand_dims(bias, ax)
+ # expand bias shape to broadcast bias over batch dims.
+ bias = jnp.reshape(bias, expanded_batch_shape + features)
bias = jnp.asarray(bias, self.dtype)
out = out + bias
return out
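To make the new shape bookkeeping concrete, here is a small standalone sketch (illustrative only; the input shape and the hard-coded contracting/batch indices are chosen for this example, not taken from the PR) of how expanded_batch_shape lines the bias up with the dot_general output for axis=(1,), batch_dims=(0,), features=(6,):

import jax.numpy as jnp
from jax import lax

inputs = jnp.ones((2, 4, 8))                   # dim 0 is a batch dim, dim 1 is contracted, dim 2 stays free
axis, batch_dims, features = (1,), (0,), (6,)

batch_shape = tuple(inputs.shape[ax] for ax in batch_dims)          # (2,)
expanded_batch_shape = tuple(
    inputs.shape[ax] if ax in batch_dims else 1
    for ax in range(inputs.ndim) if ax not in axis)                 # (2, 1)
kernel_shape = tuple(inputs.shape[ax] for ax in axis) + features    # (4, 6)

kernel = jnp.zeros(batch_shape + kernel_shape)                      # (2, 4, 6)
bias = jnp.zeros(batch_shape + features)                            # (2, 6)

out = lax.dot_general(inputs, kernel, ((axis, (1,)), (batch_dims, (0,))))
print(out.shape)                                                    # (2, 8, 6): [batch dims] + [free input dims] + [feature dims]
out = out + jnp.reshape(bias, expanded_batch_shape + features)      # (2, 1, 6) broadcasts against (2, 8, 6)
print(out.shape)                                                    # (2, 8, 6)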
28 changes: 27 additions & 1 deletion tests/linen/linen_linear_test.py
@@ -144,7 +144,7 @@ def _counter_init(rng, shape, dtype, state):
np.testing.assert_allclose(y, target)

@parameterized.parameters([((-2, 3), (), 'bijk,jklm->bilm'),
- ((3, -2), (), 'bijk,kjlm->bilm'),
+ ((3, -2), (), 'bijk,jklm->bilm'),
Contributor:
So our test case was actually testing wrong behavior? That's odd 😄

Member Author:
No, this was correct. We are free to pick the order of the kernel dimensions, but I made it invariant with respect to the order in which you pass the axes. I think this is more consistent, and it makes the implementation easier.
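To make that invariance concrete, a quick illustration (ndim=4 corresponds to the test's 4-D 'bijk' input; the snippet is for illustration only): with sorting, both axis orders normalize to the same tuple, so the kernel layout, and therefore the einsum string, no longer depends on the order in which the axes are passed.

def _normalize_axes(axes, ndim):
  return tuple(sorted([ax if ax >= 0 else ndim + ax for ax in axes]))

assert _normalize_axes((3, -2), ndim=4) == (2, 3)   # dims 'j' and 'k' of 'bijk'
assert _normalize_axes((-2, 3), ndim=4) == (2, 3)   # same result regardless of argument order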

((-2, 3), (0,), 'bijk,bjklm->bilm')])
def test_dense_general_vs_numpy(self, axis, batch_dims, einsum_expr):
rng = dict(params=random.PRNGKey(0))
@@ -271,6 +271,32 @@ def test_embed(self):
np.testing.assert_allclose(y, dummy_embedding[None])
z = embed_module.apply(initial_params, jnp.ones((3,)), method=embed_module.attend)
np.testing.assert_allclose(z, 3. * jnp.arange(4))

def test_non_final_axis(self):
class Foo(nn.Module):
@nn.compact
def __call__(self, x):
return nn.DenseGeneral(features=6, axis=1, name='dense')(x)

x = jnp.ones((2, 4, 8))
y, variables = Foo().init_with_output(random.PRNGKey(0), x)
self.assertEqual(jax.tree_map(jnp.shape, variables['params']), {
'dense': {'kernel': (4, 6), 'bias': (6,)}
})
self.assertEqual(y.shape, (2, 8, 6))

def test_non_final_axes(self):
class Foo(nn.Module):
@nn.compact
def __call__(self, x):
return nn.DenseGeneral(features=6, axis=(0, 1), name='dense')(x)

x = jnp.ones((2, 4, 8))
y, variables = Foo().init_with_output(random.PRNGKey(0), x)
self.assertEqual(jax.tree_map(jnp.shape, variables['params']), {
'dense': {'kernel': (2, 4, 6), 'bias': (6,)}
})
self.assertEqual(y.shape, (8, 6))


if __name__ == '__main__':
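For completeness, a minimal end-to-end check of the equivalence exercised by the changed parameterized case above; the input shape and feature sizes here are chosen only for illustration and are not taken from the test:

import jax.numpy as jnp
import numpy as np
from jax import random
import flax.linen as nn

x = jnp.ones((16, 8, 9, 10))                                    # 'bijk'
model = nn.DenseGeneral(features=(3, 4), axis=(3, -2), use_bias=False)
y, variables = model.init_with_output(random.PRNGKey(0), x)
kernel = variables['params']['kernel']                          # (9, 10, 3, 4), i.e. 'jklm'
np.testing.assert_allclose(y, jnp.einsum('bijk,jklm->bilm', x, kernel),
                           rtol=1e-5, atol=1e-5)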