
EsmForSequenceClassification does not support gradient checkpointing #606

Open
Amelie-Schreiber opened this issue Aug 24, 2023 · 1 comment


Amelie-Schreiber commented Aug 24, 2023

NOTE: if this is not a bug report, please use the GitHub Discussions for support questions (How do I do X?), feature requests, ideas, showcasing new applications, etc.

Bug description
ESM-2 models do not appear to be compatible with QLoRA, because they do not support gradient checkpointing.

Reproduction steps
Code to reproduce:

!pip install -q -U bitsandbytes
!pip install -q -U git+https://github.com/huggingface/transformers.git 
!pip install -q -U git+https://github.com/huggingface/peft.git
!pip install -q -U git+https://github.com/huggingface/accelerate.git
!pip install -q datasets
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification, BitsAndBytesConfig

model_id = "facebook/esm2_t6_8M_UR50D"
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16
)

tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForSequenceClassification.from_pretrained(model_id, quantization_config=bnb_config, device_map={"":0})

This next part produces the error:

from peft import prepare_model_for_kbit_training

model.gradient_checkpointing_enable()
model = prepare_model_for_kbit_training(model)
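
As an aside (a suggestion of the editor, not part of the original report): if the goal is just to get past this error, peft's prepare_model_for_kbit_training accepts a use_gradient_checkpointing flag (default True), so checkpointing can be skipped at the cost of higher activation memory. A minimal sketch:

from peft import prepare_model_for_kbit_training

# Skip the failing gradient_checkpointing_enable() call entirely;
# trades the memory savings of checkpointing for compatibility.
model = prepare_model_for_kbit_training(model, use_gradient_checkpointing=False)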

Expected behavior
The script should simply prepare the model for training with QLoRA (Quantized Low-Rank Adaptation). See here for an example, which is linked to in this article.
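
For context, a minimal sketch of the QLoRA setup this is building toward; the target_modules names are assumptions based on ESM-2's BERT-style attention naming (query/key/value), not taken from the original report:

from peft import LoraConfig, get_peft_model

lora_config = LoraConfig(
    r=8,
    lora_alpha=16,
    target_modules=["query", "key", "value"],  # assumed ESM attention projections
    lora_dropout=0.05,
    bias="none",
    task_type="SEQ_CLS",  # sequence classification head
)

model = get_peft_model(model, lora_config)
model.print_trainable_parameters()  # only the LoRA adapters should be trainable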

Logs
Please paste the command line output:

---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
<ipython-input-2-d6b5f42e99b2> in <cell line: 3>()
      1 from peft import prepare_model_for_kbit_training
      2 
----> 3 model.gradient_checkpointing_enable()
      4 model = prepare_model_for_kbit_training(model)

/usr/local/lib/python3.10/dist-packages/transformers/modeling_utils.py in gradient_checkpointing_enable(self)
   1719         """
   1720         if not self.supports_gradient_checkpointing:
-> 1721             raise ValueError(f"{self.__class__.__name__} does not support gradient checkpointing.")
   1722         self.apply(partial(self._set_gradient_checkpointing, value=True))
   1723 

ValueError: EsmForSequenceClassification does not support gradient checkpointing.

Additional context
This is a basic attempt at training a QLoRA for ESM-2 models such as facebook/esm2_t6_8M_UR50D for a sequence classification task. The error is not task dependent though, and I have the same error when trying to train a token classifier. Any assistance on making ESM-2 models compatible with QLoRA would be greatly appreciated.
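
One possible workaround (an assumption of the editor, not verified against the ESM implementation): if EsmEncoder follows the standard transformers pattern of checking a gradient_checkpointing flag in its forward pass, that flag can be set directly, bypassing the supports_gradient_checkpointing guard:

# Hypothetical workaround: flip the encoder's checkpointing flag by hand,
# assuming EsmEncoder implements the usual checkpointing branch in forward.
for module in model.modules():
    if module.__class__.__name__ == "EsmEncoder":
        module.gradient_checkpointing = True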

@sirius777coder commented

I use gradient checkpointing via TrainingArguments and it works well.
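
Presumably (an interpretation of the comment above, on a transformers version where ESM supports checkpointing) this means letting the Trainer enable it, roughly:

from transformers import Trainer, TrainingArguments

args = TrainingArguments(
    output_dir="esm2-qlora-out",       # placeholder path
    gradient_checkpointing=True,       # Trainer enables checkpointing itself
    per_device_train_batch_size=4,
    num_train_epochs=1,
)

# train_dataset is a placeholder for a tokenized classification dataset
trainer = Trainer(model=model, args=args, train_dataset=train_dataset)
trainer.train()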
