Skip to content

Commit

Permalink
Drop dependence on typing_extensions since typing Protocols are avail…
Browse files Browse the repository at this point in the history
…able in python >= 3.8

PiperOrigin-RevId: 568377349
  • Loading branch information
romanngg committed Sep 26, 2023
1 parent d839deb commit cca7385
Show file tree
Hide file tree
Showing 3 changed files with 2 additions and 5 deletions.
3 changes: 1 addition & 2 deletions neural_tangents/_src/predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@

import collections
from functools import lru_cache
from typing import Callable, Generator, Iterable, NamedTuple, Optional, Any, Union
from typing import Callable, Generator, Iterable, NamedTuple, Optional, Any, Union, Protocol

import jax
from jax import grad
Expand All @@ -40,7 +40,6 @@
from jax.tree_util import tree_all, tree_map
import numpy as np
import scipy as sp
from typing_extensions import Protocol
from .utils import dataclasses, utils
from .utils.typing import Axes, Get, KernelFn

Expand Down
3 changes: 1 addition & 2 deletions neural_tangents/_src/utils/typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,7 @@

"""Common Type Definitions."""

from typing import Any, Generator, Optional, Sequence, TYPE_CHECKING, TypeVar, Union
from typing_extensions import Protocol
from typing import Any, Generator, Optional, Sequence, TYPE_CHECKING, TypeVar, Union, Protocol

from jax import random
import jax.numpy as jnp
Expand Down
1 change: 0 additions & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@
INSTALL_REQUIRES = [
'jax>=0.4.14',
'frozendict>=2.3.8',
'typing_extensions>=4.7.1',
'tf2jax>=0.3.5',
]

Expand Down

0 comments on commit cca7385

Please sign in to comment.