Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

precomp: use normalized extended points #59

Merged
merged 6 commits into from
Oct 21, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
63 changes: 63 additions & 0 deletions bandersnatch/bandersnatch.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,20 +5,24 @@ import (
"io"

gnarkbandersnatch "github.com/consensys/gnark-crypto/ecc/bls12-381/bandersnatch"
gnarkfr "github.com/consensys/gnark-crypto/ecc/bls12-381/fr"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: note that although this is gnarkfr (bls12-381), its fp for bandersnatch, so maybe gnarkfp or something along those lines would be less confusing since we use fp already in this file

"github.com/crate-crypto/go-ipa/bandersnatch/fp"
)

var CurveParams = gnarkbandersnatch.GetEdwardsCurve()

type PointAffine = gnarkbandersnatch.PointAffine
type PointProj = gnarkbandersnatch.PointProj
type PointExtended = gnarkbandersnatch.PointExtended

var Identity = PointProj{
X: fp.Zero(),
Y: fp.One(),
Z: fp.One(),
}

var IdentityExt = PointExtendedFromProj(&Identity)

// Reads an uncompressed affine point
// Point is not guaranteed to be in the prime subgroup
func ReadUncompressedPoint(r io.Reader) (PointAffine, error) {
Expand Down Expand Up @@ -92,3 +96,62 @@ func computeY(x *fp.Element, choose_largest bool) *fp.Element {
return sqrtY.Neg(sqrtY)
}
}

// PointExtendedFromProj converts a point in projective coordinates to extended coordinates.
func PointExtendedFromProj(p *PointProj) PointExtended {
var pzinv fp.Element
pzinv.Inverse(&p.Z)
var z fp.Element
z.Mul(&p.X, &p.Y).Mul(&z, &pzinv)
return PointExtended{
X: p.X,
Y: p.Y,
Z: p.Z,
T: z,
}
}
kevaundray marked this conversation as resolved.
Show resolved Hide resolved

// PointExtendedNormalized is an extended point which is normalized.
// i.e: Z=1. We store it this way to save 32 bytes per point in memory.
type PointExtendedNormalized struct {
X, Y, T gnarkfr.Element
}
kevaundray marked this conversation as resolved.
Show resolved Hide resolved

// Neg computes p = -p1
func (p *PointExtendedNormalized) Neg(p1 *PointExtendedNormalized) *PointExtendedNormalized {
p.X.Neg(&p1.X)
p.Y = p1.Y
p.T.Neg(&p1.T)
return p
}
kevaundray marked this conversation as resolved.
Show resolved Hide resolved

// ExtendedAddNormalized computes p = p1 + p2.
// https://hyperelliptic.org/EFD/g1p/auto-twisted-extended.html#addition-madd-2008-hwcd
func ExtendedAddNormalized(p, p1 *PointExtended, p2 *PointExtendedNormalized) *gnarkbandersnatch.PointExtended {
kevaundray marked this conversation as resolved.
Show resolved Hide resolved
var A, B, C, D, E, F, G, H, tmp gnarkfr.Element
A.Mul(&p1.X, &p2.X)
B.Mul(&p1.Y, &p2.Y)
C.Mul(&p1.T, &p2.T).Mul(&C, &CurveParams.D)
D.Set(&p1.Z)
kevaundray marked this conversation as resolved.
Show resolved Hide resolved
tmp.Add(&p1.X, &p1.Y)
E.Add(&p2.X, &p2.Y).
Mul(&E, &tmp).
Sub(&E, &A).
Sub(&E, &B)
F.Sub(&D, &C)
G.Add(&D, &C)
H.Set(&A)

// mulBy5(&H)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: I'm guessing this comment can be removed or put one line down

H.Neg(&H)
gnarkfr.MulBy5(&H)
kevaundray marked this conversation as resolved.
Show resolved Hide resolved

H.Sub(&B, &H)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for making this match the referenced link including the letters used :)


p.X.Mul(&E, &F)
p.Y.Mul(&G, &H)
p.T.Mul(&E, &H)
p.Z.Mul(&F, &G)

