Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Low-level refactoring #175

Merged
merged 15 commits into from
Feb 13, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
14 changes: 14 additions & 0 deletions benchmarks/bench_fields_template.nim
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,20 @@ proc sqrBench*(T: typedesc, iters: int) =
bench("Squaring", T, iters):
r.square(x)

proc toBigBench*(T: typedesc, iters: int) =
  ## Benchmark the conversion of a field element of type `T`
  ## into a big integer through `fromField`.
  ##
  ## - `T` is the field type; the destination is a `matchingBigInt(T.C)`
  ## - `iters` is forwarded as-is to `bench`
  let fieldElem = rng.random_unsafe(T)
  var converted: matchingBigInt(T.C)
  # Registered before the timed section, like the sibling benchmarks.
  preventOptimAway(converted)
  bench("BigInt <- field conversion", T, iters):
    converted.fromField(fieldElem)

proc toFieldBench*(T: typedesc, iters: int) =
  ## Benchmark the conversion of a big integer into a field element
  ## of type `T` through `fromBig`.
  ##
  ## - `T` is the field type; the input is a random `matchingBigInt(T.C)`
  ## - `iters` is forwarded as-is to `bench`
  let bigValue = rng.random_unsafe(matchingBigInt(T.C))
  var fieldElem: T
  # Registered before the timed section, like the sibling benchmarks.
  preventOptimAway(fieldElem)
  bench("BigInt -> field conversion", T, iters):
    fieldElem.fromBig(bigValue)

proc invBench*(T: typedesc, iters: int) =
var r: T
let x = rng.random_unsafe(T)
Expand Down
3 changes: 3 additions & 0 deletions benchmarks/bench_fp.nim
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,9 @@ proc main() =
mulBench(Fp[curve], Iters)
sqrBench(Fp[curve], Iters)
smallSeparator()
toBigBench(Fp[curve], Iters)
toFieldBench(Fp[curve], Iters)
smallSeparator()
invBench(Fp[curve], ExponentIters)
sqrtBench(Fp[curve], ExponentIters)
sqrtRatioBench(Fp[curve], ExponentIters)
Expand Down
146 changes: 114 additions & 32 deletions constantine/arithmetic/assembly/limbs_asm_modular_x86.nim
Original file line number Diff line number Diff line change
Expand Up @@ -24,14 +24,111 @@ import
# They are nice for letting the compiler deal with mov,
# but too constraining, so we move things ourselves.

static: doAssert UseASM_X86_64
static: doAssert UseASM_X86_32

{.localPassC:"-fomit-frame-pointer".} # Needed so that the compiler finds enough registers

proc finalSubNoCarryImpl*(
       ctx: var Assembler_x86,
       r: Operand or OperandArray,
       a, M, scratch: OperandArray
     ) =
  ## Emit the final conditional subtraction that reduces `a`
  ## modulo `M` and stores the result in `r`.
  ##
  ## The emitted code computes `scratch = a - M` limb by limb,
  ## then keeps `a` if that subtraction borrowed and replaces it
  ## with `scratch` otherwise, before copying `a` into `r`.
  ##
  ## r, a and scratch are mutated
  ## M is read-only
  let numLimbs = M.len
  ctx.comment "Final substraction (no carry)"

  # scratch <- a - M: the first limb uses a plain `sub`,
  # every following limb chains the borrow with `sbb`.
  for i in 0 ..< numLimbs:
    ctx.mov scratch[i], a[i]
    if i > 0:
      ctx.sbb scratch[i], M[i]
    else:
      ctx.sub scratch[i], M[i]

  # A borrow (carry flag set) means `a` was already smaller than
  # the modulus, so `scratch` is discarded. Without a borrow,
  # `cmovnc` overwrites `a` with the reduced value.
  for i in 0 ..< numLimbs:
    ctx.cmovnc a[i], scratch[i]
    ctx.mov r[i], a[i]

proc finalSubMayCarryImpl*(
       ctx: var Assembler_x86,
       r: Operand or OperandArray,
       a, M, scratch: OperandArray,
       scratchReg: Operand or Register or OperandReuse
     ) =
  ## Emit the final conditional subtraction that reduces `a`
  ## modulo `M` and stores the result in `r`, for the case where
  ## the decision also depends on the carry flag that is live
  ## when this code starts.
  ##
  ## r, a, scratch and scratchReg are mutated
  ## M is read-only

  ctx.comment "Final substraction (may carry)"

  # Capture the incoming carry as a mask:
  # scratchReg becomes all ones if the carry flag was set, zero otherwise.
  ctx.sbb scratchReg, scratchReg

  # scratch <- a - M, to test a < M.
  # The first limb uses `sub`, the following ones chain the borrow.
  let numLimbs = M.len
  for i in 0 ..< numLimbs:
    ctx.mov scratch[i], a[i]
    if i > 0:
      ctx.sbb scratch[i], M[i]
    else:
      ctx.sub scratch[i], M[i]

  # Fold the borrow of the subtraction into the mask.
  # The carry flag ends up set only if there was no incoming carry
  # and `a - M` borrowed, i.e. `a` was already smaller than the modulus.
  ctx.sbb scratchReg, 0

  # With the carry flag set `a` is kept and `scratch` is discarded,
  # otherwise `cmovnc` replaces `a` with the reduced value.
  for i in 0 ..< numLimbs:
    ctx.cmovnc a[i], scratch[i]
    ctx.mov r[i], a[i]

