Llama Attention Call should not pass **kwargs #30523

kiddyboots216 opened this issue Apr 28, 2024 · 4 comments

System Info

  • transformers version: 4.40.1
  • Platform: Linux-4.18.0-513.24.1.el8_9.x86_64-x86_64-with-glibc2.28
  • Python version: 3.10.13
  • Huggingface_hub version: 0.22.2
  • Safetensors version: 0.4.1
  • Accelerate version: 0.29.2
  • Accelerate config: not found
  • PyTorch version (GPU?): 2.3.0+cu121 (True)
  • Tensorflow version (GPU?): not installed (NA)
  • Flax version (CPU?/GPU?/TPU?): not installed (NA)
  • Jax version: not installed
  • JaxLib version: not installed
  • Using GPU in script?: yes
  • Using distributed or parallel set-up in script?: FSDP

Who can help?

@ArthurZucker

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

Wrapping a LlamaModel with FSDP (with activation checkpointing, as in the traceback below) results in the following error during a forward pass:

  File "/scratch/gpfs/ashwinee/envs/align/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/scratch/gpfs/ashwinee/envs/align/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/scratch/gpfs/ashwinee/envs/align/lib/python3.10/site-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py", line 849, in forward
    output = self._fsdp_wrapped_module(*args, **kwargs)
  File "/scratch/gpfs/ashwinee/envs/align/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/scratch/gpfs/ashwinee/envs/align/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/scratch/gpfs/ashwinee/envs/align/lib/python3.10/site-packages/transformers/models/llama/modeling_llama.py", line 1196, in forward
    outputs = self.model(
  File "/scratch/gpfs/ashwinee/envs/align/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/scratch/gpfs/ashwinee/envs/align/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/scratch/gpfs/ashwinee/envs/align/lib/python3.10/site-packages/transformers/models/llama/modeling_llama.py", line 1016, in forward
    layer_outputs = decoder_layer(
  File "/scratch/gpfs/ashwinee/envs/align/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/scratch/gpfs/ashwinee/envs/align/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/scratch/gpfs/ashwinee/envs/align/lib/python3.10/site-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py", line 849, in forward
    output = self._fsdp_wrapped_module(*args, **kwargs)
  File "/scratch/gpfs/ashwinee/envs/align/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/scratch/gpfs/ashwinee/envs/align/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/scratch/gpfs/ashwinee/envs/align/lib/python3.10/site-packages/torch/distributed/algorithms/_checkpoint/checkpoint_wrapper.py", line 168, in forward
    return self.checkpoint_fn(  # type: ignore[misc]
  File "/scratch/gpfs/ashwinee/envs/align/lib/python3.10/site-packages/torch/_compile.py", line 24, in inner
    return torch._dynamo.disable(fn, recursive)(*args, **kwargs)
  File "/scratch/gpfs/ashwinee/envs/align/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 489, in _fn
    return fn(*args, **kwargs)
  File "/scratch/gpfs/ashwinee/envs/align/lib/python3.10/site-packages/torch/_dynamo/external_utils.py", line 17, in inner
    return fn(*args, **kwargs)
  File "/scratch/gpfs/ashwinee/envs/align/lib/python3.10/site-packages/torch/utils/checkpoint.py", line 489, in checkpoint
    ret = function(*args, **kwargs)
  File "/scratch/gpfs/ashwinee/envs/align/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/scratch/gpfs/ashwinee/envs/align/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/scratch/gpfs/ashwinee/envs/align/lib/python3.10/site-packages/transformers/models/llama/modeling_llama.py", line 739, in forward
    hidden_states, self_attn_weights, present_key_value = self.self_attn(
  File "/scratch/gpfs/ashwinee/envs/align/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/scratch/gpfs/ashwinee/envs/align/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
TypeError: LlamaSdpaAttention.forward() got an unexpected keyword argument 'offload_to_cpu'

This occurs because we pass **kwargs (https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L749) to a forward() that does not accept **kwargs (https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L608).

If we use another model, e.g. Mistral, this issue does not occur, because we don't pass **kwargs there (https://github.com/huggingface/transformers/blob/main/src/transformers/models/mistral/modeling_mistral.py#L757C63-L757C77).
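
For reference, the mismatch looks roughly like this in transformers 4.40.1 (a paraphrase of modeling_llama.py, not a verbatim copy):

  # LlamaDecoderLayer.forward, around modeling_llama.py#L749:
  # the attention call forwards whatever **kwargs the layer itself received.
  hidden_states, self_attn_weights, present_key_value = self.self_attn(
      hidden_states=hidden_states,
      attention_mask=attention_mask,
      position_ids=position_ids,
      past_key_value=past_key_value,
      output_attentions=output_attentions,
      use_cache=use_cache,
      cache_position=cache_position,
      **kwargs,  # e.g. offload_to_cpu, injected by the activation-checkpoint wrapper
  )

  # LlamaSdpaAttention.forward (around #L608) has a fixed parameter list with no
  # **kwargs, so any extra keyword raises the TypeError shown above.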

Expected behavior

Either remove the **kwargs forwarding on line 749, or add **kwargs to the attention forward() signatures so stray keywords are accepted.
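
A minimal sketch of the second option (untested; the first option is simply deleting the **kwargs, line from the call shown above):

  # Hypothetical signature change for LlamaAttention / LlamaSdpaAttention /
  # LlamaFlashAttention2.forward: accept and ignore stray keywords.
  def forward(
      self,
      hidden_states,
      attention_mask=None,
      position_ids=None,
      past_key_value=None,
      output_attentions=False,
      use_cache=False,
      cache_position=None,
      **kwargs,  # swallows e.g. offload_to_cpu instead of raising
  ):
      ...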

@ArthurZucker (Collaborator)

Kwargs should indeed not be passed. I would need a reproducer but feel free to open a PR for a fix! 😉

@kiddyboots216 (Author)

I will open a PR after cataloguing all the models that have this issue. GPT-NeoX also has it. The reproducer is to wrap a model in FSDP and then run a forward pass on any data.
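
For what it's worth, a minimal reproducer along those lines might look like the sketch below (untested here, launched via torchrun). It assumes the stray offload_to_cpu keyword comes from passing offload_to_cpu=True to checkpoint_wrapper, which recent PyTorch versions no longer consume explicitly and instead forward into the checkpointed module's forward; that assumption matches the traceback but is not stated in the report. The checkpoint name is a placeholder.

  import functools
  import torch
  import torch.distributed as dist
  from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
  from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
      CheckpointImpl,
      apply_activation_checkpointing,
      checkpoint_wrapper,
  )
  from transformers import AutoModelForCausalLM, AutoTokenizer
  from transformers.models.llama.modeling_llama import LlamaDecoderLayer

  dist.init_process_group("nccl")
  torch.cuda.set_device(dist.get_rank())

  name = "meta-llama/Llama-2-7b-hf"  # placeholder; any Llama-architecture checkpoint
  model = AutoModelForCausalLM.from_pretrained(name, torch_dtype=torch.bfloat16)
  model = FSDP(model, device_id=torch.cuda.current_device())

  # Activation checkpointing on every decoder layer. offload_to_cpu is captured by
  # **checkpoint_fn_kwargs and forwarded by the non-reentrant checkpoint into
  # LlamaDecoderLayer.forward, which passes it on to self_attn via **kwargs.
  apply_activation_checkpointing(
      model,
      checkpoint_wrapper_fn=functools.partial(
          checkpoint_wrapper,
          checkpoint_impl=CheckpointImpl.NO_REENTRANT,
          offload_to_cpu=True,
      ),
      check_fn=lambda m: isinstance(m, LlamaDecoderLayer),
  )

  tok = AutoTokenizer.from_pretrained(name)
  batch = tok("any data", return_tensors="pt").to(torch.cuda.current_device())
  out = model(**batch)  # TypeError: ... unexpected keyword argument 'offload_to_cpu'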

@vikram71198 commented May 9, 2024

Yep, can confirm I also see the same issue with LLaMA-3-8b-Instruct with FSDP + Gradient Checkpointing.

The Yi series of models also has this issue, I just checked. And it makes perfect sense, since they follow the LLaMA architecture.

@ArthurZucker (Collaborator)

We'll remove the kwargs! cc @zhenglongjiepheonix, who is working on something related!
We can open a separate PR for this and link this issue.
