# [pure bf16 training] w/ AnyPrecisionAdamW and Kahan summation (#21312)
This PR was prompted by this discussion with @lessw2020.
The PR works; I'm just keeping it as a Draft for now, as I haven't polished it enough to be ready for merging.
This PR shows how to perform pure bf16 training (not mixed), running AnyPrecisionAdamW also in bf16, with Kahan summation.

I think it should require 8 bytes per param (2 bytes each for the bf16 weights, grads, momentum and variance), instead of 18 for mixed-precision training, i.e. about half the memory usage for everything but the activations memory.
(I also included a hack into `load_from_disk` loading to get saved datasets, but it's unrelated to the actual feature and will be removed at the end.)

To test, check out this branch:
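One way to do it (a sketch, using the GitHub CLI; any other way of fetching the PR's branch works too):

```bash
git clone https://github.com/huggingface/transformers
cd transformers
# fetch and switch to this PR's branch via the GitHub CLI
gh pr checkout 21312
pip install -e .
```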
## Getting AnyPrecisionAdamW
You can try to install the bleeding-edge `torchdistx`, but it's very difficult to do. Since the optimizer is just Python code, we can hack-install it as in the sketch below. You will just need to update the destination path if you're not using conda or have a different Python version; to be more specific, adjust the location of your Python's `site-packages` directory.
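The file locations inside the pytorch/torchdistx repo below are my assumption — verify them against the repo before copying:

```bash
# locate this environment's site-packages directory (avoids hardcoding a conda path)
SITE=$(python -c "import sysconfig; print(sysconfig.get_paths()['purelib'])")

# fake up a minimal torchdistx package containing just the optimizer
mkdir -p $SITE/torchdistx/optimizers
touch $SITE/torchdistx/__init__.py
echo "from .anyprecision_optimizer import AnyPrecisionAdamW" > $SITE/torchdistx/optimizers/__init__.py

# assumed location of the pure-python optimizer file in the torchdistx repo
wget -O $SITE/torchdistx/optimizers/anyprecision_optimizer.py \
  https://raw.githubusercontent.com/pytorch/torchdistx/main/src/python/torchdistx/optimizers/anyprecision_optimizer.py
```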
## Training

If you have an 80GB A100, you can do the opt-1.3b setup below; for smaller cards choose one of the smaller setups. You can of course do this with any model, as this PR is model-invariant, and you can do either finetuning or training from scratch.
### opt-1.3b / bf16-pure training from scratch
First, prep an initialized opt-1.3b model:
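A sketch of one way to do it with stock transformers APIs (the output directory name is my choice):

```bash
python -c '
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
mname = "facebook/opt-1.3b"
# take the architecture and tokenizer from the hub, but initialize fresh weights
config = AutoConfig.from_pretrained(mname)
model = AutoModelForCausalLM.from_config(config)
tokenizer = AutoTokenizer.from_pretrained(mname)
model.save_pretrained("opt-1.3b-from-scratch")
tokenizer.save_pretrained("opt-1.3b-from-scratch")
'
```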
Train from scratch:
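A sketch of such a run: `--optim adamw_anyprecision` and `--optim_args` already exist in transformers, while the flag that switches the Trainer into pure-bf16 mode is what this PR adds, so take its exact name from the diff (marked as a placeholder below):

```bash
python examples/pytorch/language-modeling/run_clm.py \
    --model_name_or_path opt-1.3b-from-scratch \
    --dataset_name wikitext --dataset_config_name wikitext-2-raw-v1 \
    --do_train --output_dir output_dir --overwrite_output_dir \
    --per_device_train_batch_size 8 --logging_steps 5 \
    --optim adamw_anyprecision \
    --optim_args "use_kahan_summation=true,momentum_dtype=bfloat16,variance_dtype=bfloat16" \
    --pure_bf16  # placeholder: the actual pure-bf16 flag is defined in this PR
```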
Let's check that I got the math right for opt-1.3B
Theoretical memory allocation for optim states, weights and grads:
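A back-of-envelope check, using the per-param byte counts quoted at the top (8 bytes/param pure bf16, 18 bytes/param mixed):

```bash
python -c '
params = 1.3e9
# pure bf16: 2-byte weights + grads + momentum + variance = 8 bytes/param
print(f"pure bf16: {params * 8 / 2**30:.1f} GiB")   # ~9.7 GiB
# mixed precision: the 18 bytes/param figure quoted above
print(f"mixed:     {params * 18 / 2**30:.1f} GiB")  # ~21.8 GiB
'
```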
Real memory allocation (obtained by adding the `--skip_memory_metrics 0` flag to get memory usage reports) matched, so the theoretical and actual numbers check out memory-wise.
### opt-125m / bf16-pure training from scratch
If you want to fit into a smaller card, let's use opt-125m.
Then prep an initialized opt-125m model:
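Same sketch as for opt-1.3b, just pointed at the smaller model:

```bash
python -c '
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
mname = "facebook/opt-125m"
config = AutoConfig.from_pretrained(mname)
model = AutoModelForCausalLM.from_config(config)
tokenizer = AutoTokenizer.from_pretrained(mname)
model.save_pretrained("opt-125m-from-scratch")
tokenizer.save_pretrained("opt-125m-from-scratch")
'
```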
Train from scratch in pure bf16:
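And the same training sketch, with the same placeholder caveat for the pure-bf16 flag:

```bash
python examples/pytorch/language-modeling/run_clm.py \
    --model_name_or_path opt-125m-from-scratch \
    --dataset_name wikitext --dataset_config_name wikitext-2-raw-v1 \
    --do_train --output_dir output_dir --overwrite_output_dir \
    --per_device_train_batch_size 8 --logging_steps 5 \
    --optim adamw_anyprecision \
    --optim_args "use_kahan_summation=true,momentum_dtype=bfloat16,variance_dtype=bfloat16" \
    --pure_bf16  # placeholder: the actual pure-bf16 flag is defined in this PR
```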
### opt-125m / fp16-amp training from scratch
Same for mixed-precision fp16 (we want bf16 to give us a similar loss curve when everything else is the same):
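A sketch of the fp16 baseline: a standard `--fp16` AMP run with the default AdamW, everything else unchanged:

```bash
python examples/pytorch/language-modeling/run_clm.py \
    --model_name_or_path opt-125m-from-scratch \
    --dataset_name wikitext --dataset_config_name wikitext-2-raw-v1 \
    --do_train --output_dir output_dir --overwrite_output_dir \
    --per_device_train_batch_size 8 --logging_steps 5 \
    --fp16
```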