diff --git a/examples/06_bmg_flash_attention/xe_fmha_fwd_runner.hpp b/examples/06_bmg_flash_attention/xe_fmha_fwd_runner.hpp
index 6ce6e9e95b..ef5c4c7451 100644
--- a/examples/06_bmg_flash_attention/xe_fmha_fwd_runner.hpp
+++ b/examples/06_bmg_flash_attention/xe_fmha_fwd_runner.hpp
@@ -61,13 +61,14 @@ struct Options {
   bool is_causal;
   bool varlen = false;
   std::string scheduler;
+  std::string layout;
 
   int batch, num_heads_q, num_heads_kv, seq_len_qo, seq_len_kv, head_size_qk, head_size_vo, iterations;
   float softmax_scale;
 
   Options()
       : help(false), error(false), is_causal(false), varlen(false), batch(32), num_heads_q(16),
         num_heads_kv(16), seq_len_qo(512), head_size_qk(128),
-        seq_len_kv(512), head_size_vo(128), iterations(100), softmax_scale(1.f), scheduler("Individual") {}
+        seq_len_kv(512), head_size_vo(128), iterations(100), softmax_scale(1.f), scheduler("Individual"), layout("NHD") {}
 
   // Parses the command line
   void parse(int argc, char const **args) {
@@ -87,6 +88,12 @@ struct Options {
     }
 
     cmd.get_cmd_line_argument("scheduler", scheduler, std::string("Individual"));
+    cmd.get_cmd_line_argument("layout", layout, std::string("NHD"));
+    if (layout != "NHD" && layout != "HND") {
+      std::cerr << "Invalid --layout option: " << layout
+                << ". Supported values are NHD and HND." << std::endl;
+      return;
+    }
 
     cmd.get_cmd_line_argument("batch", batch, 32);
     cmd.get_cmd_line_argument("num_heads_q", num_heads_q, 16);
@@ -113,6 +120,7 @@ struct Options {
         << "  --is_causal                 Apply Causal Mask to the output of the first Matmul\n"
         << "  --varlen                    Enable variable sequence length\n"
         << "  --scheduler=\"Value\"         Choose between Individual or Persistent Scheduler\n"
+        << "  --layout=\"Value\"            NHD: (seq_len, num_heads, head_dim); HND: (num_heads, seq_len, head_dim); default is NHD\n"
         << "  --batch=<int>               Sets the Batch Size of the Multi-Head Self Attention module\n"
         << "  --num_heads_q=<int>         Sets the Number of Attention Heads for the Query input in the Multi-Head Self Attention module\n"
         << "  --num_heads_kv=<int>        Sets the Number of Attention Heads for the Key-Value pair in the Multi-Head Self Attention module\n"
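// Editor's note: a minimal standalone sketch (not part of this patch) of how the two
// supported layouts map a logical (batch, head, seq, dim) coordinate to a flat element
// offset. All names here are illustrative only.
#include <cstdint>

enum class Layout { NHD, HND };

// Flat offset of element (b, h, s, d) in a tensor with num_heads heads,
// seq_len sequence positions, and head_dim channels per head.
inline std::int64_t flat_offset(Layout layout, std::int64_t b, std::int64_t h,
                                std::int64_t s, std::int64_t d,
                                std::int64_t num_heads, std::int64_t seq_len,
                                std::int64_t head_dim) {
  std::int64_t batch_stride = num_heads * seq_len * head_dim;
  if (layout == Layout::NHD)  // [batch, seq_len, num_heads, head_dim]: heads interleaved per row
    return b * batch_stride + s * num_heads * head_dim + h * head_dim + d;
  else                        // [batch, num_heads, seq_len, head_dim]: each head contiguous
    return b * batch_stride + h * seq_len * head_dim + s * head_dim + d;
}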
@@ -197,7 +205,7 @@ template <class FMHAKernel, bool isVarLen> struct ExampleRunner {
   //
   // Methods
   //
 
-  bool verify(ProblemShapeType shape, bool is_causal) {
+  bool verify(ProblemShapeType shape, const Options &options) {
 
     auto batch = shape.batch;
     auto num_heads_q = shape.num_heads_q;
@@ -210,6 +218,7 @@ template <class FMHAKernel, bool isVarLen> struct ExampleRunner {
     auto block_Q_ = in_memory(block_Q);
     auto block_K_ = in_memory(block_K);
     auto block_V_ = in_memory(block_V);
+    std::vector<ElementO> host_O(block_ref_O.size());
 
     using ElementV_ = std::remove_pointer_t<decltype(block_V_.get())>;
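// Editor's note: a minimal sketch (not part of this patch) of the grouped-query
// mapping that the verification loops below rely on: with
// q_group_size = num_heads_q / num_heads_kv, consecutive query heads share one KV head,
// which is why offset_k/offset_v only advance once per q_group.

// KV head serving a given query head under grouped-query attention.
inline int kv_head_for(int q_head, int q_group_size) {
  return q_head / q_group_size;
}

// e.g. num_heads_q = 16, num_heads_kv = 4 -> q_group_size = 4:
// query heads 0..3 read KV head 0, heads 4..7 read KV head 1, and so on.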
@@ -220,143 +229,158 @@ template <class FMHAKernel, bool isVarLen> struct ExampleRunner {
 
     // loop over the batch dimension to compute the output
     // to avoid the risk of running out of device memory
-    int q_group_size = num_heads_q/num_heads_kv;
+    int q_group_size = num_heads_q / num_heads_kv;
     for (int b = 0; b < batch; b++) {
-      int kv_group_update=1;
-      for (int h = 0; h < num_heads_q; h++) {
-        cutlass::DeviceAllocation<ElementS> block_S;
-        block_S.reset(seq_len_qo * seq_len_kv);
-
-        cutlass::TensorRef ref_Q(block_Q_.get() + offset_q, LayoutQ::packed({seq_len_qo, head_size_qk}));
-        cutlass::TensorRef ref_K(block_K_.get() + offset_k, LayoutK::packed({head_size_qk, seq_len_kv}));
-        cutlass::TensorRef ref_V(block_V_.get() + offset_v, LayoutV::packed({seq_len_kv, head_size_vo}));
-        cutlass::TensorRef ref_S(block_S.get(), LayoutQ::packed({seq_len_qo, seq_len_kv}));
-
-        cutlass::reference::device::GemmComplex({seq_len_qo, seq_len_kv, head_size_qk}, 1.f, ref_Q,
-                                                cutlass::ComplexTransform::kNone, ref_K, cutlass::ComplexTransform::kNone,
-                                                0.f, ref_S, ref_S, ElementS(0),
-                                                1,                         // batch_count
-                                                seq_len_qo * head_size_qk, // batch_stride_Q
-                                                seq_len_kv * head_size_qk, // batch_stride_K
-                                                seq_len_qo * seq_len_kv,   // batch_stride_S
-                                                seq_len_qo * seq_len_kv    // batch_stride_S
-        );
-
-        compat::wait();
-
-        std::vector<ElementS> host_S(block_S.size());
-        compat::memcpy(host_S.data(), block_S.get(), host_S.size());
-
-        // delete this memory as it is no longer needed
-        block_S.reset();
-        auto offset = cute::min(seq_len_qo, seq_len_kv);
-        auto discard_seq_coord = seq_len_qo - offset;
-        auto full_tile_offset = seq_len_kv - offset;
-        if (is_causal) {
-          // apply mask to S
-          for (int row = 0; row < seq_len_qo; row++) {
-            for (int col = 0; col < seq_len_kv; col++) {
-              if ((col - full_tile_offset) > (row - discard_seq_coord))
-                host_S[col + row * seq_len_kv] = ElementS{-INFINITY};
-            }
-          }
-        }
-
-        // compute max element per row of S
-        std::vector<ElementS> max_vec(seq_len_qo, ElementS{-INFINITY});
-        for (int row = 0; row < seq_len_qo; row++) {
-          int idx = row * seq_len_kv;
-          int max_idx = row;
-          max_vec[max_idx] = host_S[idx++];
-          for (int col = 1; col < seq_len_kv; col++, idx++) {
-            if (max_vec[max_idx] < host_S[idx])
-              max_vec[max_idx] = host_S[idx];
-          }
-        }
-
-        // compute exp of S
-        for (int row = 0; row < seq_len_qo; row++) {
-          int idx = row * seq_len_kv;
-          int max_idx = row;
-          for (int col = 0; col < seq_len_kv; col++, idx++) {
-            /* FIXME: use softmax_scale instead of assuming its value here */
-            host_S[idx] = expf((host_S[idx] - max_vec[max_idx]) / sqrt(static_cast<ElementS>(head_size_qk)));
-          }
-        }
-
-        // compute sum per row of S
-        std::vector<ElementS> sum_vec(seq_len_qo, ElementS{0});
-        for (int row = 0; row < seq_len_qo; row++) {
-          int idx = row * seq_len_kv;
-          int sum_idx = row;
-          for (int col = 0; col < seq_len_kv; col++, idx++) {
-            sum_vec[sum_idx] += host_S[idx];
-          }
-
-          // scale each row with the sum to compute softmax
-          idx = row * seq_len_kv;
-          sum_idx = row;
-          for (int col = 0; col < seq_len_kv; col++, idx++) {
-            if (is_causal && row < discard_seq_coord) {
-              host_S[idx] = 0;
-            } else {
-              host_S[idx] /= sum_vec[sum_idx];
-            }
-          }
-        }
-
-        std::vector<ElementV_> host_P(host_S.size());
-        for (int p = 0; p < host_P.size(); p++)
-          host_P[p] = static_cast<ElementV_>(host_S[p]);
-
-        cutlass::DeviceAllocation<ElementV_> block_P;
-        block_P.reset(host_P.size());
-
-        compat::memcpy(block_P.get(), host_P.data(), host_P.size());
-
-        cutlass::TensorRef ref_P(block_P.get(), LayoutQ::packed({seq_len_qo, seq_len_kv}));
-
-        cutlass::DeviceAllocation<ElementS> block_acc;
-        block_acc.reset(seq_len_qo * head_size_vo);
-        cutlass::TensorRef ref_acc(block_acc.get(), LayoutO::packed({seq_len_qo, head_size_vo}));
-
-        cutlass::reference::device::GemmComplex({seq_len_qo, head_size_vo, seq_len_kv}, ElementS{1}, ref_P,
-                                                cutlass::ComplexTransform::kNone, ref_V, cutlass::ComplexTransform::kNone,
-                                                ElementS{0}, ref_acc, ref_acc, ElementS{0},
-                                                1,                         // batch_count
-                                                seq_len_qo * seq_len_kv,   // batch_stride_P
-                                                seq_len_kv * head_size_vo, // batch_stride_V
-                                                seq_len_qo * head_size_vo, // batch_stride_O
-                                                seq_len_qo * head_size_vo  // batch_stride_O
-        );
-
-        compat::wait();
-        // delete this memory as it is no longer needed
-        block_P.reset();
-
-        std::vector<ElementS> vec_acc(block_acc.size());
-        compat::memcpy(vec_acc.data(), block_acc.get(), vec_acc.size());
-
-        // delete this memory as it is no longer needed
-        block_acc.reset();
-        std::vector<ElementO> vec_out(vec_acc.size());
-        for (int i = 0; i < vec_out.size(); i++) {
-          vec_out[i] = static_cast<ElementO>(vec_acc[i]);
-        }
-        compat::memcpy(block_ref_O.get() + offset_o, vec_out.data(), vec_out.size());
-
-        offset_q += seq_len_qo * head_size_qk;
-        if (kv_group_update % q_group_size == 0) {
-          offset_k += seq_len_kv * head_size_qk;
-          offset_v += seq_len_kv * head_size_vo;
-        }
-        kv_group_update++;
-        offset_o += seq_len_qo * head_size_vo;
-      }
-    }
+      offset_q = b * seq_len_qo * num_heads_q * head_size_qk;
+      offset_k = b * seq_len_kv * num_heads_kv * head_size_qk;
+      offset_v = b * seq_len_kv * num_heads_kv * head_size_vo;
+      offset_o = b * seq_len_qo * num_heads_q * head_size_vo;
+      for (int q_group = 0; q_group < num_heads_q / q_group_size; q_group++) {
+        for (int q_head = 0; q_head < q_group_size; q_head++) {
+          cutlass::DeviceAllocation<ElementS> block_S;
+          block_S.reset(seq_len_qo * seq_len_kv);
+          using ElementQ_ = std::remove_pointer_t<decltype(block_Q_.get())>;
+          using ElementK_ = std::remove_pointer_t<decltype(block_K_.get())>;
+          cutlass::TensorRef<ElementQ_, LayoutQ> ref_Q;
+          cutlass::TensorRef<ElementK_, LayoutK> ref_K;
+          cutlass::TensorRef<ElementV_, LayoutV> ref_V;
+          cutlass::TensorRef ref_S(block_S.get(), LayoutQ::packed({seq_len_qo, seq_len_kv}));
+          if (options.layout == "NHD") {
+            // heads are interleaved in memory, so the leading dimension spans all heads
+            ref_Q = cutlass::TensorRef(block_Q_.get() + offset_q, LayoutQ(num_heads_q * head_size_qk));
+            ref_K = cutlass::TensorRef(block_K_.get() + offset_k, LayoutK(num_heads_kv * head_size_qk));
+            ref_V = cutlass::TensorRef(block_V_.get() + offset_v, LayoutV(num_heads_kv * head_size_vo));
+          } else if (options.layout == "HND") {
+            // each head is a contiguous, packed matrix
+            ref_Q = cutlass::TensorRef(block_Q_.get() + offset_q, LayoutQ::packed({seq_len_qo, head_size_qk}));
+            ref_K = cutlass::TensorRef(block_K_.get() + offset_k, LayoutK::packed({head_size_qk, seq_len_kv}));
+            ref_V = cutlass::TensorRef(block_V_.get() + offset_v, LayoutV::packed({seq_len_kv, head_size_vo}));
+          }
+
+          cutlass::reference::device::GemmComplex({seq_len_qo, seq_len_kv, head_size_qk}, 1.f, ref_Q,
+                                                  cutlass::ComplexTransform::kNone, ref_K, cutlass::ComplexTransform::kNone,
+                                                  0.f, ref_S, ref_S, ElementS(0));
+
+          compat::wait();
+
+          std::vector<ElementS> host_S(block_S.size());
+          compat::memcpy(host_S.data(), block_S.get(), host_S.size());
+
+          // delete this memory as it is no longer needed
+          block_S.reset();
+          auto offset = cute::min(seq_len_qo, seq_len_kv);
+          auto discard_seq_coord = seq_len_qo - offset;
+          auto full_tile_offset = seq_len_kv - offset;
+          if (options.is_causal) {
+            // apply mask to S
+            for (int row = 0; row < seq_len_qo; row++) {
+              for (int col = 0; col < seq_len_kv; col++) {
+                if ((col - full_tile_offset) > (row - discard_seq_coord))
+                  host_S[col + row * seq_len_kv] = ElementS{-INFINITY};
+              }
+            }
+          }
+
+          // compute max element per row of S
+          std::vector<ElementS> max_vec(seq_len_qo, ElementS{-INFINITY});
+          for (int row = 0; row < seq_len_qo; row++) {
+            int idx = row * seq_len_kv;
+            int max_idx = row;
+            max_vec[max_idx] = host_S[idx++];
+            for (int col = 1; col < seq_len_kv; col++, idx++) {
+              if (max_vec[max_idx] < host_S[idx])
+                max_vec[max_idx] = host_S[idx];
+            }
+          }
+
+          // compute exp of S
+          for (int row = 0; row < seq_len_qo; row++) {
+            int idx = row * seq_len_kv;
+            int max_idx = row;
+            for (int col = 0; col < seq_len_kv; col++, idx++) {
+              /* FIXME: use softmax_scale instead of assuming its value here */
+              host_S[idx] = expf((host_S[idx] - max_vec[max_idx]) / sqrt(static_cast<ElementS>(head_size_qk)));
+            }
+          }
+
+          // compute sum per row of S
+          std::vector<ElementS> sum_vec(seq_len_qo, ElementS{0});
+          for (int row = 0; row < seq_len_qo; row++) {
+            int idx = row * seq_len_kv;
+            int sum_idx = row;
+            for (int col = 0; col < seq_len_kv; col++, idx++) {
+              sum_vec[sum_idx] += host_S[idx];
+            }
+
+            // scale each row with the sum to compute softmax
+            idx = row * seq_len_kv;
+            sum_idx = row;
+            for (int col = 0; col < seq_len_kv; col++, idx++) {
+              if (options.is_causal && row < discard_seq_coord) {
+                host_S[idx] = 0;
+              } else {
+                host_S[idx] /= sum_vec[sum_idx];
+              }
+            }
+          }
+
+          std::vector<ElementV_> host_P(host_S.size());
+          for (int p = 0; p < host_P.size(); p++)
+            host_P[p] = static_cast<ElementV_>(host_S[p]);
+
+          cutlass::DeviceAllocation<ElementV_> block_P;
+          block_P.reset(host_P.size());
+
+          compat::memcpy(block_P.get(), host_P.data(), host_P.size());
+
+          cutlass::TensorRef ref_P(block_P.get(), LayoutQ::packed({seq_len_qo, seq_len_kv}));
+
+          cutlass::DeviceAllocation<ElementS> block_acc;
+          block_acc.reset(seq_len_qo * head_size_vo);
+          cutlass::TensorRef ref_acc(block_acc.get(), LayoutO::packed({seq_len_qo, head_size_vo}));
+
+          cutlass::reference::device::GemmComplex({seq_len_qo, head_size_vo, seq_len_kv}, ElementS{1}, ref_P,
+                                                  cutlass::ComplexTransform::kNone, ref_V, cutlass::ComplexTransform::kNone,
+                                                  ElementS{0}, ref_acc, ref_acc, ElementS{0});
+
+          compat::wait();
+          // delete this memory as it is no longer needed
+          block_P.reset();
+
+          std::vector<ElementS> vec_acc(block_acc.size());
+          compat::memcpy(vec_acc.data(), block_acc.get(), vec_acc.size());
+
+          // delete this memory as it is no longer needed
+          block_acc.reset();
+          // scatter this head's output into host_O at the position the selected layout requires
+          for (int seq = 0; seq < seq_len_qo; seq++) {
+            for (int hvo = 0; hvo < head_size_vo; hvo++) {
+              if (options.layout == "NHD") {
+                int idx = offset_o + seq * num_heads_q * head_size_vo + (q_group * q_group_size + q_head) * head_size_vo + hvo;
+                host_O[idx] = static_cast<ElementO>(vec_acc[seq * head_size_vo + hvo]);
+              } else if (options.layout == "HND") {
+                int idx = offset_o + (q_group * q_group_size + q_head) * seq_len_qo * head_size_vo + seq * head_size_vo + hvo;
+                host_O[idx] = static_cast<ElementO>(vec_acc[seq * head_size_vo + hvo]);
+              }
+            }
+          }
+
+          if (options.layout == "NHD") {
+            offset_q += head_size_qk;
+          } else if (options.layout == "HND") {
+            offset_q += head_size_qk * seq_len_qo;
+          }
+        } // end of q_head loop
+
+        if (options.layout == "NHD") {
+          offset_k += head_size_qk;
+          offset_v += head_size_vo;
+        } else if (options.layout == "HND") {
+          offset_k += head_size_qk * seq_len_kv;
+          offset_v += head_size_vo * seq_len_kv;
+        }
+      } // end of q_group loop
+    } // end of batch loop
+    compat::memcpy(block_ref_O.get(), host_O.data(), host_O.size());
     compat::wait();
-
     // Check if output from CUTLASS kernel and reference kernel are equal or not
     bool passed = cutlass::reference::device::BlockCompareRelativelyEqual(block_ref_O.get(), block_O.get(),
                                                                           block_O.size(), ElementO{0.005}, ElementO{0.005});
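// Editor's note: a self-contained sketch (not part of this patch) of the numerically
// stable row softmax used by the reference above, with the scale passed in explicitly
// as the FIXME suggests. The patch still hard-codes scale = 1/sqrt(head_size_qk);
// with that value this function computes exactly the same result, since
// exp(s/sqrt(d) - max/sqrt(d)) == exp((s - max)/sqrt(d)).
#include <algorithm>
#include <cmath>
#include <vector>

// In-place softmax over one row of scores: softmax(scale * s_i).
// Subtracting the row max before exponentiating avoids overflow.
inline void softmax_row(std::vector<float> &row, float scale) {
  float mx = -INFINITY;
  for (float v : row) mx = std::max(mx, scale * v);
  float sum = 0.f;
  for (float &v : row) { v = std::exp(scale * v - mx); sum += v; }
  for (float &v : row) v /= sum;
}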
@@ -377,16 +401,31 @@ template <class FMHAKernel, bool isVarLen> struct ExampleRunner {
 
     // Set up strides.
     // These lines can be adjusted to support different data layouts, as needed.
-    stride_Q = cutlass::make_cute_packed_stride(StrideQ{}, cute::make_shape(seq_len_qo, head_size_qk, num_heads_q, batch));
-    stride_K = cutlass::make_cute_packed_stride(StrideK{}, cute::make_shape(seq_len_kv, head_size_qk, num_heads_kv, batch));
-    stride_V = cutlass::make_cute_packed_stride(StrideV{}, cute::make_shape(head_size_vo, seq_len_kv, num_heads_kv, batch));
-    stride_O = cutlass::make_cute_packed_stride(StrideO{}, cute::make_shape(seq_len_qo, head_size_vo, num_heads_q, batch));
-
-    block_Q.reset(static_cast<std::size_t>(batch) * num_heads_q * seq_len_qo * head_size_qk);
-    block_K.reset(static_cast<std::size_t>(batch) * num_heads_kv * seq_len_kv * head_size_qk);
-    block_V.reset(static_cast<std::size_t>(batch) * num_heads_kv * seq_len_kv * head_size_vo);
-    block_O.reset(static_cast<std::size_t>(batch) * num_heads_q * seq_len_qo * head_size_vo);
-    block_ref_O.reset(static_cast<std::size_t>(batch) * num_heads_q * seq_len_qo * head_size_vo);
+
+    // The shape order in the kernel is
+    //   auto shape_Q = make_shape(s.seq_len_qo, s.head_size_qk, s.num_heads_q, s.batch);
+    //   auto shape_K = make_shape(s.seq_len_kv, s.head_size_qk, s.num_heads_kv, s.batch);
+    //   auto shape_V = make_shape(s.head_size_vo, s.seq_len_kv, s.num_heads_kv, s.batch);
+    //   auto shape_O = make_shape(s.seq_len_qo, s.head_size_vo, s.num_heads_q, s.batch);
+    // The stride needs to match the shape.
+    if (options.layout == "NHD") {
+      // {batch, seq_len, num_heads, head_dim}
+      stride_Q = cutlass::make_stride(num_heads_q * head_size_qk, Int<1>{}, head_size_qk, head_size_qk * num_heads_q * seq_len_qo);
+      stride_K = cutlass::make_stride(num_heads_kv * head_size_qk, Int<1>{}, head_size_qk, head_size_qk * num_heads_kv * seq_len_kv);
+      stride_V = cutlass::make_stride(Int<1>{}, num_heads_kv * head_size_vo, head_size_vo, head_size_vo * num_heads_kv * seq_len_kv);
+      stride_O = cutlass::make_stride(num_heads_q * head_size_vo, Int<1>{}, head_size_vo, head_size_vo * num_heads_q * seq_len_qo);
+    } else if (options.layout == "HND") {
+      // {batch, num_heads, seq_len, head_dim}
+      stride_Q = cutlass::make_stride(head_size_qk, Int<1>{}, head_size_qk * seq_len_qo, head_size_qk * num_heads_q * seq_len_qo);
+      stride_K = cutlass::make_stride(head_size_qk, Int<1>{}, head_size_qk * seq_len_kv, head_size_qk * num_heads_kv * seq_len_kv);
+      stride_V = cutlass::make_stride(Int<1>{}, head_size_vo, head_size_vo * seq_len_kv, head_size_vo * num_heads_kv * seq_len_kv);
+      stride_O = cutlass::make_stride(head_size_vo, Int<1>{}, head_size_vo * seq_len_qo, head_size_vo * num_heads_q * seq_len_qo);
+    }
+    block_Q.reset(static_cast<std::size_t>(batch) * seq_len_qo * num_heads_q * head_size_qk);
+    block_K.reset(static_cast<std::size_t>(batch) * seq_len_kv * num_heads_kv * head_size_qk);
+    block_V.reset(static_cast<std::size_t>(batch) * seq_len_kv * num_heads_kv * head_size_vo);
+    block_O.reset(static_cast<std::size_t>(batch) * seq_len_qo * num_heads_q * head_size_vo);
+    block_ref_O.reset(static_cast<std::size_t>(batch) * seq_len_qo * num_heads_q * head_size_vo);
 
     initialize_block(block_Q, seed + 2023);
     initialize_block(block_K, seed + 2022);
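// Editor's note: a small standalone check (not part of this patch) that the NHD
// stride tuple above addresses the same elements as direct row-major index
// arithmetic over [batch, seq, heads, dim]. Sizes are illustrative only.
#include <cassert>
#include <cstdint>

int main() {
  std::int64_t seq = 8, dim = 4, heads = 2, batch = 3;
  // NHD strides as in the patch: (num_heads*head_dim, 1, head_dim, head_dim*num_heads*seq_len),
  // ordered to match the kernel's (seq_len, head_dim, num_heads, batch) shape.
  std::int64_t s_seq = heads * dim, s_dim = 1, s_head = dim, s_batch = dim * heads * seq;
  for (std::int64_t b = 0; b < batch; b++)
    for (std::int64_t h = 0; h < heads; h++)
      for (std::int64_t s = 0; s < seq; s++)
        for (std::int64_t d = 0; d < dim; d++) {
          // Direct [batch, seq, heads, dim] row-major offset.
          std::int64_t expected = ((b * seq + s) * heads + h) * dim + d;
          assert(s * s_seq + d * s_dim + h * s_head + b * s_batch == expected);
        }
  return 0;
}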
@@ -447,7 +486,7 @@ template <class FMHAKernel, bool isVarLen> struct ExampleRunner {
     cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);
 
     if (!FMHAKernel::can_implement(arguments)) {
-      std::cout << "Invalid Problem Size: " << options.batch << 'x' << options.num_heads_q << 'x' <<
+      std::cerr << "Invalid Problem Size: " << options.batch << 'x' << options.num_heads_q << 'x' <<
                    options.seq_len_qo << 'x' << options.seq_len_kv << 'x' << options.head_size_qk << 'x' <<
                    options.head_size_vo << (options.is_causal ? "xCausal" : "xNonCausal") << std::endl;
       return cutlass::Status::kErrorInvalidProblem;
@@ -465,7 +504,7 @@ template <class FMHAKernel, bool isVarLen> struct ExampleRunner {
     compat::wait();
 
     // Verify that the result is correct
-    bool passed = verify(shape, options.is_causal);
+    bool passed = verify(shape, options);
     std::cout << "Disposition: " << (passed ? "Passed" : "Failed") << std::endl;
 
     if (!passed) {
"xCausal" : "xNonCausal") << std::endl; return cutlass::Status::kErrorInvalidProblem; @@ -465,7 +504,7 @@ template struct ExampleRunner { compat::wait(); // Verify that the result is correct - bool passed = verify(shape, options.is_causal); + bool passed = verify(shape, options); std::cout << "Disposition: " << (passed ? "Passed" : "Failed") << std::endl; if (!passed) { @@ -499,7 +538,7 @@ template struct ExampleRunner { std::cout << "Batch: " << options.batch << "\tNumHeads_q: " << options.num_heads_q << "\tNumHeads_kv: " << options.num_heads_kv << "\tSeq Length QO: " << options.seq_len_qo << "\tSeq Length KV: " << options.seq_len_kv << "\tHead Size QK: " << options.head_size_qk << "\tHead Size VO: " << options.head_size_vo << "\tCausal Mask: " << (options.is_causal ? "true" : "false") << "\tVariable Sequence Length: " << (options.varlen ? "true" : "false") - << "\t Scheduler: " << options.scheduler; + << "\t Scheduler: " << options.scheduler << "\tLayout: " << options.layout; printf("\nPerformance: %4.3f GB/s, %4.3f TFlop/s, %6.4f ms\n\n", gbps, tflops, cute_time * 1000); }