Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Gradient checkpointing should have no functional impact #26221

Closed
2 of 4 tasks
marianokamp opened this issue Sep 18, 2023 · 10 comments
Closed
2 of 4 tasks

Gradient checkpointing should have no functional impact #26221

marianokamp opened this issue Sep 18, 2023 · 10 comments
Labels

Comments

@marianokamp
Copy link

System Info

Latest released and py3.10.

accelerate-0.21.0 aiohttp-3.8.5 aiosignal-1.3.1 async-timeout-4.0.3 bitsandbytes-0.41.0 datasets-2.14.5 evaluate-0.4.0 frozenlist-1.4.0 huggingface-hub-0.17.1 multidict-6.0.4 peft-0.4.0 pynvml-11.5.0 regex-2023.8.8 responses-0.18.0 safetensors-0.3.3 sagemaker-inference-1.10.0 tensorboardX-2.6.2.2 tokenizers-0.13.3 transformers-4.33.2 xxhash-3.3.0 yarl-1.9.2

Who can help?

@pacman100, @muellerzr

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

Hi @pacman100, @muellerzr.

I was wondering about the memory use of LoRA. Specifically what happens if I adapt modules that are

  • (top) closer to the head of the network than to the inputs, as opposed to
  • (bottom) the other way around.

Given that the number of parameters to train remains the same in both cases, the memory usage should be the same, except that to calculate the gradients for (bottom) we would need to keep more activations around from the forward pass. If that were the case, then turning on gradient checkpointing should make (top) and (bottom) use the same memory, as we are discarding the activations and recalculating them on the backward pass. That is correct, no (@younesbelkada)?

Trying this out, I can see that behavior as expected. However, the accuracy also changed.
My understanding would be that with gradient checkpointing we would now need less memory, more time, but the functional aspects, here model performance, should be unchanged. Hence the issue.

Details

Below you can see on the x-axis on which layer of a 12 layer RoBERTa Base the adapters were applied. As you can see the memory for (bottom - lower layer numbers, closer to the embeddings) are higher than for (top - higher layer numbers, closer to the head), when not using gradient checkpointing, and they are same when using gradient checkpointing.

image

However, when looking at the model performance we can see that we have a difference of 0.1 between using and not using checkpointing.

image

Not that it matters, but this is using the glue/sst-2 dataset. I am not changing anything, but passing 0 or 1 as an argument to Trainer's gradient_checkpointing attribute (and 0 and 1 to empty-cuda-cache every 30 seconds).

Expected behavior

No functional change when using gradient_checkpointing.

@marianokamp
Copy link
Author

No answer or re-action yet, but not stale either.

@huggingface huggingface deleted a comment from github-actions bot Oct 19, 2023
@huggingface huggingface deleted a comment from github-actions bot Nov 13, 2023
@amyeroberts
Copy link
Collaborator

Gentle ping @muellerzr @pacman100

@huggingface huggingface deleted a comment from github-actions bot Dec 8, 2023
@marianokamp
Copy link
Author

@pacman100, @muellerz
Just re-ran with transformers 4.36.0, same result:

image

@huggingface huggingface deleted a comment from github-actions bot Jan 8, 2024
@marianokamp
Copy link
Author

@pacman100, @muellerzr, @younesbelkada. Anything I can do here to help you acknowledge the ticket? If I am hearing nothing I will let it auto-close.

@pacman100
Copy link
Contributor

pacman100 commented Jan 10, 2024

Hello @marianokamp, Thank you for your patience. As I don't have a clear minimal reproducer here, I ran the below experiments and don't see a diff in performance with and without gradient checkpointing.

  1. Code: https://github.com/huggingface/peft/blob/main/examples/sequence_classification/LoRA.ipynb
  2. Use the set_seed for deterministic runs:
import argparse
import os

import torch
from torch.optim import AdamW
from torch.utils.data import DataLoader
from peft import (
    get_peft_config,
    get_peft_model,
    get_peft_model_state_dict,
    set_peft_model_state_dict,
    LoraConfig,
    PeftType,
    PrefixTuningConfig,
    PromptEncoderConfig,
)

