Marking non-trainable / frozen parameters #20012

Closed
danielward27 opened this issue Feb 28, 2024 · 3 comments
danielward27 commented Feb 28, 2024

Is there a recommended way to tag an array as not trainable? Specifically for the case where it is not known beforehand that we will not wish to train the parameter (i.e. where stop_gradient cannot be hard-coded into the model).
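For context, this is the hard-coded alternative the question wants to avoid; a minimal sketch (the forward function and parameter names are illustrative):

import jax
import jax.numpy as jnp

def forward(params, x):
    # Freezing is decided at model-definition time: "w" can never be trained.
    w = jax.lax.stop_gradient(params["w"])
    return x @ w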

I am also aware that e.g. optax allows specifying which parameters are trainable, but in many cases it would be much simpler to tag the arrays in some way rather than specifying the trainable parameters via optax.
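For reference, the optax route mentioned above looks roughly like this (a sketch; the parameter labels are illustrative):

import optax

# Send frozen parameters through optax.set_to_zero(), which zeroes their
# updates, and everything else through a normal optimizer.
param_labels = {"w": "trainable", "b": "frozen"}
tx = optax.multi_transform(
    {"trainable": optax.adam(1e-3), "frozen": optax.set_to_zero()},
    param_labels,
)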

Option 1: make use of duck typing with __jax_array__ (from #10065); this is experimental and seems to be at risk of being removed.

import typing
import jax.numpy as jnp
from jax.lax import stop_gradient

class Buffer(typing.NamedTuple):
    array: jnp.ndarray

    def __jax_array__(self):
        # Invoked whenever the Buffer is converted to an array, so gradients
        # through it are always stopped.
        return stop_gradient(self.array)

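A quick check of how Option 1 behaves (a sketch, assuming __jax_array__ keeps converting as it does today):

import jax

buf = Buffer(jnp.array([3.0]))
print(jnp.sum(buf))  # 3.0: Buffer duck-types as an array

# Buffer is a NamedTuple, hence a pytree, so we can differentiate with
# respect to it; the gradient comes back zero because __jax_array__
# applies stop_gradient.
print(jax.grad(lambda b: jnp.sum(b))(buf))  # Buffer(array=Array([0.], ...))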
Option 2: add an attribute to the arrays and use it to partition the model, e.g. using equinox.

import equinox as eqx
import jax
import jax.numpy as jnp

def make_my_module():
    array = jnp.array([3.0])
    frozen_array = jnp.array([3.0])
    frozen_array.is_frozen = True  # tag the array with custom metadata
    return (array, frozen_array)

def f(diff, static):
    # Recombine the trainable and frozen halves before evaluating.
    module = eqx.combine(diff, static)
    return jnp.sum(module[0]) + jnp.sum(module[1])


def partition_fn(leaf):
    # Trainable leaves: inexact arrays that are not tagged as frozen.
    if eqx.is_inexact_array(leaf):
        if hasattr(leaf, "is_frozen"):
            return not leaf.is_frozen
        return True
    return False

my_module = make_my_module()

val, grad = jax.value_and_grad(f)(*eqx.partition(my_module, filter_spec=partition_fn))

expected_val = 6
expected_grad = (1, None)
assert val == expected_val
assert grad == expected_grad


# But this attribute might not be maintained, e.g. under vmap:
my_module_vmapped = jax.vmap(make_my_module, axis_size=1)()
val, grad = jax.value_and_grad(f)(
    *eqx.partition(my_module_vmapped, filter_spec=partition_fn)
)
assert grad != expected_grad  # the is_frozen tag was lost, so both arrays got gradients

Option 2 seems like it will have some issues, e.g. losing the attribute when the arrays are constructed under vmap.

Is there a recommended way to achieve this?

danielward27 added the enhancement label Feb 28, 2024
jakevdp (Collaborator) commented Feb 28, 2024

JAX itself doesn't have any notion of model training, so the answer to your question would depend on what framework you're using. It looks like you're using equinox, so you may find more useful answers by asking at https://github.com/patrick-kidger/equinox.
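For what it's worth, equinox supports a freezing recipe along these lines; a sketch (the Linear model and the choice to freeze its bias are illustrative):

import equinox as eqx
import jax
import jax.numpy as jnp

model = eqx.nn.Linear(2, 2, key=jax.random.PRNGKey(0))

# Mark every leaf trainable, then flip the bias to frozen.
filter_spec = jax.tree_util.tree_map(lambda _: True, model)
filter_spec = eqx.tree_at(lambda m: m.bias, filter_spec, replace=False)

def loss(diff, static, x):
    return jnp.sum(eqx.combine(diff, static)(x))

diff, static = eqx.partition(model, filter_spec)
grads = jax.grad(loss)(diff, static, jnp.ones(2))  # gradient for weight only; bias is None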

jakevdp self-assigned this Feb 28, 2024
danielward27 (Author) commented Feb 28, 2024

OK, thanks. I guess more broadly: is there a way to associate metadata with an array without duck typing?

jakevdp (Collaborator) commented Feb 28, 2024

No, not really. The only supported way to associate metadata with an array would be using a pytree (see https://jax.readthedocs.io/en/latest/pytrees.html), but there's no way to make a pytree duck-type as an array (note that __jax_array__, despite existing in some places, is not meant as a public API and is not fully supported throughout the package).
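To illustrate the pytree route: metadata registered as static aux data survives transforms like vmap, which is exactly where the attribute approach in Option 2 broke down. A minimal sketch (the Tagged class is hypothetical):

import jax
import jax.numpy as jnp

@jax.tree_util.register_pytree_node_class
class Tagged:
    def __init__(self, array, is_frozen=False):
        self.array = array
        self.is_frozen = is_frozen

    def tree_flatten(self):
        # The flag travels as static aux data, not as a traced child.
        return (self.array,), self.is_frozen

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(children[0], aux_data)

tagged = jax.vmap(lambda x: Tagged(x, is_frozen=True))(jnp.ones((4, 3)))
assert tagged.is_frozen  # the tag survives vmap, unlike an array attribute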
