[jax2tf] Refactor the backward compatibility tests.
We prepare for having multiple test files, and we want to share
the helper functions.
gnecula committed Jun 19, 2023
1 parent 5c31d3b commit 185abe0
Showing 12 changed files with 222 additions and 194 deletions.
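The commit message's stated goal is to share helper functions across multiple back-compat test files. As a rough sketch of the record each testdata file carries (the class name and layout below are hypothetical, not taken from this diff; only the field names are visible in the hunks that follow):

import dataclasses
import datetime

@dataclasses.dataclass
class CompatTestData:
    # Hypothetical record type; the field names mirror the entries
    # visible in the testdata hunks below. Fields not shown in the
    # diff are elided.
    serialized_date: datetime.date   # when the module was serialized
    inputs: tuple                    # concrete example inputs
    expected_outputs: tuple          # outputs the old module produced
    mlir_module_text: str            # MLIR module text, kept as a raw string

Each test file can then construct such records while the comparison helpers live in one shared module instead of being redefined per file.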
378 changes: 203 additions & 175 deletions jax/experimental/jax2tf/tests/back_compat_test.py

Large diffs are not rendered by default.
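Each of the testdata hunks below makes the same one-line change: mlir_module_text switches from a plain triple-quoted string to a raw one (r"""), so Python preserves backslash sequences inside the MLIR text verbatim instead of interpreting them as escape characters. A minimal illustration of the difference (not part of the diff):

plain = "tensor<3x4xf32> \n"   # Python collapses \n into one newline character
raw   = r"tensor<3x4xf32> \n"  # the backslash and the 'n' stay as two characters
assert len(raw) == len(plain) + 1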

@@ -27,7 +27,7 @@
expected_outputs=(array([[ 6.+0.j, -2.+2.j, -2.+0.j, -2.-2.j],
[22.+0.j, -2.+2.j, -2.+0.j, -2.-2.j],
[38.+0.j, -2.+2.j, -2.+0.j, -2.-2.j]], dtype=complex64),),
-mlir_module_text="""
+mlir_module_text=r"""
module @jit_func {
func.func public @main(%arg0: tensor<3x4xf32> {jax.arg_info = "x", mhlo.sharding = "{replicated}"}) -> (tensor<3x4xcomplex<f32>> {jax.result_info = ""}) {
%0 = call @fft(%arg0) : (tensor<3x4xf32>) -> tensor<3x4xcomplex<f32>>
@@ -31,7 +31,7 @@
[-0.8944271 , -0.18257445, 0.40824816]], dtype=float32), array([[-6.7082043e+00, -8.0498438e+00, -9.3914852e+00],
[ 0.0000000e+00, 1.0954441e+00, 2.1908894e+00],
[ 0.0000000e+00, 0.0000000e+00, 7.1525574e-07]], dtype=float32)),
-mlir_module_text="""
+mlir_module_text=r"""
module @jit__lambda_ {
func.func public @main() -> (tensor<3x3xf32> {jax.result_info = "[0]"}, tensor<3x3xf32> {jax.result_info = "[1]"}) {
%0 = stablehlo.iota dim = 0 : tensor<9xf32>
@@ -113,7 +113,7 @@
2.1908902300206665e+00],
[ 0.0000000000000000e+00, 0.0000000000000000e+00,
-8.8817841970012523e-16]])),
-mlir_module_text="""
+mlir_module_text=r"""
module @jit__lambda_ {
func.func public @main() -> (tensor<3x3xf64> {jax.result_info = "[0]"}, tensor<3x3xf64> {jax.result_info = "[1]"}) {
%0 = stablehlo.iota dim = 0 : tensor<9xf64>
@@ -194,7 +194,7 @@
[ 0.0000000e+00+0.j, 1.0954441e+00+0.j, 2.1908894e+00+0.j],
[ 0.0000000e+00+0.j, 0.0000000e+00+0.j, 7.1525574e-07+0.j]],
dtype=complex64)),
-mlir_module_text="""
+mlir_module_text=r"""
module @jit__lambda_ {
func.func public @main() -> (tensor<3x3xcomplex<f32>> {jax.result_info = "[0]"}, tensor<3x3xcomplex<f32>> {jax.result_info = "[1]"}) {
%0 = stablehlo.iota dim = 0 : tensor<9xcomplex<f32>>
@@ -280,7 +280,7 @@
2.1908902300206665e+00+0.j],
[ 0.0000000000000000e+00+0.j, 0.0000000000000000e+00+0.j,
-8.8817841970012523e-16+0.j]])),
-mlir_module_text="""
+mlir_module_text=r"""
module @jit__lambda_ {
func.func public @main() -> (tensor<3x3xcomplex<f64>> {jax.result_info = "[0]"}, tensor<3x3xcomplex<f64>> {jax.result_info = "[1]"}) {
%0 = stablehlo.iota dim = 0 : tensor<9xcomplex<f64>>
@@ -44,7 +44,7 @@
dtype=float32), array([-2.4598808e+01, -3.3105560e-05, -3.1002426e-05, -1.0103593e-05,
-1.0022322e-05, 4.0141886e-06, 9.5510331e-06, 2.7659882e+02],
dtype=float32)),
-mlir_module_text="""
+mlir_module_text=r"""
module @jit__lambda_ {
func.func public @main() -> (tensor<8x8xf32> {jax.result_info = "[0]"}, tensor<8xf32> {jax.result_info = "[1]"}) {
%0 = stablehlo.iota dim = 0 : tensor<64xf32>
@@ -139,7 +139,7 @@
-1.9932120610662194e-14, -5.7323356091157378e-15,
-4.5459724251334835e-16, 4.0479851042511616e-14,
9.2325194924982089e-14, 2.7659880477613365e+02])),
-mlir_module_text="""
+mlir_module_text=r"""
module @jit__lambda_ {
func.func public @main() -> (tensor<8x8xf64> {jax.result_info = "[0]"}, tensor<8xf64> {jax.result_info = "[1]"}) {
%0 = stablehlo.iota dim = 0 : tensor<64xf64>
@@ -225,7 +225,7 @@
-0.12186296 +0.j, -0.49314725 +0.j]], dtype=complex64), array([-2.4598808e+01, -3.3105560e-05, -3.1002426e-05, -1.0103593e-05,
-1.0022322e-05, 4.0141886e-06, 9.5510331e-06, 2.7659882e+02],
dtype=float32)),
-mlir_module_text="""
+mlir_module_text=r"""
module @jit__lambda_ {
func.func public @main() -> (tensor<8x8xcomplex<f32>> {jax.result_info = "[0]"}, tensor<8xf32> {jax.result_info = "[1]"}) {
%0 = stablehlo.iota dim = 0 : tensor<64xcomplex<f32>>
@@ -325,7 +325,7 @@
-1.9932120610662194e-14, -5.7323356091157378e-15,
-4.5459724251334835e-16, 4.0479851042511616e-14,
9.2325194924982089e-14, 2.7659880477613365e+02])),
-mlir_module_text="""
+mlir_module_text=r"""
module @jit__lambda_ {
func.func public @main() -> (tensor<8x8xcomplex<f64>> {jax.result_info = "[0]"}, tensor<8xf64> {jax.result_info = "[1]"}) {
%0 = stablehlo.iota dim = 0 : tensor<64xcomplex<f64>>
@@ -31,7 +31,7 @@
[-0.8944271 , -0.18257457, 0.40824813]], dtype=float32), array([[-6.7082043e+00, -8.0498438e+00, -9.3914843e+00],
[ 0.0000000e+00, 1.0954436e+00, 2.1908882e+00],
[ 0.0000000e+00, 0.0000000e+00, 5.6703755e-08]], dtype=float32)),
-mlir_module_text="""
+mlir_module_text=r"""
module @jit__lambda_ {
func.func public @main() -> (tensor<3x3xf32> {jax.result_info = "[0]"}, tensor<3x3xf32> {jax.result_info = "[1]"}) {
%0 = stablehlo.iota dim = 0 : tensor<9xf32>
@@ -110,7 +110,7 @@
[[-2.1213203e+01, -2.2910259e+01, -2.4607319e+01],
[ 0.0000000e+00, 3.4641042e-01, 6.9282258e-01],
[ 0.0000000e+00, 0.0000000e+00, 1.4548683e-06]]], dtype=float32)),
-mlir_module_text="""
+mlir_module_text=r"""
module @jit__lambda_ {
func.func public @main() -> (tensor<2x3x3xf32> {jax.result_info = "[0]"}, tensor<2x3x3xf32> {jax.result_info = "[1]"}) {
%0 = stablehlo.iota dim = 0 : tensor<18xf32>
@@ -548,7 +548,7 @@
1.49946762e-04, 1.86386926e-04, 1.89535742e-04, 2.40968098e-03,
2.56012683e-03, 2.69382820e-03, 3.27441283e-03, 2.52088105e+04],
dtype=float32)),
-mlir_module_text="""
+mlir_module_text=r"""
module @jit__lambda_ {
func.func public @main() -> (tensor<36x36xf32> {jax.result_info = "[0]"}, tensor<36xf32> {jax.result_info = "[1]"}) {
%0 = stablehlo.iota dim = 0 : tensor<1296xf32>
@@ -1360,7 +1360,7 @@
3.0421425292401405e-13, 3.1193691330212636e-13,
3.1270969371399125e-13, 4.3446674157388007e-13,
1.6764394233642590e-12, 2.5208822708003838e+04])),
-mlir_module_text="""
+mlir_module_text=r"""
module @jit__lambda_ {
func.func public @main() -> (tensor<36x36xf64> {jax.result_info = "[0]"}, tensor<36xf64> {jax.result_info = "[1]"}) {
%0 = stablehlo.iota dim = 0 : tensor<1296xf64>
@@ -24,7 +24,7 @@
inputs=(array([42, 43], dtype=uint32),),
expected_outputs=(array([[0.42591238, 0.0769949 , 0.44370103, 0.72904015],
[0.17879379, 0.81439507, 0.00191903, 0.68608475]], dtype=float32),),
-mlir_module_text="""
+mlir_module_text=r"""
module @jit_func {
func.func public @main(%arg0: tensor<2xui32> {jax.arg_info = "x", mhlo.sharding = "{replicated}"}) -> (tensor<2x4xf32> {jax.result_info = ""}) {
%0 = stablehlo.constant dense<0.000000e+00> : tensor<f32>
@@ -25,7 +25,7 @@
serialized_date=datetime.date(2023, 4, 17),
inputs=(),
expected_outputs=(array([7., 6., 5.], dtype=float32), array([6, 5, 4], dtype=int32)),
-mlir_module_text="""
+mlir_module_text=r"""
#loc = loc(unknown)
module @jit__lambda_ attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
func.func @top_k_gt_comparator(%arg0: tensor<f32> loc(unknown), %arg1: tensor<f32> loc(unknown), %arg2: tensor<i32> loc(unknown), %arg3: tensor<i32> loc(unknown)) -> tensor<i1> {
@@ -42,7 +42,7 @@
0.12771341, -0.6378056 , 0.4931458 ]], dtype=float32), array([-2.4598616e+01, -1.1325381e-03, -1.2342700e-04, 2.9237286e-05,
5.4759425e-05, 3.0579782e-04, 5.1378174e-04, 2.7659894e+02],
dtype=float32)),
-mlir_module_text="""
+mlir_module_text=r"""
module @jit__lambda_ {
func.func public @main() -> (tensor<8x8xf32> {jax.result_info = "[0]"}, tensor<8xf32> {jax.result_info = "[1]"}) {
%0 = stablehlo.iota dim = 0 : tensor<64xf32>
@@ -27,7 +27,7 @@
expected_outputs=(array([[6. , 7. , 8. ],
[0. , 1. , 2. ],
[0.5, 0.5, 0. ]], dtype=float32), array([2, 2, 2], dtype=int32), array([2, 0, 1], dtype=int32)),
-mlir_module_text="""
+mlir_module_text=r"""
module @jit__lambda_ {
func.func public @main() -> (tensor<3x3xf32> {jax.result_info = "[0]"}, tensor<3xi32> {jax.result_info = "[1]"}, tensor<3xi32> {jax.result_info = "[2]"}) {
%0 = stablehlo.iota dim = 0 : tensor<9xf32>
@@ -29,7 +29,7 @@
[-0.8944271 , -0.18257444, 0.4082482 ]], dtype=float32), array([[-6.7082043, -8.049844 , -9.391484 ],
[ 0. , 1.0954441, 2.1908882],
[ 0. , 0. , 0. ]], dtype=float32)),
-mlir_module_text="""
+mlir_module_text=r"""
module @jit__lambda_ {
func.func public @main() -> (tensor<3x3xf32> {jax.result_info = "[0]"}, tensor<3x3xf32> {jax.result_info = "[1]"}) {
%0 = stablehlo.iota dim = 0 : tensor<9xf32>
@@ -25,7 +25,7 @@
[4., 5., 6., 7.]], dtype=float32),),
expected_outputs=(array([[4., 5., 6., 7.],
[0., 1., 2., 3.]], dtype=float32),),
-mlir_module_text="""
+mlir_module_text=r"""
module @jit_wrapped {
func.func public @main(%arg0: tensor<2x4xf32> {jax.arg_info = "args[0]", mhlo.sharding = "{replicated}"}) -> (tensor<2x4xf32> {jax.result_info = ""}) {
%0 = call @wrapped(%arg0) : (tensor<2x4xf32>) -> tensor<2x4xf32>
