Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Low-level refactor part 2 #176

Merged
merged 1 commit into from
Feb 14, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 7 additions & 7 deletions constantine/arithmetic/assembly/limbs_asm_mul_mont_x86.nim
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ static: doAssert UseASM_X86_64
macro mulMont_CIOS_sparebit_gen[N: static int](
r_PIR: var Limbs[N], a_PIR, b_PIR,
M_PIR: Limbs[N], m0ninv_REG: BaseType,
skipReduction: static bool
skipFinalSub: static bool
): untyped =
## Generate an optimized Montgomery Multiplication kernel
## using the CIOS method
Expand Down Expand Up @@ -175,7 +175,7 @@ macro mulMont_CIOS_sparebit_gen[N: static int](
ctx.mov rax, r # move r away from scratchspace that will be used for final subtraction
let r2 = rax.asArrayAddr(len = N)

if skipReduction:
if skipFinalSub:
for i in 0 ..< N:
ctx.mov r2[i], t[i]
else:
Expand All @@ -185,14 +185,14 @@ macro mulMont_CIOS_sparebit_gen[N: static int](
)
result.add ctx.generate()

func mulMont_CIOS_sparebit_asm*(r: var Limbs, a, b, M: Limbs, m0ninv: BaseType, skipReduction: static bool = false) =
func mulMont_CIOS_sparebit_asm*(r: var Limbs, a, b, M: Limbs, m0ninv: BaseType, skipFinalSub: static bool = false) =
## Constant-time Montgomery multiplication
## If "skipReduction" is set
## If "skipFinalSub" is set
## the result is in the range [0, 2M)
## otherwise the result is in the range [0, M)
##
## This procedure can only be called if the modulus doesn't use the full bitwidth of its underlying representation
r.mulMont_CIOS_sparebit_gen(a, b, M, m0ninv, skipReduction)
r.mulMont_CIOS_sparebit_gen(a, b, M, m0ninv, skipFinalSub)

# Montgomery Squaring
# ------------------------------------------------------------
Expand All @@ -209,8 +209,8 @@ func squareMont_CIOS_asm*[N](
r: var Limbs[N],
a, M: Limbs[N],
m0ninv: BaseType,
hasSpareBit, skipReduction: static bool) =
hasSpareBit, skipFinalSub: static bool) =
## Constant-time modular squaring
var r2x {.noInit.}: Limbs[2*N]
r2x.square_asm_inline(a)
r.redcMont_asm_inline(r2x, M, m0ninv, hasSpareBit, skipReduction)
r.redcMont_asm_inline(r2x, M, m0ninv, hasSpareBit, skipFinalSub)
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,7 @@ proc partialRedx(
macro mulMont_CIOS_sparebit_adx_gen[N: static int](
r_PIR: var Limbs[N], a_PIR, b_PIR,
M_PIR: Limbs[N], m0ninv_REG: BaseType,
skipReduction: static bool): untyped =
skipFinalSub: static bool): untyped =
## Generate an optimized Montgomery Multiplication kernel
## using the CIOS method
## This requires the most significant word of the Modulus
Expand Down Expand Up @@ -268,7 +268,7 @@ macro mulMont_CIOS_sparebit_adx_gen[N: static int](
lo, C
)

