forked from go-kratos/kratos
/
common.go
151 lines (139 loc) · 3.4 KB
/
common.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
package pkg
import (
"fmt"
"go/ast"
"go/format"
"go/parser"
"go/token"
"io/ioutil"
"log"
"os"
"regexp"
"strings"
)
// Source bundles the text of a Go source file with its parsed syntax tree
// and the file set needed to resolve positions inside that tree.
type Source struct {
// Fset is the file set the source was parsed with; node positions in F are
// resolved through it.
Fset *token.FileSet
// Src is the raw source text that was parsed.
Src string
// F is the syntax tree obtained by parsing Src.
F *ast.File
}
// NewSource parses src as a complete Go source file and returns a Source
// holding the text, the file set used for parsing and the resulting AST.
//
// If src cannot be parsed the process is terminated through log.Fatalf. The
// parser's error is included in the message so the user can see where the
// syntax problem is instead of only learning that parsing failed.
func NewSource(src string) *Source {
	fset := token.NewFileSet()
	f, err := parser.ParseFile(fset, "", src, 0)
	if err != nil {
		log.Fatalf("无法解析源文件: %v", err)
	}
	return &Source{
		Fset: fset,
		Src:  src,
		F:    f,
	}
}
// ExprString returns the exact source text spanned by the expression typ,
// cut out of s.Src using the byte offsets the file set reports for the
// expression's start and end positions.
func (s *Source) ExprString(typ ast.Expr) string {
	begin := s.Fset.Position(typ.Pos())
	end := s.Fset.Position(typ.End())
	return s.Src[begin.Offset:end.Offset]
}
// pkgPath returns the import path literal (quotes included, exactly as it
// appears in the parsed file) of the package referred to as name, or "" when
// no import matches.
//
// An import explicitly aliased to name takes precedence. Otherwise the first
// import whose final path element equals name is returned. The comparison is
// done on the whole final element: name "pb" matches "x/pb" but not
// "x/apipb", which a plain suffix test would wrongly accept.
func (s *Source) pkgPath(name string) (res string) {
	for _, im := range s.F.Imports {
		if im.Name != nil && im.Name.Name == name {
			return im.Path.Value
		}
	}
	bare := `"` + name + `"`  // single-element path, e.g. "fmt"
	tail := "/" + name + `"`  // final element of a longer path, e.g. "go/ast"
	for _, im := range s.F.Imports {
		if p := im.Path.Value; p == bare || strings.HasSuffix(p, tail) {
			return p
		}
	}
	return ""
}
// GetDef returns the source text of the interface type declared as name.
//
// The first line of the result is the complete source line on which the
// interface type begins; every following line of the interface text is
// prefixed with one tab. The function panics if name is not found in the
// file scope or is not a type spec whose type is an interface.
func (s *Source) GetDef(name string) string {
	obj := s.F.Scope.Lookup(name)
	spec := obj.Decl.(*ast.TypeSpec)
	iface := spec.Type.(*ast.InterfaceType)

	start := s.Fset.Position(iface.Pos())
	end := s.Fset.Position(iface.End())

	var def strings.Builder
	// Use the whole source line so that the text preceding the interface
	// type on that line is kept.
	def.WriteString(strings.Split(s.Src, "\n")[start.Line-1])
	body := strings.Split(s.Src[start.Offset:end.Offset], "\n")
	for _, l := range body[1:] {
		def.WriteString("\n\t")
		def.WriteString(l)
	}
	return def.String()
}
// RegexpReplace compiles reg, finds every match of it in src and returns the
// concatenation of the template temp expanded for each match in order. Text
// of src outside the matches is not part of the result, so the result is ""
// when nothing matches. It panics if reg is not a valid regular expression.
func RegexpReplace(reg, src, temp string) string {
	re := regexp.MustCompile(reg)
	matches := re.FindAllStringSubmatchIndex(src, -1)
	out := []byte{}
	for i := range matches {
		out = re.ExpandString(out, temp, src, matches[i])
	}
	return string(out)
}
// formatPackage builds the import spec for a package referred to as name
// whose quoted import path is path.
//
// It returns "" when path is empty. When the final element of path is exactly
// name the path is returned alone, because no alias is needed; otherwise the
// result is "name path" so the alias is preserved. The final element is
// compared as a whole: alias "pb" for "x/apipb" is kept, whereas a plain
// suffix test would drop it and emit an import that does not bind "pb".
func formatPackage(name, path string) (res string) {
	if path == "" {
		return ""
	}
	if path == `"`+name+`"` || strings.HasSuffix(path, "/"+name+`"`) {
		return path
	}
	return fmt.Sprintf("%s %s", name, path)
}
// SourceText reads the file named by the GOFILE environment variable and
// returns its contents as a string. If the file cannot be read the process
// is terminated through log.Fatal with a hint to run under go generate.
func SourceText() string {
	path := os.Getenv("GOFILE")
	content, readErr := ioutil.ReadFile(path)
	if readErr == nil {
		return string(content)
	}
	log.Fatal("请使用go generate执行", path)
	// Not reached: log.Fatal exits the process. The return only satisfies
	// the compiler.
	return string(content)
}
// FormatCode runs source through go/format and returns the formatted text.
// If formatting fails, two warnings are logged and source is returned
// unchanged so the caller can still write it out and compile it to see the
// actual error.
func FormatCode(source string) string {
	formatted, err := format.Source([]byte(source))
	if err == nil {
		return string(formatted)
	}
	// Not expected in normal use, but can occur while the generator itself
	// is being developed.
	log.Printf("warning: 输出文件不合法: %s", err)
	log.Printf("warning: 详细错误请编译查看")
	return source
}
// Packages returns the import specs needed by the parameter and result types
// of the function-typed field f (for example a method of an interface).
//
// Every package-qualified identifier (pkg.Name) that occurs anywhere inside
// a parameter or result type is collected by walking the type's syntax tree,
// so nested types such as map[string]map[a.K]b.V contribute each package
// separately. Each package name is resolved with s.pkgPath and rendered with
// formatPackage. Specs are returned once each, in order of first appearance;
// packages whose import cannot be resolved are omitted. If f is nil or its
// type is not a function type the result is nil.
func (s *Source) Packages(f *ast.Field) (res []string) {
	if f == nil {
		return nil
	}
	ft, ok := f.Type.(*ast.FuncType)
	if !ok {
		return nil
	}
	seen := make(map[string]bool)
	collect := func(fields *ast.FieldList) {
		if fields == nil {
			return
		}
		for _, field := range fields.List {
			ast.Inspect(field.Type, func(n ast.Node) bool {
				sel, ok := n.(*ast.SelectorExpr)
				if !ok {
					return true
				}
				id, ok := sel.X.(*ast.Ident)
				if !ok {
					return true
				}
				pkg := formatPackage(id.Name, s.pkgPath(id.Name))
				if pkg != "" && !seen[pkg] {
					seen[pkg] = true
					res = append(res, pkg)
				}
				// Both children of pkg.Name are plain identifiers;
				// there is nothing further to visit below it.
				return false
			})
		}
	}
	// Parameters and results are walked separately rather than appended
	// into one slice, so the AST's own field lists are never written to.
	collect(ft.Params)
	collect(ft.Results)
	return res
}