[export] Add backwards compatibility test for Pallas call on GPUs.

Note that this adds only a minimal safety net against non-backwards-compatible
changes; we really should have more tests that cover more of the Triton MLIR.
Also enable serialization of such calls.

PiperOrigin-RevId: 630033989
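For context, the round trip this commit pins down is export, serialize, deserialize, run. Below is a minimal sketch of that cycle written against the current stable jax.export API (an assumption on my part: this commit predates that module and used the then-experimental export utilities), with a plain add_one standing in for the Pallas kernel:

# Minimal export round-trip sketch (assumes the modern jax.export API;
# add_one is an illustrative stand-in for the Pallas kernel under test).
import jax
import jax.numpy as jnp
from jax import export

def add_one(x):
  return x + 1.0

exported = export.export(jax.jit(add_one))(
    jax.ShapeDtypeStruct((8,), jnp.float32))
blob = exported.serialize()           # bytes that must stay loadable
restored = export.deserialize(blob)   # must keep working on newer JAX
print(restored.call(jnp.arange(8, dtype=jnp.float32)))  # [1. 2. ... 8.]

The testdata file below freezes exactly such a serialized artifact so that future JAX releases can be replayed against it.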
Showing 7 changed files with 164 additions and 7 deletions.
47 changes: 47 additions & 0 deletions
jax/_src/internal_test_util/export_back_compat_test_data/pallas/cuda_add_one.py
@@ -0,0 +1,47 @@
# Copyright 2024 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import datetime
from numpy import array, float32


# Pasted from the test output (see export_back_compat_test_util.py module docstring)
data_2024_05_02 = dict(
    testdata_version=1,
    platform='cuda',
    custom_call_targets=['__gpu$xla.gpu.triton'],
    serialized_date=datetime.date(2024, 5, 2),
    inputs=(array([0., 1., 2., 3., 4., 5., 6., 7.], dtype=float32),),
    expected_outputs=(array([1., 2., 3., 4., 5., 6., 7., 8.], dtype=float32),),
mlir_module_text=r""" | ||
#loc1 = loc("x") | ||
#loc2 = loc("third_party/py/jax/tests/pallas/export_back_compat_pallas_test.py":43:13) | ||
#loc3 = loc("jit(func)/jit(main)/pjit[in_shardings=(UnspecifiedValue,) out_shardings=(UnspecifiedValue,) in_layouts=(None,) out_layouts=(None,) resource_env=None donated_invars=(False,) name=wrapped keep_unused=False inline=False]"(#loc2)) | ||
module @jit_func attributes {jax.uses_shape_polymorphism = false, mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} { | ||
func.func public @main(%arg0: tensor<8xf32> {mhlo.layout_mode = "default"} loc("x")) -> (tensor<8xf32> {jax.result_info = "", mhlo.layout_mode = "default"}) { | ||
%0 = call @wrapped(%arg0) : (tensor<8xf32>) -> tensor<8xf32> loc(#loc3) | ||
return %0 : tensor<8xf32> loc(#loc) | ||
} loc(#loc) | ||
func.func private @wrapped(%arg0: tensor<8xf32> {mhlo.layout_mode = "default"} loc("jit(func)/jit(main)/pjit[in_shardings=(UnspecifiedValue,) out_shardings=(UnspecifiedValue,) in_layouts=(None,) out_layouts=(None,) resource_env=None donated_invars=(False,) name=wrapped keep_unused=False inline=False]"(#loc2))) -> (tensor<8xf32> {mhlo.layout_mode = "default"}) { | ||
%0 = stablehlo.custom_call @__gpu$xla.gpu.triton(%arg0) {mhlo.backend_config = {debug = false, grid_x = 8 : i32, grid_y = 1 : i32, grid_z = 1 : i32, ir = "ML\EFR\0DMLIRgooglex-trunk\00\01-\07\01\05\09\17\01\03\0F\03\0D\13\17\1B\1F#'\05\09+/37\03O1\0B\01-\07\0F\0F\0F\0F\13\13\13\0B\0F\0F\13\0B\0F\0B\0B\0B\0B\1F\0B\0B\13\05\05YY\01\09\0F\07\17\0B\03\035\02\16\02\1F\11\01\05\1D)+\1D#\0F\11\01\01#\01\01\01\03\03\19\1B\17\11U'\05\1D\11\07\00\1D'\0F\01\05\0D\0D\05\1F\11\01\81\0D\05\05!\05#\05%\13\03\10\00\00\E0\0F\05'\05)\17\11U\11#arith.overflow<none>\00#arith.fastmath<none>\00\01\02\02\0B\05\05\09\09\01\01\09!tt.ptr<f32>\00\04\D2\02\05\01P\01\01\07\04\AE\02\03\01\05\07P\01\03\07\04\82\02\03+W\05\11\11\00\09B\01\05\03\01\0FB\01\07\03\01\11F\01\09\03\01\05\05\07\0FB\01\07\03\01\11F\01\09\03\01\05\05\0B\0FB\07\05\03\01\13F\07\09\03\01\05\0F\09\0FB\07\07\03\01\11F\07\09\03\01\05\11\13\03\06\07\03\09\05\01\15\05F\07\0B\03\03\03\17\0FB\15\0D\03\03\15F\15\0F\03\03\05\19\1B\0FB\05\05\03\01\13F\05\09\03\01\05\1F\0D\0FB\05\07\03\01\11F\05\09\03\01\05!#\03\06\05\03\09\05\03%\05F\05\0B\03\03\03'\0BD\05\11\05'\1D\0D\00\01\06\03\01\05\01\00\F2\05+\A5\0B\A3\0F\11!\85\0B\0B\0B\13\0F\0D\1F\0B\0B\0F\0F\0D\07\11builtin\00tt\00arith\00module\00addptr\00load\00func\00get_program_id\00store\00return\00constant\00muli\00addi\00addf\00third_party/py/jax/tests/pallas/export_back_compat_pallas_test.py\00tt.divisibility\00add_one\00public\00/get[tree=PyTreeDef((CustomNode(NDIndexer[(PyTreeDef((*,)), (1,), ())], [*]),))]\00/add\00/swap[tree=PyTreeDef((CustomNode(NDIndexer[(PyTreeDef((*,)), (1,), ())], [*]),))]\00\08C\13\05\01\01\0B/\1D\01\1FC\03\09\03\03\03[\11\17\07\07'\01\07\01\03\03%\03_\07\17\07\07", name = "add_one", num_stages = 3 : i32, num_warps = 4 : i32}, operand_layouts = [dense<0> : tensor<1xindex>], result_layouts = [dense<0> : tensor<1xindex>]} : (tensor<8xf32>) -> tensor<8xf32> loc(#loc4) | ||
return %0 : tensor<8xf32> loc(#loc3) | ||
} loc(#loc3) | ||
} loc(#loc) | ||
#loc = loc(unknown) | ||
#loc4 = loc("jit(func)/jit(main)/jit(wrapped)/pallas_call[name=add_one which_linear=(False,) in_shapes=(ShapeDtypeStruct(shape=(8,), dtype=float32),) out_shapes=(ShapeDtypeStruct(shape=(8,), dtype=float32),) debug=False interpret=False grid_mapping=GridMapping(grid=(8,), block_mappings=(BlockMapping(block_shape=(1,), index_map_jaxpr={ lambda ; a:i32[]. let in (a,) }, indexing_mode=<jax._src.pallas.core.Blocked object at 0x72127ace4e90>), BlockMapping(block_shape=(1,), index_map_jaxpr={ lambda ; a:i32[]. let in (a,) }, indexing_mode=<jax._src.pallas.core.Blocked object at 0x72127ace4e90>)), mapped_dims=(), num_index_operands=0, num_scratch_operands=0) input_output_aliases=() compiler_params={}]"(#loc2)) | ||
""", | ||
mlir_module_serialized=b"ML\xefR\x01StableHLO_v0.9.0\x00\x01\x19\x05\x01\x03\x01\x03\x05\x03\t\x07\t\x0b\r\x03\xaf\x8b\x11\x01G\x07\x0f\x0b\x0f\x0b\x0b\x0b\x0b\x13+\x0b\x0f\x0b\x0b\x0b33\x0b\x0bS\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0f\x0b\x0f\x0b\x13\x0b\x03E\x0f\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0f\x0f\x13\x0f\x1b\x0b\x0b\x0b\x0b\x0b\x0bK\x0b\x0b\x0f\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0f\x0b\x0f/\x01\x05\x0b\x0f\x03\r\x13\x07\x17\x07\x13\x07\x02\xee\x03\x1f\x1d#\x11\x05\x0f\x11\x03\x05\x05\x11\x05\x13\x05\x15\x05\x17\x17%W\x1b\x03\t\x15\x17\x19\x07\x1b\x07\x05\x1d\x05\x19\x11\x01\x00\x05\x1b\x05\x1d\x05\x1f\x03\x0b\tG\x0bM\r]\x05c\x0fe\x03\x0b\tG\x0bM\rG\x05Q\x0fg\x05!\x05#\x03\x13)i+O-k/S1U3m5Y7S9Y\x05%\x05'\x05)\x05+\x05-\x05/\x051\x053\x055\x1d=\x11\x057\x1dA\x01\x059\x03\x03EQ\x05;\x03\x03[\x1d=\x1d?#\t\x1dA\x1dC\x03\x01\x05\x01\x13\x07\x05\x03\x03\x89\r\x03IK\x03\x03_\r\x05aOIK\x1dE\x1dG\x1dI\x1dK\x0b\x03\x1dM\r\x11oUqsuWwWy{}\x7f\x81\x83\x85\x87\x1dO\x1dQ\x13\x07!\x1dS\x1dU\x1dW\x1dY\x1d[\x1d]\x1d_\x13\x07\r\x1da\x13\x07\x11\x1f\r\x11\x00\x00\x00\x00\x00\x00\x00\x00\x01\t\x01\x02\x02)\x03!\x0b\x1b\x11\x03\x05\x03\x05\t)\x03\x05\x0f\x13\x04s\x05\x01\x11\x01\x13\x07\x03\x01\t\x03\x11\x01\x1f\x07\x03\x05\x0b\x03\x05?\t\x07\x03C\x03\x05\x03\x01\x05\x04\x01\x03\x03\x03\x11\x03!\x07\x03\x05\x0b\x03\x05\x03\x07\x07;'\x03\x05\x03\x01\x05\x04\x03\x03\x03\x06\x03\x01\x05\x01\x00\x12%c\x15\x17\x11\x0b\xfe\x0c\x07\x0f\x0f\x0f\r+\x11\x0f\x0b!\x11\x03\x11#\x0f\x05\xd2\n\x1f/!)!)#\x1f\x19\x85j\x03\x13%)9\x1f\x15\x1d\x15\x13\x11\x1f\x15\x11\x0f\x0b\x11builtin\x00vhlo\x00module\x00func_v1\x00return_v1\x00custom_call_v1\x00call_v1\x00sym_name\x00arg_attrs\x00function_type\x00res_attrs\x00sym_visibility\x00jax.uses_shape_polymorphism\x00mhlo.num_partitions\x00mhlo.num_replicas\x00jit_func\x00jit(func)/jit(main)/pjit[in_shardings=(UnspecifiedValue,) out_shardings=(UnspecifiedValue,) in_layouts=(None,) out_layouts=(None,) resource_env=None donated_invars=(False,) name=wrapped keep_unused=False inline=False]\x00third_party/py/jax/tests/pallas/export_back_compat_pallas_test.py\x00api_version\x00backend_config\x00call_target_name\x00called_computations\x00has_side_effect\x00mhlo.backend_config\x00operand_layouts\x00output_operand_aliases\x00result_layouts\x00jit(func)/jit(main)/jit(wrapped)/pallas_call[name=add_one which_linear=(False,) in_shapes=(ShapeDtypeStruct(shape=(8,), dtype=float32),) out_shapes=(ShapeDtypeStruct(shape=(8,), dtype=float32),) debug=False interpret=False grid_mapping=GridMapping(grid=(8,), block_mappings=(BlockMapping(block_shape=(1,), index_map_jaxpr={ lambda ; a:i32[]. let in (a,) }, indexing_mode=<jax._src.pallas.core.Blocked object at 0x72127ace4e90>), BlockMapping(block_shape=(1,), index_map_jaxpr={ lambda ; a:i32[]. 
let in (a,) }, indexing_mode=<jax._src.pallas.core.Blocked object at 0x72127ace4e90>)), mapped_dims=(), num_index_operands=0, num_scratch_operands=0) input_output_aliases=() compiler_params={}]\x00x\x00callee\x00mhlo.layout_mode\x00default\x00\x00wrapped\x00jax.result_info\x00main\x00public\x00private\x00__gpu$xla.gpu.triton\x00debug\x00grid_x\x00grid_y\x00grid_z\x00ir\x00ML\xefR\rMLIRgooglex-trunk\x00\x01-\x07\x01\x05\t\x17\x01\x03\x0f\x03\r\x13\x17\x1b\x1f#'\x05\t+/37\x03O1\x0b\x01-\x07\x0f\x0f\x0f\x0f\x13\x13\x13\x0b\x0f\x0f\x13\x0b\x0f\x0b\x0b\x0b\x0b\x1f\x0b\x0b\x13\x05\x05YY\x01\t\x0f\x07\x17\x0b\x03\x035\x02\x16\x02\x1f\x11\x01\x05\x1d)+\x1d#\x0f\x11\x01\x01#\x01\x01\x01\x03\x03\x19\x1b\x17\x11U'\x05\x1d\x11\x07\x00\x1d'\x0f\x01\x05\r\r\x05\x1f\x11\x01\x81\r\x05\x05!\x05#\x05%\x13\x03\x10\x00\x00\xe0\x0f\x05'\x05)\x17\x11U\x11#arith.overflow<none>\x00#arith.fastmath<none>\x00\x01\x02\x02\x0b\x05\x05\t\t\x01\x01\t!tt.ptr<f32>\x00\x04\xd2\x02\x05\x01P\x01\x01\x07\x04\xae\x02\x03\x01\x05\x07P\x01\x03\x07\x04\x82\x02\x03+W\x05\x11\x11\x00\tB\x01\x05\x03\x01\x0fB\x01\x07\x03\x01\x11F\x01\t\x03\x01\x05\x05\x07\x0fB\x01\x07\x03\x01\x11F\x01\t\x03\x01\x05\x05\x0b\x0fB\x07\x05\x03\x01\x13F\x07\t\x03\x01\x05\x0f\t\x0fB\x07\x07\x03\x01\x11F\x07\t\x03\x01\x05\x11\x13\x03\x06\x07\x03\t\x05\x01\x15\x05F\x07\x0b\x03\x03\x03\x17\x0fB\x15\r\x03\x03\x15F\x15\x0f\x03\x03\x05\x19\x1b\x0fB\x05\x05\x03\x01\x13F\x05\t\x03\x01\x05\x1f\r\x0fB\x05\x07\x03\x01\x11F\x05\t\x03\x01\x05!#\x03\x06\x05\x03\t\x05\x03%\x05F\x05\x0b\x03\x03\x03'\x0bD\x05\x11\x05'\x1d\r\x00\x01\x06\x03\x01\x05\x01\x00\xf2\x05+\xa5\x0b\xa3\x0f\x11!\x85\x0b\x0b\x0b\x13\x0f\r\x1f\x0b\x0b\x0f\x0f\r\x07\x11builtin\x00tt\x00arith\x00module\x00addptr\x00load\x00func\x00get_program_id\x00store\x00return\x00constant\x00muli\x00addi\x00addf\x00third_party/py/jax/tests/pallas/export_back_compat_pallas_test.py\x00tt.divisibility\x00add_one\x00public\x00/get[tree=PyTreeDef((CustomNode(NDIndexer[(PyTreeDef((*,)), (1,), ())], [*]),))]\x00/add\x00/swap[tree=PyTreeDef((CustomNode(NDIndexer[(PyTreeDef((*,)), (1,), ())], [*]),))]\x00\x08C\x13\x05\x01\x01\x0b/\x1d\x01\x1fC\x03\t\x03\x03\x03[\x11\x17\x07\x07'\x01\x07\x01\x03\x03%\x03_\x07\x17\x07\x07\x00name\x00add_one\x00num_stages\x00num_warps\x00", | ||
xla_call_module_version=9, | ||
nr_devices=1, | ||
) # End paste |
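Roughly, the harness consumes this dict in two directions: the current code must still reproduce expected_outputs from inputs, and the frozen mlir_module_serialized must still deserialize and execute, which is what keeps custom_call_targets such as __gpu$xla.gpu.triton stable. A sketch of the first half is below; check_backcompat is a hypothetical name for illustration, not the real helper in export_back_compat_test_util.py:

# Hypothetical sketch of half of what the harness checks; the real logic
# lives in export_back_compat_test_util.py.
import jax
import numpy as np

def check_backcompat(func, testdata):
  # Re-run the function on today's JAX with the recorded inputs...
  got = jax.jit(func)(*testdata["inputs"])
  # ...and require the recorded outputs, up to numerical tolerance. The real
  # harness additionally re-executes mlir_module_serialized and fails if the
  # custom call target no longer resolves on the current runtime.
  np.testing.assert_allclose(got, testdata["expected_outputs"][0], rtol=1e-6)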
68 changes: 68 additions & 0 deletions
tests/pallas/export_back_compat_pallas_test.py

@@ -0,0 +1,68 @@
# Copyright 2024 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for backwards compatibility of exporting code with Pallas custom calls.

See the export_back_compat_test_util module docstring for how to set up and
update these tests.
"""

from absl.testing import absltest

import jax
from jax._src import config
from jax._src import test_util as jtu
from jax._src.internal_test_util import export_back_compat_test_util as bctu

from jax._src.internal_test_util.export_back_compat_test_data.pallas import cuda_add_one

from jax.experimental import pallas as pl
try:
  from jax.experimental.pallas import gpu as plgpu
except ImportError:
  plgpu = None
import jax.numpy as jnp


config.parse_flags_with_absl()


@jtu.with_config(jax_include_full_tracebacks_in_locations=False)
class CompatTest(bctu.CompatTestBase):

  def setUp(self):
    if jax.config.x64_enabled:
      self.skipTest("Only works in 32-bit")
    if not jtu.test_device_matches(["gpu"]):
      self.skipTest("Only works on GPU")
    if (jtu.test_device_matches(["cuda"]) and
        (plgpu is None or plgpu.get_compute_capability(0) < 80)):
      self.skipTest("Only works on GPUs with capability >= sm80")
    super().setUp()

  def test_cuda_add_one(self):
    def func(x):
      def add_one(x_ref, o_ref):
        o_ref[0] = x_ref[0] + 1
      return pl.pallas_call(add_one,
                            out_shape=jax.ShapeDtypeStruct((8,), jnp.float32),
                            in_specs=[pl.BlockSpec(lambda i: i, (1,))],
                            out_specs=pl.BlockSpec(lambda i: i, (1,)),
                            grid=8)(x)
    data = self.load_testdata(cuda_add_one.data_2024_05_02)

    self.run_one_test(func, data)


if __name__ == "__main__":
  absltest.main(testLoader=jtu.JaxTestLoader())
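One caveat for readers on newer JAX releases: pl.BlockSpec later swapped its argument order to take the block shape first and the index map second, so the in_specs/out_specs above would need updating today. A present-day equivalent of func, as a sketch under that assumption:

# Present-day equivalent of the kernel above (a sketch assuming the
# post-commit BlockSpec order: block_shape first, index_map second).
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl

def add_one(x_ref, o_ref):
  o_ref[0] = x_ref[0] + 1  # each of the 8 grid steps handles one element

def func(x):
  return pl.pallas_call(
      add_one,
      out_shape=jax.ShapeDtypeStruct((8,), jnp.float32),
      in_specs=[pl.BlockSpec((1,), lambda i: i)],
      out_specs=pl.BlockSpec((1,), lambda i: i),
      grid=8)(x)

print(func(jnp.arange(8, dtype=jnp.float32)))  # [1. 2. ... 8.]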