2 changes: 1 addition & 1 deletion ggml/src/ggml-cpu/ops.cpp
@@ -8135,7 +8135,7 @@ static void ggml_compute_forward_flash_attn_ext_f16(
}

// V /= S
-const float S_inv = 1.0f/S;
+const float S_inv = S == 0.0f ? 0.0f : 1.0f/S;
ggml_vec_scale_f32(DV, VKQ32, S_inv);

// dst indices
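The ops.cpp change guards the final softmax normalization against a fully masked row: when every score in a row is masked out, the accumulated denominator S is 0, so the unguarded 1.0f/S produces +inf and the scaled outputs become inf/NaN. A minimal standalone sketch of the failure mode and the guard (hypothetical values, not the ggml code):

#include <cstdio>

// Normalize an accumulator row by the softmax denominator S. For a fully
// masked row, S accumulates to 0; the guarded reciprocal keeps the row at 0
// instead of letting 1.0f/0.0f (+inf) poison the output with inf/NaN.
static void scale_row(float * v, int n, float S) {
    const float S_inv = S == 0.0f ? 0.0f : 1.0f/S;
    for (int i = 0; i < n; ++i) {
        v[i] *= S_inv;
    }
}

int main() {
    float row[4] = {1.0f, 2.0f, 3.0f, 4.0f};
    scale_row(row, 4, 0.0f);                                  // hypothetical fully-masked row
    printf("%f %f %f %f\n", row[0], row[1], row[2], row[3]);  // all 0.000000
    return 0;
}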
48 changes: 47 additions & 1 deletion ggml/src/ggml-metal/ggml-metal-device.cpp
@@ -953,7 +953,53 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext_pad(
//ggml_metal_cv_set_int32(cv, ns20, FC_FLASH_ATTN_EXT_PAD + 21);
//ggml_metal_cv_set_int32(cv, nsg, FC_FLASH_ATTN_EXT_PAD + 22);
//ggml_metal_cv_set_int32(cv, nwg, FC_FLASH_ATTN_EXT_PAD + 23);
-ggml_metal_cv_set_int32(cv, ncpsg, FC_FLASH_ATTN_EXT_PAD + 24);
+//ggml_metal_cv_set_int32(cv, nqptg, FC_FLASH_ATTN_EXT_PAD + 24);
+ggml_metal_cv_set_int32(cv, ncpsg, FC_FLASH_ATTN_EXT_PAD + 25);

res = ggml_metal_library_compile_pipeline(lib, base, name, cv);

ggml_metal_cv_free(cv);

return res;
}

ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext_blk(
ggml_metal_library_t lib,
const struct ggml_tensor * op,
int32_t nqptg,
int32_t ncpsg) {
assert(op->op == GGML_OP_FLASH_ATTN_EXT);
GGML_UNUSED(op);

char base[256];
char name[256];

snprintf(base, 256, "kernel_%s",
"flash_attn_ext_blk");

snprintf(name, 256, "%s_nqptg=%d_ncpsg=%d",
base,
nqptg,
ncpsg);

ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name);
if (res) {
return res;
}

ggml_metal_cv_t cv = ggml_metal_cv_init();

//ggml_metal_cv_set_bool(cv, has_mask, FC_FLASH_ATTN_EXT_BLK + 0);
//ggml_metal_cv_set_bool(cv, has_sinks, FC_FLASH_ATTN_EXT_BLK + 1);
//ggml_metal_cv_set_bool(cv, has_bias, FC_FLASH_ATTN_EXT_BLK + 2);
//ggml_metal_cv_set_bool(cv, has_scap, FC_FLASH_ATTN_EXT_BLK + 3);

//ggml_metal_cv_set_int32(cv, ns10, FC_FLASH_ATTN_EXT_BLK + 20);
//ggml_metal_cv_set_int32(cv, ns20, FC_FLASH_ATTN_EXT_BLK + 21);
//ggml_metal_cv_set_int32(cv, nsg, FC_FLASH_ATTN_EXT_BLK + 22);
//ggml_metal_cv_set_int32(cv, nwg, FC_FLASH_ATTN_EXT_BLK + 23);
ggml_metal_cv_set_int32(cv, nqptg, FC_FLASH_ATTN_EXT_BLK + 24);
ggml_metal_cv_set_int32(cv, ncpsg, FC_FLASH_ATTN_EXT_BLK + 25);

res = ggml_metal_library_compile_pipeline(lib, base, name, cv);

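The new getter follows the library's memoization pattern: the pipeline name encodes every specialization parameter, a cache lookup by name short-circuits recompilation, and on a miss the kernel is compiled with nqptg and ncpsg baked in as Metal function constants at FC_FLASH_ATTN_EXT_BLK + 24 and + 25. A self-contained sketch of the name-keyed caching idea (hypothetical types; the real code goes through ggml_metal_library_get_pipeline and ggml_metal_library_compile_pipeline):

#include <cstdio>
#include <string>
#include <unordered_map>

struct pipeline { std::string name; }; // stand-in for ggml_metal_pipeline_t

static std::unordered_map<std::string, pipeline> cache;

pipeline get_pipeline_blk(int nqptg, int ncpsg) {
    // the specialization parameters are part of the key, so each
    // (nqptg, ncpsg) pair is compiled exactly once
    char name[256];
    snprintf(name, sizeof(name), "kernel_flash_attn_ext_blk_nqptg=%d_ncpsg=%d", nqptg, ncpsg);

    auto it = cache.find(name);
    if (it != cache.end()) {
        return it->second; // cache hit: reuse the compiled pipeline
    }

    pipeline res = { name }; // stand-in for compiling with the function constants set
    cache.emplace(name, res);
    return res;
}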
6 changes: 6 additions & 0 deletions ggml/src/ggml-metal/ggml-metal-device.h
@@ -141,6 +141,12 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext_pad(
bool has_mask,
int32_t ncpsg);

ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext_blk(
ggml_metal_library_t lib,
const struct ggml_tensor * op,
int32_t nqptg,
int32_t ncpsg);

ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext(
ggml_metal_library_t lib,
const struct ggml_tensor * op,
29 changes: 24 additions & 5 deletions ggml/src/ggml-metal/ggml-metal-impl.h
@@ -70,11 +70,19 @@

// function constants offsets
#define FC_FLASH_ATTN_EXT_PAD        100
-#define FC_FLASH_ATTN_EXT            200
-#define FC_FLASH_ATTN_EXT_VEC        300
-#define FC_FLASH_ATTN_EXT_VEC_REDUCE 400
-#define FC_MUL_MV                    500
-#define FC_MUL_MM                    600
+#define FC_FLASH_ATTN_EXT_BLK        200
+#define FC_FLASH_ATTN_EXT            300
+#define FC_FLASH_ATTN_EXT_VEC        400
+#define FC_FLASH_ATTN_EXT_VEC_REDUCE 500
+#define FC_MUL_MV                    600
+#define FC_MUL_MM                    700

// op-specific constants
#define OP_FLASH_ATTN_EXT_NQPTG 8
#define OP_FLASH_ATTN_EXT_NCPSG 64

#define OP_FLASH_ATTN_EXT_VEC_NQPTG 1
#define OP_FLASH_ATTN_EXT_VEC_NCPSG 32

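The base offsets give each kernel family a disjoint 100-slot range of function-constant indices, so inserting FC_FLASH_ATTN_EXT_BLK at 200 shifts every later family up by 100. The new OP_FLASH_ATTN_EXT_* constants centralize the tile sizes (8 queries by 64 cache values for the matrix kernel, 1 by 32 for the vec kernel) that were previously hard-coded at each use site. A fragment illustrating the index arithmetic (relies on the defines above; the commented values are the resulting absolute indices):

// absolute function-constant index = family base + per-constant slot
enum {
    IDX_PAD_NCPSG = FC_FLASH_ATTN_EXT_PAD + 25, // = 125
    IDX_BLK_NQPTG = FC_FLASH_ATTN_EXT_BLK + 24, // = 224
    IDX_BLK_NCPSG = FC_FLASH_ATTN_EXT_BLK + 25, // = 225
};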
// kernel argument structs
//
@@ -262,6 +270,17 @@ typedef struct {
uint64_t nb33;
} ggml_metal_kargs_flash_attn_ext_pad;

typedef struct {
int32_t ne01;
int32_t ne30;
int32_t ne31;
int32_t ne32;
int32_t ne33;
uint64_t nb31;
uint64_t nb32;
uint64_t nb33;
} ggml_metal_kargs_flash_attn_ext_blk;

typedef struct {
int32_t ne01;
int32_t ne02;
113 changes: 99 additions & 14 deletions ggml/src/ggml-metal/ggml-metal-ops.cpp
@@ -1904,19 +1904,19 @@ size_t ggml_metal_op_flash_attn_ext_extra_pad(const ggml_tensor * op) {
const bool has_mask = op->src[3] != nullptr;

if (ggml_metal_op_flash_attn_ext_use_vec(op)) {
-const bool has_kvpad = ne11 % 32 != 0;
+const bool has_kvpad = ne11 % OP_FLASH_ATTN_EXT_VEC_NCPSG != 0;

if (has_kvpad) {
-res += 32*(
+res += OP_FLASH_ATTN_EXT_VEC_NCPSG*(
nb11*ne12*ne13 +
nb21*ne22*ne23 +
(has_mask ? ggml_type_size(GGML_TYPE_F16)*ne31*ne32*ne33 : 0));
}
} else {
-const bool has_kvpad = ne11 % 64 != 0;
+const bool has_kvpad = ne11 % OP_FLASH_ATTN_EXT_NCPSG != 0;

if (has_kvpad) {
-res += 64*(
+res += OP_FLASH_ATTN_EXT_NCPSG*(
nb11*ne12*ne13 +
nb21*ne22*ne23 +
(has_mask ? ggml_type_size(GGML_TYPE_F16)*ne31*ne32*ne33 : 0));
@@ -1926,6 +1926,44 @@ size_t ggml_metal_op_flash_attn_ext_extra_pad(const ggml_tensor * op) {
return res;
}

size_t ggml_metal_op_flash_attn_ext_extra_blk(const ggml_tensor * op) {
assert(op->op == GGML_OP_FLASH_ATTN_EXT);

GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
//GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
//GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
//GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
//GGML_TENSOR_LOCALS( int32_t, ne2, op->src[2], ne);
//GGML_TENSOR_LOCALS(uint64_t, nb2, op->src[2], nb);
GGML_TENSOR_LOCALS( int32_t, ne3, op->src[3], ne);
GGML_TENSOR_LOCALS(uint64_t, nb3, op->src[3], nb);

size_t res = 0;

const bool has_mask = op->src[3] != nullptr;

if (!has_mask) {
return res;
}

const bool is_vec = ggml_metal_op_flash_attn_ext_use_vec(op);

// this optimization is not useful for the vector kernels
if (is_vec) {
return res;
}

const int nqptg = is_vec ? OP_FLASH_ATTN_EXT_VEC_NQPTG : OP_FLASH_ATTN_EXT_NQPTG;
const int ncpsg = is_vec ? OP_FLASH_ATTN_EXT_VEC_NCPSG : OP_FLASH_ATTN_EXT_NCPSG;

const int64_t ne1 = (ne01 + nqptg - 1)/nqptg;
const int64_t ne0 = (ne30 + ncpsg - 1)/ncpsg;

res += GGML_PAD(ggml_type_size(GGML_TYPE_I8)*ne0*ne1*ne32*ne33, 32);

return res;
}
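The scratch region sized here holds one int8 flag per (nqptg x ncpsg) tile of the mask, across the ne32*ne33 mask slices, padded to a 32-byte boundary; per the comment above, the vec kernels skip it because the optimization is not useful there. A worked example of the sizing under assumed dimensions (hypothetical, for illustration only):

#include <cstddef>
#include <cstdint>
#include <cstdio>

// like GGML_PAD(n, 32): round n up to a multiple of 32
static size_t pad32(size_t n) { return (n + 31) & ~(size_t)31; }

int main() {
    const int ne01 = 100, ne30 = 1000, ne32 = 1, ne33 = 1; // assumed mask dims
    const int nqptg = 8, ncpsg = 64;                       // non-vec tile sizes

    const int64_t ne1 = (ne01 + nqptg - 1)/nqptg; // 13 row blocks
    const int64_t ne0 = (ne30 + ncpsg - 1)/ncpsg; // 16 column blocks

    // 1 byte per block -> pad32(16*13) = pad32(208) = 224 bytes of scratch
    printf("%zu\n", pad32((size_t)(ne0*ne1*ne32*ne33)));
    return 0;
}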

size_t ggml_metal_op_flash_attn_ext_extra_tmp(const ggml_tensor * op) {
assert(op->op == GGML_OP_FLASH_ATTN_EXT);

@@ -2020,18 +2058,23 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) {
ggml_metal_buffer_id bid_pad = bid_dst;
bid_pad.offs += ggml_nbytes(op);

-ggml_metal_buffer_id bid_tmp = bid_pad;
-bid_tmp.offs += ggml_metal_op_flash_attn_ext_extra_pad(op);
+ggml_metal_buffer_id bid_blk = bid_pad;
+bid_blk.offs += ggml_metal_op_flash_attn_ext_extra_pad(op);
+
+ggml_metal_buffer_id bid_tmp = bid_blk;
+bid_tmp.offs += ggml_metal_op_flash_attn_ext_extra_blk(op);

if (!ggml_metal_op_flash_attn_ext_use_vec(op)) {
// half8x8 kernel
-const int64_t nqptg = 8; // queries per threadgroup !! sync with kernel template arguments !!
-const int64_t ncpsg = 64; // cache values per simdgroup !! sync with kernel template arguments !!
+const int nqptg = OP_FLASH_ATTN_EXT_NQPTG; // queries per threadgroup
+const int ncpsg = OP_FLASH_ATTN_EXT_NCPSG; // cache values per simdgroup

GGML_ASSERT(nqptg <= 32);
GGML_ASSERT(nqptg % 8 == 0);
GGML_ASSERT(ncpsg % 32 == 0);

bool need_sync = false;

const bool has_kvpad = ne11 % ncpsg != 0;

if (has_kvpad) {
@@ -2069,11 +2112,46 @@

ggml_metal_encoder_dispatch_threadgroups(enc, ncpsg, std::max(ne12, ne32), std::max(ne13, ne33), 32, 1, 1);

-ggml_metal_op_concurrency_reset(ctx);
+need_sync = true;
} else {
assert(ggml_metal_op_flash_attn_ext_extra_pad(op) == 0);
}

if (has_mask) {
assert(ggml_metal_op_flash_attn_ext_extra_blk(op) != 0);

ggml_metal_kargs_flash_attn_ext_blk args0 = {
/*.ne01 =*/ ne01,
/*.ne30 =*/ ne30,
/*.ne31 =*/ ne31,
/*.ne32 =*/ ne32,
/*.ne33 =*/ ne33,
/*.nb31 =*/ nb31,
/*.nb32 =*/ nb32,
/*.nb33 =*/ nb33,
};

ggml_metal_pipeline_t pipeline0 = ggml_metal_library_get_pipeline_flash_attn_ext_blk(lib, op, nqptg, ncpsg);

ggml_metal_encoder_set_pipeline(enc, pipeline0);
ggml_metal_encoder_set_bytes (enc, &args0, sizeof(args0), 0);
ggml_metal_encoder_set_buffer (enc, bid_src3, 1);
ggml_metal_encoder_set_buffer (enc, bid_blk, 2);

const int32_t nblk1 = ((ne01 + nqptg - 1)/nqptg);
const int32_t nblk0 = ((ne30 + ncpsg - 1)/ncpsg);

ggml_metal_encoder_dispatch_threadgroups(enc, nblk0, nblk1, ne32*ne33, 32, 1, 1);

need_sync = true;
} else {
assert(ggml_metal_op_flash_attn_ext_extra_blk(op) == 0);
}

if (need_sync) {
ggml_metal_op_concurrency_reset(ctx);
}

const int is_q = ggml_is_quantized(op->src[1]->type) ? 1 : 0;

// 2*(2*ncpsg)
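Seen as a whole, the non-vec path now runs up to two prologue kernels before the main attention kernel: the pad pass when the KV length ne11 is not a multiple of ncpsg, and the new blk pass whenever a mask is present, which classifies mask tiles so the main kernel can cheaply identify fully masked blocks (the kernel side is not shown in this diff). The need_sync flag defers the concurrency barrier so it is issued once, and only if a prologue actually ran. A condensed sketch of the control flow (helper names are hypothetical stand-ins for the encoder calls above):

void dispatch_flash_attn_ext(int ne11, int ncpsg, bool has_mask) {
    bool need_sync = false;

    if (ne11 % ncpsg != 0) {
        // pad pass: fill the K/V/mask tail in the bid_pad scratch region
        // dispatch_pad_kernel(...);
        need_sync = true;
    }

    if (has_mask) {
        // blk pass: write one int8 flag per (nqptg x ncpsg) mask tile to bid_blk
        // dispatch_blk_kernel(...);
        need_sync = true;
    }

    if (need_sync) {
        // single barrier covering whichever prologue kernels ran
        // ggml_metal_op_concurrency_reset(ctx);
    }

    // main kernel: consumes bid_pad (buffer 6) and bid_blk (buffer 7)
    // dispatch_attention_kernel(...);
}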
@@ -2150,22 +2228,25 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) {
ggml_metal_encoder_set_buffer (enc, bid_src3, 4);
ggml_metal_encoder_set_buffer (enc, bid_src4, 5);
ggml_metal_encoder_set_buffer (enc, bid_pad, 6);
-ggml_metal_encoder_set_buffer (enc, bid_dst, 7);
+ggml_metal_encoder_set_buffer (enc, bid_blk, 7);
+ggml_metal_encoder_set_buffer (enc, bid_dst, 8);

ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);

ggml_metal_encoder_dispatch_threadgroups(enc, (ne01 + nqptg - 1)/nqptg, ne02, ne03, 32, nsg, 1);
#undef FATTN_SMEM
} else {
// half4x4 kernel
-const int64_t nqptg = 1; // queries per threadgroup !! sync with kernel template arguments !!
-const int64_t ncpsg = 32; // cache values per simdgroup !! sync with kernel template arguments !!
-const int64_t nkpsg = 1*ncpsg;
+const int nqptg = OP_FLASH_ATTN_EXT_VEC_NQPTG; // queries per threadgroup
+const int ncpsg = OP_FLASH_ATTN_EXT_VEC_NCPSG; // cache values per simdgroup !! sync with kernel template arguments !!
+const int nkpsg = 1*ncpsg;

GGML_ASSERT(nqptg <= 32);
GGML_ASSERT(nqptg % 1 == 0);
GGML_ASSERT(ncpsg % 32 == 0);

bool need_sync = false;

const bool has_kvpad = ne11 % ncpsg != 0;

if (has_kvpad) {
@@ -2203,11 +2284,15 @@

ggml_metal_encoder_dispatch_threadgroups(enc, ncpsg, std::max(ne12, ne32), std::max(ne13, ne33), 32, 1, 1);

-ggml_metal_op_concurrency_reset(ctx);
+need_sync = true;
} else {
assert(ggml_metal_op_flash_attn_ext_extra_pad(op) == 0);
}

if (need_sync) {
ggml_metal_op_concurrency_reset(ctx);
}

// ne00 + 2*ncpsg*(nsg)
// for each query, we load it as f16 in shared memory (ne00)
// and store the soft_max values and the mask
1 change: 1 addition & 0 deletions ggml/src/ggml-metal/ggml-metal-ops.h
@@ -40,6 +40,7 @@ size_t ggml_metal_op_mul_mat_id_extra_ids(const struct ggml_tensor * op);
bool ggml_metal_op_flash_attn_ext_use_vec(const struct ggml_tensor * op);

size_t ggml_metal_op_flash_attn_ext_extra_pad(const struct ggml_tensor * op);
size_t ggml_metal_op_flash_attn_ext_extra_blk(const struct ggml_tensor * op);
size_t ggml_metal_op_flash_attn_ext_extra_tmp(const struct ggml_tensor * op);

int ggml_metal_op_concat (ggml_metal_op_t ctx, int idx);
1 change: 1 addition & 0 deletions ggml/src/ggml-metal/ggml-metal.cpp
@@ -194,6 +194,7 @@ static size_t ggml_backend_metal_buffer_type_get_alloc_size(ggml_backend_buffer_
case GGML_OP_FLASH_ATTN_EXT:
{
res += ggml_metal_op_flash_attn_ext_extra_pad(tensor);
res += ggml_metal_op_flash_attn_ext_extra_blk(tensor);
res += ggml_metal_op_flash_attn_ext_extra_tmp(tensor);
} break;
default:
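Because get_alloc_size adds the pad, blk, and tmp sizes on top of the tensor's own bytes, the FLASH_ATTN_EXT allocation is laid out as [ dst data | pad scratch | blk scratch | tmp scratch ], and the encoder recovers each region as a running sum of the same sizes (bid_pad, bid_blk, bid_tmp in ggml-metal-ops.cpp above), so the two sides must stay in sync. A sketch of that offset chain (hypothetical struct; the real sizes come from ggml_nbytes and the extra_* helpers):

#include <cstddef>

// sizes as computed by ggml_nbytes(op) and the extra_* helpers
struct extra_sizes { size_t nbytes, pad, blk; };

// offsets of the scratch regions inside the over-sized allocation
void layout_offsets(const extra_sizes & s, size_t & pad, size_t & blk, size_t & tmp) {
    pad = s.nbytes;    // bid_pad: right after the tensor's own data
    blk = pad + s.pad; // bid_blk: after the padded K/V/mask region
    tmp = blk + s.blk; // bid_tmp: after the per-tile mask block map
}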