Skip to content

Commit

Permalink
fix(jax-frontend): Adds use of _get_seed where needed
Browse files Browse the repository at this point in the history
  • Loading branch information
AnnaTz committed Oct 2, 2023
1 parent 33b2d96 commit a1e8cc9
Showing 1 changed file with 6 additions and 4 deletions.
10 changes: 6 additions & 4 deletions ivy/functional/frontends/jax/random.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,6 @@ def beta(key, a, b, shape=None, dtype=None):
"jax",
)
def categorical(key, logits, axis, shape=None):
_get_seed(key)
logits_arr = ivy.asarray(logits)

if axis >= 0:
Expand Down Expand Up @@ -330,18 +329,20 @@ def multivariate_normal(key, mean, cov, shape=None, dtype="float64", method="cho
@handle_jax_dtype
@to_ivy_arrays_and_back
def normal(key, shape=(), dtype=None):
return ivy.random_normal(shape=shape, dtype=dtype, seed=ivy.to_scalar(key[1]))
seed = _get_seed(key)
return ivy.random_normal(shape=shape, dtype=dtype, seed=seed)


@handle_jax_dtype
@to_ivy_arrays_and_back
def orthogonal(key, n, shape=(), dtype=None):
seed = _get_seed(key)
flat_shape = (n, n)
if shape:
flat_shape = shape + flat_shape

# Generate a random matrix with the given shape and dtype
random_matrix = ivy.random_uniform(key, shape=flat_shape, dtype=dtype)
random_matrix = ivy.random_uniform(seed=seed, shape=flat_shape, dtype=dtype)

# Compute the QR decomposition of the random matrix
q, _ = ivy.linalg.qr(random_matrix)
Expand Down Expand Up @@ -445,8 +446,9 @@ def t(key, df, shape=(), dtype="float64"):
@handle_jax_dtype
@to_ivy_arrays_and_back
def uniform(key, shape=(), dtype=None, minval=0.0, maxval=1.0):
seed = _get_seed(key)
return ivy.random_uniform(
low=minval, high=maxval, shape=shape, dtype=dtype, seed=ivy.to_scalar(key[1])
low=minval, high=maxval, shape=shape, dtype=dtype, seed=seed
)


Expand Down

0 comments on commit a1e8cc9

Please sign in to comment.