
Is GPU throughput reasonable? #192

Open
Crispig opened this issue Aug 1, 2022 · 3 comments

Comments


Crispig commented Aug 1, 2022

I am currently running some tests with ZeRO-3 Infinity, have run into some problems, and would like your help.

Machine configuration: two nodes, each with one A100-PCIE-40GB GPU, 126 GB of RAM (about 60 GB actually available during runs), and a 1 TB SSD (Samsung 980)

Benchmark code: /DeepSpeedExamples/Megatron-LM-v1.1.5-ZeRO3/

Model case tested
HIDDEN_SIZE / NUM_ATTN_HEADS / NUM_LAYERS / BATCHSIZE = 4096 / 16 / 50 / 8 (model size ~10B)
GPU memory usage: 13395/40537 MB
RAM usage: 109/126 GB (60 GB when idle)
80 GB of swap files stored on the NVMe file system
Effective TFLOPS per GPU: about 1.5
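
As a rough sanity check on the quoted model size, a minimal parameter-count estimate (transformer-block weights only; embeddings, biases, and layer norms ignored) is consistent with ~10B:

```python
# Rough transformer parameter count: attention (~4*h^2 per layer) plus
# MLP (~8*h^2 per layer) gives ~12 * num_layers * hidden_size^2 in total.
hidden_size, num_layers = 4096, 50
approx_params = 12 * num_layers * hidden_size ** 2
print(f"~{approx_params / 1e9:.1f}B parameters")  # prints ~10.1B
```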

Questions
Is the GPU throughput achieved under the current configuration reasonable, and can the throughput be increased by raising the batch size or changing other configuration options?
The effective TFLOPS per GPU calculated by flops_calculator in DeepSpeedExamples is about 1.5 TFLOPS, but the per-GPU FLOPS measured by the DeepSpeed profiler is 2.32 GFLOPS. (deepspeed_profile.txt was generated by the DeepSpeed profiler, and train.log is the output produced during training.)
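
For reference, the Megatron-style estimate that flops_calculator appears to implement looks roughly like the sketch below; the exact constants may differ, and seq_len, vocab_size, and the per-iteration time here are placeholders, not values taken from this run:

```python
# Hedged sketch of a Megatron-style per-GPU TFLOPS estimate.
# Values marked "assumed" are placeholders, not measurements.
batch_size = 8            # global batch size (from the test case above)
seq_len = 1024            # assumed sequence length
hidden_size = 4096
num_layers = 50
vocab_size = 50257        # assumed GPT-2 BPE vocabulary size
num_gpus = 2
ckpt_factor = 4           # 4 with activation checkpointing, 3 without
sec_per_iter = 30.0       # assumed; use the elapsed time per iteration from train.log

flops_per_iter = (
    24 * ckpt_factor * batch_size * seq_len * num_layers * hidden_size ** 2
    * (1 + seq_len / (6 * hidden_size)
         + vocab_size / (16 * num_layers * hidden_size))
)
print(f"{flops_per_iter / (sec_per_iter * num_gpus * 1e12):.2f} TFLOPS per GPU")
```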

deepspeed_profile.txt
train.log

I hope you can help. Thank you very much!

tjruwase (Contributor) commented Aug 1, 2022

@Crispig, thanks for your question.

The TFLOPS on 16xA100-40GB is quite low. What is the batch size? A 10B model is too small for ZeRO-Infinity with NVMe offload, given the overheads of parameter partitioning and NVMe offload. You should get much better performance with ZeRO-Offload.

Some factors to consider in order to understand and improve ZeRO-Infinity performance:

  1. The SSD is likely a bottleneck; you can profile the SSD using this guide.
  2. Offload only the optimizer state to CPU/NVMe, not the parameters, since there is sufficient GPU memory for them (see the config sketch after this list).
  3. Increase the batch size to improve compute load and efficiency.
  4. Disable activation checkpointing or reduce its frequency.
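
To make item 2 concrete, here is a minimal sketch of the relevant ds_config fragment, written as the Python dict that would be dumped to ds_config.json; the NVMe path and batch size are placeholders, not a tested configuration:

```python
# Sketch only: offload optimizer state to NVMe, keep parameters on GPU.
# Paths and sizes are placeholders; adjust to your cluster before use.
ds_config = {
    "train_micro_batch_size_per_gpu": 16,   # try values larger than 8 (item 3)
    "zero_optimization": {
        "stage": 3,
        "offload_optimizer": {              # item 2: optimizer state only
            "device": "nvme",
            "nvme_path": "/local_nvme",     # placeholder path
            "pin_memory": True,
        },
        # no "offload_param" block: parameters stay in GPU memory
    },
    "fp16": {"enabled": True},
}
```

Switching "device" to "cpu" turns this into plain ZeRO-Offload of the optimizer state, which avoids the SSD entirely if it turns out to be the bottleneck.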

Crispig (Author) commented Aug 4, 2022

Thank you very much for your reply!
The batch size I used in my previous test was 8.
I have done the following tests so far:

  1. For a 1.7B model with a batch size of 48, either without offload or with offload to the CPU, throughput can reach 36 TFLOPS. For a 17B model with only the optimizer state offloaded (ZeRO-Infinity) to the SSD, it runs at up to 15 TFLOPS.
  2. With the optimizer state offloaded (ZeRO-Infinity) to the SSD, changing only hidden_size to make the model smaller, without modifying any other configuration options, results in the following error (output from node 192.168.189.10):
    Traceback (most recent call last):
      File "pretrain_gpt2.py", line 134, in <module>
        args_defaults={'tokenizer_type': 'GPT2BPETokenizer'})
      File "/home/lcy/DeepSpeedExamples/Megatron-LM-v1.1.5-ZeRO3/megatron/training.py", line 111, in pretrain
        train_data_iterator, valid_data_iterator)
      File "/home/lcy/DeepSpeedExamples/Megatron-LM-v1.1.5-ZeRO3/megatron/training.py", line 545, in train
        lr_scheduler)
      File "/home/lcy/DeepSpeedExamples/Megatron-LM-v1.1.5-ZeRO3/megatron/training.py", line 394, in train_step
        model.step()
      File "/home/lcy/DeepSpeed/deepspeed/runtime/engine.py", line 1911, in step
        self._take_model_step(lr_kwargs)
      File "/home/lcy/DeepSpeed/deepspeed/runtime/engine.py", line 1812, in _take_model_step
        self.optimizer.step()
      File "/home/lcy/DeepSpeed/deepspeed/runtime/zero/stage3.py", line 1932, in step
        self.unscale_and_clip_grads(sub_group_id, scaled_global_grad_norm)
      File "/home/lcy/DeepSpeed/deepspeed/runtime/zero/stage3.py", line 2007, in unscale_and_clip_grads
        self.fp32_partitioned_groups_flat[sub_group_id].grad.mul_(1. / combined_scale)
    AttributeError: 'NoneType' object has no attribute 'mul_'
  3. During training, the following warning is printed constantly. I applied the correction from this issue, but I do not know whether it affects performance:
    [WARNING] [parameter_offload.py:48:_apply_to_tensors_only] A module has unknown inputs or outputs type (<class 'torch.nn.parameter.Parameter'>) and the tensors embedded in it cannot be detected. The ZeRO-3 hooks designed to trigger before or after backward pass of the module relies on knowing the input and output tensors and therefore may not get triggered properly.
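
If it helps to see what can trigger that warning, its text suggests it fires when a module's forward inputs/outputs include a bare torch.nn.Parameter rather than a Tensor produced by an op; a toy illustration (my own example, not Megatron code):

```python
import torch

class ReturnsParameter(torch.nn.Module):
    """Toy module whose forward output is a bare torch.nn.Parameter.

    ZeRO-3 attaches its pre/post-backward hooks by walking each module's
    inputs and outputs; an output whose type is Parameter rather than a
    plain Tensor is the case the warning above complains about, so the
    hooks may not fire for this module.
    """

    def __init__(self, dim: int = 8):
        super().__init__()
        self.scale = torch.nn.Parameter(torch.ones(dim))

    def forward(self, x: torch.Tensor):
        # Returning the Parameter directly can trip the warning; returning a
        # computed tensor such as `x * self.scale` yields a plain Tensor.
        return self.scale
```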

awan-10 (Contributor) commented Aug 31, 2022

Maybe I am too late here but this old Megatron has been deprecated. Can you kindly try the latest code and recipes from here?

https://github.com/microsoft/Megatron-DeepSpeed/tree/main/examples/azure
