
Equinox support? #1

Closed
patrick-kidger opened this issue May 3, 2023 · 3 comments

Comments

@patrick-kidger

This looks neat! I'm just curious about supporting Equinox as a possible backend neural network library.

This is typically called as:

model = eqx.nn.MLP(...)
model(data)

but this can still be thought of in an init/apply paradigm if you want:

init = eqx.nn.MLP
apply = eqx.nn.MLP.__call__

params = init(...)
apply(params, data)

c.f. also this example

So I'm guessing this should be straightforward/elegant to support.

(I'll own up to the fact that I'm discussing compatibility with one of my own projects here!)
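
For concreteness, a runnable version of the init/apply sketch above might look like this (the MLP sizes and key are just illustrative):

import equinox as eqx
import jax
import jax.numpy as jnp

init = eqx.nn.MLP
apply = eqx.nn.MLP.__call__

# Illustrative sizes; any Equinox module works the same way.
params = init(in_size=2, out_size=3, width_size=8, depth=2, key=jax.random.PRNGKey(0))
out = apply(params, jnp.ones(2))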

@davisyoshida
Owner

@patrick-kidger It's supposed to work out of the box, but I actually found a bug in my general-case handling of dot_general, so I've fixed that. The only thing still broken, which I haven't decided how to fix, is the formatting of pytree paths to parameters when constructing a LoRA spec. You can just write the spec manually though, and then the following works:

import equinox as eqx
import jax
import jax.numpy as jnp
import optax

import lorax


class Linear(eqx.Module):
    weight: jax.Array
    bias: jax.Array

    def __init__(self, in_size, out_size, key):
        wkey, bkey = jax.random.split(key)
        self.weight = jax.random.normal(wkey, (out_size, in_size))
        self.bias = jax.random.normal(bkey, (out_size,))

    def __call__(self, x):
        return self.weight @ x + self.bias

# lorax.lora transforms loss_fn so that its first argument is the
# (frozen_params, tunable_params) pair produced by init_lora below.
@lorax.lora
def loss_fn(model, x, y):
    pred_y = jax.vmap(model)(x)
    return jnp.mean((y - pred_y) ** 2)

@jax.jit
@jax.value_and_grad
def split_params_loss(tune_params, freeze_params, x, y):
    return loss_fn((freeze_params, tune_params), x, y)

batch_size, in_size, out_size = 32, 128, 64
k1, k2, k3, k4 = jax.random.split(jax.random.PRNGKey(0), 4)
model = Linear(in_size, out_size, key=k1)
x = jax.random.normal(k2, (batch_size, in_size))
y = jax.random.normal(k3, (batch_size, out_size))


# Manually constructed spec: rank-16 LoRA for matrices, full training for vectors.
lora_spec = jax.tree_map(lambda x: 16 if len(x.shape) > 1 else lorax.LORA_FULL, model)
freeze_params, tune_params = lorax.init_lora(
    param_tree=model,
    spec=lora_spec,
    rng=k4
)

opt = optax.adam(learning_rate=1e-1)
opt_state = opt.init(tune_params)

for i in range(100):
    loss, grad = split_params_loss(tune_params, freeze_params, x, y)
    updates, opt_state = opt.update(grad, opt_state)
    tune_params = optax.apply_updates(tune_params, updates)
    print(f'{i}: {loss:.3e}')

Well it works locally, but I'll be pushing the fixed version shortly.

@patrick-kidger
Author

patrick-kidger commented May 3, 2023

Oh nice -- that's really cool. Thank you for looking into this.

For specifying paths to parameters: I've encountered this issue before. IMO the nicest way to do it is to use a lambda function, e.g. lambda model: model.layers[-1].weight. (This is exactly what's done in equinox.tree_at, e.g. new_model = tree_at(lambda m: m.layers[-1].weight, model, new_weight) modifies the weight at this position. Of course this idea can be used independently of Equinox.)
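
As a concrete sketch (the sizes and the zero replacement are just illustrative):

import equinox as eqx
import jax
import jax.numpy as jnp

model = eqx.nn.MLP(in_size=4, out_size=2, width_size=8, depth=2, key=jax.random.PRNGKey(0))

# Swap out one weight without ever spelling its pytree path as a string.
new_weight = jnp.zeros_like(model.layers[-1].weight)
new_model = eqx.tree_at(lambda m: m.layers[-1].weight, model, new_weight)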

There is also the new keypath functionality, but this doesn't play super well with custom pytree nodes -- you have to say things like "the ith leaf" which isn't really that nice.

@davisyoshida
Owner

For now I just had it call str on unknown node types (I'm using jax.tree_util.tree_map_with_path), but I agree it's not super pretty, so I may do something else in the future.
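
Roughly, the mechanism looks like this on a toy tree (keystr is just for display; nodes registered without key support fall back to positional index keys, which is the ugly case):

import jax
import jax.numpy as jnp

def show(path, leaf):
    # keystr renders the key path, e.g. ['layer']['weight'] for dict entries.
    print(jax.tree_util.keystr(path), leaf.shape)
    return leaf

tree = {"layer": {"weight": jnp.ones((3, 2)), "bias": jnp.zeros(3)}}
jax.tree_util.tree_map_with_path(show, tree)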

I fixed the dot_general bug, and the example above should work now. Let me know if there are any problems with it.
