From dcc65e621ea3a68fdc79fa9f2c995743a7b3faf7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Thomas=20K=C3=B6ppe?= Date: Mon, 19 Feb 2024 06:20:27 -0800 Subject: [PATCH] Reverts b506fee9e389391efb1336bc7575dba913e75cdf PiperOrigin-RevId: 608319964 --- .github/workflows/ci-build.yaml | 2 +- CHANGELOG.md | 8 -------- jax/__init__.py | 6 ++++++ jax/_src/config.py | 32 +++++++++++++++++++++++++++++ jax/config.py | 29 +++++++++++++++++--------- tests/BUILD | 8 ++++++++ tests/config_test.py | 36 +++++++++++++++++++++++++++++++++ 7 files changed, 102 insertions(+), 19 deletions(-) create mode 100644 tests/config_test.py diff --git a/.github/workflows/ci-build.yaml b/.github/workflows/ci-build.yaml index bcdbcd7f91ec..e6c7455a2130 100644 --- a/.github/workflows/ci-build.yaml +++ b/.github/workflows/ci-build.yaml @@ -140,7 +140,7 @@ jobs: PY_COLORS: 1 run: | pytest -n auto --tb=short docs - pytest -n auto --tb=short --doctest-modules jax --ignore=jax/config.py --ignore=jax/experimental/jax2tf --ignore=jax/_src/lib/mlir --ignore=jax/_src/lib/triton.py --ignore=jax/interpreters/mlir.py --ignore=jax/_src/iree.py --ignore=jax/experimental/array_serialization --ignore=jax/collect_profile.py --ignore=jax/_src/tpu_custom_call.py --ignore=jax/experimental/mosaic --ignore=jax/experimental/pallas --ignore=jax/_src/pallas + pytest -n auto --tb=short --doctest-modules jax --ignore=jax/experimental/jax2tf --ignore=jax/_src/lib/mlir --ignore=jax/_src/lib/triton.py --ignore=jax/interpreters/mlir.py --ignore=jax/_src/iree.py --ignore=jax/experimental/array_serialization --ignore=jax/collect_profile.py --ignore=jax/_src/tpu_custom_call.py --ignore=jax/experimental/mosaic --ignore=jax/experimental/pallas --ignore=jax/_src/pallas documentation_render: diff --git a/CHANGELOG.md b/CHANGELOG.md index bf3e5360b55a..8e9656db069a 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -25,14 +25,6 @@ Remember to align the itemized text with the first line of an item within a list * Conversion of a non-scalar array to a Python scalar now raises an error, regardless of the size of the array. Previously a deprecation warning was raised in the case of non-scalar arrays of size 1. This follows a similar deprecation in NumPy. - * The previously deprecated configuration APIs have been removed - following a standard 3 months deprecation cycle (see {ref}`api-compatibility`). - These include - * the `jax.config.config` object and - * the `define_*_state` and `DEFINE_*` methods of {data}`jax.config`. - * Importing the `jax.config` submodule via `import jax.config` is deprecated. - To configure JAX use `import jax` and then reference the config object - via `jax.config`. ## jaxlib 0.4.25 diff --git a/jax/__init__.py b/jax/__init__.py index 68279407fd52..ece7c7611249 100644 --- a/jax/__init__.py +++ b/jax/__init__.py @@ -33,6 +33,12 @@ del _warn del _cloud_tpu_init +# Confusingly there are two things named "config": the module and the class. +# We want the exported object to be the class, so we first import the module +# to make sure a later import doesn't overwrite the class. +from jax import config as _config_module +del _config_module + # Force early import, allowing use of `jax.core` after importing `jax`. import jax.core as _core del _core diff --git a/jax/_src/config.py b/jax/_src/config.py index 0b625c8d6f92..f3af0171f929 100644 --- a/jax/_src/config.py +++ b/jax/_src/config.py @@ -23,6 +23,7 @@ import sys import threading from typing import Any, Callable, Generic, NamedTuple, NoReturn, TypeVar, cast +import warnings from jax._src import lib from jax._src.lib import jax_jit @@ -69,6 +70,23 @@ def int_env(varname: str, default: int) -> int: UPGRADE_BOOL_EXTRA_DESC = " (transient)" +_CONFIG_DEPRECATIONS = { + # Added October 26, 2023: + "check_exists", + "DEFINE_bool", + "DEFINE_integer", + "DEFINE_float", + "DEFINE_string", + "DEFINE_enum", + "define_bool_state", + "define_enum_state", + "define_int_state", + "define_float_state", + "define_string_state", + "define_string_or_object_state", +} + + class Config: _HAS_DYNAMIC_ATTRIBUTES = True @@ -82,6 +100,20 @@ def __init__(self): self.use_absl = False self._contextmanager_flags = set() + def __getattr__(self, name): + fn = None + if name in _CONFIG_DEPRECATIONS: + fn = globals().get(name, None) + if fn is None: + raise AttributeError( + f"'{type(self).__name__!r} object has no attribute {name!r}") + message = ( + f"jax.config.{name} is deprecated. Please use other libraries " + "for configuration instead." + ) + warnings.warn(message, DeprecationWarning, stacklevel=2) + return fn + def update(self, name, val): if name not in self._value_holders: raise AttributeError(f"Unrecognized config option: {name}") diff --git a/jax/config.py b/jax/config.py index 763fe6c0fde9..9435308d157f 100644 --- a/jax/config.py +++ b/jax/config.py @@ -12,14 +12,23 @@ # See the License for the specific language governing permissions and # limitations under the License. -import warnings +from jax._src.config import config as _deprecated_config # noqa: F401 -# Added February 16, 2024. -warnings.warn( - "Importing the jax.config submodule via `import jax.config` is deprecated." - " To configure JAX use `import jax` and then reference the config object" - " via `jax.config`.", - DeprecationWarning, - stacklevel=2, -) -del warnings +# Deprecations + +_deprecations = { + # Added October 27, 2023 + "config": ( + "Accessing jax.config via the jax.config submodule is deprecated.", + _deprecated_config), +} + +import typing +if typing.TYPE_CHECKING: + config = _deprecated_config +else: + from jax._src.deprecations import deprecation_getattr as _deprecation_getattr + __getattr__ = _deprecation_getattr(__name__, _deprecations) + del _deprecation_getattr +del typing +del _deprecated_config diff --git a/tests/BUILD b/tests/BUILD index 783dfdf52c71..8765c831ae7f 100644 --- a/tests/BUILD +++ b/tests/BUILD @@ -73,6 +73,14 @@ jax_test( }, ) +py_test( + name = "config_test", + srcs = ["config_test.py"], + deps = [ + "//jax", + ], +) + jax_test( name = "core_test", srcs = ["core_test.py"], diff --git a/tests/config_test.py b/tests/config_test.py new file mode 100644 index 000000000000..b801f1d7c5fd --- /dev/null +++ b/tests/config_test.py @@ -0,0 +1,36 @@ +# Copyright 2023 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +from jax import config + +class ConfigTest(unittest.TestCase): + def test_deprecations(self): + for name in ["DEFINE_bool", "define_bool_state"]: + with ( + self.subTest(name), + self.assertWarnsRegex( + DeprecationWarning, + "other libraries for configuration"), + ): + getattr(config, name) + + def test_missing_attribute(self): + with self.assertRaises(AttributeError): + config.missing_attribute + + +if __name__ == '__main__': + unittest.main()