Skip to content

Commit

Permalink
Replace deprecated jax.tree_* functions with jax.tree.*
Browse files Browse the repository at this point in the history
The top-level `jax.tree_*` aliases have long been deprecated, and will soon be removed. Alternate APIs are in `jax.tree_util`, with shorter aliases in the `jax.tree` submodule, added in JAX version 0.4.25.

PiperOrigin-RevId: 634531740
  • Loading branch information
Jake VanderPlas authored and The bayeux Authors committed May 16, 2024
1 parent 40d1672 commit fb1861b
Show file tree
Hide file tree
Showing 5 changed files with 11 additions and 11 deletions.
2 changes: 1 addition & 1 deletion bayeux/_src/shared.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ def map_fn(chain_method, fn):
elif chain_method == "vectorized":
return jax.vmap(fn)
elif chain_method == "sequential":
return functools.partial(jax.tree_map, fn)
return functools.partial(jax.tree.map, fn)
raise ValueError(f"Chain method {chain_method} not supported.")


Expand Down
14 changes: 7 additions & 7 deletions bayeux/_src/vi/tfp.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ class Custom(tfb.Bijector):

def __init__(self, bx_model):
super().__init__(
forward_min_event_ndims=jax.tree_map(jnp.ndim, bx_model.test_point))
forward_min_event_ndims=jax.tree.map(jnp.ndim, bx_model.test_point))
self.bx_model = bx_model

def _forward(self, x):
Expand All @@ -46,12 +46,12 @@ def _forward_log_det_jacobian(self, x):
return -self.inverse_log_det_jacobian(self.forward(x))

def _forward_event_shape_tensor(self, input_shape):
return jax.tree_map(jnp.shape,
self._forward(jax.tree_map(jnp.ones, input_shape)))
return jax.tree.map(jnp.shape,
self._forward(jax.tree.map(jnp.ones, input_shape)))

def _inverse_event_shape_tensor(self, output_shape):
return jax.tree_map(jnp.shape,
self._inverse(jax.tree_map(jnp.ones, output_shape)))
return jax.tree.map(jnp.shape,
self._inverse(jax.tree.map(jnp.ones, output_shape)))


def get_fit_kwargs(log_density, kwargs):
Expand Down Expand Up @@ -104,7 +104,7 @@ def get_kwargs(self, **kwargs):
return {
tfp.experimental.vi.build_factored_surrogate_posterior_stateless: (
get_build_kwargs(
jax.tree_map(jnp.shape, self.test_point),
jax.tree.map(jnp.shape, self.test_point),
self.constraining_bijector(),
kwargs)),
tfp.vi.fit_surrogate_posterior_stateless: get_fit_kwargs(
Expand Down Expand Up @@ -140,7 +140,7 @@ def __call__(self, seed, **kwargs):
elif chain_method == "parallel":
mapped_fit = jax.pmap(fit_fn)
elif chain_method == "sequential":
mapped_fit = functools.partial(jax.tree_map, fit_fn)
mapped_fit = functools.partial(jax.tree.map, fit_fn)
else:
raise ValueError(f"Chain method {chain_method} not supported.")

Expand Down
2 changes: 1 addition & 1 deletion docs/examples/numpyro_and_bayeux.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -637,7 +637,7 @@
"ax.plot(losses.T)\n",
"\n",
"draws = surrogate_posterior.sample(100, seed=jax.random.PRNGKey(1))\n",
"jax.tree_map(lambda x: np.mean(x, axis=(0, 1)), draws)"
"jax.tree.map(lambda x: np.mean(x, axis=(0, 1)), draws)"
]
}
],
Expand Down
2 changes: 1 addition & 1 deletion docs/examples/pymc_and_bayeux.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -425,7 +425,7 @@
"ax.plot(losses.T)\n",
"\n",
"draws = surrogate_posterior.sample(100, seed=jax.random.PRNGKey(1))\n",
"jax.tree_map(lambda x: np.mean(x, axis=(0, 1)), draws)"
"jax.tree.map(lambda x: np.mean(x, axis=(0, 1)), draws)"
]
}
],
Expand Down
2 changes: 1 addition & 1 deletion docs/examples/tfp_and_bayeux.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -631,7 +631,7 @@
"ax.plot(losses.T)\n",
"\n",
"draws = surrogate_posterior.sample(100, seed=draw_key)\n",
"jax.tree_map(lambda x: np.mean(x, axis=(0, 1)), draws)"
"jax.tree.map(lambda x: np.mean(x, axis=(0, 1)), draws)"
]
}
],
Expand Down

0 comments on commit fb1861b

Please sign in to comment.