forked from src-d/go-parse-utils
/
ast.go
67 lines (54 loc) · 1.6 KB
/
ast.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
package parseutil
import (
"errors"
"go/ast"
"go/parser"
"go/token"
"strings"
)
var (
	// ErrTooManyPackages is the sentinel error reported when a directory that
	// should contain exactly one Go package (after filtering) does not.
	// Despite its name and message, the parsing helpers in this file also
	// return it when no package matches at all.
	ErrTooManyPackages = errors.New("more than one package found in a directory")
)
// PackageAST parses the Go sources in the directory at path and returns the
// AST of the one package whose name does not end in "_test" (external test
// packages are skipped). ErrTooManyPackages is returned unless exactly one
// such package is found.
//
// NOTE(review): the directory is parsed without a per-file filter, so
// _test.go files that declare the non-test package name are presumably
// included in the returned package — confirm this is intended.
func PackageAST(path string) (pkg *ast.Package, err error) {
	notExternalTest := func(name string, _ *ast.Package) bool {
		return !strings.HasSuffix(name, "_test")
	}
	return parseAndFilterPackages(path, notExternalTest)
}
// PackageTestAST parses the Go sources in the directory at path and returns
// the AST of its external test package, i.e. the package whose name ends in
// "_test". ErrTooManyPackages is returned unless exactly one such package is
// found (including when the directory has no external test package).
func PackageTestAST(path string) (pkg *ast.Package, err error) {
	isExternalTest := func(name string, _ *ast.Package) bool {
		return strings.HasSuffix(name, "_test")
	}
	pkg, err = parseAndFilterPackages(path, isExternalTest)
	return pkg, err
}
// packageFilter decides whether a parsed package, identified by its package
// name, should be kept; it reports true to keep the package.
type packageFilter func(name string, pkg *ast.Package) bool
// parseAndFilterPackages resolves path with DefaultGoPath.Abs, parses every
// Go file in that directory (keeping comments), applies filter to the parsed
// packages and returns the single package that survives.
//
// If the number of surviving packages is anything other than exactly one —
// including zero — ErrTooManyPackages is returned. Errors from path
// resolution or parsing are passed through unwrapped, so callers can still
// inspect them directly.
//
// NOTE(review): the token.FileSet used for parsing is discarded, so callers
// cannot map token.Pos values in the returned AST back to file positions.
func parseAndFilterPackages(path string, filter packageFilter) (pkg *ast.Package, err error) {
	fset := token.NewFileSet()
	srcDir, err := DefaultGoPath.Abs(path)
	if err != nil {
		return nil, err
	}

	pkgs, err := parser.ParseDir(fset, srcDir, nil, parser.ParseComments)
	if err != nil {
		return nil, err
	}

	pkgs = filterPkgs(pkgs, filter)
	if len(pkgs) != 1 {
		return nil, ErrTooManyPackages
	}

	// A map has no direct accessor for its only entry, so pull it out with a
	// range loop and return it explicitly rather than via a naked return.
	for _, p := range pkgs {
		pkg = p
	}
	return pkg, nil
}
// filterPkgs returns a new map holding only the entries of pkgs for which
// filter reports true. The input map is not modified, and the result is a
// non-nil (possibly empty) map.
func filterPkgs(pkgs map[string]*ast.Package, filter packageFilter) map[string]*ast.Package {
	kept := map[string]*ast.Package{}
	for name, p := range pkgs {
		if !filter(name, p) {
			continue
		}
		kept[name] = p
	}
	return kept
}