Skip to content

Commit

Permalink
Merge pull request #400 from hawkinsp/master
Browse files Browse the repository at this point in the history
Support tuples in translation rule for zeros_like_p.
  • Loading branch information
hawkinsp committed Feb 18, 2019
2 parents 99abdf9 + 16123f1 commit 27bedc2
Showing 1 changed file with 9 additions and 4 deletions.
13 changes: 9 additions & 4 deletions jax/interpreters/xla.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,11 +183,16 @@ def translation_rule(p):
subc_a1[1] + a2)
translations[core.identity_p] = lambda c, x: x

# TODO(mattjj): zeros_like and add_jaxvals should handle any jaxval
# TODO(mattjj): add_jaxvals should handle any jaxval
def zeros_like_translation_rule(c, x):
x_shape = c.GetShape(x)
return c.Broadcast(c.Constant(onp.array(0, x_shape.element_type())),
x_shape.dimensions())
def _zeros_like(shape):
if shape.is_tuple():
return c.Tuple(*(_zeros_like(x) for x in shape.tuple_shapes()))
else:
return c.Broadcast(c.Constant(onp.array(0, shape.element_type())),
shape.dimensions())
return _zeros_like(c.GetShape(x))

translations[ad_util.zeros_like_p] = zeros_like_translation_rule
translations[ad_util.add_jaxvals_p] = lambda c, x, y: c.Add(x, y)

Expand Down

0 comments on commit 27bedc2

Please sign in to comment.