
CUDA OOM when loading large models #99

Closed
Tianwei-She opened this issue Nov 16, 2022 · 6 comments

@Tianwei-She

I'm trying out deepspeed-mii on a local machine (8 GPUs with 23 GB of VRAM each). Smaller models like bloom-560m and EleutherAI/gpt-neo-2.7B worked well. However, I got CUDA OOM errors when loading larger models, like bloom-7b1. For some even larger models, like EleutherAI/gpt-neox-20b, the server just crashed without any specific error messages or logs.

I've tried deepspeed inference before, and it worked fine on these models.

I use this script to deploy models:

import mii

mii_configs = {"tensor_parallel": 8, "dtype": "fp16"}
mii.deploy(task='text-generation',
           model="facebook/opt-6.7b",
           deployment_name="facebook/opt-6.7b",
           model_path="/home/ubuntu/.cache/huggingface/hub",
           mii_config=mii_configs)

Is there something I should change in my deployment script?

Thanks!

@mrwyattii
Contributor

mrwyattii commented Nov 21, 2022

Hi @Tianwei-She, thanks for using MII. It looks like you're seeing this problem because we try to load the model in fp32 onto each GPU before converting it to fp16 here.

In general, MII is not the most efficient with GPU memory when running multi-GPU, because:

  • we load the model onto each GPU before distributing it with DeepSpeed-Inference
  • we always load the model with the default dtype (which is typically fp32)

Here's a PR that addresses some of these inefficiencies by loading with the user-specified dtype and by allowing the model to be staged in system memory before it is distributed across the GPUs. Please give #105 a try and let me know if it fixes your problem:
pip install git+https://github.com/microsoft/deepspeed-mii@mrwyattii/address-poor-vram-usage

The script you shared should work with these changes, but if it doesn't, try adding "load_with_sys_mem": True to your mii_configs.

Note: Unfortunately, we will still need to load the entire model tensor_parallel times (either one copy on each GPU, or all copies in system memory). We are working on addressing this issue, but I don't have a fix right now.
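
For reference, a minimal sketch of the updated deployment call with these options (the load_with_sys_mem key comes from #105; the rest mirrors the script shared above):

import mii

# Sketch only: same deployment as above, but explicitly loading in fp16 and
# staging the weights in system memory before they are sharded across the GPUs.
mii_configs = {
    "tensor_parallel": 8,
    "dtype": "fp16",
    "load_with_sys_mem": True,  # option introduced by the PR referenced above
}
mii.deploy(task="text-generation",
           model="facebook/opt-6.7b",
           deployment_name="facebook/opt-6.7b",
           model_path="/home/ubuntu/.cache/huggingface/hub",
           mii_config=mii_configs)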

@mrwyattii
Contributor

Closing due to inactivity; #105 has been merged. Please reopen if you are still seeing the same error with the latest DeepSpeed-MII.

@wangshankun

@mrwyattii I'm seeing the same error in v0.0.4.
Config:
mii_config = {"tensor_parallel": 1, "dtype": "fp16", "load_with_sys_mem": True}
name = "facebook/opt-30b"
ds_config = {
    "fp16": {"enabled": True},
    "bf16": {"enabled": False},
    "zero_optimization": {
        "stage": 3,
        "offload_param": {
            "device": "cpu",
            "pin_memory": False,
        },
    },
    "train_micro_batch_size_per_gpu": 1,
}

[screenshot]

@mrwyattii
Contributor

@wangshankun what kind of GPU are you trying to run on? The OPT-30B model is ~60 GB in size. From the screenshot you shared, it looks like you only have 22 GB of GPU memory available, so you will not be able to run a model this large:
[screenshot]
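
For rough context, a back-of-the-envelope estimate of the fp16 weight footprint (a sketch; the parameter count is approximate):

# Approximate fp16 weight memory for OPT-30B (parameter count is approximate).
params = 30e9          # ~30 billion parameters
bytes_per_param = 2    # fp16 stores 2 bytes per parameter
print(f"~{params * bytes_per_param / 1e9:.0f} GB of weights")  # ~60 GB, before activations or KV cache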

@wangshankun

@mrwyattii
Did I misunderstand the meaning of CPU offload in ZeRO-3?

It is precisely because GPU memory is insufficient that I want to place the model in host memory, which is why I configured CPU offload and load_with_sys_mem.

@moussaba

Same question... Does ZeRO offload actually work in MII? I have been having a lot of difficulty trying to get DeepSpeed-MII to do any sort of CPU or NVMe offload.
