/
load.go
264 lines (235 loc) · 8.1 KB
/
load.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
package desc
import (
"fmt"
"reflect"
"sync"
"github.com/golang/protobuf/proto"
"google.golang.org/protobuf/reflect/protoreflect"
"google.golang.org/protobuf/reflect/protoregistry"
"google.golang.org/protobuf/types/descriptorpb"
"github.com/jhump/protoreflect/desc/sourceinfo"
"github.com/jhump/protoreflect/internal"
)
// Package-level caches of wrapped descriptors. They hold descriptors that wrap
// items from protoregistry.GlobalTypes and protoregistry.GlobalFiles so that an
// underlying global descriptor does not have to be wrapped again every time it
// is requested.
var (
	// loadedDescriptors holds every wrapped file and message descriptor.
	loadedDescriptors = lockingCache{cache: mapCache{}}
	// Enums are cached separately, keyed by their Go type, for compatibility
	// with older APIs that registered enums in a different way.
	// loadedEnumsMu guards all access to loadedEnums.
	loadedEnumsMu sync.RWMutex
	loadedEnums   = make(map[reflect.Type]*EnumDescriptor)
)
// LoadFileDescriptor returns the descriptor for the file registered under the
// given path in sourceinfo.GlobalFiles. If that path is not found, a known
// legacy alias for it (internal.StdFileAliases) is tried instead, for
// compatibility with older libraries that registered files under incorrect or
// non-canonical paths. Wrapped descriptors are cached so that the same file is
// not re-processed if it is fetched again later.
//
// If neither the path nor its alias is registered, the returned error is
// internal.ErrNoSuchFile(file); any other lookup error is returned as is.
func LoadFileDescriptor(file string) (*FileDescriptor, error) {
	d, err := sourceinfo.GlobalFiles.FindFileByPath(file)
	if err == protoregistry.NotFound {
		// for backwards compatibility, see if this matches a known old
		// alias for the file (older versions of libraries that registered
		// the files using incorrect/non-canonical paths)
		if alt := internal.StdFileAliases[file]; alt != "" {
			d, err = sourceinfo.GlobalFiles.FindFileByPath(alt)
		}
	}
	if err != nil {
		if err == protoregistry.NotFound {
			// Translate "not found" into this package's error, naming the
			// path the caller asked for rather than any alias we tried.
			return nil, internal.ErrNoSuchFile(file)
		}
		return nil, err
	}
	if fd := loadedDescriptors.get(d); fd != nil {
		return fd.(*FileDescriptor), nil
	}
	var fd *FileDescriptor
	loadedDescriptors.withLock(func(cache descriptorCache) {
		// double-check cache, in case it was concurrently added while
		// we were waiting for the lock
		f := cache.get(d)
		if f != nil {
			fd = f.(*FileDescriptor)
			return
		}
		fd, err = wrapFile(d, cache)
	})
	return fd, err
}
// LoadMessageDescriptor returns the descriptor for the message type with the
// given fully-qualified name, as registered in sourceinfo.GlobalTypes. If no
// such message type is registered, it returns a nil descriptor and a nil
// error; other lookup errors are returned unchanged.
func LoadMessageDescriptor(message string) (*MessageDescriptor, error) {
	name := protoreflect.FullName(message)
	msgType, err := sourceinfo.GlobalTypes.FindMessageByName(name)
	switch {
	case err == protoregistry.NotFound:
		return nil, nil
	case err != nil:
		return nil, err
	}
	return loadMessageDescriptor(msgType.Descriptor())
}
// loadMessageDescriptor returns the cached wrapper for md. On a cache miss it
// builds one with wrapMessage while holding the loadedDescriptors lock, and
// returns any error wrapMessage reports.
func loadMessageDescriptor(md protoreflect.MessageDescriptor) (*MessageDescriptor, error) {
	cached := loadedDescriptors.get(md)
	if cached == nil {
		var wrapErr error
		loadedDescriptors.withLock(func(cache descriptorCache) {
			cached, wrapErr = wrapMessage(md, cache)
		})
		if wrapErr != nil {
			return nil, wrapErr
		}
	}
	return cached.(*MessageDescriptor), nil
}
// LoadMessageDescriptorForType returns the descriptor for the given Go message
// type. The type is first turned into a (zero) message value via
// messageFromType, whose error is returned unchanged. That value is then
// passed to LoadMessageDescriptorForMessage.
func LoadMessageDescriptorForType(messageType reflect.Type) (*MessageDescriptor, error) {
	msg, convErr := messageFromType(messageType)
	if convErr == nil {
		return LoadMessageDescriptorForMessage(msg)
	}
	return nil, convErr
}
// LoadMessageDescriptorForMessage returns the descriptor for the given message.
//
// Messages that can report their own *MessageDescriptor (those with a
// GetMessageDescriptor method, such as dynamic messages) return it directly.
// Otherwise the message's protoreflect descriptor is obtained through
// ProtoReflect when available, falling back to proto.MessageReflect. That
// descriptor is wrapped with sourceinfo.WrapMessage and loaded through the
// shared descriptor cache.
func LoadMessageDescriptorForMessage(message proto.Message) (*MessageDescriptor, error) {
	// dynamic messages already carry their descriptor; skip the lookup
	type descriptorable interface {
		GetMessageDescriptor() *MessageDescriptor
	}
	if dm, isDescriptorable := message.(descriptorable); isDescriptorable {
		return dm.GetMessageDescriptor(), nil
	}
	var md protoreflect.MessageDescriptor
	switch m := message.(type) {
	case protoreflect.ProtoMessage:
		md = m.ProtoReflect().Descriptor()
	default:
		md = proto.MessageReflect(message).Descriptor()
	}
	return loadMessageDescriptor(sourceinfo.WrapMessage(md))
}
// messageFromType returns the zero value of the pointer form of mt as a
// proto.Message. A non-pointer mt is first converted to its pointer type, so
// the returned message is always a nil pointer of that type. An error naming
// the pointer type is returned if it does not implement proto.Message.
func messageFromType(mt reflect.Type) (proto.Message, error) {
	ptrType := mt
	if ptrType.Kind() != reflect.Ptr {
		ptrType = reflect.PtrTo(ptrType)
	}
	if m, ok := reflect.Zero(ptrType).Interface().(proto.Message); ok {
		return m, nil
	}
	return nil, fmt.Errorf("failed to create message from type: %v", ptrType)
}
// protoEnum is the method set shared by generated enum types. EnumDescriptor
// returns the serialized file descriptor that declares the enum together with
// the enum's index path within that file.
type protoEnum interface {
	EnumDescriptor() ([]byte, []int)
}
// NB: There is no LoadEnumDescriptor that takes a fully-qualified enum name because
// it would not be useful: protoc-gen-go does not expose the name anywhere in generated
// code or register it in a way that makes it accessible to reflection code. This also
// means we have to cache enum descriptors differently -- we can only cache them as
// they are requested, as opposed to caching all enum types whenever a file descriptor
// is cached. This is because we need to know the generated type of the enums, and we
// don't know that at the time of caching file descriptors.
// LoadEnumDescriptorForType returns the descriptor for the given generated enum
// type, using the encoded descriptor returned by its EnumDescriptor method.
// A pointer type is dereferenced first because enum descriptors are cached by
// their non-pointer type; a cached descriptor is returned without reloading.
func LoadEnumDescriptorForType(enumType reflect.Type) (*EnumDescriptor, error) {
	valueType := enumType
	if valueType.Kind() == reflect.Ptr {
		valueType = valueType.Elem()
	}
	if cached := getEnumFromCache(valueType); cached != nil {
		return cached, nil
	}
	enumVal, err := enumFromType(valueType)
	if err != nil {
		return nil, err
	}
	return loadEnumDescriptor(valueType, enumVal)
}
// getEnumFromCache returns the descriptor cached for enum type t, or nil when
// none has been stored. It reads loadedEnums under the read lock of
// loadedEnumsMu.
func getEnumFromCache(t reflect.Type) *EnumDescriptor {
	loadedEnumsMu.RLock()
	ed := loadedEnums[t]
	loadedEnumsMu.RUnlock()
	return ed
}
// putEnumInCache stores d as the descriptor for enum type t, replacing any
// previous entry. It writes loadedEnums under the write lock of loadedEnumsMu.
func putEnumInCache(t reflect.Type, d *EnumDescriptor) {
	mu := &loadedEnumsMu
	mu.Lock()
	defer mu.Unlock()
	loadedEnums[t] = d
}
// LoadEnumDescriptorForEnum returns the descriptor for the given generated enum
// value, using the encoded descriptor returned by enum.EnumDescriptor().
// Descriptors are cached by the enum's non-pointer type, so a pointer value is
// resolved to its element type before the cache is consulted.
//
// A nil enum yields an error instead of a panic.
func LoadEnumDescriptorForEnum(enum protoEnum) (*EnumDescriptor, error) {
	if enum == nil {
		return nil, fmt.Errorf("cannot load enum descriptor for nil enum")
	}
	et := reflect.TypeOf(enum)
	// we cache descriptors using non-pointer type
	if et.Kind() == reflect.Ptr {
		et = et.Elem()
		// Prefer the element type's zero value: calling a value-receiver
		// EnumDescriptor method through a nil pointer would panic. If the
		// element type does not implement protoEnum itself (the method has a
		// pointer receiver), keep the caller's value rather than panicking on
		// an unchecked type assertion.
		if ev, ok := reflect.Zero(et).Interface().(protoEnum); ok {
			enum = ev
		}
	}
	if e := getEnumFromCache(et); e != nil {
		return e, nil
	}
	return loadEnumDescriptor(et, enum)
}
// enumFromType returns a zero value of et as a protoEnum. If et itself does not
// implement protoEnum, a second candidate is tried depending on its kind:
//   - pointer types: none (a pointer's method set already includes its
//     element's methods);
//   - array, chan, map and slice types: the element type;
//   - any other type: the pointer type *et, which covers an EnumDescriptor
//     method declared with a pointer receiver (the result is then a nil
//     pointer).
//
// An error is returned if no candidate implements protoEnum.
func enumFromType(et reflect.Type) (protoEnum, error) {
	if e, ok := reflect.Zero(et).Interface().(protoEnum); ok {
		return e, nil
	}
	switch et.Kind() {
	case reflect.Ptr:
		// nothing else to try
	case reflect.Array, reflect.Chan, reflect.Map, reflect.Slice:
		et = et.Elem()
		if e, ok := reflect.Zero(et).Interface().(protoEnum); ok {
			return e, nil
		}
	default:
		// reflect.Type.Elem panics for these kinds, so try the pointer type
		// instead of the element type.
		if e, ok := reflect.Zero(reflect.PtrTo(et)).Interface().(protoEnum); ok {
			return e, nil
		}
	}
	return nil, fmt.Errorf("failed to create enum from type: %v", et)
}
// getDescriptorForEnum decodes the serialized file descriptor returned by
// enum.EnumDescriptor() using internal.DecodeFileDescriptor. The enum's Go
// type name (formatted with %T) is passed as the name argument. It returns the
// decoded file, the enum's path as reported by EnumDescriptor, and any decode
// error.
func getDescriptorForEnum(enum protoEnum) (*descriptorpb.FileDescriptorProto, []int, error) {
	encoded, enumPath := enum.EnumDescriptor()
	fdp, err := internal.DecodeFileDescriptor(fmt.Sprintf("%T", enum), encoded)
	return fdp, enumPath, err
}
// loadEnumDescriptor builds the descriptor for enum through three steps:
//   - decode its embedded file descriptor;
//   - load the registered file of the same name via LoadFileDescriptor;
//   - navigate to the enum by its path.
//
// The result is cached under et before being returned; on any error nothing
// is cached.
func loadEnumDescriptor(et reflect.Type, enum protoEnum) (*EnumDescriptor, error) {
	fileProto, enumPath, decodeErr := getDescriptorForEnum(enum)
	if decodeErr != nil {
		return nil, decodeErr
	}
	file, loadErr := LoadFileDescriptor(fileProto.GetName())
	if loadErr != nil {
		return nil, loadErr
	}
	result := findEnum(file, enumPath)
	putEnumInCache(et, result)
	return result, nil
}
// findEnum navigates fd to the enum identified by path:
//   - a single-element path indexes fd's top-level enums;
//   - otherwise path[0] indexes a top-level message, each middle element
//     indexes a nested message of the previous one, and the last element
//     indexes an enum nested in the final message reached.
//
// Indices are not bounds-checked, so an invalid path panics.
func findEnum(fd *FileDescriptor, path []int) *EnumDescriptor {
	last := len(path) - 1
	if last == 0 {
		return fd.GetEnumTypes()[path[0]]
	}
	msg := fd.GetMessageTypes()[path[0]]
	for i := 1; i < last; i++ {
		msg = msg.GetNestedMessageTypes()[path[i]]
	}
	return msg.GetNestedEnumTypes()[path[last]]
}
// LoadFieldDescriptorForExtension returns the field descriptor for the given
// extension. It loads the file named by ext.Filename and looks up ext.Name in
// it. The symbol found must meet every condition below, otherwise an error is
// returned:
//   - it is a *FieldDescriptor marked as an extension;
//   - its owner's fully-qualified name equals
//     proto.MessageName(ext.ExtendedType);
//   - its number equals ext.Field.
func LoadFieldDescriptorForExtension(ext *proto.ExtensionDesc) (*FieldDescriptor, error) {
	fd, err := LoadFileDescriptor(ext.Filename)
	if err != nil {
		return nil, err
	}
	field, isField := fd.FindSymbol(ext.Name).(*FieldDescriptor)
	// the descriptor must agree with every attribute of the ExtensionDesc
	agrees := isField &&
		field.IsExtension() &&
		field.GetOwner().GetFullyQualifiedName() == proto.MessageName(ext.ExtendedType) &&
		field.GetNumber() == ext.Field
	if !agrees {
		return nil, fmt.Errorf("file descriptor contained unexpected object with name %s", ext.Name)
	}
	return field, nil
}