[BUG] memory leak under zero.Init #2637

Closed · stas00 opened this issue Dec 21, 2022 · 10 comments

Labels: bug, training

stas00 (Contributor) commented Dec 21, 2022

Describe the bug

The code leaks a lot of memory per training iteration, but only when zero.Init is activated.

To Reproduce

I have yet to be able to reduce this to a simple test. I shared with Tunji how to reproduce it in the larger framework.

However, we found the source of the leak, rewrote the module that was leaking, and the leak was gone.

So unless you can see something that points to a bug in DeepSpeed, this is a post for posterity and can be closed.

Here is the original module that was leaking:

import math

import torch
from torch import nn
from torch.nn import functional as F
from torch.nn import init
from torch.nn.parameter import Parameter


class DecoupledLinear(nn.Linear):
    # Derived from https://pytorch.org/docs/stable/_modules/torch/nn/modules/linear.html#Linear
    """
    Implements a decoupling of parameters to allow freezing (or not) a subset of the parameters.
    In practice, the regular `weight` can be trained or frozen (i.e. `partially_freeze=True`), and if `out_additional_features > 0`, then it will create `out_additional_features * in_features` additional parameters that are always trained.
    If `out_additional_features=0`, then the module defaults back to the regular behavior of `nn.Linear`.
    """

    def __init__(
        self,
        in_features: int,
        out_features: int,
        out_additional_features: int = 0,
        bias: bool = True,
        partially_freeze: bool = True,
        device=None,
        dtype=None,
    ) -> None:
        """
        out_additional_features: int. Number of additional trainable dimensions. Only makes sense when `partially_freeze=True`.
        partially_freeze:
        """
        super().__init__(in_features, out_features, bias, device, dtype)
        self.out_additional_features = out_additional_features
        self.partially_freeze = partially_freeze

        if partially_freeze:
            self.weight.requires_grad_(False)
            if bias:
                self.bias.requires_grad_(False)

        if out_additional_features > 0:
            self.additional_weight = Parameter(
                torch.empty((out_additional_features, in_features), device=device, dtype=dtype)
            )
            if bias:
                self.additional_bias = Parameter(torch.empty(out_additional_features, device=device, dtype=dtype))
            else:
                self.register_parameter("additional_bias", None)
            self.reset_additional_parameters()

    def reset_additional_parameters(self) -> None:
        """Equivalent of the `nn.Linear.reset_parameters` but only for the additional parameters."""
        init.kaiming_uniform_(self.additional_weight, a=math.sqrt(5))
        if self.additional_bias is not None:
            fan_in, _ = init._calculate_fan_in_and_fan_out(self.additional_weight)
            bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
            init.uniform_(self.additional_bias, -bound, bound)

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        if self.out_additional_features > 0:
            all_weight = torch.cat((self.weight, self.additional_weight), 0)
            if self.bias is not None:
                all_bias = torch.cat((self.bias, self.additional_bias), 0)
            else:
                all_bias = None
        else:
            all_weight = self.weight
            if self.bias is not None:
                all_bias = self.bias
            else:
                all_bias = None
        return F.linear(input, all_weight, all_bias)

I think the concatenation of the 2 parts of the linear somehow leads to the leak. The 2 parts were needed in order to keep part of the linear layer frozen.

I rewrote it as follows to deal with each part of the linear separately, and the leak disappeared:

class DecoupledLinearNew(nn.Linear):
    # Derived from https://pytorch.org/docs/stable/_modules/torch/nn/modules/linear.html#Linear
    """
    Implements a decoupling of parameters to allow freezing (or not) a subset of the parameters.
    In practice, the regular `weight` can be trained or frozen (i.e. `partially_freeze=True`), and if `out_additional_features > 0`, then it will create `out_additional_features * in_features` additional parameters that are always trained.
    If `out_additional_features=0`, then the module defaults back to the regular behavior of `nn.Linear`.
    """

    def __init__(
        self,
        in_features: int,
        out_features: int,
        out_additional_features: int = 0,
        bias: bool = True,
        partially_freeze: bool = True,
        device=None,
        dtype=None,
    ) -> None:
        """
        out_additional_features: int. Number of additional trainable dimensions. Only makes sense when `partially_freeze=True`.
        partially_freeze: bool. If True, the regular `weight` will be frozen and extra parameters (if any) will be trainable. If False, defaults to the regular behavior of `nn.Linear`.
        """
        super().__init__(in_features, out_features, bias, device, dtype)
        self.out_additional_features = out_additional_features
        self.partially_freeze = partially_freeze

        self.in_features = in_features
        self.out_features = out_features

        if partially_freeze:
            self.weight.requires_grad_(False)
            if bias:
                self.bias.requires_grad_(False)

        if out_additional_features > 0:
            self.additional_fc = nn.Linear(
                in_features=in_features,
                out_features=out_additional_features,
                bias=bias,
            )

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        output = F.linear(input, self.weight, self.bias)

        if self.out_additional_features > 0:
            output_2 = F.linear(input, self.additional_fc.weight, self.additional_fc.bias)
            output = torch.cat((output, output_2), -1)

        return output

As you can see, I reworked it to remove the torch.cat calls on the weights (only the outputs are concatenated now), and the leak is gone. It's also much more efficient code.
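For illustration, here is a minimal usage sketch (assuming the imports and the DecoupledLinearNew definition above are in scope; the sizes are arbitrary) showing which parameters end up frozen vs. trainable:

layer = DecoupledLinearNew(in_features=16, out_features=32,
                           out_additional_features=8, partially_freeze=True)

# weight / bias are frozen; additional_fc.weight / additional_fc.bias stay trainable
for name, p in layer.named_parameters():
    print(name, tuple(p.shape), "trainable" if p.requires_grad else "frozen")

x = torch.randn(4, 16)
print(layer(x).shape)  # torch.Size([4, 40]) -> 32 frozen dims + 8 additional dims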

I still don't understand why zero.Init would be the trigger for such a leak, as zero3 w/o zero.Init, or zero2, did not leak.

While I have a more efficient solution which doesn't leak, I thought I would report this here for awareness, and perhaps you can also see something obvious that I'm missing.

It's also possible that the leak trigger was somehow in how this module was used.

@tjruwase

jomayeri (Contributor) commented:

Hi @stas00, two questions:

  1. Do you believe this PR sufficiently fixes the issue?
  2. How were you able to identify that memory was being leaked?

Thanks,
Joe

stas00 (Contributor, Author) commented Jan 15, 2023

> Hi @stas00, two questions:

> 1. Do you believe this [PR](https://github.com/microsoft/DeepSpeed/pull/2665) sufficiently fixes the issue?

I'm pretty sure this other leak was unrelated, as the one I fixed was only a temporary leak: on forward it would get partitioned again. This one kept increasing the memory usage with every iteration.

> 2. How were you able to identify that memory was being leaked?

On every iteration the memory was growing, and it only happened with zero.Init enabled.
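A minimal sketch of that kind of per-iteration check (illustrative only, not the actual code from the report): log the allocated GPU memory once per step and watch whether the baseline keeps climbing.

import torch

def log_step_memory(step: int) -> None:
    # if the allocated baseline grows every step, something is holding on to memory
    torch.cuda.synchronize()
    allocated = torch.cuda.memory_allocated() / 2**20
    reserved = torch.cuda.memory_reserved() / 2**20
    print(f"step {step}: allocated {allocated:.0f} MiB, reserved {reserved:.0f} MiB")

# called once at the end of every training iteration:
#   log_step_memory(step)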

jomayeri (Contributor) commented:

Thanks, @stas00. I am attempting to repro this locally. Could you pass along the ds_config being used and the zero.Init call? (A small repro script would work as well.)

stas00 (Contributor, Author) commented Jan 18, 2023

Thank you for trying to reproduce this, Joe.

If I had a small repro script I probably would have found the problem, but, alas, it appeared to be hidden somewhere in the ensemble of things.

ds_config is just the staple defaults of zero3 + zero.Init as integrated into transformers' from_pretrained - no offload or anything.

I think the issue was coming from nested from_pretrained calls which triggered nested zero.Init calls.
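Schematically, the suspected pattern looks something like the following (illustrative only: in reality transformers enters zero.Init internally inside from_pretrained when ZeRO-3 is enabled, and the config path and model name below are placeholders):

import deepspeed
from transformers import AutoModel

ds_config = "ds_config_zero3.json"  # placeholder config path

with deepspeed.zero.Init(config_dict_or_path=ds_config):  # outer zero.Init context
    # from_pretrained opens its own zero.Init context when ZeRO-3 is enabled,
    # so the contexts end up nested - the pattern suspected of triggering the growth
    sub_model = AutoModel.from_pretrained("some-checkpoint")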

As we overcame the leakage issue by rewriting several major parts of the code base to completely remove from_pretrained nesting, and I was not able to narrow it down to a simple script, perhaps we should just close this issue for now?

jomayeri (Contributor) commented:

Sounds good. If anything similar happens again, feel free to reopen the issue and I will help investigate.

dumpmemory commented:

It happened again. T_T

SeunghyunSEO commented Nov 24, 2023

Hello @stas00, I'm glad I found this issue, as I've been struggling to debug memory issues while loading models with Hugging Face transformers.
I've tested loading a model using from_pretrained in both 1-GPU and 8-GPU hardware settings.

args = {
    'pretrained_model_name_or_path': model_path,
    'config': config,
    'torch_dtype': "auto",
}
model = model_class.from_pretrained(**args)
gpu_memory_plot_helper(rank, device, "after initializing model")

Using the estimate_zero3_model_states_mem_needs_all_cold function, I expected the memory usage below for the 1-GPU and 8-GPU settings (a call sketch follows the tables).

HW: Setup with 1 node, 1 GPU per node.
SW: Model with 6897M total params, 314M largest layer params.
  per CPU  |  per GPU |   Options
  173.43GB |   1.17GB | offload_param=cpu , offload_optimizer=cpu , zero_init=1
  173.43GB |   1.17GB | offload_param=cpu , offload_optimizer=cpu , zero_init=0
  154.16GB |  14.02GB | offload_param=none, offload_optimizer=cpu , zero_init=1
  154.16GB |  14.02GB | offload_param=none, offload_optimizer=cpu , zero_init=0
    1.75GB | 116.79GB | offload_param=none, offload_optimizer=none, zero_init=1
   38.54GB | 116.79GB | offload_param=none, offload_optimizer=none, zero_init=0
HW: Setup with 1 node, 8 GPUs per node.
SW: Model with 6897M total params, 314M largest layer params.
  per CPU  |  per GPU |   Options
  173.43GB |   1.17GB | offload_param=cpu , offload_optimizer=cpu , zero_init=1
  308.32GB |   1.17GB | offload_param=cpu , offload_optimizer=cpu , zero_init=0
  154.16GB |   2.78GB | offload_param=none, offload_optimizer=cpu , zero_init=1
  308.32GB |   2.78GB | offload_param=none, offload_optimizer=cpu , zero_init=0
   14.04GB |  15.62GB | offload_param=none, offload_optimizer=none, zero_init=1
  308.32GB |  15.62GB | offload_param=none, offload_optimizer=none, zero_init=0
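For reference, tables like these come from calling the estimator directly; the call below is a sketch using the 6897M total / 314M largest-layer parameter counts reported above:

from deepspeed.runtime.zero.stage3 import estimate_zero3_model_states_mem_needs_all_cold

estimate_zero3_model_states_mem_needs_all_cold(
    total_params=6897 * 10**6,
    largest_layer_params=314 * 10**6,
    num_gpus_per_node=1,  # rerun with num_gpus_per_node=8 for the second table
    num_nodes=1,
)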

I used the DeepSpeed configuration below with Hugging Face accelerate and the --zero3_init_flag true option.
To the best of my knowledge, accelerate calls deepspeed internally, so it should consume the expected memory above.

{
  "bf16": {
    "enabled": true
  },
  "zero_optimization": {
    "stage": 3,
    "offload_param": {
      "device": "cpu",
      "pin_memory": true
    },
    "offload_optimizer": {
      "device": "cpu",
      "pin_memory": true
    },
    "stage3_gather_16bit_weights_on_model_save": true,
    "memory_efficient_linear": true,
    "allgather_bucket_size": 2e8,
    "reduce_bucket_size": 2e8,

    "stage3_max_live_parameters": 3e7,
    "stage3_prefetch_bucket_size": 3e7,
    "stage3_param_persistence_threshold": 1e4,
    "stage3_max_reuse_distance": 5e8
  },
  "train_batch_size": "auto",
  "train_micro_batch_size_per_gpu": "auto",
  "gradient_accumulation_steps": "auto",
  
  "communication_data_type": null,

  "gradient_clipping": 1.0,
  "prescale_gradients": false
}

However, the memory profiler gave me the following stats.

after initializing model and tokenizer
rank: 0 / device: cuda:0
CPU Virtual Memory:  used = 54.74 GB, percent = 2.7%
Allocated / Reserved: 931.07MB / 1204.00MB
Max Allocated / Max Reserved: 931.07MB / 1204.00MB
summary
|===========================================================================|
|                  PyTorch CUDA memory summary, device ID 0                 |
|---------------------------------------------------------------------------|
|            CUDA OOMs: 0            |        cudaMalloc retries: 0         |
|===========================================================================|
|        Metric         | Cur Usage  | Peak Usage | Tot Alloc  | Tot Freed  |
|---------------------------------------------------------------------------|
| Allocated memory      |    931 MiB |    931 MiB |  45573 MiB |  44642 MiB |
|---------------------------------------------------------------------------|
| Active memory         |    931 MiB |    931 MiB |  45573 MiB |  44642 MiB |
|---------------------------------------------------------------------------|
| Requested memory      |    931 MiB |    931 MiB |  45573 MiB |  44642 MiB |
|---------------------------------------------------------------------------|
| GPU reserved memory   |   1204 MiB |   1204 MiB |   3004 MiB |   1800 MiB |
|---------------------------------------------------------------------------|
| Non-releasable memory | 279479 KiB | 279479 KiB |  53150 MiB |  52877 MiB |
|---------------------------------------------------------------------------|
| Allocations           |     467    |     468    |    2259    |    1792    |
|---------------------------------------------------------------------------|
| Active allocs         |     468    |     468    |    2259    |    1791    |
|---------------------------------------------------------------------------|
| GPU reserved segments |       4    |       4    |       7    |       3    |
|---------------------------------------------------------------------------|
| Non-releasable allocs |       5    |       6    |     911    |     906    |
|---------------------------------------------------------------------------|
| Oversize allocations  |       0    |       0    |       0    |       0    |
|---------------------------------------------------------------------------|
| Oversize GPU segments |       0    |       0    |       0    |       0    |
|===========================================================================|

It returned similar results for both the 1-GPU and 8-GPU settings.
However, the memory estimated by the deepspeed function considers both the optimizer and the model parameters, while in my case I only initialized and loaded the model parameters, so I think it should consume only about largest_layer_params*4 bytes for gathering parameters, I guess.
(I don't know if accelerate reserves an expected buffer for forward/backward tensors.)
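As a purely arithmetic cross-check of that guess (not the estimator's actual formula), 314M largest-layer params at 4 bytes each does line up with the 1.17GB shown in the offload_param=cpu rows above:

largest_layer_params = 314 * 10**6
print(largest_layer_params * 4 / 2**30)  # ~1.17 GiB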

However, the 931MB from the PyTorch summary does not match the 1.17GB from the deepspeed memory estimator for GPU RAM.
And with larger models (13B or 70B llama), the allocated and reserved GPU memory gets larger and larger.
For 70B, the largest layer has nearly 1B parameters, so only ~4GB should be needed for the parameters (not including the optimizer), but nearly 10GB is allocated after loading the parameters.

+) And I checked that it does not allocate any GPU memory if I do not use the zero3 init flag for accelerate. But to the best of my knowledge, zero init is for CPU memory optimization during CPU offloading...

And note that I didn't even call deepspeed.initialize(model); the memory consumption blows up right after loading the parameters under zero.Init.

The following Python function is what I used for profiling, and I called it right after loading the model parameters with from_pretrained().

import gc
import logging

import psutil
import torch

logger = logging.getLogger(__name__)


def gpu_memory_plot_helper(
    rank,
    device,
    message: str,
):
    gc.collect()
    torch.cuda.synchronize(device)

    allocated = torch.cuda.memory_allocated(device) / (1024**2)
    max_allocated = torch.cuda.max_memory_allocated(device) / (1024**2)
    reserved = torch.cuda.memory_reserved(device) / (1024**2)
    max_reserved = torch.cuda.max_memory_reserved(device) / (1024**2)

    vm_stats = psutil.virtual_memory()
    used_GB = round((vm_stats.total - vm_stats.available) / (1024**3), 2)

    logger.info(
        f'''
        {message}
        rank: {rank} / device: {device}
        CPU Virtual Memory:  used = {used_GB} GB, percent = {vm_stats.percent}%
        Allocated / Reserved: {allocated:.2f}MB / {reserved:.2f}MB
        Max Allocated / Max Reserved: {max_allocated:.2f}MB / {max_reserved:.2f}MB
        summary
        {torch.cuda.memory_summary(device, abbreviated=True)}
        '''
    )

    torch.cuda.reset_peak_memory_stats(device)

I think it is related to the nested usage of zero.Init (gathering params or something) and from_pretrained like you said, but I couldn't find what's wrong...

I would very much appreciate it if you could answer!

Best regards.

stas00 (Contributor, Author) commented Nov 28, 2023

I have no idea how to give a confirmation from what you have shared; it's a lot of information, but a lot of it is irrelevant. I think it's also more difficult to do any such evals when you offload to CPU, as your config shows. It's much simpler to use the GPUs directly, and then you get a single measurement.

Perhaps instead of using the profiler, try passing --skip_memory_metrics 0, as explained here: https://github.com/huggingface/transformers/blob/0864dd3beb238b7bec3528a3d1d6c17a28f51a51/src/transformers/training_args.py#L532,
and let it do the memory usage reporting for you.

The memory estimators only give an estimate for params, grads, and optim states. Other allocations could be quite significant.

But if it's doing the right thing and the model is large enough, you should definitely see a significant difference in memory usage between 1 and 8 GPUs, regardless of whether you use zero.Init or not - the latter only matters if you can't load the whole model on a single GPU. The model will get sharded either way by the time deepspeed has been initialized.

Here is a practical example. Let's take t5-3b and do a translation example from HF transformers:

git clone https://github.com/huggingface/transformers
cd transformers

Now edit tests/deepspeed/ds_config_zero3.json to turn off cpu offload, that is:

[...]
    "zero_optimization": {
        "stage": 3,
        "offload_optimizer": {
            "device": "none",
            "pin_memory": true
        },
        "offload_param": {
            "device": "none",
            "pin_memory": true
        },
[...]

Now run for 1 GPU:

# 1 gpu
export BS=1; rm -r output_dir; \
PYTHONPATH=src USE_TF=0 CUDA_VISIBLE_DEVICES=0,1 deepspeed --num_gpus=1 \
examples/pytorch/translation/run_translation.py --model_name_or_path t5-3b \
--output_dir output_dir --adam_eps 1e-06 --do_train --label_smoothing 0.1 \
--learning_rate 3e-5 --logging_first_step --logging_steps 500 \
--max_source_length 128 --max_target_length 128 --num_train_epochs 1 \
--overwrite_output_dir --per_device_train_batch_size $BS --predict_with_generate \
--sortish_sampler --source_lang en --target_lang ro --dataset_name wmt16 \
--dataset_config ro-en --source_prefix 'translate English to Romanian: ' \
--val_max_target_length 128 --warmup_steps 50 --max_train_samples 10 \
--deepspeed tests/deepspeed/ds_config_zero3.json --bf16 --skip_memory_metrics 0

Now change to --num_gpus=8 and repeat the above; you will then have 2 memory usage dumps.

GPU memory usage stats via --skip_memory_metrics 0 for 1 GPU:

***** train metrics *****
  before_init_mem_gpu        =     5438MB
  train_mem_gpu_alloc_delta  =    38092MB
  train_mem_gpu_peaked_delta =     5736MB

8 GPUs (the stats are per GPU):

  before_init_mem_gpu        =      734MB
  train_mem_gpu_alloc_delta  =     4723MB
  train_mem_gpu_peaked_delta =     2213MB

(I filtered out other irrelevant stats)

You can clearly see that when 8 GPUs are used, ~1/8 of the memory is used for the first two stats. In this little program with a gazillion args we are using a tiny batch size and a tiny seqlen, so pretty much all the memory is non-activation-related.

Note: I'm using an 8x A100 80GB node - you might need to use t5-large or t5-base if your GPU memory is smaller.

So now you should be able to repeat this for your use case. Once you've proven that the code works (or leaks) without offload, only then try with offload.

Hope this helps.

SeunghyunSEO commented Nov 30, 2023

Thank you for your kind reply, @stas00.

First of all, I apologize for my English and my lack of explanation of the problem. I should have provided the information more clearly.

I didn't mean that using more GPUs does not bring memory improvements.
I have already verified that ZeRO-3 gives significant memory savings, and that enabling CPU offload saves even more memory.

My question was: why does initializing a model with DeepSpeed (via accelerate) require about 1-2 GB of GPU memory for a large GPT-2 model, even with CPU offload enabled?

And I finally realized that this is not due to memory leaks in DeepSpeed or accelerate.

This is due to this behavior of the GPT-2 class.
Whenever I initialize a GPT-2 model, it creates an upper-triangular mask in full precision for every layer, which consumes (4*((seq_len**2)+1))*num_layers / (1024**2) MB.
E.g., if seq_len=8192 and num_layers=36, this takes 9216 MB on each device, even with CPU offload.
(torch.cuda.max_memory_allocated() should return 0MB if CPU offload is activated when initializing the model.)
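A quick sanity check of that figure, using the formula as written above:

seq_len, num_layers = 8192, 36
mask_mb = (4 * (seq_len**2 + 1)) * num_layers / 1024**2
print(f"{mask_mb:.0f} MB")  # -> 9216 MB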

You know, if I partition the model with deepspeed, the parameters should be partitioned to CPU or across the devices (according to the deepspeed configuration), and each param.data should be released by allocating an empty tensor (this line).
But the attention masks are buffers, not parameters, so they are not released in this way.

That's why I asked whether there are any memory leaks and where this GPU memory is going.

For now, I've decided to either manage this attention mask globally, or hack the source code to use xformers (it does not require a materialized attention mask, only a LowerTriangularMask).

Thank you so much for your response.
It helps a lot :)

stas00 (Contributor, Author) commented Dec 4, 2023

I'm glad to hear you have tracked down the source of the GPU memory consumption and that it's not the framework, @SeunghyunSEO!

Thank you for the detailed notes on what you have discovered.