import evaluate
from datasets import load_dataset
from transformers import AutoModelForSequenceClassification, AutoTokenizer, get_linear_schedule_with_warmup, set_seed
from tqdm import tqdm

+ set_seed(100)
  1. In gradient ckpt run, add the model.gradient_checkpointing_enable command:
model = AutoModelForSequenceClassification.from_pretrained(model_name_or_path, return_dict=True)
model = get_peft_model(model, peft_config)
model.print_trainable_parameters()
model
+ model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant":False})
  1. Run the notebooks with and without gradient ckpt.
  2. mem usage:
    Screenshot 2024-01-10 at 11 32 35 AM
  3. Without gradient ckpt output logs:
0%|                                                                                                                                                | 0/115 [00:00<?, ?it/s]You're using a RobertaTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 115/115 [00:27<00:00,  4.18it/s]
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 13/13 [00:01<00:00,  8.53it/s]
epoch 0: {'accuracy': 0.7083333333333334, 'f1': 0.8210526315789474}
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 115/115 [00:26<00:00,  4.30it/s]
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 13/13 [00:01<00:00,  8.52it/s]
epoch 1: {'accuracy': 0.6838235294117647, 'f1': 0.8122270742358079}
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 115/115 [00:26<00:00,  4.31it/s]
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 13/13 [00:01<00:00,  8.53it/s]
epoch 2: {'accuracy': 0.6838235294117647, 'f1': 0.8122270742358079}
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 115/115 [00:26<00:00,  4.29it/s]
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 13/13 [00:01<00:00,  8.52it/s]
epoch 3: {'accuracy': 0.6838235294117647, 'f1': 0.8122270742358079}
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 115/115 [00:26<00:00,  4.27it/s]
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 13/13 [00:01<00:00,  8.52it/s]
epoch 4: {'accuracy': 0.6838235294117647, 'f1': 0.8122270742358079}
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 115/115 [00:26<00:00,  4.30it/s]
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 13/13 [00:01<00:00,  8.54it/s]
epoch 5: {'accuracy': 0.8186274509803921, 'f1': 0.8766666666666666}
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 115/115 [00:26<00:00,  4.26it/s]
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 13/13 [00:01<00:00,  8.54it/s]
epoch 6: {'accuracy': 0.8333333333333334, 'f1': 0.8885245901639344}
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 115/115 [00:26<00:00,  4.26it/s]
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 13/13 [00:01<00:00,  8.50it/s]
epoch 7: {'accuracy': 0.875, 'f1': 0.9109947643979057}
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 115/115 [00:26<00:00,  4.30it/s]
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 13/13 [00:01<00:00,  8.52it/s]
epoch 8: {'accuracy': 0.8872549019607843, 'f1': 0.9184397163120569}
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 115/115 [00:26<00:00,  4.30it/s]
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 13/13 [00:01<00:00,  8.51it/s]
epoch 9: {'accuracy': 0.8872549019607843, 'f1': 0.9201388888888888}
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 115/115 [00:26<00:00,  4.29it/s]
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 13/13 [00:01<00:00,  8.49it/s]
epoch 10: {'accuracy': 0.8921568627450981, 'f1': 0.9225352112676057}
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 115/115 [00:26<00:00,  4.29it/s]
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 13/13 [00:01<00:00,  8.49it/s]
epoch 11: {'accuracy': 0.8897058823529411, 'f1': 0.9220103986135182}
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 115/115 [00:26<00:00,  4.28it/s]
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 13/13 [00:01<00:00,  8.49it/s]
epoch 12: {'accuracy': 0.8946078431372549, 'f1': 0.9241622574955909}
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 115/115 [00:26<00:00,  4.27it/s]
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 13/13 [00:01<00:00,  8.46it/s]
epoch 13: {'accuracy': 0.8970588235294118, 'f1': 0.926056338028169}
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 115/115 [00:26<00:00,  4.27it/s]
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 13/13 [00:01<00:00,  8.43it/s]
epoch 14: {'accuracy': 0.8921568627450981, 'f1': 0.9225352112676057}
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 115/115 [00:26<00:00,  4.28it/s]
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 13/13 [00:01<00:00,  8.43it/s]
epoch 15: {'accuracy': 0.8872549019607843, 'f1': 0.9181494661921709}
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 115/115 [00:26<00:00,  4.28it/s]
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 13/13 [00:01<00:00,  8.49it/s]
epoch 16: {'accuracy': 0.8897058823529411, 'f1': 0.9211908931698775}
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 115/115 [00:26<00:00,  4.27it/s]
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 13/13 [00:01<00:00,  8.48it/s]
epoch 17: {'accuracy': 0.8897058823529411, 'f1': 0.9203539823008849}
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 115/115 [00:26<00:00,  4.26it/s]
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 13/13 [00:01<00:00,  8.44it/s]
epoch 18: {'accuracy': 0.8872549019607843, 'f1': 0.9195804195804195}
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 115/115 [00:26<00:00,  4.26it/s]
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 13/13 [00:01<00:00,  8.43it/s]
epoch 19: {'accuracy': 0.8921568627450981, 'f1': 0.923076923076923}
  1. with gradient checkpointing output logs:
