diff --git a/jax/_src/interpreters/pxla.py b/jax/_src/interpreters/pxla.py
index 4fb8fa6afab2..7b3060316a72 100644
--- a/jax/_src/interpreters/pxla.py
+++ b/jax/_src/interpreters/pxla.py
@@ -2537,6 +2537,9 @@ def _cached_compilation(computation, name, mesh, spmd_lowering,
   else:
     compiler_options = dict(safe_zip(compiler_options_keys, compiler_options_values))
 
+  fdo_profile = (None if compiler_options is None else
+                 compiler_options.pop("fdo_profile", None))
+
   compile_options = xb.get_compile_options(
       num_replicas=num_replicas,
       num_partitions=num_partitions,
@@ -2544,6 +2547,7 @@ def _cached_compilation(computation, name, mesh, spmd_lowering,
       use_spmd_partitioning=spmd_lowering,
       use_auto_spmd_partitioning=auto_spmd_lowering,
       env_options_overrides=compiler_options,
+      fdo_profile=fdo_profile,
   )
 
   opts = compile_options.executable_build_options
diff --git a/tests/BUILD b/tests/BUILD
index 7e216d31e6af..857e0255ffd3 100644
--- a/tests/BUILD
+++ b/tests/BUILD
@@ -195,6 +195,23 @@ jax_test(
     ],
 )
 
+jax_test(
+    name = "pgle_test",
+    srcs = ["pgle_test.py"],
+    disable_backends = [
+        "cpu",
+        "tpu",
+    ],
+    env = {"XLA_FLAGS": "--xla_dump_to=sponge --xla_gpu_enable_latency_hiding_scheduler=true"},
+    tags = [
+        "config-cuda-only",
+        "multiaccelerator",
+    ],
+    deps = [
+        "//jax:experimental",
+    ],
+)
+
 jax_test(
     name = "array_test",
     srcs = ["array_test.py"],
diff --git a/tests/pgle_test.py b/tests/pgle_test.py
new file mode 100644
index 000000000000..d7562e3d7ef0
--- /dev/null
+++ b/tests/pgle_test.py
@@ -0,0 +1,74 @@
+# Copyright 2023 The JAX Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from functools import partial
+import glob
+import logging
+import math
+import os
+import tempfile
+
+from absl.testing import absltest
+import jax
+from jax import config
+from jax._src import test_util as jtu
+from jax.sharding import NamedSharding
+from jax.experimental import profiler as exp_profiler
+import jax.numpy as jnp
+from jax.sharding import PartitionSpec as P
+import numpy as np
+
+config.parse_flags_with_absl()
+
+
+@jtu.pytest_mark_if_available('multiaccelerator')
+class PgleTest(jtu.JaxTestCase):
+
+  def testPassingFDOProfile(self):
+    mesh = jtu.create_global_mesh((2,), ('x',))
+    @partial(
+        jax.jit,
+        in_shardings=NamedSharding(mesh, P('x',)),
+        out_shardings=NamedSharding(mesh, P('x',)),
+    )
+    def f(x, y):
+      z = x @ y
+      return z @ y
+
+    shape = (8, 8)
+    x = jnp.arange(math.prod(shape)).reshape(shape).astype(np.float32)
+    y = x + 1
+    f_lowered = f.lower(x, y)
+    compiled = f_lowered.compile()
+
+    with tempfile.TemporaryDirectory() as tmpdir:
+      jax.profiler.start_trace(tmpdir)
+      compiled(x, y)
+      jax.profiler.stop_trace()
+      directories = glob.glob(os.path.join(tmpdir, 'plugins/profile/**/'))
+      directories = [d for d in directories if os.path.isdir(d)]
+      rundir = directories[-1]
+      logging.info('rundir: %s', rundir)
+      fdo_profile = exp_profiler.get_profiled_instructions_proto(rundir)
+
+    if jtu.device_under_test() == 'gpu' and jtu.is_device_cuda():
+      self.assertIn(b'custom', fdo_profile)
+
+    logging.info('fdo_profile: %s', fdo_profile)
+    # Check that passing fdo_profile through the compiler_options API works.
+    f_lowered.compile(compiler_options={'fdo_profile': fdo_profile})
+
+
+if __name__ == '__main__':
+  absltest.main(testLoader=jtu.JaxTestLoader())
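For reference, a minimal sketch of how the new option fits together end to end, following the same steps as tests/pgle_test.py: profile one execution of a compiled function, convert the trace into an FDO profile with the experimental profiler helper, and feed it back through compiler_options. The toy function, array shape, and single-device setup here are illustrative assumptions rather than part of the change, and on backends without profiler support the run directory may be missing or the profile empty.

# Sketch only: exercises the fdo_profile path added above, mirroring tests/pgle_test.py.
import glob
import os
import tempfile

import jax
import jax.numpy as jnp
from jax.experimental import profiler as exp_profiler

@jax.jit
def f(x):
  return (x @ x.T).sum()  # toy workload, assumed for illustration

x = jnp.ones((128, 128), jnp.float32)
lowered = f.lower(x)
compiled = lowered.compile()

with tempfile.TemporaryDirectory() as tmpdir:
  # Profile one execution of the compiled function.
  jax.profiler.start_trace(tmpdir)
  compiled(x)
  jax.profiler.stop_trace()

  # Locate the run directory the profiler wrote and build an FDO profile from it.
  rundirs = [d for d in glob.glob(os.path.join(tmpdir, 'plugins/profile/**/'))
             if os.path.isdir(d)]
  fdo_profile = exp_profiler.get_profiled_instructions_proto(rundirs[-1])

# Recompile, passing the profile back in through the new compiler_options key.
optimized = lowered.compile(compiler_options={'fdo_profile': fdo_profile})
optimized(x)

Popping "fdo_profile" out of compiler_options in _cached_compilation means the remaining entries still flow to env_options_overrides unchanged, so existing compiler options are unaffected by the new key.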