Skip to content

Commit

Permalink
add default values to config context managers
Browse files Browse the repository at this point in the history
  • Loading branch information
mattjj committed Apr 12, 2022
1 parent c06eff8 commit 2a46c5e
Showing 1 changed file with 19 additions and 4 deletions.
23 changes: 19 additions & 4 deletions jax/_src/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -257,7 +257,8 @@ def get_state(self):
setattr(Config, name, property(get_state))

return _StateContextManager(name, help, update_thread_local_hook,
extra_description=extra_description)
extra_description=extra_description,
default_value=True)

def define_enum_state(
self, name: str, enum_values: List[str], default: Optional[str],
Expand Down Expand Up @@ -352,18 +353,32 @@ def _trace_context(self):
return (self.x64_enabled, self.jax_numpy_rank_promotion,
self.jax_default_matmul_precision, self.jax_dynamic_shapes)

class NoDefault: pass
no_default = NoDefault()

class _StateContextManager:
def __init__(self, name, help, update_thread_local_hook,
validate_new_val_hook: Optional[Callable[[Any], None]] = None,
extra_description: str = ""):
extra_description: str = "", default_value: Any = no_default):
self._name = name
self.__name__ = name[4:] if name.startswith('jax_') else name
self.__doc__ = f"Context manager for `{name}` config option{extra_description}.\n\n{help}"
self.__doc__ = (f"Context manager for `{name}` config option"
f"{extra_description}.\n\n{help}")
self._update_thread_local_hook = update_thread_local_hook
self._validate_new_val_hook = validate_new_val_hook
self._default_value = default_value

@contextlib.contextmanager
def __call__(self, new_val):
def __call__(self, new_val=no_default):
if new_val is no_default:
if self._default_value is not no_default:
new_val = self._default_value # default_value provided to constructor
else:
# no default_value provided to constructor and no value provided as an
# argument, so we raise an error
raise TypeError(f"Context manager for {self.__name__} config option "
"requires an argument representing the new value for "
"the config option.")
if self._validate_new_val_hook:
self._validate_new_val_hook(new_val)
prev_val = getattr(_thread_local_state, self._name, unset)
Expand Down

0 comments on commit 2a46c5e

Please sign in to comment.