forked from golang/vuln
/
witness.go
394 lines (349 loc) · 10.1 KB
/
witness.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
// Copyright 2021 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package vulncheck
import (
"container/list"
"fmt"
"go/token"
"sort"
"strings"
"sync"
)
// ImportChain is a sequence of packages in which each package imports
// the package that immediately follows it. The chain starts with a
// client (entry) package and ends in a package with some known
// vulnerabilities.
type ImportChain []*PkgNode
// ImportChains returns representative import chains for each vulnerability
// in res. Each chain starts at an entry package (res.Imports.Entries) and
// ends at the vulnerability's import sink (Vuln.ImportSink). The chains for
// a vulnerability are ordered by non-decreasing length.
//
// The chains are found by a breadth-first search over res.Imports that
// starts at the vulnerable package and follows ImportedBy edges upward until
// reaching entry packages. During this search a package is expanded only
// once, to avoid analyzing every possible import chain; hence not all import
// chains are reported.
//
// Vulnerabilities with the same import sink share the same slice of chains,
// and a vulnerability without an import sink (ImportSink == 0) maps to nil.
func ImportChains(res *Result) map[*Vuln][]ImportChain {
	// Group vulns per package so each sink is searched only once.
	vPerPkg := make(map[int][]*Vuln)
	for _, v := range res.Vulns {
		vPerPkg[v.ImportSink] = append(vPerPkg[v.ImportSink], v)
	}

	// Search every sink concurrently; mu guards chains.
	var (
		wg sync.WaitGroup
		mu sync.Mutex
	)
	chains := make(map[*Vuln][]ImportChain)
	for pkgID, vulns := range vPerPkg {
		pkgID, vulns := pkgID, vulns // per-iteration copies for the goroutine (pre-Go 1.22 semantics)
		wg.Add(1)
		go func() {
			defer wg.Done()
			pChains := importChains(pkgID, res)
			mu.Lock()
			defer mu.Unlock()
			for _, v := range vulns {
				chains[v] = pChains
			}
		}()
	}
	wg.Wait()
	return chains
}
// importChains collects representative import chains that end at the
// package with ID vulnSinkID and start at an entry package of res.Imports.
// An ID of 0 means "no sink" and yields nil.
//
// Importers are explored breadth-first, and each package is expanded at most
// once. Because of this, chains are discovered in non-decreasing length.
// Every time an importer is an entry package, the chain through it is
// recorded. The entry is still explored further, since entries may
// themselves be imported by other entries.
func importChains(vulnSinkID int, res *Result) []ImportChain {
	if vulnSinkID == 0 {
		return nil
	}

	isEntry := make(map[int]bool)
	for _, id := range res.Imports.Entries {
		isEntry[id] = true
	}

	var found []ImportChain
	visited := make(map[int]bool)
	work := list.New()
	work.PushBack(&importChain{pkg: res.Imports.Packages[vulnSinkID]})
	for elem := work.Front(); elem != nil; elem = work.Front() {
		cur := elem.Value.(*importChain)
		work.Remove(elem)

		if visited[cur.pkg.ID] {
			continue
		}
		visited[cur.pkg.ID] = true

		for _, importerID := range cur.pkg.ImportedBy {
			next := &importChain{pkg: res.Imports.Packages[importerID], child: cur}
			if isEntry[next.pkg.ID] {
				found = append(found, next.ImportChain())
			}
			work.PushBack(next)
		}
	}
	return found
}
// importChain is a singly linked chain of package imports built during the
// import-graph search: pkg imports child.pkg, and the chain ends at the link
// whose child is nil (the vulnerable package the search started from).
type importChain struct {
	pkg   *PkgNode
	child *importChain
}

