Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion ggml/src/ggml-hexagon/ggml-hexagon.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2228,7 +2228,7 @@ static bool ggml_hexagon_supported_rope(const struct ggml_hexagon_session * sess

int mode = op_params[2];

if ((mode & GGML_ROPE_TYPE_NEOX) || (mode & GGML_ROPE_TYPE_MROPE) || (mode & GGML_ROPE_TYPE_VISION)) {
if ((mode & GGML_ROPE_TYPE_MROPE) || (mode & GGML_ROPE_TYPE_VISION)) {
return false;
}
if (mode & 1) {
Expand Down
87 changes: 80 additions & 7 deletions ggml/src/ggml-hexagon/htp/rope-ops.c
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,10 @@
#include "hvx-utils.h"
#include "ops-utils.h"

// Redefine GGML_ROPE_TYPE_NORMAL and GGML_ROPE_TYPE_NEOX locally because we can't include ggml.h
#define HTP_ROPE_TYPE_NORMAL 0
#define HTP_ROPE_TYPE_NEOX 2

#define htp_rope_preamble \
const uint32_t ne00 = src0->ne[0]; \
const uint32_t ne01 = src0->ne[1]; \
Expand Down Expand Up @@ -146,6 +150,57 @@ static void init_rope_ctx(struct rope_th_ctx * rope_ctx, struct htp_ops_context
rope_ctx->ext_factor, rope_ctx->theta_scale, rope_ctx->attn_factor);
}

// NeoX-style RoPE rotation implemented with HVX vector intrinsics.
//
// Element k of the first half of the row is paired with element k of the
// second half (offset num_elems / 2), and each pair is rotated by the
// (cos, sin) values stored interleaved in theta_cache.
//
// Scalar reference for the vector loop below:
//
//   for (int i = 0; i < num_elems; i += 2) {
//       const float cos_theta = theta_cache[i + 0];
//       const float sin_theta = theta_cache[i + 1];
//
//       const float x0 = src0[0];
//       const float x1 = src0[num_elems / 2];
//
//       dst[0]             = x0 * cos_theta - x1 * sin_theta;
//       dst[num_elems / 2] = x0 * sin_theta + x1 * cos_theta;
//
//       src0 += 1;
//       dst  += 1;
//   }
//
// Parameters:
//   src0        - input row
//   dst         - output row
//   num_elems   - number of floats to rotate (both halves together)
//   theta_cache - interleaved cos/sin pairs, two floats per rotated pair
//
// Only (num_elems >> 6) full iterations are executed; elements left over when
// num_elems is not a multiple of 64 are not touched by this function.
// NOTE(review): the vector loads/stores are plain HVX_Vector dereferences, so
// src0, dst, theta_cache and the half-row offset are presumably expected to be
// VLEN-aligned -- confirm against the caller.
static void hvx_calc_rope_neox_f32(const float * restrict src0,
                                   float * restrict dst,
                                   const int num_elems,
                                   const float * restrict theta_cache) {
    // Byte-wise cursors so that offsets can be expressed in bytes (VLEN units).
    const uint8_t * restrict in_ptr  = (const uint8_t *) src0;
    const uint8_t * restrict cs_ptr  = (const uint8_t *) theta_cache;
    uint8_t * restrict       out_ptr = (uint8_t *) dst;

    // Distance in bytes from an element in the first half to its partner in the second half.
    const int half_bytes = (sizeof(float) * (num_elems / 2));

    // Shift by 6 because every iteration handles two vectors at once
    // (one from each half of the row).
    const int n_iters = num_elems >> 6;

    for (int remaining = n_iters; remaining > 0; remaining--) {
        // One vector from the first half and the matching vector from the second half.
        HVX_Vector x_first  = *(const HVX_Vector *) in_ptr;
        HVX_Vector x_second = *(const HVX_Vector *) (in_ptr + half_bytes);

        // Two consecutive vectors of interleaved cos/sin values.
        HVX_Vector theta_a = *(const HVX_Vector *) cs_ptr;
        HVX_Vector theta_b = *(const HVX_Vector *) (cs_ptr + VLEN);

        // De-interleave: low vector of the pair = cos_theta, high vector = sin_theta.
        HVX_VectorPair cos_sin = Q6_W_vdeal_VVR(theta_b, theta_a, -4);
        HVX_Vector     v_cos   = Q6_V_lo_W(cos_sin);
        HVX_Vector     v_sin   = Q6_V_hi_W(cos_sin);

        // The four partial products of the 2D rotation.
        HVX_Vector first_cos  = Q6_Vqf32_vmpy_VsfVsf(x_first, v_cos);
        HVX_Vector first_sin  = Q6_Vqf32_vmpy_VsfVsf(x_first, v_sin);
        HVX_Vector second_cos = Q6_Vqf32_vmpy_VsfVsf(x_second, v_cos);
        HVX_Vector second_sin = Q6_Vqf32_vmpy_VsfVsf(x_second, v_sin);

        // out_first  = x0 * cos - x1 * sin
        // out_second = x0 * sin + x1 * cos
        HVX_Vector out_first  = Q6_Vqf32_vsub_Vqf32Vqf32(first_cos, second_sin);
        HVX_Vector out_second = Q6_Vqf32_vadd_Vqf32Vqf32(first_sin, second_cos);

        // Convert the results back with Q6_Vsf_equals_Vqf32 and store both halves.
        *(HVX_Vector *) out_ptr                = Q6_Vsf_equals_Vqf32(out_first);
        *(HVX_Vector *) (out_ptr + half_bytes) = Q6_Vsf_equals_Vqf32(out_second);

        // Advance one vector in the data rows, two vectors in the cos/sin cache.
        in_ptr  += VLEN;
        cs_ptr  += 2 * VLEN;
        out_ptr += VLEN;
    }
}

static void hvx_calc_rope_f32(const float * restrict src0,
float * restrict dst,
const int num_elems,
Expand Down Expand Up @@ -212,6 +267,9 @@ static void rope_hex_f32(struct rope_th_ctx * rope_ctx,
const struct htp_tensor * src2 = &octx->src2;
struct htp_tensor * dst = &octx->dst;

const int32_t mode = rope_ctx->mode;
const bool is_neox = mode & HTP_ROPE_TYPE_NEOX;

htp_rope_preamble;

const int32_t * pos = (const int32_t *) src1->data;
Expand Down Expand Up @@ -247,20 +305,35 @@ static void rope_hex_f32(struct rope_th_ctx * rope_ctx,
float * dst_data_loc = dst_data;

if (1 == opt_path) {
hvx_calc_rope_f32(src_loc, dst_data_loc, rope_ctx->n_dims, wp0);
if (is_neox) {
hvx_calc_rope_neox_f32(src_loc, dst_data_loc, rope_ctx->n_dims, wp0);
} else {
hvx_calc_rope_f32(src_loc, dst_data_loc, rope_ctx->n_dims, wp0);
}
} else {
for (uint32_t i0 = 0; i0 < rope_ctx->n_dims; i0 += 2) {
const float cos_theta = wp0[i0 + 0];
const float sin_theta = wp0[i0 + 1];

const float x0 = src_loc[0];
const float x1 = src_loc[1];
if (is_neox) {
const float x0 = src_loc[0];
const float x1 = src_loc[rope_ctx->n_dims/2];

dst_data_loc[0] = x0 * cos_theta - x1 * sin_theta;
dst_data_loc[rope_ctx->n_dims/2] = x0 * sin_theta + x1 * cos_theta;

src_loc += 1;
dst_data_loc += 1;
} else {
const float x0 = src_loc[0];
const float x1 = src_loc[1];

dst_data_loc[0] = x0 * cos_theta - x1 * sin_theta;
dst_data_loc[1] = x0 * sin_theta + x1 * cos_theta;
dst_data_loc[0] = x0 * cos_theta - x1 * sin_theta;
dst_data_loc[1] = x0 * sin_theta + x1 * cos_theta;

src_loc += 2;
dst_data_loc += 2;
src_loc += 2;
dst_data_loc += 2;
}
}
}

Expand Down
Loading