Skip to content

Commit

Permalink
refactor: new generic interface without thrift apache (#1434)
Browse files Browse the repository at this point in the history
  • Loading branch information
xiaost authored Jul 12, 2024
1 parent 7a557f2 commit 8671eca
Show file tree
Hide file tree
Showing 6 changed files with 53 additions and 72 deletions.
27 changes: 6 additions & 21 deletions internal/mocks/thrift/utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
package thrift

import (
"errors"
"io"

"github.com/cloudwego/kitex/pkg/protocol/bthrift"
Expand All @@ -32,37 +33,21 @@ type ApacheCodecAdapter struct {
func (p ApacheCodecAdapter) Write(tp thrift.TProtocol) error {
b := make([]byte, p.p.BLength())
b = b[:p.p.FastWriteNocopy(b, nil)]
trans := tp.Transport()
if t, ok := trans.(remoteByteBuffer); ok {
// remote.ByteBuffer not always implement io.Writer ...
// can only use WriteBinary
_, err := t.WriteBinary(b)
return err
}
_, err := tp.Transport().Write(b)
return err
}

type remoteByteBuffer interface {
ReadableLen() (n int)
Next(n int) (p []byte, err error)
WriteBinary(b []byte) (n int, err error)
}

// Read implements thrift.TStruct
func (p ApacheCodecAdapter) Read(tp thrift.TProtocol) error {
var err error
var b []byte
trans := tp.Transport()
if t, ok := trans.(remoteByteBuffer); ok {
// remote.ByteBuffer not always implement io.Reader ...
// can only use Next()
b, err = t.Next(t.ReadableLen())
} else {
n := trans.RemainingBytes()
b = make([]byte, n)
_, err = io.ReadFull(trans, b)
n := trans.RemainingBytes()
if int64(n) < 0 {
return errors.New("unknown buffer len")
}
b = make([]byte, n)
_, err = io.ReadFull(trans, b)
if err == nil {
_, err = p.p.FastRead(b)
}
Expand Down
2 changes: 1 addition & 1 deletion pkg/remote/codec/thrift/binary_protocol.go
Original file line number Diff line number Diff line change
Expand Up @@ -506,7 +506,7 @@ func (ttransportByteBuffer) Close() error { panic("not
func (ttransportByteBuffer) Flush(ctx context.Context) (err error) { panic("not implemented") }
func (ttransportByteBuffer) IsOpen() bool { panic("not implemented") }
func (ttransportByteBuffer) Open() error { panic("not implemented") }
func (ttransportByteBuffer) RemainingBytes() uint64 { panic("not implemented") }
func (p ttransportByteBuffer) RemainingBytes() uint64 { return uint64(p.ReadableLen()) }

// Transport ...
func (p *BinaryProtocol) Transport() thrift.TTransport {
Expand Down
23 changes: 17 additions & 6 deletions pkg/remote/codec/thrift/thrift.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ import (
"context"
"errors"
"fmt"
"io"

"github.com/cloudwego/kitex/pkg/protocol/bthrift"
thrift "github.com/cloudwego/kitex/pkg/protocol/bthrift/apache"
Expand Down Expand Up @@ -239,11 +240,6 @@ func (c thriftCodec) Name() string {
return serviceinfo.Thrift.String()
}

// MessageWriterWithMethodWithContext write to thrift.TProtocol
type MessageWriterWithMethodWithContext interface {
Write(ctx context.Context, method string, oprot thrift.TProtocol) error
}

// MessageWriter write to thrift.TProtocol
type MessageWriter interface {
Write(oprot thrift.TProtocol) error
Expand All @@ -254,9 +250,24 @@ type MessageReader interface {
Read(oprot thrift.TProtocol) error
}

type genericWriter interface { // used by pkg/generic
Write(ctx context.Context, method string, w io.Writer) error
}

type genericReader interface { // used by pkg/generic
Read(ctx context.Context, method string, dataLen int, r io.Reader) error
}

// MessageWriterWithMethodWithContext write to thrift.TProtocol
// TODO(marina.sakai): remove it after we use the new genericWriter interface
type MessageWriterWithMethodWithContext interface {
Write(ctx context.Context, method string, oprot thrift.TProtocol) error
}

// MessageReaderWithMethodWithContext read from thrift.TProtocol with method
// TODO(marina.sakai): remove it after we use the new genericReader interface
type MessageReaderWithMethodWithContext interface {
Read(ctx context.Context, method string, dataLen int, oprot thrift.TProtocol) error
Read(ctx context.Context, method string, dataLen int, iprot thrift.TProtocol) error
}

// ThriftMsgFastCodec ...
Expand Down
30 changes: 17 additions & 13 deletions pkg/remote/codec/thrift/thrift_data.go
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@ func verifyMarshalBasicThriftDataType(data interface{}) error {
switch data.(type) {
case MessageWriter:
case MessageWriterWithMethodWithContext:
case genericWriter:
default:
return errEncodeMismatchMsgType
}
Expand All @@ -92,18 +93,20 @@ func verifyMarshalBasicThriftDataType(data interface{}) error {
// marshalBasicThriftData only encodes the data (without the prepending method, msgType, seqId)
// It uses the old thrift way which is much slower than FastCodec and Frugal
func marshalBasicThriftData(ctx context.Context, tProt thrift.TProtocol, data interface{}, method string, rpcRole remote.RPCRole) error {
var err error
switch msg := data.(type) {
case MessageWriter:
if err := msg.Write(tProt); err != nil {
return perrors.NewProtocolErrorWithErrMsg(err, fmt.Sprintf("thrift marshal, Write failed: %s", err.Error()))
}
err = msg.Write(tProt)
case MessageWriterWithMethodWithContext:
if err := msg.Write(ctx, method, tProt); err != nil {
return perrors.NewProtocolErrorWithErrMsg(err, fmt.Sprintf("thrift marshal, Write failed: %s", err.Error()))
}
err = msg.Write(ctx, method, tProt)
case genericWriter:
err = msg.Write(ctx, method, tProt.Transport())
default:
return errEncodeMismatchMsgType
}
if err != nil {
return perrors.NewProtocolErrorWithErrMsg(err, fmt.Sprintf("thrift marshal, Write failed: %s", err.Error()))
}
return nil
}

Expand Down Expand Up @@ -227,6 +230,7 @@ func verifyUnmarshalBasicThriftDataType(data interface{}) error {
switch data.(type) {
case MessageReader:
case MessageReaderWithMethodWithContext:
case genericReader:
default:
return errDecodeMismatchMsgType
}
Expand All @@ -238,17 +242,17 @@ func decodeBasicThriftData(ctx context.Context, tProt thrift.TProtocol, method s
var err error
switch t := data.(type) {
case MessageReader:
if err = t.Read(tProt); err != nil {
return remote.NewTransError(remote.ProtocolError, err)
}
err = t.Read(tProt)
case MessageReaderWithMethodWithContext:
// methodName is necessary for generic calls to methodInfo from serviceInfo
if err = t.Read(ctx, method, dataLen, tProt); err != nil {
return remote.NewTransError(remote.ProtocolError, err)
}
err = t.Read(ctx, method, dataLen, tProt)
case genericReader:
err = t.Read(ctx, method, dataLen, tProt.Transport())
default:
return errDecodeMismatchMsgType
}
if err != nil {
return remote.NewTransError(remote.ProtocolError, err)
}
return nil
}

Expand Down
34 changes: 5 additions & 29 deletions pkg/remote/trans/netpoll/bytebuf.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@ package netpoll

import (
"errors"
"io"
"sync"

"github.com/cloudwego/netpoll"
Expand All @@ -36,11 +35,6 @@ func init() {
func NewReaderByteBuffer(r netpoll.Reader) remote.ByteBuffer {
bytebuf := bytebufPool.Get().(*netpollByteBuffer)
bytebuf.reader = r
// TODO(wangtieju): fix me when netpoll support netpoll.Reader
// and LinkBuffer not support io.Reader, type assertion would fail when r is from NewBuffer
if ir, ok := r.(io.Reader); ok {
bytebuf.ioReader = ir
}
bytebuf.status = remote.BitReadable
bytebuf.readSize = 0
return bytebuf
Expand All @@ -50,11 +44,6 @@ func NewReaderByteBuffer(r netpoll.Reader) remote.ByteBuffer {
func NewWriterByteBuffer(w netpoll.Writer) remote.ByteBuffer {
bytebuf := bytebufPool.Get().(*netpollByteBuffer)
bytebuf.writer = w
// TODO(wangtieju): fix me when netpoll support netpoll.Writer
// and LinkBuffer not support io.Reader, type assertion would fail when w is from NewBuffer
if iw, ok := w.(io.Writer); ok {
bytebuf.ioWriter = iw
}
bytebuf.status = remote.BitWritable
return bytebuf
}
Expand All @@ -64,12 +53,6 @@ func NewReaderWriterByteBuffer(rw netpoll.ReadWriter) remote.ByteBuffer {
bytebuf := bytebufPool.Get().(*netpollByteBuffer)
bytebuf.writer = rw
bytebuf.reader = rw
// TODO(wangtieju): fix me when netpoll support netpoll.ReadWriter
// and LinkBuffer not support io.ReadWriter, type assertion would fail when rw is from NewBuffer
if irw, ok := rw.(io.ReadWriter); ok {
bytebuf.ioReader = irw
bytebuf.ioWriter = irw
}
bytebuf.status = remote.BitWritable | remote.BitReadable
return bytebuf
}
Expand All @@ -81,8 +64,6 @@ func newNetpollByteBuffer() interface{} {
type netpollByteBuffer struct {
writer netpoll.Writer
reader netpoll.Reader
ioReader io.Reader
ioWriter io.Writer
status int
readSize int
}
Expand Down Expand Up @@ -130,10 +111,9 @@ func (b *netpollByteBuffer) Read(p []byte) (n int, err error) {
if b.status&remote.BitReadable == 0 {
return -1, errors.New("unreadable buffer, cannot support Read")
}
if b.ioReader != nil {
return b.ioReader.Read(p)
}
return -1, errors.New("ioReader is nil")
rb, err := b.reader.Next(len(p))
b.readSize += len(rb)
return copy(p, rb), err
}

// ReadString is a more efficient way to read string than Next.
Expand Down Expand Up @@ -188,10 +168,8 @@ func (b *netpollByteBuffer) Write(p []byte) (n int, err error) {
if b.status&remote.BitWritable == 0 {
return -1, errors.New("unwritable buffer, cannot support Write")
}
if b.ioWriter != nil {
return b.ioWriter.Write(p)
}
return -1, errors.New("ioWriter is nil")
wb, err := b.writer.Malloc(len(p))
return copy(wb, p), err
}

// WriteString is a more efficient way to write string, using the unsafe method to convert the string to []byte.
Expand Down Expand Up @@ -268,8 +246,6 @@ func (b *netpollByteBuffer) Release(e error) (err error) {
func (b *netpollByteBuffer) zero() {
b.writer = nil
b.reader = nil
b.ioReader = nil
b.ioWriter = nil
b.status = 0
b.readSize = 0
}
9 changes: 7 additions & 2 deletions pkg/remote/trans/netpoll/http_client_handler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
package netpoll

import (
"bytes"
"context"
"net/http"
"strings"
Expand Down Expand Up @@ -56,7 +57,11 @@ func init() {
// TestHTTPWrite test http_client_handler Write return err
func TestHTTPWrite(t *testing.T) {
// 1. prepare mock data
conn := &MockNetpollConn{}
conn := &MockNetpollConn{
WriterFunc: func() netpoll.Writer {
return netpoll.NewWriter(&bytes.Buffer{})
},
}
rwTimeout := time.Second
cfg := rpcinfo.NewRPCConfig()
rpcinfo.AsMutableRPCConfig(cfg).SetReadWriteTimeout(rwTimeout)
Expand All @@ -70,8 +75,8 @@ func TestHTTPWrite(t *testing.T) {
// 2. test
ctx, err := httpCilTransHdlr.Write(ctx, conn, msg)
// check ctx/err not nil
test.Assert(t, err == nil, err)
test.Assert(t, ctx != nil)
test.Assert(t, err != nil)
}

// TestHTTPRead test http_client_handler Read return err
Expand Down

0 comments on commit 8671eca

Please sign in to comment.