
inconsistent logits for identical inputs in batch (metal) #1941

Closed
josharian opened this issue Mar 8, 2024 · 4 comments · Fixed by #1947

Comments

@josharian (Contributor)

Summary

Given a beam search that contains two decoders with identical inputs (encoded state, past tokens), the same token gets slightly different logits across those two decoders.

This matters because (a) it hints that there's a bug somewhere and (b) the beam search de-dup logic assumes that logprobs will be identical for identical sequences.

It happens with Metal but not with CPU-only execution, suggesting a bug in the Metal graph evaluation.
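
For context on (b): the de-dup check effectively treats "same accumulated log-probability" as a proxy for "same token sequence". Schematically (a hypothetical sketch, not the actual whisper.cpp code), the assumption looks like this:

```cpp
// Hypothetical sketch (not the actual whisper.cpp code): the de-dup logic
// relies on exact floating point equality of accumulated log-probabilities,
// which only works if identical sequences always produce bit-identical logits.
bool is_duplicate(double sum_logprobs_a, double sum_logprobs_b) {
    return sum_logprobs_a == sum_logprobs_b;
}
```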

Reproduce

Run:

./main -m models/ggml-large-v2.bin sid5s.wav

Result:

<snip>
whisper_full_with_state: prompt[0] = [_SOT_]
whisper_full_with_state: prompt[1] = [_LANG_en]
whisper_full_with_state: prompt[2] = [_TRANSCRIBE_]

whisper_full_with_state: beam search: decoder 0: from decoder 0: token =    [_BEG_], plog = -0.11073, sum_logprobs = -0.11073
whisper_full_with_state: beam search: decoder 1: from decoder 0: token =    [_BEG_], plog = -0.11073, sum_logprobs = -0.11073
whisper_full_with_state: beam search: decoder 2: from decoder 0: token =    [_BEG_], plog = -0.11073, sum_logprobs = -0.11073
whisper_full_with_state: beam search: decoder 3: from decoder 0: token =    [_BEG_], plog = -0.11073, sum_logprobs = -0.11073
whisper_full_with_state: beam search: decoder 4: from decoder 0: token =    [_BEG_], plog = -0.11073, sum_logprobs = -0.11073
whisper_full_with_state: id =   0, decoder = 0, token =  50364, p =  0.895, ts =    [_BEG_],  0.895, result_len =    0 '[_BEG_]'
whisper_full_with_state: id =   0, decoder = 1, token =  50364, p =  0.895, ts =    [_BEG_],  0.895, result_len =    0 '[_BEG_]'
whisper_full_with_state: id =   0, decoder = 2, token =  50364, p =  0.895, ts =    [_BEG_],  0.895, result_len =    0 '[_BEG_]'
whisper_full_with_state: id =   0, decoder = 3, token =  50364, p =  0.895, ts =    [_BEG_],  0.895, result_len =    0 '[_BEG_]'
whisper_full_with_state: id =   0, decoder = 4, token =  50364, p =  0.895, ts =    [_BEG_],  0.895, result_len =    0 '[_BEG_]'
whisper_full_with_state: beam search: decoder 0: from decoder 0: token =      Great, plog = -0.66754, sum_logprobs = -0.77828
whisper_full_with_state: beam search: decoder 1: from decoder 4: token =      Great, plog = -0.66756, sum_logprobs = -0.77830
whisper_full_with_state: beam search: decoder 2: from decoder 4: token =      Thank, plog = -4.03217, sum_logprobs = -4.14290
whisper_full_with_state: beam search: decoder 3: from decoder 0: token =      Thank, plog = -4.03220, sum_logprobs = -4.14294
whisper_full_with_state: beam search: decoder 4: from decoder 2: token =          -, plog = -4.22957, sum_logprobs = -4.34030
<snip>

The lines of interest are:

whisper_full_with_state: beam search: decoder 0: from decoder 0: token =      Great, plog = -0.66754, sum_logprobs = -0.77828
whisper_full_with_state: beam search: decoder 1: from decoder 4: token =      Great, plog = -0.66756, sum_logprobs = -0.77830

Observe that plog and sum_logprobs are slightly different. They should be identical; the sequences leading up to them are identical ([_SOT_], [_LANG_en], [_TRANSCRIBE_], [_BEG_]).

It does not reproduce when running with cpu only:

./main -m models/ggml-large-v2.bin -ng sid5s.wav

Relevant lines are:

whisper_full_with_state: beam search: decoder 0: from decoder 0: token =      Great, plog = -0.66892, sum_logprobs = -0.77950
whisper_full_with_state: beam search: decoder 4: from decoder 0: token =      Great, plog = -0.66892, sum_logprobs = -0.77950

The absolute values also differ across metal and cpu; I assume that this is expected?

josharian added a commit to josharian/whisper.cpp that referenced this issue Mar 8, 2024
All else being equal, this encourages the beam candidate
selection to re-use the same decoder, which slightly
reduces the cache size.

I wouldn't expect it to make much of a performance difference,
but it helps when debug printing the cache and beam.

Added as part of understanding ggerganov#1941.
@josharian (Contributor Author)

Oh, and this reproduces with -t 1, so I don't think it is a data race.

ggerganov pushed a commit that referenced this issue Mar 9, 2024
@ggerganov (Owner)

In whisper.cpp (and llama.cpp) we use what I call a "unified KV cache":

ggerganov/llama.cpp#4130 (comment)

In short, we put the tokens from all sequences in the same KV cache buffer and construct a suitable KQ-mask that is used to discard cross-sequence attention. This is in contrast to the more straightforward approach in which each sequence has its own separate KV cache buffer and the attention for each sequence is computed independently of the rest - the KQ mask in that case is simply a causal mask.
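
To make the masking concrete, here is a minimal sketch of how such a KQ-mask could be built for a unified cache (hypothetical types and layout, not the actual ggml/whisper.cpp implementation): a query may attend to a cache slot only if the slot belongs to the same sequence and is not ahead of the query's position.

```cpp
#include <vector>

// Hypothetical descriptor for one slot in the unified KV cache (and for one
// query token in the batch): which sequence it belongs to and its position
// within that sequence.
struct CacheSlot {
    int seq_id;
    int pos;
};

// Build a KQ-mask over the unified cache: mask[i][j] == 0.0f lets query i
// attend to cache slot j; a large negative value (added to the KQ scores
// before the softmax) masks the pair out.
std::vector<std::vector<float>> build_kq_mask(const std::vector<CacheSlot> & queries,
                                              const std::vector<CacheSlot> & cache) {
    const float neg_inf = -1e30f;
    std::vector<std::vector<float>> mask(queries.size(),
                                         std::vector<float>(cache.size(), neg_inf));
    for (size_t i = 0; i < queries.size(); ++i) {
        for (size_t j = 0; j < cache.size(); ++j) {
            // Same sequence and causal: no attention to future positions.
            if (cache[j].seq_id == queries[i].seq_id && cache[j].pos <= queries[i].pos) {
                mask[i][j] = 0.0f;
            }
        }
    }
    return mask;
}
```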

A drawback of the unified KV cache is that the results for different sequences are now also a function of where their tokens end up in the buffer. The reason is that the V*QK multiplication performs a reduction over the KV buffer length; since these are floating point numbers, rounding errors occur while accumulating them, and the final sum in each row will be slightly different depending on the order of the data in the buffer. In other words, the output is not invariant to how the input is placed in the KV cache.
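
The order-sensitivity of that reduction is easy to demonstrate in isolation. The toy program below (just an illustration, unrelated to the actual kernels) accumulates the same values in two different orders and gets two different sums:

```cpp
#include <algorithm>
#include <cstdio>
#include <vector>

int main() {
    // The same multiset of values, accumulated forward and in reverse.
    std::vector<float> v = {1e8f, 1.0f, -1e8f, 0.5f, 3.14159f, -2.71828f};
    std::vector<float> w(v.rbegin(), v.rend());

    float a = 0.0f, b = 0.0f;
    for (float x : v) a += x;
    for (float x : w) b += x;

    // Prints two different results: float addition is not associative, so the
    // sum depends on the order in which the values are accumulated.
    printf("forward: %.7f\nreverse: %.7f\n", a, b);
    return 0;
}
```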

> It does not reproduce when running with cpu only

I expect a similar effect to occur on the CPU as well - maybe it's just harder to encounter.

> The absolute values also differ across metal and cpu; I assume that this is expected?

Yes, it's normal for these to differ.

@josharian (Contributor Author)

Got it. Thanks for the details.

What do you think about using a mechanism other than floating point equality in the beam search code I linked to above, since floating point equality is no longer expected?

The simplest approach is to check whether the sequence token ids are identical. That's O(n), but n is small and the comparisons are cheap. A more complicated option would be to keep a running hash of token ids. I implemented the O(n) approach while investigating this; it was straightforward.
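
For illustration, the O(n) check amounts to something like the following sketch (hypothetical names; the real change would live in whisper.cpp's beam-search candidate handling):

```cpp
#include <vector>

// Token ids are plain integers in whisper.cpp; the alias here is illustrative.
using token_id = int;

// Two beam candidates are duplicates iff their token sequences are identical.
// This replaces exact floating point equality of summed log-probabilities,
// which no longer holds for identical sequences under the unified KV cache.
static bool same_sequence(const std::vector<token_id> & a,
                          const std::vector<token_id> & b) {
    return a == b; // element-wise comparison: O(n), but n is small
}
```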

@ggerganov (Owner)

> The simplest approach is to check whether the sequence token ids are identical.

Yes, sounds like a good idea

josharian added a commit to josharian/whisper.cpp that referenced this issue Mar 10, 2024
As of ggerganov#1486, whisper.cpp uses a unified KV cache with KQ masking.
As a result, depending on where they land in the batch,
identical sequences can have slightly different outputs
due to floating point rounding errors during reduction.
See the discussion in ggerganov#1941 for more details.

The beam search code used "has an identical sum of log probabilities"
as a shorthand for "is an identical token sequence". However, per the above,
identical token sequences do not necessarily produce identical probabilities.

Instead, compare the token sequences explicitly.
This is linear in cost when they are identical,
but the lengths are always small and the comparisons are cheap.

This increases diversity during beam search.

This improves output quality for some short samples I've been working
with, at no detectable performance cost.
I haven't checked against larger corpuses.

Fixes ggerganov#1941
ggerganov pushed a commit that referenced this issue Mar 10, 2024