Skip to content

Commit

Permalink
alias chex.PRNGKey to jax.Array
Browse files Browse the repository at this point in the history
Starting with jax v0.4.16 and going forward, `jax.Array` is the correct type annotation for both new-style and old-style PRNG keys in JAX (see [JEP 9263](google/jax#17297) for details)

Note that `jax.random.KeyArray` has been aliased to `Any` under TYPE_CHECKING, and so this change will make existing annotations far more strict than they were previously.

PiperOrigin-RevId: 566933252
  • Loading branch information
Jake VanderPlas authored and ChexDev committed Sep 20, 2023
1 parent 533aeeb commit 902031d
Show file tree
Hide file tree
Showing 2 changed files with 2 additions and 2 deletions.
2 changes: 1 addition & 1 deletion chex/_src/pytypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@
Scalar = Union[float, int]
Numeric = Union[Array, Scalar]
Shape = jax.core.Shape
PRNGKey = Union[jax.random.KeyArray, jax.Array]
PRNGKey = jax.Array
PyTreeDef = jax.tree_util.PyTreeDef
Device = jax.Device
ArrayDType = type(jnp.float32)
2 changes: 1 addition & 1 deletion requirements/requirements.txt
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
absl-py>=0.9.0
typing_extensions>=4.2.0
jax>=0.4.6
jax>=0.4.16
jaxlib>=0.1.37
numpy>=1.24.1
toolz>=0.9.0

0 comments on commit 902031d

Please sign in to comment.