/
extractor.go
153 lines (129 loc) · 3.38 KB
/
extractor.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
package pkgextract
import (
	"bufio"
	"container/list"
	"fmt"
	"go/ast"
	"go/build"
	"go/parser"
	"go/printer"
	"go/token"
	"os"
	"path/filepath"
	"strconv"
)
type (
	// rewriterFunc maps the original import path of an extracted package to
	// the import path its extracted copy should have.
	rewriterFunc func(importPath string) string

	// filterFunc reports whether the imported package at importPath should be
	// extracted as well.
	filterFunc func(importPath string) bool
)
// NewPackageExtractor returns a PackageExtractor that uses opts and starts with
// an empty, writable import map. Each call gets its own map.
func NewPackageExtractor(opts PackageExtractorOptions) *PackageExtractor {
	pe := new(PackageExtractor)
	pe.PackageExtractorOptions = opts
	pe.importMap = map[string]string{}
	return pe
}
// PackageExtractor extracts a set of dependent packages
type PackageExtractor struct {
PackageExtractorOptions
importMap map[string]string
}
// PackageExtractorOptions let you configure the package extractor
type PackageExtractorOptions struct {
InitialPkg string
ImportRewriter rewriterFunc
OutputRoot string
Filter filterFunc
}
// Run executes the package extraction process. It maps pe.InitialPkg to its
// rewritten import path, discovers the dependencies to extract with
// scanForPackages, and then extracts every package recorded in the import
// map. Packages are extracted in map iteration order, which is unspecified.
// Run stops at, and returns, the first error it encounters.
func (pe *PackageExtractor) Run() error {
	pe.importMap[pe.InitialPkg] = pe.ImportRewriter(pe.InitialPkg)
	// A failed scan leaves the import map incomplete; stop before writing
	// anything rather than emit a partial extraction.
	if err := pe.scanForPackages(); err != nil {
		return err
	}
	for originalPath := range pe.importMap {
		if err := pe.extractPackage(originalPath); err != nil {
			return err
		}
	}
	return nil
}
// scanForPackages walks the import graph breadth-first from pe.InitialPkg. For
// every transitively imported package that pe.Filter accepts, it records a
// rewritten import path (from pe.ImportRewriter) in pe.importMap. Packages the
// filter rejects are neither recorded nor descended into, and each import path
// is considered at most once. pe.InitialPkg itself is not recorded here.
//
// It returns the first error from locating a package, annotated with that
// package's import path.
func (pe *PackageExtractor) scanForPackages() error {
	buildCtx := build.Default
	queue := list.New()
	queue.PushBack(pe.InitialPkg)
	seen := map[string]bool{pe.InitialPkg: true}
	for queue.Len() > 0 {
		pkgPath := queue.Remove(queue.Front()).(string)
		pkg, err := buildCtx.Import(pkgPath, "", 0)
		if err != nil {
			return fmt.Errorf("scanning package %q: %w", pkgPath, err)
		}
		for _, impPath := range pkg.Imports {
			// "C" is cgo's pseudo-package: it has no source to locate or
			// extract, so it must never reach the filter or build.Import.
			if impPath == "C" || seen[impPath] {
				continue
			}
			seen[impPath] = true
			if pe.Filter(impPath) {
				pe.importMap[impPath] = pe.ImportRewriter(impPath)
				queue.PushBack(impPath)
			}
		}
	}
	return nil
}
// extractPackage writes a copy of the package at importPath into
// OutputRoot/<new import path>, creating that directory (mode 0755) if
// needed. Each Go source file go/build selects for the default build context
// (GoFiles plus CgoFiles) is passed through extractFile.
//
// The new import path is taken from pe.importMap, the same table used to
// rewrite imports, so directories and import statements agree. Packages
// missing from the map fall back to pe.ImportRewriter.
func (pe *PackageExtractor) extractPackage(importPath string) error {
	ctx := build.Default
	pkg, err := ctx.Import(importPath, "", 0)
	if err != nil {
		return fmt.Errorf("locating package %q: %w", importPath, err)
	}
	newImportPath, ok := pe.importMap[importPath]
	if !ok {
		newImportPath = pe.ImportRewriter(importPath)
	}
	dstDir := filepath.Join(pe.OutputRoot, newImportPath)
	if err := os.MkdirAll(dstDir, 0755); err != nil {
		return err
	}
	// pkg.Dir is where go/build actually found the sources; SrcRoot/importPath
	// is wrong for vendored and module-mode packages.
	files := make([]string, 0, len(pkg.GoFiles)+len(pkg.CgoFiles))
	files = append(files, pkg.GoFiles...)
	// Non-cgo files may reference declarations in the cgo files, so the copy
	// is incomplete without them.
	files = append(files, pkg.CgoFiles...)
	for _, name := range files {
		srcPath := filepath.Join(pkg.Dir, name)
		dstPath := filepath.Join(dstDir, name)
		if err := pe.extractFile(srcPath, dstPath); err != nil {
			return err
		}
	}
	return nil
}
// extractFile parses the Go source file at srcPath (keeping comments) and
// rewrites its imports of other extracted packages via updateImports. It then
// writes the result to dstPath, creating or truncating that file.
//
// Output is printed with the same printer settings gofmt uses. The file's
// Close error is returned too, because a failed Close can mean the data did
// not reach disk.
func (pe *PackageExtractor) extractFile(srcPath, dstPath string) error {
	fset := token.NewFileSet()
	f, err := parser.ParseFile(fset, srcPath, nil, parser.ParseComments)
	if err != nil {
		return err
	}
	if err := pe.updateImports(f); err != nil {
		return err
	}
	out, err := os.Create(dstPath)
	if err != nil {
		return err
	}
	w := bufio.NewWriter(out)
	// printer.Fprint's defaults pad alignment with tabs. These are gofmt's
	// settings, so the extracted files come out gofmt-clean.
	cfg := printer.Config{Mode: printer.UseSpaces | printer.TabIndent, Tabwidth: 8}
	if err := cfg.Fprint(w, fset, f); err != nil {
		_ = out.Close() // the print error is the one worth reporting
		return err
	}
	if err := w.Flush(); err != nil {
		_ = out.Close() // the flush error is the one worth reporting
		return err
	}
	return out.Close()
}
// updateImports rewrites, in place, each import spec of f whose path has an
// entry in pe.importMap so that it refers to the mapped path instead. Imports
// without an entry are left untouched. It fails only when an import path
// literal cannot be unquoted.
func (pe *PackageExtractor) updateImports(f *ast.File) error {
	for i := range f.Imports {
		spec := f.Imports[i]
		oldPath, err := strconv.Unquote(spec.Path.Value)
		if err != nil {
			return err
		}
		replacement, found := pe.importMap[oldPath]
		if !found {
			continue
		}
		spec.Path.Value = strconv.Quote(replacement)
	}
	return nil
}