An error occurred when using the model.gradient_checkpointing_enable() feature. #27596

Closed
CaC033 opened this issue Nov 20, 2023 · 4 comments · Fixed by #27610
Comments

@CaC033 commented Nov 20, 2023

System Info

  • transformers version: 4.35.2
  • Platform: Linux-4.19.91-014.kangaroo.alios7.x86_64-x86_64-with-glibc2.29
  • Python version: 3.8.10
  • Huggingface_hub version: 0.19.4
  • Safetensors version: 0.4.0
  • Accelerate version: 0.24.1
  • Accelerate config: not found
  • PyTorch version (GPU?): 1.14.0a0+410ce96 (True)
  • Tensorflow version (GPU?): not installed (NA)
  • Flax version (CPU?/GPU?/TPU?): not installed (NA)
  • Jax version: not installed
  • JaxLib version: not installed
  • Using GPU in script?:
  • Using distributed or parallel set-up in script?:

Who can help?

No response

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

from transformers import AutoModelForCausalLM, Trainer, default_data_collator

model = AutoModelForCausalLM.from_pretrained(
    args.load,
    from_tf=False,
    config=config,
    revision='main',
    use_auth_token=None,
    low_cpu_mem_usage=False,
    ignore_mismatched_sizes=True,
    trust_remote_code=True,
    local_files_only=True,
)

if args.enable_gradient_checkpointing:
    model.gradient_checkpointing_enable()

n_params = model.num_parameters()
logger.info(f"Training model with {n_params * 1e-9:.2f}B model")
embedding_size = model.get_input_embeddings().weight.shape[0]

if len(tokenizer) > embedding_size:
    model.resize_token_embeddings(len(tokenizer))

def tokenize_function(examples):
    sources = examples['instruction'] 
    targets = examples['content']
    data_dict = preprocess(sources, targets, tokenizer)
    return data_dict

with training_args.main_process_first(desc="dataset map tokenization"):
    lm_datasets = raw_datasets.map(
        tokenize_function,
        batched=True,
        num_proc=64
    )

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=lm_datasets["train"],
    eval_dataset=lm_datasets["validation"],
    tokenizer=tokenizer,
    data_collator=default_data_collator,
    neftune_noise_alpha=0.1,
)
trainer.train()

error:

Traceback (most recent call last):
  File "/mnt/workspace/peipao/jichunengli/test_qwen_hf/ds_train_huggingface_llama.py", line 322, in <module>
  File "/mnt/workspace/peipao/jichunengli/test_qwen_hf/ds_train_huggingface_llama.py", line 288, in main
    model.gradient_checkpointing_enable()
  File "/usr/local/lib/python3.8/dist-packages/transformers/modeling_utils.py", line 1872, in gradient_checkpointing_enable
    self._set_gradient_checkpointing(enable=True, gradient_checkpointing_func=gradient_checkpointing_func)
TypeError: _set_gradient_checkpointing() got an unexpected keyword argument 'enable'

I checked the source code of _set_gradient_checkpointing and found that its parameters do include enable:

def gradient_checkpointing_enable(self, gradient_checkpointing_kwargs=None):
    """
    Activates gradient checkpointing for the current model.

    Note that in other frameworks this feature can be referred to as "activation checkpointing" or "checkpoint
    activations".

    We pass the `__call__` method of the modules instead of `forward` because `__call__` attaches all the hooks of
    the module. https://discuss.pytorch.org/t/any-different-between-model-input-and-model-forward-input/3690/2

    Args:
        gradient_checkpointing_kwargs (dict, *optional*):
            Additional keyword arguments passed along to the `torch.utils.checkpoint.checkpoint` function.
    """
    if not self.supports_gradient_checkpointing:
        raise ValueError(f"{self.__class__.__name__} does not support gradient checkpointing.")

    if gradient_checkpointing_kwargs is None:
        gradient_checkpointing_kwargs = {}

    gradient_checkpointing_func = functools.partial(checkpoint, **gradient_checkpointing_kwargs)

    self._set_gradient_checkpointing(enable=True, gradient_checkpointing_func=gradient_checkpointing_func)
    if getattr(self, "_hf_peft_config_loaded", False):
        # When using PEFT + gradient checkpointing + Trainer we need to make sure the input has requires_grad=True
        # we do it also on PEFT: https://github.com/huggingface/peft/blob/85013987aa82aa1af3da1236b6902556ce3e483e/src/peft/peft_model.py#L334
        # When training with PEFT, only LoRA layers will have requires grad set to True, but the output of frozen layers need to propagate
        # the gradients to make sure the gradient flows.
        self.enable_input_require_grads()

def _set_gradient_checkpointing(self, enable: bool = True, gradient_checkpointing_func: Callable = checkpoint):
    is_gradient_checkpointing_set = False

    # Apply it on the top-level module in case the top-level modules supports it
    # for example, LongT5Stack inherits from `PreTrainedModel`.
    if hasattr(self, "gradient_checkpointing"):
        self._gradient_checkpointing_func = gradient_checkpointing_func
        self.gradient_checkpointing = enable
        is_gradient_checkpointing_set = True

    for module in self.modules():
        if hasattr(module, "gradient_checkpointing"):
            module._gradient_checkpointing_func = gradient_checkpointing_func
            module.gradient_checkpointing = enable
            is_gradient_checkpointing_set = True

    if not is_gradient_checkpointing_set:
        raise ValueError(
            f"{self.__class__.__name__} is not compatible with gradient checkpointing. Make sure all the architecture support it by setting a boolean attribute"
            " `gradient_checkpointing` to modules of the model that uses checkpointing."
        )
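
For reference, the base-class method above does accept enable, so the TypeError suggests that the override actually being called is a different one: a model loaded with trust_remote_code=True can ship its own _set_gradient_checkpointing with the old pre-4.35 signature, and that is the method the new code path ends up invoking. A minimal sketch of that mismatch (the class below is hypothetical, not the actual remote modeling code):

import functools
from torch.utils.checkpoint import checkpoint

class OldStyleRemoteModel:
    # Hypothetical stand-in for a trust_remote_code modeling class that still
    # defines the pre-4.35 hook (called once per module with a boolean value).
    supports_gradient_checkpointing = True

    def _set_gradient_checkpointing(self, module, value=False):
        if hasattr(module, "gradient_checkpointing"):
            module.gradient_checkpointing = value

    # Simplified copy of the new base-class method shown above.
    def gradient_checkpointing_enable(self, gradient_checkpointing_kwargs=None):
        gradient_checkpointing_func = functools.partial(
            checkpoint, **(gradient_checkpointing_kwargs or {})
        )
        # The base class now calls the hook with keyword arguments that the
        # old-style override does not accept, which raises the TypeError.
        self._set_gradient_checkpointing(
            enable=True, gradient_checkpointing_func=gradient_checkpointing_func
        )

OldStyleRemoteModel().gradient_checkpointing_enable()
# TypeError: _set_gradient_checkpointing() got an unexpected keyword argument 'enable'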

Expected behavior

Please fix this bug.

@rangehow

Same issue with transformers version 4.35.2 when loading the Baichuan2-13B-Chat model.

@ArthurZucker (Collaborator)

cc @younesbelkada who worked on this recently 😉

@younesbelkada (Contributor)

Hi @CaC033 @rangehow

#27610 should fix the issue. However, note that with the new gradient checkpointing refactor, models that use code on the Hub should no longer define a _set_gradient_checkpointing method (as the Baichuan models currently do): the modules that support gradient checkpointing are now inferred automatically from the gradient_checkpointing attribute.
A long-term fix is to remove that method, as done in https://huggingface.co/baichuan-inc/Baichuan2-13B-Chat/discussions/27, since as long as it is defined you cannot pass checkpointing arguments such as use_reentrant.
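
For maintainers of such remote modeling code, a rough sketch of the new-style pattern (hypothetical module names, not the actual Baichuan/Qwen code): the modeling file defines no _set_gradient_checkpointing at all; each checkpointable module just carries a gradient_checkpointing boolean, and the PreTrainedModel machinery fills in self._gradient_checkpointing_func when checkpointing is enabled.

import torch.nn as nn

class MyDecoderStack(nn.Module):
    # Hypothetical module; in real modeling code this sits inside a
    # PreTrainedModel subclass, which is what supplies
    # self._gradient_checkpointing_func once checkpointing is enabled.
    def __init__(self, num_layers=2, hidden_size=16):
        super().__init__()
        self.layers = nn.ModuleList(
            nn.Linear(hidden_size, hidden_size) for _ in range(num_layers)
        )
        # No _set_gradient_checkpointing override anywhere: this boolean
        # attribute is what the new base-class logic looks for and toggles.
        self.gradient_checkpointing = False

    def forward(self, hidden_states):
        for layer in self.layers:
            if self.gradient_checkpointing and self.training:
                hidden_states = self._gradient_checkpointing_func(layer.__call__, hidden_states)
            else:
                hidden_states = layer(hidden_states)
        return hidden_states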

@younesbelkada (Contributor)

Hi everyone, this should now be resolved on transformers main. Again, bear in mind that you need to remove the _set_gradient_checkpointing method from the remote modeling code to avoid similar issues in the future, as support for the old gradient checkpointing interface will be removed.
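
Assuming transformers is installed from main (or the override has been removed from the remote modeling file), the call from the reproduction above is expected to go through, and the refactored API also lets you forward checkpoint arguments, e.g. the non-reentrant variant:

# `model` is the AutoModelForCausalLM loaded in the reproduction above.
model.gradient_checkpointing_enable()

# The new API forwards kwargs to torch.utils.checkpoint.checkpoint:
model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False})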
