kv cache breaks generation #219

Closed
ad8e opened this issue Dec 30, 2023 · 5 comments

ad8e commented Dec 30, 2023

If I turn the kv cache on, autoregressive generation produces junk.

Here's a diagnostic: replace autoregressive_wrapper.py with this one: autoregressive_wrapper.txt (rename to .py)

This is the diff: Screenshot_20231229_225001

Running this on a test case, such as your enwik8 example, shows that the logits are indeed scrambled between kv cache on and off.

When testing generation on a small dataset, like Karpathy's Tiny Shakespeare, we see that the KV cache breaks generation:

KV cache on:

BIANCA:
Pardon, like an unwioul.

one nobistasele me, for me.

Were'e tondove ture ture.see tove t, gureforec tat ittheeft,
Tatt, deeeFooond
't!'tcitcit!
We tufn ture-p ton tovit thon tathon tap,
A'e
't!
'e
'e tctc,
'eSattc, tc,
'e
'eeatvaS:
'e t!s tcap tLeee t,
'ee t,
De t!
A tovitoreeee ton tctc, t,
Wee tctc,
aC,
We dee
A thove
TeeTy:
Tt,
'eree t,
T,
Ty
I
'ee ton tc,
at!
'ee tononc,
'eeeee:
C tc ton,
'eat
'e ttchon tce t ttc,
'eree tc
tceeere t!
'ee ttceeee tttoreee:
ItcI
ttt tt, the,
hceTaC thore tttthore
Tu

KV cache off (by forcing self.can_cache_kv = 0):

BIANCA:
Pardon, my lord.

CORIOLANUS:
Why, having I can say 'tis the people?

First Servant:
Ha! what comes he ha'?

MENENIUS:
Now, I'll not budge thee with abard; not too much.
I have a ta'en for these sequesters must as
contrary'd to Rome, not Marcius, Lord Romenture;
My backetked Pravent; and, but at noblemant
More than any shall fine be a rought ,
Give up the to take and victor's bone
Have grace meet by my brother's life.

AUTOLYCUS:
Wear willow have we'll stout the coronatorsy.

First Gentleman:
Whom, how

If you want, you can fix it, or you can just wait and I'll figure out the cause and submit a PR within a week.

As an aside, I'm running some changes to the enwik8 example: an LR schedule, removing gradient accumulation, increasing the batch size, and turning on flash attention. This significantly improves performance and is code-size neutral. Would you welcome these changes? A rough sketch of the kind of changes I mean is below.
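
For concreteness, a minimal sketch of the kind of tweaks described above; the hyperparameters, the cosine schedule, and the use of attn_flash are my assumptions, since the actual diff is not attached to this issue:

import torch
from torch.optim.lr_scheduler import CosineAnnealingLR

from x_transformers import TransformerWrapper, Decoder, AutoregressiveWrapper

# hypothetical model config with flash attention enabled in every attention layer
model = AutoregressiveWrapper(TransformerWrapper(
    num_tokens = 256,
    max_seq_len = 1024,
    attn_layers = Decoder(
        dim = 512,
        depth = 6,
        heads = 8,
        attn_flash = True   # route attention through the fused flash / SDPA kernels
    )
))

optim = torch.optim.Adam(model.parameters(), lr = 3e-4)

# a cosine decay LR schedule in place of a constant learning rate
scheduler = CosineAnnealingLR(optim, T_max = 100_000)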

lucidrains commented Dec 30, 2023

@ad8e hey Kevin

thanks for reporting

i quickly checked on a test script and it seems to be fine

import torch

from x_transformers import (
    TransformerWrapper,
    Decoder,
    AutoregressiveWrapper
)

model = TransformerWrapper(
    num_tokens = 20000,
    max_seq_len = 1024,
    attn_layers = Decoder(
        dim = 8,
        depth = 1,
        heads = 4
    )
)

model = AutoregressiveWrapper(model)

# token ids must be integer typed for the embedding layer
prompts = torch.zeros((1, 1)).long()

# temperature 0. makes decoding deterministic, so the two runs are comparable
generated = model.generate(
    prompts,
    seq_len = 100,
    temperature = 0.,
    cache_kv = False
)

kv_cache_generated = model.generate(
    prompts,
    seq_len = 100,
    temperature = 0.,
    cache_kv = True
)

# with and without the kv cache, deterministic decoding should produce identical tokens
assert torch.allclose(generated, kv_cache_generated)

could you modify the script so that it breaks? perhaps you are using some hyperparameter that is incompatible with the kv cache (if so, it would be good to put in a patch to the can_cache_kv logic)

logits and logits2 in your code have different shapes. you need to compare logits and logits2[:, -1:]
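
to spell that out, here is a small sketch of the shape mismatch; the names and shapes are assumed from the comment above, not taken from the attached diff. with the cache on, the forward pass only returns logits for the newly decoded position, while the uncached pass returns logits for every position:

import torch

# hypothetical shapes: `logits` from a kv-cached step (only the newest position),
# `logits2` from a full uncached forward pass over the whole sequence so far
batch, seq_len, vocab = 1, 16, 256
logits  = torch.randn(batch, 1, vocab)         # (b, 1, v)
logits2 = torch.randn(batch, seq_len, vocab)   # (b, n, v)

# allclose on mismatched shapes broadcasts the single cached step against every
# position; slice out the last position of the uncached pass before comparing
matches = torch.allclose(logits, logits2[:, -1:], atol = 1e-5)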

lucidrains commented Dec 30, 2023

@ad8e as for your following offer, it is ok, as the library is model architecture specific. what you mention is all training related

lucidrains added a commit that referenced this issue Dec 30, 2023
… the context window during decoding, addressing #219

lucidrains commented Dec 30, 2023

@ad8e ah, got to the bottom of it Kevin

so it turns out the default (absolute positional embedding) is not kv cache friendly once you exceed the maximum sequence length (context window). however, it should still work when decoding from the 1st token up to the max context window size

i added an assert to prevent this, but also defaulted the enwik8 training script to use rotary positions, which is the preferred positional embedding these days (llama) and is kv cache friendly when exceeding the context length.
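
for anyone landing here, a minimal sketch of what the rotary setup looks like; the hyperparameters are placeholders rather than the enwik8 script's actual values, see the referenced commit for the real change:

from x_transformers import TransformerWrapper, Decoder, AutoregressiveWrapper

# rotary positions instead of the default absolute positional embedding,
# so cached decoding can keep working past the training context length
model = AutoregressiveWrapper(TransformerWrapper(
    num_tokens = 256,
    max_seq_len = 1024,
    attn_layers = Decoder(
        dim = 512,
        depth = 6,
        heads = 8,
        rotary_pos_emb = True   # rotary embeddings, as used in llama
    )
))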

ok, back to the holidays; have a great new years Kevin

ad8e commented Dec 30, 2023

Thanks, this fixes the issue and is better than what I would have PR'd. Sorry that I dumped two bad testcases and then went to sleep.

> @ad8e as for your following offer, it is ok

To clarify, do you mean it is ok to do it, or "it is ok" as in it is not necessary?

> have a great new years Kevin

You too!

lucidrains commented Dec 30, 2023

> Thanks, this fixes the issue and is better than what I would have PR'd. Sorry that I dumped two bad testcases and then went to sleep.
>
> > @ad8e as for your following offer, it is ok
>
> To clarify, do you mean it is ok to do it, or "it is ok" as in it is not necessary?
>
> > have a great new years Kevin
>
> You too!

it isn't necessary, not for this lib
