diff --git a/tests/nn_test.py b/tests/nn_test.py index ec4428573e5b..f3b61a165f82 100644 --- a/tests/nn_test.py +++ b/tests/nn_test.py @@ -227,7 +227,7 @@ def fwd(): a = jnp.array(1., 'float32') def f(hx, _): - hx = jax.nn.relu(hx + a) + hx = sigmoid(hx + a) return hx, None hx = jnp.array(0., 'float32')