1 change: 1 addition & 0 deletions CODEOWNERS
@@ -89,6 +89,7 @@
 /src/llama-model-loader.* @slaren
 /src/llama-model.* @CISC
 /src/llama-vocab.* @CISC
+/src/models/ @CISC
 /tests/ @ggerganov
 /tests/test-backend-ops.cpp @slaren
 /tests/test-thread-safety.cpp @slaren
2 changes: 1 addition & 1 deletion common/arg.cpp
@@ -2030,7 +2030,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
             params.system_prompt.pop_back();
         }
     }
-    ).set_examples({LLAMA_EXAMPLE_MAIN}));
+    ).set_examples({LLAMA_EXAMPLE_MAIN, LLAMA_EXAMPLE_DIFFUSION}));
     add_opt(common_arg(
         {"--in-file"}, "FNAME",
         "an input file (repeat to specify multiple files)",
14 changes: 8 additions & 6 deletions ggml/src/ggml-cuda/mmvq.cu
@@ -190,26 +190,28 @@ static __global__ void mul_mat_vec_q(
 
     const uint32_t channel_bias = ids ? channel_x : channel_dst;
 
-    float x_biases[ncols_dst][rows_per_cuda_block] = { { 0.0f } };
-    float gate_biases[ncols_dst][rows_per_cuda_block] = { { 0.0f } };
+    float x_biases[ncols_dst] = { 0.0f };
+    float gate_biases[ncols_dst] = { 0.0f };
     if constexpr (has_fusion) {
         if (use_bias) {
             x_bias = x_bias + sample_dst*stride_sample_dst + channel_bias*stride_channel_dst + row0;
+            // 1. Hide latency by prefetching bias and gate here
+            // 2. load only on threads that won't die after partial sum calculation
             if (threadIdx.x < rows_per_cuda_block && threadIdx.y == 0 &&
                 (rows_per_cuda_block == 1 || uint32_t(row0 + threadIdx.x) < stride_col_dst)) {
 #pragma unroll
                 for (int j = 0; j < ncols_dst; ++j) {
-                    x_biases[j][threadIdx.x] = x_bias[j * stride_col_dst + threadIdx.x];
+                    x_biases[j] = x_bias[j * stride_col_dst + threadIdx.x];
                 }
             }
         }
         if (use_gate_bias) {
             gate_bias = gate_bias + sample_dst*stride_sample_dst + channel_bias*stride_channel_dst + row0;
             if (threadIdx.x < rows_per_cuda_block && threadIdx.y == 0 &&
                 (rows_per_cuda_block == 1 || uint32_t(row0 + threadIdx.x) < stride_col_dst)) {
 #pragma unroll
                 for (int j = 0; j < ncols_dst; ++j) {
-                    gate_biases[j][threadIdx.x] = gate_bias[j * stride_col_dst + threadIdx.x];
+                    gate_biases[j] = gate_bias[j * stride_col_dst + threadIdx.x];
                 }
             }
         }
@@ -299,12 +301,12 @@ static __global__ void mul_mat_vec_q(
             float result = tmp[j][threadIdx.x];
             if constexpr (has_fusion) {
                 if (use_bias) {
-                    result += x_biases[j][threadIdx.x];
+                    result += x_biases[j];
                 }
                 if (use_gate) {
                     float gate_value = tmp_gate[j][threadIdx.x];
                     if (use_gate_bias) {
-                        gate_value += gate_biases[j][threadIdx.x];
+                        gate_value += gate_biases[j];
                     }
                     switch (active_glu) {
                         case GGML_GLU_OP_SWIGLU:
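Note on the mmvq.cu hunks: the per-thread bias staging shrinks from `x_biases[ncols_dst][rows_per_cuda_block]` to `x_biases[ncols_dst]`. The loads are guarded by `threadIdx.x < rows_per_cuda_block`, and each such thread only ever touches its own `threadIdx.x` slot, so the second array dimension was pure register pressure. A minimal sketch of the pattern (hypothetical kernel, not the ggml source; names are illustrative):

```cuda
// Sketch: each thread with threadIdx.x < rows_per_cuda_block owns one output
// row, so one bias value per destination column is enough per thread.
template <int ncols_dst, int rows_per_cuda_block>
__global__ void bias_epilogue_sketch(const float * __restrict__ x_bias,
                                     float * __restrict__ dst,
                                     const int stride_col_dst) {
    // was: float x_biases[ncols_dst][rows_per_cuda_block], indexed [j][threadIdx.x]
    float x_biases[ncols_dst] = { 0.0f };

    // Prefetch early to hide latency, and only on threads that still
    // participate after the partial-sum reduction (as the new comments note).
    if (threadIdx.x < rows_per_cuda_block) {
#pragma unroll
        for (int j = 0; j < ncols_dst; ++j) {
            x_biases[j] = x_bias[j * stride_col_dst + threadIdx.x];
        }
    }

    // ... dot-product partial sums and block reduction would go here ...

    if (threadIdx.x < rows_per_cuda_block) {
#pragma unroll
        for (int j = 0; j < ncols_dst; ++j) {
            dst[j * stride_col_dst + threadIdx.x] += x_biases[j];
        }
    }
}
```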
517 changes: 393 additions & 124 deletions ggml/src/ggml-vulkan/ggml-vulkan.cpp

Large diffs are not rendered by default.

34 changes: 32 additions & 2 deletions ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_base.glsl
@@ -28,8 +28,11 @@ layout (binding = 1) readonly buffer BV4 {B_TYPE_VEC4 data_b_v4[];};
 #endif
 
 layout (binding = 2) writeonly buffer D {D_TYPE data_d[];};
+
+layout (binding = 3) readonly buffer Bias {D_TYPE data_bias[];};
+
 #ifdef MUL_MAT_ID
-layout (binding = 3) readonly buffer IDS {int data_ids[];};
+layout (binding = 4) readonly buffer IDS {int data_ids[];};
 #endif
 
 #include "dequant_funcs.glsl"
@@ -45,6 +48,8 @@ layout (push_constant) uniform parameter
     uint batch_stride_b;
     uint batch_stride_d;
 
+    uint enable_bias;
+
 #ifdef MUL_MAT_ID
     uint nei0;
     uint ne11;
@@ -56,6 +61,10 @@ layout (push_constant) uniform parameter
 #endif
 } p;
 
+#ifdef MUL_MAT_ID
+uint expert_id;
+#endif
+
 void get_offsets(out uint a_offset, out uint b_offset, out uint d_offset) {
 #ifdef MUL_MAT_ID
     const uint expert_idx = gl_GlobalInvocationID.y;
@@ -75,7 +84,7 @@ void get_offsets(out uint a_offset, out uint b_offset, out uint d_offset) {
         batch_idx_a = i03 * p.ne02 + i02;
     }
 #else
-    const uint expert_id = data_ids[expert_idx];
+    expert_id = data_ids[expert_idx];
 #endif
 
     a_offset =
@@ -113,6 +122,13 @@ void reduce_result(inout FLOAT_TYPE temp[NUM_COLS][NUM_ROWS], const in uint32_t
     if (tid == 0) {
         [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
             [[unroll]] for (uint n = 0; n < num_rows; ++n) {
+                if (p.enable_bias != 0) {
+#ifdef MUL_MAT_ID
+                    temp[j][n] += FLOAT_TYPE(data_bias[expert_id*p.stride_d + first_row + n]);
+#else
+                    temp[j][n] += FLOAT_TYPE(data_bias[j*p.batch_stride_d + d_offset + first_row + n]);
+#endif
+                }
                 data_d[j*p.batch_stride_d + d_offset + first_row + n] = D_TYPE(temp[j][n]);
             }
         }
@@ -148,6 +164,13 @@ void reduce_result(FLOAT_TYPE temp[NUM_COLS][NUM_ROWS], const in uint32_t d_offs
             [[unroll]] for (uint s = 0; s < gl_NumSubgroups; ++s) {
                 temp[j][n] += tmpsh[j][n][s];
             }
+            if (p.enable_bias != 0) {
+#ifdef MUL_MAT_ID
+                temp[j][n] += FLOAT_TYPE(data_bias[expert_id*p.stride_d + first_row + n]);
+#else
+                temp[j][n] += FLOAT_TYPE(data_bias[j*p.batch_stride_d + d_offset + first_row + n]);
+#endif
+            }
             data_d[j*p.batch_stride_d + d_offset + first_row + n] = D_TYPE(temp[j][n]);
         }
     }
@@ -173,6 +196,13 @@ void reduce_result(FLOAT_TYPE temp[NUM_COLS][NUM_ROWS], const in uint32_t d_offs
     if (tid == 0) {
         [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
             [[unroll]] for (uint n = 0; n < num_rows; ++n) {
+                if (p.enable_bias != 0) {
+#ifdef MUL_MAT_ID
+                    tmpsh[j][n][0] += FLOAT_TYPE(data_bias[expert_id*p.stride_d + first_row + n]);
+#else
+                    tmpsh[j][n][0] += FLOAT_TYPE(data_bias[j*p.batch_stride_d + d_offset + first_row + n]);
+#endif
+                }
                 data_d[j*p.batch_stride_d + d_offset + first_row + n] = D_TYPE(tmpsh[j][n][0]);
             }
         }
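Note on mul_mat_vec_base.glsl: the fused bias lands at binding 3 (pushing IDS to binding 4), is gated behind the new `enable_bias` push constant, and `expert_id` is promoted from a local in `get_offsets()` to a shader-global so `reduce_result()` can address a per-expert bias row. The two addressing modes, written out as plain index arithmetic (a hypothetical helper for intuition, not repo code; `stride_d` is taken on trust from the full push-constant block, which this excerpt does not show):

```cuda
// Bias addressing used by reduce_result(), lifted out for clarity.
// MUL_MAT_ID path: one bias row per expert, selected by expert_id.
// Plain path: the bias follows the destination layout exactly.
__host__ __device__ inline unsigned bias_index(
        bool mul_mat_id,
        unsigned expert_id, unsigned stride_d,   // expert routing
        unsigned j, unsigned batch_stride_d,     // destination layout
        unsigned d_offset, unsigned first_row, unsigned n) {
    if (mul_mat_id) {
        return expert_id * stride_d + first_row + n;
    }
    // identical to the data_d[] index the shader writes to
    return j * batch_stride_d + d_offset + first_row + n;
}
```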
6 changes: 6 additions & 0 deletions ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_nc.comp
@@ -15,6 +15,8 @@ layout (binding = 2) writeonly buffer D {D_TYPE dst[];};
 layout (binding = 0) readonly buffer AV4 {A_TYPE_VEC4 data_a_v4[];};
 layout (binding = 1) readonly buffer BV4 {B_TYPE_VEC4 data_b_v4[];};
 
+layout (binding = 3) readonly buffer Bias {D_TYPE data_bias[];};
+
 layout (push_constant) uniform parameter
 {
     uint ncols_x;
@@ -29,6 +31,7 @@ layout (push_constant) uniform parameter
     uint nb03;
     uint nb13;
     uint nb23;
+    uint enable_bias;
 } p;
 
 shared FLOAT_TYPE tmp[BLOCK_SIZE];
@@ -117,6 +120,9 @@ void main() {
     }
 
     if (tid == 0) {
+        if (p.enable_bias != 0) {
+            tmp[0] += FLOAT_TYPE(data_bias[idst]);
+        }
         dst[idst] = tmp[0];
     }
 }
6 changes: 6 additions & 0 deletions ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_p021.comp
@@ -17,6 +17,8 @@ layout (binding = 2) writeonly buffer D {D_TYPE dst[];};
 layout (binding = 0) readonly buffer AV4 {A_TYPE_VEC4 data_a_v4[];};
 layout (binding = 1) readonly buffer BV4 {B_TYPE_VEC4 data_b_v4[];};
 
+layout (binding = 3) readonly buffer Bias {D_TYPE data_bias[];};
+
 layout(constant_id = 0) const int BLOCK_SIZE = 32;
 // gqa_ratio is in the range [1,8]
 layout(constant_id = 1) const uint gqa_ratio = 1;
@@ -29,6 +31,7 @@ layout (push_constant) uniform parameter
     uint nchannels_y;
     uint b_offset;
     uint d_offset;
+    uint enable_bias;
 } p;
 
 #if !USE_SUBGROUP_ADD
@@ -148,6 +151,9 @@ void main() {
         [[unroll]] for (uint c = 0; c < gqa_ratio; ++c) {
             // dst is not transposed and not permuted
            const uint idst = (channel + c)*nrows_dst + row_dst;
+            if (p.enable_bias != 0) {
+                temp[c] += FLOAT_TYPE(data_bias[idst]);
+            }
             dst[idst] = temp[c];
         }
     }
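Note on the p021/nc variants: the bias add is fused into the write-out, so only the lane that performs the final store reads `data_bias` at the same flat `idst` it is about to write. Enabling the fusion therefore costs one extra global load per output element and no extra barriers. A simplified CUDA sketch of that epilogue shape (hypothetical, assuming a power-of-two block size of at most 256):

```cuda
// One output element per block: block-wide reduction, then a single lane
// applies the optional bias and stores. Launch with power-of-two blockDim.x <= 256.
__global__ void reduce_with_bias(const float * __restrict__ x,
                                 const float * __restrict__ bias,
                                 float * __restrict__ dst,
                                 const int n, const int enable_bias) {
    __shared__ float tmp[256];
    const int tid = threadIdx.x;
    const int row = blockIdx.x;

    float acc = 0.0f;
    for (int i = tid; i < n; i += blockDim.x) {
        acc += x[row * n + i];   // stand-in for the dot-product partials
    }
    tmp[tid] = acc;
    __syncthreads();

    for (int s = blockDim.x / 2; s > 0; s >>= 1) {
        if (tid < s) {
            tmp[tid] += tmp[tid + s];
        }
        __syncthreads();
    }

    if (tid == 0) {
        float r = tmp[0];
        if (enable_bias != 0) {
            r += bias[row];      // same flat index as the store below
        }
        dst[row] = r;
    }
}
```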
104 changes: 94 additions & 10 deletions ggml/src/ggml-vulkan/vulkan-shaders/multi_add.comp
@@ -23,16 +23,100 @@ layout (push_constant) uniform parameter2
     uint rms_partials;
 } p;
 
-// Workaround for MoltenVK Bug, see https://github.com/ggml-org/llama.cpp/issues/15498
-// layout (binding = 0) readonly buffer A {A_TYPE data_a[];} a[];
-// layout (binding = 0) writeonly buffer D {D_TYPE data_d[];} d[];
-layout (binding = 0) buffer A {A_TYPE data_a[];} a[];
-layout (binding = 0) buffer D {D_TYPE data_d[];} d[];
-
-layout (binding = 0, std430) buffer PartialBuf {float partial_sums[];} partials[];
+// No readonly/writeonly decorations. Workaround for MoltenVK Bug, see https://github.com/ggml-org/llama.cpp/issues/15498
+layout (binding = 0) buffer A0 {A_TYPE data_a[];} a0;
+layout (binding = 1) buffer A1 {A_TYPE data_a[];} a1;
+layout (binding = 2) buffer A2 {A_TYPE data_a[];} a2;
+layout (binding = 3) buffer A3 {A_TYPE data_a[];} a3;
+layout (binding = 4) buffer A4 {A_TYPE data_a[];} a4;
+layout (binding = 5) buffer A5 {A_TYPE data_a[];} a5;
+layout (binding = 6) buffer A6 {A_TYPE data_a[];} a6;
+layout (binding = 7) buffer A7 {A_TYPE data_a[];} a7;
+layout (binding = 8) buffer A8 {A_TYPE data_a[];} a8;
+layout (binding = 9) buffer A9 {A_TYPE data_a[];} a9;
+layout (binding = 10) buffer A10 {A_TYPE data_a[];} a10;
+layout (binding = 11) buffer A11 {A_TYPE data_a[];} a11;
+layout (binding = 0) buffer D0 {D_TYPE data_d[];} d0;
+layout (binding = 1) buffer D1 {D_TYPE data_d[];} d1;
+layout (binding = 2) buffer D2 {D_TYPE data_d[];} d2;
+layout (binding = 3) buffer D3 {D_TYPE data_d[];} d3;
+layout (binding = 4) buffer D4 {D_TYPE data_d[];} d4;
+layout (binding = 5) buffer D5 {D_TYPE data_d[];} d5;
+layout (binding = 6) buffer D6 {D_TYPE data_d[];} d6;
+layout (binding = 7) buffer D7 {D_TYPE data_d[];} d7;
+layout (binding = 8) buffer D8 {D_TYPE data_d[];} d8;
+layout (binding = 9) buffer D9 {D_TYPE data_d[];} d9;
+layout (binding = 10) buffer D10 {D_TYPE data_d[];} d10;
+layout (binding = 11) buffer D11 {D_TYPE data_d[];} d11;
+layout (binding = 0, std430) buffer PartialBuf0 {float partial_sums[];} partials0;
+layout (binding = 1, std430) buffer PartialBuf1 {float partial_sums[];} partials1;
+layout (binding = 2, std430) buffer PartialBuf2 {float partial_sums[];} partials2;
+layout (binding = 3, std430) buffer PartialBuf3 {float partial_sums[];} partials3;
+layout (binding = 4, std430) buffer PartialBuf4 {float partial_sums[];} partials4;
+layout (binding = 5, std430) buffer PartialBuf5 {float partial_sums[];} partials5;
+layout (binding = 6, std430) buffer PartialBuf6 {float partial_sums[];} partials6;
+layout (binding = 7, std430) buffer PartialBuf7 {float partial_sums[];} partials7;
+layout (binding = 8, std430) buffer PartialBuf8 {float partial_sums[];} partials8;
+layout (binding = 9, std430) buffer PartialBuf9 {float partial_sums[];} partials9;
+layout (binding = 10, std430) buffer PartialBuf10 {float partial_sums[];} partials10;
+layout (binding = 11, std430) buffer PartialBuf11 {float partial_sums[];} partials11;
 
 layout(constant_id = 0) const uint num_srcs = 2;
 
+FLOAT_TYPE load_a(uint b, uint i) {
+    switch (b) {
+    case 0: return FLOAT_TYPE(a0.data_a[i]);
+    case 1: return FLOAT_TYPE(a1.data_a[i]);
+    case 2: return FLOAT_TYPE(a2.data_a[i]);
+    case 3: return FLOAT_TYPE(a3.data_a[i]);
+    case 4: return FLOAT_TYPE(a4.data_a[i]);
+    case 5: return FLOAT_TYPE(a5.data_a[i]);
+    case 6: return FLOAT_TYPE(a6.data_a[i]);
+    case 7: return FLOAT_TYPE(a7.data_a[i]);
+    case 8: return FLOAT_TYPE(a8.data_a[i]);
+    case 9: return FLOAT_TYPE(a9.data_a[i]);
+    case 10: return FLOAT_TYPE(a10.data_a[i]);
+    case 11: return FLOAT_TYPE(a11.data_a[i]);
+    default: return FLOAT_TYPE(0);
+    }
+}
+
+void store_d(uint b, uint i, FLOAT_TYPE v) {
+    switch (b) {
+    case 0: d0.data_d[i] = D_TYPE(v); break;
+    case 1: d1.data_d[i] = D_TYPE(v); break;
+    case 2: d2.data_d[i] = D_TYPE(v); break;
+    case 3: d3.data_d[i] = D_TYPE(v); break;
+    case 4: d4.data_d[i] = D_TYPE(v); break;
+    case 5: d5.data_d[i] = D_TYPE(v); break;
+    case 6: d6.data_d[i] = D_TYPE(v); break;
+    case 7: d7.data_d[i] = D_TYPE(v); break;
+    case 8: d8.data_d[i] = D_TYPE(v); break;
+    case 9: d9.data_d[i] = D_TYPE(v); break;
+    case 10: d10.data_d[i] = D_TYPE(v); break;
+    case 11: d11.data_d[i] = D_TYPE(v); break;
+    default: break;
+    }
+}
+
+void store_partial(uint b, uint i, float v) {
+    switch (b) {
+    case 0: partials0.partial_sums[i] = v; break;
+    case 1: partials1.partial_sums[i] = v; break;
+    case 2: partials2.partial_sums[i] = v; break;
+    case 3: partials3.partial_sums[i] = v; break;
+    case 4: partials4.partial_sums[i] = v; break;
+    case 5: partials5.partial_sums[i] = v; break;
+    case 6: partials6.partial_sums[i] = v; break;
+    case 7: partials7.partial_sums[i] = v; break;
+    case 8: partials8.partial_sums[i] = v; break;
+    case 9: partials9.partial_sums[i] = v; break;
+    case 10: partials10.partial_sums[i] = v; break;
+    case 11: partials11.partial_sums[i] = v; break;
+    default: break;
+    }
+}
+
 uint src_idx(uint s, uint i00, uint i01, uint i02, uint i03) {
     return i03*p.nb[s][3] + i02*p.nb[s][2] + i01*p.nb[s][1] + i00*p.nb[s][0];
 }
@@ -78,10 +162,10 @@ void main() {
 
         FLOAT_TYPE sum = FLOAT_TYPE(0);
         [[unroll]] for (uint s = 0; s < num_srcs; ++s) {
-            sum += FLOAT_TYPE(a[s].data_a[src_idx(s, i00, i01, i02, i03)]);
+            sum += load_a(s, src_idx(s, i00, i01, i02, i03));
         }
         sum_sq += sum*sum;
-        d[num_srcs].data_d[dst_idx(i00, i01, i02, i03)] = D_TYPE(sum);
+        store_d(num_srcs, dst_idx(i00, i01, i02, i03), sum);
 
         idx += num_threads;
     }
@@ -104,7 +188,7 @@ void main() {
     }
 
     if (gl_SubgroupID == 0 && gl_SubgroupInvocationID == 0) {
-        partials[num_srcs + 1].partial_sums[orig_idx / (num_iter * num_threads)] = sum_sq;
+        store_partial(num_srcs + 1, orig_idx / (num_iter * num_threads), sum_sq);
     }
 }
 #endif
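Note on multi_add.comp: the old layout aliased every source, the destination, and the partial-sums buffer through unsized descriptor arrays at binding 0 and indexed them with the loop variable; that runtime descriptor-array indexing is what trips the MoltenVK bug linked in the comment. The rewrite enumerates twelve explicit bindings and routes every access through `load_a`/`store_d`/`store_partial`. Because each loop over `s` is `[[unroll]]`ed, the selector is a compile-time constant and every switch folds to a direct buffer access. The same idea in C++/CUDA terms (a hypothetical analogue for intuition, not the shader):

```cuda
// Replace runtime indexing into an array of buffers with explicitly named
// buffers selected by a switch. Once the surrounding loop is unrolled, 's'
// is a constant and the switch disappears at compile time.
struct SrcBuffers {
    const float *b0, *b1, *b2, *b3;   // stands in for bindings 0..3
};

__device__ inline float load_src(const SrcBuffers &bufs, unsigned s, unsigned i) {
    switch (s) {                       // mirrors load_a() in the shader
        case 0:  return bufs.b0[i];
        case 1:  return bufs.b1[i];
        case 2:  return bufs.b2[i];
        case 3:  return bufs.b3[i];
        default: return 0.0f;
    }
}

__global__ void multi_add_sketch(SrcBuffers bufs, float *dst,
                                 unsigned n, unsigned num_srcs) {
    const unsigned i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i >= n) return;
    float sum = 0.0f;
#pragma unroll
    for (unsigned s = 0; s < 4; ++s) { // fixed trip count, so each switch folds
        if (s < num_srcs) {
            sum += load_src(bufs, s, i);
        }
    }
    dst[i] = sum;
}
```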