fp16 buffers for ADAM #289

Open

ngc92 wants to merge 2 commits into master from fp16-adam

Conversation

ngc92 (Contributor) commented Apr 29, 2024

First proof-of-concept implementation

ngc92 (Contributor, Author) commented Apr 29, 2024

Instead of having a single scale factor per tensor, we have one scale for each group of 32 values. This is less about getting more accuracy (though it might help with that) and more about ensuring that we don't need any form of cross-warp communication to handle the scales.
I'd expect the group size of 32 to increase once we switch to vectorized Adam kernels anyway.
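Roughly, the write side could look like the sketch below (illustrative only, not the kernel in this PR; the name `write_scaled_fp16` and the buffer layout are made up). Each warp reduces the absolute maximum of its 32-value group with shuffles, so there is no shared memory or cross-warp synchronization, and one fp16 scale is stored per group.

```cuda
// Illustrative sketch only: one fp16 scale per group of 32 values, computed
// entirely within a warp. Assumes blockDim.x is a multiple of 32; the kernel
// name and buffer layout are hypothetical.
#include <cuda_fp16.h>

__global__ void write_scaled_fp16(const float* src, half* dst, half* scales, long n) {
    long idx = (long)blockIdx.x * blockDim.x + threadIdx.x;
    float v = (idx < n) ? src[idx] : 0.0f;          // out-of-range lanes contribute 0
    // warp-wide max of |v| via shuffles; every lane in the group of 32 ends with the same max
    float absmax = fabsf(v);
    for (int offset = 16; offset > 0; offset /= 2) {
        absmax = fmaxf(absmax, __shfl_xor_sync(0xFFFFFFFFu, absmax, offset));
    }
    float scale = (absmax > 0.0f) ? absmax : 1.0f;  // avoid division by zero for all-zero groups
    if (idx < n) {
        dst[idx] = __float2half(v / scale);         // scaled values land in [-1, 1]
        if ((threadIdx.x & 31) == 0) {
            scales[idx / 32] = __float2half(scale); // one scale per group of 32
        }
    }
}
```

On the read side, the Adam kernel would reconstruct each value as `__half2float(dst[i]) * __half2float(scales[i / 32])` before updating it.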

karpathy (Owner) commented

I think I'm missing a bit of context on this PR. Is this following some paper / approach?

ngc92 (Contributor, Author) commented Apr 29, 2024

It comes from the appendix of "Efficient Large Scale Language Modeling with Mixtures of Experts", which in turn cites "Jukebox: A Generative Model for Music".

However, this is not actually a 1:1 implementation of that. If you want a single scaling factor per tensor, the Adam kernel needs to know which tensor it is currently in (see my other draft Adam PR). It also requires synchronization, because you need to process the entire tensor, determine the max, scale accordingly, and only then write to memory.

Having one scale factor per group requires more memory (though the amount should still be negligible, especially since I assume the group size will increase once we use vector loads here).
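Back-of-envelope: with one fp16 scale per group of 32 fp16 values, the extra storage is 2 bytes per 64 bytes, i.e. about 3% on top of the buffers themselves, and that fraction shrinks further if the group size grows with vectorized loads.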

ngc92 (Contributor, Author) commented Apr 29, 2024

Rebased on the latest changes from master.
I used #288 to generate a gpt2-large model. Without this patch, training at batch size 1 requires 12658MiB; with the fp16 buffers, this goes down to 9892MiB.

Sadly, that's still not enough to let me test gpt2-xl on my 16GB card, even at batch size 1.
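As a rough sanity check (taking gpt2-large at roughly 774M parameters): storing the two Adam buffers in fp16 instead of fp32 saves about 774e6 × 2 buffers × 2 bytes ≈ 2.9GiB, which is roughly in line with the observed drop of 12658MiB − 9892MiB ≈ 2.7GiB once the per-group scales are accounted for.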

ngc92 force-pushed the fp16-adam branch 2 times, most recently from 3252b88 to 1d96fe5 on April 29, 2024 at 23:47