Skip to content
This repository has been archived by the owner on Sep 8, 2022. It is now read-only.

Commit

Permalink
feat: add batching for bulletproofs (#58)
Browse files Browse the repository at this point in the history
* add method for getting blinding factors

* add getter for blinding factors in bbs+ sigs

* add batch verification & update bulletproofs

* get rid of blinding helper

* update comments
  • Loading branch information
cbdnyu committed Jun 2, 2022
1 parent 269410e commit 60eddc5
Show file tree
Hide file tree
Showing 14 changed files with 864 additions and 131 deletions.
2 changes: 1 addition & 1 deletion pkg/bulletproof/generators.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ import (
"github.com/coinbase/kryptology/pkg/core/curves"
)

// generators contains a list of points to be used as generators for bulletproofs
// generators contains a list of points to be used as generators for bulletproofs.
type generators []curves.Point

// ippGenerators holds generators necessary for an Inner Product Proof
Expand Down
12 changes: 6 additions & 6 deletions pkg/bulletproof/generators_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"testing"

"github.com/stretchr/testify/require"
"golang.org/x/crypto/sha3"

"github.com/coinbase/kryptology/pkg/core/curves"
)
Expand All @@ -30,14 +31,13 @@ func TestGeneratorsUniquePerDomain(t *testing.T) {
}

