@@ -7507,19 +7507,20 @@ static void ggml_compute_forward_rope_f32(
75077507 // row index used to determine which thread to use
75087508 int ir = 0 ;
75097509
7510+ const float theta_scale = powf (10000.0 , ((float )-2 )/n_dims );
7511+
75107512 for (int64_t i3 = 0 ; i3 < ne3 ; i3 ++ ) {
75117513 for (int64_t i2 = (mode == 0 ? 0 : n_past ); i2 < ne2 ; i2 ++ ) {
75127514 const int p = (mode == 0 ? n_past + i2 : i2 );
75137515 for (int64_t i1 = 0 ; i1 < ne1 ; i1 ++ ) {
75147516 if (ir ++ < ir0 ) continue ;
75157517 if (ir > ir1 ) break ;
7516-
7518+ float theta = ( float ) p ;
75177519 for (int i0 = 0 ; i0 < n_dims ; i0 += 2 ) {
7518- const float theta = powf (10000.0 , ((float )- i0 )/n_dims );
7519-
7520- const float cos_theta = cosf (p * theta );
7521- const float sin_theta = sinf (p * theta );
7520+ const float cos_theta = cosf (theta );
7521+ const float sin_theta = sinf (theta );
75227522
7523+ theta *= theta_scale ;
75237524 const float * const src = (float * )((char * ) src0 -> data + i3 * nb3 + i2 * nb2 + i1 * nb1 + i0 * nb0 );
75247525 float * dst_data = (float * )((char * ) dst -> data + i3 * nb3 + i2 * nb2 + i1 * nb1 + i0 * nb0 );
75257526
@@ -7580,19 +7581,20 @@ static void ggml_compute_forward_rope_f16(
75807581 // row index used to determine which thread to use
75817582 int ir = 0 ;
75827583
7584+ const float theta_scale = powf (10000.0 , ((float )-2 )/n_dims );
7585+
75837586 for (int64_t i3 = 0 ; i3 < ne3 ; i3 ++ ) {
75847587 for (int64_t i2 = (mode == 0 ? 0 : n_past ); i2 < ne2 ; i2 ++ ) {
75857588 const int p = (mode == 0 ? n_past + i2 : i2 );
75867589 for (int64_t i1 = 0 ; i1 < ne1 ; i1 ++ ) {
75877590 if (ir ++ < ir0 ) continue ;
75887591 if (ir > ir1 ) break ;
7589-
7592+ float theta = ( float ) p ;
75907593 for (int i0 = 0 ; i0 < n_dims ; i0 += 2 ) {
7591- const float theta = powf (10000.0 , ((float )- i0 )/n_dims );
7592-
7593- const float cos_theta = cosf (p * theta );
7594- const float sin_theta = sinf (p * theta );
7594+ const float cos_theta = cosf (theta );
7595+ const float sin_theta = sinf (theta );
75957596
7597+ theta *= theta_scale ;
75967598 const ggml_fp16_t * const src = (ggml_fp16_t * )((char * ) src0 -> data + i3 * nb3 + i2 * nb2 + i1 * nb1 + i0 * nb0 );
75977599 ggml_fp16_t * dst_data = (ggml_fp16_t * )((char * ) dst -> data + i3 * nb3 + i2 * nb2 + i1 * nb1 + i0 * nb0 );
75987600
0 commit comments