if skipReduction:
if skipFinalSub:
for i in 0 ..< N:
ctx.mov r[i], t[i]
else:
Expand All @@ -279,14 +279,14 @@ macro mulMont_CIOS_sparebit_adx_gen[N: static int](

result.add ctx.generate

func mulMont_CIOS_sparebit_asm_adx*(r: var Limbs, a, b, M: Limbs, m0ninv: BaseType, skipReduction: static bool = false) =
func mulMont_CIOS_sparebit_asm_adx*(r: var Limbs, a, b, M: Limbs, m0ninv: BaseType, skipFinalSub: static bool = false) =
## Constant-time Montgomery multiplication
## If "skipReduction" is set
## If "skipFinalSub" is set
## the result is in the range [0, 2M)
## otherwise the result is in the range [0, M)
##
## This procedure can only be called if the modulus doesn't use the full bitwidth of its underlying representation
r.mulMont_CIOS_sparebit_adx_gen(a, b, M, m0ninv, skipReduction)
r.mulMont_CIOS_sparebit_adx_gen(a, b, M, m0ninv, skipFinalSub)

# Montgomery Squaring
# ------------------------------------------------------------
Expand All @@ -295,8 +295,8 @@ func squareMont_CIOS_asm_adx*[N](
r: var Limbs[N],
a, M: Limbs[N],
m0ninv: BaseType,
hasSpareBit, skipReduction: static bool) =
hasSpareBit, skipFinalSub: static bool) =
## Constant-time modular squaring
var r2x {.noInit.}: Limbs[2*N]
r2x.square_asm_adx_inline(a)
r.redcMont_asm_adx(r2x, M, m0ninv, hasSpareBit, skipReduction)
r.redcMont_asm_adx(r2x, M, m0ninv, hasSpareBit, skipFinalSub)
16 changes: 8 additions & 8 deletions constantine/arithmetic/assembly/limbs_asm_redc_mont_x86.nim
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ macro redc2xMont_gen*[N: static int](
a_PIR: array[N*2, SecretWord],
M_PIR: array[N, SecretWord],
m0ninv_REG: BaseType,
hasSpareBit, skipReduction: static bool
hasSpareBit, skipFinalSub: static bool
) =

