Skip to content

Commit

Permalink
Migrate jax.ShapedArray -> jax.core.ShapedArray
Browse files Browse the repository at this point in the history
jax.ShapedArray is deprecated as of google/jax#15263: most users will never need to use ShapedArray directly, and so having it exposed in the top-level public namespace causes undue confusion.

PiperOrigin-RevId: 521880270
  • Loading branch information
Jake VanderPlas authored and romanngg committed Apr 19, 2023
1 parent 31bc793 commit 396ad1d
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 2 deletions.
3 changes: 2 additions & 1 deletion neural_tangents/_src/stax/linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,8 @@
from jax import numpy as np
from jax import ops
from jax import random
from jax import ShapeDtypeStruct, ShapedArray, eval_shape, vmap
from jax import ShapeDtypeStruct, eval_shape, vmap
from jax.core import ShapedArray
import jax.example_libraries.stax as ostax
import numpy as onp
from .requirements import Bool, Diagonal, get_diagonal_outer_prods, layer, mean_and_var, requires, supports_masking
Expand Down
3 changes: 2 additions & 1 deletion neural_tangents/_src/stax/requirements.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,8 @@
import jax
from jax import lax
from jax import numpy as np
from jax import eval_shape, ShapedArray
from jax import eval_shape
from jax.core import ShapedArray
from jax.tree_util import tree_map, tree_all
from ..utils import utils
import dataclasses
Expand Down

0 comments on commit 396ad1d

Please sign in to comment.