From 08e5aeae2cdee693438f7e1759f2725230afcc47 Mon Sep 17 00:00:00 2001 From: Austin Taylor Date: Mon, 13 Aug 2018 15:40:23 -0400 Subject: [PATCH] FEATURE: InterfaceLoader for loading data into an interface type (#142) --- load.go | 24 ++++++++++++++++++++++-- select_test.go | 22 ++++++++++++++++++++++ 2 files changed, 44 insertions(+), 2 deletions(-) diff --git a/load.go b/load.go index 31c05e39..2bc82a41 100644 --- a/load.go +++ b/load.go @@ -5,6 +5,15 @@ import ( "reflect" ) +type interfaceLoader struct { + v interface{} + typ reflect.Type +} + +func InterfaceLoader(value interface{}, concreteType interface{}) interface{} { + return interfaceLoader{value, reflect.TypeOf(concreteType)} +} + // Load loads any value from sql.Rows. // // value can be: @@ -28,7 +37,16 @@ func Load(rows *sql.Rows, value interface{}) (int, error) { } ptr := make([]interface{}, len(column)) - v := reflect.ValueOf(value) + var v reflect.Value + var elemType reflect.Type + + if il, ok := value.(interfaceLoader); ok { + v = reflect.ValueOf(il.v) + elemType = il.typ + } else { + v = reflect.ValueOf(value) + } + if v.Kind() != reflect.Ptr || v.IsNil() { return 0, ErrInvalidPointer } @@ -46,7 +64,9 @@ func Load(rows *sql.Rows, value interface{}) (int, error) { for rows.Next() { var elem, keyElem reflect.Value - if isMapOfSlices { + if elemType != nil { + elem = reflectAlloc(elemType) + } else if isMapOfSlices { elem = reflectAlloc(v.Type().Elem().Elem()) } else if isSlice || isMap { elem = reflectAlloc(v.Type().Elem()) diff --git a/select_test.go b/select_test.go index fe0c7929..16566d30 100644 --- a/select_test.go +++ b/select_test.go @@ -117,6 +117,28 @@ func TestMaps(t *testing.T) { } } +func TestInterfaceLoader(t *testing.T) { + for _, sess := range testSession { + reset(t, sess) + + _, err := sess.InsertInto("dbr_people"). + Columns("name", "email"). + Values("test1", "test1@test.com"). + Values("test2", "test2@test.com"). + Values("test2", "test3@test.com"). + Exec() + + var m []interface{} + cnt, err := sess.Select("*").From("dbr_people").Load(InterfaceLoader(&m, dbrPerson{})) + require.NoError(t, err) + require.Equal(t, 3, cnt) + require.Len(t, m, 3) + person, ok := m[0].(dbrPerson) + require.True(t, ok) + require.Equal(t, "test1", person.Name) + } +} + func TestPostgresArray(t *testing.T) { sess := postgresSession for _, v := range []string{