Skip to content

Commit

Permalink
Fix ReadFrom over-read
Browse files Browse the repository at this point in the history
`bufio.NewReader` usually reads ahead. Limit the input to the expected size.

Return `io.ErrUnexpectedEOF` when stream is unexpectedly truncated.

Add regression tests (and apparently some Go 1.19 formatting)

Fixes #108
  • Loading branch information
klauspost committed Sep 12, 2022
1 parent 36eb6b6 commit 2a8cb8e
Show file tree
Hide file tree
Showing 2 changed files with 171 additions and 23 deletions.
45 changes: 23 additions & 22 deletions bitset.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,6 @@ Example use:
As an alternative to BitSets, one should check out the 'big' package,
which provides a (less set-theoretical) view of bitsets.
*/
package bitset

Expand Down Expand Up @@ -434,21 +433,20 @@ func (b *BitSet) NextSet(i uint) (uint, bool) {
// including possibly the current index and up to cap(buffer).
// If the returned slice has len zero, then no more set bits were found
//
// buffer := make([]uint, 256) // this should be reused
// j := uint(0)
// j, buffer = bitmap.NextSetMany(j, buffer)
// for ; len(buffer) > 0; j, buffer = bitmap.NextSetMany(j,buffer) {
// for k := range buffer {
// do something with buffer[k]
// }
// j += 1
// }
//
// buffer := make([]uint, 256) // this should be reused
// j := uint(0)
// j, buffer = bitmap.NextSetMany(j, buffer)
// for ; len(buffer) > 0; j, buffer = bitmap.NextSetMany(j,buffer) {
// for k := range buffer {
// do something with buffer[k]
// }
// j += 1
// }
//
// It is possible to retrieve all set bits as follow:
//
// indices := make([]uint, bitmap.Count())
// bitmap.NextSetMany(0, indices)
// indices := make([]uint, bitmap.Count())
// bitmap.NextSetMany(0, indices)
//
// However if bitmap.Count() is large, it might be preferable to
// use several calls to NextSetMany, for performance reasons.
Expand Down Expand Up @@ -932,6 +930,9 @@ func (b *BitSet) ReadFrom(stream io.Reader) (int64, error) {
// Read length first
err := binary.Read(stream, binaryOrder, &length)
if err != nil {
if err == io.EOF {
err = io.ErrUnexpectedEOF
}
return 0, err
}
newset := New(uint(length))
Expand All @@ -940,17 +941,17 @@ func (b *BitSet) ReadFrom(stream io.Reader) (int64, error) {
return 0, errors.New("unmarshalling error: type mismatch")
}

// Read remaining bytes as set
// current implementation bufio.Reader is more memory efficient than
// binary.Read for large set
reader := bufio.NewReader(stream)
var item = make([]byte, binary.Size(uint64(0))) // one uint64
nWords := uint64(wordsNeeded(uint(length)))
for i := uint64(0); i < nWords; i++ {
if _, err := reader.Read(item); err != nil {
var item [8]byte
nWords := wordsNeeded(uint(length))
reader := bufio.NewReader(io.LimitReader(stream, 8*int64(nWords)))
for i := 0; i < nWords; i++ {
if _, err := reader.Read(item[:]); err != nil {
if err == io.EOF {
err = io.ErrUnexpectedEOF
}
return 0, err
}
newset.set[i] = binaryOrder.Uint64(item)
newset.set[i] = binaryOrder.Uint64(item[:])
}

*b = *newset
Expand Down
149 changes: 148 additions & 1 deletion bitset_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,14 @@
package bitset

import (
"bytes"
"compress/gzip"
"encoding"
"encoding/base64"
"encoding/json"
"errors"
"fmt"
"io"
"math"
"strconv"
"testing"
Expand Down Expand Up @@ -1279,7 +1284,7 @@ func TestMarshalUnmarshalJSONWithTrailingData(t *testing.T) {
}

// appending some noise
data = data[:len(data) - 3] // remove "
data = data[:len(data)-3] // remove "
data = append(data, []byte(`AAAAAAAAAA"`)...)

b := new(BitSet)
Expand Down Expand Up @@ -1644,3 +1649,145 @@ func TestDeleteWithBitSetInstance(t *testing.T) {

}
}

func TestWriteTo(t *testing.T) {
const length = 9585
const oneEvery = 97
addBuf := []byte(`12345678`)
bs := New(length)
// Add some bits
for i := uint(0); i < length; i += oneEvery {
bs = bs.Set(i)
}

var buf bytes.Buffer
n, err := bs.WriteTo(&buf)
if err != nil {
t.Fatal(err)
}
wantSz := buf.Len() // Size of the serialized data in bytes.
if n != int64(wantSz) {
t.Errorf("want write size to be %d, got %d", wantSz, n)
}
buf.Write(addBuf) // Add additional data on stream.

// Generate test input for regression tests:
if false {
gzout := bytes.NewBuffer(nil)
gz, err := gzip.NewWriterLevel(gzout, 9)
if err != nil {
t.Fatal(err)
}
gz.Write(buf.Bytes())
gz.Close()
t.Log("Encoded:", base64.StdEncoding.EncodeToString(gzout.Bytes()))
}

// Read back.
bs = New(length)
n, err = bs.ReadFrom(&buf)
if err != nil {
t.Fatal(err)
}
if n != int64(wantSz) {
t.Errorf("want read size to be %d, got %d", wantSz, n)
}
// Check bits
for i := uint(0); i < length; i += oneEvery {
if !bs.Test(i) {
t.Errorf("bit %d was not set", i)
}
}

more, err := io.ReadAll(&buf)
if err != nil {
t.Fatal(err)
}
if !bytes.Equal(more, addBuf) {
t.Fatalf("extra mismatch. got %v, want %v", more, addBuf)
}
}

func TestReadFrom(t *testing.T) {
addBuf := []byte(`12345678`) // Bytes after stream
tests := []struct {
length uint
oneEvery uint
input string // base64+gzipped
wantErr error
}{
{
length: 9585,
oneEvery: 97,
input: "H4sIAAAAAAAC/2IAA9VCCM3AyMDAwMSACVgYGBg4sIgLMDAwKGARd2BgYGjAFB41noDx6IAJajw64IAajw4UoMajg4ZR4/EaP5pQh1g+MDQyNjE1M7cABAAA//9W5OoOwAQAAA==",
},
{
length: 1337,
oneEvery: 42,
input: "H4sIAAAAAAAC/2IAA1ZLBgYWEIPRAUQKgJkMcCZYisEBzkSSYkSTYqCxAYZGxiamZuYWgAAAAP//D0wyWbgAAAA=",
},
{
length: 1337, // Truncated input.
oneEvery: 42,
input: "H4sIAAAAAAAC/2IAA9VCCM3AyMDAwARmAQIAAP//vR3xdRkAAAA=",
wantErr: io.ErrUnexpectedEOF,
},
{
length: 1337, // Empty input.
oneEvery: 42,
input: "H4sIAAAAAAAC/wEAAP//AAAAAAAAAAA=",
wantErr: io.ErrUnexpectedEOF,
},
}

for i, test := range tests {
t.Run(fmt.Sprint(i), func(t *testing.T) {
fatalErr := func(err error) {
t.Helper()
if err != nil {
t.Fatal(err)
}
}

var buf bytes.Buffer
b, err := base64.StdEncoding.DecodeString(test.input)
fatalErr(err)
gz, err := gzip.NewReader(bytes.NewBuffer(b))
fatalErr(err)
_, err = io.Copy(&buf, gz)
fatalErr(err)
fatalErr(gz.Close())

bs := New(test.length)
_, err = bs.ReadFrom(&buf)
if err != nil {
if errors.Is(err, test.wantErr) {
// Correct, nothing more we can test.
return
}
t.Fatalf("did not get expected error %v, got %v", test.wantErr, err)
} else {
if test.wantErr != nil {
t.Fatalf("did not get expected error %v", test.wantErr)
}
}
fatalErr(err)

// Test if correct bits are set.
for i := uint(0); i < test.length; i++ {
want := i%test.oneEvery == 0
got := bs.Test(i)
if want != got {
t.Errorf("bit %d was %v, should be %v", i, got, want)
}
}

more, err := io.ReadAll(&buf)
fatalErr(err)

if !bytes.Equal(more, addBuf) {
t.Errorf("extra mismatch. got %v, want %v", more, addBuf)
}
})
}
}

0 comments on commit 2a8cb8e

Please sign in to comment.