From 9e5c5960e25bb9dfdcfb6459c474b37948ea4f05 Mon Sep 17 00:00:00 2001 From: shaoqi Date: Tue, 23 Sep 2025 18:20:36 -0700 Subject: [PATCH 1/7] Add mul_mm_f16_f32_kq_kqv kernel --- ggml/src/ggml-opencl/CMakeLists.txt | 1 + ggml/src/ggml-opencl/ggml-opencl.cpp | 145 ++++++++++ .../kernels/mul_mm_f16_f32_kq_kqv.cl | 271 ++++++++++++++++++ 3 files changed, 417 insertions(+) create mode 100644 ggml/src/ggml-opencl/kernels/mul_mm_f16_f32_kq_kqv.cl diff --git a/ggml/src/ggml-opencl/CMakeLists.txt b/ggml/src/ggml-opencl/CMakeLists.txt index d3d97f375e8f3..681c81b88a18b 100644 --- a/ggml/src/ggml-opencl/CMakeLists.txt +++ b/ggml/src/ggml-opencl/CMakeLists.txt @@ -119,6 +119,7 @@ set(GGML_OPENCL_KERNELS pad repeat mul_mat_f16_f32 + mul_mm_f16_f32_kq_kqv conv2d conv2d_f16_f32 flash_attn_f32_f16 diff --git a/ggml/src/ggml-opencl/ggml-opencl.cpp b/ggml/src/ggml-opencl/ggml-opencl.cpp index 3dc4d03550931..2bed54cc009bb 100644 --- a/ggml/src/ggml-opencl/ggml-opencl.cpp +++ b/ggml/src/ggml-opencl/ggml-opencl.cpp @@ -376,6 +376,8 @@ struct ggml_backend_opencl_context { cl_program program_mul_mv_f32_f32; cl_program program_mul; cl_program program_mul_mat_f16_f32_tiled; + cl_program program_mul_mm_f16_f32_kqv; + cl_program program_mul_mm_f16_f32_kq; cl_program program_div; cl_program program_sub; cl_program program_norm; @@ -450,6 +452,8 @@ struct ggml_backend_opencl_context { cl_kernel kernel_mul_mat_f16_f32; cl_kernel kernel_mul_mat_f16_f32_l4; cl_kernel kernel_mul_mat_f16_f32_tiled; + cl_kernel kernel_mul_mm_f16_f32_kqv; + cl_kernel kernel_mul_mm_f16_f32_kq; cl_kernel kernel_mul_mat_q4_0_f32, kernel_mul_mat_q4_0_f32_v; cl_kernel kernel_convert_block_q4_0, kernel_restore_block_q4_0; cl_kernel kernel_convert_block_mxfp4, kernel_convert_block_mxfp4_trans, kernel_restore_block_mxfp4, kernel_restore_block_mxfp4_trans; @@ -1204,6 +1208,25 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve GGML_LOG_CONT("."); } + // mul_mm_f16_f32_kq_kqv + { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "mul_mm_f16_f32_kq_kqv.cl.h" + }; +#else + const std::string kernel_src = read_file("mul_mm_f16_f32_kq_kqv.cl"); +#endif + backend_ctx->program_mul_mm_f16_f32_kqv = + build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts+" -DKQV "); + backend_ctx->program_mul_mm_f16_f32_kq = + build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); + + CL_CHECK((backend_ctx->kernel_mul_mm_f16_f32_kqv = clCreateKernel(backend_ctx->program_mul_mm_f16_f32_kqv, "mul_mm_f16_f32_kqv", &err), err)); + CL_CHECK((backend_ctx->kernel_mul_mm_f16_f32_kq = clCreateKernel(backend_ctx->program_mul_mm_f16_f32_kq, "mul_mm_f16_f32_kq", &err), err)); + GGML_LOG_CONT("."); + } + // mul { #ifdef GGML_OPENCL_EMBED_KERNELS @@ -6694,6 +6717,128 @@ static void ggml_cl_mul_mat(ggml_backend_t backend, const ggml_tensor * src0, co #ifdef GGML_OPENCL_USE_ADRENO_KERNELS cl_context context = backend_ctx->context; + cl_int status; + cl_image_format img_fmt_1d; + cl_image_desc img_desc_1d; + cl_buffer_region region; + cl_mem A_image1d; + cl_mem A_sub_buffer; + cl_mem B_sub_buffer; + cl_mem D_image1d; + cl_mem D_sub_buffer; + + int M = ne01; + int N = ne1; + int K = ne00; + + if(src0t == GGML_TYPE_F16 && src1t == GGML_TYPE_F32){ + + if (M >= 64 && N >= 32 && K >= 16 && (ne12 % ne02) == 0){ + if (nb01 > nb02) { + // KQ + kernel = backend_ctx->kernel_mul_mm_f16_f32_kq; + } else { + // KQV + kernel = 
backend_ctx->kernel_mul_mm_f16_f32_kqv; + } + // create sub-buffer for A + // <--------------------------------------------> // + extra0 = src0->view_src ? (ggml_tensor_extra_cl *)src0->view_src->extra : (ggml_tensor_extra_cl *)src0->extra; + + region.origin = (extra0->offset); + if (nb01 > nb02) { + // KQ + region.size = nb01 * ne01; + } else { + // KQV + region.size = nb02 * ne02; + } + + A_sub_buffer = clCreateSubBuffer((extra0->data_device), 0, CL_BUFFER_CREATE_TYPE_REGION, ®ion, &status); + CL_CHECK(status); + + // <--------------------------------------------> // + + // create sub-buffer for B + // <--------------------------------------------> // + region.origin = (extra1->offset); + region.size = nb10 * ne10 * ne11 * ne12; + B_sub_buffer = clCreateSubBuffer((extra1->data_device), 0, CL_BUFFER_CREATE_TYPE_REGION, ®ion, &status); + CL_CHECK(status); + // <--------------------------------------------> // + + img_fmt_1d = {CL_RGBA, CL_FLOAT}; + memset(&img_desc_1d, 0, sizeof(img_desc_1d)); + img_desc_1d.image_type = CL_MEM_OBJECT_IMAGE1D_BUFFER; + if (nb01 > nb02) { + img_desc_1d.image_width = (nb01 * ne01 / 4)/4; + } + else { + img_desc_1d.image_width = (nb02 * ne02 / 4)/4; + } + img_desc_1d.buffer = A_sub_buffer; + A_image1d = clCreateImage(context, CL_MEM_READ_ONLY, &img_fmt_1d, &img_desc_1d, NULL, &status); + CL_CHECK(status); + + + // create sub-buffer for output C + // <--------------------------------------------> // + region.origin = (extrad->offset); + region.size = ne0 * ne1 * dst->ne[2] * dst->nb[0]; // size of C in bytes + D_sub_buffer = clCreateSubBuffer((extrad->data_device), 0, CL_BUFFER_CREATE_TYPE_REGION, ®ion, &status); + CL_CHECK(status); + // <--------------------------------------------> // + + // create image for C output + // <--------------------------------------------> // + img_fmt_1d = {CL_R, CL_FLOAT}; + memset(&img_desc_1d, 0, sizeof(img_desc_1d)); + img_desc_1d.image_type = CL_MEM_OBJECT_IMAGE1D_BUFFER; + img_desc_1d.image_width = ne0 * ne1 * dst->ne[2] * dst->nb[0] / 4; + img_desc_1d.buffer = D_sub_buffer; + D_image1d = clCreateImage(context, CL_MEM_WRITE_ONLY, &img_fmt_1d, &img_desc_1d, NULL, &status); + CL_CHECK(status); + // <--------------------------------------------> // + + // offsets = 0 when using image + int offset0 = 0; + int offset1 = 0; + + // set kernel args + // <--------------------------------------------> // + cl_uint k_arg = 0; + CL_CHECK(clSetKernelArg(kernel, k_arg++, sizeof(cl_mem), &A_image1d)); + CL_CHECK(clSetKernelArg(kernel, k_arg++, sizeof(int), &offset0)); + CL_CHECK(clSetKernelArg(kernel, k_arg++, sizeof(cl_mem), &B_sub_buffer)); + CL_CHECK(clSetKernelArg(kernel, k_arg++, sizeof(int), &offset1)); + CL_CHECK(clSetKernelArg(kernel, k_arg++, sizeof(cl_mem), &D_image1d)); + CL_CHECK(clSetKernelArg(kernel, k_arg++, sizeof(int), &extrad->offset)); + CL_CHECK(clSetKernelArg(kernel, k_arg++, sizeof(int), &M)); + CL_CHECK(clSetKernelArg(kernel, k_arg++, sizeof(int), &K)); + CL_CHECK(clSetKernelArg(kernel, k_arg++, sizeof(int), &N)); + CL_CHECK(clSetKernelArg(kernel, k_arg++, sizeof(int), &ne02)); + CL_CHECK(clSetKernelArg(kernel, k_arg++, sizeof(int), &ne12)); + CL_CHECK(clSetKernelArg(kernel, k_arg++, sizeof(int), &nb01)); + + size_t global_work_size[3] = {64, static_cast(((M+63)/64)), static_cast(((N+31)/32)*ne12)}; + size_t local_work_size[3] = {64, 1, 2}; + + backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst); + + // deallocate sub buffers and images + // 
<--------------------------------------------> // + CL_CHECK(clReleaseMemObject(A_image1d)); + CL_CHECK(clReleaseMemObject(D_image1d)); + CL_CHECK(clReleaseMemObject(A_sub_buffer)); + CL_CHECK(clReleaseMemObject(B_sub_buffer)); + CL_CHECK(clReleaseMemObject(D_sub_buffer)); + // <--------------------------------------------> // + + return; + + } + } + if (ne01 && ne1 && use_adreno_kernels(backend_ctx, src0)) { // init CL objects diff --git a/ggml/src/ggml-opencl/kernels/mul_mm_f16_f32_kq_kqv.cl b/ggml/src/ggml-opencl/kernels/mul_mm_f16_f32_kq_kqv.cl new file mode 100644 index 0000000000000..6112e8d89804d --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/mul_mm_f16_f32_kq_kqv.cl @@ -0,0 +1,271 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable +#pragma OPENCL EXTENSION cl_khr_subgroups : enable + +#define LM_FIRST_256B 0 +#define LM_SECOND_256B 64 +#define LM_THIRD_256B 128 +#define LM_FOURTH_256B 192 + + +inline float16 mm_load_a(image1d_buffer_t matrix_A, uint subMatrixAStartInElements, int nb01, int line_stride_matrix_A_in_bytes) +{ + __private float8 regA; + size_t sub_block_id_m = get_local_id(0); + +#ifdef KQV + uint a_texCoord = subMatrixAStartInElements/2 + (sub_block_id_m * nb01/4); +#else // KQ + uint a_texCoord = subMatrixAStartInElements/2 + (sub_block_id_m * line_stride_matrix_A_in_bytes/4); +#endif + + regA.s0123 = read_imagef(matrix_A, a_texCoord/4); + regA.s4567 = read_imagef(matrix_A, (a_texCoord+4)/4); + + return convert_float16(as_half16(regA)); +} + +inline void alu_32(float* regC, float16 regA, __local float* matrix_B_local, int wave_offset) { + + __private float4 rC = 0; + int i = wave_offset; + + rC += regA.s0 * ((__local float4*)matrix_B_local)[i]; + rC += regA.s1 * ((__local float4*)matrix_B_local)[i + 16]; + rC += regA.s4 * ((__local float4*)matrix_B_local)[i + 1]; + rC += regA.s5 * ((__local float4*)matrix_B_local)[i + 17]; + rC += regA.s8 * ((__local float4*)matrix_B_local)[i + 2]; + rC += regA.s9 * ((__local float4*)matrix_B_local)[i + 18]; + rC += regA.sc * ((__local float4*)matrix_B_local)[i + 3]; + rC += regA.sd * ((__local float4*)matrix_B_local)[i + 19]; + + i += 32; + + rC += regA.s2 * ((__local float4*)matrix_B_local)[i]; + rC += regA.s3 * ((__local float4*)matrix_B_local)[i + 16]; + rC += regA.s6 * ((__local float4*)matrix_B_local)[i + 1]; + rC += regA.s7 * ((__local float4*)matrix_B_local)[i + 17]; + rC += regA.sa * ((__local float4*)matrix_B_local)[i + 2]; + rC += regA.sb * ((__local float4*)matrix_B_local)[i + 18]; + rC += regA.se * ((__local float4*)matrix_B_local)[i + 3]; + rC += regA.sf * ((__local float4*)matrix_B_local)[i + 19]; + + float4* regC_vec = (float4*)regC; + *regC_vec += rC; +} + +inline void mm_mad(__local float* matrix_B_local, float* regC1_ptr, float* regC2_ptr, float* regC3_ptr, float* regC4_ptr, float* regC5_ptr, float* regC6_ptr, float* regC7_ptr, float* regC8_ptr, float16 regA, float8 regB, uint b_localOffsetInWords) +{ + + short linearIndex = get_local_id(0); + + int wave_offset = get_sub_group_id() * 64; + int offset = b_localOffsetInWords + get_sub_group_id() * 256; + + matrix_B_local[offset + LM_FIRST_256B] = regB.s0; + matrix_B_local[offset + LM_SECOND_256B] = regB.s1; + + matrix_B_local[offset + LM_THIRD_256B] = regB.s2; + matrix_B_local[offset + LM_FOURTH_256B] = regB.s3; + + alu_32(regC1_ptr, regA, matrix_B_local + 0, wave_offset); + + alu_32(regC2_ptr, regA, matrix_B_local + 16, wave_offset); + + alu_32(regC3_ptr, regA, matrix_B_local + 32, wave_offset); + + alu_32(regC4_ptr, regA, matrix_B_local + 48, wave_offset); + + + 
matrix_B_local[offset + LM_FIRST_256B] = regB.s4; + matrix_B_local[offset + LM_SECOND_256B] = regB.s5; + + matrix_B_local[offset + LM_THIRD_256B] = regB.s6; + matrix_B_local[offset + LM_FOURTH_256B] = regB.s7; + + alu_32(regC5_ptr, regA, matrix_B_local + 0, wave_offset); + + alu_32(regC6_ptr, regA, matrix_B_local + 16, wave_offset); + + alu_32(regC7_ptr, regA, matrix_B_local + 32, wave_offset); + + alu_32(regC8_ptr, regA, matrix_B_local + 48, wave_offset); + +} + +inline void mm_store_c_N(__write_only image1d_buffer_t matrix_C, float4 regC_1, float4 regC_2, float4 regC_3, float4 regC_4, +float4 regC_5, float4 regC_6, float4 regC_7, float4 regC_8, uint subMatrixCStartInElements, int line_stride_matrix_C_in_bytes, int mask) +{ + size_t sub_block_id_m = get_local_id(0); + short linearIndex = get_local_id(0); + + uint strideInWords = line_stride_matrix_C_in_bytes/4; + uint c_coordInWords_0 = (subMatrixCStartInElements + sub_block_id_m); + + uint c_coordInWords_1 = c_coordInWords_0 + 1 * strideInWords; + uint c_coordInWords_2 = c_coordInWords_0 + 2 * strideInWords; + uint c_coordInWords_3 = c_coordInWords_0 + 3 * strideInWords; + uint c_coordInWords_4 = c_coordInWords_0 + 4 * strideInWords; + uint c_coordInWords_5 = c_coordInWords_0 + 5 * strideInWords; + uint c_coordInWords_6 = c_coordInWords_0 + 6 * strideInWords; + uint c_coordInWords_7 = c_coordInWords_0 + 7 * strideInWords; + uint c_coordInWords_8 = c_coordInWords_0 + 8 * strideInWords; + uint c_coordInWords_9 = c_coordInWords_0 + 9 * strideInWords; + uint c_coordInWords_10 = c_coordInWords_0 + 10 * strideInWords; + uint c_coordInWords_11 = c_coordInWords_0 + 11 * strideInWords; + uint c_coordInWords_12 = c_coordInWords_0 + 12 * strideInWords; + uint c_coordInWords_13 = c_coordInWords_0 + 13 * strideInWords; + uint c_coordInWords_14 = c_coordInWords_0 + 14 * strideInWords; + uint c_coordInWords_15 = c_coordInWords_0 + 15 * strideInWords; + uint c_coordInWords_16 = c_coordInWords_0 + 16 * strideInWords; + uint c_coordInWords_17 = c_coordInWords_0 + 17 * strideInWords; + uint c_coordInWords_18 = c_coordInWords_0 + 18 * strideInWords; + uint c_coordInWords_19 = c_coordInWords_0 + 19 * strideInWords; + uint c_coordInWords_20 = c_coordInWords_0 + 20 * strideInWords; + uint c_coordInWords_21 = c_coordInWords_0 + 21 * strideInWords; + uint c_coordInWords_22 = c_coordInWords_0 + 22 * strideInWords; + uint c_coordInWords_23 = c_coordInWords_0 + 23 * strideInWords; + uint c_coordInWords_24 = c_coordInWords_0 + 24 * strideInWords; + uint c_coordInWords_25 = c_coordInWords_0 + 25 * strideInWords; + uint c_coordInWords_26 = c_coordInWords_0 + 26 * strideInWords; + uint c_coordInWords_27 = c_coordInWords_0 + 27 * strideInWords; + uint c_coordInWords_28 = c_coordInWords_0 + 28 * strideInWords; + uint c_coordInWords_29 = c_coordInWords_0 + 29 * strideInWords; + uint c_coordInWords_30 = c_coordInWords_0 + 30 * strideInWords; + uint c_coordInWords_31 = c_coordInWords_0 + 31 * strideInWords; + + if (mask > 0) { write_imagef(matrix_C, c_coordInWords_0, regC_1.s0); } + if (mask > 1) { write_imagef(matrix_C, c_coordInWords_1, regC_1.s1); } + if (mask > 2) { write_imagef(matrix_C, c_coordInWords_2, regC_1.s2); } + if (mask > 3) { write_imagef(matrix_C, c_coordInWords_3, regC_1.s3); } + if (mask > 4) { write_imagef(matrix_C, c_coordInWords_4, regC_2.s0); } + if (mask > 5) { write_imagef(matrix_C, c_coordInWords_5, regC_2.s1); } + if (mask > 6) { write_imagef(matrix_C, c_coordInWords_6, regC_2.s2); } + if (mask > 7) { write_imagef(matrix_C, c_coordInWords_7, 
regC_2.s3); } + if (mask > 8) { write_imagef(matrix_C, c_coordInWords_8, regC_3.s0); } + if (mask > 9) { write_imagef(matrix_C, c_coordInWords_9, regC_3.s1); } + if (mask > 10) { write_imagef(matrix_C, c_coordInWords_10, regC_3.s2); } + if (mask > 11) { write_imagef(matrix_C, c_coordInWords_11, regC_3.s3); } + if (mask > 12) { write_imagef(matrix_C, c_coordInWords_12, regC_4.s0); } + if (mask > 13) { write_imagef(matrix_C, c_coordInWords_13, regC_4.s1); } + if (mask > 14) { write_imagef(matrix_C, c_coordInWords_14, regC_4.s2); } + if (mask > 15) { write_imagef(matrix_C, c_coordInWords_15, regC_4.s3); } + if (mask > 16) { write_imagef(matrix_C, c_coordInWords_16, regC_5.s0); } + if (mask > 17) { write_imagef(matrix_C, c_coordInWords_17, regC_5.s1); } + if (mask > 18) { write_imagef(matrix_C, c_coordInWords_18, regC_5.s2); } + if (mask > 19) { write_imagef(matrix_C, c_coordInWords_19, regC_5.s3); } + if (mask > 20) { write_imagef(matrix_C, c_coordInWords_20, regC_6.s0); } + if (mask > 21) { write_imagef(matrix_C, c_coordInWords_21, regC_6.s1); } + if (mask > 22) { write_imagef(matrix_C, c_coordInWords_22, regC_6.s2); } + if (mask > 23) { write_imagef(matrix_C, c_coordInWords_23, regC_6.s3); } + if (mask > 24) { write_imagef(matrix_C, c_coordInWords_24, regC_7.s0); } + if (mask > 25) { write_imagef(matrix_C, c_coordInWords_25, regC_7.s1); } + if (mask > 26) { write_imagef(matrix_C, c_coordInWords_26, regC_7.s2); } + if (mask > 27) { write_imagef(matrix_C, c_coordInWords_27, regC_7.s3); } + if (mask > 28) { write_imagef(matrix_C, c_coordInWords_28, regC_8.s0); } + if (mask > 29) { write_imagef(matrix_C, c_coordInWords_29, regC_8.s1); } + if (mask > 30) { write_imagef(matrix_C, c_coordInWords_30, regC_8.s2); } + if (mask > 31) { write_imagef(matrix_C, c_coordInWords_31, regC_8.s3); } + +} + +#define TILESIZE_K 16 +#define TILESIZE_M 64 +#define TILESIZE_N 32 +#ifdef KQV +__kernel void mul_mm_f16_f32_kqv( +#else +__kernel void mul_mm_f16_f32_kq( +#endif + __read_only image1d_buffer_t matrix_A, + int offset0, + __global float* matrix_B, + int offset1, + __write_only image1d_buffer_t matrix_C, + int offsetd, + int M, int K, int N, + int D_A, + int D_B, + int nb01 +) +{ + + uint block_id_m = get_global_id(1); + uint block_id_n = get_global_id(2) % ((N+TILESIZE_N-1)/TILESIZE_N); + uint block_id_d = get_global_id(2) / ((N+TILESIZE_N-1)/TILESIZE_N); + + __private float16 regA; + __private float8 regB; + __private float4 regC_1; + __private float4 regC_2; + __private float4 regC_3; + __private float4 regC_4; + __private float4 regC_5; + __private float4 regC_6; + __private float4 regC_7; + __private float4 regC_8; + + const uint col = block_id_m * TILESIZE_M; + const uint row = block_id_n * TILESIZE_N; + const uint depth_A = block_id_d / (D_B/D_A); + const uint depth_B = block_id_d; + +#ifdef KQV + int line_stride_matrix_A_in_bytes = nb01 * M; + int line_stride_matrix_B_in_bytes = K * N * 4; +#else + int line_stride_matrix_A_in_bytes = K * D_A * 2; + int line_stride_matrix_B_in_bytes = K * D_B * 4; +#endif + + int line_stride_matrix_C_in_bytes = M * 4; + + const uint strideAinElements = line_stride_matrix_A_in_bytes / 2; + const uint strideBinElements = line_stride_matrix_B_in_bytes / 4; + + size_t sub_block_id_m = get_local_id(0); + + uint b_localOffsetInWords = (sub_block_id_m/16)*16 + + ((((sub_block_id_m)>>0)&1)<<2) + + ((((sub_block_id_m)>>1)&1)<<3) + + ((((sub_block_id_m)>>2)&1)<<0) + + ((((sub_block_id_m)>>3)&1)<<1); + + uint2 b_globalOffsetInWords_xy = {((sub_block_id_m%4)*4), 
(sub_block_id_m>>2)}; + uint b_globalOffsetInWords00, b_globalOffsetInWords16; + #ifdef KQV + b_globalOffsetInWords00 = b_globalOffsetInWords_xy.x + b_globalOffsetInWords_xy.y*K; + b_globalOffsetInWords16 = b_globalOffsetInWords00 + (16 * K); + uint subMatrixAStartInElements = depth_A * strideAinElements + col * nb01 / 2; + uint subMatrixBStartInElements = depth_B * strideBinElements + row * K; + #else + b_globalOffsetInWords00 = b_globalOffsetInWords_xy.x + b_globalOffsetInWords_xy.y*line_stride_matrix_B_in_bytes/4; + b_globalOffsetInWords16 = b_globalOffsetInWords00 + (16 * line_stride_matrix_B_in_bytes/4); + uint subMatrixAStartInElements = col * strideAinElements + depth_A * K; + uint subMatrixBStartInElements = row * strideBinElements + depth_B * K; + #endif + + __local float matrix_B_local[1024]; + + for (uint step=0; step < K; step+=TILESIZE_K) + { + size_t sub_block_id_m = get_local_id(0); + regA = mm_load_a(matrix_A, subMatrixAStartInElements, nb01, line_stride_matrix_A_in_bytes); + + + uint b_coordInWords00 = subMatrixBStartInElements + b_globalOffsetInWords00; + uint b_coordInWords16 = subMatrixBStartInElements + b_globalOffsetInWords16; + + regB.s0123 = vload4(b_coordInWords00/4, matrix_B); + regB.s4567 = vload4(b_coordInWords16/4, matrix_B); + + mm_mad(matrix_B_local, (float *)®C_1, (float *)®C_2, (float *)®C_3, (float *)®C_4, (float *)®C_5, (float *)®C_6, (float *)®C_7, (float *)®C_8, regA, regB, b_localOffsetInWords); + + subMatrixAStartInElements += TILESIZE_K; + subMatrixBStartInElements += TILESIZE_K; + + } + + uint subMatrixCStartInElements = depth_B * N * M + row * M + col; + mm_store_c_N(matrix_C, regC_1, regC_2, regC_3, regC_4,regC_5, regC_6, regC_7,regC_8, subMatrixCStartInElements, line_stride_matrix_C_in_bytes, (N-block_id_n*32)); + +} \ No newline at end of file From 24f32df423352a7fac3c39898bbeffceb372afb9 Mon Sep 17 00:00:00 2001 From: shaoqi Date: Mon, 27 Oct 2025 11:55:54 -0700 Subject: [PATCH 2/7] Add ggml_cl_mul_mat_kq_kqv_adreno func --- ggml/src/ggml-opencl/ggml-opencl.cpp | 277 ++++++++++++++++----------- 1 file changed, 160 insertions(+), 117 deletions(-) diff --git a/ggml/src/ggml-opencl/ggml-opencl.cpp b/ggml/src/ggml-opencl/ggml-opencl.cpp index 2bed54cc009bb..ec7cc3e3736c3 100644 --- a/ggml/src/ggml-opencl/ggml-opencl.cpp +++ b/ggml/src/ggml-opencl/ggml-opencl.cpp @@ -6651,6 +6651,164 @@ static void ggml_cl_conv_2d(ggml_backend_t backend, const ggml_tensor * src0, co backend_ctx->enqueue_ndrange_kernel(kernel, 2, global_work_size, local_work_size, dst); } +static void ggml_cl_mul_mat_kq_kqv_adreno(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { + ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context; + + ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra; + ggml_tensor_extra_cl * extra1 = (ggml_tensor_extra_cl *)src1->extra; + ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra; + + cl_ulong offset0 = extra0->offset + src0->view_offs; + cl_ulong offset1 = extra1->offset + src1->view_offs; + cl_ulong offsetd = extrad->offset + dst->view_offs; + + const int ne00 = src0 ? src0->ne[0] : 0; + const int ne01 = src0 ? src0->ne[1] : 0; + const int ne02 = src0 ? src0->ne[2] : 0; + const int ne03 = src0 ? src0->ne[3] : 0; + + const cl_ulong nb00 = src0 ? src0->nb[0] : 0; + const cl_ulong nb01 = src0 ? src0->nb[1] : 0; + const cl_ulong nb02 = src0 ? src0->nb[2] : 0; + const cl_ulong nb03 = src0 ? src0->nb[3] : 0; + + const int ne10 = src1 ? 
src1->ne[0] : 0; + const int ne11 = src1 ? src1->ne[1] : 0; + const int ne12 = src1 ? src1->ne[2] : 0; + const int ne13 = src1 ? src1->ne[3] : 0; + + const cl_ulong nb10 = src1 ? src1->nb[0] : 0; + const cl_ulong nb11 = src1 ? src1->nb[1] : 0; + const cl_ulong nb12 = src1 ? src1->nb[2] : 0; + const cl_ulong nb13 = src1 ? src1->nb[3] : 0; + + const int ne0 = dst ? dst->ne[0] : 0; + const int ne1 = dst ? dst->ne[1] : 0; + + int r2 = ne12/ne02; + int r3 = ne13/ne03; + + GGML_ASSERT(ne00 == ne10); + + cl_kernel kernel; + cl_context context = backend_ctx->context; + + cl_int status; + cl_image_format img_fmt_1d; + cl_image_desc img_desc_1d; + cl_buffer_region region; + cl_mem A_image1d; + cl_mem A_sub_buffer; + cl_mem B_sub_buffer; + cl_mem D_image1d; + cl_mem D_sub_buffer; + + int M = ne01; + int N = ne1; + int K = ne00; + + if (nb01 > nb02) { + // KQ + kernel = backend_ctx->kernel_mul_mm_f16_f32_kq; + } else { + // KQV + kernel = backend_ctx->kernel_mul_mm_f16_f32_kqv; + } + // create sub-buffer for A + // <--------------------------------------------> // + extra0 = src0->view_src ? (ggml_tensor_extra_cl *)src0->view_src->extra : (ggml_tensor_extra_cl *)src0->extra; + + region.origin = (extra0->offset); + if (nb01 > nb02) { + // KQ + region.size = nb01 * ne01; + } else { + // KQV + region.size = nb02 * ne02; + } + + A_sub_buffer = clCreateSubBuffer((extra0->data_device), 0, CL_BUFFER_CREATE_TYPE_REGION, ®ion, &status); + CL_CHECK(status); + + // <--------------------------------------------> // + + // create sub-buffer for B + // <--------------------------------------------> // + region.origin = (extra1->offset); + region.size = nb10 * ne10 * ne11 * ne12; + B_sub_buffer = clCreateSubBuffer((extra1->data_device), 0, CL_BUFFER_CREATE_TYPE_REGION, ®ion, &status); + CL_CHECK(status); + // <--------------------------------------------> // + + img_fmt_1d = {CL_RGBA, CL_FLOAT}; + memset(&img_desc_1d, 0, sizeof(img_desc_1d)); + img_desc_1d.image_type = CL_MEM_OBJECT_IMAGE1D_BUFFER; + if (nb01 > nb02) { + img_desc_1d.image_width = (nb01 * ne01 / 4)/4; + } + else { + img_desc_1d.image_width = (nb02 * ne02 / 4)/4; + } + img_desc_1d.buffer = A_sub_buffer; + A_image1d = clCreateImage(context, CL_MEM_READ_ONLY, &img_fmt_1d, &img_desc_1d, NULL, &status); + CL_CHECK(status); + + // create sub-buffer for output C + // <--------------------------------------------> // + region.origin = (extrad->offset); + region.size = ne0 * ne1 * dst->ne[2] * dst->nb[0]; // size of C in bytes + D_sub_buffer = clCreateSubBuffer((extrad->data_device), 0, CL_BUFFER_CREATE_TYPE_REGION, ®ion, &status); + CL_CHECK(status); + // <--------------------------------------------> // + + // create image for C output + // <--------------------------------------------> // + img_fmt_1d = {CL_R, CL_FLOAT}; + memset(&img_desc_1d, 0, sizeof(img_desc_1d)); + img_desc_1d.image_type = CL_MEM_OBJECT_IMAGE1D_BUFFER; + img_desc_1d.image_width = ne0 * ne1 * dst->ne[2] * dst->nb[0] / 4; + img_desc_1d.buffer = D_sub_buffer; + D_image1d = clCreateImage(context, CL_MEM_WRITE_ONLY, &img_fmt_1d, &img_desc_1d, NULL, &status); + CL_CHECK(status); + // <--------------------------------------------> // + + uint offset_src0 = 0; + uint offset_src1 = 0; + + // set kernel args + // <--------------------------------------------> // + cl_uint k_arg = 0; + CL_CHECK(clSetKernelArg(kernel, k_arg++, sizeof(cl_mem), &A_image1d)); + CL_CHECK(clSetKernelArg(kernel, k_arg++, sizeof(int), &offset_src0)); + CL_CHECK(clSetKernelArg(kernel, k_arg++, sizeof(cl_mem), 
&B_sub_buffer)); + CL_CHECK(clSetKernelArg(kernel, k_arg++, sizeof(int), &offset_src1)); + CL_CHECK(clSetKernelArg(kernel, k_arg++, sizeof(cl_mem), &D_image1d)); + CL_CHECK(clSetKernelArg(kernel, k_arg++, sizeof(int), &extrad->offset)); + CL_CHECK(clSetKernelArg(kernel, k_arg++, sizeof(int), &M)); + CL_CHECK(clSetKernelArg(kernel, k_arg++, sizeof(int), &K)); + CL_CHECK(clSetKernelArg(kernel, k_arg++, sizeof(int), &N)); + CL_CHECK(clSetKernelArg(kernel, k_arg++, sizeof(int), &ne02)); + CL_CHECK(clSetKernelArg(kernel, k_arg++, sizeof(int), &ne12)); + CL_CHECK(clSetKernelArg(kernel, k_arg++, sizeof(int), &nb01)); + + size_t global_work_size[3] = {64, static_cast(((M+63)/64)), static_cast(((N+31)/32)*ne12)}; + size_t local_work_size[3] = {64, 1, 2}; + + backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst); + + // deallocate sub buffers and images + // <--------------------------------------------> // + CL_CHECK(clReleaseMemObject(A_image1d)); + CL_CHECK(clReleaseMemObject(D_image1d)); + CL_CHECK(clReleaseMemObject(A_sub_buffer)); + CL_CHECK(clReleaseMemObject(B_sub_buffer)); + CL_CHECK(clReleaseMemObject(D_sub_buffer)); + // <--------------------------------------------> // + + return; + +} + static void ggml_cl_mul_mat(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { GGML_ASSERT(src0); GGML_ASSERT(src0->extra); @@ -6717,125 +6875,10 @@ static void ggml_cl_mul_mat(ggml_backend_t backend, const ggml_tensor * src0, co #ifdef GGML_OPENCL_USE_ADRENO_KERNELS cl_context context = backend_ctx->context; - cl_int status; - cl_image_format img_fmt_1d; - cl_image_desc img_desc_1d; - cl_buffer_region region; - cl_mem A_image1d; - cl_mem A_sub_buffer; - cl_mem B_sub_buffer; - cl_mem D_image1d; - cl_mem D_sub_buffer; - - int M = ne01; - int N = ne1; - int K = ne00; - if(src0t == GGML_TYPE_F16 && src1t == GGML_TYPE_F32){ - - if (M >= 64 && N >= 32 && K >= 16 && (ne12 % ne02) == 0){ - if (nb01 > nb02) { - // KQ - kernel = backend_ctx->kernel_mul_mm_f16_f32_kq; - } else { - // KQV - kernel = backend_ctx->kernel_mul_mm_f16_f32_kqv; - } - // create sub-buffer for A - // <--------------------------------------------> // - extra0 = src0->view_src ? 
(ggml_tensor_extra_cl *)src0->view_src->extra : (ggml_tensor_extra_cl *)src0->extra; - - region.origin = (extra0->offset); - if (nb01 > nb02) { - // KQ - region.size = nb01 * ne01; - } else { - // KQV - region.size = nb02 * ne02; - } - - A_sub_buffer = clCreateSubBuffer((extra0->data_device), 0, CL_BUFFER_CREATE_TYPE_REGION, ®ion, &status); - CL_CHECK(status); - - // <--------------------------------------------> // - - // create sub-buffer for B - // <--------------------------------------------> // - region.origin = (extra1->offset); - region.size = nb10 * ne10 * ne11 * ne12; - B_sub_buffer = clCreateSubBuffer((extra1->data_device), 0, CL_BUFFER_CREATE_TYPE_REGION, ®ion, &status); - CL_CHECK(status); - // <--------------------------------------------> // - - img_fmt_1d = {CL_RGBA, CL_FLOAT}; - memset(&img_desc_1d, 0, sizeof(img_desc_1d)); - img_desc_1d.image_type = CL_MEM_OBJECT_IMAGE1D_BUFFER; - if (nb01 > nb02) { - img_desc_1d.image_width = (nb01 * ne01 / 4)/4; - } - else { - img_desc_1d.image_width = (nb02 * ne02 / 4)/4; - } - img_desc_1d.buffer = A_sub_buffer; - A_image1d = clCreateImage(context, CL_MEM_READ_ONLY, &img_fmt_1d, &img_desc_1d, NULL, &status); - CL_CHECK(status); - - - // create sub-buffer for output C - // <--------------------------------------------> // - region.origin = (extrad->offset); - region.size = ne0 * ne1 * dst->ne[2] * dst->nb[0]; // size of C in bytes - D_sub_buffer = clCreateSubBuffer((extrad->data_device), 0, CL_BUFFER_CREATE_TYPE_REGION, ®ion, &status); - CL_CHECK(status); - // <--------------------------------------------> // - - // create image for C output - // <--------------------------------------------> // - img_fmt_1d = {CL_R, CL_FLOAT}; - memset(&img_desc_1d, 0, sizeof(img_desc_1d)); - img_desc_1d.image_type = CL_MEM_OBJECT_IMAGE1D_BUFFER; - img_desc_1d.image_width = ne0 * ne1 * dst->ne[2] * dst->nb[0] / 4; - img_desc_1d.buffer = D_sub_buffer; - D_image1d = clCreateImage(context, CL_MEM_WRITE_ONLY, &img_fmt_1d, &img_desc_1d, NULL, &status); - CL_CHECK(status); - // <--------------------------------------------> // - - // offsets = 0 when using image - int offset0 = 0; - int offset1 = 0; - - // set kernel args - // <--------------------------------------------> // - cl_uint k_arg = 0; - CL_CHECK(clSetKernelArg(kernel, k_arg++, sizeof(cl_mem), &A_image1d)); - CL_CHECK(clSetKernelArg(kernel, k_arg++, sizeof(int), &offset0)); - CL_CHECK(clSetKernelArg(kernel, k_arg++, sizeof(cl_mem), &B_sub_buffer)); - CL_CHECK(clSetKernelArg(kernel, k_arg++, sizeof(int), &offset1)); - CL_CHECK(clSetKernelArg(kernel, k_arg++, sizeof(cl_mem), &D_image1d)); - CL_CHECK(clSetKernelArg(kernel, k_arg++, sizeof(int), &extrad->offset)); - CL_CHECK(clSetKernelArg(kernel, k_arg++, sizeof(int), &M)); - CL_CHECK(clSetKernelArg(kernel, k_arg++, sizeof(int), &K)); - CL_CHECK(clSetKernelArg(kernel, k_arg++, sizeof(int), &N)); - CL_CHECK(clSetKernelArg(kernel, k_arg++, sizeof(int), &ne02)); - CL_CHECK(clSetKernelArg(kernel, k_arg++, sizeof(int), &ne12)); - CL_CHECK(clSetKernelArg(kernel, k_arg++, sizeof(int), &nb01)); - - size_t global_work_size[3] = {64, static_cast(((M+63)/64)), static_cast(((N+31)/32)*ne12)}; - size_t local_work_size[3] = {64, 1, 2}; - - backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst); - - // deallocate sub buffers and images - // <--------------------------------------------> // - CL_CHECK(clReleaseMemObject(A_image1d)); - CL_CHECK(clReleaseMemObject(D_image1d)); - CL_CHECK(clReleaseMemObject(A_sub_buffer)); - 
CL_CHECK(clReleaseMemObject(B_sub_buffer)); - CL_CHECK(clReleaseMemObject(D_sub_buffer)); - // <--------------------------------------------> // - + if (ne01 >= 64 && ne1 >= 32 && ne00 >= 16 && (ne12 % ne02) == 0){ + ggml_cl_mul_mat_kq_kqv_adreno(backend, src0, src1, dst); return; - } } From dada5171fe2934ee98a59e789717e656a0b818c5 Mon Sep 17 00:00:00 2001 From: shaoqi Date: Mon, 27 Oct 2025 12:27:16 -0700 Subject: [PATCH 3/7] fix whitespace --- ggml/src/ggml-opencl/kernels/mul_mm_f16_f32_kq_kqv.cl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ggml/src/ggml-opencl/kernels/mul_mm_f16_f32_kq_kqv.cl b/ggml/src/ggml-opencl/kernels/mul_mm_f16_f32_kq_kqv.cl index 6112e8d89804d..798f773a8d3f7 100644 --- a/ggml/src/ggml-opencl/kernels/mul_mm_f16_f32_kq_kqv.cl +++ b/ggml/src/ggml-opencl/kernels/mul_mm_f16_f32_kq_kqv.cl @@ -268,4 +268,4 @@ __kernel void mul_mm_f16_f32_kq( uint subMatrixCStartInElements = depth_B * N * M + row * M + col; mm_store_c_N(matrix_C, regC_1, regC_2, regC_3, regC_4,regC_5, regC_6, regC_7,regC_8, subMatrixCStartInElements, line_stride_matrix_C_in_bytes, (N-block_id_n*32)); -} \ No newline at end of file +} From 0fc4b8bd96b12e7afe5d8fd8b5c065d1a40c40c3 Mon Sep 17 00:00:00 2001 From: shaoqi Date: Fri, 7 Nov 2025 15:01:17 -0800 Subject: [PATCH 4/7] remove unused variable --- ggml/src/ggml-opencl/ggml-opencl.cpp | 40 +++++++++------------------- 1 file changed, 13 insertions(+), 27 deletions(-) diff --git a/ggml/src/ggml-opencl/ggml-opencl.cpp b/ggml/src/ggml-opencl/ggml-opencl.cpp index ec7cc3e3736c3..e6d248a8f1aa3 100644 --- a/ggml/src/ggml-opencl/ggml-opencl.cpp +++ b/ggml/src/ggml-opencl/ggml-opencl.cpp @@ -6658,35 +6658,21 @@ static void ggml_cl_mul_mat_kq_kqv_adreno(ggml_backend_t backend, const ggml_ten ggml_tensor_extra_cl * extra1 = (ggml_tensor_extra_cl *)src1->extra; ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra; - cl_ulong offset0 = extra0->offset + src0->view_offs; - cl_ulong offset1 = extra1->offset + src1->view_offs; - cl_ulong offsetd = extrad->offset + dst->view_offs; - - const int ne00 = src0 ? src0->ne[0] : 0; - const int ne01 = src0 ? src0->ne[1] : 0; - const int ne02 = src0 ? src0->ne[2] : 0; - const int ne03 = src0 ? src0->ne[3] : 0; + const int ne00 = src0->ne[0]; + const int ne01 = src0->ne[1]; + const int ne02 = src0->ne[2]; - const cl_ulong nb00 = src0 ? src0->nb[0] : 0; - const cl_ulong nb01 = src0 ? src0->nb[1] : 0; - const cl_ulong nb02 = src0 ? src0->nb[2] : 0; - const cl_ulong nb03 = src0 ? src0->nb[3] : 0; - - const int ne10 = src1 ? src1->ne[0] : 0; - const int ne11 = src1 ? src1->ne[1] : 0; - const int ne12 = src1 ? src1->ne[2] : 0; - const int ne13 = src1 ? src1->ne[3] : 0; + const cl_ulong nb01 = src0->nb[1]; + const cl_ulong nb02 = src0->nb[2]; - const cl_ulong nb10 = src1 ? src1->nb[0] : 0; - const cl_ulong nb11 = src1 ? src1->nb[1] : 0; - const cl_ulong nb12 = src1 ? src1->nb[2] : 0; - const cl_ulong nb13 = src1 ? src1->nb[3] : 0; + const int ne10 = src1->ne[0]; + const int ne11 = src1->ne[1]; + const int ne12 = src1->ne[2]; - const int ne0 = dst ? dst->ne[0] : 0; - const int ne1 = dst ? 
dst->ne[1] : 0; + const cl_ulong nb10 = src1->nb[0]; - int r2 = ne12/ne02; - int r3 = ne13/ne03; + const int ne0 = dst->ne[0]; + const int ne1 = dst->ne[1]; GGML_ASSERT(ne00 == ne10); @@ -6772,8 +6758,8 @@ static void ggml_cl_mul_mat_kq_kqv_adreno(ggml_backend_t backend, const ggml_ten CL_CHECK(status); // <--------------------------------------------> // - uint offset_src0 = 0; - uint offset_src1 = 0; + int offset_src0 = 0; + int offset_src1 = 0; // set kernel args // <--------------------------------------------> // From 301662b29ce6f3a8578e03dd391d1196350a6a71 Mon Sep 17 00:00:00 2001 From: shaoqi Date: Fri, 7 Nov 2025 15:30:47 -0800 Subject: [PATCH 5/7] remove redundant --- ggml/src/ggml-opencl/ggml-opencl.cpp | 4 ---- 1 file changed, 4 deletions(-) diff --git a/ggml/src/ggml-opencl/ggml-opencl.cpp b/ggml/src/ggml-opencl/ggml-opencl.cpp index e6d248a8f1aa3..a5713d3910280 100644 --- a/ggml/src/ggml-opencl/ggml-opencl.cpp +++ b/ggml/src/ggml-opencl/ggml-opencl.cpp @@ -6789,10 +6789,6 @@ static void ggml_cl_mul_mat_kq_kqv_adreno(ggml_backend_t backend, const ggml_ten CL_CHECK(clReleaseMemObject(A_sub_buffer)); CL_CHECK(clReleaseMemObject(B_sub_buffer)); CL_CHECK(clReleaseMemObject(D_sub_buffer)); - // <--------------------------------------------> // - - return; - } static void ggml_cl_mul_mat(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { From 41bf54f89678821b30fde3470fc7b17cb109077b Mon Sep 17 00:00:00 2001 From: shaoqi Date: Tue, 11 Nov 2025 13:52:48 -0800 Subject: [PATCH 6/7] refactor and clean up --- .../kernels/mul_mm_f16_f32_kq_kqv.cl | 305 +++++++++--------- 1 file changed, 153 insertions(+), 152 deletions(-) diff --git a/ggml/src/ggml-opencl/kernels/mul_mm_f16_f32_kq_kqv.cl b/ggml/src/ggml-opencl/kernels/mul_mm_f16_f32_kq_kqv.cl index 798f773a8d3f7..b7c4054b3d4dd 100644 --- a/ggml/src/ggml-opencl/kernels/mul_mm_f16_f32_kq_kqv.cl +++ b/ggml/src/ggml-opencl/kernels/mul_mm_f16_f32_kq_kqv.cl @@ -7,96 +7,109 @@ #define LM_FOURTH_256B 192 -inline float16 mm_load_a(image1d_buffer_t matrix_A, uint subMatrixAStartInElements, int nb01, int line_stride_matrix_A_in_bytes) -{ +inline float16 mm_load_a( + image1d_buffer_t matrix_A, + uint subMatrixAStartInElements, + int nb01, + int line_stride_matrix_A_in_bytes +) { __private float8 regA; size_t sub_block_id_m = get_local_id(0); #ifdef KQV - uint a_texCoord = subMatrixAStartInElements/2 + (sub_block_id_m * nb01/4); + uint a_texCoord = subMatrixAStartInElements/2 + (sub_block_id_m * nb01/4); #else // KQ - uint a_texCoord = subMatrixAStartInElements/2 + (sub_block_id_m * line_stride_matrix_A_in_bytes/4); + uint a_texCoord = subMatrixAStartInElements/2 + (sub_block_id_m * line_stride_matrix_A_in_bytes/4); #endif - regA.s0123 = read_imagef(matrix_A, a_texCoord/4); - regA.s4567 = read_imagef(matrix_A, (a_texCoord+4)/4); + regA.s0123 = read_imagef(matrix_A, a_texCoord/4); + regA.s4567 = read_imagef(matrix_A, (a_texCoord+4)/4); return convert_float16(as_half16(regA)); } -inline void alu_32(float* regC, float16 regA, __local float* matrix_B_local, int wave_offset) { - - __private float4 rC = 0; - int i = wave_offset; - - rC += regA.s0 * ((__local float4*)matrix_B_local)[i]; - rC += regA.s1 * ((__local float4*)matrix_B_local)[i + 16]; - rC += regA.s4 * ((__local float4*)matrix_B_local)[i + 1]; - rC += regA.s5 * ((__local float4*)matrix_B_local)[i + 17]; - rC += regA.s8 * ((__local float4*)matrix_B_local)[i + 2]; - rC += regA.s9 * ((__local float4*)matrix_B_local)[i + 18]; - rC += regA.sc * ((__local 
float4*)matrix_B_local)[i + 3]; - rC += regA.sd * ((__local float4*)matrix_B_local)[i + 19]; - - i += 32; - - rC += regA.s2 * ((__local float4*)matrix_B_local)[i]; - rC += regA.s3 * ((__local float4*)matrix_B_local)[i + 16]; - rC += regA.s6 * ((__local float4*)matrix_B_local)[i + 1]; - rC += regA.s7 * ((__local float4*)matrix_B_local)[i + 17]; - rC += regA.sa * ((__local float4*)matrix_B_local)[i + 2]; - rC += regA.sb * ((__local float4*)matrix_B_local)[i + 18]; - rC += regA.se * ((__local float4*)matrix_B_local)[i + 3]; - rC += regA.sf * ((__local float4*)matrix_B_local)[i + 19]; +inline float4 alu_32( + float16 regA, + __local float4* matrix_B_vec +) { + + __private float4 rC = 0; + int i = get_sub_group_id() * 64; + + rC += regA.s0 * matrix_B_vec[i]; + rC += regA.s1 * matrix_B_vec[i + 16]; + rC += regA.s4 * matrix_B_vec[i + 1]; + rC += regA.s5 * matrix_B_vec[i + 17]; + rC += regA.s8 * matrix_B_vec[i + 2]; + rC += regA.s9 * matrix_B_vec[i + 18]; + rC += regA.sc * matrix_B_vec[i + 3]; + rC += regA.sd * matrix_B_vec[i + 19]; + + i += 32; + + rC += regA.s2 * matrix_B_vec[i]; + rC += regA.s3 * matrix_B_vec[i + 16]; + rC += regA.s6 * matrix_B_vec[i + 1]; + rC += regA.s7 * matrix_B_vec[i + 17]; + rC += regA.sa * matrix_B_vec[i + 2]; + rC += regA.sb * matrix_B_vec[i + 18]; + rC += regA.se * matrix_B_vec[i + 3]; + rC += regA.sf * matrix_B_vec[i + 19]; - float4* regC_vec = (float4*)regC; - *regC_vec += rC; + return rC; } -inline void mm_mad(__local float* matrix_B_local, float* regC1_ptr, float* regC2_ptr, float* regC3_ptr, float* regC4_ptr, float* regC5_ptr, float* regC6_ptr, float* regC7_ptr, float* regC8_ptr, float16 regA, float8 regB, uint b_localOffsetInWords) -{ +inline float16 alu_16( + float16 regA, + __local float* matrix_B_local +) { + float16 out; + __local float4* matrix_B_vec = (__local float4*)matrix_B_local; - short linearIndex = get_local_id(0); + out.s0123 = alu_32(regA, matrix_B_vec); + out.s4567 = alu_32(regA, matrix_B_vec + 4); + out.s89ab = alu_32(regA, matrix_B_vec + 8); + out.scdef = alu_32(regA, matrix_B_vec + 12); - int wave_offset = get_sub_group_id() * 64; + return out; +} + +inline void mm_mad( + __local float* matrix_B_local, + float16 regA, + float8 regB, + uint b_localOffsetInWords, + float16* regC0_ptr, + float16* regC1_ptr +) { int offset = b_localOffsetInWords + get_sub_group_id() * 256; - + matrix_B_local[offset + LM_FIRST_256B] = regB.s0; matrix_B_local[offset + LM_SECOND_256B] = regB.s1; - matrix_B_local[offset + LM_THIRD_256B] = regB.s2; matrix_B_local[offset + LM_FOURTH_256B] = regB.s3; - alu_32(regC1_ptr, regA, matrix_B_local + 0, wave_offset); + float16 add0 = alu_16(regA, matrix_B_local); + *regC0_ptr += add0; - alu_32(regC2_ptr, regA, matrix_B_local + 16, wave_offset); - - alu_32(regC3_ptr, regA, matrix_B_local + 32, wave_offset); - - alu_32(regC4_ptr, regA, matrix_B_local + 48, wave_offset); - - matrix_B_local[offset + LM_FIRST_256B] = regB.s4; matrix_B_local[offset + LM_SECOND_256B] = regB.s5; - matrix_B_local[offset + LM_THIRD_256B] = regB.s6; matrix_B_local[offset + LM_FOURTH_256B] = regB.s7; - alu_32(regC5_ptr, regA, matrix_B_local + 0, wave_offset); - - alu_32(regC6_ptr, regA, matrix_B_local + 16, wave_offset); - - alu_32(regC7_ptr, regA, matrix_B_local + 32, wave_offset); - - alu_32(regC8_ptr, regA, matrix_B_local + 48, wave_offset); - + float16 add1 = alu_16(regA, matrix_B_local); + *regC1_ptr += add1; } -inline void mm_store_c_N(__write_only image1d_buffer_t matrix_C, float4 regC_1, float4 regC_2, float4 regC_3, float4 regC_4, -float4 regC_5, 
float4 regC_6, float4 regC_7, float4 regC_8, uint subMatrixCStartInElements, int line_stride_matrix_C_in_bytes, int mask) -{ +inline void mm_store_c_N( + __write_only image1d_buffer_t matrix_C, + float16 regC0, + float16 regC1, + uint subMatrixCStartInElements, + int line_stride_matrix_C_in_bytes, + int mask +) { size_t sub_block_id_m = get_local_id(0); - short linearIndex = get_local_id(0); uint strideInWords = line_stride_matrix_C_in_bytes/4; uint c_coordInWords_0 = (subMatrixCStartInElements + sub_block_id_m); @@ -133,39 +146,38 @@ float4 regC_5, float4 regC_6, float4 regC_7, float4 regC_8, uint subMatrixCStar uint c_coordInWords_30 = c_coordInWords_0 + 30 * strideInWords; uint c_coordInWords_31 = c_coordInWords_0 + 31 * strideInWords; - if (mask > 0) { write_imagef(matrix_C, c_coordInWords_0, regC_1.s0); } - if (mask > 1) { write_imagef(matrix_C, c_coordInWords_1, regC_1.s1); } - if (mask > 2) { write_imagef(matrix_C, c_coordInWords_2, regC_1.s2); } - if (mask > 3) { write_imagef(matrix_C, c_coordInWords_3, regC_1.s3); } - if (mask > 4) { write_imagef(matrix_C, c_coordInWords_4, regC_2.s0); } - if (mask > 5) { write_imagef(matrix_C, c_coordInWords_5, regC_2.s1); } - if (mask > 6) { write_imagef(matrix_C, c_coordInWords_6, regC_2.s2); } - if (mask > 7) { write_imagef(matrix_C, c_coordInWords_7, regC_2.s3); } - if (mask > 8) { write_imagef(matrix_C, c_coordInWords_8, regC_3.s0); } - if (mask > 9) { write_imagef(matrix_C, c_coordInWords_9, regC_3.s1); } - if (mask > 10) { write_imagef(matrix_C, c_coordInWords_10, regC_3.s2); } - if (mask > 11) { write_imagef(matrix_C, c_coordInWords_11, regC_3.s3); } - if (mask > 12) { write_imagef(matrix_C, c_coordInWords_12, regC_4.s0); } - if (mask > 13) { write_imagef(matrix_C, c_coordInWords_13, regC_4.s1); } - if (mask > 14) { write_imagef(matrix_C, c_coordInWords_14, regC_4.s2); } - if (mask > 15) { write_imagef(matrix_C, c_coordInWords_15, regC_4.s3); } - if (mask > 16) { write_imagef(matrix_C, c_coordInWords_16, regC_5.s0); } - if (mask > 17) { write_imagef(matrix_C, c_coordInWords_17, regC_5.s1); } - if (mask > 18) { write_imagef(matrix_C, c_coordInWords_18, regC_5.s2); } - if (mask > 19) { write_imagef(matrix_C, c_coordInWords_19, regC_5.s3); } - if (mask > 20) { write_imagef(matrix_C, c_coordInWords_20, regC_6.s0); } - if (mask > 21) { write_imagef(matrix_C, c_coordInWords_21, regC_6.s1); } - if (mask > 22) { write_imagef(matrix_C, c_coordInWords_22, regC_6.s2); } - if (mask > 23) { write_imagef(matrix_C, c_coordInWords_23, regC_6.s3); } - if (mask > 24) { write_imagef(matrix_C, c_coordInWords_24, regC_7.s0); } - if (mask > 25) { write_imagef(matrix_C, c_coordInWords_25, regC_7.s1); } - if (mask > 26) { write_imagef(matrix_C, c_coordInWords_26, regC_7.s2); } - if (mask > 27) { write_imagef(matrix_C, c_coordInWords_27, regC_7.s3); } - if (mask > 28) { write_imagef(matrix_C, c_coordInWords_28, regC_8.s0); } - if (mask > 29) { write_imagef(matrix_C, c_coordInWords_29, regC_8.s1); } - if (mask > 30) { write_imagef(matrix_C, c_coordInWords_30, regC_8.s2); } - if (mask > 31) { write_imagef(matrix_C, c_coordInWords_31, regC_8.s3); } - + if (mask > 0) { write_imagef(matrix_C, c_coordInWords_0, regC0.s0); } + if (mask > 1) { write_imagef(matrix_C, c_coordInWords_1, regC0.s1); } + if (mask > 2) { write_imagef(matrix_C, c_coordInWords_2, regC0.s2); } + if (mask > 3) { write_imagef(matrix_C, c_coordInWords_3, regC0.s3); } + if (mask > 4) { write_imagef(matrix_C, c_coordInWords_4, regC0.s4); } + if (mask > 5) { write_imagef(matrix_C, c_coordInWords_5, 
regC0.s5); } + if (mask > 6) { write_imagef(matrix_C, c_coordInWords_6, regC0.s6); } + if (mask > 7) { write_imagef(matrix_C, c_coordInWords_7, regC0.s7); } + if (mask > 8) { write_imagef(matrix_C, c_coordInWords_8, regC0.s8); } + if (mask > 9) { write_imagef(matrix_C, c_coordInWords_9, regC0.s9); } + if (mask > 10) { write_imagef(matrix_C, c_coordInWords_10, regC0.sa); } + if (mask > 11) { write_imagef(matrix_C, c_coordInWords_11, regC0.sb); } + if (mask > 12) { write_imagef(matrix_C, c_coordInWords_12, regC0.sc); } + if (mask > 13) { write_imagef(matrix_C, c_coordInWords_13, regC0.sd); } + if (mask > 14) { write_imagef(matrix_C, c_coordInWords_14, regC0.se); } + if (mask > 15) { write_imagef(matrix_C, c_coordInWords_15, regC0.sf); } + if (mask > 16) { write_imagef(matrix_C, c_coordInWords_16, regC1.s0); } + if (mask > 17) { write_imagef(matrix_C, c_coordInWords_17, regC1.s1); } + if (mask > 18) { write_imagef(matrix_C, c_coordInWords_18, regC1.s2); } + if (mask > 19) { write_imagef(matrix_C, c_coordInWords_19, regC1.s3); } + if (mask > 20) { write_imagef(matrix_C, c_coordInWords_20, regC1.s4); } + if (mask > 21) { write_imagef(matrix_C, c_coordInWords_21, regC1.s5); } + if (mask > 22) { write_imagef(matrix_C, c_coordInWords_22, regC1.s6); } + if (mask > 23) { write_imagef(matrix_C, c_coordInWords_23, regC1.s7); } + if (mask > 24) { write_imagef(matrix_C, c_coordInWords_24, regC1.s8); } + if (mask > 25) { write_imagef(matrix_C, c_coordInWords_25, regC1.s9); } + if (mask > 26) { write_imagef(matrix_C, c_coordInWords_26, regC1.sa); } + if (mask > 27) { write_imagef(matrix_C, c_coordInWords_27, regC1.sb); } + if (mask > 28) { write_imagef(matrix_C, c_coordInWords_28, regC1.sc); } + if (mask > 29) { write_imagef(matrix_C, c_coordInWords_29, regC1.sd); } + if (mask > 30) { write_imagef(matrix_C, c_coordInWords_30, regC1.se); } + if (mask > 31) { write_imagef(matrix_C, c_coordInWords_31, regC1.sf); } } #define TILESIZE_K 16 @@ -186,86 +198,75 @@ __kernel void mul_mm_f16_f32_kq( int D_A, int D_B, int nb01 -) -{ +) { - uint block_id_m = get_global_id(1); - uint block_id_n = get_global_id(2) % ((N+TILESIZE_N-1)/TILESIZE_N); - uint block_id_d = get_global_id(2) / ((N+TILESIZE_N-1)/TILESIZE_N); + uint block_id_m = get_global_id(1); + uint block_id_n = get_global_id(2) % ((N+TILESIZE_N-1)/TILESIZE_N); + uint block_id_d = get_global_id(2) / ((N+TILESIZE_N-1)/TILESIZE_N); - __private float16 regA; - __private float8 regB; - __private float4 regC_1; - __private float4 regC_2; - __private float4 regC_3; - __private float4 regC_4; - __private float4 regC_5; - __private float4 regC_6; - __private float4 regC_7; - __private float4 regC_8; - - const uint col = block_id_m * TILESIZE_M; - const uint row = block_id_n * TILESIZE_N; - const uint depth_A = block_id_d / (D_B/D_A); - const uint depth_B = block_id_d; + __private float16 regA; + __private float8 regB; + __private float16 regC0; + __private float16 regC1; + + const uint col = block_id_m * TILESIZE_M; + const uint row = block_id_n * TILESIZE_N; + const uint depth_A = block_id_d / (D_B/D_A); + const uint depth_B = block_id_d; #ifdef KQV - int line_stride_matrix_A_in_bytes = nb01 * M; - int line_stride_matrix_B_in_bytes = K * N * 4; + int line_stride_matrix_A_in_bytes = nb01 * M; + int line_stride_matrix_B_in_bytes = K * N * 4; #else - int line_stride_matrix_A_in_bytes = K * D_A * 2; - int line_stride_matrix_B_in_bytes = K * D_B * 4; + int line_stride_matrix_A_in_bytes = K * D_A * 2; + int line_stride_matrix_B_in_bytes = K * D_B * 4; #endif - int 
line_stride_matrix_C_in_bytes = M * 4; + int line_stride_matrix_C_in_bytes = M * 4; - const uint strideAinElements = line_stride_matrix_A_in_bytes / 2; - const uint strideBinElements = line_stride_matrix_B_in_bytes / 4; + const uint strideAinElements = line_stride_matrix_A_in_bytes / 2; + const uint strideBinElements = line_stride_matrix_B_in_bytes / 4; - size_t sub_block_id_m = get_local_id(0); + size_t sub_block_id_m = get_local_id(0); - uint b_localOffsetInWords = (sub_block_id_m/16)*16 + uint b_localOffsetInWords = (sub_block_id_m/16)*16 + ((((sub_block_id_m)>>0)&1)<<2) + ((((sub_block_id_m)>>1)&1)<<3) + ((((sub_block_id_m)>>2)&1)<<0) + ((((sub_block_id_m)>>3)&1)<<1); - uint2 b_globalOffsetInWords_xy = {((sub_block_id_m%4)*4), (sub_block_id_m>>2)}; - uint b_globalOffsetInWords00, b_globalOffsetInWords16; - #ifdef KQV - b_globalOffsetInWords00 = b_globalOffsetInWords_xy.x + b_globalOffsetInWords_xy.y*K; - b_globalOffsetInWords16 = b_globalOffsetInWords00 + (16 * K); - uint subMatrixAStartInElements = depth_A * strideAinElements + col * nb01 / 2; - uint subMatrixBStartInElements = depth_B * strideBinElements + row * K; - #else - b_globalOffsetInWords00 = b_globalOffsetInWords_xy.x + b_globalOffsetInWords_xy.y*line_stride_matrix_B_in_bytes/4; - b_globalOffsetInWords16 = b_globalOffsetInWords00 + (16 * line_stride_matrix_B_in_bytes/4); - uint subMatrixAStartInElements = col * strideAinElements + depth_A * K; - uint subMatrixBStartInElements = row * strideBinElements + depth_B * K; - #endif - - __local float matrix_B_local[1024]; - - for (uint step=0; step < K; step+=TILESIZE_K) - { - size_t sub_block_id_m = get_local_id(0); - regA = mm_load_a(matrix_A, subMatrixAStartInElements, nb01, line_stride_matrix_A_in_bytes); - + uint2 b_globalOffsetInWords_xy = {((sub_block_id_m%4)*4), (sub_block_id_m>>2)}; + uint b_globalOffsetInWords00, b_globalOffsetInWords16; +#ifdef KQV + b_globalOffsetInWords00 = b_globalOffsetInWords_xy.x + b_globalOffsetInWords_xy.y*K; + b_globalOffsetInWords16 = b_globalOffsetInWords00 + (16 * K); + uint subMatrixAStartInElements = depth_A * strideAinElements + col * nb01 / 2; + uint subMatrixBStartInElements = depth_B * strideBinElements + row * K; +#else + b_globalOffsetInWords00 = b_globalOffsetInWords_xy.x + b_globalOffsetInWords_xy.y*line_stride_matrix_B_in_bytes/4; + b_globalOffsetInWords16 = b_globalOffsetInWords00 + (16 * line_stride_matrix_B_in_bytes/4); + uint subMatrixAStartInElements = col * strideAinElements + depth_A * K; + uint subMatrixBStartInElements = row * strideBinElements + depth_B * K; +#endif - uint b_coordInWords00 = subMatrixBStartInElements + b_globalOffsetInWords00; - uint b_coordInWords16 = subMatrixBStartInElements + b_globalOffsetInWords16; + __local float matrix_B_local[1024]; - regB.s0123 = vload4(b_coordInWords00/4, matrix_B); - regB.s4567 = vload4(b_coordInWords16/4, matrix_B); + for (uint step=0; step < K; step+=TILESIZE_K) { + size_t sub_block_id_m = get_local_id(0); + regA = mm_load_a(matrix_A, subMatrixAStartInElements, nb01, line_stride_matrix_A_in_bytes); + + uint b_coordInWords00 = subMatrixBStartInElements + b_globalOffsetInWords00; + uint b_coordInWords16 = subMatrixBStartInElements + b_globalOffsetInWords16; - mm_mad(matrix_B_local, (float *)®C_1, (float *)®C_2, (float *)®C_3, (float *)®C_4, (float *)®C_5, (float *)®C_6, (float *)®C_7, (float *)®C_8, regA, regB, b_localOffsetInWords); + regB.s0123 = vload4(b_coordInWords00/4, matrix_B); + regB.s4567 = vload4(b_coordInWords16/4, matrix_B); - subMatrixAStartInElements += 
TILESIZE_K;
-        subMatrixBStartInElements += TILESIZE_K;
+        mm_mad(matrix_B_local, regA, regB, b_localOffsetInWords, &regC0, &regC1);
+        subMatrixAStartInElements += TILESIZE_K;
+        subMatrixBStartInElements += TILESIZE_K;
     }
-    uint subMatrixCStartInElements = depth_B * N * M + row * M + col;
-    mm_store_c_N(matrix_C, regC_1, regC_2, regC_3, regC_4,regC_5, regC_6, regC_7,regC_8, subMatrixCStartInElements, line_stride_matrix_C_in_bytes, (N-block_id_n*32));
-
+    uint subMatrixCStartInElements = depth_B * N * M + row * M + col;
+    mm_store_c_N(matrix_C, regC0, regC1, subMatrixCStartInElements, line_stride_matrix_C_in_bytes, (N-block_id_n*32));
 }

From b3ee2ab0e3559723bd21b042ecde8b7917729db4 Mon Sep 17 00:00:00 2001
From: shaoqi
Date: Thu, 13 Nov 2025 11:00:25 -0800
Subject: [PATCH 7/7] remove trailing whitespace

---
 ggml/src/ggml-opencl/ggml-opencl.cpp | 2 +-
 .../kernels/mul_mm_f16_f32_kq_kqv.cl | 545 +++++++++---------
 2 files changed, 274 insertions(+), 273 deletions(-)

diff --git a/ggml/src/ggml-opencl/ggml-opencl.cpp b/ggml/src/ggml-opencl/ggml-opencl.cpp
index a5713d3910280..620516772e972 100644
--- a/ggml/src/ggml-opencl/ggml-opencl.cpp
+++ b/ggml/src/ggml-opencl/ggml-opencl.cpp
@@ -6692,7 +6692,7 @@ static void ggml_cl_mul_mat_kq_kqv_adreno(ggml_backend_t backend, const ggml_ten
     int M = ne01;
     int N = ne1;
     int K = ne00;
-
+
     if (nb01 > nb02) {
         // KQ
         kernel = backend_ctx->kernel_mul_mm_f16_f32_kq;
diff --git a/ggml/src/ggml-opencl/kernels/mul_mm_f16_f32_kq_kqv.cl b/ggml/src/ggml-opencl/kernels/mul_mm_f16_f32_kq_kqv.cl
index b7c4054b3d4dd..ac0274b64fc0e 100644
--- a/ggml/src/ggml-opencl/kernels/mul_mm_f16_f32_kq_kqv.cl
+++ b/ggml/src/ggml-opencl/kernels/mul_mm_f16_f32_kq_kqv.cl
@@ -1,272 +1,273 @@
-#pragma OPENCL EXTENSION cl_khr_fp16 : enable
-#pragma OPENCL EXTENSION cl_khr_subgroups : enable
-
-#define LM_FIRST_256B 0
-#define LM_SECOND_256B 64
-#define LM_THIRD_256B 128
-#define LM_FOURTH_256B 192
-
-
-inline float16 mm_load_a(
-    image1d_buffer_t matrix_A,
-    uint subMatrixAStartInElements,
-    int nb01,
-    int line_stride_matrix_A_in_bytes
-) {
-    __private float8 regA;
-    size_t sub_block_id_m = get_local_id(0);
-
-#ifdef KQV
-    uint a_texCoord = subMatrixAStartInElements/2 + (sub_block_id_m * nb01/4);
-#else // KQ
-    uint a_texCoord = subMatrixAStartInElements/2 + (sub_block_id_m * line_stride_matrix_A_in_bytes/4);
-#endif
-
-    regA.s0123 = read_imagef(matrix_A, a_texCoord/4);
-    regA.s4567 = read_imagef(matrix_A, (a_texCoord+4)/4);
-
-    return convert_float16(as_half16(regA));
-}
-
-inline float4 alu_32(
-    float16 regA,
-    __local float4* matrix_B_vec
-) {
-
-    __private float4 rC = 0;
-    int i = get_sub_group_id() * 64;
-
-    rC += regA.s0 * matrix_B_vec[i];
-    rC += regA.s1 * matrix_B_vec[i + 16];
-    rC += regA.s4 * matrix_B_vec[i + 1];
-    rC += regA.s5 * matrix_B_vec[i + 17];
-    rC += regA.s8 * matrix_B_vec[i + 2];
-    rC += regA.s9 * matrix_B_vec[i + 18];
-    rC += regA.sc * matrix_B_vec[i + 3];
-    rC += regA.sd * matrix_B_vec[i + 19];
-
-    i += 32;
-
-    rC += regA.s2 * matrix_B_vec[i];
-    rC += regA.s3 * matrix_B_vec[i + 16];
-    rC += regA.s6 * matrix_B_vec[i + 1];
-    rC += regA.s7 * matrix_B_vec[i + 17];
-    rC += regA.sa * matrix_B_vec[i + 2];
-    rC += regA.sb * matrix_B_vec[i + 18];
-    rC += regA.se * matrix_B_vec[i + 3];
-    rC += regA.sf * matrix_B_vec[i + 19];
-
-    return rC;
-}
-
-inline float16 alu_16(
-    float16 regA,
-    __local float* matrix_B_local
-) {
-    float16 out;
-    __local float4* matrix_B_vec = (__local float4*)matrix_B_local;
-
-    out.s0123 = alu_32(regA, matrix_B_vec);
-    out.s4567 = alu_32(regA, matrix_B_vec + 4);
-    out.s89ab = alu_32(regA, matrix_B_vec + 8);
-    out.scdef = alu_32(regA, matrix_B_vec + 12);
-
-    return out;
-}
-
-inline void mm_mad(
-    __local float* matrix_B_local,
-    float16 regA,
-    float8 regB,
-    uint b_localOffsetInWords,
-    float16* regC0_ptr,
-    float16* regC1_ptr
-) {
-    int offset = b_localOffsetInWords + get_sub_group_id() * 256;
-
-    matrix_B_local[offset + LM_FIRST_256B] = regB.s0;
-    matrix_B_local[offset + LM_SECOND_256B] = regB.s1;
-    matrix_B_local[offset + LM_THIRD_256B] = regB.s2;
-    matrix_B_local[offset + LM_FOURTH_256B] = regB.s3;
-
-    float16 add0 = alu_16(regA, matrix_B_local);
-    *regC0_ptr += add0;
-
-    matrix_B_local[offset + LM_FIRST_256B] = regB.s4;
-    matrix_B_local[offset + LM_SECOND_256B] = regB.s5;
-    matrix_B_local[offset + LM_THIRD_256B] = regB.s6;
-    matrix_B_local[offset + LM_FOURTH_256B] = regB.s7;
-
-    float16 add1 = alu_16(regA, matrix_B_local);
-    *regC1_ptr += add1;
-}
-
-inline void mm_store_c_N(
-    __write_only image1d_buffer_t matrix_C,
-    float16 regC0,
-    float16 regC1,
-    uint subMatrixCStartInElements,
-    int line_stride_matrix_C_in_bytes,
-    int mask
-) {
-    size_t sub_block_id_m = get_local_id(0);
-
-    uint strideInWords = line_stride_matrix_C_in_bytes/4;
-    uint c_coordInWords_0 = (subMatrixCStartInElements + sub_block_id_m);
-
-    uint c_coordInWords_1 = c_coordInWords_0 + 1 * strideInWords;
-    uint c_coordInWords_2 = c_coordInWords_0 + 2 * strideInWords;
-    uint c_coordInWords_3 = c_coordInWords_0 + 3 * strideInWords;
-    uint c_coordInWords_4 = c_coordInWords_0 + 4 * strideInWords;
-    uint c_coordInWords_5 = c_coordInWords_0 + 5 * strideInWords;
-    uint c_coordInWords_6 = c_coordInWords_0 + 6 * strideInWords;
-    uint c_coordInWords_7 = c_coordInWords_0 + 7 * strideInWords;
-    uint c_coordInWords_8 = c_coordInWords_0 + 8 * strideInWords;
-    uint c_coordInWords_9 = c_coordInWords_0 + 9 * strideInWords;
-    uint c_coordInWords_10 = c_coordInWords_0 + 10 * strideInWords;
-    uint c_coordInWords_11 = c_coordInWords_0 + 11 * strideInWords;
-    uint c_coordInWords_12 = c_coordInWords_0 + 12 * strideInWords;
-    uint c_coordInWords_13 = c_coordInWords_0 + 13 * strideInWords;
-    uint c_coordInWords_14 = c_coordInWords_0 + 14 * strideInWords;
-    uint c_coordInWords_15 = c_coordInWords_0 + 15 * strideInWords;
-    uint c_coordInWords_16 = c_coordInWords_0 + 16 * strideInWords;
-    uint c_coordInWords_17 = c_coordInWords_0 + 17 * strideInWords;
-    uint c_coordInWords_18 = c_coordInWords_0 + 18 * strideInWords;
-    uint c_coordInWords_19 = c_coordInWords_0 + 19 * strideInWords;
-    uint c_coordInWords_20 = c_coordInWords_0 + 20 * strideInWords;
-    uint c_coordInWords_21 = c_coordInWords_0 + 21 * strideInWords;
-    uint c_coordInWords_22 = c_coordInWords_0 + 22 * strideInWords;
-    uint c_coordInWords_23 = c_coordInWords_0 + 23 * strideInWords;
-    uint c_coordInWords_24 = c_coordInWords_0 + 24 * strideInWords;
-    uint c_coordInWords_25 = c_coordInWords_0 + 25 * strideInWords;
-    uint c_coordInWords_26 = c_coordInWords_0 + 26 * strideInWords;
-    uint c_coordInWords_27 = c_coordInWords_0 + 27 * strideInWords;
-    uint c_coordInWords_28 = c_coordInWords_0 + 28 * strideInWords;
-    uint c_coordInWords_29 = c_coordInWords_0 + 29 * strideInWords;
-    uint c_coordInWords_30 = c_coordInWords_0 + 30 * strideInWords;
-    uint c_coordInWords_31 = c_coordInWords_0 + 31 * strideInWords;
-
-    if (mask > 0) { write_imagef(matrix_C, c_coordInWords_0, regC0.s0); }
-    if (mask > 1) { write_imagef(matrix_C, c_coordInWords_1, regC0.s1); }
-    if (mask > 2) { write_imagef(matrix_C, c_coordInWords_2, regC0.s2); }
-    if (mask > 3) { write_imagef(matrix_C, c_coordInWords_3, regC0.s3); }
-    if (mask > 4) { write_imagef(matrix_C, c_coordInWords_4, regC0.s4); }
-    if (mask > 5) { write_imagef(matrix_C, c_coordInWords_5, regC0.s5); }
-    if (mask > 6) { write_imagef(matrix_C, c_coordInWords_6, regC0.s6); }
-    if (mask > 7) { write_imagef(matrix_C, c_coordInWords_7, regC0.s7); }
-    if (mask > 8) { write_imagef(matrix_C, c_coordInWords_8, regC0.s8); }
-    if (mask > 9) { write_imagef(matrix_C, c_coordInWords_9, regC0.s9); }
-    if (mask > 10) { write_imagef(matrix_C, c_coordInWords_10, regC0.sa); }
-    if (mask > 11) { write_imagef(matrix_C, c_coordInWords_11, regC0.sb); }
-    if (mask > 12) { write_imagef(matrix_C, c_coordInWords_12, regC0.sc); }
-    if (mask > 13) { write_imagef(matrix_C, c_coordInWords_13, regC0.sd); }
-    if (mask > 14) { write_imagef(matrix_C, c_coordInWords_14, regC0.se); }
-    if (mask > 15) { write_imagef(matrix_C, c_coordInWords_15, regC0.sf); }
-    if (mask > 16) { write_imagef(matrix_C, c_coordInWords_16, regC1.s0); }
-    if (mask > 17) { write_imagef(matrix_C, c_coordInWords_17, regC1.s1); }
-    if (mask > 18) { write_imagef(matrix_C, c_coordInWords_18, regC1.s2); }
-    if (mask > 19) { write_imagef(matrix_C, c_coordInWords_19, regC1.s3); }
-    if (mask > 20) { write_imagef(matrix_C, c_coordInWords_20, regC1.s4); }
-    if (mask > 21) { write_imagef(matrix_C, c_coordInWords_21, regC1.s5); }
-    if (mask > 22) { write_imagef(matrix_C, c_coordInWords_22, regC1.s6); }
-    if (mask > 23) { write_imagef(matrix_C, c_coordInWords_23, regC1.s7); }
-    if (mask > 24) { write_imagef(matrix_C, c_coordInWords_24, regC1.s8); }
-    if (mask > 25) { write_imagef(matrix_C, c_coordInWords_25, regC1.s9); }
-    if (mask > 26) { write_imagef(matrix_C, c_coordInWords_26, regC1.sa); }
-    if (mask > 27) { write_imagef(matrix_C, c_coordInWords_27, regC1.sb); }
-    if (mask > 28) { write_imagef(matrix_C, c_coordInWords_28, regC1.sc); }
-    if (mask > 29) { write_imagef(matrix_C, c_coordInWords_29, regC1.sd); }
-    if (mask > 30) { write_imagef(matrix_C, c_coordInWords_30, regC1.se); }
-    if (mask > 31) { write_imagef(matrix_C, c_coordInWords_31, regC1.sf); }
-}
-
-#define TILESIZE_K 16
-#define TILESIZE_M 64
-#define TILESIZE_N 32
-#ifdef KQV
-__kernel void mul_mm_f16_f32_kqv(
-#else
-__kernel void mul_mm_f16_f32_kq(
-#endif
-    __read_only image1d_buffer_t matrix_A,
-    int offset0,
-    __global float* matrix_B,
-    int offset1,
-    __write_only image1d_buffer_t matrix_C,
-    int offsetd,
-    int M, int K, int N,
-    int D_A,
-    int D_B,
-    int nb01
-) {
-
-    uint block_id_m = get_global_id(1);
-    uint block_id_n = get_global_id(2) % ((N+TILESIZE_N-1)/TILESIZE_N);
-    uint block_id_d = get_global_id(2) / ((N+TILESIZE_N-1)/TILESIZE_N);
-
-    __private float16 regA;
-    __private float8 regB;
-    __private float16 regC0;
-    __private float16 regC1;
-
-    const uint col = block_id_m * TILESIZE_M;
-    const uint row = block_id_n * TILESIZE_N;
-    const uint depth_A = block_id_d / (D_B/D_A);
-    const uint depth_B = block_id_d;
-
-#ifdef KQV
-    int line_stride_matrix_A_in_bytes = nb01 * M;
-    int line_stride_matrix_B_in_bytes = K * N * 4;
-#else
-    int line_stride_matrix_A_in_bytes = K * D_A * 2;
-    int line_stride_matrix_B_in_bytes = K * D_B * 4;
-#endif
-
-    int line_stride_matrix_C_in_bytes = M * 4;
-
-    const uint strideAinElements = line_stride_matrix_A_in_bytes / 2;
-    const uint strideBinElements = line_stride_matrix_B_in_bytes / 4;
-
-    size_t sub_block_id_m = get_local_id(0);
-
-    uint b_localOffsetInWords = (sub_block_id_m/16)*16
-        + ((((sub_block_id_m)>>0)&1)<<2)
-        + ((((sub_block_id_m)>>1)&1)<<3)
-        + ((((sub_block_id_m)>>2)&1)<<0)
-        + ((((sub_block_id_m)>>3)&1)<<1);
-
-    uint2 b_globalOffsetInWords_xy = {((sub_block_id_m%4)*4), (sub_block_id_m>>2)};
-    uint b_globalOffsetInWords00, b_globalOffsetInWords16;
-#ifdef KQV
-    b_globalOffsetInWords00 = b_globalOffsetInWords_xy.x + b_globalOffsetInWords_xy.y*K;
-    b_globalOffsetInWords16 = b_globalOffsetInWords00 + (16 * K);
-    uint subMatrixAStartInElements = depth_A * strideAinElements + col * nb01 / 2;
-    uint subMatrixBStartInElements = depth_B * strideBinElements + row * K;
-#else
-    b_globalOffsetInWords00 = b_globalOffsetInWords_xy.x + b_globalOffsetInWords_xy.y*line_stride_matrix_B_in_bytes/4;
-    b_globalOffsetInWords16 = b_globalOffsetInWords00 + (16 * line_stride_matrix_B_in_bytes/4);
-    uint subMatrixAStartInElements = col * strideAinElements + depth_A * K;
-    uint subMatrixBStartInElements = row * strideBinElements + depth_B * K;
-#endif
-
-    __local float matrix_B_local[1024];
-
-    for (uint step=0; step < K; step+=TILESIZE_K) {
-        size_t sub_block_id_m = get_local_id(0);
-        regA = mm_load_a(matrix_A, subMatrixAStartInElements, nb01, line_stride_matrix_A_in_bytes);
-
-        uint b_coordInWords00 = subMatrixBStartInElements + b_globalOffsetInWords00;
-        uint b_coordInWords16 = subMatrixBStartInElements + b_globalOffsetInWords16;
-
-        regB.s0123 = vload4(b_coordInWords00/4, matrix_B);
-        regB.s4567 = vload4(b_coordInWords16/4, matrix_B);
-
-        mm_mad(matrix_B_local, regA, regB, b_localOffsetInWords, &regC0, &regC1);
-
-        subMatrixAStartInElements += TILESIZE_K;
-        subMatrixBStartInElements += TILESIZE_K;
-    }
-
-    uint subMatrixCStartInElements = depth_B * N * M + row * M + col;
-    mm_store_c_N(matrix_C, regC0, regC1, subMatrixCStartInElements, line_stride_matrix_C_in_bytes, (N-block_id_n*32));
-}
+#pragma OPENCL EXTENSION cl_khr_fp16 : enable
+#pragma OPENCL EXTENSION cl_khr_subgroups : enable
+
+#define LM_FIRST_256B 0
+#define LM_SECOND_256B 64
+#define LM_THIRD_256B 128
+#define LM_FOURTH_256B 192
+
+
+inline float16 mm_load_a(
+    image1d_buffer_t matrix_A,
+    uint subMatrixAStartInElements,
+    int nb01,
+    int line_stride_matrix_A_in_bytes
+) {
+    __private float8 regA;
+    size_t sub_block_id_m = get_local_id(0);
+
+#ifdef KQV
+    uint a_texCoord = subMatrixAStartInElements/2 + (sub_block_id_m * nb01/4);
+#else // KQ
+    uint a_texCoord = subMatrixAStartInElements/2 + (sub_block_id_m * line_stride_matrix_A_in_bytes/4);
+#endif
+
+    regA.s0123 = read_imagef(matrix_A, a_texCoord/4);
+    regA.s4567 = read_imagef(matrix_A, (a_texCoord+4)/4);
+
+    return convert_float16(as_half16(regA));
+}
+
+inline float4 alu_32(
+    float16 regA,
+    __local float4* matrix_B_vec
+) {
+
+    __private float4 rC = 0;
+    int i = get_sub_group_id() * 64;
+
+    rC += regA.s0 * matrix_B_vec[i];
+    rC += regA.s1 * matrix_B_vec[i + 16];
+    rC += regA.s4 * matrix_B_vec[i + 1];
+    rC += regA.s5 * matrix_B_vec[i + 17];
+    rC += regA.s8 * matrix_B_vec[i + 2];
+    rC += regA.s9 * matrix_B_vec[i + 18];
+    rC += regA.sc * matrix_B_vec[i + 3];
+    rC += regA.sd * matrix_B_vec[i + 19];
+
+    i += 32;
+
+    rC += regA.s2 * matrix_B_vec[i];
+    rC += regA.s3 * matrix_B_vec[i + 16];
+    rC += regA.s6 * matrix_B_vec[i + 1];
+    rC += regA.s7 * matrix_B_vec[i + 17];
+    rC += regA.sa * matrix_B_vec[i + 2];
+    rC += regA.sb * matrix_B_vec[i + 18];
+    rC += regA.se * matrix_B_vec[i + 3];
+    rC += regA.sf * matrix_B_vec[i + 19];
+
+    return rC;
+}
+
+inline float16 alu_16(
+    float16 regA,
+    __local float* matrix_B_local
+) {
+    float16 out;
+    __local float4* matrix_B_vec = (__local float4*)matrix_B_local;
+
+    out.s0123 = alu_32(regA, matrix_B_vec);
+    out.s4567 = alu_32(regA, matrix_B_vec + 4);
+    out.s89ab = alu_32(regA, matrix_B_vec + 8);
+    out.scdef = alu_32(regA, matrix_B_vec + 12);
+
+    return out;
+}
+
+inline void mm_mad(
+    __local float* matrix_B_local,
+    float16 regA,
+    float8 regB,
+    uint b_localOffsetInWords,
+    float16* regC0_ptr,
+    float16* regC1_ptr
+) {
+    int offset = b_localOffsetInWords + get_sub_group_id() * 256;
+
+    matrix_B_local[offset + LM_FIRST_256B] = regB.s0;
+    matrix_B_local[offset + LM_SECOND_256B] = regB.s1;
+    matrix_B_local[offset + LM_THIRD_256B] = regB.s2;
+    matrix_B_local[offset + LM_FOURTH_256B] = regB.s3;
+
+    float16 add0 = alu_16(regA, matrix_B_local);
+    *regC0_ptr += add0;
+
+    matrix_B_local[offset + LM_FIRST_256B] = regB.s4;
+    matrix_B_local[offset + LM_SECOND_256B] = regB.s5;
+    matrix_B_local[offset + LM_THIRD_256B] = regB.s6;
+    matrix_B_local[offset + LM_FOURTH_256B] = regB.s7;
+
+    float16 add1 = alu_16(regA, matrix_B_local);
+    *regC1_ptr += add1;
+}
+
+inline void mm_store_c_N(
+    __write_only image1d_buffer_t matrix_C,
+    float16 regC0,
+    float16 regC1,
+    uint subMatrixCStartInElements,
+    int line_stride_matrix_C_in_bytes,
+    int mask
+) {
+    size_t sub_block_id_m = get_local_id(0);
+
+    uint strideInWords = line_stride_matrix_C_in_bytes/4;
+    uint c_coordInWords_0 = (subMatrixCStartInElements + sub_block_id_m);
+
+    uint c_coordInWords_1 = c_coordInWords_0 + 1 * strideInWords;
+    uint c_coordInWords_2 = c_coordInWords_0 + 2 * strideInWords;
+    uint c_coordInWords_3 = c_coordInWords_0 + 3 * strideInWords;
+    uint c_coordInWords_4 = c_coordInWords_0 + 4 * strideInWords;
+    uint c_coordInWords_5 = c_coordInWords_0 + 5 * strideInWords;
+    uint c_coordInWords_6 = c_coordInWords_0 + 6 * strideInWords;
+    uint c_coordInWords_7 = c_coordInWords_0 + 7 * strideInWords;
+    uint c_coordInWords_8 = c_coordInWords_0 + 8 * strideInWords;
+    uint c_coordInWords_9 = c_coordInWords_0 + 9 * strideInWords;
+    uint c_coordInWords_10 = c_coordInWords_0 + 10 * strideInWords;
+    uint c_coordInWords_11 = c_coordInWords_0 + 11 * strideInWords;
+    uint c_coordInWords_12 = c_coordInWords_0 + 12 * strideInWords;
+    uint c_coordInWords_13 = c_coordInWords_0 + 13 * strideInWords;
+    uint c_coordInWords_14 = c_coordInWords_0 + 14 * strideInWords;
+    uint c_coordInWords_15 = c_coordInWords_0 + 15 * strideInWords;
+    uint c_coordInWords_16 = c_coordInWords_0 + 16 * strideInWords;
+    uint c_coordInWords_17 = c_coordInWords_0 + 17 * strideInWords;
+    uint c_coordInWords_18 = c_coordInWords_0 + 18 * strideInWords;
+    uint c_coordInWords_19 = c_coordInWords_0 + 19 * strideInWords;
+    uint c_coordInWords_20 = c_coordInWords_0 + 20 * strideInWords;
+    uint c_coordInWords_21 = c_coordInWords_0 + 21 * strideInWords;
+    uint c_coordInWords_22 = c_coordInWords_0 + 22 * strideInWords;
+    uint c_coordInWords_23 = c_coordInWords_0 + 23 * strideInWords;
+    uint c_coordInWords_24 = c_coordInWords_0 + 24 * strideInWords;
+    uint c_coordInWords_25 = c_coordInWords_0 + 25 * strideInWords;
+    uint c_coordInWords_26 = c_coordInWords_0 + 26 * strideInWords;
+    uint c_coordInWords_27 = c_coordInWords_0 + 27 * strideInWords;
+    uint c_coordInWords_28 = c_coordInWords_0 + 28 * strideInWords;
+    uint c_coordInWords_29 = c_coordInWords_0 + 29 * strideInWords;
+    uint c_coordInWords_30 = c_coordInWords_0 + 30 * strideInWords;
+    uint c_coordInWords_31 = c_coordInWords_0 + 31 * strideInWords;
+
+    if (mask > 0) { write_imagef(matrix_C, c_coordInWords_0, regC0.s0); }
+    if (mask > 1) { write_imagef(matrix_C, c_coordInWords_1, regC0.s1); }
+    if (mask > 2) { write_imagef(matrix_C, c_coordInWords_2, regC0.s2); }
+    if (mask > 3) { write_imagef(matrix_C, c_coordInWords_3, regC0.s3); }
+    if (mask > 4) { write_imagef(matrix_C, c_coordInWords_4, regC0.s4); }
+    if (mask > 5) { write_imagef(matrix_C, c_coordInWords_5, regC0.s5); }
+    if (mask > 6) { write_imagef(matrix_C, c_coordInWords_6, regC0.s6); }
+    if (mask > 7) { write_imagef(matrix_C, c_coordInWords_7, regC0.s7); }
+    if (mask > 8) { write_imagef(matrix_C, c_coordInWords_8, regC0.s8); }
+    if (mask > 9) { write_imagef(matrix_C, c_coordInWords_9, regC0.s9); }
+    if (mask > 10) { write_imagef(matrix_C, c_coordInWords_10, regC0.sa); }
+    if (mask > 11) { write_imagef(matrix_C, c_coordInWords_11, regC0.sb); }
+    if (mask > 12) { write_imagef(matrix_C, c_coordInWords_12, regC0.sc); }
+    if (mask > 13) { write_imagef(matrix_C, c_coordInWords_13, regC0.sd); }
+    if (mask > 14) { write_imagef(matrix_C, c_coordInWords_14, regC0.se); }
+    if (mask > 15) { write_imagef(matrix_C, c_coordInWords_15, regC0.sf); }
+    if (mask > 16) { write_imagef(matrix_C, c_coordInWords_16, regC1.s0); }
+    if (mask > 17) { write_imagef(matrix_C, c_coordInWords_17, regC1.s1); }
+    if (mask > 18) { write_imagef(matrix_C, c_coordInWords_18, regC1.s2); }
+    if (mask > 19) { write_imagef(matrix_C, c_coordInWords_19, regC1.s3); }
+    if (mask > 20) { write_imagef(matrix_C, c_coordInWords_20, regC1.s4); }
+    if (mask > 21) { write_imagef(matrix_C, c_coordInWords_21, regC1.s5); }
+    if (mask > 22) { write_imagef(matrix_C, c_coordInWords_22, regC1.s6); }
+    if (mask > 23) { write_imagef(matrix_C, c_coordInWords_23, regC1.s7); }
+    if (mask > 24) { write_imagef(matrix_C, c_coordInWords_24, regC1.s8); }
+    if (mask > 25) { write_imagef(matrix_C, c_coordInWords_25, regC1.s9); }
+    if (mask > 26) { write_imagef(matrix_C, c_coordInWords_26, regC1.sa); }
+    if (mask > 27) { write_imagef(matrix_C, c_coordInWords_27, regC1.sb); }
+    if (mask > 28) { write_imagef(matrix_C, c_coordInWords_28, regC1.sc); }
+    if (mask > 29) { write_imagef(matrix_C, c_coordInWords_29, regC1.sd); }
+    if (mask > 30) { write_imagef(matrix_C, c_coordInWords_30, regC1.se); }
+    if (mask > 31) { write_imagef(matrix_C, c_coordInWords_31, regC1.sf); }
+}
+
+#define TILESIZE_K 16
+#define TILESIZE_M 64
+#define TILESIZE_N 32
+#ifdef KQV
+__kernel void mul_mm_f16_f32_kqv(
+#else
+__kernel void mul_mm_f16_f32_kq(
+#endif
+    __read_only image1d_buffer_t matrix_A,
+    int offset0,
+    __global float* matrix_B,
+    int offset1,
+    __write_only image1d_buffer_t matrix_C,
+    int offsetd,
+    int M, int K, int N,
+    int D_A,
+    int D_B,
+    int nb01
+) {
+
+    uint block_id_m = get_global_id(1);
+    uint block_id_n = get_global_id(2) % ((N+TILESIZE_N-1)/TILESIZE_N);
+    uint block_id_d = get_global_id(2) / ((N+TILESIZE_N-1)/TILESIZE_N);
+
+    __private float16 regA;
+    __private float8 regB;
+    __private float16 regC0;
+    __private float16 regC1;
+
+    const uint col = block_id_m * TILESIZE_M;
+    const uint row = block_id_n * TILESIZE_N;
+    const uint depth_A = block_id_d / (D_B/D_A);
+    const uint depth_B = block_id_d;
+
+#ifdef KQV
+    int line_stride_matrix_A_in_bytes = nb01 * M;
+    int line_stride_matrix_B_in_bytes = K * N * 4;
+#else
+    int line_stride_matrix_A_in_bytes = K * D_A * 2;
+    int line_stride_matrix_B_in_bytes = K * D_B * 4;
+#endif
+
+    int line_stride_matrix_C_in_bytes = M * 4;
+
+    const uint strideAinElements = line_stride_matrix_A_in_bytes / 2;
+    const uint strideBinElements = line_stride_matrix_B_in_bytes / 4;
+
+    size_t sub_block_id_m = get_local_id(0);
+
+    uint b_localOffsetInWords = (sub_block_id_m/16)*16
+        + ((((sub_block_id_m)>>0)&1)<<2)
+        + ((((sub_block_id_m)>>1)&1)<<3)
+        + ((((sub_block_id_m)>>2)&1)<<0)
+        + ((((sub_block_id_m)>>3)&1)<<1);
+
+    uint2 b_globalOffsetInWords_xy = {((sub_block_id_m%4)*4), (sub_block_id_m>>2)};
+    uint b_globalOffsetInWords00, b_globalOffsetInWords16;
+#ifdef KQV
+    b_globalOffsetInWords00 = b_globalOffsetInWords_xy.x + b_globalOffsetInWords_xy.y*K;
+    b_globalOffsetInWords16 = b_globalOffsetInWords00 + (16 * K);
+    uint subMatrixAStartInElements = depth_A * strideAinElements + col * nb01 / 2;
+    uint subMatrixBStartInElements = depth_B * strideBinElements + row * K;
+#else
+    b_globalOffsetInWords00 = b_globalOffsetInWords_xy.x + b_globalOffsetInWords_xy.y*line_stride_matrix_B_in_bytes/4;
+    b_globalOffsetInWords16 = b_globalOffsetInWords00 + (16 * line_stride_matrix_B_in_bytes/4);
+    uint subMatrixAStartInElements = col * strideAinElements + depth_A * K;
+    uint subMatrixBStartInElements = row * strideBinElements + depth_B * K;
+#endif
+
+    __local float matrix_B_local[1024];
+
+    for (uint step=0; step < K; step+=TILESIZE_K) {
+        size_t sub_block_id_m = get_local_id(0);
+        regA = mm_load_a(matrix_A, subMatrixAStartInElements, nb01, line_stride_matrix_A_in_bytes);
+
+        uint b_coordInWords00 = subMatrixBStartInElements + b_globalOffsetInWords00;
+        uint b_coordInWords16 = subMatrixBStartInElements + b_globalOffsetInWords16;
+
+        regB.s0123 = vload4(b_coordInWords00/4, matrix_B);
+        regB.s4567 = vload4(b_coordInWords16/4, matrix_B);
+
+        mm_mad(matrix_B_local, regA, regB, b_localOffsetInWords, &regC0, &regC1);
+
+        subMatrixAStartInElements += TILESIZE_K;
+        subMatrixBStartInElements += TILESIZE_K;
+    }
+
+    uint subMatrixCStartInElements = depth_B * N * M + row * M + col;
+    mm_store_c_N(matrix_C, regC0, regC1, subMatrixCStartInElements, line_stride_matrix_C_in_bytes, (N-block_id_n*32));
+}
+
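For orientation, the NDRange geometry implied by the kernel's tile indexing above (block_id_m, block_id_n, block_id_d) can be sketched as follows. This is an illustrative sketch only, not part of the patch: ceil_div() and sketch_grid() are hypothetical helpers, and the actual dispatch, including the local work sizes, is done by the host code in ggml-opencl.cpp.

    /* Illustrative sketch (not part of the patch) of the global work sizes
     * implied by the kernel's indexing; ceil_div() and sketch_grid() are
     * hypothetical helpers. */
    static inline size_t ceil_div(size_t a, size_t b) { return (a + b - 1) / b; }

    static void sketch_grid(size_t M, size_t N, size_t D_B, size_t global[3]) {
        global[0] = 64;                    /* one lane per M-position of a 64 (M) x 32 (N) output tile */
        global[1] = ceil_div(M, 64);       /* block_id_m = get_global_id(1): tile index along M (TILESIZE_M = 64) */
        global[2] = ceil_div(N, 32) * D_B; /* get_global_id(2) packs block_id_n (N tiles, TILESIZE_N = 32) and block_id_d (batch) */
    }

Each lane accumulates its 32 outputs in regC0/regC1 and mm_store_c_N then issues up to 32 masked write_imagef calls per lane, with the mask (N - block_id_n*32) dropping rows past the edge of N.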