Skip to content

Commit

Permalink
Merge branch 'develop' into feat/retry
Browse files Browse the repository at this point in the history
  • Loading branch information
whalecold committed Jun 3, 2024
2 parents 44766d0 + f348aa4 commit e59cb66
Show file tree
Hide file tree
Showing 5 changed files with 145 additions and 27 deletions.
8 changes: 5 additions & 3 deletions pkg/remote/codec/grpc/grpc_compress.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,10 @@ import (
"errors"
"io"

"github.com/cloudwego/kitex/pkg/rpcinfo"

"github.com/bytedance/gopkg/lang/mcache"

"github.com/cloudwego/kitex/pkg/rpcinfo"

"github.com/cloudwego/kitex/pkg/remote/codec/protobuf/encoding"

"github.com/cloudwego/kitex/pkg/remote"
Expand Down Expand Up @@ -63,7 +63,9 @@ func decodeGRPCFrame(ctx context.Context, in remote.ByteBuffer) ([]byte, error)
}

func compress(compressor encoding.Compressor, data []byte) ([]byte, error) {
defer mcache.Free(data)
if len(data) != 0 {
defer mcache.Free(data)
}
cbuf := &bytes.Buffer{}
z, err := compressor.Compress(cbuf)
if err != nil {
Expand Down
9 changes: 8 additions & 1 deletion pkg/remote/codec/protobuf/protobuf.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ import (
"github.com/cloudwego/kitex/pkg/remote"
"github.com/cloudwego/kitex/pkg/remote/codec"
"github.com/cloudwego/kitex/pkg/remote/codec/perrors"
"github.com/cloudwego/kitex/pkg/serviceinfo"
)

/**
Expand All @@ -50,6 +51,12 @@ func NewProtobufCodec() remote.PayloadCodec {
return &protobufCodec{}
}

// IsProtobufCodec checks if the codec is protobufCodec
func IsProtobufCodec(c remote.PayloadCodec) bool {
_, ok := c.(*protobufCodec)
return ok
}

// protobufCodec implements PayloadMarshaler
type protobufCodec struct{}

Expand Down Expand Up @@ -210,7 +217,7 @@ func (c protobufCodec) Unmarshal(ctx context.Context, message remote.Message, in
}

func (c protobufCodec) Name() string {
return "protobuf"
return serviceinfo.Protobuf.String()
}

// MessageWriterWithContext writes to output bytebuffer
Expand Down
3 changes: 2 additions & 1 deletion pkg/remote/codec/thrift/thrift.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ import (
"github.com/cloudwego/kitex/pkg/remote/codec"
"github.com/cloudwego/kitex/pkg/remote/codec/perrors"
"github.com/cloudwego/kitex/pkg/rpcinfo"
"github.com/cloudwego/kitex/pkg/serviceinfo"
"github.com/cloudwego/kitex/pkg/stats"
)

Expand Down Expand Up @@ -231,7 +232,7 @@ func validateMessageBeforeDecode(message remote.Message, seqID int32, methodName

// Name implements the remote.PayloadCodec interface.
func (c thriftCodec) Name() string {
return "thrift"
return serviceinfo.Thrift.String()
}

// MessageWriterWithContext write to thrift.TProtocol
Expand Down
19 changes: 16 additions & 3 deletions server/option.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,9 +31,12 @@ import (
"github.com/cloudwego/kitex/pkg/limiter"
"github.com/cloudwego/kitex/pkg/registry"
"github.com/cloudwego/kitex/pkg/remote"
"github.com/cloudwego/kitex/pkg/remote/codec/protobuf"
"github.com/cloudwego/kitex/pkg/remote/codec/thrift"
"github.com/cloudwego/kitex/pkg/remote/trans/netpollmux"
"github.com/cloudwego/kitex/pkg/remote/trans/nphttp2/grpc"
"github.com/cloudwego/kitex/pkg/rpcinfo"
"github.com/cloudwego/kitex/pkg/serviceinfo"
"github.com/cloudwego/kitex/pkg/stats"
"github.com/cloudwego/kitex/pkg/streaming"
"github.com/cloudwego/kitex/pkg/utils"
Expand Down Expand Up @@ -207,9 +210,19 @@ func WithCodec(c remote.Codec) Option {
// WithPayloadCodec to set a payloadCodec that handle other payload which not support by kitex
func WithPayloadCodec(c remote.PayloadCodec) Option {
return Option{F: func(o *internal_server.Options, di *utils.Slice) {
di.Push(fmt.Sprintf("WithPayloadCodec(%+v)", c))

o.RemoteOpt.PayloadCodec = c
if thrift.IsThriftCodec(c) {
// default thriftCodec has been registered,
// if using NewThriftCodecWithConfig to set codec mode, just replace the registered one
di.Push(fmt.Sprintf("ResetThriftPayloadCodec(%+v)", c))
remote.PutPayloadCode(serviceinfo.Thrift, c)
} else if protobuf.IsProtobufCodec(c) {
di.Push(fmt.Sprintf("ResetProtobufPayloadCodec(%+v)", c))
remote.PutPayloadCode(serviceinfo.Protobuf, c)
} else {
di.Push(fmt.Sprintf("WithPayloadCodec(%+v)", c))
// if specify RemoteOpt.PayloadCodec, then the priority is highest, all payload decode will use this one
o.RemoteOpt.PayloadCodec = c
}
}}
}

Expand Down
133 changes: 114 additions & 19 deletions server/option_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,15 +20,18 @@ import (
"context"
"math/rand"
"net"
"reflect"
"testing"
"time"

"github.com/cloudwego/kitex/internal/mocks"
"github.com/cloudwego/kitex/internal/test"
"github.com/cloudwego/kitex/pkg/diagnosis"
"github.com/cloudwego/kitex/pkg/endpoint"
"github.com/cloudwego/kitex/pkg/generic"
"github.com/cloudwego/kitex/pkg/remote"
"github.com/cloudwego/kitex/pkg/remote/codec/protobuf"
"github.com/cloudwego/kitex/pkg/remote/codec/thrift"
"github.com/cloudwego/kitex/pkg/remote/trans/detection"
"github.com/cloudwego/kitex/pkg/remote/trans/netpoll"
"github.com/cloudwego/kitex/pkg/remote/trans/netpollmux"
Expand All @@ -38,6 +41,7 @@ import (
"github.com/cloudwego/kitex/pkg/serviceinfo"
"github.com/cloudwego/kitex/pkg/stats"
"github.com/cloudwego/kitex/pkg/utils"
"github.com/cloudwego/kitex/transport"
)

// TestOptionDebugInfo tests the creation of a server with DebugService option
Expand Down Expand Up @@ -307,30 +311,115 @@ func TestMuxTransportOption(t *testing.T) {

// TestPayloadCodecOption tests the creation of a server with RemoteOpt.PayloadCodec option
func TestPayloadCodecOption(t *testing.T) {
svr1 := NewServer()
time.AfterFunc(100*time.Millisecond, func() {
err := svr1.Stop()
t.Run("NotSetPayloadCodec", func(t *testing.T) {
svr := NewServer()
time.AfterFunc(100*time.Millisecond, func() {
err := svr.Stop()
test.Assert(t, err == nil, err)
})
err := svr.RegisterService(mocks.ServiceInfo(), mocks.MyServiceHandler())
test.Assert(t, err == nil, err)
err = svr.Run()
test.Assert(t, err == nil, err)
iSvr := svr.(*server)
test.Assert(t, iSvr.opt.RemoteOpt.PayloadCodec == nil)

tRecvMsg := NewRemoteMsgWithPayloadType(serviceinfo.Thrift)
tRecvMsg.SetPayloadCodec(iSvr.opt.RemoteOpt.PayloadCodec)
pc, err := remote.GetPayloadCodec(tRecvMsg)
test.Assert(t, err == nil)
test.Assert(t, thrift.IsThriftCodec(pc))

pRecvMsg := NewRemoteMsgWithPayloadType(serviceinfo.Protobuf)
pRecvMsg.SetPayloadCodec(iSvr.opt.RemoteOpt.PayloadCodec)
pc, err = remote.GetPayloadCodec(pRecvMsg)
test.Assert(t, err == nil)
test.Assert(t, protobuf.IsProtobufCodec(pc))
})
t.Run("SetPreRegisteredProtobufCodec", func(t *testing.T) {
svr := NewServer(WithPayloadCodec(protobuf.NewProtobufCodec()))
time.AfterFunc(100*time.Millisecond, func() {
err := svr.Stop()
test.Assert(t, err == nil, err)
})
err := svr.RegisterService(mocks.ServiceInfo(), mocks.MyServiceHandler())
test.Assert(t, err == nil, err)
err = svr.Run()
test.Assert(t, err == nil, err)
iSvr := svr.(*server)
test.Assert(t, iSvr.opt.RemoteOpt.PayloadCodec == nil)

tRecvMsg := NewRemoteMsgWithPayloadType(serviceinfo.Thrift)
tRecvMsg.SetPayloadCodec(iSvr.opt.RemoteOpt.PayloadCodec)
pc, err := remote.GetPayloadCodec(tRecvMsg)
test.Assert(t, err == nil)
test.Assert(t, thrift.IsThriftCodec(pc))

pRecvMsg := NewRemoteMsgWithPayloadType(serviceinfo.Protobuf)
pRecvMsg.SetPayloadCodec(iSvr.opt.RemoteOpt.PayloadCodec)
pc, err = remote.GetPayloadCodec(pRecvMsg)
test.Assert(t, err == nil)
test.Assert(t, protobuf.IsProtobufCodec(pc))
})
err := svr1.RegisterService(mocks.ServiceInfo(), mocks.MyServiceHandler())
test.Assert(t, err == nil, err)
err = svr1.Run()
test.Assert(t, err == nil, err)
iSvr1 := svr1.(*server)
test.Assert(t, iSvr1.opt.RemoteOpt.PayloadCodec == nil)

svr2 := NewServer(WithPayloadCodec(protobuf.NewProtobufCodec()))
time.AfterFunc(100*time.Millisecond, func() {
err := svr2.Stop()
t.Run("SetPreRegisteredThriftCodec", func(t *testing.T) {
thriftCodec := thrift.NewThriftCodecDisableFastMode(false, true)
svr := NewServer(WithPayloadCodec(thriftCodec))
time.AfterFunc(100*time.Millisecond, func() {
err := svr.Stop()
test.Assert(t, err == nil, err)
})
err := svr.RegisterService(mocks.ServiceInfo(), mocks.MyServiceHandler())
test.Assert(t, err == nil, err)
err = svr.Run()
test.Assert(t, err == nil, err)
iSvr := svr.(*server)
test.Assert(t, iSvr.opt.RemoteOpt.PayloadCodec == nil)

tRecvMsg := NewRemoteMsgWithPayloadType(serviceinfo.Thrift)
tRecvMsg.SetPayloadCodec(iSvr.opt.RemoteOpt.PayloadCodec)
pc, err := remote.GetPayloadCodec(tRecvMsg)
test.Assert(t, err == nil)
test.Assert(t, thrift.IsThriftCodec(pc))
test.Assert(t, reflect.DeepEqual(pc, thriftCodec))

pRecvMsg := NewRemoteMsgWithPayloadType(serviceinfo.Protobuf)
pRecvMsg.SetPayloadCodec(iSvr.opt.RemoteOpt.PayloadCodec)
pc, err = remote.GetPayloadCodec(pRecvMsg)
test.Assert(t, err == nil)
test.Assert(t, protobuf.IsProtobufCodec(pc))
})

t.Run("SetNonPreRegisteredCodec", func(t *testing.T) {
// generic.BinaryThriftGeneric().PayloadCodec() is not the pre registered codec, RemoteOpt.PayloadCodec won't be nil
binaryThriftCodec := generic.BinaryThriftGeneric().PayloadCodec()
svr := NewServer(WithPayloadCodec(binaryThriftCodec))
time.AfterFunc(100*time.Millisecond, func() {
err := svr.Stop()
test.Assert(t, err == nil, err)
})
err := svr.RegisterService(mocks.ServiceInfo(), mocks.MyServiceHandler())
test.Assert(t, err == nil, err)
err = svr.Run()
test.Assert(t, err == nil, err)
iSvr := svr.(*server)
test.Assert(t, iSvr.opt.RemoteOpt.PayloadCodec != nil)
test.DeepEqual(t, iSvr.opt.RemoteOpt.PayloadCodec, binaryThriftCodec)

tRecvMsg := NewRemoteMsgWithPayloadType(serviceinfo.Thrift)
tRecvMsg.SetPayloadCodec(iSvr.opt.RemoteOpt.PayloadCodec)
pc, err := remote.GetPayloadCodec(tRecvMsg)
test.Assert(t, err == nil)
test.Assert(t, !thrift.IsThriftCodec(pc))
test.Assert(t, reflect.DeepEqual(pc, binaryThriftCodec))

pRecvMsg := NewRemoteMsgWithPayloadType(serviceinfo.Protobuf)
pRecvMsg.SetPayloadCodec(iSvr.opt.RemoteOpt.PayloadCodec)
pc, err = remote.GetPayloadCodec(pRecvMsg)
test.Assert(t, err == nil)
test.Assert(t, !protobuf.IsProtobufCodec(pc))
test.Assert(t, reflect.DeepEqual(pc, binaryThriftCodec))
})
err = svr2.RegisterService(mocks.ServiceInfo(), mocks.MyServiceHandler())
test.Assert(t, err == nil, err)
err = svr2.Run()
test.Assert(t, err == nil, err)
iSvr2 := svr2.(*server)
test.Assert(t, iSvr2.opt.RemoteOpt.PayloadCodec != nil)
test.DeepEqual(t, iSvr2.opt.RemoteOpt.PayloadCodec, protobuf.NewProtobufCodec())
}

// TestRemoteOptGRPCCfgUintValueOption tests the creation of a server with RemoteOpt.GRPCCfg option
Expand Down Expand Up @@ -474,3 +563,9 @@ func TestRefuseTrafficWithoutServiceNamOption(t *testing.T) {
iSvr := svr.(*server)
test.Assert(t, iSvr.opt.RefuseTrafficWithoutServiceName)
}

func NewRemoteMsgWithPayloadType(ct serviceinfo.PayloadCodec) remote.Message {
remoteMsg := remote.NewMessage(nil, nil, nil, remote.Call, remote.Server)
remoteMsg.SetProtocolInfo(remote.NewProtocolInfo(transport.TTHeader, ct))
return remoteMsg
}

0 comments on commit e59cb66

Please sign in to comment.