Skip to content

Commit

Permalink
Add test + doc
Browse files Browse the repository at this point in the history
  • Loading branch information
klauspost committed Oct 7, 2022
1 parent a9b6adf commit b63f9c2
Show file tree
Hide file tree
Showing 4 changed files with 86 additions and 13 deletions.
18 changes: 12 additions & 6 deletions estream/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,12 @@

This package provides a flexible way to merge multiple streams with controlled encryption.

The stream is stateful and allows to send individually encrypted streams.

## Features

* Allows encrypted and unencrypted streams.
* Any number of keys can be used on a streams.
* Any number of keys can be used on streams.
* Each key can be encrypted by a (different) public key.
* Each stream is identified by a string "name".
* A stream has optional (unencrypted) metadata slice.
Expand Down Expand Up @@ -62,6 +64,9 @@ The functions above return an `io.WriteCloser`.
Data for this stream should be written to this interface
and `Close()` should be called before another stream can be added.

Note that enuncrypted streams are unbuffered, so it may be a benefit to insert a `bufio.Writer`
to avoid very small packets. Encrypted streams are buffered since

# Reading Streams

To read back data `r, err := estream.NewReader(input)` can be used for create a Reader.
Expand Down Expand Up @@ -137,7 +142,8 @@ but may contain data that will be ignored by older versions.

Each block is preceded by a messagepack encoded int8 indicating the block type.

Positive types must be parsed by the decoder. Negative types are *skippable* blocks.
Positive types must be parsed by the decoder. Negative types are *skippable* blocks,
so unknown skippable blocks can be ignored.

Blocks have their length encoded as a messagepack unsigned integer following the block ID.
This indicates the number of bytes to skip after the length to reach the next block ID.
Expand Down Expand Up @@ -236,10 +242,10 @@ It is expected that the parser returns the message and stops processing.

## Checksum types

| ID | Type | Bytes |
|-----|-----------------------|-----------|
| 0 | No checksum | (ignored) |
| 1 | 64 bit xxhash (XXH64) | 8 |
| ID | Type | Bytes |
|-----|------------------------------------|-----------|
| 0 | No checksum | (ignored) |
| 1 | 64 bit xxhash (XXH64) (Big Endian) | 8 |

# Version History

Expand Down
18 changes: 18 additions & 0 deletions estream/reader.go
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,7 @@ func (r *Reader) NextStream() (*Stream, error) {
return nil, errors.New("previous stream not read until EOF")
}

