-
Notifications
You must be signed in to change notification settings - Fork 3
/
c.go
106 lines (89 loc) · 2.38 KB
/
c.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
package db
import (
	"context"
	"encoding/json"
	"errors"
	"sort"
	"strings"
)
// C holds the parameters for building a multi-row INSERT statement with
// BuildSQL.
type C struct {
	// TableName is the target table. The "-" tag excludes it from JSON
	// (de)serialization, so it must be set in code.
	TableName string `json:"-"`
	// Data is a JSON array of objects, one object per row to insert. The keys
	// of the first object become the column list.
	Data string `json:"data"`
	// OnDuplicateKeyUpdate lists columns that are rendered as
	// "col = values(col)" in an "on duplicate key update" clause.
	OnDuplicateKeyUpdate []string `json:"on_duplicate_key_update"`
}
// BuildSQL builds a multi-row INSERT statement from m.
//
// m.Data must be a JSON array of objects. The keys of the first object
// (as returned by getFields) form the column list, and every object becomes
// one VALUES tuple. Each JSON value is copied verbatim into the SQL text. When
// m.OnDuplicateKeyUpdate is non-empty, an "on duplicate key update" clause is
// appended.
//
// It returns an error when TableName or Data is empty, when Data is not a
// JSON array of objects, or when Data contains no rows or no columns. ctx is
// currently unused.
//
// NOTE(review): TableName, the JSON keys and the JSON values are spliced into
// the statement without identifier quoting or value escaping; only pass
// trusted input here.
func (m C) BuildSQL(ctx context.Context) (sql string, err error) {
	if m.TableName == "" {
		return "", errors.New("param TableName is null")
	}
	if len(m.Data) == 0 {
		return "", errors.New("param Data is null")
	}

	// Decode the JSON array into one map per row, keeping raw JSON values.
	var values []map[string]json.RawMessage
	if err = json.Unmarshal([]byte(m.Data), &values); err != nil {
		return "", err
	}
	// "[]" and "null" decode without error but yield no rows; without this
	// check the code below would index an empty slice and panic.
	if len(values) == 0 {
		return "", errors.New("param Data has no rows")
	}

	fields := m.getFields(values)
	// A first row of {} (or null) gives no columns, which cannot form a
	// valid INSERT.
	if len(fields) == 0 {
		return "", errors.New("param Data has no fields")
	}

	var sqlStr strings.Builder
	sqlStr.WriteString("INSERT INTO ")
	sqlStr.WriteString(m.TableName)
	sqlStr.WriteString(" (")
	sqlStr.WriteString(strings.Join(fields, ","))
	sqlStr.WriteString(") VALUES ")
	for i, row := range values {
		if i > 0 {
			sqlStr.WriteString(",")
		}
		sqlStr.WriteString(m.getValue(fields, row))
	}

	// Upsert clause.
	if len(m.OnDuplicateKeyUpdate) > 0 {
		sqlStr.WriteString(m.getOnDuplicateKeyUpdate())
	}
	// The statement starts with "INSERT" and always ends with ")", so there
	// is no surrounding whitespace to trim.
	return sqlStr.String(), nil
}
// getOnDuplicateKeyUpdate renders the upsert suffix for the columns in
// m.OnDuplicateKeyUpdate. For example, []string{"a", "b"} becomes
// " on duplicate key update a = values(a),b = values(b)". The result includes
// the leading space so it can be appended directly after the VALUES list.
// It returns "" when m.OnDuplicateKeyUpdate is empty, instead of panicking on
// an out-of-range index.
//
// NOTE(review): MySQL documents the VALUES() function in this position as
// deprecated since 8.0.20; confirm the target server version.
func (m C) getOnDuplicateKeyUpdate() string {
	if len(m.OnDuplicateKeyUpdate) == 0 {
		return ""
	}
	var sqlStr strings.Builder
	sqlStr.WriteString(" on duplicate key update ")
	for i, field := range m.OnDuplicateKeyUpdate {
		if i > 0 {
			sqlStr.WriteString(",")
		}
		sqlStr.WriteString(field)
		sqlStr.WriteString(" = values(")
		sqlStr.WriteString(field)
		sqlStr.WriteString(")")
	}
	return sqlStr.String()
}
// getFields returns the INSERT column names: the keys of the first row,
// sorted so the generated SQL is deterministic (Go randomizes map iteration
// order). Keys that appear only in later rows are ignored. It returns nil
// when values is empty.
func (m C) getFields(values []map[string]json.RawMessage) []string {
	if len(values) == 0 {
		return nil
	}
	fields := make([]string, 0, len(values[0]))
	for k := range values[0] {
		fields = append(fields, k)
	}
	sort.Strings(fields)
	return fields
}
// getValue renders one row as a parenthesized VALUES tuple, e.g. (1,"a"),
// writing the raw JSON text of each value in the order given by fields.
//
// A field missing from valueMap is rendered as the SQL keyword DEFAULT. This
// keeps every tuple at exactly len(fields) entries, aligned with the column
// list. Skipping missing keys would produce tuples of the wrong arity, or a
// leading comma when the first field is absent. A JSON null value is copied
// through as null (SQL NULL). With no fields, the result is "()".
func (m C) getValue(fields []string, valueMap map[string]json.RawMessage) string {
	var valStr strings.Builder
	valStr.WriteString("(")
	for i, fieldName := range fields {
		if i > 0 {
			valStr.WriteString(",")
		}
		if value, ok := valueMap[fieldName]; ok {
			valStr.Write(value)
		} else {
			valStr.WriteString("DEFAULT")
		}
	}
	valStr.WriteString(")")
	return valStr.String()
}