Skip to content
241 changes: 186 additions & 55 deletions ggml/src/ggml-vulkan/ggml-vulkan.cpp

Large diffs are not rendered by default.

7 changes: 0 additions & 7 deletions ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.glsl
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,6 @@

#include "types.glsl"

#if defined(A_TYPE_PACKED16)
layout (binding = 0) readonly buffer A_PACKED16 {A_TYPE_PACKED16 data_a_packed16[];};
#endif
#if defined(A_TYPE_PACKED32)
layout (binding = 0) readonly buffer A_PACKED32 {A_TYPE_PACKED32 data_a_packed32[];};
#endif

#if defined(DATA_A_F32)
vec2 dequantize(uint ib, uint iqs, uint a_offset) {
return vec2(data_a[a_offset + ib], data_a[a_offset + ib + 1]);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,13 @@ layout (push_constant) uniform parameter

#if !RMS_NORM_ROPE_FUSION
layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
#if defined(A_TYPE_PACKED16)
layout (binding = 0) readonly buffer A_PACKED16 {A_TYPE_PACKED16 data_a_packed16[];};
#endif
#if defined(A_TYPE_PACKED32)
layout (binding = 0) readonly buffer A_PACKED32 {A_TYPE_PACKED32 data_a_packed32[];};
#endif

layout (binding = 1) readonly buffer B {B_TYPE data_b[];};
layout (binding = 2) writeonly buffer D {D_TYPE data_d[];};
#endif
Expand Down
7 changes: 7 additions & 0 deletions ggml/src/ggml-vulkan/vulkan-shaders/generic_unary_head.glsl
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,13 @@ layout (push_constant) uniform parameter
} p;

layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
#if defined(A_TYPE_PACKED16)
layout (binding = 0) readonly buffer A_PACKED16 {A_TYPE_PACKED16 data_a_packed16[];};
#endif
#if defined(A_TYPE_PACKED32)
layout (binding = 0) readonly buffer A_PACKED32 {A_TYPE_PACKED32 data_a_packed32[];};
#endif

layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};

