Skip to content

Commit

Permalink
Revert to using Union instead of | for union types to support pyt…
Browse files Browse the repository at this point in the history
…hon 3.9 (`|` only works since python 3.10). Update flax dependency.

PiperOrigin-RevId: 556877465
  • Loading branch information
romanngg committed Aug 14, 2023
1 parent e3e0f52 commit 152d0c5
Show file tree
Hide file tree
Showing 11 changed files with 98 additions and 98 deletions.
10 changes: 5 additions & 5 deletions neural_tangents/_src/batching.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@
"""


from typing import Callable, Any, TypeVar, Iterable, Optional
from typing import Callable, Any, TypeVar, Iterable, Optional, Union
from functools import partial
import warnings
import jax
Expand Down Expand Up @@ -462,7 +462,7 @@ def col_fn(n1, n2):
return flatten(k, cov2_is_none)

@utils.wraps(kernel_fn)
def serial_fn(x1_or_kernel: NTTree[np.ndarray] | NTTree[Kernel],
def serial_fn(x1_or_kernel: Union[NTTree[np.ndarray], NTTree[Kernel]],
x2: Optional[NTTree[Optional[np.ndarray]]] = None,
*args,
**kwargs) -> NTTree[Kernel]:
Expand Down Expand Up @@ -613,8 +613,8 @@ def parallel_fn(x1_or_kernel, x2=None, *args, **kwargs):
return parallel_fn_kernel(x1_or_kernel, *args, **kwargs)
raise NotImplementedError()

# Set function attributes so that `serial` can detect whether or not it is
# acting on a parallel function.
# Set function attributes so that `serial` can detect whether it is acting on
# a parallel function.
parallel_fn.device_count = device_count
return parallel_fn

Expand Down Expand Up @@ -708,7 +708,7 @@ def broadcast(arg: np.ndarray) -> np.ndarray:
return np.broadcast_to(arg, (device_count,) + arg.shape)

@utils.wraps(f)
def f_pmapped(x_or_kernel: np.ndarray | Kernel, *args, **kwargs):
def f_pmapped(x_or_kernel: Union[np.ndarray, Kernel], *args, **kwargs):
args_np, args_np_idxs = [], []
args_other = {}

Expand Down
50 changes: 25 additions & 25 deletions neural_tangents/_src/empirical.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@
import enum
import functools
import operator
from typing import Callable, KeysView, Optional, TypeVar, Iterable
from typing import Callable, KeysView, Optional, TypeVar, Iterable, Union
import warnings

import jax
Expand Down Expand Up @@ -683,8 +683,8 @@ def sum_and_contract(
fx1: np.ndarray,
fx2: np.ndarray,
fx_axis,
df_dys_1: list[np.ndarray | Zero],
df_dys_2: list[np.ndarray | Zero],
df_dys_1: list[Union[np.ndarray, Zero]],
df_dys_2: list[Union[np.ndarray, Zero]],
dy_dws_1: list[tuple[np.ndarray, rules.Structure]],
dy_dws_2: list[tuple[np.ndarray, rules.Structure]],
dtype: np.dtype
Expand Down Expand Up @@ -890,7 +890,7 @@ def empirical_ntk_fn(
trace_axes: Axes = (-1,),
diagonal_axes: Axes = (),
vmap_axes: VMapAxes = None,
implementation: NtkImplementation | int = DEFAULT_NTK_IMPLEMENTATION,
implementation: Union[NtkImplementation, int] = DEFAULT_NTK_IMPLEMENTATION,
_j_rules: bool = _DEFAULT_NTK_J_RULES,
_s_rules: bool = _DEFAULT_NTK_S_RULES,
_fwd: Optional[bool] = _DEFAULT_NTK_FWD,
Expand Down Expand Up @@ -1041,7 +1041,7 @@ def empirical_kernel_fn(
trace_axes: Axes = (-1,),
diagonal_axes: Axes = (),
vmap_axes: VMapAxes = None,
implementation: NtkImplementation | int = DEFAULT_NTK_IMPLEMENTATION,
implementation: Union[NtkImplementation, int] = DEFAULT_NTK_IMPLEMENTATION,
_j_rules: bool = _DEFAULT_NTK_J_RULES,
_s_rules: bool = _DEFAULT_NTK_S_RULES,
_fwd: Optional[bool] = _DEFAULT_NTK_FWD,
Expand Down Expand Up @@ -1190,7 +1190,7 @@ def empirical_kernel_fn(
def kernel_fn(
x1: PyTree,
x2: Optional[PyTree],
get: None | str | tuple[str, ...],
get: Union[None, str, tuple[str, ...]],
params: PyTree,
**apply_fn_kwargs
) -> PyTree:
Expand Down Expand Up @@ -1517,7 +1517,7 @@ def expand(x: np.ndarray) -> np.ndarray:


def _expand_dims(
x: Optional[PyTree] | UndefinedPrimal,
x: Union[None, PyTree, UndefinedPrimal],
axis: Optional[PyTree]
) -> Optional[PyTree]:
if axis is None or x is None or isinstance(x, UndefinedPrimal):
Expand Down Expand Up @@ -1545,7 +1545,7 @@ def _squeeze(x: PyTree, axis: Optional[PyTree]) -> PyTree:

def squeeze(
x: np.ndarray,
axis: None | int | tuple[int, ...]
axis: Union[None, int, tuple[int, ...]]
) -> np.ndarray:
"""`np.squeeze` analog working with 0-sized axes."""
if isinstance(axis, int):
Expand Down Expand Up @@ -1799,8 +1799,8 @@ def _backward_pass(
_j_rules: bool,
_s_rules: bool,
_fwd: Optional[bool]
) -> (list[list[np.ndarray | Zero]] |
list[list[tuple[np.ndarray, rules.Structure]]]):
) -> Union[list[list[Union[np.ndarray, Zero]]],
list[list[tuple[np.ndarray, rules.Structure]]]]:
"""Similar to and adapted from `jax.interpreters.ad.backward_pass`.
Traverses the computational graph in the same order as the above, but collects
Expand All @@ -1818,7 +1818,7 @@ def _backward_pass(
the NTK.
"""

def read_cotangent(v: Var) -> np.ndarray | Zero:
def read_cotangent(v: Var) -> Union[np.ndarray, Zero]:
return ct_env.pop(v, Zero(v.aval))

primal_env: dict[Var, np.ndarray] = {}
Expand Down Expand Up @@ -1988,9 +1988,9 @@ def _backprop_step(
eqn: JaxprEqn,
primal_env: dict[Var, np.ndarray],
ct_env: dict[Var, np.ndarray],
read_cotangent: Callable[[Var], np.ndarray | Zero],
read_cotangent: Callable[[Var], Union[np.ndarray, Zero]],
do_write_cotangents: bool = True
) -> tuple[np.ndarray | Zero, list[np.ndarray | UndefinedPrimal]]:
) -> tuple[Union[np.ndarray, Zero], list[Union[np.ndarray, UndefinedPrimal]]]:
"""Adapted from `jax.interpreters.ad`."""
invals = map(functools.partial(_read_primal, primal_env), eqn.invars)
cts_in = map(read_cotangent, eqn.outvars)
Expand Down Expand Up @@ -2024,9 +2024,9 @@ def _trim_cotangents(


def _trim_invals(
invals: list[np.ndarray | UndefinedPrimal],
invals: list[Union[np.ndarray, UndefinedPrimal]],
structure: rules.Structure,
) -> list[np.ndarray | UndefinedPrimal]:
) -> list[Union[np.ndarray, UndefinedPrimal]]:
trimmed_invals = list(invals)

for i in structure.in_trace_idxs:
Expand All @@ -2053,7 +2053,7 @@ def _trim_invals(
def _trim_eqn(
eqn: JaxprEqn,
idx: int,
trimmed_invals: list[np.ndarray | UndefinedPrimal],
trimmed_invals: list[Union[np.ndarray, UndefinedPrimal]],
trimmed_cts_in: ShapedArray
) -> JaxprEqn:
if eqn.primitive in rules.EQN_PARAMS_RULES:
Expand All @@ -2072,9 +2072,9 @@ def _trim_eqn(


def _trim_axis(
x: UndefinedPrimal | ShapedArray | np.ndarray,
axis: int | tuple[int, ...],
) -> UndefinedPrimal | ShapedArray:
x: Union[UndefinedPrimal, ShapedArray, np.ndarray],
axis: Union[int, tuple[int, ...]],
) -> Union[UndefinedPrimal, ShapedArray]:
"""Trim `axis` of `x` to be of length `1`. `x` is only used for shape."""
if isinstance(axis, int):
axis = (axis,)
Expand Down Expand Up @@ -2158,11 +2158,11 @@ def _eqn_vjp_fn(
def _get_jacobian(
eqn: Optional[JaxprEqn],
cts_in: ShapedArray,
invals: list[np.ndarray | UndefinedPrimal],
invals: list[Union[np.ndarray, UndefinedPrimal]],
idx: int,
_j_rules: bool,
_fwd: Optional[bool],
) -> np.ndarray | Zero:
) -> Union[np.ndarray, Zero]:
"""Get the (structured) `eqn` output Jacobian wrt `eqn.invars[idx]`."""
if eqn is None:
primitive = None
Expand Down Expand Up @@ -2214,7 +2214,7 @@ def _write_cotangent(
prim: core.Primitive,
ct_env: dict[Var, np.ndarray],
v: Var,
ct: np.ndarray | Zero
ct: Union[np.ndarray, Zero]
):
"""Adapted from `jax.interpreters.ad`."""
assert ct is not Zero, (prim, v.aval)
Expand All @@ -2235,8 +2235,8 @@ def _write_cotangent(

def _read_primal(
env: dict[Var, np.ndarray],
v: Var | Literal,
) -> np.ndarray | UndefinedPrimal:
v: Union[Var, Literal],
) -> Union[np.ndarray, UndefinedPrimal]:
if type(v) is Literal:
return v.val

Expand All @@ -2250,7 +2250,7 @@ def _read_primal(
def _write_primal(
env: dict[Var, np.ndarray],
v: Var,
val: np.ndarray | UndefinedPrimal
val: Union[np.ndarray, UndefinedPrimal]
):
if not ad.is_undefined_primal(val):
env[v] = val # pytype: disable=container-type-mismatch # jax-ndarray
Expand Down
12 changes: 6 additions & 6 deletions neural_tangents/_src/monte_carlo.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@

from functools import partial
import operator
from typing import Generator, Iterable, Optional
from typing import Generator, Iterable, Optional, Union

from .batching import batch
from .empirical import empirical_kernel_fn, NtkImplementation, DEFAULT_NTK_IMPLEMENTATION, _DEFAULT_NTK_FWD, _DEFAULT_NTK_S_RULES, _DEFAULT_NTK_J_RULES
Expand Down Expand Up @@ -95,7 +95,7 @@ def get_sampled_kernel(
x2: np.ndarray,
get: Optional[Get] = None,
**apply_fn_kwargs
) -> Generator[np.ndarray | tuple[np.ndarray, ...], None, None]:
) -> Generator[Union[np.ndarray, tuple[np.ndarray, ...]], None, None]:
for n, sample in get_samples(x1, x2, get, **apply_fn_kwargs):
if n in n_samples:
yield normalize(sample, n)
Expand All @@ -106,7 +106,7 @@ def get_sampled_kernel(
x2: np.ndarray,
get: Optional[Get] = None,
**apply_fn_kwargs
) -> np.ndarray | tuple[np.ndarray, ...]:
) -> Union[np.ndarray, tuple[np.ndarray, ...]]:
for n, sample in get_samples(x1, x2, get, **apply_fn_kwargs):
pass
return normalize(sample, n)
Expand All @@ -118,14 +118,14 @@ def monte_carlo_kernel_fn(
init_fn: InitFn,
apply_fn: ApplyFn,
key: random.KeyArray,
n_samples: int | Iterable[int],
n_samples: Union[int, Iterable[int]],
batch_size: int = 0,
device_count: int = -1,
store_on_device: bool = True,
trace_axes: Axes = (-1,),
diagonal_axes: Axes = (),
vmap_axes: Optional[VMapAxes] = None,
implementation: int | NtkImplementation = DEFAULT_NTK_IMPLEMENTATION,
implementation: Union[int, NtkImplementation] = DEFAULT_NTK_IMPLEMENTATION,
_j_rules: bool = _DEFAULT_NTK_J_RULES,
_s_rules: bool = _DEFAULT_NTK_S_RULES,
_fwd: Optional[bool] = _DEFAULT_NTK_FWD,
Expand Down Expand Up @@ -325,7 +325,7 @@ def monte_carlo_kernel_fn(


def _canonicalize_n_samples(
n_samples: int | Iterable[int]) -> tuple[set[int], bool]:
n_samples: Union[int, Iterable[int]]) -> tuple[set[int], bool]:
get_generator = True
if isinstance(n_samples, int):
get_generator = False
Expand Down
22 changes: 11 additions & 11 deletions neural_tangents/_src/predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@

import collections
from functools import lru_cache
from typing import Callable, Generator, Iterable, NamedTuple, Optional, Any
from typing import Callable, Generator, Iterable, NamedTuple, Optional, Any, Union

import jax
from jax import grad
Expand All @@ -48,7 +48,7 @@
PyTree = Any


ArrayOrScalar = None | int | float | np.ndarray
ArrayOrScalar = Union[None, int, float, np.ndarray]
"""Alias for optional arrays or scalars."""


Expand All @@ -61,7 +61,7 @@ def __call__(
fx_train_0: ArrayOrScalar = 0.,
fx_test_0: Optional[ArrayOrScalar] = None,
k_test_train: Optional[np.ndarray] = None
) -> np.ndarray | tuple[np.ndarray, np.ndarray]:
) -> Union[np.ndarray, tuple[np.ndarray, np.ndarray]]:
...


Expand Down Expand Up @@ -230,7 +230,7 @@ def predict_fn(
fx_train_0: ArrayOrScalar = 0.,
fx_test_0: Optional[ArrayOrScalar] = None,
k_test_train: Optional[np.ndarray] = None
) -> np.ndarray | tuple[np.ndarray, np.ndarray]:
) -> Union[np.ndarray, tuple[np.ndarray, np.ndarray]]:
"""Return output predictions on train [and test] set[s] at time[s] `t`.
Args:
Expand Down Expand Up @@ -304,10 +304,10 @@ class PredictFnODE(Protocol):
def __call__(
self,
t: Optional[ArrayOrScalar] = None,
fx_train_or_state_0: ArrayOrScalar | ODEState = 0.,
fx_train_or_state_0: Union[ArrayOrScalar, ODEState] = 0.,
fx_test_0: Optional[ArrayOrScalar] = None,
k_test_train: Optional[np.ndarray] = None
) -> np.ndarray | tuple[np.ndarray, np.ndarray] | ODEState:
) -> Union[np.ndarray, tuple[np.ndarray, np.ndarray], ODEState]:
...


