Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[typing] Use ParamSpec in JIT annotation #14688

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 15 additions & 8 deletions jax/_src/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,16 +24,17 @@

import collections
from collections.abc import Generator, Hashable, Iterable, Sequence
from contextlib import contextmanager, ExitStack
from functools import partial
import inspect
import math
import typing
from typing import (Any, Callable, Literal, NamedTuple, TypeVar, overload,
cast)
from typing import (Any, Callable, Literal, NamedTuple, TypeVar, cast,
overload, TYPE_CHECKING)
import weakref

import numpy as np
from contextlib import contextmanager, ExitStack
if TYPE_CHECKING:
from typing_extensions import ParamSpec

from jax._src import linear_util as lu
from jax._src import stages
Expand Down Expand Up @@ -96,6 +97,12 @@
F = TypeVar("F", bound=Callable)
T = TypeVar("T")
U = TypeVar("U")
V_co = TypeVar("V_co", covariant=True)
if TYPE_CHECKING:
P = ParamSpec("P")
else:
P = TypeVar("P")


map, unsafe_map = safe_map, map
zip, unsafe_zip = safe_zip, zip
Expand Down Expand Up @@ -140,7 +147,7 @@ def _update_debug_special_thread_local(_):


def jit(
fun: Callable,
fun: Callable[P, V_co],
in_shardings=sharding_impls.UNSPECIFIED,
out_shardings=sharding_impls.UNSPECIFIED,
static_argnums: int | Sequence[int] | None = None,
Expand All @@ -152,7 +159,7 @@ def jit(
backend: str | None = None,
inline: bool = False,
abstracted_axes: Any | None = None,
) -> pjit.JitWrapped:
) -> pjit.JitWrapped[P, V_co]:
"""Sets up ``fun`` for just-in-time compilation with XLA.

Args:
Expand Down Expand Up @@ -1796,7 +1803,7 @@ def cache_miss(*args, **kwargs):
### Decide whether we can support the C++ fast path
use_fastpath = False
if execute is not None and isinstance(execute, pxla.ExecuteReplicated):
execute_replicated = typing.cast(pxla.ExecuteReplicated, execute)
execute_replicated = cast(pxla.ExecuteReplicated, execute)
use_fastpath = (
# TODO(sharadmv): Enable effects in replicated computation
not execute_replicated.has_unordered_effects
Expand All @@ -1806,7 +1813,7 @@ def cache_miss(*args, **kwargs):

### If we can use the fastpath, we return required info to the caller.
if use_fastpath:
execute_replicated = typing.cast(pxla.ExecuteReplicated, execute)
execute_replicated = cast(pxla.ExecuteReplicated, execute)
out_handler = execute_replicated.out_handler
in_handler = execute_replicated.in_handler

Expand Down
15 changes: 13 additions & 2 deletions jax/_src/pjit.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,14 @@
import logging
import operator as op
import weakref
from typing import Callable, cast, NamedTuple, Any, Union, Optional
from typing import (cast, Callable, NamedTuple, Any, Generic, Optional,
TypeVar, TYPE_CHECKING, Union)
import threading
import warnings

import numpy as np
if TYPE_CHECKING:
from typing_extensions import ParamSpec

from jax._src import api
from jax._src import api_util
Expand Down Expand Up @@ -653,7 +656,15 @@ def ax_leaf(l):
return broadcast_prefix(abstracted_axes, args, ax_leaf)


class JitWrapped(stages.Wrapped):

V_co = TypeVar("V_co", covariant=True)
if TYPE_CHECKING:
P = ParamSpec("P")
else:
P = TypeVar("P")


class JitWrapped(stages.Wrapped[P, V_co], Generic[P, V_co]):

def eval_shape(self, *args, **kwargs):
"""See ``jax.eval_shape``."""
Expand Down
17 changes: 14 additions & 3 deletions jax/_src/stages.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,9 +32,13 @@

from collections.abc import Sequence
from dataclasses import dataclass
from typing import Any, NamedTuple, Protocol, Union
from typing import (Any, Generic, NamedTuple, Protocol, TypeVar, Union,
TYPE_CHECKING)
import warnings

if TYPE_CHECKING:
from typing_extensions import ParamSpec

import jax

from jax._src import core
Expand Down Expand Up @@ -707,7 +711,14 @@ def cost_analysis(self) -> Any | None:
return None


class Wrapped(Protocol):
V_co = TypeVar("V_co", covariant=True)
if TYPE_CHECKING:
P = ParamSpec("P")
else:
P = TypeVar("P")


class Wrapped(Protocol, Generic[P, V_co]):
"""A function ready to be specialized, lowered, and compiled.

This protocol reflects the output of functions such as
Expand All @@ -716,7 +727,7 @@ class Wrapped(Protocol):
to compilation, and the result compiled prior to execution.
"""

def __call__(self, *args, **kwargs):
def __call__(self, *args: P.args, **kwargs: P.kwargs) -> V_co:
"""Executes the wrapped function, lowering and compiling as needed."""
raise NotImplementedError

Expand Down
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,7 @@ def generate_proto(source):
'importlib_metadata>=4.6;python_version<"3.10"',
],
extras_require={
'dev': ['typing_extensions>=4.8.0'],
# Minimum jaxlib version; used in testing.
'minimum-jaxlib': [f'jaxlib=={_minimum_jaxlib_version}'],

Expand Down