integration with Flax? #11

Open
nestordemeure opened this issue Apr 13, 2022 · 4 comments

Comments

@nestordemeure

Is there any interest in integrating this work with Flax?

They already have an `init` function that decouples parameter initialization from the model definition, which could make introducing mup fairly plug-and-play.

Plus, they rely on optax for their optimizers. Since that library focuses on composability, you might be able to introduce a transformation that takes an optimizer and makes it mup-compatible.
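As a rough illustration of that second point, the wrapper idea can be sketched in plain JAX without depending on optax itself. The names `scale_by_width` and `width_mults`, and the `(init, update)` pair, are hypothetical and only mimic the shape of a composable gradient transformation. The rule used here (dividing the updates of hidden weight matrices by a width multiplier, as is commonly described for Adam-style optimizers under muP) is an assumption, not something confirmed in this thread.

```python
import jax
import jax.numpy as jnp


def scale_by_width(width_mults):
    """Hypothetical transformation: divide each update by its width multiplier.

    `width_mults` is a pytree with the same structure as the params, holding
    width / base_width for hidden matrices and 1.0 for everything else.
    """

    def init(params):
        # This transformation keeps no state.
        return ()

    def update(updates, state, params=None):
        scaled = jax.tree_util.tree_map(lambda u, m: u / m, updates, width_mults)
        return scaled, state

    return init, update


# Toy parameter tree; only the hidden matrix is treated as width-dependent (assumption).
params = {
    "input": jnp.ones((4, 8)),
    "hidden": jnp.ones((8, 8)),
    "output": jnp.ones((8, 2)),
}
width_mults = {"input": 1.0, "hidden": 4.0, "output": 1.0}

init, update = scale_by_width(width_mults)
state = init(params)

# Stand-in for the updates a base optimizer would have produced.
raw_updates = jax.tree_util.tree_map(jnp.ones_like, params)
scaled_updates, state = update(raw_updates, state)
```

In a real setup, a step like this would presumably be chained after the base optimizer's own transformation, so that the model definition stays untouched.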

Overall, I believe the Flax ecosystem could make mup more easily accessible to people.

@thegregyang
Contributor

Integration with Flax would be fantastic, but neither I nor @edwardjhu are familiar with it. If someone from the Flax team can work with us, we can definitely advise on the integration process.

@davisyoshida

davisyoshida commented Apr 25, 2022

@nestordemeure In case you're interested, I have a first draft of a port to JAX/Haiku here. If you're not attached to Flax in particular, you could use this. You could also probably adapt this design to Flax if you wanted, since Flax and Haiku are more similar to each other than Flax and torch.

Edit: @thegregyang By the way, can you take a look at the plots in the README there? The optimal learning rate stabilizes with width, but it does look like I see better training loss for SP sometimes. Is that indicative of a bug? My coord checks look good, nothing grows with width, output norm (at init) decays with width.
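For readers unfamiliar with the coord check mentioned above, here is a minimal sketch of the init-time part only, in plain JAX. It is not the code from the linked port: the toy MLP, the helper names `init_mlp` and `forward_with_stats`, and the specific init scaling (readout standard deviation shrinking like 1/width) are all assumptions chosen for illustration. A fuller check would also track these statistics over a few training steps.

```python
import jax
import jax.numpy as jnp


def init_mlp(key, d_in, width, d_out):
    """Toy init (assumed scaling): the readout std shrinks like 1/width."""
    k1, k2, k3 = jax.random.split(key, 3)
    return {
        "w_in": jax.random.normal(k1, (d_in, width)) / jnp.sqrt(d_in),
        "w_hidden": jax.random.normal(k2, (width, width)) / jnp.sqrt(width),
        "w_out": jax.random.normal(k3, (width, d_out)) / width,
    }


def forward_with_stats(params, x):
    """Return the mean absolute value of each layer's pre-activations."""
    z1 = x @ params["w_in"]
    z2 = jax.nn.relu(z1) @ params["w_hidden"]
    out = jax.nn.relu(z2) @ params["w_out"]
    return {
        name: float(jnp.mean(jnp.abs(v)))
        for name, v in [("z1", z1), ("z2", z2), ("out", out)]
    }


# Same inputs for every width, so only the model size changes.
x = jax.random.normal(jax.random.PRNGKey(0), (32, 16))
stats = {}
for width in (64, 256, 1024):
    params = init_mlp(jax.random.PRNGKey(1), d_in=16, width=width, d_out=32)
    stats[width] = forward_with_stats(params, x)
    print(width, stats[width])
```

With this scaling, the hidden statistics should stay roughly flat as the width grows while the output statistic shrinks, which matches the behaviour described in the comment.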

@thegregyang
Contributor

Hey @davisyoshida your repo looks great so far!

For your plot, you'd get better results if you tune the input, output, and hidden learning rates for your small model and scale up from there, sweeping a global lr multiplier on the x-axis (ideally, you tune (lr, init) for all parameter tensors, but these 3 learning rates should be a good practical approximation). In particular, for a fair comparison, the curves for your small model in both SP and muP plots should be the same. Your current plots are just looking at a slice of the HP space (of (lr, init) for all parameter tensors) away from the true optimum.
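The sweep described above could be set up roughly as follows. The three base learning rates are placeholder values standing in for whatever tuning the small model would yield, and `scaled_lrs` is a hypothetical helper; only the structure (three tuned rates, one shared multiplier on the x-axis) comes from the comment.

```python
# Hypothetical per-group learning rates, as if tuned on the small base model.
base_lrs = {"input": 1e-2, "hidden": 1e-3, "output": 1e-2}

# Global multipliers to sweep; these become the x-axis of the plot.
multipliers = [2.0 ** k for k in range(-3, 4)]


def scaled_lrs(base_lrs, multiplier):
    """Scale every group's tuned learning rate by one shared multiplier."""
    return {name: lr * multiplier for name, lr in base_lrs.items()}


sweep = [(m, scaled_lrs(base_lrs, m)) for m in multipliers]
for m, lrs in sweep:
    # A real experiment would train each width with `lrs` and record the loss.
    print(m, lrs)
```

Because the multiplier 1.0 reproduces the tuned rates exactly, the small model's curve is the same in both the SP and muP plots, which is the fairness condition mentioned above.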

@davisyoshida

Ah that makes perfect sense, I'll generate new versions of the figures. Thanks!
