Skip to content

Commit

Permalink
Fix type mismatch in jet rule for abs (#3807)
Browse files Browse the repository at this point in the history
  • Loading branch information
jakevdp committed Jul 21, 2020
1 parent a6e2d20 commit 71f80a5
Showing 1 changed file with 2 additions and 1 deletion.
3 changes: 2 additions & 1 deletion jax/experimental/jet.py
Original file line number Diff line number Diff line change
Expand Up @@ -502,8 +502,9 @@ def _reduce_chooser_taylor_rule(g):

def _abs_taylor_rule(x, series_in, **params):
x, = x
zero = lax.full_like(x, 0, shape=())
primal_out = lax.abs_p.bind(x, **params)
negs = lax.select(lax.lt(x, 0.0), lax.full_like(x, -1), lax.full_like(x, 1.0))
negs = lax.select(lax.lt(x, zero), lax.full_like(x, -1), lax.full_like(x, 1.0))
fix_sign = lambda y: negs * y
series_out = [fix_sign(*terms_in, **params) for terms_in in zip(*series_in)]
return primal_out, series_out
Expand Down

0 comments on commit 71f80a5

Please sign in to comment.