Expand Down Expand Up @@ -465,10 +465,10 @@ def dstate_dt(state_t: ODEState, unused_t) -> ODEState:

def predict_fn(
t: Optional[ArrayOrScalar] = None,
fx_train_or_state_0: ArrayOrScalar | ODEState = 0.,
fx_train_or_state_0: Union[ArrayOrScalar, ODEState] = 0.,
fx_test_0: Optional[ArrayOrScalar] = None,
k_test_train: Optional[np.ndarray] = None
) -> np.ndarray | tuple[np.ndarray, np.ndarray] | ODEState:
) -> Union[np.ndarray, tuple[np.ndarray, np.ndarray], ODEState]:
"""Return output predictions on train [and test] set[s] at time[s] `t`.
Args:
Expand Down Expand Up @@ -512,7 +512,7 @@ def predict_fn(
t_shape = t.shape
t = t.reshape((-1,))

# ODE solver requires `t[0]` to be the time where `fx_train_0` [and
# ODE solver requires `t[0]` to be the time when `fx_train_0` [and
# `fx_test_0`] are evaluated, but also a strictly increasing sequence of
# timesteps, so we always temporarily append an [almost] `0` at the start.
t0 = np.where(t[0] == 0,
Expand Down Expand Up @@ -635,7 +635,7 @@ def k_inv_y(g: str):
def predict_fn(get: Optional[Get] = None,
k_test_train=None,
k_test_test=None
) -> dict[str, np.ndarray | Gaussian]:
) -> dict[str, Union[np.ndarray, Gaussian]]:
"""`test`-set posterior given respective covariance matrices.
Args:
Expand Down Expand Up @@ -1281,7 +1281,7 @@ def _inv_expm1_fn(evals: np.ndarray, t: np.ndarray):
return _inv_expm1_fn


