Skip to content

Commit

Permalink
fix: EnableSkipDecoder switch not working in Buffer Protocol scenario (
Browse files Browse the repository at this point in the history
  • Loading branch information
DMwangnima committed Jun 7, 2024
1 parent bb385e7 commit c028c86
Show file tree
Hide file tree
Showing 4 changed files with 87 additions and 21 deletions.
4 changes: 0 additions & 4 deletions pkg/remote/codec/thrift/skip_decoder.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,10 +27,6 @@ import (
"github.com/cloudwego/kitex/pkg/remote/codec/perrors"
)

const (
EnableSkipDecoder CodecType = 0b10000
)

// skipDecoder is used to parse the input byte-by-byte and skip the thrift payload
// for making use of Frugal and FastCodec in standard Thrift Binary Protocol scenario.
type skipDecoder struct {
Expand Down
7 changes: 6 additions & 1 deletion pkg/remote/codec/thrift/thrift.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,8 @@ const (
FastWrite CodecType = 0b0001
FastRead CodecType = 0b0010

FastReadWrite = FastRead | FastWrite
FastReadWrite = FastRead | FastWrite
EnableSkipDecoder CodecType = 0b10000
)

var (
Expand Down Expand Up @@ -196,6 +197,10 @@ func (c thriftCodec) Unmarshal(ctx context.Context, message remote.Message, in r
data := message.Data()
msgBeginLen := bthrift.Binary.MessageBeginLength(methodName, msgType, seqID)
dataLen := message.PayloadLen() - msgBeginLen - bthrift.Binary.MessageEndLength()
// For Buffer Protocol, dataLen would be negative. Set it to zero so as not to confuse
if dataLen < 0 {
dataLen = 0
}

ri := message.RPCInfo()
rpcinfo.Record(ctx, ri, stats.WaitReadStart, nil)
Expand Down
18 changes: 13 additions & 5 deletions pkg/remote/codec/thrift/thrift_frugal_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -123,13 +123,13 @@ func TestFrugalCodec(t *testing.T) {
ctx := context.Background()
codec := &thriftCodec{FrugalRead | FrugalWrite}

testFrugalDataConversion(t, ctx, codec)
testFrugalDataConversion(t, ctx, codec, transport.TTHeader)
})
t.Run("fallback to frugal and data has tag", func(t *testing.T) {
ctx := context.Background()
codec := NewThriftCodec()

testFrugalDataConversion(t, ctx, codec)
testFrugalDataConversion(t, ctx, codec, transport.TTHeader)
})
t.Run("configure BasicCodec to disable frugal fallback", func(t *testing.T) {
ctx := context.Background()
Expand All @@ -142,23 +142,31 @@ func TestFrugalCodec(t *testing.T) {
out.Flush()
test.Assert(t, err != nil)
})
t.Run("configure frugal and SkipDecoder for Buffer Protocol", func(t *testing.T) {
ctx := context.Background()
codec := NewThriftCodecWithConfig(FrugalRead | FrugalWrite | EnableSkipDecoder)

testFrugalDataConversion(t, ctx, codec, transport.PurePayload)
})
})
}
}

