diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp
index ab94bc3d78f68..1bdfd14d5809c 100644
--- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp
+++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp
@@ -466,6 +466,14 @@ static constexpr std::initializer_list<std::array<const int, 3>> rope_view_set_rows_ed
     { 2, 0, 1 }, // set_rows->src[0] == view
 };
 
+static constexpr std::initializer_list<std::array<const int, 3>> rms_norm_mul_rope_view_set_rows_edges {
+    { 1, 0, 0 }, // mul->src[0] == rms
+    { 2, 0, 1 }, // rope->src[0] == mul
+    { 3, 0, 2 }, // view->src[0] == rope
+    { 4, 0, 3 }, // set_rows->src[0] == view
+};
+
+
 struct vk_device_struct {
     std::recursive_mutex mutex;
 
@@ -617,6 +625,8 @@ struct vk_device_struct {
     vk_pipeline pipeline_rms_norm_mul_f32;
     vk_pipeline pipeline_rms_norm_partials_f32;
     vk_pipeline pipeline_rms_norm_mul_partials_f32;
+    vk_pipeline pipeline_rms_norm_mul_rope_f32_f32;
+    vk_pipeline pipeline_rms_norm_mul_rope_f32_f16;
     vk_pipeline pipeline_rms_norm_back_f32;
     vk_pipeline pipeline_l2_norm_f32;
 
@@ -1060,6 +1070,7 @@ struct vk_op_diag_mask_push_constants {
 };
 
 struct vk_op_rope_push_constants {
+    uint32_t rope_mode;
     uint32_t ncols;
     uint32_t n_dims;
     float freq_scale;
@@ -1079,6 +1090,12 @@ struct vk_op_rope_push_constants {
     uint32_t set_rows_stride;
 };
 
+// For fused rms_norm+mul+rope(+view+set_rows)
+struct vk_op_rms_norm_mul_rope_push_constants {
+    vk_op_binary_push_constants bin;
+    vk_op_rope_push_constants rope;
+};
+
 struct vk_op_soft_max_push_constants {
     uint32_t KX;
     uint32_t KY;
@@ -3557,6 +3574,12 @@ static void ggml_vk_load_shaders(vk_device& device) {
     ggml_vk_create_pipeline(device, device->pipeline_rms_norm_partials_f32, "rms_norm_partials_f32", rms_norm_partials_f32_len, rms_norm_partials_f32_data, "main", 4, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {0, 0}, 1, true);
     ggml_vk_create_pipeline(device, device->pipeline_rms_norm_mul_partials_f32, "rms_norm_mul_partials_f32", rms_norm_partials_f32_len, rms_norm_partials_f32_data, "main", 4, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {0, 1}, 1, true);
 
+    if (device->float_controls_rte_fp16 &&
+        sizeof(vk_op_rms_norm_mul_rope_push_constants) <= device->properties.limits.maxPushConstantsSize) {
+        ggml_vk_create_pipeline(device, device->pipeline_rms_norm_mul_rope_f32_f32, "rms_norm_mul_rope_f32_f32", rms_norm_mul_rope_f32_f32_len, rms_norm_mul_rope_f32_f32_data, "main", 7, sizeof(vk_op_rms_norm_mul_rope_push_constants), {1, 1, 1}, {0, 1}, 1, true);
+        ggml_vk_create_pipeline(device, device->pipeline_rms_norm_mul_rope_f32_f16, "rms_norm_mul_rope_f32_f16", rms_norm_mul_rope_f32_f16_rte_len, rms_norm_mul_rope_f32_f16_rte_data, "main", 7, sizeof(vk_op_rms_norm_mul_rope_push_constants), {1, 1, 1}, {0, 1}, 1, true);
+    }
+
     ggml_vk_create_pipeline(device, device->pipeline_rms_norm_back_f32, "rms_norm_back_f32", rms_norm_back_f32_len, rms_norm_back_f32_data, "main", 3, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
     ggml_vk_create_pipeline(device, device->pipeline_l2_norm_f32, "l2_norm_f32", l2_norm_f32_len, l2_norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
 
@@ -9908,21 +9931,149 @@ static uint32_t ggml_vk_rms_partials_size(ggml_backend_vk_context * ctx, const g
     return num_bytes;
 }
 
-static void ggml_vk_rms_norm(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, float * op_params) {
+static vk_op_rope_push_constants ggml_vk_make_rope_constants(const ggml_tensor *dst, const ggml_tensor *src0, const bool has_ff, bool backprop, const uint32_t set_rows_stride) {
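+    // Decode the ROPE parameters from dst->op_params; the slot layout matches what
+    // ggml's ROPE op stores there (int32 slots 1..4, float slots 5..10, mrope sections at 11..14).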
+ const int n_dims = ((const int32_t *) dst->op_params)[1]; + const int mode = ((const int32_t *) dst->op_params)[2]; + // const int n_ctx = ((const int32_t *) dst->op_params)[3]; + const int n_ctx_orig = ((const int32_t *) dst->op_params)[4]; + const float freq_base = ((const float *) dst->op_params)[5]; + const float freq_scale = ((const float *) dst->op_params)[6]; + const float ext_factor = ((const float *) dst->op_params)[7]; + const float attn_factor = ((const float *) dst->op_params)[8]; + const float beta_fast = ((const float *) dst->op_params)[9]; + const float beta_slow = ((const float *) dst->op_params)[10]; + int sections[4] {}; + if (mode & GGML_ROPE_TYPE_MROPE) { + memcpy(sections, (const int32_t *) dst->op_params + 11, sizeof(int)*4); + } + + const bool is_imrope = mode == GGML_ROPE_TYPE_IMROPE; + + float corr_dims[2]; + ggml_rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims); + + const float theta_scale = powf(freq_base, -2.0f/n_dims); + + uint32_t nb01 = src0->nb[1] / ggml_type_size(src0->type); + uint32_t nb02 = src0->nb[2] / ggml_type_size(src0->type); + + vk_op_rope_push_constants rope { + (uint32_t)mode, (uint32_t)src0->ne[0], (uint32_t)n_dims, freq_scale, (uint32_t)src0->ne[1], + freq_base, ext_factor, attn_factor, {corr_dims[0], corr_dims[1]}, theta_scale, + has_ff, (uint32_t)src0->ne[2], nb01, nb02, + { sections[0], sections[1], sections[2], sections[3] }, is_imrope, backprop, set_rows_stride, + }; + + return rope; +} + +static void ggml_vk_rms_norm(ggml_backend_vk_context * ctx, vk_context& subctx, const struct ggml_cgraph * cgraph, int node_idx, float * op_params) { + ggml_tensor * dst; + const ggml_tensor * src0; + const ggml_tensor * src1; + + if (ctx->num_additional_fused_ops > 0) { + // fused rms_norm + mul + ggml_tensor *mul = cgraph->nodes[node_idx + 1]; + ggml_tensor *other_src = mul->src[0] == cgraph->nodes[node_idx + 0] ? mul->src[1] : mul->src[0]; + dst = mul; + src0 = cgraph->nodes[node_idx]->src[0]; + src1 = other_src; + } else { + dst = cgraph->nodes[node_idx]; + src0 = src1 = dst->src[0]; + } + const uint32_t src0_type_size = ggml_type_size(src0->type); const uint32_t src1_type_size = ggml_type_size(src1->type); const uint32_t dst_type_size = ggml_type_size(dst->type); uint32_t param3 = ctx->do_add_rms_partials ? 
ggml_vk_rms_num_partials(ctx, dst) : 0; - ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, nullptr, dst, GGML_OP_RMS_NORM, { + vk_op_binary_push_constants bin { (uint32_t)ggml_nelements(src0), (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2],(uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size, (uint32_t)src1->ne[0], (uint32_t)src1->ne[1], (uint32_t)src1->ne[2],(uint32_t)src1->ne[3], (uint32_t)src1->nb[0] / src1_type_size, (uint32_t)src1->nb[1] / src1_type_size, (uint32_t)src1->nb[2] / src1_type_size, (uint32_t)src1->nb[3] / src1_type_size, (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2],(uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size, 0, op_params[0], 0.0f, (int32_t)param3, - }); + }; + + // more than one fused op means rms_norm+mul+rope + if (ctx->num_additional_fused_ops > 1) { + static constexpr uint32_t max_tensors = 7; + const ggml_tensor *tensors[max_tensors] {}; + + ggml_tensor *rms = cgraph->nodes[node_idx + 0]; + ggml_tensor *mul = cgraph->nodes[node_idx + 1]; + ggml_tensor *rope = cgraph->nodes[node_idx + 2]; + + ggml_tensor *other_src = mul->src[0] == rms ? mul->src[1] : mul->src[0]; + + bool do_set_rows = ctx->num_additional_fused_ops == 4; + + tensors[0] = rms->src[0]; + tensors[1] = other_src; + tensors[2] = mul; + tensors[3] = rope->src[1]; // pos + tensors[4] = rope->src[2]; // ff + tensors[5] = cgraph->nodes[node_idx + ctx->num_additional_fused_ops]; // dst + tensors[6] = do_set_rows ? tensors[5]->src[1] : nullptr; + const uint32_t set_rows_stride = do_set_rows ? tensors[5]->nb[1] / ggml_type_size(tensors[5]->type) : 0; + + vk_op_rms_norm_mul_rope_push_constants pc; + pc.bin = bin; + pc.rope = ggml_vk_make_rope_constants(rope, rope->src[0], tensors[4] != nullptr, false, set_rows_stride); + + vk_pipeline pipeline = tensors[5]->type == GGML_TYPE_F16 ? 
ctx->device->pipeline_rms_norm_mul_rope_f32_f16 : ctx->device->pipeline_rms_norm_mul_rope_f32_f32;
+
+        ggml_pipeline_request_descriptor_sets(ctx, pipeline, 1);
+
+        ggml_backend_vk_buffer_context * buf_ctx[max_tensors];
+        vk_buffer buf[max_tensors];
+        size_t offset[max_tensors];
+        bool uma[max_tensors];
+
+        for (uint32_t i = 0; i < max_tensors; ++i) {
+            if (!tensors[i]) {
+                // If any remaining descriptors are unused, just point them at src[0]
+                buf[i] = buf[0];
+                offset[i] = 0;
+                continue;
+            }
+            buf_ctx[i] = (ggml_backend_vk_buffer_context *)tensors[i]->buffer->context;
+            buf[i] = nullptr;
+            offset[i] = 0;
+            uma[i] = false;
+
+            if (ctx->device->uma) {
+                ggml_vk_host_get(ctx->device, tensors[i]->data, buf[i], offset[i]);
+                uma[i] = buf[i] != nullptr;
+            }
+            if (!uma[i]) {
+                buf[i] = buf_ctx[i]->dev_buffer;
+                offset[i] = vk_tensor_offset(tensors[i]) + tensors[i]->view_offs;
+            }
+            GGML_ASSERT(buf[i] != nullptr);
+        }
+
+        std::array<uint32_t, 3> elements;
+        elements = { (uint32_t)rms->src[0]->ne[1], (uint32_t)rms->src[0]->ne[2], (uint32_t)rms->src[0]->ne[3] };
+
+        static_assert(max_tensors == 7);
+        ggml_vk_dispatch_pipeline(ctx, subctx, pipeline,
+            {
+                ggml_vk_subbuffer(ctx, buf[0], offset[0]),
+                ggml_vk_subbuffer(ctx, buf[1], offset[1]),
+                ggml_vk_subbuffer(ctx, buf[2], offset[2]),
+                ggml_vk_subbuffer(ctx, buf[3], offset[3]),
+                ggml_vk_subbuffer(ctx, buf[4], offset[4]),
+                ggml_vk_subbuffer(ctx, buf[5], offset[5]),
+                ggml_vk_subbuffer(ctx, buf[6], offset[6]),
+            }, pc, elements);
+    } else {
+        ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, nullptr, dst, GGML_OP_RMS_NORM, std::move(bin));
+    }
 
     if (ctx->do_add_rms_partials_offset_calculation) {
         ctx->prealloc_size_add_rms_partials_offset += ggml_vk_rms_partials_size(ctx, src0);
@@ -10117,9 +10268,6 @@ static void ggml_vk_rope(ggml_backend_vk_context * ctx, vk_context& subctx, cons
     // const int n_ctx = ((int32_t *) dst->op_params)[3];
     const int n_ctx_orig = ((int32_t *) dst->op_params)[4];
     const float freq_base = ((float *) dst->op_params)[5];
-    const float freq_scale = ((float *) dst->op_params)[6];
-    const float ext_factor = ((float *) dst->op_params)[7];
-    const float attn_factor = ((float *) dst->op_params)[8];
     const float beta_fast = ((float *) dst->op_params)[9];
     const float beta_slow = ((float *) dst->op_params)[10];
     int sections[4] {};
@@ -10127,16 +10275,9 @@ static void ggml_vk_rope(ggml_backend_vk_context * ctx, vk_context& subctx, cons
         memcpy(sections, (int32_t *) dst->op_params + 11, sizeof(int)*4);
     }
 
-    const bool is_imrope = mode == GGML_ROPE_TYPE_IMROPE;
-
     float corr_dims[2];
     ggml_rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims);
 
-    const float theta_scale = powf(freq_base, -2.0f/n_dims);
-
-    uint32_t s1 = src0->nb[1] / ggml_type_size(src0->type);
-    uint32_t s2 = src0->nb[2] / ggml_type_size(src0->type);
-
     uint32_t set_rows_stride = 0;
     // Fused rope + view + set_rows passes the set_rows destination stride in set_rows_stride
     // and overrides the dst and sets src3=row_indices
@@ -10146,12 +10287,8 @@
         dst = cgraph->nodes[node_idx + 2];
     }
 
-    ggml_vk_op_f32(ctx, subctx, src0, src1, src2, src3, dst, GGML_OP_ROPE, {
-        (uint32_t)src0->ne[0], (uint32_t)n_dims, freq_scale, (uint32_t)src0->ne[1],
-        freq_base, ext_factor, attn_factor, {corr_dims[0], corr_dims[1]}, theta_scale,
-        src2 != nullptr, (uint32_t)src0->ne[2], s1, s2,
-        { sections[0], sections[1], sections[2], sections[3] }, is_imrope, backprop, set_rows_stride,
-    });
+    ggml_vk_op_f32(ctx, subctx, src0, src1, src2, src3, dst, GGML_OP_ROPE,
+        ggml_vk_make_rope_constants(cgraph->nodes[node_idx], src0, src2 != nullptr, backprop, set_rows_stride));
 }
 
 static void ggml_vk_argsort(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {
@@ -11666,6 +11803,10 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
             if (n->op == GGML_OP_GLU) {
                 std::cerr << " " << ggml_glu_op_name(ggml_get_glu_op(n)) << " " << (n->src[1] ? "split" : "single") << " ";
             }
+            if (n->op == GGML_OP_ROPE) {
+                const int mode = ((const int32_t *) n->op_params)[2];
+                std::cerr << " rope mode: " << mode;
+            }
             std::cerr << std::endl;
         }
 #endif
@@ -11773,14 +11914,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgr
         break;
 
     case GGML_OP_RMS_NORM:
-        if (ctx->num_additional_fused_ops > 0) {
-            // fused rms_norm + mul
-            ggml_tensor *mul = cgraph->nodes[node_idx + 1];
-            ggml_tensor *other_src = mul->src[0] == node ? mul->src[1] : mul->src[0];
-            ggml_vk_rms_norm(ctx, compute_ctx, src0, other_src, mul, (float *)node->op_params);
-        } else {
-            ggml_vk_rms_norm(ctx, compute_ctx, src0, src0, node, (float *)node->op_params);
-        }
+        ggml_vk_rms_norm(ctx, compute_ctx, cgraph, node_idx, (float *)node->op_params);
         break;
     case GGML_OP_RMS_NORM_BACK:
         ggml_vk_rms_norm_back(ctx, compute_ctx, src0, src1, node);
@@ -12766,6 +12900,70 @@ static bool ggml_vk_can_fuse_rope_set_rows(ggml_backend_vk_context * ctx, const
     return true;
 }
 
+// Check whether the tensors overlap in memory but are not equal.
+// Fusions can potentially overwrite src tensors in ways that are not prevented
+// by ggml-alloc. If the fusion is entirely elementwise, then it's OK for them
+// to overlap if they are exactly equal.
+// XXX TODO this check is probably missing from several fusion optimizations.
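+// (The test below is a standard half-open interval check: [a_base, a_base + a_size)
+// and [b_base, b_base + b_size) intersect iff each base lies below the other range's
+// end, while an exact base-and-size match is treated as a benign elementwise alias.)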
+static bool ggml_vk_tensors_overlap_but_not_equal(const ggml_tensor * a, const ggml_tensor * b) {
+    ggml_backend_vk_buffer_context * a_buf_ctx = (ggml_backend_vk_buffer_context *)a->buffer->context;
+    vk_buffer a_buf = a_buf_ctx->dev_buffer;
+    ggml_backend_vk_buffer_context * b_buf_ctx = (ggml_backend_vk_buffer_context *)b->buffer->context;
+    vk_buffer b_buf = b_buf_ctx->dev_buffer;
+    if (a_buf == b_buf) {
+        auto a_base = vk_tensor_offset(a) + a->view_offs;
+        auto a_size = ggml_nbytes(a);
+        auto b_base = vk_tensor_offset(b) + b->view_offs;
+        auto b_size = ggml_nbytes(b);
+
+        if (a_base == b_base && a_size == b_size) {
+            return false;
+        }
+
+        if ((b_base <= a_base && a_base < b_base + b_size) ||
+            (a_base <= b_base && b_base < a_base + a_size)) {
+            return true;
+        }
+    }
+    return false;
+}
+
+static bool ggml_vk_can_fuse_rms_norm_mul_rope(ggml_backend_vk_context * ctx, const struct ggml_cgraph * cgraph,
+                                               int node_idx) {
+    GGML_UNUSED(ctx);
+    const ggml_tensor *rms = cgraph->nodes[node_idx + 0];
+    const ggml_tensor *mul = cgraph->nodes[node_idx + 1];
+    const ggml_tensor *rope = cgraph->nodes[node_idx + 2];
+
+    const int mode = ((const int32_t *) rope->op_params)[2];
+
+    // noncontig tensors aren't tested, and don't seem common in practice
+    if (!ggml_is_contiguous(rms) ||
+        !ggml_is_contiguous(mul) ||
+        !ggml_is_contiguous(rope)) {
+        return false;
+    }
+
+    // only norm/neox are handled in the shader
+    if (mode != GGML_ROPE_TYPE_NEOX && mode != GGML_ROPE_TYPE_NORMAL) {
+        return false;
+    }
+
+    // shared memory size for passing data from mul->rope
+    if (mul->ne[0] > 1024) {
+        return false;
+    }
+
+    // must not overwrite srcs in a way that's not elementwise
+    ggml_tensor *other_src = mul->src[0] == rms ? mul->src[1] : mul->src[0];
+    if (ggml_vk_tensors_overlap_but_not_equal(rms->src[0], rope) ||
+        ggml_vk_tensors_overlap_but_not_equal(other_src, rope)) {
+        return false;
+    }
+
+    return true;
+}
+
 static uint32_t ggml_vk_fuse_multi_add(ggml_backend_vk_context * ctx, const struct ggml_cgraph * cgraph, int node_idx) {
     const ggml_tensor *first_node = cgraph->nodes[node_idx];
 
@@ -12911,12 +13109,20 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
             uint32_t num_adds = ggml_vk_fuse_multi_add(ctx, cgraph, i);
             if (num_adds) {
                 ctx->num_additional_fused_ops = num_adds - 1;
-            } else if (ggml_vk_can_fuse(ctx, cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL })) {
-                ctx->num_additional_fused_ops = 1;
             } else if (ggml_vk_can_fuse(ctx, cgraph, i, { GGML_OP_MUL_MAT, GGML_OP_ADD })) {
                 ctx->num_additional_fused_ops = 1;
             } else if (ggml_vk_can_fuse(ctx, cgraph, i, { GGML_OP_MUL_MAT_ID, GGML_OP_ADD_ID })) {
                 ctx->num_additional_fused_ops = 1;
+            } else if (ggml_can_fuse_subgraph(cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL, GGML_OP_ROPE, GGML_OP_VIEW, GGML_OP_SET_ROWS }, { i + 4 }) &&
+                       ggml_check_edges(cgraph, i, rms_norm_mul_rope_view_set_rows_edges) &&
+                       ggml_vk_can_fuse_rms_norm_mul_rope(ctx, cgraph, i) &&
+                       ggml_vk_can_fuse_rope_set_rows(ctx, cgraph, i + 2)) {
+                ctx->num_additional_fused_ops = 4;
+            } else if (ggml_vk_can_fuse(ctx, cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL, GGML_OP_ROPE }) &&
+                       ggml_vk_can_fuse_rms_norm_mul_rope(ctx, cgraph, i)) {
+                ctx->num_additional_fused_ops = 2;
+            } else if (ggml_vk_can_fuse(ctx, cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL })) {
+                ctx->num_additional_fused_ops = 1;
             } else if (ggml_can_fuse_subgraph(cgraph, i, { GGML_OP_ROPE, GGML_OP_VIEW, GGML_OP_SET_ROWS }, { i + 2 }) &&
                        ggml_check_edges(cgraph, i, rope_view_set_rows_edges) &&
ggml_vk_can_fuse_rope_set_rows(ctx, cgraph, i)) { @@ -13149,14 +13355,34 @@ static void ggml_vk_graph_optimize(ggml_backend_t backend, struct ggml_cgraph * } if (ok) { current_set.push_back(j); + + int rope_idx = j; + + // When we've found RMS_NORM + MUL, try to find a ROPE that uses it + if (j > 0 && + graph->nodes[j]->op == GGML_OP_MUL && + graph->nodes[j-1]->op == GGML_OP_RMS_NORM) { + for (int k = j + 1; k < std::min(j + 15, graph->n_nodes); ++k) { + if (graph->nodes[k]->op == GGML_OP_ROPE && + graph->nodes[k]->src[0] == graph->nodes[j] && + // Check that other srcs are already valid + graph->nodes[k]->src[1]->op == GGML_OP_NONE && + (graph->nodes[k]->src[2] == nullptr || graph->nodes[k]->src[2]->op == GGML_OP_NONE)) { + rope_idx = k; + current_set.push_back(rope_idx); + used[rope_idx] = true; + break; + } + } + } // Look for ROPE + VIEW + SET_ROWS and make them consecutive - if (graph->nodes[j]->op == GGML_OP_ROPE) { + if (graph->nodes[rope_idx]->op == GGML_OP_ROPE) { int view_idx = -1; int set_rows_idx = -1; - for (int k = j+1; k < std::min(j + 10, graph->n_nodes); ++k) { + for (int k = rope_idx+1; k < std::min(rope_idx + 10, graph->n_nodes); ++k) { if (view_idx == -1 && graph->nodes[k]->op == GGML_OP_VIEW && - graph->nodes[k]->src[0] == graph->nodes[j]) { + graph->nodes[k]->src[0] == graph->nodes[rope_idx]) { view_idx = k; continue; } diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/generic_binary_head.glsl b/ggml/src/ggml-vulkan/vulkan-shaders/generic_binary_head.glsl index 99595fc688c08..c1ad5172562d4 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/generic_binary_head.glsl +++ b/ggml/src/ggml-vulkan/vulkan-shaders/generic_binary_head.glsl @@ -3,6 +3,9 @@ #include "rte.glsl" #include "utils.glsl" +#if RMS_NORM_ROPE_FUSION +#include "rope_params.glsl" +#endif layout (push_constant) uniform parameter { @@ -12,11 +15,16 @@ layout (push_constant) uniform parameter uint ne20; uint ne21; uint ne22; uint ne23; uint nb20; uint nb21; uint nb22; uint nb23; uint misalign_offsets; float param1; float param2; int param3; +#if RMS_NORM_ROPE_FUSION + rope_params rope; +#endif } p; +#if !RMS_NORM_ROPE_FUSION layout (binding = 0) readonly buffer A {A_TYPE data_a[];}; layout (binding = 1) readonly buffer B {B_TYPE data_b[];}; layout (binding = 2) writeonly buffer D {D_TYPE data_d[];}; +#endif // true if src0/src1 are the same shape and the indices can be reused without additional modulus layout(constant_id = 0) const bool norepeat = false; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp b/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp index d5b211ffaa7bb..3a47949d5a657 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp @@ -3,6 +3,32 @@ #include "generic_binary_head.glsl" #include "types.glsl" +#if RMS_NORM_ROPE_FUSION + +layout (binding = 0) readonly buffer A {A_TYPE data_a[];}; +layout (binding = 1) readonly buffer B {B_TYPE data_b[];}; + +// data is passed from rms_norm -> rope through shared memory. +// rms_norm calls this data_d, rope calls this rope_data_a. 
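+// Keeping the normalized row on-chip means the fused shader never writes the
+// rms_norm result to global memory; only the final rope output is stored.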
+// Binding 2 is not used
+shared FLOAT_TYPE rope_data_a[1024];
+#define data_d rope_data_a
+
+layout (binding = 3) readonly buffer R_Y {int rope_data_pos[];};
+layout (binding = 4) readonly buffer R_Z {float rope_data_ff[];};
+layout (binding = 5) writeonly buffer R_D {ROPE_D_TYPE rope_data_d[];};
+layout (binding = 6) readonly buffer R_I {uvec2 rope_data_i[];}; // indices for set_rows
+
+#include "rope_params.glsl"
+#include "rope_funcs.glsl"
+
+#define GGML_ROPE_TYPE_NORMAL 0
+#define GGML_ROPE_TYPE_NEOX 2
+#define GGML_ROPE_TYPE_MROPE 8
+#define GGML_ROPE_TYPE_VISION 24
+
+#endif
+
 #extension GL_EXT_control_flow_attributes : enable
 
 #define BLOCK_SIZE 512
@@ -28,8 +54,12 @@ void rms_norm(uint num_iters) {
     uint32_t a_offset = samp*stride_sample + channel*stride_channel + row*stride_row + get_aoffset();
     uint32_t b_offset = src1_idx(0, row, channel, samp) + get_boffset();
+#if RMS_NORM_ROPE_FUSION
+    // Per-row offset in shared memory
+    uint32_t d_offset = 0;
+#else
     uint32_t d_offset = ((samp*nchannels + channel)*nrows + row)*ncols + get_doffset();
-
+#endif
     FLOAT_TYPE sum = FLOAT_TYPE(0.0f); // partial sum for thread in warp
 
     [[unroll]] for (uint col = tid, idx = 0; idx < num_iters; col += BLOCK_SIZE, ++idx) {
@@ -79,6 +109,18 @@ void rms_norm(uint num_iters) {
             data_d[d_offset + col] = D_TYPE(scale * FLOAT_TYPE(data_a[a_offset + col]));
         }
     }
+#if RMS_NORM_ROPE_FUSION
+    barrier();
+    rope_params rp = p.rope;
+    uint rope_row = (samp*nchannels + channel)*nrows + row;
+    for (uint t = 2*tid; t < ncols; t += 2*BLOCK_SIZE) {
+        if (rp.rope_mode == GGML_ROPE_TYPE_NEOX) {
+            rope_neox(t, rope_row, rp);
+        } else if (rp.rope_mode == GGML_ROPE_TYPE_NORMAL) {
+            rope_norm(t, rope_row, rp);
+        }
+    }
+#endif
 }
 
 void main() {
diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/rope_funcs.glsl b/ggml/src/ggml-vulkan/vulkan-shaders/rope_funcs.glsl
new file mode 100644
index 0000000000000..9726b722d1e46
--- /dev/null
+++ b/ggml/src/ggml-vulkan/vulkan-shaders/rope_funcs.glsl
@@ -0,0 +1,227 @@
+
+float rope_yarn_ramp(const float low, const float high, const uint i0) {
+    const float y = (i0 / 2 - low) / max(0.001f, high - low);
+    return 1.0f - min(1.0f, max(0.0f, y));
+}
+
+uint rope_a_coord(const uint i0, const uint i01, const uint i02, rope_params p) {
+#if RMS_NORM_ROPE_FUSION
+    // Per-row offset in shared memory
+    const uint ix = i0;
+#else
+    const uint ix = i02*p.nb02 + i01*p.nb01 + i0;
+#endif
+    return ix;
+}
+
+void rope_yarn(const float theta_extrap, const uint i0, out float cos_theta, out float sin_theta, rope_params p) {
+    float mscale = p.attn_factor;
+    // Get n-d rotational scaling corrected for extrapolation
+    float theta_interp = p.freq_scale * theta_extrap;
+    float theta = theta_interp;
+    if (p.ext_factor != 0.0f) {
+        float ramp_mix = rope_yarn_ramp(p.corr_dims[0], p.corr_dims[1], i0) * p.ext_factor;
+        theta = theta_interp * (1 - ramp_mix) + theta_extrap * ramp_mix;
+
+        // Get n-d magnitude scaling corrected for interpolation
+        mscale *= 1.0f + 0.1f * log(1.0f / p.freq_scale);
+    }
+    // Backpropagation uses inverted rotation
+    if (p.is_back != 0) {
+        theta = -theta;
+    }
+    cos_theta = cos(theta) * mscale;
+    sin_theta = sin(theta) * mscale;
+}
+
+void rope_norm(const uint i0, const uint i1, rope_params p) {
+    uint ne0 = p.ncols;
+    uint ne1 = p.p_delta_rows;
+
+    if (i0 >= ne0) {
+        return;
+    }
+
+    // i1 is actually i2*nb2+i1, but the rows are contiguous
+    const uint i01 = i1 % ne1;
+    const uint i02 = i1 / ne1;
+
+    uint idst = i1*ne0 + i0;
+    const uint ix = rope_a_coord(i0, i01, i02, p);
+
+    // Fusion
optimization: ROPE + VIEW + SET_ROWS.. + // The rope output is viewed as a 1D tensor and offset based on a row index in data_i. + if (p.set_rows_stride != 0) { + idst = i01*ne0 + i0; + idst += rope_data_i[i02].x * p.set_rows_stride; + } + + if (i0 >= p.n_dims) { + rope_data_d[idst + 0] = ROPE_D_TYPE(rope_data_a[ix + 0]); + rope_data_d[idst + 1] = ROPE_D_TYPE(rope_data_a[ix + 1]); + + return; + } + + const float theta_base = rope_data_pos[i02] * pow(p.theta_scale, i0/2.0f); + + const float freq_factor = p.has_ff != 0 ? rope_data_ff[i0/2] : 1.0f; + + float cos_theta, sin_theta; + rope_yarn(theta_base / freq_factor, i0, cos_theta, sin_theta, p); + + const float x0 = float(rope_data_a[ix + 0]); + const float x1 = float(rope_data_a[ix + 1]); + + rope_data_d[idst + 0] = ROPE_D_TYPE(x0*cos_theta - x1*sin_theta); + rope_data_d[idst + 1] = ROPE_D_TYPE(x0*sin_theta + x1*cos_theta); +} + +void rope_neox(const uint i0, const uint i1, rope_params p) { + uint ne0 = p.ncols; + uint ne1 = p.p_delta_rows; + + if (i0 >= ne0) { + return; + } + + const uint i01 = i1 % ne1; + const uint i02 = i1 / ne1; + + uint idst = i1*ne0 + i0/2; + const uint ix = rope_a_coord(i0/2, i01, i02, p); + + // Fusion optimization: ROPE + VIEW + SET_ROWS.. + // The rope output is viewed as a 1D tensor and offset based on a row index in rope_data_i. + if (p.set_rows_stride != 0) { + idst = i01*ne0 + i0/2; + idst += rope_data_i[i02].x * p.set_rows_stride; + } + + if (i0 >= p.n_dims) { + rope_data_d[idst + i0/2 + 0] = ROPE_D_TYPE(rope_data_a[ix + i0/2 + 0]); + rope_data_d[idst + i0/2 + 1] = ROPE_D_TYPE(rope_data_a[ix + i0/2 + 1]); + + return; + } + + const float theta_base = rope_data_pos[i02] * pow(p.theta_scale, i0/2.0f); + + const float freq_factor = p.has_ff != 0 ? rope_data_ff[i0/2] : 1.0f; + + float cos_theta, sin_theta; + rope_yarn(theta_base / freq_factor, i0, cos_theta, sin_theta, p); + + const float x0 = float(rope_data_a[ix + 0]); + const float x1 = float(rope_data_a[ix + p.n_dims/2]); + + rope_data_d[idst + 0] = ROPE_D_TYPE(x0*cos_theta - x1*sin_theta); + rope_data_d[idst + p.n_dims/2] = ROPE_D_TYPE(x0*sin_theta + x1*cos_theta); +} + + +void rope_multi(const uint i0, const uint i1, rope_params p) { + uint ne0 = p.ncols; + uint ne1 = p.p_delta_rows; + uint ne2 = p.ne02; + + if (i0 >= ne0) { + return; + } + + const uint i01 = i1 % ne1; + const uint i02 = i1 / ne1; + + const uint idst = i1*ne0 + i0/2; + const uint ix = rope_a_coord(i0/2, i01, i02, p); + + if (i0 >= p.n_dims) { + rope_data_d[idst + i0/2 + 0] = ROPE_D_TYPE(rope_data_a[ix + i0/2 + 0]); + rope_data_d[idst + i0/2 + 1] = ROPE_D_TYPE(rope_data_a[ix + i0/2 + 1]); + + return; + } + + const int sect_dims = p.sections[0] + p.sections[1] + p.sections[2] + p.sections[3]; + const int sec_w = p.sections[1] + p.sections[0]; + const uint sector = (i0 / 2) % sect_dims; + + float theta_base = 0.0; + if (p.is_imrope != 0) { + if (sector % 3 == 1 && sector < 3 * p.sections[1]) { + theta_base = rope_data_pos[i02 + ne2 * 1]*pow(p.theta_scale, i0/2.0f); + } else if (sector % 3 == 2 && sector < 3 * p.sections[2]) { + theta_base = rope_data_pos[i02 + ne2 * 2]*pow(p.theta_scale, i0/2.0f); + } else if (sector % 3 == 0 && sector < 3 * p.sections[0]) { + theta_base = rope_data_pos[i02]*pow(p.theta_scale, i0/2.0f); + } else { + theta_base = rope_data_pos[i02 + ne2 * 3]*pow(p.theta_scale, i0/2.0f); + } + } else { + if (sector < p.sections[0]) { + theta_base = rope_data_pos[i02]*pow(p.theta_scale, i0/2.0f); + } + else if (sector >= p.sections[0] && sector < sec_w) { + theta_base = 
rope_data_pos[i02 + ne2 * 1]*pow(p.theta_scale, i0/2.0f); + } + else if (sector >= sec_w && sector < sec_w + p.sections[2]) { + theta_base = rope_data_pos[i02 + ne2 * 2]*pow(p.theta_scale, i0/2.0f); + } + else if (sector >= sec_w + p.sections[2]) { + theta_base = rope_data_pos[i02 + ne2 * 3]*pow(p.theta_scale, i0/2.0f); + } + } + + const float freq_factor = p.has_ff != 0 ? rope_data_ff[i0/2] : 1.0f; + + float cos_theta, sin_theta; + rope_yarn(theta_base / freq_factor, i0, cos_theta, sin_theta, p); + + const float x0 = float(rope_data_a[ix + 0]); + const float x1 = float(rope_data_a[ix + p.n_dims/2]); + + rope_data_d[idst + 0] = ROPE_D_TYPE(x0*cos_theta - x1*sin_theta); + rope_data_d[idst + p.n_dims/2] = ROPE_D_TYPE(x0*sin_theta + x1*cos_theta); +} + +void rope_vision(const uint i0, const uint i1, rope_params p) { + uint ne0 = p.ncols; + uint ne1 = p.p_delta_rows; + uint ne2 = p.ne02; + + if (i0 >= ne0) { + return; + } + + const uint i01 = i1 % ne1; + const uint i02 = i1 / ne1; + + const uint idst = i1*ne0 + i0/2; + const uint ix = rope_a_coord(i0/2, i01, i02, p); + + const int sect_dims = p.sections[0] + p.sections[1]; + const int sec_w = p.sections[1] + p.sections[0]; + const uint sector = (i0 / 2) % sect_dims; + + float theta_base = 0.0; + if (sector < p.sections[0]) { + const uint p0 = sector; + theta_base = rope_data_pos[i02]*pow(p.theta_scale, p0); + } + else if (sector >= p.sections[0] && sector < sec_w) { + const uint p0 = sector - p.sections[0]; + theta_base = rope_data_pos[i02 + ne2]*pow(p.theta_scale, p0); + } + + const float freq_factor = p.has_ff != 0 ? rope_data_ff[i0/2] : 1.0f; + + float cos_theta, sin_theta; + rope_yarn(theta_base / freq_factor, i0, cos_theta, sin_theta, p); + + const float x0 = float(rope_data_a[ix + 0]); + const float x1 = float(rope_data_a[ix + p.n_dims]); + + rope_data_d[idst + 0] = ROPE_D_TYPE(x0*cos_theta - x1*sin_theta); + rope_data_d[idst + p.n_dims] = ROPE_D_TYPE(x0*sin_theta + x1*cos_theta); +} + diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.glsl b/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.glsl index fa2bb33394cb2..d9b4d4c03f34f 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.glsl +++ b/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.glsl @@ -3,56 +3,18 @@ #extension GL_EXT_shader_16bit_storage : require #include "rte.glsl" +#include "rope_params.glsl" layout(local_size_x = 1, local_size_y = 256, local_size_z = 1) in; -layout (binding = 0) readonly buffer X {A_TYPE data_a[];}; -layout (binding = 1) readonly buffer Y {int data_pos[];}; -layout (binding = 2) readonly buffer Z {float data_ff[];}; -layout (binding = 3) writeonly buffer D {D_TYPE data_d[];}; -layout (binding = 4) readonly buffer I {uvec2 data_i[];}; // indices for set_rows +layout (binding = 0) readonly buffer X {A_TYPE rope_data_a[];}; +layout (binding = 1) readonly buffer Y {int rope_data_pos[];}; +layout (binding = 2) readonly buffer Z {float rope_data_ff[];}; +layout (binding = 3) writeonly buffer D {ROPE_D_TYPE rope_data_d[];}; +layout (binding = 4) readonly buffer I {uvec2 rope_data_i[];}; // indices for set_rows -layout (push_constant) uniform parameter { - uint ncols; - uint n_dims; - float freq_scale; - uint p_delta_rows; - float freq_base; - float ext_factor; - float attn_factor; - float corr_dims[2]; - float theta_scale; - uint has_ff; - uint ne02; - uint s1; - uint s2; - int sections[4]; - uint is_imrope; - uint is_back; - uint set_rows_stride; -} p; - -float rope_yarn_ramp(const float low, const float high, const uint i0) { - const float y = (i0 / 
2 - low) / max(0.001f, high - low); - return 1.0f - min(1.0f, max(0.0f, y)); -} -void rope_yarn(const float theta_extrap, const uint i0, out float cos_theta, out float sin_theta) { - float mscale = p.attn_factor; - // Get n-d rotational scaling corrected for extrapolation - float theta_interp = p.freq_scale * theta_extrap; - float theta = theta_interp; - if (p.ext_factor != 0.0f) { - float ramp_mix = rope_yarn_ramp(p.corr_dims[0], p.corr_dims[1], i0) * p.ext_factor; - theta = theta_interp * (1 - ramp_mix) + theta_extrap * ramp_mix; +layout (push_constant) uniform parameter { + rope_params pc; +}; - // Get n-d magnitude scaling corrected for interpolation - mscale *= 1.0f + 0.1f * log(1.0f / p.freq_scale); - } - // Backprogagation uses inverted rotation - if (p.is_back != 0) { - theta = -theta; - } - cos_theta = cos(theta) * mscale; - sin_theta = sin(theta) * mscale; -} diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp b/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp index 54aabcf222838..7c1fb1cd22440 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp @@ -1,70 +1,11 @@ #version 450 #include "rope_head.glsl" +#include "rope_funcs.glsl" void main() { const uint i0 = 2*gl_GlobalInvocationID.y; - uint ne0 = p.ncols; - uint ne1 = p.p_delta_rows; - uint ne2 = p.ne02; - - if (i0 >= ne0) { - return; - } - - const uint row_dst = gl_GlobalInvocationID.x; - - const uint row_x = row_dst % ne1; - const uint channel_x = row_dst / ne1; - - const uint idst = row_dst*ne0 + i0/2; - const uint ix = channel_x*p.s2 + row_x*p.s1 + i0/2; - - if (i0 >= p.n_dims) { - data_d[idst + i0/2 + 0] = data_a[ix + i0/2 + 0]; - data_d[idst + i0/2 + 1] = data_a[ix + i0/2 + 1]; - - return; - } - - const int sect_dims = p.sections[0] + p.sections[1] + p.sections[2] + p.sections[3]; - const int sec_w = p.sections[1] + p.sections[0]; - const uint sector = (i0 / 2) % sect_dims; - - float theta_base = 0.0; - if (p.is_imrope != 0) { - if (sector % 3 == 1 && sector < 3 * p.sections[1]) { - theta_base = data_pos[channel_x + ne2 * 1]*pow(p.theta_scale, i0/2.0f); - } else if (sector % 3 == 2 && sector < 3 * p.sections[2]) { - theta_base = data_pos[channel_x + ne2 * 2]*pow(p.theta_scale, i0/2.0f); - } else if (sector % 3 == 0 && sector < 3 * p.sections[0]) { - theta_base = data_pos[channel_x]*pow(p.theta_scale, i0/2.0f); - } else { - theta_base = data_pos[channel_x + ne2 * 3]*pow(p.theta_scale, i0/2.0f); - } - } else { - if (sector < p.sections[0]) { - theta_base = data_pos[channel_x]*pow(p.theta_scale, i0/2.0f); - } - else if (sector >= p.sections[0] && sector < sec_w) { - theta_base = data_pos[channel_x + ne2 * 1]*pow(p.theta_scale, i0/2.0f); - } - else if (sector >= sec_w && sector < sec_w + p.sections[2]) { - theta_base = data_pos[channel_x + ne2 * 2]*pow(p.theta_scale, i0/2.0f); - } - else if (sector >= sec_w + p.sections[2]) { - theta_base = data_pos[channel_x + ne2 * 3]*pow(p.theta_scale, i0/2.0f); - } - } - - const float freq_factor = p.has_ff != 0 ? 
data_ff[i0/2] : 1.0f; - - float cos_theta, sin_theta; - rope_yarn(theta_base / freq_factor, i0, cos_theta, sin_theta); - - const float x0 = float(data_a[ix + 0]); - const float x1 = float(data_a[ix + p.n_dims/2]); - - data_d[idst + 0] = D_TYPE(x0*cos_theta - x1*sin_theta); - data_d[idst + p.n_dims/2] = D_TYPE(x0*sin_theta + x1*cos_theta); + // i1 is actually i2*nb2+i1, but the rows are contiguous + const uint i1 = gl_GlobalInvocationID.x; + rope_multi(i0, i1, pc); } diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp b/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp index 9f4538155a05c..68f00c180bb9f 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp @@ -1,48 +1,11 @@ #version 450 #include "rope_head.glsl" +#include "rope_funcs.glsl" void main() { const uint i0 = 2*gl_GlobalInvocationID.y; - uint ne0 = p.ncols; - uint ne1 = p.p_delta_rows; - - if (i0 >= ne0) { - return; - } - - const uint row_dst = gl_GlobalInvocationID.x; - - const uint row_x = row_dst % ne1; - const uint channel_x = row_dst / ne1; - - uint idst = row_dst*ne0 + i0/2; - const uint ix = channel_x*p.s2 + row_x*p.s1 + i0/2; - - // Fusion optimization: ROPE + VIEW + SET_ROWS.. - // The rope output is viewed as a 1D tensor and offset based on a row index in data_i. - if (p.set_rows_stride != 0) { - idst = row_x*ne0 + i0/2; - idst += data_i[channel_x].x * p.set_rows_stride; - } - - if (i0 >= p.n_dims) { - data_d[idst + i0/2 + 0] = D_TYPE(data_a[ix + i0/2 + 0]); - data_d[idst + i0/2 + 1] = D_TYPE(data_a[ix + i0/2 + 1]); - - return; - } - - const float theta_base = data_pos[channel_x] * pow(p.theta_scale, i0/2.0f); - - const float freq_factor = p.has_ff != 0 ? data_ff[i0/2] : 1.0f; - - float cos_theta, sin_theta; - rope_yarn(theta_base / freq_factor, i0, cos_theta, sin_theta); - - const float x0 = float(data_a[ix + 0]); - const float x1 = float(data_a[ix + p.n_dims/2]); - - data_d[idst + 0] = D_TYPE(x0*cos_theta - x1*sin_theta); - data_d[idst + p.n_dims/2] = D_TYPE(x0*sin_theta + x1*cos_theta); + // i1 is actually i2*nb2+i1, but the rows are contiguous + const uint i1 = gl_GlobalInvocationID.x; + rope_neox(i0, i1, pc); } diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp b/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp index f4209ed9582aa..28a939ec6ad39 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp @@ -1,48 +1,11 @@ #version 450 #include "rope_head.glsl" +#include "rope_funcs.glsl" void main() { const uint i0 = 2*gl_GlobalInvocationID.y; - uint ne0 = p.ncols; - uint ne1 = p.p_delta_rows; - - if (i0 >= ne0) { - return; - } - - const uint row_dst = gl_GlobalInvocationID.x; - - const uint row_x = row_dst % ne1; - const uint channel_x = row_dst / ne1; - - uint idst = row_dst*ne0 + i0; - const uint ix = channel_x*p.s2 + row_x*p.s1 + i0; - - // Fusion optimization: ROPE + VIEW + SET_ROWS.. - // The rope output is viewed as a 1D tensor and offset based on a row index in data_i. - if (p.set_rows_stride != 0) { - idst = row_x*ne0 + i0; - idst += data_i[channel_x].x * p.set_rows_stride; - } - - if (i0 >= p.n_dims) { - data_d[idst + 0] = D_TYPE(data_a[ix + 0]); - data_d[idst + 1] = D_TYPE(data_a[ix + 1]); - - return; - } - - const float theta_base = data_pos[channel_x] * pow(p.theta_scale, i0/2.0f); - - const float freq_factor = p.has_ff != 0 ? 
data_ff[i0/2] : 1.0f; - - float cos_theta, sin_theta; - rope_yarn(theta_base / freq_factor, i0, cos_theta, sin_theta); - - const float x0 = float(data_a[ix + 0]); - const float x1 = float(data_a[ix + 1]); - - data_d[idst + 0] = D_TYPE(x0*cos_theta - x1*sin_theta); - data_d[idst + 1] = D_TYPE(x0*sin_theta + x1*cos_theta); + // i1 is actually i2*nb2+i1, but the rows are contiguous + const uint i1 = gl_GlobalInvocationID.x; + rope_norm(i0, i1, pc); } diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/rope_params.glsl b/ggml/src/ggml-vulkan/vulkan-shaders/rope_params.glsl new file mode 100644 index 0000000000000..82f39cee349d8 --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/rope_params.glsl @@ -0,0 +1,27 @@ +#if !defined(GGML_ROPE_PARAMS) +#define GGML_ROPE_PARAMS + +#include "rte.glsl" + +struct rope_params { + uint rope_mode; + uint ncols; + uint n_dims; + float freq_scale; + uint p_delta_rows; + float freq_base; + float ext_factor; + float attn_factor; + float corr_dims[2]; + float theta_scale; + uint has_ff; + uint ne02; + uint nb01; + uint nb02; + int sections[4]; + uint is_imrope; + uint is_back; + uint set_rows_stride; +}; + +#endif // !defined(GGML_ROPE_PARAMS) diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/rope_vision.comp b/ggml/src/ggml-vulkan/vulkan-shaders/rope_vision.comp index d37d1c1043f8a..ea1e0fdb41688 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/rope_vision.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/rope_vision.comp @@ -1,47 +1,11 @@ #version 450 #include "rope_head.glsl" +#include "rope_funcs.glsl" void main() { const uint i0 = 2*gl_GlobalInvocationID.y; - uint ne0 = p.ncols; - uint ne1 = p.p_delta_rows; - uint ne2 = p.ne02; - - if (i0 >= ne0) { - return; - } - - const uint row_dst = gl_GlobalInvocationID.x; - - const uint row_x = row_dst % ne1; - const uint channel_x = row_dst / ne1; - - const uint idst = row_dst*ne0 + i0/2; - const uint ix = channel_x*p.s2 + row_x*p.s1 + i0/2; - - const int sect_dims = p.sections[0] + p.sections[1]; - const int sec_w = p.sections[1] + p.sections[0]; - const uint sector = (i0 / 2) % sect_dims; - - float theta_base = 0.0; - if (sector < p.sections[0]) { - const uint p0 = sector; - theta_base = data_pos[channel_x]*pow(p.theta_scale, p0); - } - else if (sector >= p.sections[0] && sector < sec_w) { - const uint p0 = sector - p.sections[0]; - theta_base = data_pos[channel_x + ne2]*pow(p.theta_scale, p0); - } - - const float freq_factor = p.has_ff != 0 ? 
data_ff[i0/2] : 1.0f; - - float cos_theta, sin_theta; - rope_yarn(theta_base / freq_factor, i0, cos_theta, sin_theta); - - const float x0 = float(data_a[ix + 0]); - const float x1 = float(data_a[ix + p.n_dims]); - - data_d[idst + 0] = D_TYPE(x0*cos_theta - x1*sin_theta); - data_d[idst + p.n_dims] = D_TYPE(x0*sin_theta + x1*cos_theta); + // i1 is actually i2*nb2+i1, but the rows are contiguous + const uint i1 = gl_GlobalInvocationID.x; + rope_vision(i0, i1, pc); } diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp index bd178875d55f6..c2e42cf006e96 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp @@ -695,6 +695,8 @@ void process_shaders() { string_to_spv("group_norm_f32", "group_norm.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}})); string_to_spv("rms_norm_f32", "rms_norm.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}})); string_to_spv("rms_norm_partials_f32", "rms_norm_partials.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}})); + string_to_spv("rms_norm_mul_rope_f32_f32", "rms_norm.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"ROPE_D_TYPE", "float"}, {"RMS_NORM_ROPE_FUSION", "1"}})); + string_to_spv("rms_norm_mul_rope_f32_f16_rte", "rms_norm.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"ROPE_D_TYPE", "float16_t"}, {"RMS_NORM_ROPE_FUSION", "1"}, {"RTE16", "1"}})); string_to_spv("rms_norm_back_f32", "rms_norm_back.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}})); string_to_spv("l2_norm_f32", "l2_norm.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}})); @@ -840,25 +842,25 @@ void process_shaders() { string_to_spv("soft_max_f32_f16", "soft_max.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float"}})); string_to_spv("soft_max_back_f32", "soft_max_back.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}})); - string_to_spv("rope_norm_f32", "rope_norm.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); - string_to_spv("rope_norm_f16", "rope_norm.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}}); - string_to_spv("rope_norm_f16_rte", "rope_norm.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"RTE16", "1"}}); - string_to_spv("rope_norm_f32_f16", "rope_norm.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float16_t"}}); - string_to_spv("rope_norm_f32_f16_rte", "rope_norm.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float16_t"}, {"RTE16", "1"}}); - - string_to_spv("rope_neox_f32", "rope_neox.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); - string_to_spv("rope_neox_f16", "rope_neox.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}}); - string_to_spv("rope_neox_f16_rte", "rope_neox.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"RTE16", "1"}}); - string_to_spv("rope_neox_f32_f16", "rope_neox.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float16_t"}}); - string_to_spv("rope_neox_f32_f16_rte", "rope_neox.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float16_t"}, {"RTE16", "1"}}); - - string_to_spv("rope_multi_f32", "rope_multi.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); - string_to_spv("rope_multi_f16", "rope_multi.comp", {{"A_TYPE", "float16_t"}, 
{"D_TYPE", "float16_t"}});
-    string_to_spv("rope_multi_f16_rte", "rope_multi.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"RTE16", "1"}});
-
-    string_to_spv("rope_vision_f32", "rope_vision.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
-    string_to_spv("rope_vision_f16", "rope_vision.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
-    string_to_spv("rope_vision_f16_rte", "rope_vision.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"RTE16", "1"}});
+    string_to_spv("rope_norm_f32", "rope_norm.comp", {{"A_TYPE", "float"}, {"ROPE_D_TYPE", "float"}});
+    string_to_spv("rope_norm_f16", "rope_norm.comp", {{"A_TYPE", "float16_t"}, {"ROPE_D_TYPE", "float16_t"}});
+    string_to_spv("rope_norm_f16_rte", "rope_norm.comp", {{"A_TYPE", "float16_t"}, {"ROPE_D_TYPE", "float16_t"}, {"RTE16", "1"}});
+    string_to_spv("rope_norm_f32_f16", "rope_norm.comp", {{"A_TYPE", "float"}, {"ROPE_D_TYPE", "float16_t"}});
+    string_to_spv("rope_norm_f32_f16_rte", "rope_norm.comp", {{"A_TYPE", "float"}, {"ROPE_D_TYPE", "float16_t"}, {"RTE16", "1"}});
+
+    string_to_spv("rope_neox_f32", "rope_neox.comp", {{"A_TYPE", "float"}, {"ROPE_D_TYPE", "float"}});
+    string_to_spv("rope_neox_f16", "rope_neox.comp", {{"A_TYPE", "float16_t"}, {"ROPE_D_TYPE", "float16_t"}});
+    string_to_spv("rope_neox_f16_rte", "rope_neox.comp", {{"A_TYPE", "float16_t"}, {"ROPE_D_TYPE", "float16_t"}, {"RTE16", "1"}});
+    string_to_spv("rope_neox_f32_f16", "rope_neox.comp", {{"A_TYPE", "float"}, {"ROPE_D_TYPE", "float16_t"}});
+    string_to_spv("rope_neox_f32_f16_rte", "rope_neox.comp", {{"A_TYPE", "float"}, {"ROPE_D_TYPE", "float16_t"}, {"RTE16", "1"}});
+
+    string_to_spv("rope_multi_f32", "rope_multi.comp", {{"A_TYPE", "float"}, {"ROPE_D_TYPE", "float"}});
+    string_to_spv("rope_multi_f16", "rope_multi.comp", {{"A_TYPE", "float16_t"}, {"ROPE_D_TYPE", "float16_t"}});
+    string_to_spv("rope_multi_f16_rte", "rope_multi.comp", {{"A_TYPE", "float16_t"}, {"ROPE_D_TYPE", "float16_t"}, {"RTE16", "1"}});
+
+    string_to_spv("rope_vision_f32", "rope_vision.comp", {{"A_TYPE", "float"}, {"ROPE_D_TYPE", "float"}});
+    string_to_spv("rope_vision_f16", "rope_vision.comp", {{"A_TYPE", "float16_t"}, {"ROPE_D_TYPE", "float16_t"}});
+    string_to_spv("rope_vision_f16_rte", "rope_vision.comp", {{"A_TYPE", "float16_t"}, {"ROPE_D_TYPE", "float16_t"}, {"RTE16", "1"}});
 
     string_to_spv("argsort_f32", "argsort.comp", {{"A_TYPE", "float"}});
 
diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp
index 967a53c63d86d..011e2be134e35 100644
--- a/tests/test-backend-ops.cpp
+++ b/tests/test-backend-ops.cpp
@@ -2294,6 +2294,79 @@ struct test_rope_set_rows : public test_case {
     }
 };
 
+// GGML_OP_RMS_NORM + GGML_OP_MUL + GGML_OP_ROPE (+ GGML_OP_VIEW + GGML_OP_SET_ROWS)
+struct test_rms_norm_mul_rope : public test_case {
+    const std::array<int64_t, 4> ne;
+    const float eps;
+    const bool multi_add; // test a sequence of adds feeding into rms_norm
+    const bool set_rows;
+    int mode;
+
+    std::string op_desc(ggml_tensor * t) override {
+        GGML_UNUSED(t);
+        return "RMS_NORM_MUL_ROPE";
+    }
+
+    bool run_whole_graph() override { return true; }
+
+    std::string vars() override {
+        return VARS_TO_STR5(ne, eps, multi_add, set_rows, mode);
+    }
+
+    test_rms_norm_mul_rope(std::array<int64_t, 4> ne, float eps = 1e-6f, bool multi_add = false,
+                           bool set_rows = false, int mode = GGML_ROPE_TYPE_NORMAL)
+        : ne(ne), eps(eps), multi_add(multi_add), set_rows(set_rows), mode(mode) {}
+
+    ggml_tensor * build_graph(ggml_context * ctx) override {
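+        // Graph under test: (optional add chain ->) rms_norm -> mul -> rope
+        // (-> view -> set_rows when set_rows is true), i.e. the op sequences
+        // the fused backend paths above are expected to match.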
+        ggml_tensor * a = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, ne[0], ne[1], ne[2], 1);
+        ggml_tensor * b = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, ne[0], ne[1], ne[2], 1);
+        ggml_tensor * c = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, ne[0], ne[1], ne[2], 1);
+
+        if (multi_add) {
+            a = ggml_add(ctx, ggml_add(ctx, a, b), c);
+        }
+
+        a = ggml_mul(ctx, ggml_rms_norm(ctx, a, eps), b);
+
+        ggml_tensor * pos = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, ne[2]);
+
+        ggml_tensor * rope = ggml_rope(ctx, a, pos, ne[0], mode);
+
+        ggml_tensor * out;
+
+        if (set_rows) {
+            ggml_tensor * view = ggml_view_2d(ctx, rope, ne[0] * ne[1], ne[2], rope->nb[2], 0);
+
+            ggml_tensor * dst = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, ne[0] * ne[1], ne[2] * ne[3], 1, 1);
+            ggml_set_name(dst, "dst");
+
+            ggml_tensor * row_idxs = ggml_new_tensor_3d(ctx, GGML_TYPE_I64, ne[2], 1, 1);
+            ggml_set_name(row_idxs, "row_idxs");
+
+            out = ggml_set_rows(ctx, dst, view, row_idxs);
+            ggml_set_name(out, "out");
+        } else {
+            out = rope;
+        }
+
+        return out;
+    }
+
+    void initialize_tensors(ggml_context * ctx) override {
+        for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
+            if (t->type == GGML_TYPE_I64 || t->type == GGML_TYPE_I32) {
+                if (ggml_is_view_op(t->op)) {
+                    continue;
+                }
+
+                init_set_rows_row_ids(t, ne[2]);
+            } else {
+                init_tensor_uniform(t);
+            }
+        }
+    }
+};
+
 // GGML_OP_ARGMAX
 struct test_argmax : public test_case {
     const ggml_type type;
@@ -6743,6 +6816,22 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
         }
     }
 
+    for (auto multi_add : {false, true}) {
+        for (auto set_rows : {false, true}) {
+            for (auto rope : {GGML_ROPE_TYPE_NORMAL, GGML_ROPE_TYPE_NEOX}) {
+                test_cases.emplace_back(new test_rms_norm_mul_rope({768, 1, 1, 1}, 1e-6f, multi_add, set_rows, rope));
+                test_cases.emplace_back(new test_rms_norm_mul_rope({768, 3, 1, 1}, 1e-6f, multi_add, set_rows, rope));
+                test_cases.emplace_back(new test_rms_norm_mul_rope({768, 3, 5, 1}, 1e-6f, multi_add, set_rows, rope));
+                test_cases.emplace_back(new test_rms_norm_mul_rope({128, 32, 2, 1}, 1e-6f, multi_add, set_rows, rope));
+                test_cases.emplace_back(new test_rms_norm_mul_rope({128, 4, 2, 1}, 1e-6f, multi_add, set_rows, rope));
+                test_cases.emplace_back(new test_rms_norm_mul_rope({128, 32, 50, 1}, 1e-6f, multi_add, set_rows, rope));
+                test_cases.emplace_back(new test_rms_norm_mul_rope({128, 4, 50, 1}, 1e-6f, multi_add, set_rows, rope));
+                test_cases.emplace_back(new test_rms_norm_mul_rope({8192, 2, 2, 1}, 1e-6f, multi_add, set_rows, rope));
+                test_cases.emplace_back(new test_rms_norm_mul_rope({8192, 2, 2, 1}, 1e-6f, multi_add, set_rows, rope));
+            }
+        }
+    }
+
     test_cases.emplace_back(new test_l2_norm(GGML_TYPE_F32, {64, 5, 4, 3}, 1e-12f));
 
     for (int64_t d_conv : {3, 4}) {