0%|                                                                                                                                                | 0/115 [00:00<?, ?it/s]You're using a RobertaTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 115/115 [00:41<00:00,  2.77it/s]
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 13/13 [00:01<00:00,  8.37it/s]
epoch 0: {'accuracy': 0.7083333333333334, 'f1': 0.8210526315789474}
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 115/115 [00:40<00:00,  2.82it/s]
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 13/13 [00:01<00:00,  8.37it/s]
epoch 1: {'accuracy': 0.6838235294117647, 'f1': 0.8122270742358079}
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 115/115 [00:40<00:00,  2.82it/s]
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 13/13 [00:01<00:00,  8.39it/s]
epoch 2: {'accuracy': 0.6838235294117647, 'f1': 0.8122270742358079}
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 115/115 [00:40<00:00,  2.82it/s]
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 13/13 [00:01<00:00,  8.41it/s]
epoch 3: {'accuracy': 0.6838235294117647, 'f1': 0.8122270742358079}
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 115/115 [00:40<00:00,  2.81it/s]
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 13/13 [00:01<00:00,  8.40it/s]
epoch 4: {'accuracy': 0.6838235294117647, 'f1': 0.8122270742358079}
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 115/115 [00:40<00:00,  2.83it/s]
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 13/13 [00:01<00:00,  8.40it/s]
epoch 5: {'accuracy': 0.8186274509803921, 'f1': 0.8766666666666666}
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 115/115 [00:41<00:00,  2.79it/s]
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 13/13 [00:01<00:00,  8.44it/s]
epoch 6: {'accuracy': 0.8333333333333334, 'f1': 0.8885245901639344}
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 115/115 [00:40<00:00,  2.81it/s]
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 13/13 [00:01<00:00,  8.42it/s]
epoch 7: {'accuracy': 0.875, 'f1': 0.9109947643979057}
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 115/115 [00:40<00:00,  2.83it/s]
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 13/13 [00:01<00:00,  8.40it/s]
epoch 8: {'accuracy': 0.8872549019607843, 'f1': 0.9184397163120569}
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 115/115 [00:40<00:00,  2.84it/s]
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 13/13 [00:01<00:00,  8.39it/s]
epoch 9: {'accuracy': 0.8872549019607843, 'f1': 0.9201388888888888}
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 115/115 [00:40<00:00,  2.82it/s]
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 13/13 [00:01<00:00,  8.40it/s]
epoch 10: {'accuracy': 0.8921568627450981, 'f1': 0.9225352112676057}
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 115/115 [00:40<00:00,  2.83it/s]
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 13/13 [00:01<00:00,  8.39it/s]
epoch 11: {'accuracy': 0.8897058823529411, 'f1': 0.9220103986135182}
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 115/115 [00:40<00:00,  2.81it/s]
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 13/13 [00:01<00:00,  8.40it/s]
epoch 12: {'accuracy': 0.8946078431372549, 'f1': 0.9241622574955909}
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 115/115 [00:40<00:00,  2.82it/s]
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 13/13 [00:01<00:00,  8.39it/s]
epoch 13: {'accuracy': 0.8970588235294118, 'f1': 0.926056338028169}
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 115/115 [00:40<00:00,  2.82it/s]
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 13/13 [00:01<00:00,  8.34it/s]
epoch 14: {'accuracy': 0.8921568627450981, 'f1': 0.9225352112676057}
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 115/115 [00:40<00:00,  2.83it/s]
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 13/13 [00:01<00:00,  8.33it/s]
epoch 15: {'accuracy': 0.8872549019607843, 'f1': 0.9181494661921709}
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 115/115 [00:40<00:00,  2.81it/s]
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 13/13 [00:01<00:00,  8.41it/s]
epoch 16: {'accuracy': 0.8897058823529411, 'f1': 0.9211908931698775}
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 115/115 [00:40<00:00,  2.82it/s]
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 13/13 [00:01<00:00,  8.42it/s]
epoch 17: {'accuracy': 0.8897058823529411, 'f1': 0.9203539823008849}
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 115/115 [00:40<00:00,  2.81it/s]
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 13/13 [00:01<00:00,  8.35it/s]
epoch 18: {'accuracy': 0.8872549019607843, 'f1': 0.9195804195804195}
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 115/115 [00:40<00:00,  2.83it/s]
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 13/13 [00:01<00:00,  8.39it/s]
epoch 19: {'accuracy': 0.8921568627450981, 'f1': 0.923076923076923}

