[BUG] memory leak under zero.Init
#2637
Comments
I'm pretty sure this other leak was unrelated, as the one I fixed was only a temp leak. On every iteration the memory was growing, and it only happened with zero.Init.
Thanks @stas00. I am attempting to repro this locally. Could you pass along the ds_config being used and the
Thank you for trying to reproduce this, Joe. If I had a small repro script I probably would have found the problem, but, alas, it appeared to be hidden somewhere in the ensemble of things. The ds_config is just the staple defaults of zero3. I think the issue was coming from nested zero.Init usage. As we overcame the leakage issue by rewriting several major parts of the code base to completely remove the problematic code, I think this can be closed.
Sounds good. If anything similar happens again, feel free to reopen the issue and I will help investigate.
It happened again. T_T
Hello @stas00, I'm glad I found this issue, as I've been struggling to debug memory issues while loading models using huggingface transformers. This is how I load the model and profile the GPU memory:

args = {
'pretrained_model_name_or_path': model_path,
'config': config,
'torch_dtype': "auto",
}
model = model_class.from_pretrained(**args)
gpu_memory_plot_helper(rank, device, "after initializing model")

I used a deepspeed ZeRO-3 configuration with cpu offload together with huggingface accelerate, with the zero3 init flag enabled. However, the memory profiler gave me the following stats.
It returns similar results for both the 1-gpu and the 8-gpu settings. However, it does not match the 931MB from the pytorch summary nor the 1.17GB from the deepspeed memory estimator for GPU RAM.

Also, I checked that it does not allocate any gpu memory if I do not use the zero3 init flag for accelerate. But to the best of my knowledge, zero init is for CPU memory optimization during cpu offloading... And note that I didn't even use

The following python function is what I used for profiling, and I call it right after loading the model parameters with from_pretrained():

import gc
import logging

import psutil
import torch

logger = logging.getLogger(__name__)


def gpu_memory_plot_helper(
    rank,
    device,
    message: str,
):
    gc.collect()
    torch.cuda.synchronize(device)
    allocated = torch.cuda.memory_allocated(device) / (1024**2)
    max_allocated = torch.cuda.max_memory_allocated(device) / (1024**2)
    reserved = torch.cuda.memory_reserved(device) / (1024**2)
    max_reserved = torch.cuda.max_memory_reserved(device) / (1024**2)
    vm_stats = psutil.virtual_memory()
    used_GB = round(((vm_stats.total - vm_stats.available) / (1024**3)), 2)
    logger.info(
        f'''
        {message}
        rank: {rank} / device: {device}
        CPU Virtual Memory: used = {used_GB} GB, percent = {vm_stats.percent}%
        Allocated / Reserved: {allocated:.2f}MB / {reserved:.2f}MB
        Max Allocated / Max Reserved: {max_allocated:.2f}MB / {max_reserved:.2f}MB
        summary
        {torch.cuda.memory_summary(device, abbreviated=True)}
        '''
    )
    torch.cuda.reset_peak_memory_stats(device)

I think it is related to the nested usage of zero.Init. I would very much appreciate it if you could answer me! Best regards.
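For cross-checking a hand-rolled profiler like the one above, deepspeed also ships a small helper that reports GPU and CPU memory in one call; a minimal sketch, assuming a deepspeed version where deepspeed.runtime.utils.see_memory_usage is available:

from deepspeed.runtime.utils import see_memory_usage

# reports allocated / max-allocated / reserved GPU memory plus CPU virtual memory;
# force=True makes it actually print (the call is a no-op without it)
see_memory_usage("after initializing model", force=True)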
I have no idea how to give a confirmation from what you have shared; it's a lot of information, but a lot of it is irrelevant. I think it's also more difficult to do any such evals when you offload to CPU, as your config shows. It's much simpler to use the GPUs directly, and then you get a single measurement. Perhaps instead of using the profiler try to pass

The memory estimators only give a suggestion for params, grads and optim states. Other allocations could be quite significant. But if it's doing the right thing and the model is large enough, you should definitely see a significant difference in memory usage between 1 and 8 gpus, regardless of whether you use zero.Init or not. The latter is only important if you can't load the whole model on a single GPU; the model will get sharded regardless by the time deepspeed is initialized.

Here is a practical example. Let's take t5-3b and run the translation example from HF transformers. Now edit the script, run it for 1 gpu, and look at the gpu memory usage stats; then do the same with 8 gpus (the stats are per gpu).

(I filtered out the other irrelevant stats.) You can clearly see that when 8 gpus are used, ~1/8 of the memory is used for the first two stats. In this little program with a gazillion args we are using a tiny batch size and a tiny seqlen, so pretty much all of the memory is non-activation-related.

Note: I'm using an 8x A100 80GB node; you might need to use t5-large or t5-base if your gpu memory is smaller.

So now you should be able to repeat this for your use case. Once you've proven that the code works (or leaks) without offload, only then try it with offload.

Hope this helps.
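A rough, self-contained way to run the same 1-vs-8 GPU comparison can be sketched as follows; the model choice, file name and launch line are illustrative assumptions, not the exact commands from the example above:

# minimal sketch: build t5-3b under deepspeed.zero.Init and print per-rank
# allocated GPU memory. launch e.g. with:
#   deepspeed --num_gpus=1 check_shard.py   vs   deepspeed --num_gpus=8 check_shard.py
import os

import deepspeed
import torch
from transformers import T5Config, T5ForConditionalGeneration  # assumes HF transformers

deepspeed.init_distributed()
local_rank = int(os.environ.get("LOCAL_RANK", 0))
torch.cuda.set_device(local_rank)

# random init from the config is enough for a memory check, no checkpoint needed;
# zero.Init partitions each parameter across the ranks as it is created
with deepspeed.zero.Init():
    model = T5ForConditionalGeneration(T5Config.from_pretrained("t5-3b"))

torch.cuda.synchronize()
print(f"rank {torch.distributed.get_rank()}: "
      f"allocated = {torch.cuda.memory_allocated() / 2**20:.0f} MB")
# with 8 ranks each one should report roughly 1/8 of the single-gpu number,
# since ZeRO-3 keeps only its shard of the parameters on each GPU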
Thank you for your kind reply @stas00. First of all, I apologize for my English and my lack of explanation of the problem; I should have provided the information more clearly. I didn't mean that using more GPUs does not bring a memory improvement. My question was why GPU memory gets allocated at all right after from_pretrained().

And I finally realized that this is not due to memory leaks in DeepSpeed or Accelerate. It is due to the behavior of the GPT-2 class: if I partition the model with deepspeed, the parameters are partitioned to cpu or to each device (according to the deepspeed configuration), but the attention-mask buffers that the GPT-2 blocks register are not partitioned and stay on every device. That's why I asked you about it.

For now, I've decided to manage this attention mask globally, or to hack the source code to use xformers (it does not require an attention mask, only a LowerTriangularMask). Thank you so much for your response.
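A quick way to see how much of a model's footprint comes from registered buffers (which ZeRO-3 leaves whole on every rank) versus parameters (which it partitions) is to sum over named_buffers() and named_parameters(); a minimal sketch, assuming an HF transformers GPT-2 checkpoint rather than any code from this thread:

import torch
from transformers import GPT2LMHeadModel  # assumption: HF transformers is installed

model = GPT2LMHeadModel.from_pretrained("gpt2")

def tensor_mb(tensors):
    # total size in MB of an iterable of tensors
    return sum(t.numel() * t.element_size() for t in tensors) / (1024 ** 2)

param_mb = tensor_mb(p for _, p in model.named_parameters())
buffer_mb = tensor_mb(b for _, b in model.named_buffers())

print(f"parameters: {param_mb:.1f} MB (partitioned by ZeRO-3)")
print(f"buffers:    {buffer_mb:.1f} MB (e.g. GPT-2's per-layer causal-mask buffers, replicated on every rank)")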
I'm glad to hear you have tracked down the source of gpu memory consumption and that it's not the framework, @SeunghyunSEO! Thank you for the detailed notes on what you have discovered.
Describe the bug

Only when activating zero.Init does the code leak a lot per training iteration.

To Reproduce
I'm yet to be able to reduce this to a simple test. I shared with Tunji how to reproduce it in the large framework.
However, we found the source of the leak, rewrote the offending module, and the leak was gone.
So unless you can see something that points to a bug in deepspeed, this is a post for posterity and can be closed.
Here is the original module that was leaking (see the sketch below):

I think it's the concatenation of the 2 parts of the linear that somehow leads to the leak. The 2 parts were needed in order to keep part of the linear layer frozen.
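A minimal sketch of the described pattern, with hypothetical names, not the verbatim module:

import torch
import torch.nn as nn
import torch.nn.functional as F

class PartiallyFrozenLinear(nn.Module):
    # hypothetical reconstruction: the layer's weight is the torch.cat of a
    # frozen part and a trainable part, concatenated on every forward pass
    def __init__(self, in_features, frozen_out, trainable_out):
        super().__init__()
        self.frozen_weight = nn.Parameter(torch.empty(frozen_out, in_features), requires_grad=False)
        self.trainable_weight = nn.Parameter(torch.empty(trainable_out, in_features))
        nn.init.normal_(self.frozen_weight, std=0.02)
        nn.init.normal_(self.trainable_weight, std=0.02)

    def forward(self, x):
        # the suspected trigger: torch.cat over *parameter* tensors inside forward()
        weight = torch.cat([self.frozen_weight, self.trainable_weight], dim=0)
        return F.linear(x, weight)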
I rewrote it as follows, dealing with each part of the linear separately, and the leak disappeared (again only sketched below):
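A matching sketch of the rewrite's idea, applying each part on its own so the parameter tensors are never concatenated; again hypothetical names, not the verbatim rewrite:

class PartiallyFrozenLinearSplit(nn.Module):
    # each part is its own nn.Linear; outputs are written into a preallocated
    # tensor, so no torch.cat is performed on parameters
    def __init__(self, in_features, frozen_out, trainable_out):
        super().__init__()
        self.frozen = nn.Linear(in_features, frozen_out, bias=False)
        self.frozen.weight.requires_grad_(False)
        self.trainable = nn.Linear(in_features, trainable_out, bias=False)

    def forward(self, x):
        out = x.new_empty(x.shape[:-1] + (self.frozen.out_features + self.trainable.out_features,))
        out[..., : self.frozen.out_features] = self.frozen(x)
        out[..., self.frozen.out_features :] = self.trainable(x)
        return out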
As you can see, I reworked it to remove the torch.cat calls and the leak is gone. It's also much more efficient code.

I still don't understand why zero.Init would be the trigger for such a leak, as zero3 without zero.Init, or zero2, did not leak.

While I have a more efficient solution which doesn't leak, I thought I would report this here for awareness; perhaps you can also see something obvious that I'm missing. It's also possible that the leak trigger was somehow in how this module was used.
@tjruwase