From 166e12fc7df7f269be0d8464cfe3a925e3db7271 Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Sat, 11 Mar 2023 15:07:06 -0800 Subject: [PATCH] Replaces references to jax.numpy.DeviceArray with jax.Array. PiperOrigin-RevId: 515924016 --- optax/_src/schedule.py | 2 +- optax/_src/second_order.py | 16 ++++++++-------- 2 files changed, 9 insertions(+), 9 deletions(-) diff --git a/optax/_src/schedule.py b/optax/_src/schedule.py index 24025279..0f17f9e8 100644 --- a/optax/_src/schedule.py +++ b/optax/_src/schedule.py @@ -393,7 +393,7 @@ def join_schedules(schedules: Sequence[base.Schedule], Returns: schedule: A function that maps step counts to values. """ - def schedule(step: jnp.DeviceArray) -> jnp.DeviceArray: + def schedule(step: jax.Array) -> jax.Array: output = schedules[0](step) for boundary, schedule in zip(boundaries, schedules[1:]): output = jnp.where(step < boundary, output, schedule(step - boundary)) diff --git a/optax/_src/second_order.py b/optax/_src/second_order.py index 6793dbc9..ea619ba7 100644 --- a/optax/_src/second_order.py +++ b/optax/_src/second_order.py @@ -43,11 +43,11 @@ def ravel(p: Any) -> Array: def hvp( loss: LossFun, - v: jnp.DeviceArray, + v: jax.Array, params: Any, - inputs: jnp.DeviceArray, - targets: jnp.DeviceArray, -) -> jnp.DeviceArray: + inputs: jax.Array, + targets: jax.Array, +) -> jax.Array: """Performs an efficient vector-Hessian (of `loss`) product. Args: @@ -69,9 +69,9 @@ def hvp( def hessian_diag( loss: LossFun, params: Any, - inputs: jnp.DeviceArray, - targets: jnp.DeviceArray, -) -> jnp.DeviceArray: + inputs: jax.Array, + targets: jax.Array, +) -> jax.Array: """Computes the diagonal hessian of `loss` at (`inputs`, `targets`). Args: @@ -94,7 +94,7 @@ def fisher_diag( params: Any, inputs: jnp.ndarray, targets: jnp.ndarray, -) -> jnp.DeviceArray: +) -> jax.Array: """Computes the diagonal of the (observed) Fisher information matrix. Args: