Skip to content

Commit

Permalink
Replace uses of jax.partial with functools.partial, in preparation fo…
Browse files Browse the repository at this point in the history
…r removing jax.partial.

jax.partial is an alias for functools.partial, and functools.partial is a Python standard library API. There's no need for jax to export this function.

PiperOrigin-RevId: 396370975
  • Loading branch information
hawkinsp authored and jax authors committed Sep 13, 2021
1 parent c9915fe commit 80599c0
Show file tree
Hide file tree
Showing 4 changed files with 11 additions and 7 deletions.
7 changes: 4 additions & 3 deletions jax/_src/lax/polar.py
Expand Up @@ -20,6 +20,7 @@
is described in the docstring to `polarU`. This file covers the serial
case.
"""
import functools
import jax
from jax import lax
import jax.numpy as jnp
Expand Down Expand Up @@ -89,7 +90,7 @@ def polar(a, side='right', method='qdwh', eps=None, maxiter=50):
return _polar(a, side, method, eps, maxiter)


@jax.partial(jax.jit, static_argnums=(1, 2, 4))
@functools.partial(jax.jit, static_argnums=(1, 2, 4))
def _polar(a, side, method, eps, maxiter):
if side not in ("left", "right"):
raise ValueError(f"side={side} was invalid.")
Expand All @@ -110,7 +111,7 @@ def polar_unitary(a, method="qdwh", eps=None, maxiter=50):
return _polar_unitary(a, method, eps, maxiter)


@jax.partial(jax.jit, static_argnums=(1, 3))
@functools.partial(jax.jit, static_argnums=(1, 3))
def _polar_unitary(a, method, eps, maxiter):
if method not in ("svd", "qdwh"):
raise ValueError(f"method={method} is unsupported.")
Expand All @@ -127,7 +128,7 @@ def _polar_unitary(a, method, eps, maxiter):
return unitary, info


@jax.partial(jax.jit, static_argnums=(2,))
@functools.partial(jax.jit, static_argnums=(2,))
def _qdwh(matrix, eps, maxiter):
""" Computes the unitary factor in the polar decomposition of A using
the QDWH method. QDWH implements a 3rd order Pade approximation to the
Expand Down
3 changes: 2 additions & 1 deletion jax/_src/scipy/eigh.py
Expand Up @@ -15,6 +15,7 @@

"""Serial algorithm for eigh."""

import functools
import jax
import jax.numpy as jnp
import jax.scipy as jsp
Expand Down Expand Up @@ -87,7 +88,7 @@ def body_f(args):
return V1, V2


@jax.partial(jax.jit, static_argnums=(3, 4))
@functools.partial(jax.jit, static_argnums=(3, 4))
def _split_spectrum_jittable(P, H, V0, rank, precision):
""" The jittable portion of `split_spectrum`. At this point the sizes of the
relevant matrix blocks have been concretized.
Expand Down
6 changes: 4 additions & 2 deletions jax/experimental/sparse/__init__.py
Expand Up @@ -141,6 +141,7 @@
As an example of a more complicated sparse workflow, let's consider a simple logistic regression
implemented in JAX. Notice that the following implementation has no reference to sparsity:
>>> import functools
>>> from sklearn.datasets import make_classification
>>> from jax.scipy import optimize
Expand All @@ -156,7 +157,8 @@
...
>>> def fit_logreg(X, y):
... params = jnp.zeros(X.shape[1] + 1)
... result = optimize.minimize(jax.partial(loss, X=X, y=y), x0=params, method='BFGS')
... result = optimize.minimize(functools.partial(loss, X=X, y=y),
... x0=params, method='BFGS')
... return result.x
>>> X, y = make_classification(n_classes=2, random_state=1701)
Expand Down Expand Up @@ -221,4 +223,4 @@
BCOO,
)

from .transform import sparsify
from .transform import sparsify
2 changes: 1 addition & 1 deletion tests/x64_context_test.py
Expand Up @@ -14,14 +14,14 @@


import concurrent.futures
from functools import partial
import time

from absl.testing import absltest
from absl.testing import parameterized

from jax._src import api
from jax import lax
from jax import partial
from jax import random
from jax.config import config
from jax.experimental import enable_x64, disable_x64
Expand Down

0 comments on commit 80599c0

Please sign in to comment.