[Benchmark] HF Trainer on A100 #15026
precision: fp16 vs bf16 vs tf32 vs fp32
Main interest: benchmarking the new --bf16 and --tf32 modes on Ampere/RTX-3090, compared to the fp16 and fp32 modes.
Benchmark
The benchmark uses 3 different t5 models, and at the end of the section also gpt2. For t5 the main script is:
and now adding one of:
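As a hedged illustration only (not the exact benchmark invocations), the four precision modes correspond to the following `TrainingArguments` switches; `tf32=True` essentially toggles `torch.backends.cuda.matmul.allow_tf32` under the hood:

```python
from transformers import TrainingArguments

# Illustrative only -- one of these precision switches is used per run:
args_fp32 = TrainingArguments(output_dir="out")              # baseline: full fp32
args_fp16 = TrainingArguments(output_dir="out", fp16=True)   # mixed precision via fp16 autocast
args_bf16 = TrainingArguments(output_dir="out", bf16=True)   # mixed precision via bf16 autocast (Ampere+)
args_tf32 = TrainingArguments(output_dir="out", tf32=True)   # fp32 with TF32 matmuls (Ampere+)
```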
But we are going to use a special benchmarking tool that will do all the work for us: #14934
Important notes:
Setup:
Benchmark 1: t5-small
Conclusions:
Benchmark 2: t5-base
Conclusions:
Benchmark 3: t5-large
Conclusions:
If I use a higher bs=16 instead of 8, bf16 does deliver better performance than fp32, but still not on par with fp16:
It'd be great to know why CUDA doesn't activate some optimization on its own, since not everybody is going to run benchmarks. But if you do run benchmarks and find yourself in this situation, Eddie Yan proposed a workaround to add. So let's try it:
Both bf16 and tf32 show much better performance here.
Benchmark 4: gpt2
Conclusions:
Benchmark 5: gpt2-medium
Conclusions:
gradient accumulation steps
Let's choose one of the setups above and measure the impact of different gradient accumulation step values.
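For orientation, gradient accumulation is a single `TrainingArguments` knob; here is a minimal sketch (the values are my own illustration, not the benchmarked ones):

```python
from transformers import TrainingArguments

# Illustrative values: the effective batch size is
# per_device_train_batch_size * gradient_accumulation_steps * number_of_gpus.
args = TrainingArguments(
    output_dir="out",
    per_device_train_batch_size=16,
    gradient_accumulation_steps=4,  # optimizer steps once every 4 forward/backward passes
    bf16=True,
)
```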
Let's filter out just one subset so that it's easier to compare the gradient accumulation differences alone, re-running with just bf16 enabled:
Conclusions:
batch size
Conclusions:
gradient checkpointing
Conclusions:
Let's look at memory:
We can clearly see that peak GPU memory is ~2/3 lower. Note: I had to halve the BS in the 2nd benchmark as I was getting OOM.
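For reference, a minimal sketch of enabling gradient checkpointing via the Trainer (the surrounding values are illustrative): it drops intermediate activations and recomputes them during the backward pass, which is why peak memory shrinks while the step time grows.

```python
from transformers import TrainingArguments

# Illustrative: trade extra compute (activation recomputation) for memory.
args = TrainingArguments(
    output_dir="out",
    per_device_train_batch_size=16,
    gradient_checkpointing=True,
    bf16=True,
)
```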
optimizers
Let's do fp32 first:
Observations:
fp16:
bf16:
Observations:
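Assuming the optimizer is being switched through the Trainer's `--optim` flag (an assumption on my part), a minimal sketch of the kind of variants being compared:

```python
from transformers import TrainingArguments

# Illustrative: each run picks one optimizer implementation.
for optim in ("adamw_hf", "adamw_torch", "adamw_apex_fused", "adafactor"):
    args = TrainingArguments(output_dir=f"out-{optim}", optim=optim, bf16=True)
    # ... build Trainer(args=args, ...) and train as usual
```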
combining winning strategies
Now let's combine the winning strategies from each individual benchmark above and compare with the baseline:
Getting an almost 3x improvement in speed!
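Purely as an illustration (the specific flags below are my assumption, not necessarily the exact winning set), such a combined configuration would look something like:

```python
from transformers import TrainingArguments

# Hypothetical combination of individually-fast settings: reduced precision,
# TF32 matmuls, a larger batch size and a fused optimizer.
args = TrainingArguments(
    output_dir="out-combined",
    per_device_train_batch_size=32,
    bf16=True,
    tf32=True,
    optim="adamw_apex_fused",
)
```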
RTX-3090 vs A100
In all the benchmarks above I was using a bigger batch size and running more samples than in the corresponding RTX-3090 benchmarks, since the A100 40GB card can handle more than the RTX-3090 24GB. But let's now compare the two using the same config. So the RTX-3090 will be fully loaded, while the A100 will be only partially loaded. Also, each card is running on a different machine, so there is a bit of a hardware difference as well.
A100
RTX-3090
Observations:
Same software was used for both setups:
I thought that perhaps this had to do with bf16, so I re-did the same with:
A100
RTX-3090
Still not good for A100. Let's try w/o tf32:
A100
RTX-3090
This is better for A100. So tf32 was making things worse here for some reason. Eddie Yan explained the reason for RTX-3090 being faster:
@stas00 Another interesting issue that I found regarding batch size is that it is an important parameter when the model is mostly in fp32 and relies on autocast to dispatch to fp16 or bf16. I believe this is because the overhead of casting back and forth can dominate the total runtime compared to the actual kernel/operator. Consider the following microbenchmark:
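A minimal sketch of this kind of microbenchmark (my own reconstruction, not the original script; layer size, batch sizes and iteration counts are arbitrary): keep a layer's weights in fp32 and time its forward pass with and without autocast across batch sizes.

```python
import time
import torch

# Sketch only: time a Linear layer whose weights stay in fp32, with and
# without autocast, across batch sizes. At small batch sizes the fp16/bf16
# cast overhead can outweigh the faster reduced-precision GEMM.
device = "cuda"
hidden = 4096
model = torch.nn.Linear(hidden, hidden).to(device)  # fp32 weights
steps, warmup = 100, 10

def bench(x, amp_dtype=None):
    for i in range(steps + warmup):
        if i == warmup:
            torch.cuda.synchronize()
            t0 = time.perf_counter()
        if amp_dtype is None:
            model(x)  # plain fp32
        else:
            with torch.autocast(device_type="cuda", dtype=amp_dtype):
                model(x)
    torch.cuda.synchronize()
    return (time.perf_counter() - t0) / steps * 1e3  # ms per iteration

for bs in (8, 16, 32, 64):
    x = torch.randn(bs, hidden, device=device)
    print(f"bs={bs:3d} fp32={bench(x):.3f}ms "
          f"fp16={bench(x, torch.float16):.3f}ms "
          f"bf16={bench(x, torch.bfloat16):.3f}ms")
```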
I get the following times on A6000 (similar architecture to 3090):
Thank you, @eqy.
update: edited out the original note on casting back, since the explicit casting is not being measured.
Added a nicely formatted table output so it's much easier to analyze. Updated script attached: bench.txt
On RTX-3090 I get:
Autocast: True
Speedup:
Autocast: False
Speedup:
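For what it's worth, the "Speedup" rows in tables like these are just each reduced-precision variant measured against the fp32 baseline (baseline time divided by variant time); a small hypothetical helper:

```python
# Hypothetical helper: given per-dtype iteration times in ms, compute the
# speedup of each variant over the fp32 baseline.
def speedups(times_ms: dict[str, float], baseline: str = "fp32") -> dict[str, float]:
    base = times_ms[baseline]
    return {name: base / t for name, t in times_ms.items()}

print(speedups({"fp32": 10.0, "fp16": 4.0, "bf16": 5.0}))
# {'fp32': 1.0, 'fp16': 2.5, 'bf16': 2.0}
```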
I believe there are "speed-of-light" cases where the cast-back wouldn't be necessary, though this may not be possible for the architectures we're interested in. Here, I think the big picture is that once the batch-size falls below a certain amount, the "building-block" operations like GEMMs will be slower in reduced precision vs. fp32 when casts are needed.
Why do you think bs=32 is an oddball relative to the other batch sizes for speedup? In both cases, w/ and w/o amp, it's relatively faster for bf16 and fp16 than bs=64, and much more significantly so for fp16. One would expect 8 < 16 < 32 < 64, but here it is 8 < 16 < 64 < 32. So the actual results are proportionally in line, but the speedups aren't.
That's interesting, I didn't see quite so dramatic results on an A100 (80GB), 2 runs:
Autocast: True
Results:
Speedup:
Autocast: False
Results:
Speedup:
Autocast: True
Results:
Speedup:
Autocast: False
Results:
Speedup:
Is all benchmarking done on a single A100 ("NVIDIA_TESLA_A100") GPU? Can you also include CUDA memory required vs. data points for training vs. number of GPUs?
Yes.
I don't understand your question.
On how many data points and epochs is it benchmarked with a single GPU?
I get an error with 4 GPUs, 20 epochs on A100 with 700,000 data points:
It's defined by the settings in each benchmark's command line.
To your last OOM comment: please let's not derail this benchmark issue. If you want to discuss an unrelated question, please open a new issue. It's best to delete it from here and post it in another issue. Thank you.
🖥 Benchmarking transformers w/ HF Trainer on a single A100 40GB
We are going to use a special benchmarking tool that will do all the work for us: #14934
This is the index post and specific benchmarks are in their own posts below:
Note that each benchmark was run only once, so multiple runs with averaging would probably give slightly different results. The purpose here, though, is to see rough relative differences, not to give exact numbers.
See also the same benchmarks for RTX-3090