Skip to content

Commit

Permalink
support axis argument in nn.glu (#2879)
Browse files Browse the repository at this point in the history
* support axis argument in nn.glu

* also add basic correctness test

* Update nn_test.py
  • Loading branch information
jekbradbury committed May 3, 2020
1 parent 9f7115e commit 1cc6b7d
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 1 deletion.
3 changes: 2 additions & 1 deletion jax/nn/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,7 +189,8 @@ def glu(x, axis=-1):
"""Gated linear unit activation function."""
size = x.shape[axis]
assert size % 2 == 0, "axis size must be divisible by 2"
return x[..., :size // 2] * sigmoid(x[..., size // 2:])
x1, x2 = np.split(x, 2, axis)
return x1 * sigmoid(x2)

# other functions

Expand Down
4 changes: 4 additions & 0 deletions tests/nn_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,10 @@ def testEluGrad(self):
def testEluValue(self):
val = nn.elu(1e4)
self.assertAllClose(val, 1e4, check_dtypes=False)

def testGluValue(self):
val = nn.glu(np.array([1.0, 0.0]))
self.assertAllClose(val, np.array([0.5]), check_dtypes=True)

@parameterized.parameters(*itertools.product(
(np.float32, np.bfloat16, np.float16),
Expand Down

0 comments on commit 1cc6b7d

Please sign in to comment.