# No register spilling handling
Expand Down Expand Up @@ -153,7 +153,7 @@ macro redc2xMont_gen*[N: static int](
# v is invalidated from now on
let t = repackRegisters(v, u[N], u[N+1])

if hasSpareBit and skipReduction:
if hasSpareBit and skipFinalSub:
for i in 0 ..< N:
ctx.mov r[i], t[i]
elif hasSpareBit:
Expand All @@ -170,22 +170,22 @@ func redcMont_asm_inline*[N: static int](
M: array[N, SecretWord],
m0ninv: BaseType,
hasSpareBit: static bool,
skipReduction: static bool = false
skipFinalSub: static bool = false
) {.inline.} =
## Constant-time Montgomery reduction
## Inline-version
redc2xMont_gen(r, a, M, m0ninv, hasSpareBit, skipReduction)
redc2xMont_gen(r, a, M, m0ninv, hasSpareBit, skipFinalSub)

func redcMont_asm*[N: static int](
r: var array[N, SecretWord],
a: array[N*2, SecretWord],
M: array[N, SecretWord],
m0ninv: BaseType,
hasSpareBit, skipReduction: static bool
hasSpareBit, skipFinalSub: static bool
) =
## Constant-time Montgomery reduction
static: doAssert UseASM_X86_64, "This requires x86-64."
redcMont_asm_inline(r, a, M, m0ninv, hasSpareBit, skipReduction)
redcMont_asm_inline(r, a, M, m0ninv, hasSpareBit, skipFinalSub)

# Montgomery conversion
# ----------------------------------------------------------
Expand Down Expand Up @@ -351,8 +351,8 @@ when isMainModule:
var a_sqr{.noInit.}, na_sqr{.noInit.}: Limbs[2]
var a_sqr_comba{.noInit.}, na_sqr_comba{.noInit.}: Limbs[2]

a_sqr.redcMont_asm(adbl_sqr, M, 1, hasSpareBit = false, skipReduction = false)
na_sqr.redcMont_asm(nadbl_sqr, M, 1, hasSpareBit = false, skipReduction = false)
a_sqr.redcMont_asm(adbl_sqr, M, 1, hasSpareBit = false, skipFinalSub = false)
na_sqr.redcMont_asm(nadbl_sqr, M, 1, hasSpareBit = false, skipFinalSub = false)
a_sqr_comba.redc2xMont_Comba(adbl_sqr, M, 1)
na_sqr_comba.redc2xMont_Comba(nadbl_sqr, M, 1)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ macro redc2xMont_adx_gen[N: static int](
a_PIR: array[N*2, SecretWord],
M_PIR: array[N, SecretWord],
m0ninv_REG: BaseType,
hasSpareBit, skipReduction: static bool
hasSpareBit, skipFinalSub: static bool
) =

# No register spilling handling
Expand Down Expand Up @@ -131,7 +131,7 @@ macro redc2xMont_adx_gen[N: static int](

let t = repackRegisters(v, u[N])

if hasSpareBit and skipReduction:
if hasSpareBit and skipFinalSub:
for i in 0 ..< N:
ctx.mov r[i], t[i]
elif hasSpareBit:
Expand All @@ -148,22 +148,22 @@ func redcMont_asm_adx_inline*[N: static int](
M: array[N, SecretWord],
m0ninv: BaseType,
hasSpareBit: static bool,
skipReduction: static bool = false
skipFinalSub: static bool = false
) {.inline.} =
## Constant-time Montgomery reduction
## Inline-version
redc2xMont_adx_gen(r, a, M, m0ninv, hasSpareBit, skipReduction)
redc2xMont_adx_gen(r, a, M, m0ninv, hasSpareBit, skipFinalSub)

func redcMont_asm_adx*[N: static int](
r: var array[N, SecretWord],
a: array[N*2, SecretWord],
M: array[N, SecretWord],
m0ninv: BaseType,
hasSpareBit: static bool,
skipReduction: static bool = false
skipFinalSub: static bool = false
) =
## Constant-time Montgomery reduction
redcMont_asm_adx_inline(r, a, M, m0ninv, hasSpareBit, skipReduction)
redcMont_asm_adx_inline(r, a, M, m0ninv, hasSpareBit, skipFinalSub)


# Montgomery conversion
Expand Down
12 changes: 6 additions & 6 deletions constantine/arithmetic/bigints_montgomery.nim
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ import
#
# ############################################################

func getMont*(mres: var BigInt, a, N, r2modM: BigInt, m0ninv: static BaseType, spareBits: static int) =
func getMont*(mres: var BigInt, a, N, r2modM: BigInt, m0ninv: BaseType, spareBits: static int) =
## Convert a BigInt from its natural representation
## to the Montgomery residue form
##
Expand All @@ -41,7 +41,7 @@ func getMont*(mres: var BigInt, a, N, r2modM: BigInt, m0ninv: static BaseType, s
## and R = (2^WordBitWidth)^W
getMont(mres.limbs, a.limbs, N.limbs, r2modM.limbs, m0ninv, spareBits)

func fromMont*[mBits](r: var BigInt[mBits], a, M: BigInt[mBits], m0ninv: static BaseType, spareBits: static int) =
func fromMont*[mBits](r: var BigInt[mBits], a, M: BigInt[mBits], m0ninv: BaseType, spareBits: static int) =
## Convert a BigInt from its Montgomery residue form
## to the natural representation
##
Expand All @@ -52,20 +52,20 @@ func fromMont*[mBits](r: var BigInt[mBits], a, M: BigInt[mBits], m0ninv: static
fromMont(r.limbs, a.limbs, M.limbs, m0ninv, spareBits)

func mulMont*(r: var BigInt, a, b, M: BigInt, negInvModWord: static BaseType,
spareBits: static int, skipReduction: static bool = false) =
spareBits: static int, skipFinalSub: static bool = false) =
## Compute r <- a*b (mod M) in the Montgomery domain
##
## This resets r to zero before processing. Use {.noInit.}
## to avoid duplicating with Nim zero-init policy
mulMont(r.limbs, a.limbs, b.limbs, M.limbs, negInvModWord, spareBits, skipReduction)
mulMont(r.limbs, a.limbs, b.limbs, M.limbs, negInvModWord, spareBits, skipFinalSub)

func squareMont*(r: var BigInt, a, M: BigInt, negInvModWord: static BaseType,
spareBits: static int, skipReduction: static bool = false) =
spareBits: static int, skipFinalSub: static bool = false) =
## Compute r <- a^2 (mod M) in the Montgomery domain
##
## This resets r to zero before processing. Use {.noInit.}
## to avoid duplicating with Nim zero-init policy
squareMont(r.limbs, a.limbs, M.limbs, negInvModWord, spareBits, skipReduction)
squareMont(r.limbs, a.limbs, M.limbs, negInvModWord, spareBits, skipFinalSub)

func powMont*[mBits: static int](
a: var BigInt[mBits], exponent: openarray[byte],
Expand Down
90 changes: 59 additions & 31 deletions constantine/arithmetic/finite_fields.nim
Original file line number Diff line number Diff line change
Expand Up @@ -214,14 +214,14 @@ func double*(r: var FF, a: FF) {.meter.} =
overflowed = overflowed or not(r.mres < FF.fieldMod())
discard csub(r.mres, FF.fieldMod(), overflowed)

func prod*(r: var FF, a, b: FF, skipReduction: static bool = false) {.meter.} =
func prod*(r: var FF, a, b: FF, skipFinalSub: static bool = false) {.meter.} =
## Store the product of ``a`` by ``b`` modulo p into ``r``
## ``r`` is initialized / overwritten
r.mres.mulMont(a.mres, b.mres, FF.fieldMod(), FF.getNegInvModWord(), FF.getSpareBits(), skipReduction)
r.mres.mulMont(a.mres, b.mres, FF.fieldMod(), FF.getNegInvModWord(), FF.getSpareBits(), skipFinalSub)

func square*(r: var FF, a: FF, skipReduction: static bool = false) {.meter.} =
func square*(r: var FF, a: FF, skipFinalSub: static bool = false) {.meter.} =
## Squaring modulo p
r.mres.squareMont(a.mres, FF.fieldMod(), FF.getNegInvModWord(), FF.getSpareBits(), skipReduction)
r.mres.squareMont(a.mres, FF.fieldMod(), FF.getNegInvModWord(), FF.getSpareBits(), skipFinalSub)

func neg*(r: var FF, a: FF) {.meter.} =
## Negate modulo p
Expand Down Expand Up @@ -413,38 +413,23 @@ func `*=`*(a: var FF, b: FF) {.meter.} =
## Multiplication modulo p
a.prod(a, b)

func square*(a: var FF, skipReduction: static bool = false) {.meter.} =
func square*(a: var FF, skipFinalSub: static bool = false) {.meter.} =
## Squaring modulo p
a.square(a, skipReduction)
a.square(a, skipFinalSub)

func square_repeated*(a: var FF, num: int, skipReduction: static bool = false) {.meter.} =
func square_repeated*(a: var FF, num: int, skipFinalSub: static bool = false) {.meter.} =
## Repeated squarings
# Except in Tonelli-Shanks, num is always known at compile-time
# and square repeated is inlined, so the compiler should optimize the branches away.

# TODO: understand the conditions to avoid the final substraction
for _ in 0 ..< num:
a.square(skipReduction = false)

func square_repeated*(r: var FF, a: FF, num: int, skipReduction: static bool = false) {.meter.} =
## Repeated squarings

# TODO: understand the conditions to avoid the final substraction
r.square(a)
for _ in 1 ..< num:
r.square()

func square_repeated_then_mul*(a: var FF, num: int, b: FF, skipReduction: static bool = false) {.meter.} =
## Square `a`, `num` times and then multiply by b
## Assumes at least 1 squaring
a.square_repeated(num, skipReduction = false)
a.prod(a, b, skipReduction = skipReduction)
for _ in 0 ..< num-1:
a.square(skipFinalSub = true)
a.square(skipFinalSub)

func square_repeated_then_mul*(r: var FF, a: FF, num: int, b: FF, skipReduction: static bool = false) {.meter.} =
## Square `a`, `num` times and then multiply by b
## Assumes at least 1 squaring
r.square_repeated(a, num, skipReduction = false)
r.prod(r, b, skipReduction = skipReduction)
func square_repeated*(r: var FF, a: FF, num: int, skipFinalSub: static bool = false) {.meter.} =
  ## Repeated squarings: square `a` `num` times and store the result in `r`.
  ##
  ## At least one squaring is always performed, so `num <= 1` behaves
  ## like a single `r.square(a, skipFinalSub)`.
  ##
  ## Every squaring except the last one is called with `skipFinalSub = true`;
  ## only the last squaring honours the caller-supplied `skipFinalSub`.
  if num <= 1:
    # A lone squaring is both the first and the last one.
    # Without this branch the first and the trailing squaring would both run
    # while the middle loop is empty, i.e. `a` would be squared twice.
    r.square(a, skipFinalSub)
  else:
    # First squaring reads from `a`, all following ones work in-place on `r`.
    r.square(a, skipFinalSub = true)
    # Middle squarings: num - 2 iterations.
    for _ in 1 ..< num-1:
      r.square(skipFinalSub = true)
    # Last squaring decides whether the final subtraction is skipped.
    r.square(skipFinalSub)

func `*=`*(a: var FF, b: static int) =
## Multiplication by a small integer known at compile-time
Expand Down Expand Up @@ -550,3 +535,46 @@ template mulCheckSparse*(a: var Fp, b: Fp) =

{.pop.} # inline
{.pop.} # raises no exceptions

# ############################################################
#
# Field arithmetic ergonomic macros
#
# ############################################################

import std/macros

macro addchain*(fn: untyped): untyped =
  ## Rewrite the body of a straight-line `func` so that every `prod`, `*=`,
  ## `square` and `square_repeated` statement is asked to skip its final
  ## subtraction, except the very last such statement.
  ##
  ## - Method-call statements (`r.prod(a, b)`, `r.square(a)`,
  ##   `r.square_repeated(n)`) get a trailing `true` argument appended.
  ## - `a *= b` is rewritten to `prod(a, a, b, true)`.
  ## - Comments and `var` sections are copied unchanged.
  ##
  ## The last operation is located by statement kind (call or infix) rather
  ## than by position, so a trailing comment or `var` section does not cause
  ## the real final operation to be rewritten.
  ##
  ## This assumes straight-line code: any other statement kind is rejected
  ## at compile time.
  fn.expectKind(nnkFuncDef)

  let stmts = fn[^1]

  # Index of the last call/infix statement; it must be left untouched.
  # Stays -1 when the body contains no operation at all.
  var lastOp = -1
  for i, statement in stmts:
    if statement.kind in {nnkCall, nnkInfix}:
      lastOp = i

  result = fn
  var body = newStmtList()

  for i, statement in stmts:
    statement.expectKind({nnkCommentStmt, nnkVarSection, nnkCall, nnkInfix})

    var s = statement.copyNimTree()
    if i != lastOp:
      # Modify all operations but the last one
      if s.kind == nnkCall:
        doAssert s[0].kind == nnkDotExpr, "Only method call syntax or infix syntax is supported in addition chains"
        doAssert s[0][1].eqIdent"prod" or s[0][1].eqIdent"square" or s[0][1].eqIdent"square_repeated",
          "Only prod, square and square_repeated calls are supported in addition chains"
        # Append the trailing `true` argument to the call
        s.add newLit(true)
      elif s.kind == nnkInfix:
        doAssert s[0].eqIdent"*=", "Only the `*=` infix operator is supported in addition chains"
        # a *= b -> prod(a, a, b, true)
        s = newCall(
          bindSym"prod",
          s[1],
          s[1],
          s[2],
          newLit(true)
        )

    body.add s

  result[^1] = body
  # echo result.toStrLit()
7 changes: 4 additions & 3 deletions constantine/arithmetic/finite_fields_square_root.nim
Original file line number Diff line number Diff line change
Expand Up @@ -194,13 +194,14 @@ func invsqrt_tonelli_shanks_pre(
t.square(z)
t *= a
r = z
var b = t
var root = Fp.C.tonelliShanks(root_of_unity)
var b {.noInit.} = t
var root {.noInit.} = Fp.C.tonelliShanks(root_of_unity)

var buf {.noInit.}: Fp

for i in countdown(e, 2, 1):
b.square_repeated(i-2)
if i-2 >= 1:
b.square_repeated(i-2)

let bNotOne = not b.isOne()
buf.prod(r, root)
Expand Down
Loading