
broadcast_as error when processing multiple tokens at once in quantized example #2153

Open
EricLBuehler opened this issue May 2, 2024 · 9 comments


@EricLBuehler

EricLBuehler commented May 2, 2024

Hello all,

Thanks for your great work here. We are implementing speculative decoding at mistral.rs, and were in the final stages of testing when we discovered some very strange behavior. Specifically, the following error results when sending multiple tokens at once during the completion steps:

Error: cannot broadcast [3, 3] to [1, 32, 3, 5]

Reproducing this error is simple:

In quantized/main.rs at line 578:

-           let input = Tensor::new(&[next_token], &device)?.unsqueeze(0)?;
+           let input = Tensor::new(&[next_token, next_token, next_token], &device)?.unsqueeze(0)?;
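For reference, the shape clash behind the error can be reproduced in isolation. This is a minimal standalone sketch (not the model code itself), just mirroring the shapes from the error above, assuming the candle_core crate is available:

```rust
use candle_core::{DType, Device, Tensor};

fn main() -> candle_core::Result<()> {
    let device = Device::Cpu;
    // Square causal mask built from the 3 new tokens only: shape (seq_len, seq_len) = (3, 3).
    let mask = Tensor::zeros((3, 3), DType::F32, &device)?;
    // The attention scores have shape (batch, heads, seq_len, kv_len) = (1, 32, 3, 5)
    // once the kv cache holds earlier tokens. Broadcasting the square mask over them
    // fails with: cannot broadcast [3, 3] to [1, 32, 3, 5]
    let _ = mask.broadcast_as((1, 32, 3, 5))?;
    Ok(())
}
```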

Is this a bug?

@EricLBuehler EricLBuehler changed the title [Possible Bug] Error: cannot broadcast [3, 3] to [1, 32, 3, 5] in quantized example [Possible Bug] broadcast_as error when processing multiple tokens at once in quantized example May 2, 2024
@EricLBuehler

@LaurentMazare, is this a mistake on my part?

@LaurentMazare

Not sure I understand. This model was designed to be fed the prompt and then one token at a time, so it is somewhat expected that it fails if, after the prompt, you pass it multiple tokens at once. Do you mean that the error message should be more explicit about why this is failing?

@EricLBuehler

For speculative decoding, we need to run the target model on multiple tokens at once, once per step. Re-running the target model on the full prompt each time would be a big performance hit, which is why I tried this. Is there some workaround, like disabling the attention mask?

@LaurentMazare

I think disabling the attention mask would be incorrect: you want the tokens in the batch you're processing to be causal between themselves and still able to attend to all tokens in the kv cache. So you want a mask that is rectangular rather than square, based on how many tokens are in the kv cache at the moment. For a batch of 4 tokens and a kv cache that already has 5 tokens processed, it should look like the following (1 = masked out):

000000111
000000011
000000001
000000000
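A minimal sketch of such a mask builder (an illustrative helper, not candle's actual API; same 1 = masked convention as above), where seqlen_offset is the number of tokens already in the kv cache:

```rust
use candle_core::{Device, Result, Tensor};

// Illustrative helper (hypothetical name/signature): builds a (t, seqlen_offset + t)
// mask for `t` new tokens on top of a kv cache of `seqlen_offset` tokens. The first
// `seqlen_offset` columns are always visible (0); the trailing square block is the
// usual causal mask (1 above the diagonal).
fn rectangular_causal_mask(t: usize, seqlen_offset: usize, device: &Device) -> Result<Tensor> {
    let mask: Vec<u8> = (0..t)
        .flat_map(|i| (0..seqlen_offset + t).map(move |j| u8::from(j > seqlen_offset + i)))
        .collect();
    Tensor::from_vec(mask, (t, seqlen_offset + t), device)
}
```

With t = 4 and seqlen_offset = 5 this reproduces the 4 x 9 mask above; it is then broadcast over the (batch, heads, t, seqlen_offset + t) attention scores just like the square mask in the single-token case.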

@EricLBuehler

Ok. Would this be similar to #2111?

@EricLBuehler EricLBuehler changed the title [Possible Bug] broadcast_as error when processing multiple tokens at once in quantized example broadcast_as error when processing multiple tokens at once in quantized example May 3, 2024
@LaurentMazare

Indeed, it looks like the mask part at the bottom. It would be great if you could make a fresh PR with that change for the model you care about.

@EricLBuehler

Ok, so just to confirm: is it this part?

https://github.com/huggingface/candle/pull/2111/files#diff-ed262e4bc9a4a093e64842a2f61a85e1713c4efde0618ac7b31ad58dc5d171e3R137-R149

I can add a PR for this to some of the models if you think it is a good idea.

@LaurentMazare

Yep, exactly that part. It would probably be good to support it for at least llama and quantized-llama (and others too, but they might need a bit more work as their mask generation is different).

@EricLBuehler

I was able to make a general causal masker implementation here:

https://github.com/EricLBuehler/mistral.rs/blob/cc2f60a0bc4acfde636464ac408722335e0be732/mistralrs-core/src/layers.rs#L253

It works for all models with a causal or causal + sliding-window mask. Should I submit this as a PR?
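This is not the actual code in layers.rs, but a sketch of the general idea under the same convention (1 = masked), where a None window reduces to the plain rectangular causal mask:

```rust
use candle_core::{Device, Result, Tensor};

// Illustrative sketch (hypothetical helper, not the mistral.rs implementation):
// rectangular causal mask with an optional sliding window, so the new token at
// absolute position `seqlen_offset + i` can only attend to the last `w` positions
// ending at itself.
fn causal_sliding_window_mask(
    t: usize,
    seqlen_offset: usize,
    window: Option<usize>,
    device: &Device,
) -> Result<Tensor> {
    let total = seqlen_offset + t;
    let mask: Vec<u8> = (0..t)
        .flat_map(|i| {
            let pos = seqlen_offset + i;
            (0..total).map(move |j| {
                let future = j > pos;                                  // causal constraint
                let outside = window.map_or(false, |w| j + w <= pos);  // sliding window
                u8::from(future || outside)
            })
        })
        .collect();
    Tensor::from_vec(mask, (t, total), device)
}
```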
