From 4f41ee11d6a4ddb02e4644922bf0b56880ee6f55 Mon Sep 17 00:00:00 2001 From: Jeff Bolz Date: Sat, 17 May 2025 15:35:47 +0900 Subject: [PATCH 1/5] vulkan: use scalar FA rather than coopmat2 when N==1 (#13554) --- ggml/src/ggml-vulkan/ggml-vulkan.cpp | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index 0856a1122832d..fe3669b462c38 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -5872,10 +5872,17 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx vk_pipeline *pipelines; bool small_rows = N <= get_fa_num_small_rows(path); + // coopmat1 does not actually support "small rows" (it needs 16 rows). + // So use scalar instead. if (small_rows && path == FA_COOPMAT1) { path = FA_SCALAR; } + // scalar is faster than coopmat2 when N==1 + if (N == 1 && path == FA_COOPMAT2) { + path = FA_SCALAR; + } + bool f32acc = path == FA_SCALAR || dst->op_params[3] == GGML_PREC_F32; switch (path) { From 2f5a4e1e09567e09be8b84a5f976334de9539d17 Mon Sep 17 00:00:00 2001 From: Jeff Bolz Date: Sat, 17 May 2025 16:14:55 +0900 Subject: [PATCH 2/5] vulkan: move common FA code to flash_attn_base.comp (#13556) * vulkan: move common FA code to flash_attn_base.comp * vulkan: move common FA index/stride setup code to flash_attn_base.comp * build fix --- .../vulkan-shaders/flash_attn.comp | 153 +---------------- .../vulkan-shaders/flash_attn_base.comp | 162 ++++++++++++++++++ .../vulkan-shaders/flash_attn_cm1.comp | 152 +--------------- .../vulkan-shaders/flash_attn_cm2.comp | 120 +------------ 4 files changed, 170 insertions(+), 417 deletions(-) create mode 100644 ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.comp diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp index 1683557681439..ce230a8f7d910 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +++ 
b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp @@ -9,60 +9,13 @@ #extension GL_KHR_shader_subgroup_shuffle : enable #include "types.comp" +#include "flash_attn_base.comp" -layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; - -layout (constant_id = 0) const uint32_t WorkGroupSize = 128; -layout (constant_id = 1) const uint32_t Br = 1; -layout (constant_id = 2) const uint32_t Bc = 32; -layout (constant_id = 3) const uint32_t D = 32; - -layout (constant_id = 5) const uint32_t D_split = 16; const uint32_t D_per_thread = D / D_split; const uint32_t cols_per_iter = WorkGroupSize / D_split; const uint32_t cols_per_thread = Bc / cols_per_iter; -layout (push_constant) uniform parameter { - uint32_t N; - uint32_t KV; - - uint32_t ne1; - uint32_t ne2; - uint32_t ne3; - - uint32_t neq2; - uint32_t neq3; - uint32_t nek2; - uint32_t nek3; - uint32_t nev2; - uint32_t nev3; - uint32_t nem1; - - uint32_t nb01; - uint32_t nb02; - uint32_t nb03; - uint32_t nb11; - uint32_t nb12; - uint32_t nb13; - uint32_t nb21; - uint32_t nb22; - uint32_t nb23; - uint32_t nb31; - - float scale; - float max_bias; - float logit_softcap; - - uint32_t mask; - uint32_t n_head_log2; - float m0; - float m1; - - uint32_t gqa_ratio; - uint32_t split_kv; - uint32_t k_num; -} p; layout (binding = 0) readonly buffer Q {float data_q[];}; layout (binding = 0) readonly buffer QV4 {vec4 data_qv4[];}; @@ -71,39 +24,6 @@ layout (binding = 1) readonly buffer KV4 {f16vec4 data_kv4[];}; layout (binding = 2) readonly buffer V {float16_t data_v[];}; layout (binding = 2) readonly buffer VV4 {f16vec4 data_vv4[];}; layout (binding = 3) readonly buffer M {float16_t data_m[];}; -layout (binding = 4) writeonly buffer O {D_TYPE data_o[];}; - -#if defined(A_TYPE_PACKED16) -#define BINDING_IDX_K 0 -#define BINDING_IDX_V 1 -layout (binding = 1) readonly buffer KV_PACKED16 {A_TYPE_PACKED16 data_packed16[];} kv_packed[2]; -#endif - -#if defined(DATA_A_Q4_0) -#define BLOCK_BYTE_SIZE 18 - -vec4 dequantize4(uint 
ib, uint iqs, uint a_offset, uint binding_idx) { - uint vui_lo = uint(kv_packed[binding_idx].data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 0]); - uint vui_hi = uint(kv_packed[binding_idx].data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 1]); - uint shift = (iqs & 0x10) >> 2; - vui_lo >>= shift; - vui_hi >>= shift; - - return float(kv_packed[binding_idx].data_packed16[a_offset + ib].d) * (vec4(vui_lo & 0xF, (vui_lo >> 8) & 0xF, vui_hi & 0xF, (vui_hi >> 8) & 0xF) - 8.0f); -} -#endif - -#if defined(DATA_A_Q8_0) -#define BLOCK_BYTE_SIZE 34 -vec4 dequantize4(uint ib, uint iqs, uint a_offset, uint binding_idx) { - const i8vec2 v0 = unpack8(int32_t(kv_packed[binding_idx].data_packed16[a_offset + ib].qs[iqs / 2])).xy; // vec4 used due to #12147 - const i8vec2 v1 = unpack8(int32_t(kv_packed[binding_idx].data_packed16[a_offset + ib].qs[iqs / 2 + 1])).xy; - - return float(kv_packed[binding_idx].data_packed16[a_offset + ib].d) * vec4(v0.x, v0.y, v1.x, v1.y); -} -#endif - -#define CEIL_DIV(a, b) (((a) + (b) - 1) / (b)) // Store the output when doing grouped query attention. // Rows index by Q's dimension 2, and the first N rows are valid. @@ -114,27 +34,6 @@ D_TYPE perElemOpGqaStore(const in uint32_t r, const in uint32_t c, const in D_TY return elem; } -// Store column zero. This is used to save per-row m and L values for split_k. -ACC_TYPE perElemOpStoreCol0(const in uint32_t r, const in uint32_t c, const in ACC_TYPE elem, const in uint32_t o_offset, const in uint32_t iq2, const in uint32_t N) -{ - if (r < N && c == 0) { - uint32_t offset = iq2 + r; - data_o[o_offset + offset] = D_TYPE(elem); - } - return elem; -} - -// Load the slope matrix, indexed by Q's dimension 2. -ACC_TYPE perElemOpComputeSlope(const in uint32_t r, const in uint32_t c, const in ACC_TYPE elem, const in uint32_t iq2) -{ - const uint32_t h = iq2 + (r % p.gqa_ratio); - - const ACC_TYPE base = ACC_TYPE(h < p.n_head_log2 ? p.m0 : p.m1); - const int exph = int(h < p.n_head_log2 ? 
h + 1 : 2*(h - p.n_head_log2) + 1); - - return ACC_TYPE(pow(base, ACC_TYPE(exph))); -} - shared FLOAT_TYPE tmpsh[WorkGroupSize]; shared vec4 tmpshv4[WorkGroupSize]; @@ -146,58 +45,12 @@ void main() { init_iq_shmem(gl_WorkGroupSize); #endif - const uint32_t tid = gl_LocalInvocationIndex; - const uint32_t N = p.N; - const uint32_t KV = p.KV; + init_indices(); + const uint32_t tid = gl_LocalInvocationIndex; const uint32_t d_tid = gl_LocalInvocationIndex % D_split; const uint32_t col_tid = gl_LocalInvocationIndex / D_split; - uint32_t i = gl_WorkGroupID.x; - uint32_t split_k_index = 0; - - if (p.k_num > 1) { - i = 0; - split_k_index = gl_WorkGroupID.x; - } - - const uint32_t Tr = CEIL_DIV(N, Br); - - const uint32_t start_j = split_k_index * p.split_kv / Bc; - const uint32_t end_j = CEIL_DIV(min(KV, (split_k_index + 1) * p.split_kv), Bc); - - // When not using grouped query attention, all rows share the same iq2, equal to gl_WorkGroupID.y. - // When using grouped query attention, each workgroup does gqa_ratio consecutive values of iq2. - const uint32_t iq2 = gl_WorkGroupID.y * p.gqa_ratio; - const uint32_t iq3 = gl_WorkGroupID.z; - - // broadcast factors - const uint32_t rk2 = p.neq2/p.nek2; - const uint32_t rk3 = p.neq3/p.nek3; - - const uint32_t rv2 = p.neq2/p.nev2; - const uint32_t rv3 = p.neq3/p.nev3; - - // k indices - const uint32_t ik3 = iq3 / rk3; - const uint32_t ik2 = iq2 / rk2; - - // v indices - const uint32_t iv3 = iq3 / rv3; - const uint32_t iv2 = iq2 / rv2; - - // nb?1 are already divided by the type size and are in units of elements. - // When using grouped query attention, Q is indexed by iq2, so the stride - // should be nb02 (which is in bytes). - uint32_t q_stride = p.gqa_ratio > 1 ? (p.nb02 / 4) : p.nb01; - uint32_t k_stride = p.nb11; - uint32_t v_stride = p.nb21; - // When using grouped query attention, all rows use the same mask (stride 0). 
- // "p.gqa_ratio >> 16" is just a roundabout way of writing zero - // that prevents the compiler from folding the "&" through the select - // and breaking the alignment detection. - uint32_t m_stride = (p.gqa_ratio > 1) ? (p.gqa_ratio >> 16) : KV; - uint32_t q_offset = (iq2*p.nb02+iq3*p.nb03) / 4; [[unroll]] for (uint32_t idx = 0; idx < Br * D / 4; idx += gl_WorkGroupSize.x) { diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.comp b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.comp new file mode 100644 index 0000000000000..61d90e2d8ed21 --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.comp @@ -0,0 +1,162 @@ + +layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; + +layout (constant_id = 0) const uint32_t WorkGroupSize = 128; +layout (constant_id = 1) const uint32_t Br = 1; +layout (constant_id = 2) const uint32_t Bc = 32; +layout (constant_id = 3) const uint32_t D = 32; +layout (constant_id = 4) const uint32_t Clamp = 0; +layout (constant_id = 5) const uint32_t D_split = 16; + + +layout (push_constant) uniform parameter { + uint32_t N; + uint32_t KV; + + uint32_t ne1; + uint32_t ne2; + uint32_t ne3; + + uint32_t neq2; + uint32_t neq3; + uint32_t nek2; + uint32_t nek3; + uint32_t nev2; + uint32_t nev3; + uint32_t nem1; + + uint32_t nb01; + uint32_t nb02; + uint32_t nb03; + uint32_t nb11; + uint32_t nb12; + uint32_t nb13; + uint32_t nb21; + uint32_t nb22; + uint32_t nb23; + uint32_t nb31; + + float scale; + float max_bias; + float logit_softcap; + + uint32_t mask; + uint32_t n_head_log2; + float m0; + float m1; + + uint32_t gqa_ratio; + uint32_t split_kv; + uint32_t k_num; +} p; + +layout (binding = 4) writeonly buffer O {D_TYPE data_o[];}; + +#if defined(A_TYPE_PACKED16) +#define BINDING_IDX_K 0 +#define BINDING_IDX_V 1 +layout (binding = 1) readonly buffer KV_PACKED16 {A_TYPE_PACKED16 data_packed16[];} kv_packed[2]; +#endif + +#if defined(DATA_A_Q4_0) +#define BLOCK_BYTE_SIZE 18 + +vec4 dequantize4(uint 
ib, uint iqs, uint a_offset, uint binding_idx) { + uint vui_lo = uint(kv_packed[binding_idx].data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 0]); + uint vui_hi = uint(kv_packed[binding_idx].data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 1]); + uint shift = (iqs & 0x10) >> 2; + vui_lo >>= shift; + vui_hi >>= shift; + + return float(kv_packed[binding_idx].data_packed16[a_offset + ib].d) * (vec4(vui_lo & 0xF, (vui_lo >> 8) & 0xF, vui_hi & 0xF, (vui_hi >> 8) & 0xF) - 8.0f); +} +#endif + +#if defined(DATA_A_Q8_0) +#define BLOCK_BYTE_SIZE 34 +vec4 dequantize4(uint ib, uint iqs, uint a_offset, uint binding_idx) { + const i8vec2 v0 = unpack8(int32_t(kv_packed[binding_idx].data_packed16[a_offset + ib].qs[iqs / 2])).xy; // vec4 used due to #12147 + const i8vec2 v1 = unpack8(int32_t(kv_packed[binding_idx].data_packed16[a_offset + ib].qs[iqs / 2 + 1])).xy; + + return float(kv_packed[binding_idx].data_packed16[a_offset + ib].d) * vec4(v0.x, v0.y, v1.x, v1.y); +} +#endif + +#define CEIL_DIV(a, b) (((a) + (b) - 1) / (b)) + + +// Store column zero. This is used to save per-row m and L values for split_k. +ACC_TYPE perElemOpStoreCol0(const in uint32_t r, const in uint32_t c, const in ACC_TYPE elem, const in uint32_t o_offset, const in uint32_t iq2, const in uint32_t N) +{ + if (r < N && c == 0) { + uint32_t offset = iq2 + r; + data_o[o_offset + offset] = D_TYPE(elem); + } + return elem; +} + +// Load the slope matrix, indexed by Q's dimension 2. +ACC_TYPE perElemOpComputeSlope(const in uint32_t r, const in uint32_t c, const in ACC_TYPE elem, const in uint32_t iq2) +{ + const uint32_t h = iq2 + (r % p.gqa_ratio); + + const ACC_TYPE base = ACC_TYPE(h < p.n_head_log2 ? p.m0 : p.m1); + const int exph = int(h < p.n_head_log2 ? 
h + 1 : 2*(h - p.n_head_log2) + 1); + + return ACC_TYPE(pow(base, ACC_TYPE(exph))); +} + +uint32_t i, N, KV, split_k_index, Tr, start_j, end_j, + iq2, iq3, rk2, rk3, rv2, rv3, ik2, ik3, iv2, iv3, + q_stride, k_stride, v_stride, m_stride; + +void init_indices() +{ + N = p.N; + KV = p.KV; + + i = gl_WorkGroupID.x; + split_k_index = 0; + + if (p.k_num > 1) { + i = 0; + split_k_index = gl_WorkGroupID.x; + } + + Tr = CEIL_DIV(N, Br); + + start_j = split_k_index * p.split_kv / Bc; + end_j = CEIL_DIV(min(KV, (split_k_index + 1) * p.split_kv), Bc); + + // When not using grouped query attention, all rows share the same iq2, equal to gl_WorkGroupID.y. + // When using grouped query attention, each workgroup does gqa_ratio consecutive values of iq2. + iq2 = gl_WorkGroupID.y * p.gqa_ratio; + iq3 = gl_WorkGroupID.z; + + // broadcast factors + rk2 = p.neq2/p.nek2; + rk3 = p.neq3/p.nek3; + + rv2 = p.neq2/p.nev2; + rv3 = p.neq3/p.nev3; + + // k indices + ik3 = iq3 / rk3; + ik2 = iq2 / rk2; + + // v indices + iv3 = iq3 / rv3; + iv2 = iq2 / rv2; + + // nb?1 are already divided by the type size and are in units of elements. + // When using grouped query attention, Q is indexed by iq2, so the stride + // should be nb02 (which is in bytes). + q_stride = p.gqa_ratio > 1 ? (p.nb02 / 4) : p.nb01; + k_stride = p.nb11; + v_stride = p.nb21; + // When using grouped query attention, all rows use the same mask (stride 0). + // "p.gqa_ratio >> 16" is just a roundabout way of writing zero + // that prevents the compiler from folding the "&" through the select + // and breaking the alignment detection. + m_stride = (p.gqa_ratio > 1) ? 
(p.gqa_ratio >> 16) : KV; +} diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp index 8b86b623bd94d..da478be24fb6e 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp @@ -11,14 +11,7 @@ #extension GL_KHR_cooperative_matrix : enable #include "types.comp" - -layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; - -layout (constant_id = 1) const uint32_t Br = 1; -layout (constant_id = 2) const uint32_t Bc = 32; -layout (constant_id = 3) const uint32_t D = 32; - -layout (constant_id = 5) const uint32_t D_split = 16; +#include "flash_attn_base.comp" const uint32_t D_per_thread = D / D_split; const uint32_t row_split = 4; @@ -26,46 +19,6 @@ const uint32_t rows_per_thread = Br / row_split; const uint32_t cols_per_iter = gl_WorkGroupSize.x / D_split / row_split; const uint32_t cols_per_thread = Bc / cols_per_iter; -layout (push_constant) uniform parameter { - uint32_t N; - uint32_t KV; - - uint32_t ne1; - uint32_t ne2; - uint32_t ne3; - - uint32_t neq2; - uint32_t neq3; - uint32_t nek2; - uint32_t nek3; - uint32_t nev2; - uint32_t nev3; - uint32_t nem1; - - uint32_t nb01; - uint32_t nb02; - uint32_t nb03; - uint32_t nb11; - uint32_t nb12; - uint32_t nb13; - uint32_t nb21; - uint32_t nb22; - uint32_t nb23; - uint32_t nb31; - - float scale; - float max_bias; - float logit_softcap; - - uint32_t mask; - uint32_t n_head_log2; - float m0; - float m1; - - uint32_t gqa_ratio; - uint32_t split_kv; - uint32_t k_num; -} p; layout (binding = 0) readonly buffer Q {float data_q[];}; layout (binding = 0) readonly buffer QV4 {vec4 data_qv4[];}; @@ -74,39 +27,6 @@ layout (binding = 1) readonly buffer KV4 {f16vec4 data_kv4[];}; layout (binding = 2) readonly buffer V {float16_t data_v[];}; layout (binding = 2) readonly buffer VV4 {f16vec4 data_vv4[];}; layout (binding = 3) readonly buffer M {float16_t data_m[];}; -layout 
(binding = 4) writeonly buffer O {D_TYPE data_o[];}; - -#if defined(A_TYPE_PACKED16) -#define BINDING_IDX_K 0 -#define BINDING_IDX_V 1 -layout (binding = 1) readonly buffer KV_PACKED16 {A_TYPE_PACKED16 data_packed16[];} kv_packed[2]; -#endif - -#if defined(DATA_A_Q4_0) -#define BLOCK_BYTE_SIZE 18 - -vec4 dequantize4(uint ib, uint iqs, uint a_offset, uint binding_idx) { - uint vui_lo = uint(kv_packed[binding_idx].data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 0]); - uint vui_hi = uint(kv_packed[binding_idx].data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 1]); - uint shift = (iqs & 0x10) >> 2; - vui_lo >>= shift; - vui_hi >>= shift; - - return float(kv_packed[binding_idx].data_packed16[a_offset + ib].d) * (vec4(vui_lo & 0xF, (vui_lo >> 8) & 0xF, vui_hi & 0xF, (vui_hi >> 8) & 0xF) - 8.0f); -} -#endif - -#if defined(DATA_A_Q8_0) -#define BLOCK_BYTE_SIZE 34 -vec4 dequantize4(uint ib, uint iqs, uint a_offset, uint binding_idx) { - const i8vec2 v0 = unpack8(int32_t(kv_packed[binding_idx].data_packed16[a_offset + ib].qs[iqs / 2])).xy; // vec4 used due to #12147 - const i8vec2 v1 = unpack8(int32_t(kv_packed[binding_idx].data_packed16[a_offset + ib].qs[iqs / 2 + 1])).xy; - - return float(kv_packed[binding_idx].data_packed16[a_offset + ib].d) * vec4(v0.x, v0.y, v1.x, v1.y); -} -#endif - -#define CEIL_DIV(a, b) (((a) + (b) - 1) / (b)) // Store the output when doing grouped query attention. // Rows index by Q's dimension 2, and the first N rows are valid. @@ -117,27 +37,6 @@ D_TYPE perElemOpGqaStore(const in uint32_t r, const in uint32_t c, const in D_TY return elem; } -// Store column zero. This is used to save per-row m and L values for split_k. 
-ACC_TYPE perElemOpStoreCol0(const in uint32_t r, const in uint32_t c, const in ACC_TYPE elem, const in uint32_t o_offset, const in uint32_t iq2, const in uint32_t N) -{ - if (r < N && c == 0) { - uint32_t offset = iq2 + r; - data_o[o_offset + offset] = D_TYPE(elem); - } - return elem; -} - -// Load the slope matrix, indexed by Q's dimension 2. -ACC_TYPE perElemOpComputeSlope(const in uint32_t r, const in uint32_t c, const in ACC_TYPE elem, const in uint32_t iq2) -{ - const uint32_t h = iq2 + (r % p.gqa_ratio); - - const ACC_TYPE base = ACC_TYPE(h < p.n_head_log2 ? p.m0 : p.m1); - const int exph = int(h < p.n_head_log2 ? h + 1 : 2*(h - p.n_head_log2) + 1); - - return ACC_TYPE(pow(base, ACC_TYPE(exph))); -} - // These need to be supported N,M values for a MatBc x MatBr x 16 coopmatmuladd const uint32_t MatBr = 16; const uint32_t MatBc = 16; @@ -162,9 +61,9 @@ void main() { init_iq_shmem(gl_WorkGroupSize); #endif + init_indices(); + const uint32_t tid = gl_LocalInvocationIndex; - const uint32_t N = p.N; - const uint32_t KV = p.KV; const uint32_t threads_per_rowgroup = gl_WorkGroupSize.x / row_split; const uint32_t row_tid = gl_LocalInvocationIndex / threads_per_rowgroup; @@ -173,51 +72,6 @@ void main() { #define tile_row(r) (row_tid * rows_per_thread + (r)) - uint32_t i = gl_WorkGroupID.x; - uint32_t split_k_index = 0; - - if (p.k_num > 1) { - i = 0; - split_k_index = gl_WorkGroupID.x; - } - - const uint32_t Tr = CEIL_DIV(N, Br); - - const uint32_t start_j = split_k_index * p.split_kv / Bc; - const uint32_t end_j = CEIL_DIV(min(KV, (split_k_index + 1) * p.split_kv), Bc); - - // When not using grouped query attention, all rows share the same iq2, equal to gl_WorkGroupID.y. - // When using grouped query attention, each workgroup does gqa_ratio consecutive values of iq2. 
- const uint32_t iq2 = gl_WorkGroupID.y * p.gqa_ratio; - const uint32_t iq3 = gl_WorkGroupID.z; - - // broadcast factors - const uint32_t rk2 = p.neq2/p.nek2; - const uint32_t rk3 = p.neq3/p.nek3; - - const uint32_t rv2 = p.neq2/p.nev2; - const uint32_t rv3 = p.neq3/p.nev3; - - // k indices - const uint32_t ik3 = iq3 / rk3; - const uint32_t ik2 = iq2 / rk2; - - // v indices - const uint32_t iv3 = iq3 / rv3; - const uint32_t iv2 = iq2 / rv2; - - // nb?1 are already divided by the type size and are in units of elements. - // When using grouped query attention, Q is indexed by iq2, so the stride - // should be nb02 (which is in bytes). - uint32_t q_stride = p.gqa_ratio > 1 ? (p.nb02 / 4) : p.nb01; - uint32_t k_stride = p.nb11; - uint32_t v_stride = p.nb21; - // When using grouped query attention, all rows use the same mask (stride 0). - // "p.gqa_ratio >> 16" is just a roundabout way of writing zero - // that prevents the compiler from folding the "&" through the select - // and breaking the alignment detection. - uint32_t m_stride = (p.gqa_ratio > 1) ? 
(p.gqa_ratio >> 16) : KV; - uint32_t q_offset = (iq2*p.nb02+iq3*p.nb03) / 4; [[unroll]] for (uint32_t idx = 0; idx < Br * D / 4; idx += gl_WorkGroupSize.x) { diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp index b926a578aded6..6acf67a03a463 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp @@ -18,62 +18,12 @@ #include "types.comp" #include "dequant_funcs_cm2.comp" - -layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; - -layout (constant_id = 1) const uint32_t Br = 32; -layout (constant_id = 2) const uint32_t Bc = 32; -layout (constant_id = 3) const uint32_t D = 32; -layout (constant_id = 4) const uint32_t Clamp = gl_CooperativeMatrixClampModeConstantNV; - -layout (push_constant) uniform parameter { - uint32_t N; - uint32_t KV; - - uint32_t ne1; - uint32_t ne2; - uint32_t ne3; - - uint32_t neq2; - uint32_t neq3; - uint32_t nek2; - uint32_t nek3; - uint32_t nev2; - uint32_t nev3; - uint32_t nem1; - - uint32_t nb01; - uint32_t nb02; - uint32_t nb03; - uint32_t nb11; - uint32_t nb12; - uint32_t nb13; - uint32_t nb21; - uint32_t nb22; - uint32_t nb23; - uint32_t nb31; - - float scale; - float max_bias; - float logit_softcap; - - uint32_t mask; - uint32_t n_head_log2; - float m0; - float m1; - - uint32_t gqa_ratio; - uint32_t split_kv; - uint32_t k_num; -} p; +#include "flash_attn_base.comp" layout (binding = 0) readonly buffer Q {uint8_t data_q[];}; layout (binding = 1) readonly buffer K {uint8_t data_k[];}; layout (binding = 2) readonly buffer V {uint8_t data_v[];}; layout (binding = 3) readonly buffer M {uint8_t data_m[];}; -layout (binding = 4) writeonly buffer O {D_TYPE data_o[];}; - -#define CEIL_DIV(a, b) (((a) + (b) - 1) / (b)) ACC_TYPE maxReduce(const in ACC_TYPE x, const in ACC_TYPE y) { return max(x, y); @@ -118,67 +68,12 @@ D_TYPE perElemOpGqaStore(const in uint32_t r, const in 
uint32_t c, const in D_TY return elem; } -// Store column zero. This is used to save per-row m and L values for split_k. -ACC_TYPE perElemOpStoreCol0(const in uint32_t r, const in uint32_t c, const in ACC_TYPE elem, const in uint32_t o_offset, const in uint32_t iq2, const in uint32_t N) -{ - if (r < N && c == 0) { - uint32_t offset = iq2 + r; - data_o[o_offset + offset] = D_TYPE(elem); - } - return elem; -} - -// Load the slope matrix, indexed by Q's dimension 2. -ACC_TYPE perElemOpComputeSlope(const in uint32_t r, const in uint32_t c, const in ACC_TYPE elem, const in uint32_t iq2) -{ - const uint32_t h = iq2 + (r % p.gqa_ratio); - - const ACC_TYPE base = ACC_TYPE(h < p.n_head_log2 ? p.m0 : p.m1); - const int exph = int(h < p.n_head_log2 ? h + 1 : 2*(h - p.n_head_log2) + 1); - - return ACC_TYPE(pow(base, ACC_TYPE(exph))); -} - void main() { #ifdef NEEDS_INIT_IQ_SHMEM init_iq_shmem(gl_WorkGroupSize); #endif - const uint32_t N = p.N; - const uint32_t KV = p.KV; - - uint32_t i = gl_WorkGroupID.x; - uint32_t split_k_index = 0; - - if (p.k_num > 1) { - i = 0; - split_k_index = gl_WorkGroupID.x; - } - - const uint32_t Tr = CEIL_DIV(N, Br); - - const uint32_t start_j = split_k_index * p.split_kv / Bc; - const uint32_t end_j = CEIL_DIV(min(KV, (split_k_index + 1) * p.split_kv), Bc); - - // When not using grouped query attention, all rows share the same iq2, equal to gl_WorkGroupID.y. - // When using grouped query attention, each workgroup does gqa_ratio consecutive values of iq2. 
- const uint32_t iq2 = gl_WorkGroupID.y * p.gqa_ratio; - const uint32_t iq3 = gl_WorkGroupID.z; - - // broadcast factors - const uint32_t rk2 = p.neq2/p.nek2; - const uint32_t rk3 = p.neq3/p.nek3; - - const uint32_t rv2 = p.neq2/p.nev2; - const uint32_t rv3 = p.neq3/p.nev3; - - // k indices - const uint32_t ik3 = iq3 / rk3; - const uint32_t ik2 = iq2 / rk2; - - // v indices - const uint32_t iv3 = iq3 / rv3; - const uint32_t iv2 = iq2 / rv2; + init_indices(); tensorLayoutNV<2, gl_CooperativeMatrixClampModeConstantNV> tensorLayoutQ = createTensorLayoutNV(2, gl_CooperativeMatrixClampModeConstantNV); tensorLayoutNV<2, Clamp> tensorLayoutK = createTensorLayoutNV(2, Clamp); @@ -195,17 +90,6 @@ void main() { tensorLayoutK = setTensorLayoutDimensionNV(tensorLayoutK, KV, D); tensorLayoutV = setTensorLayoutDimensionNV(tensorLayoutV, KV, D); - // nb?1 are already divided by the type size and are in units of elements. - // When using grouped query attention, Q is indexed by iq2, so the stride - // should be nb02 (which is in bytes). - uint32_t q_stride = p.gqa_ratio > 1 ? (p.nb02 / 4) : p.nb01; - uint32_t k_stride = p.nb11; - uint32_t v_stride = p.nb21; - // When using grouped query attention, all rows use the same mask (stride 0). - // "p.gqa_ratio >> 16" is just a roundabout way of writing zero - // that prevents the compiler from folding the "&" through the select - // and breaking the alignment detection. - uint32_t m_stride = (p.gqa_ratio > 1) ? 
(p.gqa_ratio >> 16) : KV; // hint to the compiler that strides are aligned for the aligned variant of the shader if (Clamp != gl_CooperativeMatrixClampModeConstantNV) { From 518329b2d4434d0683c30b741b4f4cf5734b5f99 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sat, 17 May 2025 12:58:55 +0300 Subject: [PATCH 3/5] parallel : add option for non-shared and larger prompts (#13598) * parallel : add option for non-shared and larger prompts * parallel : update readme [no ci] * cont : add note about base models [no ci] * parallel : better var name ggml-ci --- common/arg.cpp | 4 +- examples/parallel/README.md | 11 ++++ examples/parallel/parallel.cpp | 100 ++++++++++++++++++++++++++++----- 3 files changed, 99 insertions(+), 16 deletions(-) diff --git a/common/arg.cpp b/common/arg.cpp index a1fd4c9651b9c..8aa72515d1042 100644 --- a/common/arg.cpp +++ b/common/arg.cpp @@ -2585,7 +2585,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex [](common_params & params, int value) { params.n_junk = value; } - ).set_examples({LLAMA_EXAMPLE_PASSKEY})); + ).set_examples({LLAMA_EXAMPLE_PASSKEY, LLAMA_EXAMPLE_PARALLEL})); add_opt(common_arg( {"--pos"}, "N", string_format("position of the passkey in the junk text (default: %d)", params.i_pos), @@ -2648,7 +2648,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex [](common_params & params) { params.is_pp_shared = true; } - ).set_examples({LLAMA_EXAMPLE_BENCH})); + ).set_examples({LLAMA_EXAMPLE_BENCH, LLAMA_EXAMPLE_PARALLEL})); add_opt(common_arg( {"-npp"}, "n0,n1,...", "number of prompt tokens", diff --git a/examples/parallel/README.md b/examples/parallel/README.md index df04567337b15..ece3a66416edd 100644 --- a/examples/parallel/README.md +++ b/examples/parallel/README.md @@ -1,3 +1,14 @@ # llama.cpp/example/parallel Simplified simulation of serving incoming requests in parallel + +## Example + +Generate 128 client requests (`-ns 128`), simulating 8 concurrent clients 
(`-np 8`). The system prompt is shared (`-pps`), meaning that it is computed once at the start. The client requests consist of 10 junk questions (`-j 10`) followed by the actual question. + +```bash +llama-parallel -m model.gguf -np 8 -ns 128 --top-k 1 -pps --junk 10 -c 16384 +``` + +> [!NOTE] +> It's recommended to use base models with this example. Instruction tuned models might not be able to properly follow the custom chat template specified here, so the results might not be as expected. diff --git a/examples/parallel/parallel.cpp b/examples/parallel/parallel.cpp index 80698518e3102..b967731a2153c 100644 --- a/examples/parallel/parallel.cpp +++ b/examples/parallel/parallel.cpp @@ -34,11 +34,61 @@ static std::string k_system = R"(Transcript of a never ending dialog, where the User interacts with an Assistant. The Assistant is helpful, kind, honest, good at writing, and never fails to answer the User's requests immediately and with precision. -User: Recommend a nice restaurant in the area. -Assistant: I recommend the restaurant "The Golden Duck". It is a 5 star restaurant with a great view of the city. The food is delicious and the service is excellent. The prices are reasonable and the portions are generous. The restaurant is located at 123 Main Street, New York, NY 10001. The phone number is (212) 555-1234. The hours are Monday through Friday from 11:00 am to 10:00 pm. The restaurant is closed on Saturdays and Sundays. -User: Who is Richard Feynman? -Assistant: Richard Feynman was an American physicist who is best known for his work in quantum mechanics and particle physics. He was awarded the Nobel Prize in Physics in 1965 for his contributions to the development of quantum electrodynamics. He was a popular lecturer and author, and he wrote several books, including "Surely You're Joking, Mr. Feynman!" and "What Do You Care What Other People Think?". -User:)"; +User: +Recommend a nice restaurant in the area. 
+Assistant: +I recommend the restaurant "The Golden Duck". It is a 5 star restaurant with a great view of the city. The food is delicious and the service is excellent. The prices are reasonable and the portions are generous. The restaurant is located at 123 Main Street, New York, NY 10001. The phone number is (212) 555-1234. The hours are Monday through Friday from 11:00 am to 10:00 pm. The restaurant is closed on Saturdays and Sundays. +User: +Who is Richard Feynman? +Assistant: +Richard Feynman was an American physicist who is best known for his work in quantum mechanics and particle physics. He was awarded the Nobel Prize in Physics in 1965 for his contributions to the development of quantum electrodynamics. He was a popular lecturer and author, and he wrote several books, including "Surely You're Joking, Mr. Feynman!" and "What Do You Care What Other People Think?". +)"; + +static std::vector<std::string> k_questions = { + "What is the tallest mountain in the world?", + "Who was the first person to win two Nobel Prizes?", + "Which country invented paper?", + "What organ is primarily responsible for pumping blood throughout the body?", + "Which planet is known for its prominent ring system?", + "Who directed the movie 'Inception'?", + "What is the freezing point of water in Fahrenheit?", + "Which animal is known to have the longest lifespan?", + "What language has the most native speakers worldwide?", + "What is the capital city of Canada?", + "Who is credited with inventing the World Wide Web?", + "Which metal is liquid at room temperature?", + "What is the term for an animal that eats both plants and meat?", + "Who painted 'The Starry Night'?", + "What gas do humans exhale that plants use for photosynthesis?", + "What year did World War II end?", + "Which continent has the most countries?", + "Who wrote the novel 'Frankenstein'?", + "What does DNA stand for?", + "What is the main ingredient in traditional Japanese miso soup?" 
+}; + +static std::vector<std::string> k_answers = { + "The tallest mountain in the world is Mount Everest.", + "Marie Curie was the first person to win two Nobel Prizes.", + "Paper was invented in China.", + "The heart is the organ responsible for pumping blood.", + "Saturn is known for its prominent ring system.", + "Christopher Nolan directed the movie 'Inception'.", + "The freezing point of water in Fahrenheit is 32°F.", + "The bowhead whale is known to have the longest lifespan among mammals.", + "Mandarin Chinese has the most native speakers in the world.", + "The capital city of Canada is Ottawa.", + "Tim Berners-Lee is credited with inventing the World Wide Web.", + "Mercury is the metal that is liquid at room temperature.", + "An animal that eats both plants and meat is called an omnivore.", + "'The Starry Night' was painted by Vincent van Gogh.", + "Humans exhale carbon dioxide, which plants use in photosynthesis.", + "World War II ended in 1945.", + "Africa is the continent with the most countries.", + "The novel 'Frankenstein' was written by Mary Shelley.", + "DNA stands for Deoxyribonucleic Acid.", + "The main ingredient in traditional Japanese miso soup is fermented soybean paste." +}; static std::vector<std::string> k_prompts = { "What is the meaning of life?", @@ -49,7 +99,7 @@ static std::vector<std::string> k_prompts = { "What is the best way to learn a new language?", "How to get a job at Google?", "If you could have any superpower, what would it be?", - "I want to learn how to play the piano.", + "I want to learn how to play the piano. 
What would be the best way to do it?", }; struct client { @@ -68,6 +118,7 @@ struct client { int64_t t_start_prompt; int64_t t_start_gen; + int32_t n_past = 0; int32_t n_prompt = 0; int32_t n_decoded = 0; int32_t i_batch = -1; @@ -107,6 +158,7 @@ int main(int argc, char ** argv) { common_params params; params.n_predict = 128; + params.n_junk = 0; if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_PARALLEL)) { return 1; @@ -128,6 +180,12 @@ int main(int argc, char ** argv) { const bool dump_kv_cache = params.dump_kv_cache; + // is the system prompt shared in the cache + const bool is_sp_shared = params.is_pp_shared; + + // extra text to insert in each client's prompt in order to make it larger + const int32_t n_junk = params.n_junk; + // init llama.cpp llama_backend_init(); llama_numa_init(params.numa); @@ -169,6 +227,7 @@ int main(int argc, char ** argv) { } std::vector tokens_system; + tokens_system = common_tokenize(ctx, k_system, true); const int32_t n_tokens_system = tokens_system.size(); @@ -190,7 +249,7 @@ int main(int argc, char ** argv) { LOG_INF("%s: n_parallel = %d, n_sequences = %d, cont_batching = %d, system tokens = %d\n", __func__, n_clients, n_seq, cont_batching, n_tokens_system); LOG_INF("\n"); - { + if (is_sp_shared) { LOG_INF("%s: Evaluating the system prompt ...\n", __func__); for (int32_t i = 0; i < n_tokens_system; ++i) { @@ -228,7 +287,7 @@ int main(int argc, char ** argv) { client.i_batch = batch.n_tokens; - common_batch_add(batch, client.sampled, n_tokens_system + client.n_prompt + client.n_decoded, { client.id + 1 }, true); + common_batch_add(batch, client.sampled, client.n_past++, { client.id + 1 }, true); client.n_decoded += 1; } @@ -254,9 +313,23 @@ int main(int argc, char ** argv) { client.t_start_gen = 0; client.input = k_prompts[rand() % k_prompts.size()]; - client.prompt = client.input + "\nAssistant:"; client.response = ""; + // construct the prompt: + // [system prompt] + [junk] + [user prompt] + client.n_past = 0; + 
client.prompt = ""; + if (is_sp_shared) { + client.n_past = n_tokens_system; + } else { + client.prompt += k_system; + } + for (int i = 0; i < n_junk; ++i) { + const int r = rand() % k_questions.size(); + client.prompt += "User:\n" + k_questions[r] + "\nAssistant:\n " + k_answers[r] + "\n"; + } + client.prompt += "User:\n" + client.input + "\nAssistant:\n"; + common_sampler_reset(client.smpl); // do not prepend BOS because we have a system prompt! @@ -264,7 +337,7 @@ int main(int argc, char ** argv) { tokens_prompt = common_tokenize(ctx, client.prompt, false); for (size_t i = 0; i < tokens_prompt.size(); ++i) { - common_batch_add(batch, tokens_prompt[i], i + n_tokens_system, { client.id + 1 }, false); + common_batch_add(batch, tokens_prompt[i], client.n_past++, { client.id + 1 }, false); } // extract the logits only for the last token @@ -363,10 +436,9 @@ int main(int argc, char ** argv) { // client.id, client.seq_id, id, client.n_decoded, client.i_batch, token_str.c_str()); if (client.n_decoded > 2 && - (llama_vocab_is_eog(vocab, id) || - (params.n_predict > 0 && client.n_decoded + client.n_prompt >= params.n_predict) || - client.response.find("User:") != std::string::npos || - client.response.find('\n') != std::string::npos)) { + (llama_vocab_is_eog(vocab, id) || + (params.n_predict > 0 && client.n_decoded >= params.n_predict) || + client.response.find("User:") != std::string::npos)) { // basic reverse prompt const size_t pos = client.response.find("User:"); if (pos != std::string::npos) { From e3a7cf6c5bf6a0a24217f88607b06e4405a2b5d9 Mon Sep 17 00:00:00 2001 From: "Gilad S." 
<7817232+giladgd@users.noreply.github.com> Date: Sat, 17 May 2025 21:26:43 +0300 Subject: [PATCH 4/5] cmake: use the current build config for vulkan-shaders-gen (#13595) * fix: use the current build config for `vulkan-shaders-gen` * fix: only pass a valid build type to `--config` --- ggml/src/ggml-vulkan/CMakeLists.txt | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/ggml/src/ggml-vulkan/CMakeLists.txt b/ggml/src/ggml-vulkan/CMakeLists.txt index 16e10a9f399d5..662f137710716 100644 --- a/ggml/src/ggml-vulkan/CMakeLists.txt +++ b/ggml/src/ggml-vulkan/CMakeLists.txt @@ -54,6 +54,11 @@ if (Vulkan_FOUND) -DCMAKE_RUNTIME_OUTPUT_DIRECTORY=${CMAKE_RUNTIME_OUTPUT_DIRECTORY} ) + set(VULKAN_SHADER_GEN_CMAKE_BUILD_ARGS "") + if (CMAKE_BUILD_TYPE AND CMAKE_BUILD_TYPE MATCHES "Debug|Release|MinSizeRel|RelWithDebInfo") + list(APPEND VULKAN_SHADER_GEN_CMAKE_BUILD_ARGS --config=${CMAKE_BUILD_TYPE}) + endif() + # Test all shader extensions test_shader_extension_support( "GL_KHR_cooperative_matrix" @@ -149,7 +154,7 @@ if (Vulkan_FOUND) vulkan-shaders-gen SOURCE_DIR ${CMAKE_CURRENT_SOURCE_DIR}/vulkan-shaders CMAKE_ARGS ${VULKAN_SHADER_GEN_CMAKE_ARGS} - BUILD_COMMAND ${CMAKE_COMMAND} --build . + BUILD_COMMAND ${CMAKE_COMMAND} --build . ${VULKAN_SHADER_GEN_CMAKE_BUILD_ARGS} INSTALL_COMMAND ${CMAKE_COMMAND} --install . 
INSTALL_DIR ${CMAKE_BINARY_DIR} ) From 6a2bc8bfb7cd502e5ebc72e36c97a6f848c21c2c Mon Sep 17 00:00:00 2001 From: Isaac McFadyen Date: Sat, 17 May 2025 17:59:48 -0400 Subject: [PATCH 5/5] server : added --no-prefill-assistant flag (#13608) * added no-prefill-assistant flag * reworded documentation comment * updated server README.md --- common/arg.cpp | 10 ++++++++++ common/common.h | 1 + tools/server/README.md | 2 ++ tools/server/server.cpp | 2 ++ tools/server/utils.hpp | 3 ++- 5 files changed, 17 insertions(+), 1 deletion(-) diff --git a/common/arg.cpp b/common/arg.cpp index 8aa72515d1042..305168043c27c 100644 --- a/common/arg.cpp +++ b/common/arg.cpp @@ -2880,6 +2880,16 @@ common_params_context common_params_parser_init(common_params & params, llama_ex params.chat_template = read_file(value); } ).set_examples({LLAMA_EXAMPLE_MAIN, LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_CHAT_TEMPLATE_FILE")); + add_opt(common_arg( + {"--no-prefill-assistant"}, + string_format( + "whether to prefill the assistant's response if the last message is an assistant message (default: prefill enabled)\n" + "when this flag is set, if the last message is an assistant message then it will be treated as a full message and not prefilled\n" + ), + [](common_params & params) { + params.prefill_assistant = false; + } + ).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_NO_PREFILL_ASSISTANT")); add_opt(common_arg( {"-sps", "--slot-prompt-similarity"}, "SIMILARITY", string_format("how much the prompt of a request must match the prompt of a slot in order to use that slot (default: %.2f, 0.0 = disabled)\n", params.slot_prompt_similarity), diff --git a/common/common.h b/common/common.h index a99a36029a53c..da525dd420b7d 100644 --- a/common/common.h +++ b/common/common.h @@ -368,6 +368,7 @@ struct common_params { bool use_jinja = false; // NOLINT bool enable_chat_template = true; common_reasoning_format reasoning_format = COMMON_REASONING_FORMAT_DEEPSEEK; + bool prefill_assistant = true; // if 
true, any trailing assistant message will be prefilled into the response std::vector api_keys; diff --git a/tools/server/README.md b/tools/server/README.md index 17ad93df61f87..0b84966ae86d7 100644 --- a/tools/server/README.md +++ b/tools/server/README.md @@ -13,6 +13,7 @@ Set of LLM REST APIs and a simple web front end to interact with llama.cpp. * Multimodal ([documentation](../../docs/multimodal.md)) / with OpenAI-compatible API support * Monitoring endpoints * Schema-constrained JSON response format + * Prefilling of assistant messages similar to the Claude API * [Function calling](../../docs/function-calling.md) / tool use for ~any model * Speculative decoding * Easy-to-use web UI @@ -175,6 +176,7 @@ The project is under active development, and we are [looking for feedback and co | `--reasoning-format FORMAT` | reasoning format (default: deepseek; allowed values: deepseek, none)
controls whether thought tags are extracted from the response, and in which format they're returned. 'none' leaves thoughts unparsed in `message.content`, 'deepseek' puts them in `message.reasoning_content` (for DeepSeek R1 & Command R7B only).
only supported for non-streamed responses
(env: LLAMA_ARG_THINK) | | `--chat-template JINJA_TEMPLATE` | set custom jinja chat template (default: template taken from model's metadata)
if suffix/prefix are specified, template will be disabled
only commonly used templates are accepted (unless --jinja is set before this flag):
list of built-in templates:
bailing, chatglm3, chatglm4, chatml, command-r, deepseek, deepseek2, deepseek3, exaone3, falcon3, gemma, gigachat, glmedge, granite, llama2, llama2-sys, llama2-sys-bos, llama2-sys-strip, llama3, llama4, megrez, minicpm, mistral-v1, mistral-v3, mistral-v3-tekken, mistral-v7, mistral-v7-tekken, monarch, openchat, orion, phi3, phi4, rwkv-world, smolvlm, vicuna, vicuna-orca, yandex, zephyr
(env: LLAMA_ARG_CHAT_TEMPLATE) | | `--chat-template-file JINJA_TEMPLATE_FILE` | set custom jinja chat template file (default: template taken from model's metadata)
if suffix/prefix are specified, template will be disabled
only commonly used templates are accepted (unless --jinja is set before this flag):
list of built-in templates:
bailing, chatglm3, chatglm4, chatml, command-r, deepseek, deepseek2, deepseek3, exaone3, falcon3, gemma, gigachat, glmedge, granite, llama2, llama2-sys, llama2-sys-bos, llama2-sys-strip, llama3, llama4, megrez, minicpm, mistral-v1, mistral-v3, mistral-v3-tekken, mistral-v7, mistral-v7-tekken, monarch, openchat, orion, phi3, phi4, rwkv-world, smolvlm, vicuna, vicuna-orca, yandex, zephyr
(env: LLAMA_ARG_CHAT_TEMPLATE_FILE) | +| `--no-prefill-assistant` | whether to prefill the assistant's response if the last message is an assistant message (default: prefill enabled)
when this flag is set, if the last message is an assistant message then it will be treated as a full message and not prefilled
(env: LLAMA_ARG_NO_PREFILL_ASSISTANT) | | `-sps, --slot-prompt-similarity SIMILARITY` | how much the prompt of a request must match the prompt of a slot in order to use that slot (default: 0.50, 0.0 = disabled)
| | `--lora-init-without-apply` | load LoRA adapters without applying them (apply later via POST /lora-adapters) (default: disabled) | | `--draft-max, --draft, --draft-n N` | number of tokens to draft for speculative decoding (default: 16)
(env: LLAMA_ARG_DRAFT_MAX) | diff --git a/tools/server/server.cpp b/tools/server/server.cpp index 129d013ac75f7..348588a2cb224 100644 --- a/tools/server/server.cpp +++ b/tools/server/server.cpp @@ -4348,6 +4348,7 @@ int main(int argc, char ** argv) { json data = oaicompat_completion_params_parse( body, params.use_jinja, + params.prefill_assistant, params.reasoning_format, ctx_server.chat_templates.get(), ctx_server.mctx, @@ -4369,6 +4370,7 @@ int main(int argc, char ** argv) { json data = oaicompat_completion_params_parse( body, params.use_jinja, + params.prefill_assistant, params.reasoning_format, ctx_server.chat_templates.get(), ctx_server.mctx, diff --git a/tools/server/utils.hpp b/tools/server/utils.hpp index 232eef195437f..3e7733539fe0e 100644 --- a/tools/server/utils.hpp +++ b/tools/server/utils.hpp @@ -583,6 +583,7 @@ static json oaicompat_completion_params_parse(const json & body) { static json oaicompat_completion_params_parse( const json & body, /* openai api json semantics */ bool use_jinja, + bool prefill_assistant, common_reasoning_format reasoning_format, const struct common_chat_templates * tmpls, bool allow_non_text, @@ -732,7 +733,7 @@ static json oaicompat_completion_params_parse( // if the assistant message appears at the end of list, we do not add end-of-turn token // for ex. this can be useful to modify the reasoning process in reasoning models - bool prefill_assistant_message = !inputs.messages.empty() && inputs.messages.back().role == "assistant"; + bool prefill_assistant_message = !inputs.messages.empty() && inputs.messages.back().role == "assistant" && prefill_assistant; common_chat_msg last_message; if (prefill_assistant_message) { last_message = inputs.messages.back();