-
Notifications
You must be signed in to change notification settings - Fork 0
/
parse_struct_parser.go
143 lines (117 loc) · 3.16 KB
/
parse_struct_parser.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
package parser
import (
"bytes"
"go/ast"
"go/printer"
"go/token"
"io"
"log"
"os"
"strings"
"golang.org/x/tools/go/ast/astutil"
"golang.org/x/tools/go/packages"
)
// Why it is written this way: see https://github.com/golang/go/issues/27477
// Parser parses the packages identified by import paths and collects
// information about their Go source code.
type Parser struct {
	filter            func(os.FileInfo) bool // file filter
	fset              *token.FileSet         // file set of the package currently being processed
	useSourceImporter bool                   // use the source importer
	replaceImportPath bool                   // replace import paths
	fromPath          string                 // import path prefix to be replaced
	toPath            string                 // replacement for fromPath
	output            io.Writer
	op                Op // operation to perform, e.g. generate an interface or an implementation
	replaceCallExpr   bool

	PkgInfo
}
// NewParser returns a Parser whose settings are copied from opt.
func NewParser(opt Option) *Parser {
	p := new(Parser)

	p.op = opt.Op
	p.filter = opt.Filter
	p.output = opt.Output

	p.useSourceImporter = opt.UseSourceImporter
	p.replaceCallExpr = opt.ReplaceCallExpr

	// Import path replacement settings.
	p.replaceImportPath = opt.ReplaceImportPath
	p.fromPath = opt.FromPath
	p.toPath = opt.ToPath

	return p
}
// GetPkgInfo returns the PkgInfo value embedded in the parser.
func (p *Parser) GetPkgInfo() PkgInfo {
	info := p.PkgInfo
	return info
}
// ParseByGoPackages loads the packages matching patterns with
// golang.org/x/tools/go/packages and inspects each of them.
//
// A pattern may be a directory or a package import path, for example
// '~/a/b/c', 'bytes' or 'github.com/donnol/tools'.
//
// If loading fails, the error is returned together with a zero-valued result.
// Otherwise result.Patterns holds the given patterns and result.Pkgs holds one
// inspected entry per loaded package, in load order.
func (p *Parser) ParseByGoPackages(patterns ...string) (result Packages, err error) {
	mode := packages.NeedName |
		packages.NeedFiles |
		packages.NeedCompiledGoFiles |
		packages.NeedImports |
		packages.NeedDeps |
		packages.NeedExportFile |
		packages.NeedTypes |
		packages.NeedSyntax |
		packages.NeedTypesInfo |
		packages.NeedTypesSizes |
		packages.NeedModule

	loaded, err := packages.Load(&packages.Config{Mode: mode}, patterns...)
	if err != nil {
		return result, err
	}

	inspector := NewInspector(InspectOption{Parser: p})

	inspected := make([]Package, len(loaded))
	for i, pkg := range loaded {
		// Store the package's file set on the parser before inspecting it.
		p.fset = pkg.Fset
		inspected[i] = inspector.InspectPkg(pkg)
	}

	result.Patterns = patterns
	result.Pkgs = inspected
	return result, nil
}
// GetStandardPackages returns the package path of every package matched by the
// "std" pattern. It panics if the packages cannot be loaded.
func (p *Parser) GetStandardPackages() []string {
	stdPkgs, err := packages.Load(nil, "std")
	if err != nil {
		panic(err)
	}

	paths := make([]string, len(stdPkgs))
	for i := range stdPkgs {
		paths[i] = stdPkgs[i].PkgPath
	}
	return paths
}
// replaceFileImportPath rewrites every import of file whose path starts with
// p.fromPath so that this prefix becomes p.toPath, formats the resulting
// source and writes it back over fileName.
//
// fileName must already exist: it is opened with O_TRUNC and without O_CREATE.
// The file is only truncated after printing and formatting have succeeded.
// The first error from printing, formatting, opening, writing or closing the
// file is returned.
func (p *Parser) replaceFileImportPath(fileName string, file *ast.File) error {
	// Rewrite the matching import paths.
	for _, spec := range file.Imports {
		path := strings.Trim(spec.Path.Value, `"`)
		if !strings.HasPrefix(path, p.fromPath) {
			continue
		}
		topath := strings.Replace(path, p.fromPath, p.toPath, 1)
		rewrote := astutil.RewriteImport(p.fset, file, path, topath)
		log.Printf("From %s to %s, rewrote: %v\n", p.fromPath, p.toPath, rewrote)
	}

	// Print the AST and format the result. A printing failure must stop here
	// so that incomplete source never reaches the file.
	var buf bytes.Buffer
	if err := printer.Fprint(&buf, p.fset, file); err != nil {
		return err
	}
	content, err := Format(fileName, buf.String(), true)
	if err != nil {
		return err
	}

	// Overwrite the original file with the formatted content.
	output, err := os.OpenFile(fileName, os.O_RDWR|os.O_TRUNC, os.ModePerm)
	if err != nil {
		return err
	}
	if _, err := output.Write([]byte(content)); err != nil {
		// The write error is the one worth reporting; close only to release the descriptor.
		_ = output.Close()
		return err
	}
	// Close can surface a delayed write error, so its result is returned.
	return output.Close()
}