-
-
Notifications
You must be signed in to change notification settings - Fork 265
/
protodesc.go
202 lines (174 loc) · 5.75 KB
/
protodesc.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
package protodesc
import (
"errors"
"fmt"
"os"
"path/filepath"
"strings"
"github.com/golang/protobuf/proto"
"github.com/golang/protobuf/protoc-gen-go/descriptor"
"github.com/jhump/protoreflect/desc"
"github.com/jhump/protoreflect/desc/protoparse"
"github.com/jhump/protoreflect/grpcreflect"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
)
var (
	// errNoMethodNameSpecified is returned by parseServiceMethod when the call
	// symbol is empty, or is nothing but a single leading ".".
	errNoMethodNameSpecified = errors.New("no method name specified")
)
// GetMethodDescFromProto parses the proto file at path proto and returns the
// descriptor of the method named by call.
//
// imports is passed to the parser as its import search path. When proto is an
// absolute path, only its base name is handed to the parser, so the file's
// directory must be reachable through imports (NOTE(review): presumably callers
// add filepath.Dir(proto) to imports — confirm).
//
// Only the parsed file itself is searched for the service; services declared
// in files it imports are not considered.
func GetMethodDescFromProto(call, proto string, imports []string) (*desc.MethodDescriptor, error) {
	name := proto
	if filepath.IsAbs(proto) {
		name = filepath.Base(proto)
	}

	parser := protoparse.Parser{ImportPaths: imports}
	parsed, err := parser.ParseFiles(name)
	if err != nil {
		return nil, err
	}

	// A single filename was requested, so on success the first entry is it.
	fd := parsed[0]
	return getMethodDesc(call, map[string]*desc.FileDescriptor{fd.GetName(): fd})
}
// GetMethodDescFromProtoSet reads the protoset file at path protoset and
// returns the descriptor of the method named by call.
//
// Failures to read the file, or to decode its contents, are wrapped with %w so
// callers can inspect the underlying cause with errors.Is / errors.As (for
// example errors.Is(err, fs.ErrNotExist) for a missing file). Error message
// text is unchanged.
func GetMethodDescFromProtoSet(call, protoset string) (*desc.MethodDescriptor, error) {
	b, err := os.ReadFile(protoset)
	if err != nil {
		return nil, fmt.Errorf("could not load protoset file %q: %w", protoset, err)
	}

	res, err := GetMethodDescFromProtoSetBinary(call, b)
	if err != nil {
		// GetMethodDescFromProtoSetBinary only knows it was given bytes; when the
		// failure was a decode error, re-wrap it so the message names the file.
		if strings.Contains(err.Error(), "could not parse contents of protoset binary") {
			return nil, fmt.Errorf("could not parse contents of protoset file %q: %w", protoset, err)
		}
		return nil, err
	}
	return res, nil
}
// GetMethodDescFromProtoSetBinary decodes b as a serialized FileDescriptorSet
// and returns the descriptor of the method named by call.
//
// Every file in the set is linked against its dependencies, and each
// dependency must itself be present in the set; otherwise an error is
// returned. The decode error is wrapped with %w so callers can inspect it.
func GetMethodDescFromProtoSetBinary(call string, b []byte) (*desc.MethodDescriptor, error) {
	var fds descriptor.FileDescriptorSet
	if err := proto.Unmarshal(b, &fds); err != nil {
		// GetMethodDescFromProtoSet matches on this message text; keep it stable.
		return nil, fmt.Errorf("could not parse contents of protoset binary: %w", err)
	}

	unresolved := make(map[string]*descriptor.FileDescriptorProto, len(fds.File))
	for _, fd := range fds.File {
		unresolved[fd.GetName()] = fd
	}

	// resolveFileDescriptor memoizes into resolved, so files already linked as a
	// dependency of an earlier entry are not rebuilt.
	resolved := make(map[string]*desc.FileDescriptor, len(fds.File))
	for _, fd := range fds.File {
		if _, err := resolveFileDescriptor(unresolved, resolved, fd.GetName()); err != nil {
			return nil, err
		}
	}
	return getMethodDesc(call, resolved)
}
// GetMethodDescFromReflect resolves the method named by call through the gRPC
// server reflection API using client.
//
// call may be given as "package.Service.Method" or "package.Service/Method";
// the slash form is normalized to dots before the reflection lookup. An
// Unimplemented status from the server is reported as missing reflection
// support.
func GetMethodDescFromReflect(call string, client *grpcreflect.Client) (*desc.MethodDescriptor, error) {
	call = strings.ReplaceAll(call, "/", ".")

	file, err := client.FileContainingSymbol(call)
	if err != nil {
		return nil, reflectionSupport(err)
	}
	if file == nil {
		// Previously a nil file with a nil error fell through reflectionSupport(nil)
		// and was returned to the caller as (nil, nil).
		return nil, fmt.Errorf("no file descriptor found for symbol %q", call)
	}

	files := map[string]*desc.FileDescriptor{file.GetName(): file}
	return getMethodDesc(call, files)
}
// getMethodDesc looks up the method named by call among files.
//
// call is split into a service name and a method name by parseServiceMethod.
// The service name must resolve, in one of files, to a service descriptor, and
// that service must declare the method.
func getMethodDesc(call string, files map[string]*desc.FileDescriptor) (*desc.MethodDescriptor, error) {
	svcName, methodName, err := parseServiceMethod(call)
	if err != nil {
		return nil, err
	}

	symbol, err := findServiceSymbol(files, svcName)
	if err != nil {
		return nil, err
	}

	// A nil symbol and a symbol that is not a service both fail this assertion.
	service, isService := symbol.(*desc.ServiceDescriptor)
	if !isService {
		return nil, fmt.Errorf("cannot find service %q", svcName)
	}

	if method := service.FindMethodByName(methodName); method != nil {
		return method, nil
	}
	return nil, fmt.Errorf("service %q does not include a method named %q", svcName, methodName)
}
// resolveFileDescriptor builds the linked descriptor for filename from the raw
// descriptor protos in unresolved, recursively resolving its dependencies
// first.
//
// Results are memoized in resolved, which the function both reads and writes.
// An error is returned if filename, or any file it transitively depends on,
// is missing from unresolved. An error is also returned if the dependency
// graph contains a cycle. Without that check, a malformed protoset whose files
// import each other would recurse until the stack overflows.
func resolveFileDescriptor(unresolved map[string]*descriptor.FileDescriptorProto, resolved map[string]*desc.FileDescriptor, filename string) (*desc.FileDescriptor, error) {
	return resolveFileDescriptorTracked(unresolved, resolved, map[string]bool{}, filename)
}