func noDuplicates(gs generators) bool {
var seen []curves.Point
seen := map[[32]byte]bool{}
for _, G := range gs {
for _, seenG := range seen {
if seenG.Equal(G) {
return false
}
value := sha3.Sum256(G.ToAffineCompressed())
if seen[value] {
return false
}
seen = append(seen, G)
seen[value] = true
}
return true
}
Expand Down
41 changes: 20 additions & 21 deletions pkg/bulletproof/helpers.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ import (
"github.com/coinbase/kryptology/pkg/core/curves"
)

// innerProduct takes two lists of scalars (a, b) and performs the dot product returning a single scalar
// innerProduct takes two lists of scalars (a, b) and performs the dot product returning a single scalar.
func innerProduct(a, b []curves.Scalar) (curves.Scalar, error) {
if len(a) != len(b) {
return nil, errors.New("length of scalar vectors must be the same")
Expand All @@ -31,7 +31,7 @@ func innerProduct(a, b []curves.Scalar) (curves.Scalar, error) {
return innerProduct, nil
}

// splitPointVector takes a vector of points, splits it in half returning each half
// splitPointVector takes a vector of points, splits it in half returning each half.
func splitPointVector(points []curves.Point) ([]curves.Point, []curves.Point, error) {
if len(points) < 1 {
return nil, nil, errors.New("length of points must be at least one")
Expand All @@ -45,7 +45,7 @@ func splitPointVector(points []curves.Point) ([]curves.Point, []curves.Point, er
return firstHalf, secondHalf, nil
}

// splitScalarVector takes a vector of scalars, splits it in half returning each half
// splitScalarVector takes a vector of scalars, splits it in half returning each half.
func splitScalarVector(scalars []curves.Scalar) ([]curves.Scalar, []curves.Scalar, error) {
if len(scalars) < 1 {
return nil, nil, errors.New("length of scalars must be at least one")
Expand All @@ -59,7 +59,7 @@ func splitScalarVector(scalars []curves.Scalar) ([]curves.Scalar, []curves.Scala
return firstHalf, secondHalf, nil
}

// multiplyScalarToPointVector takes a single scalar and a list of points, multiplies each point by scalar
// multiplyScalarToPointVector takes a single scalar and a list of points, multiplies each point by scalar.
func multiplyScalarToPointVector(x curves.Scalar, g []curves.Point) []curves.Point {
products := make([]curves.Point, len(g))
for i, gElem := range g {
Expand All @@ -70,7 +70,7 @@ func multiplyScalarToPointVector(x curves.Scalar, g []curves.Point) []curves.Poi
return products
}

// multiplyScalarToScalarVector takes a single scalar (x) and a list of scalars (a), multiplies each scalar in the vector by the scalar
// multiplyScalarToScalarVector takes a single scalar (x) and a list of scalars (a), multiplies each scalar in the vector by the scalar.
func multiplyScalarToScalarVector(x curves.Scalar, a []curves.Scalar) []curves.Scalar {
products := make([]curves.Scalar, len(a))
for i, aElem := range a {
Expand All @@ -81,7 +81,7 @@ func multiplyScalarToScalarVector(x curves.Scalar, a []curves.Scalar) []curves.S
return products
}

// multiplyPairwisePointVectors takes two lists of points (g, h) and performs a pairwise multiplication returning a list of points
// multiplyPairwisePointVectors takes two lists of points (g, h) and performs a pairwise multiplication returning a list of points.
func multiplyPairwisePointVectors(g, h []curves.Point) ([]curves.Point, error) {
if len(g) != len(h) {
return nil, errors.New("length of point vectors must be the same")
Expand All @@ -94,7 +94,7 @@ func multiplyPairwisePointVectors(g, h []curves.Point) ([]curves.Point, error) {
return product, nil
}

// multiplyPairwiseScalarVectors takes two lists of points (a, b) and performs a pairwise multiplication returning a list of scalars
// multiplyPairwiseScalarVectors takes two lists of points (a, b) and performs a pairwise multiplication returning a list of scalars.
func multiplyPairwiseScalarVectors(a, b []curves.Scalar) ([]curves.Scalar, error) {
if len(a) != len(b) {
return nil, errors.New("length of point vectors must be the same")
Expand All @@ -107,7 +107,7 @@ func multiplyPairwiseScalarVectors(a, b []curves.Scalar) ([]curves.Scalar, error
return product, nil
}

// addPairwiseScalarVectors takes two lists of scalars (a, b) and performs a pairwise addition returning a list of scalars
// addPairwiseScalarVectors takes two lists of scalars (a, b) and performs a pairwise addition returning a list of scalars.
func addPairwiseScalarVectors(a, b []curves.Scalar) ([]curves.Scalar, error) {
if len(a) != len(b) {
return nil, errors.New("length of scalar vectors must be the same")
Expand All @@ -120,7 +120,7 @@ func addPairwiseScalarVectors(a, b []curves.Scalar) ([]curves.Scalar, error) {
return sum, nil
}

// subtractPairwiseScalarVectors takes two lists of scalars (a, b) and performs a pairwise subtraction returning a list of scalars
// subtractPairwiseScalarVectors takes two lists of scalars (a, b) and performs a pairwise subtraction returning a list of scalars.
func subtractPairwiseScalarVectors(a, b []curves.Scalar) ([]curves.Scalar, error) {
if len(a) != len(b) {
return nil, errors.New("length of scalar vectors must be the same")
Expand All @@ -132,7 +132,7 @@ func subtractPairwiseScalarVectors(a, b []curves.Scalar) ([]curves.Scalar, error
return diff, nil
}

// invertScalars takes a list of scalars then returns a list with each element inverted
// invertScalars takes a list of scalars then returns a list with each element inverted.
func invertScalars(xs []curves.Scalar) ([]curves.Scalar, error) {
xinvs := make([]curves.Scalar, len(xs))
for i, x := range xs {
Expand All @@ -146,36 +146,35 @@ func invertScalars(xs []curves.Scalar) ([]curves.Scalar, error) {
return xinvs, nil
}

// isPowerOfTwo returns whether a number i is a power of two or not
// isPowerOfTwo returns whether a number i is a power of two or not.
func isPowerOfTwo(i int) bool {
return i&(i-1) == 0
}

// get2nVector returns a scalar vector 2^n such that [1, 2, 4, ... 2^(n-1)]
// See k^n and 2^n definitions on pg 12 of https://eprint.iacr.org/2017/1066.pdf
func get2nVector(len int, curve curves.Curve) []curves.Scalar {
vector2n := make([]curves.Scalar, len)
func get2nVector(length int, curve curves.Curve) []curves.Scalar {
vector2n := make([]curves.Scalar, length)
vector2n[0] = curve.Scalar.One()
vector2n[1] = vector2n[0].Double()
for i := 2; i < len; i++ {
for i := 1; i < length; i++ {
vector2n[i] = vector2n[i-1].Double()
}
return vector2n
}

func get1nVector(len int, curve curves.Curve) []curves.Scalar {
vector1n := make([]curves.Scalar, len)
for i := 0; i < len; i++ {
func get1nVector(length int, curve curves.Curve) []curves.Scalar {
vector1n := make([]curves.Scalar, length)
for i := 0; i < length; i++ {
vector1n[i] = curve.Scalar.One()
}
return vector1n
}

func getknVector(k curves.Scalar, len int, curve curves.Curve) []curves.Scalar {
vectorkn := make([]curves.Scalar, len)
func getknVector(k curves.Scalar, length int, curve curves.Curve) []curves.Scalar {
vectorkn := make([]curves.Scalar, length)
vectorkn[0] = curve.Scalar.One()
vectorkn[1] = k
for i := 2; i < len; i++ {
for i := 2; i < length; i++ {
vectorkn[i] = vectorkn[i-1].Mul(k)
}
return vectorkn
Expand Down
16 changes: 8 additions & 8 deletions pkg/bulletproof/helpers_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -68,18 +68,18 @@ func TestSplitListofPointsZeroLength(t *testing.T) {
require.Error(t, err)
}

func randScalarVec(len int, curve curves.Curve) []curves.Scalar {
var out []curves.Scalar
for i := 0; i < len; i++ {
out = append(out, curve.Scalar.Random(crand.Reader))
func randScalarVec(length int, curve curves.Curve) []curves.Scalar {
out := make([]curves.Scalar, length)
for i := 0; i < length; i++ {
out[i] = curve.Scalar.Random(crand.Reader)
}
return out
}

func randPointVec(len int, curve curves.Curve) []curves.Point {
var out []curves.Point
for i := 0; i < len; i++ {
out = append(out, curve.Point.Random(crand.Reader))
func randPointVec(length int, curve curves.Curve) []curves.Point {
out := make([]curves.Point, length)
for i := 0; i < length; i++ {
out[i] = curve.Point.Random(crand.Reader)
}
return out
}
38 changes: 20 additions & 18 deletions pkg/bulletproof/ipp_prover.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ import (

// InnerProductProver is the struct used to create InnerProductProofs
// It specifies which curve to use and holds precomputed generators
// See NewInnerProductProver() for prover initialization
// See NewInnerProductProver() for prover initialization.
type InnerProductProver struct {
curve curves.Curve
generators ippGenerators
Expand Down Expand Up @@ -47,7 +47,7 @@ type ippRecursion struct {
// NewInnerProductProver initializes a new prover
// It uses the specified domain to generate generators for vectors of at most maxVectorLength
// A prover can be used to construct inner product proofs for vectors of length less than or equal to maxVectorLength
// A prover is defined by an explicit curve
// A prover is defined by an explicit curve.
func NewInnerProductProver(maxVectorLength int, domain []byte, curve curves.Curve) (*InnerProductProver, error) {
generators, err := getGeneratorPoints(maxVectorLength, domain, curve)
if err != nil {
Expand All @@ -57,7 +57,7 @@ func NewInnerProductProver(maxVectorLength int, domain []byte, curve curves.Curv
}

// NewInnerProductProof initializes a new InnerProductProof for a specified curve
// This should be used in tandem with UnmarshalBinary() to convert a marshaled proof into the struct
// This should be used in tandem with UnmarshalBinary() to convert a marshaled proof into the struct.
func NewInnerProductProof(curve *curves.Curve) *InnerProductProof {
var capLs, capRs []curves.Point
newProof := InnerProductProof{
Expand All @@ -73,7 +73,7 @@ func NewInnerProductProof(curve *curves.Curve) *InnerProductProof {
// rangeToIPP takes the output of a range proof and converts it into an inner product proof
// See section 4.2 on pg 20
// The conversion specifies generators to use (g and hPrime), as well as the two vectors l, r of which the inner product is tHat
// Additionally, note that the P used for the IPP is in fact P*h^-mu from the range proof
// Additionally, note that the P used for the IPP is in fact P*h^-mu from the range proof.
func (prover *InnerProductProver) rangeToIPP(proofG, proofH []curves.Point, l, r []curves.Scalar, tHat curves.Scalar, capPhmuinv, u curves.Point, transcript *merlin.Transcript) (*InnerProductProof, error) {
// Note that P as a witness is only g^l * h^r
// P needs to be in the form of g^l * h^r * u^<l,r>
Expand Down Expand Up @@ -131,7 +131,7 @@ func (prover *InnerProductProver) getP(a, b []curves.Scalar, u curves.Point) (cu

// Prove executes the prover protocol on pg 16 of https://eprint.iacr.org/2017/1066.pdf
// It generates an inner product proof for vectors a and b, using u to blind the inner product in P
// A transcript is used for the Fiat Shamir heuristic
// A transcript is used for the Fiat Shamir heuristic.
func (prover *InnerProductProver) Prove(a, b []curves.Scalar, u curves.Point, transcript *merlin.Transcript) (*InnerProductProof, error) {
// Vectors must have length power of two
if !isPowerOfTwo(len(a)) {
Expand Down Expand Up @@ -241,8 +241,10 @@ func (prover *InnerProductProver) proveRecursive(recursionParams *ippRecursion)
capR := rga.Add(rhb).Add(ucR)

// Add L,R for verifier to use to calculate final g, h
newL := append(recursionParams.capLs, capL)
newR := append(recursionParams.capRs, capR)
newL := recursionParams.capLs
newL = append(newL, capL)
newR := recursionParams.capRs
newR = append(newR, capR)

// Get x from L, R for non-interactive (See section 4.4 pg22 of https://eprint.iacr.org/2017/1066.pdf)
// Note this replaces the interactive model, i.e. L36-28 of pg16 of https://eprint.iacr.org/2017/1066.pdf
Expand Down Expand Up @@ -325,10 +327,10 @@ func (prover *InnerProductProver) proveRecursive(recursionParams *ippRecursion)
// For each recursion, it takes the current state of the transcript and appends the newly calculated L and R values
// A new scalar is then read from the transcript
// See section 4.4 pg22 of https://eprint.iacr.org/2017/1066.pdf
func (prover *InnerProductProver) calcx(L, R curves.Point, transcript *merlin.Transcript) (curves.Scalar, error) {
// Add the newest L and R values to transcript
transcript.AppendMessage([]byte("addRecursiveL"), L.ToAffineUncompressed())
transcript.AppendMessage([]byte("addRecursiveR"), R.ToAffineUncompressed())
func (prover *InnerProductProver) calcx(capL, capR curves.Point, transcript *merlin.Transcript) (curves.Scalar, error) {
// Add the newest capL and capR values to transcript
transcript.AppendMessage([]byte("addRecursiveL"), capL.ToAffineUncompressed())
transcript.AppendMessage([]byte("addRecursiveR"), capR.ToAffineUncompressed())
// Read 64 bytes from, set to scalar
outBytes := transcript.ExtractBytes([]byte("getx"), 64)
x, err := prover.curve.NewScalar().SetBytesWide(outBytes)
Expand All @@ -339,7 +341,7 @@ func (prover *InnerProductProver) calcx(L, R curves.Point, transcript *merlin.Tr
return x, nil
}

// MarshalBinary takes an inner product proof and marshals into bytes
// MarshalBinary takes an inner product proof and marshals into bytes.
func (proof *InnerProductProof) MarshalBinary() []byte {
var out []byte
out = append(out, proof.a.Bytes()...)
Expand All @@ -353,36 +355,36 @@ func (proof *InnerProductProof) MarshalBinary() []byte {
}

// UnmarshalBinary takes bytes of a marshaled proof and writes them into an inner product proof
// The inner product proof used should be from the output of NewInnerProductProof()
// The inner product proof used should be from the output of NewInnerProductProof().
func (proof *InnerProductProof) UnmarshalBinary(data []byte) error {
scalarLen := len(proof.curve.NewScalar().Bytes())
pointLen := len(proof.curve.NewGeneratorPoint().ToAffineCompressed())
ptr := 0
// Get scalars
a, err := proof.curve.NewScalar().SetBytes(data[ptr : ptr+scalarLen])
if err != nil {
return errors.New("InnerProductProof UnmarshalBinary SetBytes")
return errors.New("innerProductProof UnmarshalBinary SetBytes")
}
proof.a = a
ptr += scalarLen
b, err := proof.curve.NewScalar().SetBytes(data[ptr : ptr+scalarLen])
if err != nil {
return errors.New("InnerProductProof UnmarshalBinary SetBytes")
return errors.New("innerProductProof UnmarshalBinary SetBytes")
}
proof.b = b
ptr += scalarLen
// Get points
var capLs, capRs []curves.Point
var capLs, capRs []curves.Point //nolint:prealloc // pointer arithmetic makes it too unreadable.
for ptr < len(data) {
capLElem, err := proof.curve.Point.FromAffineCompressed(data[ptr : ptr+pointLen])
if err != nil {
return errors.New("InnerProductProof UnmarshalBinary FromAffineCompressed")
return errors.New("innerProductProof UnmarshalBinary FromAffineCompressed")
}
capLs = append(capLs, capLElem)
ptr += pointLen
capRElem, err := proof.curve.Point.FromAffineCompressed(data[ptr : ptr+pointLen])
if err != nil {
return errors.New("InnerProductProof UnmarshalBinary FromAffineCompressed")
return errors.New("innerProductProof UnmarshalBinary FromAffineCompressed")
}
capRs = append(capRs, capRElem)
ptr += pointLen
Expand Down
28 changes: 14 additions & 14 deletions pkg/bulletproof/ipp_verifier.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ import (

// InnerProductVerifier is the struct used to verify inner product proofs
// It specifies which curve to use and holds precomputed generators
// See NewInnerProductProver() for prover initialization
// See NewInnerProductProver() for prover initialization.
type InnerProductVerifier struct {
curve curves.Curve
generators ippGenerators
Expand All @@ -18,7 +18,7 @@ type InnerProductVerifier struct {
// NewInnerProductVerifier initializes a new verifier
// It uses the specified domain to generate generators for vectors of at most maxVectorLength
// A verifier can be used to verify inner product proofs for vectors of length less than or equal to maxVectorLength
// A verifier is defined by an explicit curve
// A verifier is defined by an explicit curve.
func NewInnerProductVerifier(maxVectorLength int, domain []byte, curve curves.Curve) (*InnerProductVerifier, error) {
generators, err := getGeneratorPoints(maxVectorLength, domain, curve)
if err != nil {
Expand Down Expand Up @@ -97,9 +97,9 @@ func (verifier *InnerProductVerifier) VerifyFromRangeProof(proofG, proofH []curv
return lhs.Equal(rhs), nil
}

// getRHS gets the right hand side of the final comparison of section 3.1 on pg17
func (verifier *InnerProductVerifier) getRHS(P curves.Point, proof *InnerProductProof, xs []curves.Scalar) (curves.Point, error) {
product := P
// getRHS gets the right hand side of the final comparison of section 3.1 on pg17.
func (*InnerProductVerifier) getRHS(capP curves.Point, proof *InnerProductProof, xs []curves.Scalar) (curves.Point, error) {
product := capP
for j, Lj := range proof.capLs {
Rj := proof.capRs[j]
xj := xs[j]
Expand All @@ -115,7 +115,7 @@ func (verifier *InnerProductVerifier) getRHS(P curves.Point, proof *InnerProduct
return product, nil
}

// getLHS gets the left hand side of the final comparison of section 3.1 on pg17
// getLHS gets the left hand side of the final comparison of section 3.1 on pg17.
func (verifier *InnerProductVerifier) getLHS(u curves.Point, proof *InnerProductProof, g, h []curves.Point, s []curves.Scalar) (curves.Point, error) {
sInv, err := invertScalars(s)
if err != nil {
Expand All @@ -138,14 +138,14 @@ func (verifier *InnerProductVerifier) getLHS(u curves.Point, proof *InnerProduct

// getxs calculates the x values from Ls and Rs
// Note that each x is read from the transcript, then the L and R at a certain index are written to the transcript
// This mirrors the reading of xs and writing of Ls and Rs in the prover
func getxs(transcript *merlin.Transcript, Ls, Rs []curves.Point, curve curves.Curve) ([]curves.Scalar, error) {
xs := make([]curves.Scalar, len(Ls))
for i, Li := range Ls {
Ri := Rs[i]
// This mirrors the reading of xs and writing of Ls and Rs in the prover.
func getxs(transcript *merlin.Transcript, capLs, capRs []curves.Point, curve curves.Curve) ([]curves.Scalar, error) {
xs := make([]curves.Scalar, len(capLs))
for i, capLi := range capLs {
capRi := capRs[i]
// Add the newest L and R values to transcript
transcript.AppendMessage([]byte("addRecursiveL"), Li.ToAffineUncompressed())
transcript.AppendMessage([]byte("addRecursiveR"), Ri.ToAffineUncompressed())
transcript.AppendMessage([]byte("addRecursiveL"), capLi.ToAffineUncompressed())
transcript.AppendMessage([]byte("addRecursiveR"), capRi.ToAffineUncompressed())
// Read 64 bytes from, set to scalar
outBytes := transcript.ExtractBytes([]byte("getx"), 64)
x, err := curve.NewScalar().SetBytesWide(outBytes)
Expand Down Expand Up @@ -185,7 +185,7 @@ func (verifier *InnerProductVerifier) gets(xs []curves.Scalar, n int) ([]curves.
// getsNew calculates the vector s of values used for verification
// It provides analogous functionality as gets(), but uses a O(n) algorithm vs O(nlogn)
// The algorithm inverts all xs, then begins multiplying the inversion by the square of x elements to
// calculate all s values thus minimizing necessary inversions/ computation
// calculate all s values thus minimizing necessary inversions/ computation.
func (verifier *InnerProductVerifier) getsNew(xs []curves.Scalar, n int) ([]curves.Scalar, error) {
var err error
ss := make([]curves.Scalar, n)
Expand Down

0 comments on commit 60eddc5

Please sign in to comment.