From 21f6736005d71c47391ab88e824e70b09f3bd553 Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Tue, 11 Jul 2023 12:42:32 -0700 Subject: [PATCH] Remove several deprecated APIs --- CHANGELOG.md | 16 ++++++++---- jax/__init__.py | 48 +++-------------------------------- jax/_src/interpreters/pxla.py | 11 +------- jax/interpreters/pxla.py | 28 -------------------- jax/numpy/__init__.py | 6 ----- 5 files changed, 15 insertions(+), 94 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 1220be0bd9c9..43ae992d7ea4 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -13,11 +13,6 @@ Remember to align the itemized text with the first line of an item within a list https://jax.readthedocs.io/en/latest/deprecation.html * JAX now requires NumPy 1.22 or newer as per https://jax.readthedocs.io/en/latest/deprecation.html - * `jax.interpreters.pxla.device_put` has been removed. This was deprecated in - JAX version 0.4.6: use `jax.device_put` instead. - * `jax.interpreters.pxla.make_sharded_device_array` has been removed. This was - deprecated in JAX version 0.4.6: use `jax.make_array_from_single_device_arrays` - instead. * Passing optional arguments to {func}`jax.numpy.ndarray.at` by position is no longer supported, after being deprecated in JAX version 0.4.7. For example, instead of `x.at[i].get(True)`, use `x.at[i].get(indices_are_sorted=True)` @@ -26,6 +21,17 @@ Remember to align the itemized text with the first line of an item within a list * `jax.Array.broadcast`: use {func}`jax.lax.broadcast` instead. * `jax.Array.broadcast_in_dim`: use {func}`jax.lax.broadcast_in_dim` instead. * `jax.Array.split`: use {func}`jax.numpy.split` instead. + * The following APIs have been removed after previous deprecation: + * `jax.ad`: use {mod}`jax.interpreters.ad`. + * `jax.curry`: use ``curry = lambda f: partial(partial, f)``. + * `jax.partial_eval`: use {mod}`jax.interpreters.partial_eval`. + * `jax.pxla`: use {mod}`jax.interpreters.pxla`. + * `jax.xla`: use {mod}`jax.interpreters.xla`. + * `jax.ShapedArray`: use {class}`jax.core.ShapedArray`. + * `jax.interpreters.pxla.device_put`: use {func}`jax.device_put`. + * `jax.interpreters.pxla.make_sharded_device_array`: use {func}`jax.make_array_from_single_device_arrays`. + * `jax.interpreters.pxla.ShardedDeviceArray`: use {class}`jax.Array`. + * `jax.numpy.DeviceArray`: use {class}`jax.Array`. * Breaking changes * JAX now requires ml_dtypes version 0.2.0 or newer. diff --git a/jax/__init__.py b/jax/__init__.py index 08c1f402dc77..e22eb0fe765a 100644 --- a/jax/__init__.py +++ b/jax/__init__.py @@ -82,7 +82,6 @@ from jax._src.api import clear_backends as clear_backends from jax._src.api import clear_caches as clear_caches from jax._src.custom_derivatives import closure_convert as closure_convert -from jax._src.util import curry as _deprecated_curry from jax._src.custom_derivatives import custom_gradient as custom_gradient from jax._src.custom_derivatives import custom_jvp as custom_jvp from jax._src.custom_derivatives import custom_vjp as custom_vjp @@ -95,7 +94,6 @@ from jax._src.xla_bridge import devices as devices from jax._src.api import disable_jit as disable_jit from jax._src.api import eval_shape as eval_shape -from jax._src.api_util import flatten_fun_nokwargs as _deprecated_flatten_fun_nokwargs from jax._src.dtypes import float0 as float0 from jax._src.api import grad as grad from jax._src.api import hessian as hessian @@ -120,19 +118,15 @@ from jax._src.xla_bridge import process_index as process_index from jax._src.callback import pure_callback_api as pure_callback from jax._src.ad_checkpoint import checkpoint_wrapper as remat -from jax._src.core import ShapedArray as _deprecated_ShapedArray from jax._src.api import ShapeDtypeStruct as ShapeDtypeStruct from jax._src.api import value_and_grad as value_and_grad from jax._src.api import vjp as vjp from jax._src.api import vmap as vmap from jax._src.api import xla_computation as xla_computation -from jax.interpreters import ad as _deprecated_ad -import jax.interpreters.batching -import jax.interpreters.mlir -from jax.interpreters import partial_eval as _deprecated_partial_eval -from jax.interpreters import pxla as _deprecated_pxla -from jax.interpreters import xla as _deprecated_xla +# Force import, allowing jax.interpreters.* to be used after import jax. +from jax.interpreters import ad, batching, mlir, partial_eval, pxla, xla +del ad, batching, mlir, partial_eval, pxla, xla from jax._src.array import ( make_array_from_single_device_arrays as make_array_from_single_device_arrays, @@ -191,47 +185,11 @@ "jax.abstract_arrays is deprecated. Refer to jax.core.", _deprecated_abstract_arrays ), - # Added 28 March 2023 - "ShapedArray": ( - "jax.ShapedArray is deprecated. Use jax.core.ShapedArray", - _deprecated_ShapedArray, - ), - "ad": ( - "jax.ad is deprecated. Use jax.interpreters.ad", - _deprecated_ad, - ), - "partial_eval": ( - "jax.partial_eval is deprecated. Use jax.interpreters.partial_eval", - _deprecated_partial_eval, - ), - "pxla": ( - "jax.pxla is deprecated. Use jax.interpreters.pxla", - _deprecated_pxla, - ), - "xla": ( - "jax.xla is deprecated. Use jax.interpreters.xla", - _deprecated_xla, - ), - "curry": ( - "jax.curry is deprecated. Use curry = lambda f: partial(partial, f)", - _deprecated_curry, - ), - "flatten_fun_nokwargs": ( - "jax.flatten_fun_nokwargs is deprecated. Use jax.api_util.flatten_fun_nokwargs.", - _deprecated_flatten_fun_nokwargs, - ), } import typing as _typing if _typing.TYPE_CHECKING: from jax._src import abstract_arrays as abstract_arrays - from jax._src.core import ShapedArray as ShapedArray - from jax.interpreters import ad as ad - from jax.interpreters import partial_eval as partial_eval - from jax.interpreters import pxla as pxla - from jax.interpreters import xla as xla - from jax._src.util import curry as curry - from jax._src.api_util import flatten_fun_nokwargs as flatten_fun_nokwargs else: from jax._src.deprecations import deprecation_getattr as _deprecation_getattr __getattr__ = _deprecation_getattr(__name__, _deprecations) diff --git a/jax/_src/interpreters/pxla.py b/jax/_src/interpreters/pxla.py index f27890750012..782031830644 100644 --- a/jax/_src/interpreters/pxla.py +++ b/jax/_src/interpreters/pxla.py @@ -24,7 +24,7 @@ import logging import math from typing import (Any, Callable, NamedTuple, Optional, Sequence, Union, - Iterable, TYPE_CHECKING, cast, TypeVar) + Iterable, cast, TypeVar) import numpy as np @@ -331,15 +331,6 @@ def make_sharded_device_array( aval.shape, sharding, device_buffers) # type: ignore -if TYPE_CHECKING: - ShardedDeviceArray = Any -else: - class ShardedDeviceArray(object): - def __init__(self): - raise RuntimeError("ShardedDeviceArray is a backward compatibility shim " - "and cannot be instantiated.") - - def _hashable_index(idx): return tree_map(lambda x: (x.start, x.stop) if type(x) == slice else x, idx) diff --git a/jax/interpreters/pxla.py b/jax/interpreters/pxla.py index e23c743ad8c2..7d26831e93bd 100644 --- a/jax/interpreters/pxla.py +++ b/jax/interpreters/pxla.py @@ -110,31 +110,3 @@ sharding_spec_sharding_proto as sharding_spec_sharding_proto, spec_to_indices as spec_to_indices, ) - -# Deprecations - -from jax._src.interpreters.pxla import ( - ShardedDeviceArray as _deprecated_ShardedDeviceArray, -) - -_deprecations = { - # Added March 15, 2023: - "ShardedDeviceArray": ( - ( - "jax.interpreters.pxla.ShardedDeviceArray is deprecated. Use " - "jax.Array." - ), - _deprecated_ShardedDeviceArray, - ), -} - -import typing -if typing.TYPE_CHECKING: - from jax._src.interpreters.pxla import ( - ShardedDeviceArray as ShardedDeviceArray, - ) -else: - from jax._src.deprecations import deprecation_getattr as _deprecation_getattr - __getattr__ = _deprecation_getattr(__name__, _deprecations) - del _deprecation_getattr -del typing diff --git a/jax/numpy/__init__.py b/jax/numpy/__init__.py index a5b449a2690c..922ed69642c8 100644 --- a/jax/numpy/__init__.py +++ b/jax/numpy/__init__.py @@ -425,11 +425,6 @@ # Deprecations _deprecations = { - # Added March 14, 2023: - "DeviceArray": ( - "jax.numpy.DeviceArray is deprecated. Use jax.Array.", - ndarray, - ), # Added June 2, 2023: "alltrue": ( "jax.numpy.alltrue is deprecated. Use jax.numpy.all", @@ -451,7 +446,6 @@ import typing if typing.TYPE_CHECKING: - from jax._src.basearray import Array as DeviceArray alltrue = all cumproduct = cumprod product = prod