Skip to content

Commit

Permalink
introduce an upgrade flag for custom_vjp overhaul (still experimental)
Browse files Browse the repository at this point in the history
Along the way, make "boolean upgrade flag" a common option for current
and future re-use.
  • Loading branch information
froystig committed Mar 28, 2022
1 parent 5dc068f commit 1d9a6c1
Showing 1 changed file with 24 additions and 5 deletions.
29 changes: 24 additions & 5 deletions jax/_src/config.py
Expand Up @@ -54,6 +54,14 @@ def int_env(varname: str, default: int) -> int:
return int(os.getenv(varname, str(default)))


UPGRADE_BOOL_HELP = (
" This will be enabled by default in future versions of JAX, at which "
"point all uses of the flag will be considered deprecated (following "
"https://jax.readthedocs.io/en/latest/api_compatibility.html).")

UPGRADE_BOOL_EXTRA_DESC = " (transient)"


class Config:
_HAS_DYNAMIC_ATTRIBUTES = True

Expand Down Expand Up @@ -181,6 +189,7 @@ def define_bool_state(
self, name: str, default: bool, help: str, *,
update_global_hook: Optional[Callable[[bool], None]] = None,
update_thread_local_hook: Optional[Callable[[Optional[bool]], None]] = None,
upgrade: bool = False,
extra_description: str = ""):
"""Set up thread-local state and return a contextmanager for managing it.
Expand All @@ -203,6 +212,9 @@ def define_bool_state(
update_thread_local_hook: a optional callback that is called with the
updated value of the thread-local state when it is altered or set
initially.
upgrade: optional indicator that this flag controls a canonical feature
upgrade, so that it is `True` for the incoming functionality, `False`
for the outgoing functionality to be deprecated.
extra_description: string, optional: extra information to add to the
summary description.
Expand All @@ -228,8 +240,12 @@ def define_bool_state(
The value of the thread-local state or flag can be accessed via
``config.jax_enable_foo``. Reading it via ``config.FLAGS.jax_enable_foo`` is
an error.
"""
name = name.lower()
if upgrade:
help += ' ' + UPGRADE_BOOL_HELP
extra_description += UPGRADE_BOOL_EXTRA_DESC
self.DEFINE_bool(name, bool_env(name.upper(), default), help,
update_hook=update_global_hook)
self._contextmanager_flags.add(name)
Expand Down Expand Up @@ -548,15 +564,12 @@ def update_thread_local_jit_state(**kw):
'computations. Logging is performed with `absl.logging` at WARNING '
'level.'))


enable_custom_prng = config.define_bool_state(
name='jax_enable_custom_prng',
default=False,
help=('Enables an internal upgrade that allows one to define custom '
'pseudo-random number generator implementations. This will '
'be enabled by default in future versions of JAX, at which point '
'disabling it will be considered deprecated. In a version '
'after that the flag will be removed altogether.'),
extra_description=" (transient)")
'pseudo-random number generator implementations.'))

default_prng_impl = config.define_enum_state(
name='jax_default_prng_impl',
Expand All @@ -565,6 +578,12 @@ def update_thread_local_jit_state(**kw):
help=('Select the default PRNG implementation, used when one is not '
'explicitly provided at seeding time.'))

enable_custom_vjp_by_custom_transpose = config.define_bool_state(
name='jax_enable_custom_vjp_by_custom_transpose',
default=False,
help=('Enables an internal upgrade that implements `jax.custom_vjp` by '
'reduction to `jax.custom_jvp` and `jax.custom_transpose`.'))

hlo_source_file_canonicalization_regex = config.define_string_state(
name='jax_hlo_source_file_canonicalization_regex',
default=None,
Expand Down

0 comments on commit 1d9a6c1

Please sign in to comment.