/
protodesc.go
135 lines (115 loc) · 3.69 KB
/
protodesc.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
package protodesc
import (
"fmt"
"io/ioutil"
"path/filepath"
"strings"
"github.com/golang/protobuf/proto"
"github.com/golang/protobuf/protoc-gen-go/descriptor"
"github.com/jhump/protoreflect/desc"
"github.com/jhump/protoreflect/desc/protoparse"
)
// GetMethodDescFromProto gets the method descriptor for the given call symbol
// from the proto file at path proto.
//
// call names the method, e.g. "package.Service/Method" or
// "package.Service.Method". imports is passed to the parser as its import
// paths. If proto is an absolute path, only its base name is handed to the
// parser (NOTE(review): presumably its directory must then be reachable via
// imports — confirm against protoparse's resolution rules).
func GetMethodDescFromProto(call, proto string, imports []string) (*desc.MethodDescriptor, error) {
	p := &protoparse.Parser{ImportPaths: imports}

	filename := proto
	if filepath.IsAbs(filename) {
		filename = filepath.Base(proto)
	}

	fds, err := p.ParseFiles(filename)
	if err != nil {
		return nil, err
	}
	// Defensive: never index into an empty result.
	if len(fds) == 0 {
		return nil, fmt.Errorf("could not parse given file %q", proto)
	}
	fileDesc := fds[0]

	files := map[string]*desc.FileDescriptor{fileDesc.GetName(): fileDesc}
	return getMethodDesc(call, files)
}
// GetMethodDescFromProtoSet gets the method descriptor for the given call
// symbol from the protoset file at path protoset.
//
// The file is read and unmarshaled as a FileDescriptorSet. Every file in the
// set is then built together with its dependencies, which must also be present
// in the set. Finally the call symbol is looked up across all of the built
// files.
func GetMethodDescFromProtoSet(call, protoset string) (*desc.MethodDescriptor, error) {
	b, err := ioutil.ReadFile(protoset)
	if err != nil {
		return nil, fmt.Errorf("could not load protoset file %q: %v", protoset, err)
	}

	var fds descriptor.FileDescriptorSet
	if err := proto.Unmarshal(b, &fds); err != nil {
		return nil, fmt.Errorf("could not parse contents of protoset file %q: %v", protoset, err)
	}

	// Index the raw descriptors by file name so dependencies can be found while resolving.
	unresolved := make(map[string]*descriptor.FileDescriptorProto, len(fds.File))
	for _, fd := range fds.File {
		unresolved[fd.GetName()] = fd
	}

	// resolved is filled in by resolveFileDescriptor and shared across
	// iterations, so shared dependencies are built only once.
	resolved := make(map[string]*desc.FileDescriptor, len(fds.File))
	for _, fd := range fds.File {
		if _, err := resolveFileDescriptor(unresolved, resolved, fd.GetName()); err != nil {
			return nil, err
		}
	}

	return getMethodDesc(call, resolved)
}
// getMethodDesc finds the method named by call among the given file
// descriptors.
//
// call is split into a service and a method name by parseSymbol. An error is
// returned in any of these cases:
//   - either part is empty;
//   - no file declares the service symbol;
//   - the symbol found is not a service;
//   - the service has no method with that name.
func getMethodDesc(call string, files map[string]*desc.FileDescriptor) (*desc.MethodDescriptor, error) {
	serviceName, methodName := parseSymbol(call)
	if serviceName == "" || methodName == "" {
		return nil, fmt.Errorf("given method name %q is not in expected format: 'service/method' or 'service.method'", call)
	}

	symbol, err := findServiceSymbol(files, serviceName)
	if err != nil {
		return nil, err
	}

	// A nil symbol fails the type assertion too, so a missing symbol and a
	// non-service symbol are reported identically.
	serviceDesc, isService := symbol.(*desc.ServiceDescriptor)
	if !isService {
		return nil, fmt.Errorf("cannot find service %q", serviceName)
	}

	methodDesc := serviceDesc.FindMethodByName(methodName)
	if methodDesc == nil {
		return nil, fmt.Errorf("service %q does not include a method named %q", serviceName, methodName)
	}
	return methodDesc, nil
}
// resolveFileDescriptor builds the *desc.FileDescriptor for filename. It first
// resolves, recursively, every dependency listed by that file.
//
// unresolved holds the raw descriptor protos keyed by file name. resolved
// caches descriptors that have already been built and is updated in place, so
// each file is built at most once. An error is returned in any of these cases:
//   - filename or a transitive dependency is missing from unresolved;
//   - the dependencies form an import cycle;
//   - desc.CreateFileDescriptor returns an error.
func resolveFileDescriptor(unresolved map[string]*descriptor.FileDescriptorProto, resolved map[string]*desc.FileDescriptor, filename string) (*desc.FileDescriptor, error) {
	return resolveFileDescriptorDeps(unresolved, resolved, map[string]bool{}, filename)
}