macro finalSub_gen*[N: static int](
       r_PIR: var array[N, SecretWord],
       a_EIR, M_PIR: array[N, SecretWord],
       scratch_EIR: var array[N, SecretWord],
       mayCarry: static bool): untyped =
  ## Generate the assembly of a final conditional subtraction.
  ##
  ## The generated code stores in `r`:
  ##   a-M if subtracting M from a does not borrow
  ##       (or, with `mayCarry`, if the incoming carry flag is set)
  ##   a   otherwise
  ##
  ## - r_PIR is a pointer to the result array, mutated,
  ## - a_EIR is an array of registers, mutated,
  ## - M_PIR is a pointer to an array, read-only,
  ## - scratch_EIR is an array of registers, mutated
  ## - mayCarry is set to true when the carry flag also needs to be read
  var ctx = init(Assembler_x86, BaseType)

  # Operand descriptors, declared in the same order as the parameters.
  let
    dst = init(OperandArray, nimSymbol = r_PIR, N, PointerInReg, InputOutput)
    acc = init(OperandArray, nimSymbol = a_EIR, N, ElemsInReg, InputOutput)
    # We could force the modulus as immediate by specializing per moduli
    modulus = init(OperandArray, nimSymbol = M_PIR, N, PointerInReg, Input)
    tmp = init(OperandArray, nimSymbol = scratch_EIR, N, ElemsInReg, Output_EarlyClobber)

  if mayCarry:
    # `rax` is handed over as the scratch register holding the carry mask.
    ctx.finalSubMayCarryImpl(dst, acc, modulus, tmp, rax)
  else:
    ctx.finalSubNoCarryImpl(dst, acc, modulus, tmp)

  result = newStmtList()
  result.add ctx.generate()

# Field addition
# ------------------------------------------------------------

macro addmod_gen[N: static int](R: var Limbs[N], A, B, m: Limbs[N]): untyped =
macro addmod_gen[N: static int](R: var Limbs[N], A, B, m: Limbs[N], hasSpareBit: static bool): untyped =
## Generate an optimized modular addition kernel
# Register pressure note:
# We could generate a kernel per modulus m by hardcoding it as immediate
Expand Down Expand Up @@ -68,33 +165,18 @@ macro addmod_gen[N: static int](R: var Limbs[N], A, B, m: Limbs[N]): untyped =
# Interleaved copy in a second buffer as well
ctx.mov v[i], u[i]

# Mask: overflowed contains 0xFFFF or 0x0000
# TODO: unnecessary if MSB never set, i.e. "Field.getSpareBits >= 1"
let overflowed = b.reuseRegister()
ctx.sbb overflowed, overflowed
if hasSparebit:
ctx.finalSubNoCarryImpl(r, u, M, v)
else:
ctx.finalSubMayCarryImpl(
r, u, M, v, b.reuseRegister()
)

# Now substract the modulus to test a < p
for i in 0 ..< N:
if i == 0:
ctx.sub v[0], M[0]
else:
ctx.sbb v[i], M[i]

# If it overflows here, it means that it was
# smaller than the modulus and we don't need V
ctx.sbb overflowed, 0

# Conditional Mov and
# and store result
for i in 0 ..< N:
ctx.cmovnc u[i], v[i]
ctx.mov r[i], u[i]

result.add ctx.generate
result.add ctx.generate()

func addmod_asm*(r: var Limbs, a, b, m: Limbs) =
func addmod_asm*(r: var Limbs, a, b, m: Limbs, hasSpareBit: static bool) =
## Constant-time modular addition
addmod_gen(r, a, b, m)
addmod_gen(r, a, b, m, hasSpareBit)

# Field subtraction
# ------------------------------------------------------------
Expand Down Expand Up @@ -225,7 +307,7 @@ when isMainModule:
debugecho " a: ", a.toHex()
debugecho " b: ", b.toHex()
debugecho " m: ", m.toHex()
addmod_asm(a, a, b, m)
addmod_asm(a, a, b, m, hasSpareBit = false)
debugecho "after:"
debugecho " a: ", a.toHex().tolower
debugecho " s: ", s
Expand All @@ -245,7 +327,7 @@ when isMainModule:
debugecho " a: ", a.toHex()
debugecho " b: ", b.toHex()
debugecho " m: ", m.toHex()
addmod_asm(a, a, b, m)
addmod_asm(a, a, b, m, hasSpareBit = false)
debugecho "after:"
debugecho " a: ", a.toHex().tolower
debugecho " s: ", s
Expand All @@ -265,7 +347,7 @@ when isMainModule:
debugecho " a: ", a.toHex()
debugecho " b: ", b.toHex()
debugecho " m: ", m.toHex()
addmod_asm(a, a, b, m)
addmod_asm(a, a, b, m, hasSpareBit = false)
debugecho "after:"
debugecho " a: ", a.toHex().tolower
debugecho " s: ", s
Expand All @@ -285,7 +367,7 @@ when isMainModule:
debugecho " a: ", a.toHex()
debugecho " b: ", b.toHex()
debugecho " m: ", m.toHex()
addmod_asm(a, a, b, m)
addmod_asm(a, a, b, m, hasSpareBit = false)
debugecho "after:"
debugecho " a: ", a.toHex().tolower
debugecho " s: ", s
Expand All @@ -308,7 +390,7 @@ when isMainModule:
debugecho " a: ", a.toHex()
debugecho " b: ", b.toHex()
debugecho " m: ", m.toHex()
submod_asm(a, a, b, m)
submod_asm(a, a, b, m, hasSpareBit = false)
debugecho "after:"
debugecho " a: ", a.toHex().tolower
debugecho " s: ", s
Expand All @@ -333,7 +415,7 @@ when isMainModule:
debugecho " a: ", a.toHex()
debugecho " b: ", b.toHex()
debugecho " m: ", m.toHex()
submod_asm(r, a, b, m)
submod_asm(r, a, b, m, hasSpareBit = false)
debugecho "after:"
debugecho " r: ", r.toHex().tolower
debugecho " s: ", s
Expand Down