Skip to content

Commit

Permalink
Merge 45da0dd into b0cc41c
Browse files Browse the repository at this point in the history
  • Loading branch information
jayantxie committed Sep 27, 2023
2 parents b0cc41c + 45da0dd commit 1ed7fd4
Show file tree
Hide file tree
Showing 6 changed files with 36 additions and 23 deletions.
5 changes: 5 additions & 0 deletions pkg/generic/json_test/generic_init.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,8 @@ var reqRegression = `{"Msg":"hello","InnerBase":{"Base":{"LogID":"log_id_inner"}

var respMsgWithExtra = `{"Msg":"world","required_field":"required_field","extra_field":"extra_field"}`

var reqExtendMsg = `{"Msg":123}`

var errResp = "Test Error"

type Simple struct {
Expand Down Expand Up @@ -158,6 +160,9 @@ type GenericServiceImpl struct{}

// GenericCall ...
func (g *GenericServiceImpl) GenericCall(ctx context.Context, method string, request interface{}) (response interface{}, err error) {
if method == "ExtendMethod" {
return request, nil
}
buf := request.(string)
rpcinfo := rpcinfo.GetRPCInfo(ctx)
fmt.Printf("Method from Ctx: %s\n", rpcinfo.Invocation().MethodName())
Expand Down
7 changes: 7 additions & 0 deletions pkg/generic/json_test/generic_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,13 @@ func testThrift(t *testing.T) {
test.Assert(t, ok)
test.Assert(t, reflect.DeepEqual(gjson.Get(respStr, "Msg").String(), "world"), "world")

// extend method
resp, err = cli.GenericCall(context.Background(), "ExtendMethod", reqExtendMsg, callopt.WithRPCTimeout(100*time.Second))
test.Assert(t, err == nil, err)
respStr, ok = resp.(string)
test.Assert(t, ok)
test.Assert(t, respStr == reqExtendMsg)

svr.Stop()
}

Expand Down
3 changes: 2 additions & 1 deletion pkg/generic/json_test/idl/example.thrift
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
include "base.thrift"
include "self_ref.thrift"
include "extend.thrift"
namespace go kitex.test.server

enum FOO {
Expand Down Expand Up @@ -42,7 +43,7 @@ struct A {
2: self_ref.A a
}

service ExampleService {
service ExampleService extends extend.ExtendService {
ExampleResp ExampleMethod(1: ExampleReq req)throws(1: Exception err),
A Foo(1: A req)
string Ping(1: string msg)
Expand Down
12 changes: 12 additions & 0 deletions pkg/generic/json_test/idl/extend.thrift
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
namespace go extend

struct ExampleReq {
1: i64 Msg
}
struct ExampleResp {
1: i64 Msg
}

service ExtendService {
ExampleResp ExtendMethod(1: ExampleReq req)
}
2 changes: 1 addition & 1 deletion pkg/generic/json_test/idl/self_ref.thrift
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
namespace go kitex.test.server
namespace go self_ref

struct A {
1: A self
Expand Down
30 changes: 9 additions & 21 deletions pkg/generic/thrift/parse.go
Original file line number Diff line number Diff line change
Expand Up @@ -75,8 +75,6 @@ func Parse(tree *parser.Thrift, mode ParseMode) (*descriptor.ServiceDescriptor,
Router: descriptor.NewRouter(),
}

structsCache := map[string]*descriptor.TypeDescriptor{}

// support one service
svcs := tree.Services
switch mode {
Expand All @@ -92,10 +90,13 @@ func Parse(tree *parser.Thrift, mode ParseMode) (*descriptor.ServiceDescriptor,

visitedSvcs := make(map[*parser.Service]bool, len(tree.Services))
for _, svc := range svcs {
for p := range getAllFunctions(svc, tree, visitedSvcs) {
fn := p.data.(*parser.Function)
if err := addFunction(fn, p.tree, sDsc, structsCache); err != nil {
return nil, err
for p := range getAllSvcs(svc, tree, visitedSvcs) {
svc := p.data.(*parser.Service)
structsCache := map[string]*descriptor.TypeDescriptor{}
for _, fn := range svc.Functions {
if err := addFunction(fn, p.tree, sDsc, structsCache); err != nil {
return nil, err
}
}
}
}
Expand All @@ -107,8 +108,7 @@ type pair struct {
data interface{}
}

func getAllFunctions(svc *parser.Service, tree *parser.Thrift, visitedSvcs map[*parser.Service]bool) chan *pair {
ch := make(chan *pair)
func getAllSvcs(svc *parser.Service, tree *parser.Thrift, visitedSvcs map[*parser.Service]bool) chan *pair {
svcs := make(chan *pair)
addSvc := func(tree *parser.Thrift, svc *parser.Service) {
if exist := visitedSvcs[svc]; !exist {
Expand All @@ -130,19 +130,7 @@ func getAllFunctions(svc *parser.Service, tree *parser.Thrift, visitedSvcs map[*
}
close(svcs)
})
gofunc.GoFunc(context.Background(), func() {
for p := range svcs {
svc := p.data.(*parser.Service)
for _, fn := range svc.Functions {
ch <- &pair{
tree: p.tree,
data: fn,
}
}
}
close(ch)
})
return ch
return svcs
}

func addFunction(fn *parser.Function, tree *parser.Thrift, sDsc *descriptor.ServiceDescriptor, structsCache map[string]*descriptor.TypeDescriptor) (err error) {
Expand Down

0 comments on commit 1ed7fd4

Please sign in to comment.