Memory leak when using CLIPTextModel #31439
Comments
Hi @minsuk00
@younesbelkada - cc @muellerz Additionally, calling
cc @muellerzr regarding the accelerate behaviour. Regarding
Hey there @minsuk00, hope you are doing well. (Redirected from #33345 to here.) I primarily see two concerns raised here regarding VRAM usage. Let's clear these up one by one.

### Small VRAM Usage Spike on the 2nd GPU / Non-primary GPUs

When loading a model, PyTorch initializes a CUDA context for managing resources like device memory and kernel execution. This setup includes detecting which GPUs are available. By default, CUDA might allocate a small amount of memory on all available GPUs, not just the one we explicitly specified for the model to be deployed on. This can cause a minor memory spike on non-primary GPUs.

```python
clip_text_model = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14").to("cuda:1")
```

The code above places the model on GPU-1. However, without specific settings, CUDA initializes contexts on all GPUs, leading to small memory footprints on each. To ensure that only GPU-1 is utilized, and to prevent CUDA from initializing on other GPUs, you can set the `CUDA_VISIBLE_DEVICES` environment variable:

```python
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "1"  # Only GPU-1 is recognized
```

This restricts CUDA to see and use only the GPU we specify and avoids unnecessary memory allocation on the others.

🤗 Tip: Always set `CUDA_VISIBLE_DEVICES` before importing `torch`:

```python
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "1"

import torch
# rest of the code
```
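As a quick sanity check, something like the following (an illustrative sketch, assuming the same single-GPU setup as above) confirms that only one device is visible and shows how much memory PyTorch has allocated on it:

```python
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "1"  # must be set before torch is imported

import torch

print(torch.cuda.device_count())       # expected: 1 (physical GPU-1 is now visible as cuda:0)
print(torch.cuda.get_device_name(0))   # name of the only visible device
print(torch.cuda.memory_allocated(0))  # bytes PyTorch has allocated on that device so far
```

Note that once `CUDA_VISIBLE_DEVICES="1"` is set, the remaining GPU is addressed as `cuda:0` inside the process.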
### Residual VRAM Usage After Freeing Up Memory

When we use PyTorch with CUDA to load a model, the framework initializes a CUDA context. The CUDA context can be thought of as a data structure that manages crucial GPU resources such as device memory allocation, kernel execution, and synchronization mechanisms. Previously, to try to clear the VRAM entirely, we executed the following code:

```python
del clip_text_model
gc.collect()
torch.cuda.empty_cache()
```

This frees up most of the VRAM used by the model, but it does not dismantle the CUDA context or the other minor memory allocations made during the model's initialization. As a result, we still observe some residual VRAM usage. The CUDA context is a persistent entity: it maintains the state and resources required for efficient GPU operations across the lifecycle of the program, so that the overhead of initializing these resources is not repeatedly incurred.
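One way to see this effect (a minimal sketch) is to compare what PyTorch itself reports against what the process holds overall:

```python
import gc
import torch

# ... after deleting the model ...
gc.collect()
torch.cuda.empty_cache()

print(torch.cuda.memory_allocated())  # tensors still tracked by PyTorch; should be near 0 now
print(torch.cuda.memory_reserved())   # memory kept by PyTorch's caching allocator
# `nvidia-smi` will still show some usage for the process: that is mostly the
# CUDA context, which empty_cache() does not (and cannot) release.
```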
The good news is that this residual VRAM usage is generally consistent across different models, regardless of their type or size. For example:

```python
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "1"

import gc
import torch
from transformers import CLIPTextModel

# With CUDA_VISIBLE_DEVICES="1", the only visible GPU is addressed as cuda:0
clip_text_model = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14").to("cuda:0")

del clip_text_model
gc.collect()
torch.cuda.empty_cache()
```

And similarly:
```python
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "1"

import gc
import torch
from transformers import AutoModelForCausalLM

checkpoint = "meta-llama/Meta-Llama-3.1-8B-Instruct"
model = AutoModelForCausalLM.from_pretrained(
    checkpoint,
    device_map="auto",
    torch_dtype=torch.float16,
    token=hugging_face_token,  # your Hugging Face access token
)

del model
gc.collect()
torch.cuda.empty_cache()
```

🌟 Important point to note: loading multiple models therefore does not cumulatively increase residual VRAM usage beyond what is necessary for the CUDA context and other minimal allocations. After using one model, you can delete it and load another as needed, without worrying that each model load carries extra VRAM overhead of its own.
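As an illustration of this point (a minimal sketch; the helper name and the choice of a second model are just for demonstration), you can load and free two different models in the same process and check the numbers after each round:

```python
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "1"

import gc
import torch
from transformers import CLIPTextModel, CLIPVisionModel

def load_and_free(model_cls, checkpoint):
    model = model_cls.from_pretrained(checkpoint).to("cuda:0")
    # ... use the model ...
    del model
    gc.collect()
    torch.cuda.empty_cache()
    # PyTorch's own allocations are back to (near) zero after each round
    print(model_cls.__name__, torch.cuda.memory_allocated(0))

load_and_free(CLIPTextModel, "openai/clip-vit-large-patch14")
load_and_free(CLIPVisionModel, "openai/clip-vit-large-patch14")
# The residual usage reported by nvidia-smi stays roughly the same after both
# rounds: it is the CUDA context, not leftover weights from either model.
```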
If you are looking to mitigate even this small residual VRAM usage, the next section discusses techniques to do so. If you simply want to load another model and continue with your work after the first one, I have briefly highlighted the guidelines in the TLDR & Ending Notes section at the end.

### Mitigating Residual VRAM Usage

#### (1) Restart the IPython kernel

Restarting the Jupyter Notebook kernel clears all defined variables and the CUDA context, effectively freeing up all GPU memory used during the session. This is a straightforward way to reset the GPU's memory state after the model operations.

#### (2) Run the script from a terminal

Running your scripts directly through a terminal session instead of a notebook can help manage memory more effectively. Each time the script completes, the Python process terminates, which clears all memory allocations, including the CUDA context. Here's how to set up and run your script:

```python
%%writefile model_load_script.py
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "1"

import gc
import torch
from transformers import CLIPTextModel

# With CUDA_VISIBLE_DEVICES="1", the only visible GPU is addressed as cuda:0
clip_text_model = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14").to("cuda:0")

# Operations are performed here
# Once completed, clean up the model and VRAM
del clip_text_model
gc.collect()
torch.cuda.empty_cache()
```

To execute the script, run the following command in your terminal:

```bash
python model_load_script.py
```
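If you need to repeat this inside one long-running script, the same process-isolation idea can be applied with Python's multiprocessing (a minimal sketch with illustrative names, assuming the single-GPU setup above): the model work runs in a child process, and its allocations, including the CUDA context, are released when the child exits.

```python
import multiprocessing as mp

def run_model():
    # Set visibility before torch is imported in the child process
    import os
    os.environ["CUDA_VISIBLE_DEVICES"] = "1"

    import torch
    from transformers import CLIPTextModel

    model = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14").to("cuda:0")
    # ... do the actual work here ...

if __name__ == "__main__":
    ctx = mp.get_context("spawn")  # "spawn" is the safe start method for CUDA work
    p = ctx.Process(target=run_model)
    p.start()
    p.join()
    # The child has exited, so its VRAM (CUDA context included) is returned to the driver.
```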
Hey @minsuk00, does the answer above fix your issue?
@SunMarc @nnilayy
1. For the first problem, the VRAM spike on the non-primary GPU, setting the `CUDA_VISIBLE_DEVICES` environment variable fixed it. Since only one GPU is then visible, the model has to be moved to `cuda:0`:

```python
clip_text_model = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14").to("cuda:0")
```

Data is correctly allocated on GPU-1 as a result.
2. I confirmed that the residual VRAM after freeing the memory is indeed shared among multiple models.

Thanks for the help! Really appreciate it.
Yes, that's what you should do. If your issue is fixed, feel free to close the issue. Thanks!
### System Info

- `transformers` version: 4.26.1

### Who can help?

### Information

### Tasks

- An officially supported task in the `examples` folder (such as GLUE/SQuAD, ...)

### Reproduction
I can't free GPU memory after I use CLIPTextModel. Also, memory is allocated on another device for some reason. The problem should be reproducible with the following code snippet:
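Roughly, the reproduction looks like this (reconstructed from the discussion in the comments; the exact original snippet is not shown above):

```python
import gc
import torch
from transformers import CLIPTextModel

clip_text_model = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14").to("cuda:1")

# Try to free the memory afterwards
del clip_text_model
gc.collect()
torch.cuda.empty_cache()
# VRAM on cuda:1 does not drop back to zero, and a small allocation also
# shows up on cuda:0.
```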
### Expected behavior
I've also tried using garbage collection and explicitly moving the model to the CPU, but they don't work.