Observations: No performance gap between runs with gradient checkpointing and without gradient checkpointing.

@marianokamp
Copy link
Author

Thanks @pacman100. I got it now - a minimalist example is needed. I will try to create one over the weekend.

@marianokamp
Copy link
Author

@pacman100. Hi Sourab, thanks for investing the time!

You didn't say otherwise, so it's confirmed that using gradient checkpointing should not change the functional impact of the model, correct?

I now have a minimal implementation sample notebook that shows the issue.

Background: The original code is from an article that illustrates for educational purposes how a simple LoRA implementation looks like. It's just Python code and worked fine, until I tried gradient checkpointing in the 2nd article.

I am not aware of specific expectations that the transformers lib has on code. But there are two things I do in my example that may be worth pointing out as not being in the middle of the road. (a) Freezing modules and (b) overwriting the forward function in the module to be adapted to point it to the adapter implementation in the forward pass. Both work fine without gradient checkpointing, but maybe they are problematic with gradient checkpointing? The code is in the example I linked above, but for easier consumption I reproduce this method here:

def adapt_model(model):

    class MinimalLoRAAdapter(nn.Module): 
        def __init__(self, 
                     adaptee):
            super().__init__()

            self.adaptee = adaptee

            self.orig_forward = adaptee.forward
            adaptee.forward = self.forward # <-----------------
            
            r = 1
            adaptee.lora_A = nn.Parameter(
                torch.randn(adaptee.in_features, r) / math.sqrt(adaptee.in_features)
            )
            adaptee.lora_B = nn.Parameter(torch.zeros(r, adaptee.out_features))

        def forward(self, x, *args, **kwargs):
            return (
                self.orig_forward(x, *args, **kwargs) # <-----------------
                + F.dropout(x, 0.1) @ self.adaptee.lora_A @ self.adaptee.lora_B
            )
   
    # freeze all layers, incl. embeddings, except for the classifier
    for m in model.roberta.modules():    
        m.requires_grad_(False) # <-----------------

    # Adapt linear modules in transformer layers
    for m in model.roberta.encoder.modules():    
        if isinstance(m, nn.Linear):
            MinimalLoRAAdapter(m)

Here is an excerpt from the output. Full output in the linked notebook (check eval_accuracy):

---- without gradient checkpointing ----

