/
binder.go
216 lines (192 loc) · 5.36 KB
/
binder.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
package sqlplus
import (
"database/sql"
"encoding/json"
"errors"
"fmt"
"log"
"reflect"
"strings"
)
type binder struct {
rows *sql.Rows
// ats 列表类型
ats reflect.Type
// avs 列表值
avs reflect.Value
// item 从类型新创建的值
item reflect.Value
// keys 从item上的键
keys map[string]reflect.Value
// fields 可以放入Scan中的值指针
fields []interface{}
}
func (b *binder) analysisSlice(list interface{}) (err error) {
b.ats = reflect.TypeOf(list)
b.avs = reflect.ValueOf(list)
if b.ats.Kind() != reflect.Ptr {
return errors.New("传入的list必须是指针")
}
if b.ats.Elem().Kind() != reflect.Slice {
return errors.New("传入的list必须是个slice")
}
if b.ats.Elem().Elem().Kind() != reflect.Struct {
return errors.New("传入的list必须是struct类型的slice")
}
b.item = reflect.New(b.ats.Elem().Elem())
b.keys = make(map[string]reflect.Value)
return
}
func (b *binder) parseSlideAll() (err error) {
cts, err := b.rows.ColumnTypes()
if err != nil {
return
}
for b.rows.Next() {
// 清空重构
b.fields = []interface{}{}
b.item = reflect.New(b.ats.Elem().Elem())
b.keys = make(map[string]reflect.Value)
// 将新创建的对象上的数据项根据tag映射到key
b.decode(b.item.Elem())
// 将key上的指针按column类型顺序整理进`b.fields`数组中
err = b.merge(cts)
if err != nil {
return
}
// 读入数据至`b.fields`指针中。
err = b.rows.Scan(b.fields...)
// 记下错误,同时也赋值,不因为个别字段问题丧失所有数据
// 将`b.fields`里指针映射的数据:`b.item`合并到`b.avs` slice数组中
b.avs.Elem().Set(reflect.Append(b.avs.Elem(), b.item.Elem()))
}
return
}
func (b *binder) analysisStruct(obj interface{}) (err error) {
b.ats = reflect.TypeOf(obj)
b.avs = reflect.ValueOf(obj)
if b.ats.Kind() != reflect.Ptr {
return errors.New("传入的 obj 必须是指针")
}
if b.ats.Elem().Kind() != reflect.Struct {
return fmt.Errorf("传入的 obj %v 必须是个 struct", b.ats.Elem().Kind())
}
b.item = b.avs
b.keys = make(map[string]reflect.Value)
return
}
func (b *binder) parseStruct() (err error) {
cts, err := b.rows.ColumnTypes()
if err != nil {
return
}
b.decode(b.item.Elem())
err = b.merge(cts)
if err != nil {
return
}
b.rows.Next()
err = b.rows.Scan(b.fields...)
return
}
func (b *binder) mustLimit1(query string) string {
query = strings.TrimSpace(query)
//if !strings.Contains(strings.ToLower(query), "limit") && query[len(query)-1] != 42 {
// query += " limit 1"
//}
return query
}
type jsonField struct {
Field interface{}
}
func (jf *jsonField) Scan(src interface{}) (err error) {
switch src.(type) {
case json.RawMessage:
err = json.Unmarshal(src.(json.RawMessage), jf.Field)
case string:
err = json.Unmarshal([]byte(src.(string)), jf.Field)
case []byte:
err = json.Unmarshal(src.([]byte), jf.Field)
}
return
}
func (b *binder) merge(cts []*sql.ColumnType) (err error) {
for _, v := range cts {
if f := b.keys[v.Name()]; f.CanAddr() && f.Addr().CanInterface() {
// 要先检查类型是否匹配
if b.canScan(v, f.Type()) {
b.fields = append(b.fields, f.Addr().Interface())
} else {
if v.DatabaseTypeName() == "PgTypeJsonb" || v.DatabaseTypeName() == "PgTypeJson" {
b.fields = append(b.fields, &jsonField{f.Addr().Interface()})
} else {
log.Println("ParseRows type not pare -> ", v.Name(), v.DatabaseTypeName(), v.ScanType(), f.Type())
b.fields = append(b.fields, reflect.New(v.ScanType()).Interface())
}
}
} else {
/*
如果查询出的字段,不在struct有标记的field中,会导致Scan时数量对不上的问题
为了补齐,需创建一个对应字段类型的变量指针
*/
f := reflect.New(v.ScanType()).Interface()
b.fields = append(b.fields, &f)
}
}
return
}
func (b *binder) canScan(t1 *sql.ColumnType, t2 reflect.Type) bool {
if t1.ScanType() == t2 || "*"+t1.ScanType().String() == t2.String() {
return true
} else {
if t1.ScanType().String() == "time.Time" && t2.String() == "json_data.JsonDate" {
return true
}
if len(t1.DatabaseTypeName()) > 2 && t1.DatabaseTypeName()[0:3] == "INT" {
return t1.ScanType().String()[0:3] == "int" && t2.String()[0:3] == "int"
} else if t1.ScanType().String() == "time.Time" && t2.String() == "pq.NullTime" {
return true
} else if t1.DatabaseTypeName() == "_INT4" && t2.String() == "pq.Int64Array" {
return true
} else if t1.DatabaseTypeName() == "_VARCHAR" && t2.String() == "pq.StringArray" {
return true
} else if t1.DatabaseTypeName() == "TEXT" && t2.String() == "sql.NullString" {
return true
} else {
return false
}
}
}
func (b *binder) decode(v reflect.Value) {
if !v.IsValid() {
return
}
for i := 0; i < v.NumField(); i++ {
f := v.Field(i)
tag := b.getTag(v.Type().Field(i).Tag)
if tag == "" {
if f.Kind() == reflect.Struct {
// 没有tag的类型引用
b.decode(f.Addr().Elem())
}
} else {
/*
只要有tag,视为解析的终点
因为一条记录是一个线形的一维数组,不是树形结构
*/
if f.CanInterface() && f.CanAddr() {
// 忽略得到了类型,也无法赋值的私有类型
b.keys[tag] = f
}
}
}
return
}
func (b *binder) getTag(t reflect.StructTag) (tag string) {
if tag = t.Get("sql"); tag == "" {
if tag = t.Get("json"); tag == "" {
tag = t.Get("xml")
}
}
return
}