Skip to content

Commit

Permalink
Remove references to deprecated jax.ShapedArray
Browse files Browse the repository at this point in the history
This is deprecated as of google/jax#15263: most users will never need to use ShapedArray directly, and so having it exposed in the top-level public namespace causes undue confusion.

PiperOrigin-RevId: 520659571
  • Loading branch information
Jake VanderPlas authored and romanngg committed Apr 19, 2023
1 parent 59ccf11 commit 31bc793
Showing 1 changed file with 2 additions and 1 deletion.
3 changes: 2 additions & 1 deletion neural_tangents/_src/utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
import warnings

import jax
from jax import core
from jax import random
import jax.numpy as np
from jax.tree_util import tree_all, tree_map
Expand Down Expand Up @@ -470,7 +471,7 @@ def mask(


def size_at(
x: Union[_ArrayOrShape, jax.ShapedArray],
x: Union[_ArrayOrShape, core.ShapedArray],
axes: Optional[Iterable[int]] = None
) -> int:
if hasattr(x, 'shape'):
Expand Down

0 comments on commit 31bc793

Please sign in to comment.