return p
}
78 changes: 62 additions & 16 deletions banderwagon/precomp.go
Original file line number Diff line number Diff line change
Expand Up @@ -61,20 +61,24 @@ func NewPrecompMSM(points []Element) (MSMPrecomp, error) {
// MSM calculates the 256-MSM of the given scalars on the fixed basis.
kevaundray marked this conversation as resolved.
Show resolved Hide resolved
// It automatically detects how many non-zero scalars there are and parallelizes the computation.
func (msm *MSMPrecomp) MSM(scalars []fr.Element) Element {
result := Identity.inner
result := bandersnatch.IdentityExt

for i := range scalars {
if !scalars[i].IsZero() {
msm.precompPoints[i].ScalarMul(scalars[i], &result)
}
}
return Element{inner: result}
return Element{inner: bandersnatch.PointProj{
X: result.X,
Y: result.Y,
Z: result.Z,
}}
kevaundray marked this conversation as resolved.
Show resolved Hide resolved
}

// PrecompPoint is a precomputed table for a single point.
type PrecompPoint struct {
windowSize int
windows [][]bandersnatch.PointAffine
windows [][]bandersnatch.PointExtendedNormalized
kevaundray marked this conversation as resolved.
Show resolved Hide resolved
}

// NewPrecompPoint creates a new PrecompPoint for the given point and window size.
Expand All @@ -88,27 +92,23 @@ func NewPrecompPoint(point Element, windowSize int) (PrecompPoint, error) {

res := PrecompPoint{
windowSize: windowSize,
windows: make([][]bandersnatch.PointAffine, 256/windowSize),
windows: make([][]bandersnatch.PointExtendedNormalized, 256/windowSize),
}

windows := make([][]bandersnatch.PointProj, 256/windowSize)
windows := make([][]bandersnatch.PointExtended, 256/windowSize)
group, _ := errgroup.WithContext(context.Background())
group.SetLimit(runtime.NumCPU())
for i := 0; i < len(res.windows); i++ {
i := i
base := point.inner
base := bandersnatch.PointExtendedFromProj(&point.inner)
group.Go(func() error {
windows[i] = make([]bandersnatch.PointProj, 1<<(windowSize-1))
windows[i] = make([]bandersnatch.PointExtended, 1<<(windowSize-1))
curr := base
for j := 0; j < len(windows[i]); j++ {
windows[i][j] = curr
curr.Add(&curr, &base)
}
batchProjToAffine(windows[i])
res.windows[i] = make([]bandersnatch.PointAffine, 1<<(windowSize-1))
for j := range windows[i] {
res.windows[i][j].FromProj(&windows[i][j])
}
res.windows[i] = batchToExtendedPointNormalized(windows[i])
return nil
})
point.ScalarMul(&point, &specialWindow)
Expand All @@ -121,12 +121,12 @@ func NewPrecompPoint(point Element, windowSize int) (PrecompPoint, error) {
// ScalarMul multiplies the point by the given scalar using the precomputed points.
// It applies a trick to push a carry between windows since our precomputed tables
// avoid storing point inverses.
func (pp *PrecompPoint) ScalarMul(scalar fr.Element, res *bandersnatch.PointProj) {
func (pp *PrecompPoint) ScalarMul(scalar fr.Element, res *bandersnatch.PointExtended) {
numWindowsInLimb := 64 / pp.windowSize

scalar.FromMont()
var carry uint64
var pNeg bandersnatch.PointAffine
var pNeg bandersnatch.PointExtendedNormalized
for l := 0; l < fr.Limbs; l++ {
for w := 0; w < numWindowsInLimb; w++ {
windowValue := (scalar[l]>>(pp.windowSize*w))&((1<<pp.windowSize)-1) + carry
Expand All @@ -139,11 +139,11 @@ func (pp *PrecompPoint) ScalarMul(scalar fr.Element, res *bandersnatch.PointProj
windowValue = (1 << pp.windowSize) - windowValue
if windowValue != 0 {
pNeg.Neg(&pp.windows[l*numWindowsInLimb+w][windowValue-1])
res.MixedAdd(res, &pNeg)
bandersnatch.ExtendedAddNormalized(res, res, &pNeg)
}
carry = 1
} else {
res.MixedAdd(res, &pp.windows[l*numWindowsInLimb+w][windowValue-1])
bandersnatch.ExtendedAddNormalized(res, res, &pp.windows[l*numWindowsInLimb+w][windowValue-1])
kevaundray marked this conversation as resolved.
Show resolved Hide resolved
}
}
}
Expand Down Expand Up @@ -195,3 +195,49 @@ func batchProjToAffine(points []bandersnatch.PointProj) []bandersnatch.PointAffi

return result
}

func batchToExtendedPointNormalized(points []bandersnatch.PointExtended) []bandersnatch.PointExtendedNormalized {
kevaundray marked this conversation as resolved.
Show resolved Hide resolved
result := make([]bandersnatch.PointExtendedNormalized, len(points))
zeroes := make([]bool, len(points))
accumulator := fp.One()

// batch invert all points[].Z coordinates with Montgomery batch inversion trick
// (stores points[].Z^-1 in result[i].X to avoid allocating a slice of fr.Elements)
for i := 0; i < len(points); i++ {
if points[i].Z.IsZero() {
zeroes[i] = true
continue
}
result[i].X = accumulator
accumulator.Mul(&accumulator, &points[i].Z)
}

var accInverse fp.Element
accInverse.Inverse(&accumulator)

for i := len(points) - 1; i >= 0; i-- {
if zeroes[i] {
// do nothing, (X=0, Y=0) is infinity point in affine
continue
}
result[i].X.Mul(&result[i].X, &accInverse)
accInverse.Mul(&accInverse, &points[i].Z)
}

// batch convert to affine.
parallel.Execute(len(points), func(start, end int) {
for i := start; i < end; i++ {
if zeroes[i] {
// do nothing, (X=0, Y=0) is infinity point in affine
continue
}

a := result[i].X
result[i].X.Mul(&points[i].X, &a)
result[i].Y.Mul(&points[i].Y, &a)
kevaundray marked this conversation as resolved.
Show resolved Hide resolved
result[i].T.Mul(&result[i].X, &result[i].Y)
kevaundray marked this conversation as resolved.
Show resolved Hide resolved
}
})

return result
}