Skip to content

Commit

Permalink
Remove several deprecated APIs
Browse files Browse the repository at this point in the history
  • Loading branch information
jakevdp committed Jul 11, 2023
1 parent a29d4bc commit 21f6736
Show file tree
Hide file tree
Showing 5 changed files with 15 additions and 94 deletions.
16 changes: 11 additions & 5 deletions CHANGELOG.md
Expand Up @@ -13,11 +13,6 @@ Remember to align the itemized text with the first line of an item within a list
https://jax.readthedocs.io/en/latest/deprecation.html
* JAX now requires NumPy 1.22 or newer as per
https://jax.readthedocs.io/en/latest/deprecation.html
* `jax.interpreters.pxla.device_put` has been removed. This was deprecated in
JAX version 0.4.6: use `jax.device_put` instead.
* `jax.interpreters.pxla.make_sharded_device_array` has been removed. This was
deprecated in JAX version 0.4.6: use `jax.make_array_from_single_device_arrays`
instead.
* Passing optional arguments to {func}`jax.numpy.ndarray.at` by position is
no longer supported, after being deprecated in JAX version 0.4.7.
For example, instead of `x.at[i].get(True)`, use `x.at[i].get(indices_are_sorted=True)`
Expand All @@ -26,6 +21,17 @@ Remember to align the itemized text with the first line of an item within a list
* `jax.Array.broadcast`: use {func}`jax.lax.broadcast` instead.
* `jax.Array.broadcast_in_dim`: use {func}`jax.lax.broadcast_in_dim` instead.
* `jax.Array.split`: use {func}`jax.numpy.split` instead.
* The following APIs have been removed after previous deprecation:
* `jax.ad`: use {mod}`jax.interpreters.ad`.
* `jax.curry`: use ``curry = lambda f: partial(partial, f)``.
* `jax.partial_eval`: use {mod}`jax.interpreters.partial_eval`.
* `jax.pxla`: use {mod}`jax.interpreters.pxla`.
* `jax.xla`: use {mod}`jax.interpreters.xla`.
* `jax.ShapedArray`: use {class}`jax.core.ShapedArray`.
* `jax.interpreters.pxla.device_put`: use {func}`jax.device_put`.
* `jax.interpreters.pxla.make_sharded_device_array`: use {func}`jax.make_array_from_single_device_arrays`.
* `jax.interpreters.pxla.ShardedDeviceArray`: use {class}`jax.Array`.
* `jax.numpy.DeviceArray`: use {class}`jax.Array`.

* Breaking changes
* JAX now requires ml_dtypes version 0.2.0 or newer.
Expand Down
48 changes: 3 additions & 45 deletions jax/__init__.py
Expand Up @@ -82,7 +82,6 @@
from jax._src.api import clear_backends as clear_backends
from jax._src.api import clear_caches as clear_caches
from jax._src.custom_derivatives import closure_convert as closure_convert
from jax._src.util import curry as _deprecated_curry
from jax._src.custom_derivatives import custom_gradient as custom_gradient
from jax._src.custom_derivatives import custom_jvp as custom_jvp
from jax._src.custom_derivatives import custom_vjp as custom_vjp
Expand All @@ -95,7 +94,6 @@
from jax._src.xla_bridge import devices as devices
from jax._src.api import disable_jit as disable_jit
from jax._src.api import eval_shape as eval_shape
from jax._src.api_util import flatten_fun_nokwargs as _deprecated_flatten_fun_nokwargs
from jax._src.dtypes import float0 as float0
from jax._src.api import grad as grad
from jax._src.api import hessian as hessian
Expand All @@ -120,19 +118,15 @@
from jax._src.xla_bridge import process_index as process_index
from jax._src.callback import pure_callback_api as pure_callback
from jax._src.ad_checkpoint import checkpoint_wrapper as remat
from jax._src.core import ShapedArray as _deprecated_ShapedArray
from jax._src.api import ShapeDtypeStruct as ShapeDtypeStruct
from jax._src.api import value_and_grad as value_and_grad
from jax._src.api import vjp as vjp
from jax._src.api import vmap as vmap
from jax._src.api import xla_computation as xla_computation

from jax.interpreters import ad as _deprecated_ad
import jax.interpreters.batching
import jax.interpreters.mlir
from jax.interpreters import partial_eval as _deprecated_partial_eval
from jax.interpreters import pxla as _deprecated_pxla
from jax.interpreters import xla as _deprecated_xla
# Force import, allowing jax.interpreters.* to be used after import jax.
from jax.interpreters import ad, batching, mlir, partial_eval, pxla, xla
del ad, batching, mlir, partial_eval, pxla, xla

