Skip to content

Commit

Permalink
zstd: Fix ReadFrom with small blocks (#278)
Browse files Browse the repository at this point in the history
Two 'last' blocks was added on small payloads when using ReadFrom.

Fixes #277
  • Loading branch information
klauspost committed Aug 13, 2020
1 parent fb1e79e commit 898127a
Show file tree
Hide file tree
Showing 3 changed files with 33 additions and 4 deletions.
8 changes: 4 additions & 4 deletions zstd/blockenc.go
Original file line number Diff line number Diff line change
Expand Up @@ -295,7 +295,7 @@ func (b *blockEnc) encodeRaw(a []byte) {
b.output = bh.appendTo(b.output[:0])
b.output = append(b.output, a...)
if debug {
println("Adding RAW block, length", len(a))
println("Adding RAW block, length", len(a), "last:", b.last)
}
}

Expand All @@ -308,7 +308,7 @@ func (b *blockEnc) encodeRawTo(dst, src []byte) []byte {
dst = bh.appendTo(dst)
dst = append(dst, src...)
if debug {
println("Adding RAW block, length", len(src))
println("Adding RAW block, length", len(src), "last:", b.last)
}
return dst
}
Expand All @@ -322,7 +322,7 @@ func (b *blockEnc) encodeLits(raw bool) error {
// Don't compress extremely small blocks
if len(b.literals) < 32 || raw {
if debug {
println("Adding RAW block, length", len(b.literals))
println("Adding RAW block, length", len(b.literals), "last:", b.last)
}
bh.setType(blockTypeRaw)
b.output = bh.appendTo(b.output)
Expand All @@ -349,7 +349,7 @@ func (b *blockEnc) encodeLits(raw bool) error {
switch err {
case huff0.ErrIncompressible:
if debug {
println("Adding RAW block, length", len(b.literals))
println("Adding RAW block, length", len(b.literals), "last:", b.last)
}
bh.setType(blockTypeRaw)
b.output = bh.appendTo(b.output)
Expand Down
1 change: 1 addition & 0 deletions zstd/encoder.go
Original file line number Diff line number Diff line change
Expand Up @@ -190,6 +190,7 @@ func (e *Encoder) nextBlock(final bool) error {
s.filling = s.filling[:0]
s.headerWritten = true
s.fullFrameWritten = true
s.eofWritten = true
return nil
}

Expand Down
28 changes: 28 additions & 0 deletions zstd/encoder_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -667,6 +667,34 @@ func TestEncoder_EncodeAllSilesia(t *testing.T) {
t.Log("Encoded content matched")
}

func TestEncoderReadFrom(t *testing.T) {
buffer := bytes.NewBuffer(nil)
encoder, err := NewWriter(buffer)
if err != nil {
t.Fatal(err)
}
if _, err := encoder.ReadFrom(strings.NewReader("0")); err != nil {
t.Fatal(err)
}
if err := encoder.Close(); err != nil {
t.Fatal(err)
}

dec, _ := NewReader(nil)
toDec := buffer.Bytes()
toDec = append(toDec, toDec...)
decoded, err := dec.DecodeAll(toDec, nil)
if err != nil {
t.Fatal(err)
}

if !bytes.Equal([]byte("00"), decoded) {
t.Logf("encoded: % x\n", buffer.Bytes())
t.Fatalf("output mismatch, got %s", string(decoded))
}
dec.Close()
}

func TestEncoder_EncodeAllEmpty(t *testing.T) {
if testing.Short() {
t.SkipNow()
Expand Down

0 comments on commit 898127a

Please sign in to comment.