def _check_inputs(fx_train_or_state_0: ArrayOrScalar | ODEState,
def _check_inputs(fx_train_or_state_0: Union[ArrayOrScalar, ODEState],
fx_test_0: ArrayOrScalar,
k_test_train: Optional[np.ndarray]):
if isinstance(fx_train_or_state_0, ODEState):
Expand Down
8 changes: 4 additions & 4 deletions neural_tangents/_src/stax/linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
import functools
import operator as op
import string
from typing import Callable, Iterable, Optional, Sequence
from typing import Callable, Iterable, Optional, Sequence, Union
import warnings

import jax
Expand Down Expand Up @@ -111,8 +111,8 @@ def Identity() -> InternalLayer:
@supports_masking(remask_kernel=False)
def DotGeneral(
*,
lhs: Optional[np.ndarray | float] = None,
rhs: Optional[np.ndarray | float] = None,
lhs: Optional[Union[np.ndarray, float]] = None,
rhs: Optional[Union[np.ndarray, float]] = None,
dimension_numbers: lax.DotDimensionNumbers = (((), ()), ((), ())),
precision: Optional[lax.Precision] = None,
batch_axis: int = 0,
Expand Down Expand Up @@ -2592,7 +2592,7 @@ def kernel_fn_train(k: Kernel, **kwargs):
@supports_masking(remask_kernel=True)
def ImageResize(
shape: Sequence[int],
method: str | jax.image.ResizeMethod,
method: Union[str, jax.image.ResizeMethod],
antialias: bool = True,
precision: lax.Precision = lax.Precision.HIGHEST,
batch_axis: int = 0,
Expand Down

0 comments on commit 152d0c5

Please sign in to comment.