// resolveFileDescriptorDeps is the recursive worker for resolveFileDescriptor.
// inProgress marks the files on the current resolution path so that cycles are
// reported instead of recursing forever.
func resolveFileDescriptorDeps(unresolved map[string]*descriptor.FileDescriptorProto, resolved map[string]*desc.FileDescriptor, inProgress map[string]bool, filename string) (*desc.FileDescriptor, error) {
	if r, ok := resolved[filename]; ok {
		return r, nil
	}
	fd, ok := unresolved[filename]
	if !ok {
		return nil, fmt.Errorf("no descriptor found for %q", filename)
	}
	// Reaching a file that is already on the current path means it imports
	// itself transitively; without this check the recursion never terminates.
	if inProgress[filename] {
		return nil, fmt.Errorf("import cycle detected involving %q", filename)
	}
	inProgress[filename] = true
	defer delete(inProgress, filename)

	deps := make([]*desc.FileDescriptor, 0, len(fd.GetDependency()))
	for _, dep := range fd.GetDependency() {
		depFd, err := resolveFileDescriptorDeps(unresolved, resolved, inProgress, dep)
		if err != nil {
			return nil, err
		}
		deps = append(deps, depFd)
	}

	result, err := desc.CreateFileDescriptor(fd, deps...)
	if err != nil {
		return nil, err
	}
	resolved[filename] = result
	return result, nil
}
// findServiceSymbol searches each file descriptor in resolved for a symbol
// with the given fully-qualified name. It returns the first match, following
// map iteration order.
//
// The match may be any kind of descriptor; callers must check whether it is
// actually a service. If no file declares the symbol, a
// "cannot find service" error is returned.
func findServiceSymbol(resolved map[string]*desc.FileDescriptor, fullyQualifiedName string) (desc.Descriptor, error) {
	var found desc.Descriptor
	for _, fileDesc := range resolved {
		found = fileDesc.FindSymbol(fullyQualifiedName)
		if found != nil {
			break
		}
	}
	if found == nil {
		return nil, fmt.Errorf("cannot find service %q", fullyQualifiedName)
	}
	return found, nil
}
// parseSymbol splits a call symbol into its service and method parts.
//
// A single leading "/" is ignored, so gRPC full method names such as
// "/package.Service/Method" are accepted. The split happens at the last "/"
// if one is present, otherwise at the last ".". For example, both
// "package.Service/Method" and "package.Service.Method" yield
// ("package.Service", "Method"). If neither separator is present, both
// results are empty. Either part may also come back empty, e.g. "Service/"
// yields ("Service", ""); callers are expected to reject that.
func parseSymbol(svcAndMethod string) (string, string) {
	svcAndMethod = strings.TrimPrefix(svcAndMethod, "/")
	pos := strings.LastIndex(svcAndMethod, "/")
	if pos < 0 {
		pos = strings.LastIndex(svcAndMethod, ".")
		if pos < 0 {
			return "", ""
		}
	}
	return svcAndMethod[:pos], svcAndMethod[pos+1:]
}