Skip to content

Commit

Permalink
Change --jax_xla_profile_version definition to config.
Browse files Browse the repository at this point in the history
Changing the flag to a config permits more contained testing.
This is in preparation for an upcoming change to incorporate
AutoFDO profile versions in the cache key.

Testing: test workload.
PiperOrigin-RevId: 554942573
  • Loading branch information
jax authors committed Aug 8, 2023
1 parent 3e50fea commit d01695c
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 7 deletions.
8 changes: 8 additions & 0 deletions jax/_src/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -1081,6 +1081,14 @@ def _update_disable_jit_thread_local(val):
'work under pmap/pjit.')
)

jax_xla_profile_version = config.define_int_state(
name='jax_xla_profile_version',
default=0,
help=('Optional profile version for XLA compilation. This is meaningful '
'only when XLA is configured to support the remote compilation '
'profile feature.')
)

@contextlib.contextmanager
def explicit_device_put_scope() -> Iterator[None]:
"""Indicates that the current context is an explicit device_put*() call."""
Expand Down
9 changes: 2 additions & 7 deletions jax/_src/xla_bridge.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@
from jax._src import lib
from jax._src import distributed
from jax._src import config as jax_config
from jax._src.config import bool_env, config, int_env
from jax._src.config import bool_env, config
from jax._src.lib import xla_client
from jax._src import traceback_util
from jax._src import util
Expand Down Expand Up @@ -78,11 +78,6 @@
bool_env('JAX_DISABLE_MOST_OPTIMIZATIONS', False),
'Try not to do much optimization work. This can be useful if the cost of '
'optimization is greater than that of running a less-optimized program.')
_XLA_PROFILE_VERSION = jax_config.DEFINE_integer(
'jax_xla_profile_version', int_env('JAX_XLA_PROFILE_VERSION', 0),
'Optional profile version for XLA compilation. '
'This is meaningful only when XLA is configured to '
'support the remote compilation profile feature.')
CUDA_VISIBLE_DEVICES = jax_config.DEFINE_string(
'jax_cuda_visible_devices', 'all',
'Restricts the set of CUDA devices that JAX will use. Either "all", or a '
Expand Down Expand Up @@ -175,7 +170,7 @@ def get_compile_options(
debug_options.xla_llvm_disable_expensive_passes = True
debug_options.xla_test_all_input_layouts = False

compile_options.profile_version = _XLA_PROFILE_VERSION.value
compile_options.profile_version = config.jax_xla_profile_version
return compile_options


Expand Down

0 comments on commit d01695c

Please sign in to comment.