/
paramutil.go
89 lines (86 loc) · 2.28 KB
/
paramutil.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
package utils
import (
"reflect"
"strings"
)
// FormDataCopyFields copies field values from formData into the matching
// fields of dest and returns the number of destination fields that changed.
//
// formData and dest may each be a struct or a pointer to a struct. dest must
// be a non-nil pointer for any field to be written.
//
// Each source field is keyed by the first comma-separated element of its
// tagName struct tag (e.g. "name" for `form:"name,omitempty"` with tagName
// "form"). When that tag is empty, the field's Go name is used, but only if
// IsCapital(name) is true (presumably: the name starts with an upper-case
// letter — confirm against IsCapital).
//
// The following source fields are never copied:
//   - a field named "ID";
//   - unexported fields, because reflect refuses to assign values read
//     from them.
//
// copyFields decides, per destination field, whether and how a keyed value
// is assigned.
//
// If either argument is nil, a nil pointer, or not a struct, or if dest is
// not settable (e.g. passed by value), nothing is copied and 0 is returned.
func FormDataCopyFields(formData interface{}, dest interface{}, tagName string) int {
	// reflect.Indirect dereferences one pointer level, like the explicit Elem
	// calls it replaces, and maps nil pointers to the invalid zero Value.
	valueFormData := reflect.Indirect(reflect.ValueOf(formData))
	valueDest := reflect.Indirect(reflect.ValueOf(dest))
	// Guard against inputs that would make Type/NumField/Field or a later
	// Set call panic.
	if valueFormData.Kind() != reflect.Struct || valueDest.Kind() != reflect.Struct || !valueDest.CanSet() {
		return 0
	}

	formType := valueFormData.Type()
	fieldValues := make(map[string]reflect.Value, formType.NumField())
	for i := 0; i < formType.NumField(); i++ {
		fieldType := formType.Field(i)
		// "ID" is excluded, presumably so form input can never overwrite the
		// model's primary key. PkgPath is non-empty only for unexported fields.
		if fieldType.Name == "ID" || fieldType.PkgPath != "" {
			continue
		}
		fieldTagName := strings.SplitN(fieldType.Tag.Get(tagName), ",", 2)[0]
		switch {
		case fieldTagName != "":
			fieldValues[fieldTagName] = valueFormData.Field(i)
		case IsCapital(fieldType.Name):
			fieldValues[fieldType.Name] = valueFormData.Field(i)
		}
	}
	return copyFields(fieldValues, valueDest, tagName)
}
// copyFields assigns values from fieldValues into the matching fields of the
// struct value dst and returns how many destination fields actually changed.
//
// Each field of dst is matched by the first comma-separated element of its
// tagName struct tag. When that tag is empty, the field's Go name is used, but
// only if IsCapital(name) is true. Fields with no entry in fieldValues are left
// alone.
//
// Only these destination kinds are handled:
//   - string;
//   - signed integers;
//   - unsigned integers (uintptr excluded);
//   - floats;
//   - bool.
//
// A value is copied only when all of the following hold:
//   - the source belongs to the same kind family (so an int source may fill
//     an int64 field, and a string source may fill a named string type);
//   - the value fits the destination without overflow;
//   - it differs from the current value.
//
// A zero signed-integer source never overwrites the destination. Fields that
// cannot be set are skipped: unexported fields, or any field when dst is not
// addressable. Mismatched kinds are skipped too. Neither case panics.
func copyFields(fieldValues map[string]reflect.Value, dst reflect.Value, tagName string) int {
	// family maps a kind to the representative kind of the group it belongs
	// to. Kinds this function does not copy (uintptr, pointers, structs,
	// slices, ...) map to reflect.Invalid.
	family := func(k reflect.Kind) reflect.Kind {
		switch k {
		case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
			return reflect.Int
		case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
			return reflect.Uint
		case reflect.Float32, reflect.Float64:
			return reflect.Float64
		case reflect.String, reflect.Bool:
			return k
		default:
			return reflect.Invalid
		}
	}

	affected := 0
	dstType := dst.Type()
	for i := 0; i < dstType.NumField(); i++ {
		fieldType := dstType.Field(i)
		fieldTagName := strings.SplitN(fieldType.Tag.Get(tagName), ",", 2)[0]
		if fieldTagName == "" {
			if !IsCapital(fieldType.Name) {
				continue
			}
			fieldTagName = fieldType.Name
		}
		srcValue, exists := fieldValues[fieldTagName]
		if !exists || !srcValue.IsValid() {
			continue
		}
		fieldValue := dst.Field(i)
		// The Set* methods panic on unexported fields and on fields of a
		// non-addressable struct.
		if !fieldValue.CanSet() {
			continue
		}
		kind := family(fieldValue.Kind())
		// A mismatched source kind would make the typed accessors below panic.
		if kind == reflect.Invalid || family(srcValue.Kind()) != kind {
			continue
		}

		changed := false
		switch kind {
		case reflect.String:
			if v := srcValue.String(); fieldValue.String() != v {
				fieldValue.SetString(v)
				changed = true
			}
		case reflect.Int:
			// Zero never overwrites the destination (presumably it means
			// "not provided" in the form). Out-of-range values are skipped
			// rather than silently truncated.
			if v := srcValue.Int(); v != 0 && fieldValue.Int() != v && !fieldValue.OverflowInt(v) {
				fieldValue.SetInt(v)
				changed = true
			}
		case reflect.Uint:
			if v := srcValue.Uint(); fieldValue.Uint() != v && !fieldValue.OverflowUint(v) {
				fieldValue.SetUint(v)
				changed = true
			}
		case reflect.Float64:
			if v := srcValue.Float(); fieldValue.Float() != v && !fieldValue.OverflowFloat(v) {
				fieldValue.SetFloat(v)
				changed = true
			}
		case reflect.Bool:
			if v := srcValue.Bool(); fieldValue.Bool() != v {
				fieldValue.SetBool(v)
				changed = true
			}
		}
		if changed {
			affected++
		}
	}
	return affected
}