mirrored from https://chromium.googlesource.com/infra/luci/luci-go
/
mask.go
464 lines (434 loc) · 16.4 KB
/
mask.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
// Copyright 2020 The LUCI Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Package mask provides utility functions for google protobuf field masks.
//
// Supports advanced field mask semantics:
// - Refer to fields and map keys using . literals:
// - Supported map key types: string, integer, bool. (double, float, enum,
// and bytes keys are not supported by protobuf or this implementation)
// - Fields: "publisher.name" means field "name" of field "publisher"
// - String map keys: "metadata.year" means string key 'year' of map field
// metadata
// - Integer map keys (e.g. int32): 'year_ratings.0' means integer key 0 of
// a map field year_ratings
// - Bool map keys: 'access_text.true' means boolean key true of a map field
// access_text
// - String map keys that cannot be represented as an unquoted string literal,
// must be quoted using backticks: metadata.`year.published`, metadata.`17`,
// metadata.``. Backtick can be escaped with ``: a.`b``c` means map key "b`c"
// of map field a.
// - Refer to all map keys using a * literal: "topics.*.archived" means field
// "archived" of all map values of map field "topics".
// - Refer to all elements of a repeated field using a * literal: authors.*.name
// - Refer to all fields of a message using * literal: publisher.*.
// - Prohibit addressing a single element in repeated fields: authors.0.name
//
// FieldMask.paths string grammar:
// path = segment {'.' segment}
// segment = literal | star | quoted_string;
// literal = string | integer | boolean
// string = (letter | '_') {letter | '_' | digit}
// integer = ['-'] digit {digit};
// boolean = 'true' | 'false';
// quoted_string = '`' { utf8-no-backtick | '``' } '`'
// star = '*'
package mask
import (
"fmt"
"sort"
"strings"
"github.com/golang/protobuf/proto"
"google.golang.org/genproto/protobuf/field_mask"
protoV2 "google.golang.org/protobuf/proto"
"google.golang.org/protobuf/reflect/protoreflect"
"go.chromium.org/luci/common/data/stringset"
)
// Mask is a tree representation of a field mask. Every Mask value is also a
// node of that tree: each node represents one segment of a path, e.g. "bar"
// in "foo.bar.qux".
//
// A field mask with paths ["a", "b.c"] is parsed as
//
//	  <root>
//	  /    \
//	"a"    "b"
//	       /
//	     "c"
//
// Zero value is not valid. Use IsEmpty() to check if the mask is zero.
type Mask struct {
// descriptor is the proto descriptor of the message of the field this node
// represents. If the field kind is not a message, then descriptor is nil and
// the node must be a leaf, unless isRepeated is true, which denotes a
// repeated scalar field.
descriptor protoreflect.MessageDescriptor
// isRepeated indicates whether the segment represents a repeated field.
// Children of such a node are the field elements.
isRepeated bool
// children maps a segment to its node. In the example above, the root has
// keys "a" and "b"; the Mask stored under "b" has a single child "c".
// Segments of every type (int, bool, string, star) are stored as strings.
// A node with no children includes the value it represents entirely.
children map[string]Mask
}
// FromFieldMask builds a Mask for targetMsg's message type out of the paths
// of fieldMask.
//
// Trailing stars are dropped, e.g. ['a.*'] is parsed as ['a'].
// Redundant paths are dropped, e.g. ['a', 'a.b'] is parsed as ['a'].
//
// When isFieldNameJSON is true, paths are expected to use JSON field names
// (e.g. "fooBar") rather than the canonical proto names (e.g. "foo_bar").
// The child field names in the returned mask are always in canonical form.
//
// When isUpdateMask is true, a repeated field may only appear as the last
// field of a path.
//
// The first path that fails to parse aborts the call and its error is
// returned together with a zero Mask.
func FromFieldMask(fieldMask *field_mask.FieldMask, targetMsg proto.Message, isFieldNameJSON bool, isUpdateMask bool) (Mask, error) {
	desc := proto.MessageReflect(targetMsg).Descriptor()
	rawPaths := fieldMask.GetPaths()

	parsed := make([]path, 0, len(rawPaths))
	for _, raw := range rawPaths {
		p, err := parsePath(raw, desc, isFieldNameJSON)
		if err != nil {
			return Mask{}, err
		}
		parsed = append(parsed, p)
	}
	return fromParsedPaths(parsed, desc, isUpdateMask)
}
// MustFromReadMask is a shortcut for FromFieldMask with both isFieldNameJSON
// and isUpdateMask set to false. It takes the mask as variadic paths and
// panics if the mask is invalid.
//
// It is useful when the mask is hardcoded.
func MustFromReadMask(targetMsg proto.Message, paths ...string) Mask {
	fm := &field_mask.FieldMask{Paths: paths}
	m, err := FromFieldMask(fm, targetMsg, false, false)
	if err != nil {
		panic(err)
	}
	return m
}
// All returns a mask for targetMsg's message type that selects every field.
// It is built from the single path "*" and panics on failure, like
// MustFromReadMask.
func All(targetMsg proto.Message) Mask {
	const everything = "*"
	return MustFromReadMask(targetMsg, everything)
}
// fromParsedPaths builds the mask tree rooted at message descriptor desc from
// already-parsed paths. The paths are normalized first (see normalizePaths).
//
// Every segment becomes a node under its parent. The descriptor of a new node
// is derived from its parent:
//   - parent is a map entry: the node is a map value, so it takes the message
//     descriptor of the entry's "value" field;
//   - parent is a repeated field: the node is an element and shares the
//     parent's descriptor;
//   - otherwise the segment is looked up as a field name of the parent message.
//
// If isUpdateMask is true, an error is returned as soon as a path continues
// past a repeated field.
func fromParsedPaths(parsedPaths []path, desc protoreflect.MessageDescriptor, isUpdateMask bool) (Mask, error) {
	root := Mask{
		descriptor: desc,
		children:   map[string]Mask{},
	}
	for _, segs := range normalizePaths(parsedPaths) {
		// Mask is copied by value, but the children map is shared, so writes
		// through node are visible in the tree hanging off root.
		node, nodeName := root, ""
		for _, seg := range segs {
			if isUpdateMask && node.isRepeated {
				return Mask{}, fmt.Errorf("update mask allows a repeated field only at the last position; field: %s is not last", nodeName)
			}
			next, exists := node.children[seg]
			if !exists {
				next = Mask{children: map[string]Mask{}}
				nodeDesc := node.descriptor
				if nodeDesc.IsMapEntry() {
					next.descriptor = nodeDesc.Fields().ByName(protoreflect.Name("value")).Message()
				} else if node.isRepeated {
					next.descriptor = nodeDesc
				} else {
					fd := nodeDesc.Fields().ByName(protoreflect.Name(seg))
					next.descriptor = fd.Message()
					next.isRepeated = fd.Cardinality() == protoreflect.Repeated
				}
				node.children[seg] = next
			}
			node, nodeName = next, seg
		}
	}
	return root, nil
}
// normalizePaths normalizes parsed paths. Returns a new slice of paths.
//
// Removes a trailing star from each path, e.g. converts ["a", "*"] to ["a"].
// Removes paths that have a segment prefix already present in paths,
// e.g. removes ["a", "b"] from [["a", "b"], ["a"]].
//
// The result slice is ordered by the number of segments of each path. If two
// paths have the same number of segments, the tie is broken by comparing the
// segments at each index lexicographically.
func normalizePaths(paths []path) []path {
	paths = removeTrailingStars(paths)
	sort.SliceStable(paths, func(i, j int) bool {
		a, b := paths[i], paths[j]
		if len(a) != len(b) {
			return len(a) < len(b)
		}
		for k, seg := range a {
			if seg != b[k] {
				return seg < b[k]
			}
		}
		// Identical paths: neither sorts before the other. A "less" function
		// must report false here to be a valid strict ordering.
		return false
	})

	delimiter := string(pathDelimiter)
	// keyOf renders segments as an unambiguous string. Each segment is quoted
	// before joining because a segment (e.g. a string map key) may itself
	// contain the delimiter; joining raw segments would make ["m", "x.y"]
	// indistinguishable from ["m", "x", "y"].
	keyOf := func(segs path) string {
		quoted := make([]string, len(segs))
		for i, seg := range segs {
			quoted[i] = fmt.Sprintf("%q", seg)
		}
		return strings.Join(quoted, delimiter)
	}

	present := stringset.New(len(paths))
	// hasPresentPrefix reports whether any non-empty prefix of p (including p
	// itself) has already been kept.
	hasPresentPrefix := func(p path) bool {
		for i := range p {
			if present.Has(keyOf(p[:i+1])) {
				return true
			}
		}
		return false
	}

	ret := make([]path, 0, len(paths))
	// Paths are visited shortest first, so every possible prefix of a path has
	// been considered before the path itself.
	for _, p := range paths {
		if hasPresentPrefix(p) {
			continue
		}
		ret = append(ret, p)
		present.Add(keyOf(p))
	}
	return ret
}
// removeTrailingStars returns a new slice holding every input path with its
// final segment dropped when that segment is "*". At most one trailing star is
// removed per path. Paths are re-sliced, not copied, so they share backing
// storage with the inputs.
func removeTrailingStars(paths []path) []path {
	trimmed := make([]path, len(paths))
	for i, p := range paths {
		if last := len(p) - 1; last >= 0 && p[last] == "*" {
			p = p[:last]
		}
		trimmed[i] = p
	}
	return trimmed
}
// Trim clears the fields of msg that are not in the mask.
//
// An empty mask (see IsEmpty) makes this a noop. Otherwise an error is
// returned when msg fails the descriptor check, i.e. it is nil or its message
// descriptor differs from the one the mask was built for.
//
// Inclusion is decided the same way Includes decides it; see its doc.
func (m Mask) Trim(msg proto.Message) error {
	if m.IsEmpty() {
		return nil
	}
	target := proto.MessageReflect(msg)
	err := checkMsgHaveDesc(target, m.descriptor)
	if err == nil {
		m.trimImpl(target)
	}
	return err
}
// trimImpl walks the populated fields of reflectMsg and trims them in place
// according to m:
//   - excluded fields are cleared;
//   - entirely included fields are left untouched;
//   - partially included fields are trimmed recursively with the child mask
//     registered under the field name.
func (m Mask) trimImpl(reflectMsg protoreflect.Message) {
	reflectMsg.Range(func(fd protoreflect.FieldDescriptor, val protoreflect.Value) bool {
		name := string(fd.Name())
		incl, _ := m.includesImpl(path{name})
		if incl == Exclude {
			reflectMsg.Clear(fd)
			return true
		}
		if incl != IncludePartially {
			// Included entirely: keep the value as is.
			return true
		}

		// Partial inclusion normally comes from the child registered under
		// this field name; the lookup yields a zero Mask if there is none.
		sub := m.children[name]
		switch {
		case fd.IsMap():
			sub.trimMap(val.Map(), fd.MapValue().Kind())
		case fd.Kind() != protoreflect.MessageKind:
			// A scalar has no subfields, so a mask that addresses something
			// below it cannot select anything: drop the field. FromFieldMask
			// is expected to reject such masks in the first place.
			reflectMsg.Clear(fd)
		case fd.IsList():
			// The star child is the only child a list field can have.
			starChild, ok := sub.children["*"]
			if !ok {
				break
			}
			elems := val.List()
			for i := 0; i < elems.Len(); i++ {
				starChild.trimImpl(elems.Get(i).Message())
			}
		default:
			sub.trimImpl(val.Message())
		}
		return true
	})
}
// trimMap trims the entries of protoMap in place. m is the mask node of the
// map field, so its children are keyed by the string form of map keys (or
// "*"). valueKind is the kind of the map's values.
//
// Excluded entries are removed. Partially included entries are removed when
// the values are not messages (a scalar has no subfields to keep); otherwise
// the value message is trimmed recursively.
func (m Mask) trimMap(protoMap protoreflect.Map, valueKind protoreflect.Kind) {
	valuesAreMessages := valueKind == protoreflect.MessageKind
	protoMap.Range(func(key protoreflect.MapKey, val protoreflect.Value) bool {
		keyStr := key.String()
		incl, _ := m.includesImpl(path{keyStr})
		switch incl {
		case Exclude:
			protoMap.Clear(key)
		case IncludePartially:
			if !valuesAreMessages {
				protoMap.Clear(key)
				break
			}
			// The entry can be partially included through the child for this
			// exact key, through the star child, or both, so each is applied.
			// NOTE(review): when both children exist the value is trimmed by
			// each in turn, so only fields kept by both survive - confirm
			// that this is the intended result.
			if byKey, ok := m.children[keyStr]; ok {
				byKey.trimImpl(val.Message())
			}
			if byStar, ok := m.children["*"]; ok {
				byStar.trimImpl(val.Message())
			}
		}
		return true
	})
}
// Inclusiveness describes how much of a field value at a given path is
// selected by a mask.
//
// The numeric values are ordered from least to most inclusive; code in this
// file compares them with ">" to pick the more inclusive of two results.
type Inclusiveness int8

