[BUG] Peft Training with Zero.Init() and Zero3 will increase GPU memory every forward step #3002

Closed
dumpmemory opened this issue Mar 13, 2023 · 11 comments

dumpmemory commented Mar 13, 2023

Describe the bug
When I use PEFT LoRA to train a GPT-2 model with ZeRO-3 and the zero.Init() function, GPU memory increases with every forward step. When I disable zero.Init(), training works as normal.
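
For context, here is a minimal sketch of the kind of setup involved (a hypothetical script, not the repro from huggingface/peft#161; the model name, LoRA target modules, and DeepSpeed config values are assumptions):

# Hypothetical GPT-2 + LoRA setup built under deepspeed.zero.Init() with a ZeRO-3 config.
import deepspeed
from transformers import AutoModelForCausalLM
from peft import LoraConfig, get_peft_model

ds_config = {
    "train_micro_batch_size_per_gpu": 1,
    "gradient_accumulation_steps": 1,
    "fp16": {"enabled": True},
    "optimizer": {"type": "AdamW", "params": {"lr": 1e-4}},
    "zero_optimization": {"stage": 3},
}

# The leak is observed when the model is constructed under zero.Init();
# passing enabled=False here (i.e. skipping zero.Init) avoids it.
with deepspeed.zero.Init(config_dict_or_path=ds_config, enabled=True):
    model = AutoModelForCausalLM.from_pretrained("gpt2")

peft_config = LoraConfig(r=8, lora_alpha=16, target_modules=["c_attn"], task_type="CAUSAL_LM")
model = get_peft_model(model, peft_config)

engine, optimizer, _, _ = deepspeed.initialize(
    model=model,
    model_parameters=[p for p in model.parameters() if p.requires_grad],
    config=ds_config,
)
# GPU memory then grows on every forward step of engine(...) in the reported setup.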

To Reproduce

  1. All details can be found in GPT2 Training GPU Memory Increase with LoRA and Zero 3 (huggingface/peft#161)

Expected behavior
Training should run without GPU memory increasing at each step.
ds_report output

--------------------------------------------------
DeepSpeed C++/CUDA extension op report
--------------------------------------------------
NOTE: Ops not installed will be just-in-time (JIT) compiled at
      runtime if needed. Op compatibility means that your system
      meet the required dependencies to JIT install the op.
--------------------------------------------------
JIT compiled ops requires ninja
ninja .................. [OKAY]
--------------------------------------------------
op name ................ installed .. compatible
--------------------------------------------------
 [WARNING]  async_io requires the dev libaio .so object and headers but these were not found.
 [WARNING]  async_io: please install the libaio-devel package with yum
 [WARNING]  If libaio is already installed (perhaps from source), try setting the CFLAGS and LDFLAGS environment variables to where it can be found.
async_io ............... [NO] ....... [NO]
cpu_adagrad ............ [NO] ....... [OKAY]
cpu_adam ............... [NO] ....... [OKAY]
fused_adam ............. [NO] ....... [OKAY]
fused_lamb ............. [NO] ....... [OKAY]
quantizer .............. [NO] ....... [OKAY]
random_ltd ............. [NO] ....... [OKAY]
 [WARNING]  please install triton==1.0.0 if you want to use sparse attention
sparse_attn ............ [NO] ....... [NO]
spatial_inference ...... [NO] ....... [OKAY]
transformer ............ [NO] ....... [OKAY]
stochastic_transformer . [NO] ....... [OKAY]
transformer_inference .. [NO] ....... [OKAY]
utils .................. [NO] ....... [OKAY]
--------------------------------------------------
DeepSpeed general environment info:
torch install path ............... ['/opt/miniconda/lib/python3.8/site-packages/torch']
torch version .................... 1.12.1
deepspeed install path ........... ['/opt/miniconda/lib/python3.8/site-packages/deepspeed']
deepspeed info ................... 0.8.2, unknown, unknown
torch cuda version ............... 11.3
torch hip version ................ None
nvcc version ..................... 11.3
deepspeed wheel compiled w. ...... torch 1.12, cuda 11.3


System info:

  • GPU count and types: 8x NVIDIA RTX 3080 Ti
dumpmemory added the "bug" and "training" labels on Mar 13, 2023
dumpmemory (Author) commented

I have also tried the tohtana/nested_zero_init branch, which did not fix it.

tohtana (Contributor) commented Apr 22, 2023

@dumpmemory
I found that Zero3's all-gathered parameters are not freed for LoRA Linear modules.
The following fix prevented the memory leak in my environment. Can you try this?

$ git diff
diff --git a/src/peft/tuners/lora.py b/src/peft/tuners/lora.py
index 1d1680d..97f0a4e 100644
--- a/src/peft/tuners/lora.py
+++ b/src/peft/tuners/lora.py
@@ -484,7 +484,7 @@ class Linear(nn.Linear, LoraLayer):
                 self.unmerge()
             result = F.linear(x, transpose(self.weight, self.fan_in_fan_out), bias=self.bias)
         elif self.r[self.active_adapter] > 0 and not self.merged:
-            result = F.linear(x, transpose(self.weight, self.fan_in_fan_out), bias=self.bias)
+            result = torch.matmul(x, transpose(self.weight, not self.fan_in_fan_out)) + self.bias

             x = x.to(self.lora_A[self.active_adapter].weight.dtype)

Although ZeRO-3 sets self.weight.data to an empty tensor, PyTorch does not free the memory with the original code.
A reference to the buffer holding the all-gathered parameters might still be alive, but I couldn't write a simple repro using only PyTorch.
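
For what it's worth, here is a minimal sketch of the suspected mechanism, not a repro of the DeepSpeed-specific leak (the sizes and the .data swap are illustrative assumptions): a tensor that autograd has saved for the backward pass keeps its CUDA storage alive even after .data is replaced by an empty tensor, until the graph itself is released.

import torch
import torch.nn.functional as F

# Both tensors require grad, so F.linear saves the weight for backward
# (it is needed to compute the gradient with respect to the input).
w = torch.randn(4096, 4096, device="cuda", requires_grad=True)
x = torch.randn(8, 4096, device="cuda", requires_grad=True)

y = F.linear(x, w)                       # autograd saves `w` for backward
w.data = torch.empty(0, device="cuda")   # mimic ZeRO-3 releasing the gathered param via .data

torch.cuda.synchronize()
print(torch.cuda.memory_allocated())     # the 4096x4096 buffer is still held by the graph

del y                                    # dropping the graph releases the saved tensor
torch.cuda.synchronize()
print(torch.cuda.memory_allocated())     # allocation drops by ~64 MB (4096 * 4096 * 4 bytes)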

dumpmemory (Author) commented

> @dumpmemory I found that Zero3's all-gathered parameters are not freed for LoRA Linear modules. The following fix prevented the memory leak in my environment. Can you try this? […]

I will test this, thanks for your help. I will update with the results later.

dumpmemory (Author) commented

It worked! With peft commit 10a2a6db5dc9cabb63a36c0fb489aeb2b9a1e433 and the modification above, DeepSpeed 0.9.1, and torch 2.0. Thanks for your help.

--------------------------------------------------
DeepSpeed C++/CUDA extension op report
--------------------------------------------------
NOTE: Ops not installed will be just-in-time (JIT) compiled at
      runtime if needed. Op compatibility means that your system
      meet the required dependencies to JIT install the op.
--------------------------------------------------
JIT compiled ops requires ninja
ninja .................. [OKAY]
--------------------------------------------------
op name ................ installed .. compatible
--------------------------------------------------
 [WARNING]  async_io requires the dev libaio .so object and headers but these were not found.
 [WARNING]  async_io: please install the libaio-devel package with yum
 [WARNING]  If libaio is already installed (perhaps from source), try setting the CFLAGS and LDFLAGS environment variables to where it can be found.
async_io ............... [NO] ....... [NO]
cpu_adagrad ............ [NO] ....... [OKAY]
cpu_adam ............... [NO] ....... [OKAY]
fused_adam ............. [NO] ....... [OKAY]
fused_lamb ............. [NO] ....... [OKAY]
quantizer .............. [NO] ....... [OKAY]
random_ltd ............. [NO] ....... [OKAY]
 [WARNING]  sparse_attn requires a torch version >= 1.5 and < 2.0 but detected 2.0
 [WARNING]  using untested triton version (2.0.0), only 1.0.0 is known to be compatible
sparse_attn ............ [NO] ....... [NO]
spatial_inference ...... [NO] ....... [OKAY]
transformer ............ [NO] ....... [OKAY]
stochastic_transformer . [NO] ....... [OKAY]
transformer_inference .. [NO] ....... [OKAY]
utils .................. [NO] ....... [OKAY]
--------------------------------------------------
DeepSpeed general environment info:
torch install path ............... ['/opt/miniconda/lib/python3.8/site-packages/torch']
torch version .................... 2.0.0
deepspeed install path ........... ['/opt/miniconda/lib/python3.8/site-packages/deepspeed']
deepspeed info ................... 0.9.1, unknown, unknown
torch cuda version ............... 11.7
torch hip version ................ None
nvcc version ..................... 11.7
deepspeed wheel compiled w. ...... torch 2.0, cuda 11.7

dumpmemory (Author) commented

I will try to make a PR on peft following your idea. Thanks again.

tohtana (Contributor) commented Apr 30, 2023

@dumpmemory This memory leak can also be fixed by setting memory_efficient_linear to false in the ZeRO section of the DeepSpeed configuration.

By default, DeepSpeed replaces PyTorch's linear with a different implementation. This might cause the memory leak. I will investigate what memory_efficient_linear does.
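
For reference, a hedged sketch of where that switch lives (the key placement is an assumption based on DeepSpeed's ZeRO config schema, not quoted from this thread); the same behavior can also be controlled through the mem_efficient_linear argument of zero.Init():

# Hypothetical ZeRO-3 config snippet disabling DeepSpeed's memory-efficient
# linear override as a workaround for this leak.
ds_config = {
    "zero_optimization": {
        "stage": 3,
        "memory_efficient_linear": False,
    },
}

# The equivalent switch on zero.Init itself:
# with deepspeed.zero.Init(config_dict_or_path=ds_config, mem_efficient_linear=False):
#     model = AutoModelForCausalLM.from_pretrained("gpt2")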

dumpmemory (Author) commented

> @dumpmemory This memory leak can also be fixed by setting memory_efficient_linear to false in the ZeRO section of the DeepSpeed configuration. […]

Thanks for your work!

tjruwase (Contributor) commented May 1, 2023

@dumpmemory, can you please try PR #3413 created by @tohtana? Thanks!

dumpmemory (Author) commented

> @dumpmemory, can you please try PR #3413 created by @tohtana? Thanks!

Yes, I can. Can I test it after my holiday? Thanks.

tjruwase (Contributor) commented May 2, 2023

@dumpmemory, of course! By the way, the PR is merged so you can use the master branch when you are ready.

Happy holidays to you! Thanks for your help.

dumpmemory (Author) commented

It worked with peft (commit 10a2a6db5dc9cabb63a36c0fb489aeb2b9a1e433) and peft 3.0.
