Skip to content

Commit

Permalink
BEGIN_PUBLIC
Browse files Browse the repository at this point in the history
Use new lax.cond API (#43)
ENND_PUBLIC

PiperOrigin-RevId: 320217575
  • Loading branch information
romanngg committed Jul 8, 2020
1 parent 661a4fd commit 156a4f7
Show file tree
Hide file tree
Showing 11 changed files with 6 additions and 6 deletions.
6 changes: 3 additions & 3 deletions neural_tangents/predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -457,9 +457,9 @@ def predict_fn(
# `fx_test_0`] are evaluated, but also a strictly increasing sequence of
# timesteps, so we always temporarily append an [almost] `0` at the start.
identity = lambda x: x
t0 = lax.cond(t[0] == 0.,
np.full((1,), -1e-24, t.dtype), identity,
np.zeros((1,), t.dtype), identity)
t0 = np.where(t[0] == 0,
np.full((1,), -1e-24, t.dtype),
np.zeros((1,), t.dtype))
t = np.concatenate([t0, t])

# Solve the ODE.
Expand Down
6 changes: 3 additions & 3 deletions neural_tangents/utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -355,10 +355,10 @@ def get_masked_array(x: ArrayOrList,
if mask_constant is None:
mask = None
else:
id_fn = lambda m: m
mask = lax.cond(np.isnan(mask_constant),
np.isnan(x), id_fn,
x == mask_constant, id_fn)
lambda x: np.isnan(x),
lambda x: x == mask_constant,
x)
else:
raise TypeError(x, type(x))

Expand Down
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.

0 comments on commit 156a4f7

Please sign in to comment.