Skip to content

Commit

Permalink
replace accidental use of jax.numpy.min w/ builtin
Browse files Browse the repository at this point in the history
  • Loading branch information
mattjj committed May 1, 2020
1 parent 1cdd8f1 commit 2263899
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion jax/numpy/lax_numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -1044,7 +1044,7 @@ def gradient_along_axis(a, h, axis):
return []
axis = [_canonicalize_axis(i, a.ndim) for i in axis]

if min([s for i, s in enumerate(a.shape) if i in axis]) < 2:
if _min([s for i, s in enumerate(a.shape) if i in axis]) < 2:
raise ValueError("Shape of array too small to calculate "
"a numerical gradient, "
"at least 2 elements are required.")
Expand Down

0 comments on commit 2263899

Please sign in to comment.