[pure bf16 training] w/ AnyPrecisionAdamW and Kahan summation #21312

Draft: wants to merge 1 commit into main from full-bf16-train

Conversation

stas00 (Contributor) commented on Jan 26, 2023

This PR was prompted by this discussion with @lessw2020.

The PR works; I'm keeping it as a Draft for now since I haven't polished it enough to be ready for merging.

How to perform pure bf16 training (not mixed), running AnyPrecisionAdamW also in bf16 with Kahan summation

I think it should require ~8 bytes per param instead of ~18 for mixed-precision training, i.e. roughly half the memory usage for everything but activations.

(This PR also includes a hack to load_from_disk for loading saved datasets, but it's unrelated to the actual feature and will be removed at the end.)
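
For intuition, here is a minimal, illustrative sketch of why the Kahan compensation buffer matters (this is not the torchdistx implementation): in bf16, repeatedly adding small optimizer updates to a weight silently loses the low-order bits, while a Kahan-compensated update carries those lost bits forward in an extra bf16 buffer.

import torch

p = torch.zeros(4, dtype=torch.bfloat16)        # bf16 "weights"
comp = torch.zeros_like(p)                      # Kahan compensation buffer (also bf16)
p_naive = p.clone()                             # same weights, updated naively
p_ref = torch.zeros(4, dtype=torch.float32)     # fp32 reference

for _ in range(1000):
    update = torch.full((4,), 1e-3, dtype=torch.bfloat16)  # tiny step, as an optimizer would apply
    p_ref += update.float()

    p_naive += update                           # naive bf16 add: tiny steps get rounded away

    y = update - comp                           # re-inject the bits lost on the previous step
    t = p + y
    comp = (t - p) - y                          # record what got rounded away this step
    p = t

print(p_naive, p, p_ref)  # the Kahan-updated weights track the fp32 reference far more closely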

To test, check out this branch:

git clone https://github.com/huggingface/transformers transformers-bf16
cd transformers-bf16
git checkout full-bf16-train

Getting AnyPrecisionAdamW

You can try to install the bleeding-edge torchdistx, but that's quite difficult to do. Since the optimizer is just Python code, we can simply hack-install it like this:

mkdir -p $CONDA_PREFIX/lib/python3.8/site-packages/torchdistx/optimizers
wget https://raw.githubusercontent.com/pytorch/torchdistx/main/src/python/torchdistx/optimizers/anyprecision_optimizer.py \
-O $CONDA_PREFIX/lib/python3.8/site-packages/torchdistx/optimizers/__init__.py

You will just need to update the destination path if you're not using conda or are on a different Python version; specifically, adjust it to point at your Python's site-packages directory.
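
If you want to resolve that directory programmatically and sanity-check the hack-install, a snippet like this should work (assuming the downloaded file has no other torchdistx dependencies):

import sysconfig
print(sysconfig.get_paths()["purelib"])  # the site-packages dir the wget destination should point at

from torchdistx.optimizers import AnyPrecisionAdamW  # resolves via namespace packages, no torchdistx/__init__.py needed
print(AnyPrecisionAdamW)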

Training

If you have an 80GB A100, you can use the opt-1.3b setup below; otherwise, for smaller cards, choose one of the smaller setups.

You can of course do this with any model; this PR is model-invariant.

And you can do either fine-tuning or training from scratch.

opt-1.3b / bf16-pure training from scratch

First, prep a freshly initialized opt-1.3b model:


cat << EOT > prep-bf16.py
from transformers import AutoConfig, AutoModel, AutoTokenizer
import torch

mname = "facebook/opt-1.3b"

config = AutoConfig.from_pretrained(mname)
model = AutoModel.from_config(config, torch_dtype=torch.bfloat16)
tokenizer = AutoTokenizer.from_pretrained(mname)

path = "opt-1.3b-bf16"

model.save_pretrained(path)
tokenizer.save_pretrained(path)
EOT

python prep-bf16.py
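
Optionally, you can sanity-check that the saved checkpoint really loads in bf16 (illustrative; torch_dtype="auto" picks up the dtype recorded in the saved config):

from transformers import AutoModel

model = AutoModel.from_pretrained("opt-1.3b-bf16", torch_dtype="auto")
print(next(model.parameters()).dtype)  # expect torch.bfloat16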

Train from scratch:

rm -rf save_dir;  PYTHONPATH="src" python -m torch.distributed.run \
--nproc_per_node=1 --nnode=1 --node_rank=0 \
--master_addr=127.0.0.1 --master_port=9901 \
examples/pytorch/language-modeling/run_clm.py --bf16 \
--half_precision_backend no_amp --seed 42 --model_name_or_path opt-1.3b-bf16 \
--dataset_name wikitext --dataset_config_name wikitext-103-raw-v1 --optim \
adamw_anyprecision --optim_args \
'use_kahan_summation=true, momentum_dtype=bfloat16, variance_dtype=bfloat16, compensation_buffer_dtype=bfloat16' \
--per_device_train_batch_size 12 --per_device_eval_batch_size 12 \
--gradient_accumulation_steps 1 --do_train --do_eval --logging_steps 10 \
--save_steps 1000 --eval_steps 100 --weight_decay 0.1 --num_train_epochs 1 \
--adam_beta1 0.9 --adam_beta2 0.95 --learning_rate 0.0002 --lr_scheduler_type \
linear --warmup_steps 500 --report_to tensorboard --output_dir save_dir
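
As a rough illustration (not the exact Trainer parsing code), the --optim_args string above boils down to keyword arguments for AnyPrecisionAdamW along these lines:

import torch

optim_args = ("use_kahan_summation=true, momentum_dtype=bfloat16, "
              "variance_dtype=bfloat16, compensation_buffer_dtype=bfloat16")

kwargs = {}
for pair in optim_args.split(","):
    key, value = (x.strip() for x in pair.split("="))
    kwargs[key] = value == "true" if value in ("true", "false") else getattr(torch, value)

print(kwargs)  # {'use_kahan_summation': True, 'momentum_dtype': torch.bfloat16, ...}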

Let's check that I got the math right for opt-1.3B

Theoretical memory allocation for optim states, weights, grads

breakdown:       n_params*(optim + grad + weights)
bf16 mixed precision: 1.3*(8     +   2  +  4+2   ) = 1.3*16 = 20.8GB
bf16 pure:            1.3*(4+2   +   2  +    2   ) = 1.3*10 = 13.0GB
-----------------------------------------------------
diff:                                                          7.8GB
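
The same arithmetic as a tiny script, in case you want to plug in other model sizes (the per-param byte counts are the ones from the table above):

n_params_b = 1.3  # billions of params; bytes/param * 1e9 params = GB (decimal)

# mixed precision: fp32 Adam states (4+4) + bf16 grads (2) + fp32 master weights (4) + bf16 weights (2)
mixed = n_params_b * (8 + 2 + 4 + 2)   # 20.8 GB
# pure bf16: bf16 Adam states (2+2) + bf16 Kahan buffer (2) + bf16 grads (2) + bf16 weights (2)
pure = n_params_b * (4 + 2 + 2 + 2)    # 13.0 GB

print(mixed, pure, mixed - pure)       # 20.8 13.0 7.8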

Real memory allocation (obtained by adding the --skip_memory_metrics 0 flag to get memory usage reports):

a. bf16 mixed precision:
  before_init_mem_gpu        =        0MB
  init_mem_gpu_alloc_delta   =     5019MB
  init_mem_gpu_peaked_delta  =        0MB
  train_mem_gpu_alloc_delta  =    20076MB
  train_mem_gpu_peaked_delta =      123MB
-----------------------------------------
  total                      =    25218MB             

b. bf16 pure:
  before_init_mem_gpu        =        0MB
  init_mem_gpu_alloc_delta   =     5019MB
  init_mem_gpu_peaked_delta  =        0MB
  train_mem_gpu_alloc_delta  =    12548MB
  train_mem_gpu_peaked_delta =      124MB
-----------------------------------------
  total                      =    17691MB             


diff: 7527MB ≈ 7.53GB

So the theoretical and actual numbers check out memory-wise.
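
For completeness, the measured totals and the diff are just sums of the reported deltas:

mixed_total = 0 + 5019 + 0 + 20076 + 123   # 25218 MB
pure_total  = 0 + 5019 + 0 + 12548 + 124   # 17691 MB
print(mixed_total, pure_total, (mixed_total - pure_total) / 1000)  # diff ~7.53 GB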

opt-125m / bf16-pure training from scratch

If you want to fit into a smaller card, let's use opt-125m.

Then prep a freshly initialized opt-125m model:


cat << EOT > prep-bf16.py
from transformers import AutoConfig, AutoModel, AutoTokenizer
import torch

mname = "facebook/opt-125m"

config = AutoConfig.from_pretrained(mname)
model = AutoModel.from_config(config, torch_dtype=torch.bfloat16)
tokenizer = AutoTokenizer.from_pretrained(mname)

path = "opt-125m-bf16"

model.save_pretrained(path)
tokenizer.save_pretrained(path)
EOT

python prep-bf16.py

Train from scratch in pure bf16:

rm -rf save_dir;  PYTHONPATH="src" python -m torch.distributed.run \
--nproc_per_node=1 --nnode=1 --node_rank=0 \
--master_addr=127.0.0.1 --master_port=9901 \
examples/pytorch/language-modeling/run_clm.py --bf16 \
--half_precision_backend no_amp --seed 42 --model_name_or_path opt-125m-bf16 \
--dataset_name wikitext --dataset_config_name wikitext-103-raw-v1 --optim \
adamw_anyprecision --optim_args \
'use_kahan_summation=true, momentum_dtype=bfloat16, variance_dtype=bfloat16, compensation_buffer_dtype=bfloat16' \
--per_device_train_batch_size 12 --per_device_eval_batch_size 12 \
--gradient_accumulation_steps 1 --do_train --do_eval --logging_steps 10 \
--save_steps 1000 --eval_steps 100 --weight_decay 0.1 --num_train_epochs 1 \
--adam_beta1 0.9 --adam_beta2 0.95 --learning_rate 0.0002 --lr_scheduler_type \
linear --warmup_steps 500 --report_to tensorboard --output_dir save_dir

opt-125m / fp16-amp training from scratch

Same, but with mixed-precision fp16 (we want pure bf16 to give us a similar loss curve when everything else is the same):


cat << EOT > prep-fp16.py
from transformers import AutoConfig, AutoModel, AutoTokenizer
import torch

mname = "facebook/opt-125m"

config = AutoConfig.from_pretrained(mname)
model = AutoModel.from_config(config, torch_dtype=torch.float16)
tokenizer = AutoTokenizer.from_pretrained(mname)

path = "opt-125m-fp16"

model.save_pretrained(path)
tokenizer.save_pretrained(path)
EOT

python prep-fp16.py
rm -rf save_dir;  PYTHONPATH="src" python -m torch.distributed.run \
--nproc_per_node=1 --nnode=1 --node_rank=0 \
--master_addr=127.0.0.1 --master_port=9901 \
examples/pytorch/language-modeling/run_clm.py --fp16 \
--seed 42 --model_name_or_path opt-125m-fp16 \
--dataset_name wikitext --dataset_config_name wikitext-103-raw-v1 \
--per_device_train_batch_size 12 --per_device_eval_batch_size 12 \
--gradient_accumulation_steps 1 --do_train --do_eval --logging_steps 10 \
--save_steps 1000 --eval_steps 100 --weight_decay 0.1 --num_train_epochs 1 \
--adam_beta1 0.9 --adam_beta2 0.95 --learning_rate 0.0002 --lr_scheduler_type \
linear --warmup_steps 500 --report_to tensorboard --output_dir save_dir
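
To compare the pure-bf16 and fp16-amp loss curves side by side, one option is to read each run's trainer_state.json. A sketch, assuming you give the two runs distinct --output_dir values (save_dir_bf16 and save_dir_fp16 are hypothetical names; the commands above both write to save_dir):

import json

def losses(output_dir):
    # log_history in trainer_state.json carries the logged training loss per step
    with open(f"{output_dir}/trainer_state.json") as f:
        state = json.load(f)
    return [(e["step"], e["loss"]) for e in state["log_history"] if "loss" in e]

for run in ("save_dir_bf16", "save_dir_fp16"):   # hypothetical output dirs
    print(run, losses(run)[:5])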

@stas00 stas00 changed the title [pure bf16 training] AnyPrecisionAdamW [pure bf16 training] w/ AnyPrecisionAdamW Jan 26, 2023
@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint.

@stas00 stas00 changed the title [pure bf16 training] w/ AnyPrecisionAdamW [pure bf16 training] w/ AnyPrecisionAdamW and Kahan summation Feb 2, 2023
@stas00 stas00 added the WIP Label your PR/Issue with WIP for some long outstanding Issues/PRs that are work in progress label Feb 27, 2023
@huggingface huggingface deleted a comment from github-actions bot Feb 27, 2023