uint get_idx() {
Expand Down
1 change: 1 addition & 0 deletions ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec.comp
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
#extension GL_EXT_shader_explicit_arithmetic_types_int32 : require

#include "mul_mat_vec_base.glsl"
#include "dequant_funcs.glsl"

layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;

Expand Down
2 changes: 0 additions & 2 deletions ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_base.glsl
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,6 @@

#include "mul_mat_vec_iface.glsl"

#include "dequant_funcs.glsl"

layout (push_constant) uniform parameter
{
uint ncols;
Expand Down
8 changes: 5 additions & 3 deletions ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iface.glsl
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,15 @@
#define MAT_VEC_FUSION_FLAGS_SCALE0 0x4
#define MAT_VEC_FUSION_FLAGS_SCALE1 0x8

#ifndef MMQ
layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
#if defined(A_TYPE_VEC4)
layout (binding = 0) readonly buffer AV4 {A_TYPE_VEC4 data_a_v4[];};
#endif
#else
layout (binding = 0) readonly buffer A {A_TYPE_PACKED16 data_a[];};
#if defined(A_TYPE_PACKED16)
layout (binding = 0) readonly buffer A_PACKED16 {A_TYPE_PACKED16 data_a_packed16[];};
#endif
#if defined(A_TYPE_PACKED32)
layout (binding = 0) readonly buffer A_PACKED32 {A_TYPE_PACKED32 data_a_packed32[];};
#endif

layout (binding = 1) readonly buffer B {B_TYPE data_b[];};
Expand Down
62 changes: 27 additions & 35 deletions ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq.comp
Original file line number Diff line number Diff line change
Expand Up @@ -10,60 +10,56 @@

layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;

#if defined(DATA_A_QUANT_LEGACY) || defined(DATA_A_MXFP4)
#define K_PER_ITER 8

#include "mul_mmq_funcs.glsl"
#elif defined(DATA_A_QUANT_K)
#define K_PER_ITER 16
#else
#error unimplemented
#endif

uint a_offset, b_offset, d_offset;

int32_t cache_b_qs[2];
int32_t cache_b_qs[K_PER_ITER / 4];
vec2 cache_b_ds;

#include "mul_mat_vecq_funcs.glsl"

void iter(inout FLOAT_TYPE temp[NUM_COLS][NUM_ROWS], const uint first_row, const uint num_rows, const uint tid, const uint i) {
[[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
const uint col = i*BLOCK_SIZE + tid*K_PER_ITER;

// Preload data_b block
const uint b_block_idx = (j*p.batch_stride_b + col) / QUANT_K_Q8_1 + b_offset;
const uint b_qs_idx = tid % 4;
const uint b_qs_idx = tid % (32 / K_PER_ITER);
const uint b_block_idx_outer = b_block_idx / 4;
const uint b_block_idx_inner = b_block_idx % 4;
cache_b_ds = vec2(data_b[b_block_idx_outer].ds[b_block_idx_inner]);

#if QUANT_R == 2
// Assumes K_PER_ITER == 8
cache_b_qs[0] = data_b[b_block_idx_outer].qs[b_block_idx_inner * 8 + b_qs_idx];
cache_b_qs[1] = data_b[b_block_idx_outer].qs[b_block_idx_inner * 8 + b_qs_idx + 4];
#else
#if K_PER_ITER == 8
cache_b_qs[0] = data_b[b_block_idx_outer].qs[b_block_idx_inner * 8 + b_qs_idx * 2];
cache_b_qs[1] = data_b[b_block_idx_outer].qs[b_block_idx_inner * 8 + b_qs_idx * 2 + 1];
#elif K_PER_ITER == 16
cache_b_qs[0] = data_b[b_block_idx_outer].qs[b_block_idx_inner * 8 + b_qs_idx * 4 ];
cache_b_qs[1] = data_b[b_block_idx_outer].qs[b_block_idx_inner * 8 + b_qs_idx * 4 + 1];
cache_b_qs[2] = data_b[b_block_idx_outer].qs[b_block_idx_inner * 8 + b_qs_idx * 4 + 2];
cache_b_qs[3] = data_b[b_block_idx_outer].qs[b_block_idx_inner * 8 + b_qs_idx * 4 + 3];
#else
#error unimplemented
#endif
#endif

uint ibi = first_row*p.ncols;
[[unroll]] for (uint n = 0; n < num_rows; ++n) {
const uint a_block_idx = (ibi + col)/QUANT_K + a_offset;
const uint a_block_idx = (ibi + col)/QUANT_K_Q8_1 + a_offset;
ibi += p.ncols;

int32_t q_sum = 0;
#if QUANT_R == 2
const i32vec2 data_a_qs = repack(a_block_idx, b_qs_idx);
q_sum += dotPacked4x8EXT(data_a_qs.x,
cache_b_qs[0]);
q_sum += dotPacked4x8EXT(data_a_qs.y,
cache_b_qs[1]);
#else
int32_t data_a_qs = repack(a_block_idx, b_qs_idx * 2);
q_sum += dotPacked4x8EXT(data_a_qs,
cache_b_qs[0]);
data_a_qs = repack(a_block_idx, b_qs_idx * 2 + 1);
q_sum += dotPacked4x8EXT(data_a_qs,
cache_b_qs[1]);
#endif

#if QUANT_AUXF == 1
temp[j][n] += mul_q8_1(q_sum, get_d(a_block_idx), cache_b_ds, 4);
#else
temp[j][n] += mul_q8_1(q_sum, get_dm(a_block_idx), cache_b_ds, 4);
#endif
temp[j][n] += mmvq_dot_product(a_block_idx, b_qs_idx);
}
}
}
Expand All @@ -72,7 +68,7 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
const uint tid = gl_LocalInvocationID.x;

get_offsets(a_offset, b_offset, d_offset);
a_offset /= QUANT_K;
a_offset /= QUANT_K_Q8_1;
b_offset /= QUANT_K_Q8_1;

FLOAT_TYPE temp[NUM_COLS][NUM_ROWS];
Expand Down Expand Up @@ -102,14 +98,6 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
unroll_count = 2;
unrolled_iters = num_iters & ~(unroll_count - 1);

#if K_PER_ITER == 2
if ((p.ncols & 1) != 0 &&
unrolled_iters == num_iters &&
unrolled_iters > 0) {
unrolled_iters -= unroll_count;
}
#endif

while (i < unrolled_iters) {
// Manually partially unroll the loop
[[unroll]] for (uint k = 0; k < unroll_count; ++k) {
Expand All @@ -128,6 +116,10 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
void main() {
const uint first_row = NUM_ROWS * (gl_WorkGroupID.x + gl_NumWorkGroups.x * gl_WorkGroupID.z);

#ifdef NEEDS_INIT_IQ_SHMEM
init_iq_shmem(gl_WorkGroupSize);
#endif

// do NUM_ROWS at a time, unless there aren't enough remaining rows
if (first_row + NUM_ROWS <= p.stride_d) {
compute_outputs(first_row, NUM_ROWS);
Expand Down
Loading
Loading