3 changes: 3 additions & 0 deletions tests/integration/fmha/fmha.cpp
@@ -245,6 +245,7 @@ void fmha_run_(const test_params_t& p, uint32_t iter, uint32_t warmup) {
false,
kSeqLast,
false,
false,
false>;
using accum_t = typename fmha_forward_op_t::accum_t;

@@ -346,6 +347,8 @@ void fmha_run_(const test_params_t& p, uint32_t iter, uint32_t warmup) {
kUseBias ? klen_pad32 * qlen : 0,
kUseBias ? 0 : 0, // broadcast on N (head num)
kUseBias ? klen_pad32 : 0,
nullptr,
nullptr,
softmax_scale,
0,
0,
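Note on the two hunks above: following the template parameter order in fmha_forward.hpp (kIsCausal, kSeqLast, kIsTraining, kIsDropout, kVarlen), the added trailing false is the new kVarlen flag, and the two added nullptr arguments bind to the new cu_seqlen_q / cu_seqlen_k kernel arguments, which the dense test path leaves null. A minimal sketch of what the added positional arguments mean (names are illustrative only):

#include <cstdint>

// Sketch: bindings of the arguments added at the call site above.
constexpr bool kVarlenArg = false;         // new trailing template flag
int32_t* const cu_seqlen_q_arg = nullptr;  // unused when kVarlen == false
int32_t* const cu_seqlen_k_arg = nullptr;  // unused when kVarlen == false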
142 changes: 113 additions & 29 deletions tests/integration/fmha/fmha_forward.hpp
@@ -6,12 +6,14 @@ Fused Multi-Head Attention Forward
This is an implementation of the Flash Attention algorithm
(see: Dao et al., https://arxiv.org/pdf/2205.14135v2.pdf)
*/
#include <sys/types.h>
#include <limits>
#include "fmha_forward_policy.h"
#include "fmha_utils.h"

namespace gpu::xetla {
namespace fmha {

template <
typename fmha_policy,
typename scalar_t,
@@ -21,7 +23,8 @@ template <
bool kIsCausal,
bool kSeqLast,
bool kIsTraining,
bool kIsDropout>
bool kIsDropout,
bool kVarlen>
class fmha_forward_t {
public:
using accum_t = float;
Expand All @@ -47,6 +50,9 @@ class fmha_forward_t {
uint32_t bias_strideB;
uint32_t bias_strideN;
uint32_t bias_strideF;
// Sequence length info
int32_t* cu_seqlen_q;
int32_t* cu_seqlen_k;
// Softmax scale is the reciprocal square root of head size by default
accum_t sm_scale;
// Dropout scale is computed from dropout prob
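The two new pointers are cumulative (prefix-sum) sequence-length tables with B + 1 entries, so batch b's tokens occupy rows [cu_seqlen[b], cu_seqlen[b + 1]) of the packed buffers. A minimal host-side sketch of building such a table (helper name is hypothetical):

#include <cstdint>
#include <numeric>
#include <vector>

// Hedged sketch: cu[b + 1] - cu[b] recovers batch b's length, which is
// exactly how the kernel consumes cu_seqlen_q / cu_seqlen_k below.
std::vector<int32_t> make_cu_seqlen(const std::vector<int32_t>& seqlens) {
  std::vector<int32_t> cu(seqlens.size() + 1, 0);
  std::partial_sum(seqlens.begin(), seqlens.end(), cu.begin() + 1);
  return cu;  // e.g. {3, 5, 2} -> {0, 3, 8, 10}; cu.back() == total tokens
}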
@@ -77,6 +83,8 @@ class fmha_forward_t {
uint32_t bias_strideB,
uint32_t bias_strideN,
uint32_t bias_strideF,
int32_t* cu_seqlen_q,
int32_t* cu_seqlen_k,
accum_t sm_scale,
accum_t dropout_prob,
uint32_t alibi_padded_block_size,
@@ -100,6 +108,8 @@ class fmha_forward_t {
bias_strideB(bias_strideB),
bias_strideN(bias_strideN),
bias_strideF(bias_strideF),
cu_seqlen_q(cu_seqlen_q),
cu_seqlen_k(cu_seqlen_k),
sm_scale(sm_scale),
dp_prob(dropout_prob),
dp_scale(1.f / (1.f - dropout_prob)),
@@ -115,31 +125,25 @@ class fmha_forward_t {
static constexpr uint32_t accum_step = fmha_policy::accum_step;
static constexpr uint32_t stages = fmha_policy::stages;
static constexpr uint32_t sync_freq = fmha_policy::sync_freq;
static constexpr uint32_t kBr = fmha_policy::kBr;
static constexpr uint32_t kBc = fmha_policy::kBc;
static constexpr uint32_t kHm = fmha_policy::kHm;
static constexpr uint32_t kSgBr = fmha_policy::kSgBr;
static constexpr uint32_t kSgBc = fmha_policy::kSgBc;
static constexpr uint32_t kSgHm = fmha_policy::kSgHm;

using comp_attr = std::conditional_t<
std::is_same_v<scalar_t, bf16> && (arch_tag < gpu_arch::XeHpc),
group::compute_attr_t<accum_t, accum_t, accum_t>,
group::compute_attr_t<scalar_t, scalar_t, accum_t>>;
using comp_attr = group::compute_attr_t<scalar_t, scalar_t, accum_t>;
using knobs = group::perf_tuning_knob_t<accum_step, stages, sync_freq>;

// use fpu when M==1 even if xmx is available
static constexpr bool _use_xmx = arch_tag >= gpu_arch::XeHpg && kSgBr != 1;
using compute_policy_BrBc = std::conditional_t<
_use_xmx,
(arch_tag >= gpu_arch::XeHpg),
group::compute_policy_default_xmx<comp_attr, knobs, arch_tag>,
group::compute_policy_default_fpu<comp_attr, knobs, arch_tag>>;
// TODO(Yi): add k slicing?
// TODO: add k slicing
using compute_policy_BrBm = std::conditional_t<
_use_xmx,
(arch_tag >= gpu_arch::XeHpg),
group::compute_policy_default_xmx<comp_attr, knobs, arch_tag>,
group::compute_policy_default_fpu<comp_attr, knobs, arch_tag>>;
// ---------------- // Tile shape and Threads // ---------------- //
static constexpr uint32_t kBr = fmha_policy::kBr;
static constexpr uint32_t kBc = fmha_policy::kBc;
static constexpr uint32_t kHm = fmha_policy::kHm;
static constexpr uint32_t kSgBr = fmha_policy::kSgBr;
static constexpr uint32_t kSgBc = fmha_policy::kSgBc;
static constexpr uint32_t kSgHm = fmha_policy::kSgHm;

using tile_shape_BrBc = group::tile_shape_t<kBc, kBr, kSgBc, kSgBr>;
using tile_shape_BrHm = group::tile_shape_t<kHm, kBr, kSgHm, kSgBr>;
@@ -268,6 +272,21 @@ class fmha_forward_t {
args.O_ptr,
{end_x, end_y, b_stride * args.uB},
{start_acc, start_y});
} else if constexpr (kVarlen) {
int32_t start_y = args.cu_seqlen_q[batch_id] + item.get_group(1) * kBr;
uint32_t end_y = start_y + kBr;
int32_t limit_y = args.cu_seqlen_q[batch_id + 1];
end_y = end_y < limit_y ? end_y : limit_y;

int32_t start_acc = head_id * args.uH;
uint32_t end_x = start_acc + args.uH;
const uint32_t ld_qo = args.uH * args.uN;

mem_desc_Qi.init(
args.Q_ptr, {end_x, end_y, ld_qo}, {start_acc, start_y});

mem_desc_Oi.init(
args.O_ptr, {end_x, end_y, ld_qo}, {start_acc, start_y});
} else { // 2d mem: [BxF, NxH]
// startF
int32_t start_y = batch_id * args.uF + item.get_group(1) * kBr;
@@ -277,16 +296,13 @@ class fmha_forward_t {
end_y = end_y > boundary_y ? boundary_y : end_y;

int32_t start_acc = head_id * args.uH;
uint32_t end_acc = start_acc + args.uH;
const uint32_t ld_qo = args.uH * args.uN;

mem_desc_Qi.init(
args.Q_ptr,
{args.uH * args.uN, end_y, ld_qo},
{start_acc, start_y});
args.Q_ptr, {end_acc, end_y, ld_qo}, {start_acc, start_y});
mem_desc_Oi.init(
args.O_ptr,
{args.uH * args.uN, end_y, ld_qo},
{start_acc, start_y});
args.O_ptr, {end_acc, end_y, ld_qo}, {start_acc, start_y});
}

int32_t start_x_ml = item.get_group(1) * kBr + sg_idy * kSgBr;
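In the kVarlen branch above, Q and O are addressed as one packed [total_q, N * H] matrix: each work-group's kBr-row window starts at the batch's cu_seqlen_q offset and is clamped to the batch's end. The row arithmetic, restated as a standalone sketch:

#include <algorithm>
#include <cstdint>

struct row_window_t {
  int32_t start_y;  // first query row owned by this work-group
  uint32_t end_y;   // one past the last valid row, clamped to the batch
};

// Hedged sketch of the varlen Q/O row window computed in init_context above.
row_window_t varlen_q_rows(const int32_t* cu_seqlen_q, uint32_t batch_id,
                           uint32_t group_y, uint32_t kBr) {
  int32_t start_y = cu_seqlen_q[batch_id] + int32_t(group_y * kBr);
  int32_t limit_y = cu_seqlen_q[batch_id + 1];
  int32_t end_y = std::min(start_y + int32_t(kBr), limit_y);
  return {start_y, uint32_t(end_y)};
}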
@@ -331,27 +347,46 @@ class fmha_forward_t {
args.V_ptr,
{end_y, end_x, b_stride * args.uB},
{start_acc, start_x});
} else if constexpr (kVarlen) {
int32_t start_x = startT + args.cu_seqlen_k[batch_id];
uint32_t end_x = start_x + kBc;
int32_t limit_x = args.cu_seqlen_k[batch_id + 1];
end_x = end_x < limit_x ? end_x : limit_x;

int32_t start_acc = head_id * args.uNkv / args.uN * args.uH;
uint32_t end_y = start_acc + args.uH;
mem_desc_Kj_T.init(
args.K_ptr,
{end_x, end_y, args.uNkv * args.uH},
{start_x, start_acc});

mem_desc_Vj.init(
args.V_ptr,
{end_y, end_x, args.uNkv * args.uH},
{start_acc, start_x});

} else {
int32_t start_x = batch_id * args.uT + startT;
uint32_t end_x = start_x + kBc;
uint32_t boundary_x = (batch_id + 1) * args.uT;
end_x = end_x > boundary_x ? boundary_x : end_x;

int32_t start_acc = head_id_kv * args.uH;
uint32_t end_acc = start_acc + args.uH;

mem_desc_Kj_T.init(
args.K_ptr,
{end_x, args.uH * args.uNkv, args.uH * args.uNkv},
{end_x, end_acc, args.uH * args.uNkv},
{start_x, start_acc});
mem_desc_Vj.init(
args.V_ptr,
{args.uH * args.uNkv, end_x, args.uH * args.uNkv},
{end_acc, end_x, args.uH * args.uNkv},
{start_acc, start_x});
}
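In the varlen K/V branch above, K and V are packed as [total_k, uNkv * uH], and start_acc selects the KV head that query head head_id maps to under grouped-query attention. The mapping, restated as a standalone sketch:

#include <cstdint>

// Hedged sketch of the head mapping used above: with uN query heads and uNkv
// KV heads, uN / uNkv consecutive query heads share one KV head, selected by
// the integer expression (head_id * uNkv) / uN.
uint32_t kv_head_col_offset(uint32_t head_id, uint32_t uN, uint32_t uNkv,
                            uint32_t uH) {
  uint32_t kv_head = head_id * uNkv / uN;  // e.g. uN=8, uNkv=2: heads 0..3 -> 0
  return kv_head * uH;  // column offset into the packed [total_k, uNkv*uH] K/V
}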

// B, N, 1, T
// gid * T + startT
if constexpr (kUseAlibi) {
if constexpr (kUseAlibi && !kVarlen) {
int32_t batch_start = gid * args.uAT;
int32_t start_x = batch_start + startT;
uint32_t end_x = startT + kBc;
@@ -363,6 +398,15 @@ class fmha_forward_t {
args.A_ptr, {end_x, 1, args.uAT * args.uN * args.uB}, {start_x, 0});
}

// B, N (per batch) or N (shared across batches)
if constexpr (kUseAlibi && kVarlen) {
// assume uAT in varlen equals N or 0
int32_t start_x = batch_id * args.uAT + head_id;
uint32_t end_x = start_x + 1;
end_x = end_x >= args.uN ? end_x : args.uN;
mem_desc_Ai.init(args.A_ptr, {end_x, 1, 1}, {start_x, 0});
}

if constexpr (kUseBias && !kIsCausal) {
int32_t start_x = startT;
uint32_t end_x = start_x + kBc;
@@ -442,7 +486,7 @@ class fmha_forward_t {
matAccSij.reg *= args.sm_scale;

// + beta * alibi
if constexpr (kUseAlibi) {
if constexpr (kUseAlibi && !kVarlen) {
using alibi_op_t = bias_add_op_t<scalar_t, arch_tag>;
using alibi_args_t = typename alibi_op_t::arguments_t;

@@ -455,6 +499,16 @@ class fmha_forward_t {
alibi_op(matAccSij, ctx.mem_desc_Ai.coord, alibi_args);
}

if constexpr (kUseAlibi && kVarlen) {
using alibi_op_t =
bias_add_op_t<scalar_t, arch_tag, add_type::single_element>;
using alibi_args_t = typename alibi_op_t::arguments_t;

alibi_op_t alibi_op;
alibi_args_t alibi_args(ctx.mem_desc_Ai.base, ctx.mem_desc_Ai.shape);
alibi_op(matAccSij, ctx.mem_desc_Ai.coord, alibi_args);
}
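Here single_element suggests one alibi value per (batch, head) rather than a per-position row. A sketch of the assumed semantics, i.e. broadcasting that scalar over the score tile (the authoritative behavior is defined by bias_add_op_t):

#include <cstdint>

// Hedged sketch, assuming add_type::single_element broadcasts one scalar per
// (batch, head) across the whole Sij tile; the index mirrors mem_desc_Ai
// above (uAT == uN per batch, or uAT == 0 to share slopes across batches).
template <typename accum_t, int kRows, int kCols>
void add_alibi_single_element(accum_t (&scores)[kRows][kCols],
                              const accum_t* alibi, uint32_t batch_id,
                              uint32_t head_id, uint32_t uAT) {
  accum_t a = alibi[batch_id * uAT + head_id];  // one value per (batch, head)
  for (int r = 0; r < kRows; ++r)
    for (int c = 0; c < kCols; ++c)
      scores[r][c] += a;
}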

// Add attn_mask if needed
if constexpr (kUseBias && !kIsCausal) {
if (args.is_bias_add) {
@@ -533,14 +587,22 @@ class fmha_forward_t {

/// @brief apply mask to matAccSij.
inline void apply_mask(
nd_item<3>& item,
matAccSij_t& matAccSij,
arguments_t& args,
uint32_t startF,
uint32_t startT) {
using tile_mask = tile_mask_t<matAccSij_t>;

uint32_t sg_startT = startT + ctx.sg_idx * kSgBc;
uint32_t remainT = std::max(int(args.uT) - int(sg_startT), 0);
uint32_t real_T;
if constexpr (kVarlen) {
int32_t batch_id = item.get_group(0) / args.uN;
real_T = args.cu_seqlen_k[batch_id + 1] - args.cu_seqlen_k[batch_id];
} else {
real_T = args.uT;
}
uint32_t remainT = std::max(int(real_T) - int(sg_startT), 0);
if (remainT < kSgBc) {
tile_mask::padding_mask(matAccSij, remainT);
}
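With the change above, the padding mask honors the per-batch key length: under kVarlen, real_T comes from the cu_seqlen_k deltas instead of the uniform uT. A standalone sketch of the masking step:

#include <algorithm>
#include <cstdint>
#include <limits>

// Hedged sketch of apply_mask above: key columns at or beyond the batch's
// real length are driven to -inf so they vanish after softmax.
template <typename accum_t, int kRows, int kCols>
void mask_padding_cols(accum_t (&scores)[kRows][kCols], uint32_t real_T,
                       uint32_t sg_startT) {
  uint32_t remainT = std::max(int(real_T) - int(sg_startT), 0);
  for (int r = 0; r < kRows; ++r)
    for (int c = int(remainT); c < kCols; ++c)
      scores[r][c] = -std::numeric_limits<accum_t>::infinity();
}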
@@ -867,6 +929,19 @@ class fmha_forward_t {

// initialize context for flash mha loops
ctx.init_context(item, args);
uint32_t gid = item.get_group(0);
uint32_t batch_id = gid / args.uN; // get batch idx
// Early exit when the current work-group would access data beyond the
// actual seqlen in varlen fwd
if constexpr (kVarlen) {
int32_t actual_seqlen_q =
args.cu_seqlen_q[batch_id + 1] - args.cu_seqlen_q[batch_id];
int32_t seqlen_q = item.get_group(1) * kBr;

if (seqlen_q >= actual_seqlen_q) {
return;
}
}
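Since the launch grid is sized for the longest sequence in the batch, a work-group whose starting row falls past its batch's actual query length returns immediately; the key loop below applies the same test against cu_seqlen_k before each kBc tile. Both predicates, restated as a sketch:

#include <cstdint>

// Hedged sketch of the two varlen early exits: one at kernel entry (queries),
// one per key tile inside the startT loop below.
bool q_block_active(const int32_t* cu_seqlen_q, uint32_t batch_id,
                    uint32_t group_y, uint32_t kBr) {
  int32_t actual_seqlen_q = cu_seqlen_q[batch_id + 1] - cu_seqlen_q[batch_id];
  return int32_t(group_y * kBr) < actual_seqlen_q;  // false -> return early
}

bool k_tile_active(const int32_t* cu_seqlen_k, uint32_t batch_id,
                   uint32_t startT) {
  int32_t actual_seqlen = cu_seqlen_k[batch_id + 1] - cu_seqlen_k[batch_id];
  return int32_t(startT) < actual_seqlen;  // false -> break the key loop
}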
// preload Qi to local memory
preload_Qi(args);
// initialize matAccOi for accumulating the output
@@ -877,6 +952,15 @@ class fmha_forward_t {

// iterate through the keys
for (uint32_t startT = 0; startT < args.uT; startT += kBc) {
// Early exit from the key loop in varlen fwd once startT passes the
// actual seqlen.
if constexpr (kVarlen) {
int32_t actual_seqlen =
args.cu_seqlen_k[batch_id + 1] - args.cu_seqlen_k[batch_id];
if (startT >= actual_seqlen) {
break;
}
}
if constexpr (kIsCausal) {
if (startT >= endF)
break;
@@ -887,7 +971,7 @@ class fmha_forward_t {
matAccSij_t matAccSij(0);
gemm_Sij(matAccSij, args);
// apply mask
apply_mask(matAccSij, args, startF, startT);
apply_mask(item, matAccSij, args, startF, startT);
// softmax
dp_mask_tile_t mask_in;
softmax_fwd(matAccSij, matAccOi, mask_in, args);