Skip to content

Commit

Permalink
[typing] regularize types of jax.random API
Browse files Browse the repository at this point in the history
  • Loading branch information
jakevdp committed Oct 20, 2023
1 parent 347b8a5 commit 8f82f2e
Show file tree
Hide file tree
Showing 3 changed files with 84 additions and 88 deletions.
4 changes: 2 additions & 2 deletions jax/_src/prng.py
Expand Up @@ -423,7 +423,7 @@ def prngkeyarrayimpl_unflatten(impl, children):


# TODO(frostig): remove, rerouting callers directly to random_seed
def seed_with_impl(impl: PRNGImpl, seed: int | Array) -> PRNGKeyArrayImpl:
def seed_with_impl(impl: PRNGImpl, seed: int | typing.ArrayLike) -> PRNGKeyArrayImpl:
return random_seed(seed, impl=impl)


Expand Down Expand Up @@ -694,7 +694,7 @@ def iterated_vmap_binary_bcast(shape1, shape2, f):
return f


def random_seed(seeds, impl):
def random_seed(seeds: int | typing.ArrayLike, impl: PRNGImpl) -> PRNGKeyArrayImpl:
# Avoid overflow error in X32 mode by first converting ints to int64.
# This breaks JIT invariance for large ints, but supports the common
# use-case of instantiating with Python hashes in X32 mode.
Expand Down

0 comments on commit 8f82f2e

Please sign in to comment.