Skip to content

Commit

Permalink
jax.nn.glu: fix static argname issue
Browse files Browse the repository at this point in the history
  • Loading branch information
jakevdp committed Mar 17, 2022
1 parent c3a4a6e commit c762e07
Show file tree
Hide file tree
Showing 2 changed files with 2 additions and 2 deletions.
2 changes: 1 addition & 1 deletion jax/_src/nn/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -258,7 +258,7 @@ def gelu(x: Array, approximate: bool = True) -> Array:
else:
return jnp.array(x * (lax.erf(x / np.sqrt(2)) + 1) / 2, dtype=x.dtype)

@partial(jax.jit, static_argnames=("glu",))
@partial(jax.jit, static_argnames=("axis",))
def glu(x: Array, axis: int = -1) -> Array:
"""Gated linear unit activation function.
Expand Down
2 changes: 1 addition & 1 deletion tests/nn_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ def testEluValue(self):
self.assertAllClose(val, 1e4, check_dtypes=False)

def testGluValue(self):
val = nn.glu(jnp.array([1.0, 0.0]))
val = nn.glu(jnp.array([1.0, 0.0]), axis=0)
self.assertAllClose(val, jnp.array([0.5]))

@parameterized.parameters(False, True)
Expand Down

0 comments on commit c762e07

Please sign in to comment.