diff --git a/blake3.go b/blake3.go index 5ede51f..14bc907 100644 --- a/blake3.go +++ b/blake3.go @@ -154,71 +154,6 @@ func (n node) chainingValue() (cv [8]uint32) { return } -// An OutputReader produces an seekable stream of 2^64 - 1 pseudorandom output -// bytes. -type OutputReader struct { - n node - block [blockSize]byte - off uint64 -} - -// Read implements io.Reader. Callers may assume that Read returns len(p), nil -// unless the read would extend beyond the end of the stream. -func (or *OutputReader) Read(p []byte) (int, error) { - if or.off == math.MaxUint64 { - return 0, io.EOF - } else if rem := math.MaxUint64 - or.off; uint64(len(p)) > rem { - p = p[:rem] - } - lenp := len(p) - for len(p) > 0 { - if or.off%blockSize == 0 { - or.n.counter = or.off / blockSize - words := or.n.compress() - wordsToBytes(words[:], or.block[:]) - } - - n := copy(p, or.block[or.off%blockSize:]) - p = p[n:] - or.off += uint64(n) - } - return lenp, nil -} - -// Seek implements io.Seeker. -func (or *OutputReader) Seek(offset int64, whence int) (int64, error) { - off := or.off - switch whence { - case io.SeekStart: - if offset < 0 { - return 0, errors.New("seek position cannot be negative") - } - off = uint64(offset) - case io.SeekCurrent: - if offset < 0 { - if uint64(-offset) > off { - return 0, errors.New("seek position cannot be negative") - } - off -= uint64(-offset) - } else { - off += uint64(offset) - } - case io.SeekEnd: - off = uint64(offset) - 1 - default: - panic("invalid whence") - } - or.off = off - or.n.counter = uint64(off) / blockSize - if or.off%blockSize != 0 { - words := or.n.compress() - wordsToBytes(words[:], or.block[:]) - } - // NOTE: or.off >= 2^63 will result in a negative return value. - // Nothing we can do about this. - return int64(or.off), nil -} - // chunkState manages the state involved in hashing a single chunk of input. 
type chunkState struct { n node @@ -233,6 +168,12 @@ func (cs *chunkState) chunkCounter() uint64 { return cs.n.counter } +// complete is a helper method that reports whether a full chunk has been +// processed. +func (cs *chunkState) complete() bool { + return cs.bytesConsumed == chunkSize +} + // update incorporates input into the chunkState. func (cs *chunkState) update(input []byte) { for len(input) > 0 { @@ -311,26 +252,6 @@ type Hasher struct { size int // output size, for Sum } -func newHasher(key [8]uint32, flags uint32, size int) *Hasher { - return &Hasher{ - cs: newChunkState(key, 0, flags), - key: key, - flags: flags, - size: size, - } -} - -// New returns a Hasher for the specified size and key. If key is nil, the hash -// is unkeyed. -func New(size int, key []byte) *Hasher { - if key == nil { - return newHasher(iv, 0, size) - } - var keyWords [8]uint32 - bytesToWords(key[:], keyWords[:]) - return newHasher(keyWords, flagKeyedHash, size) -} - // addChunkChainingValue appends a chunk to the right edge of the Merkle tree. func (h *Hasher) addChunkChainingValue(cv [8]uint32, totalChunks uint64) { // This chunk might complete some subtrees. For each completed subtree, its @@ -383,7 +304,7 @@ func (h *Hasher) Write(p []byte) (int, error) { // If the current chunk is complete, finalize it and add it to the tree, // then reset the chunk state (but keep incrementing the counter across // chunks). - if h.cs.bytesConsumed == chunkSize { + if h.cs.complete() { cv := h.cs.node().chainingValue() totalChunks := h.cs.chunkCounter() + 1 h.addChunkChainingValue(cv, totalChunks) @@ -423,6 +344,26 @@ func (h *Hasher) XOF() *OutputReader { } } +func newHasher(key [8]uint32, flags uint32, size int) *Hasher { + return &Hasher{ + cs: newChunkState(key, 0, flags), + key: key, + flags: flags, + size: size, + } +} + +// New returns a Hasher for the specified size and key. If key is nil, the hash +// is unkeyed. 
+func New(size int, key []byte) *Hasher { + if key == nil { + return newHasher(iv, 0, size) + } + var keyWords [8]uint32 + bytesToWords(key[:], keyWords[:]) + return newHasher(keyWords, flagKeyedHash, size) +} + // Sum256 returns the unkeyed BLAKE3 hash of b, truncated to 256 bits. func Sum256(b []byte) (out [32]byte) { h := newHasher(iv, 0, 0) @@ -463,5 +404,70 @@ func DeriveKey(subKey []byte, ctx string, srcKey []byte) { h.XOF().Read(subKey) } +// An OutputReader produces a seekable stream of 2^64 - 1 pseudorandom output +// bytes. +type OutputReader struct { + n node + block [blockSize]byte + off uint64 +} + +// Read implements io.Reader. Callers may assume that Read returns len(p), nil +// unless the read would extend beyond the end of the stream. +func (or *OutputReader) Read(p []byte) (int, error) { + if or.off == math.MaxUint64 { + return 0, io.EOF + } else if rem := math.MaxUint64 - or.off; uint64(len(p)) > rem { + p = p[:rem] + } + lenp := len(p) + for len(p) > 0 { + if or.off%blockSize == 0 { + or.n.counter = or.off / blockSize + words := or.n.compress() + wordsToBytes(words[:], or.block[:]) + } + + n := copy(p, or.block[or.off%blockSize:]) + p = p[n:] + or.off += uint64(n) + } + return lenp, nil +} + +// Seek implements io.Seeker. +func (or *OutputReader) Seek(offset int64, whence int) (int64, error) { + off := or.off + switch whence { + case io.SeekStart: + if offset < 0 { + return 0, errors.New("seek position cannot be negative") + } + off = uint64(offset) + case io.SeekCurrent: + if offset < 0 { + if uint64(-offset) > off { + return 0, errors.New("seek position cannot be negative") + } + off -= uint64(-offset) + } else { + off += uint64(offset) + } + case io.SeekEnd: + off = uint64(offset) - 1 + default: + panic("invalid whence") + } + or.off = off + or.n.counter = uint64(off) / blockSize + if or.off%blockSize != 0 { + words := or.n.compress() + wordsToBytes(words[:], or.block[:]) + } + // NOTE: or.off >= 2^63 will result in a negative return value.
+ // Nothing we can do about this. + return int64(or.off), nil +} + // ensure that Hasher implements hash.Hash var _ hash.Hash = (*Hasher)(nil) diff --git a/blake3_test.go b/blake3_test.go index 1f885cc..7db3605 100644 --- a/blake3_test.go +++ b/blake3_test.go @@ -118,6 +118,27 @@ func TestXOF(t *testing.T) { if n != 0 || err != io.EOF { t.Errorf("expected (0, EOF) when reading past end of stream, got (%v, %v)", n, err) } + + // test invalid seek offsets + _, err = xof.Seek(-1, io.SeekStart) + if err == nil { + t.Error("expected invalid offset error, got nil") + } + xof.Seek(0, io.SeekStart) + _, err = xof.Seek(-1, io.SeekCurrent) + if err == nil { + t.Error("expected invalid offset error, got nil") + } + + // test invalid seek whence + didPanic := func() (p bool) { + defer func() { p = recover() != nil }() + xof.Seek(0, 17) + return + }() + if !didPanic { + t.Error("expected panic when seeking with invalid whence") + } } func TestSum(t *testing.T) { @@ -142,6 +163,27 @@ func TestSum(t *testing.T) { } } +func TestReset(t *testing.T) { + for _, vec := range testVectors.Cases { + in := testInput[:vec.InputLen] + + h := blake3.New(32, nil) + h.Write(in) + out1 := h.Sum(nil) + h.Reset() + h.Write(in) + out2 := h.Sum(nil) + if !bytes.Equal(out1, out2) { + t.Error("Reset did not reset Hasher state properly") + } + } + + // gotta have 100% test coverage... + if blake3.New(0, nil).BlockSize() != 64 { + t.Error("incorrect block size") + } +} + type nopReader struct{} func (nopReader) Read(p []byte) (int, error) { return len(p), nil }