Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

perf(storage): remove protobuf's copy of data on unmarshalling #9526

Merged
merged 13 commits into from Mar 19, 2024
2 changes: 1 addition & 1 deletion storage/go.mod
Expand Up @@ -8,6 +8,7 @@ require (
cloud.google.com/go v0.112.1
cloud.google.com/go/compute/metadata v0.2.3
cloud.google.com/go/iam v1.1.6
github.com/golang/protobuf v1.5.3
github.com/google/go-cmp v0.6.0
github.com/google/uuid v1.6.0
github.com/googleapis/gax-go/v2 v2.12.2
Expand All @@ -26,7 +27,6 @@ require (
github.com/go-logr/logr v1.4.1 // indirect
github.com/go-logr/stdr v1.2.2 // indirect
github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da // indirect
github.com/golang/protobuf v1.5.3 // indirect
github.com/google/martian/v3 v3.3.2 // indirect
github.com/google/s2a-go v0.1.7 // indirect
github.com/googleapis/enterprise-certificate-proxy v0.3.2 // indirect
Expand Down
262 changes: 252 additions & 10 deletions storage/grpc_client.go
Expand Up @@ -27,15 +27,18 @@ import (
"cloud.google.com/go/internal/trace"
gapic "cloud.google.com/go/storage/internal/apiv2"
"cloud.google.com/go/storage/internal/apiv2/storagepb"
"github.com/golang/protobuf/proto"
"github.com/googleapis/gax-go/v2"
"google.golang.org/api/googleapi"
"google.golang.org/api/iterator"
"google.golang.org/api/option"
"google.golang.org/api/option/internaloption"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/encoding"
"google.golang.org/grpc/metadata"
"google.golang.org/grpc/status"
"google.golang.org/protobuf/encoding/protowire"
fieldmaskpb "google.golang.org/protobuf/types/known/fieldmaskpb"
)

Expand Down Expand Up @@ -902,12 +905,50 @@ func (c *grpcStorageClient) RewriteObject(ctx context.Context, req *rewriteObjec
return r, nil
}

// bytesCodec is a grpc codec which permits receiving messages as either
// protobuf messages, or as raw []bytes.
type bytesCodec struct {
encoding.Codec
}

func (bytesCodec) Marshal(v any) ([]byte, error) {
vv, ok := v.(proto.Message)
if !ok {
return nil, fmt.Errorf("failed to marshal, message is %T, want proto.Message", v)
}
return proto.Marshal(vv)
}

func (bytesCodec) Unmarshal(data []byte, v any) error {
switch v := v.(type) {
case *[]byte:
// If gRPC could recycle the data []byte after unmarshaling (through
// buffer pools), we would need to make a copy here.
*v = data
return nil
case proto.Message:
return proto.Unmarshal(data, v)
default:
return fmt.Errorf("can not unmarshal type %T", v)
}
}

func (bytesCodec) Name() string {
// If this isn't "", then gRPC sets the content-subtype of the call to this
// value and we get errors.
return ""
}

func (c *grpcStorageClient) NewRangeReader(ctx context.Context, params *newRangeReaderParams, opts ...storageOption) (r *Reader, err error) {
ctx = trace.StartSpan(ctx, "cloud.google.com/go/storage.grpcStorageClient.NewRangeReader")
defer func() { trace.EndSpan(ctx, err) }()

s := callSettings(c.settings, opts...)

s.gax = append(s.gax, gax.WithGRPCOptions(
grpc.ForceCodec(bytesCodec{}),
))

if s.userProject != "" {
ctx = setUserProjectMetadata(ctx, s.userProject)
}
Expand All @@ -923,6 +964,8 @@ func (c *grpcStorageClient) NewRangeReader(ctx context.Context, params *newRange
req.Generation = params.gen
}

var databuf []byte

// Define a function that initiates a Read with offset and length, assuming
// we have already read seen bytes.
reopen := func(seen int64) (*readStreamResponse, context.CancelFunc, error) {
Expand Down Expand Up @@ -957,12 +1000,23 @@ func (c *grpcStorageClient) NewRangeReader(ctx context.Context, params *newRange
return err
}

msg, err = stream.Recv()
// Receive the message into databuf as a wire-encoded message so we can
// use a custom decoder to avoid an extra copy at the protobuf layer.
err := stream.RecvMsg(&databuf)
// These types of errors show up on the Recv call, rather than the
// initialization of the stream via ReadObject above.
if s, ok := status.FromError(err); ok && s.Code() == codes.NotFound {
return ErrObjectNotExist
}
if err != nil {
return err
}
// Use a custom decoder that uses protobuf unmarshalling for all
// fields except the checksummed data.
// Subsequent receives in Read calls will skip all protobuf
// unmarshalling and directly read the content from the gRPC []byte
// response, since only the first call will contain other fields.
msg, err = readFullObjectResponse(databuf)

return err
}, s.retry, s.idempotent)
Expand Down Expand Up @@ -1008,6 +1062,7 @@ func (c *grpcStorageClient) NewRangeReader(ctx context.Context, params *newRange
leftovers: msg.GetChecksummedData().GetContent(),
settings: s,
zeroRange: params.length == 0,
databuf: databuf,
},
}

Expand Down Expand Up @@ -1406,6 +1461,7 @@ type gRPCReader struct {
stream storagepb.Storage_ReadObjectClient
reopen func(seen int64) (*readStreamResponse, context.CancelFunc, error)
leftovers []byte
databuf []byte
cancel context.CancelFunc
settings *settings
}
Expand Down Expand Up @@ -1436,7 +1492,7 @@ func (r *gRPCReader) Read(p []byte) (int, error) {
}

// Attempt to Recv the next message on the stream.
msg, err := r.recv()
content, err := r.recv()
if err != nil {
return 0, err
}
Expand All @@ -1448,7 +1504,6 @@ func (r *gRPCReader) Read(p []byte) (int, error) {
// present in the response here.
// TODO: Figure out if we need to support decompressive transcoding
// https://cloud.google.com/storage/docs/transcoding.
content := msg.GetChecksummedData().GetContent()
n = copy(p[n:], content)
leftover := len(content) - n
if leftover > 0 {
Expand All @@ -1471,18 +1526,20 @@ func (r *gRPCReader) Close() error {
return nil
}

// recv attempts to Recv the next message on the stream. In the event
// that a retryable error is encountered, the stream will be closed, reopened,
// and Recv again. This will attempt to Recv until one of the following is true:
// recv attempts to Recv the next message on the stream and extract the object
// data that it contains. In the event that a retryable error is encountered,
// the stream will be closed, reopened, and RecvMsg again.
// This will attempt to Recv until one of the following is true:
//
// * Recv is successful
// * A non-retryable error is encountered
// * The Reader's context is canceled
//
// The last error received is the one that is returned, which could be from
// an attempt to reopen the stream.
func (r *gRPCReader) recv() (*storagepb.ReadObjectResponse, error) {
msg, err := r.stream.Recv()
func (r *gRPCReader) recv() ([]byte, error) {
err := r.stream.RecvMsg(&r.databuf)

var shouldRetry = ShouldRetry
if r.settings.retry != nil && r.settings.retry.shouldRetry != nil {
shouldRetry = r.settings.retry.shouldRetry
Expand All @@ -1492,10 +1549,195 @@ func (r *gRPCReader) recv() (*storagepb.ReadObjectResponse, error) {
// reopen the stream, but will backoff if further attempts are necessary.
// Reopening the stream Recvs the first message, so if retrying is
// successful, the next logical chunk will be returned.
msg, err = r.reopenStream()
msg, err := r.reopenStream()
tritone marked this conversation as resolved.
Show resolved Hide resolved
return msg.GetChecksummedData().GetContent(), err
}

if err != nil {
return nil, err
}

return readObjectResponseContent(r.databuf)
}

// ReadObjectResponse field and subfield numbers.
const (
checksummedDataField = protowire.Number(1)
checksummedDataContentField = protowire.Number(1)
checksummedDataCRC32CField = protowire.Number(2)
objectChecksumsField = protowire.Number(2)
contentRangeField = protowire.Number(3)
metadataField = protowire.Number(4)
)

// readObjectResponseContent returns the checksummed_data.content field of a
// ReadObjectResponse message, or an error if the message is invalid.
// This can be used on recvs of objects after the first recv, since only the
// first message will contain non-data fields.
func readObjectResponseContent(b []byte) ([]byte, error) {
checksummedData, err := readProtoBytes(b, checksummedDataField)
if err != nil {
return b, fmt.Errorf("invalid ReadObjectResponse.ChecksummedData: %v", err)
}
content, err := readProtoBytes(checksummedData, checksummedDataContentField)
if err != nil {
return content, fmt.Errorf("invalid ReadObjectResponse.ChecksummedData.Content: %v", err)
}

return msg, err
return content, nil
}

// readFullObjectResponse returns the ReadObjectResponse that is encoded in the
// wire-encoded message buffer b, or an error if the message is invalid.
// This must be used on the first recv of an object as it may contain all fields
// of ReadObjectResponse, and we use or pass on those fields to the user.
// This function is essentially identical to proto.Unmarshal, except it aliases
// the data in the input []byte. If the proto library adds a feature to
// Unmarshal that does that, this function can be dropped.
func readFullObjectResponse(b []byte) (*storagepb.ReadObjectResponse, error) {
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Might want to note that this function is essentially identical to proto.Unmarshal, except it aliases the data in the input []byte. If we ever add a feature to Unmarshal that does that, this function can be dropped.

msg := &storagepb.ReadObjectResponse{}

// Loop over the entire message, extracting fields as we go. This does not
// handle field concatenation, in which the contents of a single field
// are split across multiple protobuf tags.
off := 0
for off < len(b) {
// Consume the next tag. This will tell us which field is next in the
// buffer, its type, and how much space it takes up.
fieldNum, fieldType, fieldLength := protowire.ConsumeTag(b[off:])
if fieldLength < 0 {
return nil, protowire.ParseError(fieldLength)
}
off += fieldLength

// Unmarshal the field according to its type. Only fields that are not
// nil will be present.
switch {
case fieldNum == checksummedDataField && fieldType == protowire.BytesType:
// The ChecksummedData field was found. Initialize the struct.
msg.ChecksummedData = &storagepb.ChecksummedData{}

// Get the bytes corresponding to the checksummed data.
fieldContent, n := protowire.ConsumeBytes(b[off:])
if n < 0 {
return nil, fmt.Errorf("invalid ReadObjectResponse.ChecksummedData: %v", protowire.ParseError(n))
}
off += n

// Get the nested fields. We need to do this manually as it contains
// the object content bytes.
contentOff := 0
for contentOff < len(fieldContent) {
gotNum, gotTyp, n := protowire.ConsumeTag(fieldContent[contentOff:])
if n < 0 {
return nil, protowire.ParseError(n)
}
contentOff += n

switch {
case gotNum == checksummedDataContentField && gotTyp == protowire.BytesType:
// Get the content bytes.
bytes, n := protowire.ConsumeBytes(fieldContent[contentOff:])
if n < 0 {
return nil, fmt.Errorf("invalid ReadObjectResponse.ChecksummedData.Content: %v", protowire.ParseError(n))
}
msg.ChecksummedData.Content = bytes
contentOff += n
case gotNum == checksummedDataCRC32CField && gotTyp == protowire.Fixed32Type:
v, n := protowire.ConsumeFixed32(fieldContent[contentOff:])
if n < 0 {
return nil, fmt.Errorf("invalid ReadObjectResponse.ChecksummedData.Crc32C: %v", protowire.ParseError(n))
}
msg.ChecksummedData.Crc32C = &v
contentOff += n
default:
n = protowire.ConsumeFieldValue(gotNum, gotTyp, fieldContent[contentOff:])
if n < 0 {
return nil, protowire.ParseError(n)
}
contentOff += n
}
}
case fieldNum == objectChecksumsField && fieldType == protowire.BytesType:
// The field was found. Initialize the struct.
msg.ObjectChecksums = &storagepb.ObjectChecksums{}

// Get the bytes corresponding to the checksums.
bytes, n := protowire.ConsumeBytes(b[off:])
if n < 0 {
return nil, fmt.Errorf("invalid ReadObjectResponse.ObjectChecksums: %v", protowire.ParseError(n))
}
off += n

// Unmarshal.
if err := proto.Unmarshal(bytes, msg.ObjectChecksums); err != nil {
return nil, err
}
case fieldNum == contentRangeField && fieldType == protowire.BytesType:
msg.ContentRange = &storagepb.ContentRange{}

bytes, n := protowire.ConsumeBytes(b[off:])
if n < 0 {
return nil, fmt.Errorf("invalid ReadObjectResponse.ContentRange: %v", protowire.ParseError(n))
}
off += n

if err := proto.Unmarshal(bytes, msg.ContentRange); err != nil {
return nil, err
}
case fieldNum == metadataField && fieldType == protowire.BytesType:
msg.Metadata = &storagepb.Object{}

bytes, n := protowire.ConsumeBytes(b[off:])
if n < 0 {
return nil, fmt.Errorf("invalid ReadObjectResponse.Metadata: %v", protowire.ParseError(n))
}
off += n

if err := proto.Unmarshal(bytes, msg.Metadata); err != nil {
return nil, err
}
default:
fieldLength = protowire.ConsumeFieldValue(fieldNum, fieldType, b[off:])
if fieldLength < 0 {
return nil, fmt.Errorf("default: %v", protowire.ParseError(fieldLength))
}
off += fieldLength
}
}

return msg, nil
}

// readProtoBytes returns the contents of the protobuf field with number num
// and type bytes from a wire-encoded message. If the field cannot be found,
// the returned slice will be nil and no error will be returned.
//
// It does not handle field concatenation, in which the contents of a single field
// are split across multiple protobuf tags. Encoded data containing split fields
// of this form is technically permissable, but uncommon.
func readProtoBytes(b []byte, num protowire.Number) ([]byte, error) {
off := 0
for off < len(b) {
gotNum, gotTyp, n := protowire.ConsumeTag(b[off:])
if n < 0 {
return nil, protowire.ParseError(n)
}
off += n
if gotNum == num && gotTyp == protowire.BytesType {
b, n := protowire.ConsumeBytes(b[off:])
if n < 0 {
return nil, protowire.ParseError(n)
}
return b, nil
}
n = protowire.ConsumeFieldValue(gotNum, gotTyp, b[off:])
if n < 0 {
return nil, protowire.ParseError(n)
}
off += n
}
return nil, nil
}

// reopenStream "closes" the existing stream and attempts to reopen a stream and
Expand Down