// ImportChain converts the linked chain starting at r into an ImportChain.
// The packages are listed from r to the end of the chain, and a nil r
// yields nil.
//
// The chain is walked iteratively, which builds the result in linear time.
// A recursive prepend would copy the tail at every level.
func (r *importChain) ImportChain() ImportChain {
	var chain ImportChain
	for n := r; n != nil; n = n.child {
		chain = append(chain, n.pkg)
	}
	return chain
}
// CallStack is a call stack starting with a client (entry) function or
// method and ending with the frame of a vulnerable symbol. Each entry's
// Call is the call site through which that frame invokes the next one;
// the last entry has a nil Call.
type CallStack []StackEntry
// StackEntry is an element of a call stack: a function frame together
// with the call site that leads to the next frame.
type StackEntry struct {
	// Function whose frame is on the stack.
	Function *FuncNode

	// Call is the call site in Function that induces the next stack frame.
	// It is nil when the frame is the last frame in the stack.
	Call *CallSite
}
// CallStacks returns representative call stacks for each vulnerability in
// res. Each stack starts at an entry function or method (res.Calls.Entries)
// and ends at the vulnerable symbol (Vuln.CallSink).
//
// The stacks of a vulnerability are sorted with stackLess, a heuristic for
// how easy a stack is to understand. Stacks through fewer standard-library
// functions come first, then shorter stacks, then stacks with fewer
// unresolved (dynamic) call sites.
//
// The stacks are found by a breadth-first search over res.Calls. It starts at
// the vulnerable symbol and walks up through callers until reaching entry
// functions. Each function is visited at most once to avoid a potential
// exponential explosion; hence not all call stacks are reported. A
// vulnerability without a call sink (CallSink == 0) maps to nil.
func CallStacks(res *Result) map[*Vuln][]CallStack {
	var (
		wg sync.WaitGroup
		mu sync.Mutex // guards stacksPerVuln
	)
	stacksPerVuln := make(map[*Vuln][]CallStack)
	for _, vuln := range res.Vulns {
		vuln := vuln // per-iteration copy for the goroutine (pre-Go 1.22 semantics)
		wg.Add(1)
		go func() {
			defer wg.Done()
			cs := callStacks(vuln.CallSink, res)
			// Sort call stacks by their estimated value to the user.
			sort.SliceStable(cs, func(i, j int) bool { return stackLess(cs[i], cs[j]) })
			mu.Lock()
			defer mu.Unlock()
			stacksPerVuln[vuln] = cs
		}()
	}
	wg.Wait()
	return stacksPerVuln
}
// callStacks collects representative call stacks that end at the function
// with ID vulnSinkID and start at an entry function of res.Calls. An ID of 0
// means "no sink" and yields nil.
//
// Callers are explored breadth-first, and each function is expanded at most
// once. Because of this, stacks are discovered in non-decreasing length.
// Whenever a caller is an entry function, the stack through it is recorded.
// The caller is still explored further.
func callStacks(vulnSinkID int, res *Result) []CallStack {
	if vulnSinkID == 0 {
		return nil
	}

	isEntry := make(map[int]bool)
	for _, id := range res.Calls.Entries {
		isEntry[id] = true
	}

	var found []CallStack
	visited := make(map[int]bool)
	work := list.New()
	work.PushBack(&callChain{f: res.Calls.Functions[vulnSinkID]})
	for elem := work.Front(); elem != nil; elem = work.Front() {
		cur := elem.Value.(*callChain)
		work.Remove(elem)

		if visited[cur.f.ID] {
			continue
		}
		visited[cur.f.ID] = true

		// One call site per caller, picked in deterministic order, is enough
		// because every function is expanded only once.
		for _, site := range callsites(cur.f.CallSites, res, visited) {
			next := &callChain{
				f:     res.Calls.Functions[site.Parent],
				call:  site,
				child: cur,
			}
			if isEntry[next.f.ID] {
				found = append(found, next.CallStack())
			}
			work.PushBack(next)
		}
	}
	return found
}
// callsites selects one call site from sites for every caller function
// (CallSite.Parent) that is not marked in visited. For each caller, the
// smallest site according to csLess is kept. The selected sites are
// returned ordered by their caller functions according to funcLess. All
// sites are assumed to share the same callee.
func callsites(sites []*CallSite, result *Result, visited map[int]bool) []*CallSite {
	best := make(map[int]*CallSite)
	for _, site := range sites {
		caller := site.Parent
		if !visited[caller] && csLess(site, best[caller]) {
			best[caller] = site
		}
	}

	// Map iteration order is random; sorting the callers restores determinism.
	callers := make([]*FuncNode, 0, len(best))
	for id := range best {
		callers = append(callers, result.Calls.Functions[id])
	}
	sort.SliceStable(callers, func(a, b int) bool { return funcLess(callers[a], callers[b]) })

	var picked []*CallSite
	for _, fn := range callers {
		picked = append(picked, best[fn.ID])
	}
	return picked
}
// callChain is a singly linked chain of function calls built during the
// call-graph search. f calls child.f through the call site call. The chain
// ends at the link whose child is nil, which is the vulnerable function the
// search started from.
type callChain struct {
	call  *CallSite // nil for the last link (the searched-for sink function)
	f     *FuncNode
	child *callChain
}

