From 117e47afbdced20fb06df6678685fdd9caccf586 Mon Sep 17 00:00:00 2001 From: ktr Date: Sat, 4 Mar 2023 17:44:42 +0900 Subject: [PATCH] use protocompile (#644) * pass unit test * unit test and CLI E2E tests passed * REPL E2E tests are passed * fix tests * delete unnecessary files * add a test for fully-qualified service name * add a test for fully-qualified message name * tweak code/comments --- e2e/repl_test.go | 8 + .../teste2e_oldrepl-desc_a_map.golden | 11 +- ...e2e_oldrepl-desc_a_repeated_message.golden | 10 +- ...teste2e_oldrepl-desc_simple_message.golden | 10 +- .../teste2e_oldrepl-show_service.golden | 28 +- ...ecting_only_fully-qualified_service.golden | 5 + ...repl-desc_a_fully-qualified_message.golden | 6 + .../fixtures/teste2e_repl-desc_a_map.golden | 11 +- ...este2e_repl-desc_a_repeated_message.golden | 10 +- .../teste2e_repl-desc_simple_message.golden | 10 +- ...esc_simple_message_in_empty_package.golden | 10 +- .../fixtures/teste2e_repl-show_service.golden | 28 +- fill/filler.go | 10 +- fill/proto/convert.go | 107 ------- fill/proto/convert_test.go | 141 --------- fill/proto/interactive_filler.go | 284 +++++++++++------- fill/proto/interactive_filler_test.go | 121 ++++++-- fill/proto/testdata/test.proto | 29 ++ fill/silent_filler.go | 34 ++- fill/silent_filler_test.go | 25 +- go.mod | 4 +- go.sum | 8 +- grpc/grpcreflection/reflection.go | 64 ++-- idl/idl_test.go | 182 ----------- idl/proto/proto.go | 188 ------------ idl/proto/proto_test.go | 64 ---- mode/cli.go | 9 +- mode/common.go | 59 ++-- mode/repl.go | 6 +- proto/descsource.go | 79 +++++ proto/helper.go | 34 +++ proto/mock.go | 112 +++++++ proto/registry.go | 59 ++++ repl/commands.go | 4 +- repl/completer.go | 20 +- repl/completer_test.go | 9 +- repl/proto_commands.go | 32 +- repl/repl_test.go | 14 +- usecase/call_rpc.go | 82 ++--- usecase/call_rpc_test.go | 41 +-- usecase/format_descriptor.go | 29 +- usecase/format_messages.go | 6 +- usecase/format_method.go | 4 +- usecase/format_methods.go | 2 +- usecase/format_packages.go | 6 +- usecase/format_service_descriptors.go | 6 +- usecase/format_services.go | 6 +- usecase/format_services_old.go | 47 --- usecase/get_type_descriptor.go | 18 +- usecase/list_packages.go | 29 +- usecase/list_rpcs.go | 42 ++- usecase/list_services.go | 8 +- usecase/list_services_old.go | 26 -- usecase/parse_method.go | 15 +- usecase/use_package.go | 13 +- usecase/use_service.go | 21 +- usecase/usecase.go | 23 +- 57 files changed, 1075 insertions(+), 1194 deletions(-) create mode 100644 e2e/testdata/fixtures/teste2e_repl-call_unary_by_selecting_only_fully-qualified_service.golden create mode 100644 e2e/testdata/fixtures/teste2e_repl-desc_a_fully-qualified_message.golden delete mode 100644 fill/proto/convert.go delete mode 100644 fill/proto/convert_test.go create mode 100644 fill/proto/testdata/test.proto delete mode 100644 idl/idl_test.go delete mode 100644 idl/proto/proto_test.go create mode 100644 proto/descsource.go create mode 100644 proto/helper.go create mode 100644 proto/mock.go create mode 100644 proto/registry.go delete mode 100644 usecase/format_services_old.go delete mode 100644 usecase/list_services_old.go diff --git a/e2e/repl_test.go b/e2e/repl_test.go index 83c0dac5..1ac4929f 100644 --- a/e2e/repl_test.go +++ b/e2e/repl_test.go @@ -75,6 +75,10 @@ func TestE2E_REPL(t *testing.T) { commonFlags: "--proto testdata/test.proto", input: []interface{}{"service Example", "call Unary", "kaguya"}, }, + "call Unary by selecting only fully-qualified service": { + commonFlags: "--proto testdata/test.proto", + input: []interface{}{"service api.Example", "call Unary", "kaguya"}, + }, "call Unary by selecting only service (empty package)": { registerEmptyPackageService: true, commonFlags: "--proto testdata/empty_package.proto", @@ -308,6 +312,10 @@ func TestE2E_REPL(t *testing.T) { commonFlags: "--proto testdata/test.proto", input: []interface{}{"desc SimpleRequest"}, }, + "desc a fully-qualified message": { + commonFlags: "--proto testdata/test.proto", + input: []interface{}{"desc api.SimpleRequest"}, + }, "desc simple message in empty package": { commonFlags: "--proto testdata/empty_package.proto", input: []interface{}{"desc SimpleRequest"}, diff --git a/e2e/testdata/fixtures/teste2e_oldrepl-desc_a_map.golden b/e2e/testdata/fixtures/teste2e_oldrepl-desc_a_map.golden index 2c218562..080ab002 100644 --- a/e2e/testdata/fixtures/teste2e_oldrepl-desc_a_map.golden +++ b/e2e/testdata/fixtures/teste2e_oldrepl-desc_a_map.golden @@ -1,7 +1,6 @@ -+-------+--------------------------------+----------+ -| FIELD | TYPE | REPEATED | -+-------+--------------------------------+----------+ -| kvs | map | | -+-------+--------------------------------+----------+ ++-------+-----------------------------+----------+ +| FIELD | TYPE | REPEATED | ++-------+-----------------------------+----------+ +| kvs | map | false | ++-------+-----------------------------+----------+ diff --git a/e2e/testdata/fixtures/teste2e_oldrepl-desc_a_repeated_message.golden b/e2e/testdata/fixtures/teste2e_oldrepl-desc_a_repeated_message.golden index 60731141..799bda91 100644 --- a/e2e/testdata/fixtures/teste2e_oldrepl-desc_a_repeated_message.golden +++ b/e2e/testdata/fixtures/teste2e_oldrepl-desc_a_repeated_message.golden @@ -1,6 +1,6 @@ -+-------+---------------------+----------+ -| FIELD | TYPE | REPEATED | -+-------+---------------------+----------+ -| name | TYPE_MESSAGE (Name) | true | -+-------+---------------------+----------+ ++-------+----------------+----------+ +| FIELD | TYPE | REPEATED | ++-------+----------------+----------+ +| name | message (Name) | true | ++-------+----------------+----------+ diff --git a/e2e/testdata/fixtures/teste2e_oldrepl-desc_simple_message.golden b/e2e/testdata/fixtures/teste2e_oldrepl-desc_simple_message.golden index 1b3f6537..2d364568 100644 --- a/e2e/testdata/fixtures/teste2e_oldrepl-desc_simple_message.golden +++ b/e2e/testdata/fixtures/teste2e_oldrepl-desc_simple_message.golden @@ -1,6 +1,6 @@ -+-------+-------------+----------+ -| FIELD | TYPE | REPEATED | -+-------+-------------+----------+ -| name | TYPE_STRING | false | -+-------+-------------+----------+ ++-------+--------+----------+ +| FIELD | TYPE | REPEATED | ++-------+--------+----------+ +| name | string | false | ++-------+--------+----------+ diff --git a/e2e/testdata/fixtures/teste2e_oldrepl-show_service.golden b/e2e/testdata/fixtures/teste2e_oldrepl-show_service.golden index 6ea2270b..585b53d6 100644 --- a/e2e/testdata/fixtures/teste2e_oldrepl-show_service.golden +++ b/e2e/testdata/fixtures/teste2e_oldrepl-show_service.golden @@ -1,24 +1,6 @@ -+---------+---------------------------+-----------------------------+----------------+ -| SERVICE | RPC | REQUEST TYPE | RESPONSE TYPE | -+---------+---------------------------+-----------------------------+----------------+ -| Example | Unary | SimpleRequest | SimpleResponse | -| Example | UnaryMessage | UnaryMessageRequest | SimpleResponse | -| Example | UnaryRepeated | UnaryRepeatedRequest | SimpleResponse | -| Example | UnaryRepeatedMessage | UnaryRepeatedMessageRequest | SimpleResponse | -| Example | UnaryRepeatedEnum | UnaryRepeatedEnumRequest | SimpleResponse | -| Example | UnarySelf | UnarySelfRequest | SimpleResponse | -| Example | UnaryMap | UnaryMapRequest | SimpleResponse | -| Example | UnaryMapMessage | UnaryMapMessageRequest | SimpleResponse | -| Example | UnaryOneof | UnaryOneofRequest | SimpleResponse | -| Example | UnaryEnum | UnaryEnumRequest | SimpleResponse | -| Example | UnaryBytes | UnaryBytesRequest | SimpleResponse | -| Example | UnaryHeader | UnaryHeaderRequest | SimpleResponse | -| Example | UnaryHeaderTrailer | SimpleRequest | SimpleResponse | -| Example | UnaryHeaderTrailerFailure | SimpleRequest | SimpleResponse | -| Example | UnaryWithMapResponse | SimpleRequest | MapResponse | -| Example | UnaryEcho | UnaryMessageRequest | SimpleResponse | -| Example | ClientStreaming | SimpleRequest | SimpleResponse | -| Example | ServerStreaming | SimpleRequest | SimpleResponse | -| Example | BidiStreaming | SimpleRequest | SimpleResponse | -+---------+---------------------------+-----------------------------+----------------+ ++-------------+ +| NAME | ++-------------+ +| api.Example | ++-------------+ diff --git a/e2e/testdata/fixtures/teste2e_repl-call_unary_by_selecting_only_fully-qualified_service.golden b/e2e/testdata/fixtures/teste2e_repl-call_unary_by_selecting_only_fully-qualified_service.golden new file mode 100644 index 00000000..92088c5a --- /dev/null +++ b/e2e/testdata/fixtures/teste2e_repl-call_unary_by_selecting_only_fully-qualified_service.golden @@ -0,0 +1,5 @@ + +{ + "message": "kaguya" +} + diff --git a/e2e/testdata/fixtures/teste2e_repl-desc_a_fully-qualified_message.golden b/e2e/testdata/fixtures/teste2e_repl-desc_a_fully-qualified_message.golden new file mode 100644 index 00000000..2d364568 --- /dev/null +++ b/e2e/testdata/fixtures/teste2e_repl-desc_a_fully-qualified_message.golden @@ -0,0 +1,6 @@ ++-------+--------+----------+ +| FIELD | TYPE | REPEATED | ++-------+--------+----------+ +| name | string | false | ++-------+--------+----------+ + diff --git a/e2e/testdata/fixtures/teste2e_repl-desc_a_map.golden b/e2e/testdata/fixtures/teste2e_repl-desc_a_map.golden index 2c218562..080ab002 100644 --- a/e2e/testdata/fixtures/teste2e_repl-desc_a_map.golden +++ b/e2e/testdata/fixtures/teste2e_repl-desc_a_map.golden @@ -1,7 +1,6 @@ -+-------+--------------------------------+----------+ -| FIELD | TYPE | REPEATED | -+-------+--------------------------------+----------+ -| kvs | map | | -+-------+--------------------------------+----------+ ++-------+-----------------------------+----------+ +| FIELD | TYPE | REPEATED | ++-------+-----------------------------+----------+ +| kvs | map | false | ++-------+-----------------------------+----------+ diff --git a/e2e/testdata/fixtures/teste2e_repl-desc_a_repeated_message.golden b/e2e/testdata/fixtures/teste2e_repl-desc_a_repeated_message.golden index 60731141..799bda91 100644 --- a/e2e/testdata/fixtures/teste2e_repl-desc_a_repeated_message.golden +++ b/e2e/testdata/fixtures/teste2e_repl-desc_a_repeated_message.golden @@ -1,6 +1,6 @@ -+-------+---------------------+----------+ -| FIELD | TYPE | REPEATED | -+-------+---------------------+----------+ -| name | TYPE_MESSAGE (Name) | true | -+-------+---------------------+----------+ ++-------+----------------+----------+ +| FIELD | TYPE | REPEATED | ++-------+----------------+----------+ +| name | message (Name) | true | ++-------+----------------+----------+ diff --git a/e2e/testdata/fixtures/teste2e_repl-desc_simple_message.golden b/e2e/testdata/fixtures/teste2e_repl-desc_simple_message.golden index 1b3f6537..2d364568 100644 --- a/e2e/testdata/fixtures/teste2e_repl-desc_simple_message.golden +++ b/e2e/testdata/fixtures/teste2e_repl-desc_simple_message.golden @@ -1,6 +1,6 @@ -+-------+-------------+----------+ -| FIELD | TYPE | REPEATED | -+-------+-------------+----------+ -| name | TYPE_STRING | false | -+-------+-------------+----------+ ++-------+--------+----------+ +| FIELD | TYPE | REPEATED | ++-------+--------+----------+ +| name | string | false | ++-------+--------+----------+ diff --git a/e2e/testdata/fixtures/teste2e_repl-desc_simple_message_in_empty_package.golden b/e2e/testdata/fixtures/teste2e_repl-desc_simple_message_in_empty_package.golden index 1b3f6537..2d364568 100644 --- a/e2e/testdata/fixtures/teste2e_repl-desc_simple_message_in_empty_package.golden +++ b/e2e/testdata/fixtures/teste2e_repl-desc_simple_message_in_empty_package.golden @@ -1,6 +1,6 @@ -+-------+-------------+----------+ -| FIELD | TYPE | REPEATED | -+-------+-------------+----------+ -| name | TYPE_STRING | false | -+-------+-------------+----------+ ++-------+--------+----------+ +| FIELD | TYPE | REPEATED | ++-------+--------+----------+ +| name | string | false | ++-------+--------+----------+ diff --git a/e2e/testdata/fixtures/teste2e_repl-show_service.golden b/e2e/testdata/fixtures/teste2e_repl-show_service.golden index 6ea2270b..585b53d6 100644 --- a/e2e/testdata/fixtures/teste2e_repl-show_service.golden +++ b/e2e/testdata/fixtures/teste2e_repl-show_service.golden @@ -1,24 +1,6 @@ -+---------+---------------------------+-----------------------------+----------------+ -| SERVICE | RPC | REQUEST TYPE | RESPONSE TYPE | -+---------+---------------------------+-----------------------------+----------------+ -| Example | Unary | SimpleRequest | SimpleResponse | -| Example | UnaryMessage | UnaryMessageRequest | SimpleResponse | -| Example | UnaryRepeated | UnaryRepeatedRequest | SimpleResponse | -| Example | UnaryRepeatedMessage | UnaryRepeatedMessageRequest | SimpleResponse | -| Example | UnaryRepeatedEnum | UnaryRepeatedEnumRequest | SimpleResponse | -| Example | UnarySelf | UnarySelfRequest | SimpleResponse | -| Example | UnaryMap | UnaryMapRequest | SimpleResponse | -| Example | UnaryMapMessage | UnaryMapMessageRequest | SimpleResponse | -| Example | UnaryOneof | UnaryOneofRequest | SimpleResponse | -| Example | UnaryEnum | UnaryEnumRequest | SimpleResponse | -| Example | UnaryBytes | UnaryBytesRequest | SimpleResponse | -| Example | UnaryHeader | UnaryHeaderRequest | SimpleResponse | -| Example | UnaryHeaderTrailer | SimpleRequest | SimpleResponse | -| Example | UnaryHeaderTrailerFailure | SimpleRequest | SimpleResponse | -| Example | UnaryWithMapResponse | SimpleRequest | MapResponse | -| Example | UnaryEcho | UnaryMessageRequest | SimpleResponse | -| Example | ClientStreaming | SimpleRequest | SimpleResponse | -| Example | ServerStreaming | SimpleRequest | SimpleResponse | -| Example | BidiStreaming | SimpleRequest | SimpleResponse | -+---------+---------------------------+-----------------------------+----------------+ ++-------------+ +| NAME | ++-------------+ +| api.Example | ++-------------+ diff --git a/fill/filler.go b/fill/filler.go index 9b5c9be6..74925914 100644 --- a/fill/filler.go +++ b/fill/filler.go @@ -1,7 +1,11 @@ // Package fill provides fillers that fills each field with a value. package fill -import "errors" +import ( + "errors" + + "google.golang.org/protobuf/types/dynamicpb" +) var ( ErrCodecMismatch = errors.New("unsupported codec (could be invalid JSON format)") @@ -15,7 +19,7 @@ type Filler interface { // - io.EOF: At the end of input. // - ErrCodecMismatch: If v isn't a supported type. // - Fill(v interface{}) error + Fill(v *dynamicpb.Message) error } // InteractiveFillerOpts represents options for InteractiveFiller. @@ -41,5 +45,5 @@ type InteractiveFiller interface { // - io.EOF: At the end of input. // - ErrCodecMismatch: If v isn't a supported type. // - Fill(v interface{}, opts InteractiveFillerOpts) error + Fill(v *dynamicpb.Message, opts InteractiveFillerOpts) error } diff --git a/fill/proto/convert.go b/fill/proto/convert.go deleted file mode 100644 index e0b0c535..00000000 --- a/fill/proto/convert.go +++ /dev/null @@ -1,107 +0,0 @@ -package proto - -import ( - "fmt" - "strconv" - - "github.com/golang/protobuf/protoc-gen-go/descriptor" - "github.com/pkg/errors" -) - -var protoDefaults = map[descriptor.FieldDescriptorProto_Type]interface{}{ - descriptor.FieldDescriptorProto_TYPE_DOUBLE: float64(0), - descriptor.FieldDescriptorProto_TYPE_FLOAT: float32(0), - descriptor.FieldDescriptorProto_TYPE_INT64: int64(0), - descriptor.FieldDescriptorProto_TYPE_UINT64: uint64(0), - descriptor.FieldDescriptorProto_TYPE_INT32: int32(0), - descriptor.FieldDescriptorProto_TYPE_UINT32: uint32(0), - descriptor.FieldDescriptorProto_TYPE_FIXED64: uint64(0), - descriptor.FieldDescriptorProto_TYPE_FIXED32: uint32(0), - descriptor.FieldDescriptorProto_TYPE_BOOL: false, - descriptor.FieldDescriptorProto_TYPE_STRING: "", - descriptor.FieldDescriptorProto_TYPE_BYTES: []byte{}, - descriptor.FieldDescriptorProto_TYPE_SFIXED64: int64(0), - descriptor.FieldDescriptorProto_TYPE_SFIXED32: int32(0), - descriptor.FieldDescriptorProto_TYPE_SINT64: int64(0), - descriptor.FieldDescriptorProto_TYPE_SINT32: int32(0), -} - -// convertValue converts a string input pv to fieldType. -func convertValue(pv string, fieldType descriptor.FieldDescriptorProto_Type) (interface{}, error) { - if pv == "" { - d, found := protoDefaults[fieldType] - if found { - return d, nil - } - // if not found, we'll let the normal code execute - } - - var v interface{} - var err error - - switch fieldType { - case descriptor.FieldDescriptorProto_TYPE_DOUBLE: - v, err = strconv.ParseFloat(pv, 64) - - case descriptor.FieldDescriptorProto_TYPE_FLOAT: - v, err = strconv.ParseFloat(pv, 32) - v = float32(v.(float64)) - - case descriptor.FieldDescriptorProto_TYPE_INT64: - v, err = strconv.ParseInt(pv, 10, 64) - - case descriptor.FieldDescriptorProto_TYPE_UINT64: - v, err = strconv.ParseUint(pv, 10, 64) - - case descriptor.FieldDescriptorProto_TYPE_INT32: - v, err = strconv.ParseInt(pv, 10, 32) - v = int32(v.(int64)) - - case descriptor.FieldDescriptorProto_TYPE_UINT32: - v, err = strconv.ParseUint(pv, 10, 32) - v = uint32(v.(uint64)) - - case descriptor.FieldDescriptorProto_TYPE_FIXED64: - v, err = strconv.ParseUint(pv, 10, 64) - - case descriptor.FieldDescriptorProto_TYPE_FIXED32: - v, err = strconv.ParseUint(pv, 10, 32) - v = uint32(v.(uint64)) - - case descriptor.FieldDescriptorProto_TYPE_BOOL: - v, err = strconv.ParseBool(pv) - - case descriptor.FieldDescriptorProto_TYPE_STRING: - // already string - v = pv - - // Use strconv.Unquote to interpret byte literals and Unicode literals. - // For example, a user inputs `\x6f\x67\x69\x73\x6f`, - // His expects "ogiso" in string, but backslashes in the input are not interpreted as an escape sequence. - // So, we need to call strconv.Unquote to interpret backslashes as an escape sequence. - case descriptor.FieldDescriptorProto_TYPE_BYTES: - pv, err = strconv.Unquote(`"` + pv + `"`) - v = []byte(pv) - - case descriptor.FieldDescriptorProto_TYPE_SFIXED64: - v, err = strconv.ParseInt(pv, 10, 64) - - case descriptor.FieldDescriptorProto_TYPE_SFIXED32: - v, err = strconv.ParseInt(pv, 10, 32) - v = int32(v.(int64)) - - case descriptor.FieldDescriptorProto_TYPE_SINT64: - v, err = strconv.ParseInt(pv, 10, 64) - - case descriptor.FieldDescriptorProto_TYPE_SINT32: - v, err = strconv.ParseInt(pv, 10, 32) - v = int32(v.(int64)) - - default: - return nil, fmt.Errorf("invalid type: %s", fieldType) - } - if err != nil { - return nil, errors.Wrapf(err, "failed to convert an inputted value '%s' to type %s", pv, fieldType) - } - return v, nil -} diff --git a/fill/proto/convert_test.go b/fill/proto/convert_test.go deleted file mode 100644 index ab8c6303..00000000 --- a/fill/proto/convert_test.go +++ /dev/null @@ -1,141 +0,0 @@ -package proto - -import ( - "reflect" - "testing" - - "github.com/golang/protobuf/protoc-gen-go/descriptor" -) - -func Test_convertValue(t *testing.T) { - cases := map[string]struct { - v string - fieldType descriptor.FieldDescriptorProto_Type - - expected interface{} - hasErr bool - }{ - "default of string": { - v: "", - fieldType: descriptor.FieldDescriptorProto_TYPE_STRING, - expected: "", - }, - "double": { - v: "100.2", - fieldType: descriptor.FieldDescriptorProto_TYPE_DOUBLE, - expected: float64(100.2), - }, - "float": { - v: "100.2", - fieldType: descriptor.FieldDescriptorProto_TYPE_FLOAT, - expected: float32(100.2), - }, - "int64": { - v: "100", - fieldType: descriptor.FieldDescriptorProto_TYPE_INT64, - expected: int64(100), - }, - "uint64": { - v: "100", - fieldType: descriptor.FieldDescriptorProto_TYPE_UINT64, - expected: uint64(100), - }, - "int32": { - v: "100", - fieldType: descriptor.FieldDescriptorProto_TYPE_INT32, - expected: int32(100), - }, - "uint32": { - v: "100", - fieldType: descriptor.FieldDescriptorProto_TYPE_UINT32, - expected: uint32(100), - }, - "fixed64": { - v: "100", - fieldType: descriptor.FieldDescriptorProto_TYPE_FIXED64, - expected: uint64(100), - }, - "fixed32": { - v: "100", - fieldType: descriptor.FieldDescriptorProto_TYPE_FIXED32, - expected: uint32(100), - }, - "bool": { - v: "true", - fieldType: descriptor.FieldDescriptorProto_TYPE_BOOL, - expected: true, - }, - "string": { - v: "violet evergarden", - fieldType: descriptor.FieldDescriptorProto_TYPE_STRING, - expected: "violet evergarden", - }, - "bytes": { - v: "ogiso", - fieldType: descriptor.FieldDescriptorProto_TYPE_BYTES, - expected: []byte("ogiso"), - }, - "bytes (non-ascii string)": { - v: "小木曽", - fieldType: descriptor.FieldDescriptorProto_TYPE_BYTES, - expected: []byte("小木曽"), - }, - "bytes (Unicode literals)": { - v: "\u5c0f\u6728\u66fd", - fieldType: descriptor.FieldDescriptorProto_TYPE_BYTES, - expected: []byte("小木曽"), - }, - "sfixed64": { - v: "100", - fieldType: descriptor.FieldDescriptorProto_TYPE_SFIXED64, - expected: int64(100), - }, - "sfixed32": { - v: "100", - fieldType: descriptor.FieldDescriptorProto_TYPE_SFIXED32, - expected: int32(100), - }, - "sint64": { - v: "100", - fieldType: descriptor.FieldDescriptorProto_TYPE_SINT64, - expected: int64(100), - }, - "sint32": { - v: "100", - fieldType: descriptor.FieldDescriptorProto_TYPE_SINT32, - expected: int32(100), - }, - "invalid type": { - v: "100", - fieldType: descriptor.FieldDescriptorProto_TYPE_SINT64 + 1, // Invalid type. - hasErr: true, - }, - "invalid value": { - v: "100.10", - fieldType: descriptor.FieldDescriptorProto_TYPE_INT32, - hasErr: true, - }, - } - - for name, c := range cases { - c := c - t.Run(name, func(t *testing.T) { - actual, err := convertValue(c.v, c.fieldType) - if c.hasErr { - if err == nil { - t.Errorf("convertValue must return an error, but got nil") - } - return - } - - if err != nil { - t.Fatalf("convertValue must not return errors, but got an error: '%s'", err) - } - - if !reflect.DeepEqual(c.expected, actual) { - t.Errorf("expected '%v' (type = %T), but got '%v' (type = %T)", - c.expected, c.expected, actual, actual) - } - }) - } -} diff --git a/fill/proto/interactive_filler.go b/fill/proto/interactive_filler.go index ea176bf7..ae3113d3 100644 --- a/fill/proto/interactive_filler.go +++ b/fill/proto/interactive_filler.go @@ -8,14 +8,12 @@ import ( "strconv" "strings" - "github.com/jhump/protoreflect/desc" - "github.com/jhump/protoreflect/desc/builder" - "github.com/jhump/protoreflect/dynamic" "github.com/ktr0731/evans/fill" "github.com/ktr0731/evans/logger" "github.com/ktr0731/evans/prompt" "github.com/pkg/errors" - "google.golang.org/protobuf/types/descriptorpb" + "google.golang.org/protobuf/reflect/protoreflect" + "google.golang.org/protobuf/types/dynamicpb" ) // InteractiveFiller is an implementation of fill.InteractiveFiller. @@ -37,13 +35,8 @@ func NewInteractiveFiller(prompt prompt.Prompt, prefixFormat string) *Interactiv // Fill let you input each field interactively by using a prompt. v will be set field values inputted by a prompt. // // Note that Fill resets the previous state when it is called again. -func (f *InteractiveFiller) Fill(v interface{}, opts fill.InteractiveFillerOpts) error { - msg, ok := v.(*dynamic.Message) - if !ok { - return fill.ErrCodecMismatch - } - - resolver := newResolver(f.prompt, f.prefixFormat, prompt.ColorInitial, msg, nil, false, opts) +func (f *InteractiveFiller) Fill(v *dynamicpb.Message, opts fill.InteractiveFillerOpts) error { + resolver := newResolver(f.prompt, f.prefixFormat, prompt.ColorInitial, v, nil, false, opts) _, err := resolver.resolve() if err != nil { return err @@ -57,9 +50,9 @@ type resolver struct { prefixFormat string color prompt.Color - msg *dynamic.Message + msg *dynamicpb.Message - m *desc.MessageDescriptor + m protoreflect.MessageDescriptor ancestors []string // repeated represents that the message is repeated field or not. // If the message is not a field or not a repeated field, it is false. @@ -72,7 +65,7 @@ func newResolver( prompt prompt.Prompt, prefixFormat string, color prompt.Color, - msg *dynamic.Message, + msg *dynamicpb.Message, ancestors []string, repeated bool, opts fill.InteractiveFillerOpts, @@ -82,26 +75,29 @@ func newResolver( prefixFormat: prefixFormat, color: color, msg: msg, - m: msg.GetMessageDescriptor(), + m: msg.Descriptor(), ancestors: ancestors, repeated: repeated, opts: opts, } } -func (r *resolver) resolve() (*dynamic.Message, error) { +func (r *resolver) resolve() (*dynamicpb.Message, error) { selectedOneof := make(map[string]interface{}) - for _, f := range r.m.GetFields() { - if isOneOfField := f.GetOneOf() != nil; isOneOfField { - fqn := f.GetOneOf().GetFullyQualifiedName() + // for _, f := range r.m.Fields(). { + for i := 0; i < r.m.Fields().Len(); i++ { + f := r.m.Fields().Get(i) + + if isOneOfField := f.ContainingOneof() != nil; isOneOfField { + fqn := string(f.ContainingOneof().FullName()) if _, selected := selectedOneof[fqn]; selected { // Skip if one of choices is already selected. continue } selectedOneof[fqn] = nil - if err := r.resolveOneof(f.GetOneOf()); err != nil { + if err := r.resolveOneof(f.ContainingOneof()); err != nil { return nil, err } continue @@ -122,80 +118,125 @@ func (r *resolver) resolve() (*dynamic.Message, error) { return r.msg, nil } -func (r *resolver) resolveOneof(o *desc.OneOfDescriptor) error { - choices := make([]string, 0, len(o.GetChoices())) - for _, c := range o.GetChoices() { - choices = append(choices, c.GetName()) +func (r *resolver) resolveOneof(o protoreflect.OneofDescriptor) error { + choices := make([]string, 0, o.Fields().Len()) + for i := 0; i < o.Fields().Len(); i++ { + c := o.Fields().Get(i) + choices = append(choices, string(c.Name())) } - choice, err := r.selectChoices(o.GetFullyQualifiedName(), choices) + choice, err := r.selectChoices(string(o.FullName()), choices) if err != nil { return err } - return r.resolveField(o.GetChoices()[choice]) + return r.resolveField(o.Fields().Get(choice)) } -func (r *resolver) resolveField(f *desc.FieldDescriptor) error { - resolve := func(f *desc.FieldDescriptor) (interface{}, error) { - var converter func(string) (interface{}, error) +func (r *resolver) resolveField(f protoreflect.FieldDescriptor) error { + resolve := func(f protoreflect.FieldDescriptor) (protoreflect.Value, error) { + var converter func(string) (protoreflect.Value, error) - switch t := f.GetType(); t { - case descriptorpb.FieldDescriptorProto_TYPE_MESSAGE: + switch t := f.Kind(); t { + case protoreflect.MessageKind: if r.skipMessage(f) { - return nil, prompt.ErrSkip + return protoreflect.Value{}, prompt.ErrSkip } msgr := newResolver( r.prompt, r.prefixFormat, r.color.NextVal(), - dynamic.NewMessage(f.GetMessageType()), - append(r.ancestors, f.GetName()), - r.repeated || f.IsRepeated(), + dynamicpb.NewMessage(f.Message()), + append(r.ancestors, string(f.Name())), + r.repeated || f.IsList(), r.opts, ) - return msgr.resolve() - case descriptorpb.FieldDescriptorProto_TYPE_ENUM: - return r.resolveEnum(r.makePrefix(f), f.GetEnumType()) - case descriptorpb.FieldDescriptorProto_TYPE_DOUBLE: - converter = func(v string) (interface{}, error) { return strconv.ParseFloat(v, 64) } - - case descriptorpb.FieldDescriptorProto_TYPE_FLOAT: - converter = func(v string) (interface{}, error) { + msg, err := msgr.resolve() + if err != nil { + return protoreflect.Value{}, err + } + + return protoreflect.ValueOf(msg), nil + case protoreflect.EnumKind: + v, err := r.resolveEnum(r.makePrefix(f), f.Enum()) + if err != nil { + return protoreflect.Value{}, err + } + + return protoreflect.ValueOf(protoreflect.EnumNumber(v)), nil + case protoreflect.DoubleKind: + converter = func(v string) (protoreflect.Value, error) { + f, err := strconv.ParseFloat(v, 64) + if err != nil { + return protoreflect.Value{}, err + } + + return protoreflect.ValueOf(f), nil + } + + case protoreflect.FloatKind: + converter = func(v string) (protoreflect.Value, error) { f, err := strconv.ParseFloat(v, 32) - return float32(f), err + if err != nil { + return protoreflect.Value{}, err + } + + return protoreflect.ValueOf(float32(f)), nil + } + + case protoreflect.Int64Kind, protoreflect.Sfixed64Kind, protoreflect.Sint64Kind: + converter = func(v string) (protoreflect.Value, error) { + n, err := strconv.ParseInt(v, 10, 64) + if err != nil { + return protoreflect.Value{}, err + } + + return protoreflect.ValueOf(n), nil } - case descriptorpb.FieldDescriptorProto_TYPE_INT64, - descriptorpb.FieldDescriptorProto_TYPE_SFIXED64, - descriptorpb.FieldDescriptorProto_TYPE_SINT64: - converter = func(v string) (interface{}, error) { return strconv.ParseInt(v, 10, 64) } + case protoreflect.Uint64Kind, protoreflect.Fixed64Kind: + converter = func(v string) (protoreflect.Value, error) { + n, err := strconv.ParseUint(v, 10, 64) + if err != nil { + return protoreflect.Value{}, err + } - case descriptorpb.FieldDescriptorProto_TYPE_UINT64, - descriptorpb.FieldDescriptorProto_TYPE_FIXED64: - converter = func(v string) (interface{}, error) { return strconv.ParseUint(v, 10, 64) } + return protoreflect.ValueOf(n), nil + } - case descriptorpb.FieldDescriptorProto_TYPE_INT32, - descriptorpb.FieldDescriptorProto_TYPE_SFIXED32, - descriptorpb.FieldDescriptorProto_TYPE_SINT32: - converter = func(v string) (interface{}, error) { + case protoreflect.Int32Kind, protoreflect.Sfixed32Kind, protoreflect.Sint32Kind: + converter = func(v string) (protoreflect.Value, error) { i, err := strconv.ParseInt(v, 10, 32) - return int32(i), err + if err != nil { + return protoreflect.Value{}, err + } + + return protoreflect.ValueOf(int32(i)), err } - case descriptorpb.FieldDescriptorProto_TYPE_UINT32, - descriptorpb.FieldDescriptorProto_TYPE_FIXED32: - converter = func(v string) (interface{}, error) { + case protoreflect.Uint32Kind, protoreflect.Fixed32Kind: + converter = func(v string) (protoreflect.Value, error) { u, err := strconv.ParseUint(v, 10, 32) - return uint32(u), err + if err != nil { + return protoreflect.Value{}, err + } + + return protoreflect.ValueOf(uint32(u)), err } - case descriptorpb.FieldDescriptorProto_TYPE_BOOL: - converter = func(v string) (interface{}, error) { return strconv.ParseBool(v) } + case protoreflect.BoolKind: + converter = func(v string) (protoreflect.Value, error) { + b, err := strconv.ParseBool(v) + if err != nil { + return protoreflect.Value{}, err + } + + return protoreflect.ValueOf(b), nil + } - case descriptorpb.FieldDescriptorProto_TYPE_STRING: - converter = func(v string) (interface{}, error) { return v, nil } + case protoreflect.StringKind: + converter = func(v string) (protoreflect.Value, error) { return protoreflect.ValueOf(v), nil } // For bytes, if neither BytesAsBase64 nor BytesAsQuotedLiterals is explicitly set, // try to decode as base64 first, and if that fails, fall back trying to parse @@ -211,28 +252,27 @@ func (r *resolver) resolveField(f *desc.FieldDescriptor) error { // For example, a user inputs `\x6f\x67\x69\x73\x6f`, // His expects "ogiso" in string, but backslashes in the input are not interpreted as an escape sequence. // So, we need to call strconv.Unquote to interpret backslashes as an escape sequence. - case descriptorpb.FieldDescriptorProto_TYPE_BYTES: - converter = func(v string) (interface{}, error) { + case protoreflect.BytesKind: + converter = func(v string) (protoreflect.Value, error) { if r.opts.BytesFromFile { b, err := os.ReadFile(v) if err != nil { - return nil, err + return protoreflect.Value{}, err } - return b, nil - + return protoreflect.ValueOf(b), nil } else if r.opts.BytesAsBase64 { b, err := base64.StdEncoding.DecodeString(v) if err != nil { - return nil, err + return protoreflect.Value{}, err } - return b, nil + return protoreflect.ValueOf(b), nil } else if r.opts.BytesAsQuotedLiterals { v, err := strconv.Unquote(`"` + v + `"`) if err != nil { - return nil, err + return protoreflect.Value{}, err } - return []byte(v), nil + return protoreflect.ValueOf([]byte(v)), nil } // try to decode as base64 @@ -243,18 +283,18 @@ func (r *resolver) resolveField(f *desc.FieldDescriptor) error { if err2 != nil { // failed to parse as this too, assume user intended to input base64, propagate // that error - return nil, err + return protoreflect.Value{}, err } // log a warning and return the decoded literal string logger.Println(`warning: entering bytes as quoted literal is deprecated. Use --bytes-as-quoted-literals or base64 encoding"`) - return []byte(v), nil + return protoreflect.ValueOf([]byte(v)), nil } // succeeded decoding as base64, return - return b, nil + return protoreflect.ValueOf(b), nil } default: - return nil, fmt.Errorf("invalid type: %s", t) + return protoreflect.Value{}, fmt.Errorf("invalid type: %s", t) } prefix := r.makePrefix(f) @@ -262,13 +302,15 @@ func (r *resolver) resolveField(f *desc.FieldDescriptor) error { return r.input(prefix, f, converter) } - if !f.IsRepeated() { + if f.Cardinality() != protoreflect.Repeated { // TODO: or cardinality v, err := resolve(f) if err != nil { return err } - return r.msg.TrySetField(f, v) + // TODO: is it okay? + r.msg.Set(f, v) + return nil } color := r.color @@ -292,16 +334,23 @@ func (r *resolver) resolveField(f *desc.FieldDescriptor) error { return err } - if err := r.msg.TryAddRepeatedField(f, v); err != nil { - return err + switch { + case f.IsList(): + r.msg.Mutable(f).List().Append(v) + case f.IsMap(): + key := v.Message().Get(v.Message().Descriptor().Fields().Get(0)).MapKey() + val := v.Message().Get(v.Message().Descriptor().Fields().Get(1)) + r.msg.Mutable(f).Map().Set(key, val) } } } -func (r *resolver) resolveEnum(prefix string, e *desc.EnumDescriptor) (int32, error) { - choices := make([]string, 0, len(e.GetValues())) - for _, v := range e.GetValues() { - choices = append(choices, v.GetName()) +func (r *resolver) resolveEnum(prefix string, e protoreflect.EnumDescriptor) (int32, error) { + choices := make([]string, 0, e.Values().Len()) + // for _, v := range e.GetValues() { + for i := 0; i < e.Values().Len(); i++ { + v := e.Values().Get(i) + choices = append(choices, string(v.Name())) } choice, err := r.selectChoices(prefix, choices) @@ -309,34 +358,24 @@ func (r *resolver) resolveEnum(prefix string, e *desc.EnumDescriptor) (int32, er return 0, err } - value := e.GetValues()[choice].AsEnumValueDescriptorProto() + num := int32(e.Values().Get(choice).Number()) - return *value.Number, nil + return num, nil } -func (r *resolver) input(prefix string, f *desc.FieldDescriptor, converter func(string) (interface{}, error)) (interface{}, error) { +func (r *resolver) input(prefix string, f protoreflect.FieldDescriptor, converter func(string) (protoreflect.Value, error)) (protoreflect.Value, error) { r.prompt.SetPrefix(prefix) r.prompt.SetPrefixColor(r.color) in, err := r.prompt.Input() if err != nil { - return nil, err + return protoreflect.Value{}, err } if in == "" { - if f.IsRepeated() { - builder, err := builder.FromField(f) - if err != nil { - return nil, err - } - - // Clear "repeated". - builder.Label = descriptorpb.FieldDescriptorProto_Label(0) - f, err = builder.Build() - if err != nil { - return nil, err - } + if f.IsList() { + return defaultValueFromKind(f.Kind()), nil } - return f.GetDefaultValue(), nil + return protoreflect.ValueOf(f.Default().Interface()), nil } return converter(in) @@ -358,9 +397,9 @@ func (r *resolver) selectChoices(msg string, choices []string) (int, error) { return n, nil } -func (r *resolver) addRepeatedField(f *desc.FieldDescriptor) bool { +func (r *resolver) addRepeatedField(f protoreflect.FieldDescriptor) bool { if !r.opts.AddRepeatedManually { - if f.GetType() != descriptorpb.FieldDescriptorProto_TYPE_MESSAGE || len(f.GetMessageType().GetFields()) != 0 { + if f.Kind() != protoreflect.MessageKind || f.Message().Fields().Len() != 0 { return true } @@ -368,7 +407,7 @@ func (r *resolver) addRepeatedField(f *desc.FieldDescriptor) bool { // For user's experience, always display prompt in this case. } - msg := fmt.Sprintf("add a repeated field value? field=%s", f.GetFullyQualifiedName()) + msg := fmt.Sprintf("add a repeated field value? field=%s", f.FullName()) choices := []string{"yes", "no"} n, _, err := r.prompt.Select(msg, choices) if err != nil || n == 1 { @@ -378,17 +417,17 @@ func (r *resolver) addRepeatedField(f *desc.FieldDescriptor) bool { return true } -func (r *resolver) skipMessage(f *desc.FieldDescriptor) bool { +func (r *resolver) skipMessage(f protoreflect.FieldDescriptor) bool { if !r.opts.DigManually { return false } - msg := fmt.Sprintf("dig down? field=%s", f.GetFullyQualifiedName()) + msg := fmt.Sprintf("dig down? field=%s", f.FullName()) n, _, _ := r.prompt.Select(msg, []string{"dig down", "skip"}) return n == 1 } -func (r *resolver) makePrefix(field *desc.FieldDescriptor) string { +func (r *resolver) makePrefix(field protoreflect.FieldDescriptor) string { const delimiter = "::" joinedAncestor := strings.Join(r.ancestors, delimiter) @@ -399,12 +438,35 @@ func (r *resolver) makePrefix(field *desc.FieldDescriptor) string { s := r.prefixFormat s = strings.ReplaceAll(s, "{ancestor}", joinedAncestor) - s = strings.ReplaceAll(s, "{name}", field.GetName()) - s = strings.ReplaceAll(s, "{type}", field.GetType().String()) + s = strings.ReplaceAll(s, "{name}", string(field.Name())) + s = strings.ReplaceAll(s, "{type}", field.Kind().String()) - if r.repeated || field.IsRepeated() { + if r.repeated || field.IsList() { return " " + s } return s } + +var protoDefaults = map[protoreflect.Kind]interface{}{ + protoreflect.DoubleKind: float64(0), + protoreflect.FloatKind: float32(0), + protoreflect.Int64Kind: int64(0), + protoreflect.Uint64Kind: uint64(0), + protoreflect.Int32Kind: int32(0), + protoreflect.Uint32Kind: uint32(0), + protoreflect.Fixed64Kind: uint64(0), + protoreflect.Fixed32Kind: uint32(0), + protoreflect.BoolKind: false, + protoreflect.StringKind: "", + protoreflect.BytesKind: []byte{}, + protoreflect.Sfixed64Kind: int64(0), + protoreflect.Sfixed32Kind: int32(0), + protoreflect.Sint64Kind: int64(0), + protoreflect.Sint32Kind: int32(0), +} + +// convertValue converts a string input pv to protoreflect.Value. +func defaultValueFromKind(kind protoreflect.Kind) protoreflect.Value { + return protoreflect.ValueOf(protoDefaults[kind]) +} diff --git a/fill/proto/interactive_filler_test.go b/fill/proto/interactive_filler_test.go index f3367d82..6e06c8e5 100644 --- a/fill/proto/interactive_filler_test.go +++ b/fill/proto/interactive_filler_test.go @@ -1,15 +1,18 @@ package proto import ( + "context" "fmt" + "reflect" "runtime" "testing" + "github.com/bufbuild/protocompile" "github.com/golang/protobuf/jsonpb" //nolint:staticcheck - "github.com/jhump/protoreflect/desc/builder" - "github.com/jhump/protoreflect/dynamic" "github.com/ktr0731/evans/fill" "github.com/ktr0731/evans/prompt" + "google.golang.org/protobuf/reflect/protoreflect" + "google.golang.org/protobuf/types/dynamicpb" ) type stubPrompt struct { @@ -54,34 +57,18 @@ func (p *stubPrompt) SetPrefix(string) {} func (p *stubPrompt) SetPrefixColor(prompt.Color) {} func TestInteractiveFiller(t *testing.T) { - b := builder.NewMessage("Message") - b.AddField(builder.NewField("a", builder.FieldTypeMessage(builder.NewMessage("SubMessage"))).SetRepeated()) - b.AddField(builder.NewField("b", builder.FieldTypeEnum( - builder.NewEnum("Enum"). - AddValue(builder.NewEnumValue("enum1").SetNumber(5)). - AddValue(builder.NewEnumValue("enum2").SetNumber(7)), - ))) - b.AddField(builder.NewField("c", builder.FieldTypeDouble())) - b.AddField(builder.NewField("d", builder.FieldTypeFloat())) - b.AddField(builder.NewField("e", builder.FieldTypeInt64())) - b.AddField(builder.NewField("f", builder.FieldTypeSFixed64())) - b.AddField(builder.NewField("g", builder.FieldTypeSInt64())) - b.AddField(builder.NewField("h", builder.FieldTypeUInt64())) - b.AddField(builder.NewField("i", builder.FieldTypeFixed64())) - b.AddField(builder.NewField("j", builder.FieldTypeInt32())) - b.AddField(builder.NewField("k", builder.FieldTypeSFixed32())) - b.AddField(builder.NewField("l", builder.FieldTypeSInt32())) - b.AddField(builder.NewField("m", builder.FieldTypeUInt32())) - b.AddField(builder.NewField("n", builder.FieldTypeFixed32())) - b.AddField(builder.NewField("o", builder.FieldTypeBool())) - b.AddField(builder.NewField("p", builder.FieldTypeString())) - b.AddField(builder.NewField("q", builder.FieldTypeBytes())) - m, err := b.Build() + c := &protocompile.Compiler{ + Resolver: protocompile.WithStandardImports(&protocompile.SourceResolver{ + ImportPaths: []string{"testdata"}, + }), + } + compiled, err := c.Compile(context.TODO(), "test.proto") if err != nil { - t.Fatalf("Build should not return an error, but got '%s'", err) + t.Fatal(err) } - msg := dynamic.NewMessage(m) + m := compiled[0].Messages().ByName(protoreflect.Name("Message")) + msg := dynamicpb.NewMessage(m) p := &stubPrompt{ t: t, input: []string{ @@ -129,3 +116,83 @@ func TestInteractiveFiller(t *testing.T) { t.Errorf("want: %s\ngot: %s", want, got) } } + +func Test_defaultValueFromKind(t *testing.T) { + cases := map[string]struct { + kind protoreflect.Kind + + expected protoreflect.Value + }{ + "string": { + kind: protoreflect.StringKind, + expected: protoreflect.ValueOf(""), + }, + "double": { + kind: protoreflect.DoubleKind, + expected: protoreflect.ValueOf(float64(0)), + }, + "float": { + kind: protoreflect.FloatKind, + expected: protoreflect.ValueOf(float32(0)), + }, + "int64": { + kind: protoreflect.Int64Kind, + expected: protoreflect.ValueOf(int64(0)), + }, + "uint64": { + kind: protoreflect.Uint64Kind, + expected: protoreflect.ValueOf(uint64(0)), + }, + "int32": { + kind: protoreflect.Int32Kind, + expected: protoreflect.ValueOf(int32(0)), + }, + "uint32": { + kind: protoreflect.Uint32Kind, + expected: protoreflect.ValueOf(uint32(0)), + }, + "fixed64": { + kind: protoreflect.Fixed64Kind, + expected: protoreflect.ValueOf(uint64(0)), + }, + "fixed32": { + kind: protoreflect.Fixed32Kind, + expected: protoreflect.ValueOf(uint32(0)), + }, + "bool": { + kind: protoreflect.BoolKind, + expected: protoreflect.ValueOf(false), + }, + "bytes": { + kind: protoreflect.BytesKind, + expected: protoreflect.ValueOf([]byte{}), + }, + "sfixed64": { + kind: protoreflect.Sfixed64Kind, + expected: protoreflect.ValueOf(int64(0)), + }, + "sfixed32": { + kind: protoreflect.Sfixed32Kind, + expected: protoreflect.ValueOf(int32(0)), + }, + "sint64": { + kind: protoreflect.Sint64Kind, + expected: protoreflect.ValueOf(int64(0)), + }, + "sint32": { + kind: protoreflect.Sint32Kind, + expected: protoreflect.ValueOf(int32(0)), + }, + } + + for name, c := range cases { + name, c := name, c + t.Run(name, func(t *testing.T) { + actual := defaultValueFromKind(c.kind) + if !reflect.DeepEqual(c.expected, actual) { + t.Errorf("expected '%v' (type = %T), but got '%v' (type = %T)", + c.expected, c.expected, actual, actual) + } + }) + } +} diff --git a/fill/proto/testdata/test.proto b/fill/proto/testdata/test.proto new file mode 100644 index 00000000..3c6e8104 --- /dev/null +++ b/fill/proto/testdata/test.proto @@ -0,0 +1,29 @@ +syntax = "proto3"; + +package api; + +message Message { + message SubMessage {} + enum Enum { + enum1 = 0; + enum2 = 7; + } + + repeated SubMessage a = 1; + Enum b = 2; + double c = 3; + float d = 4; + int64 e = 5; + sfixed64 f = 6; + sint64 g = 7; + uint64 h = 8; + fixed64 i = 9; + int32 j = 10; + sfixed32 k = 11; + sint32 l = 12; + uint32 m = 13; + fixed32 n = 14; + bool o = 15; + string p = 16; + bytes q = 17; +} diff --git a/fill/silent_filler.go b/fill/silent_filler.go index 784ffb93..1d61df50 100644 --- a/fill/silent_filler.go +++ b/fill/silent_filler.go @@ -4,36 +4,38 @@ import ( "encoding/json" "io" - "github.com/pkg/errors" + "google.golang.org/protobuf/encoding/protojson" + "google.golang.org/protobuf/types/dynamicpb" ) // SilentFilter is a Filler implementation that doesn't behave interactive actions. type SilentFiller struct { - dec *json.Decoder + dec *protojson.UnmarshalOptions + in *json.Decoder } // NewSilentFiller receives input as io.Reader and returns an instance of SilentFiller. func NewSilentFiller(in io.Reader) *SilentFiller { return &SilentFiller{ - dec: json.NewDecoder(in), + dec: &protojson.UnmarshalOptions{ + Resolver: nil, // TODO + }, + in: json.NewDecoder(in), } } // Fill fills values of each field from a JSON string. If the JSON string is invalid JSON format or v is a nil pointer, // Fill returns ErrCodecMismatch. -func (f *SilentFiller) Fill(v interface{}) error { - err := f.dec.Decode(v) - if err != nil { - if errors.Is(err, io.EOF) { - return io.EOF - } +func (f *SilentFiller) Fill(v *dynamicpb.Message) error { + var in interface{} + if err := f.in.Decode(&in); err != nil { + return err + } - switch err.(type) { - case *json.InvalidUnmarshalError, *json.SyntaxError: - return ErrCodecMismatch - default: - return errors.Wrap(err, "failed to read input as JSON") - } + b, err := json.Marshal(in) + if err != nil { + return err } - return nil + + return f.dec.Unmarshal(b, v) } diff --git a/fill/silent_filler_test.go b/fill/silent_filler_test.go index 705d2826..9ce79df3 100644 --- a/fill/silent_filler_test.go +++ b/fill/silent_filler_test.go @@ -1,10 +1,15 @@ package fill_test import ( + "context" + "path/filepath" "strings" "testing" + "github.com/bufbuild/protocompile" "github.com/ktr0731/evans/fill" + "google.golang.org/protobuf/reflect/protoreflect" + "google.golang.org/protobuf/types/dynamicpb" ) func TestSilentFiller(t *testing.T) { @@ -12,15 +17,29 @@ func TestSilentFiller(t *testing.T) { in string hasErr bool }{ - "normal": {in: `{"foo": "bar"}`}, + "normal": {in: `{"p": "bar"}`}, "invalid JSON": {in: `foo`, hasErr: true}, } + + c := &protocompile.Compiler{ + Resolver: protocompile.WithStandardImports(&protocompile.SourceResolver{ + ImportPaths: []string{filepath.Join("proto", "testdata")}, + }), + } + compiled, err := c.Compile(context.TODO(), "test.proto") + if err != nil { + t.Fatal(err) + } + + md := compiled[0].Messages().ByName(protoreflect.Name("Message")) + for name, c := range cases { c := c t.Run(name, func(t *testing.T) { + f := fill.NewSilentFiller(strings.NewReader(c.in)) - var i interface{} - err := f.Fill(&i) + i := dynamicpb.NewMessage(md) + err := f.Fill(i) if c.hasErr { if err == nil { t.Errorf("Fill must return an error, but got nil") diff --git a/go.mod b/go.mod index 4c43b6f8..3e9c08d0 100644 --- a/go.mod +++ b/go.mod @@ -4,6 +4,7 @@ go 1.20 require ( github.com/Songmu/gocredits v0.3.0 + github.com/bufbuild/protocompile v0.1.0 github.com/chzyer/readline v0.0.0-20180603132655-2972be24d48e github.com/fatih/color v1.13.0 github.com/golang/protobuf v1.5.2 @@ -37,11 +38,12 @@ require ( github.com/tj/go-spin v1.1.0 github.com/zchee/go-xdgbasedir v1.0.3 go.uber.org/goleak v1.2.0 + golang.org/x/exp v0.0.0-20220325121720-054d8573a5d8 golang.org/x/sync v0.1.0 golang.org/x/tools v0.5.0 google.golang.org/genproto v0.0.0-20221024183307-1bc688fe9f3e google.golang.org/grpc v1.51.0 - google.golang.org/protobuf v1.28.1 + google.golang.org/protobuf v1.28.2-0.20220831092852-f930b1dc76e8 ) require ( diff --git a/go.sum b/go.sum index 0793ea95..92a02aaa 100644 --- a/go.sum +++ b/go.sum @@ -244,6 +244,8 @@ github.com/beorn7/perks v1.0.1/go.mod h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6r github.com/bgentry/speakeasy v0.1.0/go.mod h1:+zsyZBPWlz7T6j88CTgSN5bM796AkVf0kBD4zp0CCIs= github.com/blakesmith/ar v0.0.0-20190502131153-809d4375e1fb h1:m935MPodAbYS46DG4pJSv7WO+VECIWUQ7OJYSoTrMh4= github.com/blakesmith/ar v0.0.0-20190502131153-809d4375e1fb/go.mod h1:PkYb9DJNAwrSvRx5DYA+gUcOIgTGVMNkfSCbZM8cWpI= +github.com/bufbuild/protocompile v0.1.0 h1:HjgJBI85hY/qmW5tw/66sNDZ7z0UDdVSi/5r40WHw4s= +github.com/bufbuild/protocompile v0.1.0/go.mod h1:ix/MMMdsT3fzxfw91dvbfzKW3fRRnuPCP47kpAm5m/4= github.com/caarlos0/ctrlc v1.2.0 h1:AtbThhmbeYx1WW3WXdWrd94EHKi+0NPRGS4/4pzrjwk= github.com/caarlos0/ctrlc v1.2.0/go.mod h1:n3gDlSjsXZ7rbD9/RprIR040b7oaLfNStikPd4gFago= github.com/caarlos0/env/v6 v6.10.0 h1:lA7sxiGArZ2KkiqpOQNf8ERBRWI+v8MWIH+eGjSN22I= @@ -1093,6 +1095,8 @@ golang.org/x/exp v0.0.0-20200119233911-0405dc783f0a/go.mod h1:2RIsYlXP63K8oxa1u0 golang.org/x/exp v0.0.0-20200207192155-f17229e696bd/go.mod h1:J/WKrq2StrnmMY6+EHIKF9dgMWnmCNThgcyBT1FY9mM= golang.org/x/exp v0.0.0-20200224162631-6cc2880d07d6/go.mod h1:3jZMyOhIsHpP37uCMkUooju7aAi5cS1Q23tOzKc+0MU= golang.org/x/exp v0.0.0-20200331195152-e8c3332aa8e5/go.mod h1:4M0jN8W1tt0AVLNr8HDosyJCDCDuyL9N9+3m7wDWgKw= +golang.org/x/exp v0.0.0-20220325121720-054d8573a5d8 h1:Xt4/LzbTwfocTk9ZLEu4onjeFucl88iW+v4j4PWbQuE= +golang.org/x/exp v0.0.0-20220325121720-054d8573a5d8/go.mod h1:lgLbSvA5ygNOMpwM/9anMpWVlVJ7Z+cHWq/eFuinpGE= golang.org/x/image v0.0.0-20190227222117-0694c2d4d067/go.mod h1:kZ7UVZpmo3dzQBMxlp+ypCbDeSB+sBbTgSJuh5dn5js= golang.org/x/image v0.0.0-20190802002840-cff245a6509b/go.mod h1:FeLwcggjj3mMvU+oOTbSwawSJRM1uh48EjtB4UJZlP0= golang.org/x/lint v0.0.0-20181026193005-c67002cb31c3/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE= @@ -1639,8 +1643,8 @@ google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp0 google.golang.org/protobuf v1.26.0/go.mod h1:9q0QmTI4eRPtz6boOQmLYwt+qCgq0jsYwAQnmE0givc= google.golang.org/protobuf v1.27.1/go.mod h1:9q0QmTI4eRPtz6boOQmLYwt+qCgq0jsYwAQnmE0givc= google.golang.org/protobuf v1.28.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I= -google.golang.org/protobuf v1.28.1 h1:d0NfwRgPtno5B1Wa6L2DAG+KivqkdutMf1UhdNx175w= -google.golang.org/protobuf v1.28.1/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I= +google.golang.org/protobuf v1.28.2-0.20220831092852-f930b1dc76e8 h1:KR8+MyP7/qOlV+8Af01LtjL04bu7on42eVsxT4EyBQk= +google.golang.org/protobuf v1.28.2-0.20220831092852-f930b1dc76e8/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I= gopkg.in/alecthomas/kingpin.v2 v2.2.6/go.mod h1:FMv+mEhP44yOT+4EoQTLFTRgOQ1FBLkstjWtayDeSgw= gopkg.in/alexcesaro/quotedprintable.v3 v3.0.0-20150716171945-2caba252f4dc h1:2gGKlE2+asNV9m7xrywl36YYNnBG5ZQ0r/BOOxqPpmk= gopkg.in/alexcesaro/quotedprintable.v3 v3.0.0-20150716171945-2caba252f4dc/go.mod h1:m7x9LTH6d71AHyAX77c9yqWCCa3UKHcVEj9y7hAtKDk= diff --git a/grpc/grpcreflection/reflection.go b/grpc/grpcreflection/reflection.go index 1b3926d2..1e0f7702 100644 --- a/grpc/grpcreflection/reflection.go +++ b/grpc/grpcreflection/reflection.go @@ -6,7 +6,6 @@ import ( "context" "strings" - "github.com/jhump/protoreflect/desc" gr "github.com/jhump/protoreflect/grpcreflect" "github.com/ktr0731/grpc-web-go-client/grpcweb" "github.com/ktr0731/grpc-web-go-client/grpcweb/grpcweb_reflection_v1alpha" @@ -15,6 +14,9 @@ import ( "google.golang.org/grpc/metadata" "google.golang.org/grpc/reflection/grpc_reflection_v1alpha" "google.golang.org/grpc/status" + "google.golang.org/protobuf/reflect/protodesc" + "google.golang.org/protobuf/reflect/protoreflect" + "google.golang.org/protobuf/reflect/protoregistry" ) // ServiceName represents the gRPC reflection service name. @@ -24,16 +26,19 @@ var ErrTLSHandshakeFailed = errors.New("TLS handshake failed") // Client defines gRPC reflection client. type Client interface { - // ListPackages lists file descriptors from the gRPC reflection server. - // ListPackages returns these errors: + // ListServices lists registered service names. + // ListServices returns these errors: // - ErrTLSHandshakeFailed: TLS misconfig. - ListPackages() ([]*desc.FileDescriptor, error) + ListServices() ([]string, error) + // FindSymbol returns the symbol associated with the given name. + FindSymbol(name string) (protoreflect.Descriptor, error) // Reset clears internal states of Client. Reset() } type client struct { - client *gr.Client + resolver *protoregistry.Files + client *gr.Client } func getCtx(headers map[string][]string) context.Context { @@ -47,19 +52,21 @@ func getCtx(headers map[string][]string) context.Context { // NewClient returns an instance of gRPC reflection client for gRPC protocol. func NewClient(conn grpc.ClientConnInterface, headers map[string][]string) Client { return &client{ - client: gr.NewClientV1Alpha(getCtx(headers), grpc_reflection_v1alpha.NewServerReflectionClient(conn)), + client: gr.NewClientV1Alpha(getCtx(headers), grpc_reflection_v1alpha.NewServerReflectionClient(conn)), + resolver: protoregistry.GlobalFiles, } } // NewWebClient returns an instance of gRPC reflection client for gRPC-Web protocol. func NewWebClient(conn *grpcweb.ClientConn, headers map[string][]string) Client { return &client{ - client: gr.NewClientV1Alpha(getCtx(headers), grpcweb_reflection_v1alpha.NewServerReflectionClient(conn)), + client: gr.NewClientV1Alpha(getCtx(headers), grpcweb_reflection_v1alpha.NewServerReflectionClient(conn)), + resolver: protoregistry.GlobalFiles, } } -func (c *client) ListPackages() ([]*desc.FileDescriptor, error) { - ssvcs, err := c.client.ListServices() +func (c *client) ListServices() ([]string, error) { + svcs, err := c.client.ListServices() if err != nil { msg := status.Convert(err).Message() // Check whether the error message contains TLS related error. @@ -73,21 +80,36 @@ func (c *client) ListPackages() ([]*desc.FileDescriptor, error) { return nil, errors.Wrap(err, "failed to list services from reflection enabled gRPC server") } - fds := make([]*desc.FileDescriptor, 0, len(ssvcs)) - for _, s := range ssvcs { - svc, err := c.client.ResolveService(s) - if err != nil { - if gr.IsElementNotFoundError(err) { - // Service doesn't expose the ServiceDescriptor, skip. - continue - } - return nil, errors.Wrapf(err, "failed to resolve service '%s'", s) - } + return svcs, nil +} + +func (c *client) FindSymbol(name string) (protoreflect.Descriptor, error) { + fullName := protoreflect.FullName(name) + + d, err := c.resolver.FindDescriptorByName(fullName) + if err != nil && !errors.Is(err, protoregistry.NotFound) { + return nil, err + } + if err == nil { + return d, nil + } + + jfd, err := c.client.FileContainingSymbol(name) + if err != nil { + return nil, errors.Wrap(err, "failed to find file containing symbol") + } + + // TODO: consider dependencies + fd, err := protodesc.NewFile(jfd.AsFileDescriptorProto(), c.resolver) + if err != nil { + return nil, err + } - fds = append(fds, svc.GetFile()) + if err := c.resolver.RegisterFile(fd); err != nil { + return nil, err } - return fds, nil + return c.resolver.FindDescriptorByName(fullName) } func (c *client) Reset() { diff --git a/idl/idl_test.go b/idl/idl_test.go deleted file mode 100644 index c00cd4c3..00000000 --- a/idl/idl_test.go +++ /dev/null @@ -1,182 +0,0 @@ -package idl_test - -import ( - "errors" - "path/filepath" - "testing" - - "github.com/google/go-cmp/cmp" - "github.com/ktr0731/evans/idl" - "github.com/ktr0731/evans/idl/proto" -) - -func TestSpec(t *testing.T) { - cases := map[string]struct { - newNormalSpec func(*testing.T) idl.Spec - newEmptyPackageSpec func(*testing.T) idl.Spec - }{ - "proto": { - newNormalSpec: func(t *testing.T) idl.Spec { - fnames := []string{"message.proto", "api.proto", "other_package.proto"} - spec, err := proto.LoadFiles([]string{filepath.Join("proto", "testdata")}, fnames) - if err != nil { - t.Fatalf("LoadFiles must not return an error, but got '%s'", err) - } - return spec - }, - newEmptyPackageSpec: func(t *testing.T) idl.Spec { - fnames := []string{"empty_package.proto"} - spec, err := proto.LoadFiles([]string{filepath.Join("proto", "testdata")}, fnames) - if err != nil { - t.Fatalf("LoadFiles must not return an error, but got '%s'", err) - } - return spec - }, - }, - } - - for name, c := range cases { - c := c - t.Run(name, func(t *testing.T) { - t.Run("normal", func(t *testing.T) { - spec := c.newNormalSpec(t) - - t.Run("ServiceNames", func(t *testing.T) { - actualServiceNames := spec.ServiceNames() - expectedServiceNames := []string{"api.Example"} - if diff := cmp.Diff(expectedServiceNames, actualServiceNames); diff != "" { - t.Errorf("ServiceNames returned unexpected service names:\n%s", diff) - } - }) - - t.Run("RPCs", func(t *testing.T) { - _, err := spec.RPCs("") - if err != idl.ErrServiceUnselected { - t.Errorf("RPCs must return ErrServiceUnselected if svcName is empty, but got '%s'", err) - } - _, err = spec.RPCs("Foo") - if err != idl.ErrUnknownServiceName { - t.Errorf("RPCs must return ErrUnknownServiceName, but got '%s'", err) - } - - rpcs, err := spec.RPCs("api.Example") - if err != nil { - t.Fatalf("api.Example service must have an RPC, but couldn't get it: '%s'", err) - } - actualRPCNames := make([]string, len(rpcs)) - for i, rpc := range rpcs { - actualRPCNames[i] = rpc.Name - } - expectedRPCNames := []string{"RPC"} - if diff := cmp.Diff(expectedRPCNames, actualRPCNames); diff != "" { - t.Errorf("RPCs returned unexpected RPC names:\n%s", diff) - } - }) - - t.Run("RPC", func(t *testing.T) { - _, err := spec.RPC("", "") - if err != idl.ErrServiceUnselected { - t.Errorf("RPC must return ErrServiceUnselected if svcName is empty, but got '%s'", err) - } - _, err = spec.RPC("Foo", "") - if err != idl.ErrUnknownServiceName { - t.Errorf("RPC must return ErrUnknownServiceName, but got '%s'", err) - } - _, err = spec.RPC("api.Example", "") - if err != idl.ErrUnknownRPCName { - t.Errorf("RPC must return ErrUnknownRPCName if rpcName is empty, but got '%s'", err) - } - - actualRPC, err := spec.RPC("api.Example", "RPC") - if err != nil { - t.Fatalf("Example service must have an RPC named 'RPC', but couldn't get it: '%s'", err) - } - - const expectedFQRN = "api.Example.RPC" - if actualFQRN := actualRPC.FullyQualifiedName; actualFQRN != expectedFQRN { - t.Errorf("expected FullyQualifiedName is '%s', but got '%s'", expectedFQRN, actualFQRN) - } - }) - - t.Run("ResolveSymbol", func(t *testing.T) { - _, err := spec.ResolveSymbol("Foo") - if !errors.Is(err, idl.ErrUnknownSymbol) { - t.Fatalf("ResolveSymbol must return ErrUnknownSymbol because api.Foo is an undefined type, but got '%s'", err) - } - actual, err := spec.ResolveSymbol("api.Request") - if err != nil { - t.Fatalf("ResolveSymbol must return the descriptor of api.Request, but got an error: '%s'", err) - } - - if actual == nil { - t.Errorf("actual must not be nil") - } - }) - }) - - t.Run("empty package", func(t *testing.T) { - spec := c.newEmptyPackageSpec(t) - - t.Run("ServiceNames", func(t *testing.T) { - actualServiceNames := spec.ServiceNames() - expectedServiceNames := []string{"Example"} - if diff := cmp.Diff(expectedServiceNames, actualServiceNames); diff != "" { - t.Errorf("ServiceNames returned unexpected service names:\n%s", diff) - } - }) - - t.Run("RPCs", func(t *testing.T) { - _, err := spec.RPCs("") - if err != idl.ErrServiceUnselected { - t.Errorf("RPCs must return ErrServiceUnselected if svcName is empty, but got '%s'", err) - } - _, err = spec.RPCs("Foo") - if err != idl.ErrUnknownServiceName { - t.Errorf("RPCs must return ErrUnknownServiceName, but got '%s'", err) - } - rpcs, err := spec.RPCs("Example") - if err != nil { - t.Fatalf("RPCs must not return an error if pkgName is empty, but got '%s'", err) - } - - actualRPCNames := make([]string, len(rpcs)) - for i, rpc := range rpcs { - actualRPCNames[i] = rpc.Name - } - expectedRPCNames := []string{"RPC"} - if diff := cmp.Diff(expectedRPCNames, actualRPCNames); diff != "" { - t.Errorf("RPCs returned unexpected RPC names:\n%s", diff) - } - }) - - t.Run("RPC", func(t *testing.T) { - _, err := spec.RPC("Example", "") - if err != idl.ErrUnknownRPCName { - t.Errorf("RPC must return ErrUnknownRPCName if rpcName is empty, but got '%s'", err) - } - - actualRPC, err := spec.RPC("Example", "RPC") - if err != nil { - t.Fatalf("Example service must have an RPC named 'RPC', but couldn't get it: '%s'", err) - } - - const expectedFQRN = "Example.RPC" - if actualFQRN := actualRPC.FullyQualifiedName; actualFQRN != expectedFQRN { - t.Errorf("expected FullyQualifiedName is '%s', but got '%s'", expectedFQRN, actualFQRN) - } - }) - - t.Run("ResolveSymbol", func(t *testing.T) { - actual, err := spec.ResolveSymbol("Request") - if err != nil { - t.Fatalf("ResolveSymbol must return the descriptor of api.Request, but got an error: '%s'", err) - } - - if actual == nil { - t.Errorf("actual must not be nil") - } - }) - }) - }) - } -} diff --git a/idl/proto/proto.go b/idl/proto/proto.go index 4e5e98f2..dfc16e01 100644 --- a/idl/proto/proto.go +++ b/idl/proto/proto.go @@ -2,197 +2,9 @@ package proto import ( - "fmt" - "sort" "strings" - - "github.com/jhump/protoreflect/desc" - "github.com/jhump/protoreflect/desc/protoparse" - "github.com/jhump/protoreflect/desc/protoprint" - "github.com/jhump/protoreflect/dynamic" - "github.com/ktr0731/evans/grpc" - "github.com/ktr0731/evans/grpc/grpcreflection" - "github.com/ktr0731/evans/idl" - "github.com/pkg/errors" ) -type spec struct { - fileDescs []*desc.FileDescriptor - pkgNames []string - // Loaded service descriptors. - svcDescs []*desc.ServiceDescriptor - // key: fully qualified service name, val: method descriptors belong to the service. - rpcDescs map[string][]*desc.MethodDescriptor - // key: fully qualified message name, val: the message descriptor. - msgDescs map[string]*desc.MessageDescriptor -} - -func (s *spec) ServiceNames() []string { - svcNames := make([]string, len(s.svcDescs)) - for i, d := range s.svcDescs { - svcNames[i] = d.GetFullyQualifiedName() - } - return svcNames -} - -func (s *spec) RPCs(svcName string) ([]*grpc.RPC, error) { - if svcName == "" { - return nil, idl.ErrServiceUnselected - } - - rpcDescs, ok := s.rpcDescs[svcName] - if !ok { - return nil, idl.ErrUnknownServiceName - } - - rpcs := make([]*grpc.RPC, len(rpcDescs)) - for i, d := range rpcDescs { - rpc, err := s.RPC(svcName, d.GetName()) - if err != nil { - panic(fmt.Sprintf("RPC must not return an error, but got '%s'", err)) - } - rpcs[i] = rpc - } - return rpcs, nil -} - -func (s *spec) RPC(svcName, rpcName string) (*grpc.RPC, error) { - if svcName == "" { - return nil, idl.ErrServiceUnselected - } - - rpcDescs, ok := s.rpcDescs[svcName] - if !ok { - return nil, idl.ErrUnknownServiceName - } - - for _, d := range rpcDescs { - if d.GetName() == rpcName { - return &grpc.RPC{ - Name: d.GetName(), - FullyQualifiedName: d.GetFullyQualifiedName(), - RequestType: &grpc.Type{ - Name: d.GetInputType().GetName(), - FullyQualifiedName: d.GetInputType().GetFullyQualifiedName(), - New: func() interface{} { - return dynamic.NewMessage(d.GetInputType()) - }, - }, - ResponseType: &grpc.Type{ - Name: d.GetOutputType().GetName(), - FullyQualifiedName: d.GetOutputType().GetFullyQualifiedName(), - New: func() interface{} { - return dynamic.NewMessage(d.GetOutputType()) - }, - }, - IsServerStreaming: d.IsServerStreaming(), - IsClientStreaming: d.IsClientStreaming(), - }, nil - } - } - return nil, idl.ErrUnknownRPCName -} - -// ResolveSymbol returns the descriptor of the passed fully-qualified descriptor name. -// The actual type of the returned interface{} implements desc.Descriptor. -func (s *spec) ResolveSymbol(symbol string) (interface{}, error) { - for _, f := range s.fileDescs { - d := f.FindSymbol(symbol) - if d != nil { - return d, nil - } - } - return nil, idl.ErrUnknownSymbol -} - -// FormatDescriptor formats v as a Protocol Buffers descriptor type. -// If v doesn't implement desc.Descriptor, it returns an error. -func (s *spec) FormatDescriptor(v interface{}) (string, error) { - desc, ok := v.(desc.Descriptor) - if !ok { - return "", errors.New("v should be a desc.Descriptor") - } - p := &protoprint.Printer{ - Compact: true, - ForceFullyQualifiedNames: true, - SortElements: true, - } - str, err := p.PrintProtoToString(desc) - if err != nil { - return "", errors.Wrap(err, "failed to convert the descriptor to string") - } - return strings.TrimSpace(str), nil -} - -// LoadFiles receives proto file names and import paths like protoc's options. -// Then, LoadFiles parses these files and instantiates a new idl.Spec. -func LoadFiles(importPaths []string, fnames []string) (idl.Spec, error) { - p := &protoparse.Parser{ - ImportPaths: importPaths, - } - fileDescs, err := p.ParseFiles(fnames...) - if err != nil { - return nil, errors.Wrap(err, "proto: failed to parse passed proto files") - } - - // Collect dependency file descriptors - for _, d := range fileDescs { - fileDescs = append(fileDescs, d.GetDependencies()...) - } - - return newSpec(fileDescs), nil -} - -// LoadByReflection receives a gRPC reflection client, then tries to instantiate a new idl.Spec by using gRPC reflection. -func LoadByReflection(client grpcreflection.Client) (idl.Spec, error) { - fileDescs, err := client.ListPackages() - if err != nil { - return nil, errors.Wrap(err, "failed to list packages by gRPC reflection") - } - return newSpec(fileDescs), nil -} - -func newSpec(fds []*desc.FileDescriptor) idl.Spec { - var ( - encounteredPkgs = make(map[string]interface{}) - encounteredSvcs = make(map[string]interface{}) - pkgNames []string - svcDescs []*desc.ServiceDescriptor - rpcDescs = make(map[string][]*desc.MethodDescriptor) - msgDescs = make(map[string]*desc.MessageDescriptor) - ) - for _, f := range fds { - if _, encountered := encounteredPkgs[f.GetPackage()]; !encountered { - pkgNames = append(pkgNames, f.GetPackage()) - encounteredPkgs[f.GetPackage()] = nil - } - for _, svc := range f.GetServices() { - fqsn := svc.GetFullyQualifiedName() - if _, encountered := encounteredSvcs[fqsn]; !encountered { - svcDescs = append(svcDescs, svc) - encounteredSvcs[fqsn] = nil - } - rpcDescs[fqsn] = append(rpcDescs[fqsn], svc.GetMethods()...) - } - - for _, m := range f.GetMessageTypes() { - msgDescs[m.GetFullyQualifiedName()] = m - } - } - - sort.Slice(pkgNames, func(i, j int) bool { - return pkgNames[i] < pkgNames[j] - }) - - return &spec{ - fileDescs: fds, - pkgNames: pkgNames, - svcDescs: svcDescs, - rpcDescs: rpcDescs, - msgDescs: msgDescs, - } -} - // FullyQualifiedServiceName returns the fully-qualified service name. func FullyQualifiedServiceName(pkg, svc string) string { var s []string diff --git a/idl/proto/proto_test.go b/idl/proto/proto_test.go deleted file mode 100644 index 10232d0a..00000000 --- a/idl/proto/proto_test.go +++ /dev/null @@ -1,64 +0,0 @@ -package proto_test - -import ( - "errors" - "testing" - - "github.com/jhump/protoreflect/desc" - "github.com/ktr0731/evans/grpc/grpcreflection" - "github.com/ktr0731/evans/idl/proto" -) - -func TestLoadFiles(t *testing.T) { - cases := map[string]struct { - fnames []string - hasErr bool - }{ - "normal": {fnames: []string{"message.proto", "api.proto", "other_package.proto"}}, - "invalid proto": {fnames: []string{"invalid.proto"}, hasErr: true}, - } - - for name, c := range cases { - c := c - t.Run(name, func(t *testing.T) { - _, err := proto.LoadFiles([]string{"testdata"}, c.fnames) - if c.hasErr { - if err == nil { - t.Errorf("LoadFiles must return an error, but got nil") - } - return - } - if err != nil { - t.Errorf("LoadFiles must not return an error, but got '%s'", err) - } - }) - } -} - -type reflectionClient struct { - grpcreflection.Client - descs []*desc.FileDescriptor - err error -} - -func (c *reflectionClient) ListPackages() ([]*desc.FileDescriptor, error) { - return c.descs, c.err -} - -func TestLoadByReflection(t *testing.T) { - t.Run("normal", func(t *testing.T) { - refCli := &reflectionClient{} - _, err := proto.LoadByReflection(refCli) - if err != nil { - t.Errorf("must not return an error, but got '%s'", err) - } - }) - - t.Run("reflection client returns an error", func(t *testing.T) { - refCli := &reflectionClient{err: errors.New("an err")} - _, err := proto.LoadByReflection(refCli) - if err == nil { - t.Errorf("must return an error, but got nil") - } - }) -} diff --git a/mode/cli.go b/mode/cli.go index 1a62d381..994c322e 100644 --- a/mode/cli.go +++ b/mode/cli.go @@ -12,11 +12,10 @@ import ( "github.com/ktr0731/evans/format" "github.com/ktr0731/evans/format/curl" fmtjson "github.com/ktr0731/evans/format/json" - "github.com/ktr0731/evans/idl" - "github.com/ktr0731/evans/idl/proto" "github.com/ktr0731/evans/present" "github.com/ktr0731/evans/present/json" "github.com/ktr0731/evans/present/name" + "github.com/ktr0731/evans/proto" "github.com/ktr0731/evans/usecase" "github.com/ktr0731/go-multierror" "github.com/mattn/go-isatty" @@ -140,7 +139,7 @@ func NewListCLIInvoker(ui cui.UI, fqn, format string) CLIInvoker { return "", commonErr // Return commonErr because UsePackage will be deprecated. } - if err := usecase.UseService(svc); err != nil && errors.Is(err, idl.ErrUnknownServiceName) { + if err := usecase.UseService(svc); err != nil && errors.Is(err, usecase.ErrUnknownServiceName) { return "", commonErr } else if err != nil { return "", errors.Wrapf(err, "failed to use service '%s'", svc) @@ -193,7 +192,7 @@ func RunAsCLIMode(cfg *config.Config, invoker CLIInvoker) error { }() } - spec, err := newSpec(cfg, gRPCClient) + descSource, err := newDescSource(cfg, gRPCClient) if err != nil { injectResult = multierror.Append(injectResult, err) } @@ -204,8 +203,8 @@ func RunAsCLIMode(cfg *config.Config, invoker CLIInvoker) error { usecase.InjectPartially( usecase.Dependencies{ - Spec: spec, GRPCClient: gRPCClient, + DescSource: descSource, ResourcePresenter: json.NewPresenter(" "), }, ) diff --git a/mode/common.go b/mode/common.go index 7b1c3a5d..dfc7839e 100644 --- a/mode/common.go +++ b/mode/common.go @@ -7,26 +7,11 @@ import ( "github.com/ktr0731/evans/config" "github.com/ktr0731/evans/grpc" "github.com/ktr0731/evans/grpc/grpcreflection" - "github.com/ktr0731/evans/idl" - "github.com/ktr0731/evans/idl/proto" + "github.com/ktr0731/evans/proto" "github.com/ktr0731/evans/usecase" "github.com/pkg/errors" ) -func newSpec(cfg *config.Config, grpcClient grpcreflection.Client) (spec idl.Spec, err error) { - if cfg.Server.Reflection { - spec, err = proto.LoadByReflection(grpcClient) - } else { - spec, err = proto.LoadFiles(cfg.Default.ProtoPath, cfg.Default.ProtoFile) - } - if errors.Is(err, grpcreflection.ErrTLSHandshakeFailed) { - return nil, errors.New("TLS handshake failed. check whether client or server is misconfigured") - } else if err != nil { - return nil, errors.Wrap(err, "failed to instantiate the spec") - } - return spec, nil -} - func newGRPCClient(cfg *config.Config) (grpc.Client, error) { addr := fmt.Sprintf("%s:%s", cfg.Server.Host, cfg.Server.Port) if cfg.Request.Web { @@ -64,7 +49,12 @@ func gRPCReflectionPackageFilteredPackages(pkgNames []string) []string { func setDefault(cfg *config.Config) error { // If the spec has only one package, mark it as the default package. if cfg.Default.Package == "" { - pkgs := gRPCReflectionPackageFilteredPackages(usecase.ListPackages()) + got, err := usecase.ListPackages() + if err != nil { + return err + } + + pkgs := gRPCReflectionPackageFilteredPackages(got) if len(pkgs) == 1 { cfg.Default.Package = pkgs[0] } else { @@ -88,7 +78,14 @@ func setDefault(cfg *config.Config) error { // If the spec has only one service, mark it as the default service. if cfg.Default.Service == "" { - svcNames := usecase.ListServicesOld() + svcNames, err := usecase.ListServices() + if err != nil { + return err + } + + // Ignore server reflection name because it's provided imply when reflection is enabled. + svcNames = dropString(svcNames, "grpc.reflection.v1alpha.ServerReflection") + if len(svcNames) != 1 { return nil } @@ -104,3 +101,29 @@ func setDefault(cfg *config.Config) error { } return nil } + +func newDescSource(cfg *config.Config, grpcClient grpcreflection.Client) (descSource proto.DescriptorSource, err error) { + if cfg.Server.Reflection { + descSource = proto.NewDescriptorSourceFromReflection(grpcClient) + } else { + descSource, err = proto.NewDescriptorSourceFromFiles(cfg.Default.ProtoPath, cfg.Default.ProtoFile) + } + if errors.Is(err, grpcreflection.ErrTLSHandshakeFailed) { + return nil, errors.New("TLS handshake failed. check whether client or server is misconfigured") + } else if err != nil { + return nil, errors.Wrap(err, "failed to instantiate the spec") + } + + return +} + +func dropString(slice []string, s string) []string { + newSlice := make([]string, 0, len(slice)) + for _, e := range slice { + if e != s { + newSlice = append(newSlice, e) + } + } + + return newSlice +} diff --git a/mode/repl.go b/mode/repl.go index 9321d668..2e3e378b 100644 --- a/mode/repl.go +++ b/mode/repl.go @@ -23,16 +23,16 @@ func RunAsREPLMode(cfg *config.Config, ui cui.UI, cache *cache.Cache) error { } defer gRPCClient.Close(context.Background()) - spec, err := newSpec(cfg, gRPCClient) + descSource, err := newDescSource(cfg, gRPCClient) if err != nil { - return errors.Wrap(err, "failed to instantiate a new spec") + return errors.Wrap(err, "failed to instantiate a desc source") } usecase.Inject( usecase.Dependencies{ - Spec: spec, InteractiveFiller: proto.NewInteractiveFiller(prompt.New(), cfg.REPL.InputPromptFormat), GRPCClient: gRPCClient, + DescSource: descSource, ResourcePresenter: table.NewPresenter(), }, ) diff --git a/proto/descsource.go b/proto/descsource.go new file mode 100644 index 00000000..27489520 --- /dev/null +++ b/proto/descsource.go @@ -0,0 +1,79 @@ +package proto + +import ( + "context" + + "github.com/bufbuild/protocompile" + "github.com/bufbuild/protocompile/linker" + "github.com/pkg/errors" + "google.golang.org/protobuf/reflect/protoreflect" +) + +//go:generate moq -out mock.go . DescriptorSource +type DescriptorSource interface { + ListServices() ([]string, error) + FindSymbol(name string) (protoreflect.Descriptor, error) +} + +type reflection struct { + client interface { + ListServices() ([]string, error) + FindSymbol(name string) (protoreflect.Descriptor, error) + } +} + +func NewDescriptorSourceFromReflection(c interface { + ListServices() ([]string, error) + FindSymbol(name string) (protoreflect.Descriptor, error) +}) DescriptorSource { + return &reflection{c} +} + +func (r *reflection) ListServices() ([]string, error) { + return r.client.ListServices() +} + +func (r *reflection) FindSymbol(name string) (protoreflect.Descriptor, error) { + return r.client.FindSymbol(name) +} + +type files struct { + fds linker.Files +} + +func NewDescriptorSourceFromFiles(importPaths []string, fnames []string) (DescriptorSource, error) { + c := &protocompile.Compiler{ + Resolver: protocompile.WithStandardImports(&protocompile.SourceResolver{ + ImportPaths: importPaths, + }), + } + compiled, err := c.Compile(context.TODO(), fnames...) + if err != nil { + return nil, errors.Wrap(err, "proto: failed to compile proto files") + } + + return &files{fds: compiled}, nil +} + +var errSymbolNotFound = errors.New("proto: symbol not found") + +func (f *files) ListServices() ([]string, error) { + var services []string + for _, fd := range f.fds { + for i := 0; i < fd.Services().Len(); i++ { + services = append(services, string(fd.Services().Get(i).FullName())) + } + } + + return services, nil +} + +func (f *files) FindSymbol(name string) (protoreflect.Descriptor, error) { + for _, fd := range f.fds { + if d := fd.FindDescriptorByName(protoreflect.FullName(name)); d != nil { + return d, nil + } + } + + return nil, errors.Wrapf(errSymbolNotFound, "symbol %s", name) +} diff --git a/proto/helper.go b/proto/helper.go new file mode 100644 index 00000000..2b9e305c --- /dev/null +++ b/proto/helper.go @@ -0,0 +1,34 @@ +package proto + +import "strings" + +// FullyQualifiedServiceName returns the fully-qualified service name. +func FullyQualifiedServiceName(pkg, svc string) string { + var s []string + if pkg != "" { + s = []string{pkg, svc} + } else { + s = []string{svc} + } + return strings.Join(s, ".") +} + +// FullyQualifiedMessageName returns the fully-qualified message name. +func FullyQualifiedMessageName(pkg, msg string) string { + var s []string + if pkg != "" { + s = []string{pkg, msg} + } else { + s = []string{msg} + } + return strings.Join(s, ".") +} + +// ParseFullyQualifiedServiceName returns the package and service name from a fully-qualified service name. +func ParseFullyQualifiedServiceName(fqsn string) (string, string) { + i := strings.LastIndex(fqsn, ".") + if i == -1 { + return "", fqsn + } + return fqsn[:i], fqsn[i+1:] +} diff --git a/proto/mock.go b/proto/mock.go new file mode 100644 index 00000000..3602c961 --- /dev/null +++ b/proto/mock.go @@ -0,0 +1,112 @@ +// Code generated by moq; DO NOT EDIT. +// github.com/matryer/moq + +package proto + +import ( + "google.golang.org/protobuf/reflect/protoreflect" + "sync" +) + +// Ensure, that DescriptorSourceMock does implement DescriptorSource. +// If this is not the case, regenerate this file with moq. +var _ DescriptorSource = &DescriptorSourceMock{} + +// DescriptorSourceMock is a mock implementation of DescriptorSource. +// +// func TestSomethingThatUsesDescriptorSource(t *testing.T) { +// +// // make and configure a mocked DescriptorSource +// mockedDescriptorSource := &DescriptorSourceMock{ +// FindSymbolFunc: func(name string) (protoreflect.Descriptor, error) { +// panic("mock out the FindSymbol method") +// }, +// ListServicesFunc: func() ([]string, error) { +// panic("mock out the ListServices method") +// }, +// } +// +// // use mockedDescriptorSource in code that requires DescriptorSource +// // and then make assertions. +// +// } +type DescriptorSourceMock struct { + // FindSymbolFunc mocks the FindSymbol method. + FindSymbolFunc func(name string) (protoreflect.Descriptor, error) + + // ListServicesFunc mocks the ListServices method. + ListServicesFunc func() ([]string, error) + + // calls tracks calls to the methods. + calls struct { + // FindSymbol holds details about calls to the FindSymbol method. + FindSymbol []struct { + // Name is the name argument value. + Name string + } + // ListServices holds details about calls to the ListServices method. + ListServices []struct { + } + } + lockFindSymbol sync.RWMutex + lockListServices sync.RWMutex +} + +// FindSymbol calls FindSymbolFunc. +func (mock *DescriptorSourceMock) FindSymbol(name string) (protoreflect.Descriptor, error) { + if mock.FindSymbolFunc == nil { + panic("DescriptorSourceMock.FindSymbolFunc: method is nil but DescriptorSource.FindSymbol was just called") + } + callInfo := struct { + Name string + }{ + Name: name, + } + mock.lockFindSymbol.Lock() + mock.calls.FindSymbol = append(mock.calls.FindSymbol, callInfo) + mock.lockFindSymbol.Unlock() + return mock.FindSymbolFunc(name) +} + +// FindSymbolCalls gets all the calls that were made to FindSymbol. +// Check the length with: +// +// len(mockedDescriptorSource.FindSymbolCalls()) +func (mock *DescriptorSourceMock) FindSymbolCalls() []struct { + Name string +} { + var calls []struct { + Name string + } + mock.lockFindSymbol.RLock() + calls = mock.calls.FindSymbol + mock.lockFindSymbol.RUnlock() + return calls +} + +// ListServices calls ListServicesFunc. +func (mock *DescriptorSourceMock) ListServices() ([]string, error) { + if mock.ListServicesFunc == nil { + panic("DescriptorSourceMock.ListServicesFunc: method is nil but DescriptorSource.ListServices was just called") + } + callInfo := struct { + }{} + mock.lockListServices.Lock() + mock.calls.ListServices = append(mock.calls.ListServices, callInfo) + mock.lockListServices.Unlock() + return mock.ListServicesFunc() +} + +// ListServicesCalls gets all the calls that were made to ListServices. +// Check the length with: +// +// len(mockedDescriptorSource.ListServicesCalls()) +func (mock *DescriptorSourceMock) ListServicesCalls() []struct { +} { + var calls []struct { + } + mock.lockListServices.RLock() + calls = mock.calls.ListServices + mock.lockListServices.RUnlock() + return calls +} diff --git a/proto/registry.go b/proto/registry.go new file mode 100644 index 00000000..ff2cc77d --- /dev/null +++ b/proto/registry.go @@ -0,0 +1,59 @@ +package proto + +import ( + "strings" + + "github.com/pkg/errors" + "google.golang.org/protobuf/reflect/protoreflect" + "google.golang.org/protobuf/reflect/protoregistry" + "google.golang.org/protobuf/types/dynamicpb" +) + +type anyResolver struct { + protoregistry.ExtensionTypeResolver + descSource DescriptorSource +} + +func NewAnyResolver(descSource DescriptorSource) interface { + protoregistry.ExtensionTypeResolver + protoregistry.MessageTypeResolver +} { + return &anyResolver{ + descSource: descSource, + } +} + +func (r *anyResolver) FindMessageByName(m protoreflect.FullName) (protoreflect.MessageType, error) { + d, err := r.descSource.FindSymbol(string(m)) + if err != nil { + return nil, err + } + if errors.Is(err, errSymbolNotFound) { + // Fallback to protoregistry.GlobalTypes. + return protoregistry.GlobalTypes.FindMessageByName(m) + } + + md := d.(protoreflect.MessageDescriptor) // TODO: handle "ok". + + return dynamicpb.NewMessageType(md), nil +} + +func (r *anyResolver) FindMessageByURL(url string) (protoreflect.MessageType, error) { + n := strings.LastIndex(url, "/") + if n != -1 { + url = url[n+1:] + } + + d, err := r.descSource.FindSymbol(url) + if err != nil && !errors.Is(err, errSymbolNotFound) { + return nil, err + } + if errors.Is(err, errSymbolNotFound) { + // Fallback to protoregistry.GlobalTypes. + return protoregistry.GlobalTypes.FindMessageByURL(url) + } + + md := d.(protoreflect.MessageDescriptor) // TODO: handle "ok". + + return dynamicpb.NewMessageType(md), nil +} diff --git a/repl/commands.go b/repl/commands.go index 2109a96a..92a85396 100644 --- a/repl/commands.go +++ b/repl/commands.go @@ -64,7 +64,7 @@ func (c *packageCommand) Validate(args []string) error { func (c *packageCommand) Run(_ io.Writer, args []string) error { pkgName := args[0] err := usecase.UsePackage(pkgName) - if errors.Is(err, idl.ErrUnknownPackageName) { + if errors.Is(err, usecase.ErrUnknownPackageName) { return errors.Errorf("unknown package name '%s'", args[0]) } return err @@ -132,7 +132,7 @@ func (c *showCommand) Run(w io.Writer, args []string) error { case "p", "package", "packages": f = usecase.FormatPackages case "s", "svc", "service", "services": - f = usecase.FormatServicesOld + f = usecase.FormatServices case "m", "msg", "message", "messages": f = usecase.FormatMessages case "a", "r", "rpc", "api": diff --git a/repl/completer.go b/repl/completer.go index 943ef2d4..9a541091 100644 --- a/repl/completer.go +++ b/repl/completer.go @@ -90,7 +90,11 @@ func newCompleter(cmds map[string]commander) *completer { }, "package": func(args []string) (s []*prompt.Suggest) { if len(args) == 1 { - pkgs := usecase.ListPackages() + pkgs, err := usecase.ListPackages() + if err != nil { + return + } + for _, pkg := range pkgs { if pkg == "" { s = append(s, prompt.NewSuggestion(`''`, "default for package name unspecified protos")) @@ -103,7 +107,12 @@ func newCompleter(cmds map[string]commander) *completer { }, "service": func(args []string) (s []*prompt.Suggest) { if len(args) == 1 { - for _, svc := range usecase.ListServicesOld() { + svcs, err := usecase.ListServices() + if err != nil { + return + } + + for _, svc := range svcs { s = append(s, prompt.NewSuggestion(svc, "")) } } @@ -126,8 +135,13 @@ func newCompleter(cmds map[string]commander) *completer { return nil } + svcs, err := usecase.ListServices() + if err != nil { + return + } + encountered := make(map[string]interface{}) - for _, svc := range usecase.ListServicesOld() { + for _, svc := range svcs { rpcs, err := usecase.ListRPCs(svc) if err != nil { panic(fmt.Sprintf("ListRPCs must not return an error, but got '%s'", err)) diff --git a/repl/completer_test.go b/repl/completer_test.go index 320b9bd0..1087cd2e 100644 --- a/repl/completer_test.go +++ b/repl/completer_test.go @@ -4,7 +4,7 @@ import ( "strings" "testing" - "github.com/ktr0731/evans/idl/proto" + "github.com/ktr0731/evans/proto" "github.com/ktr0731/evans/usecase" ) @@ -23,11 +23,12 @@ func (d *dummyDocument) TextBeforeCursor() string { func TestCompleter(t *testing.T) { cmpl := newCompleter(commands) - spec, err := proto.LoadFiles([]string{"testdata"}, []string{"test.proto"}) + descSource, err := proto.NewDescriptorSourceFromFiles([]string{"testdata"}, []string{"test.proto"}) if err != nil { - t.Fatalf("LoadFiles must not return an error, but got '%s'", err) + t.Fatal(err) } - usecase.Inject(usecase.Dependencies{Spec: spec}) + usecase.Inject(usecase.Dependencies{DescSource: descSource}) + err = usecase.UsePackage("api") if err != nil { t.Fatalf("UsePackage must not return an error, but got '%s'", err) diff --git a/repl/proto_commands.go b/repl/proto_commands.go index 35ae4733..366f1d63 100644 --- a/repl/proto_commands.go +++ b/repl/proto_commands.go @@ -6,12 +6,11 @@ import ( "sort" "strconv" - "github.com/golang/protobuf/protoc-gen-go/descriptor" - "github.com/jhump/protoreflect/desc" "github.com/ktr0731/evans/usecase" "github.com/olekukonko/tablewriter" "github.com/pkg/errors" "github.com/spf13/pflag" + "google.golang.org/protobuf/reflect/protoreflect" ) type descCommand struct{} @@ -43,13 +42,14 @@ func (c *descCommand) Run(w io.Writer, args []string) error { table := tablewriter.NewWriter(w) table.SetHeader([]string{"field", "type", "repeated"}) - fields := td.(*desc.MessageDescriptor).GetFields() - rows := make([][]string, len(fields)) - for i, field := range fields { + fields := td.(protoreflect.MessageDescriptor).Fields() + rows := make([][]string, fields.Len()) + for i := 0; i < fields.Len(); i++ { + field := fields.Get(i) rows[i] = []string{ - field.GetName(), + string(field.Name()), presentTypeName(field), - strconv.FormatBool(field.IsRepeated() && !field.IsMap()), + strconv.FormatBool(field.IsList() && !field.IsMap()), } } @@ -62,21 +62,21 @@ func (c *descCommand) Run(w io.Writer, args []string) error { return nil } -func presentTypeName(f *desc.FieldDescriptor) string { - typeName := f.GetType().String() +func presentTypeName(f protoreflect.FieldDescriptor) string { + typeName := f.Kind().String() - switch f.GetType() { - case descriptor.FieldDescriptorProto_TYPE_MESSAGE: + switch f.Kind() { + case protoreflect.MessageKind: if f.IsMap() { typeName = fmt.Sprintf( "map<%s, %s>", - presentTypeName(f.GetMapKeyType()), - presentTypeName(f.GetMapValueType())) + presentTypeName(f.MapKey()), + presentTypeName(f.MapValue())) } else { - typeName += fmt.Sprintf(" (%s)", f.GetMessageType().GetName()) + typeName += fmt.Sprintf(" (%s)", f.Message().Name()) } - case descriptor.FieldDescriptorProto_TYPE_ENUM: - typeName += fmt.Sprintf(" (%s)", f.GetEnumType().GetName()) + case protoreflect.EnumKind: + typeName += fmt.Sprintf(" (%s)", f.Enum().Name()) } return typeName } diff --git a/repl/repl_test.go b/repl/repl_test.go index 8c7e1384..a82dc27b 100644 --- a/repl/repl_test.go +++ b/repl/repl_test.go @@ -11,8 +11,8 @@ import ( "github.com/google/go-cmp/cmp" "github.com/ktr0731/evans/config" "github.com/ktr0731/evans/cui" - "github.com/ktr0731/evans/grpc" "github.com/ktr0731/evans/prompt" + "github.com/ktr0731/evans/proto" "github.com/ktr0731/evans/usecase" ) @@ -87,7 +87,6 @@ func TestREPL_makePrefix(t *testing.T) { cases := map[string]struct { pkgName string svcName string - RPCsErr error hasErr bool expected string @@ -107,16 +106,11 @@ func TestREPL_makePrefix(t *testing.T) { REPL: &config.REPL{}, Server: &config.Server{Host: "127.0.0.1", Port: "50051"}, } - dummySpec := &SpecMock{ - ServiceNamesFunc: func() []string { - return []string{"api.Example"} - }, - RPCsFunc: func(svcName string) ([]*grpc.RPC, error) { - return nil, c.RPCsErr - }, + dummyDescSource := &proto.DescriptorSourceMock{ + ListServicesFunc: func() ([]string, error) { return []string{"api.Example"}, nil }, } t.Run(name, func(t *testing.T) { - usecase.Inject(usecase.Dependencies{Spec: dummySpec}) + usecase.Inject(usecase.Dependencies{DescSource: dummyDescSource}) r, err := New(dummyCfg, prompt.New(), nil, c.pkgName, c.svcName) if c.hasErr { diff --git a/usecase/call_rpc.go b/usecase/call_rpc.go index ebce079e..c0193290 100644 --- a/usecase/call_rpc.go +++ b/usecase/call_rpc.go @@ -2,14 +2,13 @@ package usecase import ( "context" + "fmt" "io" "strings" "sync" "time" - "github.com/jhump/protoreflect/dynamic" - "github.com/ktr0731/evans/grpc" - "github.com/ktr0731/evans/idl/proto" + pb "github.com/ktr0731/evans/proto" "github.com/ktr0731/evans/fill" "github.com/ktr0731/evans/logger" @@ -19,6 +18,9 @@ import ( "google.golang.org/grpc/codes" "google.golang.org/grpc/metadata" "google.golang.org/grpc/status" + "google.golang.org/protobuf/proto" + "google.golang.org/protobuf/reflect/protoreflect" + "google.golang.org/protobuf/types/dynamicpb" ) // ErrorCode represents an application error code. @@ -52,16 +54,19 @@ func CallRPC(ctx context.Context, w io.Writer, rpcName string) error { return dm.CallRPC(ctx, w, rpcName, false, dm.filler) } func (m *dependencyManager) CallRPC(ctx context.Context, w io.Writer, rpcName string, rerunPrevious bool, filler fill.Filler) error { - fqsn := proto.FullyQualifiedServiceName(m.state.selectedPackage, m.state.selectedService) - rpc, err := m.spec.RPC(fqsn, rpcName) + fqsn := pb.FullyQualifiedServiceName(m.state.selectedPackage, m.state.selectedService) + d, err := m.descSource.FindSymbol(fmt.Sprintf("%s.%s", fqsn, rpcName)) if err != nil { return errors.Wrapf(err, "failed to get the RPC descriptor for: %s", rpcName) } - if rerunPrevious && rpc.IsClientStreaming { + + rpc := d.(protoreflect.MethodDescriptor) // TODO: handle "ok". + + if rerunPrevious && rpc.IsStreamingClient() { return errors.New("cannot rerun previous RPC as client/bidi streaming RPCs are not supported") } - newRequest := func() (interface{}, error) { - req := rpc.RequestType.New() + newRequest := func() (*dynamicpb.Message, error) { + req := dynamicpb.NewMessage(rpc.Input()) if !rerunPrevious { err = filler.Fill(req) if errors.Is(err, io.EOF) { @@ -70,18 +75,18 @@ func (m *dependencyManager) CallRPC(ctx context.Context, w io.Writer, rpcName st if err != nil { return nil, err } - if err := m.updateRPCCallState(rpcName, req); err != nil { + if err := m.updateMethodCallState(rpcName, req); err != nil { return nil, err } return req, nil } - if err = m.getPreviousRPCRequest(rpc, req.(*dynamic.Message)); err != nil { + if err = m.getPreviousRPCRequest(rpc, req); err != nil { return nil, err } return req, nil } newResponse := func() interface{} { - return rpc.ResponseType.New() + return dynamicpb.NewMessage(rpc.Output()) } flushHeader := func(header metadata.MD) { m.responseFormatter.FormatHeader(header) @@ -139,20 +144,20 @@ func (m *dependencyManager) CallRPC(ctx context.Context, w io.Writer, rpcName st } streamDesc := &gogrpc.StreamDesc{ - StreamName: rpc.Name, - ServerStreams: rpc.IsServerStreaming, - ClientStreams: rpc.IsClientStreaming, + StreamName: string(rpc.Name()), + ServerStreams: rpc.IsStreamingServer(), + ClientStreams: rpc.IsStreamingClient(), } switch { - case rpc.IsClientStreaming && rpc.IsServerStreaming: + case rpc.IsStreamingClient() && rpc.IsStreamingServer(): ctx, cancel, err := enhanceContext(ctx) if err != nil { cancel() return errors.Wrap(err, "failed to enhance context with metadata") } - stream, err := m.gRPCClient.NewBidiStream(ctx, streamDesc, rpc.FullyQualifiedName) + stream, err := m.gRPCClient.NewBidiStream(ctx, streamDesc, string(rpc.FullName())) if err != nil { cancel() return errors.Wrapf(err, "failed to create a bidi stream for RPC '%s'", streamDesc.StreamName) @@ -256,14 +261,14 @@ func (m *dependencyManager) CallRPC(ctx context.Context, w io.Writer, rpcName st // 5. Send the close message and receive the response. // 6. Format the response and output it. // - case rpc.IsClientStreaming: + case rpc.IsStreamingClient(): ctx, cancel, err := enhanceContext(ctx) if err != nil { cancel() return errors.Wrap(err, "failed to enhance context with metadata") } - stream, err := m.gRPCClient.NewClientStream(ctx, streamDesc, rpc.FullyQualifiedName) + stream, err := m.gRPCClient.NewClientStream(ctx, streamDesc, string(rpc.FullName())) if err != nil { cancel() return errors.Wrapf(err, "failed to create a new client stream for RPC '%s'", streamDesc.StreamName) @@ -317,7 +322,7 @@ func (m *dependencyManager) CallRPC(ctx context.Context, w io.Writer, rpcName st // 4. Format a received response and output it. // 5. If io.EOF received, finish the RPC connection. // - case rpc.IsServerStreaming: + case rpc.IsStreamingServer(): req, err := newRequest() if err != nil { return err @@ -329,7 +334,7 @@ func (m *dependencyManager) CallRPC(ctx context.Context, w io.Writer, rpcName st return errors.Wrap(err, "failed to enhance context with metadata") } - stream, err := m.gRPCClient.NewServerStream(ctx, streamDesc, rpc.FullyQualifiedName) + stream, err := m.gRPCClient.NewServerStream(ctx, streamDesc, string(rpc.FullName())) if err != nil { cancel() return errors.Wrapf(err, "failed to create a new server stream for RPC '%s'", streamDesc.StreamName) @@ -387,7 +392,7 @@ func (m *dependencyManager) CallRPC(ctx context.Context, w io.Writer, rpcName st } } - // If both of rpc.IsServerStreaming and rpc.IsClientStreaming are nil, it means its RPC is an unary RPC. + // If both of rpc.IsStreamingClient() and rpc.IsStreamingServer() are false, it means its RPC is an unary RPC. // Unary RPCs are processed by the following instruction. // // 1. Create a new request and fill input to it. @@ -408,7 +413,7 @@ func (m *dependencyManager) CallRPC(ctx context.Context, w io.Writer, rpcName st } res := newResponse() - header, trailer, err := m.gRPCClient.Invoke(ctx, rpc.FullyQualifiedName, req, res) + header, trailer, err := m.gRPCClient.Invoke(ctx, string(rpc.FullName()), req, res) stat, err := handleGRPCResponseError(err) if err != nil { cancel() @@ -431,33 +436,32 @@ func (m *dependencyManager) CallRPC(ctx context.Context, w io.Writer, rpcName st } } -// Gets a request with the body containing the payload of its previous RPC +// Gets a request with the body containing the payload of its previous method // Only RPCs that are repeatable by definition will have their previous requests returned. -func (m *dependencyManager) getPreviousRPCRequest(rpc *grpc.RPC, req *dynamic.Message) error { - id := rpcIdentifier(rpc.Name) +func (m *dependencyManager) getPreviousRPCRequest(method protoreflect.MethodDescriptor, req proto.Message) error { + id := rpcIdentifier(string(method.FullName())) if _, ok := m.state.rpcCallState[id]; !ok { - return errors.Errorf("no previous request exists for RPC: %s, please issue a normal request", id) + return errors.Errorf("no previous request exists for method: %s, please issue a normal request", id) } - if rpc.IsClientStreaming { - return errors.Errorf("cannot rerun previous RPC: %s as client/bidi streaming RPCs are not supported", id) + if method.IsStreamingClient() { + return errors.Errorf("cannot rerun previous method: %s as client/bidi streaming RPCs are not supported", id) } previousReqBytes := m.state.rpcCallState[id].requestPayload if previousReqBytes == nil { - return errors.Errorf("no previous request body exists for RPC: %s, please issue a normal request", id) + return errors.Errorf("no previous request body exists for method: %s, please issue a normal request", id) } - err := req.Unmarshal(previousReqBytes) + err := proto.Unmarshal(previousReqBytes, req) // TODO: Custom resolver. if err != nil { - return errors.Wrapf(err, "error while unmarshalling request for RPC: %s, please run without the --repeat option", id) + return errors.Wrapf(err, "error while unmarshalling request for method: %s, please run without the --repeat option", id) } return nil } -// Updates the last call state for the given RPC. This is done by serializing +// Updates the last call state for the given method. This is done by serializing // the request payload and store it into the state buffer indexed by the rpcName -// The RPC is repeatable only if it is not a client streaming rpc. -func (m *dependencyManager) updateRPCCallState(rpcName string, req interface{}) error { - message := req.(*dynamic.Message) - reqBytes, err := message.Marshal() +// The method is repeatable only if it is not a client streaming method. +func (m *dependencyManager) updateMethodCallState(rpcName string, req *dynamicpb.Message) error { + reqBytes, err := proto.Marshal(req) if err != nil { return err } @@ -471,10 +475,10 @@ func (m *dependencyManager) updateRPCCallState(rpcName string, req interface{}) } type interactiveFiller struct { - fillFunc func(v interface{}) error + fillFunc func(v *dynamicpb.Message) error } -func (f *interactiveFiller) Fill(v interface{}) error { +func (f *interactiveFiller) Fill(v *dynamicpb.Message) error { return f.fillFunc(v) } @@ -484,7 +488,7 @@ func CallRPCInteractively(ctx context.Context, w io.Writer, rpcName string, digM func (m *dependencyManager) CallRPCInteractively(ctx context.Context, w io.Writer, rpcName string, digManually, bytesAsBase64, bytesAsQuotedLiterals, bytesFromFile, rerunPrevious, addRepeatedManually bool) error { return m.CallRPC(ctx, w, rpcName, rerunPrevious, &interactiveFiller{ - fillFunc: func(v interface{}) error { + fillFunc: func(v *dynamicpb.Message) error { return m.interactiveFiller.Fill(v, fill.InteractiveFillerOpts{ DigManually: digManually, BytesAsBase64: bytesAsBase64, diff --git a/usecase/call_rpc_test.go b/usecase/call_rpc_test.go index 4a0de972..848eb65c 100644 --- a/usecase/call_rpc_test.go +++ b/usecase/call_rpc_test.go @@ -3,30 +3,29 @@ package usecase import ( "testing" - "github.com/ktr0731/evans/grpc" - - "github.com/jhump/protoreflect/dynamic" + "google.golang.org/protobuf/proto" + "google.golang.org/protobuf/reflect/protoreflect" ) func TestGetPreviousRPCRequest(t *testing.T) { cases := map[string]struct { expectedError string - rpc *grpc.RPC + method protoreflect.MethodDescriptor rpcCallState map[rpcIdentifier]callState }{ "no previous request exists": { - rpc: getStubRPC(false), - expectedError: "no previous request exists for RPC: TestRPC, please issue a normal request", + method: getStubMethod(false), + expectedError: "no previous request exists for method: TestRPC, please issue a normal request", }, "previous request is client streaming": { rpcCallState: map[rpcIdentifier]callState{"TestRPC": {}}, - rpc: getStubRPC(true), - expectedError: "cannot rerun previous RPC: TestRPC as client/bidi streaming RPCs are not supported", + method: getStubMethod(true), + expectedError: "cannot rerun previous method: TestRPC as client/bidi streaming RPCs are not supported", }, "previous request bytes are nil": { rpcCallState: map[rpcIdentifier]callState{"TestRPC": {}}, - rpc: getStubRPC(false), - expectedError: "no previous request body exists for RPC: TestRPC, please issue a normal request", + method: getStubMethod(false), + expectedError: "no previous request body exists for method: TestRPC, please issue a normal request", }, } for name, c := range cases { @@ -37,8 +36,8 @@ func TestGetPreviousRPCRequest(t *testing.T) { rpcCallState: c.rpcCallState, }, } - var req *dynamic.Message - err := d.getPreviousRPCRequest(c.rpc, req) + var req proto.Message + err := d.getPreviousRPCRequest(c.method, req) if err == nil || err.Error() != c.expectedError { t.Errorf("expected error %s, but got %s", c.expectedError, err) } @@ -46,10 +45,18 @@ func TestGetPreviousRPCRequest(t *testing.T) { } } -func getStubRPC(clientStreaming bool) *grpc.RPC { - return &grpc.RPC{ - Name: "TestRPC", - IsServerStreaming: true, - IsClientStreaming: clientStreaming, +type stubMethod struct { + protoreflect.MethodDescriptor + + isStreamingClient bool +} + +func (m *stubMethod) FullName() protoreflect.FullName { return protoreflect.FullName("TestRPC") } +func (m *stubMethod) IsStreamingClient() bool { return m.isStreamingClient } +func (m *stubMethod) IsStreamingServer() bool { return true } + +func getStubMethod(clientStreaming bool) *stubMethod { + return &stubMethod{ + isStreamingClient: clientStreaming, } } diff --git a/usecase/format_descriptor.go b/usecase/format_descriptor.go index f01012bb..b8565d1f 100644 --- a/usecase/format_descriptor.go +++ b/usecase/format_descriptor.go @@ -2,8 +2,12 @@ package usecase import ( "fmt" + "strings" + "github.com/jhump/protoreflect/desc" + "github.com/jhump/protoreflect/desc/protoprint" "github.com/pkg/errors" + "google.golang.org/protobuf/reflect/protodesc" ) // FormatDescriptor formats the descriptor of the passed symbol. @@ -11,13 +15,32 @@ func FormatDescriptor(symbol string) (string, error) { return dm.FormatDescriptor(symbol) } func (m *dependencyManager) FormatDescriptor(symbol string) (string, error) { - v, err := m.spec.ResolveSymbol(symbol) + d, err := m.descSource.FindSymbol(symbol) if err != nil { return "", errors.Wrapf(err, "failed to resolve symbol '%s'", symbol) } - out, err := m.spec.FormatDescriptor(v) + + fd, err := desc.CreateFileDescriptor(protodesc.ToFileDescriptorProto(d.ParentFile())) if err != nil { - return "", errors.Wrapf(err, "failed to format the descriptor of symbol '%s'", symbol) + return "", err + } + + jd := fd.FindSymbol(symbol) + if jd == nil { + return "", errors.New("symbol should be found") + } + + p := &protoprint.Printer{ + Compact: true, + ForceFullyQualifiedNames: true, + SortElements: true, } + str, err := p.PrintProtoToString(jd) + if err != nil { + return "", errors.Wrap(err, "failed to convert the descriptor to string") + } + + out := strings.TrimSpace(str) + return fmt.Sprintf("%s:\n%s", symbol, out), nil } diff --git a/usecase/format_messages.go b/usecase/format_messages.go index 1240ca89..6d37d790 100644 --- a/usecase/format_messages.go +++ b/usecase/format_messages.go @@ -11,7 +11,11 @@ func FormatMessages() (string, error) { return dm.FormatMessages() } func (m *dependencyManager) FormatMessages() (string, error) { - svcs := m.ListServicesOld() + svcs, err := m.ListServices() + if err != nil { + return "", err + } + type message struct { Message string `json:"message"` } diff --git a/usecase/format_method.go b/usecase/format_method.go index 13632133..6eb9417b 100644 --- a/usecase/format_method.go +++ b/usecase/format_method.go @@ -1,6 +1,8 @@ package usecase import ( + "fmt" + "github.com/pkg/errors" ) @@ -27,5 +29,5 @@ func (m *dependencyManager) FormatMethod(fqmn string) (string, error) { } return out, nil } - return "", errors.New("method is not found") + return "", fmt.Errorf("method is not found: %s", fqmn) } diff --git a/usecase/format_methods.go b/usecase/format_methods.go index 241f6894..67ba2199 100644 --- a/usecase/format_methods.go +++ b/usecase/format_methods.go @@ -3,7 +3,7 @@ package usecase import ( "sort" - "github.com/ktr0731/evans/idl/proto" + "github.com/ktr0731/evans/proto" "github.com/pkg/errors" ) diff --git a/usecase/format_packages.go b/usecase/format_packages.go index 5556d312..2a82fb59 100644 --- a/usecase/format_packages.go +++ b/usecase/format_packages.go @@ -11,7 +11,11 @@ func FormatPackages() (string, error) { return dm.FormatPackages() } func (m *dependencyManager) FormatPackages() (string, error) { - pkgs := m.ListPackages() + pkgs, err := m.ListPackages() + if err != nil { + return "", err + } + type pkg struct { Package string `json:"package"` } diff --git a/usecase/format_service_descriptors.go b/usecase/format_service_descriptors.go index 61f011b1..0bb59c4a 100644 --- a/usecase/format_service_descriptors.go +++ b/usecase/format_service_descriptors.go @@ -11,7 +11,11 @@ func FormatServiceDescriptors() (string, error) { return dm.FormatServiceDescriptors() } func (m *dependencyManager) FormatServiceDescriptors() (string, error) { - svcs := ListServices() + svcs, err := m.ListServices() + if err != nil { + return "", err + } + out := make([]string, 0, len(svcs)) for _, s := range svcs { o, err := FormatDescriptor(s) diff --git a/usecase/format_services.go b/usecase/format_services.go index 7c40998f..c74e2f34 100644 --- a/usecase/format_services.go +++ b/usecase/format_services.go @@ -11,7 +11,11 @@ func FormatServices() (string, error) { return dm.FormatServices() } func (m *dependencyManager) FormatServices() (string, error) { - fqsns := m.ListServices() + fqsns, err := m.ListServices() + if err != nil { + return "", err + } + type svc struct { Name string `json:"name" name:"target"` } diff --git a/usecase/format_services_old.go b/usecase/format_services_old.go deleted file mode 100644 index ea56cd36..00000000 --- a/usecase/format_services_old.go +++ /dev/null @@ -1,47 +0,0 @@ -package usecase - -import ( - "sort" - - "github.com/pkg/errors" -) - -// FormatServicesOld formats all package names. -// Deprecated: dropped in the next major release. -func FormatServicesOld() (string, error) { - return dm.FormatServicesOld() -} -func (m *dependencyManager) FormatServicesOld() (string, error) { - svcs := m.ListServicesOld() - type service struct { - Service string `json:"service"` - RPC string `json:"rpc"` - RequestType string `json:"request type" table:"request type"` - ResponseType string `json:"response type" table:"response type"` - } - var v struct { - Services []service `json:"services"` - } - for _, svc := range svcs { - rpcs, err := m.ListRPCs(svc) - if err != nil { - return "", errors.Wrapf(err, "failed to list RPCs associated with '%s'", svc) - } - for _, rpc := range rpcs { - v.Services = append(v.Services, service{ - svc, - rpc.Name, - rpc.RequestType.Name, - rpc.ResponseType.Name, - }) - } - } - sort.Slice(v.Services, func(i, j int) bool { - return v.Services[i].Service < v.Services[j].Service - }) - out, err := m.resourcePresenter.Format(v) - if err != nil { - return "", errors.Wrap(err, "failed to format service names by presenter") - } - return out, nil -} diff --git a/usecase/get_type_descriptor.go b/usecase/get_type_descriptor.go index 0f561909..211cd585 100644 --- a/usecase/get_type_descriptor.go +++ b/usecase/get_type_descriptor.go @@ -1,18 +1,26 @@ package usecase import ( - "github.com/ktr0731/evans/idl/proto" + "strings" + + "github.com/ktr0731/evans/proto" "github.com/pkg/errors" + "google.golang.org/protobuf/reflect/protoreflect" ) // GetTypeDescriptor gets the descriptor of a type which belongs to the currently selected package. -func GetTypeDescriptor(typeName string) (interface{}, error) { +func GetTypeDescriptor(typeName string) (protoreflect.Descriptor, error) { return dm.GetTypeDescriptor(typeName) } -func (m *dependencyManager) GetTypeDescriptor(typeName string) (interface{}, error) { +func (m *dependencyManager) GetTypeDescriptor(typeName string) (protoreflect.Descriptor, error) { pkgName := m.state.selectedPackage - fqmn := proto.FullyQualifiedMessageName(pkgName, typeName) - d, err := m.spec.ResolveSymbol(fqmn) + + fqmn := typeName + if !strings.HasPrefix(typeName, pkgName+".") { + fqmn = proto.FullyQualifiedMessageName(pkgName, typeName) + } + + d, err := m.descSource.FindSymbol(fqmn) if err != nil { return nil, errors.Wrapf(err, "failed to get the type descriptor of '%s'", typeName) } diff --git a/usecase/list_packages.go b/usecase/list_packages.go index 46976270..6ef404f1 100644 --- a/usecase/list_packages.go +++ b/usecase/list_packages.go @@ -3,30 +3,33 @@ package usecase import ( "sort" - "github.com/ktr0731/evans/idl/proto" + "github.com/ktr0731/evans/proto" ) // ListPackages lists all package names. -func ListPackages() []string { +func ListPackages() ([]string, error) { return dm.ListPackages() } -func (m *dependencyManager) ListPackages() []string { - svcNames := m.spec.ServiceNames() - encountered := make(map[string]interface{}) - toPackageName := func(svcName string) string { - pkg, _ := proto.ParseFullyQualifiedServiceName(svcName) - return pkg +func (m *dependencyManager) ListPackages() ([]string, error) { + pkgMap := map[string]struct{}{} + svcs, err := m.descSource.ListServices() + if err != nil { + return nil, err } - for _, svc := range svcNames { - encountered[toPackageName(svc)] = nil + + for _, s := range svcs { + pkg, _ := proto.ParseFullyQualifiedServiceName(s) + pkgMap[pkg] = struct{}{} } - pkgs := make([]string, 0, len(svcNames)) - for pkg := range encountered { + + pkgs := make([]string, 0, len(pkgMap)) + for pkg := range pkgMap { pkgs = append(pkgs, pkg) } sort.Slice(pkgs, func(i, j int) bool { return pkgs[i] < pkgs[j] }) - return pkgs + + return pkgs, nil } diff --git a/usecase/list_rpcs.go b/usecase/list_rpcs.go index e3d0330e..4bedcbaf 100644 --- a/usecase/list_rpcs.go +++ b/usecase/list_rpcs.go @@ -2,8 +2,10 @@ package usecase import ( "github.com/ktr0731/evans/grpc" - "github.com/ktr0731/evans/idl/proto" + "github.com/ktr0731/evans/proto" "github.com/pkg/errors" + "google.golang.org/protobuf/reflect/protoreflect" + "google.golang.org/protobuf/types/dynamicpb" ) // ListRPCs lists all RPC belong to the selected service. @@ -21,9 +23,43 @@ func (m *dependencyManager) ListRPCs(svcName string) ([]*grpc.RPC, error) { } func (m *dependencyManager) listRPCs(fqsn string) ([]*grpc.RPC, error) { - rpcs, err := m.spec.RPCs(fqsn) + var rpcs []*grpc.RPC + svcs, err := m.descSource.ListServices() if err != nil { - return nil, errors.Wrap(err, "failed to list RPCs") + return nil, err } + + for _, service := range svcs { + d, err := m.descSource.FindSymbol(service) + if err != nil { + return nil, errors.Wrapf(err, "failed to resolve service %s", service) + } + + sd := d.(protoreflect.ServiceDescriptor) // TODO: handle "ok". + for i := 0; i < sd.Methods().Len(); i++ { + md := sd.Methods().Get(i) + rpcs = append(rpcs, &grpc.RPC{ + Name: string(md.Name()), + FullyQualifiedName: string(md.FullName()), + RequestType: &grpc.Type{ + Name: string(md.Input().Name()), + FullyQualifiedName: string(md.Input().FullName()), + New: func() interface{} { + return dynamicpb.NewMessageType(md.Input()) + }, + }, + ResponseType: &grpc.Type{ + Name: string(md.Output().Name()), + FullyQualifiedName: string(md.Output().FullName()), + New: func() interface{} { + return dynamicpb.NewMessageType(md.Output()) + }, + }, + IsServerStreaming: md.IsStreamingServer(), + IsClientStreaming: md.IsStreamingClient(), + }) + } + } + return rpcs, nil } diff --git a/usecase/list_services.go b/usecase/list_services.go index 34c8d3e3..f4c965c8 100644 --- a/usecase/list_services.go +++ b/usecase/list_services.go @@ -1,13 +1,13 @@ package usecase // ListServices returns the loaded fully-qualified service names. -func ListServices() []string { +func ListServices() ([]string, error) { return dm.ListServices() } -func (m *dependencyManager) ListServices() []string { +func (m *dependencyManager) ListServices() ([]string, error) { return m.listServices() } -func (m *dependencyManager) listServices() []string { - return m.spec.ServiceNames() +func (m *dependencyManager) listServices() ([]string, error) { + return m.descSource.ListServices() } diff --git a/usecase/list_services_old.go b/usecase/list_services_old.go deleted file mode 100644 index 1d736435..00000000 --- a/usecase/list_services_old.go +++ /dev/null @@ -1,26 +0,0 @@ -package usecase - -import ( - "github.com/ktr0731/evans/idl/proto" -) - -// ListServicesOld returns the services belong to the selected package. -// The returned service names are NOT fully-qualified. -func ListServicesOld() []string { - return dm.ListServicesOld() -} -func (m *dependencyManager) ListServicesOld() []string { - return m.listServicesOld(m.state.selectedPackage) -} - -func (m *dependencyManager) listServicesOld(pkgName string) []string { - var svcs []string - svcNames := m.spec.ServiceNames() - for i := range svcNames { - pkg, svc := proto.ParseFullyQualifiedServiceName(svcNames[i]) - if pkg == pkgName { - svcs = append(svcs, svc) - } - } - return svcs -} diff --git a/usecase/parse_method.go b/usecase/parse_method.go index de0553d6..3ff8146c 100644 --- a/usecase/parse_method.go +++ b/usecase/parse_method.go @@ -1,8 +1,10 @@ package usecase import ( - "errors" "strings" + + "github.com/pkg/errors" + "google.golang.org/protobuf/reflect/protoreflect" ) // ParseFullyQualifiedMethodName parses the passed fully-qualified method as fully-qualified service name and method name. @@ -19,7 +21,14 @@ func (m *dependencyManager) ParseFullyQualifiedMethodName(fqmn string) (string, if i == -1 { return "", "", errors.New("invalid fully-qualified method name") } + v, err := m.descSource.FindSymbol(fqmn) + if err != nil { + return "", "", errors.Wrap(err, "failed to find the symbol") + } + if _, ok := v.(protoreflect.MethodDescriptor); !ok { + return "", "", errors.New("symbol is not method descriptor") + } + svc, mtd := fqmn[:i], fqmn[i+1:] - _, err := m.spec.RPC(svc, mtd) - return svc, mtd, err + return svc, mtd, nil } diff --git a/usecase/use_package.go b/usecase/use_package.go index cb6c8211..ef82b2aa 100644 --- a/usecase/use_package.go +++ b/usecase/use_package.go @@ -1,22 +1,25 @@ package usecase -import "github.com/ktr0731/evans/idl" - // UsePackage modifies pkgName as the currently selected package. // UsePackage may return these errors: // -// - idl.ErrUnknownPackageName: pkgName is not in loaded packages. +// - ErrUnknownPackageName: pkgName is not in loaded packages. // func UsePackage(pkgName string) error { return dm.UsePackage(pkgName) } func (m *dependencyManager) UsePackage(pkgName string) error { - for _, pkg := range ListPackages() { + pkgs, err := ListPackages() + if err != nil { + return err + } + + for _, pkg := range pkgs { if pkg == pkgName { m.state.selectedPackage = pkgName m.state.selectedService = "" return nil } } - return idl.ErrUnknownPackageName + return ErrUnknownPackageName } diff --git a/usecase/use_service.go b/usecase/use_service.go index 920b25cd..61a21f5e 100644 --- a/usecase/use_service.go +++ b/usecase/use_service.go @@ -1,8 +1,7 @@ package usecase import ( - "github.com/ktr0731/evans/idl" - "github.com/ktr0731/evans/idl/proto" + "github.com/ktr0731/evans/proto" "github.com/pkg/errors" ) @@ -19,20 +18,28 @@ func (m *dependencyManager) UseService(svcName string) error { if svcName == "" { return errors.Errorf("invalid service name '%s'", svcName) } + + fqsns, err := m.descSource.ListServices() + if err != nil { + return err + } + var hasPackage bool - for _, fqsn := range m.spec.ServiceNames() { + for _, fqsn := range fqsns { pkg, svc := proto.ParseFullyQualifiedServiceName(fqsn) if m.state.selectedPackage == pkg { hasPackage = true - if svcName == svc { - m.state.selectedService = svcName + // Keep backward-compatibility. + // TODO: Delete package related code after releasing v1.0.0. + if svcName == svc || svcName == fqsn { + m.state.selectedService = svc return nil } } } if hasPackage { - return idl.ErrUnknownServiceName + return ErrUnknownServiceName } // In the case of empty package. - return idl.ErrPackageUnselected + return ErrPackageUnselected } diff --git a/usecase/usecase.go b/usecase/usecase.go index c8a60393..8fcdb6cb 100644 --- a/usecase/usecase.go +++ b/usecase/usecase.go @@ -6,8 +6,19 @@ import ( "github.com/ktr0731/evans/fill" "github.com/ktr0731/evans/format" "github.com/ktr0731/evans/grpc" - "github.com/ktr0731/evans/idl" "github.com/ktr0731/evans/present" + "github.com/ktr0731/evans/proto" + "github.com/pkg/errors" +) + +var ( + ErrPackageUnselected = errors.New("package unselected") + ErrServiceUnselected = errors.New("service unselected") + + ErrUnknownPackageName = errors.New("unknown package name") + ErrUnknownServiceName = errors.New("unknown service name") + ErrUnknownRPCName = errors.New("unknown RPC name") + ErrUnknownSymbol = errors.New("unknown symbol") ) var ( @@ -16,7 +27,7 @@ var ( ) type dependencyManager struct { - spec idl.Spec + descSource proto.DescriptorSource filler fill.Filler interactiveFiller fill.InteractiveFiller gRPCClient grpc.Client @@ -39,7 +50,7 @@ type callState struct { } type Dependencies struct { - Spec idl.Spec + DescSource proto.DescriptorSource Filler fill.Filler InteractiveFiller fill.InteractiveFiller GRPCClient grpc.Client @@ -54,7 +65,7 @@ func Inject(deps Dependencies) { func (m *dependencyManager) Inject(d Dependencies) { dm = &dependencyManager{ - spec: d.Spec, + descSource: d.DescSource, filler: d.Filler, interactiveFiller: d.InteractiveFiller, gRPCClient: d.GRPCClient, @@ -71,8 +82,8 @@ func InjectPartially(deps Dependencies) { } func (m *dependencyManager) InjectPartially(d Dependencies) { - if d.Spec != nil { - m.spec = d.Spec + if d.DescSource != nil { + m.descSource = d.DescSource } if d.Filler != nil { m.filler = d.Filler