Is there a recommended way to tag an array as not trainable? Specifically for the case where it is not known beforehand that we do not wish to train the parameter (i.e. stop_gradient is not hard-coded into the model).
I am also aware that e.g. optax allows specifying which parameters are trainable (sketched below for reference), but in many cases it would be much simpler to tag the arrays themselves rather than spelling out the trainable parameters on the optax side.
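For reference, a minimal sketch of what the optax route looks like, using optax.multi_transform; the parameter names and learning rate here are placeholders:

import jax.numpy as jnp
import optax

params = {"w": jnp.array([3.0]), "w_frozen": jnp.array([3.0])}

# Label every parameter, then map each label to a transformation;
# set_to_zero() effectively freezes the parameter.
tx = optax.multi_transform(
    {"train": optax.sgd(1e-2), "freeze": optax.set_to_zero()},
    param_labels={"w": "train", "w_frozen": "freeze"},
)
opt_state = tx.init(params)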
Option 1: make use of duck typing with __jax_array__ (from #10065); this is experimental and seems to be at risk of removal.
import typing

import jax.numpy as jnp
from jax.lax import stop_gradient

class Buffer(typing.NamedTuple):
    array: jnp.ndarray

    def __jax_array__(self):
        # jnp.asarray and friends call this, so the wrapped array is
        # always seen through stop_gradient.
        return stop_gradient(self.array)
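For illustration, a usage sketch, under the assumption that __jax_array__ is honoured by jnp.asarray (which is version-dependent, given its experimental status):

import jax

w = jnp.array([2.0])
b = Buffer(jnp.array([3.0]))

def loss(w, b):
    # jnp.asarray triggers Buffer.__jax_array__, applying stop_gradient.
    return jnp.sum(w * jnp.asarray(b))

print(jax.grad(loss)(w, b))              # [3.]
print(jax.grad(loss, argnums=1)(w, b).array)  # [0.] -- the buffer is frozen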
Option 2: add an attribute to the arrays and use it to partition the model, e.g. using equinox:
import equinox as eqx
import jax
import jax.numpy as jnp

def make_my_module():
    array = jnp.array([3.])
    frozen_array = jnp.array([3.])
    # Ad-hoc tag: attach an attribute directly to the array object.
    frozen_array.is_frozen = True
    return (array, frozen_array)
def f(diff, static):
    module = eqx.combine(diff, static)
    return jnp.sum(module[0]) + jnp.sum(module[1])

def partition_fn(leaf):
    # Trainable iff the leaf is an inexact array without the is_frozen tag.
    if eqx.is_inexact_array(leaf):
        if hasattr(leaf, "is_frozen"):
            return not leaf.is_frozen
        return True
    return False
my_module = make_my_module()
val, grad = jax.value_and_grad(f)(*eqx.partition(my_module, filter_spec=partition_fn))

expected_val = 6
expected_grad = (1, None)  # no gradient for the frozen array
assert val == expected_val
assert grad == expected_grad
# But this attribute might not be maintained e.g. under vmap
my_module_vmapped = jax.vmap(make_my_module, axis_size=1)()
val, grad = jax.value_and_grad(f)(
*eqx.partition(my_module_vmapped, filter_spec=partition_fn)
)
assert grad != expected_grad
Option 2 seems like it will have some issues, e.g. losing the attribute when the arrays are constructed under vmap.
Is there a recommended way to achieve this?
JAX itself doesn't have any notion of model training, so the answer to your question would depend on what framework you're using. It looks like you're using equinox, so you may find more useful answers by asking at https://github.com/patrick-kidger/equinox.
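For what it's worth, the equinox-native pattern would presumably be a filter_spec pytree of bools mirroring the model, rather than an attribute on the array itself; a minimal sketch (the two-element structure here is just illustrative):

import equinox as eqx
import jax
import jax.numpy as jnp

params = (jnp.array([3.0]), jnp.array([3.0]))
# filter_spec mirrors the model: True = trainable, False = frozen.
filter_spec = (True, False)

def f(diff, static):
    a, b = eqx.combine(diff, static)
    return jnp.sum(a) + jnp.sum(b)

diff, static = eqx.partition(params, filter_spec)
val, grad = jax.value_and_grad(f)(diff, static)
# grad == (Array([1.]), None)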
No, not really. The only supported way to associate metadata with an array would be using a pytree (see https://jax.readthedocs.io/en/latest/pytrees.html), but there's no way to make a pytree duck-type as an array (note that __jax_array__, despite existing in some places, is not meant as a public API and is not fully supported throughout the package).
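To sketch what that pytree approach might look like (the Frozen wrapper below is hypothetical, not a JAX API):

import jax
import jax.numpy as jnp
from jax.lax import stop_gradient

@jax.tree_util.register_pytree_node_class
class Frozen:
    """Hypothetical wrapper marking an array as non-trainable."""

    def __init__(self, array):
        self.array = array

    def tree_flatten(self):
        return (self.array,), None  # (children, aux_data)

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(*children)

def unwrap(leaf):
    # Apply stop_gradient to Frozen leaves; pass everything else through.
    return stop_gradient(leaf.array) if isinstance(leaf, Frozen) else leaf

params = (jnp.array([3.0]), Frozen(jnp.array([3.0])))

def f(params):
    a, b = jax.tree_util.tree_map(
        unwrap, params, is_leaf=lambda x: isinstance(x, Frozen)
    )
    return jnp.sum(a) + jnp.sum(b)

val, grad = jax.value_and_grad(f)(params)
# val == 6.0; grad == (Array([1.]), Frozen(Array([0.])))

Because transformations like vmap flatten and rebuild pytrees, a registered wrapper like this survives vmap, unlike the ad-hoc attribute in Option 2.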