Skip to content

Commit

Permalink
Merge pull request #21070 from shuhand0:rel0.0.7
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 631218770
  • Loading branch information
jax authors committed May 6, 2024
2 parents f6d8852 + aac3679 commit eee2783
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 21 deletions.
5 changes: 3 additions & 2 deletions .github/workflows/metal_plugin_ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ jobs:
strategy:
fail-fast: false # don't cancel all jobs on failure
matrix:
jaxlib-version: ["plugin_latest"]
jaxlib-version: ["pypi_latest", "nightly"]
name: "Jax-Metal plugin test (jaxlib=${{ matrix.jaxlib-version }})"
runs-on: [self-hosted, macOS, ARM64]

Expand All @@ -32,13 +32,14 @@ jobs:
python3 -m venv ${GITHUB_WORKSPACE}/jax-metal-venv
source ${GITHUB_WORKSPACE}/jax-metal-venv/bin/activate
pip install -U pip numpy wheel
pip install jax-metal absl-py pytest
pip install absl-py pytest
if [[ "${{ matrix.jaxlib-version }}" == "nightly" ]]; then
pip install --pre jaxlib \
-f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html
fi;
cd jax
pip install .
pip install jax-metal
- name: Run test
run: |
source ${GITHUB_WORKSPACE}/jax-metal-venv/bin/activate
Expand Down
27 changes: 8 additions & 19 deletions tests/lax_metal_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -302,13 +302,11 @@ def testCountNonzero(self, shape, dtype, axis):
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=False)
self._CompileAndCheck(jnp_fun, args_maker)

@jtu.sample_product(shape=all_shapes, dtype=all_dtypes)
@jtu.sample_product(shape=nonzerodim_shapes, dtype=all_dtypes)
def testNonzero(self, shape, dtype):
rng = jtu.rand_some_zero(self.rng())
args_maker = lambda: [rng(shape, dtype)]
with jtu.ignore_warning(category=DeprecationWarning,
message="Calling nonzero on 0d arrays.*"):
self._CheckAgainstNumpy(np.nonzero, jnp.nonzero, args_maker, check_dtypes=False)
self._CheckAgainstNumpy(np.nonzero, jnp.nonzero, args_maker, check_dtypes=False)

@jtu.sample_product(
[dict(shape=shape, fill_value=fill_value)
Expand Down Expand Up @@ -370,20 +368,16 @@ def np_fun(x):
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=False)
self._CompileAndCheck(jnp_fun, args_maker)

@jtu.sample_product(shape=all_shapes, dtype=all_dtypes)
@jtu.sample_product(shape=nonzerodim_shapes, dtype=all_dtypes)
def testArgWhere(self, shape, dtype):
rng = jtu.rand_some_zero(self.rng())
args_maker = lambda: [rng(shape, dtype)]
with jtu.ignore_warning(category=DeprecationWarning,
message="Calling nonzero on 0d arrays.*"):
self._CheckAgainstNumpy(np.argwhere, jnp.argwhere, args_maker, check_dtypes=False)
self._CheckAgainstNumpy(np.argwhere, jnp.argwhere, args_maker, check_dtypes=False)

# JIT compilation requires specifying a size statically. Full test of this
# behavior is in testNonzeroSize().
jnp_fun = lambda x: jnp.argwhere(x, size=np.size(x) // 2)
with jtu.ignore_warning(category=DeprecationWarning,
message="Calling nonzero on 0d arrays.*"):
self._CompileAndCheck(jnp_fun, args_maker)
self._CompileAndCheck(jnp_fun, args_maker)

@jtu.sample_product(
[dict(shape=shape, fill_value=fill_value)
Expand Down Expand Up @@ -2055,7 +2049,6 @@ def attempt_sideeffect(x):
self.assertAllClose(np_input, expected_np_input_after_call)
self.assertAllClose(jnp_input, expected_jnp_input_after_call)

@unittest.skip("Jax-metal fail to convert 1D convolution op.")
@jtu.sample_product(
mode=['full', 'same', 'valid'],
op=['convolve', 'correlate'],
Expand All @@ -2077,7 +2070,6 @@ def np_fun(x, y):
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=True, tol=tol)
self._CompileAndCheck(jnp_fun, args_maker)

@unittest.skip("Jax-metal fail to convert 1D convolution op.")
@jtu.sample_product(
mode=['full', 'same', 'valid'],
op=['convolve', 'correlate'],
Expand Down Expand Up @@ -4431,15 +4423,12 @@ def args_maker(): return []
self._CompileAndCheck(jnp_fun, args_maker)

@jtu.sample_product(
shape=all_shapes, dtype=all_dtypes,
shape=nonzerodim_shapes, dtype=all_dtypes,
)
def testWhereOneArgument(self, shape, dtype):
rng = jtu.rand_some_zero(self.rng())
args_maker = lambda: [rng(shape, dtype)]

with jtu.ignore_warning(category=DeprecationWarning,
message="Calling nonzero on 0d arrays.*"):
self._CheckAgainstNumpy(np.where, jnp.where, args_maker, check_dtypes=False)
self._CheckAgainstNumpy(np.where, jnp.where, args_maker, check_dtypes=False)

# JIT compilation requires specifying a size statically. Full test of
# this behavior is in testNonzeroSize().
Expand Down Expand Up @@ -5724,7 +5713,7 @@ def test_gather_ir(self):
#loc = loc(unknown)
module @jit_gather attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
func.func public @main(%arg0: tensor<3x2x3xf32> {mhlo.sharding = "{replicated}"} loc(unknown), %arg1: tensor<3x2xi32> {mhlo.sharding = "{replicated}"} loc(unknown)) -> tensor<3x2xf32> {
%0 = "stablehlo.gather"(%arg0, %arg1) {dimension_numbers = #stablehlo.gather<offset_dims = [1], collapsed_slice_dims = [0, 2], start_index_map = [0, 2], index_vector_dim = 1>, indices_are_sorted = false, slice_sizes = dense<[1, 2, 1]> : tensor<3xi64>} : (tensor<3x2x3xf32>, tensor<3x2xi32>) -> tensor<3x2xf32> loc(#loc2)
%0 = "stablehlo.gather"(%arg0, %arg1) {dimension_numbers = #stablehlo.gather<offset_dims = [1], collapsed_slice_dims = [0, 2], start_index_map = [0, 2], index_vector_dim = 1>, indices_are_sorted = false, slice_sizes = array<i64: 1, 2, 1>} : (tensor<3x2x3xf32>, tensor<3x2xi32>) -> tensor<3x2xf32> loc(#loc2)
return %0 : tensor<3x2xf32> loc(#loc)
} loc(#loc)
} loc(#loc)
Expand Down

0 comments on commit eee2783

Please sign in to comment.