Skip to content

Commit

Permalink
Make typing_extensions a dev-dependency
Browse files Browse the repository at this point in the history
  • Loading branch information
NeilGirdhar committed Mar 13, 2024
1 parent caaac77 commit 621d9c1
Show file tree
Hide file tree
Showing 3 changed files with 15 additions and 7 deletions.
9 changes: 6 additions & 3 deletions jax/_src/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,11 +28,13 @@
from functools import partial
import inspect
import math
from typing import Any, Callable, Literal, NamedTuple, TypeVar, cast, overload
from typing import (Any, Callable, Literal, NamedTuple, TypeVar, cast,
overload, TYPE_CHECKING)
import weakref

import numpy as np
from typing_extensions import ParamSpec
if TYPE_CHECKING:
from typing_extensions import ParamSpec

from jax._src import linear_util as lu
from jax._src import stages
Expand Down Expand Up @@ -96,7 +98,8 @@
T = TypeVar("T")
U = TypeVar("U")
V_co = TypeVar("V_co", covariant=True)
P = ParamSpec("P")
if TYPE_CHECKING:
P = ParamSpec("P")


map, unsafe_map = safe_map, map
Expand Down
11 changes: 8 additions & 3 deletions jax/_src/stages.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,10 +32,12 @@

from collections.abc import Sequence
from dataclasses import dataclass
from typing import Any, Generic, NamedTuple, Protocol, TypeVar, Union
from typing import (Any, Generic, NamedTuple, Protocol, TypeVar, Union,
TYPE_CHECKING)
import warnings

from typing_extensions import ParamSpec
if TYPE_CHECKING:
from typing_extensions import ParamSpec

import jax

Expand Down Expand Up @@ -710,7 +712,10 @@ def cost_analysis(self) -> Any | None:


V_co = TypeVar("V_co", covariant=True)
P = ParamSpec("P")
if TYPE_CHECKING:
P = ParamSpec("P")
else:
P = TypeVar("P")


class Wrapped(Protocol, Generic[P, V_co]):
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,9 +84,9 @@ def generate_proto(source):
# Python versions < 3.10. Can be dropped when 3.10 is the minimum
# required Python version.
'importlib_metadata>=4.6;python_version<"3.10"',
'typing_extensions>=4.5.0',
],
extras_require={
'dev': ['typing_extensions>=4.8.0'],
# Minimum jaxlib version; used in testing.
'minimum-jaxlib': [f'jaxlib=={_minimum_jaxlib_version}'],

Expand Down

0 comments on commit 621d9c1

Please sign in to comment.