const (
	// Exclude means the field value is not included at all.
	Exclude Inclusiveness = 0

	// IncludePartially means only some subfields of the field value are
	// included.
	IncludePartially Inclusiveness = 1

	// IncludeEntirely means the whole field value is included.
	IncludeEntirely Inclusiveness = 2
)
// Includes reports how much of the field value at the given path is selected
// by the mask.
//
// The path must use canonical field names, i.e. not JSON names. If the path
// cannot be parsed, the parse error is returned together with Exclude.
func (m Mask) Includes(path string) (Inclusiveness, error) {
	segs, err := parsePath(path, m.descriptor, false)
	if err != nil {
		return Exclude, err
	}
	result, _ := m.includesImpl(segs)
	return result, nil
}
// includesImpl implements Includes() on an already-parsed path. Besides the
// inclusiveness it returns a mask:
//   - IncludeEntirely: the leaf mask that covers the path;
//   - IncludePartially: the intermediate mask reached by the last segment of
//     the path;
//   - Exclude: an empty mask.
func (m Mask) includesImpl(p path) (Inclusiveness, Mask) {
	switch {
	case len(m.children) == 0:
		// A leaf covers everything below it.
		return IncludeEntirely, m
	case len(p) == 0:
		// The path ends on an intermediate node: only some of the value's
		// subfields are selected.
		return IncludePartially, m
	}

	head, rest := p[0], p[1:]
	best, bestMask := Exclude, Mask{}
	// The star child matches any segment, so it is consulted in addition to
	// the child named by the segment itself. E.g. with children
	// {"a": {"b": {}}, "*": {"c": {}}} and segment "x", only "*" can match.
	// The more inclusive of the two results wins.
	for _, seg := range [...]string{head, "*"} {
		child, ok := m.children[seg]
		if !ok {
			continue
		}
		if incl, mask := child.includesImpl(rest); incl > best {
			best, bestMask = incl, mask
		}
	}
	return best, bestMask
}
// Merge copies the masked fields from src into dest.
//
// An empty mask (see IsEmpty) makes this a noop. An error is returned when
// src or dest fails the descriptor check, i.e. it is nil or its message
// descriptor differs from the one the mask was built for; the error text
// says which of the two messages was rejected.
//
// Empty fields are merged as long as they are present in the mask. Repeated
// fields and map fields are overwritten entirely; partial updates are not
// supported for such fields.
func (m Mask) Merge(src, dest proto.Message) error {
	if m.IsEmpty() {
		return nil
	}

	from := proto.MessageReflect(src)
	if err := checkMsgHaveDesc(from, m.descriptor); err != nil {
		return fmt.Errorf("src message: %s", err)
	}

	to := proto.MessageReflect(dest)
	if err := checkMsgHaveDesc(to, m.descriptor); err != nil {
		return fmt.Errorf("dest message: %s", err)
	}

	m.mergeImpl(from, to)
	return nil
}
// mergeImpl copies every field named by a child of m from src into dest.
// Both messages must have the descriptor m was built for.
//
// Per field:
//   - list and map fields are replaced wholesale with clones of src's
//     elements;
//   - a singular message field whose submask has children is merged
//     recursively, creating the message in dest if needed;
//   - a singular message field that is included entirely but unset in src is
//     cleared in dest;
//   - anything else (scalars, or message fields included entirely) is set to
//     a clone of src's value.
func (m Mask) mergeImpl(src, dest protoreflect.Message) {
	for seg, submask := range m.children {
		// star field is not supported for update mask so this won't be nil
		fieldDesc := m.descriptor.Fields().ByName(protoreflect.Name(seg))
		srcVal, kind := src.Get(fieldDesc), fieldDesc.Kind()
		switch {
		case fieldDesc.IsList():
			newField := dest.NewField(fieldDesc)
			srcList, destList := srcVal.List(), newField.List()
			for i := 0; i < srcList.Len(); i++ {
				destList.Append(cloneValue(srcList.Get(i), kind))
			}
			dest.Set(fieldDesc, newField)
		case fieldDesc.IsMap():
			newField := dest.NewField(fieldDesc)
			destMap := newField.Map()
			mapValKind := fieldDesc.MapValue().Kind()
			srcVal.Map().Range(func(k protoreflect.MapKey, v protoreflect.Value) bool {
				destMap.Set(k, cloneValue(v, mapValKind))
				return true
			})
			dest.Set(fieldDesc, newField)
		case kind == protoreflect.MessageKind && len(submask.children) > 0:
			// only singular message field can be merged partially
			submask.mergeImpl(srcVal.Message(), dest.Mutable(fieldDesc).Message())
		case kind == protoreflect.MessageKind && !src.Has(fieldDesc):
			// For an unset message field Get returns an empty, read-only
			// message, and cloning it yields another invalid message.
			// protoreflect.Message.Set is documented to panic on such a
			// value, so mirror "unset in src" by clearing the field in dest.
			dest.Clear(fieldDesc)
		default:
			// scalar value, or a populated message field included entirely
			dest.Set(fieldDesc, cloneValue(srcVal, kind))
		}
	}
}
// cloneValue returns a deep copy of v when kind is a message kind; for every
// other kind (i.e. scalars) v itself is returned. List and map values are
// not expected here because mergeImpl handles them explicitly.
func cloneValue(v protoreflect.Value, kind protoreflect.Kind) protoreflect.Value {
	if kind != protoreflect.MessageKind {
		return v
	}
	copied := protoV2.Clone(v.Message().Interface())
	return protoreflect.ValueOf(copied.ProtoReflect())
}
// Submask returns the sub-mask found by following path from the received
// mask.
//
// For example, for a mask ["a.b.c"], m.Submask("a.b") returns a mask ["c"].
//
// If the received mask includes the path entirely, the result is a mask
// without children, i.e. one that includes everything. For example, for mask
// ["a"], m.Submask("a.b") returns such a mask.
//
// Returns an error if the path cannot be parsed or if it is excluded from the
// received mask.
func (m Mask) Submask(path string) (Mask, error) {
	// Parsing starts at this node rather than at a root message. The node
	// counts as a list only if it is repeated and not a map entry.
	ctx := &parseCtx{
		curDescriptor: m.descriptor,
		isList:        m.isRepeated && (m.descriptor == nil || !m.descriptor.IsMapEntry()),
	}
	segs, err := parsePathWithContext(path, ctx, false)
	if err != nil {
		return Mask{}, err
	}

	incl, found := m.includesImpl(segs)
	switch incl {
	case Exclude:
		return Mask{}, fmt.Errorf("the given path %q is excluded from mask", path)
	case IncludePartially:
		return found, nil
	case IncludeEntirely:
		// The path may reach below the leaf of this mask, so the result is
		// described by the parse context as it stands after parsing
		// (presumably parsePathWithContext advances ctx along the path).
		endDesc := ctx.curDescriptor
		return Mask{
			descriptor: endDesc,
			isRepeated: ctx.isList || (endDesc != nil && endDesc.IsMapEntry()),
		}, nil
	}
	return Mask{}, fmt.Errorf("unknown Inclusiveness: %d", incl)
}
// IsEmpty reports whether the mask is the zero value: no descriptor, not
// repeated and no children. Trim and Merge are both noops on such a mask, so
// Trim keeps everything and Merge merges nothing.
func (m Mask) IsEmpty() bool {
	if m.descriptor != nil || m.isRepeated {
		return false
	}
	return len(m.children) == 0
}
// checkMsgHaveDesc validates that the descriptor of the given proto reflect
// message is the expected message descriptor.
//
// It returns an error when msg is nil, when expectedDesc is nil (the mask
// does not describe a message, so no message can match it), or when the two
// descriptors differ. It returns nil when they are the same descriptor.
func checkMsgHaveDesc(msg protoreflect.Message, expectedDesc protoreflect.MessageDescriptor) error {
	if msg == nil {
		return fmt.Errorf("nil message")
	}
	msgDesc := msg.Descriptor()
	if msgDesc == expectedDesc {
		return nil
	}
	if expectedDesc == nil {
		// Calling FullName on a nil descriptor would panic; report the
		// mismatch as an error instead.
		return fmt.Errorf("expected no message descriptor; got descriptor: %s", msgDesc.FullName())
	}
	return fmt.Errorf("expected message have descriptor: %s; got descriptor: %s", expectedDesc.FullName(), msgDesc.FullName())
}