forked from utrack/pontoon
/
reghttp.go
127 lines (104 loc) · 2.65 KB
/
reghttp.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
package main
import (
"fmt"
"go/ast"
"go/types"
"reflect"
"strings"
"github.com/pkg/errors"
"golang.org/x/tools/go/ast/astutil"
"golang.org/x/tools/go/packages"
)
// getHandlerNames scans RegisterHTTP's function body and
// returns registered ops/paths/handler function AST.
//
// sel is the function whose body is scanned. Its body scope is used to
// locate the enclosing *ast.FuncDecl, and the first parameter of that
// declaration is treated as the router on which handlers are registered.
//
// An error is returned when the source file or the declaration cannot be
// located, or when the function has no body scope or no parameters.
// Registration arguments that the visitor cannot resolve to literals make
// the walk panic.
func (b builder) getHandlerNames(sel *types.Func) ([]hdlPathPtr, error) {
	af, err := b.astFindFile(sel.Pos())
	if err != nil {
		return nil, err
	}

	// go/types reports a nil scope for functions whose body it has not
	// seen (e.g. imported ones); guard before asking for its extent.
	scope := sel.Scope()
	if scope == nil {
		return nil, errors.Errorf("function %s has no body scope", sel.Name())
	}

	reg, exact := astutil.PathEnclosingInterval(af, scope.Pos(), scope.End())
	if !exact {
		return nil, errors.New("cannot find exact func path")
	}
	// The path is expected to start with the body block followed by the
	// declaration that owns it.
	if len(reg) < 2 {
		return nil, errors.Errorf("function %s: enclosing path is too short (%d nodes)", sel.Name(), len(reg))
	}
	funcBody, ok := reg[0].(*ast.BlockStmt)
	if !ok {
		return nil, errors.Errorf("function %s: expected body block, got %T", sel.Name(), reg[0])
	}
	funcDecl, ok := reg[1].(*ast.FuncDecl)
	if !ok {
		return nil, errors.Errorf("function %s: expected function declaration, got %T", sel.Name(), reg[1])
	}
	params := funcDecl.Type.Params
	if params == nil || len(params.List) == 0 {
		return nil, errors.Errorf("function %s has no router parameter", sel.Name())
	}

	vis := &visRegHTTP{
		muxDecl: params.List[0],
		pkg:     sel.Pkg(),
		pkgReg:  b.pkg,
	}
	ast.Walk(vis, funcBody)
	return vis.hits, nil
}
// visRegHTTP is an ast.Visitor that collects handler registrations made
// through calls on one specific router parameter.
type visRegHTTP struct {
// muxDecl is the parameter declaration of the router; Visit only
// considers calls whose receiver identifier is declared by this field.
muxDecl *ast.Field
// hits accumulates one entry per registration call found by Visit.
hits []hdlPathPtr
// pkg is the package whose imports are searched when litFromExpr
// resolves a selector expression such as otherpkg.Const.
pkg *types.Package
// pkgReg is the loaded package whose Imports map is used by litFromExpr
// to find the source file of a declaration in an imported package.
pkgReg *packages.Package
}
// Visit implements ast.Visitor. It looks for calls of the form
//
//	mux.Method(op, path, handler)
//
// where mux is the identifier declared by vr.muxDecl, and records each of
// them in vr.hits. op and path are resolved to literals with litFromExpr
// and stored with surrounding double quotes trimmed; handler must be a
// selector expression and is stored as is.
//
// Nodes that are not such a call are skipped and the walk continues into
// their children. The children of a recorded call are not visited.
// Visit panics when the handler argument is not a selector expression or
// when litFromExpr cannot resolve op or path.
func (vr *visRegHTTP) Visit(node ast.Node) ast.Visitor {
	if node == nil {
		return nil
	}
	cv, ok := node.(*ast.CallExpr)
	if !ok {
		return vr
	}
	se, ok := cv.Fun.(*ast.SelectorExpr)
	if !ok {
		return vr
	}
	muxIdent, ok := se.X.(*ast.Ident)
	if !ok {
		return vr
	}
	// Obj is nil for identifiers the parser did not resolve (for example
	// imported package names), so such calls cannot be router calls.
	if muxIdent.Obj == nil || muxIdent.Obj.Decl != vr.muxDecl {
		return vr
	}
	// Router calls with fewer than three arguments are not registrations
	// in the (op, path, handler) shape; skip them instead of indexing out
	// of range.
	if len(cv.Args) < 3 {
		return vr
	}

	argHandlerFunc, ok := cv.Args[2].(*ast.SelectorExpr)
	if !ok {
		panic(fmt.Sprintf("visRegHTTP: handler argument of %s.%s must be a selector expression, got %T",
			muxIdent.Name, se.Sel.Name, cv.Args[2]))
	}
	argOp := vr.litFromExpr(cv.Args[0])
	argPath := vr.litFromExpr(cv.Args[1])

	vr.hits = append(vr.hits, hdlPathPtr{
		op:   strings.Trim(argOp.Value, `"`),
		path: strings.Trim(argPath.Value, `"`),
		fn:   argHandlerFunc,
	})
	return nil
}
// litFromExpr resolves ex to the basic literal it ultimately denotes.
//
// Supported shapes:
//   - *ast.BasicLit is returned as is;
//   - *ast.Ident is followed to the const/var spec that declares it;
//   - *ast.SelectorExpr (pkg.Name) is looked up in the syntax of the
//     imported package;
//   - *ast.ValueSpec is resolved through its first declared name.
//
// It panics with a descriptive message when ex has any other shape or
// when a declaration needed for the resolution cannot be found.
func (vr *visRegHTTP) litFromExpr(ex ast.Node) *ast.BasicLit {
	switch v := ex.(type) {
	case *ast.BasicLit:
		return v
	case *ast.Ident:
		if v.Obj == nil {
			panic(fmt.Sprintf("litFromExpr: identifier %q has no resolved declaration", v.Name))
		}
		spec, ok := v.Obj.Decl.(*ast.ValueSpec)
		if !ok {
			panic(fmt.Sprintf("litFromExpr: %q is declared by %T, want *ast.ValueSpec", v.Name, v.Obj.Decl))
		}
		return vr.litFromSpec(spec, v.Name)
	case *ast.SelectorExpr: // usually reference from another package
		return vr.litFromSpec(vr.importedSpec(v), v.Sel.Name)
	case *ast.ValueSpec:
		if len(v.Names) == 0 {
			panic("litFromExpr: value spec declares no names")
		}
		return vr.litFromSpec(v, v.Names[0].Name)
	default:
		panic(fmt.Sprintf("litFromExpr: cannot convert %v (%v) to BasicLit", ex, reflect.TypeOf(ex).String()))
	}
}

