diff --git a/tests/jaxpr_util_test.py b/tests/jaxpr_util_test.py index 557e7a4b9247..351524e443f6 100644 --- a/tests/jaxpr_util_test.py +++ b/tests/jaxpr_util_test.py @@ -57,13 +57,14 @@ def sub(x, y): hist = jaxpr_util.primitives_by_shape(make_jaxpr(f)(1., 1.).jaxpr) + t = '64' if FLAGS.jax_enable_x64 else '32' shapes = [ - 'add :: float32[]', - 'sin :: float32[]', - 'cos :: float32[]', - 'reduce_sum :: float32[]', - 'concatenate :: float32[2]', - 'xla_call :: float32[] *', + f'add :: float{t}[]', + f'sin :: float{t}[]', + f'cos :: float{t}[]', + f'reduce_sum :: float{t}[]', + f'concatenate :: float{t}[2]', + f'xla_call :: float{t}[] *', ] for k in shapes: self.assertEqual(hist[k], 1)