Skip to content

Commit

Permalink
Replaced jax.lax.select with jnp.where
Browse files Browse the repository at this point in the history
  • Loading branch information
nlsfnr committed Aug 30, 2022
1 parent 260e5ba commit c9b341f
Showing 1 changed file with 4 additions and 4 deletions.
8 changes: 4 additions & 4 deletions jmp/_src/loss_scale.py
Expand Up @@ -143,12 +143,12 @@ def adjust(self, grads_finite: jnp.ndarray) -> "DynamicLossScale":
"""Returns the next state dependent on whether grads are finite."""
assert grads_finite.ndim == 0, "Expected boolean scalar"

first_finite = lambda a, b: jax.lax.select(jnp.isfinite(a).all(), a, b)
loss_scale = jax.lax.select(
first_finite = lambda a, b: jnp.where(jnp.isfinite(a).all(), a, b)
loss_scale = jnp.where(
grads_finite,

# When grads are finite increase loss scale periodically.
jax.lax.select(
jnp.where(
self.counter == (self.period - 1),
first_finite(self.loss_scale * self.factor,
self.loss_scale),
Expand Down Expand Up @@ -188,4 +188,4 @@ def all_finite(tree) -> jnp.ndarray:
def select_tree(pred: jnp.ndarray, a: T, b: T) -> T:
"""Selects a pytree based on the given predicate."""
assert pred.ndim == 0 and pred.dtype == jnp.bool_, "expected boolean scalar"
return jax.tree_map(functools.partial(jax.lax.select, pred), a, b)
return jax.tree_map(functools.partial(jnp.where, pred), a, b)

0 comments on commit c9b341f

Please sign in to comment.