@@ -2287,29 +2287,31 @@ static void aclnn_rope_cache_init(ggml_backend_cann_context & ctx,
22872287 // Step1.1: prepare theta_scale exponent. If this exponent is updated, theta_scale_tensor should be updated too.
22882288 acl_tensor_ptr acl_theta_scale_tensor;
22892289 bool theta_scale_updated = false ;
2290- if (ctx.rope_cache .theta_scale_length != theta_scale_length || ctx.rope_cache .indep_sects != indep_sects) {
2290+ if (ctx.rope_cache .theta_scale_length != theta_scale_length || ctx.rope_cache .theta_scale != theta_scale || ctx. rope_cache . indep_sects != indep_sects) {
22912291 theta_scale_updated = true ;
22922292 if (ctx.rope_cache .theta_scale_exp_host != nullptr ) {
22932293 free (ctx.rope_cache .theta_scale_exp_host );
22942294 }
22952295 ctx.rope_cache .theta_scale_exp_host = (float *) malloc (theta_scale_length * sizeof (float ));
22962296
22972297 if (!indep_sects) {
2298- for (int i = 0 ; i < theta_scale_length; i++) {
2299- ctx.rope_cache .theta_scale_exp_host [i] = i;
2298+ ctx.rope_cache .theta_scale_exp_host [0 ] = 1 ;
2299+ for (int i = 1 ; i < theta_scale_length; i++) {
2300+ ctx.rope_cache .theta_scale_exp_host [i] = ctx.rope_cache .theta_scale_exp_host [i-1 ] * theta_scale;
23002301 }
23012302 } else {
23022303 int sect_dims = sections[0 ] + sections[1 ] + sections[2 ] + sections[3 ];
23032304 int sec_w = sections[1 ] + sections[0 ];
23042305 int sec_e = sections[2 ] + sec_w;
2305- int exp = 0 ;
2306- for (int i = 0 ; i < theta_scale_length; i++) {
2306+
2307+ ctx.rope_cache .theta_scale_exp_host [0 ] = 1 ;
2308+ for (int i = 1 ; i < theta_scale_length; i++) {
23072309 int sector = i % sect_dims;
23082310 if (sector == 0 || sector == sections[0 ] || sector == sec_w || sector == sec_e) {
2311+ ctx.rope_cache .theta_scale_exp_host [i] = 1 ;
2312+ continue ;
23092313 }
2310- exp = 0 ;
2311- ctx.rope_cache .theta_scale_exp_host [i] = exp;
2312- exp++;
2314+ ctx.rope_cache .theta_scale_exp_host [i] = ctx.rope_cache .theta_scale_exp_host [i-1 ] * theta_scale;
23132315 }
23142316 }
23152317
@@ -2329,15 +2331,7 @@ static void aclnn_rope_cache_init(ggml_backend_cann_context & ctx,
23292331
23302332 // Step1.2: prepare theta_scale_tensor. If either theta_scale or theta_scale_tensor's exponent is updated,
23312333 // theta_scale_tensor should be updated.
2332- if (ctx.rope_cache .theta_scale != theta_scale || theta_scale_updated) {
2333- theta_scale_updated = true ;
2334- acl_scalar_ptr acl_theta_scale = ggml_cann_create_scalar (&theta_scale, aclDataType::ACL_FLOAT);
2335- GGML_CANN_CALL_ACLNN_OP (ctx, PowScalarTensor, acl_theta_scale.get (), acl_theta_scale_tensor.get (),
2336- acl_theta_scale_tensor.get ());
2337- float res[64 ];
2338- ACL_CHECK (aclrtSynchronizeStream (ctx.stream ()));
2339- ACL_CHECK (aclrtMemcpy (res, 64 *4 , ctx.rope_cache .theta_scale_cache , 64 *4 , ACL_MEMCPY_DEVICE_TO_HOST));
2340- }
2334+
23412335
23422336 // Step1.3: prepare rope_yarn_ramp, if this part updated, should update theta_scale_tensor.
23432337 bool yarn_ramp_tensor_updated = false ;
@@ -2410,7 +2404,8 @@ static void aclnn_rope_cache_init(ggml_backend_cann_context & ctx,
24102404 if (mrope_used) {
24112405 if (ctx.rope_cache .sections [0 ] != sections[0 ] || ctx.rope_cache .sections [1 ] != sections[1 ] ||
24122406 ctx.rope_cache .sections [2 ] != sections[2 ] || ctx.rope_cache .sections [3 ] != sections[3 ] ||
2413- ctx.rope_cache .theta_scale_length != theta_scale_length) {
2407+ ctx.rope_cache .theta_scale_length != theta_scale_length ||
2408+ ctx.rope_cache .is_imrope != is_imrope) {
24142409 if (ctx.rope_cache .position_select_index_host != nullptr ) {
24152410 free (ctx.rope_cache .position_select_index_host );
24162411 }
@@ -2461,17 +2456,6 @@ static void aclnn_rope_cache_init(ggml_backend_cann_context & ctx,
24612456 sizeof (int ), theta_scale_ne, theta_scale_nb, 1 );
24622457 }
24632458
2464- ctx.rope_cache .indep_sects = indep_sects;
2465- ctx.rope_cache .theta_scale_length = theta_scale_length;
2466- ctx.rope_cache .freq_scale = freq_scale;
2467- ctx.rope_cache .theta_scale = theta_scale;
2468- ctx.rope_cache .ext_factor = ext_factor;
2469- ctx.rope_cache .sections [0 ] = sections[0 ];
2470- ctx.rope_cache .sections [1 ] = sections[1 ];
2471- ctx.rope_cache .sections [2 ] = sections[2 ];
2472- ctx.rope_cache .sections [3 ] = sections[3 ];
2473-
2474-
24752459 ggml_cann_pool_alloc freq_fac_res_allocator (ctx.pool ());
24762460 // Step2: divide by freq_factors
24772461 if (src2) {
@@ -2502,13 +2486,6 @@ static void aclnn_rope_cache_init(ggml_backend_cann_context & ctx,
25022486 ggml_cann_create_tensor (src1->data , ggml_cann_type_mapping (src1->type ), ggml_type_size (src1->type ), mrope_position_ne,
25032487 mrope_position_nb, 2 );
25042488
2505- int res[128 ];
2506- ACL_CHECK (aclrtSynchronizeStream (ctx.stream ()));
2507- ACL_CHECK (aclrtMemcpy (res, 8 *4 , src1->data , 8 *4 , ACL_MEMCPY_DEVICE_TO_HOST));
2508-
2509- ACL_CHECK (aclrtSynchronizeStream (ctx.stream ()));
2510- ACL_CHECK (aclrtMemcpy (res, 64 *4 , ctx.rope_cache .position_select_index , 64 *4 , ACL_MEMCPY_DEVICE_TO_HOST));
2511-
25122489 // selected position tensor's shape is a transpose of cache tensor.
25132490 int64_t selected_position_ne[] = {position_length, theta_scale_length};
25142491 size_t selected_position_nb[] = {sizeof (float ), position_length * sizeof (float )};
@@ -2518,10 +2495,6 @@ static void aclnn_rope_cache_init(ggml_backend_cann_context & ctx,
25182495 selected_position_nb, 2 );
25192496 GGML_CANN_CALL_ACLNN_OP (ctx, IndexSelect, mrope_position.get (), 0 , position_select_index_tensor.get (), acl_position_tensor.get ());
25202497
2521-
2522- ACL_CHECK (aclrtSynchronizeStream (ctx.stream ()));
2523- ACL_CHECK (aclrtMemcpy (res, 128 *4 , mrope_position_buffer, 128 *4 , ACL_MEMCPY_DEVICE_TO_HOST));
2524-
25252498 // transpose
25262499 int64_t transposed_ne[] = {position_length, 1 , theta_scale_length, 1 };
25272500 size_t transposed_nb[GGML_MAX_DIMS];
@@ -2551,10 +2524,6 @@ static void aclnn_rope_cache_init(ggml_backend_cann_context & ctx,
25512524 ggml_cann_create_tensor (theta_buffer, ACL_FLOAT, sizeof (float ), cache_ne, cache_nb, GGML_MAX_DIMS);
25522525 aclnn_mul (ctx, acl_position_tensor.get (), acl_theta_scale_tensor.get (), acl_theta_tensor.get ());
25532526
2554- float res[128 ];
2555- ACL_CHECK (aclrtSynchronizeStream (ctx.stream ()));
2556- ACL_CHECK (aclrtMemcpy (res, 128 *4 , theta_buffer, 128 *4 , ACL_MEMCPY_DEVICE_TO_HOST));
2557-
25582527 // Step4: calculate sin cos.
25592528 // init sin_repeat && cos_repeat, only to accelerate first layer on each device
25602529 if (position_length > ctx.rope_cache .position_length ) {
@@ -2621,9 +2590,16 @@ static void aclnn_rope_cache_init(ggml_backend_cann_context & ctx,
26212590
26222591 // Other layers use cache except first layer.
26232592 ctx.rope_cache .cached = true ;
2624- ctx.rope_cache .ext_factor = ext_factor;
2593+ ctx.rope_cache .indep_sects = indep_sects;
2594+ ctx.rope_cache .theta_scale_length = theta_scale_length;
2595+ ctx.rope_cache .freq_scale = freq_scale;
26252596 ctx.rope_cache .theta_scale = theta_scale;
2626- ctx.rope_cache .freq_scale = freq_scale;
2597+ ctx.rope_cache .ext_factor = ext_factor;
2598+ ctx.rope_cache .is_imrope = is_imrope;
2599+ ctx.rope_cache .sections [0 ] = sections[0 ];
2600+ ctx.rope_cache .sections [1 ] = sections[1 ];
2601+ ctx.rope_cache .sections [2 ] = sections[2 ];
2602+ ctx.rope_cache .sections [3 ] = sections[3 ];
26272603 ctx.rope_cache .attn_factor = attn_factor;
26282604 ctx.rope_cache .is_neox = is_neox;
26292605}
@@ -2677,11 +2653,15 @@ void ggml_cann_rope(ggml_backend_cann_context & ctx, ggml_tensor * dst) {
26772653 float corr_dims[2 ];
26782654 ggml_rope_yarn_corr_dims (n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims);
26792655
2680- const bool is_neox = mode & GGML_ROPE_TYPE_NEOX;
2656+ bool is_neox = mode & GGML_ROPE_TYPE_NEOX;
26812657 const bool is_imrope = mode == GGML_ROPE_TYPE_IMROPE; // qwen3vl apply interleaved mrope
26822658 const bool mrope_used = mode & GGML_ROPE_TYPE_MROPE; // ggml_rope_multi, note: also true for vision (24 & 8 == true) and for imrope
26832659 const bool is_vision = mode == GGML_ROPE_TYPE_VISION;
26842660
2661+ if (is_imrope || mrope_used) {
2662+ is_neox = true ;
2663+ }
2664+
26852665 // init ctx.rope_cos/rope_sin cache
26862666 aclnn_rope_cache_init (ctx, dst, corr_dims, ext_factor, theta_scale, freq_scale, attn_factor, is_neox, sections, mrope_used, is_imrope, is_vision);
26872667
0 commit comments