Skip to content

Commit

Permalink
Replaces references to jax.numpy.DeviceArray with jax.Array.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 515924016
  • Loading branch information
hawkinsp authored and OptaxDev committed Mar 11, 2023
1 parent 451b006 commit 166e12f
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 9 deletions.
2 changes: 1 addition & 1 deletion optax/_src/schedule.py
Original file line number Diff line number Diff line change
Expand Up @@ -393,7 +393,7 @@ def join_schedules(schedules: Sequence[base.Schedule],
Returns:
schedule: A function that maps step counts to values.
"""
def schedule(step: jnp.DeviceArray) -> jnp.DeviceArray:
def schedule(step: jax.Array) -> jax.Array:
output = schedules[0](step)
for boundary, schedule in zip(boundaries, schedules[1:]):
output = jnp.where(step < boundary, output, schedule(step - boundary))
Expand Down
16 changes: 8 additions & 8 deletions optax/_src/second_order.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,11 +43,11 @@ def ravel(p: Any) -> Array:

def hvp(
loss: LossFun,
v: jnp.DeviceArray,
v: jax.Array,
params: Any,
inputs: jnp.DeviceArray,
targets: jnp.DeviceArray,
) -> jnp.DeviceArray:
inputs: jax.Array,
targets: jax.Array,
) -> jax.Array:
"""Performs an efficient vector-Hessian (of `loss`) product.
Args:
Expand All @@ -69,9 +69,9 @@ def hvp(
def hessian_diag(
loss: LossFun,
params: Any,
inputs: jnp.DeviceArray,
targets: jnp.DeviceArray,
) -> jnp.DeviceArray:
inputs: jax.Array,
targets: jax.Array,
) -> jax.Array:
"""Computes the diagonal hessian of `loss` at (`inputs`, `targets`).
Args:
Expand All @@ -94,7 +94,7 @@ def fisher_diag(
params: Any,
inputs: jnp.ndarray,
targets: jnp.ndarray,
) -> jnp.DeviceArray:
) -> jax.Array:
"""Computes the diagonal of the (observed) Fisher information matrix.
Args:
Expand Down

0 comments on commit 166e12f

Please sign in to comment.