/
message.go
116 lines (107 loc) · 3.34 KB
/
message.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
package protomock
import (
"github.com/nlm/protoc-gen-mock/pkg/pb/mockpb"
"google.golang.org/protobuf/proto"
"google.golang.org/protobuf/reflect/protoreflect"
"google.golang.org/protobuf/reflect/protoregistry"
)
// Tunables for the mock generator.
const (
	// mockMaxDepth caps how deeply mockMessage recurses into nested messages;
	// a message reached at this depth is left unpopulated.
	mockMaxDepth = 100
	// mockDefaultFieldRepeat is the number of elements generated for list and
	// map fields when no positive repeat count is configured via mock rules.
	mockDefaultFieldRepeat = 3
)
// getRepetitions returns how many elements should be generated for a
// repeated (list) or map field.
//
// The count is read from the field's mockpb.E_Rules extension: map fields use
// the map rule's repeat value and list fields use the list rule's repeat
// value. When the relevant rule is absent or its repeat value is not
// positive, mockDefaultFieldRepeat is returned instead.
func getRepetitions(field protoreflect.FieldDescriptor) int {
	rules := proto.GetExtension(field.Options(), mockpb.E_Rules).(*mockpb.MockRules)
	switch {
	case field.IsMap():
		if n := rules.GetMap().GetRepeat(); n > 0 {
			return int(n)
		}
	case field.IsList():
		if n := rules.GetList().GetRepeat(); n > 0 {
			return int(n)
		}
	}
	return mockDefaultFieldRepeat
}
// newMessage returns a new, empty instance of the message type described by
// desc.
//
// The concrete Go type is resolved through protoregistry.GlobalTypes. A
// lookup miss is treated as a programming error and panics, on the
// expectation that any message type reachable from a field of a registered
// message is itself registered.
func newMessage(desc protoreflect.MessageDescriptor) protoreflect.ProtoMessage {
	msgType, lookupErr := protoregistry.GlobalTypes.FindMessageByName(desc.FullName())
	if lookupErr == nil {
		return msgType.New().Interface()
	}
	panic(lookupErr)
}
// mockList appends generated elements to the repeated (non-map) field of msg.
//
// The number of elements comes from getRepetitions. Message-typed elements
// (including proto2 groups) are appended as empty messages and populated
// recursively by mockMessage at the caller's depth, so the mockMaxDepth limit
// applies to recursive list fields too. All other element kinds are produced
// by mockScalar using the field's options.
//
// NOTE(review): even with depth propagated, a self-referential repeated field
// fans out by the repeat count at every level until mockMaxDepth is reached;
// consider a smaller depth budget for lists if that becomes a problem.
func mockList(msg proto.Message, field protoreflect.FieldDescriptor, depth int) {
	lst := msg.ProtoReflect().Mutable(field).List()
	// The repeat count depends only on the field's options, so read it once
	// instead of re-evaluating the extension on every loop iteration.
	n := getRepetitions(field)
	switch field.Kind() {
	case protoreflect.MessageKind, protoreflect.GroupKind:
		for i := 0; i < n; i++ {
			elem := lst.AppendMutable()
			// Propagate depth: restarting at 0 here would let self-referential
			// repeated fields recurse without ever hitting mockMaxDepth.
			mockMessage(elem.Message().Interface(), depth)
		}
	default:
		for i := 0; i < n; i++ {
			lst.Append(mockScalar(field, field.Options()))
		}
	}
}
// mockMap inserts generated entries into the map field of msg.
//
// getRepetitions decides how many entries to generate. Protobuf does not allow
// message-typed map keys, so keys always come from mockScalar. Message-typed
// values are created in place with Map.Mutable and populated recursively by
// mockMessage; all other values come from mockScalar. Keys and values are
// both generated with the map field's own options. If mockScalar yields a key
// that is already present, that entry is reused/overwritten, so the map may
// end up with fewer entries than requested.
func mockMap(msg proto.Message, field protoreflect.FieldDescriptor, depth int) {
	mp := msg.ProtoReflect().Mutable(field).Map()
	keyDesc, valDesc := field.MapKey(), field.MapValue()
	// The repeat count depends only on the field's options; read it once.
	n := getRepetitions(field)
	for i := 0; i < n; i++ {
		key := mockScalar(keyDesc, field.Options()).MapKey()
		switch valDesc.Kind() {
		case protoreflect.MessageKind:
			// Mutable inserts an empty message for the key (or returns the
			// existing one), which is then filled in place.
			mockMessage(mp.Mutable(key).Message().Interface(), depth)
		default:
			mp.Set(key, mockScalar(valDesc, field.Options()))
		}
	}
}
// mockUnary sets a singular (non-list, non-map) field of msg to a mock value.
//
// Message-typed fields (including proto2 groups) receive a fresh message of
// the field's type, obtained from Message.NewField so no global registry
// lookup is required, and populated recursively by mockMessage. Every other
// kind is set to the value produced by mockScalar using the field's options.
//
// For members of a oneof, each Set clears the previously set member, so only
// the last member mocked by the caller remains populated.
func mockUnary(msg proto.Message, field protoreflect.FieldDescriptor, depth int) {
	m := msg.ProtoReflect()
	switch field.Kind() {
	case protoreflect.MessageKind, protoreflect.GroupKind:
		v := m.NewField(field)
		mockMessage(v.Message().Interface(), depth)
		m.Set(field, v)
	default:
		m.Set(field, mockScalar(field, field.Options()))
	}
}
// mockField fills a single field of msg with mock data, dispatching on the
// field's shape: map fields go to mockMap, repeated (non-map) fields to
// mockList, and everything else to mockUnary. depth is forwarded unchanged.
func mockField(msg proto.Message, field protoreflect.FieldDescriptor, depth int) {
	if field.IsMap() {
		mockMap(msg, field, depth)
		return
	}
	if field.IsList() {
		mockList(msg, field, depth)
		return
	}
	mockUnary(msg, field, depth)
}
// mockMessage populates every field declared by msg's descriptor with mock
// data, recursing into nested messages.
//
// depth is the current nesting level. Once it reaches mockMaxDepth the call
// does nothing, which is what limits recursion into nested (possibly
// self-referential) message types as long as callers pass depth through.
// Each field is mocked at depth+1.
func mockMessage(msg proto.Message, depth int) {
	if depth < mockMaxDepth {
		fds := msg.ProtoReflect().Descriptor().Fields()
		for idx, count := 0, fds.Len(); idx < count; idx++ {
			mockField(msg, fds.Get(idx), depth+1)
		}
	}
}