Training speed ups: +10-15% tokens/second/gpu #29
Conversation
Maybe attach a training curve / eval perplexity for some small-scale experiments (e.g., 87M params, 1.7B tokens), just to make sure there is no degradation in performance?
Some results on the 160M param scale (this is with the earlier initialization from Mosaic and bias in the MLPs). At least low-precision layer norm and fused cross-entropy have basically no effect on train loss / val ppl, while giving us a good speedup. It's unclear whether gelu and float16 reduce affect convergence; we will test at 1B next.
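For context, "low precision layer norm" here refers to running the normalization in the activation dtype (bf16/fp16) rather than upcasting to fp32. Below is a minimal PyTorch sketch of that idea; the class name and exact behavior are illustrative assumptions, not necessarily what open_lm ships:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class LPLayerNorm(nn.LayerNorm):
    """Layer norm computed in the input's (low) precision.

    Sketch only: casts the affine parameters down to the activation dtype
    instead of upcasting activations to fp32.
    """

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        weight = self.weight.to(x.dtype) if self.weight is not None else None
        bias = self.bias.to(x.dtype) if self.bias is not None else None
        return F.layer_norm(x, self.normalized_shape, weight, bias, self.eps)

# Example: bf16 activations stay in bf16 through the norm.
ln = LPLayerNorm(4096)
out = ln(torch.randn(2, 2048, 4096, dtype=torch.bfloat16))
```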
This PR provides three changes that improve tokens/second/gpu for open_lm. Each change is guarded by a flag that is off by default, since these changes likely affect model convergence and downstream accuracy. We can enable them once we train large enough models with these changes.
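Concretely, the idea is that every optimization stays opt-in behind a command-line flag. A rough sketch of the pattern is below; the flag names, and which changes they correspond to, are assumptions drawn from the discussion above rather than a definitive list of this PR's flags:

```python
import argparse

parser = argparse.ArgumentParser()
# Illustrative flag names; everything defaults to the old behavior, so
# existing training runs are unaffected unless a flag is passed explicitly.
parser.add_argument("--lp-layer-norm", action="store_true",
                    help="compute layer norm in the activation dtype instead of fp32")
parser.add_argument("--fused-xentropy", action="store_true",
                    help="swap in a fused cross-entropy kernel for the stock loss")
parser.add_argument("--grad-reduce-dtype", default="fp32", choices=["fp32", "fp16"],
                    help="dtype used for the distributed gradient all-reduce")
args = parser.parse_args([])  # parse defaults for this example

print(args.lp_layer_norm, args.fused_xentropy, args.grad_reduce_dtype)
# False False fp32 -> every speedup is off unless opted into
```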
Below, I'm quoting rough improvements in tokens/second/gpu for a 7B model with batch size 16 and 2048-token sequences, on a machine with 8 80GB A100s.
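For reference, tokens/second/gpu is just the tokens one GPU processes per step divided by step wall-clock time. A small helper makes the accounting explicit; it assumes the batch size 16 is per GPU, and the step time below is made up, chosen only to land near the ~4000 t/s/g figure quoted later:

```python
def tokens_per_second_per_gpu(per_gpu_batch: int, seq_len: int, step_time_s: float) -> float:
    """Tokens processed by one GPU per step, divided by step wall-clock time."""
    return per_gpu_batch * seq_len / step_time_s

# PR setup: per-GPU batch 16, 2048-token sequences (7B model on 8x A100-80GB).
# An ~8.2 s step corresponds to roughly 4000 t/s/g; the timing is illustrative.
print(tokens_per_second_per_gpu(16, 2048, 8.2))  # ~3996
```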
These tests were done with `xformers==0.0.19` and `torch==2.0.0+cu117`. Further upgrading to `xformers==0.0.22` (+200 t/s/g) and `torch==2.0.1+cu118` (+100 t/s/g) gets us to about 4000 tokens/second/gpu on one node.

Next steps: multi-node. With the above changes, our 2-node performance is at ~3700 t/s/g, but this goes up to ~3840 t/s/g when we double the batch size to 32. Can we close the gap to 4k? Might be hard since we don't have the same EFA as Mosaic reports here: https://github.com/mosaicml/llm-foundry/tree/main/scripts/train/benchmarking#a100-80gb-with-1600-gbps-node-node-interconnect-roce
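As a back-of-the-envelope view of that gap, here is the scaling efficiency of the 2-node numbers relative to the ~4000 t/s/g single-node figure:

```python
single_node = 4000   # ~t/s/g on one node with the upgraded xformers/torch
two_nodes = {"batch 16": 3700, "batch 32": 3840}

for cfg, tps in two_nodes.items():
    print(f"2 nodes, {cfg}: {tps} t/s/g = {tps / single_node:.1%} of single-node")
# 2 nodes, batch 16: 3700 t/s/g = 92.5% of single-node
# 2 nodes, batch 32: 3840 t/s/g = 96.0% of single-node
```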
Closes #22 #27