Skip to content

Commit

Permalink
libmbedtls: bignum: restore mbedtls_mpi_exp_mod() from v3.5.2
Browse files Browse the repository at this point in the history
The implementation of mbedtls_mpi_exp_mod() in Mbed TLS v3.6.0
introduces a large performance regression in "xtest 4011" on QEMUv7
(32-bit). One iteration of the test used to take 1.4 second on my
machine but the newer implementation now needs 23 seconds. To make
matters worse, xtest 4011 runs ten iterations so in reality it is
14 seconds vs. almost 4 minutes for the whole test.
Revert mbedtls_mpi_exp_mod() to the v3.5.2 implementation to gain the
performance back. The upstream commit that changed the algorithm is
[2]. Note that some mpi_* static functions have been made non static
and renamed mbedtls_mpi_* in the current Mbed TLS so the code from
v3.5.2 is modified accordingly.

Link: https://optee.readthedocs.io/en/latest/building/devices/qemu.html#qemu-v7 [1]
Link: Mbed-TLS/mbedtls@1ba4058 [2]
Signed-off-by: Jerome Forissier <jerome.forissier@linaro.org>
  • Loading branch information
jforissier committed Jun 5, 2024
1 parent afd7e60 commit f147fc1
Showing 1 changed file with 254 additions and 46 deletions.
300 changes: 254 additions & 46 deletions lib/libmbedtls/mbedtls/library/bignum.c
Original file line number Diff line number Diff line change
Expand Up @@ -1720,6 +1720,33 @@ void mbedtls_mpi_montred(mbedtls_mpi *A, const mbedtls_mpi *N,
mbedtls_mpi_montmul(A, &U, N, mm, T);
}

/**
* Select an MPI from a table without leaking the index.
*
* This is functionally equivalent to mbedtls_mpi_copy(R, T[idx]) except it
* reads the entire table in order to avoid leaking the value of idx to an
* attacker able to observe memory access patterns.
*
* \param[out] R Where to write the selected MPI.
* \param[in] T The table to read from.
* \param[in] T_size The number of elements in the table.
* \param[in] idx The index of the element to select;
* this must satisfy 0 <= idx < T_size.
*
* \return \c 0 on success, or a negative error code.
*/
static int mpi_select(mbedtls_mpi *R, const mbedtls_mpi *T, size_t T_size, size_t idx)
{
int ret = MBEDTLS_ERR_ERROR_CORRUPTION_DETECTED;

for (size_t i = 0; i < T_size; i++) {
MBEDTLS_MPI_CHK(mbedtls_mpi_safe_cond_assign(R, &T[i],
(unsigned char) mbedtls_ct_uint_eq(i, idx)));
}
cleanup:
return ret;
}

/*
* Sliding-window exponentiation: X = A^E mod N (HAC 14.85)
*/
Expand All @@ -1728,6 +1755,13 @@ int mbedtls_mpi_exp_mod(mbedtls_mpi *X, const mbedtls_mpi *A,
mbedtls_mpi *prec_RR)
{
int ret = MBEDTLS_ERR_ERROR_CORRUPTION_DETECTED;
size_t window_bitsize;
size_t i, j, nblimbs;
size_t bufsize, nbits;
size_t exponent_bits_in_window = 0;
mbedtls_mpi_uint ei, mm, state;
mbedtls_mpi RR, T, W[(size_t) 1 << MBEDTLS_MPI_WINDOW_SIZE], WW, Apos;
int neg;

if (mbedtls_mpi_cmp_int(N, 0) <= 0 || (N->p[0] & 1) == 0) {
return MBEDTLS_ERR_MPI_BAD_INPUT_DATA;
Expand All @@ -1743,88 +1777,262 @@ int mbedtls_mpi_exp_mod(mbedtls_mpi *X, const mbedtls_mpi *A,
}

/*
* Ensure that the exponent that we are passing to the core is not NULL.
* Init temps and window size
*/
if (E->n == 0) {
ret = mbedtls_mpi_lset(X, 1);
return ret;
mbedtls_mpi_montg_init(&mm, N);
mbedtls_mpi_init(&RR); mbedtls_mpi_init(&T);
mbedtls_mpi_init(&Apos);
mbedtls_mpi_init(&WW);
memset(W, 0, sizeof(W));

i = mbedtls_mpi_bitlen(E);

window_bitsize = (i > 671) ? 6 : (i > 239) ? 5 :
(i > 79) ? 4 : (i > 23) ? 3 : 1;

#if (MBEDTLS_MPI_WINDOW_SIZE < 6)
if (window_bitsize > MBEDTLS_MPI_WINDOW_SIZE) {
window_bitsize = MBEDTLS_MPI_WINDOW_SIZE;
}
#endif

const size_t w_table_used_size = (size_t) 1 << window_bitsize;

/*
* Allocate working memory for mbedtls_mpi_core_exp_mod()
* This function is not constant-trace: its memory accesses depend on the
* exponent value. To defend against timing attacks, callers (such as RSA
* and DHM) should use exponent blinding. However this is not enough if the
* adversary can find the exponent in a single trace, so this function
* takes extra precautions against adversaries who can observe memory
* access patterns.
*
* This function performs a series of multiplications by table elements and
* squarings, and we want the prevent the adversary from finding out which
* table element was used, and from distinguishing between multiplications
* and squarings. Firstly, when multiplying by an element of the window
* W[i], we do a constant-trace table lookup to obfuscate i. This leaves
* squarings as having a different memory access patterns from other
* multiplications. So secondly, we put the accumulator in the table as
* well, and also do a constant-trace table lookup to multiply by the
* accumulator which is W[x_index].
*
* This way, all multiplications take the form of a lookup-and-multiply.
* The number of lookup-and-multiply operations inside each iteration of
* the main loop still depends on the bits of the exponent, but since the
* other operations in the loop don't have an easily recognizable memory
* trace, an adversary is unlikely to be able to observe the exact
* patterns.
*
* An adversary may still be able to recover the exponent if they can
* observe both memory accesses and branches. However, branch prediction
* exploitation typically requires many traces of execution over the same
* data, which is defeated by randomized blinding.
*/
size_t T_limbs = mbedtls_mpi_core_exp_mod_working_limbs(N->n, E->n);
mbedtls_mpi_uint *T = (mbedtls_mpi_uint *) mbedtls_calloc(T_limbs, sizeof(mbedtls_mpi_uint));
if (T == NULL) {
return MBEDTLS_ERR_MPI_ALLOC_FAILED;
}
const size_t x_index = 0;
mbedtls_mpi_init(&W[x_index]);

j = N->n + 1;
/* All W[i] including the accumulator must have at least N->n limbs for
* the mbedtls_mpi_montmul() and mbedtls_mpi_montred() calls later.
* Here we ensure that
* W[1] and the accumulator W[x_index] are large enough. later we'll grow
* other W[i] to the same length. They must not be shrunk midway through
* this function!
*/
MBEDTLS_MPI_CHK(mbedtls_mpi_grow(&W[x_index], j));
MBEDTLS_MPI_CHK(mbedtls_mpi_grow(&W[1], j));
MBEDTLS_MPI_CHK(mbedtls_mpi_grow(&T, j * 2));

mbedtls_mpi RR;
mbedtls_mpi_init_mempool(&RR);
/*
* Compensate for negative A (and correct at the end)
*/
neg = (A->s == -1);
if (neg) {
MBEDTLS_MPI_CHK(mbedtls_mpi_copy(&Apos, A));
Apos.s = 1;
A = &Apos;
}

/*
* If 1st call, pre-compute R^2 mod N
*/
if (prec_RR == NULL || prec_RR->p == NULL) {
MBEDTLS_MPI_CHK(mbedtls_mpi_core_get_mont_r2_unsafe(&RR, N));
MBEDTLS_MPI_CHK(mbedtls_mpi_lset(&RR, 1));
MBEDTLS_MPI_CHK(mbedtls_mpi_shift_l(&RR, N->n * 2 * biL));
MBEDTLS_MPI_CHK(mbedtls_mpi_mod_mpi(&RR, &RR, N));

if (prec_RR != NULL) {
*prec_RR = RR;
memcpy(prec_RR, &RR, sizeof(mbedtls_mpi));
}
} else {
MBEDTLS_MPI_CHK(mbedtls_mpi_grow(prec_RR, N->n));
RR = *prec_RR;
memcpy(&RR, prec_RR, sizeof(mbedtls_mpi));
}

/*
* To preserve constness we need to make a copy of A. Using X for this to
* save memory.
* W[1] = A * R^2 * R^-1 mod N = A * R mod N
*/
MBEDTLS_MPI_CHK(mbedtls_mpi_copy(X, A));
if (mbedtls_mpi_cmp_mpi(A, N) >= 0) {
MBEDTLS_MPI_CHK(mbedtls_mpi_mod_mpi(&W[1], A, N));
/* This should be a no-op because W[1] is already that large before
* mbedtls_mpi_mod_mpi(), but it's necessary to avoid an overflow
* in mbedtls_mpi_montmul() below, so let's make sure. */
MBEDTLS_MPI_CHK(mbedtls_mpi_grow(&W[1], N->n + 1));
} else {
MBEDTLS_MPI_CHK(mbedtls_mpi_copy(&W[1], A));
}

/*
* Compensate for negative A (and correct at the end).
*/
X->s = 1;
/* Note that this is safe because W[1] always has at least N->n limbs
* (it grew above and was preserved by mbedtls_mpi_copy()). */
mbedtls_mpi_montmul(&W[1], &RR, N, mm, &T);

/*
* Make sure that X is in a form that is safe for consumption by
* the core functions.
*
* - The core functions will not touch the limbs of X above N->n. The
* result will be correct if those limbs are 0, which the mod call
* ensures.
* - Also, X must have at least as many limbs as N for the calls to the
* core functions.
* W[x_index] = R^2 * R^-1 mod N = R mod N
*/
if (mbedtls_mpi_cmp_mpi(X, N) >= 0) {
MBEDTLS_MPI_CHK(mbedtls_mpi_mod_mpi(X, X, N));
MBEDTLS_MPI_CHK(mbedtls_mpi_copy(&W[x_index], &RR));
mbedtls_mpi_montred(&W[x_index], N, mm, &T);


if (window_bitsize > 1) {
/*
* W[i] = W[1] ^ i
*
* The first bit of the sliding window is always 1 and therefore we
* only need to store the second half of the table.
*
* (There are two special elements in the table: W[0] for the
* accumulator/result and W[1] for A in Montgomery form. Both of these
* are already set at this point.)
*/
j = w_table_used_size / 2;

MBEDTLS_MPI_CHK(mbedtls_mpi_grow(&W[j], N->n + 1));
MBEDTLS_MPI_CHK(mbedtls_mpi_copy(&W[j], &W[1]));

for (i = 0; i < window_bitsize - 1; i++) {
mbedtls_mpi_montmul(&W[j], &W[j], N, mm, &T);
}

/*
* W[i] = W[i - 1] * W[1]
*/
for (i = j + 1; i < w_table_used_size; i++) {
MBEDTLS_MPI_CHK(mbedtls_mpi_grow(&W[i], N->n + 1));
MBEDTLS_MPI_CHK(mbedtls_mpi_copy(&W[i], &W[i - 1]));

mbedtls_mpi_montmul(&W[i], &W[1], N, mm, &T);
}
}

nblimbs = E->n;
bufsize = 0;
nbits = 0;
state = 0;

while (1) {
if (bufsize == 0) {
if (nblimbs == 0) {
break;
}

nblimbs--;

bufsize = sizeof(mbedtls_mpi_uint) << 3;
}

bufsize--;

ei = (E->p[nblimbs] >> bufsize) & 1;

/*
* skip leading 0s
*/
if (ei == 0 && state == 0) {
continue;
}

if (ei == 0 && state == 1) {
/*
* out of window, square W[x_index]
*/
MBEDTLS_MPI_CHK(mpi_select(&WW, W, w_table_used_size, x_index));
mbedtls_mpi_montmul(&W[x_index], &WW, N, mm, &T);
continue;
}

/*
* add ei to current window
*/
state = 2;

nbits++;
exponent_bits_in_window |= (ei << (window_bitsize - nbits));

if (nbits == window_bitsize) {
/*
* W[x_index] = W[x_index]^window_bitsize R^-1 mod N
*/
for (i = 0; i < window_bitsize; i++) {
MBEDTLS_MPI_CHK(mpi_select(&WW, W, w_table_used_size,
x_index));
mbedtls_mpi_montmul(&W[x_index], &WW, N, mm, &T);
}

/*
* W[x_index] = W[x_index] * W[exponent_bits_in_window] R^-1 mod N
*/
MBEDTLS_MPI_CHK(mpi_select(&WW, W, w_table_used_size,
exponent_bits_in_window));
mbedtls_mpi_montmul(&W[x_index], &WW, N, mm, &T);

state--;
nbits = 0;
exponent_bits_in_window = 0;
}
}
MBEDTLS_MPI_CHK(mbedtls_mpi_grow(X, N->n));

/*
* Convert to and from Montgomery around mbedtls_mpi_core_exp_mod().
* process the remaining bits
*/
{
mbedtls_mpi_uint mm = mbedtls_mpi_core_montmul_init(N->p);
mbedtls_mpi_core_to_mont_rep(X->p, X->p, N->p, N->n, mm, RR.p, T);
mbedtls_mpi_core_exp_mod(X->p, X->p, N->p, N->n, E->p, E->n, RR.p, T);
mbedtls_mpi_core_from_mont_rep(X->p, X->p, N->p, N->n, mm, T);
for (i = 0; i < nbits; i++) {
MBEDTLS_MPI_CHK(mpi_select(&WW, W, w_table_used_size, x_index));
mbedtls_mpi_montmul(&W[x_index], &WW, N, mm, &T);

exponent_bits_in_window <<= 1;

if ((exponent_bits_in_window & ((size_t) 1 << window_bitsize)) != 0) {
MBEDTLS_MPI_CHK(mpi_select(&WW, W, w_table_used_size, 1));
mbedtls_mpi_montmul(&W[x_index], &WW, N, mm, &T);
}
}

/*
* Correct for negative A.
* W[x_index] = A^E * R * R^-1 mod N = A^E mod N
*/
if (A->s == -1 && (E->p[0] & 1) != 0) {
mbedtls_ct_condition_t is_x_non_zero = mbedtls_mpi_core_check_zero_ct(X->p, X->n);
X->s = mbedtls_ct_mpi_sign_if(is_x_non_zero, -1, 1);
mbedtls_mpi_montred(&W[x_index], N, mm, &T);

MBEDTLS_MPI_CHK(mbedtls_mpi_add_mpi(X, N, X));
if (neg && E->n != 0 && (E->p[0] & 1) != 0) {
W[x_index].s = -1;
MBEDTLS_MPI_CHK(mbedtls_mpi_add_mpi(&W[x_index], N, &W[x_index]));
}

/*
* Load the result in the output variable.
*/
MBEDTLS_MPI_CHK(mbedtls_mpi_copy(X, &W[x_index]));

cleanup:

mbedtls_mpi_zeroize_and_free(T, T_limbs);
/* The first bit of the sliding window is always 1 and therefore the first
* half of the table was unused. */
for (i = w_table_used_size/2; i < w_table_used_size; i++) {
mbedtls_mpi_free(&W[i]);
}

mbedtls_mpi_free(&W[x_index]);
mbedtls_mpi_free(&W[1]);
mbedtls_mpi_free(&T);
mbedtls_mpi_free(&Apos);
mbedtls_mpi_free(&WW);

if (prec_RR == NULL || prec_RR->p == NULL) {
mbedtls_mpi_free(&RR);
Expand Down

0 comments on commit f147fc1

Please sign in to comment.