/
importer.go
126 lines (116 loc) · 2.61 KB
/
importer.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
package parser
import (
"errors"
"go/ast"
"go/parser"
"go/token"
"go/types"
"io/ioutil"
"os"
"path"
"path/filepath"
"strings"
"sync"
)
// CREDIT
// Most of this was stolen from https://github.com/ernesto-jimenez/gogen/blob/master/importer/importer.go - Copyright (c) 2015 Ernesto Jiménez
// which has an MIT license
// The GOPATH entries are computed on first use and then shared by every
// caller for the rest of the process.
var (
	gopathlistOnce  sync.Once
	gopathlistCache []string
)

// gopathlist returns the directories listed in the GOPATH environment
// variable, split on the OS path-list separator. The environment is only
// read on the first call; later calls return the cached slice.
func gopathlist() []string {
	gopathlistOnce.Do(func() {
		raw := os.Getenv("GOPATH")
		gopathlistCache = filepath.SplitList(raw)
	})
	return gopathlistCache
}
// smartImporter is a types.Importer that first tries to load a package from
// its source directory under GOPATH and delegates to a base importer when
// that is not possible.
type smartImporter struct {
// base is the importer that imports are delegated to when a package
// cannot be located, listed, or type-checked from source.
base types.Importer
// packages caches successfully imported packages, keyed by import path.
packages map[string]*types.Package
}
// newSmartImporter builds a smartImporter that delegates to base and starts
// with an empty package cache.
func newSmartImporter(base types.Importer) *smartImporter {
	imp := new(smartImporter)
	imp.base = base
	imp.packages = map[string]*types.Package{}
	return imp
}
// Import implements types.Importer.
//
// An empty path, or one starting with '.', is treated as a directory: it is
// cleaned, made absolute, and has its GOPATH source prefix removed via
// stripGopath. The resulting path is looked up in the cache; on a miss the
// package is loaded with doimport and cached. Failed imports are not cached.
func (i *smartImporter) Import(path string) (*types.Package, error) {
	if path == "" || path[0] == '.' {
		abs, err := filepath.Abs(filepath.Clean(path))
		if err != nil {
			return nil, err
		}
		path = stripGopath(abs)
	}

	if cached, found := i.packages[path]; found {
		return cached, nil
	}

	loaded, err := i.doimport(path)
	if err != nil {
		return nil, err
	}
	i.packages[path] = loaded
	return loaded, nil
}
// doimport type-checks package p from its source files under GOPATH.
//
// The package directory is located with lookupImport. Every non-directory
// entry whose name has the ".go" extension is parsed, and the parsed files
// are type-checked with i itself as the importer, so transitive imports go
// through the same lookup and cache.
//
// When the directory cannot be found or listed, or when type-checking
// fails, the import is delegated to i.base. A file that cannot be read or
// parsed aborts the import and that error is returned.
func (i *smartImporter) doimport(p string) (*types.Package, error) {
	dir, err := lookupImport(p)
	if err != nil {
		return i.base.Import(p)
	}
	entries, err := ioutil.ReadDir(dir)
	if err != nil {
		return i.base.Import(p)
	}

	fset := token.NewFileSet()
	var files []*ast.File
	for _, entry := range entries {
		if entry.IsDir() {
			continue
		}
		name := entry.Name()
		// name is a bare file name (no separators), so the slash-only
		// path.Ext is sufficient here.
		if path.Ext(name) != ".go" {
			continue
		}
		// NOTE: _test.go files are not filtered out; they are parsed and
		// type-checked together with the package's other files.

		// dir is an OS path, so it is joined with filepath rather than path.
		// Passing a nil src makes ParseFile read the file itself.
		f, err := parser.ParseFile(fset, filepath.Join(dir, name), nil, 0)
		if err != nil {
			return nil, err
		}
		files = append(files, f)
	}

	conf := types.Config{Importer: i}
	pkg, err := conf.Check(p, fset, files, nil)
	if err != nil {
		return i.base.Import(p)
	}
	return pkg, nil
}
// lookupImport returns the absolute directory that holds the source of
// import path p, checking <gopath>/src/<p> for each GOPATH entry in order and
// returning the first one that exists and is a directory.
//
// Empty GOPATH entries are skipped: joining an empty entry would yield a
// path relative to the working directory rather than a location in GOPATH.
// An error is returned when no entry contains the package directory.
func lookupImport(p string) (string, error) {
	// Import paths are slash-separated; convert once to the OS form.
	rel := filepath.FromSlash(p)
	for _, gopath := range gopathlist() {
		if gopath == "" {
			continue
		}
		absPath, err := filepath.Abs(filepath.Join(gopath, "src", rel))
		if err != nil {
			return "", err
		}
		if info, err := os.Stat(absPath); err == nil && info.IsDir() {
			return absPath, nil
		}
	}
	return "", errors.New("not in GOPATH: " + p)
}
// stripGopath converts an absolute directory inside a GOPATH source tree to
// its import path by removing the leading "<gopath>/src/" of the first
// GOPATH entry that p lies under. The remainder is returned slash-separated.
// If p is not under any GOPATH entry's src directory it is returned
// unchanged.
func stripGopath(p string) string {
	for _, gopath := range gopathlist() {
		prefix := filepath.Join(gopath, "src") + string(filepath.Separator)
		// Only a true prefix counts; the same text appearing in the middle
		// of p must be left alone.
		if strings.HasPrefix(p, prefix) {
			return filepath.ToSlash(p[len(prefix):])
		}
	}
	return p
}