Skip to content

Commit

Permalink
_StateContextManager now preserves the type of the value it stores.
Browse files Browse the repository at this point in the history
This change is a follow-on to #16866, which added an ABSL-like API
for flags defined with `DEFINE_...`. Here we add a similar typed API for flags
defined with `define_..._state`.

See https://github.com/abseil/abseil-py/blob/37dad4d356ca9e13f1c533ad6309631b397a2b6b/absl/flags/_flagvalues.py#L1333.

PiperOrigin-RevId: 570721827
  • Loading branch information
superbobry authored and jax authors committed Oct 4, 2023
1 parent 2fe00f8 commit 923498f
Showing 1 changed file with 65 additions and 54 deletions.
119 changes: 65 additions & 54 deletions jax/_src/config.py
Expand Up @@ -202,11 +202,16 @@ def parse_flags_with_absl(self):
already_configured_with_absl = True

def define_bool_state(
self, name: str, default: bool, help: str, *,
update_global_hook: Optional[Callable[[bool], None]] = None,
update_thread_local_hook: Optional[Callable[[Optional[bool]], None]] = None,
upgrade: bool = False,
extra_description: str = ""):
self,
name: str,
default: bool,
help: str,
*,
update_global_hook: Optional[Callable[[bool], None]] = None,
update_thread_local_hook: Optional[Callable[[Optional[bool]], None]] = None,
upgrade: bool = False,
extra_description: str = '',
) -> _StateContextManager[bool]:
"""Set up thread-local state and return a contextmanager for managing it.
This function is a convenience wrapper. It defines a flag, environment
Expand Down Expand Up @@ -266,20 +271,21 @@ def define_bool_state(
update_hook=update_global_hook)
self._contextmanager_flags.add(name)

def get_state(self):
val = _thread_local_state.__dict__.get(name, unset)
return val if val is not unset else self._read(name)
setattr(Config, name, property(get_state))

return _StateContextManager(name, help, update_thread_local_hook,
extra_description=extra_description,
default_value=True)
s = _StateContextManager[bool](
name, help, update_thread_local_hook,
extra_description=extra_description, default_value=True)
setattr(Config, name, property(lambda _: s.value))
return s

def define_enum_state(
self, name: str, enum_values: list[str], default: Optional[str],
help: str, update_global_hook: Optional[Callable[[str], None]] = None,
update_thread_local_hook: Optional[Callable[[Optional[str]], None]] \
= None):
self,
name: str,
enum_values: list[str],
default: Optional[str],
help: str,
update_global_hook: Optional[Callable[[str], None]] = None,
update_thread_local_hook: Optional[Callable[[Optional[str]], None]] = None,
) -> _StateContextManager[str]:
"""Set up thread-local state and return a contextmanager for managing it.
Args:
name: string, converted to lowercase to define the name of the config
Expand All @@ -303,24 +309,24 @@ def define_enum_state(
update_hook=update_global_hook)
self._contextmanager_flags.add(name)

def get_state(self):
val = _thread_local_state.__dict__.get(name, unset)
return val if val is not unset else self._read(name)
setattr(Config, name, property(get_state))

def validate(new_val):
if (new_val is not None and
(type(new_val) is not str or new_val not in enum_values)):
raise ValueError(f"new enum value must be None or in {enum_values}, "
f"got {new_val} of type {type(new_val)}.")

return _StateContextManager(name, help, update_thread_local_hook, validate)
s = _StateContextManager[str](name, help, update_thread_local_hook, validate)
setattr(Config, name, property(lambda _: s.value))
return s

