
FIX: Fixes unexpected behaviour for Llava / LLama & AWQ Fused modules + revert #30070 at the same time #30317

Merged
5 commits merged from younesbelkada-patch-1 into main on Apr 18, 2024

Conversation

@younesbelkada (Contributor) commented on Apr 18, 2024:

What does this PR do?

Fixes a silent behaviour change introduced by a recent PR: passing a `None` attention mask results in unexpected behaviour for AWQ fused modules. The fix is to force-set a dummy `_attn_implementation` on the config objects of the modules that contain fused modules.
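
Roughly, the idea looks like the sketch below. This is a minimal illustration only, with hypothetical helper and parameter names; only `config._attn_implementation` and the `"custom"` dummy value come from this PR, the rest is a simplified assumption rather than the actual awq.py integration code.

# Minimal sketch, assuming `fused_attention_cls` is the AWQ fused attention class.
def _force_dummy_attn_implementation(model, fused_attention_cls):
    # Check whether any of the model's modules were replaced by fused attention modules
    has_fused_modules = any(isinstance(m, fused_attention_cls) for m in model.modules())
    if has_fused_modules:
        # "custom" is only a dummy value: it keeps the model from taking the code path
        # that drops the attention mask (and passes `None` downstream), which the
        # fused attention modules do not expect.
        model.config._attn_implementation = "custom"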

I can confirm that the failing slow tests now pass with these changes.

cc @ArthurZucker @fxmarty


@younesbelkada changed the title from "FIX: Fixes unexpected behaviour for Llava / LLama & AWQ Fused modules" to "FIX: Fixes unexpected behaviour for Llava / LLama & AWQ Fused modules + revert #30070 at the same time" on Apr 18, 2024
@ArthurZucker (Collaborator) left a comment:

Let's make sure to run at least the slow Llama tests before merging, so we can iterate fast if they fail.

Comment on lines +253 to +255
# For AWQ fused + Llama we need to set `config._attn_implementation` = "custom" to avoid unexpected behavior and pass
# `None` attention mask to the fused attention modules as now the attention mask is dropped by our models and dealt
# by the `AttentionMaskConverter` module.

Suggested change
# For AWQ fused + Llama we need to set `config._attn_implementation` = "custom" to avoid unexpected behavior and pass
# `None` attention mask to the fused attention modules as now the attention mask is dropped by our models and dealt
# by the `AttentionMaskConverter` module.
# For AWQ fused + Llama we need to set `config._attn_implementation` = "custom" to avoid unexpected behaviors. We loop over the layers to make sure the vision/text config are
# modified only if some of their modules were fused

Nit, but more understandable.
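
For context, here is a rough sketch of the loop the suggested comment describes for a multimodal model such as Llava: only the config of a sub-model whose modules were actually fused gets the dummy value. The attribute and helper names below are assumptions (Llava-style naming), not the exact code touched by this PR.

# Sketch only; `fused_attention_cls` stands in for the AWQ fused attention class.
def _mark_fused_sub_configs(model, fused_attention_cls):
    # e.g. for Llava: `language_model` (text) and `vision_tower` (vision)
    for submodel in (getattr(model, "language_model", None), getattr(model, "vision_tower", None)):
        if submodel is None or not hasattr(submodel, "config"):
            continue
        # Modify the config only if this sub-model actually contains fused modules
        if any(isinstance(m, fused_attention_cls) for m in submodel.modules()):
            submodel.config._attn_implementation = "custom"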

Comment on lines 269 to 270
if sliding_window is None or key_value_length < sliding_window:
ignore_causal_mask = not is_tracing
@fxmarty (Collaborator) left a comment:

Please apply the above fix and LGTM

If you can run the following,

from transformers import WhisperForCausalLM, WhisperForConditionalGeneration, WhisperProcessor
import torch
from datasets import load_dataset

# Load the main Whisper model and its processor
processor = WhisperProcessor.from_pretrained("openai/whisper-large-v2")
model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-large-v2")

# Smaller distil-whisper checkpoint used as the assistant model
assistant_model = WhisperForCausalLM.from_pretrained("distil-whisper/distil-large-v2")

# Grab a dummy audio sample and turn it into input features
ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
sample = ds[0]["audio"]
input_features = processor(
    sample["array"], sampling_rate=sample["sampling_rate"], return_tensors="pt"
).input_features

# Generate with the assistant model (assisted generation)
predicted_ids = model.generate(input_features, assistant_model=assistant_model)

# decode token ids to text
transcription = processor.batch_decode(predicted_ids, skip_special_tokens=True)[0]
print(transcription)

as well as the Llava AWQ test and the Llama slow test, just to be sure.

@younesbelkada (Contributor, Author) commented:

Thanks @fxmarty @ArthurZucker! Just tested, everything seems to pass!

@LysandreJik merged commit 5728b5a into main on Apr 18, 2024 (16 of 21 checks passed).
@LysandreJik deleted the younesbelkada-patch-1 branch on Apr 18, 2024 at 13:51.
zucchini-nlp pushed a commit to zucchini-nlp/transformers that referenced this pull request on Apr 18, 2024
… + revert huggingface#30070 at the same time (huggingface#30317)

* Update awq.py

* style

* revert felix PR

* fix

* add felix comments
ArthurZucker pushed a commit that referenced this pull request on Apr 22, 2024 (same commit message as above).
ydshieh pushed a commit that referenced this pull request on Apr 23, 2024 (same commit message as above).
itazap pushed a commit that referenced this pull request on May 14, 2024 (same commit message as above).