Skip to content

Commit

Permalink
fix jaxpr util test under enable_x64
Browse files Browse the repository at this point in the history
  • Loading branch information
froystig committed Aug 19, 2020
1 parent b892236 commit 5135fd1
Showing 1 changed file with 7 additions and 6 deletions.
13 changes: 7 additions & 6 deletions tests/jaxpr_util_test.py
Expand Up @@ -57,13 +57,14 @@ def sub(x, y):

hist = jaxpr_util.primitives_by_shape(make_jaxpr(f)(1., 1.).jaxpr)

t = '64' if FLAGS.jax_enable_x64 else '32'
shapes = [
'add :: float32[]',
'sin :: float32[]',
'cos :: float32[]',
'reduce_sum :: float32[]',
'concatenate :: float32[2]',
'xla_call :: float32[] *',
f'add :: float{t}[]',
f'sin :: float{t}[]',
f'cos :: float{t}[]',
f'reduce_sum :: float{t}[]',
f'concatenate :: float{t}[2]',
f'xla_call :: float{t}[] *',
]
for k in shapes:
self.assertEqual(hist[k], 1)
Expand Down

0 comments on commit 5135fd1

Please sign in to comment.