diff --git a/demo/test.c b/demo/test.c index 3ea16bf6..0a9bfc31 100644 --- a/demo/test.c +++ b/demo/test.c @@ -1846,6 +1846,33 @@ static mp_err s_fill_with_ones(mp_int *a, int size) return err; } +static int test_s_mp_sqr(void) +{ + mp_int a, b, c; + int i; + + DOR(mp_init_multi(&a, &b, &c, NULL)); + + /* s_mp_mul() has a hardcoded branch to s_mul_comba if s_mul_comba is available, + so test another 10 just in case. */ + for (i = 1; i < MP_MAX_COMBA + 10; i++) { + DO(s_fill_with_ones(&a, i)); + DO(s_mp_sqr(&a, &b)); + DO(s_mp_mul(&a, &a, &c, 2*i + 1)); + EXPECT(mp_cmp(&b, &c) == MP_EQ); + DO(mp_rand(&a, i)); + DO(s_mp_sqr(&a, &b)); + DO(s_mp_mul(&a, &a, &c, 2*i + 1)); + EXPECT(mp_cmp(&b, &c) == MP_EQ); + } + + mp_clear_multi(&a, &b, &c, NULL); + return EXIT_SUCCESS; +LBL_ERR: + mp_clear_multi(&a, &b, &c, NULL); + return EXIT_FAILURE; +} + static int test_s_mp_sqr_comba(void) { mp_int a, r1, r2; @@ -2373,6 +2400,7 @@ static int unit_tests(int argc, char **argv) T1(mp_xor, MP_XOR), T3(s_mp_div_recursive, ONLY_PUBLIC_API, S_MP_DIV_RECURSIVE, S_MP_DIV_SCHOOL), T3(s_mp_div_small, ONLY_PUBLIC_API, S_MP_DIV_SMALL, S_MP_DIV_SCHOOL), + T2(s_mp_sqr, ONLY_PUBLIC_API, S_MP_SQR), /* s_mp_mul_comba not (yet) testable because s_mp_mul branches to s_mp_mul_comba automatically */ T2(s_mp_sqr_comba, ONLY_PUBLIC_API, S_MP_SQR_COMBA), T2(s_mp_mul_balance, ONLY_PUBLIC_API, S_MP_MUL_BALANCE), diff --git a/s_mp_sqr.c b/s_mp_sqr.c index 4a203063..da9aa69c 100644 --- a/s_mp_sqr.c +++ b/s_mp_sqr.c @@ -38,9 +38,10 @@ mp_err s_mp_sqr(const mp_int *a, mp_int *b) r = (mp_word)a->dp[ix] * (mp_word)a->dp[iy]; /* now calculate the double precision result, note we use - * addition instead of *2 since it's easier to optimize + * addition instead of *2 since it's easier to optimize. */ - r = (mp_word)t.dp[ix + iy] + r + r + (mp_word)u; + /* Some architectures and/or compilers seem to prefer a bit-shift nowadays */ + r = (mp_word)t.dp[ix + iy] + (r<<1) + (mp_word)u; /* store lower part */ t.dp[ix + iy] = (mp_digit)(r & (mp_word)MP_MASK); @@ -50,9 +51,21 @@ mp_err s_mp_sqr(const mp_int *a, mp_int *b) } /* propagate upwards */ while (u != 0uL) { - r = (mp_word)t.dp[ix + iy] + (mp_word)u; - t.dp[ix + iy] = (mp_digit)(r & (mp_word)MP_MASK); - u = (mp_digit)(r >> (mp_word)MP_DIGIT_BIT); + mp_digit tmp; + /* + "u" can get bigger than MP_DIGIT_MAX and would need a bigger type + for the sum (mp_word). That is costly if mp_word is not a native + integer but a bigint from the compiler library. We do a manual + multiword addition instead. + */ + /* t.dp[ix + iy] has been masked off by MP_MASK and is hence of the correct size + and we can just add the lower part of "u". Carry is guaranteed to fit into + the type used for mp_digit, too, so we can extract it later. */ + tmp = t.dp[ix + iy] + (u & MP_MASK); + /* t.dp[ix + iy] is set to the result minus the carry, carry is still in "tmp" */ + t.dp[ix + iy] = tmp & MP_MASK; + /* Add high part of "u" and the carry from "tmp" to get the next "u" */ + u = (u >> MP_DIGIT_BIT) + (tmp >> MP_DIGIT_BIT); ++iy; } }