diff --git a/crypto/bn256/gnark/g1.go b/crypto/bn256/gnark/g1.go index 2f933dd5360..59e04cb247f 100644 --- a/crypto/bn256/gnark/g1.go +++ b/crypto/bn256/gnark/g1.go @@ -1,6 +1,7 @@ package bn256 import ( + "errors" "math/big" "github.com/consensys/gnark-crypto/ecc/bn254" @@ -31,21 +32,62 @@ func (g *G1) ScalarMult(a *G1, scalar *big.Int) { // Unmarshal deserializes `buf` into `g` // -// Note: whether the deserialization is of a compressed -// or an uncompressed point, is encoded in the bytes. -// -// For our purpose, the point will always be serialized -// as uncompressed, ie 64 bytes. +// The input is expected to be in the EVM format: +// 64 bytes: [32-byte x coordinate][32-byte y coordinate] +// where each coordinate is in big-endian format. // // This method also checks whether the point is on the // curve and in the prime order subgroup. func (g *G1) Unmarshal(buf []byte) (int, error) { - return g.inner.SetBytes(buf) + if len(buf) < 64 { + return 0, errors.New("invalid G1 point size") + } + + if allZeroes(buf[:64]) { + // point at infinity + g.inner.X.SetZero() + g.inner.Y.SetZero() + return 64, nil + } + + if err := g.inner.X.SetBytesCanonical(buf[:32]); err != nil { + return 0, err + } + if err := g.inner.Y.SetBytesCanonical(buf[32:64]); err != nil { + return 0, err + } + + if !g.inner.IsOnCurve() { + return 0, errors.New("point is not on curve") + } + if !g.inner.IsInSubGroup() { + return 0, errors.New("point is not in correct subgroup") + } + return 64, nil } // Marshal serializes the point into a byte slice. // -// Note: The point is serialized as uncompressed. +// The output is in EVM format: 64 bytes total. +// [32-byte x coordinate][32-byte y coordinate] +// where each coordinate is a big-endian integer padded to 32 bytes. func (p *G1) Marshal() []byte { - return p.inner.Marshal() + output := make([]byte, 64) + + xBytes := p.inner.X.Bytes() + copy(output[:32], xBytes[:]) + + yBytes := p.inner.Y.Bytes() + copy(output[32:64], yBytes[:]) + + return output +} + +func allZeroes(buf []byte) bool { + for i := range buf { + if buf[i] != 0 { + return false + } + } + return true } diff --git a/crypto/bn256/gnark/g2.go b/crypto/bn256/gnark/g2.go index 205373a5919..07452cc2d87 100644 --- a/crypto/bn256/gnark/g2.go +++ b/crypto/bn256/gnark/g2.go @@ -1,6 +1,8 @@ package bn256 import ( + "errors" + "github.com/consensys/gnark-crypto/ecc/bn254" ) @@ -18,21 +20,66 @@ type G2 struct { // Unmarshal deserializes `buf` into `g` // -// Note: whether the deserialization is of a compressed -// or an uncompressed point, is encoded in the bytes. -// -// For our purpose, the point will always be serialized -// as uncompressed, ie 128 bytes. +// The input is expected to be in the EVM format: +// 128 bytes: [32-byte x.0][32-byte x.1][32-byte y.0][32-byte y.1] +// where each value is a big-endian integer. // // This method also checks whether the point is on the // curve and in the prime order subgroup. func (g *G2) Unmarshal(buf []byte) (int, error) { - return g.inner.SetBytes(buf) + if len(buf) < 128 { + return 0, errors.New("invalid G2 point size") + } + + if allZeroes(buf[:128]) { + // point at infinity + g.inner.X.A0.SetZero() + g.inner.X.A1.SetZero() + g.inner.Y.A0.SetZero() + g.inner.Y.A1.SetZero() + return 128, nil + } + if err := g.inner.X.A0.SetBytesCanonical(buf[0:32]); err != nil { + return 0, err + } + if err := g.inner.X.A1.SetBytesCanonical(buf[32:64]); err != nil { + return 0, err + } + if err := g.inner.Y.A0.SetBytesCanonical(buf[64:96]); err != nil { + return 0, err + } + if err := g.inner.Y.A1.SetBytesCanonical(buf[96:128]); err != nil { + return 0, err + } + + if !g.inner.IsOnCurve() { + return 0, errors.New("point is not on curve") + } + if !g.inner.IsInSubGroup() { + return 0, errors.New("point is not in correct subgroup") + } + return 128, nil } // Marshal serializes the point into a byte slice. // -// Note: The point is serialized as uncompressed. +// The output is in EVM format: 128 bytes total. +// [32-byte x.0][32-byte x.1][32-byte y.0][32-byte y.1] +// where each value is a big-endian integer. func (g *G2) Marshal() []byte { - return g.inner.Marshal() + output := make([]byte, 128) + + xA0Bytes := g.inner.X.A0.Bytes() + copy(output[:32], xA0Bytes[:]) + + xA1Bytes := g.inner.X.A1.Bytes() + copy(output[32:64], xA1Bytes[:]) + + yA0Bytes := g.inner.Y.A0.Bytes() + copy(output[64:96], yA0Bytes[:]) + + yA1Bytes := g.inner.Y.A1.Bytes() + copy(output[96:128], yA1Bytes[:]) + + return output } diff --git a/crypto/bn256/gnark/native_format_test.go b/crypto/bn256/gnark/native_format_test.go new file mode 100644 index 00000000000..e2b67449321 --- /dev/null +++ b/crypto/bn256/gnark/native_format_test.go @@ -0,0 +1,42 @@ +package bn256 + +import ( + "testing" + + "github.com/consensys/gnark-crypto/ecc/bn254" +) + +func TestNativeGnarkFormatIncompatibility(t *testing.T) { + // Use official gnark serialization + _, _, g1Gen, _ := bn254.Generators() + wrongSer := g1Gen.Bytes() + + var evmG1 G1 + _, err := evmG1.Unmarshal(wrongSer[:]) + if err == nil { + t.Fatalf("points serialized using the official bn254 serialization algorithm, should not work with the evm format") + } +} + +func TestSerRoundTrip(t *testing.T) { + _, _, g1Gen, g2Gen := bn254.Generators() + + expectedG1 := G1{inner: g1Gen} + bytesG1 := expectedG1.Marshal() + + expectedG2 := G2{inner: g2Gen} + bytesG2 := expectedG2.Marshal() + + var gotG1 G1 + gotG1.Unmarshal(bytesG1) + + var gotG2 G2 + gotG2.Unmarshal(bytesG2) + + if !expectedG1.inner.Equal(&gotG1.inner) { + t.Errorf("serialization roundtrip failed for G1") + } + if !expectedG2.inner.Equal(&gotG2.inner) { + t.Errorf("serialization roundtrip failed for G2") + } +}