/
helpers.go
129 lines (111 loc) · 2.73 KB
/
helpers.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
// Copyright (c) 2015-present Mattermost, Inc. All Rights Reserved.
// See LICENSE.txt for license information.
package asthelpers
import (
"go/ast"
"go/types"
"github.com/pkg/errors"
"golang.org/x/tools/go/packages"
)
// GetPackage loads the package identified by pkgPath with its name, type
// information, syntax trees and types.Info populated.
//
// An error is returned when loading fails, when no package matches pkgPath,
// or when the loaded package reports errors of its own. go/packages records
// per-package problems (such as an import path that cannot be resolved, or
// type-checking failures) in Package.Errors rather than in Load's error
// result, so without that check a missing or broken package would be handed
// back as if it were valid.
//
// If pkgPath is a pattern matching several packages, only the first one is
// returned.
func GetPackage(pkgPath string) (*packages.Package, error) {
	cfg := &packages.Config{
		Mode: packages.NeedName | packages.NeedTypes | packages.NeedSyntax | packages.NeedTypesInfo,
	}
	pkgs, err := packages.Load(cfg, pkgPath)
	if err != nil {
		return nil, errors.Wrapf(err, "loading package %s", pkgPath)
	}
	if len(pkgs) == 0 {
		return nil, errors.Errorf("could not find package %s", pkgPath)
	}
	pkg := pkgs[0]
	if len(pkg.Errors) > 0 {
		// packages.Error implements error, so %v prints each message.
		return nil, errors.Errorf("errors loading package %s: %v", pkgPath, pkg.Errors)
	}
	return pkg, nil
}
// FindInterface looks up the interface type declared as name in files.
// It is a convenience wrapper around FindInterfaceWithIdent that drops the
// identifier naming the interface; any lookup error is passed through as is.
func FindInterface(name string, files []*ast.File) (*ast.InterfaceType, error) {
	found, _, lookupErr := FindInterfaceWithIdent(name, files)
	if lookupErr != nil {
		return found, lookupErr
	}
	return found, nil
}
// FindInterfaceWithIdent searches files, in order, for a type spec named name
// whose type expression is an interface literal. It returns that interface
// node together with the identifier that names it.
//
// The first match in traversal order wins. Type specs nested inside function
// bodies are considered as well as top-level ones. If no file contains such a
// declaration, both nodes are nil and an error is returned.
func FindInterfaceWithIdent(name string, files []*ast.File) (*ast.InterfaceType, *ast.Ident, error) {
	var (
		found   *ast.InterfaceType
		foundAt *ast.Ident
	)
	visit := func(n ast.Node) bool {
		if found != nil {
			// Already located; prune the remainder of the walk.
			return false
		}
		spec, isSpec := n.(*ast.TypeSpec)
		if !isSpec {
			return true
		}
		it, isIface := spec.Type.(*ast.InterfaceType)
		if !isIface || spec.Name.Name != name {
			return true
		}
		found, foundAt = it, spec.Name
		return false
	}
	for _, file := range files {
		ast.Inspect(file, visit)
		if found != nil {
			return found, foundAt, nil
		}
	}
	return nil, nil, errors.Errorf("could not find %s interface", name)
}
// FindMethodsCalledOnType returns, in traversal order and including
// duplicates, the selector names used in caller on a receiver whose type is
// identical to typ. For each selector expression X.Sel, the receiver is X
// itself when X is an identifier (a.GetTeams()), or the last identifier of X
// when X is a selector (p.API.GetTeams()). It is resolved through info.
//
// Selector names are collected whether they denote methods or fields. A nil
// caller yields nil.
func FindMethodsCalledOnType(info *types.Info, typ types.Type, caller *ast.FuncDecl) []string {
	if caller == nil {
		return nil
	}
	var methods []string
	ast.Inspect(caller, func(n ast.Node) bool {
		s, ok := n.(*ast.SelectorExpr)
		if !ok {
			return true
		}
		var receiver *ast.Ident
		switch r := s.X.(type) {
		case *ast.Ident:
			// Left-hand side of the selector is an identifier, eg:
			//
			//	a := p.API
			//	a.GetTeams()
			receiver = r
		case *ast.SelectorExpr:
			// Left-hand side of the selector is a selector, eg:
			//
			//	p.API.GetTeams()
			receiver = r.Sel
		}
		if receiver != nil {
			if obj := info.ObjectOf(receiver); obj != nil && types.Identical(obj.Type(), typ) {
				methods = append(methods, s.Sel.Name)
			}
		}
		// Always keep descending: s.X may itself contain calls on typ, e.g.
		// the argument in wrap(p.API.GetConfig()).API.GetTeams(). Pruning
		// here would silently miss them.
		return true
	})
	return methods
}
// FindReceiverMethods returns, in source order, every function declaration in
// files for which extractReceiverTypeName reports receiverName. Passing ""
// therefore matches plain functions, along with any declaration whose receiver
// type name could not be extracted.
func FindReceiverMethods(receiverName string, files []*ast.File) []*ast.FuncDecl {
	var methods []*ast.FuncDecl
	for _, file := range files {
		// Function declarations only occur at the top level of a file, so
		// scanning Decls visits exactly the FuncDecls a full walk would.
		for _, decl := range file.Decls {
			fn, isFunc := decl.(*ast.FuncDecl)
			if isFunc && extractReceiverTypeName(fn) == receiverName {
				methods = append(methods, fn)
			}
		}
	}
	return methods
}
// extractReceiverTypeName returns the base type name of fn's receiver. It
// returns "" when fn has no receiver (a plain function), when the receiver
// list is empty (as can happen in ASTs built from code with parse errors), or
// when the receiver type does not reduce to an identifier.
//
// Pointer, parenthesized and generic receivers are unwrapped, so
// `func (t *T)`, `func (t (*T))` and `func (t *T[K, V])` all yield "T".
func extractReceiverTypeName(fn *ast.FuncDecl) string {
	if fn == nil || fn.Recv == nil || len(fn.Recv.List) == 0 {
		return ""
	}
	t := fn.Recv.List[0].Type
	for {
		switch e := t.(type) {
		case *ast.StarExpr: // *T
			t = e.X
		case *ast.ParenExpr: // (T)
			t = e.X
		case *ast.IndexExpr: // T[K]
			t = e.X
		case *ast.IndexListExpr: // T[K, V]
			t = e.X
		case *ast.Ident:
			return e.Name
		default:
			return ""
		}
	}
}