Equinox support? #1
@patrick-kidger It's supposed to work out of the box, but actually I found a bug in my general case handling. Here's an example:

```python
import equinox as eqx
import jax
import jax.numpy as jnp
import optax

import lorax


class Linear(eqx.Module):
    weight: jax.Array
    bias: jax.Array

    def __init__(self, in_size, out_size, key):
        wkey, bkey = jax.random.split(key)
        self.weight = jax.random.normal(wkey, (out_size, in_size))
        self.bias = jax.random.normal(bkey, (out_size,))

    def __call__(self, x):
        return self.weight @ x + self.bias


@lorax.lora
def loss_fn(model, x, y):
    pred_y = jax.vmap(model)(x)
    return jnp.mean((y - pred_y) ** 2)


@jax.jit
@jax.value_and_grad
def split_params_loss(tune_params, freeze_params, x, y):
    # The lora-transformed loss takes the (frozen, tunable) pair in place of the model.
    return loss_fn((freeze_params, tune_params), x, y)


batch_size, in_size, out_size = 32, 128, 64
k1, k2, k3, k4 = jax.random.split(jax.random.PRNGKey(0), 4)

model = Linear(in_size, out_size, key=k1)
x = jax.random.normal(k2, (batch_size, in_size))
y = jax.random.normal(k3, (batch_size, out_size))

# Rank-16 LoRA for matrices, full training for everything else (the bias).
lora_spec = jax.tree_map(lambda x: 16 if len(x.shape) > 1 else lorax.LORA_FULL, model)

freeze_params, tune_params = lorax.init_lora(
    param_tree=model,
    spec=lora_spec,
    rng=k4,
)

opt = optax.adam(learning_rate=1e-1)
opt_state = opt.init(tune_params)

for i in range(100):
    loss, grad = split_params_loss(tune_params, freeze_params, x, y)
    updates, opt_state = opt.update(grad, opt_state)
    tune_params = optax.apply_updates(tune_params, updates)
    print(f'{i}: {loss:.3e}')
```

Well, it works locally, but I'll be pushing the fixed version shortly.
Oh nice -- that's really cool. Thank you for looking into this.

For specifying paths to parameters: I've encountered this issue before. IMO the nicest way to do it is to use a lambda function. There is also the new keypath functionality, but this doesn't play super well with custom pytree nodes -- you have to say things like "the n-th flattened child of this node" rather than just naming the attribute.
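For instance, a minimal sketch of the lambda approach against the `Linear` example above -- the use of `eqx.tree_at` to rewrite one entry of the spec is just my illustration, not part of lorax's API:

```python
# Start from a spec that trains every parameter in full...
lora_spec = jax.tree_map(lambda _: lorax.LORA_FULL, model)
# ...then use a lambda "path" to mark just the weight matrix for rank-16 LoRA.
lora_spec = eqx.tree_at(lambda m: m.weight, lora_spec, 16)
```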
For now I just had it call …

I fixed the dot bug, and the example above should work now; let me know if there are any problems with it.
This looks neat! I'm just curious about supporting Equinox as a possible backend neural network library.
This is typically called as:
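For example (a minimal sketch -- `eqx.nn.MLP` and its arguments here are just a stand-in for whatever module you actually use):

```python
import equinox as eqx
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)

# Construction puts the parameters inside the module itself...
model = eqx.nn.MLP(in_size=2, out_size=3, width_size=8, depth=2, key=key)

# ...and applying it to data is just calling the module.
x = jnp.ones(2)
y = model(x)
```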
but this can still be thought of in an init/apply paradigm if you want it to be:
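Continuing the sketch above, using `eqx.partition` and `eqx.combine` to split the module into its arrays ("init") and a pure function of them ("apply"):

```python
# "init": split the module into its parameters (arrays) and everything else (static).
params, static = eqx.partition(model, eqx.is_array)

# "apply": recombine and call, so the parameters are an ordinary pytree that can be
# passed through jax.grad / optax / lorax like any other parameter tree.
def apply_fn(params, x):
    model = eqx.combine(params, static)
    return model(x)

y = apply_fn(params, x)
```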
cf. also this example.
So I'm guessing this should be straightforward/elegant to support.
(I'll own up to the fact that I'm discussing compatibility with one of my own projects here!)