
Sequence parallel prefill attention kernel#2

Merged
ivanium merged 67 commits into main from seq-parallel-kernel
Jul 31, 2024
Conversation

@ivanium
Owner

@ivanium ivanium commented Jul 21, 2024

This PR mainly implements the RadixAttention.seq_parallel_extend_forward_flashinfer() method. We adopt the parallelization strategy discussed earlier and overlap communication with computation in each iteration of the ring attention algorithm. Specifically, each SP worker holds:

  • q tensor: [batch_size, seq_len, q_head_num // SP_SIZE, head_dim] # partitioned along the q_head_num dimension; assumes q_head_num is divisible by SP_SIZE
  • k tensor: [batch_size, seq_len // SP_SIZE, k_head_num, head_dim] # seq_len // SP_SIZE is adjusted accordingly when seq_len is not divisible by SP_SIZE
  • v tensor: [batch_size, seq_len // SP_SIZE, v_head_num, head_dim] # same as the k tensor

NOTE: for now the kernel is only tested when seq_len is divisible by SP_SIZE. We will update the code later.
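The per-worker partitioning above can be sketched as follows. All sizes here are hypothetical, and numpy stands in for torch purely for illustration:

```python
import numpy as np

# Hypothetical sizes for illustration only.
batch, seq_len, head_dim = 2, 16, 64
q_head_num, kv_head_num, SP_SIZE = 8, 4, 4

# q is partitioned along the head dimension; k/v along the sequence dimension.
q_local = np.zeros((batch, seq_len, q_head_num // SP_SIZE, head_dim))
k_local = np.zeros((batch, seq_len // SP_SIZE, kv_head_num, head_dim))
v_local = np.zeros((batch, seq_len // SP_SIZE, kv_head_num, head_dim))
```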

To balance the workload across workers, we schedule the computation as follows:
at iteration i (starting from 0), each SP worker computes the self-attention for its currently active shard (via the ragged attention kernel) plus i cross-shard attentions (via the paged attention kernel). Workloads are perfectly balanced across SP workers within each iteration, although the per-iteration workload grows step by step. We still need to investigate the performance impact of this design.
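The scheduling above can be sketched as a small schedule generator. Note `ring_schedule` and the shard-rotation direction are illustrative assumptions, not names from this PR; the invariant it demonstrates is that at iteration i every worker issues exactly i + 1 attention calls:

```python
def ring_schedule(sp_size):
    """For each iteration, list the KV shard ids each SP worker attends to.

    Worker w starts with shard w; at every iteration it receives one more
    shard from its ring neighbor (the rotation direction is an assumption).
    """
    schedule = []
    for it in range(sp_size):
        per_worker = []
        for w in range(sp_size):
            # the locally held shard (d = 0) plus the `it` shards received so far
            per_worker.append([(w - d) % sp_size for d in range(it + 1)])
        schedule.append(per_worker)
    return schedule

# With 4 SP workers: at iteration i every worker processes i + 1 shards,
# so the per-iteration workload is balanced across workers.
sched = ring_schedule(4)
for it, per_worker in enumerate(sched):
    assert all(len(shards) == it + 1 for shards in per_worker)
```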

ivanium and others added 30 commits July 19, 2024 13:33
…ement a sequence parallel kernel. Verified with 2 sp workers
TODO: turn communication into async fashion and overlap it with computation
… kv cache management before testing because we haven't implemented kv cache management for seq parallel yet
…ODO: fix the bug that causes communication hang.
@ivanium ivanium mentioned this pull request Jul 28, 2024
@ivanium ivanium changed the title Sequence parallel attention kernel [WIP] Sequence parallel attention kernel Jul 28, 2024
Comment on lines +967 to +972
seq_lens_cpu[i] = seq_len // model_runner.sp_size + (
seq_len % model_runner.sp_size > model_runner.sp_rank
)
prefix_lens_cpu[i] = prefix_len // model_runner.sp_size + (
prefix_len % model_runner.sp_size > model_runner.sp_rank
)
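The snippet splits a sequence so that the first seq_len % sp_size ranks each take one extra token. A standalone sketch of the same formula (`sp_shard_len` is a hypothetical helper name, not from the PR):

```python
def sp_shard_len(seq_len, sp_size, sp_rank):
    # base share, plus one extra token for the first (seq_len % sp_size) ranks
    return seq_len // sp_size + (seq_len % sp_size > sp_rank)

# e.g. 10 tokens over 4 SP workers -> shard lengths [3, 3, 2, 2]
lens = [sp_shard_len(10, 4, r) for r in range(4)]
```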
Collaborator
#3 is merged into main; let's resolve this one.

return prev_rank


def get_actual_tensor_model_parallel_world_size():
Collaborator

I'd recommend renaming actual -> kv, but it's just a naming issue.

Owner Author

Makes sense. I will fix this.

Owner Author

Fixed.

(existing_sid, sid) if existing_sid > sid else (sid, existing_sid)
)
q_data = qs[i]
kv_data = torch.stack(owned_shards[j], dim=1)
Collaborator

Instead of a stack, can we: 1. allocate kv_data together; 2. construct views for k_data and v_data (k_data = kv_data[:, :k_head, :], v_data = kv_data[:, k_head:, :])? This avoids a potentially large memcpy.
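The suggestion amounts to allocating one contiguous buffer and slicing it, so writes to k and v land directly in kv_data with no copy. A minimal numpy sketch of the idea (torch views behave the same way under basic slicing; the names here are illustrative):

```python
import numpy as np

def alloc_kv(num_tokens, k_heads, v_heads, head_dim):
    # One contiguous allocation; k_data/v_data are zero-copy views into it.
    kv_data = np.empty((num_tokens, k_heads + v_heads, head_dim), dtype=np.float32)
    k_data = kv_data[:, :k_heads, :]
    v_data = kv_data[:, k_heads:, :]
    return kv_data, k_data, v_data

kv, k, v = alloc_kv(num_tokens=8, k_heads=4, v_heads=4, head_dim=16)
k[...] = 1.0  # writes through the view land in kv directly, no memcpy
```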

Owner Author

I should have used the ragged attention kernel here, which avoids creating this redundant kv_data. Will push a fix soon.

Owner Author

Fixed.

Comment on lines +995 to +997
kv_indices = torch.arange(
0, torch.sum(seq_lens), dtype=torch.int32, device="cuda"
)
Collaborator

this seems pretty hacky and I don't understand why arange is correct...

Owner Author

We will get rid of this magic after switching to the ragged attention kernel for the non-causal attention parts.

Owner Author

Fixed.

Collaborator

@ZYHowell ZYHowell left a comment

LGTM. Shall we merge this and do decoding in the next PR?

@ivanium
Owner Author

ivanium commented Jul 31, 2024

Sounds good

@ivanium ivanium changed the title [WIP] Sequence parallel attention kernel Sequence parallel prefill attention kernel Jul 31, 2024
@ivanium ivanium merged commit 98c1154 into main Jul 31, 2024
@ivanium ivanium deleted the seq-parallel-kernel branch September 6, 2024 23:54
ivanium added a commit that referenced this pull request Sep 7, 2024
* test: test cases of combining multiple attention kernel calls to implement a sequence parallel kernel. Verified with 2 sp workers

* fix: simplify flashinfer kernel initialization (begin_forward() and end_forward())

* test: add logic for sp worker 1 which is basically the same but with different orders of kernel calls

* chore: format tweak

* feat: a general seq parallel attention kernel that achieves workload balance

* fix: minor tweak to loop iteration within ring attention

* feat [radix_attention]: seq_parallel kernel with sync communication.

TODO: turn communication into async fashion and overlap it with computation

* test: update test cases for seq parallel attn kernel. Need to disable kv cache management before testing because we haven't implemented kv cache management for seq parallel yet

* chore [radix_attention]: format tweak

* feat: async communication within ring attention

* fix [parallel_utils]: add missed files

* fix [infer_batch]: set default values for newly added sp-related metadata

* fix [bench_latency]: minor fixes to input args

* feat [parallel_utils]: get actual tp rank and size when both TP and SP are enabled

* feat [linear]: add QKVParallelLinear

* feat [llama2]: update llama model to use our QKVParallelLinear

* feat [model_runner]: initialize model parallel with sequence parallel

* fix [infer_batch]: 1. a minor issue when calling get_prefill_indices; 2. flashinfer initialization args

* fix [bench_latency]: load model with sp_rank

* feat [radix_attention]: automatically dispatch to seq-parallel attn kernel when sp_size > 1

* debug: stash current debug changes

* fix [radix_attention]: reshape q tensor before running the kernel

* bug fix for sp layout types

* fix: adjust tensor layout. TODO: fix many dirty hacks and hardcoded values

* fix [wip]: disable p2p communication within ring attention for now. TODO: fix the bug that causes communication hang.

* chore [bench_latency]: disable decode for now since we haven't supported it

* upstream with correct prefill sp layout

* fix early exit on decode SP

* chore: tweak format

* update layout

* bug fix

* fix [linear, radix_attention]: fix q head indexes per SP worker to align with GQA setting.

* fix [infer_batch]: set up flashinfer kernels for the batch size > 1 case

* chore: tweak format

* fix [radix_attention]: revert commented-out kv cache store operations in normal attention

* fix: adjust k, v tensor shape to align with both TP and SP setting

* chore [llama2]: minor adjustment

* fix: update bench_latency to evenly distribute each sequence across all SP workers to avoid the layout issue

* test: update test cases to align with current kernel in args

* fix [model_runner]: initialize TokenToKVPool with correct num_heads and enable KV cache store in SP attention

* chore [radix_attention]: clean up comments

* fix [model_runner]: correct num_heads in memory profiling as well to avoid OOM

* fix [infer_batch]: adopt SP KV cache allocation

* feat [linear]: correctly partition q proj along the num_heads dimension with GQA

* chore [llama2]: clean up stale variables

* feat [infer_batch]: adjust positions to SP layout when preparing input_metadata

* feat [infer_batch]: use dedicated paged attn kernel for cross-SP-shard attn

* feat [parallel_state]: create sequence parallel comm groups

* test [sp_comm_group]: simple test case with sp_size = 2

* doc [parallel_state]: doc string for our SP group organization

* fix [infer_batch]: add padding zeros to positions tensor and out_cache_loc to fix positional encoding and KV cache store

* feat [radix_attn, infer_batch]: create masks for padded sequences; attn now works for unevenly-distributed sequences too

* chore [bench_latency]: revert original prompts

* fix [parallel_state]: rename "actual" to "kv"

* refactor [radix_attention]: unified two cases with different comm-comp tradeoffs

* chore: rename "actual_tp_[size|rank]" to "kv_tp_[size|rank]"

* fix [infer_batch]: ensure prefix_lens is not None in init_flashinfer_args

* fix [infer_batch]: only pad positions and out_cache_loc for prefill

* chore [linear]: clean up and revise comments

* chore [parallel_state]: revise comments

* chore [linear]: revise comments and class names

* chore [radix_attention]: add defensive checks

---------

Co-authored-by: ZYHowell <yhzhuang@cmu.edu>