diff --git a/benchmarks/bench_fields_template.nim b/benchmarks/bench_fields_template.nim index c8216227b..ef1690942 100644 --- a/benchmarks/bench_fields_template.nim +++ b/benchmarks/bench_fields_template.nim @@ -91,6 +91,20 @@ proc sqrBench*(T: typedesc, iters: int) = bench("Squaring", T, iters): r.square(x) +proc toBigBench*(T: typedesc, iters: int) = + var r: matchingBigInt(T.C) + let x = rng.random_unsafe(T) + preventOptimAway(r) + bench("BigInt <- field conversion", T, iters): + r.fromField(x) + +proc toFieldBench*(T: typedesc, iters: int) = + var r: T + let x = rng.random_unsafe(matchingBigInt(T.C)) + preventOptimAway(r) + bench("BigInt -> field conversion", T, iters): + r.fromBig(x) + proc invBench*(T: typedesc, iters: int) = var r: T let x = rng.random_unsafe(T) diff --git a/benchmarks/bench_fp.nim b/benchmarks/bench_fp.nim index 6d251b85f..9ca81ce9e 100644 --- a/benchmarks/bench_fp.nim +++ b/benchmarks/bench_fp.nim @@ -50,6 +50,9 @@ proc main() = mulBench(Fp[curve], Iters) sqrBench(Fp[curve], Iters) smallSeparator() + toBigBench(Fp[curve], Iters) + toFieldBench(Fp[curve], Iters) + smallSeparator() invBench(Fp[curve], ExponentIters) sqrtBench(Fp[curve], ExponentIters) sqrtRatioBench(Fp[curve], ExponentIters) diff --git a/constantine/arithmetic/assembly/limbs_asm_modular_x86.nim b/constantine/arithmetic/assembly/limbs_asm_modular_x86.nim index f057c093e..9fe52a50c 100644 --- a/constantine/arithmetic/assembly/limbs_asm_modular_x86.nim +++ b/constantine/arithmetic/assembly/limbs_asm_modular_x86.nim @@ -24,14 +24,111 @@ import # They are nice to let the compiler deals with mov # but too constraining so we move things ourselves. -static: doAssert UseASM_X86_64 +static: doAssert UseASM_X86_32 {.localPassC:"-fomit-frame-pointer".} # Needed so that the compiler finds enough registers +proc finalSubNoCarryImpl*( + ctx: var Assembler_x86, + r: Operand or OperandArray, + a, M, scratch: OperandArray + ) = + ## Reduce `a` into `r` modulo `M` + ## + ## r, a, scratch, scratchReg are mutated + ## M is read-only + let N = M.len + ctx.comment "Final substraction (no carry)" + for i in 0 ..< N: + ctx.mov scratch[i], a[i] + if i == 0: + ctx.sub scratch[i], M[i] + else: + ctx.sbb scratch[i], M[i] + + # If we borrowed it means that we were smaller than + # the modulus and we don't need "scratch" + for i in 0 ..< N: + ctx.cmovnc a[i], scratch[i] + ctx.mov r[i], a[i] + +proc finalSubMayCarryImpl*( + ctx: var Assembler_x86, + r: Operand or OperandArray, + a, M, scratch: OperandArray, + scratchReg: Operand or Register or OperandReuse + ) = + ## Reduce `a` into `r` modulo `M` + ## To be used when the final substraction can + ## also depend on the carry flag + ## + ## r, a, scratch, scratchReg are mutated + ## M is read-only + + ctx.comment "Final substraction (may carry)" + + # Mask: scratchReg contains 0xFFFF or 0x0000 + ctx.sbb scratchReg, scratchReg + + # Now substract the modulus to test a < p + let N = M.len + for i in 0 ..< N: + ctx.mov scratch[i], a[i] + if i == 0: + ctx.sub scratch[i], M[i] + else: + ctx.sbb scratch[i], M[i] + + # If it overflows here, it means that it was + # smaller than the modulus and we don't need `scratch` + ctx.sbb scratchReg, 0 + + # If we borrowed it means that we were smaller than + # the modulus and we don't need "scratch" + for i in 0 ..< N: + ctx.cmovnc a[i], scratch[i] + ctx.mov r[i], a[i] + +macro finalSub_gen*[N: static int]( + r_PIR: var array[N, SecretWord], + a_EIR, M_PIR: array[N, SecretWord], + scratch_EIR: var array[N, SecretWord], + mayCarry: static bool): 
untyped = + ## Returns: + ## a-M if a > M + ## a otherwise + ## + ## - r_PIR is a pointer to the result array, mutated, + ## - a_EIR is an array of registers, mutated, + ## - M_PIR is a pointer to an array, read-only, + ## - scratch_EIR is an array of registers, mutated + ## - mayCarry is set to true when the carry flag also needs to be read + result = newStmtList() + + var ctx = init(Assembler_x86, BaseType) + let + r = init(OperandArray, nimSymbol = r_PIR, N, PointerInReg, InputOutput) + # We reuse the reg used for b for overflow detection + a = init(OperandArray, nimSymbol = a_EIR, N, ElemsInReg, InputOutput) + # We could force m as immediate by specializing per moduli + M = init(OperandArray, nimSymbol = M_PIR, N, PointerInReg, Input) + t = init(OperandArray, nimSymbol = scratch_EIR, N, ElemsInReg, Output_EarlyClobber) + + if mayCarry: + ctx.finalSubMayCarryImpl( + r, a, M, t, rax + ) + else: + ctx.finalSubNoCarryImpl( + r, a, M, t + ) + + result.add ctx.generate() + # Field addition # ------------------------------------------------------------ -macro addmod_gen[N: static int](R: var Limbs[N], A, B, m: Limbs[N]): untyped = +macro addmod_gen[N: static int](R: var Limbs[N], A, B, m: Limbs[N], hasSpareBit: static bool): untyped = ## Generate an optimized modular addition kernel # Register pressure note: # We could generate a kernel per modulus m by hardcoding it as immediate @@ -68,33 +165,18 @@ macro addmod_gen[N: static int](R: var Limbs[N], A, B, m: Limbs[N]): untyped = # Interleaved copy in a second buffer as well ctx.mov v[i], u[i] - # Mask: overflowed contains 0xFFFF or 0x0000 - # TODO: unnecessary if MSB never set, i.e. "Field.getSpareBits >= 1" - let overflowed = b.reuseRegister() - ctx.sbb overflowed, overflowed + if hasSparebit: + ctx.finalSubNoCarryImpl(r, u, M, v) + else: + ctx.finalSubMayCarryImpl( + r, u, M, v, b.reuseRegister() + ) - # Now substract the modulus to test a < p - for i in 0 ..< N: - if i == 0: - ctx.sub v[0], M[0] - else: - ctx.sbb v[i], M[i] - - # If it overflows here, it means that it was - # smaller than the modulus and we don't need V - ctx.sbb overflowed, 0 - - # Conditional Mov and - # and store result - for i in 0 ..< N: - ctx.cmovnc u[i], v[i] - ctx.mov r[i], u[i] - - result.add ctx.generate + result.add ctx.generate() -func addmod_asm*(r: var Limbs, a, b, m: Limbs) = +func addmod_asm*(r: var Limbs, a, b, m: Limbs, hasSpareBit: static bool) = ## Constant-time modular addition - addmod_gen(r, a, b, m) + addmod_gen(r, a, b, m, hasSpareBit) # Field substraction # ------------------------------------------------------------ @@ -225,7 +307,7 @@ when isMainModule: debugecho " a: ", a.toHex() debugecho " b: ", b.toHex() debugecho " m: ", m.toHex() - addmod_asm(a, a, b, m) + addmod_asm(a, a, b, m, hasSpareBit = false) debugecho "after:" debugecho " a: ", a.toHex().tolower debugecho " s: ", s @@ -245,7 +327,7 @@ when isMainModule: debugecho " a: ", a.toHex() debugecho " b: ", b.toHex() debugecho " m: ", m.toHex() - addmod_asm(a, a, b, m) + addmod_asm(a, a, b, m, hasSpareBit = false) debugecho "after:" debugecho " a: ", a.toHex().tolower debugecho " s: ", s @@ -265,7 +347,7 @@ when isMainModule: debugecho " a: ", a.toHex() debugecho " b: ", b.toHex() debugecho " m: ", m.toHex() - addmod_asm(a, a, b, m) + addmod_asm(a, a, b, m, hasSpareBit = false) debugecho "after:" debugecho " a: ", a.toHex().tolower debugecho " s: ", s @@ -285,7 +367,7 @@ when isMainModule: debugecho " a: ", a.toHex() debugecho " b: ", b.toHex() debugecho " m: ", m.toHex() - addmod_asm(a, a, 
b, m) + addmod_asm(a, a, b, m, hasSpareBit = false) debugecho "after:" debugecho " a: ", a.toHex().tolower debugecho " s: ", s @@ -308,7 +390,7 @@ when isMainModule: debugecho " a: ", a.toHex() debugecho " b: ", b.toHex() debugecho " m: ", m.toHex() - submod_asm(a, a, b, m) + submod_asm(a, a, b, m, hasSpareBit = false) debugecho "after:" debugecho " a: ", a.toHex().tolower debugecho " s: ", s @@ -333,7 +415,7 @@ when isMainModule: debugecho " a: ", a.toHex() debugecho " b: ", b.toHex() debugecho " m: ", m.toHex() - submod_asm(r, a, b, m) + submod_asm(r, a, b, m, hasSpareBit = false) debugecho "after:" debugecho " r: ", r.toHex().tolower debugecho " s: ", s diff --git a/constantine/arithmetic/assembly/limbs_asm_montred_x86_adx_bmi2.nim b/constantine/arithmetic/assembly/limbs_asm_montred_x86_adx_bmi2.nim deleted file mode 100644 index eb2cb18c5..000000000 --- a/constantine/arithmetic/assembly/limbs_asm_montred_x86_adx_bmi2.nim +++ /dev/null @@ -1,160 +0,0 @@ -# Constantine -# Copyright (c) 2018-2019 Status Research & Development GmbH -# Copyright (c) 2020-Present Mamy André-Ratsimbazafy -# Licensed and distributed under either of -# * MIT license (license terms in the root directory or at http://opensource.org/licenses/MIT). -# * Apache v2 license (license terms in the root directory or at http://www.apache.org/licenses/LICENSE-2.0). -# at your option. This file may not be copied, modified, or distributed except according to those terms. - -import - # Standard library - std/macros, - # Internal - ../../config/common, - ../../primitives, - ./limbs_asm_montred_x86 - -# ############################################################ -# -# Assembly implementation of finite fields -# -# ############################################################ - -static: doAssert UseASM_X86_64 - -# MULX/ADCX/ADOX -{.localPassC:"-madx -mbmi2".} -# Necessary for the compiler to find enough registers (enabled at -O1) -{.localPassC:"-fomit-frame-pointer".} - -# No exceptions allowed -{.push raises: [].} - -# Montgomery reduction -# ------------------------------------------------------------ - -macro montyRedc2x_adx_gen*[N: static int]( - r_MR: var array[N, SecretWord], - a_MR: array[N*2, SecretWord], - M_MR: array[N, SecretWord], - m0ninv_MR: BaseType, - hasSpareBit: static bool - ) = - result = newStmtList() - - var ctx = init(Assembler_x86, BaseType) - let - # We could force M as immediate by specializing per moduli - M = init(OperandArray, nimSymbol = M_MR, N, PointerInReg, Input) - - let uSlots = N+1 - let vSlots = max(N-1, 5) - - var # Scratchspaces - u = init(OperandArray, nimSymbol = ident"U", uSlots, ElemsInReg, InputOutput_EnsureClobber) - v = init(OperandArray, nimSymbol = ident"V", vSlots, ElemsInReg, InputOutput_EnsureClobber) - - # Prologue - let usym = u.nimSymbol - let vsym = v.nimSymbol - result.add quote do: - static: doAssert: sizeof(SecretWord) == sizeof(ByteAddress) - var `usym`{.noinit.}: Limbs[`uSlots`] - var `vsym` {.noInit.}: Limbs[`vSlots`] - `vsym`[0] = cast[SecretWord](`r_MR`[0].unsafeAddr) - `vsym`[1] = cast[SecretWord](`a_MR`[0].unsafeAddr) - `vsym`[2] = SecretWord(`m0ninv_MR`) - - let r_temp = v[0].asArrayAddr(len = N) - let a = v[1].asArrayAddr(len = 2*N) - let m0ninv = v[2] - let lo = v[3] - let hi = v[4] - - # Algorithm - # --------------------------------------------------------- - # for i in 0 .. n-1: - # hi <- 0 - # m <- a[i] * m0ninv mod 2^w (i.e. simple multiplication) - # for j in 0 .. 
n-1: - # (hi, lo) <- a[i+j] + m * M[j] + hi - # a[i+j] <- lo - # a[i+n] += hi - # for i in 0 .. n-1: - # r[i] = a[i+n] - # if r >= M: - # r -= M - - # No register spilling handling - doAssert N <= 6, "The Assembly-optimized montgomery multiplication requires at most 6 limbs." - - ctx.mov rdx, m0ninv - - for i in 0 ..< N: - ctx.mov u[i], a[i] - - for i in 0 ..< N: - # RDX contains m0ninv at the start of each loop - ctx.comment "" - ctx.imul rdx, u[0] # m <- a[i] * m0ninv mod 2^w - ctx.comment "---- Reduction " & $i - ctx.`xor` u[N], u[N] - - for j in 0 ..< N-1: - ctx.comment "" - ctx.mulx hi, lo, M[j], rdx - ctx.adcx u[j], lo - ctx.adox u[j+1], hi - - # Last limb - ctx.comment "" - ctx.mulx hi, lo, M[N-1], rdx - ctx.mov rdx, m0ninv # Reload m0ninv for next iter - ctx.adcx u[N-1], lo - ctx.adox hi, u[N] - ctx.adcx u[N], hi - - u.rotateLeft() - - ctx.mov rdx, r_temp - let r = rdx.asArrayAddr(len = N) - - # This does a[i+n] += hi - # but in a separate carry chain, fused with the - # copy "r[i] = a[i+n]" - for i in 0 ..< N: - if i == 0: - ctx.add u[i], a[i+N] - else: - ctx.adc u[i], a[i+N] - - let t = repackRegisters(v, u[N]) - - if hasSpareBit: - ctx.finalSubNoCarry(r, u, M, t) - else: - ctx.finalSubCanOverflow(r, u, M, t, hi) - - # Code generation - result.add ctx.generate() - -func montRed_asm_adx_bmi2_impl*[N: static int]( - r: var array[N, SecretWord], - a: array[N*2, SecretWord], - M: array[N, SecretWord], - m0ninv: BaseType, - hasSpareBit: static bool - ) = - ## Constant-time Montgomery reduction - ## Inline-version - montyRedc2x_adx_gen(r, a, M, m0ninv, hasSpareBit) - -func montRed_asm_adx_bmi2*[N: static int]( - r: var array[N, SecretWord], - a: array[N*2, SecretWord], - M: array[N, SecretWord], - m0ninv: BaseType, - hasSpareBit: static bool - ) = - ## Constant-time Montgomery reduction - montRed_asm_adx_bmi2_impl(r, a, M, m0ninv, hasSpareBit) diff --git a/constantine/arithmetic/assembly/limbs_asm_montmul_x86.nim b/constantine/arithmetic/assembly/limbs_asm_mul_mont_x86.nim similarity index 75% rename from constantine/arithmetic/assembly/limbs_asm_montmul_x86.nim rename to constantine/arithmetic/assembly/limbs_asm_mul_mont_x86.nim index 6c9e3f205..18db83759 100644 --- a/constantine/arithmetic/assembly/limbs_asm_montmul_x86.nim +++ b/constantine/arithmetic/assembly/limbs_asm_mul_mont_x86.nim @@ -12,7 +12,8 @@ import # Internal ../../config/common, ../../primitives, - ./limbs_asm_montred_x86, + ./limbs_asm_modular_x86, + ./limbs_asm_redc_mont_x86, ./limbs_asm_mul_x86 # ############################################################ @@ -34,7 +35,11 @@ static: doAssert UseASM_X86_64 # Montgomery multiplication # ------------------------------------------------------------ # Fallback when no ADX and BMI2 support (MULX, ADCX, ADOX) -macro montMul_CIOS_sparebit_gen[N: static int](r_MM: var Limbs[N], a_MM, b_MM, M_MM: Limbs[N], m0ninv_MM: BaseType): untyped = +macro mulMont_CIOS_sparebit_gen[N: static int]( + r_PIR: var Limbs[N], a_PIR, b_PIR, + M_PIR: Limbs[N], m0ninv_REG: BaseType, + skipReduction: static bool + ): untyped = ## Generate an optimized Montgomery Multiplication kernel ## using the CIOS method ## @@ -44,14 +49,17 @@ macro montMul_CIOS_sparebit_gen[N: static int](r_MM: var Limbs[N], a_MM, b_MM, M ## M[^1] < high(SecretWord) shr 2 (i.e. less than 0b00111...1111) ## https://hackmd.io/@zkteam/modular_multiplication + # No register spilling handling + doAssert N <= 6, "The Assembly-optimized montgomery multiplication requires at most 6 limbs." 
+ result = newStmtList() var ctx = init(Assembler_x86, BaseType) let - scratchSlots = max(N, 6) + scratchSlots = 6 # We could force M as immediate by specializing per moduli - M = init(OperandArray, nimSymbol = M_MM, N, PointerInReg, Input) + M = init(OperandArray, nimSymbol = M_PIR, N, PointerInReg, Input) # If N is too big, we need to spill registers. TODO. t = init(OperandArray, nimSymbol = ident"t", N, ElemsInReg, Output_EarlyClobber) # MultiPurpose Register slots @@ -62,10 +70,10 @@ macro montMul_CIOS_sparebit_gen[N: static int](r_MM: var Limbs[N], a_MM, b_MM, M m0ninv = Operand( desc: OperandDesc( asmId: "[m0ninv]", - nimSymbol: m0ninv_MM, + nimSymbol: m0ninv_REG, rm: MemOffsettable, constraint: Input, - cEmit: "&" & $m0ninv_MM + cEmit: "&" & $m0ninv_REG ) ) @@ -76,7 +84,7 @@ macro montMul_CIOS_sparebit_gen[N: static int](r_MM: var Limbs[N], a_MM, b_MM, M b = scratch[1].asArrayAddr(len = N) # Store the `b` operand A = scratch[2] # High part of extended precision multiplication C = scratch[3] - m = scratch[4] # Stores (t[0] * m0ninv) mod 2^w + m = scratch[4] # Stores (t[0] * m0ninv) mod 2ʷ r = scratch[5] # Stores the `r` operand # Registers used: @@ -94,18 +102,18 @@ macro montMul_CIOS_sparebit_gen[N: static int](r_MM: var Limbs[N], a_MM, b_MM, M result.add quote do: static: doAssert: sizeof(SecretWord) == sizeof(ByteAddress) - var `tsym`: typeof(`r_MM`) # zero init + var `tsym`: typeof(`r_PIR`) # zero init # Assumes 64-bit limbs on 64-bit arch (or you can't store an address) var `scratchSym` {.noInit.}: Limbs[`scratchSlots`] - `scratchSym`[0] = cast[SecretWord](`a_MM`[0].unsafeAddr) - `scratchSym`[1] = cast[SecretWord](`b_MM`[0].unsafeAddr) - `scratchSym`[5] = cast[SecretWord](`r_MM`[0].unsafeAddr) + `scratchSym`[0] = cast[SecretWord](`a_PIR`[0].unsafeAddr) + `scratchSym`[1] = cast[SecretWord](`b_PIR`[0].unsafeAddr) + `scratchSym`[5] = cast[SecretWord](`r_PIR`[0].unsafeAddr) # Algorithm # ----------------------------------------- # for i=0 to N-1 # (A, t[0]) <- a[0] * b[i] + t[0] - # m <- (t[0] * m0ninv) mod 2^w + # m <- (t[0] * m0ninv) mod 2ʷ # (C, _) <- m * M[0] + t[0] # for j=1 to N-1 # (A, t[j]) <- a[j] * b[i] + A + t[j] @@ -127,7 +135,7 @@ macro montMul_CIOS_sparebit_gen[N: static int](r_MM: var Limbs[N], a_MM, b_MM, M ctx.adc rdx, 0 ctx.mov A, rdx - # m <- (t[0] * m0ninv) mod 2^w + # m <- (t[0] * m0ninv) mod 2ʷ ctx.mov m, m0ninv ctx.imul m, t[0] @@ -164,19 +172,27 @@ macro montMul_CIOS_sparebit_gen[N: static int](r_MM: var Limbs[N], a_MM, b_MM, M ctx.add A, C ctx.mov t[N-1], A - ctx.mov rdx, r - let r2 = rdx.asArrayAddr(len = N) - - ctx.finalSubNoCarry( - r2, t, M, - scratch - ) - - result.add ctx.generate - -func montMul_CIOS_sparebit_asm*(r: var Limbs, a, b, M: Limbs, m0ninv: BaseType) = - ## Constant-time modular multiplication - montMul_CIOS_sparebit_gen(r, a, b, M, m0ninv) + ctx.mov rax, r # move r away from scratchspace that will be used for final substraction + let r2 = rax.asArrayAddr(len = N) + + if skipReduction: + for i in 0 ..< N: + ctx.mov r2[i], t[i] + else: + ctx.finalSubNoCarryImpl( + r2, t, M, + scratch + ) + result.add ctx.generate() + +func mulMont_CIOS_sparebit_asm*(r: var Limbs, a, b, M: Limbs, m0ninv: BaseType, skipReduction: static bool = false) = + ## Constant-time Montgomery multiplication + ## If "skipReduction" is set + ## the result is in the range [0, 2M) + ## otherwise the result is in the range [0, M) + ## + ## This procedure can only be called if the modulus doesn't use the full bitwidth of its underlying representation + 
r.mulMont_CIOS_sparebit_gen(a, b, M, m0ninv, skipReduction) # Montgomery Squaring # ------------------------------------------------------------ @@ -189,25 +205,12 @@ func square_asm_inline[rLen, aLen: static int](r: var Limbs[rLen], a: Limbs[aLen ## but not for stack variables sqr_gen(r, a) -func montRed_asm_inline[N: static int]( - r: var array[N, SecretWord], - a: array[N*2, SecretWord], - M: array[N, SecretWord], - m0ninv: BaseType, - hasSpareBit: static bool - ) {.inline.} = - ## Constant-time Montgomery reduction - ## Extra indirection as the generator assumes that - ## arrays are pointers, which is true for parameters - ## but not for stack variables - montyRedc2x_gen(r, a, M, m0ninv, hasSpareBit) - -func montSquare_CIOS_asm*[N]( +func squareMont_CIOS_asm*[N]( r: var Limbs[N], a, M: Limbs[N], m0ninv: BaseType, - hasSpareBit: static bool) = + hasSpareBit, skipReduction: static bool) = ## Constant-time modular squaring var r2x {.noInit.}: Limbs[2*N] r2x.square_asm_inline(a) - r.montRed_asm_inline(r2x, M, m0ninv, hasSpareBit) + r.redcMont_asm_inline(r2x, M, m0ninv, hasSpareBit, skipReduction) diff --git a/constantine/arithmetic/assembly/limbs_asm_montmul_x86_adx_bmi2.nim b/constantine/arithmetic/assembly/limbs_asm_mul_mont_x86_adx_bmi2.nim similarity index 78% rename from constantine/arithmetic/assembly/limbs_asm_montmul_x86_adx_bmi2.nim rename to constantine/arithmetic/assembly/limbs_asm_mul_mont_x86_adx_bmi2.nim index 6229cd733..d46c06415 100644 --- a/constantine/arithmetic/assembly/limbs_asm_montmul_x86_adx_bmi2.nim +++ b/constantine/arithmetic/assembly/limbs_asm_mul_mont_x86_adx_bmi2.nim @@ -8,12 +8,12 @@ import # Standard library - std/macros, + std/[macros, algorithm], # Internal ../../config/common, ../../primitives, - ./limbs_asm_montred_x86, - ./limbs_asm_montred_x86_adx_bmi2, + ./limbs_asm_modular_x86, + ./limbs_asm_redc_mont_x86_adx_bmi2, ./limbs_asm_mul_x86_adx_bmi2 # ############################################################ @@ -36,6 +36,7 @@ static: doAssert UseASM_X86_64 # Montgomery Multiplication # ------------------------------------------------------------ + proc mulx_by_word( ctx: var Assembler_x86, hi: Operand, @@ -149,7 +150,7 @@ proc partialRedx( ctx.mov rdx, t[0] ctx.imul rdx, m0ninv - # Clear carry flags - TODO: necessary? + # Clear carry flags ctx.`xor` S, S # S,_ := t[0] + m*M[0] @@ -158,6 +159,8 @@ proc partialRedx( ctx.adcx lo, t[0] # set the carry flag for the future ADCX ctx.mov t[0], S + ctx.mov lo, 0 + # for j=1 to N-1 # (S,t[j-1]) := t[j] + m*M[j] + S ctx.comment " for j=1 to N-1" @@ -170,26 +173,31 @@ proc partialRedx( # Last carries # t[N-1] = S + C ctx.comment " Reduction carry " - ctx.mov S, 0 - ctx.adcx t[N-1], S - ctx.adox t[N-1], C + ctx.adcx lo, C # lo contains 0 so C += S + ctx.adox t[N-1], lo -macro montMul_CIOS_sparebit_adx_bmi2_gen[N: static int](r_MM: var Limbs[N], a_MM, b_MM, M_MM: Limbs[N], m0ninv_MM: BaseType): untyped = +macro mulMont_CIOS_sparebit_adx_gen[N: static int]( + r_PIR: var Limbs[N], a_PIR, b_PIR, + M_PIR: Limbs[N], m0ninv_REG: BaseType, + skipReduction: static bool): untyped = ## Generate an optimized Montgomery Multiplication kernel ## using the CIOS method ## This requires the most significant word of the Modulus ## M[^1] < high(SecretWord) shr 2 (i.e. less than 0b00111...1111) ## https://hackmd.io/@zkteam/modular_multiplication + # No register spilling handling + doAssert N <= 6, "The Assembly-optimized montgomery multiplication requires at most 6 limbs." 
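As an aside on the `skipReduction` plumbing above: the step it elides is the final conditional subtraction performed by `finalSubNoCarryImpl`. Below is a minimal pure-Nim sketch of that step (not part of the patch); the name `finalSubSketch`, the fixed 4-limb width and the plain comparisons are illustrative stand-ins for the constant-time primitives the library actually uses.

proc finalSubSketch(a: var array[4, uint64], M: array[4, uint64]) =
  ## Map a value in [0, 2M) to [0, M): compute a - M limb by limb and
  ## keep the difference only when the subtraction did not borrow (a >= M).
  var t: array[4, uint64]
  var borrow = 0'u64
  for i in 0 ..< 4:
    let d1 = a[i] - M[i]
    let b1 = if a[i] < M[i]: 1'u64 else: 0'u64
    t[i] = d1 - borrow
    let b2 = if d1 < borrow: 1'u64 else: 0'u64
    borrow = b1 or b2                      # at most one of b1/b2 can be set
  let keep = 0'u64 - (1'u64 - borrow)      # all-ones when a >= M, zero otherwise
  for i in 0 ..< 4:
    a[i] = (t[i] and keep) or (a[i] and not keep)

The assembly version does the same thing with a SUB/SBB chain followed by CMOVNC, which is why skipping it saves a few dependent instructions per multiplication.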
+ result = newStmtList() var ctx = init(Assembler_x86, BaseType) let - scratchSlots = max(N, 6) + scratchSlots = 6 - r = init(OperandArray, nimSymbol = r_MM, N, PointerInReg, InputOutput_EnsureClobber) + r = init(OperandArray, nimSymbol = r_PIR, N, PointerInReg, InputOutput_EnsureClobber) # We could force M as immediate by specializing per moduli - M = init(OperandArray, nimSymbol = M_MM, N, PointerInReg, Input) + M = init(OperandArray, nimSymbol = M_PIR, N, PointerInReg, Input) # If N is too big, we need to spill registers. TODO. t = init(OperandArray, nimSymbol = ident"t", N, ElemsInReg, Output_EarlyClobber) # MultiPurpose Register slots @@ -220,12 +228,12 @@ macro montMul_CIOS_sparebit_adx_bmi2_gen[N: static int](r_MM: var Limbs[N], a_MM result.add quote do: static: doAssert: sizeof(SecretWord) == sizeof(ByteAddress) - var `tsym`: typeof(`r_MM`) # zero init + var `tsym`{.noInit.}: typeof(`r_PIR`) # zero init # Assumes 64-bit limbs on 64-bit arch (or you can't store an address) var `scratchSym` {.noInit.}: Limbs[`scratchSlots`] - `scratchSym`[0] = cast[SecretWord](`a_MM`[0].unsafeAddr) - `scratchSym`[1] = cast[SecretWord](`b_MM`[0].unsafeAddr) - `scratchSym`[4] = SecretWord `m0ninv_MM` + `scratchSym`[0] = cast[SecretWord](`a_PIR`[0].unsafeAddr) + `scratchSym`[1] = cast[SecretWord](`b_PIR`[0].unsafeAddr) + `scratchSym`[4] = SecretWord `m0ninv_REG` # Algorithm # ----------------------------------------- @@ -238,9 +246,6 @@ macro montMul_CIOS_sparebit_adx_bmi2_gen[N: static int](r_MM: var Limbs[N], a_MM # (C,t[j-1]) := t[j] + m*M[j] + C # t[N-1] = C + A - # No register spilling handling - doAssert N <= 6, "The Assembly-optimized montgomery multiplication requires at most 6 limbs." - for i in 0 ..< N: if i == 0: ctx.mulx_by_word( @@ -263,47 +268,35 @@ macro montMul_CIOS_sparebit_adx_bmi2_gen[N: static int](r_MM: var Limbs[N], a_MM lo, C ) - ctx.finalSubNoCarry( - r, t, M, - scratch - ) + if skipReduction: + for i in 0 ..< N: + ctx.mov r[i], t[i] + else: + ctx.finalSubNoCarryImpl( + r, t, M, + scratch + ) result.add ctx.generate -func montMul_CIOS_sparebit_asm_adx_bmi2*(r: var Limbs, a, b, M: Limbs, m0ninv: BaseType) = - ## Constant-time modular multiplication - ## Requires the prime modulus to have a spare bit in the representation. (Hence if using 64-bit words and 4 words, to be at most 255-bit) - montMul_CIOS_sparebit_adx_bmi2_gen(r, a, b, M, m0ninv) +func mulMont_CIOS_sparebit_asm_adx*(r: var Limbs, a, b, M: Limbs, m0ninv: BaseType, skipReduction: static bool = false) = + ## Constant-time Montgomery multiplication + ## If "skipReduction" is set + ## the result is in the range [0, 2M) + ## otherwise the result is in the range [0, M) + ## + ## This procedure can only be called if the modulus doesn't use the full bitwidth of its underlying representation + r.mulMont_CIOS_sparebit_adx_gen(a, b, M, m0ninv, skipReduction) # Montgomery Squaring # ------------------------------------------------------------ -func square_asm_adx_bmi2_inline[rLen, aLen: static int](r: var Limbs[rLen], a: Limbs[aLen]) {.inline.} = - ## Multi-precision Squaring - ## Extra indirection as the generator assumes that - ## arrays are pointers, which is true for parameters - ## but not for stack variables. 
- sqrx_gen(r, a) - -func montRed_asm_adx_bmi2_inline[N: static int]( - r: var array[N, SecretWord], - a: array[N*2, SecretWord], - M: array[N, SecretWord], - m0ninv: BaseType, - hasSpareBit: static bool - ) {.inline.} = - ## Constant-time Montgomery reduction - ## Extra indirection as the generator assumes that - ## arrays are pointers, which is true for parameters - ## but not for stack variables. - montyRedc2x_adx_gen(r, a, M, m0ninv, hasSpareBit) - -func montSquare_CIOS_asm_adx_bmi2*[N]( +func squareMont_CIOS_asm_adx*[N]( r: var Limbs[N], a, M: Limbs[N], m0ninv: BaseType, - hasSpareBit: static bool) = + hasSpareBit, skipReduction: static bool) = ## Constant-time modular squaring var r2x {.noInit.}: Limbs[2*N] - r2x.square_asm_adx_bmi2_inline(a) - r.montRed_asm_adx_bmi2_inline(r2x, M, m0ninv, hasSpareBit) + r2x.square_asm_adx_inline(a) + r.redcMont_asm_adx(r2x, M, m0ninv, hasSpareBit, skipReduction) diff --git a/constantine/arithmetic/assembly/limbs_asm_mul_x86_adx_bmi2.nim b/constantine/arithmetic/assembly/limbs_asm_mul_x86_adx_bmi2.nim index 844a38c33..607768488 100644 --- a/constantine/arithmetic/assembly/limbs_asm_mul_x86_adx_bmi2.nim +++ b/constantine/arithmetic/assembly/limbs_asm_mul_x86_adx_bmi2.nim @@ -111,7 +111,7 @@ proc mulaccx_by_word( ctx.adcx hi, rdx ctx.adox hi, rdx -macro mulx_gen[rLen, aLen, bLen: static int](rx: var Limbs[rLen], ax: Limbs[aLen], bx: Limbs[bLen]) = +macro mulx_gen[rLen, aLen, bLen: static int](r_PIR: var Limbs[rLen], a_PIR: Limbs[aLen], b_PIR: Limbs[bLen]) = ## `a`, `b`, `r` can have a different number of limbs ## if `r`.limbs.len < a.limbs.len + b.limbs.len ## The result will be truncated, i.e. it will be @@ -123,9 +123,9 @@ macro mulx_gen[rLen, aLen, bLen: static int](rx: var Limbs[rLen], ax: Limbs[aLen var ctx = init(Assembler_x86, BaseType) let - r = init(OperandArray, nimSymbol = rx, rLen, PointerInReg, InputOutput_EnsureClobber) - a = init(OperandArray, nimSymbol = ax, aLen, PointerInReg, Input) - b = init(OperandArray, nimSymbol = bx, bLen, PointerInReg, Input) + r = init(OperandArray, nimSymbol = r_PIR, rLen, PointerInReg, InputOutput_EnsureClobber) + a = init(OperandArray, nimSymbol = a_PIR, aLen, PointerInReg, Input) + b = init(OperandArray, nimSymbol = b_PIR, bLen, PointerInReg, Input) # MULX requires RDX @@ -168,24 +168,24 @@ macro mulx_gen[rLen, aLen, bLen: static int](rx: var Limbs[rLen], ax: Limbs[aLen # Codegen result.add ctx.generate -func mul_asm_adx_bmi2_impl*[rLen, aLen, bLen: static int]( +func mul_asm_adx_inline*[rLen, aLen, bLen: static int]( r: var Limbs[rLen], a: Limbs[aLen], b: Limbs[bLen]) {.inline.} = ## Multi-precision Multiplication ## Assumes r doesn't alias a or b ## Inline version mulx_gen(r, a, b) -func mul_asm_adx_bmi2*[rLen, aLen, bLen: static int]( +func mul_asm_adx*[rLen, aLen, bLen: static int]( r: var Limbs[rLen], a: Limbs[aLen], b: Limbs[bLen]) = ## Multi-precision Multiplication ## Assumes r doesn't alias a or b - mul_asm_adx_bmi2_impl(r, a, b) + mul_asm_adx_inline(r, a, b) # Squaring # ----------------------------------------------------------------------------------------------- # # Strategy: -# We want to use the same scheduling as mul_asm_adx_bmi2 +# We want to use the same scheduling as mul_asm_adx # and so process `a[0.. 2, "The Assembly-optimized montgomery reduction requires a minimum of 2 limbs." + doAssert N <= 6, "The Assembly-optimized montgomery reduction requires at most 6 limbs." 
+ result = newStmtList() var ctx = init(Assembler_x86, BaseType) @@ -103,7 +49,7 @@ macro montyRedc2x_gen*[N: static int]( # so we store everything in scratchspaces restoring as needed let # We could force M as immediate by specializing per moduli - M = init(OperandArray, nimSymbol = M_MR, N, PointerInReg, Input) + M = init(OperandArray, nimSymbol = M_PIR, N, PointerInReg, Input) # MUL requires RAX and RDX let uSlots = N+2 @@ -119,9 +65,9 @@ macro montyRedc2x_gen*[N: static int]( result.add quote do: var `usym`{.noinit.}: Limbs[`uSlots`] var `vsym` {.noInit.}: Limbs[`vSlots`] - `vsym`[0] = cast[SecretWord](`r_MR`[0].unsafeAddr) - `vsym`[1] = cast[SecretWord](`a_MR`[0].unsafeAddr) - `vsym`[2] = SecretWord(`m0ninv_MR`) + `vsym`[0] = cast[SecretWord](`r_PIR`[0].unsafeAddr) + `vsym`[1] = cast[SecretWord](`a_PIR`[0].unsafeAddr) + `vsym`[2] = SecretWord(`m0ninv_REG`) let r_temp = v[0].asArrayAddr(len = N) let a = v[1].asArrayAddr(len = 2*N) @@ -131,7 +77,7 @@ macro montyRedc2x_gen*[N: static int]( # --------------------------------------------------------- # for i in 0 .. n-1: # hi <- 0 - # m <- a[i] * m0ninv mod 2^w (i.e. simple multiplication) + # m <- a[i] * m0ninv mod 2ʷ (i.e. simple multiplication) # for j in 0 .. n-1: # (hi, lo) <- a[i+j] + m * M[j] + hi # a[i+j] <- lo @@ -141,15 +87,11 @@ macro montyRedc2x_gen*[N: static int]( # if r >= M: # r -= M - # No register spilling handling - doAssert N > 2, "The Assembly-optimized montgomery reduction requires a minimum of 2 limbs." - doAssert N <= 6, "The Assembly-optimized montgomery reduction requires at most 6 limbs." - for i in 0 ..< N: ctx.mov u[i], a[i] ctx.mov u[N], u[0] - ctx.imul u[0], m0ninv # m <- a[i] * m0ninv mod 2^w + ctx.imul u[0], m0ninv # m <- a[i] * m0ninv mod 2ʷ ctx.mov rax, u[0] # scratch: [a[0] * m0, a[1], a[2], a[3], a[0]] for 4 limbs @@ -208,27 +150,138 @@ macro montyRedc2x_gen*[N: static int]( else: ctx.adc u[i], a[i+N] + # v is invalidated from now on let t = repackRegisters(v, u[N], u[N+1]) - - # v is invalidated - if hasSpareBit: - ctx.finalSubNoCarry(r, u, M, t) + + if hasSpareBit and skipReduction: + for i in 0 ..< N: + ctx.mov r[i], t[i] + elif hasSpareBit: + ctx.finalSubNoCarryImpl(r, u, M, t) else: - ctx.finalSubCanOverflow(r, u, M, t, rax) + ctx.finalSubMayCarryImpl(r, u, M, t, rax) # Code generation result.add ctx.generate() -func montRed_asm*[N: static int]( +func redcMont_asm_inline*[N: static int]( r: var array[N, SecretWord], a: array[N*2, SecretWord], M: array[N, SecretWord], m0ninv: BaseType, - hasSpareBit: static bool + hasSpareBit: static bool, + skipReduction: static bool = false + ) {.inline.} = + ## Constant-time Montgomery reduction + ## Inline-version + redc2xMont_gen(r, a, M, m0ninv, hasSpareBit, skipReduction) + +func redcMont_asm*[N: static int]( + r: var array[N, SecretWord], + a: array[N*2, SecretWord], + M: array[N, SecretWord], + m0ninv: BaseType, + hasSpareBit, skipReduction: static bool ) = ## Constant-time Montgomery reduction static: doAssert UseASM_X86_64, "This requires x86-64." - montyRedc2x_gen(r, a, M, m0ninv, hasSpareBit) + redcMont_asm_inline(r, a, M, m0ninv, hasSpareBit, skipReduction) + +# Montgomery conversion +# ---------------------------------------------------------- + +macro mulMont_by_1_gen[N: static int]( + t_EIR: var array[N, SecretWord], + M_PIR: array[N, SecretWord], + m0ninv_REG: BaseType) = + + # No register spilling handling + doAssert N <= 6, "The Assembly-optimized montgomery reduction requires at most 6 limbs." 
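For reference alongside the word-by-word reduction loop in the algorithm comments above, here is a hypothetical plain-Nim `redc2xSketch` (not part of the patch) using 32-bit limbs so 64-bit intermediates suffice; it assumes an odd modulus with a spare top bit, as `hasSpareBit` does, and omits the final subtraction, i.e. it behaves like the `skipReduction` path and returns a value in [0, 2M).

proc redc2xSketch[N: static int](r: var array[N, uint32],
                                 a: array[2*N, uint32],
                                 M: array[N, uint32],
                                 m0ninv: uint32) =
  ## Word-by-word Montgomery reduction of a double-width value:
  ## r <- a * R⁻¹ (mod M) with R = 2^(32*N), result in [0, 2M).
  var a = a                                 # local copy, mutated in place
  for i in 0 ..< N:
    var hi = 0'u64
    let m = a[i] * m0ninv                   # m <- a[i] * m0ninv mod 2³²
    for j in 0 ..< N:
      let t = uint64(a[i+j]) + uint64(m) * uint64(M[j]) + hi
      a[i+j] = uint32(t and 0xFFFFFFFF'u64) # a[i+j] <- lo
      hi = t shr 32
    var k = i + N                           # a[i+N] += hi, propagating the carry
    while hi != 0 and k < 2*N:
      let t = uint64(a[k]) + hi
      a[k] = uint32(t and 0xFFFFFFFF'u64)
      hi = t shr 32
      inc k
  for i in 0 ..< N:                         # r[i] = a[i+N]
    r[i] = a[i+N]

With N = 4 and 64-bit limbs instead, this is the loop the assembly kernels above unroll and schedule by hand.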
+ + result = newStmtList() + + var ctx = init(Assembler_x86, BaseType) + # On x86, compilers only let us use 15 out of 16 registers + # RAX and RDX are defacto used due to the MUL instructions + # so we store everything in scratchspaces restoring as needed + let + scratchSlots = 2 + + t = init(OperandArray, nimSymbol = t_EIR, N, ElemsInReg, InputOutput_EnsureClobber) + # We could force M as immediate by specializing per moduli + M = init(OperandArray, nimSymbol = M_PIR, N, PointerInReg, Input) + # MultiPurpose Register slots + scratch = init(OperandArray, nimSymbol = ident"scratch", scratchSlots, ElemsInReg, InputOutput_EnsureClobber) + + # MUL requires RAX and RDX + + m0ninv = Operand( + desc: OperandDesc( + asmId: "[m0ninv]", + nimSymbol: m0ninv_REG, + rm: MemOffsettable, + constraint: Input, + cEmit: "&" & $m0ninv_REG + ) + ) + + C = scratch[0] # Stores the high-part of muliplication + m = scratch[1] # Stores (t[0] * m0ninv) mod 2ʷ + + let scratchSym = scratch.nimSymbol + + # Copy a in t + result.add quote do: + var `scratchSym` {.noInit.}: Limbs[`scratchSlots`] + + # Algorithm + # --------------------------------------------------------- + # for i in 0 .. n-1: + # m <- t[0] * m0ninv mod 2ʷ (i.e. simple multiplication) + # C, _ = t[0] + m * M[0] + # for j in 1 .. n-1: + # (C, t[j-1]) <- r[j] + m*M[j] + C + # t[n-1] = C + + ctx.comment "for i in 0 ..< N:" + for i in 0 ..< N: + ctx.comment " m <- t[0] * m0ninv mod 2ʷ" + ctx.mov m, m0ninv + ctx.imul m, t[0] + + ctx.comment " C, _ = t[0] + m * M[0]" + ctx.`xor` C, C + ctx.mov rax, M[0] + ctx.mul rdx, rax, m, rax + ctx.add rax, t[0] + ctx.adc C, rdx + + ctx.comment " for j in 1 .. n-1:" + for j in 1 ..< N: + ctx.comment " (C, t[j-1]) <- r[j] + m*M[j] + C" + ctx.mov rax, M[j] + ctx.mul rdx, rax, m, rax + ctx.add C, t[j] + ctx.adc rdx, 0 + ctx.add C, rax + ctx.adc rdx, 0 + ctx.mov t[j-1], C + ctx.mov C, rdx + + ctx.comment " final carry" + ctx.mov t[N-1], C + + result.add ctx.generate() + +func fromMont_asm*(r: var Limbs, a, M: Limbs, m0ninv: BaseType) = + ## Constant-time Montgomery residue form to BigInt conversion + var t{.noInit.} = a + block: + t.mulMont_by_1_gen(M, m0ninv) + + block: # Map from [0, 2p) to [0, p) + var workspace{.noInit.}: typeof(r) + r.finalSub_gen(t, M, workspace, mayCarry = false) # Sanity checks # ---------------------------------------------------------- @@ -242,7 +295,7 @@ when isMainModule: # TODO: Properly handle low number of limbs - func montyRedc2x_Comba[N: static int]( + func redc2xMont_Comba[N: static int]( r: var array[N, SecretWord], a: array[N*2, SecretWord], M: array[N, SecretWord], @@ -298,10 +351,10 @@ when isMainModule: var a_sqr{.noInit.}, na_sqr{.noInit.}: Limbs[2] var a_sqr_comba{.noInit.}, na_sqr_comba{.noInit.}: Limbs[2] - a_sqr.montRed_asm(adbl_sqr, M, 1, hasSpareBit = false) - na_sqr.montRed_asm(nadbl_sqr, M, 1, hasSpareBit = false) - a_sqr_comba.montyRedc2x_Comba(adbl_sqr, M, 1) - na_sqr_comba.montyRedc2x_Comba(nadbl_sqr, M, 1) + a_sqr.redcMont_asm(adbl_sqr, M, 1, hasSpareBit = false, skipReduction = false) + na_sqr.redcMont_asm(nadbl_sqr, M, 1, hasSpareBit = false, skipReduction = false) + a_sqr_comba.redc2xMont_Comba(adbl_sqr, M, 1) + na_sqr_comba.redc2xMont_Comba(nadbl_sqr, M, 1) debugecho "--------------------------------" debugecho "after:" diff --git a/constantine/arithmetic/assembly/limbs_asm_redc_mont_x86_adx_bmi2.nim b/constantine/arithmetic/assembly/limbs_asm_redc_mont_x86_adx_bmi2.nim new file mode 100644 index 000000000..c8367102f --- /dev/null +++ 
b/constantine/arithmetic/assembly/limbs_asm_redc_mont_x86_adx_bmi2.nim @@ -0,0 +1,269 @@ +# Constantine +# Copyright (c) 2018-2019 Status Research & Development GmbH +# Copyright (c) 2020-Present Mamy André-Ratsimbazafy +# Licensed and distributed under either of +# * MIT license (license terms in the root directory or at http://opensource.org/licenses/MIT). +# * Apache v2 license (license terms in the root directory or at http://www.apache.org/licenses/LICENSE-2.0). +# at your option. This file may not be copied, modified, or distributed except according to those terms. + +import + # Standard library + std/macros, + # Internal + ../../config/common, + ../../primitives, + ./limbs_asm_modular_x86 + +# ############################################################ +# +# Assembly implementation of finite fields +# +# ############################################################ + +static: doAssert UseASM_X86_64 + +# MULX/ADCX/ADOX +{.localPassC:"-madx -mbmi2".} +# Necessary for the compiler to find enough registers (enabled at -O1) +{.localPassC:"-fomit-frame-pointer".} + +# No exceptions allowed +{.push raises: [].} + +# Montgomery reduction +# ------------------------------------------------------------ + +macro redc2xMont_adx_gen[N: static int]( + r_PIR: var array[N, SecretWord], + a_PIR: array[N*2, SecretWord], + M_PIR: array[N, SecretWord], + m0ninv_REG: BaseType, + hasSpareBit, skipReduction: static bool + ) = + + # No register spilling handling + doAssert N <= 6, "The Assembly-optimized montgomery multiplication requires at most 6 limbs." + + result = newStmtList() + + var ctx = init(Assembler_x86, BaseType) + let + # We could force M as immediate by specializing per moduli + M = init(OperandArray, nimSymbol = M_PIR, N, PointerInReg, Input) + + let uSlots = N+1 + let vSlots = max(N-1, 5) + + var # Scratchspaces + u = init(OperandArray, nimSymbol = ident"U", uSlots, ElemsInReg, InputOutput_EnsureClobber) + v = init(OperandArray, nimSymbol = ident"V", vSlots, ElemsInReg, InputOutput_EnsureClobber) + + # Prologue + let usym = u.nimSymbol + let vsym = v.nimSymbol + result.add quote do: + static: doAssert: sizeof(SecretWord) == sizeof(ByteAddress) + var `usym`{.noinit.}: Limbs[`uSlots`] + var `vsym` {.noInit.}: Limbs[`vSlots`] + `vsym`[0] = cast[SecretWord](`r_PIR`[0].unsafeAddr) + `vsym`[1] = cast[SecretWord](`a_PIR`[0].unsafeAddr) + `vsym`[2] = SecretWord(`m0ninv_REG`) + + let r_temp = v[0].asArrayAddr(len = N) + let a = v[1].asArrayAddr(len = 2*N) + let m0ninv = v[2] + let lo = v[3] + let hi = v[4] + + # Algorithm + # --------------------------------------------------------- + # for i in 0 .. n-1: + # hi <- 0 + # m <- a[i] * m0ninv mod 2ʷ (i.e. simple multiplication) + # for j in 0 .. n-1: + # (hi, lo) <- a[i+j] + m * M[j] + hi + # a[i+j] <- lo + # a[i+n] += hi + # for i in 0 .. 
n-1: + # r[i] = a[i+n] + # if r >= M: + # r -= M + + ctx.mov rdx, m0ninv + + for i in 0 ..< N: + ctx.mov u[i], a[i] + + for i in 0 ..< N: + # RDX contains m0ninv at the start of each loop + ctx.comment "" + ctx.imul rdx, u[0] # m <- a[i] * m0ninv mod 2ʷ + ctx.comment "---- Reduction " & $i + ctx.`xor` u[N], u[N] + + for j in 0 ..< N-1: + ctx.comment "" + ctx.mulx hi, lo, M[j], rdx + ctx.adcx u[j], lo + ctx.adox u[j+1], hi + + # Last limb + ctx.comment "" + ctx.mulx hi, lo, M[N-1], rdx + ctx.mov rdx, m0ninv # Reload m0ninv for next iter + ctx.adcx u[N-1], lo + ctx.adox hi, u[N] + ctx.adcx u[N], hi + + u.rotateLeft() + + ctx.mov rdx, r_temp + let r = rdx.asArrayAddr(len = N) + + # This does a[i+n] += hi + # but in a separate carry chain, fused with the + # copy "r[i] = a[i+n]" + for i in 0 ..< N: + if i == 0: + ctx.add u[i], a[i+N] + else: + ctx.adc u[i], a[i+N] + + let t = repackRegisters(v, u[N]) + + if hasSpareBit and skipReduction: + for i in 0 ..< N: + ctx.mov r[i], t[i] + elif hasSpareBit: + ctx.finalSubNoCarryImpl(r, u, M, t) + else: + ctx.finalSubMayCarryImpl(r, u, M, t, hi) + + # Code generation + result.add ctx.generate() + +func redcMont_asm_adx_inline*[N: static int]( + r: var array[N, SecretWord], + a: array[N*2, SecretWord], + M: array[N, SecretWord], + m0ninv: BaseType, + hasSpareBit: static bool, + skipReduction: static bool = false + ) {.inline.} = + ## Constant-time Montgomery reduction + ## Inline-version + redc2xMont_adx_gen(r, a, M, m0ninv, hasSpareBit, skipReduction) + +func redcMont_asm_adx*[N: static int]( + r: var array[N, SecretWord], + a: array[N*2, SecretWord], + M: array[N, SecretWord], + m0ninv: BaseType, + hasSpareBit: static bool, + skipReduction: static bool = false + ) = + ## Constant-time Montgomery reduction + redcMont_asm_adx_inline(r, a, M, m0ninv, hasSpareBit, skipReduction) + + +# Montgomery conversion +# ---------------------------------------------------------- + +macro mulMont_by_1_adx_gen[N: static int]( + t_EIR: var array[N, SecretWord], + M_PIR: array[N, SecretWord], + m0ninv_REG: BaseType) = + + # No register spilling handling + doAssert N <= 6, "The Assembly-optimized montgomery reduction requires at most 6 limbs." + + result = newStmtList() + + var ctx = init(Assembler_x86, BaseType) + # On x86, compilers only let us use 15 out of 16 registers + # RAX and RDX are defacto used due to the MUL instructions + # so we store everything in scratchspaces restoring as needed + let + scratchSlots = 1 + + t = init(OperandArray, nimSymbol = t_EIR, N, ElemsInReg, InputOutput_EnsureClobber) + # We could force M as immediate by specializing per moduli + M = init(OperandArray, nimSymbol = M_PIR, N, PointerInReg, Input) + # MultiPurpose Register slots + scratch = init(OperandArray, nimSymbol = ident"scratch", scratchSlots, ElemsInReg, InputOutput_EnsureClobber) + + # MUL requires RAX and RDX + + m0ninv = Operand( + desc: OperandDesc( + asmId: "[m0ninv]", + nimSymbol: m0ninv_REG, + rm: MemOffsettable, + constraint: Input, + cEmit: "&" & $m0ninv_REG + ) + ) + + C = scratch[0] # Stores the high-part of muliplication + + let scratchSym = scratch.nimSymbol + + # Copy a in t + result.add quote do: + var `scratchSym` {.noInit.}: Limbs[`scratchSlots`] + + # Algorithm + # --------------------------------------------------------- + # for i in 0 .. n-1: + # m <- t[0] * m0ninv mod 2ʷ (i.e. simple multiplication) + # C, _ = t[0] + m * M[0] + # for j in 1 .. 
n-1: + # (C, t[j-1]) <- r[j] + m*M[j] + C + # t[n-1] = C + + # Low-level optimizations + # https://www.intel.com/content/dam/www/public/us/en/documents/manuals/64-ia-32-architectures-optimization-manual.pdf + # Section 3.5.1.8 xor'ing a reg with itself is free (except for instruction code size) + + ctx.comment "for i in 0 ..< N:" + for i in 0 ..< N: + ctx.comment " m <- t[0] * m0ninv mod 2ʷ" + ctx.mov rdx, m0ninv + ctx.imul rdx, t[0] + + # Create 2 parallel carry-chains for adcx and adox + # We need to clear the carry/overflow flags first for ADCX/ADOX + # with the smallest instruction if possible (xor rax, rax) + # to reduce instruction-cache miss + ctx.comment " C, _ = t[0] + m * M[0]" + ctx.`xor` rax, rax + ctx.mulx C, rax, M[0], rdx + ctx.adcx rax, t[0] # Set C', the carry flag for future adcx, but don't accumulate in C yet + ctx.mov t[0], C + + # for j=1 to N-1 + # (S,t[j-1]) := t[j] + m*M[j] + S + ctx.comment " for j=1 to N-1" + ctx.comment " (C,t[j-1]) := t[j] + m*M[j] + C" + for j in 1 ..< N: + ctx.adcx t[j-1], t[j] + ctx.mulx t[j], C, M[j], rdx + ctx.adox t[j-1], C + + ctx.comment " final carries" + ctx.mov rax, 0 + ctx.adcx t[N-1], rax + ctx.adox t[N-1], rax + + result.add ctx.generate() + +func fromMont_asm_adx*(r: var Limbs, a, M: Limbs, m0ninv: BaseType) = + ## Constant-time Montgomery residue form to BigInt conversion + ## Requires ADX and BMI2 instruction set + var t{.noInit.} = a + block: + t.mulMont_by_1_adx_gen(M, m0ninv) + + block: # Map from [0, 2p) to [0, p) + var workspace{.noInit.}: typeof(r) + r.finalSub_gen(t, M, workspace, mayCarry = false) diff --git a/constantine/arithmetic/bigints.nim b/constantine/arithmetic/bigints.nim index 755eaae55..183e15de8 100644 --- a/constantine/arithmetic/bigints.nim +++ b/constantine/arithmetic/bigints.nim @@ -298,7 +298,7 @@ func prod_high_words*[rBits, aBits, bBits](r: var BigInt[rBits], a: BigInt[aBits # - Barret reduction # - Approximating multiplication by a fractional constant in the form f(a) = K/C * a # with K and C known at compile-time. - # We can instead find a well chosen M = (2^WordBitWidth)^w, with M > C (i.e. M is a power of 2 bigger than C) + # We can instead find a well chosen M = (2^WordBitWidth)ʷ, with M > C (i.e. M is a power of 2 bigger than C) # Precompute P = K*M/C at compile-time # and at runtime do P*a/M <=> P*a >> WordBitWidth*w # i.e. prod_high_words(result, P, a, w) diff --git a/constantine/arithmetic/bigints_montgomery.nim b/constantine/arithmetic/bigints_montgomery.nim index 992f93e6d..550ae8fe2 100644 --- a/constantine/arithmetic/bigints_montgomery.nim +++ b/constantine/arithmetic/bigints_montgomery.nim @@ -24,9 +24,9 @@ import # # ############################################################ -func montyResidue*(mres: var BigInt, a, N, r2modM: BigInt, m0ninv: static BaseType, spareBits: static int) = +func getMont*(mres: var BigInt, a, N, r2modM: BigInt, m0ninv: static BaseType, spareBits: static int) = ## Convert a BigInt from its natural representation - ## to the Montgomery n-residue form + ## to the Montgomery residue form ## ## `mres` is overwritten. It's bitlength must be properly set before calling this procedure. 
## @@ -39,37 +39,35 @@ func montyResidue*(mres: var BigInt, a, N, r2modM: BigInt, m0ninv: static BaseTy ## - `r2modM` is R² (mod M) ## with W = M.len ## and R = (2^WordBitWidth)^W - montyResidue(mres.limbs, a.limbs, N.limbs, r2modM.limbs, m0ninv, spareBits) + getMont(mres.limbs, a.limbs, N.limbs, r2modM.limbs, m0ninv, spareBits) -func redc*[mBits](r: var BigInt[mBits], a, M: BigInt[mBits], m0ninv: static BaseType, spareBits: static int) = - ## Convert a BigInt from its Montgomery n-residue form +func fromMont*[mBits](r: var BigInt[mBits], a, M: BigInt[mBits], m0ninv: static BaseType, spareBits: static int) = + ## Convert a BigInt from its Montgomery residue form ## to the natural representation ## ## `mres` is modified in-place ## ## Caller must take care of properly switching between ## the natural and montgomery domain. - let one = block: - var one {.noInit.}: BigInt[mBits] - one.setOne() - one - redc(r.limbs, a.limbs, one.limbs, M.limbs, m0ninv, spareBits) + fromMont(r.limbs, a.limbs, M.limbs, m0ninv, spareBits) -func montyMul*(r: var BigInt, a, b, M: BigInt, negInvModWord: static BaseType, spareBits: static int) = +func mulMont*(r: var BigInt, a, b, M: BigInt, negInvModWord: static BaseType, + spareBits: static int, skipReduction: static bool = false) = ## Compute r <- a*b (mod M) in the Montgomery domain ## ## This resets r to zero before processing. Use {.noInit.} ## to avoid duplicating with Nim zero-init policy - montyMul(r.limbs, a.limbs, b.limbs, M.limbs, negInvModWord, spareBits) + mulMont(r.limbs, a.limbs, b.limbs, M.limbs, negInvModWord, spareBits, skipReduction) -func montySquare*(r: var BigInt, a, M: BigInt, negInvModWord: static BaseType, spareBits: static int) = +func squareMont*(r: var BigInt, a, M: BigInt, negInvModWord: static BaseType, + spareBits: static int, skipReduction: static bool = false) = ## Compute r <- a^2 (mod M) in the Montgomery domain ## ## This resets r to zero before processing. 
Use {.noInit.} ## to avoid duplicating with Nim zero-init policy - montySquare(r.limbs, a.limbs, M.limbs, negInvModWord, spareBits) + squareMont(r.limbs, a.limbs, M.limbs, negInvModWord, spareBits, skipReduction) -func montyPow*[mBits: static int]( +func powMont*[mBits: static int]( a: var BigInt[mBits], exponent: openarray[byte], M, one: BigInt[mBits], negInvModWord: static BaseType, windowSize: static int, spareBits: static int @@ -91,9 +89,9 @@ func montyPow*[mBits: static int]( const scratchLen = if windowSize == 1: 2 else: (1 shl windowSize) + 1 var scratchSpace {.noInit.}: array[scratchLen, Limbs[mBits.wordsRequired]] - montyPow(a.limbs, exponent, M.limbs, one.limbs, negInvModWord, scratchSpace, spareBits) + powMont(a.limbs, exponent, M.limbs, one.limbs, negInvModWord, scratchSpace, spareBits) -func montyPowUnsafeExponent*[mBits: static int]( +func powMontUnsafeExponent*[mBits: static int]( a: var BigInt[mBits], exponent: openarray[byte], M, one: BigInt[mBits], negInvModWord: static BaseType, windowSize: static int, spareBits: static int @@ -115,12 +113,12 @@ func montyPowUnsafeExponent*[mBits: static int]( const scratchLen = if windowSize == 1: 2 else: (1 shl windowSize) + 1 var scratchSpace {.noInit.}: array[scratchLen, Limbs[mBits.wordsRequired]] - montyPowUnsafeExponent(a.limbs, exponent, M.limbs, one.limbs, negInvModWord, scratchSpace, spareBits) + powMontUnsafeExponent(a.limbs, exponent, M.limbs, one.limbs, negInvModWord, scratchSpace, spareBits) from ../io/io_bigints import exportRawUint # Workaround recursive dependencies -func montyPow*[mBits, eBits: static int]( +func powMont*[mBits, eBits: static int]( a: var BigInt[mBits], exponent: BigInt[eBits], M, one: BigInt[mBits], negInvModWord: static BaseType, windowSize: static int, spareBits: static int @@ -137,9 +135,9 @@ func montyPow*[mBits, eBits: static int]( var expBE {.noInit.}: array[(ebits + 7) div 8, byte] expBE.exportRawUint(exponent, bigEndian) - montyPow(a, expBE, M, one, negInvModWord, windowSize, spareBits) + powMont(a, expBE, M, one, negInvModWord, windowSize, spareBits) -func montyPowUnsafeExponent*[mBits, eBits: static int]( +func powMontUnsafeExponent*[mBits, eBits: static int]( a: var BigInt[mBits], exponent: BigInt[eBits], M, one: BigInt[mBits], negInvModWord: static BaseType, windowSize: static int, spareBits: static int @@ -160,7 +158,7 @@ func montyPowUnsafeExponent*[mBits, eBits: static int]( var expBE {.noInit.}: array[(ebits + 7) div 8, byte] expBE.exportRawUint(exponent, bigEndian) - montyPowUnsafeExponent(a, expBE, M, one, negInvModWord, windowSize, spareBits) + powMontUnsafeExponent(a, expBE, M, one, negInvModWord, windowSize, spareBits) {.pop.} # inline {.pop.} # raises no exceptions diff --git a/constantine/arithmetic/finite_fields.nim b/constantine/arithmetic/finite_fields.nim index f15a8d427..3ffc5c9fc 100644 --- a/constantine/arithmetic/finite_fields.nim +++ b/constantine/arithmetic/finite_fields.nim @@ -56,16 +56,20 @@ func fromBig*(dst: var FF, src: BigInt) = when nimvm: dst.mres.montyResidue_precompute(src, FF.fieldMod(), FF.getR2modP(), FF.getNegInvModWord()) else: - dst.mres.montyResidue(src, FF.fieldMod(), FF.getR2modP(), FF.getNegInvModWord(), FF.getSpareBits()) + dst.mres.getMont(src, FF.fieldMod(), FF.getR2modP(), FF.getNegInvModWord(), FF.getSpareBits()) func fromBig*[C: static Curve](T: type FF[C], src: BigInt): FF[C] {.noInit.} = ## Convert a BigInt to its Montgomery form result.fromBig(src) +func fromField*(dst: var BigInt, src: FF) {.noInit, inline.} = + ## Convert a finite-field 
element to a BigInt in natural representation + dst.fromMont(src.mres, FF.fieldMod(), FF.getNegInvModWord(), FF.getSpareBits()) + func toBig*(src: FF): auto {.noInit, inline.} = ## Convert a finite-field element to a BigInt in natural representation var r {.noInit.}: typeof(src.mres) - r.redc(src.mres, FF.fieldMod(), FF.getNegInvModWord(), FF.getSpareBits()) + r.fromField(src) return r # Copy @@ -102,7 +106,7 @@ func cswap*(a, b: var FF, ctl: CTBool) {.meter.} = # Note: for `+=`, double, sum # not(a.mres < FF.fieldMod()) is unnecessary if the prime has the form -# (2^64)^w - 1 (if using uint64 words). +# (2^64)ʷ - 1 (if using uint64 words). # In practice I'm not aware of such prime being using in elliptic curves. # 2^127 - 1 and 2^521 - 1 are used but 127 and 521 are not multiple of 32/64 @@ -148,7 +152,7 @@ func setMinusOne*(a: var FF) = func `+=`*(a: var FF, b: FF) {.meter.} = ## In-place addition modulo p when UseASM_X86_64 and a.mres.limbs.len <= 6: # TODO: handle spilling - addmod_asm(a.mres.limbs, a.mres.limbs, b.mres.limbs, FF.fieldMod().limbs) + addmod_asm(a.mres.limbs, a.mres.limbs, b.mres.limbs, FF.fieldMod().limbs, FF.getSpareBits() >= 1) else: var overflowed = add(a.mres, b.mres) overflowed = overflowed or not(a.mres < FF.fieldMod()) @@ -165,7 +169,7 @@ func `-=`*(a: var FF, b: FF) {.meter.} = func double*(a: var FF) {.meter.} = ## Double ``a`` modulo p when UseASM_X86_64 and a.mres.limbs.len <= 6: # TODO: handle spilling - addmod_asm(a.mres.limbs, a.mres.limbs, a.mres.limbs, FF.fieldMod().limbs) + addmod_asm(a.mres.limbs, a.mres.limbs, a.mres.limbs, FF.fieldMod().limbs, FF.getSpareBits() >= 1) else: var overflowed = double(a.mres) overflowed = overflowed or not(a.mres < FF.fieldMod()) @@ -175,7 +179,7 @@ func sum*(r: var FF, a, b: FF) {.meter.} = ## Sum ``a`` and ``b`` into ``r`` modulo p ## r is initialized/overwritten when UseASM_X86_64 and a.mres.limbs.len <= 6: # TODO: handle spilling - addmod_asm(r.mres.limbs, a.mres.limbs, b.mres.limbs, FF.fieldMod().limbs) + addmod_asm(r.mres.limbs, a.mres.limbs, b.mres.limbs, FF.fieldMod().limbs, FF.getSpareBits() >= 1) else: var overflowed = r.mres.sum(a.mres, b.mres) overflowed = overflowed or not(r.mres < FF.fieldMod()) @@ -204,20 +208,20 @@ func double*(r: var FF, a: FF) {.meter.} = ## Double ``a`` into ``r`` ## `r` is initialized/overwritten when UseASM_X86_64 and a.mres.limbs.len <= 6: # TODO: handle spilling - addmod_asm(r.mres.limbs, a.mres.limbs, a.mres.limbs, FF.fieldMod().limbs) + addmod_asm(r.mres.limbs, a.mres.limbs, a.mres.limbs, FF.fieldMod().limbs, FF.getSpareBits() >= 1) else: var overflowed = r.mres.double(a.mres) overflowed = overflowed or not(r.mres < FF.fieldMod()) discard csub(r.mres, FF.fieldMod(), overflowed) -func prod*(r: var FF, a, b: FF) {.meter.} = +func prod*(r: var FF, a, b: FF, skipReduction: static bool = false) {.meter.} = ## Store the product of ``a`` by ``b`` modulo p into ``r`` ## ``r`` is initialized / overwritten - r.mres.montyMul(a.mres, b.mres, FF.fieldMod(), FF.getNegInvModWord(), FF.getSpareBits()) + r.mres.mulMont(a.mres, b.mres, FF.fieldMod(), FF.getNegInvModWord(), FF.getSpareBits(), skipReduction) -func square*(r: var FF, a: FF) {.meter.} = +func square*(r: var FF, a: FF, skipReduction: static bool = false) {.meter.} = ## Squaring modulo p - r.mres.montySquare(a.mres, FF.fieldMod(), FF.getNegInvModWord(), FF.getSpareBits()) + r.mres.squareMont(a.mres, FF.fieldMod(), FF.getNegInvModWord(), FF.getSpareBits(), skipReduction) func neg*(r: var FF, a: FF) {.meter.} = ## Negate modulo p @@ 
-261,20 +265,20 @@ func cneg*(r: var FF, a: FF, ctl: SecretBool) {.meter.} = func cneg*(a: var FF, ctl: SecretBool) {.meter.} = ## Constant-time in-place conditional negation ## The negation is only performed if ctl is "true" - var t = a + var t {.noInit.} = a a.cneg(t, ctl) func cadd*(a: var FF, b: FF, ctl: SecretBool) {.meter.} = ## Constant-time out-place conditional addition ## The addition is only performed if ctl is "true" - var t = a + var t {.noInit.} = a t += b a.ccopy(t, ctl) func csub*(a: var FF, b: FF, ctl: SecretBool) {.meter.} = ## Constant-time out-place conditional substraction ## The substraction is only performed if ctl is "true" - var t = a + var t {.noInit.} = a t -= b a.ccopy(t, ctl) @@ -340,7 +344,7 @@ func pow*(a: var FF, exponent: BigInt) = ## ``a``: a field element to be exponentiated ## ``exponent``: a big integer const windowSize = 5 # TODO: find best window size for each curves - a.mres.montyPow( + a.mres.powMont( exponent, FF.fieldMod(), FF.getMontyOne(), FF.getNegInvModWord(), windowSize, @@ -352,7 +356,7 @@ func pow*(a: var FF, exponent: openarray[byte]) = ## ``a``: a field element to be exponentiated ## ``exponent``: a big integer in canonical big endian representation const windowSize = 5 # TODO: find best window size for each curves - a.mres.montyPow( + a.mres.powMont( exponent, FF.fieldMod(), FF.getMontyOne(), FF.getNegInvModWord(), windowSize, @@ -371,7 +375,7 @@ func powUnsafeExponent*(a: var FF, exponent: BigInt) = ## - power analysis ## - timing analysis const windowSize = 5 # TODO: find best window size for each curves - a.mres.montyPowUnsafeExponent( + a.mres.powMontUnsafeExponent( exponent, FF.fieldMod(), FF.getMontyOne(), FF.getNegInvModWord(), windowSize, @@ -390,7 +394,7 @@ func powUnsafeExponent*(a: var FF, exponent: openarray[byte]) = ## - power analysis ## - timing analysis const windowSize = 5 # TODO: find best window size for each curves - a.mres.montyPowUnsafeExponent( + a.mres.powMontUnsafeExponent( exponent, FF.fieldMod(), FF.getMontyOne(), FF.getNegInvModWord(), windowSize, @@ -409,21 +413,39 @@ func `*=`*(a: var FF, b: FF) {.meter.} = ## Multiplication modulo p a.prod(a, b) -func square*(a: var FF) {.meter.} = +func square*(a: var FF, skipReduction: static bool = false) {.meter.} = ## Squaring modulo p - a.mres.montySquare(a.mres, FF.fieldMod(), FF.getNegInvModWord(), FF.getSpareBits()) + a.square(a, skipReduction) -func square_repeated*(r: var FF, num: int) {.meter.} = +func square_repeated*(a: var FF, num: int, skipReduction: static bool = false) {.meter.} = ## Repeated squarings + # Except in Tonelli-Shanks, num is always known at compile-time + # and square repeated is inlined, so the compiler should optimize the branches away. 
+ + # TODO: understand the conditions to avoid the final substraction for _ in 0 ..< num: - r.square() + a.square(skipReduction = false) -func square_repeated*(r: var FF, a: FF, num: int) {.meter.} = +func square_repeated*(r: var FF, a: FF, num: int, skipReduction: static bool = false) {.meter.} = ## Repeated squarings + + # TODO: understand the conditions to avoid the final substraction r.square(a) for _ in 1 ..< num: r.square() +func square_repeated_then_mul*(a: var FF, num: int, b: FF, skipReduction: static bool = false) {.meter.} = + ## Square `a`, `num` times and then multiply by b + ## Assumes at least 1 squaring + a.square_repeated(num, skipReduction = false) + a.prod(a, b, skipReduction = skipReduction) + +func square_repeated_then_mul*(r: var FF, a: FF, num: int, b: FF, skipReduction: static bool = false) {.meter.} = + ## Square `a`, `num` times and then multiply by b + ## Assumes at least 1 squaring + r.square_repeated(a, num, skipReduction = false) + r.prod(r, b, skipReduction = skipReduction) + func `*=`*(a: var FF, b: static int) = ## Multiplication by a small integer known at compile-time # Implementation: diff --git a/constantine/arithmetic/finite_fields_double_precision.nim b/constantine/arithmetic/finite_fields_double_precision.nim index ab161c0c3..f5cb81a8e 100644 --- a/constantine/arithmetic/finite_fields_double_precision.nim +++ b/constantine/arithmetic/finite_fields_double_precision.nim @@ -73,7 +73,7 @@ func redc2x*(r: var Fp, a: FpDbl) = ## Reduce a double-precision field element into r ## from [0, 2ⁿp) range to [0, p) range const N = r.mres.limbs.len - montyRedc2x( + redc2xMont( r.mres.limbs, a.limbs2x, Fp.C.Mod.limbs, diff --git a/constantine/arithmetic/limbs_division.nim b/constantine/arithmetic/limbs_division.nim index c1f6b2f8a..4e0e6eca9 100644 --- a/constantine/arithmetic/limbs_division.nim +++ b/constantine/arithmetic/limbs_division.nim @@ -128,7 +128,7 @@ func numWordsFromBits(bits: int): int {.inline.} = func shlAddMod_estimate(a: LimbsViewMut, aLen: int, c: SecretWord, M: LimbsViewConst, mBits: int ): tuple[neg, tooBig: SecretBool] = - ## Estimate a <- a shl 2^w + c (mod M) + ## Estimate a <- a shl 2ʷ + c (mod M) ## ## with w the base word width, usually 32 on 32-bit platforms and 64 on 64-bit platforms ## diff --git a/constantine/arithmetic/limbs_extmul.nim b/constantine/arithmetic/limbs_extmul.nim index 47a8e2eb1..49ec322ac 100644 --- a/constantine/arithmetic/limbs_extmul.nim +++ b/constantine/arithmetic/limbs_extmul.nim @@ -74,7 +74,7 @@ func prod*[rLen, aLen, bLen: static int](r: var Limbs[rLen], a: Limbs[aLen], b: when UseASM_X86_64 and aLen <= 6: # ADX implies BMI2 if ({.noSideEffect.}: hasAdx()): - mul_asm_adx_bmi2(r, a, b) + mul_asm_adx(r, a, b) else: mul_asm(r, a, b) elif UseASM_X86_64: @@ -98,7 +98,7 @@ func prod_high_words*[rLen, aLen, bLen]( # - Barret reduction # - Approximating multiplication by a fractional constant in the form f(a) = K/C * a # with K and C known at compile-time. - # We can instead find a well chosen M = (2^WordBitWidth)^w, with M > C (i.e. M is a power of 2 bigger than C) + # We can instead find a well chosen M = (2^WordBitWidth)ʷ, with M > C (i.e. M is a power of 2 bigger than C) # Precompute P = K*M/C at compile-time # and at runtime do P*a/M <=> P*a >> (WordBitWidth*w) # i.e. 
prod_high_words(result, P, a, w) @@ -203,7 +203,7 @@ func square*[rLen, aLen]( when UseASM_X86_64 and aLen in {4, 6} and rLen == 2*aLen: # ADX implies BMI2 if ({.noSideEffect.}: hasAdx()): - square_asm_adx_bmi2(r, a) + square_asm_adx(r, a) else: square_asm(r, a) elif UseASM_X86_64: diff --git a/constantine/arithmetic/limbs_montgomery.nim b/constantine/arithmetic/limbs_montgomery.nim index de77075a0..92d39eef1 100644 --- a/constantine/arithmetic/limbs_montgomery.nim +++ b/constantine/arithmetic/limbs_montgomery.nim @@ -15,12 +15,12 @@ import ./limbs, ./limbs_extmul when UseASM_X86_32: - import ./assembly/limbs_asm_montred_x86 + import ./assembly/limbs_asm_redc_mont_x86 when UseASM_X86_64: import - ./assembly/limbs_asm_montmul_x86, - ./assembly/limbs_asm_montmul_x86_adx_bmi2, - ./assembly/limbs_asm_montred_x86_adx_bmi2 + ./assembly/limbs_asm_mul_mont_x86, + ./assembly/limbs_asm_mul_mont_x86_adx_bmi2, + ./assembly/limbs_asm_redc_mont_x86_adx_bmi2 # ############################################################ # @@ -53,12 +53,19 @@ when UseASM_X86_64: # Montgomery Reduction # ------------------------------------------------------------ -func montyRedc2x_CIOS[N: static int]( +func redc2xMont_CIOS[N: static int]( r: var array[N, SecretWord], a: array[N*2, SecretWord], M: array[N, SecretWord], - m0ninv: BaseType) = + m0ninv: BaseType, skipReduction: static bool = false) = ## Montgomery reduce a double-precision bigint modulo M + ## + ## This maps + ## - [0, 4p²) -> [0, 2p) with skipReduction + ## - [0, 4p²) -> [0, p) without + ## + ## SkipReduction skips the final substraction step. + ## For skipReduction, M needs to have a spare bit in it's representation i.e. unused MSB. # - Analyzing and Comparing Montgomery Multiplication Algorithms # Cetin Kaya Koc and Tolga Acar and Burton S. Kaliski Jr. # http://pdfs.semanticscholar.org/5e39/41ff482ec3ee41dc53c3298f0be085c69483.pdf @@ -81,7 +88,7 @@ func montyRedc2x_CIOS[N: static int]( # # for i in 0 .. n-1: # C <- 0 - # m <- a[i] * m0ninv mod 2^w (i.e. simple multiplication) + # m <- a[i] * m0ninv mod 2ʷ (i.e. simple multiplication) # for j in 0 .. n-1: # (C, S) <- a[i+j] + m * M[j] + C # a[i+j] <- S @@ -95,8 +102,8 @@ func montyRedc2x_CIOS[N: static int]( # to the higher limb if any, thank you "implementation detail" # missing from paper. - var a = a # Copy "t" for mutation and ensure on stack - var res: typeof(r) # Accumulator + var a {.noInit.} = a # Copy "t" for mutation and ensure on stack + var res {.noInit.}: typeof(r) # Accumulator staticFor i, 0, N: var C = Zero let m = a[i] * SecretWord(m0ninv) @@ -112,15 +119,23 @@ func montyRedc2x_CIOS[N: static int]( addC(carry, res[i], a[i+N], res[i], carry) # Final substraction - discard res.csub(M, SecretWord(carry).isNonZero() or not(res < M)) + when not skipReduction: + discard res.csub(M, SecretWord(carry).isNonZero() or not(res < M)) r = res -func montyRedc2x_Comba[N: static int]( +func redc2xMont_Comba[N: static int]( r: var array[N, SecretWord], a: array[N*2, SecretWord], M: array[N, SecretWord], - m0ninv: BaseType) = + m0ninv: BaseType, skipReduction: static bool = false) = ## Montgomery reduce a double-precision bigint modulo M + ## + ## This maps + ## - [0, 4p²) -> [0, 2p) with skipReduction + ## - [0, 4p²) -> [0, p) without + ## + ## SkipReduction skips the final substraction step. + ## For skipReduction, M needs to have a spare bit in it's representation i.e. unused MSB. 
# We use Product Scanning / Comba multiplication var t, u, v = Zero var carry: Carry @@ -156,18 +171,25 @@ func montyRedc2x_Comba[N: static int]( addC(carry, z[N-1], v, a[2*N-1], Carry(0)) # Final substraction - discard z.csub(M, SecretBool(carry) or not(z < M)) + when not skipReduction: + discard z.csub(M, SecretBool(carry) or not(z < M)) r = z # Montgomery Multiplication # ------------------------------------------------------------ -func montyMul_CIOS_sparebit(r: var Limbs, a, b, M: Limbs, m0ninv: BaseType) = +func mulMont_CIOS_sparebit(r: var Limbs, a, b, M: Limbs, m0ninv: BaseType, skipReduction: static bool = false) = ## Montgomery Multiplication using Coarse Grained Operand Scanning (CIOS) ## and no-carry optimization. ## This requires the most significant word of the Modulus ## M[^1] < high(SecretWord) shr 1 (i.e. less than 0b01111...1111) - ## https://hackmd.io/@zkteam/modular_multiplication + ## https://hackmd.io/@gnark/modular_multiplication + ## + ## This maps + ## - [0, 2p) -> [0, 2p) with skipReduction + ## - [0, 2p) -> [0, p) without + ## + ## SkipReduction skips the final substraction step. # We want all the computation to be kept in registers # hence we use a temporary `t`, hoping that the compiler does it. @@ -175,7 +197,7 @@ func montyMul_CIOS_sparebit(r: var Limbs, a, b, M: Limbs, m0ninv: BaseType) = const N = t.len staticFor i, 0, N: # (A, t[0]) <- a[0] * b[i] + t[0] - # m <- (t[0] * m0ninv) mod 2^w + # m <- (t[0] * m0ninv) mod 2ʷ # (C, _) <- m * M[0] + t[0] var A: SecretWord muladd1(A, t[0], a[0], b[i], t[0]) @@ -191,10 +213,11 @@ func montyMul_CIOS_sparebit(r: var Limbs, a, b, M: Limbs, m0ninv: BaseType) = t[N-1] = C + A - discard t.csub(M, not(t < M)) + when not skipReduction: + discard t.csub(M, not(t < M)) r = t -func montyMul_CIOS(r: var Limbs, a, b, M: Limbs, m0ninv: BaseType) {.used.} = +func mulMont_CIOS(r: var Limbs, a, b, M: Limbs, m0ninv: BaseType) {.used.} = ## Montgomery Multiplication using Coarse Grained Operand Scanning (CIOS) # - Analyzing and Comparing Montgomery Multiplication Algorithms # Cetin Kaya Koc and Tolga Acar and Burton S. Kaliski Jr. @@ -221,7 +244,7 @@ func montyMul_CIOS(r: var Limbs, a, b, M: Limbs, m0ninv: BaseType) {.used.} = addC(tNp1, tN, tN, A, Carry(0)) # Reduction - # m <- (t[0] * m0ninv) mod 2^w + # m <- (t[0] * m0ninv) mod 2ʷ # (C, _) <- m * M[0] + t[0] var C, lo = Zero let m = t[0] * SecretWord(m0ninv) @@ -239,11 +262,18 @@ func montyMul_CIOS(r: var Limbs, a, b, M: Limbs, m0ninv: BaseType) {.used.} = # t[N+1] can only be non-zero in the intermediate computation # since it is immediately reduce to t[N] at the end of each "i" iteration # However if t[N] is non-zero we have t > M - discard t.csub(M, tN.isNonZero() or not(t < M)) # TODO: (t >= M) is unnecessary for prime in the form (2^64)^w + discard t.csub(M, tN.isNonZero() or not(t < M)) # TODO: (t >= M) is unnecessary for prime in the form (2^64)ʷ r = t -func montyMul_FIPS(r: var Limbs, a, b, M: Limbs, m0ninv: BaseType) = +func mulMont_FIPS(r: var Limbs, a, b, M: Limbs, m0ninv: BaseType, skipReduction: static bool = false) = ## Montgomery Multiplication using Finely Integrated Product Scanning (FIPS) + ## + ## This maps + ## - [0, 2p) -> [0, 2p) with skipReduction + ## - [0, 2p) -> [0, p) without + ## + ## SkipReduction skips the final substraction step. + ## For skipReduction, M needs to have a spare bit in it's representation i.e. unused MSB. 
# - Architectural Enhancements for Montgomery # Multiplication on Embedded RISC Processors # Johann Großschädl and Guy-Armand Kamendje, 2003 @@ -276,7 +306,8 @@ func montyMul_FIPS(r: var Limbs, a, b, M: Limbs, m0ninv: BaseType) = u = t t = Zero - discard z.csub(M, v.isNonZero() or not(z < M)) + when not skipReduction: + discard z.csub(M, v.isNonZero() or not(z < M)) r = z # Montgomery Squaring @@ -313,36 +344,81 @@ func montyMul_FIPS(r: var Limbs, a, b, M: Limbs, m0ninv: BaseType) = # ... # for j in 1 ..< N: # <- Montgomery reduce. +# Montgomery Conversion +# ------------------------------------------------------------ +# +# In Montgomery form, inputs are scaled by a constant R +# so a' = aR (mod p) and b' = bR (mod p) +# +# A classic multiplication would do a'*b' = abR² (mod p) +# we then need to remove the extra R, hence: +# - Montgomery reduction (redc) does 1/R (mod p) to map abR² (mod p) -> abR (mod p) +# - Montgomery multiplication directly compute mulMont(aR, bR) = abR (mod p) +# +# So to convert a to a' = aR (mod p), we can do mulMont(a, R²) = aR (mod p) +# and to convert a' to a = aR / R (mod p) we can do: +# - redc(aR) = a +# - or mulMont(aR, 1) = a + +func fromMont_CIOS(r: var Limbs, a, M: Limbs, m0ninv: BaseType) = + ## Convert from Montgomery form to canonical BigInt form + # for i in 0 .. n-1: + # m <- t[0] * m0ninv mod 2ʷ (i.e. simple multiplication) + # C, _ = t[0] + m * M[0] + # for j in 1 ..n-1: + # (C, t[j-1]) <- r[j] + m*M[j] + C + # t[n-1] = C + + const N = a.len + var t {.noInit.} = a # Ensure working in registers + + staticFor i, 0, N: + let m = t[0] * SecretWord(m0ninv) + var C, lo: SecretWord + muladd1(C, lo, m, M[0], t[0]) + staticFor j, 1, N: + muladd2(C, t[j-1], m, M[j], C, t[j]) + t[N-1] = C + + discard t.csub(M, not(t < M)) + r = t + # Exported API # ------------------------------------------------------------ # TODO upstream, using Limbs[N] breaks semcheck -func montyRedc2x*[N: static int]( +func redc2xMont*[N: static int]( r: var array[N, SecretWord], a: array[N*2, SecretWord], M: array[N, SecretWord], - m0ninv: BaseType, spareBits: static int) {.inline.} = + m0ninv: BaseType, + spareBits: static int, skipReduction: static bool = false) {.inline.} = ## Montgomery reduce a double-precision bigint modulo M + + const skipReduction = skipReduction and spareBits >= 1 + when UseASM_X86_64 and r.len <= 6: # ADX implies BMI2 if ({.noSideEffect.}: hasAdx()): - montRed_asm_adx_bmi2(r, a, M, m0ninv, spareBits >= 1) + redcMont_asm_adx(r, a, M, m0ninv, spareBits >= 1, skipReduction) else: when r.len in {3..6}: - montRed_asm(r, a, M, m0ninv, spareBits >= 1) + redcMont_asm(r, a, M, m0ninv, spareBits >= 1, skipReduction) else: - montyRedc2x_CIOS(r, a, M, m0ninv) - # montyRedc2x_Comba(r, a, M, m0ninv) + redc2xMont_CIOS(r, a, M, m0ninv, skipReduction) + # redc2xMont_Comba(r, a, M, m0ninv) elif UseASM_X86_64 and r.len in {3..6}: # TODO: Assembly faster than GCC but slower than Clang - montRed_asm(r, a, M, m0ninv, spareBits >= 1) + redcMont_asm(r, a, M, m0ninv, spareBits >= 1, skipReduction) else: - montyRedc2x_CIOS(r, a, M, m0ninv) - # montyRedc2x_Comba(r, a, M, m0ninv) + redc2xMont_CIOS(r, a, M, m0ninv, skipReduction) + # redc2xMont_Comba(r, a, M, m0ninv, skipReduction) -func montyMul*( +func mulMont*( r: var Limbs, a, b, M: Limbs, - m0ninv: static BaseType, spareBits: static int) {.inline.} = + m0ninv: BaseType, + spareBits: static int, + skipReduction: static bool = false) {.inline.} = ## Compute r <- a*b (mod M) in the Montgomery domain ## `m0ninv` = -1/M (mod SecretWord). 
Our words are 2^32 or 2^64 ## @@ -369,43 +445,50 @@ func montyMul*( # The implementation is visible from here, the compiler can make decision whether to: # - specialize/duplicate code for m0ninv == 1 (especially if only 1 curve is needed) # - keep it generic and optimize code size + + const skipReduction = skipReduction and spareBits >= 1 + when spareBits >= 1: when UseASM_X86_64 and a.len in {2 .. 6}: # TODO: handle spilling # ADX implies BMI2 if ({.noSideEffect.}: hasAdx()): - montMul_CIOS_sparebit_asm_adx_bmi2(r, a, b, M, m0ninv) + mulMont_CIOS_sparebit_asm_adx(r, a, b, M, m0ninv, skipReduction) else: - montMul_CIOS_sparebit_asm(r, a, b, M, m0ninv) + mulMont_CIOS_sparebit_asm(r, a, b, M, m0ninv, skipReduction) else: - montyMul_CIOS_sparebit(r, a, b, M, m0ninv) + mulMont_CIOS_sparebit(r, a, b, M, m0ninv, skipReduction) else: - montyMul_FIPS(r, a, b, M, m0ninv) + mulMont_FIPS(r, a, b, M, m0ninv, skipReduction) -func montySquare*[N](r: var Limbs[N], a, M: Limbs[N], - m0ninv: static BaseType, spareBits: static int) {.inline.} = +func squareMont*[N](r: var Limbs[N], a, M: Limbs[N], + m0ninv: BaseType, + spareBits: static int, + skipReduction: static bool = false) {.inline.} = ## Compute r <- a^2 (mod M) in the Montgomery domain ## `m0ninv` = -1/M (mod SecretWord). Our words are 2^31 or 2^63 + const skipReduction = skipReduction and spareBits >= 1 + when UseASM_X86_64 and a.len in {4, 6}: # ADX implies BMI2 if ({.noSideEffect.}: hasAdx()): - # With ADX and spare bit, montSquare_CIOS_asm_adx_bmi2 + # With ADX and spare bit, squareMont_CIOS_asm_adx # which uses unfused squaring then Montgomery reduction # is slightly slower than fused Montgomery multiplication when spareBits >= 1: - montMul_CIOS_sparebit_asm_adx_bmi2(r, a, a, M, m0ninv) + mulMont_CIOS_sparebit_asm_adx(r, a, a, M, m0ninv, skipReduction) else: - montSquare_CIOS_asm_adx_bmi2(r, a, M, m0ninv, spareBits >= 1) + squareMont_CIOS_asm_adx(r, a, M, m0ninv, spareBits >= 1, skipReduction) else: - montSquare_CIOS_asm(r, a, M, m0ninv, spareBits >= 1) + squareMont_CIOS_asm(r, a, M, m0ninv, spareBits >= 1, skipReduction) elif UseASM_X86_64: var r2x {.noInit.}: Limbs[2*N] r2x.square(a) - r.montyRedc2x(r2x, M, m0ninv, spareBits) + r.redc2xMont(r2x, M, m0ninv, spareBits, skipReduction) else: - montyMul(r, a, a, M, m0ninv, spareBits) - -func redc*(r: var Limbs, a, one, M: Limbs, + mulMont(r, a, a, M, m0ninv, spareBits, skipReduction) + +func fromMont*(r: var Limbs, a, M: Limbs, m0ninv: static BaseType, spareBits: static int) = ## Transform a bigint ``a`` from it's Montgomery N-residue representation (mod N) ## to the regular natural representation (mod N) @@ -424,10 +507,16 @@ func redc*(r: var Limbs, a, one, M: Limbs, # - https://en.wikipedia.org/wiki/Montgomery_modular_multiplication#Montgomery_arithmetic_on_multiprecision_(variable-radix)_integers # - http://langevin.univ-tln.fr/cours/MLC/extra/montgomery.pdf # Montgomery original paper - # - montyMul(r, a, one, M, m0ninv, spareBits) + when UseASM_X86_64 and a.len in {2 .. 
6}: + # ADX implies BMI2 + if ({.noSideEffect.}: hasAdx()): + fromMont_asm_adx(r, a, M, m0ninv) + else: + fromMont_asm(r, a, M, m0ninv) + else: + fromMont_CIOS(r, a, M, m0ninv) -func montyResidue*(r: var Limbs, a, M, r2modM: Limbs, +func getMont*(r: var Limbs, a, M, r2modM: Limbs, m0ninv: static BaseType, spareBits: static int) = ## Transform a bigint ``a`` from it's natural representation (mod N) ## to a the Montgomery n-residue representation @@ -446,7 +535,7 @@ func montyResidue*(r: var Limbs, a, M, r2modM: Limbs, ## Important: `r` is overwritten ## The result `r` buffer size MUST be at least the size of `M` buffer # Reference: https://eprint.iacr.org/2017/1057.pdf - montyMul(r, a, r2ModM, M, m0ninv, spareBits) + mulMont(r, a, r2ModM, M, m0ninv, spareBits) # Montgomery Modular Exponentiation # ------------------------------------------ @@ -489,7 +578,7 @@ func getWindowLen(bufLen: int): uint = while (1 shl result) + 1 > bufLen: dec result -func montyPowPrologue( +func powMontPrologue( a: var Limbs, M, one: Limbs, m0ninv: static BaseType, scratchspace: var openarray[Limbs], @@ -507,12 +596,12 @@ func montyPowPrologue( else: scratchspace[2] = a for k in 2 ..< 1 shl result: - scratchspace[k+1].montyMul(scratchspace[k], a, M, m0ninv, spareBits) + scratchspace[k+1].mulMont(scratchspace[k], a, M, m0ninv, spareBits) # Set a to one a = one -func montyPowSquarings( +func powMontSquarings( a: var Limbs, exponent: openarray[byte], M: Limbs, @@ -557,12 +646,11 @@ func montyPowSquarings( # We have k bits and can do k squaring for i in 0 ..< k: - tmp.montySquare(a, M, m0ninv, spareBits) - a = tmp + a.squareMont(a, M, m0ninv, spareBits) return (k, bits) -func montyPow*( +func powMont*( a: var Limbs, exponent: openarray[byte], M, one: Limbs, @@ -596,7 +684,7 @@ func montyPow*( ## A window of size 5 requires (2^5 + 1)*(381 + 7)/8 = 33 * 48 bytes = 1584 bytes ## of scratchspace (on the stack). - let window = montyPowPrologue(a, M, one, m0ninv, scratchspace, spareBits) + let window = powMontPrologue(a, M, one, m0ninv, scratchspace, spareBits) # We process bits with from most to least significant. # At each loop iteration with have acc_len bits in acc. 
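Background illustration (not part of the patch): the renamed powMont / powMontPrologue / powMontSquarings implement a left-to-right fixed-window modular exponentiation in constant time over Limbs. Below is a minimal, variable-time Nim sketch of the same structure over a hypothetical uint64 toy modulus; every name in it is made up for illustration and only meant to clarify the "k squarings, then one table lookup and multiply" flow.

func powFixedWindow(a, e, M: uint64): uint64 =
  ## Left-to-right fixed-window exponentiation (variable time, toy modulus M < 2^32)
  const windowSize = 4
  # Table of small powers, mirroring powMontPrologue's scratchspace: table[i] = a^i mod M
  var table: array[1 shl windowSize, uint64]
  table[0] = 1
  for i in 1 ..< table.len:
    table[i] = table[i-1] * a mod M
  result = 1
  # Consume the exponent `windowSize` bits at a time, most significant first,
  # mirroring powMontSquarings: k squarings, then one lookup + multiply.
  var shift = 64
  while shift > 0:
    let k = min(windowSize, shift)
    shift -= k
    for _ in 0 ..< k:
      result = result * result mod M
    let bits = int((e shr shift) and ((1'u64 shl k) - 1))
    if bits != 0:   # powMont does this selection with a constant-time ccopy instead
      result = result * table[bits] mod M

when isMainModule:
  var expected = 1'u64
  for _ in 0 ..< 117:
    expected = expected * 5 mod 1000003
  doAssert powFixedWindow(5, 117, 1000003) == expected
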
@@ -607,7 +695,7 @@ func montyPow*( acc, acc_len: uint e = 0 while acc_len > 0 or e < exponent.len: - let (k, bits) = montyPowSquarings( + let (k, bits) = powMontSquarings( a, exponent, M, m0ninv, scratchspace[0], window, acc, acc_len, e, @@ -626,10 +714,10 @@ func montyPow*( # Multiply with the looked-up value # we keep the product only if the exponent bits are not all zeroes - scratchspace[0].montyMul(a, scratchspace[1], M, m0ninv, spareBits) + scratchspace[0].mulMont(a, scratchspace[1], M, m0ninv, spareBits) a.ccopy(scratchspace[0], SecretWord(bits).isNonZero()) -func montyPowUnsafeExponent*( +func powMontUnsafeExponent*( a: var Limbs, exponent: openarray[byte], M, one: Limbs, @@ -649,13 +737,13 @@ func montyPowUnsafeExponent*( # TODO: scratchspace[1] is unused when window > 1 - let window = montyPowPrologue(a, M, one, m0ninv, scratchspace, spareBits) + let window = powMontPrologue(a, M, one, m0ninv, scratchspace, spareBits) var acc, acc_len: uint e = 0 while acc_len > 0 or e < exponent.len: - let (_, bits) = montyPowSquarings( + let (_, bits) = powMontSquarings( a, exponent, M, m0ninv, scratchspace[0], window, acc, acc_len, e, @@ -665,10 +753,10 @@ func montyPowUnsafeExponent*( ## Warning ⚠️: Exposes the exponent bits if bits != 0: if window > 1: - scratchspace[0].montyMul(a, scratchspace[1+bits], M, m0ninv, spareBits) + scratchspace[0].mulMont(a, scratchspace[1+bits], M, m0ninv, spareBits) else: # scratchspace[1] holds the original `a` - scratchspace[0].montyMul(a, scratchspace[1], M, m0ninv, spareBits) + scratchspace[0].mulMont(a, scratchspace[1], M, m0ninv, spareBits) a = scratchspace[0] {.pop.} # raises no exceptions diff --git a/constantine/config/precompute.nim b/constantine/config/precompute.nim index e62c673fe..00d10aa6a 100644 --- a/constantine/config/precompute.nim +++ b/constantine/config/precompute.nim @@ -355,7 +355,7 @@ func r3mod*(M: BigInt): BigInt = ## This is used in hash-to-curve to ## reduce a double-sized bigint mod M ## and map it to the Montgomery domain - ## with just redc2x + montyMul + ## with just redc2x + mulMont r_powmod(3, M) func montyOne*(M: BigInt): BigInt = @@ -391,7 +391,7 @@ func primePlus1div2*(P: BigInt): BigInt = # (P+1)/2 = P/2 + 1 if P is odd, # this avoids overflowing if the prime uses all bits - # i.e. in the form (2^64)^w - 1 or (2^32)^w - 1 + # i.e. 
in the form (2^64)ʷ - 1 or (2^32)ʷ - 1 result = P result.shiftRight(1) @@ -491,7 +491,7 @@ func toCanonicalIntRepr*[bits: static int]( # ############################################################ # This is needed to avoid recursive dependencies -func montyMul_precompute(r: var BigInt, a, b, M: BigInt, m0ninv: BaseType) = +func mulMont_precompute(r: var BigInt, a, b, M: BigInt, m0ninv: BaseType) = ## Montgomery Multiplication using Coarse Grained Operand Scanning (CIOS) var t: typeof(M) # zero-init const N = t.limbs.len @@ -527,4 +527,4 @@ func montyResidue_precompute*(r: var BigInt, a, M, r2modM: BigInt, ## Transform a bigint ``a`` from it's natural representation (mod N) ## to a the Montgomery n-residue representation ## This is intended for compile-time precomputations-only - montyMul_precompute(r, a, r2ModM, M, m0ninv) + mulMont_precompute(r, a, r2ModM, M, m0ninv) diff --git a/constantine/hash_to_curve/h2c_hash_to_field.nim b/constantine/hash_to_curve/h2c_hash_to_field.nim index aa53539ea..8dd46e077 100644 --- a/constantine/hash_to_curve/h2c_hash_to_field.nim +++ b/constantine/hash_to_curve/h2c_hash_to_field.nim @@ -154,15 +154,15 @@ func expandMessageXMD*[B1, B2, B3: byte|char, len_in_bytes: static int]( return func redc2x[FF](r: var FF, big2x: BigInt) {.inline.} = - r.mres.limbs.montyRedc2x( + r.mres.limbs.redc2xMont( big2x.limbs, FF.fieldMod().limbs, FF.getNegInvModWord(), FF.getSpareBits() ) -func montyMul(r: var BigInt, a, b: BigInt, FF: type) {.inline.} = - r.limbs.montyMul( +func mulMont(r: var BigInt, a, b: BigInt, FF: type) {.inline.} = + r.limbs.mulMont( a.limbs, b.limbs, FF.fieldMod().limbs, FF.getNegInvModWord(), @@ -228,14 +228,14 @@ func hashToField*[Field; B1, B2, B3: byte|char, count: static int]( # Reduces modulo p and output in Montgomery domain when m == 1: output[i].redc2x(big2x) - output[i].mres.montyMul( + output[i].mres.mulMont( output[i].mres, Fp[Field.C].getR3ModP(), Fp[Field.C]) else: output[i].coords[j].redc2x(big2x) - output[i].coords[j].mres.montyMul( + output[i].coords[j].mres.mulMont( output[i].coords[j].mres, Fp[Field.C].getR3ModP(), Fp[Field.C]) diff --git a/constantine/primitives/cpuinfo_x86.nim b/constantine/primitives/cpuinfo_x86.nim index 35a1a6d77..47d298438 100644 --- a/constantine/primitives/cpuinfo_x86.nim +++ b/constantine/primitives/cpuinfo_x86.nim @@ -249,7 +249,7 @@ let hasBmi2Impl = testX86Feature(Bmi2) hasTsxHleImpl = testX86Feature(TsxHle) hasTsxRtmImpl = testX86Feature(TsxRtm) - hasAdxImpl = testX86Feature(TsxHle) + hasAdxImpl = testX86Feature(Adx) hasSgxImpl = testX86Feature(Sgx) hasGfniImpl = testX86Feature(Gfni) hasAesImpl = testX86Feature(Aes) diff --git a/constantine/primitives/macro_assembler_x86.nim b/constantine/primitives/macro_assembler_x86.nim index d2aefbd6c..6917becf3 100644 --- a/constantine/primitives/macro_assembler_x86.nim +++ b/constantine/primitives/macro_assembler_x86.nim @@ -337,6 +337,10 @@ func generate*(a: Assembler_x86): NimNode = ) ) ) + result = nnkBlockStmt.newTree( + newEmptyNode(), + result + ) func getStrOffset(a: Assembler_x86, op: Operand): string = if op.kind != kFromArray: diff --git a/constantine/protocols/ethereum_evm_precompiles.nim b/constantine/protocols/ethereum_evm_precompiles.nim index ace843fe0..27402b759 100644 --- a/constantine/protocols/ethereum_evm_precompiles.nim +++ b/constantine/protocols/ethereum_evm_precompiles.nim @@ -195,8 +195,8 @@ func eth_evm_ecmul*( var sprime{.noInit.}: typeof(smod.mres) # Due to mismatch between the BigInt[256] input and the rest being BigInt[254] - # we use 
the low-level montyResidue instead of 'fromBig' - montyResidue(smod.mres.limbs, s.limbs, + # we use the low-level getMont instead of 'fromBig' + getMont(smod.mres.limbs, s.limbs, Fr[BN254_Snarks].fieldMod().limbs, Fr[BN254_Snarks].getR2modP().limbs, Fr[BN254_Snarks].getNegInvModWord(), diff --git a/constantine/tower_field_extensions/assembly/fp2_asm_x86_adx_bmi2.nim b/constantine/tower_field_extensions/assembly/fp2_asm_x86_adx_bmi2.nim index 541783970..53cf0653c 100644 --- a/constantine/tower_field_extensions/assembly/fp2_asm_x86_adx_bmi2.nim +++ b/constantine/tower_field_extensions/assembly/fp2_asm_x86_adx_bmi2.nim @@ -13,8 +13,8 @@ import ../../arithmetic, ../../arithmetic/assembly/[ limbs_asm_mul_x86_adx_bmi2, - limbs_asm_montmul_x86_adx_bmi2, - limbs_asm_montred_x86_adx_bmi2 + limbs_asm_mul_mont_x86_adx_bmi2, + limbs_asm_redc_mont_x86_adx_bmi2 ] @@ -50,7 +50,7 @@ func has2extraBits(F: type Fp): bool = # 𝔽p2 squaring # ------------------------------------------------------------ -func sqrx2x_complex_asm_adx_bmi2*( +func sqrx2x_complex_asm_adx*( r: var array[2, FpDbl], a: array[2, Fp] ) = @@ -67,11 +67,11 @@ func sqrx2x_complex_asm_adx_bmi2*( t0.double(a.c1) t1.sum(a.c0, a.c1) - r.c1.mul_asm_adx_bmi2_impl(t0, a.c0) + r.c1.mul_asm_adx_inline(t0, a.c0) t0.diff(a.c0, a.c1) - r.c0.mul_asm_adx_bmi2_impl(t0, t1) + r.c0.mul_asm_adx_inline(t0, t1) -func sqrx_complex_sparebit_asm_adx_bmi2*( +func sqrx_complex_sparebit_asm_adx*( r: var array[2, Fp], a: array[2, Fp] ) = @@ -85,15 +85,15 @@ func sqrx_complex_sparebit_asm_adx_bmi2*( var v0 {.noInit.}, v1 {.noInit.}: typeof(r.c0) v0.diff(a.c0, a.c1) v1.sum(a.c0, a.c1) - r.c1.mres.limbs.montMul_CIOS_sparebit_asm_adx_bmi2(a.c0.mres.limbs, a.c1.mres.limbs, Fp.fieldMod().limbs, Fp.getNegInvModWord()) + r.c1.mres.limbs.mulMont_CIOS_sparebit_asm_adx(a.c0.mres.limbs, a.c1.mres.limbs, Fp.fieldMod().limbs, Fp.getNegInvModWord()) # aliasing: a unneeded now r.c1.double() - r.c0.mres.limbs.montMul_CIOS_sparebit_asm_adx_bmi2(v0.mres.limbs, v1.mres.limbs, Fp.fieldMod().limbs, Fp.getNegInvModWord()) + r.c0.mres.limbs.mulMont_CIOS_sparebit_asm_adx(v0.mres.limbs, v1.mres.limbs, Fp.fieldMod().limbs, Fp.getNegInvModWord()) # 𝔽p2 multiplication # ------------------------------------------------------------ -func mulx2x_complex_asm_adx_bmi2*( +func mul2x_fp2_complex_asm_adx*( r: var array[2, FpDbl], a, b: array[2, Fp] ) = @@ -101,15 +101,15 @@ func mulx2x_complex_asm_adx_bmi2*( var D {.noInit.}: typeof(r.c0) var t0 {.noInit.}, t1 {.noInit.}: typeof(a.c0) - r.c0.limbs2x.mul_asm_adx_bmi2_impl(a.c0.mres.limbs, b.c0.mres.limbs) - D.limbs2x.mul_asm_adx_bmi2_impl(a.c1.mres.limbs, b.c1.mres.limbs) + r.c0.limbs2x.mul_asm_adx_inline(a.c0.mres.limbs, b.c0.mres.limbs) + D.limbs2x.mul_asm_adx_inline(a.c1.mres.limbs, b.c1.mres.limbs) when Fp.has1extraBit(): t0.sumUnr(a.c0, a.c1) t1.sumUnr(b.c0, b.c1) else: t0.sum(a.c0, a.c1) t1.sum(b.c0, b.c1) - r.c1.limbs2x.mul_asm_adx_bmi2_impl(t0.mres.limbs, t1.mres.limbs) + r.c1.limbs2x.mul_asm_adx_inline(t0.mres.limbs, t1.mres.limbs) when Fp.has1extraBit(): r.c1.diff2xUnr(r.c1, r.c0) r.c1.diff2xUnr(r.c1, D) @@ -118,20 +118,20 @@ func mulx2x_complex_asm_adx_bmi2*( r.c1.diff2xMod(r.c1, D) r.c0.diff2xMod(r.c0, D) -func mulx_complex_asm_adx_bmi2*( +func mul_fp2_complex_asm_adx*( r: var array[2, Fp], a, b: array[2, Fp] ) = ## Complex multiplication on 𝔽p2 var d {.noInit.}: array[2,doublePrec(Fp)] - d.mulx2x_complex_asm_adx_bmi2(a, b) - r.c0.mres.limbs.montRed_asm_adx_bmi2_impl( + d.mul2x_fp2_complex_asm_adx(a, b) + 
r.c0.mres.limbs.redcMont_asm_adx_inline( d.c0.limbs2x, Fp.fieldMod().limbs, Fp.getNegInvModWord(), Fp.has1extraBit() ) - r.c1.mres.limbs.montRed_asm_adx_bmi2_impl( + r.c1.mres.limbs.redcMont_asm_adx_inline( d.c1.limbs2x, Fp.fieldMod().limbs, Fp.getNegInvModWord(), diff --git a/constantine/tower_field_extensions/extension_fields.nim b/constantine/tower_field_extensions/extension_fields.nim index 8b1c7ebea..6635f848a 100644 --- a/constantine/tower_field_extensions/extension_fields.nim +++ b/constantine/tower_field_extensions/extension_fields.nim @@ -454,7 +454,7 @@ func prod*[C: static Curve]( # (c0 + c1 x) (u + v x) => u c0 + (u c0 + u c1)x + v c1 x² # => u c0 + β v c1 + (v c0 + u c1) x when a.fromComplexExtension() and u == 1 and v == 1: - let t = a.c0 + let t {.noInit.} = a.c0 r.c0.diff(t, a.c1) r.c1.sum(t, a.c1) else: @@ -504,7 +504,7 @@ func prod2x*[C: static Curve]( const Beta {.used.} = C.getNonResidueFp() when complex and U == 1 and V == 1: - let a1 = a.c1 + let a1 {.noInit.} = a.c1 r.c1.sum2xMod(a.c0, a1) r.c0.diff2xMod(a.c0, a1) else: @@ -567,13 +567,13 @@ func `/=`*[C: static Curve](a: var Fp2[C], _: type NonResidue) = # 1/2 * (c0 + c1, c1 - c0) when a.fromComplexExtension() and u == 1 and v == 1: - let t = a.c0 + let t {.noInit.} = a.c0 a.c0 += a.c1 a.c1 -= t a.div2() else: - var a0 = a.c0 - let a1 = a.c1 + var a0 {.noInit.} = a.c0 + let a1 {.noInit.} = a.c1 const u2v2 = u*u - Beta*v*v # (u² - βv²) # TODO can be precomputed to avoid costly inversion. var u2v2inv {.noInit.}: a.c0.typeof @@ -605,7 +605,7 @@ func prod*(r: var QuadraticExt, a: QuadraticExt, _: type NonResidue) = ## - if sextic non-residue: 𝔽p8, 𝔽p12 or 𝔽p24 ## ## Assumes that the non-residue is sqrt(lower extension non-residue) - let t = a.c0 + let t {.noInit.} = a.c0 r.c0.prod(a.c1, NonResidue) r.c1 = t @@ -628,7 +628,7 @@ func prod2x*( _: type NonResidue) = ## Multiplication by non-residue static: doAssert not(r.c0 is FpDbl), "Wrong dispatch, there is a specific non-residue multiplication for the base extension." - let t = a.c0 + let t {.noInit.} = a.c0 r.c0.prod2x(a.c1, NonResidue) `=`(r.c1, t) # "r.c1 = t", is refused by the compiler. 
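Background illustration (not part of the patch): the renamed mul_fp2_complex_asm_adx / mul2x_fp2_complex_asm_adx kernels use the standard 3-multiplication complex product (with unreduced sums via sumUnr when the field has a spare bit). The sketch below cross-checks that formula against the schoolbook 4-multiplication product over the integers modulo a hypothetical toy modulus; the identity is a ring identity, so the choice of modulus does not matter, and all names are made up for illustration.

const p = 1000003'u64              # toy modulus, small enough that all products fit in uint64

type Fp2 = tuple[c0, c1: uint64]   # c0 + c1·i with i² = -1 (complex extension)

func fpMul(x, y: uint64): uint64 = x * y mod p
func fpAdd(x, y: uint64): uint64 = (x + y) mod p
func fpSub(x, y: uint64): uint64 = (x + p - y) mod p

func mulSchoolbook(a, b: Fp2): Fp2 =
  ## 4 multiplications: (a0·b0 - a1·b1) + (a0·b1 + a1·b0)·i
  result.c0 = fpSub(fpMul(a.c0, b.c0), fpMul(a.c1, b.c1))
  result.c1 = fpAdd(fpMul(a.c0, b.c1), fpMul(a.c1, b.c0))

func mulThreeMul(a, b: Fp2): Fp2 =
  ## 3 multiplications, the shape used by the 𝔽p2 assembly kernels:
  ##   v0 = a0·b0, v1 = a1·b1
  ##   c0 = v0 - v1
  ##   c1 = (a0 + a1)·(b0 + b1) - v0 - v1
  let v0 = fpMul(a.c0, b.c0)
  let v1 = fpMul(a.c1, b.c1)
  result.c0 = fpSub(v0, v1)
  result.c1 = fpSub(fpSub(fpMul(fpAdd(a.c0, a.c1), fpAdd(b.c0, b.c1)), v0), v1)

when isMainModule:
  let a = (c0: 123456'u64, c1: 789012'u64)
  let b = (c0: 345678'u64, c1: 901234'u64)
  doAssert mulSchoolbook(a, b) == mulThreeMul(a, b)
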
@@ -650,7 +650,7 @@ func prod*(r: var CubicExt, a: CubicExt, _: type NonResidue) = ## For all curves γ = v with v the factor for the cubic extension coordinate ## and v³ = ξ ## (c0 + c1 v + c2 v²) v => ξ c2 + c0 v + c1 v² - let t = a.c2 + let t {.noInit.} = a.c2 r.c1 = a.c0 r.c2 = a.c1 r.c0.prod(t, NonResidue) @@ -677,7 +677,7 @@ func prod2x*( ## For all curves γ = v with v the factor for cubic extension coordinate ## and v³ = ξ ## (c0 + c1 v + c2 v²) v => ξ c2 + c0 v + c1 v² - let t = a.c2 + let t {.noInit.} = a.c2 r.c1 = a.c0 r.c2 = a.c1 r.c0.prod2x(t, NonResidue) @@ -1243,7 +1243,7 @@ func square*(r: var QuadraticExt, a: QuadraticExt) = when true: when UseASM_X86_64 and a.c0.mres.limbs.len <= 6 and r.typeof.has1extraBit(): if ({.noSideEffect.}: hasAdx()): - r.coords.sqrx_complex_sparebit_asm_adx_bmi2(a.coords) + r.coords.sqrx_complex_sparebit_asm_adx(a.coords) else: r.square_complex(a) else: @@ -1281,7 +1281,7 @@ func prod*(r: var QuadraticExt, a, b: QuadraticExt) = else: # faster when UseASM_X86_64 and a.c0.mres.limbs.len <= 6: if ({.noSideEffect.}: hasAdx()): - r.coords.mulx_complex_asm_adx_bmi2(a.coords, b.coords) + r.coords.mul_fp2_complex_asm_adx(a.coords, b.coords) else: var d {.noInit.}: doublePrec(typeof(r)) d.prod2x_complex(a, b) @@ -1318,7 +1318,7 @@ func prod2x*(r: var QuadraticExt2x, a, b: QuadraticExt) = when a.fromComplexExtension(): when UseASM_X86_64 and a.c0.mres.limbs.len <= 6: if ({.noSideEffect.}: hasAdx()): - r.coords.mulx2x_complex_asm_adx_bmi2(a.coords, b.coords) + r.coords.mul2x_fp2_complex_asm_adx(a.coords, b.coords) else: r.prod2x_complex(a, b) else: @@ -1591,7 +1591,7 @@ func prod2xImpl(r: var CubicExt2x, a, b: CubicExt) = V2.prod2x(a.c2, b.c2) # r₀ = β ((a₁ + a₂)(b₁ + b₂) - v₁ - v₂) + v₀ - when false: # CubicExt.has1extraBit(): + when a.c0 is Fp and CubicExt.has1extraBit(): t0.sumUnr(a.c1, a.c2) t1.sumUnr(b.c1, b.c2) else: @@ -1604,7 +1604,7 @@ func prod2xImpl(r: var CubicExt2x, a, b: CubicExt) = r.c0.sum2xMod(r.c0, V0) # r₁ = (a₀ + a₁) * (b₀ + b₁) - v₀ - v₁ + β v₂ - when false: # CubicExt.has1extraBit(): + when a.c0 is Fp and CubicExt.has1extraBit(): t0.sumUnr(a.c0, a.c1) t1.sumUnr(b.c0, b.c1) else: @@ -1617,7 +1617,7 @@ func prod2xImpl(r: var CubicExt2x, a, b: CubicExt) = r.c1.sum2xMod(r.c1, r.c2) # r₂ = (a₀ + a₂) * (b₀ + b₂) - v₀ - v₂ + v₁ - when false: # CubicExt.has1extraBit(): + when a.c0 is Fp and CubicExt.has1extraBit(): t0.sumUnr(a.c0, a.c2) t1.sumUnr(b.c0, b.c2) else: diff --git a/docs/optimizations.md b/docs/optimizations.md index a7906ddec..5741a7f04 100644 --- a/docs/optimizations.md +++ b/docs/optimizations.md @@ -84,7 +84,11 @@ The optimizations can be of algebraic, algorithmic or "implementation details" n - [ ] NAF recoding - [ ] windowed-NAF recoding - [ ] SIMD vectorized select in window algorithm - - [ ] Almost Montgomery Multiplication, https://eprint.iacr.org/2011/239.pdf + - [ ] Montgomery Multiplication with no final substraction, + - Bos and Montgomery, https://eprint.iacr.org/2017/1057.pdf + - Colin D Walter, https://colinandmargaret.co.uk/Research/CDW_ELL_99.pdf + - Hachez and Quisquater, https://link.springer.com/content/pdf/10.1007%2F3-540-44499-8_23.pdf + - Gueron, https://eprint.iacr.org/2011/239.pdf - [ ] Pippenger multi-exponentiation (variable-time) - [ ] parallelized Pippenger diff --git a/helpers/prng_unsafe.nim b/helpers/prng_unsafe.nim index ccac0b0c9..721fdee36 100644 --- a/helpers/prng_unsafe.nim +++ b/helpers/prng_unsafe.nim @@ -148,7 +148,7 @@ func random_unsafe(rng: var RngState, a: var FF) = # Note: a simple modulo 
will be biaised but it's simple and "fast" reduced.reduce(unreduced, FF.fieldMod()) - a.mres.montyResidue(reduced, FF.fieldMod(), FF.getR2modP(), FF.getNegInvModWord(), FF.getSpareBits()) + a.mres.getMont(reduced, FF.fieldMod(), FF.getR2modP(), FF.getNegInvModWord(), FF.getSpareBits()) func random_unsafe(rng: var RngState, a: var ExtensionField) = ## Recursively initialize an extension Field element @@ -179,7 +179,7 @@ func random_highHammingWeight(rng: var RngState, a: var FF) = # Note: a simple modulo will be biaised but it's simple and "fast" reduced.reduce(unreduced, FF.fieldMod()) - a.mres.montyResidue(reduced, FF.fieldMod(), FF.getR2modP(), FF.getNegInvModWord(), FF.getSpareBits()) + a.mres.getMont(reduced, FF.fieldMod(), FF.getR2modP(), FF.getNegInvModWord(), FF.getSpareBits()) func random_highHammingWeight(rng: var RngState, a: var ExtensionField) = ## Recursively initialize an extension Field element @@ -224,7 +224,7 @@ func random_long01Seq(rng: var RngState, a: var FF) = # Note: a simple modulo will be biaised but it's simple and "fast" reduced.reduce(unreduced, FF.fieldMod()) - a.mres.montyResidue(reduced, FF.fieldMod(), FF.getR2modP(), FF.getNegInvModWord(), FF.getSpareBits()) + a.mres.getMont(reduced, FF.fieldMod(), FF.getR2modP(), FF.getNegInvModWord(), FF.getSpareBits()) func random_long01Seq(rng: var RngState, a: var ExtensionField) = ## Recursively initialize an extension Field element diff --git a/tests/t_finite_fields.nim b/tests/t_finite_fields.nim index 29d9bc0cd..ed91e12ae 100644 --- a/tests/t_finite_fields.nim +++ b/tests/t_finite_fields.nim @@ -8,7 +8,8 @@ import std/unittest, ../constantine/arithmetic, - ../constantine/io/io_fields, + ../constantine/arithmetic/limbs_montgomery, + ../constantine/io/[io_bigints, io_fields], ../constantine/config/curves static: doAssert defined(testingCurves), "This modules requires the -d:testingCurves compile option" @@ -317,4 +318,54 @@ proc largeField() = check: bool r.isZero() + test "fromMont doesn't need a final substraction with 256-bit prime (full word used)": + block: + var a: Fp[Secp256k1] + a.mres = Fp[Secp256k1].getMontyPrimeMinus1() + let expected = BigInt[256].fromHex"0xFFFFFFFF FFFFFFFF FFFFFFFF FFFFFFFF FFFFFFFF FFFFFFFF FFFFFFFE FFFFFC2E" + + var r: BigInt[256] + r.fromField(a) + + check: bool(r == expected) + block: + var a: Fp[Secp256k1] + var d: FpDbl[Secp256k1] + + # Set Montgomery repr to the largest field element in Montgomery Residue form + a.mres = BigInt[256].fromHex"0xFFFFFFFF FFFFFFFF FFFFFFFF FFFFFFFF FFFFFFFF FFFFFFFF FFFFFFFE FFFFFC2E" + d.limbs2x = (BigInt[512].fromHex"0xFFFFFFFF FFFFFFFF FFFFFFFF FFFFFFFF FFFFFFFF FFFFFFFF FFFFFFFE FFFFFC2E").limbs + + var r, expected: BigInt[256] + + r.fromField(a) + expected.limbs.redc2xMont(d.limbs2x, Secp256k1.Mod().limbs, Fp[Secp256k1].getNegInvModWord(), Fp[Secp256k1].getSpareBits()) + + check: bool(r == expected) + + test "fromMont doesn't need a final substraction with 255-bit prime (1 spare bit)": + block: + var a: Fp[Curve25519] + a.mres = Fp[Curve25519].getMontyPrimeMinus1() + let expected = BigInt[255].fromHex"0x7fffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffec" + + var r: BigInt[255] + r.fromField(a) + + check: bool(r == expected) + block: + var a: Fp[Curve25519] + var d: FpDbl[Curve25519] + + # Set Montgomery repr to the largest field element in Montgomery Residue form + a.mres = BigInt[255].fromHex"0x7fffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffec" + d.limbs2x = 
(BigInt[512].fromHex"0x7fffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffec").limbs + + var r, expected: BigInt[255] + + r.fromField(a) + expected.limbs.redc2xMont(d.limbs2x, Curve25519.Mod().limbs, Fp[Curve25519].getNegInvModWord(), Fp[Curve25519].getSpareBits()) + + check: bool(r == expected) + largeField()
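
Background note (not part of the patch): the two fromMont tests above rely on a standard bound on Montgomery reduction. REDC(T) = (T + m·M)/R with m = T·(-M⁻¹) mod R < R, so REDC(T) < T/R + M. For fromMont the input is a canonical Montgomery residue T = a < M, so the result is < M/R + M < M + 1; it can only equal M if a ≡ 0 (mod M), in which case it is 0 instead, so it always lands in [0, M) and no final subtraction is needed. A minimal single-word Nim sketch of that bound, with a hypothetical 16-bit toy modulus and R = 2^16:

const
  M = 65063'u32    # toy odd modulus < 2^16 (hypothetical value)
  R = 65536'u32    # R = 2^16

func negInvModR(m: uint32): uint32 =
  ## -m⁻¹ mod R for odd m, via Newton / Hensel lifting
  var inv = 1'u32
  for _ in 0 ..< 5:                  # 2 -> 4 -> 8 -> 16 -> 32 bits of precision
    inv = inv * (2'u32 - m * inv)    # wraps mod 2^32, which is fine
  result = (R - (inv and (R - 1))) and (R - 1)

func redc(T, m0ninv: uint32): uint32 =
  ## One-word Montgomery reduction: returns T·R⁻¹ (mod M), with no final subtraction
  let m = (T * m0ninv) and (R - 1)
  result = (T + m * M) shr 16

when isMainModule:
  let m0ninv = negInvModR(M)
  for i in 0 ..< int(M):             # every canonical residue a in [0, M)
    let a = uint32(i)
    let r = redc(a, m0ninv)
    doAssert r < M                   # already in [0, M): the final subtraction is unneeded
    doAssert r * R mod M == a        # and r·R ≡ a (mod M), i.e. r = a·R⁻¹ (mod M)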