Processing text prompts in batches for LLMs #2108
Comments
Sounds like a good thing to add (as mentioned on Discord), but rather than doing a …
Okay, that is a cleaner choice. After changing the hashmap to have a (usize, usize) key, i.e. masks: HashMap<(usize, usize), Tensor>, this implementation of mask works as you'd want, I think:

```rust
fn mask(&mut self, t: usize, u: usize) -> Result<Tensor> {
    if let Some(mask) = self.masks.get(&(t, u)) {
        Ok(mask.clone())
    } else {
        let mask: Vec<_> = (0..t)
            .flat_map(|i| (0..u).map(move |j| u8::from(j > i + (u - t))))
            .collect();
        let mask = Tensor::from_slice(&mask, (t, u), &self.device)?;
        self.masks.insert((t, u), mask.clone());
        Ok(mask)
    }
}
```

Then, in …
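For concreteness, here is that function dropped into a self-contained harness. The Cache struct, its device field layout, and the candle_core import path are illustrative assumptions (candle's real cache also carries the per-layer k/v tensors), not the library's actual declarations:

```rust
use std::collections::HashMap;

use candle_core::{Device, Result, Tensor};

// Minimal stand-in holding only what the mask logic needs.
struct Cache {
    masks: HashMap<(usize, usize), Tensor>,
    device: Device,
}

impl Cache {
    // t = length of the current prompt chunk, u = cached positions + t.
    fn mask(&mut self, t: usize, u: usize) -> Result<Tensor> {
        if let Some(mask) = self.masks.get(&(t, u)) {
            Ok(mask.clone())
        } else {
            // Row i of the chunk may attend to the (u - t) cached positions
            // plus positions 0..=i of the chunk; every later position is masked.
            let mask: Vec<_> = (0..t)
                .flat_map(|i| (0..u).map(move |j| u8::from(j > i + (u - t))))
                .collect();
            let mask = Tensor::from_slice(&mask, (t, u), &self.device)?;
            self.masks.insert((t, u), mask.clone());
            Ok(mask)
        }
    }
}

fn main() -> Result<()> {
    let mut cache = Cache { masks: HashMap::new(), device: Device::Cpu };
    // Second chunk of 128 tokens with 128 tokens already in the kv cache:
    // the mask comes out as (128, 256), matching the attention scores.
    let mask = cache.mask(128, 256)?;
    println!("{:?}", mask.dims());
    Ok(())
}
```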
Tested it on batch sizes of 1, 64, 73, 128, 256, and 1024 with a prompt of 892 tokens and it seems to work well, though a batch size of one takes ages, as expected.
Cool, happy to get a PR for this if you want to make one. One nitpick: I would suggest using …
Nothing left to ask, so I'm closing the issue. |
When prompts get longer than trivial sizes, memory usage spikes because the whole prompt is packed into one Tensor and sent through a single forward pass at whatever length it happens to be. These spikes can be reduced by processing the prompt in chunks.
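As an illustration of what that looks like on the driver side, here is a sketch of a chunked prompt loop. The chunk size, the forward closure standing in for the model call, and all names here are assumptions, not the actual candle example code:

```rust
use candle_core::{Device, Error, Result, Tensor};

// Illustrative chunk size; it trades peak memory against the number of passes.
const CHUNK_SIZE: usize = 128;

// `forward` stands in for the model call (roughly model.forward(&input, index_pos)
// for candle's Llama); the kv cache inside the model carries state across chunks.
fn process_prompt(
    tokens: &[u32],
    device: &Device,
    mut forward: impl FnMut(&Tensor, usize) -> Result<Tensor>,
) -> Result<Tensor> {
    let mut index_pos = 0;
    let mut last_logits = None;
    for chunk in tokens.chunks(CHUNK_SIZE) {
        // Only chunk.len() tokens are materialized per forward pass instead of
        // the whole prompt, which bounds the activation memory spike.
        let input = Tensor::new(chunk, device)?.unsqueeze(0)?; // (1, chunk_len)
        last_logits = Some(forward(&input, index_pos)?);
        index_pos += chunk.len();
    }
    // Generation continues from the logits of the final chunk.
    last_logits.ok_or_else(|| Error::Msg("empty prompt".to_string()))
}
```

A larger chunk means fewer forward passes but a higher memory peak, which is also why a chunk size of 1 takes so long, as noted in the comments above.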
The implementation of CausalSelfAttention for Llama inside candle-transformers only handles two cases: a special case of seq_len 1, which occurs while generating text, and a seq_len that matches the whole prompt with an index_pos of 0, which occurs once when processing the whole prompt. If you chunk the prompt into blocks of tokens, the second chunk hits a broadcasting error: the mask is sized to the prompt chunk, but the kv cache has grown k and v to include the previous data, so the shapes no longer match. (E.g. for a chunk size of 128 the mask is [128, 128], but on the second chunk the att Tensor ends up as [1, 32, 128, 256].)
This can be fixed by creating the mask in a different way:
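Along the lines of the (t, u)-keyed mask quoted in the comment above, a sketch of what the attention side would do with it follows; the mask_scores helper, its arguments, and the where_cond-based fill are assumptions about the Llama forward pass rather than candle's exact code, and Cache refers to the struct sketched earlier in the thread:

```rust
use candle_core::{Result, Tensor};

// att has shape (batch, heads, seq_len, index_pos + seq_len) once the kv cache
// already holds index_pos earlier positions; the mask must match that last dim.
// Assumes f32 attention scores.
fn mask_scores(
    cache: &mut Cache, // the (t, u)-keyed mask cache sketched earlier
    att: &Tensor,
    seq_len: usize,
    index_pos: usize,
) -> Result<Tensor> {
    if seq_len == 1 {
        // Single-token decode step: nothing to mask.
        return Ok(att.clone());
    }
    let mask = cache
        .mask(seq_len, index_pos + seq_len)?
        .broadcast_as(att.dims())?;
    // Set masked positions to -inf before the softmax.
    let neg_inf = Tensor::new(f32::NEG_INFINITY, att.device())?.broadcast_as(att.dims())?;
    mask.where_cond(&neg_inf, att)
}
```

With the mask sized this way, the two existing cases keep working: processing the whole prompt is just the index_pos == 0 case, and the single-token step skips the mask entirely.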
I'm not sure how efficient that is, but it produces the same results when processing the prompt in batches as when sending it all in at once. It is also only a drop-in for the Llama model, since that model has a kv cache; quantized_llama.rs's ModelWeights doesn't have a kv cache to modify... I have a modified llama example with the batch processing and this change in the model struct that I could submit as a PR if you'd like to see it, but the above code is all that's needed in candle-transformers.