Skip to content

Commit

Permalink
Remove jax.api.
Browse files Browse the repository at this point in the history
Functions exported as jax.api were aliases for names in jax.*. Use the jax.* names instead.
  • Loading branch information
hawkinsp committed Sep 16, 2021
1 parent 75f941b commit 6a1b626
Show file tree
Hide file tree
Showing 7 changed files with 20 additions and 74 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
Expand Up @@ -12,6 +12,9 @@ PLEASE REMEMBER TO REPLACE THE '..main' WITH AN ACTUAL TAG in the GITHUB LINK.
* [GitHub
commits](https://github.com/google/jax/compare/jax-v0.2.20...main).
* Breaking Changes
* `jax.api` has been removed. Functions that were available as `jax.api.*`
were aliases for functions in `jax.*`; please use the functions in
`jax.*` instead.
* `jax.lax.partial` was an accidental export that has now been removed. Use
`functools.partial` instead.
* Boolean scalar indices now raise a `TypeError`; previously this silently
Expand Down
4 changes: 2 additions & 2 deletions benchmarks/api_benchmark.py
Expand Up @@ -175,7 +175,7 @@ def jit_dispatch_without_transfer(state):
imgs = np.ones((128, 224, 224), np.float32)
imgs = jax.device_put(imgs)

f = jax.api.jit(lambda x: x+1)
f = jax.jit(lambda x: x+1)
f(imgs)

while state:
Expand All @@ -186,7 +186,7 @@ def jit_dispatch_without_transfer(state):
def jit_dispatch_with_transfer(state):
imgs = np.ones((128, 224, 224), np.float32)

f = jax.api.jit(lambda x: x+1)
f = jax.jit(lambda x: x+1)
f(imgs).block_until_ready()

while state:
Expand Down
1 change: 0 additions & 1 deletion jax/__init__.py
Expand Up @@ -119,7 +119,6 @@
# These submodules are separate because they are in an import cycle with
# jax and rely on the names imported above.
from . import abstract_arrays as abstract_arrays
from . import api as api
from . import api_util as api_util
from . import dtypes as dtypes
from . import errors as errors
Expand Down
56 changes: 0 additions & 56 deletions jax/api.py

This file was deleted.

4 changes: 2 additions & 2 deletions jax/experimental/jax2tf/tests/primitive_harness.py
Expand Up @@ -1783,7 +1783,7 @@ def linear_solve(a, b, solve, transpose_solve=None, symmetric=False):
return lax.custom_linear_solve(matvec, b, solve, transpose_solve, symmetric)

def explicit_jacobian_solve(matvec, b):
return lax.stop_gradient(jnp.linalg.solve(jax.api.jacobian(matvec)(b), b))
return lax.stop_gradient(jnp.linalg.solve(jax.jacobian(matvec)(b), b))

def _make_harness(name,
*,
Expand Down Expand Up @@ -2079,7 +2079,7 @@ def _make_select_and_scatter_add_harness(name,
padding=((0, 0), (0, 0), (0, 0)),
nb_inactive_dims=0):
ones = (1,) * len(shape)
cotangent_shape = jax.api.eval_shape(
cotangent_shape = jax.eval_shape(
lambda x: lax._select_and_gather_add(x, x, lax.ge_p, window_dimensions,
window_strides, padding, ones, ones),
np.ones(shape, dtype)).shape
Expand Down
22 changes: 11 additions & 11 deletions jax/experimental/sparse/ops.py
Expand Up @@ -34,7 +34,7 @@

from typing import Any, Sequence, Tuple

from jax import api
import jax
from jax import core
from jax import dtypes
from jax import jit
Expand Down Expand Up @@ -1436,15 +1436,15 @@ def fromdense(cls, mat, *, nse=None, index_dtype=np.int32):
nse = (mat != 0).sum()
return cls(csr_fromdense(mat, nse=nse, index_dtype=index_dtype), shape=mat.shape)

@api.jit
@jax.jit
def todense(self):
return csr_todense(self.data, self.indices, self.indptr, shape=self.shape)

@api.jit
@jax.jit
def matvec(self, v):
return csr_matvec(self.data, self.indices, self.indptr, v, shape=self.shape)

@api.jit
@jax.jit
def matmat(self, B):
return csr_matmat(self.data, self.indices, self.indptr, B, shape=self.shape)

Expand Down Expand Up @@ -1475,15 +1475,15 @@ def fromdense(cls, mat, *, nse=None, index_dtype=np.int32):
nse = (mat != 0).sum()
return cls(csr_fromdense(mat.T, nse=nse, index_dtype=index_dtype), shape=mat.shape)

@api.jit
@jax.jit
def todense(self):
return csr_todense(self.data, self.indices, self.indptr, shape=self.shape[::-1]).T

@api.jit
@jax.jit
def matvec(self, v):
return csr_matvec(self.data, self.indices, self.indptr, v, shape=self.shape[::-1], transpose=True)

@api.jit
@jax.jit
def matmat(self, B):
return csr_matmat(self.data, self.indices, self.indptr, B, shape=self.shape[::-1], transpose=True)

Expand Down Expand Up @@ -1514,15 +1514,15 @@ def fromdense(cls, mat, *, nse=None, index_dtype=np.int32):
nse = (mat != 0).sum()
return cls(coo_fromdense(mat, nse=nse, index_dtype=index_dtype), shape=mat.shape)

@api.jit
@jax.jit
def todense(self):
return coo_todense(self.data, self.row, self.col, shape=self.shape)

@api.jit
@jax.jit
def matvec(self, v):
return coo_matvec(self.data, self.row, self.col, v, shape=self.shape)

@api.jit
@jax.jit
def matmat(self, B):
return coo_matmat(self.data, self.row, self.col, B, shape=self.shape)

Expand Down Expand Up @@ -1635,7 +1635,7 @@ def _dedupe(self):
"""Return a de-duplicated representation of the BCOO matrix."""
return BCOO(_dedupe_bcoo(self.data, self.indices, self.shape), shape=self.shape)

@api.jit
@jax.jit
def todense(self):
"""Create a dense version of the array."""
return bcoo_todense(self.data, self.indices, shape=self.shape)
Expand Down
4 changes: 2 additions & 2 deletions jax/tools/jax_to_hlo.py
Expand Up @@ -70,7 +70,7 @@ def fn(x, y, z):

from absl import app
from absl import flags
import jax.api
import jax
import jax.numpy as jnp
from jax.lib import xla_client

Expand Down Expand Up @@ -130,7 +130,7 @@ def ordered_wrapper(*args):
arg_names = [arg_name for arg_name, _ in input_shapes]
return fn_curried(**dict(zip(arg_names, args)))

comp = jax.api.xla_computation(ordered_wrapper)(*args)
comp = jax.xla_computation(ordered_wrapper)(*args)
return (comp.as_serialized_hlo_module_proto(), comp.as_hlo_text())


Expand Down

0 comments on commit 6a1b626

Please sign in to comment.