Use python >= 3.9 generic type annotations. Add support for 3.11, drop 3.8 following JAX / tf2jax. Upgrade dependencies and github actions.

PiperOrigin-RevId: 555585732
romanngg committed Aug 11, 2023
1 parent 5afc14a commit ae36d40
Showing 22 changed files with 265 additions and 270 deletions.
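Note on the pattern applied throughout the diff below: `typing.Tuple`, `Dict` and `Union` become the builtin `tuple`, `dict` (PEP 585) and the `|` operator (PEP 604). What follows is a minimal, hypothetical sketch of the migration, not taken from this repository; `X | Y` in annotations only evaluates at runtime on Python 3.10+, so keeping 3.9 support relies on deferred annotation evaluation via `from __future__ import annotations` (PEP 563).

# Hypothetical example of the annotation migration; names are made up.
from __future__ import annotations  # defers annotation evaluation (PEP 563)

import numpy as np

# Before: from typing import Dict, Tuple, Union
#   def halve(x: Union[np.ndarray, Dict[str, np.ndarray]]
#             ) -> Tuple[np.ndarray, np.ndarray]: ...

# After: builtin generics and `|` unions.
def halve(x: np.ndarray | dict[str, np.ndarray]
          ) -> tuple[np.ndarray, np.ndarray]:
  """Splits `x` (or its first dict entry) in two along axis 0."""
  if isinstance(x, dict):
    x = next(iter(x.values()))
  n = x.shape[0] // 2
  return x[:n], x[n:]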
3 changes: 1 addition & 2 deletions examples/imdb.py
@@ -12,7 +12,6 @@
 """

 import time
-from typing import Tuple

 from absl import app
 from jax import random
@@ -100,7 +99,7 @@ def main(*args, use_dummy_data: bool = False, **kwargs) -> None:


 def _get_dummy_data(mask_constant: float
-                    ) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
+                    ) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
   """Return dummy data for when downloading embeddings is not feasible."""
   n_train, n_test = 6, 6

18 changes: 9 additions & 9 deletions neural_tangents/_src/batching.py
@@ -42,7 +42,7 @@
 """


-from typing import Callable, Tuple, Union, Dict, Any, TypeVar, Iterable, Optional
+from typing import Callable, Any, TypeVar, Iterable, Optional
 from functools import partial
 import warnings
 import jax
@@ -131,9 +131,9 @@ def batch(kernel_fn: _KernelFn,
 _Output = TypeVar('_Output')


-def _scan(f: Callable[[_Carry, _Input], Tuple[_Carry, _Output]],
+def _scan(f: Callable[[_Carry, _Input], tuple[_Carry, _Output]],
           init: _Carry,
-          xs: Iterable[_Input]) -> Tuple[_Carry, _Output]:
+          xs: Iterable[_Input]) -> tuple[_Carry, _Output]:
   """Implements an unrolled version of scan.

   Based on :obj:`jax.lax.scan` and has a similar API.
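For context on `_scan`, here is a minimal standalone sketch of an unrolled scan with a `jax.lax.scan`-like API (illustrative only; the actual `_scan` in this file also handles pytrees and stacks per-step outputs):

from __future__ import annotations

from typing import Callable, Iterable, TypeVar

_Carry = TypeVar('_Carry')
_Input = TypeVar('_Input')
_Output = TypeVar('_Output')


def unrolled_scan(f: Callable[[_Carry, _Input], tuple[_Carry, _Output]],
                  init: _Carry,
                  xs: Iterable[_Input]) -> tuple[_Carry, list[_Output]]:
  # Plain Python loop instead of `jax.lax.scan`: thread the carry through
  # every step and collect the per-step outputs in a list.
  carry, ys = init, []
  for x in xs:
    carry, y = f(carry, x)
    ys.append(y)
  return carry, ys


# Usage: running sum over 0..4 -> final carry 10, outputs [0, 1, 3, 6, 10].
total, partials = unrolled_scan(lambda c, x: (c + x, c + x), 0, range(5))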
@@ -179,9 +179,9 @@ def _flatten_batch_dimensions(k: np.ndarray,


 @utils.nt_tree_fn(nargs=1)
-def _flatten_kernel_dict(k: Dict[str, Any],
+def _flatten_kernel_dict(k: dict[str, Any],
                          x2_is_none: bool,
-                         is_parallel: bool) -> Dict[str, Any]:
+                         is_parallel: bool) -> dict[str, Any]:
   if 'nngp' in k:
     # We only use `batch_size` to compute `shape1` and `shape2` for the batch.
     # This only happens if k_dict came from a `Kernel` in which case it must
@@ -399,7 +399,7 @@ def col_fn(x1, x2):

   def serial_fn_kernel(k: NTTree[Kernel], *args, **kwargs) -> NTTree[Kernel]:

-    def get_n1_n2(k: NTTree[Kernel]) -> Tuple[int, ...]:
+    def get_n1_n2(k: NTTree[Kernel]) -> tuple[int, ...]:
       if utils.is_list_or_tuple(k):
         # TODO(schsam): We might want to check for consistency here, but I can't
         # imagine a case where we could get inconsistent kernels.
@@ -462,7 +462,7 @@ def col_fn(n1, n2):
     return flatten(k, cov2_is_none)

   @utils.wraps(kernel_fn)
-  def serial_fn(x1_or_kernel: Union[NTTree[np.ndarray], NTTree[Kernel]],
+  def serial_fn(x1_or_kernel: NTTree[np.ndarray] | NTTree[Kernel],
                 x2: Optional[NTTree[Optional[np.ndarray]]] = None,
                 *args,
                 **kwargs) -> NTTree[Kernel]:
@@ -623,7 +623,7 @@ def _get_n_batches_and_batch_sizes(n1: int,
                                    n2: int,
                                    batch_size: int,
                                    device_count: int
-                                   ) -> Tuple[int, int, int, int]:
+                                   ) -> tuple[int, int, int, int]:
   # TODO(romann): if dropout batching works for different batch sizes, relax.
   max_serial_batch_size = onp.gcd(n1, n2) // device_count

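The `onp.gcd(n1, n2)` above is what lets the serial batch size tile both dimensions of the (n1, n2) kernel matrix evenly; below is a small numeric sketch of the constraint (illustrative values only, the rest of the function body is not shown in this hunk):

import numpy as onp

# Any serial batch size must divide both n1 and n2 so the kernel matrix
# splits into whole batches; the largest such size divides gcd(n1, n2),
# shared evenly across devices.
n1, n2, device_count = 120, 80, 2
max_serial_batch_size = onp.gcd(n1, n2) // device_count  # gcd = 40 -> 20
assert n1 % (max_serial_batch_size * device_count) == 0  # 120 % 40 == 0
assert n2 % (max_serial_batch_size * device_count) == 0  #  80 % 40 == 0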
@@ -708,7 +708,7 @@ def broadcast(arg: np.ndarray) -> np.ndarray:
     return np.broadcast_to(arg, (device_count,) + arg.shape)

   @utils.wraps(f)
-  def f_pmapped(x_or_kernel: Union[np.ndarray, Kernel], *args, **kwargs):
+  def f_pmapped(x_or_kernel: np.ndarray | Kernel, *args, **kwargs):
     args_np, args_np_idxs = [], []
     args_other = {}

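The `broadcast` helper in this last hunk replicates an argument over a new leading device axis so `jax.pmap` can split it across devices; here is a self-contained sketch of the same idea (hypothetical usage, not this file's surrounding code):

import jax
from jax import numpy as np  # this file aliases jax.numpy as `np`

device_count = jax.device_count()


def broadcast(arg: np.ndarray) -> np.ndarray:
  # Prepend a device axis by replicating `arg` device_count times.
  return np.broadcast_to(arg, (device_count,) + arg.shape)


x = np.ones((3, 4))
out = jax.pmap(lambda a: a * 2.)(broadcast(x))  # shape (device_count, 3, 4)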
