diff --git a/pkg/remote/codec/grpc/grpc_compress.go b/pkg/remote/codec/grpc/grpc_compress.go index 3d257bb9ef..a8fe32a0a9 100644 --- a/pkg/remote/codec/grpc/grpc_compress.go +++ b/pkg/remote/codec/grpc/grpc_compress.go @@ -23,10 +23,10 @@ import ( "errors" "io" - "github.com/cloudwego/kitex/pkg/rpcinfo" - "github.com/bytedance/gopkg/lang/mcache" + "github.com/cloudwego/kitex/pkg/rpcinfo" + "github.com/cloudwego/kitex/pkg/remote/codec/protobuf/encoding" "github.com/cloudwego/kitex/pkg/remote" @@ -63,7 +63,9 @@ func decodeGRPCFrame(ctx context.Context, in remote.ByteBuffer) ([]byte, error) } func compress(compressor encoding.Compressor, data []byte) ([]byte, error) { - defer mcache.Free(data) + if len(data) != 0 { + defer mcache.Free(data) + } cbuf := &bytes.Buffer{} z, err := compressor.Compress(cbuf) if err != nil { diff --git a/pkg/remote/codec/protobuf/protobuf.go b/pkg/remote/codec/protobuf/protobuf.go index 3745184d02..a5a11e0aff 100644 --- a/pkg/remote/codec/protobuf/protobuf.go +++ b/pkg/remote/codec/protobuf/protobuf.go @@ -26,6 +26,7 @@ import ( "github.com/cloudwego/kitex/pkg/remote" "github.com/cloudwego/kitex/pkg/remote/codec" "github.com/cloudwego/kitex/pkg/remote/codec/perrors" + "github.com/cloudwego/kitex/pkg/serviceinfo" ) /** @@ -50,6 +51,12 @@ func NewProtobufCodec() remote.PayloadCodec { return &protobufCodec{} } +// IsProtobufCodec checks if the codec is protobufCodec +func IsProtobufCodec(c remote.PayloadCodec) bool { + _, ok := c.(*protobufCodec) + return ok +} + // protobufCodec implements PayloadMarshaler type protobufCodec struct{} @@ -210,7 +217,7 @@ func (c protobufCodec) Unmarshal(ctx context.Context, message remote.Message, in } func (c protobufCodec) Name() string { - return "protobuf" + return serviceinfo.Protobuf.String() } // MessageWriterWithContext writes to output bytebuffer diff --git a/pkg/remote/codec/thrift/thrift.go b/pkg/remote/codec/thrift/thrift.go index 3ec0ed0301..eb9771e965 100644 --- a/pkg/remote/codec/thrift/thrift.go +++ b/pkg/remote/codec/thrift/thrift.go @@ -28,6 +28,7 @@ import ( "github.com/cloudwego/kitex/pkg/remote/codec" "github.com/cloudwego/kitex/pkg/remote/codec/perrors" "github.com/cloudwego/kitex/pkg/rpcinfo" + "github.com/cloudwego/kitex/pkg/serviceinfo" "github.com/cloudwego/kitex/pkg/stats" ) @@ -231,7 +232,7 @@ func validateMessageBeforeDecode(message remote.Message, seqID int32, methodName // Name implements the remote.PayloadCodec interface. func (c thriftCodec) Name() string { - return "thrift" + return serviceinfo.Thrift.String() } // MessageWriterWithContext write to thrift.TProtocol diff --git a/server/option.go b/server/option.go index 59e587d7cd..a9c67fc37d 100644 --- a/server/option.go +++ b/server/option.go @@ -31,9 +31,12 @@ import ( "github.com/cloudwego/kitex/pkg/limiter" "github.com/cloudwego/kitex/pkg/registry" "github.com/cloudwego/kitex/pkg/remote" + "github.com/cloudwego/kitex/pkg/remote/codec/protobuf" + "github.com/cloudwego/kitex/pkg/remote/codec/thrift" "github.com/cloudwego/kitex/pkg/remote/trans/netpollmux" "github.com/cloudwego/kitex/pkg/remote/trans/nphttp2/grpc" "github.com/cloudwego/kitex/pkg/rpcinfo" + "github.com/cloudwego/kitex/pkg/serviceinfo" "github.com/cloudwego/kitex/pkg/stats" "github.com/cloudwego/kitex/pkg/streaming" "github.com/cloudwego/kitex/pkg/utils" @@ -207,9 +210,19 @@ func WithCodec(c remote.Codec) Option { // WithPayloadCodec to set a payloadCodec that handle other payload which not support by kitex func WithPayloadCodec(c remote.PayloadCodec) Option { return Option{F: func(o *internal_server.Options, di *utils.Slice) { - di.Push(fmt.Sprintf("WithPayloadCodec(%+v)", c)) - - o.RemoteOpt.PayloadCodec = c + if thrift.IsThriftCodec(c) { + // default thriftCodec has been registered, + // if using NewThriftCodecWithConfig to set codec mode, just replace the registered one + di.Push(fmt.Sprintf("ResetThriftPayloadCodec(%+v)", c)) + remote.PutPayloadCode(serviceinfo.Thrift, c) + } else if protobuf.IsProtobufCodec(c) { + di.Push(fmt.Sprintf("ResetProtobufPayloadCodec(%+v)", c)) + remote.PutPayloadCode(serviceinfo.Protobuf, c) + } else { + di.Push(fmt.Sprintf("WithPayloadCodec(%+v)", c)) + // if specify RemoteOpt.PayloadCodec, then the priority is highest, all payload decode will use this one + o.RemoteOpt.PayloadCodec = c + } }} } diff --git a/server/option_test.go b/server/option_test.go index 5ccb28b16c..1a9609b32e 100644 --- a/server/option_test.go +++ b/server/option_test.go @@ -20,6 +20,7 @@ import ( "context" "math/rand" "net" + "reflect" "testing" "time" @@ -27,8 +28,10 @@ import ( "github.com/cloudwego/kitex/internal/test" "github.com/cloudwego/kitex/pkg/diagnosis" "github.com/cloudwego/kitex/pkg/endpoint" + "github.com/cloudwego/kitex/pkg/generic" "github.com/cloudwego/kitex/pkg/remote" "github.com/cloudwego/kitex/pkg/remote/codec/protobuf" + "github.com/cloudwego/kitex/pkg/remote/codec/thrift" "github.com/cloudwego/kitex/pkg/remote/trans/detection" "github.com/cloudwego/kitex/pkg/remote/trans/netpoll" "github.com/cloudwego/kitex/pkg/remote/trans/netpollmux" @@ -38,6 +41,7 @@ import ( "github.com/cloudwego/kitex/pkg/serviceinfo" "github.com/cloudwego/kitex/pkg/stats" "github.com/cloudwego/kitex/pkg/utils" + "github.com/cloudwego/kitex/transport" ) // TestOptionDebugInfo tests the creation of a server with DebugService option @@ -307,30 +311,115 @@ func TestMuxTransportOption(t *testing.T) { // TestPayloadCodecOption tests the creation of a server with RemoteOpt.PayloadCodec option func TestPayloadCodecOption(t *testing.T) { - svr1 := NewServer() - time.AfterFunc(100*time.Millisecond, func() { - err := svr1.Stop() + t.Run("NotSetPayloadCodec", func(t *testing.T) { + svr := NewServer() + time.AfterFunc(100*time.Millisecond, func() { + err := svr.Stop() + test.Assert(t, err == nil, err) + }) + err := svr.RegisterService(mocks.ServiceInfo(), mocks.MyServiceHandler()) + test.Assert(t, err == nil, err) + err = svr.Run() test.Assert(t, err == nil, err) + iSvr := svr.(*server) + test.Assert(t, iSvr.opt.RemoteOpt.PayloadCodec == nil) + + tRecvMsg := NewRemoteMsgWithPayloadType(serviceinfo.Thrift) + tRecvMsg.SetPayloadCodec(iSvr.opt.RemoteOpt.PayloadCodec) + pc, err := remote.GetPayloadCodec(tRecvMsg) + test.Assert(t, err == nil) + test.Assert(t, thrift.IsThriftCodec(pc)) + + pRecvMsg := NewRemoteMsgWithPayloadType(serviceinfo.Protobuf) + pRecvMsg.SetPayloadCodec(iSvr.opt.RemoteOpt.PayloadCodec) + pc, err = remote.GetPayloadCodec(pRecvMsg) + test.Assert(t, err == nil) + test.Assert(t, protobuf.IsProtobufCodec(pc)) + }) + t.Run("SetPreRegisteredProtobufCodec", func(t *testing.T) { + svr := NewServer(WithPayloadCodec(protobuf.NewProtobufCodec())) + time.AfterFunc(100*time.Millisecond, func() { + err := svr.Stop() + test.Assert(t, err == nil, err) + }) + err := svr.RegisterService(mocks.ServiceInfo(), mocks.MyServiceHandler()) + test.Assert(t, err == nil, err) + err = svr.Run() + test.Assert(t, err == nil, err) + iSvr := svr.(*server) + test.Assert(t, iSvr.opt.RemoteOpt.PayloadCodec == nil) + + tRecvMsg := NewRemoteMsgWithPayloadType(serviceinfo.Thrift) + tRecvMsg.SetPayloadCodec(iSvr.opt.RemoteOpt.PayloadCodec) + pc, err := remote.GetPayloadCodec(tRecvMsg) + test.Assert(t, err == nil) + test.Assert(t, thrift.IsThriftCodec(pc)) + + pRecvMsg := NewRemoteMsgWithPayloadType(serviceinfo.Protobuf) + pRecvMsg.SetPayloadCodec(iSvr.opt.RemoteOpt.PayloadCodec) + pc, err = remote.GetPayloadCodec(pRecvMsg) + test.Assert(t, err == nil) + test.Assert(t, protobuf.IsProtobufCodec(pc)) }) - err := svr1.RegisterService(mocks.ServiceInfo(), mocks.MyServiceHandler()) - test.Assert(t, err == nil, err) - err = svr1.Run() - test.Assert(t, err == nil, err) - iSvr1 := svr1.(*server) - test.Assert(t, iSvr1.opt.RemoteOpt.PayloadCodec == nil) - svr2 := NewServer(WithPayloadCodec(protobuf.NewProtobufCodec())) - time.AfterFunc(100*time.Millisecond, func() { - err := svr2.Stop() + t.Run("SetPreRegisteredThriftCodec", func(t *testing.T) { + thriftCodec := thrift.NewThriftCodecDisableFastMode(false, true) + svr := NewServer(WithPayloadCodec(thriftCodec)) + time.AfterFunc(100*time.Millisecond, func() { + err := svr.Stop() + test.Assert(t, err == nil, err) + }) + err := svr.RegisterService(mocks.ServiceInfo(), mocks.MyServiceHandler()) + test.Assert(t, err == nil, err) + err = svr.Run() test.Assert(t, err == nil, err) + iSvr := svr.(*server) + test.Assert(t, iSvr.opt.RemoteOpt.PayloadCodec == nil) + + tRecvMsg := NewRemoteMsgWithPayloadType(serviceinfo.Thrift) + tRecvMsg.SetPayloadCodec(iSvr.opt.RemoteOpt.PayloadCodec) + pc, err := remote.GetPayloadCodec(tRecvMsg) + test.Assert(t, err == nil) + test.Assert(t, thrift.IsThriftCodec(pc)) + test.Assert(t, reflect.DeepEqual(pc, thriftCodec)) + + pRecvMsg := NewRemoteMsgWithPayloadType(serviceinfo.Protobuf) + pRecvMsg.SetPayloadCodec(iSvr.opt.RemoteOpt.PayloadCodec) + pc, err = remote.GetPayloadCodec(pRecvMsg) + test.Assert(t, err == nil) + test.Assert(t, protobuf.IsProtobufCodec(pc)) + }) + + t.Run("SetNonPreRegisteredCodec", func(t *testing.T) { + // generic.BinaryThriftGeneric().PayloadCodec() is not the pre registered codec, RemoteOpt.PayloadCodec won't be nil + binaryThriftCodec := generic.BinaryThriftGeneric().PayloadCodec() + svr := NewServer(WithPayloadCodec(binaryThriftCodec)) + time.AfterFunc(100*time.Millisecond, func() { + err := svr.Stop() + test.Assert(t, err == nil, err) + }) + err := svr.RegisterService(mocks.ServiceInfo(), mocks.MyServiceHandler()) + test.Assert(t, err == nil, err) + err = svr.Run() + test.Assert(t, err == nil, err) + iSvr := svr.(*server) + test.Assert(t, iSvr.opt.RemoteOpt.PayloadCodec != nil) + test.DeepEqual(t, iSvr.opt.RemoteOpt.PayloadCodec, binaryThriftCodec) + + tRecvMsg := NewRemoteMsgWithPayloadType(serviceinfo.Thrift) + tRecvMsg.SetPayloadCodec(iSvr.opt.RemoteOpt.PayloadCodec) + pc, err := remote.GetPayloadCodec(tRecvMsg) + test.Assert(t, err == nil) + test.Assert(t, !thrift.IsThriftCodec(pc)) + test.Assert(t, reflect.DeepEqual(pc, binaryThriftCodec)) + + pRecvMsg := NewRemoteMsgWithPayloadType(serviceinfo.Protobuf) + pRecvMsg.SetPayloadCodec(iSvr.opt.RemoteOpt.PayloadCodec) + pc, err = remote.GetPayloadCodec(pRecvMsg) + test.Assert(t, err == nil) + test.Assert(t, !protobuf.IsProtobufCodec(pc)) + test.Assert(t, reflect.DeepEqual(pc, binaryThriftCodec)) }) - err = svr2.RegisterService(mocks.ServiceInfo(), mocks.MyServiceHandler()) - test.Assert(t, err == nil, err) - err = svr2.Run() - test.Assert(t, err == nil, err) - iSvr2 := svr2.(*server) - test.Assert(t, iSvr2.opt.RemoteOpt.PayloadCodec != nil) - test.DeepEqual(t, iSvr2.opt.RemoteOpt.PayloadCodec, protobuf.NewProtobufCodec()) } // TestRemoteOptGRPCCfgUintValueOption tests the creation of a server with RemoteOpt.GRPCCfg option @@ -474,3 +563,9 @@ func TestRefuseTrafficWithoutServiceNamOption(t *testing.T) { iSvr := svr.(*server) test.Assert(t, iSvr.opt.RefuseTrafficWithoutServiceName) } + +func NewRemoteMsgWithPayloadType(ct serviceinfo.PayloadCodec) remote.Message { + remoteMsg := remote.NewMessage(nil, nil, nil, remote.Call, remote.Server) + remoteMsg.SetProtocolInfo(remote.NewProtocolInfo(transport.TTHeader, ct)) + return remoteMsg +}