zstd: Fix default level first dictionary encode (#829)
Fix `allDirty` not being set correctly on the first encode, meaning the lookup table was not copied.

This would cause the first encode to use only partial dictionary lookups.

Fixes #828
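
A minimal sketch of the symptom, assuming a dictionary blob at a hypothetical path `my.dict`; the `zstd` calls (`NewWriter`, `WithEncoderDict`, `EncodeAll`) are the package's real API. Before this fix, the first `EncodeAll` on a fresh dictionary encoder at the default level could compress measurably worse than later calls:

    package main

    import (
        "fmt"
        "os"

        "github.com/klauspost/compress/zstd"
    )

    func main() {
        // Hypothetical dictionary blob; any zstd dictionary will do.
        dict, err := os.ReadFile("my.dict")
        if err != nil {
            panic(err)
        }
        payload := []byte("data that overlaps heavily with the dictionary contents")

        enc, err := zstd.NewWriter(nil, zstd.WithEncoderDict(dict))
        if err != nil {
            panic(err)
        }
        defer enc.Close()

        // Before the fix, the first call could come out larger than the
        // second, because the dictionary lookup table was only partially
        // populated on the first encode.
        first := enc.EncodeAll(payload, nil)
        second := enc.EncodeAll(payload, nil)
        fmt.Println("first:", len(first), "second:", len(second))
    }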
klauspost committed Jun 20, 2023
1 parent ae6c851 commit 4edb2e8
Showing 4 changed files with 153 additions and 9 deletions.
148 changes: 148 additions & 0 deletions zstd/dict_test.go
@@ -189,6 +189,154 @@ func TestEncoder_SmallDict(t *testing.T) {
	}
}

func TestEncoder_SmallDictFresh(t *testing.T) {
	// All files have CRC
	zr := testCreateZipReader("testdata/dict-tests-small.zip", t)
	var dicts [][]byte
	var encs []func() *Encoder
	var noDictEncs []*Encoder
	var encNames []string

	for _, tt := range zr.File {
		if !strings.HasSuffix(tt.Name, ".dict") {
			continue
		}
		func() {
			r, err := tt.Open()
			if err != nil {
				t.Fatal(err)
			}
			defer r.Close()
			in, err := io.ReadAll(r)
			if err != nil {
				t.Fatal(err)
			}
			dicts = append(dicts, in)
			for level := SpeedFastest; level < speedLast; level++ {
				if isRaceTest && level >= SpeedBestCompression {
					break
				}
				level := level
				encs = append(encs, func() *Encoder {
					enc, err := NewWriter(nil, WithEncoderConcurrency(1), WithEncoderDict(in), WithEncoderLevel(level), WithWindowSize(1<<17))
					if err != nil {
						t.Fatal(err)
					}
					return enc
				})
				encNames = append(encNames, fmt.Sprint("level-", level.String(), "-dict-", len(dicts)))

				enc, err := NewWriter(nil, WithEncoderConcurrency(1), WithEncoderLevel(level), WithWindowSize(1<<17))
				if err != nil {
					t.Fatal(err)
				}
				noDictEncs = append(noDictEncs, enc)
			}
		}()
	}
	dec, err := NewReader(nil, WithDecoderConcurrency(1), WithDecoderDicts(dicts...))
	if err != nil {
		t.Fatal(err)
		return
	}
	defer dec.Close()
	for i, tt := range zr.File {
		if testing.Short() && i > 100 {
			break
		}
		if !strings.HasSuffix(tt.Name, ".zst") {
			continue
		}
		r, err := tt.Open()
		if err != nil {
			t.Fatal(err)
		}
		defer r.Close()
		in, err := io.ReadAll(r)
		if err != nil {
			t.Fatal(err)
		}
		decoded, err := dec.DecodeAll(in, nil)
		if err != nil {
			t.Fatal(err)
		}
		if testing.Short() && len(decoded) > 1000 {
			continue
		}

		t.Run("encodeall-"+tt.Name, func(t *testing.T) {
			// Attempt to compress with all dicts
			var b []byte
			var tmp []byte
			for i := range encs {
				i := i
				t.Run(encNames[i], func(t *testing.T) {
					enc := encs[i]()
					defer enc.Close()
					b = enc.EncodeAll(decoded, b[:0])
					tmp, err = dec.DecodeAll(in, tmp[:0])
					if err != nil {
						t.Fatal(err)
					}
					if !bytes.Equal(tmp, decoded) {
						t.Fatal("output mismatch")
					}

					tmp = noDictEncs[i].EncodeAll(decoded, tmp[:0])

					if strings.Contains(t.Name(), "dictplain") && strings.Contains(t.Name(), "dict-1") {
						t.Log("reference:", len(in), "no dict:", len(tmp), "with dict:", len(b), "SAVED:", len(tmp)-len(b))
						// Check that we reduced this significantly
						if len(b) > 250 {
							t.Error("output was bigger than expected")
						}
					}
				})
			}
		})
		t.Run("stream-"+tt.Name, func(t *testing.T) {
			// Attempt to compress with all dicts
			var tmp []byte
			for i := range encs {
				i := i
				t.Run(encNames[i], func(t *testing.T) {
					enc := encs[i]()
					defer enc.Close()
					var buf bytes.Buffer
					enc.ResetContentSize(&buf, int64(len(decoded)))
					_, err := enc.Write(decoded)
					if err != nil {
						t.Fatal(err)
					}
					err = enc.Close()
					if err != nil {
						t.Fatal(err)
					}
					tmp, err = dec.DecodeAll(buf.Bytes(), tmp[:0])
					if err != nil {
						t.Fatal(err)
					}
					if !bytes.Equal(tmp, decoded) {
						t.Fatal("output mismatch")
					}
					var buf2 bytes.Buffer
					noDictEncs[i].Reset(&buf2)
					noDictEncs[i].Write(decoded)
					noDictEncs[i].Close()

					if strings.Contains(t.Name(), "dictplain") && strings.Contains(t.Name(), "dict-1") {
						t.Log("reference:", len(in), "no dict:", buf2.Len(), "with dict:", buf.Len(), "SAVED:", buf2.Len()-buf.Len())
						// Check that we reduced this significantly
						if buf.Len() > 250 {
							t.Error("output was bigger than expected")
						}
					}
				})
			}
		})
	}
}

func benchmarkEncodeAllLimitedBySize(b *testing.B, lowerLimit int, upperLimit int) {
	zr := testCreateZipReader("testdata/dict-tests-small.zip", b)
	t := testing.TB(b)
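The new test can be exercised on its own with standard Go tooling from the repository root:

    go test -run TestEncoder_SmallDictFresh ./zstd/
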
1 change: 1 addition & 0 deletions zstd/enc_base.go
@@ -144,6 +144,7 @@ func (e *fastBase) resetBase(d *dict, singleBlock bool) {
	} else {
		e.crc.Reset()
	}
+	e.blk.dictLitEnc = nil
	if d != nil {
		low := e.lowMem
		if singleBlock {
2 changes: 1 addition & 1 deletion zstd/enc_dfast.go
@@ -1084,7 +1084,7 @@ func (e *doubleFastEncoderDict) Reset(d *dict, singleBlock bool) {
			}
		}
		e.lastDictID = d.id
-		e.allDirty = true
+		allDirty = true
	}
	// Reset table to initial state
	e.cur = e.maxMatchOff
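The change above is the heart of the fix: `Reset` snapshots the dirty flag into a local `allDirty` at the top of the method, and the table-copy decision later in the function reads that local, so assigning to the field `e.allDirty` left the local false on the first encode and the full lookup-table copy was skipped. A stripped-down sketch of the bug class, with hypothetical names rather than the real encoder types:

    package main

    import "fmt"

    type enc struct {
        allDirty  bool
        lookup    [4]uint32 // working table consulted during encoding
        dictTable [4]uint32 // table prebuilt from the dictionary
    }

    func (e *enc) reset(newDict bool) {
        allDirty := e.allDirty // local snapshot; the logic below reads this
        if newDict {
            // Bug: e.allDirty = true -- the local snapshot stays false.
            allDirty = true // fix: update the value the copy check reads
        }
        if allDirty {
            copy(e.lookup[:], e.dictTable[:]) // skipped before the fix
        }
    }

    func main() {
        e := &enc{dictTable: [4]uint32{1, 2, 3, 4}}
        e.reset(true)
        fmt.Println(e.lookup) // [1 2 3 4]; with the bug: [0 0 0 0]
    }
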
11 changes: 3 additions & 8 deletions zstd/enc_fast.go
@@ -829,13 +829,12 @@ func (e *fastEncoderDict) Reset(d *dict, singleBlock bool) {
		}
		if true {
			end := e.maxMatchOff + int32(len(d.content)) - 8
-			for i := e.maxMatchOff; i < end; i += 3 {
+			for i := e.maxMatchOff; i < end; i += 2 {
				const hashLog = tableBits

				cv := load6432(d.content, i-e.maxMatchOff)
-				nextHash := hashLen(cv, hashLog, tableFastHashLen)      // 0 -> 5
-				nextHash1 := hashLen(cv>>8, hashLog, tableFastHashLen)  // 1 -> 6
-				nextHash2 := hashLen(cv>>16, hashLog, tableFastHashLen) // 2 -> 7
+				nextHash := hashLen(cv, hashLog, tableFastHashLen)     // 0 -> 6
+				nextHash1 := hashLen(cv>>8, hashLog, tableFastHashLen) // 1 -> 7
				e.dictTable[nextHash] = tableEntry{
					val:    uint32(cv),
					offset: i,
@@ -844,10 +843,6 @@ func (e *fastEncoderDict) Reset(d *dict, singleBlock bool) {
					val:    uint32(cv >> 8),
					offset: i + 1,
				}
-				e.dictTable[nextHash2] = tableEntry{
-					val:    uint32(cv >> 16),
-					offset: i + 2,
-				}
			}
		}
		e.lastDictID = d.id
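For intuition on the new loop shape: the old variant loaded 8 bytes, inserted three hash-table entries (offsets i, i+1, i+2) and stepped by 3; the new one inserts two entries (i, i+1) and steps by 2. Ignoring the tail boundary, both index every dictionary offset, as this small enumeration sketch (a hypothetical helper, not part of the package) illustrates:

    package main

    import "fmt"

    // covered lists which dictionary offsets receive a table entry when the
    // loop steps by step and inserts perLoad consecutive offsets per load.
    func covered(start, end, step, perLoad int) []int {
        var offsets []int
        for i := start; i < end; i += step {
            for j := 0; j < perLoad; j++ {
                offsets = append(offsets, i+j)
            }
        }
        return offsets
    }

    func main() {
        fmt.Println("old (step 3):", covered(0, 12, 3, 3)) // 0..11, grouped in threes
        fmt.Println("new (step 2):", covered(0, 12, 2, 2)) // 0..11, grouped in twos
    }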
