Skip to content

Commit

Permalink
Reverts b506fee
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 608319964
  • Loading branch information
tkoeppe authored and jax authors committed Feb 19, 2024
1 parent 2101725 commit dcc65e6
Show file tree
Hide file tree
Showing 7 changed files with 102 additions and 19 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/ci-build.yaml
Expand Up @@ -140,7 +140,7 @@ jobs:
PY_COLORS: 1
run: |
pytest -n auto --tb=short docs
pytest -n auto --tb=short --doctest-modules jax --ignore=jax/config.py --ignore=jax/experimental/jax2tf --ignore=jax/_src/lib/mlir --ignore=jax/_src/lib/triton.py --ignore=jax/interpreters/mlir.py --ignore=jax/_src/iree.py --ignore=jax/experimental/array_serialization --ignore=jax/collect_profile.py --ignore=jax/_src/tpu_custom_call.py --ignore=jax/experimental/mosaic --ignore=jax/experimental/pallas --ignore=jax/_src/pallas
pytest -n auto --tb=short --doctest-modules jax --ignore=jax/experimental/jax2tf --ignore=jax/_src/lib/mlir --ignore=jax/_src/lib/triton.py --ignore=jax/interpreters/mlir.py --ignore=jax/_src/iree.py --ignore=jax/experimental/array_serialization --ignore=jax/collect_profile.py --ignore=jax/_src/tpu_custom_call.py --ignore=jax/experimental/mosaic --ignore=jax/experimental/pallas --ignore=jax/_src/pallas
documentation_render:
Expand Down
8 changes: 0 additions & 8 deletions CHANGELOG.md
Expand Up @@ -25,14 +25,6 @@ Remember to align the itemized text with the first line of an item within a list
* Conversion of a non-scalar array to a Python scalar now raises an error, regardless
of the size of the array. Previously a deprecation warning was raised in the case of
non-scalar arrays of size 1. This follows a similar deprecation in NumPy.
* The previously deprecated configuration APIs have been removed
following a standard 3 months deprecation cycle (see {ref}`api-compatibility`).
These include
* the `jax.config.config` object and
* the `define_*_state` and `DEFINE_*` methods of {data}`jax.config`.
* Importing the `jax.config` submodule via `import jax.config` is deprecated.
To configure JAX use `import jax` and then reference the config object
via `jax.config`.

## jaxlib 0.4.25

Expand Down
6 changes: 6 additions & 0 deletions jax/__init__.py
Expand Up @@ -33,6 +33,12 @@
del _warn
del _cloud_tpu_init

# Confusingly there are two things named "config": the module and the class.
# We want the exported object to be the class, so we first import the module
# to make sure a later import doesn't overwrite the class.
from jax import config as _config_module
del _config_module

# Force early import, allowing use of `jax.core` after importing `jax`.
import jax.core as _core
del _core
Expand Down
32 changes: 32 additions & 0 deletions jax/_src/config.py
Expand Up @@ -23,6 +23,7 @@
import sys
import threading
from typing import Any, Callable, Generic, NamedTuple, NoReturn, TypeVar, cast
import warnings

from jax._src import lib
from jax._src.lib import jax_jit
Expand Down Expand Up @@ -69,6 +70,23 @@ def int_env(varname: str, default: int) -> int:
UPGRADE_BOOL_EXTRA_DESC = " (transient)"


_CONFIG_DEPRECATIONS = {
# Added October 26, 2023:
"check_exists",
"DEFINE_bool",
"DEFINE_integer",
"DEFINE_float",
"DEFINE_string",
"DEFINE_enum",
"define_bool_state",
"define_enum_state",
"define_int_state",
"define_float_state",
"define_string_state",
"define_string_or_object_state",
}


class Config:
_HAS_DYNAMIC_ATTRIBUTES = True

Expand All @@ -82,6 +100,20 @@ def __init__(self):
self.use_absl = False
self._contextmanager_flags = set()

def __getattr__(self, name):
fn = None
if name in _CONFIG_DEPRECATIONS:
fn = globals().get(name, None)
if fn is None:
raise AttributeError(
f"'{type(self).__name__!r} object has no attribute {name!r}")
message = (
f"jax.config.{name} is deprecated. Please use other libraries "
"for configuration instead."
)
warnings.warn(message, DeprecationWarning, stacklevel=2)
return fn

def update(self, name, val):
if name not in self._value_holders:
raise AttributeError(f"Unrecognized config option: {name}")
Expand Down
29 changes: 19 additions & 10 deletions jax/config.py
Expand Up @@ -12,14 +12,23 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import warnings
from jax._src.config import config as _deprecated_config # noqa: F401

# Added February 16, 2024.
warnings.warn(
"Importing the jax.config submodule via `import jax.config` is deprecated."
" To configure JAX use `import jax` and then reference the config object"
" via `jax.config`.",
DeprecationWarning,
stacklevel=2,
)
del warnings
# Deprecations

_deprecations = {
# Added October 27, 2023
"config": (
"Accessing jax.config via the jax.config submodule is deprecated.",
_deprecated_config),
}

import typing
if typing.TYPE_CHECKING:
config = _deprecated_config
else:
from jax._src.deprecations import deprecation_getattr as _deprecation_getattr
__getattr__ = _deprecation_getattr(__name__, _deprecations)
del _deprecation_getattr
del typing
del _deprecated_config
8 changes: 8 additions & 0 deletions tests/BUILD
Expand Up @@ -73,6 +73,14 @@ jax_test(
},
)

py_test(
name = "config_test",
srcs = ["config_test.py"],
deps = [
"//jax",
],
)

jax_test(
name = "core_test",
srcs = ["core_test.py"],
Expand Down
36 changes: 36 additions & 0 deletions tests/config_test.py
@@ -0,0 +1,36 @@
# Copyright 2023 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

from jax import config

class ConfigTest(unittest.TestCase):
def test_deprecations(self):
for name in ["DEFINE_bool", "define_bool_state"]:
with (
self.subTest(name),
self.assertWarnsRegex(
DeprecationWarning,
"other libraries for configuration"),
):
getattr(config, name)

def test_missing_attribute(self):
with self.assertRaises(AttributeError):
config.missing_attribute


if __name__ == '__main__':
unittest.main()

0 comments on commit dcc65e6

Please sign in to comment.