-
Notifications
You must be signed in to change notification settings - Fork 18
/
update.go
104 lines (98 loc) · 2.91 KB
/
update.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
package fieldmask
import (
"fmt"
"strings"
"google.golang.org/protobuf/proto"
"google.golang.org/protobuf/reflect/protoreflect"
"google.golang.org/protobuf/types/known/fieldmaskpb"
)
// Update applies values from src onto dst, with mask selecting which fields
// are written.
//
// The mask is interpreted as follows:
//   - No paths (including a nil mask): only the fields that are populated in
//     src are copied to dst; populated nested messages are updated
//     recursively in the same manner.
//   - The special value "*": dst is reset and then filled from src, i.e. a
//     full replacement of all fields.
//   - Anything else: every path is split on "." and the field it names is
//     updated, descending into nested messages one segment at a time.
//
// Repeated fields and maps are copied by reference from src to dst. Field
// mask paths referring to individual entries in maps or repeated fields are
// ignored.
//
// Update panics if dst and src do not share the same message descriptor.
//
// See: https://google.aip.dev/134 (Standard methods: Update).
func Update(mask *fieldmaskpb.FieldMask, dst, src proto.Message) {
	dstMsg, srcMsg := dst.ProtoReflect(), src.ProtoReflect()
	if dstDesc, srcDesc := dstMsg.Descriptor(), srcMsg.Descriptor(); dstDesc != srcDesc {
		panic(fmt.Sprintf(
			"dst (%s) and src (%s) messages have different types",
			dstDesc.FullName(),
			srcDesc.FullName(),
		))
	}

	paths := mask.GetPaths()

	// No update mask: update all fields of src that are set on the wire.
	if len(paths) == 0 {
		updateWireSetFields(dstMsg, srcMsg)
		return
	}

	// Update mask is [*]: do a full replacement of all fields.
	if IsFullReplacement(mask) {
		proto.Reset(dst)
		proto.Merge(dst, src)
		return
	}

	// Regular mask: apply each path on its own.
	for _, path := range paths {
		updateNamedField(dstMsg, srcMsg, strings.Split(path, "."))
	}
}
// updateWireSetFields copies every populated field of src into dst.
//
// A singular message field that is already populated in dst is updated
// recursively, so fields of the nested dst message that are not populated in
// src keep their current value. Every other field (scalars, lists, maps, and
// message fields not yet populated in dst) is overwritten with the value
// taken from src; that value is set as-is, without cloning.
func updateWireSetFields(dst, src protoreflect.Message) {
	src.Range(func(fd protoreflect.FieldDescriptor, v protoreflect.Value) bool {
		// Lists and maps are ruled out first so that only singular message
		// fields are candidates for a recursive update.
		singularMessage := !fd.IsList() && !fd.IsMap() && fd.Message() != nil
		if singularMessage && dst.Has(fd) {
			updateWireSetFields(dst.Get(fd).Message(), v.Message())
		} else {
			dst.Set(fd, v)
		}
		// Always continue with the next populated field.
		return true
	})
}
// updateNamedField applies a single field mask path, already split into its
// dot-separated segments, from src to dst.
//
// The first segment is looked up by name in the descriptor of src:
//   - Empty segments or an unknown field name: nothing happens.
//   - Last segment: the field is copied from src to dst when it is populated
//     in src, and cleared in dst otherwise.
//   - More segments follow: the field must be a singular message, in which
//     case the remaining segments are applied to the nested messages. Paths
//     that descend into lists, maps or scalar fields are ignored.
//
// When descending, an empty nested message is allocated in dst if the field
// is not populated there. src is only ever read and never modified.
func updateNamedField(dst, src protoreflect.Message, segments []string) {
	if len(segments) == 0 {
		return
	}
	field := src.Descriptor().Fields().ByName(protoreflect.Name(segments[0]))
	if field == nil {
		// No known field by that name.
		return
	}
	// A named field in this message.
	if len(segments) == 1 {
		if src.Has(field) {
			dst.Set(field, src.Get(field))
		} else {
			dst.Clear(field)
		}
		return
	}
	// A named field in a nested message.
	switch {
	case field.IsList(), field.IsMap():
		// Nested fields in repeated or map not supported.
		return
	case field.Message() != nil:
		// If the message field is not set in dst, allocate an empty value so
		// that the nested field has somewhere to be written.
		if !dst.Has(field) {
			dst.Set(field, dst.NewField(field))
		}
		// Do not allocate anything in src: the source message must stay
		// untouched. Get on an unpopulated message field yields an empty,
		// read-only message whose fields all report as unset, so the
		// recursion clears the targeted field in dst, which is the same
		// outcome as reading from a freshly allocated empty message.
		updateNamedField(dst.Get(field).Message(), src.Get(field).Message(), segments[1:])
	default:
		// Scalar fields have no nested fields to descend into.
		return
	}
}