diff --git a/jax/config.py b/jax/config.py index 947c5abf8230..9435308d157f 100644 --- a/jax/config.py +++ b/jax/config.py @@ -12,6 +12,23 @@ # See the License for the specific language governing permissions and # limitations under the License. -# TODO(phawkins): fix users of this alias and delete this file. +from jax._src.config import config as _deprecated_config # noqa: F401 -from jax._src.config import config # noqa: F401 +# Deprecations + +_deprecations = { + # Added October 27, 2023 + "config": ( + "Accessing jax.config via the jax.config submodule is deprecated.", + _deprecated_config), +} + +import typing +if typing.TYPE_CHECKING: + config = _deprecated_config +else: + from jax._src.deprecations import deprecation_getattr as _deprecation_getattr + __getattr__ = _deprecation_getattr(__name__, _deprecations) + del _deprecation_getattr +del typing +del _deprecated_config