Sequence parallel prefill attention kernel #2
Conversation
```python
seq_lens_cpu[i] = seq_len // model_runner.sp_size + (
    seq_len % model_runner.sp_size > model_runner.sp_rank
)
prefix_lens_cpu[i] = prefix_len // model_runner.sp_size + (
    prefix_len % model_runner.sp_size > model_runner.sp_rank
)
```
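As a sanity check on this split rule: each rank gets `seq_len // sp_size` tokens, plus one extra token if its rank index is smaller than the remainder, so the shard lengths always sum back to `seq_len`. A minimal standalone illustration (values hypothetical):

```python
def shard_len(seq_len: int, sp_size: int, sp_rank: int) -> int:
    # Same formula as the diff above: floor division, plus one extra token
    # for the first (seq_len % sp_size) ranks.
    return seq_len // sp_size + (seq_len % sp_size > sp_rank)

# seq_len = 10 over 4 SP ranks -> [3, 3, 2, 2], which sums back to 10.
assert [shard_len(10, 4, r) for r in range(4)] == [3, 3, 2, 2]
assert sum(shard_len(10, 4, r) for r in range(4)) == 10
```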
```python
    return prev_rank


def get_actual_tensor_model_parallel_world_size():
```
I'd recommend renaming `actual` -> `kv`, but it's just a naming issue.
Makes sense. I will fix this.
```python
    (existing_sid, sid) if existing_sid > sid else (sid, existing_sid)
)
q_data = qs[i]
kv_data = torch.stack(owned_shards[j], dim=1)
```
Instead of a stack, can we: 1. allocate `kv_data` together; 2. construct views for `k_data` and `v_data`? (`k_data = kv_data[:, :k_head, :]`, `v_data = kv_data[:, k_head:, :]`) This avoids a potentially large memcpy.
I should have used the ragged attention kernel here, which avoids creating this redundant `kv_data`. Will push a fix for this soon.
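For reference, a minimal sketch of the reviewer's suggestion (shapes and names hypothetical, not the PR's actual code): allocate the joint buffer once, land each incoming shard in its slice, and expose `k_data`/`v_data` as views so no stack is needed downstream:

```python
import torch

seq_len, num_kv_heads, head_dim = 128, 8, 64
device = "cuda" if torch.cuda.is_available() else "cpu"

# Stand-ins for the 2 * num_kv_heads K/V shards received over the ring.
owned_shards = [torch.randn(seq_len, head_dim, device=device)
                for _ in range(2 * num_kv_heads)]

# Allocate kv_data together ...
kv_data = torch.empty(seq_len, 2 * num_kv_heads, head_dim, device=device)
for j, shard in enumerate(owned_shards):
    # ... ideally the communication recv would target this slice directly,
    # so even this copy disappears.
    kv_data[:, j, :].copy_(shard)

# k_data / v_data are views into the same storage: no extra memcpy.
k_data = kv_data[:, :num_kv_heads, :]
v_data = kv_data[:, num_kv_heads:, :]
```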
```python
kv_indices = torch.arange(
    0, torch.sum(seq_lens), dtype=torch.int32, device="cuda"
)
```
This seems pretty hacky and I don't understand why `arange` is correct...
We will get rid of this magic after switching to the ragged attention kernel for the non-causal attention parts.
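For context, one reading of why the `arange` can work in this test path (an assumption, not confirmed in the thread): with KV cache management disabled, the batch's KV entries are packed contiguously from slot 0, so the page-index table degenerates to the identity mapping:

```python
import torch

# Hypothetical packed layout with page_size = 1: sequence i occupies slots
# [offsets[i], offsets[i] + seq_lens[i]) in the KV pool, back to back from 0.
seq_lens = torch.tensor([3, 5, 2], dtype=torch.int32)
kv_indptr = torch.cat([torch.zeros(1, dtype=torch.int32),
                       torch.cumsum(seq_lens, dim=0, dtype=torch.int32)])
# With that packing, the flattened slot ids are exactly 0..sum(seq_lens)-1,
# i.e. an arange; any other allocation order would need the real slot ids.
kv_indices = torch.arange(0, int(seq_lens.sum()), dtype=torch.int32)
```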
ZYHowell left a comment
LGTM. Shall we merge this and do decoding in the next PR?
Sounds good
* test: test cases of combining multiple attention kernel calls to implement a sequence parallel kernel. Verified with 2 sp workers
* fix: simplify flashinfer kernel initialization (begin_forward() and end_forward())
* test: add logic for sp worker 1, which is basically the same but with different orders of kernel calls
* chore: format tweak
* feat: a general seq parallel attention kernel that achieves workload balance
* fix: minor tweak to loop iteration within ring attention
* feat [radix_attention]: seq_parallel kernel with sync communication. TODO: turn communication into async fashion and overlap it with computation
* test: update test cases for seq parallel attn kernel. Need to disable kv cache management before testing because we haven't implemented kv cache management for seq parallel yet
* chore [radix_attention]: format tweak
* feat: async communication within ring attention
* fix [parallel_utils]: add missed files
* fix [infer_batch]: set default values for newly added sp-related metadata
* fix [bench_latency]: minor fixes to input args
* feat [parallel_utils]: get actual tp rank and size when both TP and SP are enabled
* feat [linear]: add QKVParallelLinear
* feat [llama2]: update llama model to use our QKVParallelLinear
* feat [model_runner]: initialize model parallel with sequence parallel
* fix [infer_batch]: 1. a minor issue when calling get_prefill_indices; 2. flashinfer initialization args
* fix [bench_latency]: load model with sp_rank
* feat [radix_attention]: automatically dispatch to seq-parallel attn kernel when sp_size > 1
* debug: stash current debug changes
* fix [radix_attention]: reshape q tensor before running the kernel
* bug fix for sp layout types
* fix: adjust tensor layout. TODO: fix many dirty hacks and hardcoded values
* fix [wip]: disable p2p communication within ring attention for now. TODO: fix the bug that causes communication hang
* chore [bench_latency]: disable decode for now since we haven't supported it
* upstream with correct prefill sp layout
* fix early exit on decode SP
* chore: tweak format
* update layout
* bug fix
* fix [linear, radix_attention]: fix q head indexes per SP worker to align with GQA setting
* fix [infer_batch]: set up flashinfer kernels for the batch size > 1 case
* chore: tweak format
* fix [radix_attention]: revert commented-out kv cache store operations in normal attention
* fix: adjust k, v tensor shape to align with both TP and SP setting
* chore [llama2]: minor adjustment
* fix: update bench_latency to evenly distribute each sequence across all SP workers to avoid the layout issue
* test: update test cases to align with current kernel args
* fix [model_runner]: initialize TokenToKVPool with correct num_heads and enable KV cache store in SP attention
* chore [radix_attention]: clean up comments
* fix [model_runner]: correct num_heads in memory profiling as well to avoid OOM
* fix [infer_batch]: adopt SP KV cache allocation
* feat [linear]: correctly partition q proj along the num_heads dimension with GQA
* chore [llama2]: clean up stable variables
* feat [infer_batch]: adjust positions to SP layout when preparing input_metadata
* feat [infer_batch]: use dedicated paged attn kernel for cross-SP-shard attn
* feat [parallel_state]: create sequence parallel comm groups
* test [sp_comm_group]: simple test case with sp_size = 2
* doc [parallel_state]: doc string for our SP group organization
* fix [infer_batch]: add padding zeros to positions tensor and out_cache_loc to fix positional encoding and KV cache store
* feat [radix_attn, infer_batch]: create masks for padded sequences; now attn works for unevenly-distributed sequences too
* chore [bench_latency]: revert original prompts
* fix [parallel_state]: rename "actual" to "kv"
* refactor [radix_attention]: unified two cases with different comm-comp tradeoffs
* chore: rename "actual_tp_[size|rank]" to "kv_tp_[size|rank]"
* fix [infer_batch]: ensure prefix_lens is not None in init_flashinfer_args
* fix [infer_batch]: only pad positions and out_cache_loc for prefill
* chore [linear]: clean up and revise comments
* chore [parallel_state]: revise comments
* chore [linear]: revise comments and class names
* chore [radix_attention]: add defensive checks

---------

Co-authored-by: ZYHowell <yhzhuang@cmu.edu>
This PR mainly implements the `RadixAttention.seq_parallel_extend_forward_flashinfer()` method. We adopted the parallelization strategy discussed before and overlapped communication and computation in each iteration within the ring attention algorithm. Specifically, each SP worker holds:

- a shard of the query heads along the `q_head_num` dimension, assuming that `q_head_num` is divisible by `SP_SIZE`;
- a sequence shard of length `seq_len // SP_SIZE`, which should be adjusted correspondingly when `seq_len` is not divisible by `SP_SIZE`.

NOTE: for now the kernel is only tested when `seq_len` is divisible by `SP_SIZE`. We will update the code later.

To balance the workload of all workers, we schedule the computation tasks in the following way: at iteration `i` (starting from 0), each SP worker computes the self-attention for its currently active shard (calling the ragged attention kernel) and `i` cross-shard attentions (calling the paged attention kernel). SP workers therefore have a perfectly balanced workload in each iteration, although the workload per iteration increases step by step. We will need to further investigate the performance impact of this design.
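To make the schedule concrete, here is a small self-contained emulation of one worker's task order: pure PyTorch on CPU, with log-sum-exp merging standing in for the real ragged/paged kernels and ring communication, and all names hypothetical rather than the PR's actual functions:

```python
import torch

def attn_partial(q, k, v, causal=False):
    """One partial attention call: returns the shard's output and its LSE."""
    scores = q @ k.transpose(-1, -2) / q.shape[-1] ** 0.5      # [h, n, m]
    if causal:
        n, m = scores.shape[-2:]
        mask = torch.ones(n, m).tril().bool()
        scores = scores.masked_fill(~mask, float("-inf"))
    lse = torch.logsumexp(scores, dim=-1, keepdim=True)        # [h, n, 1]
    return torch.exp(scores - lse) @ v, lse

def merge(parts):
    """Combine partial softmax results across shards via log-sum-exp."""
    outs, lses = zip(*parts)
    lse = torch.logsumexp(torch.stack(lses), dim=0)
    return sum(o * torch.exp(l - lse) for o, l in zip(outs, lses))

def sp_worker_schedule(qs, kv_shards):
    """qs[i]: this worker's q-head slice for sequence shard i.
    Iteration i: one causal self-attention on the active shard ("ragged")
    plus i full cross-shard attentions against earlier shards ("paged")."""
    results = []
    for i in range(len(qs)):
        parts = [attn_partial(qs[i], *kv_shards[i], causal=True)]
        for j in range(i):
            parts.append(attn_partial(qs[i], *kv_shards[j]))
        results.append(merge(parts))
    return torch.cat(results, dim=1)

# Check against full causal attention: 2 heads, 4 shards of 8 tokens each.
torch.manual_seed(0)
h, d, sp, s = 2, 16, 4, 8
q, k, v = (torch.randn(h, sp * s, d) for _ in range(3))
qs = [q[:, i * s:(i + 1) * s] for i in range(sp)]
kvs = [(k[:, i * s:(i + 1) * s], v[:, i * s:(i + 1) * s]) for i in range(sp)]
ref, _ = attn_partial(q, k, v, causal=True)
assert torch.allclose(sp_worker_schedule(qs, kvs), ref, atol=1e-5)
```

This also shows why the workload is balanced: every worker executes exactly `1 + i` kernel calls at iteration `i`, regardless of its rank.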