// CallStack converts the linked chain starting at c into a CallStack, with
// one StackEntry per link from c to the end of the chain. A nil c yields nil.
//
// The chain is walked iteratively, which builds the result in linear time.
// A recursive prepend would copy the tail at every level.
func (c *callChain) CallStack() CallStack {
	var stack CallStack
	for n := c; n != nil; n = n.child {
		stack = append(stack, StackEntry{Function: n.f, Call: n.call})
	}
	return stack
}
// weight estimates how hard a call stack is to understand when it is shown
// to the client as a witness: smaller is more understandable. It is
// currently the number of entries whose call site exists but is not
// resolved (i.e. a dynamic call).
func weight(stack CallStack) int {
	unresolved := 0
	for i := range stack {
		if call := stack[i].Call; call != nil && !call.Resolved {
			unresolved++
		}
	}
	return unresolved
}
// isStdPackage reports whether the import path pkg looks like a standard
// library package. It uses the heuristic that the first path element of
// standard library packages contains no "." (for instance, see Contains in
// pkgsite/+/refs/heads/master/internal/stdlib/stdlib.go). The empty path
// is not considered standard.
func isStdPackage(pkg string) bool {
	if len(pkg) == 0 {
		return false
	}
	first := pkg
	if slash := strings.IndexByte(pkg, '/'); slash >= 0 {
		first = pkg[:slash]
	}
	return strings.IndexByte(first, '.') < 0
}
// confidence estimates whether a call stack is realizable in practice. It
// is currently the number of stack frames whose function belongs to a
// standard library package (see isStdPackage). Such stacks have been
// experimentally shown to often be false positives, so a lower value means
// higher confidence.
func confidence(stack CallStack) int {
	stdFrames := 0
	for i := range stack {
		if isStdPackage(stack[i].Function.PkgPath) {
			stdFrames++
		}
	}
	return stdFrames
}
// stackLess reports whether call stack s1 should be presented before s2,
// based on their estimated value to the user. Stacks are compared
// lexicographically by:
//  1. confidence: the number of standard-library frames (fewer first),
//  2. length (shorter first), and
//  3. weight: the number of unresolved, dynamic call sites (fewer first).
//
// Stacks that tie on all three keys are equivalent, and stackLess returns
// false for them. This keeps it a strict weak ordering, as sort.SliceStable
// requires, and lets the stable sort keep equivalent stacks in the
// deterministic order produced by the search.
func stackLess(s1, s2 CallStack) bool {
	if c1, c2 := confidence(s1), confidence(s2); c1 != c2 {
		return c1 < c2
	}
	if len(s1) != len(s2) {
		return len(s1) < len(s2)
	}
	if w1, w2 := weight(s1), weight(s2); w1 != w2 {
		return w1 < w2
	}
	// Equivalent stacks: returning true here would make less(x, x) true
	// and cause SliceStable to reorder ties.
	return false
}
// csLess reports whether call site cs1 orders before cs2.
//
// A nil cs2 counts as larger than any site, so callers can keep a running
// minimum that starts from an absent (nil) entry. Sites with a position
// come before sites without one, and two positioned sites are compared with
// posLess. When the positions are equal or both are missing, each site is
// compared by its own "RecvType.Name" string representation.
func csLess(cs1, cs2 *CallSite) bool {
	if cs2 == nil {
		return true
	}
	p1, p2 := cs1.Pos, cs2.Pos
	switch {
	case p1 != nil && p2 != nil:
		// Common case: order by source position.
		if posLess(*p1, *p2) {
			return true
		}
		if posLess(*p2, *p1) {
			return false
		}
		// Identical positions; should not occur in practice. Fall through
		// to the string comparison below.
	case p1 != nil:
		// Only cs2 lacks a position.
		return true
	case p2 != nil:
		// Only cs1 lacks a position.
		return false
	}
	// Positions are equal or both missing.
	return fmt.Sprintf("%v.%v", cs1.RecvType, cs1.Name) < fmt.Sprintf("%v.%v", cs2.RecvType, cs2.Name)
}
// posLess reports whether position p1 orders before p2. It compares the
// line number first, then the column, and finally the file name when both
// are equal.
func posLess(p1, p2 token.Position) bool {
	switch {
	case p1.Line != p2.Line:
		return p1.Line < p2.Line
	case p1.Column != p2.Column:
		return p1.Column < p2.Column
	default:
		return strings.Compare(p1.Filename, p2.Filename) < 0
	}
}
// funcLess reports whether function node f1 orders before f2.
//
// Functions with a position come before functions without one, and two
// positioned functions are compared with posLess. When the positions are
// equal or both are missing (which should happen only for inits), the
// functions are compared by their String representation. Both-missing
// previously returned true in both directions, which broke the strict weak
// ordering sort.SliceStable needs.
func funcLess(f1, f2 *FuncNode) bool {
	p1, p2 := f1.Pos, f2.Pos
	switch {
	case p1 != nil && p2 != nil:
		if posLess(*p1, *p2) {
			return true
		}
		if posLess(*p2, *p1) {
			return false
		}
		// Identical positions; should not occur in practice. Fall through
		// to the string comparison below.
	case p1 != nil:
		// Only f2 lacks a position.
		return true
	case p2 != nil:
		// Only f1 lacks a position.
		return false
	}
	// Positions are equal or both missing.
	return f1.String() < f2.String()
}