Skip to content

Commit

Permalink
Deflake tests
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 538511611
  • Loading branch information
romanngg authored and jaehlee committed Aug 2, 2023
1 parent 0189579 commit 1e82970
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 4 deletions.
2 changes: 1 addition & 1 deletion tests/experimental/empirical_tf_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,7 +228,7 @@ def _compare_ntks(

x1_jax = np.array(x1)
x2_jax = np.array(x2)
params_jax = jax.tree_map(lambda x: np.array(x), params)
params_jax = jax.tree_map(np.array, params)

jax_ntks = [ntk_fn_i(x1_jax, x2_jax, params_jax)
for ntk_fn_i in jax_ntk_fns]
Expand Down
6 changes: 3 additions & 3 deletions tests/rules_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,9 +61,9 @@
(),
# (0,),
(2,),
(0, 0),
(1, 0),
(0, 1),
# (0, 0),
# (1, 0),
# (0, 1),
# (1, 1),
# (2, 1),
(1, 2),
Expand Down

0 comments on commit 1e82970

Please sign in to comment.