
Processing text prompts in batches for LLMs #2108

Closed
tbogdala opened this issue Apr 22, 2024 · 4 comments

Comments

@tbogdala

When prompts grow beyond trivial sizes, memory usage spikes because the whole prompt is packed into one Tensor and sent through a single forward pass at whatever length it happens to be. These spikes can be reduced by processing the prompt in chunks.

The implementation of CausalSelfAttention for Llama inside candle-transformers only handles two cases: the special case of seq_len 1, which occurs while generating text, and a seq_len that matches the whole prompt with an index_pos of 0, which occurs once when the full prompt is processed in one go. If you chunk the prompt into blocks of tokens, processing the second chunk fails with a broadcasting error: the generated mask is sized to the prompt chunk, but the kv cache has grown k & v to include the previous data, so the shapes no longer match. (E.g. with a chunk size of 128, the mask is [128, 128], but on the second chunk the att Tensor ends up as [1, 32, 128, 256].)

This can be fixed by creating the mask in a different way:

// Mask for the current chunk only (seq_len x seq_len).
let mut mask = cache.mask(seq_len)?;
if index_pos != 0 {
    // Prepend zero columns covering the positions already in the kv cache so the
    // mask width matches the widened k/v tensors.
    let zero_history = Tensor::zeros((seq_len, (index_pos / seq_len) * seq_len), mask.dtype(), mask.device())?;
    mask = Tensor::cat(&[zero_history, mask], 1)?;
}
mask = mask.broadcast_as(att.shape())?;

I'm not sure how efficient that is, but it produces the same results when processing the prompt in batches as when sending it all in at once. It's also only a drop-in change for the Llama model, since that one has a kv cache; quantized_llama.rs's ModelWeights doesn't have a kv cache to modify...

I have a modified llama example with the batch processing, plus this change in the model struct, that I could submit as a PR if you'd like to see it, but the above code is all that's needed in candle-transformers.
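
In case it's useful, the chunked prompt loop in that modified example looks roughly like the sketch below. The process_prompt_chunks helper name is mine, and the llama.forward(&input, index_pos, &mut cache) call is assumed to match the example's signature, so treat it as a sketch rather than the exact patch:

use candle_core::{Device, Result, Tensor};
use candle_transformers::models::llama::{Cache, Llama};

// Hypothetical helper: feed the prompt through the model in fixed-size chunks
// instead of one big forward pass, returning the logits of the last chunk.
fn process_prompt_chunks(
    llama: &Llama,
    cache: &mut Cache,
    tokens: &[u32],
    chunk_size: usize, // must be > 0
    device: &Device,
) -> Result<Tensor> {
    let mut index_pos = 0;
    let mut last_logits = None;
    for chunk in tokens.chunks(chunk_size) {
        // Shape the chunk as a (1, chunk_len) batch of token ids.
        let input = Tensor::new(chunk, device)?.unsqueeze(0)?;
        // Assumed forward signature; the kv cache grows with each call,
        // which is why the mask change above is needed.
        last_logits = Some(llama.forward(&input, index_pos, cache)?);
        index_pos += chunk.len();
    }
    match last_logits {
        Some(logits) => Ok(logits),
        None => candle_core::bail!("prompt was empty"),
    }
}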

@LaurentMazare
Collaborator

Sounds like a good thing to add (as mentioned on Discord), but rather than doing a cat, what about just modifying the way mask is defined so that it creates the appropriate mask from the beginning? Obviously you will have to pass index_pos to the mask function for that, but then it should be pretty straightforward.

@tbogdala
Author

Okay, that is a cleaner choice. After changing the hashmap to have a (usize, usize) key ...

masks: HashMap<(usize, usize), Tensor>,

... this implementation of mask works as you'd want, I think:

fn mask(&mut self, t: usize, u: usize) -> Result<Tensor> {
    if let Some(mask) = self.masks.get(&(t, u)) {
        Ok(mask.clone())
    } else {
        // t is the current chunk length, u is index_pos + t (total positions so far);
        // row i of the chunk may attend to absolute positions 0..=i + (u - t).
        let mask: Vec<_> = (0..t)
            .flat_map(|i| (0..u).map(move |j| u8::from(j > i + (u - t))))
            .collect();
        let mask = Tensor::from_slice(&mask, (t, u), &self.device)?;
        self.masks.insert((t, u), mask.clone());
        Ok(mask)
    }
}

Then, in forward, the mask can be created much like it originally was:

let mask = cache.mask(seq_len, index_pos + seq_len)?.broadcast_as(att.shape())?;
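
For context, here is roughly where that line slots into the attention computation; this is a sketch based on the shape of the existing forward in llama.rs rather than the exact diff, and the masked_fill helper is the one already defined in that file (details may differ between versions):

// att has shape (b, n_heads, seq_len, index_pos + seq_len) once the kv cache
// has been concatenated into k and v.
let att = if seq_len == 1 {
    // Single-token generation: nothing ahead of the token to mask out.
    att
} else {
    let mask = cache.mask(seq_len, index_pos + seq_len)?.broadcast_as(att.shape())?;
    masked_fill(&att, &mask, f32::NEG_INFINITY)?
};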

I tested it with batch sizes of 1, 64, 73, 128, 256, and 1024 on a prompt of 892 tokens and it seems to work well, though a batch size of one takes ages, as expected.

@LaurentMazare
Collaborator

Cool, happy to get a PR for this if you want to make one. One nitpick: I would suggest using j + t > i + u as the condition, so as to avoid having to think about why u - t has to be non-negative (which is the case given how you call the function, but it would result in an underflow if the function weren't called properly).
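
For reference, with that change the closure body would read something like:

// Same mask as before; j + t > i + u is equivalent to j > i + (u - t)
// when u >= t, but it cannot underflow if the function is misused.
let mask: Vec<_> = (0..t)
    .flat_map(|i| (0..u).map(move |j| u8::from(j + t > i + u)))
    .collect();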

@tbogdala
Author

Nothing left to ask, so I'm closing the issue.
