Skip to content

Commit

Permalink
Low-level refactor part 2 (#176)
Browse files Browse the repository at this point in the history
  • Loading branch information
mratsim committed Feb 14, 2022
1 parent 14af7e8 commit 5db30ef
Show file tree
Hide file tree
Showing 15 changed files with 154 additions and 129 deletions.
14 changes: 7 additions & 7 deletions constantine/arithmetic/assembly/limbs_asm_mul_mont_x86.nim
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ static: doAssert UseASM_X86_64
macro mulMont_CIOS_sparebit_gen[N: static int](
r_PIR: var Limbs[N], a_PIR, b_PIR,
M_PIR: Limbs[N], m0ninv_REG: BaseType,
skipReduction: static bool
skipFinalSub: static bool
): untyped =
## Generate an optimized Montgomery Multiplication kernel
## using the CIOS method
Expand Down Expand Up @@ -175,7 +175,7 @@ macro mulMont_CIOS_sparebit_gen[N: static int](
ctx.mov rax, r # move r away from scratchspace that will be used for final substraction
let r2 = rax.asArrayAddr(len = N)

if skipReduction:
if skipFinalSub:
for i in 0 ..< N:
ctx.mov r2[i], t[i]
else:
Expand All @@ -185,14 +185,14 @@ macro mulMont_CIOS_sparebit_gen[N: static int](
)
result.add ctx.generate()

func mulMont_CIOS_sparebit_asm*(r: var Limbs, a, b, M: Limbs, m0ninv: BaseType, skipReduction: static bool = false) =
func mulMont_CIOS_sparebit_asm*(r: var Limbs, a, b, M: Limbs, m0ninv: BaseType, skipFinalSub: static bool = false) =
## Constant-time Montgomery multiplication
## If "skipReduction" is set
## If "skipFinalSub" is set
## the result is in the range [0, 2M)
## otherwise the result is in the range [0, M)
##
## This procedure can only be called if the modulus doesn't use the full bitwidth of its underlying representation
r.mulMont_CIOS_sparebit_gen(a, b, M, m0ninv, skipReduction)
r.mulMont_CIOS_sparebit_gen(a, b, M, m0ninv, skipFinalSub)

# Montgomery Squaring
# ------------------------------------------------------------
Expand All @@ -209,8 +209,8 @@ func squareMont_CIOS_asm*[N](
r: var Limbs[N],
a, M: Limbs[N],
m0ninv: BaseType,
hasSpareBit, skipReduction: static bool) =
hasSpareBit, skipFinalSub: static bool) =
## Constant-time modular squaring
var r2x {.noInit.}: Limbs[2*N]
r2x.square_asm_inline(a)
r.redcMont_asm_inline(r2x, M, m0ninv, hasSpareBit, skipReduction)
r.redcMont_asm_inline(r2x, M, m0ninv, hasSpareBit, skipFinalSub)
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,7 @@ proc partialRedx(
macro mulMont_CIOS_sparebit_adx_gen[N: static int](
r_PIR: var Limbs[N], a_PIR, b_PIR,
M_PIR: Limbs[N], m0ninv_REG: BaseType,
skipReduction: static bool): untyped =
skipFinalSub: static bool): untyped =
## Generate an optimized Montgomery Multiplication kernel
## using the CIOS method
## This requires the most significant word of the Modulus
Expand Down Expand Up @@ -268,7 +268,7 @@ macro mulMont_CIOS_sparebit_adx_gen[N: static int](
lo, C
)

if skipReduction:
if skipFinalSub:
for i in 0 ..< N:
ctx.mov r[i], t[i]
else:
Expand All @@ -279,14 +279,14 @@ macro mulMont_CIOS_sparebit_adx_gen[N: static int](

result.add ctx.generate

func mulMont_CIOS_sparebit_asm_adx*(r: var Limbs, a, b, M: Limbs, m0ninv: BaseType, skipReduction: static bool = false) =
func mulMont_CIOS_sparebit_asm_adx*(r: var Limbs, a, b, M: Limbs, m0ninv: BaseType, skipFinalSub: static bool = false) =
## Constant-time Montgomery multiplication
## If "skipReduction" is set
## If "skipFinalSub" is set
## the result is in the range [0, 2M)
## otherwise the result is in the range [0, M)
##
## This procedure can only be called if the modulus doesn't use the full bitwidth of its underlying representation
r.mulMont_CIOS_sparebit_adx_gen(a, b, M, m0ninv, skipReduction)
r.mulMont_CIOS_sparebit_adx_gen(a, b, M, m0ninv, skipFinalSub)

# Montgomery Squaring
# ------------------------------------------------------------
Expand All @@ -295,8 +295,8 @@ func squareMont_CIOS_asm_adx*[N](
r: var Limbs[N],
a, M: Limbs[N],
m0ninv: BaseType,
hasSpareBit, skipReduction: static bool) =
hasSpareBit, skipFinalSub: static bool) =
## Constant-time modular squaring
var r2x {.noInit.}: Limbs[2*N]
r2x.square_asm_adx_inline(a)
r.redcMont_asm_adx(r2x, M, m0ninv, hasSpareBit, skipReduction)
r.redcMont_asm_adx(r2x, M, m0ninv, hasSpareBit, skipFinalSub)
16 changes: 8 additions & 8 deletions constantine/arithmetic/assembly/limbs_asm_redc_mont_x86.nim
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ macro redc2xMont_gen*[N: static int](
a_PIR: array[N*2, SecretWord],
M_PIR: array[N, SecretWord],
m0ninv_REG: BaseType,
hasSpareBit, skipReduction: static bool
hasSpareBit, skipFinalSub: static bool
) =

# No register spilling handling
Expand Down Expand Up @@ -153,7 +153,7 @@ macro redc2xMont_gen*[N: static int](
# v is invalidated from now on
let t = repackRegisters(v, u[N], u[N+1])

if hasSpareBit and skipReduction:
if hasSpareBit and skipFinalSub:
for i in 0 ..< N:
ctx.mov r[i], t[i]
elif hasSpareBit:
Expand All @@ -170,22 +170,22 @@ func redcMont_asm_inline*[N: static int](
M: array[N, SecretWord],
m0ninv: BaseType,
hasSpareBit: static bool,
skipReduction: static bool = false
skipFinalSub: static bool = false
) {.inline.} =
## Constant-time Montgomery reduction
## Inline-version
redc2xMont_gen(r, a, M, m0ninv, hasSpareBit, skipReduction)
redc2xMont_gen(r, a, M, m0ninv, hasSpareBit, skipFinalSub)

func redcMont_asm*[N: static int](
r: var array[N, SecretWord],
a: array[N*2, SecretWord],
M: array[N, SecretWord],
m0ninv: BaseType,
hasSpareBit, skipReduction: static bool
hasSpareBit, skipFinalSub: static bool
) =
## Constant-time Montgomery reduction
static: doAssert UseASM_X86_64, "This requires x86-64."
redcMont_asm_inline(r, a, M, m0ninv, hasSpareBit, skipReduction)
redcMont_asm_inline(r, a, M, m0ninv, hasSpareBit, skipFinalSub)

# Montgomery conversion
# ----------------------------------------------------------
Expand Down Expand Up @@ -351,8 +351,8 @@ when isMainModule:
var a_sqr{.noInit.}, na_sqr{.noInit.}: Limbs[2]
var a_sqr_comba{.noInit.}, na_sqr_comba{.noInit.}: Limbs[2]

a_sqr.redcMont_asm(adbl_sqr, M, 1, hasSpareBit = false, skipReduction = false)
na_sqr.redcMont_asm(nadbl_sqr, M, 1, hasSpareBit = false, skipReduction = false)
a_sqr.redcMont_asm(adbl_sqr, M, 1, hasSpareBit = false, skipFinalSub = false)
na_sqr.redcMont_asm(nadbl_sqr, M, 1, hasSpareBit = false, skipFinalSub = false)
a_sqr_comba.redc2xMont_Comba(adbl_sqr, M, 1)
na_sqr_comba.redc2xMont_Comba(nadbl_sqr, M, 1)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ macro redc2xMont_adx_gen[N: static int](
a_PIR: array[N*2, SecretWord],
M_PIR: array[N, SecretWord],
m0ninv_REG: BaseType,
hasSpareBit, skipReduction: static bool
hasSpareBit, skipFinalSub: static bool
) =

# No register spilling handling
Expand Down Expand Up @@ -131,7 +131,7 @@ macro redc2xMont_adx_gen[N: static int](

let t = repackRegisters(v, u[N])

if hasSpareBit and skipReduction:
if hasSpareBit and skipFinalSub:
for i in 0 ..< N:
ctx.mov r[i], t[i]
elif hasSpareBit:
Expand All @@ -148,22 +148,22 @@ func redcMont_asm_adx_inline*[N: static int](
M: array[N, SecretWord],
m0ninv: BaseType,
hasSpareBit: static bool,
skipReduction: static bool = false
skipFinalSub: static bool = false
) {.inline.} =
## Constant-time Montgomery reduction
## Inline-version
redc2xMont_adx_gen(r, a, M, m0ninv, hasSpareBit, skipReduction)
redc2xMont_adx_gen(r, a, M, m0ninv, hasSpareBit, skipFinalSub)

func redcMont_asm_adx*[N: static int](
r: var array[N, SecretWord],
a: array[N*2, SecretWord],
M: array[N, SecretWord],
m0ninv: BaseType,
hasSpareBit: static bool,
skipReduction: static bool = false
skipFinalSub: static bool = false
) =
## Constant-time Montgomery reduction
redcMont_asm_adx_inline(r, a, M, m0ninv, hasSpareBit, skipReduction)
redcMont_asm_adx_inline(r, a, M, m0ninv, hasSpareBit, skipFinalSub)


# Montgomery conversion
Expand Down
12 changes: 6 additions & 6 deletions constantine/arithmetic/bigints_montgomery.nim
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ import
#
# ############################################################

func getMont*(mres: var BigInt, a, N, r2modM: BigInt, m0ninv: static BaseType, spareBits: static int) =
func getMont*(mres: var BigInt, a, N, r2modM: BigInt, m0ninv: BaseType, spareBits: static int) =
## Convert a BigInt from its natural representation
## to the Montgomery residue form
##
Expand All @@ -41,7 +41,7 @@ func getMont*(mres: var BigInt, a, N, r2modM: BigInt, m0ninv: static BaseType, s
## and R = (2^WordBitWidth)^W
getMont(mres.limbs, a.limbs, N.limbs, r2modM.limbs, m0ninv, spareBits)

func fromMont*[mBits](r: var BigInt[mBits], a, M: BigInt[mBits], m0ninv: static BaseType, spareBits: static int) =
func fromMont*[mBits](r: var BigInt[mBits], a, M: BigInt[mBits], m0ninv: BaseType, spareBits: static int) =
## Convert a BigInt from its Montgomery residue form
## to the natural representation
##
Expand All @@ -52,20 +52,20 @@ func fromMont*[mBits](r: var BigInt[mBits], a, M: BigInt[mBits], m0ninv: static
fromMont(r.limbs, a.limbs, M.limbs, m0ninv, spareBits)

func mulMont*(r: var BigInt, a, b, M: BigInt, negInvModWord: static BaseType,
spareBits: static int, skipReduction: static bool = false) =
spareBits: static int, skipFinalSub: static bool = false) =
## Compute r <- a*b (mod M) in the Montgomery domain
##
## This resets r to zero before processing. Use {.noInit.}
## to avoid duplicating with Nim zero-init policy
mulMont(r.limbs, a.limbs, b.limbs, M.limbs, negInvModWord, spareBits, skipReduction)
mulMont(r.limbs, a.limbs, b.limbs, M.limbs, negInvModWord, spareBits, skipFinalSub)

func squareMont*(r: var BigInt, a, M: BigInt, negInvModWord: static BaseType,
spareBits: static int, skipReduction: static bool = false) =
spareBits: static int, skipFinalSub: static bool = false) =
## Compute r <- a^2 (mod M) in the Montgomery domain
##
## This resets r to zero before processing. Use {.noInit.}
## to avoid duplicating with Nim zero-init policy
squareMont(r.limbs, a.limbs, M.limbs, negInvModWord, spareBits, skipReduction)
squareMont(r.limbs, a.limbs, M.limbs, negInvModWord, spareBits, skipFinalSub)

func powMont*[mBits: static int](
a: var BigInt[mBits], exponent: openarray[byte],
Expand Down
90 changes: 59 additions & 31 deletions constantine/arithmetic/finite_fields.nim
Original file line number Diff line number Diff line change
Expand Up @@ -214,14 +214,14 @@ func double*(r: var FF, a: FF) {.meter.} =
overflowed = overflowed or not(r.mres < FF.fieldMod())
discard csub(r.mres, FF.fieldMod(), overflowed)

func prod*(r: var FF, a, b: FF, skipReduction: static bool = false) {.meter.} =
func prod*(r: var FF, a, b: FF, skipFinalSub: static bool = false) {.meter.} =
## Store the product of ``a`` by ``b`` modulo p into ``r``
## ``r`` is initialized / overwritten
r.mres.mulMont(a.mres, b.mres, FF.fieldMod(), FF.getNegInvModWord(), FF.getSpareBits(), skipReduction)
r.mres.mulMont(a.mres, b.mres, FF.fieldMod(), FF.getNegInvModWord(), FF.getSpareBits(), skipFinalSub)

func square*(r: var FF, a: FF, skipReduction: static bool = false) {.meter.} =
func square*(r: var FF, a: FF, skipFinalSub: static bool = false) {.meter.} =
## Squaring modulo p
r.mres.squareMont(a.mres, FF.fieldMod(), FF.getNegInvModWord(), FF.getSpareBits(), skipReduction)
r.mres.squareMont(a.mres, FF.fieldMod(), FF.getNegInvModWord(), FF.getSpareBits(), skipFinalSub)

func neg*(r: var FF, a: FF) {.meter.} =
## Negate modulo p
Expand Down Expand Up @@ -413,38 +413,23 @@ func `*=`*(a: var FF, b: FF) {.meter.} =
## Multiplication modulo p
a.prod(a, b)

func square*(a: var FF, skipReduction: static bool = false) {.meter.} =
func square*(a: var FF, skipFinalSub: static bool = false) {.meter.} =
## Squaring modulo p
a.square(a, skipReduction)
a.square(a, skipFinalSub)

func square_repeated*(a: var FF, num: int, skipReduction: static bool = false) {.meter.} =
func square_repeated*(a: var FF, num: int, skipFinalSub: static bool = false) {.meter.} =
## Repeated squarings
# Except in Tonelli-Shanks, num is always known at compile-time
# and square repeated is inlined, so the compiler should optimize the branches away.

# TODO: understand the conditions to avoid the final substraction
for _ in 0 ..< num:
a.square(skipReduction = false)

func square_repeated*(r: var FF, a: FF, num: int, skipReduction: static bool = false) {.meter.} =
## Repeated squarings

# TODO: understand the conditions to avoid the final substraction
r.square(a)
for _ in 1 ..< num:
r.square()

func square_repeated_then_mul*(a: var FF, num: int, b: FF, skipReduction: static bool = false) {.meter.} =
## Square `a`, `num` times and then multiply by b
## Assumes at least 1 squaring
a.square_repeated(num, skipReduction = false)
a.prod(a, b, skipReduction = skipReduction)
for _ in 0 ..< num-1:
a.square(skipFinalSub = true)
a.square(skipFinalSub)

func square_repeated_then_mul*(r: var FF, a: FF, num: int, b: FF, skipReduction: static bool = false) {.meter.} =
## Square `a`, `num` times and then multiply by b
## Assumes at least 1 squaring
r.square_repeated(a, num, skipReduction = false)
r.prod(r, b, skipReduction = skipReduction)
func square_repeated*(r: var FF, a: FF, num: int, skipFinalSub: static bool = false) {.meter.} =
## Repeated squarings
r.square(a, skipFinalSub = true)
for _ in 1 ..< num-1:
r.square(skipFinalSub = true)
r.square(skipFinalSub)

func `*=`*(a: var FF, b: static int) =
## Multiplication by a small integer known at compile-time
Expand Down Expand Up @@ -550,3 +535,46 @@ template mulCheckSparse*(a: var Fp, b: Fp) =

{.pop.} # inline
{.pop.} # raises no exceptions

# ############################################################
#
# Field arithmetic ergonomic macros
#
# ############################################################

import std/macros

macro addchain*(fn: untyped): untyped =
## Modify all prod, `*=`, square, square_repeated calls
## to skipFinalSub except the very last call.
## This assumes straight-line code.
fn.expectKind(nnkFuncDef)

result = fn
var body = newStmtList()

for i, statement in fn[^1]:
statement.expectKind({nnkCommentStmt, nnkVarSection, nnkCall, nnkInfix})

var s = statement.copyNimTree()
if i + 1 != result[^1].len:
# Modify all but the last
if s.kind == nnkCall:
doAssert s[0].kind == nnkDotExpr, "Only method call syntax or infix syntax is supported in addition chains"
doAssert s[0][1].eqIdent"prod" or s[0][1].eqIdent"square" or s[0][1].eqIdent"square_repeated"
s.add newLit(true)
elif s.kind == nnkInfix:
doAssert s[0].eqIdent"*="
# a *= b -> prod(a, a, b, true)
s = newCall(
bindSym"prod",
s[1],
s[1],
s[2],
newLit(true)
)

body.add s

result[^1] = body
# echo result.toStrLit()
7 changes: 4 additions & 3 deletions constantine/arithmetic/finite_fields_square_root.nim
Original file line number Diff line number Diff line change
Expand Up @@ -194,13 +194,14 @@ func invsqrt_tonelli_shanks_pre(
t.square(z)
t *= a
r = z
var b = t
var root = Fp.C.tonelliShanks(root_of_unity)
var b {.noInit.} = t
var root {.noInit.} = Fp.C.tonelliShanks(root_of_unity)

var buf {.noInit.}: Fp

for i in countdown(e, 2, 1):
b.square_repeated(i-2)
if i-2 >= 1:
b.square_repeated(i-2)

let bNotOne = not b.isOne()
buf.prod(r, root)
Expand Down
Loading

0 comments on commit 5db30ef

Please sign in to comment.