func testFrugalDataConversion(t *testing.T, ctx context.Context, codec remote.PayloadCodec) {
func testFrugalDataConversion(t *testing.T, ctx context.Context, codec remote.PayloadCodec, protocol transport.Protocol) {
for _, tb := range transportBuffers {
t.Run(tb.Name, func(t *testing.T) {
// encode client side
sendMsg := initFrugalTagSendMsg(transport.TTHeader)
sendMsg := initFrugalTagSendMsg(protocol)
buf := tb.NewBuffer()
err := codec.Marshal(ctx, sendMsg, buf)
test.Assert(t, err == nil, err)
buf.Flush()

// decode server side
recvMsg := initFrugalTagRecvMsg()
recvMsg.SetPayloadLen(buf.ReadableLen())
if protocol != transport.PurePayload {
recvMsg.SetPayloadLen(buf.ReadableLen())
}
test.Assert(t, err == nil, err)
err = codec.Unmarshal(ctx, recvMsg, buf)
test.Assert(t, err == nil, err)
Expand Down
79 changes: 68 additions & 11 deletions pkg/remote/codec/thrift/thrift_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -131,17 +131,7 @@ func TestNormal(t *testing.T) {
test.Assert(t, err == nil, err)

// compare Req Arg
sendReq := (sendMsg.Data()).(*mt.MockTestArgs).Req
recvReq := (recvMsg.Data()).(*mt.MockTestArgs).Req
test.Assert(t, sendReq.Msg == recvReq.Msg)
test.Assert(t, len(sendReq.StrList) == len(recvReq.StrList))
test.Assert(t, len(sendReq.StrMap) == len(recvReq.StrMap))
for i, item := range sendReq.StrList {
test.Assert(t, item == recvReq.StrList[i])
}
for k := range sendReq.StrMap {
test.Assert(t, sendReq.StrMap[k] == recvReq.StrMap[k])
}
compare(t, sendMsg, recvMsg)
})
}
}
Expand Down Expand Up @@ -229,6 +219,59 @@ func TestTransErrorUnwrap(t *testing.T) {
test.Assert(t, uwErr2.Error() == errMsg)
}

func TestSkipDecoder(t *testing.T) {
testcases := []struct {
desc string
codec remote.PayloadCodec
protocol transport.Protocol
}{
{
desc: "Disable SkipDecoder, fallback to Apache Thrift Codec for Buffer Protocol",
codec: NewThriftCodec(),
protocol: transport.PurePayload,
},
{
desc: "Disable SkipDecoder, using FastCodec for TTHeader Protocol",
codec: NewThriftCodec(),
protocol: transport.TTHeader,
},
{
desc: "Enable SkipDecoder, using FastCodec for Buffer Protocol",
codec: NewThriftCodecWithConfig(FastRead | FastWrite | EnableSkipDecoder),
protocol: transport.PurePayload,
},
{
desc: "Enable SkipDecoder, using FastCodec for TTHeader Protocol",
codec: NewThriftCodecWithConfig(FastRead | FastWrite | EnableSkipDecoder),
protocol: transport.TTHeader,
},
}

for _, tc := range testcases {
for _, tb := range transportBuffers {
t.Run(tc.desc+"#"+tb.Name, func(t *testing.T) {
// encode client side
sendMsg := initSendMsg(tc.protocol)
buf := tb.NewBuffer()
err := tc.codec.Marshal(context.Background(), sendMsg, buf)
test.Assert(t, err == nil, err)
buf.Flush()

// decode server side
recvMsg := initRecvMsg()
if tc.protocol != transport.PurePayload {
recvMsg.SetPayloadLen(buf.ReadableLen())
}
err = tc.codec.Unmarshal(context.Background(), recvMsg, buf)
test.Assert(t, err == nil, err)

// compare Req Arg
compare(t, sendMsg, recvMsg)
})
}
}
}

func initSendMsg(tp transport.Protocol) remote.Message {
var _args mt.MockTestArgs
_args.Req = prepareReq()
Expand All @@ -247,6 +290,20 @@ func initRecvMsg() remote.Message {
return msg
}

func compare(t *testing.T, sendMsg, recvMsg remote.Message) {
sendReq := (sendMsg.Data()).(*mt.MockTestArgs).Req
recvReq := (recvMsg.Data()).(*mt.MockTestArgs).Req
test.Assert(t, sendReq.Msg == recvReq.Msg)
test.Assert(t, len(sendReq.StrList) == len(recvReq.StrList))
test.Assert(t, len(sendReq.StrMap) == len(recvReq.StrMap))
for i, item := range sendReq.StrList {
test.Assert(t, item == recvReq.StrList[i])
}
for k := range sendReq.StrMap {
test.Assert(t, sendReq.StrMap[k] == recvReq.StrMap[k])
}
}

func initServerErrorMsg(tp transport.Protocol, ri rpcinfo.RPCInfo, transErr *remote.TransError) remote.Message {
errMsg := remote.NewMessage(transErr, svcInfo, ri, remote.Exception, remote.Server)
errMsg.SetProtocolInfo(remote.NewProtocolInfo(tp, svcInfo.PayloadCodec))
Expand Down

0 comments on commit c028c86

Please sign in to comment.