This repository has been archived by the owner on Jul 10, 2018. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 0
/
main.go
96 lines (84 loc) · 2.27 KB
/
main.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
//go:generate esc -o templates.go templates
package main
import (
"flag"
"log"
"net/http"
"path"
"strings"
"github.com/huangjunwen/sqlw/datasrc"
_ "github.com/huangjunwen/sqlw/datasrc/drivers/mysql"
"github.com/huangjunwen/sqlw/render"
)
// commaSeperatd is a flag.Value holding a comma separated list of names,
// e.g. "-whitelist user,order". (The misspelled name is kept so existing
// references keep compiling.)
type commaSeperatd []string

// Compile-time check that *commaSeperatd satisfies flag.Value.
var _ flag.Value = (*commaSeperatd)(nil)

// String returns the entries joined by commas. The flag package may call
// String on a zero-valued receiver, including a nil pointer, so a nil
// receiver yields "" instead of panicking.
func (cs *commaSeperatd) String() string {
	if cs == nil {
		return ""
	}
	return strings.Join(*cs, ",")
}

// Set replaces the current entries with the comma separated names in s.
// Whitespace around each name is trimmed and empty entries are dropped,
// so "a, b,," yields [a b] and "" yields an empty list rather than [""].
func (cs *commaSeperatd) Set(s string) error {
	var names []string
	for _, name := range strings.Split(s, ",") {
		if name = strings.TrimSpace(name); name != "" {
			names = append(names, name)
		}
	}
	*cs = names
	return nil
}
// Command-line settings; main registers each of these with the flag package
// and they hold their final values once flag.Parse has run.
var (
	// -driver and -dsn: which database to connect to.
	driverName, dataSourceName string

	// -out and -pkg: where the generated code goes and its package name.
	outputDir, outputPkg string

	// -stmt and -tmpl: statement XML directory and template source.
	stmtDir, tmplDir string

	// -whitelist and -blacklist: table name filters.
	whitelist, blacklist commaSeperatd
)
func main() {
	// Parse flags.
	flag.StringVar(&driverName, "driver", "mysql", "Driver name. (e.g. 'mysql')")
	flag.StringVar(&dataSourceName, "dsn", "root:123456@tcp(localhost:3306)/dev?parseTime=true", "Data source name. ")
	flag.StringVar(&outputDir, "out", "models", "Output directory for generated code.")
	flag.StringVar(&outputPkg, "pkg", "", "Alternative package name of the generated code.")
	flag.StringVar(&stmtDir, "stmt", "", "Statement xml directory.")
	flag.StringVar(&tmplDir, "tmpl", "", "Custom templates directory.")
	flag.Var(&whitelist, "whitelist", "Comma seperated table names to render.")
	flag.Var(&blacklist, "blacklist", "Comma seperated table names not to render.")
	flag.Parse()

	if driverName == "" {
		log.Fatal("Missing -driver")
	}
	if dataSourceName == "" {
		log.Fatal("Missing -dsn")
	}

	// log.Fatal exits through os.Exit, which skips deferred calls. Keeping it
	// out of run guarantees the loader's deferred Close actually executes.
	if err := run(); err != nil {
		log.Fatal(err)
	}
}

// run opens the data source loader, builds a renderer from the parsed flags
// and runs it, returning the first error encountered. loader.Close is
// deferred, so it runs on every return path once NewLoader has succeeded.
func run() error {
	loader, err := datasrc.NewLoader(driverName, dataSourceName)
	if err != nil {
		return err
	}
	defer loader.Close()

	renderer, err := render.NewRenderer(
		render.Loader(loader),
		render.OutputDir(outputDir),
		render.OutputPkg(outputPkg),
		render.StmtDir(stmtDir),
		render.TmplFS(selectTemplateFS(loader.DriverName())),
		render.Whitelist([]string(whitelist)),
		render.Blacklist([]string(blacklist)),
	)
	if err != nil {
		return err
	}
	return renderer.Run()
}

// selectTemplateFS returns the file system templates are read from, based
// on the -tmpl flag:
//
//   - "-tmpl DIR" reads templates from directory DIR on disk;
//   - "-tmpl @NAME" uses the embedded template set under /templates/NAME;
//   - no -tmpl uses the embedded set named defaultSet (the driver name).
func selectTemplateFS(defaultSet string) http.FileSystem {
	if tmplDir != "" && !strings.HasPrefix(tmplDir, "@") {
		return http.Dir(tmplDir)
	}
	set := defaultSet
	if tmplDir != "" {
		set = strings.TrimPrefix(tmplDir, "@")
	}
	// FS is generated by esc (see the go:generate directive); passing false
	// presumably selects the copy embedded in the binary — verify in templates.go.
	return newPrefixFS(path.Join("/templates", set), FS(false))
}