from jax._src.array import (
make_array_from_single_device_arrays as make_array_from_single_device_arrays,
Expand Down Expand Up @@ -191,47 +185,11 @@
"jax.abstract_arrays is deprecated. Refer to jax.core.",
_deprecated_abstract_arrays
),
# Added 28 March 2023
"ShapedArray": (
"jax.ShapedArray is deprecated. Use jax.core.ShapedArray",
_deprecated_ShapedArray,
),
"ad": (
"jax.ad is deprecated. Use jax.interpreters.ad",
_deprecated_ad,
),
"partial_eval": (
"jax.partial_eval is deprecated. Use jax.interpreters.partial_eval",
_deprecated_partial_eval,
),
"pxla": (
"jax.pxla is deprecated. Use jax.interpreters.pxla",
_deprecated_pxla,
),
"xla": (
"jax.xla is deprecated. Use jax.interpreters.xla",
_deprecated_xla,
),
"curry": (
"jax.curry is deprecated. Use curry = lambda f: partial(partial, f)",
_deprecated_curry,
),
"flatten_fun_nokwargs": (
"jax.flatten_fun_nokwargs is deprecated. Use jax.api_util.flatten_fun_nokwargs.",
_deprecated_flatten_fun_nokwargs,
),
}

import typing as _typing
if _typing.TYPE_CHECKING:
from jax._src import abstract_arrays as abstract_arrays
from jax._src.core import ShapedArray as ShapedArray
from jax.interpreters import ad as ad
from jax.interpreters import partial_eval as partial_eval
from jax.interpreters import pxla as pxla
from jax.interpreters import xla as xla
from jax._src.util import curry as curry
from jax._src.api_util import flatten_fun_nokwargs as flatten_fun_nokwargs
else:
from jax._src.deprecations import deprecation_getattr as _deprecation_getattr
__getattr__ = _deprecation_getattr(__name__, _deprecations)
Expand Down
11 changes: 1 addition & 10 deletions jax/_src/interpreters/pxla.py
Expand Up @@ -24,7 +24,7 @@
import logging
import math
from typing import (Any, Callable, NamedTuple, Optional, Sequence, Union,
Iterable, TYPE_CHECKING, cast, TypeVar)
Iterable, cast, TypeVar)

import numpy as np

Expand Down Expand Up @@ -331,15 +331,6 @@ def make_sharded_device_array(
aval.shape, sharding, device_buffers) # type: ignore


if TYPE_CHECKING:
ShardedDeviceArray = Any
else:
class ShardedDeviceArray(object):
def __init__(self):
raise RuntimeError("ShardedDeviceArray is a backward compatibility shim "
"and cannot be instantiated.")


def _hashable_index(idx):
return tree_map(lambda x: (x.start, x.stop) if type(x) == slice else x, idx)

Expand Down
28 changes: 0 additions & 28 deletions jax/interpreters/pxla.py
Expand Up @@ -110,31 +110,3 @@
sharding_spec_sharding_proto as sharding_spec_sharding_proto,
spec_to_indices as spec_to_indices,
)

# Deprecations

from jax._src.interpreters.pxla import (
ShardedDeviceArray as _deprecated_ShardedDeviceArray,
)

_deprecations = {
# Added March 15, 2023:
"ShardedDeviceArray": (
(
"jax.interpreters.pxla.ShardedDeviceArray is deprecated. Use "
"jax.Array."
),
_deprecated_ShardedDeviceArray,
),
}

import typing
if typing.TYPE_CHECKING:
from jax._src.interpreters.pxla import (
ShardedDeviceArray as ShardedDeviceArray,
)
else:
from jax._src.deprecations import deprecation_getattr as _deprecation_getattr
__getattr__ = _deprecation_getattr(__name__, _deprecations)
del _deprecation_getattr
del typing
6 changes: 0 additions & 6 deletions jax/numpy/__init__.py
Expand Up @@ -425,11 +425,6 @@
# Deprecations

_deprecations = {
# Added March 14, 2023:
"DeviceArray": (
"jax.numpy.DeviceArray is deprecated. Use jax.Array.",
ndarray,
),
# Added June 2, 2023:
"alltrue": (
"jax.numpy.alltrue is deprecated. Use jax.numpy.all",
Expand All @@ -451,7 +446,6 @@

import typing
if typing.TYPE_CHECKING:
from jax._src.basearray import Array as DeviceArray
alltrue = all
cumproduct = cumprod
product = prod
Expand Down

0 comments on commit 21f6736

Please sign in to comment.