
[trainer] new in pytorch: torch.optim._multi_tensor faster optimizers #9965

Open
stas00 opened this issue Feb 2, 2021 · 7 comments
Labels: Performance, WIP

stas00 commented Feb 2, 2021

Back in September, PyTorch introduced torch.optim._multi_tensor (pytorch/pytorch#43507), which should be much more efficient in situations with lots of small feature tensors (i.e., transformers) and thus should show an appreciable speed-up in training. If someone is interested in the progress of this project, here is the stack to track: pytorch/pytorch#48223.

This feature is currently in an alpha stage, so users can try it out by simply replacing torch.optim with torch.optim._multi_tensor in the HF Trainer or in their own trainer.
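
For illustration, the swap is just a change of import; here is a minimal sketch with a placeholder model and arbitrary hyperparameters:

import torch

# from torch.optim import AdamW               # the stock optimizer
from torch.optim._multi_tensor import AdamW   # experimental multi-tensor variant

model = torch.nn.Linear(512, 512)                         # placeholder model
optimizer = AdamW(model.parameters(), lr=3e-5, eps=1e-6)  # same signature as torch.optim.AdamW

# the training step itself is unchanged
loss = model(torch.randn(8, 512)).sum()
loss.backward()
optimizer.step()
optimizer.zero_grad()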

Eventually it will replace torch.optim, so otherwise there is nothing we need to do.

@blefaudeux, who alerted me to this improvement, suggested it should give good speed-ups for DDP/Sharded DDP training.

If resources allow, it'd be good to run some benchmarks. Please feel free to beat me to it.

Thanks to @blefaudeux for the heads up, and @izdeby for working on this enhancement and clarifying where things are at.

Heads up to @sgugger and @patrickvonplaten: nothing else needs to be done.

stas00 commented Feb 2, 2021

I did a quick benchmark with --sharded_ddp --fp16 and with just --fp16, and there is no visible difference. Perhaps it is more visible with a different kind of training/model combination.

Testing HF AdamW vs. torch.optim._multi_tensor.AdamW

# benchmark with just  --fp16

# baseline HF `AdamW`

export BS=16; rm -r output_dir; PYTHONPATH=../../src USE_TF=0  python -m torch.distributed.launch --nproc_per_node=2 ./finetune_trainer.py --model_name_or_path t5-large --output_dir output_dir --adam_eps 1e-06 --data_dir wmt_en_ro --do_train --freeze_embeds --label_smoothing 0.1 --learning_rate 3e-5 --logging_first_step --logging_steps 1000 --max_source_length 128 --max_target_length 128 --num_train_epochs 1 --overwrite_output_dir --per_device_train_batch_size $BS  --sortish_sampler --task translation_en_to_ro --warmup_steps 500 --n_train 20000  --fp16

{'train_runtime': 226.5618, 'train_samples_per_second': 2.759, 'epoch': 1.0}

# w/ torch.optim._multi_tensor.AdamW

export BS=16; rm -r output_dir; PYTHONPATH=../../src USE_TF=0  python -m torch.distributed.launch --nproc_per_node=2 ./finetune_trainer.py --model_name_or_path t5-large --output_dir output_dir --adam_eps 1e-06 --data_dir wmt_en_ro --do_train --freeze_embeds --label_smoothing 0.1 --learning_rate 3e-5 --logging_first_step --logging_steps 1000 --max_source_length 128 --max_target_length 128 --num_train_epochs 1 --overwrite_output_dir --per_device_train_batch_size $BS  --sortish_sampler --task translation_en_to_ro --warmup_steps 500 --n_train 20000  --fp16

{'train_runtime': 226.1715, 'train_samples_per_second': 2.763, 'epoch': 1.0}

The change I made was:

--- a/examples/seq2seq/seq2seq_trainer.py
+++ b/examples/seq2seq/seq2seq_trainer.py
@@ -24,7 +24,6 @@ from transformers.integrations import is_fairscale_available
 from transformers.models.fsmt.configuration_fsmt import FSMTConfig
 from transformers.optimization import (
     Adafactor,
-    AdamW,
     get_constant_schedule,
     get_constant_schedule_with_warmup,
     get_cosine_schedule_with_warmup,
@@ -32,6 +31,7 @@ from transformers.optimization import (
     get_linear_schedule_with_warmup,
     get_polynomial_decay_schedule_with_warmup,
 )
+from torch.optim._multi_tensor import AdamW
 from transformers.trainer_pt_utils import get_tpu_sampler
 from transformers.training_args import ParallelMode

and this is with pytorch-nightly from today.

blefaudeux commented:

You must have a really strange bottleneck in that test if neither the latest fairscale nor these optimizers are changing anything. These optimizers are measurably faster in isolation, and sure enough we see a difference in the fairscale CI, even on a dummy job / small model (see, for instance, the two last jobs).
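
For anyone who wants to sanity-check the "faster in isolation" claim outside the trainer, a rough standalone timing sketch along these lines should show it (the parameter count and size are arbitrary, and a CUDA device is assumed):

import time
import torch
from torch.optim import AdamW as AdamWSingle
from torch.optim._multi_tensor import AdamW as AdamWMulti

def bench(opt_cls, n_params=200, size=10_000, steps=50):
    # many smallish parameter tensors, roughly mimicking a transformer's parameter layout
    params = [torch.randn(size, device="cuda", requires_grad=True) for _ in range(n_params)]
    for p in params:
        p.grad = torch.randn_like(p)
    opt = opt_cls(params, lr=1e-3)
    torch.cuda.synchronize()
    start = time.time()
    for _ in range(steps):
        opt.step()
    torch.cuda.synchronize()
    return time.time() - start

print("single-tensor AdamW:", bench(AdamWSingle))
print("multi-tensor  AdamW:", bench(AdamWMulti))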

blefaudeux commented:

Testing with the same command, I see vastly varying throughput depending on num_train_epochs, which seems a bit strange to me.

stas00 commented Feb 2, 2021

To share with others: @blefaudeux and his team recently made speed improvements in fairscale (master) that should have been quite visible, but a few days ago we tested this same script with --sharded_ddp and saw no improvement whatsoever. So something odd is going on.

stas00 commented Feb 4, 2021

I will leave this issue open for now as an incentive to profile this script and identify the bottleneck.
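
For whoever picks this up, a minimal way to check whether the optimizer step is even a meaningful slice of the step time would be something along these lines (placeholder model and data; torch.profiler as shipped in recent PyTorch releases):

import torch
from torch.profiler import profile, ProfilerActivity
from torch.optim._multi_tensor import AdamW

model = torch.nn.Sequential(
    torch.nn.Linear(512, 512), torch.nn.ReLU(), torch.nn.Linear(512, 512)
).cuda()
optimizer = AdamW(model.parameters(), lr=3e-5)
data = torch.randn(16, 512, device="cuda")

with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof:
    for _ in range(10):
        loss = model(data).sum()
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

# if the optimizer kernels sit far down this table, the bottleneck is elsewhere
print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=20))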

stas00 added the WIP label and removed the Feature request label on Mar 6, 2021

jaketae commented Jan 15, 2022

@stas00 Do you think this should be revisited given the discussion in upstream PyTorch?

stas00 commented Jan 15, 2022

Yes, I was just about to revisit it.

edit: I thought you might have wanted to work on this, but the PyTorch team asks for a profiler run and the like, so I will probably look into testing it out again myself.

--- original comment ---

Do you want to take the lead on this experiment, @jaketae?

The new --optim HF Trainer flag just got merged, so you could quickly implement --optim adamw_torch_multi_tensor in the same way as --optim adamw; a rough sketch follows.
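
Here is a rough standalone sketch of the mapping such an option needs, assuming the new value is routed the same way adamw_torch already is (the dispatcher below is illustrative, not the actual Trainer internals):

import torch

def get_optimizer_cls(optim_name: str):
    # illustrative dispatcher; in transformers this would live in the Trainer's --optim handling
    if optim_name == "adamw_torch":
        from torch.optim import AdamW
    elif optim_name == "adamw_torch_multi_tensor":
        from torch.optim._multi_tensor import AdamW  # experimental multi-tensor variant
    else:
        raise ValueError(f"unknown optimizer: {optim_name}")
    return AdamW

model = torch.nn.Linear(4, 4)  # placeholder model
optimizer = get_optimizer_cls("adamw_torch_multi_tensor")(model.parameters(), lr=3e-5, eps=1e-6)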

You can use this tool for benchmarking, #14934, if it helps. I think it's pretty stable now; I will propose it as a PR.
