Skip to content

Commit

Permalink
...
Browse files Browse the repository at this point in the history
  • Loading branch information
qiangxue committed Jul 13, 2016
1 parent 52c82a3 commit 094c37e
Show file tree
Hide file tree
Showing 6 changed files with 202 additions and 226 deletions.
7 changes: 7 additions & 0 deletions builder.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,9 @@ type Builder interface {
// The parameters to this method should be the list column names to be selected.
// A column name may have an optional alias name. For example, Select("id", "my_name AS name").
Select(...string) *SelectQuery
// ModelQuery returns a new ModelQuery object that can be used to perform model insertion, update, and deletion.
// The parameter to this method should be a pointer to the model struct that needs to be inserted, updated, or deleted.
Model(interface{}) *ModelQuery

// GeneratePlaceholder generates an anonymous parameter placeholder with the given parameter ID.
GeneratePlaceholder(int) string
Expand Down Expand Up @@ -136,6 +139,10 @@ func (b *BaseBuilder) Select(cols ...string) *SelectQuery {
return NewSelectQuery(b.db.Builder, b.executor).Select(cols...)
}

func (b *BaseBuilder) Model(model interface{}) *ModelQuery {
return newModelQuery(model)
}

// GeneratePlaceholder generates an anonymous parameter placeholder with the given parameter ID.
func (b *BaseBuilder) GeneratePlaceholder(int) string {
return "?"
Expand Down
192 changes: 137 additions & 55 deletions field.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,26 +15,10 @@ import (
// FieldMapFunc converts a struct field name into a DB column name.
type FieldMapFunc func(string) string

type fieldMapKey struct {
t reflect.Type
m reflect.Value
}

type fieldInfo struct {
Name string
ColName string
IsPK bool
Path []int
}

type fieldMap map[string]fieldInfo

var (
// DbTag is the name of the struct tag used to specify the column name for the associated struct field
DbTag = "db"

muFieldMap sync.Mutex
fieldMaps = make(map[fieldMapKey]fieldMap)
fieldRegex = regexp.MustCompile(`([^A-Z_])([A-Z])`)
)

Expand All @@ -46,31 +30,115 @@ func DefaultFieldMapFunc(f string) string {
return strings.ToLower(fieldRegex.ReplaceAllString(f, "${1}_$2"))
}

// getFieldMap builds a field map for a struct.
// The map returned will have field names as keys and field positions as values.
// Only exported fields are considered. For anonymous fields that are structs,
// their exported fields will be included in the map recursively.
// See TestGetFieldMap() for an example.
func getFieldMap(a reflect.Type, mapper FieldMapFunc) fieldMap {
muFieldMap.Lock()
defer muFieldMap.Unlock()

key := fieldMapKey{a, reflect.ValueOf(mapper)}
if m, ok := fieldMaps[key]; ok {
return m
type fieldInfo struct {
name string
dbName string
isPK bool
path []int
}

type structInfo struct {
nameMap map[string]*fieldInfo
dbNameMap map[string]*fieldInfo
pkNames []string
}

type structValue struct {
*structInfo
value reflect.Value
tableName string
}

func newStructValue(model interface{}, mapper FieldMapFunc) *structValue {
value := reflect.ValueOf(model)
if value.Kind() != reflect.Ptr || value.Elem().Kind() != reflect.Struct || value.IsNil() {
return nil
}
t := reflect.TypeOf(model)
var tableName string
if tm, ok := model.(TableModel); ok {
tableName = tm.TableName()
} else {
tableName = DefaultFieldMapFunc(t.Name())
}

fm := fieldMap{}
buildFieldMap(a, make([]int, 0), "", "", fm, mapper)
fieldMaps[key] = fm
si := getStructInfo(t, mapper)
return &structValue{
structInfo: si,
value: value,
tableName: tableName,
}
}

return fm
func (s *structValue) pk() map[string]interface{} {
return s.fields(s.pkNames...)
}

var scannerType = reflect.TypeOf((*sql.Scanner)(nil)).Elem()
func (s *structValue) fields(attrs ...string) map[string]interface{} {
v := map[string]interface{}{}
if len(attrs) == 0 {
for _, fi := range s.nameMap {
v[fi.dbName] = fi.getValue(s.value)
}
} else {
for _, attr := range attrs {
if fi, ok := s.nameMap[attr]; ok {
v[fi.dbName] = fi.getValue(s.value)
}
}
}
return v
}

func (fi *fieldInfo) getValue(a reflect.Value) interface{} {
for _, i := range fi.path {
a = a.Field(i)
if a.Kind() == reflect.Ptr {
if a.IsNil() {
return nil
}
a = a.Elem()
}
}
return a.Interface()
}

// buildFieldMap is called by getFieldMap recursively to build field map for a struct.
func buildFieldMap(a reflect.Type, path []int, namePrefix, colPrefix string, fm fieldMap, mapper FieldMapFunc) {
// getStructField returns the reflection value of the field specified by its field map path.
func (fi *fieldInfo) getField(a reflect.Value) reflect.Value {
for _, i := range fi.path {
a = indirect(a.Field(i))
}
return a
}

type structInfoMapKey struct {
t reflect.Type
m reflect.Value
}

var structInfoMap = make(map[structInfoMapKey]*structInfo)
var muStructInfoMap sync.Mutex

func getStructInfo(a reflect.Type, mapper FieldMapFunc) *structInfo {
muStructInfoMap.Lock()
defer muStructInfoMap.Unlock()

key := structInfoMapKey{a, reflect.ValueOf(mapper)}
if si, ok := structInfoMap[key]; ok {
return si
}

si := &structInfo{
nameMap: map[string]*fieldInfo{},
dbNameMap: map[string]*fieldInfo{},
}
buildStructInfo(si, a, make([]int, 0), "", "", mapper)
structInfoMap[key] = si

return si
}

func buildStructInfo(si *structInfo, a reflect.Type, path []int, namePrefix, dbNamePrefix string, mapper FieldMapFunc) {
n := a.NumField()
for i := 0; i < n; i++ {
field := a.Field(i)
Expand All @@ -90,12 +158,13 @@ func buildFieldMap(a reflect.Type, path []int, namePrefix, colPrefix string, fm
ft = ft.Elem()
}

colName := tag
name := field.Name
if colName == "" && !field.Anonymous {
colName = field.Name
dbName, isPK := parseTag(tag)
if dbName == "" && !field.Anonymous {
if mapper != nil {
colName = mapper(colName)
dbName = mapper(field.Name)
} else {
dbName = field.Name
}
}
if field.Anonymous {
Expand All @@ -104,20 +173,41 @@ func buildFieldMap(a reflect.Type, path []int, namePrefix, colPrefix string, fm

if ft.Kind() == reflect.Struct && !reflect.PtrTo(ft).Implements(scannerType) {
// dive into non-scanner struct
buildFieldMap(ft, path2, concat(namePrefix, name), concat(colPrefix, colName), fm, mapper)
} else if colName != "" {
buildStructInfo(si, ft, path2, concat(namePrefix, name), concat(dbNamePrefix, dbName), mapper)
} else if dbName != "" {
// non-anonymous scanner or struct field
colName = concat(colPrefix, colName)
fm[colName] = fieldInfo{
Name: concat(namePrefix, name),
ColName: colName,
IsPK: false,
Path: path2,
fi := &fieldInfo{
name: concat(namePrefix, name),
dbName: concat(dbNamePrefix, dbName),
isPK: false,
path: path2,
}
si.nameMap[fi.name] = fi
si.dbNameMap[fi.dbName] = fi
if isPK {
si.pkNames = append(si.pkNames, fi.name)
}
}
}
if len(si.pkNames) == 0 {
if _, ok := si.nameMap["ID"]; ok {
si.pkNames = append(si.pkNames, "ID")
}
}
}

func parseTag(tag string) (string, bool) {
if tag == "pk" {
return "", true
}
if strings.HasPrefix(tag, "pk,") {
return tag[3:], true
}
return tag, false
}

var scannerType = reflect.TypeOf((*sql.Scanner)(nil)).Elem()

func concat(s1, s2 string) string {
if s1 == "" {
return s2
Expand All @@ -128,14 +218,6 @@ func concat(s1, s2 string) string {
}
}

// getStructField returns the reflection value of the field specified by its field map path.
func (fi fieldInfo) getStructField(a reflect.Value) reflect.Value {
for _, i := range fi.Path {
a = indirect(a.Field(i))
}
return a
}

// indirect dereferences pointers and returns the actual value it points to.
// If a pointer is nil, it will be initialized with a new value.
func indirect(v reflect.Value) reflect.Value {
Expand Down
33 changes: 16 additions & 17 deletions field_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,6 @@
package dbx

import (
"encoding/json"
"reflect"
"testing"

"github.com/stretchr/testify/assert"
Expand Down Expand Up @@ -45,18 +43,19 @@ type FB struct {
B1 string
}

func TestGetFieldMap(t *testing.T) {
var a struct {
X1 string
FA
X2 int
B *FB
FB `db:"c"`
c int
}
ta := reflect.TypeOf(a)
r := getFieldMap(ta, DefaultFieldMapFunc)

v, _ := json.Marshal(r)
assert.Equal(t, `{"a1":[1,0],"a2":[1,1],"b.b1":[3,0],"c.b1":[4,0],"x1":[0],"x2":[2]}`, string(v))
}
//
//func TestGetFieldMap(t *testing.T) {
// var a struct {
// X1 string
// FA
// X2 int
// B *FB
// FB `db:"c"`
// c int
// }
// ta := reflect.TypeOf(a)
// r := getFieldMap(ta, DefaultFieldMapFunc)
//
// v, _ := json.Marshal(r)
// assert.Equal(t, `{"a1":[1,0],"a2":[1,1],"b.b1":[3,0],"c.b1":[4,0],"x1":[0],"x2":[2]}`, string(v))
//}
Loading

0 comments on commit 094c37e

Please sign in to comment.