-
Notifications
You must be signed in to change notification settings - Fork 0
/
field_column.go
111 lines (95 loc) · 3.12 KB
/
field_column.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
package do
import (
	"database/sql"
	"fmt"
	"reflect"
	"strings"
	"time"
)
// ObjectAndFieldsHelper returns a function that, for a row with the given
// column types, allocates a new T and returns it together with the scan
// destinations for that row.
//
// Destinations are chosen as follows:
//   - if *T implements interface{ ValuePtrs() []any }, a copy of its
//     ValuePtrs() result is used;
//   - if *T implements sql.Scanner (sql.NullString, sql.NullInt64, custom
//     scanners, ...), or T is time.Time, r itself is the single destination;
//   - if T is not a struct (string, []byte, bool, the integer and float
//     kinds, any, or a named type over one of them), r itself is the single
//     destination;
//   - otherwise T is a struct and FieldsByColumnType maps the columns onto
//     its fields.
//
// When r is the single destination, the row should have exactly one column.
//
// fieldMappers is optional. Only its first element is used; it is handed to
// FieldsByColumnType to derive a column name from a Go field name.
func ObjectAndFieldsHelper[T any](fieldMappers ...func(string) string) func(colTypes []*sql.ColumnType) (r *T, fields []any) {
	var fieldMapper func(string) string
	if len(fieldMappers) != 0 {
		fieldMapper = fieldMappers[0]
	}
	return func(colTypes []*sql.ColumnType) (r *T, fields []any) {
		r = new(T)
		switch vi := any(r).(type) {
		case interface{ ValuePtrs() []any }:
			// Append onto nil so fields is a copy of the slice ValuePtrs returned.
			fields = append(fields, vi.ValuePtrs()...)
		case sql.Scanner, *time.Time:
			// database/sql scans directly into these. Mapping their internal
			// struct fields by column name (the default branch) would be wrong.
			fields = append(fields, r)
		default:
			if reflect.TypeOf(r).Elem().Kind() == reflect.Struct {
				fields = FieldsByColumnType(r, colTypes, fieldMapper)
			} else {
				// Non-struct destinations: *string, *[]byte, *bool, numeric
				// pointers, *any and pointers to named types over them are
				// handled by Rows.Scan itself.
				fields = append(fields, r)
			}
		}
		return r, fields
	}
}
// FieldsByColumnType returns one scan destination per entry of colTypes, in
// the same order. t must be a pointer to a struct. Each destination is the
// address of the field of *t whose column name equals the column's name.
//
// A field's column name comes from fieldMapper(field.Name) when fieldMapper is
// non-nil. Otherwise it is the field's `db` tag, or the lower-cased field name
// when the tag is empty.
//
// A column that matches no field gets a nil entry. database/sql's Rows.Scan
// reports an error for such a destination.
func FieldsByColumnType(t any, colTypes []*sql.ColumnType, fieldMapper func(string) string) (fields []any) {
	wanted := make(map[string]struct{})
	for _, col := range colTypes {
		wanted[col.Name()] = struct{}{}
	}

	byColumn := fieldsByColumnName(t, wanted, fieldMapper)
	for i := range colTypes {
		fields = append(fields, byColumn[colTypes[i].Name()])
	}
	return fields
}
// fieldsByColumnName checks that t is a pointer to a struct. It then returns
// the addresses of that struct's fields keyed by column name, restricted to
// the names present in validName. It panics if t is not a pointer or does not
// point to a struct.
func fieldsByColumnName(t any, validName map[string]struct{}, fieldMapper func(string) string) (nameValues map[string]any) {
	ptrVal := reflect.ValueOf(t)
	ptrType := ptrVal.Type()
	if ptrType.Kind() != reflect.Ptr {
		panic(fmt.Errorf("t must be a struct pointer, but t's type is %v", ptrType))
	}

	structType, structVal := ptrType.Elem(), ptrVal.Elem()
	if structType.Kind() != reflect.Struct {
		panic(fmt.Errorf("t must be a struct pointer, but t's type is %v", structType))
	}
	return fieldsByColumnNameInner(structType, structVal, validName, fieldMapper)
}
// fieldsByColumnNameInner walks the struct value val, whose type is typ. It
// returns the address of every field whose column name is a key of validName,
// keyed by that column name. val must be addressable, which holds when it was
// obtained through a pointer as fieldsByColumnName does.
//
// A field's column name is fieldMapper(field.Name) when fieldMapper is
// non-nil. Otherwise it is the field's `db` tag, or the lower-cased field name
// when the tag is empty.
//
// Embedded structs are flattened: their fields are collected recursively and
// merged into the result with MergeKeyValue. An embedded pointer to a struct
// is followed as well. If that pointer is nil and settable, it is first set to
// a newly allocated zero value so its fields can receive scan results.
// Embedded non-struct types are treated like ordinary named fields.
// Unexported fields are skipped, because reflect cannot expose their
// addresses.
func fieldsByColumnNameInner(typ reflect.Type, val reflect.Value, validName map[string]struct{}, fieldMapper func(string) string) (nameValues map[string]any) {
	return collectColumnFields(typ, val, validName, fieldMapper, make(map[reflect.Type]bool))
}

// collectColumnFields implements fieldsByColumnNameInner. walking records the
// struct types on the current embedding path, so a type cycle through embedded
// fields (e.g. `type Node struct{ *Node }`) is not descended into forever.
func collectColumnFields(typ reflect.Type, val reflect.Value, validName map[string]struct{}, fieldMapper func(string) string, walking map[reflect.Type]bool) map[string]any {
	walking[typ] = true
	defer delete(walking, typ)

	nameValues := make(map[string]any)
	for i := 0; i < typ.NumField(); i++ {
		field := typ.Field(i)
		value := val.Field(i)

		if field.Anonymous {
			embedded := value
			if embedded.Kind() == reflect.Ptr && embedded.Type().Elem().Kind() == reflect.Struct {
				elemType := embedded.Type().Elem()
				if embedded.IsNil() {
					if walking[elemType] || !embedded.CanSet() {
						// Either a cycle back to a type being walked, or a nil
						// pointer to an unexported type that cannot be allocated.
						continue
					}
					embedded.Set(reflect.New(elemType))
				}
				embedded = embedded.Elem()
			}
			if embedded.Kind() == reflect.Struct {
				if !walking[embedded.Type()] {
					nv := collectColumnFields(embedded.Type(), embedded, validName, fieldMapper, walking)
					nameValues = MergeKeyValue(nameValues, nv)
				}
				continue
			}
			// Embedded non-struct types (e.g. `type Name string`) fall through
			// and are handled like ordinary named fields.
		}

		if !field.IsExported() {
			// value.Addr().Interface() panics for unexported fields.
			continue
		}

		var fieldName string
		if fieldMapper != nil {
			fieldName = fieldMapper(field.Name)
		} else if fieldName = field.Tag.Get("db"); fieldName == "" {
			fieldName = strings.ToLower(field.Name)
		}

		if _, ok := validName[fieldName]; ok {
			nameValues[fieldName] = value.Addr().Interface()
		}
	}
	return nameValues
}