
usage of past_key_values produces different output than the whole sequence at once #26344

Closed
2 of 4 tasks
IvanSedykh opened this issue Sep 22, 2023 · 5 comments

Comments

@IvanSedykh
Contributor

System Info

transformers 4.33.1

Who can help?

@ArthurZucker @younesbelkada @gante

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

When I use past_key_values, the model does not produce the same logits as it does when I feed the whole sequence in at once.

Please see the code snippet below for details.

from transformers import AutoModelForCausalLM, AutoTokenizer
import torch


model_name = "codellama/CodeLlama-7b-hf"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
    model_name, torch_dtype=torch.float16, low_cpu_mem_usage=True, device_map="auto"
)


prompt = """
import json

fname = 'some_file.json'
with open(fname) as f:
    data = json."""

all_input_ids = tokenizer([prompt], return_tensors='pt').input_ids

# process the whole sequence
with torch.no_grad():
    all_outputs = model(all_input_ids)
# get logits for the last token
last_token_logits = all_outputs.logits[0][-1:]

with torch.no_grad():
    # process the sequence except the last token
    kv = model(all_input_ids[:, :-1]).past_key_values
    # input only the last token with previous kv_cache
    new_output = model(all_input_ids[:, -1:], past_key_values=kv)
# extract the last token logits
new_last_token_logits = new_output.logits[0][-1:]

# these two distributions should be equal, but they are not
print(torch.dist(last_token_logits, new_last_token_logits))
# tensor(0.4462)
assert torch.allclose(last_token_logits, new_last_token_logits)  # fails
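
As an extra, purely illustrative sanity check, the gap can also be quantified per element and the argmax tokens compared, to see how far the two distributions actually diverge (this is not part of the original repro):

# optional diagnostics on the logits computed above
diff = (last_token_logits - new_last_token_logits).abs()
print("max abs diff:", diff.max().item())
print("mean abs diff:", diff.mean().item())
print("same argmax token:",
      torch.equal(last_token_logits.argmax(-1), new_last_token_logits.argmax(-1)))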

Expected behavior

If I understand kv caching correctly, the outputs should be exactly the same. This is important because the generate method relies heavily on past_key_values, so a bug here would affect many applications.
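
For context on why this matters in practice, here is a minimal greedy-decoding sketch of the kind of loop that generate builds on, reusing model, tokenizer, and all_input_ids from the snippet above (the loop is illustrative only, not the actual generate implementation):

# greedy decoding that reuses past_key_values, feeding one new token per step
generated = all_input_ids
past = None
with torch.no_grad():
    for _ in range(20):
        if past is None:
            # first step: feed the full prompt
            out = model(generated)
        else:
            # later steps: feed only the newest token plus the cache
            out = model(generated[:, -1:], past_key_values=past)
        past = out.past_key_values
        next_token = out.logits[:, -1, :].argmax(dim=-1, keepdim=True)
        generated = torch.cat([generated, next_token.to(generated.device)], dim=-1)

print(tokenizer.decode(generated[0]))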

@ArthurZucker
Collaborator

Hey! Thanks for opening an issue. This is pretty much a duplicate of #25420, where we take a deep dive into this!


@gante
Member

gante commented Oct 23, 2023

Hey @IvanSedykh 👋

As Arthur wrote, this is a duplicate of #25420 -- you can find a detailed answer here

@IvanSedykh
Contributor Author

Hi @gante !
Thank you for this investigation, it's much clearer now. 🤗


This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.
