Skip to content

Commit

Permalink
Enable assembly for more cases (#559)
Browse files Browse the repository at this point in the history
* Enable assembly for more cases
* Fix excessive allocation in decoder in tests.
  • Loading branch information
klauspost committed Apr 22, 2022
1 parent 4a51c29 commit aa5b572
Show file tree
Hide file tree
Showing 8 changed files with 1,164 additions and 121 deletions.
13 changes: 12 additions & 1 deletion zstd/_generate/gen.go
Expand Up @@ -70,6 +70,7 @@ func main() {

exec := executeSimple{
useSeqs: true,
safeMem: false,
}
exec.generateProcedure("sequenceDecs_executeSimple_amd64")

Expand All @@ -79,6 +80,11 @@ func main() {
decodeSync.setBMI2(true)
decodeSync.generateProcedure("sequenceDecs_decodeSync_bmi2")

decodeSync.execute.safeMem = true
decodeSync.setBMI2(false)
decodeSync.generateProcedure("sequenceDecs_decodeSync_safe_amd64")
decodeSync.setBMI2(true)
decodeSync.generateProcedure("sequenceDecs_decodeSync_safe_bmi2")
Generate()
b, err := ioutil.ReadFile(out.Value.String())
if err != nil {
Expand Down Expand Up @@ -860,6 +866,7 @@ func (o options) adjustOffsetInMemory(name string, moP, llP Mem, offsetB reg.GPV

// executeSimple holds the options used when generating the sequence
// execution assembly via generateProcedure.
type executeSimple struct {
useSeqs bool // Generate code that uses the `seqs` auxiliary table
// safeMem makes the "copy non-overlapping match" step use
// copyMemoryPrecise instead of copyMemory.
// NOTE(review): presumably this keeps the generated code from writing
// past the end of the output buffer — confirm against copyMemory.
safeMem bool
}

// copySize returns register size used to fast copy.
Expand Down Expand Up @@ -1130,7 +1137,11 @@ func (e executeSimple) executeSingleTriple(c *executeSingleTripleContext, handle

Comment("Copy non-overlapping match")
{
e.copyMemoryPrecise("2", src, c.outBase, ml)
if e.safeMem {
e.copyMemoryPrecise("2", src, c.outBase, ml)
} else {
e.copyMemory("2", src, c.outBase, ml)
}
ADDQ(ml, c.outBase)
ADDQ(ml, c.outPosition)
JMP(LabelRef("handle_loop"))
Expand Down
14 changes: 8 additions & 6 deletions zstd/decoder.go
Expand Up @@ -347,18 +347,20 @@ func (d *Decoder) DecodeAll(input, dst []byte) ([]byte, error) {
}
frame.history.setDict(&dict)
}

if frame.FrameContentSize != fcsUnknown && frame.FrameContentSize > d.o.maxDecodedSize-uint64(len(dst)) {
return dst, ErrDecoderSizeExceeded
if frame.WindowSize > d.o.maxWindowSize {
return dst, ErrWindowSizeExceeded
}
if frame.FrameContentSize < 1<<30 {
// Never preallocate more than 1 GB up front.
if frame.FrameContentSize != fcsUnknown {
if frame.FrameContentSize > d.o.maxDecodedSize-uint64(len(dst)) {
return dst, ErrDecoderSizeExceeded
}
if cap(dst)-len(dst) < int(frame.FrameContentSize) {
dst2 := make([]byte, len(dst), len(dst)+int(frame.FrameContentSize))
dst2 := make([]byte, len(dst), len(dst)+int(frame.FrameContentSize)+compressedBlockOverAlloc)
copy(dst2, dst)
dst = dst2
}
}

if cap(dst) == 0 {
// Allocate len(input) * 2 by default if nothing is provided
// and we didn't get frame content size.
Expand Down
4 changes: 2 additions & 2 deletions zstd/decoder_options.go
Expand Up @@ -31,7 +31,7 @@ func (o *decoderOptions) setDefault() {
if o.concurrent > 4 {
o.concurrent = 4
}
o.maxDecodedSize = 1 << 63
o.maxDecodedSize = 64 << 30
}

// WithDecoderLowmem will set whether to use a lower amount of memory,
Expand Down Expand Up @@ -66,7 +66,7 @@ func WithDecoderConcurrency(n int) DOption {
// WithDecoderMaxMemory sets a maximum decoded size for in-memory
// non-streaming operations or a maximum window size for streaming operations.
// This can be used to control memory usage of potentially hostile content.
// Maximum and default is 1 << 63 bytes.
// Maximum is 1 << 63 bytes. Default is 64GiB.
func WithDecoderMaxMemory(n uint64) DOption {
return func(o *decoderOptions) error {
if n == 0 {
Expand Down
33 changes: 30 additions & 3 deletions zstd/decoder_test.go
Expand Up @@ -410,21 +410,42 @@ func TestNewDecoderBad(t *testing.T) {
if true {
t.Run("Reader-4", func(t *testing.T) {
newFn := func() (*Decoder, error) {
return NewReader(nil, WithDecoderConcurrency(4))
return NewReader(nil, WithDecoderConcurrency(4), WithDecoderMaxMemory(1<<30))
}
testDecoderFileBad(t, "testdata/bad.zip", newFn, errMap)

})
t.Run("Reader-1", func(t *testing.T) {
newFn := func() (*Decoder, error) {
return NewReader(nil, WithDecoderConcurrency(1))
return NewReader(nil, WithDecoderConcurrency(1), WithDecoderMaxMemory(1<<30))
}
testDecoderFileBad(t, "testdata/bad.zip", newFn, errMap)
})
t.Run("Reader-4-bigmem", func(t *testing.T) {
newFn := func() (*Decoder, error) {
return NewReader(nil, WithDecoderConcurrency(4), WithDecoderMaxMemory(1<<30), WithDecoderLowmem(false))
}
testDecoderFileBad(t, "testdata/bad.zip", newFn, errMap)

})
t.Run("Reader-1-bigmem", func(t *testing.T) {
newFn := func() (*Decoder, error) {
return NewReader(nil, WithDecoderConcurrency(1), WithDecoderMaxMemory(1<<30), WithDecoderLowmem(false))
}
testDecoderFileBad(t, "testdata/bad.zip", newFn, errMap)
})
}
t.Run("DecodeAll", func(t *testing.T) {
defer timeout(10 * time.Second)()
dec, err := NewReader(nil)
dec, err := NewReader(nil, WithDecoderMaxMemory(1<<30))
if err != nil {
t.Fatal(err)
}
testDecoderDecodeAllError(t, "testdata/bad.zip", dec, errMap)
})
t.Run("DecodeAll-bigmem", func(t *testing.T) {
defer timeout(10 * time.Second)()
dec, err := NewReader(nil, WithDecoderMaxMemory(1<<30), WithDecoderLowmem(false))
if err != nil {
t.Fatal(err)
}
Expand Down Expand Up @@ -1591,6 +1612,12 @@ func testDecoderDecodeAllError(t *testing.T, fn string, dec *Decoder, errMap map
} else {
want := errMap[tt.Name]
if want != err.Error() {
if want == ErrFrameSizeMismatch.Error() && err == ErrDecoderSizeExceeded {
return
}
if want == ErrWindowSizeExceeded.Error() && err == ErrDecoderSizeExceeded {
return
}
t.Errorf("error mismatch, prev run got %s, now got %s", want, err.Error())
}
return
Expand Down
13 changes: 13 additions & 0 deletions zstd/framedec.go
Expand Up @@ -326,6 +326,19 @@ func (d *frameDec) runDecoder(dst []byte, dec *blockDec) ([]byte, error) {
d.history.ignoreBuffer = len(dst)
// Store input length, so we only check new data.
crcStart := len(dst)
d.history.decoders.maxSyncLen = 0
if d.FrameContentSize != fcsUnknown {
d.history.decoders.maxSyncLen = d.FrameContentSize + uint64(len(dst))
if d.history.decoders.maxSyncLen > d.o.maxDecodedSize {
return dst, ErrDecoderSizeExceeded
}
if uint64(cap(dst)) < d.history.decoders.maxSyncLen {
// Alloc for output
dst2 := make([]byte, len(dst), d.history.decoders.maxSyncLen+compressedBlockOverAlloc)
copy(dst2, dst)
dst = dst2
}
}
var err error
for {
err = dec.reset(d.rawInput, d.WindowSize)
Expand Down
1 change: 1 addition & 0 deletions zstd/seqdec.go
Expand Up @@ -73,6 +73,7 @@ type sequenceDecs struct {
seqSize int
windowSize int
maxBits uint8
maxSyncLen uint64
}

// initialize all 3 decoders from the stream input.
Expand Down
34 changes: 30 additions & 4 deletions zstd/seqdec_amd64.go
Expand Up @@ -39,11 +39,29 @@ func sequenceDecs_decodeSync_amd64(s *sequenceDecs, br *bitReader, ctx *decodeSy
//go:noescape
func sequenceDecs_decodeSync_bmi2(s *sequenceDecs, br *bitReader, ctx *decodeSyncAsmContext) int

// sequenceDecs_decodeSync_safe_amd64 does the same as sequenceDecs_decodeSync_amd64,
// but does not write beyond the end of the output buffer.
// decodeSyncSimple calls it instead of the non-safe variant when useSafe is set
// and BMI2 is not available.
//go:noescape
func sequenceDecs_decodeSync_safe_amd64(s *sequenceDecs, br *bitReader, ctx *decodeSyncAsmContext) int

// sequenceDecs_decodeSync_safe_bmi2 does the same as sequenceDecs_decodeSync_bmi2,
// but does not write beyond the end of the output buffer.
// decodeSyncSimple calls it instead of the non-safe variant when useSafe is set
// and cpuinfo.HasBMI2() reports true.
//go:noescape
func sequenceDecs_decodeSync_safe_bmi2(s *sequenceDecs, br *bitReader, ctx *decodeSyncAsmContext) int

// decode sequences from the stream with the provided history but without a dictionary.
func (s *sequenceDecs) decodeSyncSimple(hist []byte) (bool, error) {
if len(s.dict) > 0 || cap(s.out)-len(s.out) < maxCompressedBlockSizeAlloc {
if len(s.dict) > 0 {
return false, nil
}
if s.maxSyncLen == 0 && cap(s.out)-len(s.out) < maxCompressedBlockSize {
return false, nil
}
useSafe := false
if s.maxSyncLen == 0 && cap(s.out)-len(s.out) < maxCompressedBlockSizeAlloc {
useSafe = true
}
if s.maxSyncLen > 0 && uint64(cap(s.out))-compressedBlockOverAlloc < s.maxSyncLen {
useSafe = true
}

br := s.br

Expand Down Expand Up @@ -73,9 +91,17 @@ func (s *sequenceDecs) decodeSyncSimple(hist []byte) (bool, error) {

var errCode int
if cpuinfo.HasBMI2() {
errCode = sequenceDecs_decodeSync_bmi2(s, br, &ctx)
if useSafe {
errCode = sequenceDecs_decodeSync_safe_bmi2(s, br, &ctx)
} else {
errCode = sequenceDecs_decodeSync_bmi2(s, br, &ctx)
}
} else {
errCode = sequenceDecs_decodeSync_amd64(s, br, &ctx)
if useSafe {
errCode = sequenceDecs_decodeSync_safe_amd64(s, br, &ctx)
} else {
errCode = sequenceDecs_decodeSync_amd64(s, br, &ctx)
}
}
switch errCode {
case noError:
Expand Down Expand Up @@ -211,7 +237,7 @@ func (s *sequenceDecs) decode(seqs []seqVals) error {
}
}
if errCode != 0 {
i := len(seqs) - ctx.iteration
i := len(seqs) - ctx.iteration - 1
switch errCode {
case errorMatchLenOfsMismatch:
ml := ctx.seqs[i].ml
Expand Down

0 comments on commit aa5b572

Please sign in to comment.