Skip to content

Commit

Permalink
chex: alias PRNGKey to jax.Array
Browse files Browse the repository at this point in the history
Going forward, `jax.Array` is the correct type annotation for both new-style and old-style PRNG keys in JAX (see google/jax#17297)

Note that `jax.random.KeyArray` has been aliased to `Any` under TYPE_CHECKING, and so this change will make existing annotations more strict, not less strict.

PiperOrigin-RevId: 565133147
  • Loading branch information
Jake VanderPlas authored and DistraxDev committed Sep 13, 2023
1 parent 079849a commit 3a832d3
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion distrax/_src/distributions/distribution.py
Original file line number Diff line number Diff line change
Expand Up @@ -324,7 +324,7 @@ def convert_seed_and_sample_shape(
else: # key is of type PRNGKey
rng = seed

return rng, sample_shape
return rng, sample_shape # type: ignore[bad-return-type]


def to_batch_shape_index(
Expand Down

0 comments on commit 3a832d3

Please sign in to comment.