// Temp storage for blocks.
block := make([]byte, 1024)
for {
// Read block ID.
Expand All @@ -108,10 +109,14 @@ func (r *Reader) NextStream() (*Stream, error) {
return nil, r.setErr(err)
}
id := blockID(n)

// Read block size
sz, err := r.mr.ReadUint32()
if err != nil {
return nil, r.setErr(err)
}

// Read block data
if cap(block) < int(sz) {
block = make([]byte, sz)
}
Expand All @@ -121,15 +126,19 @@ func (r *Reader) NextStream() (*Stream, error) {
return nil, r.setErr(err)
}

// Parse block
switch id {
case blockPlainKey:
// Read plaintext key.
key, _, err := msgp.ReadBytesBytes(block, make([]byte, 0, 32))
if err != nil {
return nil, r.setErr(err)
}
if len(key) != 32 {
return nil, r.setErr(fmt.Errorf("unexpected key length: %d", len(key)))
}

// Set key for following streams.
r.key = (*[32]byte)(key)
case blockEncryptedKey:
// Read public key
Expand Down Expand Up @@ -177,7 +186,9 @@ func (r *Reader) NextStream() (*Stream, error) {
return nil, r.setErr(fmt.Errorf("unexpected key length: %d", len(key)))
}
r.key = (*[32]byte)(key)

case blockPlainStream, blockEncStream:
// Read metadata
name, block, err := msgp.ReadStringBytes(block)
if err != nil {
return nil, r.setErr(err)
Expand All @@ -195,13 +206,16 @@ func (r *Reader) NextStream() (*Stream, error) {
return nil, r.setErr(fmt.Errorf("unknown checksum type %d", checksum))
}

// Return plaintext stream
if id == blockPlainStream {
return &Stream{
Reader: r.newStreamReader(checksum),
Name: name,
Extra: extra,
}, nil
}

// Handle encrypted streams.
if r.key == nil {
if r.skipEncrypted {
if err := r.skipDataBlocks(); err != nil {
Expand Down Expand Up @@ -252,6 +266,7 @@ func (r *Reader) NextStream() (*Stream, error) {
}
}

// skipDataBlocks reads data blocks until end.
func (r *Reader) skipDataBlocks() error {
for {
// Read block ID.
Expand Down Expand Up @@ -291,6 +306,7 @@ func (r *Reader) skipDataBlocks() error {
}
}

// setErr sets a stateful error.
func (r *Reader) setErr(err error) error {
if r.err != nil {
return r.err
Expand Down Expand Up @@ -320,13 +336,15 @@ type streamReader struct {
check checksumType
}

// newStreamReader creates a stream reader that can be read to get all data blocks.
func (r *Reader) newStreamReader(ct checksumType) *streamReader {
sr := &streamReader{up: r, check: ct}
sr.h.Reset()
r.inStream = true
return sr
}

// Read will return data blocks as on stream.
func (r *streamReader) Read(b []byte) (int, error) {
if r.isEOF {
return 0, io.EOF
Expand Down
32 changes: 28 additions & 4 deletions estream/stream.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,11 +42,21 @@ import (
// No private key should be returned for this.
type ReplaceFn func(key *rsa.PublicKey) (*rsa.PrivateKey, *rsa.PublicKey)

// ReplaceKeysOptions allows passing additional options to ReplaceKeys.
type ReplaceKeysOptions struct {
// If EncryptAll set all unencrypted keys will be encrypted.
EncryptAll bool

// PassErrors will pass through error an error packet,
// and not return an error.
PassErrors bool
}

// ReplaceKeys will replace the keys in a stream.
//
// A replace function must be provided. See ReplaceFn for functionality.
// If encryptAll is set unencrypted keys will be re-encrypted.
func ReplaceKeys(w io.Writer, r io.Reader, replace ReplaceFn, encryptAll bool) error {
// If encryptAll is set.
func ReplaceKeys(w io.Writer, r io.Reader, replace ReplaceFn, o ReplaceKeysOptions) error {
var ver [2]byte
if _, err := io.ReadFull(r, ver[:]); err != nil {
return err
Expand All @@ -59,9 +69,14 @@ func ReplaceKeys(w io.Writer, r io.Reader, replace ReplaceFn, encryptAll bool) e
if _, err := w.Write(ver[:]); err != nil {
return err
}
block := make([]byte, 1024)
// Input
mr := msgp.NewReader(r)
mw := msgp.NewWriter(w)

// Temporary block storage.
block := make([]byte, 1024)

// Write a block.
writeBlock := func(id blockID, sz uint32, content []byte) error {
if err := mw.WriteInt8(int8(id)); err != nil {
return err
Expand All @@ -80,6 +95,8 @@ func ReplaceKeys(w io.Writer, r io.Reader, replace ReplaceFn, encryptAll bool) e
return err
}
id := blockID(n)

// Read size
sz, err := mr.ReadUint32()
if err != nil {
return err
Expand Down Expand Up @@ -146,7 +163,7 @@ func ReplaceKeys(w io.Writer, r io.Reader, replace ReplaceFn, encryptAll bool) e
return err
}
case blockPlainKey:
if !encryptAll {
if !o.EncryptAll {
if err := writeBlock(id, sz, block); err != nil {
return err
}
Expand Down Expand Up @@ -184,6 +201,13 @@ func ReplaceKeys(w io.Writer, r io.Reader, replace ReplaceFn, encryptAll bool) e
}
return mw.Flush()
case blockError:
if o.PassErrors {
if err := writeBlock(id, sz, block); err != nil {
return err
}
return mw.Flush()
}
// Return error
msg, _, err := msgp.ReadStringBytes(block)
if err != nil {
return err
Expand Down
31 changes: 28 additions & 3 deletions estream/stream_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ func TestStreamRoundtrip(t *testing.T) {
}
st.Close()
wantStreams += 2
wantDecStreams += 1
wantDecStreams++
}
err = w.Close()
if err != nil {
Expand Down Expand Up @@ -169,7 +169,6 @@ func TestStreamRoundtrip(t *testing.T) {
if gotStreams != wantDecStreams {
t.Errorf("want %d streams, got %d", wantStreams, gotStreams)
}

}

func TestReplaceKeys(t *testing.T) {
Expand Down Expand Up @@ -250,7 +249,7 @@ func TestReplaceKeys(t *testing.T) {
}
t.Fatal("unknown key\n", *key, "\nwant\n", priv.PublicKey)
return nil, nil
}, true)
}, ReplaceKeysOptions{EncryptAll: true})
if err != nil {
t.Fatal(err)
}
Expand Down Expand Up @@ -302,3 +301,29 @@ func TestReplaceKeys(t *testing.T) {
t.Errorf("want %d streams, got %d", wantStreams, gotStreams)
}
}

func TestError(t *testing.T) {
var buf bytes.Buffer
w := NewWriter(&buf)
if err := w.AddKeyPlain(); err != nil {
t.Fatal(err)
}
want := "an error message!"
if err := w.AddError(want); err != nil {
t.Fatal(err)
}
w.Close()

// Read back...
r, err := NewReader(&buf)
if err != nil {
t.Fatal(err)
}
st, err := r.NextStream()
if err == nil {
t.Fatalf("did not receive error, got %v, err: %v", st, err)
}
if err.Error() != want {
t.Errorf("Expected %q, got %q", want, err.Error())
}
}

0 comments on commit b63f9c2

Please sign in to comment.