// resolveFileDescriptorTracked does the work for resolveFileDescriptor.
// inProgress holds the files on the current recursion path, used to detect
// import cycles.
func resolveFileDescriptorTracked(unresolved map[string]*descriptor.FileDescriptorProto, resolved map[string]*desc.FileDescriptor, inProgress map[string]bool, filename string) (*desc.FileDescriptor, error) {
	if r, ok := resolved[filename]; ok {
		return r, nil
	}
	if inProgress[filename] {
		return nil, fmt.Errorf("import cycle detected involving %q", filename)
	}

	fd, ok := unresolved[filename]
	if !ok {
		return nil, fmt.Errorf("no descriptor found for %q", filename)
	}

	inProgress[filename] = true
	defer delete(inProgress, filename)

	deps := make([]*desc.FileDescriptor, 0, len(fd.GetDependency()))
	for _, dep := range fd.GetDependency() {
		depFd, err := resolveFileDescriptorTracked(unresolved, resolved, inProgress, dep)
		if err != nil {
			return nil, err
		}
		deps = append(deps, depFd)
	}

	result, err := desc.CreateFileDescriptor(fd, deps...)
	if err != nil {
		return nil, err
	}
	resolved[filename] = result
	return result, nil
}
// findServiceSymbol searches every file in resolved for the symbol named
// fullyQualifiedName and returns the first match encountered.
//
// Map iteration order is unspecified, so when several files define the symbol,
// which one wins is not deterministic. The result is a generic
// desc.Descriptor, and callers must check that it is actually a service.
func findServiceSymbol(resolved map[string]*desc.FileDescriptor, fullyQualifiedName string) (desc.Descriptor, error) {
	for _, file := range resolved {
		found := file.FindSymbol(fullyQualifiedName)
		if found == nil {
			continue
		}
		return found, nil
	}
	return nil, fmt.Errorf("cannot find service %q", fullyQualifiedName)
}
// parseServiceMethod parses the fully-qualified service name without a leading "."
// and the method name from the input string.
//
// valid inputs:
//
//	package.Service.Method
//	.package.Service.Method
//	package.Service/Method
//	.package.Service/Method
//
// An empty input (or one that is only ".") yields errNoMethodNameSpecified.
// Input with more than one "/", no separator at all, or an empty service or
// method component (e.g. "package.Service/", "/Method") yields the
// invalid-method-name error.
func parseServiceMethod(svcAndMethod string) (string, string, error) {
	// A single leading "." (fully-qualified protobuf notation) is accepted and dropped.
	if strings.HasPrefix(svcAndMethod, ".") {
		svcAndMethod = svcAndMethod[1:]
	}
	if len(svcAndMethod) == 0 {
		return "", "", errNoMethodNameSpecified
	}

	var svc, mth string
	switch strings.Count(svcAndMethod, "/") {
	case 0:
		// Dotted form: the method is everything after the last ".".
		pos := strings.LastIndex(svcAndMethod, ".")
		if pos < 0 {
			return "", "", newInvalidMethodNameError(svcAndMethod)
		}
		svc, mth = svcAndMethod[:pos], svcAndMethod[pos+1:]
	case 1:
		pos := strings.IndexByte(svcAndMethod, '/')
		svc, mth = svcAndMethod[:pos], svcAndMethod[pos+1:]
	default:
		return "", "", newInvalidMethodNameError(svcAndMethod)
	}

	// Reject empty components here rather than letting them surface later as a
	// confusing "cannot find service" / "no method named" lookup failure.
	if svc == "" || mth == "" {
		return "", "", newInvalidMethodNameError(svcAndMethod)
	}
	return svc, mth, nil
}
// newInvalidMethodNameError builds the error returned for a call symbol that
// matches neither the "package.Service.Method" nor the "package.Service/Method"
// form. The offending input is quoted in the message.
func newInvalidMethodNameError(svcAndMethod string) error {
	const format = "method name must be package.Service.Method or package.Service/Method: %q"
	return fmt.Errorf(format, svcAndMethod)
}
// reflectionSupport translates an error from a reflection request into a
// user-facing error.
//
// A nil err stays nil. An error carrying a gRPC status with code Unimplemented
// is replaced by a message saying the server does not support reflection. Any
// other error is returned unchanged.
func reflectionSupport(err error) error {
	if err == nil {
		return nil
	}
	stat, isStatus := status.FromError(err)
	if !isStatus || stat.Code() != codes.Unimplemented {
		return err
	}
	return errors.New("server does not support the reflection API")
}