From d076e14b4849f4262f6f50042a9370ec5ce0116d Mon Sep 17 00:00:00 2001 From: Jaewan Park Date: Sat, 24 Feb 2024 05:49:17 +0900 Subject: [PATCH] rpc_util: Fix RecvBufferPool deactivation issues (#6766) --- experimental/shared_buffer_pool_test.go | 196 ++++++++++++++++++------ rpc_util.go | 54 ++++--- server.go | 5 +- 3 files changed, 188 insertions(+), 67 deletions(-) diff --git a/experimental/shared_buffer_pool_test.go b/experimental/shared_buffer_pool_test.go index c13b2dc0221..7c4074e18bd 100644 --- a/experimental/shared_buffer_pool_test.go +++ b/experimental/shared_buffer_pool_test.go @@ -26,12 +26,12 @@ import ( "time" "google.golang.org/grpc" + "google.golang.org/grpc/encoding/gzip" "google.golang.org/grpc/experimental" "google.golang.org/grpc/internal/grpctest" "google.golang.org/grpc/internal/stubserver" testgrpc "google.golang.org/grpc/interop/grpc_testing" - testpb "google.golang.org/grpc/interop/grpc_testing" ) type s struct { @@ -44,59 +44,161 @@ func Test(t *testing.T) { const defaultTestTimeout = 10 * time.Second -func (s) TestRecvBufferPool(t *testing.T) { - ss := &stubserver.StubServer{ - FullDuplexCallF: func(stream testgrpc.TestService_FullDuplexCallServer) error { - for i := 0; i < 10; i++ { - preparedMsg := &grpc.PreparedMsg{} - err := preparedMsg.Encode(stream, &testpb.StreamingOutputCallResponse{ - Payload: &testpb.Payload{ - Body: []byte{'0' + uint8(i)}, - }, - }) +func (s) TestRecvBufferPoolStream(t *testing.T) { + tcs := []struct { + name string + callOpts []grpc.CallOption + }{ + { + name: "default", + }, + { + name: "useCompressor", + callOpts: []grpc.CallOption{ + grpc.UseCompressor(gzip.Name), + }, + }, + } + + for _, tc := range tcs { + t.Run(tc.name, func(t *testing.T) { + const reqCount = 10 + + ss := &stubserver.StubServer{ + FullDuplexCallF: func(stream testgrpc.TestService_FullDuplexCallServer) error { + for i := 0; i < reqCount; i++ { + preparedMsg := &grpc.PreparedMsg{} + if err := preparedMsg.Encode(stream, &testgrpc.StreamingOutputCallResponse{ + Payload: &testgrpc.Payload{ + Body: []byte{'0' + uint8(i)}, + }, + }); err != nil { + return err + } + stream.SendMsg(preparedMsg) + } + return nil + }, + } + + pool := &checkBufferPool{} + sopts := []grpc.ServerOption{experimental.RecvBufferPool(pool)} + dopts := []grpc.DialOption{experimental.WithRecvBufferPool(pool)} + if err := ss.Start(sopts, dopts...); err != nil { + t.Fatalf("Error starting endpoint server: %v", err) + } + defer ss.Stop() + + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() + + stream, err := ss.Client.FullDuplexCall(ctx, tc.callOpts...) + if err != nil { + t.Fatalf("ss.Client.FullDuplexCall failed: %v", err) + } + + var ngot int + var buf bytes.Buffer + for { + reply, err := stream.Recv() + if err == io.EOF { + break + } if err != nil { - return err + t.Fatal(err) } - stream.SendMsg(preparedMsg) + ngot++ + if buf.Len() > 0 { + buf.WriteByte(',') + } + buf.Write(reply.GetPayload().GetBody()) } - return nil - }, + if want := 10; ngot != want { + t.Fatalf("Got %d replies, want %d", ngot, want) + } + if got, want := buf.String(), "0,1,2,3,4,5,6,7,8,9"; got != want { + t.Fatalf("Got replies %q; want %q", got, want) + } + + if len(pool.puts) != reqCount { + t.Fatalf("Expected 10 buffers to be returned to the pool, got %d", len(pool.puts)) + } + }) } - sopts := []grpc.ServerOption{experimental.RecvBufferPool(grpc.NewSharedBufferPool())} - dopts := []grpc.DialOption{experimental.WithRecvBufferPool(grpc.NewSharedBufferPool())} - if err := ss.Start(sopts, dopts...); err != nil { - t.Fatalf("Error starting endpoint server: %v", err) +} + +func (s) TestRecvBufferPoolUnary(t *testing.T) { + tcs := []struct { + name string + callOpts []grpc.CallOption + }{ + { + name: "default", + }, + { + name: "useCompressor", + callOpts: []grpc.CallOption{ + grpc.UseCompressor(gzip.Name), + }, + }, } - defer ss.Stop() - ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) - defer cancel() + for _, tc := range tcs { + t.Run(tc.name, func(t *testing.T) { + const largeSize = 1024 - stream, err := ss.Client.FullDuplexCall(ctx) - if err != nil { - t.Fatalf("ss.Client.FullDuplexCall failed: %f", err) - } + ss := &stubserver.StubServer{ + UnaryCallF: func(ctx context.Context, in *testgrpc.SimpleRequest) (*testgrpc.SimpleResponse, error) { + return &testgrpc.SimpleResponse{ + Payload: &testgrpc.Payload{ + Body: make([]byte, largeSize), + }, + }, nil + }, + } - var ngot int - var buf bytes.Buffer - for { - reply, err := stream.Recv() - if err == io.EOF { - break - } - if err != nil { - t.Fatal(err) - } - ngot++ - if buf.Len() > 0 { - buf.WriteByte(',') - } - buf.Write(reply.GetPayload().GetBody()) - } - if want := 10; ngot != want { - t.Errorf("Got %d replies, want %d", ngot, want) - } - if got, want := buf.String(), "0,1,2,3,4,5,6,7,8,9"; got != want { - t.Errorf("Got replies %q; want %q", got, want) + pool := &checkBufferPool{} + sopts := []grpc.ServerOption{experimental.RecvBufferPool(pool)} + dopts := []grpc.DialOption{experimental.WithRecvBufferPool(pool)} + if err := ss.Start(sopts, dopts...); err != nil { + t.Fatalf("Error starting endpoint server: %v", err) + } + defer ss.Stop() + + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() + + const reqCount = 10 + for i := 0; i < reqCount; i++ { + if _, err := ss.Client.UnaryCall( + ctx, + &testgrpc.SimpleRequest{ + Payload: &testgrpc.Payload{ + Body: make([]byte, largeSize), + }, + }, + tc.callOpts..., + ); err != nil { + t.Fatalf("ss.Client.UnaryCall failed: %v", err) + } + } + + const bufferCount = reqCount * 2 // req + resp + if len(pool.puts) != bufferCount { + t.Fatalf("Expected %d buffers to be returned to the pool, got %d", bufferCount, len(pool.puts)) + } + }) } } + +type checkBufferPool struct { + puts [][]byte +} + +func (p *checkBufferPool) Get(size int) []byte { + return make([]byte, size) +} + +func (p *checkBufferPool) Put(bs *[]byte) { + p.puts = append(p.puts, *bs) +} diff --git a/rpc_util.go b/rpc_util.go index d17ede0fa43..82493d237bc 100644 --- a/rpc_util.go +++ b/rpc_util.go @@ -744,17 +744,19 @@ type payloadInfo struct { uncompressedBytes []byte } -func recvAndDecompress(p *parser, s *transport.Stream, dc Decompressor, maxReceiveMessageSize int, payInfo *payloadInfo, compressor encoding.Compressor) ([]byte, error) { - pf, buf, err := p.recvMsg(maxReceiveMessageSize) +// recvAndDecompress reads a message from the stream, decompressing it if necessary. +// +// Cancelling the returned cancel function releases the buffer back to the pool. So the caller should cancel as soon as +// the buffer is no longer needed. +func recvAndDecompress(p *parser, s *transport.Stream, dc Decompressor, maxReceiveMessageSize int, payInfo *payloadInfo, compressor encoding.Compressor, +) (uncompressedBuf []byte, cancel func(), err error) { + pf, compressedBuf, err := p.recvMsg(maxReceiveMessageSize) if err != nil { - return nil, err - } - if payInfo != nil { - payInfo.compressedLength = len(buf) + return nil, nil, err } if st := checkRecvPayload(pf, s.RecvCompress(), compressor != nil || dc != nil); st != nil { - return nil, st.Err() + return nil, nil, st.Err() } var size int @@ -762,21 +764,35 @@ func recvAndDecompress(p *parser, s *transport.Stream, dc Decompressor, maxRecei // To match legacy behavior, if the decompressor is set by WithDecompressor or RPCDecompressor, // use this decompressor as the default. if dc != nil { - buf, err = dc.Do(bytes.NewReader(buf)) - size = len(buf) + uncompressedBuf, err = dc.Do(bytes.NewReader(compressedBuf)) + size = len(uncompressedBuf) } else { - buf, size, err = decompress(compressor, buf, maxReceiveMessageSize) + uncompressedBuf, size, err = decompress(compressor, compressedBuf, maxReceiveMessageSize) } if err != nil { - return nil, status.Errorf(codes.Internal, "grpc: failed to decompress the received message: %v", err) + return nil, nil, status.Errorf(codes.Internal, "grpc: failed to decompress the received message: %v", err) } if size > maxReceiveMessageSize { // TODO: Revisit the error code. Currently keep it consistent with java // implementation. - return nil, status.Errorf(codes.ResourceExhausted, "grpc: received message after decompression larger than max (%d vs. %d)", size, maxReceiveMessageSize) + return nil, nil, status.Errorf(codes.ResourceExhausted, "grpc: received message after decompression larger than max (%d vs. %d)", size, maxReceiveMessageSize) } + } else { + uncompressedBuf = compressedBuf } - return buf, nil + + if payInfo != nil { + payInfo.compressedLength = len(compressedBuf) + payInfo.uncompressedBytes = uncompressedBuf + + cancel = func() {} + } else { + cancel = func() { + p.recvBufferPool.Put(&compressedBuf) + } + } + + return uncompressedBuf, cancel, nil } // Using compressor, decompress d, returning data and size. @@ -796,6 +812,9 @@ func decompress(compressor encoding.Compressor, d []byte, maxReceiveMessageSize // size is used as an estimate to size the buffer, but we // will read more data if available. // +MinRead so ReadFrom will not reallocate if size is correct. + // + // TODO: If we ensure that the buffer size is the same as the DecompressedSize, + // we can also utilize the recv buffer pool here. buf := bytes.NewBuffer(make([]byte, 0, size+bytes.MinRead)) bytesRead, err := buf.ReadFrom(io.LimitReader(dcReader, int64(maxReceiveMessageSize)+1)) return buf.Bytes(), int(bytesRead), err @@ -811,18 +830,15 @@ func decompress(compressor encoding.Compressor, d []byte, maxReceiveMessageSize // dc takes precedence over compressor. // TODO(dfawley): wrap the old compressor/decompressor using the new API? func recv(p *parser, c baseCodec, s *transport.Stream, dc Decompressor, m any, maxReceiveMessageSize int, payInfo *payloadInfo, compressor encoding.Compressor) error { - buf, err := recvAndDecompress(p, s, dc, maxReceiveMessageSize, payInfo, compressor) + buf, cancel, err := recvAndDecompress(p, s, dc, maxReceiveMessageSize, payInfo, compressor) if err != nil { return err } + defer cancel() + if err := c.Unmarshal(buf, m); err != nil { return status.Errorf(codes.Internal, "grpc: failed to unmarshal the received message: %v", err) } - if payInfo != nil { - payInfo.uncompressedBytes = buf - } else { - p.recvBufferPool.Put(&buf) - } return nil } diff --git a/server.go b/server.go index 0bf5c78b0dd..155a512bc3e 100644 --- a/server.go +++ b/server.go @@ -1342,7 +1342,8 @@ func (s *Server) processUnaryRPC(ctx context.Context, t transport.ServerTranspor if len(shs) != 0 || len(binlogs) != 0 { payInfo = &payloadInfo{} } - d, err := recvAndDecompress(&parser{r: stream, recvBufferPool: s.opts.recvBufferPool}, stream, dc, s.opts.maxReceiveMessageSize, payInfo, decomp) + + d, cancel, err := recvAndDecompress(&parser{r: stream, recvBufferPool: s.opts.recvBufferPool}, stream, dc, s.opts.maxReceiveMessageSize, payInfo, decomp) if err != nil { if e := t.WriteStatus(stream, status.Convert(err)); e != nil { channelz.Warningf(logger, s.channelzID, "grpc: Server.processUnaryRPC failed to write status: %v", e) @@ -1353,6 +1354,8 @@ func (s *Server) processUnaryRPC(ctx context.Context, t transport.ServerTranspor t.IncrMsgRecv() } df := func(v any) error { + defer cancel() + if err := s.getCodec(stream.ContentSubtype()).Unmarshal(d, v); err != nil { return status.Errorf(codes.Internal, "grpc: error unmarshalling request: %v", err) }