/
visitor.go
191 lines (161 loc) · 3.84 KB
/
visitor.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
package defercheck
import (
"go/ast"
"go/types"
"golang.org/x/tools/go/analysis"
"golang.org/x/tools/go/cfg"
)
// visitor walks the control-flow graph handed to Analyze and reports
// variables whose values were evaluated by a defer statement and are later
// reassigned, or (for named results) returned.
type visitor struct {
	// pass supplies type information (TypesInfo) and is used for reporting.
	pass *analysis.Pass
	// fType is the signature of the function being analyzed; its named
	// results are checked at every return statement.
	fType *ast.FuncType
	// Keep track of the captured variables when we visit each block.
	// If the captured variables change, we'll need to reanalyze.
	// Keyed by cfg.Block.Index; the value is the entry state the block was
	// most recently analyzed with.
	visitedBlocks map[int32]visitorState
}
// newVisitor creates a visitor that reports through pass while analyzing
// the function whose signature is fType. The returned visitor has not
// visited any blocks yet.
func newVisitor(pass *analysis.Pass, fType *ast.FuncType) *visitor {
	v := new(visitor)
	v.pass = pass
	v.fType = fType
	v.visitedBlocks = map[int32]visitorState{}
	return v
}
// Analyze walks g starting from Blocks[0] (the entry block in go/cfg) with
// an empty state, reporting any problems found along the way. A CFG with
// no blocks is ignored.
func (v *visitor) Analyze(g *cfg.CFG) {
	if len(g.Blocks) > 0 {
		v.analyzeBlock(g.Blocks[0], visitorState{})
	}
}
// analyzeBlock analyzes block b starting from the given entry state, then
// recurses into each successor with the state at the end of b.
//
// visitedBlocks remembers the entry state b was last analyzed with; when b
// is reached again with an equal state the work would be identical, so it
// is skipped. Because states only ever grow along a path, this check is
// also what stops the recursion on cycles in the CFG.
func (v *visitor) analyzeBlock(b *cfg.Block, state visitorState) {
	if prev, seen := v.visitedBlocks[b.Index]; seen && prev.Equals(state) {
		return
	}
	// Record the entry state before descending so a cycle back to b sees it.
	v.visitedBlocks[b.Index] = state

	cur := state
	for i := range b.Nodes {
		cur = v.analyzeNode(b.Nodes[i], cur)
	}
	for i := range b.Succs {
		v.analyzeBlock(b.Succs[i], cur)
	}
}
// analyzeNode inspects one CFG node and returns the state after it. It
// looks for:
//  1. defer statements, whose evaluated identifiers are added to the state
//     (see analyzeDefer);
//  2. return statements: every named result of the function that a defer
//     has already evaluated is reported;
//  3. assignments to plain identifiers that a defer has already evaluated.
//
// Function literals and declarations are not descended into.
func (v *visitor) analyzeNode(n ast.Node, state visitorState) visitorState {
	ast.Inspect(n, func(n ast.Node) bool {
		switch n := n.(type) {
		case *ast.FuncLit, *ast.FuncDecl:
			return false
		case *ast.DeferStmt:
			state = v.analyzeDefer(n, state)
			return false
		case *ast.ReturnStmt:
			if v.fType.Results == nil {
				return false
			}
			for _, retField := range v.fType.Results.List {
				// One field can declare several names, as in (a, b int);
				// each is a separate named result and must be checked.
				// Unnamed results have no Names, so the loop is empty.
				for _, ident := range retField.Names {
					// Defs may map an identifier to a nil object; treat it
					// like a missing entry so Use never sees nil.
					obj := v.pass.TypesInfo.Defs[ident]
					if obj == nil {
						continue
					}
					evalNode, evaled := state.deferEvaledVars.Use(obj)
					if !evaled {
						continue
					}
					v.pass.Reportf(evalNode.Pos(),
						"variable %s evaluated by defer, then returned later", ident.Name)
				}
			}
			return false
		case *ast.AssignStmt:
			for _, lhs := range n.Lhs {
				// Only plain identifiers are checked; selectors, index
				// expressions and dereferences on the LHS are skipped.
				ident, isIdent := lhs.(*ast.Ident)
				if !isIdent {
					continue
				}
				obj := v.pass.TypesInfo.Uses[ident]
				if obj == nil {
					continue
				}
				evalNode, evaled := state.deferEvaledVars.Use(obj)
				if !evaled {
					continue
				}
				v.pass.Reportf(evalNode.Pos(),
					"variable %s evaluated by defer, then reassigned later", ident.Name)
			}
			return false
		default:
			return true
		}
	})
	return state
}
// analyzeDefer records, in the returned state, the variables referenced by
// the defer statement n outside of any function literal. Such references
// (call arguments, the receiver, a function-valued variable) are evaluated
// when the defer statement executes, so later changes to those variables
// are not seen by the deferred call.
func (v *visitor) analyzeDefer(n ast.Node, state visitorState) visitorState {
	ast.Inspect(n, func(n ast.Node) bool {
		switch n := n.(type) {
		case *ast.FuncLit, *ast.FuncDecl:
			// A closure body reads variables when it runs, not when the
			// defer statement executes, so nothing inside it is captured.
			return false
		case *ast.Ident:
			// Only non-field variables can be the plain-identifier target
			// of an assignment or a named result. Package names, functions,
			// methods, types, constants, builtins and struct fields are
			// skipped: they can never be reported, and tracking them only
			// enlarges the state (forcing extra block re-analysis).
			obj, isVar := v.pass.TypesInfo.Uses[n].(*types.Var)
			if !isVar || obj.IsField() {
				return false
			}
			state.deferEvaledVars = state.deferEvaledVars.Add(obj, n)
			return false
		default:
			return true
		}
	})
	return state
}
// visitorState is the state carried from one CFG block to the next while
// walking a path through the function.
type visitorState struct {
	// deferEvaledVars holds the variables a defer statement on the current
	// path has already evaluated, each with the identifier where it was
	// first evaluated.
	deferEvaledVars varSet
}
// Equals reports whether a and b carry the same set of defer-evaluated
// variables, as decided by varSet.Equals.
func (a visitorState) Equals(b visitorState) bool {
	mine, theirs := a.deferEvaledVars, b.deferEvaledVars
	return mine.Equals(theirs)
}
// varSet is an immutable set of variables evaluated by defer statements.
// It maps each variable's types.Object to the node where a defer first
// evaluated it. No method modifies its receiver; Add returns a new set
// when it has to grow.
//
// Keys are the objects themselves rather than types.Object.Id(): Id is
// just the (package-qualified) name, so distinct variables sharing a name,
// such as a shadowed err or a struct field and a local, would otherwise
// collide and produce false reports.
type varSet map[types.Object]ast.Node

// Contains reports whether obj is in the set.
func (vs varSet) Contains(obj types.Object) bool {
	_, exists := vs[obj]
	return exists
}

// Use returns the node at which obj was first recorded and whether obj is
// in the set at all.
func (vs varSet) Use(obj types.Object) (ast.Node, bool) {
	n, exists := vs[obj]
	return n, exists
}

// Add returns a set containing everything in vs plus obj, recorded at n.
// If obj is already present, vs itself is returned and the earlier node is
// kept; otherwise a copy is made so vs is never mutated.
func (vs varSet) Add(obj types.Object, n ast.Node) varSet {
	if _, exists := vs[obj]; exists {
		return vs
	}
	grown := make(varSet, len(vs)+1)
	for k, v := range vs {
		grown[k] = v
	}
	grown[obj] = n
	return grown
}

// Equals reports whether vs and other contain exactly the same variables.
// Only the keys are compared; two sets holding the same variables with
// different recorded nodes are considered equal.
func (vs varSet) Equals(other varSet) bool {
	if len(vs) != len(other) {
		return false
	}
	// Equal sizes plus vs ⊆ other implies the sets are equal.
	for k := range vs {
		if _, exists := other[k]; !exists {
			return false
		}
	}
	return true
}