Skip to content

Commit

Permalink
Make typing_extensions a dev-dependency
Browse files Browse the repository at this point in the history
  • Loading branch information
NeilGirdhar committed Feb 26, 2023
1 parent 353bba1 commit c877b16
Show file tree
Hide file tree
Showing 3 changed files with 11 additions and 7 deletions.
8 changes: 5 additions & 3 deletions jax/_src/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,10 +28,11 @@
import inspect
from typing import (Any, Callable, Generator, Hashable, Iterable, List, Literal,
NamedTuple, Optional, Sequence, Tuple, TypeVar, Union,
overload)
overload, TYPE_CHECKING)

import numpy as np
from typing_extensions import ParamSpec
if TYPE_CHECKING:
from typing_extensions import ParamSpec

import jax
from jax._src import linear_util as lu
Expand Down Expand Up @@ -108,7 +109,8 @@
T = TypeVar("T")
U = TypeVar("U")
V_co = TypeVar("V_co", covariant=True)
P = ParamSpec("P")
if TYPE_CHECKING:
P = ParamSpec("P")


map, unsafe_map = safe_map, map
Expand Down
8 changes: 5 additions & 3 deletions jax/_src/stages.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,9 @@

from dataclasses import dataclass
from typing import (Any, Dict, Generic, List, NamedTuple, Optional, Protocol,
Sequence, Tuple, TypeVar)
from typing_extensions import ParamSpec
Sequence, Tuple, TypeVar, TYPE_CHECKING)
if TYPE_CHECKING:
from typing_extensions import ParamSpec

import jax
from jax import tree_util
Expand Down Expand Up @@ -620,7 +621,8 @@ def compiler_ir(self, dialect: Optional[str] = None) -> Optional[Any]:


V_co = TypeVar("V_co", covariant=True)
P = ParamSpec("P")
if TYPE_CHECKING:
P = ParamSpec("P")


class Wrapped(Protocol, Generic[P, V_co]):
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,9 +67,9 @@ def generate_proto(source):
'numpy>=1.20',
'opt_einsum',
'scipy>=1.5',
'typing_extensions>=4.5.0',
],
extras_require={
'dev': ['typing_extensions>=4.5.0'],
# Minimum jaxlib version; used in testing.
'minimum-jaxlib': [f'jaxlib=={_minimum_jaxlib_version}'],

Expand Down

0 comments on commit c877b16

Please sign in to comment.