Conversation
ArthurZucker
left a comment
Great work!
Next up is LLM eval with paged vs. no paged? (model performance testing)
src/transformers/generation/continuous_batching/cache_manager.py
read_index: List of tensors indicating which cache positions to read from, one per attention group.
logits_indices: Tensor indicating which positions in the output should be used for next-token prediction.
cache: The [`PagedAttentionCache`] instance managing the KV cache.
block_table: Block table for paged KV cache. If provided, uses `flash_attn_with_kvcache` for fused attention + cache update.
It would help a LOT if you could show what it should look like, etc. I remember trying to figure it out from code only, and it can be super annoying!
The better the docs, the happier I am.
I added documentation in the docstring, but no drawing / concrete example just there. The reason is that we are adding support for compile soon, and that will involve a dedicated wrapper. I think the cleanest would be to add a concrete example / ASCII drawing there. Leaving a TODO for this; the current documentation should make things understandable at least.
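Until that concrete example lands, here is a rough sketch of what these indices can look like for a tiny batch. All values are illustrative, made up for this comment, and not taken from the actual implementation:

```python
# Hypothetical continuous-batching step: two sequences packed into one flat batch.
# Sequence A is prefilling 3 prompt tokens; sequence B is decoding 1 new token.
# Flat token layout: [A0, A1, A2, B_next]
input_ids = [101, 102, 103, 204]

# FA-style cumulative sequence lengths: A spans [0, 3), B spans [3, 4)
cu_seqlens = [0, 3, 4]

# logits_indices: only the LAST position of each sequence is needed for
# next-token prediction -> positions 2 (end of A) and 3 (end of B)
logits_indices = [cu_seqlens[i + 1] - 1 for i in range(len(cu_seqlens) - 1)]
print(logits_indices)  # -> [2, 3]
```

The same idea extends to `read_index`: one index tensor per attention group, each listing the cache slots that sequence's attention should read.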
elif isinstance(self.eos_token_id, int):
    if self.eos_token_id >= 0:
        self._eos_token_ids.add(self.eos_token_id)
# If there are multiple EOS token IDs, add them to the set only if they are valid, i.e. non-negative
never saw that happening
this handles the case where we set `eos_token_id=-1` for infinite generation
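For context, a minimal stand-in for the validation being discussed (simplified, not the actual transformers code):

```python
def collect_eos_ids(eos_token_id):
    """Collect valid (non-negative) EOS token ids from an int or a list.

    eos_token_id=-1 is the conventional way to request infinite generation,
    so negative ids are skipped rather than treated as stop tokens.
    """
    eos_ids = set()
    if isinstance(eos_token_id, int):
        if eos_token_id >= 0:
            eos_ids.add(eos_token_id)
    elif eos_token_id is not None:
        # Multiple EOS ids: keep only the valid, i.e. non-negative, ones
        eos_ids.update(t for t in eos_token_id if t >= 0)
    return eos_ids

print(collect_eos_ids(2))           # -> {2}
print(collect_eos_ids(-1))          # -> set()
print(collect_eos_ids([2, -1, 7]))  # -> {2, 7}
```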
vasqu
left a comment
Looks fine on my side. I think we can refactor the paged FA into the general FA modeling, since we duplicate a lot of effort on processing kwargs.
attn_output = attn_output[0]
# Reshape output from [batch_size, 1, num_heads, head_dim] to [batch_size, num_heads, head_dim]
attn_output = attn_output.squeeze(1)
return attn_output, None
General comment on maybe refactoring this. Not a must for this PR, and maybe I'll tackle it in a different PR. Atm, we manually do a lot of the kwargs processing which can already be done in the base FA forward here.
We should enter the right branch automatically due to the FA kwargs being available. This would leverage
which you should not need to handle here. It would probably require some additional logic to add the KV cache branch and some kwargs processing (force int dtype, etc.). Wdyt?
Agreed, it would make things easier in general. Though since the behavior is quite different when switching from one function to the other, we would need to document this quite carefully!
Out of scope for this PR, as discussed. Yeah, agreed it might behave differently, so it needs some docs then.
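To illustrate the refactor idea being discussed, a toy sketch of picking the FA branch from which kwargs are present, so callers don't process them manually. The helper and kwarg handling here are hypothetical, not the transformers API; only the three kernel names are real flash-attn entry points:

```python
# Illustrative only: route to the right flash-attn kernel based on the
# kwargs the caller provides, instead of each call site branching by hand.
def select_fa_branch(**kwargs):
    if kwargs.get("block_table") is not None:
        return "flash_attn_with_kvcache"  # fused attention + KV-cache update
    if kwargs.get("cu_seqlens_q") is not None:
        return "flash_attn_varlen_func"   # packed variable-length batch
    return "flash_attn_func"              # plain dense batch

print(select_fa_branch(block_table=[[0, 1]]))    # -> flash_attn_with_kvcache
print(select_fa_branch(cu_seqlens_q=[0, 3, 4]))  # -> flash_attn_varlen_func
print(select_fa_branch())                        # -> flash_attn_func
```

The real version would also need the extra kwargs processing mentioned above (forcing int dtypes for the index tensors, etc.).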
# This function is not implemented but should never be called because block table is not used on NPU
def npu_flash_attn_with_kvcache():
    raise NotImplementedError("npu_flash_attn_with_kvcache is not implemented")
Summary
! This PR is in draft, waiting for #44227 to be merged.
This PR adds support for the `flash_attn_with_kvcache` kernel in continuous batching. This is very efficient for decode-only batches, hence especially useful for long generations. Right now, it needs to be explicitly turned on by the user, because there are some divergences in generation. Once this is fixed / confirmed to be expected, it will be turned on by default.
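To make the "paged" part concrete, here is a toy illustration of the block-table indirection such a kernel consumes. The helper, block size, and values are made up for this example, not taken from the PR:

```python
# Toy paged-KV-cache block table: each sequence maps its logical cache
# positions to fixed-size physical blocks in a shared pool. A fused kernel
# like flash_attn_with_kvcache reads the cache through this indirection and
# appends the new token's KV in the same launch.
BLOCK_SIZE = 4

block_table = [
    [0, 2],  # seq 0: logical block 0 -> physical block 0, logical block 1 -> physical block 2
    [1],     # seq 1: logical block 0 -> physical block 1
]

def physical_slot(seq, pos):
    """Map a (sequence, logical position) pair to a slot in the block pool."""
    block = block_table[seq][pos // BLOCK_SIZE]
    return block * BLOCK_SIZE + pos % BLOCK_SIZE

print(physical_slot(0, 5))  # -> 9  (logical pos 5 lives in physical block 2)
print(physical_slot(1, 2))  # -> 6  (logical pos 2 lives in physical block 1)
```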
This PR also fixes a few things:
Performance
Performance on long generations:
Tests
All tests, including the ones added for this feature, pass.