integration with Flax? #11
Comments
Integration with Flax would be fantastic, but neither I nor @edwardjhu is familiar with it. If someone from the Flax team can work with us, we can definitely advise on the integration process.
@nestordemeure In case you're interested, I have a first draft of a port to JAX/Haiku here. If you're not attached to Flax in particular, you could use this. You could also probably adapt this design to Flax if you wanted, since Flax and Haiku are more similar to each other than Flax and torch. Edit: @thegregyang By the way, can you take a look at the plots in the README there? The optimal learning rate stabilizes with width, but I sometimes see better training loss for SP. Is that indicative of a bug? My coord checks look good: nothing grows with width, and the output norm (at init) decays with width.
Hey @davisyoshida, your repo looks great so far! For your plot, you'd get better results if you tune the input, output, and hidden learning rates for your small model and scale up from there, sweeping a global lr multiplier on the x-axis. (Ideally you would tune (lr, init) for all parameter tensors, but these 3 learning rates should be a good practical approximation.) In particular, for a fair comparison, the curves for your small model should be the same in the SP and muP plots. Your current plots only look at a slice of the HP space (of (lr, init) for all parameter tensors) that is away from the true optimum.
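The sweep described in the comment above could be set up roughly as follows. This is a minimal sketch, not code from either repository: the helper `sweep_learning_rates`, the group names, and the numeric values are all hypothetical placeholders. It assumes an Adam-style rule in which only the hidden (matrix-like) learning rate is divided by the width multiplier under muP, with the output layer's width dependence handled elsewhere (for example by an output multiplier).

```python
def sweep_learning_rates(base_lrs, width_mult, global_mults, parametrization="mup"):
    """Return {global_mult: {group: lr}} for one model width.

    base_lrs: per-group learning rates tuned on the small (base) model.
    width_mult: width of this model divided by the base model's width.
    global_mults: values of the global lr multiplier swept on the x-axis.
    """
    if parametrization not in ("mup", "sp"):
        raise ValueError(f"unknown parametrization: {parametrization}")
    sweep = {}
    for g in global_mults:
        # Every group is scaled by the same global multiplier.
        lrs = {name: lr * g for name, lr in base_lrs.items()}
        if parametrization == "mup":
            # Assumed Adam-style muP rule: hidden lr shrinks with width.
            lrs["hidden"] = lrs["hidden"] / width_mult
        sweep[g] = lrs
    return sweep


# Placeholder values only; these are not tuned results from the thread.
base_lrs = {"input": 1e-2, "hidden": 1e-3, "output": 1e-2}
global_mults = [0.25, 0.5, 1.0, 2.0, 4.0]

small_mup = sweep_learning_rates(base_lrs, 1, global_mults, "mup")
small_sp = sweep_learning_rates(base_lrs, 1, global_mults, "sp")
wide_mup = sweep_learning_rates(base_lrs, 8, global_mults, "mup")
```

At `width_mult = 1` the two parametrizations produce identical learning rates, which is why the small-model curves should coincide in the SP and muP plots.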
Ah, that makes perfect sense. I'll generate new versions of the figures. Thanks!
Is there any interest in integrating this work with Flax?
They already have an init function that decouples parameter initialization from model definition, which could make introducing mup fairly plug-and-play.
Plus, they rely on optax for their optimizers. Since that library focuses on composability, you might be able to introduce a transformation that takes an optimizer and makes it mup compatible.
Overall, I believe the Flax ecosystem could make mup more easily accessible to people.