From 5ba0642482d166d502ff50a94f59f35922c94cb4 Mon Sep 17 00:00:00 2001 From: Kyle Xiao Date: Wed, 3 Jul 2024 18:01:34 +0800 Subject: [PATCH] feat: add PrependError for thriftgo (#1420) --- pkg/generic/thrift/base.go | 105 ++++++++++--------- pkg/protocol/bthrift/apache/exception.go | 28 ----- pkg/protocol/bthrift/exception.go | 54 +++++++++- pkg/protocol/bthrift/exception_test.go | 39 +++++++ pkg/remote/trans/netpollmux/control_frame.go | 19 ++-- 5 files changed, 155 insertions(+), 90 deletions(-) delete mode 100644 pkg/protocol/bthrift/apache/exception.go diff --git a/pkg/generic/thrift/base.go b/pkg/generic/thrift/base.go index 8a3f05febb..0139ca24c3 100644 --- a/pkg/generic/thrift/base.go +++ b/pkg/generic/thrift/base.go @@ -19,6 +19,7 @@ package thrift import ( "fmt" + "github.com/cloudwego/kitex/pkg/protocol/bthrift" thrift "github.com/cloudwego/kitex/pkg/protocol/bthrift/apache" ) @@ -110,18 +111,18 @@ func (p *TrafficEnv) Read(iprot thrift.TProtocol) (err error) { return nil ReadStructBeginError: - return thrift.PrependError(fmt.Sprintf("%T read struct begin error: ", p), err) + return bthrift.PrependError(fmt.Sprintf("%T read struct begin error: ", p), err) ReadFieldBeginError: - return thrift.PrependError(fmt.Sprintf("%T read field %d begin error: ", p, fieldId), err) + return bthrift.PrependError(fmt.Sprintf("%T read field %d begin error: ", p, fieldId), err) ReadFieldError: - return thrift.PrependError(fmt.Sprintf("%T read field %d '%s' error: ", p, fieldId, fieldIDToName_TrafficEnv[fieldId]), err) + return bthrift.PrependError(fmt.Sprintf("%T read field %d '%s' error: ", p, fieldId, fieldIDToName_TrafficEnv[fieldId]), err) SkipFieldError: - return thrift.PrependError(fmt.Sprintf("%T field %d skip type %d error: ", p, fieldId, fieldTypeId), err) + return bthrift.PrependError(fmt.Sprintf("%T field %d skip type %d error: ", p, fieldId, fieldTypeId), err) ReadFieldEndError: - return thrift.PrependError(fmt.Sprintf("%T read field end error", p), err) + return bthrift.PrependError(fmt.Sprintf("%T read field end error", p), err) ReadStructEndError: - return thrift.PrependError(fmt.Sprintf("%T read struct end error: ", p), err) + return bthrift.PrependError(fmt.Sprintf("%T read struct end error: ", p), err) } func (p *TrafficEnv) ReadField1(iprot thrift.TProtocol) error { @@ -166,13 +167,13 @@ func (p *TrafficEnv) Write(oprot thrift.TProtocol) (err error) { } return nil WriteStructBeginError: - return thrift.PrependError(fmt.Sprintf("%T write struct begin error: ", p), err) + return bthrift.PrependError(fmt.Sprintf("%T write struct begin error: ", p), err) WriteFieldError: - return thrift.PrependError(fmt.Sprintf("%T write field %d error: ", p, fieldId), err) + return bthrift.PrependError(fmt.Sprintf("%T write field %d error: ", p, fieldId), err) WriteFieldStopError: - return thrift.PrependError(fmt.Sprintf("%T write field stop error: ", p), err) + return bthrift.PrependError(fmt.Sprintf("%T write field stop error: ", p), err) WriteStructEndError: - return thrift.PrependError(fmt.Sprintf("%T write struct end error: ", p), err) + return bthrift.PrependError(fmt.Sprintf("%T write struct end error: ", p), err) } func (p *TrafficEnv) writeField1(oprot thrift.TProtocol) (err error) { @@ -187,9 +188,9 @@ func (p *TrafficEnv) writeField1(oprot thrift.TProtocol) (err error) { } return nil WriteFieldBeginError: - return thrift.PrependError(fmt.Sprintf("%T write field 1 begin error: ", p), err) + return bthrift.PrependError(fmt.Sprintf("%T write field 1 begin error: ", p), err) WriteFieldEndError: - return thrift.PrependError(fmt.Sprintf("%T write field 1 end error: ", p), err) + return bthrift.PrependError(fmt.Sprintf("%T write field 1 end error: ", p), err) } func (p *TrafficEnv) writeField2(oprot thrift.TProtocol) (err error) { @@ -204,9 +205,9 @@ func (p *TrafficEnv) writeField2(oprot thrift.TProtocol) (err error) { } return nil WriteFieldBeginError: - return thrift.PrependError(fmt.Sprintf("%T write field 2 begin error: ", p), err) + return bthrift.PrependError(fmt.Sprintf("%T write field 2 begin error: ", p), err) WriteFieldEndError: - return thrift.PrependError(fmt.Sprintf("%T write field 2 end error: ", p), err) + return bthrift.PrependError(fmt.Sprintf("%T write field 2 end error: ", p), err) } func (p *TrafficEnv) String() string { @@ -408,18 +409,18 @@ func (p *Base) Read(iprot thrift.TProtocol) (err error) { return nil ReadStructBeginError: - return thrift.PrependError(fmt.Sprintf("%T read struct begin error: ", p), err) + return bthrift.PrependError(fmt.Sprintf("%T read struct begin error: ", p), err) ReadFieldBeginError: - return thrift.PrependError(fmt.Sprintf("%T read field %d begin error: ", p, fieldId), err) + return bthrift.PrependError(fmt.Sprintf("%T read field %d begin error: ", p, fieldId), err) ReadFieldError: - return thrift.PrependError(fmt.Sprintf("%T read field %d '%s' error: ", p, fieldId, fieldIDToName_Base[fieldId]), err) + return bthrift.PrependError(fmt.Sprintf("%T read field %d '%s' error: ", p, fieldId, fieldIDToName_Base[fieldId]), err) SkipFieldError: - return thrift.PrependError(fmt.Sprintf("%T field %d skip type %d error: ", p, fieldId, fieldTypeId), err) + return bthrift.PrependError(fmt.Sprintf("%T field %d skip type %d error: ", p, fieldId, fieldTypeId), err) ReadFieldEndError: - return thrift.PrependError(fmt.Sprintf("%T read field end error", p), err) + return bthrift.PrependError(fmt.Sprintf("%T read field end error", p), err) ReadStructEndError: - return thrift.PrependError(fmt.Sprintf("%T read struct end error: ", p), err) + return bthrift.PrependError(fmt.Sprintf("%T read struct end error: ", p), err) } func (p *Base) ReadField1(iprot thrift.TProtocol) error { @@ -535,13 +536,13 @@ func (p *Base) Write(oprot thrift.TProtocol) (err error) { } return nil WriteStructBeginError: - return thrift.PrependError(fmt.Sprintf("%T write struct begin error: ", p), err) + return bthrift.PrependError(fmt.Sprintf("%T write struct begin error: ", p), err) WriteFieldError: - return thrift.PrependError(fmt.Sprintf("%T write field %d error: ", p, fieldId), err) + return bthrift.PrependError(fmt.Sprintf("%T write field %d error: ", p, fieldId), err) WriteFieldStopError: - return thrift.PrependError(fmt.Sprintf("%T write field stop error: ", p), err) + return bthrift.PrependError(fmt.Sprintf("%T write field stop error: ", p), err) WriteStructEndError: - return thrift.PrependError(fmt.Sprintf("%T write struct end error: ", p), err) + return bthrift.PrependError(fmt.Sprintf("%T write struct end error: ", p), err) } func (p *Base) writeField1(oprot thrift.TProtocol) (err error) { @@ -556,9 +557,9 @@ func (p *Base) writeField1(oprot thrift.TProtocol) (err error) { } return nil WriteFieldBeginError: - return thrift.PrependError(fmt.Sprintf("%T write field 1 begin error: ", p), err) + return bthrift.PrependError(fmt.Sprintf("%T write field 1 begin error: ", p), err) WriteFieldEndError: - return thrift.PrependError(fmt.Sprintf("%T write field 1 end error: ", p), err) + return bthrift.PrependError(fmt.Sprintf("%T write field 1 end error: ", p), err) } func (p *Base) writeField2(oprot thrift.TProtocol) (err error) { @@ -573,9 +574,9 @@ func (p *Base) writeField2(oprot thrift.TProtocol) (err error) { } return nil WriteFieldBeginError: - return thrift.PrependError(fmt.Sprintf("%T write field 2 begin error: ", p), err) + return bthrift.PrependError(fmt.Sprintf("%T write field 2 begin error: ", p), err) WriteFieldEndError: - return thrift.PrependError(fmt.Sprintf("%T write field 2 end error: ", p), err) + return bthrift.PrependError(fmt.Sprintf("%T write field 2 end error: ", p), err) } func (p *Base) writeField3(oprot thrift.TProtocol) (err error) { @@ -590,9 +591,9 @@ func (p *Base) writeField3(oprot thrift.TProtocol) (err error) { } return nil WriteFieldBeginError: - return thrift.PrependError(fmt.Sprintf("%T write field 3 begin error: ", p), err) + return bthrift.PrependError(fmt.Sprintf("%T write field 3 begin error: ", p), err) WriteFieldEndError: - return thrift.PrependError(fmt.Sprintf("%T write field 3 end error: ", p), err) + return bthrift.PrependError(fmt.Sprintf("%T write field 3 end error: ", p), err) } func (p *Base) writeField4(oprot thrift.TProtocol) (err error) { @@ -607,9 +608,9 @@ func (p *Base) writeField4(oprot thrift.TProtocol) (err error) { } return nil WriteFieldBeginError: - return thrift.PrependError(fmt.Sprintf("%T write field 4 begin error: ", p), err) + return bthrift.PrependError(fmt.Sprintf("%T write field 4 begin error: ", p), err) WriteFieldEndError: - return thrift.PrependError(fmt.Sprintf("%T write field 4 end error: ", p), err) + return bthrift.PrependError(fmt.Sprintf("%T write field 4 end error: ", p), err) } func (p *Base) writeField5(oprot thrift.TProtocol) (err error) { @@ -626,9 +627,9 @@ func (p *Base) writeField5(oprot thrift.TProtocol) (err error) { } return nil WriteFieldBeginError: - return thrift.PrependError(fmt.Sprintf("%T write field 5 begin error: ", p), err) + return bthrift.PrependError(fmt.Sprintf("%T write field 5 begin error: ", p), err) WriteFieldEndError: - return thrift.PrependError(fmt.Sprintf("%T write field 5 end error: ", p), err) + return bthrift.PrependError(fmt.Sprintf("%T write field 5 end error: ", p), err) } func (p *Base) writeField6(oprot thrift.TProtocol) (err error) { @@ -658,9 +659,9 @@ func (p *Base) writeField6(oprot thrift.TProtocol) (err error) { } return nil WriteFieldBeginError: - return thrift.PrependError(fmt.Sprintf("%T write field 6 begin error: ", p), err) + return bthrift.PrependError(fmt.Sprintf("%T write field 6 begin error: ", p), err) WriteFieldEndError: - return thrift.PrependError(fmt.Sprintf("%T write field 6 end error: ", p), err) + return bthrift.PrependError(fmt.Sprintf("%T write field 6 end error: ", p), err) } func (p *Base) String() string { @@ -788,18 +789,18 @@ func (p *BaseResp) Read(iprot thrift.TProtocol) (err error) { return nil ReadStructBeginError: - return thrift.PrependError(fmt.Sprintf("%T read struct begin error: ", p), err) + return bthrift.PrependError(fmt.Sprintf("%T read struct begin error: ", p), err) ReadFieldBeginError: - return thrift.PrependError(fmt.Sprintf("%T read field %d begin error: ", p, fieldId), err) + return bthrift.PrependError(fmt.Sprintf("%T read field %d begin error: ", p, fieldId), err) ReadFieldError: - return thrift.PrependError(fmt.Sprintf("%T read field %d '%s' error: ", p, fieldId, fieldIDToName_BaseResp[fieldId]), err) + return bthrift.PrependError(fmt.Sprintf("%T read field %d '%s' error: ", p, fieldId, fieldIDToName_BaseResp[fieldId]), err) SkipFieldError: - return thrift.PrependError(fmt.Sprintf("%T field %d skip type %d error: ", p, fieldId, fieldTypeId), err) + return bthrift.PrependError(fmt.Sprintf("%T field %d skip type %d error: ", p, fieldId, fieldTypeId), err) ReadFieldEndError: - return thrift.PrependError(fmt.Sprintf("%T read field end error", p), err) + return bthrift.PrependError(fmt.Sprintf("%T read field end error", p), err) ReadStructEndError: - return thrift.PrependError(fmt.Sprintf("%T read struct end error: ", p), err) + return bthrift.PrependError(fmt.Sprintf("%T read struct end error: ", p), err) } func (p *BaseResp) ReadField1(iprot thrift.TProtocol) error { @@ -877,13 +878,13 @@ func (p *BaseResp) Write(oprot thrift.TProtocol) (err error) { } return nil WriteStructBeginError: - return thrift.PrependError(fmt.Sprintf("%T write struct begin error: ", p), err) + return bthrift.PrependError(fmt.Sprintf("%T write struct begin error: ", p), err) WriteFieldError: - return thrift.PrependError(fmt.Sprintf("%T write field %d error: ", p, fieldId), err) + return bthrift.PrependError(fmt.Sprintf("%T write field %d error: ", p, fieldId), err) WriteFieldStopError: - return thrift.PrependError(fmt.Sprintf("%T write field stop error: ", p), err) + return bthrift.PrependError(fmt.Sprintf("%T write field stop error: ", p), err) WriteStructEndError: - return thrift.PrependError(fmt.Sprintf("%T write struct end error: ", p), err) + return bthrift.PrependError(fmt.Sprintf("%T write struct end error: ", p), err) } func (p *BaseResp) writeField1(oprot thrift.TProtocol) (err error) { @@ -898,9 +899,9 @@ func (p *BaseResp) writeField1(oprot thrift.TProtocol) (err error) { } return nil WriteFieldBeginError: - return thrift.PrependError(fmt.Sprintf("%T write field 1 begin error: ", p), err) + return bthrift.PrependError(fmt.Sprintf("%T write field 1 begin error: ", p), err) WriteFieldEndError: - return thrift.PrependError(fmt.Sprintf("%T write field 1 end error: ", p), err) + return bthrift.PrependError(fmt.Sprintf("%T write field 1 end error: ", p), err) } func (p *BaseResp) writeField2(oprot thrift.TProtocol) (err error) { @@ -915,9 +916,9 @@ func (p *BaseResp) writeField2(oprot thrift.TProtocol) (err error) { } return nil WriteFieldBeginError: - return thrift.PrependError(fmt.Sprintf("%T write field 2 begin error: ", p), err) + return bthrift.PrependError(fmt.Sprintf("%T write field 2 begin error: ", p), err) WriteFieldEndError: - return thrift.PrependError(fmt.Sprintf("%T write field 2 end error: ", p), err) + return bthrift.PrependError(fmt.Sprintf("%T write field 2 end error: ", p), err) } func (p *BaseResp) writeField3(oprot thrift.TProtocol) (err error) { @@ -947,9 +948,9 @@ func (p *BaseResp) writeField3(oprot thrift.TProtocol) (err error) { } return nil WriteFieldBeginError: - return thrift.PrependError(fmt.Sprintf("%T write field 3 begin error: ", p), err) + return bthrift.PrependError(fmt.Sprintf("%T write field 3 begin error: ", p), err) WriteFieldEndError: - return thrift.PrependError(fmt.Sprintf("%T write field 3 end error: ", p), err) + return bthrift.PrependError(fmt.Sprintf("%T write field 3 end error: ", p), err) } func (p *BaseResp) String() string { diff --git a/pkg/protocol/bthrift/apache/exception.go b/pkg/protocol/bthrift/apache/exception.go deleted file mode 100644 index 2a0a1f67ff..0000000000 --- a/pkg/protocol/bthrift/apache/exception.go +++ /dev/null @@ -1,28 +0,0 @@ -/* - * Copyright 2024 CloudWeGo Authors - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package apache - -import "github.com/apache/thrift/lib/go/thrift" - -// originally from github.com/apache/thrift@v0.13.0/lib/go/thrift/exception.go - -// Generic Thrift exception -type TException interface { - error -} - -var PrependError = thrift.PrependError diff --git a/pkg/protocol/bthrift/exception.go b/pkg/protocol/bthrift/exception.go index 9854520c87..e3ed256840 100644 --- a/pkg/protocol/bthrift/exception.go +++ b/pkg/protocol/bthrift/exception.go @@ -17,12 +17,13 @@ package bthrift import ( + "errors" "fmt" thrift "github.com/cloudwego/kitex/pkg/protocol/bthrift/apache" ) -// ApplicationException represents the application exception decoder for replacing apache.TApplicationException +// ApplicationException is for replacing apache.TApplicationException // it implements ThriftMsgFastCodec interface. type ApplicationException struct { t int32 @@ -177,3 +178,54 @@ func (e *ApplicationException) Error() string { } return fmt.Sprintf("unknown exception type [%d]", e.t) } + +// TransportException is for replacing apache.TransportException +// it implements ThriftMsgFastCodec interface. +type TransportException struct { + ApplicationException // same implementation ... +} + +// NewTransportException ... +func NewTransportException(t int32, m string) *TransportException { + ret := TransportException{} + ret.t = t + ret.m = m + return &ret +} + +// ProtocolException is for replacing apache.ProtocolException +// it implements ThriftMsgFastCodec interface. +type ProtocolException struct { + ApplicationException // same implementation ... +} + +// NewTransportException ... +func NewProtocolException(t int32, m string) *ProtocolException { + ret := ProtocolException{} + ret.t = t + ret.m = m + return &ret +} + +// Generic Thrift exception with TypeId method +type tException interface { + Error() string + TypeId() int32 +} + +// Prepends additional information to an error without losing the Thrift exception interface +func PrependError(prepend string, err error) error { + if t, ok := err.(*TransportException); ok { + return NewTransportException(t.TypeID(), prepend+t.Error()) + } + if t, ok := err.(*ProtocolException); ok { + return NewProtocolException(t.TypeID(), prepend+err.Error()) + } + if t, ok := err.(*ApplicationException); ok { + return NewApplicationException(t.TypeID(), prepend+t.Error()) + } + if t, ok := err.(tException); ok { // apache thrift exception? + return NewApplicationException(t.TypeId(), prepend+t.Error()) + } + return errors.New(prepend + err.Error()) +} diff --git a/pkg/protocol/bthrift/exception_test.go b/pkg/protocol/bthrift/exception_test.go index 9a625c9266..574653d2cf 100644 --- a/pkg/protocol/bthrift/exception_test.go +++ b/pkg/protocol/bthrift/exception_test.go @@ -18,6 +18,7 @@ package bthrift import ( "bytes" + "errors" "testing" "github.com/cloudwego/kitex/internal/test" @@ -61,3 +62,41 @@ func TestApplicationException(t *testing.T) { test.Assert(t, ex4.TypeID() == 1) test.Assert(t, ex4.Msg() == "t1") } + +func TestPrependError(t *testing.T) { + var ok bool + ex0 := NewTransportException(1, "world") + err0 := PrependError("hello ", ex0) + ex0, ok = err0.(*TransportException) + test.Assert(t, ok) + test.Assert(t, ex0.TypeID() == 1) + test.Assert(t, ex0.Error() == "hello world") + + ex1 := NewProtocolException(2, "world") + err1 := PrependError("hello ", ex1) + ex1, ok = err1.(*ProtocolException) + test.Assert(t, ok) + test.Assert(t, ex1.TypeID() == 2) + test.Assert(t, ex1.Error() == "hello world") + + ex2 := NewApplicationException(3, "world") + err2 := PrependError("hello ", ex2) + ex2, ok = err2.(*ApplicationException) + test.Assert(t, ok) + test.Assert(t, ex2.TypeID() == 3) + test.Assert(t, ex2.Error() == "hello world") + + err3 := PrependError("hello ", errors.New("world")) + _, ok = err3.(tException) + test.Assert(t, !ok) + test.Assert(t, err3.Error() == "hello world") + + // the code below, it's for compatibility test only. + // it can be removed in the future along with Read/Write method + ex9 := thrift.NewTApplicationException(9, "world") + err9 := PrependError("hello ", ex9) + ex, ok := err9.(tException) + test.Assert(t, ok) + test.Assert(t, ex.TypeId() == 9) + test.Assert(t, ex.Error() == "hello world") +} diff --git a/pkg/remote/trans/netpollmux/control_frame.go b/pkg/remote/trans/netpollmux/control_frame.go index 9c50813fbe..4c060a24d7 100644 --- a/pkg/remote/trans/netpollmux/control_frame.go +++ b/pkg/remote/trans/netpollmux/control_frame.go @@ -25,7 +25,8 @@ package netpollmux import ( "fmt" - thrift "github.com/cloudwego/kitex/pkg/protocol/bthrift/apache" + "github.com/cloudwego/kitex/pkg/protocol/bthrift" + thrift "github.com/cloudwego/kitex/pkg/protocol/bthrift/apache" ) type ControlFrame struct{} @@ -66,16 +67,16 @@ func (p *ControlFrame) Read(iprot thrift.TProtocol) (err error) { return nil ReadStructBeginError: - return thrift.PrependError(fmt.Sprintf("%T read struct begin error: ", p), err) + return bthrift.PrependError(fmt.Sprintf("%T read struct begin error: ", p), err) ReadFieldBeginError: - return thrift.PrependError(fmt.Sprintf("%T read field %d begin error: ", p, fieldId), err) + return bthrift.PrependError(fmt.Sprintf("%T read field %d begin error: ", p, fieldId), err) SkipFieldTypeError: - return thrift.PrependError(fmt.Sprintf("%T skip field type %d error", p, fieldTypeId), err) + return bthrift.PrependError(fmt.Sprintf("%T skip field type %d error", p, fieldTypeId), err) ReadFieldEndError: - return thrift.PrependError(fmt.Sprintf("%T read field end error", p), err) + return bthrift.PrependError(fmt.Sprintf("%T read field end error", p), err) ReadStructEndError: - return thrift.PrependError(fmt.Sprintf("%T read struct end error: ", p), err) + return bthrift.PrependError(fmt.Sprintf("%T read struct end error: ", p), err) } func (p *ControlFrame) Write(oprot thrift.TProtocol) (err error) { @@ -92,11 +93,11 @@ func (p *ControlFrame) Write(oprot thrift.TProtocol) (err error) { } return nil WriteStructBeginError: - return thrift.PrependError(fmt.Sprintf("%T write struct begin error: ", p), err) + return bthrift.PrependError(fmt.Sprintf("%T write struct begin error: ", p), err) WriteFieldStopError: - return thrift.PrependError(fmt.Sprintf("%T write field stop error: ", p), err) + return bthrift.PrependError(fmt.Sprintf("%T write field stop error: ", p), err) WriteStructEndError: - return thrift.PrependError(fmt.Sprintf("%T write struct end error: ", p), err) + return bthrift.PrependError(fmt.Sprintf("%T write struct end error: ", p), err) } func (p *ControlFrame) String() string {