From 0c517f2d8372a351f139146773150528ffac7ee7 Mon Sep 17 00:00:00 2001 From: Shuhan Ding Date: Thu, 2 May 2024 17:18:55 -0700 Subject: [PATCH 1/4] update along with lax_numpy_test --- tests/lax_metal_test.py | 23 +++++++---------------- 1 file changed, 7 insertions(+), 16 deletions(-) diff --git a/tests/lax_metal_test.py b/tests/lax_metal_test.py index 87ca6eaee322..5c82a6d300a9 100644 --- a/tests/lax_metal_test.py +++ b/tests/lax_metal_test.py @@ -302,13 +302,11 @@ def testCountNonzero(self, shape, dtype, axis): self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=False) self._CompileAndCheck(jnp_fun, args_maker) - @jtu.sample_product(shape=all_shapes, dtype=all_dtypes) + @jtu.sample_product(shape=nonzerodim_shapes, dtype=all_dtypes) def testNonzero(self, shape, dtype): rng = jtu.rand_some_zero(self.rng()) args_maker = lambda: [rng(shape, dtype)] - with jtu.ignore_warning(category=DeprecationWarning, - message="Calling nonzero on 0d arrays.*"): - self._CheckAgainstNumpy(np.nonzero, jnp.nonzero, args_maker, check_dtypes=False) + self._CheckAgainstNumpy(np.nonzero, jnp.nonzero, args_maker, check_dtypes=False) @jtu.sample_product( [dict(shape=shape, fill_value=fill_value) @@ -370,20 +368,16 @@ def np_fun(x): self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=False) self._CompileAndCheck(jnp_fun, args_maker) - @jtu.sample_product(shape=all_shapes, dtype=all_dtypes) + @jtu.sample_product(shape=nonzerodim_shapes, dtype=all_dtypes) def testArgWhere(self, shape, dtype): rng = jtu.rand_some_zero(self.rng()) args_maker = lambda: [rng(shape, dtype)] - with jtu.ignore_warning(category=DeprecationWarning, - message="Calling nonzero on 0d arrays.*"): - self._CheckAgainstNumpy(np.argwhere, jnp.argwhere, args_maker, check_dtypes=False) + self._CheckAgainstNumpy(np.argwhere, jnp.argwhere, args_maker, check_dtypes=False) # JIT compilation requires specifying a size statically. Full test of this # behavior is in testNonzeroSize(). jnp_fun = lambda x: jnp.argwhere(x, size=np.size(x) // 2) - with jtu.ignore_warning(category=DeprecationWarning, - message="Calling nonzero on 0d arrays.*"): - self._CompileAndCheck(jnp_fun, args_maker) + self._CompileAndCheck(jnp_fun, args_maker) @jtu.sample_product( [dict(shape=shape, fill_value=fill_value) @@ -4431,15 +4425,12 @@ def args_maker(): return [] self._CompileAndCheck(jnp_fun, args_maker) @jtu.sample_product( - shape=all_shapes, dtype=all_dtypes, + shape=nonzerodim_shapes, dtype=all_dtypes, ) def testWhereOneArgument(self, shape, dtype): rng = jtu.rand_some_zero(self.rng()) args_maker = lambda: [rng(shape, dtype)] - - with jtu.ignore_warning(category=DeprecationWarning, - message="Calling nonzero on 0d arrays.*"): - self._CheckAgainstNumpy(np.where, jnp.where, args_maker, check_dtypes=False) + self._CheckAgainstNumpy(np.where, jnp.where, args_maker, check_dtypes=False) # JIT compilation requires specifying a size statically. Full test of # this behavior is in testNonzeroSize(). From 99e5b8e9992a610f76ce88f6d255c717f12563ba Mon Sep 17 00:00:00 2001 From: Shuhan Ding Date: Thu, 2 May 2024 18:20:39 -0700 Subject: [PATCH 2/4] enable conv1d test --- tests/lax_metal_test.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/tests/lax_metal_test.py b/tests/lax_metal_test.py index 5c82a6d300a9..48e036966f21 100644 --- a/tests/lax_metal_test.py +++ b/tests/lax_metal_test.py @@ -2049,7 +2049,6 @@ def attempt_sideeffect(x): self.assertAllClose(np_input, expected_np_input_after_call) self.assertAllClose(jnp_input, expected_jnp_input_after_call) - @unittest.skip("Jax-metal fail to convert 1D convolution op.") @jtu.sample_product( mode=['full', 'same', 'valid'], op=['convolve', 'correlate'], @@ -2071,7 +2070,6 @@ def np_fun(x, y): self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=True, tol=tol) self._CompileAndCheck(jnp_fun, args_maker) - @unittest.skip("Jax-metal fail to convert 1D convolution op.") @jtu.sample_product( mode=['full', 'same', 'valid'], op=['convolve', 'correlate'], From 28fe45d8727a5061117a038285acc28a2b5559b6 Mon Sep 17 00:00:00 2001 From: Shuhan Ding Date: Thu, 2 May 2024 18:40:01 -0700 Subject: [PATCH 3/4] metal_plugin ci with jaxlib nightly --- .github/workflows/metal_plugin_ci.yml | 5 +++-- tests/lax_metal_test.py | 2 +- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/.github/workflows/metal_plugin_ci.yml b/.github/workflows/metal_plugin_ci.yml index a62b0a1b6eb2..75ae8d7e80a9 100644 --- a/.github/workflows/metal_plugin_ci.yml +++ b/.github/workflows/metal_plugin_ci.yml @@ -17,7 +17,7 @@ jobs: strategy: fail-fast: false # don't cancel all jobs on failure matrix: - jaxlib-version: ["plugin_latest"] + jaxlib-version: ["plugin_latest", "nightly"] name: "Jax-Metal plugin test (jaxlib=${{ matrix.jaxlib-version }})" runs-on: [self-hosted, macOS, ARM64] @@ -32,13 +32,14 @@ jobs: python3 -m venv ${GITHUB_WORKSPACE}/jax-metal-venv source ${GITHUB_WORKSPACE}/jax-metal-venv/bin/activate pip install -U pip numpy wheel - pip install jax-metal absl-py pytest + pip install absl-py pytest if [[ "${{ matrix.jaxlib-version }}" == "nightly" ]]; then pip install --pre jaxlib \ -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html fi; cd jax pip install . + pip install jax-metal - name: Run test run: | source ${GITHUB_WORKSPACE}/jax-metal-venv/bin/activate diff --git a/tests/lax_metal_test.py b/tests/lax_metal_test.py index 48e036966f21..5069187d2334 100644 --- a/tests/lax_metal_test.py +++ b/tests/lax_metal_test.py @@ -5713,7 +5713,7 @@ def test_gather_ir(self): #loc = loc(unknown) module @jit_gather attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} { func.func public @main(%arg0: tensor<3x2x3xf32> {mhlo.sharding = "{replicated}"} loc(unknown), %arg1: tensor<3x2xi32> {mhlo.sharding = "{replicated}"} loc(unknown)) -> tensor<3x2xf32> { - %0 = "stablehlo.gather"(%arg0, %arg1) {dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = dense<[1, 2, 1]> : tensor<3xi64>} : (tensor<3x2x3xf32>, tensor<3x2xi32>) -> tensor<3x2xf32> loc(#loc2) + %0 = "stablehlo.gather"(%arg0, %arg1) {dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array} : (tensor<3x2x3xf32>, tensor<3x2xi32>) -> tensor<3x2xf32> loc(#loc2) return %0 : tensor<3x2xf32> loc(#loc) } loc(#loc) } loc(#loc) From aac36799fde8379847f89b68b6eccb13449cce27 Mon Sep 17 00:00:00 2001 From: Shuhan Ding Date: Mon, 6 May 2024 13:51:22 -0700 Subject: [PATCH 4/4] fix jaxlib config name --- .github/workflows/metal_plugin_ci.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/metal_plugin_ci.yml b/.github/workflows/metal_plugin_ci.yml index 75ae8d7e80a9..6e67841460e9 100644 --- a/.github/workflows/metal_plugin_ci.yml +++ b/.github/workflows/metal_plugin_ci.yml @@ -17,7 +17,7 @@ jobs: strategy: fail-fast: false # don't cancel all jobs on failure matrix: - jaxlib-version: ["plugin_latest", "nightly"] + jaxlib-version: ["pypi_latest", "nightly"] name: "Jax-Metal plugin test (jaxlib=${{ matrix.jaxlib-version }})" runs-on: [self-hosted, macOS, ARM64]