System Info

transformers version: 4.40.1

Who can help?

@ArthurZucker

Tasks

An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
My own task or dataset (give details below)

Reproduction
Wrapping a LlamaModel with FSDP results in the following error during a forward pass:
File "/scratch/gpfs/ashwinee/envs/align/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/scratch/gpfs/ashwinee/envs/align/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
return forward_call(*args, **kwargs)
File "/scratch/gpfs/ashwinee/envs/align/lib/python3.10/site-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py", line 849, in forward
output = self._fsdp_wrapped_module(*args, **kwargs)
File "/scratch/gpfs/ashwinee/envs/align/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/scratch/gpfs/ashwinee/envs/align/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
return forward_call(*args, **kwargs)
File "/scratch/gpfs/ashwinee/envs/align/lib/python3.10/site-packages/transformers/models/llama/modeling_llama.py", line 1196, in forward
outputs = self.model(
File "/scratch/gpfs/ashwinee/envs/align/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/scratch/gpfs/ashwinee/envs/align/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
return forward_call(*args, **kwargs)
File "/scratch/gpfs/ashwinee/envs/align/lib/python3.10/site-packages/transformers/models/llama/modeling_llama.py", line 1016, in forward
layer_outputs = decoder_layer(
File "/scratch/gpfs/ashwinee/envs/align/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/scratch/gpfs/ashwinee/envs/align/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
return forward_call(*args, **kwargs)
File "/scratch/gpfs/ashwinee/envs/align/lib/python3.10/site-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py", line 849, in forward
output = self._fsdp_wrapped_module(*args, **kwargs)
File "/scratch/gpfs/ashwinee/envs/align/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/scratch/gpfs/ashwinee/envs/align/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
return forward_call(*args, **kwargs)
File "/scratch/gpfs/ashwinee/envs/align/lib/python3.10/site-packages/torch/distributed/algorithms/_checkpoint/checkpoint_wrapper.py", line 168, in forward
return self.checkpoint_fn( # type: ignore[misc]
File "/scratch/gpfs/ashwinee/envs/align/lib/python3.10/site-packages/torch/_compile.py", line 24, in inner
return torch._dynamo.disable(fn, recursive)(*args, **kwargs)
File "/scratch/gpfs/ashwinee/envs/align/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 489, in _fn
return fn(*args, **kwargs)
File "/scratch/gpfs/ashwinee/envs/align/lib/python3.10/site-packages/torch/_dynamo/external_utils.py", line 17, in inner
return fn(*args, **kwargs)
File "/scratch/gpfs/ashwinee/envs/align/lib/python3.10/site-packages/torch/utils/checkpoint.py", line 489, in checkpoint
ret = function(*args, **kwargs)
File "/scratch/gpfs/ashwinee/envs/align/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/scratch/gpfs/ashwinee/envs/align/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
return forward_call(*args, **kwargs)
File "/scratch/gpfs/ashwinee/envs/align/lib/python3.10/site-packages/transformers/models/llama/modeling_llama.py", line 739, in forward
hidden_states, self_attn_weights, present_key_value = self.self_attn(
File "/scratch/gpfs/ashwinee/envs/align/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/scratch/gpfs/ashwinee/envs/align/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
return forward_call(*args, **kwargs)
TypeError: LlamaSdpaAttention.forward() got an unexpected keyword argument 'offload_to_cpu'
Kwargs should indeed not be passed. I would need a reproducer but feel free to open a PR for a fix! 😉
I will open a PR after cataloguing all the models that have this issue; GPTNeoX also has it. The reproducer is to wrap a model in FSDP and then run a forward pass on any data, roughly as sketched below.
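Untested sketch of that reproducer (the checkpoint name, wrap policy, and launcher are placeholders, and the activation-checkpointing wrapper visible in the traceback is omitted for brevity):

```python
# Run with torchrun, e.g.: torchrun --nproc_per_node=2 repro.py
import functools
import os

import torch
import torch.distributed as dist
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
from transformers import AutoModelForCausalLM
from transformers.models.llama.modeling_llama import LlamaDecoderLayer

dist.init_process_group("nccl")
torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))

# Any Llama checkpoint should do; this name is only illustrative.
model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf")

# Shard each decoder layer with FSDP. In the traceback above, the layers are
# additionally wrapped with activation checkpointing, which is what injects the
# stray kwarg; that wrapping is left out here.
model = FSDP(
    model,
    auto_wrap_policy=functools.partial(
        transformer_auto_wrap_policy, transformer_layer_cls={LlamaDecoderLayer}
    ),
    device_id=torch.cuda.current_device(),
)

# Forward pass on arbitrary data; this is where the TypeError above is hit.
input_ids = torch.randint(0, model.config.vocab_size, (1, 16), device="cuda")
model(input_ids=input_ids)
```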
This occurs because line 749 passes **kwargs (https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L749) to a forward() that does not accept **kwargs (https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L608).
With another model, e.g. Mistral, this issue does not occur, because **kwargs is not passed there (https://github.com/huggingface/transformers/blob/main/src/transformers/models/mistral/modeling_mistral.py#L757C63-L757C77).
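As a toy illustration of the failure mode (plain Python, not the actual transformers code):

```python
def attention_forward(hidden_states, attention_mask=None):
    # No **kwargs in the signature, like LlamaSdpaAttention.forward().
    return hidden_states

def decoder_layer_forward(hidden_states, **kwargs):
    # Blindly forwards **kwargs, like the decoder layer at line 749.
    return attention_forward(hidden_states, **kwargs)

try:
    decoder_layer_forward([1.0], offload_to_cpu=True)
except TypeError as e:
    print(e)  # attention_forward() got an unexpected keyword argument 'offload_to_cpu'
```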
Expected behavior
Either remove the **kwargs pass-through on line 749 or add **kwargs to the attention forward().
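A toy sketch of either fix (illustrative names only, not the actual transformers signatures):

```python
def attention_forward(hidden_states, attention_mask=None, **kwargs):
    # Fix (b): the callee accepts and ignores stray kwargs.
    return hidden_states

def decoder_layer_forward(hidden_states, **kwargs):
    # Fix (a): the caller simply stops forwarding **kwargs to the attention module.
    return attention_forward(hidden_states)

print(decoder_layer_forward([1.0], offload_to_cpu=True))  # no TypeError with either change
```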