diff --git a/rpc_util.go b/rpc_util.go index b7723aa09cb..a4b6bc6873c 100644 --- a/rpc_util.go +++ b/rpc_util.go @@ -640,14 +640,18 @@ func encode(c baseCodec, msg any) ([]byte, error) { return b, nil } -// compress returns the input bytes compressed by compressor or cp. If both -// compressors are nil, returns nil. +// compress returns the input bytes compressed by compressor or cp. +// If both compressors are nil, or if the message has zero length, returns nil, +// indicating no compression was done. // // TODO(dfawley): eliminate cp parameter by wrapping Compressor in an encoding.Compressor. func compress(in []byte, cp Compressor, compressor encoding.Compressor) ([]byte, error) { if compressor == nil && cp == nil { return nil, nil } + if len(in) == 0 { + return nil, nil + } wrapErr := func(err error) error { return status.Errorf(codes.Internal, "grpc: error while compressing: %v", err.Error()) } diff --git a/test/compressor_test.go b/test/compressor_test.go index 91e21e5266e..a18d14f4ac7 100644 --- a/test/compressor_test.go +++ b/test/compressor_test.go @@ -290,6 +290,7 @@ func (s) TestSetSendCompressorSuccess(t *testing.T) { for _, tt := range []struct { name string desc string + payload *testpb.Payload dialOpts []grpc.DialOption resCompressor string wantCompressInvokes int32 @@ -297,12 +298,21 @@ func (s) TestSetSendCompressorSuccess(t *testing.T) { { name: "identity_request_and_gzip_response", desc: "request is uncompressed and response is gzip compressed", + payload: &testpb.Payload{Body: []byte("payload")}, resCompressor: "gzip", wantCompressInvokes: 1, }, + { + name: "identity_request_and_empty_response", + desc: "request is uncompressed and response is gzip compressed", + payload: nil, + resCompressor: "gzip", + wantCompressInvokes: 0, + }, { name: "gzip_request_and_identity_response", desc: "request is gzip compressed and response is uncompressed with identity", + payload: &testpb.Payload{Body: []byte("payload")}, resCompressor: "identity", dialOpts: []grpc.DialOption{ // Use WithCompressor instead of UseCompressor to avoid counting @@ -314,24 +324,26 @@ func (s) TestSetSendCompressorSuccess(t *testing.T) { } { t.Run(tt.name, func(t *testing.T) { t.Run("unary", func(t *testing.T) { - testUnarySetSendCompressorSuccess(t, tt.resCompressor, tt.wantCompressInvokes, tt.dialOpts) + testUnarySetSendCompressorSuccess(t, tt.payload, tt.resCompressor, tt.wantCompressInvokes, tt.dialOpts) }) t.Run("stream", func(t *testing.T) { - testStreamSetSendCompressorSuccess(t, tt.resCompressor, tt.wantCompressInvokes, tt.dialOpts) + testStreamSetSendCompressorSuccess(t, tt.payload, tt.resCompressor, tt.wantCompressInvokes, tt.dialOpts) }) }) } } -func testUnarySetSendCompressorSuccess(t *testing.T, resCompressor string, wantCompressInvokes int32, dialOpts []grpc.DialOption) { +func testUnarySetSendCompressorSuccess(t *testing.T, payload *testpb.Payload, resCompressor string, wantCompressInvokes int32, dialOpts []grpc.DialOption) { wc := setupGzipWrapCompressor(t) ss := &stubserver.StubServer{ - EmptyCallF: func(ctx context.Context, in *testpb.Empty) (*testpb.Empty, error) { + UnaryCallF: func(ctx context.Context, in *testpb.SimpleRequest) (*testpb.SimpleResponse, error) { if err := grpc.SetSendCompressor(ctx, resCompressor); err != nil { return nil, err } - return &testpb.Empty{}, nil + return &testpb.SimpleResponse{ + Payload: payload, + }, nil }, } if err := ss.Start(nil, dialOpts...); err != nil { @@ -342,7 +354,7 @@ func testUnarySetSendCompressorSuccess(t *testing.T, resCompressor string, wantC ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) defer cancel() - if _, err := ss.Client.EmptyCall(ctx, &testpb.Empty{}); err != nil { + if _, err := ss.Client.UnaryCall(ctx, &testpb.SimpleRequest{}); err != nil { t.Fatalf("Unexpected unary call error, got: %v, want: nil", err) } @@ -352,7 +364,7 @@ func testUnarySetSendCompressorSuccess(t *testing.T, resCompressor string, wantC } } -func testStreamSetSendCompressorSuccess(t *testing.T, resCompressor string, wantCompressInvokes int32, dialOpts []grpc.DialOption) { +func testStreamSetSendCompressorSuccess(t *testing.T, payload *testpb.Payload, resCompressor string, wantCompressInvokes int32, dialOpts []grpc.DialOption) { wc := setupGzipWrapCompressor(t) ss := &stubserver.StubServer{ FullDuplexCallF: func(stream testgrpc.TestService_FullDuplexCallServer) error { @@ -364,7 +376,9 @@ func testStreamSetSendCompressorSuccess(t *testing.T, resCompressor string, want return err } - return stream.Send(&testpb.StreamingOutputCallResponse{}) + return stream.Send(&testpb.StreamingOutputCallResponse{ + Payload: payload, + }) }, } if err := ss.Start(nil, dialOpts...); err != nil {