[..]
model.is_gradient_checkpointing=False
[..]
{'train_runtime': 457.1886, 'train_samples_per_second': 489.951, 'train_steps_per_second': 2.187, 'train_loss': 0.38296363830566404, 'epoch': 3.32}
{'eval_loss': 0.23593959212303162, 'eval_accuracy': 0.908256880733945, 'eval_runtime': 1.6902, 'eval_samples_per_second': 515.919, 'eval_steps_per_second': 64.49, 'epoch': 3.32}

---- with gradient checkpointing ----

[..]
model.is_gradient_checkpointing=True
[..]
{'train_runtime': 227.8506, 'train_samples_per_second': 983.101, 'train_steps_per_second': 4.389, 'train_loss': 0.6675097045898437, 'epoch': 3.32}
{'eval_loss': 0.6635248064994812, 'eval_accuracy': 0.5194954128440367, 'eval_runtime': 1.6397, 'eval_samples_per_second': 531.808, 'eval_steps_per_second': 66.476, 'epoch': 3.32}
[..]

I tried the above with both GPU and CPU and I can observe the same behavior. Hope that helps to narrow it down.

@huggingface huggingface deleted a comment from github-actions bot Feb 12, 2024
@huggingface huggingface deleted a comment from github-actions bot Mar 8, 2024
@amyeroberts
Copy link
Collaborator

Gentle ping @pacman100

@huggingface huggingface deleted a comment from github-actions bot Apr 2, 2024
@pacman100
Copy link
Contributor

pacman100 commented Apr 2, 2024

Hello @marianokamp,

Thank you for the minimal reproducer via the notebook. I ran it using the latest versions with the below changes:

+ gradient_checkpointing_kwargs = None
    if cp_enabled:
