Skip to content

Commit

Permalink
Turn this package into a wrapper for protobuf/encoding/protodelim
Browse files Browse the repository at this point in the history
Since Go Protobuf v1.30.0, the protodelim package is available upstream.

The only notable API difference is that protodelim does not return the number of
bytes read, which is why I added the countingReader type to pbutil/decode.go.
  • Loading branch information
stapelberg committed Jan 12, 2024
1 parent 5a0f916 commit 7c0096c
Show file tree
Hide file tree
Showing 5 changed files with 63 additions and 84 deletions.
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -4,5 +4,5 @@ go 1.19

require (
github.com/google/go-cmp v0.5.9
google.golang.org/protobuf v1.28.1
google.golang.org/protobuf v1.31.0
)
4 changes: 2 additions & 2 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -4,5 +4,5 @@ github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38=
github.com/google/go-cmp v0.5.9/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw=
google.golang.org/protobuf v1.28.1 h1:d0NfwRgPtno5B1Wa6L2DAG+KivqkdutMf1UhdNx175w=
google.golang.org/protobuf v1.28.1/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I=
google.golang.org/protobuf v1.31.0 h1:g0LDEJHgrBl9N9r17Ru3sqWhkIx2NB67okBHPwC7hs8=
google.golang.org/protobuf v1.31.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I=
78 changes: 34 additions & 44 deletions pbutil/decode.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,15 +15,40 @@
package pbutil

import (
"encoding/binary"
"errors"
"io"

"google.golang.org/protobuf/encoding/protodelim"
"google.golang.org/protobuf/proto"
)

// TODO: Give error package name prefix in next minor release.
var errInvalidVarint = errors.New("invalid varint32 encountered")
type countingReader struct {
r io.Reader
n int
}

// implements protodelim.Reader
func (r *countingReader) Read(p []byte) (n int, err error) {
n, err = r.r.Read(p)
if n > 0 {
r.n += n
}
return n, err
}

// implements protodelim.Reader
func (c *countingReader) ReadByte() (byte, error) {
var buf [1]byte
for {
n, err := c.Read(buf[:])
if n == 0 && err == nil {
// io.Reader states: Callers should treat a return of 0 and nil as
// indicating that nothing happened; in particular it does not
// indicate EOF.
continue
}
return buf[0], err
}
}

