/
render.go
221 lines (198 loc) · 5.77 KB
/
render.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
// This file loads the code-generation templates, renders them once per table and writes the generated files to disk.
package gen
import (
"fmt"
"go/format"
"io/ioutil"
"log"
"os"
"path"
"path/filepath"
"strings"
"unicode/utf8"
"github.com/i2eco/generator/internal/model"
"github.com/i2eco/generator/pkg/arg"
"github.com/flosch/pongo2"
"github.com/pkg/errors"
"github.com/smartwalle/pongo2render"
)
// tpls caches the raw text of every loaded template file, keyed by the
// file's path relative to the template directory.
var tpls map[string]string

// tplDirs records every template sub-directory, keyed by the directory's
// path relative to the template directory.
var tplDirs map[string]bool
// init registers the custom pongo2 filters used by the templates
// ("lowerfirst" -> lwfirst, "upperfirst" -> upperfirst) and allocates the
// template caches that loadTmpl fills.
func init() {
	// RegisterFilter reports an error (e.g. a name that is already taken)
	// instead of panicking; surface it rather than dropping it silently.
	if err := pongo2.RegisterFilter("lowerfirst", lwfirst); err != nil {
		log.Println("[init] register filter lowerfirst err: ", err.Error())
	}
	if err := pongo2.RegisterFilter("upperfirst", upperfirst); err != nil {
		log.Println("[init] register filter upperfirst err: ", err.Error())
	}
	tpls = make(map[string]string)
	tplDirs = make(map[string]bool)
}
// lwfirst implements the pongo2 "lowerfirst" filter: it lower-cases the
// first rune of the input string and keeps the remainder unchanged. An
// input whose Len() is not positive yields "".
//
// Note: the lowered result can collide with a Go keyword (e.g. "Type" ->
// "type"); this filter does not check for that, so templates must avoid it.
func lwfirst(in *pongo2.Value, param *pongo2.Value) (*pongo2.Value, *pongo2.Error) {
	result := ""
	if in.Len() > 0 {
		s := in.String()
		first, width := utf8.DecodeRuneInString(s)
		result = strings.ToLower(string(first)) + s[width:]
	}
	return pongo2.AsValue(result), nil
}
// upperfirst implements the pongo2 "upperfirst" filter: it upper-cases the
// first rune of the input string and keeps the remainder unchanged. An
// input whose Len() is not positive yields "".
func upperfirst(in *pongo2.Value, param *pongo2.Value) (*pongo2.Value, *pongo2.Error) {
	if in.Len() <= 0 {
		return pongo2.AsValue(""), nil
	}
	t := in.String()
	// Decode the first rune instead of indexing t[0]: t[0] is a single byte,
	// and string(t[0]) re-encodes it as a different (Latin-1) rune, so the
	// old strings.Replace either missed a multi-byte first character or
	// upper-cased an unrelated character later in the string.
	r, size := utf8.DecodeRuneInString(t)
	return pongo2.AsValue(strings.ToUpper(string(r)) + t[size:]), nil
}
// loadTmpl recursively walks arg.TmplDir and caches it in memory: the text
// of every readable file is stored in tpls, and every sub-directory (the
// root itself excluded) is recorded in tplDirs. Both are keyed by the path
// relative to arg.TmplDir. ".git" directories are skipped entirely.
//
// A walk error (for example, arg.TmplDir does not exist) stops the walk and
// is logged; a file that cannot be read is logged and skipped.
func loadTmpl() {
	root := arg.TmplDir
	err := filepath.Walk(root, func(p string, info os.FileInfo, err error) error {
		// Check err before touching info: Walk passes a nil info when it
		// cannot stat p (including the root), so info.IsDir() would panic.
		if err != nil {
			return err
		}
		if info.IsDir() {
			if info.Name() == ".git" {
				return filepath.SkipDir
			}
			if p != root {
				if rel, relErr := filepath.Rel(root, p); relErr == nil {
					tplDirs[rel] = true
				}
			}
			// Directories have no content to load.
			return nil
		}
		rel, err := filepath.Rel(root, p)
		if err != nil {
			return nil
		}
		b, err := ioutil.ReadFile(p)
		if err != nil {
			log.Println("[loadTmpl] read template err: ", err.Error(), p)
			return nil
		}
		tpls[rel] = string(b)
		return nil
	})
	if err != nil {
		log.Println(err)
	}
}
// Render loads the templates from arg.TmplDir and renders each of them once
// per table in schemaTpls, writing the generated files to disk.
//
// The shared pongo2 context is seeded with "camelTableNames": the set of all
// table names converted with snakeToCamel, so every template can refer to
// every table.
func Render(schemaTpls map[string]model.Table) {
	loadTmpl()
	camelTableNames := make(map[string]struct{}, len(schemaTpls))
	for tableName := range schemaTpls {
		camelTableNames[snakeToCamel(tableName)] = struct{}{}
	}
	ctx := pongo2.Context{
		"camelTableNames": camelTableNames,
	}
	render(ctx, schemaTpls)
}
// getPath turns a template path into the path of the generated file. It
// first replaces every "TABLE_NAME" placeholder with tableName, then turns
// every ".go.tmpl" into ".gen.go". The two replacements run in sequence,
// not simultaneously.
func getPath(path string, tableName string) string {
	withTable := strings.ReplaceAll(path, "TABLE_NAME", tableName)
	return strings.ReplaceAll(withTable, ".go.tmpl", ".gen.go")
}
// loadImports builds the default import set offered to the templates. It
// contains one entry per template sub-directory in tplDirs (prefixed with
// arg.Module, with the top-level "dao" directory excluded) plus a fixed
// list of project and third-party packages. When arg.Debug == "true" the
// resulting set is printed.
func loadImports() map[string]struct{} {
	imports := map[string]struct{}{}
	add := func(pkg string) { imports[pkg] = struct{}{} }

	for relPath := range tplDirs {
		// The dao directory is skipped here.
		if relPath != "dao" {
			add(arg.Module + "/" + relPath)
		}
	}

	fixed := []string{
		arg.Module + "/pkg/mus",
		arg.Model + "/mysql",
		arg.Dao,
		arg.Model + "/trans",
		"go.uber.org/zap",
		"github.com/jinzhu/gorm",
		"strings",
		"time",
	}
	for _, pkg := range fixed {
		add(pkg)
	}

	if arg.Debug == "true" {
		fmt.Println("imports module: ", imports)
	}
	return imports
}
// render executes every loaded template (tpls) once per table in schemas
// and writes each result to disk.
//
// ctx is reused across all iterations. The per-table keys set below
// (tableName, camelTableName, lcamelTableName, hasOpenId, hasDeleteTime,
// imports, columns, hasPrimaryKey, camelPrimaryKey, primaryKey,
// primaryKeyType) are overwritten on each pass. Keys supplied by the caller
// stay visible to every template.
//
// Templates whose relative path contains "dao/" are written under
// arg.OutDao with the first "dao/" removed; all others go under arg.Out.
// Output file names come from getPath. A template execution error or a
// write error aborts via log.Panicln.
func render(ctx pongo2.Context, schemas map[string]model.Table) {
	renderer := pongo2render.NewRender(arg.TmplDir)
	for tplPath, content := range tpls {
		globalImports := loadImports()
		// Drop the package the generated file itself lives in, to avoid an
		// import cycle.
		delete(globalImports, arg.Module+"/"+filepath.Dir(tplPath))

		// dao templates are emitted into a separate output directory.
		outDir, outRel := arg.Out, tplPath
		if strings.Contains(tplPath, "dao/") {
			outDir, outRel = arg.OutDao, strings.Replace(tplPath, "dao/", "", 1)
		}

		for tableName, schema := range schemas {
			// NOTE: all tables rendered from this template share one imports
			// map, so anything added here is also seen by later tables.
			schema.Imports = globalImports

			var hasOpenId, hasDeleteTime bool
			for _, col := range schema.Columns {
				if col.CamelName == "DeleteTime" {
					hasDeleteTime = true
				}
				if col.CamelName == "OpenId" && col.ColumnKey != "PRI" {
					hasOpenId = true
				}
			}
			// Always import "time" (this also covers DeleteTime columns).
			// Original TODO: "time" was inexplicably sometimes missing,
			// presumably because data got mixed up between tables.
			schema.Imports["time"] = struct{}{}

			ctx["tableName"] = tableName
			ctx["camelTableName"] = snakeToCamel(tableName)
			ctx["lcamelTableName"] = lowerFirst(snakeToCamel(tableName))
			ctx["hasOpenId"] = hasOpenId
			ctx["hasDeleteTime"] = hasDeleteTime
			ctx["imports"] = schema.Imports
			ctx["columns"] = schema.Columns
			ctx["hasPrimaryKey"] = schema.HasPrimaryKey
			ctx["camelPrimaryKey"] = schema.CamelPrimaryKey
			ctx["primaryKey"] = schema.PrimaryKey
			ctx["primaryKeyType"] = schema.PrimaryKeyType

			buf, err := renderer.TemplateFromString(content).Execute(ctx)
			if err != nil {
				// Previously this error was overwritten by write's result, so
				// a broken template surfaced only as a misleading format error.
				log.Panicln("[render] execute err: ", err.Error(), tplPath, tableName)
			}
			if err = write(filepath.Join(outDir, getPath(outRel, tableName)), buf); err != nil {
				log.Panicln("[render] write err: ", err.Error(), tplPath, tableName, buf)
			}
		}
	}
}
// write formats buf as Go source with go/format and writes the result to
// filename with mode 0644. It first creates the parent directory via
// createPath. When arg.Debug == "true" a success line is printed.
//
// The file is only touched after formatting succeeds. A template that
// renders invalid Go therefore returns an error and leaves any existing
// file intact; previously os.Create truncated the file before formatting
// was attempted.
func write(filename string, buf string) error {
	// filename is built with filepath.Join (OS separators), but path.Dir only
	// splits on '/'. Normalizing first keeps the parent directory correct on
	// Windows and is a no-op on '/'-separated systems.
	if err := createPath(path.Dir(filepath.ToSlash(filename))); err != nil {
		return errors.New("write create path " + err.Error())
	}
	bts, err := format.Source([]byte(buf))
	if err != nil {
		return errors.New("format buf error " + err.Error())
	}
	// ioutil.WriteFile creates or truncates the file itself, so the former
	// os.Create (whose handle was never written to) is unnecessary.
	err = ioutil.WriteFile(filename, bts, 0644)
	if err != nil {
		return errors.New("write write file " + err.Error())
	}
	if arg.Debug == "true" {
		fmt.Println("write file success, file name: ", filename)
	}
	return nil
}