Skip to content

Commit

Permalink
Fix WriteBatchIterator: doesn’t correctly detect invalid sequences (#132
Browse files Browse the repository at this point in the history
)
  • Loading branch information
linxGnu committed Nov 30, 2023
1 parent 254f68b commit 23acbb6
Show file tree
Hide file tree
Showing 2 changed files with 60 additions and 13 deletions.
19 changes: 6 additions & 13 deletions write_batch.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ package grocksdb
import "C"

import (
"encoding/binary"
"errors"
"io"
)
Expand Down Expand Up @@ -342,21 +343,13 @@ func (iter *WriteBatchIterator) decodeRecType() WriteBatchRecordType {
}

func (iter *WriteBatchIterator) decodeVarint() uint64 {
var n int
var x uint64
for shift := uint(0); shift < 64 && n < len(iter.data); shift += 7 {
b := uint64(iter.data[n])
n++
x |= (b & 0x7F) << shift
if (b & 0x80) == 0 {
iter.data = iter.data[n:]
return x
}
}
if n == len(iter.data) {
v, n := binary.Uvarint(iter.data)
if n > 0 {
iter.data = iter.data[n:]
} else if n == 0 {
iter.err = io.ErrShortBuffer
} else {
iter.err = errors.New("malformed varint")
}
return 0
return v
}
54 changes: 54 additions & 0 deletions write_batch_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package grocksdb

import (
"math"
"testing"

"github.com/stretchr/testify/require"
Expand Down Expand Up @@ -89,3 +90,56 @@ func TestWriteBatchIterator(t *testing.T) {
// there shouldn't be any left
require.False(t, iter.Next())
}

func TestDecodeVarint_ISSUE131(t *testing.T) {
t.Parallel()

tests := []struct {
name string
in []byte
wantValue uint64
expectErr bool
}{
{
name: "invalid: 10th byte",
in: []byte{0xd7, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0x7f},
wantValue: 0,
expectErr: true,
},
{
name: "valid: math.MaxUint64-40",
in: []byte{0xd7, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0x01},
wantValue: math.MaxUint64 - 40,
expectErr: false,
},
{
name: "invalid: with more than MaxVarintLen64 bytes",
in: []byte{0xd7, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0x01},
wantValue: 0,
expectErr: true,
},
{
name: "invalid: 1000 bytes",
in: func() []byte {
b := make([]byte, 1000)
for i := range b {
b[i] = 0xff
}
b[999] = 0
return b
}(),
wantValue: 0,
expectErr: true,
},
}

for _, test := range tests {
wbi := &WriteBatchIterator{data: test.in}
require.EqualValues(t, test.wantValue, wbi.decodeVarint(), test.name)
if test.expectErr {
require.Error(t, wbi.err, test.name)
} else {
require.NoError(t, wbi.err, test.name)
}
}
}

0 comments on commit 23acbb6

Please sign in to comment.