From 923498fb45e82f9b1632725ed63ac49fafb8790b Mon Sep 17 00:00:00 2001 From: Sergei Lebedev Date: Wed, 4 Oct 2023 09:44:12 -0700 Subject: [PATCH] `_StateContextManager` now preserves the type of the value it stores. This change is a follow-on to google/jax#16866, which added an ABSL-like API for flags defined with `DEFINE_...`. Here we add a similar typed API for flags defined with `define_..._state`. See https://github.com/abseil/abseil-py/blob/37dad4d356ca9e13f1c533ad6309631b397a2b6b/absl/flags/_flagvalues.py#L1333. PiperOrigin-RevId: 570721827 --- jax/_src/config.py | 119 +++++++++++++++++++++++++-------------------- 1 file changed, 65 insertions(+), 54 deletions(-) diff --git a/jax/_src/config.py b/jax/_src/config.py index 1c9eb5be2243..50b5bda026db 100644 --- a/jax/_src/config.py +++ b/jax/_src/config.py @@ -202,11 +202,16 @@ def parse_flags_with_absl(self): already_configured_with_absl = True def define_bool_state( - self, name: str, default: bool, help: str, *, - update_global_hook: Optional[Callable[[bool], None]] = None, - update_thread_local_hook: Optional[Callable[[Optional[bool]], None]] = None, - upgrade: bool = False, - extra_description: str = ""): + self, + name: str, + default: bool, + help: str, + *, + update_global_hook: Optional[Callable[[bool], None]] = None, + update_thread_local_hook: Optional[Callable[[Optional[bool]], None]] = None, + upgrade: bool = False, + extra_description: str = '', + ) -> _StateContextManager[bool]: """Set up thread-local state and return a contextmanager for managing it. This function is a convenience wrapper. It defines a flag, environment @@ -266,20 +271,21 @@ def define_bool_state( update_hook=update_global_hook) self._contextmanager_flags.add(name) - def get_state(self): - val = _thread_local_state.__dict__.get(name, unset) - return val if val is not unset else self._read(name) - setattr(Config, name, property(get_state)) - - return _StateContextManager(name, help, update_thread_local_hook, - extra_description=extra_description, - default_value=True) + s = _StateContextManager[bool]( + name, help, update_thread_local_hook, + extra_description=extra_description, default_value=True) + setattr(Config, name, property(lambda _: s.value)) + return s def define_enum_state( - self, name: str, enum_values: list[str], default: Optional[str], - help: str, update_global_hook: Optional[Callable[[str], None]] = None, - update_thread_local_hook: Optional[Callable[[Optional[str]], None]] \ - = None): + self, + name: str, + enum_values: list[str], + default: Optional[str], + help: str, + update_global_hook: Optional[Callable[[str], None]] = None, + update_thread_local_hook: Optional[Callable[[Optional[str]], None]] = None, + ) -> _StateContextManager[str]: """Set up thread-local state and return a contextmanager for managing it. Args: name: string, converted to lowercase to define the name of the config @@ -303,24 +309,24 @@ def define_enum_state( update_hook=update_global_hook) self._contextmanager_flags.add(name) - def get_state(self): - val = _thread_local_state.__dict__.get(name, unset) - return val if val is not unset else self._read(name) - setattr(Config, name, property(get_state)) - def validate(new_val): if (new_val is not None and (type(new_val) is not str or new_val not in enum_values)): raise ValueError(f"new enum value must be None or in {enum_values}, " f"got {new_val} of type {type(new_val)}.") - return _StateContextManager(name, help, update_thread_local_hook, validate) + s = _StateContextManager[str](name, help, update_thread_local_hook, validate) + setattr(Config, name, property(lambda _: s.value)) + return s def define_int_state( - self, name: str, default: Optional[int], - help: str, update_global_hook: Optional[Callable[[str], None]] = None, - update_thread_local_hook: Optional[Callable[[Optional[str]], None]] \ - = None): + self, + name: str, + default: Optional[int], + help: str, + update_global_hook: Optional[Callable[[str], None]] = None, + update_thread_local_hook: Optional[Callable[[Optional[str]], None]] = None, + ) -> _StateContextManager[int]: """Set up thread-local state and return a contextmanager for managing it. Args: name: string, converted to lowercase to define the name of the config @@ -343,23 +349,23 @@ def define_int_state( self.DEFINE_integer(name, default, help=help, update_hook=update_global_hook) self._contextmanager_flags.add(name) - def get_state(self): - val = _thread_local_state.__dict__.get(name, unset) - return val if val is not unset else self._read(name) - setattr(Config, name, property(get_state)) - def validate(new_val): if new_val is not None and not isinstance(new_val, int): raise ValueError(f'new int config value must be None or of type int, ' f'got {new_val} of type {type(new_val)}') - return _StateContextManager(name, help, update_thread_local_hook, validate) + s = _StateContextManager[int](name, help, update_thread_local_hook, validate) + setattr(Config, name, property(lambda _: s.value)) + return s def define_float_state( - self, name: str, default: Optional[float], - help: str, update_global_hook: Optional[Callable[[str], None]] = None, - update_thread_local_hook: Optional[Callable[[Optional[str]], None]] \ - = None): + self, + name: str, + default: Optional[float], + help: str, + update_global_hook: Optional[Callable[[str], None]] = None, + update_thread_local_hook: Optional[Callable[[Optional[str]], None]] = None, + ) -> _StateContextManager[float]: """Set up thread-local state and return a contextmanager for managing it. Args: name: string, converted to lowercase to define the name of the config @@ -382,22 +388,23 @@ def define_float_state( self.DEFINE_float(name, default, help=help, update_hook=update_global_hook) self._contextmanager_flags.add(name) - def get_state(self): - val = _thread_local_state.__dict__.get(name, unset) - return val if val is not unset else self._read(name) - setattr(Config, name, property(get_state)) - def validate(new_val): if new_val is not None and not isinstance(new_val, (float, int)): raise ValueError(f'new float config value must be None or of type float, ' f'got {new_val} of type {type(new_val)}') - return _StateContextManager(name, help, update_thread_local_hook, validate) + s = _StateContextManager[float](name, help, update_thread_local_hook, validate) + setattr(Config, name, property(lambda _: s.value)) + return s def define_string_state( - self, name: str, default: Optional[str], help: str, + self, + name: str, + default: Optional[str], + help: str, update_global_hook: Optional[Callable[[str], None]] = None, - update_thread_local_hook: Optional[Callable[[Optional[str]], None]] = None): + update_thread_local_hook: Optional[Callable[[Optional[str]], None]] = None, + ) -> _StateContextManager[str]: """Set up thread-local state and return a contextmanager for managing it. See docstring for ``define_bool_state``. @@ -436,7 +443,8 @@ def define_string_or_object_state( help: str, update_global_hook: Optional[Callable[[Any], None]] = None, update_thread_local_hook: Optional[Callable[[Any], None]] = None, - validate_new_val_hook: Optional[Callable[[Any], None]] = None): + validate_new_val_hook: Optional[Callable[[Any], None]] = None, + ) -> _StateContextManager[Any]: """Set up thread-local state and return a contextmanager for managing it. Similar to ``define_string_state``, except the context manager will accept @@ -468,13 +476,10 @@ def define_string_or_object_state( update_hook=update_global_hook) self._contextmanager_flags.add(name) - def get_state(self): - val = _thread_local_state.__dict__.get(name, unset) - return val if val is not unset else self._read(name) - setattr(Config, name, property(get_state)) - - return _StateContextManager(name, help, update_thread_local_hook, - validate_new_val_hook) + s = _StateContextManager[Any]( + name, help, update_thread_local_hook, validate_new_val_hook) + setattr(Config, name, property(lambda _: s.value)) + return s def _trace_context(self): """Returns a tuple of configuration values that affect tracing. @@ -505,7 +510,8 @@ def _trace_context(self): class NoDefault: pass no_default = NoDefault() -class _StateContextManager: + +class _StateContextManager(Generic[_T]): def __init__(self, name, help, update_thread_local_hook, validate_new_val_hook: Optional[Callable[[Any], None]] = None, extra_description: str = "", default_value: Any = no_default): @@ -517,6 +523,11 @@ def __init__(self, name, help, update_thread_local_hook, self._validate_new_val_hook = validate_new_val_hook self._default_value = default_value + @property + def value(self) -> _T: + val = _thread_local_state.__dict__.get(self._name, unset) + return val if val is not unset else config._read(self._name) + @contextlib.contextmanager def __call__(self, new_val: Any = no_default): if new_val is no_default: