Perf: isSquare - constant-time Jacobi/Kronecker/Legendre symbol using fast GCD #199

Closed
mratsim opened this issue Aug 6, 2022 · 0 comments

mratsim commented Aug 6, 2022

According to Pornin, we can expect up to a 7.5x speedup (with assembly) over a naive exponentiation by (p-1)/2.

See https://github.com/pornin/x25519-cm0/blob/75a53f2/src/x25519-cm0.S#L89-L155

The gf_legendre_inner() function computes the Legendre symbol for a field
element. This is not actually needed for X25519, and is included here
only because it could be helpful in other operations adjacent to
X25519, e.g. the use of the Elligator2 map for encoding/hashing values
into curve points in a constant-time way. The Legendre symbol of x is:
  •  1 if x is a non-zero quadratic residue in the field
  • -1 if x is not a quadratic residue in the field
  •  0 if x is zero
The traditional method is again Fermat's Little Theorem: for a prime
p, the Legendre symbol of x is equal to x^((p-1)/2) mod p. This would
again require about 270000 cycles for p = 2^255-19.
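As a point of comparison, the exponentiation-based method quoted above can be sketched in a few lines of Python. This is only an illustration: the function name `legendre_euler` is made up here, and Python's built-in `pow` is not constant-time, so it says nothing about the cycle counts mentioned in the quote.

```python
def legendre_euler(x, p):
    """Legendre symbol of x modulo an odd prime p, via x^((p-1)/2) mod p.

    Hypothetical helper for illustration; not constant-time.
    """
    r = pow(x % p, (p - 1) // 2, p)
    # The exponentiation yields 0, 1, or p-1; map p-1 to -1.
    return -1 if r == p - 1 else r


P25519 = 2**255 - 19

print(legendre_euler(4, P25519))  # 4 = 2^2 is a non-zero square
print(legendre_euler(2, P25519))  # p = 5 mod 8, so 2 is not a square
print(legendre_euler(0, P25519))  # zero maps to 0
```

The GCD-based approach described next computes the same three-valued result without the roughly 255 modular squarings this exponentiation needs.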

The algorithm implemented here is roughly the same as the binary GCD
used for inversion. It internally computes the GCD of x and p with the
exact same steps (hence, it always converges with the same number of
iterations); it does not keep track of the Bezout coefficients, since
these are not needed for a Legendre symbol; however, it follows value
updates to compute the symbol. What is actually computed is the
Kronecker symbol (x|p), with the following properties:

(x|n) is equal to the Legendre symbol of x modulo n when n is a
nonnegative odd prime.

(x|n) == (y|n) if x == y mod n and either n > 0, or x and y have
the same sign.

If n and m are not both negative, then (n|m) == (m|n), unless
both n == 3 mod 4 and m == 3 mod 4, in which case (n|m) == -(m|n).
(This is the law of quadratic reciprocity.)

(2|n) == 1 if n == 1 or 7 mod 8, or -1 if n == 3 or 5 mod 8.
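The last two properties only depend on the low bits of the operands, which is what makes them cheap to apply inside a GCD loop. A minimal sketch, with hypothetical helper names `kronecker_two` and `reciprocity_sign` and restricted to odd positive arguments:

```python
def kronecker_two(n):
    """(2|n) for odd n > 0, read off from n mod 8."""
    assert n > 0 and n & 1
    return 1 if n & 7 in (1, 7) else -1


def reciprocity_sign(n, m):
    """Sign s with (n|m) == s * (m|n), for odd positive n and m."""
    assert n > 0 and m > 0 and n & 1 and m & 1
    return -1 if (n & 3 == 3 and m & 3 == 3) else 1


# Cross-check (2|p) against x^((p-1)/2) mod p for a few small primes.
for p in (3, 5, 7, 11, 13, 17, 19, 23):
    euler = pow(2, (p - 1) // 2, p)
    assert kronecker_two(p) == (1 if euler == 1 else -1)
```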

In the course of the binary GCD algorithm, we work over two values a
and b, which converge toward 0 and 1 respectively. b is always odd.
Each iteration consists of three successive steps:

  1. If a and b are odd and a < b, then a and b are exchanged.
  2. If a is odd, then a is replaced with a-b.
  3. a <- a/2

When adapted to the Legendre symbol computation, we use the same steps,
but also maintain the expected Kronecker symbol in a variable j which
is initially 1, and is negated when appropriate:

  • Step 1 exercises the law of quadratic reciprocity; j is negated if
    both a and b are equal to 3 modulo 4 at the time of the swap.

  • Step 2 does not change the Kronecker symbol; a critical observation
    here is that throughout the optimized binary GCD algorithm, it can
    never happen that a and b are both negative.

  • Step 3 negates j if and only if b == 3 or 5 mod 8 at that point.

These updates to j only need to look at the low bits of a and b (up to
three bits) and are thus largely compatible with the intermediate values
maintained by the optimized binary GCD in its inner loop. This implies
a relatively low overhead for the inner loop iterations. Combined with
the savings obtained by not keeping track of the Bezout coefficients,
we finally achieve the Legendre symbol computation in 43726 cycles, i.e.
even faster than inversions. This implementation is fully constant-time.
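The three steps and the accompanying updates to j can be sketched in plain Python. This is a variable-time illustration on unbounded integers, not the constant-time assembly linked above: the name `kronecker_binary_gcd` is made up, the loop simply runs until a reaches 0 instead of using a fixed iteration count, and it follows the plain (non-optimized) binary GCD in which a and b stay non-negative.

```python
def kronecker_binary_gcd(x, p):
    """Kronecker symbol (x|p) for odd p > 0, tracked through a binary GCD.

    Invariant: (x|p) == j * (a|b), with b odd and positive.
    Hypothetical sketch; variable-time.
    """
    assert p > 0 and p & 1
    a, b, j = x % p, p, 1
    while a != 0:
        if a & 1:
            if a < b:
                # Step 1: swap, applying quadratic reciprocity.
                if a & 3 == 3 and b & 3 == 3:
                    j = -j
                a, b = b, a
            # Step 2: a <- a - b leaves the symbol unchanged.
            a -= b
        # Step 3: a <- a/2, which multiplies the symbol by (2|b).
        if b & 7 in (3, 5):
            j = -j
        a >>= 1
    # b now holds gcd(x, p); the symbol is 0 unless that gcd is 1.
    return j if b == 1 else 0


p = 2**255 - 19
print(kronecker_binary_gcd(4, p), kronecker_binary_gcd(2, p), kronecker_binary_gcd(0, p))
```

For a prime p the result agrees with x^((p-1)/2) mod p mapped to {-1, 0, 1}, which gives a simple way to test such a routine.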

See https://github.com/bitcoin-core/secp256k1/blob/7e1bbef/doc/safegcd_implementation.md#8-from-gcds-to-jacobi-symbol

8. From GCDs to Jacobi symbol

We can also use a similar approach to calculate Jacobi symbol (x | M) by keeping track of an extra variable j, for which at every step (x | M) = j (g | f). As we update f and g, we make corresponding updates to j using properties of the Jacobi symbol. In particular, we update j whenever we divide g by 2 or swap f and g; these updates depend only on the values of f and g modulo 4 or 8, and can thus be applied very quickly. Overall, this calculation is slightly simpler than the one for modular inverse because we no longer need to keep track of d and e.

However, one difficulty of this approach is that the Jacobi symbol (a | n) is only defined for positive odd integers n, whereas in the original safegcd algorithm, f, g can take negative values. We resolve this by using the following modified steps:

        # Before
        if delta > 0 and g & 1:
            delta, f, g = 1 - delta, g, (g - f) // 2
        # After
        if delta > 0 and g & 1:
            delta, f, g = 1 - delta, g, (g + f) // 2

The algorithm is still correct, since the changed divstep, called a "posdivstep" (see section 8.4 and E.5 in the paper) preserves gcd(f, g). However, there's no proof that the modified algorithm will converge. The justification for posdivsteps is completely empirical: in practice, it appears that the vast majority of inputs converge to f=g=gcd(f0, g0) in a number of steps proportional to their logarithm.

Note that:

  • We require inputs to satisfy gcd(x, M) = 1.
  • We need to update the termination condition from g=0 to f=1.
  • We deal with the case where g=0 on input specially.

We account for the possibility of nonconvergence by only performing a bounded number of posdivsteps, and then falling back to square-root based Jacobi calculation if a solution has not yet been found.
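A one-step-at-a-time sketch of this idea, before any of the batching optimizations, might look as follows. The function name `jacobi_posdivsteps`, the starting value delta = 1, and the `max_steps` default are assumptions chosen for illustration, not values taken from the secp256k1 code; returning `None` stands in for "fall back to another method".

```python
def jacobi_posdivsteps(x, M, max_steps=10_000):
    """Jacobi symbol (x|M) via posdivsteps, or None if not converged.

    Requires M odd and positive, x > 0, and gcd(x, M) == 1.
    Invariant: (x|M) == j * (g|f), with f odd and f, g positive.
    Hypothetical sketch; variable-time, arbitrary step bound.
    """
    assert M > 0 and M & 1 and x > 0
    delta, f, g, j = 1, M, x, 1
    if f == 1:
        return j
    for _ in range(max_steps):
        if delta > 0 and g & 1:
            # Swap f and g: quadratic reciprocity on the old (f, g).
            if f & 3 == 3 and g & 3 == 3:
                j = -j
            delta, f, g = 1 - delta, g, (g + f) // 2
        elif g & 1:
            delta, g = 1 + delta, (g + f) // 2
        else:
            delta, g = 1 + delta, g // 2
        # Every branch halves a value congruent to g mod f: apply (2|f).
        if f & 7 in (3, 5):
            j = -j
        if f == 1:
            return j
    return None  # caller should fall back to a different algorithm


print(jacobi_posdivsteps(3, 7), jacobi_posdivsteps(2, 7), jacobi_posdivsteps(5, 9))
```

Because g is only ever replaced by (g + f)/2 or g/2, both values stay positive, which is what keeps the Jacobi symbol (g|f) defined at every step.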

The optimizations in sections 3-7 above are described in the context of the original divsteps, but in the C implementation we also adapt most of them (not including "avoiding modulus operations", since it's not necessary to track d, e, and "constant-time operation", since we never calculate Jacobi symbols for secret data) to the posdivsteps version.

Python: https://gist.github.com/robot-dreams/ceb00162b80384f2ae1913aaa2b35e75

N = 62

def count_trailing_zeros(x):
    assert x != 0
    ans = 0
    while x & 1 == 0:
        ans += 1
        x >>= 1
    return ans

def update_fg(f, g, t):
    u, v, q, r = t
    cf, cg = u * f + v * g, q * f + r * g
    assert cf % 2**N == 0
    assert cg % 2**N == 0
    return cf >> N, cg >> N

def divsteps_N_matrix(eta, f, g, jac):
    u, v, q, r = 1, 0, 0, 1
    i = N
    while True:
        z = min(i, count_trailing_zeros(g))
        eta, g, u, v = eta - z, g >> z, u << z, v << z
        i -= z
        if z & 1 and (f % 8 == 3 or f % 8 == 5):
            jac = -jac
        if i == 0:
            break
        assert (g & 1) == 1
        if eta < 0:
            eta, f, g, u, v, q, r = -eta, g, f, q, r, u, v
            if f % 4 == 3 and g % 4 == 3:
                jac = -jac
            limit = min(i, eta + 1, 6)
            assert limit > 0 and limit <= N
            m = (1 << limit) - 1
            w = (f * g * (f * f - 2)) & m
        else:
            limit = min(i, eta + 1, 4)
            assert limit > 0 and limit <= N
            m = (1 << limit) - 1
            w = f + (((f + 1) & 4) << 1)
            w = (-w * g) & m
        g, q, r = g + f * w, q + u * w, r + v * w
        assert g % (2**limit) == 0
    return eta, (u, v, q, r), jac

def fastjac_divsteps(x, M):
    assert x > 0 and M & 1
    jac = 1
    eta, f, g = -1, M, x
    while f != 1:
        assert f & 1
        eta, t, jac = divsteps_N_matrix(eta, f, g, jac)
        f, g = update_fg(f, g, t)
    return jac

def slowjac(a, n):
    jac = 1
    while True:
        assert n & 1 == 1
        if a == 1:
            return jac
        if a == n - 1:
            if n & 3 == 3:
                jac = -jac
            return jac
        while a & 1 == 0:
            if n & 7 in (3, 5):
                jac = -jac
            a >>= 1
        if a & 3 == 3 and n & 3 == 3:
            jac = -jac
        a, n = n % a, a

if __name__ == '__main__':
    from math import gcd
    from random import randint
    for i in range(10000):
        while True:
            a = randint(1, 10**6)
            M = 2 * randint(1, 10**6) + 1
            if M > 1 and gcd(a, M) == 1:
                break
        assert fastjac_divsteps(a, M) == slowjac(a, M)

This will significantly accelerate:

  • BLS12-377 Fp2 deserialization
  • BN254_Snarks hash-to-curve (via SVDW)
    # TODO: faster Legendre symbol.
    # We can optimize the 2 legendre symbols + 3 sqrt to
    # - either 2 legendre and 1 sqrt
    # - or 3 fused legendre+sqrt
    let e1 = gx1.isSquare()
    let e2 = gx2.isSquare() and not e1

Paper: https://eprint.iacr.org/2021/1271.pdf

Reference code:
