diff --git a/zstd/bytebuf.go b/zstd/bytebuf.go index 512ffe5b95..55a388553d 100644 --- a/zstd/bytebuf.go +++ b/zstd/bytebuf.go @@ -109,7 +109,7 @@ func (r *readerWrapper) readBig(n int, dst []byte) ([]byte, error) { } func (r *readerWrapper) readByte() (byte, error) { - n2, err := r.r.Read(r.tmp[:1]) + n2, err := io.ReadFull(r.r, r.tmp[:1]) if err != nil { if err == io.EOF { err = io.ErrUnexpectedEOF diff --git a/zstd/decoder_test.go b/zstd/decoder_test.go index bab352075d..002857a6c0 100644 --- a/zstd/decoder_test.go +++ b/zstd/decoder_test.go @@ -563,6 +563,48 @@ func TestNewDecoderSmallFile(t *testing.T) { t.Logf("Decoded %d bytes with %f.2 MB/s", n, mbpersec) } +// cursedReader wraps a reader and returns zero bytes every other read. +// This is used to test the ability of the consumer to handle empty reads without EOF, +// which can happen when reading from a network connection. +type cursedReader struct { + io.Reader + numReads int +} + +func (r *cursedReader) Read(p []byte) (n int, err error) { + r.numReads++ + if r.numReads%2 == 0 { + return 0, nil + } + + return r.Reader.Read(p) +} + +func TestNewDecoderZeroLengthReads(t *testing.T) { + if testing.Short() { + t.SkipNow() + } + file := "testdata/z000028.zst" + const wantSize = 39807 + f, err := os.Open(file) + if err != nil { + t.Fatal(err) + } + defer f.Close() + dec, err := NewReader(&cursedReader{Reader: f}) + if err != nil { + t.Fatal(err) + } + defer dec.Close() + n, err := io.Copy(io.Discard, dec) + if err != nil { + t.Fatal(err) + } + if n != wantSize { + t.Errorf("want size %d, got size %d", wantSize, n) + } +} + type readAndBlock struct { buf []byte unblock chan struct{}