diff --git a/stream/stream_test.go b/stream/stream_test.go new file mode 100644 index 0000000..42bd7e9 --- /dev/null +++ b/stream/stream_test.go @@ -0,0 +1,99 @@ +// Copyright 2019 Google LLC +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file or at +// https://developers.google.com/open-source/licenses/bsd + +package stream_test + +import ( + "bytes" + "crypto/rand" + "fmt" + "testing" + + "github.com/muesli/sasquatch/stream" + "golang.org/x/crypto/chacha20poly1305" +) + +const cs = stream.ChunkSize + +func TestRoundTrip(t *testing.T) { + for _, stepSize := range []int{512, 600, 1000, cs} { + for _, length := range []int{0, 1000, cs, cs + 100} { + t.Run(fmt.Sprintf("len=%d,step=%d", length, stepSize), + func(t *testing.T) { testRoundTrip(t, stepSize, length) }) + } + } +} + +func testRoundTrip(t *testing.T, stepSize, length int) { + src := make([]byte, length) + if _, err := rand.Read(src); err != nil { + t.Fatal(err) + } + buf := &bytes.Buffer{} + key := make([]byte, chacha20poly1305.KeySize) + if _, err := rand.Read(key); err != nil { + t.Fatal(err) + } + + w, err := stream.NewWriter(key, buf) + if err != nil { + t.Fatal(err) + } + + var n int + for n < length { + b := length - n + if b > stepSize { + b = stepSize + } + nn, err := w.Write(src[n : n+b]) + if err != nil { + t.Fatal(err) + } + if nn != b { + t.Errorf("Write returned %d, expected %d", nn, b) + } + n += nn + + nn, err = w.Write(src[n:n]) + if err != nil { + t.Fatal(err) + } + if nn != 0 { + t.Errorf("Write returned %d, expected 0", nn) + } + } + + if err := w.Close(); err != nil { + t.Error("Close returned an error:", err) + } + + t.Logf("buffer size: %d", buf.Len()) + + r, err := stream.NewReader(key, buf) + if err != nil { + t.Fatal(err) + } + + n = 0 + readBuf := make([]byte, stepSize) + for n < length { + b := length - n + if b > stepSize { + b = stepSize + } + nn, err := r.Read(readBuf) + if err != nil { + t.Fatalf("Read error at index %d: %v", n, err) + } + + if !bytes.Equal(readBuf[:nn], src[n:n+nn]) { + t.Errorf("wrong data at indexes %d - %d", n, n+nn) + } + + n += nn + } +}