diff --git a/demo/test.c b/demo/test.c index 0ebcfd07..3ea16bf6 100644 --- a/demo/test.c +++ b/demo/test.c @@ -1828,6 +1828,51 @@ static int test_mp_root_n(void) return EXIT_FAILURE; } +/* Less error-prone than -1 + 2^n with mp_2expt */ +static mp_err s_fill_with_ones(mp_int *a, int size) +{ + int i; + mp_err err = MP_OKAY; + + mp_zero(a); + + if ((err = mp_grow(a, size)) != MP_OKAY) goto LTM_ERR; + for (i = 0; i < size; i++) { + a->dp[i] = (mp_digit)MP_MASK; + a->used++; + } + +LTM_ERR: + return err; +} + +static int test_s_mp_sqr_comba(void) +{ + mp_int a, r1, r2; + int i, j; + + DOR(mp_init_multi(&a, &r1, &r2, NULL)); + + for (i = 1; i <= MP_MAX_COMBA; i++) { + DO(s_fill_with_ones(&a, i)); + DO(s_mp_sqr_comba(&a, &r1)); + DO(s_mp_sqr(&a, &r2)); + EXPECT(mp_cmp(&r1, &r2) == MP_EQ); + for (j = 0; j < 20; j++) { + DO(mp_rand(&a, i)); + DO(s_mp_sqr_comba(&a, &r1)); + DO(s_mp_sqr(&a, &r2)); + EXPECT(mp_cmp(&r1, &r2) == MP_EQ); + } + } + + mp_clear_multi(&a, &r1, &r2, NULL); + return EXIT_SUCCESS; +LBL_ERR: + mp_clear_multi(&a, &r1, &r2, NULL); + return EXIT_FAILURE; +} + static int test_s_mp_mul_balance(void) { mp_int a, b, c; @@ -2328,6 +2373,8 @@ static int unit_tests(int argc, char **argv) T1(mp_xor, MP_XOR), T3(s_mp_div_recursive, ONLY_PUBLIC_API, S_MP_DIV_RECURSIVE, S_MP_DIV_SCHOOL), T3(s_mp_div_small, ONLY_PUBLIC_API, S_MP_DIV_SMALL, S_MP_DIV_SCHOOL), + /* s_mp_mul_comba not (yet) testable because s_mp_mul branches to s_mp_mul_comba automatically */ + T2(s_mp_sqr_comba, ONLY_PUBLIC_API, S_MP_SQR_COMBA), T2(s_mp_mul_balance, ONLY_PUBLIC_API, S_MP_MUL_BALANCE), T2(s_mp_mul_karatsuba, ONLY_PUBLIC_API, S_MP_MUL_KARATSUBA), T2(s_mp_sqr_karatsuba, ONLY_PUBLIC_API, S_MP_SQR_KARATSUBA), diff --git a/mp_mul.c b/mp_mul.c index d35fa8ef..81807406 100644 --- a/mp_mul.c +++ b/mp_mul.c @@ -23,7 +23,7 @@ mp_err mp_mul(const mp_int *a, const mp_int *b, mp_int *c) } else if ((a == b) && MP_HAS(S_MP_SQR_COMBA) && /* can we use the fast comba multiplier? */ (((a->used * 2) + 1) < MP_WARRAY) && - (a->used < (MP_MAX_COMBA / 2))) { + (a->used <= MP_MAX_COMBA)) { err = s_mp_sqr_comba(a, c); } else if ((a == b) && MP_HAS(S_MP_SQR)) {