def define_int_state(
self, name: str, default: Optional[int],
help: str, update_global_hook: Optional[Callable[[str], None]] = None,
update_thread_local_hook: Optional[Callable[[Optional[str]], None]] \
= None):
self,
name: str,
default: Optional[int],
help: str,
update_global_hook: Optional[Callable[[str], None]] = None,
update_thread_local_hook: Optional[Callable[[Optional[str]], None]] = None,
) -> _StateContextManager[int]:
"""Set up thread-local state and return a contextmanager for managing it.
Args:
name: string, converted to lowercase to define the name of the config
Expand All @@ -343,23 +349,23 @@ def define_int_state(
self.DEFINE_integer(name, default, help=help, update_hook=update_global_hook)
self._contextmanager_flags.add(name)

def get_state(self):
val = _thread_local_state.__dict__.get(name, unset)
return val if val is not unset else self._read(name)
setattr(Config, name, property(get_state))

def validate(new_val):
if new_val is not None and not isinstance(new_val, int):
raise ValueError(f'new int config value must be None or of type int, '
f'got {new_val} of type {type(new_val)}')

return _StateContextManager(name, help, update_thread_local_hook, validate)
s = _StateContextManager[int](name, help, update_thread_local_hook, validate)
setattr(Config, name, property(lambda _: s.value))
return s

def define_float_state(
self, name: str, default: Optional[float],
help: str, update_global_hook: Optional[Callable[[str], None]] = None,
update_thread_local_hook: Optional[Callable[[Optional[str]], None]] \
= None):
self,
name: str,
default: Optional[float],
help: str,
update_global_hook: Optional[Callable[[str], None]] = None,
update_thread_local_hook: Optional[Callable[[Optional[str]], None]] = None,
) -> _StateContextManager[float]:
"""Set up thread-local state and return a contextmanager for managing it.
Args:
name: string, converted to lowercase to define the name of the config
Expand All @@ -382,22 +388,23 @@ def define_float_state(
self.DEFINE_float(name, default, help=help, update_hook=update_global_hook)
self._contextmanager_flags.add(name)

def get_state(self):
val = _thread_local_state.__dict__.get(name, unset)
return val if val is not unset else self._read(name)
setattr(Config, name, property(get_state))

def validate(new_val):
if new_val is not None and not isinstance(new_val, (float, int)):
raise ValueError(f'new float config value must be None or of type float, '
f'got {new_val} of type {type(new_val)}')

return _StateContextManager(name, help, update_thread_local_hook, validate)
s = _StateContextManager[float](name, help, update_thread_local_hook, validate)
setattr(Config, name, property(lambda _: s.value))
return s

def define_string_state(
self, name: str, default: Optional[str], help: str,
self,
name: str,
default: Optional[str],
help: str,
update_global_hook: Optional[Callable[[str], None]] = None,
update_thread_local_hook: Optional[Callable[[Optional[str]], None]] = None):
update_thread_local_hook: Optional[Callable[[Optional[str]], None]] = None,
) -> _StateContextManager[str]:
"""Set up thread-local state and return a contextmanager for managing it.
See docstring for ``define_bool_state``.
Expand Down Expand Up @@ -436,7 +443,8 @@ def define_string_or_object_state(
help: str,
update_global_hook: Optional[Callable[[Any], None]] = None,
update_thread_local_hook: Optional[Callable[[Any], None]] = None,
validate_new_val_hook: Optional[Callable[[Any], None]] = None):
validate_new_val_hook: Optional[Callable[[Any], None]] = None,
) -> _StateContextManager[Any]:
"""Set up thread-local state and return a contextmanager for managing it.
Similar to ``define_string_state``, except the context manager will accept
Expand Down Expand Up @@ -468,13 +476,10 @@ def define_string_or_object_state(
update_hook=update_global_hook)
self._contextmanager_flags.add(name)

def get_state(self):
val = _thread_local_state.__dict__.get(name, unset)
return val if val is not unset else self._read(name)
setattr(Config, name, property(get_state))

return _StateContextManager(name, help, update_thread_local_hook,
validate_new_val_hook)
s = _StateContextManager[Any](
name, help, update_thread_local_hook, validate_new_val_hook)
setattr(Config, name, property(lambda _: s.value))
return s

def _trace_context(self):
"""Returns a tuple of configuration values that affect tracing.
Expand Down Expand Up @@ -505,7 +510,8 @@ def _trace_context(self):
class NoDefault: pass
no_default = NoDefault()

class _StateContextManager:

class _StateContextManager(Generic[_T]):
def __init__(self, name, help, update_thread_local_hook,
validate_new_val_hook: Optional[Callable[[Any], None]] = None,
extra_description: str = "", default_value: Any = no_default):
Expand All @@ -517,6 +523,11 @@ def __init__(self, name, help, update_thread_local_hook,
self._validate_new_val_hook = validate_new_val_hook
self._default_value = default_value

@property
def value(self) -> _T:
val = _thread_local_state.__dict__.get(self._name, unset)
return val if val is not unset else config._read(self._name)

@contextlib.contextmanager
def __call__(self, new_val: Any = no_default):
if new_val is no_default:
Expand Down

0 comments on commit 923498f

Please sign in to comment.