-
Notifications
You must be signed in to change notification settings - Fork 5
/
binder.go
95 lines (88 loc) · 2.4 KB
/
binder.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
package orm
import (
"database/sql"
"fmt"
"reflect"
"unsafe"
)
// ptrsFor returns the scan destinations for the columns in cts, in column
// order, suitable for passing to (*sql.Rows).Scan.
//
// v must be an addressable struct value or a non-nil pointer to a struct.
// A column matches a field when its name equals the field's schema name or
// "<table>.<name>" for this schema's table; the first matching non-virtual
// field wins. Virtual fields are never matched. Columns that match no field
// are skipped, so the result can be shorter than cts, in which case Scan will
// reject the destination count. Nested structs are not expanded.
//
// cts is only read, never modified, so callers may reuse it across rows.
func (o *schema) ptrsFor(v reflect.Value, cts []*sql.ColumnType) []interface{} {
	t := v.Type()
	if t.Kind() == reflect.Ptr {
		v = v.Elem()
		t = t.Elem()
	}
	tableName := o.Table
	scanInto := make([]interface{}, 0, len(cts))
	for _, ct := range cts {
		colName := ct.Name()
		// NOTE(review): assumes o.fields is parallel to the struct's fields
		// (same order, at least t.NumField() entries) — confirm against schema.
		for i := 0; i < t.NumField(); i++ {
			field := o.fields[i]
			if field.Virtual {
				continue
			}
			if colName != field.Name && colName != tableName+"."+field.Name {
				continue
			}
			// reflect.NewAt yields a pointer Value that can be converted with
			// Interface() even when the struct field is unexported, which
			// v.Field(i).Addr().Interface() would refuse.
			ptr := reflect.NewAt(t.Field(i).Type, unsafe.Pointer(v.Field(i).UnsafeAddr()))
			scanInto = append(scanInto, ptr.Interface())
			break
		}
	}
	return scanInto
}
// bind scans every remaining row of rows into obj, which must be a non-nil
// pointer.
//
// If obj points to a slice, each row is decoded into a freshly allocated
// element that is appended to the slice; the element type may be a struct or
// one or more levels of pointer to a struct (e.g. []T or []*T). The slice in
// obj is only updated after all rows have been read without error.
//
// Otherwise obj must point to a struct (or to a non-nil pointer to a struct),
// which each row overwrites in turn, so it ends up holding the last row. If
// there are no rows, obj is left untouched.
//
// Columns are matched to fields by ptrsFor. bind returns the first error from
// ColumnTypes, Scan, or row iteration (rows.Err). It does not close rows.
func (o *schema) bind(rows *sql.Rows, obj interface{}) error {
	cts, err := rows.ColumnTypes()
	if err != nil {
		return err
	}
	t := reflect.TypeOf(obj)
	if t == nil || t.Kind() != reflect.Ptr {
		return fmt.Errorf("obj should be a ptr")
	}
	v := reflect.ValueOf(obj)
	if v.IsNil() {
		return fmt.Errorf("obj should be a non-nil ptr")
	}
	// obj is always a pointer, so one dereference reaches the target.
	t = t.Elem()
	v = v.Elem()

	if t.Kind() != reflect.Slice {
		for rows.Next() {
			// Hand ptrsFor a private copy of cts so the shared column list
			// cannot be disturbed between rows.
			ptrs := o.ptrsFor(v, append([]*sql.ColumnType(nil), cts...))
			if err := rows.Scan(ptrs...); err != nil {
				return err
			}
		}
		return rows.Err()
	}

	elemType := t.Elem()
	slice := v
	for rows.Next() {
		// Start from a zero element and allocate every pointer level so that
		// target ends up as the addressable struct to scan into, while elem
		// keeps the value of the slice's element type.
		elem := reflect.New(elemType).Elem()
		target := elem
		for target.Kind() == reflect.Ptr {
			target.Set(reflect.New(target.Type().Elem()))
			target = target.Elem()
		}
		ptrs := o.ptrsFor(target, append([]*sql.ColumnType(nil), cts...))
		if err := rows.Scan(ptrs...); err != nil {
			return err
		}
		slice = reflect.Append(slice, elem)
	}
	// Next returning false can mean an iteration error rather than exhaustion.
	if err := rows.Err(); err != nil {
		return err
	}
	v.Set(slice)
	return nil
}