From 585ac35b4207e73a377810c60864da12a6be3650 Mon Sep 17 00:00:00 2001 From: Iwan Kawrakow Date: Thu, 13 Jul 2023 10:32:19 +0200 Subject: [PATCH 1/3] 3-5% faster Q4_0 on Metal --- ggml-metal.metal | 62 +++++++++++++++++++++++++++++------------------- 1 file changed, 38 insertions(+), 24 deletions(-) diff --git a/ggml-metal.metal b/ggml-metal.metal index 30d60fa58d686..8867dcfce5add 100644 --- a/ggml-metal.metal +++ b/ggml-metal.metal @@ -395,9 +395,12 @@ kernel void kernel_mul_mat_q4_0_f32( // each thread in a SIMD group deals with 1 block. for (int column = 0; column < nb / N_SIMDWIDTH; column++) { + float sumy = 0; for (int i = 0; i < QK4_0 / 4; i++) { y_curr[i] = *((device float4 *)(y + N_SIMDWIDTH * (tiisg + column * QK4_0) + 4 * i)); + sumy += y_curr[i][0] + y_curr[i][1] + y_curr[i][2] + y_curr[i][3]; } + sumy *= (-8.f); for (int row = 0; row < N_DST; row++) { // prefetch next x block @@ -405,39 +408,50 @@ kernel void kernel_mul_mat_q4_0_f32( // calculate float d = qb_curr.d; - float2 acc = {0.0f, 0.0f}; + float acc = sumy; for (int i = 0; i < 16; i++) { - acc[0] += yl[i] * (qb_curr.qs[i] & 0xF) + yl[i+16] * (qb_curr.qs[i] >> 4); - acc[1] += yl[i] + yl[i+16]; + acc += yl[i] * (qb_curr.qs[i] & 0xF) + yl[i+16] * (qb_curr.qs[i] >> 4); } - sumf[row] += d * (acc[0] - 8.f*acc[1]); + sumf[row] += d * acc; qb_curr = qb_next; } } - for (int i = 0; i < QK4_0 / 4; i++) { - y_curr[i] = *((device float4 *)(y + N_SIMDWIDTH * (tiisg + (nb / N_SIMDWIDTH) * QK4_0) + 4 * i)); - } - - for (int row = 0; row < N_DST; row++) { - // prefetch next x block - qb_next = x[tiisg + ((row + 1) % N_DST) * nb + (nb / N_SIMDWIDTH + ((row + 1) / N_DST)) * N_SIMDWIDTH]; - - // calculate - float d = qb_curr.d; - float2 acc = {0.0f, 0.0f}; - for (int i = 0; i < 16; i++) { - acc[0] += yl[i] * (qb_curr.qs[i] & 0xF) + yl[i+16] * (qb_curr.qs[i] >> 4); - acc[1] += yl[i] + yl[i+16]; + if (nb % N_SIMDWIDTH == 0) { + for (int row = 0; row < N_DST; ++row) { + all_sum = simd_sum(sumf[row]); + if (tiisg == 0 && ((r0 * N_SIMDGROUP + sgitg) * N_DST + row) < ne01) { + dst[r1*ne0 + (r0 * N_SIMDGROUP + sgitg) * N_DST + row] = all_sum; + } } - if (tiisg < nb % N_SIMDWIDTH) { - sumf[row] += d * (acc[0] - 8.f*acc[1]); + } else { + + float sumy = 0; + for (int i = 0; i < QK4_0 / 4; i++) { + y_curr[i] = *((device float4 *)(y + N_SIMDWIDTH * (tiisg + (nb / N_SIMDWIDTH) * QK4_0) + 4 * i)); + sumy += y_curr[i][0] + y_curr[i][1] + y_curr[i][2] + y_curr[i][3]; } - qb_curr = qb_next; + sumy *= (-8.f); - all_sum = simd_sum(sumf[row]); - if (tiisg == 0 && ((r0 * N_SIMDGROUP + sgitg) * N_DST + row) < ne01) { - dst[r1*ne0 + (r0 * N_SIMDGROUP + sgitg) * N_DST + row] = all_sum; + for (int row = 0; row < N_DST; row++) { + // prefetch next x block + qb_next = x[tiisg + ((row + 1) % N_DST) * nb + (nb / N_SIMDWIDTH + ((row + 1) / N_DST)) * N_SIMDWIDTH]; + + // calculate + float d = qb_curr.d; + float acc = sumy; + for (int i = 0; i < 16; i++) { + acc += yl[i] * (qb_curr.qs[i] & 0xF) + yl[i+16] * (qb_curr.qs[i] >> 4); + } + if (tiisg < nb % N_SIMDWIDTH) { + sumf[row] += d * acc; + } + qb_curr = qb_next; + + all_sum = simd_sum(sumf[row]); + if (tiisg == 0 && ((r0 * N_SIMDGROUP + sgitg) * N_DST + row) < ne01) { + dst[r1*ne0 + (r0 * N_SIMDGROUP + sgitg) * N_DST + row] = all_sum; + } } } } From 0f7967089f3a80bb1be51eb6a9ca7da5c4466c01 Mon Sep 17 00:00:00 2001 From: Iwan Kawrakow Date: Thu, 13 Jul 2023 10:51:45 +0200 Subject: [PATCH 2/3] 7-25% faster Q4_1 on Metal --- ggml-metal.m | 8 +-- ggml-metal.metal | 181 +++++++++++++++++++++++++++++++++++------------ 2 files changed, 136 insertions(+), 53 deletions(-) diff --git a/ggml-metal.m b/ggml-metal.m index 02dc9beb94c58..c795ee22784bd 100644 --- a/ggml-metal.m +++ b/ggml-metal.m @@ -739,12 +739,8 @@ void ggml_metal_graph_compute( [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:13]; [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:14]; - if (src0t == GGML_TYPE_Q4_0) { - [encoder dispatchThreadgroups:MTLSizeMake(ne01 / 8+((ne01 % 8) & 0x01), ne11, 1) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; - } - else if (src0t == GGML_TYPE_Q4_1) { - [encoder setThreadgroupMemoryLength:nth0*nth1*sizeof(float) atIndex:0]; - [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne11, 1) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; + if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1) { + [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7) / 8, ne11, 1) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; } else if (src0t == GGML_TYPE_Q2_K || src0t == GGML_TYPE_Q3_K || diff --git a/ggml-metal.metal b/ggml-metal.metal index 8867dcfce5add..6251194605f37 100644 --- a/ggml-metal.metal +++ b/ggml-metal.metal @@ -463,68 +463,155 @@ kernel void kernel_mul_mat_q4_1_f32( constant int64_t & ne00, constant int64_t & ne10, constant int64_t & ne0, - threadgroup float * sum [[threadgroup(0)]], + constant int64_t & ne01[[buffer(4)]], uint2 tgpig[[threadgroup_position_in_grid]], - uint2 tpitg[[thread_position_in_threadgroup]], - uint2 tptg[[threads_per_threadgroup]]) { - const int nb = ne00/QK4_1; - - const int64_t r0 = tgpig.x; - const int64_t r1 = tgpig.y; - - device const block_q4_1 * x = (device const block_q4_1 *) src0 + r0*nb; + uint tiisg[[thread_index_in_simdgroup]], + uint sgitg[[simdgroup_index_in_threadgroup]]) { + const int nb = ne00/QK4_0; + const int r0 = tgpig.x; + const int r1 = tgpig.y; + device const block_q4_1 * x = (device const block_q4_1 *) src0 + (r0 * N_SIMDGROUP + sgitg) * N_DST * nb; device const float * y = (device const float *) src1 + r1*ne10; + block_q4_1 qb_curr, qb_next; + float4 y_curr[8]; // src1 vector cache + float sumf[N_DST]={0.f}, all_sum; + thread float * yl=(thread float *)y_curr; - const uint nth = tptg.x*tptg.y; - const uint ith = tptg.y*tpitg.x + tpitg.y; - - const int ix = tpitg.y/4; // 0 or 1 - const int iy = tpitg.y - 4*ix; // 0...3 - - const int first = 4 * iy; - - float sumf = 0; - - for (int i = 2*tpitg.x + ix; i < nb; i += 2*tptg.x) { - - const float d = (float)x[i].d; - const float m = (float)x[i].m; + // bootstrap + qb_curr = x[tiisg]; + // each thread in a SIMD group deals with 1 block. + for (int column = 0; column < nb / N_SIMDWIDTH; column++) { - device const uint8_t * xl = x[i].qs + first; - device const float * yl = y + i * QK4_1 + first; + float sumy = 0; + for (int i = 0; i < QK4_0 / 4; i++) { + y_curr[i] = *((device float4 *)(y + N_SIMDWIDTH * (tiisg + column * QK4_0) + 4 * i)); + sumy += y_curr[i][0] + y_curr[i][1] + y_curr[i][2] + y_curr[i][3]; + } - float2 acc = {0.0f, 0.0f}; + for (int row = 0; row < N_DST; row++) { + // prefetch next x block + qb_next = x[tiisg + ((row + 1) % N_DST) * nb + (column + ((row + 1) / N_DST)) * N_SIMDWIDTH]; - for (int j = 0; j < 4; ++j) { + // calculate + const float d = qb_curr.d; + const float m = qb_curr.m; + float acc = 0.f; + for (int i = 0; i < 16; i++) { + acc += yl[i] * (qb_curr.qs[i] & 0xF) + yl[i+16] * (qb_curr.qs[i] >> 4); + } + sumf[row] += d * acc + m * sumy; + qb_curr = qb_next; + } + } - acc[0] += yl[j+ 0] * (d * (xl[j] & 0xF) + m); - acc[1] += yl[j+16] * (d * (xl[j] >> 4) + m); + if (nb % N_SIMDWIDTH == 0) { + for (int row = 0; row < N_DST; ++row) { + all_sum = simd_sum(sumf[row]); + if (tiisg == 0 && ((r0 * N_SIMDGROUP + sgitg) * N_DST + row) < ne01) { + dst[r1*ne0 + (r0 * N_SIMDGROUP + sgitg) * N_DST + row] = all_sum; + } + } + } else { + float sumy = 0; + for (int i = 0; i < QK4_0 / 4; i++) { + y_curr[i] = *((device float4 *)(y + N_SIMDWIDTH * (tiisg + (nb / N_SIMDWIDTH) * QK4_0) + 4 * i)); + sumy += y_curr[i][0] + y_curr[i][1] + y_curr[i][2] + y_curr[i][3]; } - sumf += acc[0] + acc[1]; - } + for (int row = 0; row < N_DST; row++) { + // prefetch next x block + qb_next = x[tiisg + ((row + 1) % N_DST) * nb + (nb / N_SIMDWIDTH + ((row + 1) / N_DST)) * N_SIMDWIDTH]; - sum[ith] = sumf; + // calculate + const float d = qb_curr.d; + const float m = qb_curr.m; + float acc = 0.f; + for (int i = 0; i < 16; i++) { + acc += yl[i] * (qb_curr.qs[i] & 0xF) + yl[i+16] * (qb_curr.qs[i] >> 4); + } + if (tiisg < nb % N_SIMDWIDTH) { + sumf[row] += d * acc + m * sumy; + } + qb_curr = qb_next; - // - // Accumulate the sum from all threads in the threadgroup - // - threadgroup_barrier(mem_flags::mem_threadgroup); - if (ith%4 == 0) { - sum[ith] += sum[ith+1] + sum[ith+2] + sum[ith+3]; - } - threadgroup_barrier(mem_flags::mem_threadgroup); - if (ith%16 == 0) { - sum[ith] += sum[ith+4] + sum[ith+8] + sum[ith+12]; - } - threadgroup_barrier(mem_flags::mem_threadgroup); - if (ith == 0) { - for (uint i = 16; i < nth; i += 16) sum[0] += sum[i]; - dst[r1*ne0 + r0] = sum[0]; + all_sum = simd_sum(sumf[row]); + if (tiisg == 0 && ((r0 * N_SIMDGROUP + sgitg) * N_DST + row) < ne01) { + dst[r1*ne0 + (r0 * N_SIMDGROUP + sgitg) * N_DST + row] = all_sum; + } + } } } +//kernel void kernel_mul_mat_q4_1_f32( +// device const void * src0, +// device const float * src1, +// device float * dst, +// constant int64_t & ne00, +// constant int64_t & ne10, +// constant int64_t & ne0, +// threadgroup float * sum [[threadgroup(0)]], +// uint2 tgpig[[threadgroup_position_in_grid]], +// uint2 tpitg[[thread_position_in_threadgroup]], +// uint2 tptg[[threads_per_threadgroup]]) { +// const int nb = ne00/QK4_1; +// +// const int64_t r0 = tgpig.x; +// const int64_t r1 = tgpig.y; +// +// device const block_q4_1 * x = (device const block_q4_1 *) src0 + r0*nb; +// device const float * y = (device const float *) src1 + r1*ne10; +// +// const uint nth = tptg.x*tptg.y; +// const uint ith = tptg.y*tpitg.x + tpitg.y; +// +// const int ix = tpitg.y/4; // 0 or 1 +// const int iy = tpitg.y - 4*ix; // 0...3 +// +// const int first = 4 * iy; +// +// float sumf = 0; +// +// for (int i = 2*tpitg.x + ix; i < nb; i += 2*tptg.x) { +// +// const float d = (float)x[i].d; +// const float m = (float)x[i].m; +// +// device const uint8_t * xl = x[i].qs + first; +// device const float * yl = y + i * QK4_1 + first; +// +// float2 acc = {0.0f, 0.0f}; +// +// for (int j = 0; j < 4; ++j) { +// +// acc[0] += yl[j+ 0] * (d * (xl[j] & 0xF) + m); +// acc[1] += yl[j+16] * (d * (xl[j] >> 4) + m); +// +// } +// +// sumf += acc[0] + acc[1]; +// } +// +// sum[ith] = sumf; +// +// // +// // Accumulate the sum from all threads in the threadgroup +// // +// threadgroup_barrier(mem_flags::mem_threadgroup); +// if (ith%4 == 0) { +// sum[ith] += sum[ith+1] + sum[ith+2] + sum[ith+3]; +// } +// threadgroup_barrier(mem_flags::mem_threadgroup); +// if (ith%16 == 0) { +// sum[ith] += sum[ith+4] + sum[ith+8] + sum[ith+12]; +// } +// threadgroup_barrier(mem_flags::mem_threadgroup); +// if (ith == 0) { +// for (uint i = 16; i < nth; i += 16) sum[0] += sum[i]; +// dst[r1*ne0 + r0] = sum[0]; +// } +//} + kernel void kernel_mul_mat_f16_f32( device const char * src0, device const char * src1, From 08812fe2e655b1f7e1ff346c7eb5871c4c266993 Mon Sep 17 00:00:00 2001 From: Iwan Kawrakow Date: Thu, 13 Jul 2023 11:33:34 +0200 Subject: [PATCH 3/3] Oops, forgot to delete the original Q4_1 kernel --- ggml-metal.metal | 69 ------------------------------------------------ 1 file changed, 69 deletions(-) diff --git a/ggml-metal.metal b/ggml-metal.metal index 6251194605f37..f094a1d407f08 100644 --- a/ggml-metal.metal +++ b/ggml-metal.metal @@ -543,75 +543,6 @@ kernel void kernel_mul_mat_q4_1_f32( } } -//kernel void kernel_mul_mat_q4_1_f32( -// device const void * src0, -// device const float * src1, -// device float * dst, -// constant int64_t & ne00, -// constant int64_t & ne10, -// constant int64_t & ne0, -// threadgroup float * sum [[threadgroup(0)]], -// uint2 tgpig[[threadgroup_position_in_grid]], -// uint2 tpitg[[thread_position_in_threadgroup]], -// uint2 tptg[[threads_per_threadgroup]]) { -// const int nb = ne00/QK4_1; -// -// const int64_t r0 = tgpig.x; -// const int64_t r1 = tgpig.y; -// -// device const block_q4_1 * x = (device const block_q4_1 *) src0 + r0*nb; -// device const float * y = (device const float *) src1 + r1*ne10; -// -// const uint nth = tptg.x*tptg.y; -// const uint ith = tptg.y*tpitg.x + tpitg.y; -// -// const int ix = tpitg.y/4; // 0 or 1 -// const int iy = tpitg.y - 4*ix; // 0...3 -// -// const int first = 4 * iy; -// -// float sumf = 0; -// -// for (int i = 2*tpitg.x + ix; i < nb; i += 2*tptg.x) { -// -// const float d = (float)x[i].d; -// const float m = (float)x[i].m; -// -// device const uint8_t * xl = x[i].qs + first; -// device const float * yl = y + i * QK4_1 + first; -// -// float2 acc = {0.0f, 0.0f}; -// -// for (int j = 0; j < 4; ++j) { -// -// acc[0] += yl[j+ 0] * (d * (xl[j] & 0xF) + m); -// acc[1] += yl[j+16] * (d * (xl[j] >> 4) + m); -// -// } -// -// sumf += acc[0] + acc[1]; -// } -// -// sum[ith] = sumf; -// -// // -// // Accumulate the sum from all threads in the threadgroup -// // -// threadgroup_barrier(mem_flags::mem_threadgroup); -// if (ith%4 == 0) { -// sum[ith] += sum[ith+1] + sum[ith+2] + sum[ith+3]; -// } -// threadgroup_barrier(mem_flags::mem_threadgroup); -// if (ith%16 == 0) { -// sum[ith] += sum[ith+4] + sum[ith+8] + sum[ith+12]; -// } -// threadgroup_barrier(mem_flags::mem_threadgroup); -// if (ith == 0) { -// for (uint i = 16; i < nth; i += 16) sum[0] += sum[i]; -// dst[r1*ne0 + r0] = sum[0]; -// } -//} - kernel void kernel_mul_mat_f16_f32( device const char * src0, device const char * src1,