Includes test vectors for hashing to curve #239

Merged — 1 commit, merged on Jul 23, 2021
147 changes: 147 additions & 0 deletions group/expander.go
@@ -0,0 +1,147 @@
package group

import (
    "crypto"
    "encoding/binary"
    "errors"
    "io"

    "github.com/cloudflare/circl/xof"
)

type Expander interface {
    // Expand generates a pseudo-random byte string of a determined length by
    // expanding an input string.
    Expand(in []byte, length uint) (pseudo []byte)
}

type expanderMD struct {
    h   crypto.Hash
    dst []byte
}

// NewExpanderMD returns a hash function based on a Merkle-Damgård hash function.
func NewExpanderMD(h crypto.Hash, dst []byte) *expanderMD {
    return &expanderMD{h, dst}
}

func (e *expanderMD) calcDSTPrime() []byte {
    var dstPrime []byte
    if l := len(e.dst); l > maxDSTLength {
        H := e.h.New()
        mustWrite(H, longDSTPrefix[:])
        mustWrite(H, e.dst)
        dstPrime = H.Sum(nil)
    } else {
        dstPrime = make([]byte, l, l+1)
        copy(dstPrime, e.dst)
    }
    return append(dstPrime, byte(len(dstPrime)))
}

func (e *expanderMD) Expand(in []byte, n uint) []byte {
    H := e.h.New()
    bLen := uint(H.Size())
    ell := (n + (bLen - 1)) / bLen
    if ell > 255 {
        panic(errorLongOutput)
    }

    zPad := make([]byte, H.BlockSize())
    libStr := []byte{0, 0}
    libStr[0] = byte((n >> 8) & 0xFF)
    libStr[1] = byte(n & 0xFF)
    dstPrime := e.calcDSTPrime()

    H.Reset()
    mustWrite(H, zPad)
    mustWrite(H, in)
    mustWrite(H, libStr)
    mustWrite(H, []byte{0})
    mustWrite(H, dstPrime)
    b0 := H.Sum(nil)

    H.Reset()
    mustWrite(H, b0)
    mustWrite(H, []byte{1})
    mustWrite(H, dstPrime)
    bi := H.Sum(nil)
    pseudo := append([]byte{}, bi...)
    for i := uint(2); i <= ell; i++ {
        H.Reset()
        for i := range b0 {
            bi[i] ^= b0[i]
        }
        mustWrite(H, bi)
        mustWrite(H, []byte{byte(i)})
        mustWrite(H, dstPrime)
        bi = H.Sum(nil)
        pseudo = append(pseudo, bi...)
    }
    return pseudo[0:n]
}

// expanderXOF is based on an extendable output function.
type expanderXOF struct {
    id        xof.ID
    kSecLevel uint
    dst       []byte
}

// NewExpanderXOF returns an Expander based on an extendable output function.
// The kSecLevel parameter is the target security level in bits, and dst is
// a domain separation string.
func NewExpanderXOF(id xof.ID, kSecLevel uint, dst []byte) *expanderXOF {
    return &expanderXOF{id, kSecLevel, dst}
}

// Expand panics if output's length is longer than 2^16 bytes.
func (e *expanderXOF) Expand(in []byte, n uint) []byte {
    bLen := []byte{0, 0}
    binary.BigEndian.PutUint16(bLen, uint16(n))
    pseudo := make([]byte, n)
    dstPrime := e.calcDSTPrime()

    H := e.id.New()
    mustWrite(H, in)
    mustWrite(H, bLen)
    mustWrite(H, dstPrime)
    mustReadFull(H, pseudo)
    return pseudo
}

func (e *expanderXOF) calcDSTPrime() []byte {
    var dstPrime []byte
    if l := len(e.dst); l > maxDSTLength {
        H := e.id.New()
        mustWrite(H, longDSTPrefix[:])
        mustWrite(H, e.dst)
        max := ((2 * e.kSecLevel) + 7) / 8
        dstPrime = make([]byte, max, max+1)
        mustReadFull(H, dstPrime)
    } else {
        dstPrime = make([]byte, l, l+1)
        copy(dstPrime, e.dst)
    }
    return append(dstPrime, byte(len(dstPrime)))
}

func mustWrite(w io.Writer, b []byte) {
    if n, err := w.Write(b); err != nil || n != len(b) {
        panic(err)
    }
}

func mustReadFull(r io.Reader, b []byte) {
    if n, err := io.ReadFull(r, b); err != nil || n != len(b) {
        panic(err)
    }
}

const maxDSTLength = 255

var (
    longDSTPrefix = [17]byte{'H', '2', 'C', '-', 'O', 'V', 'E', 'R', 'S', 'I', 'Z', 'E', '-', 'D', 'S', 'T', '-'}

    errorLongOutput = errors.New("requested too many bytes")
)
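For context, here is a minimal usage sketch of the expander API added above (not part of the diff; the DST string and input message are made-up placeholders, and the blank crypto/sha256 import is only there so that crypto.SHA256.New() is registered):

package main

import (
    "crypto"
    _ "crypto/sha256" // registers SHA-256 for crypto.SHA256.New()
    "encoding/hex"
    "fmt"

    "github.com/cloudflare/circl/group"
    "github.com/cloudflare/circl/xof"
)

func main() {
    dst := []byte("EXAMPLE-V01-CS01") // placeholder domain-separation tag

    // Merkle-Damgård expander (expand_message_xmd-style): 64 pseudo-random bytes.
    xmd := group.NewExpanderMD(crypto.SHA256, dst)
    fmt.Println(hex.EncodeToString(xmd.Expand([]byte("input msg"), 64)))

    // XOF expander (expand_message_xof-style) over SHAKE-128 at a 128-bit target level.
    xo := group.NewExpanderXOF(xof.SHAKE128, 128, dst)
    fmt.Println(hex.EncodeToString(xo.Expand([]byte("input msg"), 64)))
}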
112 changes: 112 additions & 0 deletions group/expander_test.go
@@ -0,0 +1,112 @@
package group_test

import (
    "bytes"
    "crypto"
    "encoding/hex"
    "encoding/json"
    "fmt"
    "os"
    "path/filepath"
    "strconv"
    "testing"

    "github.com/cloudflare/circl/group"
    "github.com/cloudflare/circl/internal/test"
    "github.com/cloudflare/circl/xof"
)

func TestExpander(t *testing.T) {
    fileNames, err := filepath.Glob("./testdata/expand*.json")
    if err != nil {
        t.Fatal(err)
    }

    for _, fileName := range fileNames {
        f, err := os.Open(fileName)
        if err != nil {
            t.Fatal(err)
        }
        dec := json.NewDecoder(f)
        var v vectorExpanderSuite
        err = dec.Decode(&v)
        if err != nil {
            t.Fatal(err)
        }
        f.Close()

        t.Run(v.Name+"/"+v.Hash, func(t *testing.T) { testExpander(t, &v) })
    }
}

func testExpander(t *testing.T, vs *vectorExpanderSuite) {
    var exp group.Expander
    switch vs.Hash {
    case "SHA256":
        exp = group.NewExpanderMD(crypto.SHA256, []byte(vs.DST))
    case "SHA512":
        exp = group.NewExpanderMD(crypto.SHA512, []byte(vs.DST))
    case "SHAKE128":
        exp = group.NewExpanderXOF(xof.SHAKE128, 0, []byte(vs.DST))
    case "SHAKE256":
        exp = group.NewExpanderXOF(xof.SHAKE256, 0, []byte(vs.DST))
    default:
        t.Skip("hash not supported: " + vs.Hash)
    }

    for i, v := range vs.Tests {
        lenBytes, err := strconv.ParseUint(v.Len, 0, 64)
        if err != nil {
            t.Fatal(err)
        }

        got := exp.Expand([]byte(v.Msg), uint(lenBytes))
        want, err := hex.DecodeString(v.UniformBytes)
        if err != nil {
            t.Fatal(err)
        }

        if !bytes.Equal(got, want) {
            test.ReportError(t, got, want, i)
        }
    }
}

type vectorExpanderSuite struct {
    DST   string `json:"DST"`
    Hash  string `json:"hash"`
    Name  string `json:"name"`
    Tests []struct {
        DstPrime     string `json:"DST_prime"`
        Len          string `json:"len_in_bytes"`
        Msg          string `json:"msg"`
        MsgPrime     string `json:"msg_prime"`
        UniformBytes string `json:"uniform_bytes"`
    } `json:"tests"`
}

func BenchmarkExpander(b *testing.B) {
    in := []byte("input")
    dst := []byte("dst")

    for _, v := range []struct {
        Name string
        Exp  group.Expander
    }{
        {"XMD", group.NewExpanderMD(crypto.SHA256, dst)},
        {"XOF", group.NewExpanderXOF(xof.SHAKE128, 0, dst)},
    } {
        exp := v.Exp
        for l := 8; l <= 10; l++ {
            max := int64(1) << uint(l)

            b.Run(fmt.Sprintf("%v/%v", v.Name, max), func(b *testing.B) {
                b.SetBytes(max)
                b.ResetTimer()
                for i := 0; i < b.N; i++ {
                    exp.Expand(in, uint(max))
                }
            })
        }
    }
}
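For reference, the testdata files globbed by TestExpander are expected to look roughly like the sample below. Only the field names are taken from the struct tags in vectorExpanderSuite; every value here is a placeholder, not a real test vector:

{
  "DST": "EXAMPLE-DST",
  "hash": "SHA256",
  "name": "expand_message_xmd",
  "tests": [
    {
      "DST_prime": "...",
      "len_in_bytes": "0x20",
      "msg": "abc",
      "msg_prime": "...",
      "uniform_bytes": "..."
    }
  ]
}

Because len_in_bytes is parsed with strconv.ParseUint(v.Len, 0, 64), both decimal and 0x-prefixed hexadecimal lengths are accepted.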
13 changes: 1 addition & 12 deletions group/group.go
@@ -2,21 +2,9 @@
package group

import (
"crypto/elliptic"
"encoding"
"errors"
"io"

"github.com/cloudflare/circl/ecc/p384"
)

var (
    // P256 is the group generated by P-256 elliptic curve.
    P256 Group = wG{elliptic.P256()}
    // P384 is the group generated by P-384 elliptic curve.
    P384 Group = wG{p384.P384()}
    // P521 is the group generated by P-521 elliptic curve.
    P521 Group = wG{elliptic.P521()}
)

type Params struct {
@@ -36,6 +24,7 @@ type Group interface {
    RandomElement(io.Reader) Element
    RandomScalar(io.Reader) Scalar
    HashToElement(data, dst []byte) Element
    HashToElementNonUniform(b, dst []byte) Element
    HashToScalar(data, dst []byte) Scalar
}

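The one addition to the Group interface in this file is HashToElementNonUniform. A short sketch of how the two hash-to-group entry points might be called (illustrative only; the message and DST below are placeholders, not values from the PR):

package main

import (
    "fmt"

    "github.com/cloudflare/circl/group"
)

func main() {
    g := group.P256
    msg := []byte("some message")     // placeholder input
    dst := []byte("EXAMPLE-V01-P256") // placeholder domain-separation tag

    u := g.HashToElement(msg, dst)           // uniform (hash_to_curve-style) encoding
    e := g.HashToElementNonUniform(msg, dst) // cheaper, non-uniform (encode_to_curve-style) encoding
    fmt.Println(u, e)
}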
55 changes: 15 additions & 40 deletions group/group_test.go
@@ -10,14 +10,16 @@ import (
"github.com/cloudflare/circl/internal/test"
)

var allGroups = []group.Group{
group.P256,
group.P384,
group.P521,
group.Ristretto255,
}

func TestGroup(t *testing.T) {
const testTimes = 1 << 7
for _, g := range []group.Group{
group.P256,
group.P384,
group.P521,
group.Ristretto255,
} {
for _, g := range allGroups {
g := g
n := g.(fmt.Stringer).String()
t.Run(n+"/Add", func(tt *testing.T) { testAdd(tt, testTimes, g) })
@@ -126,22 +128,24 @@ func isZero(b []byte) bool {
func testMarshal(t *testing.T, testTimes int, g group.Group) {
    params := g.Params()
    I := g.Identity()
    got, _ := I.MarshalBinary()
    got, err := I.MarshalBinary()
    test.CheckNoErr(t, err, "error on MarshalBinary")
    if !isZero(got) {
        test.ReportError(t, got, "Non-zero identity")
    }
    if l := uint(len(got)); !(l == 1 || l == params.ElementLength) {
        test.ReportError(t, l, params.ElementLength)
    }
    got, _ = I.MarshalBinaryCompress()
    got, err = I.MarshalBinaryCompress()
    test.CheckNoErr(t, err, "error on MarshalBinaryCompress")
    if !isZero(got) {
        test.ReportError(t, got, "Non-zero identity")
    }
    if l := uint(len(got)); !(l == 1 || l == params.CompressedElementLength) {
        test.ReportError(t, l, params.CompressedElementLength)
    }
    II := g.NewElement()
    err := II.UnmarshalBinary(got)
    err = II.UnmarshalBinary(got)
    if err != nil || !I.IsEqual(II) {
        test.ReportError(t, I, II)
    }
@@ -203,11 +207,7 @@ func testScalar(t *testing.T, testTimes int, g group.Group) {
}

func BenchmarkElement(b *testing.B) {
    for _, g := range []group.Group{
        group.P256,
        group.P384,
        group.P521,
    } {
    for _, g := range allGroups {
        x := g.RandomElement(rand.Reader)
        y := g.RandomElement(rand.Reader)
        n := g.RandomScalar(rand.Reader)
@@ -236,11 +236,7 @@ func BenchmarkElement(b *testing.B) {
}

func BenchmarkScalar(b *testing.B) {
    for _, g := range []group.Group{
        group.P256,
        group.P384,
        group.P521,
    } {
    for _, g := range allGroups {
        x := g.RandomScalar(rand.Reader)
        y := g.RandomScalar(rand.Reader)
        name := g.(fmt.Stringer).String()
@@ -261,24 +257,3 @@ func BenchmarkScalar(b *testing.B) {
        })
    }
}

func BenchmarkHash(b *testing.B) {
    for _, g := range []group.Group{
        group.P256,
        group.P384,
        group.P521,
    } {
        g := g
        name := g.(fmt.Stringer).String()
        b.Run(name+"/HashToElement", func(b *testing.B) {
            for i := 0; i < b.N; i++ {
                g.HashToElement(nil, nil)
            }
        })
        b.Run(name+"/HashToScalar", func(b *testing.B) {
            for i := 0; i < b.N; i++ {
                g.HashToScalar(nil, nil)
            }
        })
    }
}