
FIX: Fix multiple generations for new HF cache format #444

Merged
merged 1 commit into main from younesbelkada-patch-2 on May 1, 2024

Conversation

younesbelkada
Collaborator

What does this PR do?

Currently, with transformers main and the latest AutoAWQ, users hit a failure when combining fused modules with multiple calls to generate:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, AwqConfig

model_id = "hf-internal-testing/Mixtral-tiny-AWQ"

quantization_config = AwqConfig(bits=4, fuse_max_seq_len=128, do_fuse=True)
model = AutoModelForCausalLM.from_pretrained(model_id, quantization_config=quantization_config).to(0)

dummy_input = torch.LongTensor([[0, 1, 0, 1]]).to(0)

_ = model.generate(dummy_input, use_cache=True)

# second generate fails
_ = model.generate(dummy_input, use_cache=True)

This PR addresses that by checking the correct attributes on the fused attention modules.
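
For context, here is a minimal sketch of the kind of attribute check involved, assuming a simplified fused-attention module that keeps a rolling start_pos offset into a pre-allocated cache. The class and attribute handling below are illustrative assumptions, not the exact code changed in this PR:

import torch
import torch.nn as nn

class FusedAttentionSketch(nn.Module):
    """Illustrative only -- not the actual AutoAWQ fused attention class."""

    def __init__(self, max_seq_len: int):
        super().__init__()
        self.max_seq_len = max_seq_len
        self.start_pos = 0  # rolling write offset into the fused module's own cache

    def forward(self, hidden_states: torch.Tensor, past_key_value=None):
        seq_len = hidden_states.shape[1]
        # With the new HF cache format, past_key_value is a Cache object that
        # exposes get_seq_length(), not the legacy tuple of tensors, so the
        # module has to check the attribute rather than assume the old layout.
        if past_key_value is not None and hasattr(past_key_value, "get_seq_length"):
            cached_len = past_key_value.get_seq_length()
        else:
            cached_len = 0
        # A fresh generate() call arrives with no cached tokens, so the rolling
        # offset must be rewound instead of carrying over from the previous run.
        if cached_len == 0:
            self.start_pos = 0
        self.start_pos += seq_len
        return hidden_states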

cc @casper-hansen

@casper-hansen
Owner

Hi @younesbelkada, these are some interesting edge cases. Would it help to use the attention module the same way it is used in AutoAWQ? Currently, transformers just resets start_pos, which seems to require multiple fixes, and there are probably more bugs that we are not aware of.
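
For reference, "resetting start_pos" between calls amounts to something like the sketch below. The helper is hypothetical, not the actual transformers integration code; only the start_pos attribute on the fused modules is taken from the discussion above:

def reset_fused_cache(model):
    """Hypothetical helper: rewind every fused module's cache offset to 0."""
    for module in model.modules():
        if hasattr(module, "start_pos"):
            module.start_pos = 0

# usage (hypothetical): call before each fresh generation
# reset_fused_cache(model)
# _ = model.generate(dummy_input, use_cache=True)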

@casper-hansen casper-hansen merged commit 33af761 into main May 1, 2024
@younesbelkada younesbelkada deleted the younesbelkada-patch-2 branch May 2, 2024 08:09