// litFromSpec resolves the initializer that spec assigns to name, picking
// the value at the same position as name in a multi-name declaration.
func (vr *visRegHTTP) litFromSpec(spec *ast.ValueSpec, name string) *ast.BasicLit {
	for i, n := range spec.Names {
		if n.Name != name {
			continue
		}
		if i >= len(spec.Values) {
			panic(fmt.Sprintf("litFromExpr: %q has no explicit initializer", name))
		}
		return vr.litFromExpr(spec.Values[i])
	}
	panic(fmt.Sprintf("litFromExpr: %q is not declared by the located spec", name))
}

// importedSpec finds the const/var spec that declares sel.Sel in the
// imported package named by sel.X.
func (vr *visRegHTTP) importedSpec(sel *ast.SelectorExpr) *ast.ValueSpec {
	pkgIdent, ok := sel.X.(*ast.Ident)
	if !ok {
		panic(fmt.Sprintf("litFromExpr: unsupported selector base %T", sel.X))
	}
	name := sel.Sel.Name

	pkg := findImportedPackage(vr.pkg, pkgIdent.Name)
	if pkg == nil {
		panic(fmt.Sprintf("litFromExpr: no imported package named %q", pkgIdent.Name))
	}
	obj := pkg.Scope().Lookup(name)
	if obj == nil {
		panic(fmt.Sprintf("litFromExpr: %s.%s is not declared at package level", pkgIdent.Name, name))
	}
	imp := vr.pkgReg.Imports[pkg.Path()]
	if imp == nil {
		panic(fmt.Sprintf("litFromExpr: package %q is not among the loaded imports", pkg.Path()))
	}
	// The second result is dropped as before; a failed lookup is detected
	// through the nil file instead.
	f, _ := astFindFile(imp, obj.Pos())
	if f == nil {
		panic(fmt.Sprintf("litFromExpr: cannot find source file declaring %s.%s", pkgIdent.Name, name))
	}

	// The innermost enclosing node is the GenDecl when the name starts a
	// spec, but the ValueSpec itself when it is a later name of a
	// multi-name spec, so search the whole path for the declaration.
	path, _ := astutil.PathEnclosingInterval(f, obj.Pos()-1, obj.Pos())
	for _, n := range path {
		decl, ok := n.(*ast.GenDecl)
		if !ok {
			continue
		}
		for _, s := range decl.Specs {
			vs, ok := s.(*ast.ValueSpec)
			if !ok {
				continue
			}
			for _, id := range vs.Names {
				if id.Name == name {
					return vs
				}
			}
		}
	}
	panic(fmt.Sprintf("litFromExpr: cannot find const/var declaration of %s.%s", pkgIdent.Name, name))
}
// findImportedPackage returns the first direct import of pkg whose
// declared package name equals name, or nil when there is none.
func findImportedPackage(pkg *types.Package, name string) *types.Package {
	imports := pkg.Imports()
	for i := 0; i < len(imports); i++ {
		candidate := imports[i]
		if candidate.Name() != name {
			continue
		}
		return candidate
	}
	return nil
}