Skip to content

Commit

Permalink
Merge pull request #443 from google/issue442
Browse files Browse the repository at this point in the history
fix batchnorm square vs sqrt error (fixes #442)
  • Loading branch information
mattjj committed Feb 25, 2019
2 parents dd5b2a6 + 7e93bff commit d703470
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion jax/experimental/stax.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,7 @@ def init_fun(input_shape):
def apply_fun(params, x, rng=None):
beta, gamma = params
mean, var = np.mean(x, axis, keepdims=True), fastvar(x, axis, keepdims=True)
z = (x - mean) / (var + epsilon)**2
z = (x - mean) / np.sqrt(var + epsilon)
if center and scale: return gamma * z + beta
if center: return z + beta
if scale: return gamma * z
Expand Down

0 comments on commit d703470

Please sign in to comment.