-         model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant":False})
+        gradient_checkpointing_kwargs = {"use_reentrant":False}
    
    training_args = TrainingArguments(
        gradient_checkpointing=cp_enabled,
+        gradient_checkpointing_kwargs=gradient_checkpointing_kwargs,
...

The issue you are facing with gradient checkpointing with LoRA is as follows:

  1. Let's see the behaviour for use_reentrant=True as mentioned in https://pytorch.org/docs/stable/checkpoint.html:

At least one input and output must have requires_grad=True for the reentrant variant. If this condition is unmet, the checkpointed part of the model will not have gradients. The non-reentrant version does not have this requirement.

  1. You were correctly setting model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant":False}) but the Trainer was resetting it because of trainingArgument gradient_checkpointing and as you didn't pass gradient_checkpointing_kwargs, the default value of use_reentrant=True was used. This is clear from the warning in your notebook output:
    Screenshot 2024-04-02 at 3 35 53 PM
  2. Now, as the embedding layer is frozen, neither the input nor the output has requires_grad=True which is required when using use_reentrant=True. As such, no gradients are computed and no learning happens leading to very low model accuracy.
  3. The above changes rectify this to use the recommended use_reentrant=False.
  4. Another alternative if you still want to use use_reentrant=True is to make the outputs of the embedding layer require grads even though you won't be needing it as this fulfils the condition of least one input and output must have requires_grad=True for the reentrant variant. You can see this being done in the PEFT codebase at https://github.com/huggingface/peft/blob/02b5aeddf9c1ea11451f10a8a26da7e5df8cca4a/src/peft/utils/other.py#L112-L122

Output with the above changes:
Screenshot 2024-04-02 at 3 43 36 PM

Library versions:
Screenshot 2024-04-02 at 3 44 05 PM

Code:

from transformers import AutoModelForSequenceClassification, AutoTokenizer, DataCollatorWithPadding, TrainingArguments, Trainer, set_seed
from torch import nn
from torch.nn import functional as F
import math

hf_ckp = 'roberta-base'
set_seed(100)

def compute_metrics(eval_pred):
    logits, labels = eval_pred
    predictions = np.argmax(logits, axis=-1)
    return {f"accuracy": (predictions == labels).mean()}

def count_parameters(m, verbose=True):
    total_count = 0
    learnable_count = 0
    if verbose:
        print("Parameters (name, tunable, count):")

    output_width = max([len(n) for n, _ in m.named_parameters()])
    for n, p in m.named_parameters():
        count = p.data.numel()
        if verbose:
            print(f" {n:{output_width}} {p.requires_grad:5b} {count:>11d}")
        total_count += count
        if p.requires_grad:
            learnable_count += count

    print(
        f"Total parameters: {total_count:,}, "
        f"thereof learnable: {learnable_count:,} "
        f"({learnable_count/total_count*100.:5.4f}%)"
    )

    return total_count, learnable_count

def adapt_model(model):
    
    # Minimalized example in place of the original LoRA-from-Scratch 
    # implementation from the article: 
    # https://towardsdatascience.com/dive-into-lora-adapters-38f4da488ede
    class MinimalLoRAAdapter(nn.Module): 
        def __init__(self, 
                     adaptee):
            super().__init__()

            self.adaptee = adaptee

            self.orig_forward = adaptee.forward
            adaptee.forward = self.forward
            
            r = 1
            adaptee.lora_A = nn.Parameter(
                torch.randn(adaptee.in_features, r) / math.sqrt(adaptee.in_features)
            )
            adaptee.lora_B = nn.Parameter(torch.zeros(r, adaptee.out_features))

        def forward(self, x, *args, **kwargs):
            return (
                self.orig_forward(x, *args, **kwargs)
                + F.dropout(x, 0.1) @ self.adaptee.lora_A @ self.adaptee.lora_B
            )
   
    # freeze all layers, incl. embeddings, except for the classifier
    for m in model.roberta.modules():    
        m.requires_grad_(False)

    # Adapt linear modules in transformer layers
    for m in model.roberta.encoder.modules():    
        if isinstance(m, nn.Linear):
            MinimalLoRAAdapter(m)
%%time

tokenizer = AutoTokenizer.from_pretrained(hf_ckp)
collator = DataCollatorWithPadding(tokenizer=tokenizer)

datasets.logging.disable_progress_bar()
dataset = datasets.load_dataset("glue", "sst2")
train = dataset["train"]
valid = dataset["validation"]

def preprocess_function(examples):
        return tokenizer(examples['sentence'], padding=False, truncation=True)

tokenized_train = train.map(preprocess_function, batched=False)
tokenized_valid = valid.map(preprocess_function, batched=False)

def train(cp_enabled, model):
    gradient_checkpointing_kwargs = None
    if cp_enabled:
        gradient_c_heckpointing_kwargs = {"use_reentrant":False}
    
    training_args = TrainingArguments(
        gradient_checkpointing=cp_enabled,
        gradient_checkpointing_kwargs=gradient_checkpointing_kwargs,
        output_dir="out",    
        per_device_train_batch_size=224,
        learning_rate=3e-5,
        save_steps=10_000,
        eval_steps=   250,
        max_steps = 1_000,
        evaluation_strategy="steps",
        save_strategy="steps",
        save_total_limit=1,
        disable_tqdm=True,
        metric_for_best_model='eval_accuracy',
        report_to="none", # Disable wandb, tensorboard
    )

    trainer = Trainer(
            model=model,
            args=training_args,
            train_dataset=tokenized_train,
            eval_dataset=tokenized_valid,
            tokenizer=tokenizer,
            data_collator=collator,
            compute_metrics=compute_metrics,
    )
    print(f'{model.is_gradient_checkpointing=}')
    total, learnable = count_parameters(model, verbose=False)
    
    trainer.train()
    trainer.evaluate()
    
    
print('\n---- without gradient checkpointing ----\n')
model = AutoModelForSequenceClassification.from_pretrained(hf_ckp, num_labels=2)   
adapt_model(model)
train(False, model)

del(model) # essential!

print('\n---- with gradient checkpointing ----\n')
model = AutoModelForSequenceClassification.from_pretrained(hf_ckp, num_labels=2)
adapt_model(model)

train(True, model)

@marianokamp
Copy link
Author

@pacman100, thanks for your help and walking me through the solution in detail. I am still a bit confused by the API, but I understand the steps you showed me and following them fixed my issue in my original, non-minimal, code. All clear for me now. Much appreciated, Sourab!

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
Projects
None yet
Development

No branches or pull requests

3 participants