Skip to content

Commit

Permalink
Improve pytype checking of XLA types inside JAX.
Browse files Browse the repository at this point in the history
Add an explicit `.pyi` file for jax/_src/lib/__init__.pyi, which works around a bug in pytype where the types of modules that are re-exported becomes `Any`.

[XLA:Python] Fix type declaration for sharding specs.

PiperOrigin-RevId: 404313123
  • Loading branch information
hawkinsp authored and jax authors committed Oct 19, 2021
1 parent e783cbc commit 8f0bfcb
Show file tree
Hide file tree
Showing 4 changed files with 62 additions and 17 deletions.
24 changes: 12 additions & 12 deletions jax/_src/lib/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,15 +26,15 @@
'xla_extension',
]

# First, before attempting to import jaxlib, warn about experimental machine
# First, before attempting to from jax import jaxlib, warn about experimental machine
# configurations.
if platform.system() == "Darwin" and platform.machine() == "arm64":
warnings.warn("JAX on Mac ARM machines is experimental and minimally tested. "
"Please see https://github.com/google/jax/issues/5501 in the "
"event of problems.")

try:
import jaxlib
import jaxlib as jaxlib
except ModuleNotFoundError as err:
raise ModuleNotFoundError(
'jax requires jaxlib to be installed. See '
Expand All @@ -43,7 +43,7 @@

from jax.version import _minimum_jaxlib_version as _minimum_jaxlib_version_str
try:
from jaxlib import version as jaxlib_version
import jaxlib.version as jaxlib_version
except Exception as err:
# jaxlib is too old to have version number.
msg = f'This version of jax requires jaxlib version >= {_minimum_jaxlib_version_str}.'
Expand All @@ -68,40 +68,40 @@ def _check_jaxlib_version():

_check_jaxlib_version()

from jaxlib import cpu_feature_guard
import jaxlib.cpu_feature_guard as cpu_feature_guard
cpu_feature_guard.check_cpu_features()

from jaxlib import xla_client
from jaxlib import lapack
from jaxlib import pocketfft
import jaxlib.xla_client as xla_client
import jaxlib.lapack as lapack
import jaxlib.pocketfft as pocketfft

xla_extension = xla_client._xla
pytree = xla_client._xla.pytree
jax_jit = xla_client._xla.jax_jit
pmap_lib = xla_client._xla.pmap_lib

try:
from jaxlib import cusolver # pytype: disable=import-error
import jaxlib.cusolver as cusolver # pytype: disable=import-error
except ImportError:
cusolver = None

try:
from jaxlib import cusparse # pytype: disable=import-error
import jaxlib.cusparse as cusparse # pytype: disable=import-error
except ImportError:
cusparse = None

try:
from jaxlib import rocsolver # pytype: disable=import-error
import jaxlib.rocsolver as rocsolver # pytype: disable=import-error
except ImportError:
rocsolver = None

try:
from jaxlib import cuda_prng # pytype: disable=import-error
import jaxlib.cuda_prng as cuda_prng # pytype: disable=import-error
except ImportError:
cuda_prng = None

try:
from jaxlib import cuda_linalg # pytype: disable=import-error
import jaxlib.cuda_linalg as cuda_linalg # pytype: disable=import-error
except ImportError:
cuda_linalg = None

Expand Down
39 changes: 39 additions & 0 deletions jax/_src/lib/__init__.pyi
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
# Copyright 2018 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# This .pyi file exists primarily to help pytype infer the types of reexported
# modules from jaxlib. Without an explicit type stub, many types become Any.
# (Google pytype bug b/192059119).

from typing import Any, Optional, Tuple

import jaxlib.lapack as lapack
import jaxlib.pocketfft as pocketfft
import jaxlib.xla_client as xla_client
import jaxlib.xla_extension as xla_extension
import jaxlib.xla_extension.jax_jit as jax_jit
import jaxlib.xla_extension.pmap_lib as pmap_lib
import jaxlib.xla_extension.pytree as pytree

version: Tuple[int, ...]

cuda_path: Optional[str]
cuda_linalg: Optional[Any]
cuda_prng: Optional[Any]
cusolver: Optional[Any]
cusparse: Optional[Any]
rocsolver: Optional[Any]
tpu_driver_client: Optional[Any]

_xla_extension_version: int
10 changes: 7 additions & 3 deletions jax/interpreters/pxla.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,9 +93,13 @@ class WeakRefList(list):
AvalDimSharding = Union[Unstacked, Chunked, NoSharding]
MeshDimAssignment = Union[ShardedAxis, Replicated]

# https://mypy.readthedocs.io/en/stable/runtime_troubles.html#typing-type-checking
# mypy will consider this constant to be True at type check time.
MYPY = False

# TODO(jblespiau): Remove the version check when jaxlib 0.1.70 is the minimal
# version.
if TYPE_CHECKING or _xla_extension_version < 30:
if MYPY or (not TYPE_CHECKING and _xla_extension_version < 30):
class ShardingSpec:
"""Describes the sharding of an ndarray.
Expand Down Expand Up @@ -473,8 +477,8 @@ def make_sharded_device_array(
"""
if sharding_spec is None:
sharded_aval = aval.update(shape=aval.shape[1:])
sharding_spec = _pmap_sharding_spec(aval.shape[0], aval.shape[0],
1, None, sharded_aval, 0)
sharding_spec = _pmap_sharding_spec(aval.shape[0], aval.shape[0], 1, None,
sharded_aval, 0)

if indices is None:
indices = spec_to_indices(aval.shape, sharding_spec)
Expand Down
6 changes: 4 additions & 2 deletions jax/interpreters/xla.py
Original file line number Diff line number Diff line change
Expand Up @@ -1325,9 +1325,11 @@ def _tuple_output(*args, **kwargs):
yield (ans,)

def lower_fun(fun: Callable, *, multiple_results: bool, parallel: bool = False,
backend=None, new_style: bool = False):
backend=None, new_style: bool = False) -> Callable:
if new_style:
def f_new(ctx, avals_in, avals_out, *xla_args, **params):
def f_new(ctx: TranslationContext, avals_in: Sequence[core.AbstractValue],
avals_out: Sequence[core.AbstractValue], *xla_args: xc.XlaOp,
**params) -> Sequence[xc.XlaOp]:
wrapped_fun = lu.wrap_init(fun, params)
if not multiple_results:
wrapped_fun = _tuple_output(wrapped_fun)
Expand Down

0 comments on commit 8f0bfcb

Please sign in to comment.