forked from zhufuyi/sponge
/
main.go
157 lines (129 loc) · 4.36 KB
/
main.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
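// Package main implements the protoc-gen-go-rpc-tmpl plugin, which generates *.go service templates, *_client_test.go client tests and *_rpc.go error-code files from .proto files.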
package main
import (
"bytes"
"flag"
"fmt"
"os"
"path/filepath"
"strings"
"time"
"github.com/github-tree/sponge/cmd/protoc-gen-go-rpc-tmpl/internal/generate/service"
"google.golang.org/protobuf/compiler/protogen"
"google.golang.org/protobuf/types/pluginpb"
)
const (
// helpInfo is the usage text printed when the plugin binary is run directly
// with the -h flag (see main).
helpInfo = `
# generate *.go file
protoc --proto_path=. --proto_path=./third_party --go-rpc-tmpl_out=. --go-rpc-tmpl_opt=paths=source_relative \
--go-rpc-tmpl_opt=moduleName=yourModuleName --go-rpc-tmpl_opt=serverName=yourServerName *.proto
Note:
If you want to merge the code, after generating the code, execute the command "sponge merge rpc-pb",
you don't worry about it affecting the logic code you have already written, in case of accidents,
you can find the pre-merge code in the directory /tmp/sponge_merge_backup_code.
`
// optErrFormat is the panic message used by saveFile when a required plugin
// option is empty; the %s verb receives the option name (e.g. "moduleName").
optErrFormat = `--go-rpc-tmpl_opt error, '%s' cannot be empty.
Usage example:
protoc --proto_path=. --proto_path=./third_party \
--go-rpc-tmpl_out=. --go-rpc-tmpl_opt=paths=source_relative \
--go-rpc-tmpl_opt=moduleName=yourModuleName --go-rpc-tmpl_opt=serverName=yourServerName \
*.proto
`
)
// main is the plugin entry point.
//
// When run directly with -h it prints helpInfo and exits. Otherwise it runs
// as a protoc plugin: protogen reads the code generator request, and each
// --go-rpc-tmpl_opt key=value pair is routed to the flag set below
// (moduleName, serverName, tmplDir, ecodeOut). Files are generated only for
// the .proto files protoc asked to generate.
func main() {
	var h bool
	flag.BoolVar(&h, "h", false, "help information")
	flag.Parse()
	if h {
		fmt.Print(helpInfo)
		return
	}

	var flags flag.FlagSet
	var moduleName, serverName, tmplDir, ecodeOut string
	flags.StringVar(&moduleName, "moduleName", "", "module name")
	flags.StringVar(&serverName, "serverName", "", "server name")
	flags.StringVar(&tmplDir, "tmplDir", "internal/service", "rpc template file directory, the default value is internal/service")
	flags.StringVar(&ecodeOut, "ecodeOut", "internal/ecode", "rpc error code file directory, the default value is internal/ecode")

	options := protogen.Options{
		ParamFunc: flags.Set,
	}
	options.Run(func(gen *protogen.Plugin) error {
		gen.SupportedFeatures = uint64(pluginpb.CodeGeneratorResponse_FEATURE_PROTO3_OPTIONAL)
		for _, f := range gen.Files {
			if !f.Generate {
				continue
			}
			// Best effort: a failure on one .proto file must not stop the others,
			// but it is reported on stderr rather than dropped silently. stdout
			// is reserved for the response protoc reads from the plugin.
			if err := saveRPCTmplFiles(f, moduleName, serverName, tmplDir, ecodeOut); err != nil {
				fmt.Fprintf(os.Stderr, "protoc-gen-go-rpc-tmpl: %s: %v\n", f.Desc.Path(), err)
			}
		}
		return nil
	})
}
// saveRPCTmplFiles generates and writes the three outputs for one .proto file:
//
//   - <prefix>.go: the service template, written into tmplOut. An existing
//     file is not overwritten.
//   - <prefix>_client_test.go: the client test, written into tmplOut. It
//     always overwrites.
//   - <prefix>_rpc.go: the error-code file, written verbatim into ecodeOut.
//
// The module and server names are substituted into the first two files. The
// first write error aborts the remaining writes and is returned.
func saveRPCTmplFiles(f *protogen.File, moduleName string, serverName string, tmplOut string, ecodeOut string) error {
	prefix := f.GeneratedFilenamePrefix
	serviceCode, clientTestCode, ecodeCode := service.GenerateFiles(prefix, f)

	if err := saveFile(moduleName, serverName, tmplOut, prefix+".go", serviceCode, false); err != nil {
		return err
	}
	if err := saveFile(moduleName, serverName, tmplOut, prefix+"_client_test.go", clientTestCode, true); err != nil {
		return err
	}
	return saveFileSimple(ecodeOut, prefix+"_rpc.go", ecodeCode, false)
}
// saveFile writes generated template content into directory out, using only
// the base name of filePath. It first replaces these placeholders:
//
//   - "moduleNameExample" with moduleName
//   - "serverNameExample" with serverName
//   - "ServerNameExample" with serverName with its first letter upper-cased
//
// Empty content is skipped and nil is returned. When isNeedCovered is false
// and the target already exists, the content is written to a sibling file
// with a ".gen<yyyymmddThhmmss>" suffix, leaving the existing file untouched.
//
// It panics with a usage message (optErrFormat) if moduleName or serverName
// is empty. It returns an error if out cannot be created or the file cannot
// be written.
func saveFile(moduleName string, serverName string, out string, filePath string, content []byte, isNeedCovered bool) error {
	if len(content) == 0 {
		return nil
	}
	if moduleName == "" {
		panic(fmt.Sprintf(optErrFormat, "moduleName"))
	}
	if serverName == "" {
		panic(fmt.Sprintf(optErrFormat, "serverName"))
	}

	// 0755 rather than 0766: a directory needs the search (x) bit to be
	// traversable, and a failure here should surface instead of being masked
	// by a later, less informative WriteFile error.
	if err := os.MkdirAll(out, 0755); err != nil {
		return fmt.Errorf("create directory %q: %w", out, err)
	}

	// filepath.Join is OS-portable, and for an empty out it yields a path
	// relative to the working directory instead of one at the filesystem root.
	_, name := filepath.Split(filePath)
	file := filepath.Join(out, name)
	if !isNeedCovered && isExists(file) {
		file += ".gen" + time.Now().Format("20060102T150405")
	}

	content = bytes.ReplaceAll(content, []byte("moduleNameExample"), []byte(moduleName))
	content = bytes.ReplaceAll(content, []byte("serverNameExample"), []byte(serverName))
	content = bytes.ReplaceAll(content, firstLetterToUpper("serverNameExample"), firstLetterToUpper(serverName))
	return os.WriteFile(file, content, 0666)
}
// saveFileSimple writes content verbatim (no placeholder substitution) into
// directory out, using only the base name of filePath. Empty content is
// skipped and nil is returned. When isNeedCovered is false and the target
// already exists, the content is written to a sibling file with a
// ".gen<yyyymmddThhmmss>" suffix, leaving the existing file untouched.
//
// It returns an error if out cannot be created or the file cannot be written.
func saveFileSimple(out string, filePath string, content []byte, isNeedCovered bool) error {
	if len(content) == 0 {
		return nil
	}

	// 0755 rather than 0766: a directory needs the search (x) bit to be
	// traversable, and a creation failure should be reported, not ignored.
	if err := os.MkdirAll(out, 0755); err != nil {
		return fmt.Errorf("create directory %q: %w", out, err)
	}

	// filepath.Join is OS-portable, and for an empty out it yields a path
	// relative to the working directory instead of one at the filesystem root.
	_, name := filepath.Split(filePath)
	file := filepath.Join(out, name)
	if !isNeedCovered && isExists(file) {
		file += ".gen" + time.Now().Format("20060102T150405")
	}
	return os.WriteFile(file, content, 0666)
}
// isExists reports whether something is present at path f.
//
// Only a "does not exist" error from os.Stat counts as absent. Any other stat
// error (for example, permission denied) is treated as present. As a result,
// saveFile and saveFileSimple divert their output to a ".gen" sibling rather
// than writing over that path.
func isExists(f string) bool {
	_, statErr := os.Stat(f)
	return statErr == nil || !os.IsNotExist(statErr)
}
// firstLetterToUpper returns s with its first character upper-cased, as a
// byte slice (e.g. "serverName" -> "ServerName"). An empty input yields an
// empty, non-nil slice.
//
// The first character is taken as a whole rune, not a single byte. Slicing
// s[:1] would split a multi-byte first character such as "é" into invalid
// UTF-8.
func firstLetterToUpper(s string) []byte {
	if s == "" {
		return []byte{}
	}

	// Ranging over a string visits rune start offsets, so the second offset
	// seen marks the end of the first rune. For a one-rune string the end is
	// len(s).
	end := len(s)
	for i := range s {
		if i > 0 {
			end = i
			break
		}
	}
	return []byte(strings.ToUpper(s[:end]) + s[end:])
}