diff --git a/bn_mp_n_root.c b/bn_mp_n_root.c index 3f959f1d4..b25291874 100644 --- a/bn_mp_n_root.c +++ b/bn_mp_n_root.c @@ -3,7 +3,8 @@ /* LibTomMath, multiple-precision integer library -- Tom St Denis */ /* SPDX-License-Identifier: Unlicense */ -/* find the n'th root of an integer +/* + * Find the n'th root of an integer. * * Result found such that (c)**b <= a and (c+1)**b > a * @@ -12,11 +13,13 @@ * which will find the root in log(N) time where * each step involves a fair bit. */ + +#ifdef LTM_USE_SMALLER_NTH_ROOT mp_err mp_n_root(const mp_int *a, mp_digit b, mp_int *c) { mp_int t1, t2, t3, a_; mp_ord cmp; - int ilog2; + int ilog2; mp_err err; /* input must be positive if b is even */ @@ -75,6 +78,7 @@ mp_err mp_n_root(const mp_int *a, mp_digit b, mp_int *c) if ((err = mp_2expt(&t2,ilog2)) != MP_OKAY) { goto LBL_ERR; } + do { /* t1 = t2 */ if ((err = mp_copy(&t2, &t1)) != MP_OKAY) { @@ -167,4 +171,361 @@ mp_err mp_n_root(const mp_int *a, mp_digit b, mp_int *c) return err; } +#else /* LTM_USE_SMALLER_NTH_ROOT */ + +/* + Returns floor(log2(value)) by counting right shifts; 0 and 1 both give 0. + With GCC or a compatible compiler __builtin_clz(), or the x86 instruction + BSR, could be used instead, but let me assure you: the function + s_floor_log2() will not be the bottleneck here. +*/ +static int s_floor_log2(mp_digit value) +{ + int r = 0; + while ((value >>= 1) != 0) { + r++; + } + return r; +} +/* + Extra version for int needed because mp_digit is + a) unsigned + b) can be any size between 8 and 64 bits + Two versions with the same code, just different + input types, seem silly but all the ways known to + me that can work around that are either complicated + or dependent on compiler specifics or are ugly or + all of the above. 
+ An example for "all of the above": + #define FLOOR_ILOG2(T) \ + int s_floor_ilog2_##T(T value) { \ + int r = 0; \ + while ((value >>= 1) != 0) { \ + r++; \ + } \ + return r; \ + } + FLOOR_ILOG2(int) + FLOOR_ILOG2(mp_digit) +*/ + +/* + Here, "value" will not be negative, so it is, in theory, + possible to use the function above by casting "int" to + "mp_digit" but "mp_digit" can be smaller than "int", much + smaller. + */ +static int s_floor_ilog2(int value) +{ + int r = 0; + while ((value >>= 1) != 0) { + r++; + } + return r; +} +/* + The cut-off between Newton's method and bisection is at + about ln(x)/(ln ln (x)) * 1.2. + Floating point methods are not available, so a rough + approximation must do. + By taking the bitcount of the number as floor(log_2(x)) + and, together with ln(x) ~ floor(log2(x)) * log(2) + implemented as 69/100 * floor(log2(x)), we can get + a sufficiently good approximation. + This snippet assumes "int" is at least 16 bits wide. + TODO: check if it is possible to use mp_word instead + which is guaranteed to be at least 16 bits wide +*/ +#include <limits.h> +static int s_recurrence_bisection_cutoff(int value) +{ + int lnx, lnlnx; + + /* + such small values should have been handled by an nth-root + implementation with native integers + */ + if (value < 8) { + return 1; + } + + /* ln(x) ~ floor(log2(x)) * log(2) */ + if (value > ((INT_MAX / 69))) { + /* + if "value" is so big that a multiplication + with 69 overflows we can safely spend + two digits of accuracy for a better sleep. 
+ */ + lnx = (value / 100) * 69; + } else { + lnx = ((69 * value) / 100); + } + /* ln ln x */ + lnlnx = s_floor_ilog2(lnx); + /* cannot overflow anymore here */ + lnlnx = ((69 * lnlnx) / 100); + + lnx = lnx / lnlnx; + /* floor(ln(x)/(ln ln (x))) < floor(fln2(x)/(fln2 fln2 (x))) + 1 for x >= 8 */ + lnx += 1; + /* apply twiddle factor */ + /* cannot overflow */ + lnx = ((12 * lnx) / 10); + return lnx; +} + +/* + Compute log_2(b) bits of a^(1/b) or all of them with a binary search method +*/ +static mp_err s_bisection(mp_int *a, mp_digit b, mp_int *c, int cutoff, int rootsize) +{ + mp_int low, high, mid, midpow; + mp_err err; + int comp, i = 0; + + /* force at least one run */ + if (cutoff == 0) { + cutoff = 1; + } + + if ((err = mp_init_multi(&low, &high, &mid, &midpow, NULL)) != MP_OKAY) { + return err; + } + if ((err = mp_2expt(&high, rootsize)) != MP_OKAY) { + goto LTM_ERR; + } + if ((err = mp_2expt(&low, rootsize - 2)) != MP_OKAY) { + goto LTM_ERR; + } + while (mp_cmp(&low, &high) == MP_LT) { + if (i++ == cutoff) { + mp_exch(&high, c); + goto LTM_ERR; + } + if ((err = mp_add(&low, &high, &mid)) != MP_OKAY) { + goto LTM_ERR; + } + if ((err = mp_div_2(&mid, &mid)) != MP_OKAY) { + goto LTM_ERR; + } + if ((err = mp_expt_d(&mid, b, &midpow)) != MP_OKAY) { + goto LTM_ERR; + } + comp = mp_cmp(&midpow, a); + if (mp_cmp(&low, &mid) == MP_LT && comp == MP_LT) { + mp_exch(&low, &mid); + } else if (mp_cmp(&high, &mid) == MP_GT && comp == MP_GT) { + mp_exch(&high, &mid); + } else { + mp_exch(&mid, c); + goto LTM_ERR; + } + } + if ((err = mp_add_d(&mid, 1, &mid)) != MP_OKAY) { + goto LTM_ERR; + } + mp_exch(&mid, c); +LTM_ERR: + mp_clear_multi(&low, &high, &mid, &midpow, NULL); + return err; +} + +static mp_err s_newton(mp_int *a, mp_digit b, mp_int *c, int cutoff, int rootsize) +{ + mp_int xi, t1, t2; + mp_err err = MP_OKAY; + + if ((err = mp_init_multi(&xi, &t1, &t2, NULL)) != MP_OKAY) { + return err; + } + if ((err = s_bisection(a, b, &t1, cutoff, rootsize)) != MP_OKAY) { 
+ goto LTM_ERR; + } + if ((err = mp_add_d(&t1, 1, &xi)) != MP_OKAY) { + goto LTM_ERR; + } + while (mp_cmp(&t1, &xi) == MP_LT) { + if ((rootsize--) == 0) { + break; + } + if ((err = mp_copy(&t1, &xi)) != MP_OKAY) { + goto LTM_ERR; + } + if ((err = mp_mul_d(&xi, b - 1, &t2)) != MP_OKAY) { + goto LTM_ERR; + } + if ((err = mp_expt_d(&xi, b - 1, &t1)) != MP_OKAY) { + goto LTM_ERR; + } + if ((err = mp_div(a, &t1, &t1, NULL)) != MP_OKAY) { + goto LTM_ERR; + } + if ((err = mp_add(&t1, &t2, &t1)) != MP_OKAY) { + goto LTM_ERR; + } + if ((err = mp_div_d(&t1, b, &t1, NULL)) != MP_OKAY) { + goto LTM_ERR; + } + } + mp_exch(&xi, c); +LTM_ERR: + mp_clear_multi(&xi, &t1, &t2, NULL); + return err; +} + +mp_err mp_n_root(const mp_int *a, mp_digit b, mp_int *c) +{ + mp_int A; + mp_int t1; + int cmp; + mp_err err = MP_OKAY; + int ilog2, rootsize, cutoff, even_faster; + mp_sign neg; + + /* + * Checks, balances and shortcuts + * + * if b = 0 -> MP_VAL division by zero + * if b even and a neg. -> MP_VAL non-real result + * if a = 0 and b > 0 -> 0 + * if a = 0 and b < 0 -> n/a b is unsigned + * if |a| = 1 -> a + * if a > 0 and b < 0 -> n/a b is unsigned + * if b > log_2(a) -> 1 (with the sign of a) + */ + + if (b == 0) { + return MP_VAL; + } + + if (b == 1) { + if ((err = mp_copy(a, c)) != MP_OKAY) { + return err; + } + return MP_OKAY; + } + if (b == 2) { + return mp_sqrt(a, c); + } + + /* remember the sign of the input before c is written; c may be the same object as a */ + neg = a->sign; + /* an even root of a negative number is not real; reject it before any shortcut */ + if ((neg == MP_NEG) && ((b & 1) == 0)) { + return MP_VAL; + } + + /* zero has to be caught before the shortcut for huge b, which returns 1 */ + if (mp_iszero(a)) { + mp_zero(c); + return MP_OKAY; + } + + /* TODO: check if an exception for unity is sensible; |a| = 1 is its own root (b is odd here if a < 0) */ + if ((a->used == 1) && (a->dp[0] == 1)) { + return mp_copy(a, c); + } +#if ( !(defined MP_8BIT) && !(defined MP_16BIT) ) + /* The type "mp_digit" can be bigger than int */ + if (sizeof(mp_digit) > sizeof(int) && b > INT_MAX) { + /* In that case "b" is bigger than log_2(x), hence floor(x^(1/b)) = 1 */ + mp_set(c, 1); + c->sign = neg; + return MP_OKAY; + } +#endif + 
+#ifdef LTM_USE_SMALL_NTH_ROOT + if (a->used == 1) { + ilog2 = s_small_nthroot(a->dp[0], b); + mp_set(c,ilog2); + return MP_OKAY; + } +#endif + if ((err = mp_init(&A)) != MP_OKAY) { + return err; + } + if ((err = mp_copy(a, &A)) != MP_OKAY) { + goto LTM_ERR_2; + } + A.sign = MP_ZPOS; + + ilog2 = mp_count_bits(a); + + if (ilog2 < (int)(b)) { + mp_set(c, 1uL); + c->sign = neg; + goto LTM_ERR_2; + } + + rootsize = (ilog2/(int)(b)) + 1; + cutoff = s_floor_log2(b); + + even_faster = s_recurrence_bisection_cutoff(ilog2); + if (b < (mp_digit)even_faster) { + if ((err = s_newton(&A, b, c, cutoff, rootsize)) != MP_OKAY) { + goto LTM_ERR_2; + } + } else { + if ((err = s_bisection(&A, b, c, -1, rootsize)) != MP_OKAY) { + goto LTM_ERR_2; + } + } + + if ((err = mp_init(&t1)) != MP_OKAY) { + goto LTM_ERR_2; + } + if ((err = mp_expt_d(c, b, &t1)) != MP_OKAY) { + goto LTM_ERR_1; + } + cmp = mp_cmp(&t1, &A); + if (cmp == MP_GT) { + if ((err = mp_sub_d(c, 1u, c)) != MP_OKAY) { + goto LTM_ERR_1; + } + for (;;) { + if ((err = mp_expt_d(c, b, &t1)) != MP_OKAY) { + goto LTM_ERR_1; + } + cmp = mp_cmp(&t1, &A); + if (cmp != MP_GT) { + break; + } + if ((err = mp_sub_d(c, 1u, c)) != MP_OKAY) { + goto LTM_ERR_1; + } + } + } else if (cmp == MP_LT) { + if ((err = mp_add_d(c, 1u, c)) != MP_OKAY) { + goto LTM_ERR_1; + } + for (;;) { + if ((err = mp_expt_d(c, b, &t1)) != MP_OKAY) { + goto LTM_ERR_1; + } + cmp = mp_cmp(&t1, &A); + if (cmp != MP_LT) { + break; + } + if ((err = mp_add_d(c, 1u, c)) != MP_OKAY) { + goto LTM_ERR_1; + } + } + /* Does overshoot in contrast to the other branch above */ + if (cmp != MP_EQ) { + if ((err = mp_sub_d(c, 1u, c)) != MP_OKAY) { + goto LTM_ERR_1; + } + } + } + + c->sign = neg; /* reached on success only: every error path jumps past this line */ +LTM_ERR_1: + mp_clear(&t1); +LTM_ERR_2: + mp_clear(&A); + return err; +} +#endif + #endif diff --git a/demo/test.c b/demo/test.c index 8627bf269..f53885ef3 100644 --- a/demo/test.c +++ b/demo/test.c @@ -803,7 +803,7 @@ static int test_mp_sqrt(void) } mp_n_root(&a, 2uL, 
&c); if (mp_cmp_mag(&b, &c) != MP_EQ) { - printf("mp_sqrt() bad result!\n"); + printf("mp_sqrt() or mp_n_root() bad result!\n"); goto LBL_ERR; } } @@ -1876,7 +1876,6 @@ static int test_mp_n_root(void) "16", "15", "15" } }; - if ((e = mp_init_multi(&a, &c, &r, NULL)) != MP_OKAY) { return EXIT_FAILURE; } diff --git a/doc/bn.tex b/doc/bn.tex index 36daf8ce4..be34bdb04 100644 --- a/doc/bn.tex +++ b/doc/bn.tex @@ -1768,11 +1768,12 @@ \section{Root Finding} \begin{alltt} int mp_n_root (mp_int * a, mp_digit b, mp_int * c) \end{alltt} -This computes $c = a^{1/b}$ such that $c^b \le a$ and $(c+1)^b > a$. Will return a positive root only for even roots and return -a root with the sign of the input for odd roots. For example, performing $4^{1/2}$ will return $2$ whereas $(-8)^{1/3}$ -will return $-2$. +This computes $c = a^{1/b}$ such that $c^b \le a$ and $(c+1)^b > a$. Will return a positive root only for even roots and return a root with the sign of the input for odd roots. For example, performing $4^{1/2}$ will return $2$ whereas $(-8)^{1/3}$ will return $-2$. This algorithm uses the ``Newton Approximation'' method and will converge on the correct root fairly quickly. +An alternative algorithm can be chosen by defining the macro \texttt{LTM\_USE\_SMALLER\_NTH\_ROOT}. The smaller version is much slower\footnote{The speed is still usable.} but compiles to code that is about 3--4 times smaller. + +Neither function runs in constant time; that is, neither should be used in programs that need to be cryptographically secure. The square root $c = a^{1/2}$ (with the same conditions $c^2 \le a$ and $(c+1)^2 > a$) is implemented with a faster algorithm. 
diff --git a/tommath.h b/tommath.h index 9334efff4..8bba5b507 100644 --- a/tommath.h +++ b/tommath.h @@ -504,6 +504,7 @@ mp_err mp_lcm(const mp_int *a, const mp_int *b, mp_int *c) MP_WUR; * returns error if a < 0 and b is even */ mp_err mp_n_root(const mp_int *a, mp_digit b, mp_int *c) MP_WUR; + MP_DEPRECATED(mp_n_root_ex) mp_err mp_n_root_ex(const mp_int *a, mp_digit b, mp_int *c, int fast) MP_WUR; /* special sqrt algo */ diff --git a/tommath_class.h b/tommath_class.h index 90c27e8c6..d0e611e97 100644 --- a/tommath_class.h +++ b/tommath_class.h @@ -629,6 +629,14 @@ # define BN_MP_SUB_D_C # define BN_MP_EXCH_C # define BN_MP_CLEAR_MULTI_C +# define BN_MP_ADD_C +# define BN_MP_DIV_2_C +# define BN_MP_DIV_D_C +# define BN_MP_SQRT_C +# define BN_MP_ISZERO_C +# define BN_MP_ZERO_C +# define BN_MP_INIT_C +# define BN_MP_CLEAR_C #endif #if defined(BN_MP_NEG_C)