// ReadDelimited decodes a message from the provided length-delimited stream,
// where the length is encoded as 32-bit varint prefix to the message body.
Expand All @@ -37,45 +62,10 @@ var errInvalidVarint = errors.New("invalid varint32 encountered")
// of the stream has been reached in doing so. In that case, any subsequent
// calls return (0, io.EOF).
func ReadDelimited(r io.Reader, m proto.Message) (n int, err error) {
// TODO: Consider allowing the caller to specify a decode buffer in the
// next major version.

// TODO: Consider using error wrapping to annotate error state in pass-
// through cases in the next minor version.

// Per AbstractParser#parsePartialDelimitedFrom with
// CodedInputStream#readRawVarint32.
var headerBuf [binary.MaxVarintLen32]byte
var bytesRead, varIntBytes int
var messageLength uint64
for varIntBytes == 0 { // i.e. no varint has been decoded yet.
if bytesRead >= len(headerBuf) {
return bytesRead, errInvalidVarint
}
// We have to read byte by byte here to avoid reading more bytes
// than required. Each read byte is appended to what we have
// read before.
newBytesRead, err := r.Read(headerBuf[bytesRead : bytesRead+1])
if newBytesRead == 0 {
if err != nil {
return bytesRead, err
}
// A Reader should not return (0, nil); but if it does, it should
// be treated as no-op according to the Reader contract.
continue
}
bytesRead += newBytesRead
// Now present everything read so far to the varint decoder and
// see if a varint can be decoded already.
messageLength, varIntBytes = binary.Uvarint(headerBuf[:bytesRead])
}

messageBuf := make([]byte, messageLength)
newBytesRead, err := io.ReadFull(r, messageBuf)
bytesRead += newBytesRead
if err != nil {
return bytesRead, err
cr := &countingReader{r: r}
opts := protodelim.UnmarshalOptions{
MaxSize: -1,
}

return bytesRead, proto.Unmarshal(messageBuf, m)
err = opts.UnmarshalFrom(cr, m)
return cr.n, err
}
42 changes: 24 additions & 18 deletions pbutil/decode_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ package pbutil

import (
"bytes"
"encoding/binary"
"errors"
"io"
"testing"
Expand All @@ -29,29 +30,34 @@ import (

func TestReadDelimitedIllegalVarint(t *testing.T) {
var tests = []struct {
in []byte
n int
err error
name string
in []byte
n int
}{
{
in: []byte{255, 255, 255, 255, 255},
n: 5,
err: errInvalidVarint,
name: "all 0xFF",
in: []byte{255, 255, 255, 255, 255},
n: 5,
},

// Ensure ReadDelimited eventually stops parsing a varint instead of
// looping as long as the input bytes have the continuation bit set.
{
in: []byte{255, 255, 255, 255, 255, 255},
n: 5,
err: errInvalidVarint,
name: "infinite continuation bits",
in: bytes.Repeat([]byte{255}, 2*binary.MaxVarintLen64),
n: binary.MaxVarintLen64,
},
}
for _, test := range tests {
n, err := ReadDelimited(bytes.NewReader(test.in), nil)
if got, want := n, test.n; !cmp.Equal(got, want) {
t.Errorf("ReadDelimited(%#v, nil) = %#v, ?; want = %#v, ?", test.in, got, want)
}
if got, want := err, test.err; !errors.Is(got, want) {
t.Errorf("ReadDelimited(%#v, nil) = ?, %#v; want = ?, %#v", test.in, got, want)
}
t.Run(test.name, func(t *testing.T) {
n, err := ReadDelimited(bytes.NewReader(test.in), nil)
if got, want := n, test.n; !cmp.Equal(got, want) {
t.Errorf("ReadDelimited(%#v, nil) = %#v, ?; want = %#v, ?", test.in, got, want)
}
if err == nil {
t.Errorf("ReadDelimited(%#v) unexpectedly did not result in an error", test.in)
}
})
}
}

Expand All @@ -61,7 +67,7 @@ func TestReadDelimitedPrematureHeader(t *testing.T) {
if got, want := n, 1; !cmp.Equal(got, want) {
t.Errorf("ReadDelimited(%#v, nil) = %#v, ?; want = %#v, ?", data[0:1], got, want)
}
if got, want := err, io.EOF; !errors.Is(got, want) {
if got, want := err, io.ErrUnexpectedEOF; !errors.Is(got, want) {
t.Errorf("ReadDelimited(%#v, nil) = ?, %#v; want = ?, %#v", data[0:1], got, want)
}
}
Expand All @@ -83,7 +89,7 @@ func TestReadDelimitedPrematureHeaderIncremental(t *testing.T) {
if got, want := n, 1; !cmp.Equal(got, want) {
t.Errorf("ReadDelimited(%#v, nil) = %#v, ?; want = %#v, ?", data[0:1], got, want)
}
if got, want := err, io.EOF; !errors.Is(got, want) {
if got, want := err, io.ErrUnexpectedEOF; !errors.Is(got, want) {
t.Errorf("ReadDelimited(%#v, nil) = ?, %#v; want = ?, %#v", data[0:1], got, want)
}
}
Expand Down
21 changes: 2 additions & 19 deletions pbutil/encode.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,9 @@
package pbutil

import (
"encoding/binary"
"io"

"google.golang.org/protobuf/encoding/protodelim"
"google.golang.org/protobuf/proto"
)

Expand All @@ -28,22 +28,5 @@ import (
// number of bytes written and any applicable error. This is roughly
// equivalent to the companion Java API's MessageLite#writeDelimitedTo.
func WriteDelimited(w io.Writer, m proto.Message) (n int, err error) {
// TODO: Consider allowing the caller to specify an encode buffer in the
// next major version.

buffer, err := proto.Marshal(m)
if err != nil {
return 0, err
}

var buf [binary.MaxVarintLen32]byte
encodedLength := binary.PutUvarint(buf[:], uint64(len(buffer)))

sync, err := w.Write(buf[:encodedLength])
if err != nil {
return sync, err
}

n, err = w.Write(buffer)
return n + sync, err
return protodelim.MarshalTo(w, m)
}

0 comments on commit 7c0096c

Please sign in to comment.