forked from bufbuild/protocompile
-
Notifications
You must be signed in to change notification settings - Fork 0
/
walk.go
176 lines (159 loc) · 4.55 KB
/
walk.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
package ast
import (
"fmt"
"google.golang.org/protobuf/reflect/protopath"
"google.golang.org/protobuf/reflect/protorange"
"google.golang.org/protobuf/reflect/protoreflect"
)
type (
	// WalkOption configures a walk performed by Inspect. Options can install
	// before and after hooks that run around each node in the tree, and can
	// restrict which nodes are visited (by token range, by intersection with a
	// token, or by depth).
	WalkOption func(*walkOptions)

	// walkOptions is the configuration assembled from a set of WalkOptions.
	walkOptions struct {
		// before is the hook installed by WithBefore, or nil.
		before func(protopath.Values) error
		// after is the hook installed by WithAfter, or nil.
		after func(protopath.Values) error
		// hasRangeRequirement reports whether WithRange was used; start and
		// end are the bounds it supplied.
		hasRangeRequirement bool
		start, end          Token
		// depthLimit is the maximum protopath length explored (WithDepthLimit).
		depthLimit int
		// hasIntersectionRequirement reports whether WithIntersection was
		// used; intersects is the token it supplied.
		hasIntersectionRequirement bool
		intersects                 Token
	}
)
// WithBefore returns a WalkOption that installs fn as a hook invoked before a
// node is visited during a walk operation. fn receives the protopath.Values
// describing the path to the current value. Inspect calls it for every
// message-typed value it reaches, including a repeated message field as a
// whole (in addition to each of its elements).
//
// If fn returns protorange.Break, the current value is not visited, but the
// walk continues into its children. If fn returns any other error, the node
// is not visited and the walk operation is aborted.
func WithBefore(fn func(protopath.Values) error) WalkOption {
	return func(options *walkOptions) {
		options.before = fn
	}
}
// WithAfter returns a WalkOption that installs fn as a hook invoked after a
// node (as well as any descendants) is visited during a walk operation. fn
// receives the protopath.Values describing the path to the value being left.
// Inspect calls it when leaving message-typed values, including a repeated
// message field as a whole.
//
// Inspect has no error result and discards the value returned by fn, so an
// error from an after hook neither skips nodes nor aborts the walk.
func WithAfter(fn func(protopath.Values) error) WalkOption {
	return func(options *walkOptions) {
		options.after = fn
	}
}
// WithRange returns a WalkOption that restricts which nodes Inspect passes to
// its visit function. A node is visited only when its token span overlaps
// [start, end], that is node.Start() <= end && node.End() >= start. Nodes
// outside the range are not visited, but their children are still traversed.
func WithRange(start, end Token) WalkOption {
	return func(o *walkOptions) {
		o.start, o.end = start, end
		o.hasRangeRequirement = true
	}
}
// WithIntersection returns a WalkOption that restricts which nodes Inspect
// passes to its visit function. A node is visited only when its token span
// contains intersects, that is node.Start() <= intersects <= node.End().
// Nodes that do not contain it are not visited, but their children are still
// traversed.
func WithIntersection(intersects Token) WalkOption {
	return func(o *walkOptions) {
		o.intersects = intersects
		o.hasIntersectionRequirement = true
	}
}
// WithDepthLimit returns a WalkOption that bounds how deep Inspect descends.
// Depth is the length of the protopath.Path to a value; every field access and
// list index adds a step. Values whose path is longer than limit are not
// visited and their children are not traversed. Inspect uses 32 when this
// option is not given. Because the path always contains the root step, a
// limit below 1 stops the walk before anything is visited.
func WithDepthLimit(limit int) WalkOption {
	return func(o *walkOptions) { o.depthLimit = limit }
}
// Inspect traverses an AST in depth-first order. It starts by calling
// visit(node); node must not be nil. If visit returns true, Inspect continues
// into that node's children; if it returns false, the node's children are
// skipped.
//
// The traversal is driven by protobuf reflection (protorange, in stable
// field-number order) over node.ProtoReflect(). Message-typed values reached
// by the walk are passed to visit as Nodes. A repeated message field as a
// whole is not passed to visit; its elements are. Extension fields are not
// traversed.
//
// Behavior can be adjusted with WalkOptions:
//   - WithRange and WithIntersection restrict which nodes are passed to visit.
//     Nodes that fail a requirement are not visited, but their children are
//     still traversed.
//   - WithDepthLimit (default 32) bounds the length of the protopath.Path that
//     is explored. Deeper values are not visited, are not passed to hooks, and
//     their children are not traversed.
//   - WithBefore and WithAfter install hooks that run around every
//     message-typed value, including repeated message fields as a whole. The
//     after hook also runs for a node whose visit returned false. If the before
//     hook returns protorange.Break, the value is not visited but its children
//     are; any other error from it stops the walk.
//
// Inspect has no error result, so errors from the before hook (which stop the
// walk) and from the after hook (which do not) are discarded.
func Inspect(node Node, visit func(Node) bool, opts ...WalkOption) {
	wOpts := walkOptions{
		depthLimit: 32,
	}
	for _, opt := range opts {
		opt(&wOpts)
	}

	// check classifies the value at the top of v. It reports the value's kind
	// and whether it is a repeated field as a whole (rather than one element).
	// It returns protorange.Break for extension fields so they are not traversed.
	check := func(v protopath.Values) (protoreflect.Kind, bool, error) {
		top := v.Index(-1)
		switch top.Step.Kind() {
		case protopath.RootStep:
			return protoreflect.MessageKind, false, nil
		case protopath.FieldAccessStep:
			fd := top.Step.FieldDescriptor()
			if fd.IsExtension() {
				return 0, false, protorange.Break
			}
			return fd.Kind(), fd.IsList(), nil
		case protopath.ListIndexStep:
			// A list element has the kind of the repeated field holding it.
			// Elements of extension fields never get here because the
			// extension's FieldAccessStep already returned protorange.Break.
			if prev := v.Index(-2); prev.Step.Kind() == protopath.FieldAccessStep {
				return prev.Step.FieldDescriptor().Kind(), false, nil
			}
			return 0, false, nil
		case protopath.UnknownAccessStep:
			// Raw unknown-field bytes are never a Node. Report a non-message
			// kind so they are skipped instead of reaching the panic below.
			return protoreflect.BytesKind, false, nil
		default:
			panic(fmt.Sprintf("ast.Inspect: invalid step kind %q in path: %s", top.Step.Kind().String(), v.Path))
		}
	}

	// skipDepth is the path length of the node whose visit returned false, or
	// -1 when no subtree is being skipped. While it is set, every value with a
	// longer path belongs to that node's subtree. Tracking the depth keeps the
	// skip in effect for all of the node's descendants, and lets the node's own
	// after hook run when it is popped.
	skipDepth := -1

	push := func(v protopath.Values) error {
		depth := len(v.Path)
		if depth > wOpts.depthLimit {
			return protorange.Break
		}
		if skipDepth >= 0 && depth > skipDepth {
			return protorange.Break
		}
		kind, isList, err := check(v)
		if err != nil {
			return err
		}
		if kind != protoreflect.MessageKind {
			return nil
		}
		if wOpts.before != nil {
			if err := wOpts.before(v); err != nil {
				if err == protorange.Break {
					// Don't visit this value, but keep walking into its children.
					return nil
				}
				return err
			}
		}
		if isList {
			// The repeated field as a whole is not a Node; each element is
			// pushed (and visited) separately.
			return nil
		}
		// NOTE(review): assumes every message in the AST implements Node; the
		// type assertion panics otherwise.
		n := v.Index(-1).Value.Message().Interface().(Node)
		if wOpts.hasRangeRequirement && (n.Start() > wOpts.end || n.End() < wOpts.start) {
			return nil
		}
		if wOpts.hasIntersectionRequirement && (n.Start() > wOpts.intersects || n.End() < wOpts.intersects) {
			return nil
		}
		if !visit(n) {
			skipDepth = depth
		}
		return nil
	}

	pop := func(v protopath.Values) error {
		depth := len(v.Path)
		if skipDepth >= 0 {
			if depth > skipDepth {
				// Inside a skipped subtree. push returned Break without running
				// the before hook, so the after hook is not run either.
				return nil
			}
			// Leaving the node whose visit returned false; resume normal walking.
			skipDepth = -1
		}
		if wOpts.after == nil || depth > wOpts.depthLimit {
			// Over-limit values got no before hook in push, so they get no
			// after hook here.
			return nil
		}
		if kind, _, err := check(v); err == nil && kind == protoreflect.MessageKind {
			// Inspect has no error result, so after-hook errors are deliberately ignored.
			_ = wOpts.after(v)
		}
		return nil
	}

	// Range only fails when a before hook aborts the walk. Inspect cannot
	// return that error, so it is deliberately dropped.
	_ = protorange.Options{Stable: true}.Range(node.ProtoReflect(), push, pop)
}