Skip to content

Commit

Permalink
instantiate zeros (#2924)
Browse files Browse the repository at this point in the history
fix dtype

remove TODO
  • Loading branch information
jacobjinkelly committed May 2, 2020
1 parent f8fa589 commit a821e67
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 9 deletions.
4 changes: 3 additions & 1 deletion jax/experimental/jet.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from jax.util import unzip2
from jax import ad_util
from jax.tree_util import (register_pytree_node, tree_structure,
treedef_is_leaf, tree_flatten, tree_unflatten)
treedef_is_leaf, tree_flatten, tree_unflatten, tree_map)
import jax.linear_util as lu
from jax.interpreters import xla
from jax.lax import lax
Expand Down Expand Up @@ -59,6 +59,8 @@ def jet_fun(primals, series):
with core.new_master(JetTrace) as master:
out_primals, out_terms = yield (master, primals, series), {}
del master
out_terms = [tree_map(lambda x: onp.zeros_like(x, dtype=onp.result_type(out_primals[0])), series[0])
if s is zero_series else s for s in out_terms]
yield out_primals, out_terms

@lu.transformation
Expand Down
23 changes: 15 additions & 8 deletions tests/jet_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,10 +61,6 @@ def check_jet(self, fun, primals, series, atol=1e-5, rtol=1e-5,
self.assertAllClose(y, expected_y, atol=atol, rtol=rtol,
check_dtypes=check_dtypes)

# TODO(duvenaud): Lower zero_series to actual zeros automatically.
if terms == zero_series:
terms = tree_map(np.zeros_like, expected_terms)

self.assertAllClose(terms, expected_terms, atol=atol, rtol=rtol,
check_dtypes=check_dtypes)

Expand All @@ -86,10 +82,6 @@ def _convert(x):
self.assertAllClose(y, expected_y, atol=atol, rtol=rtol,
check_dtypes=check_dtypes)

# TODO(duvenaud): Lower zero_series to actual zeros automatically.
if terms == zero_series:
terms = tree_map(np.zeros_like, expected_terms)

self.assertAllClose(terms, expected_terms, atol=atol, rtol=rtol,
check_dtypes=check_dtypes)

Expand Down Expand Up @@ -291,6 +283,21 @@ def test_select(self):
series_in = (terms_b, terms_x, terms_y)
self.check_jet(np.where, primals, series_in)

def test_inst_zero(self):
def f(x):
return 2.
def g(x):
return 2. + 0 * x
x = np.ones(1)
order = 3
f_out_primals, f_out_series = jet(f, (x, ), ([np.ones_like(x) for _ in range(order)], ))
assert f_out_series is not zero_series

g_out_primals, g_out_series = jet(g, (x, ), ([np.ones_like(x) for _ in range(order)], ))

assert g_out_primals == f_out_primals
assert g_out_series == f_out_series


if __name__ == '__main__':
absltest.main()

0 comments on commit a821e67

Please sign in to comment.