Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 28 additions & 0 deletions demo/test.c
Original file line number Diff line number Diff line change
Expand Up @@ -1846,6 +1846,33 @@ static mp_err s_fill_with_ones(mp_int *a, int size)
return err;
}

/*
 * Cross-check s_mp_sqr() against s_mp_mul() with both factors equal.
 *
 * For every size i in [1, MP_MAX_COMBA + 9] two inputs are squared: one
 * produced by s_fill_with_ones() and one produced by mp_rand(). Each
 * square is compared with the product s_mp_mul(a, a) via mp_cmp().
 *
 * Returns EXIT_SUCCESS when the loop runs to completion, EXIT_FAILURE
 * otherwise (DO/EXPECT are expected to jump to LBL_ERR on failure and
 * DOR to bail out before anything needs clearing -- macro definitions
 * are not part of this function).
 */
static int test_s_mp_sqr(void)
{
   int i = 1;
   int status = EXIT_FAILURE;
   mp_int a, b, c;

   DOR(mp_init_multi(&a, &b, &c, NULL));

   /* s_mp_mul() branches to s_mp_mul_comba when that is available, so
      keep going for another 10 sizes past MP_MAX_COMBA just in case. */
   while (i < MP_MAX_COMBA + 10) {
      /* Input built by s_fill_with_ones() */
      DO(s_fill_with_ones(&a, i));
      DO(s_mp_sqr(&a, &b));
      DO(s_mp_mul(&a, &a, &c, 2*i + 1));
      EXPECT(mp_cmp(&b, &c) == MP_EQ);

      /* Input built by mp_rand() */
      DO(mp_rand(&a, i));
      DO(s_mp_sqr(&a, &b));
      DO(s_mp_mul(&a, &a, &c, 2*i + 1));
      EXPECT(mp_cmp(&b, &c) == MP_EQ);

      i++;
   }

   /* Only reached when no check above bailed out. */
   status = EXIT_SUCCESS;

LBL_ERR:
   /* Shared cleanup for the success and the failure path. */
   mp_clear_multi(&a, &b, &c, NULL);
   return status;
}

static int test_s_mp_sqr_comba(void)
{
mp_int a, r1, r2;
Expand Down Expand Up @@ -2373,6 +2400,7 @@ static int unit_tests(int argc, char **argv)
T1(mp_xor, MP_XOR),
T3(s_mp_div_recursive, ONLY_PUBLIC_API, S_MP_DIV_RECURSIVE, S_MP_DIV_SCHOOL),
T3(s_mp_div_small, ONLY_PUBLIC_API, S_MP_DIV_SMALL, S_MP_DIV_SCHOOL),
T2(s_mp_sqr, ONLY_PUBLIC_API, S_MP_SQR),
/* s_mp_mul_comba not (yet) testable because s_mp_mul branches to s_mp_mul_comba automatically */
T2(s_mp_sqr_comba, ONLY_PUBLIC_API, S_MP_SQR_COMBA),
T2(s_mp_mul_balance, ONLY_PUBLIC_API, S_MP_MUL_BALANCE),
Expand Down
23 changes: 18 additions & 5 deletions s_mp_sqr.c
Original file line number Diff line number Diff line change
Expand Up @@ -38,9 +38,10 @@ mp_err s_mp_sqr(const mp_int *a, mp_int *b)
r = (mp_word)a->dp[ix] * (mp_word)a->dp[iy];

/* now calculate the double precision result, note we use
* addition instead of *2 since it's easier to optimize
* addition instead of *2 since it's easier to optimize.
*/
r = (mp_word)t.dp[ix + iy] + r + r + (mp_word)u;
/* Some architectures and/or compilers seem to prefer a bit-shift nowadays */
r = (mp_word)t.dp[ix + iy] + (r<<1) + (mp_word)u;

/* store lower part */
t.dp[ix + iy] = (mp_digit)(r & (mp_word)MP_MASK);
Expand All @@ -50,9 +51,21 @@ mp_err s_mp_sqr(const mp_int *a, mp_int *b)
}
/* propagate upwards */
while (u != 0uL) {
r = (mp_word)t.dp[ix + iy] + (mp_word)u;
t.dp[ix + iy] = (mp_digit)(r & (mp_word)MP_MASK);
u = (mp_digit)(r >> (mp_word)MP_DIGIT_BIT);
mp_digit tmp;
/*
"u" can get bigger than MP_DIGIT_MAX and would need a bigger type
for the sum (mp_word). That is costly if mp_word is not a native
integer but a bigint from the compiler library. We do a manual
multiword addition instead.
*/
/* t.dp[ix + iy] has been masked off by MP_MASK and is hence of the correct size
and we can just add the lower part of "u". Carry is guaranteed to fit into
the type used for mp_digit, too, so we can extract it later. */
tmp = t.dp[ix + iy] + (u & MP_MASK);
/* t.dp[ix + iy] is set to the result minus the carry, carry is still in "tmp" */
t.dp[ix + iy] = tmp & MP_MASK;
/* Add high part of "u" and the carry from "tmp" to get the next "u" */
u = (u >> MP_DIGIT_BIT) + (tmp >> MP_DIGIT_BIT);
++iy;
}
}
Expand Down