From 0ae25e1b9adb80693fc56456f33324c0cad8ad6e Mon Sep 17 00:00:00 2001 From: Brett Jones Date: Sun, 30 May 2021 17:38:16 -0500 Subject: [PATCH] Implement nested structs (#29) Shout out to @schmath for the [implementation](https://github.com/blockloop/scan/issues/16#issuecomment-748884402) --- fakes_test.go | 40 +++++++++++++++------------------------- scanner.go | 22 +++++++++++++--------- scanner_test.go | 33 +++++++++++++++++++++++++++++++++ 3 files changed, 61 insertions(+), 34 deletions(-) diff --git a/fakes_test.go b/fakes_test.go index b9ece7b..bbd1e13 100644 --- a/fakes_test.go +++ b/fakes_test.go @@ -47,9 +47,8 @@ func fakeRowsWithRecords(t testing.TB, cols []string, rows ...[]interface{}) *Fa type FakeRowsScanner struct { CloseStub func() error closeMutex sync.RWMutex - closeArgsForCall []struct { - } - closeReturns struct { + closeArgsForCall []struct{} + closeReturns struct { result1 error } closeReturnsOnCall map[int]struct { @@ -57,9 +56,8 @@ type FakeRowsScanner struct { } ColumnTypesStub func() ([]*sql.ColumnType, error) columnTypesMutex sync.RWMutex - columnTypesArgsForCall []struct { - } - columnTypesReturns struct { + columnTypesArgsForCall []struct{} + columnTypesReturns struct { result1 []*sql.ColumnType result2 error } @@ -69,9 +67,8 @@ type FakeRowsScanner struct { } ColumnsStub func() ([]string, error) columnsMutex sync.RWMutex - columnsArgsForCall []struct { - } - columnsReturns struct { + columnsArgsForCall []struct{} + columnsReturns struct { result1 []string result2 error } @@ -81,9 +78,8 @@ type FakeRowsScanner struct { } ErrStub func() error errMutex sync.RWMutex - errArgsForCall []struct { - } - errReturns struct { + errArgsForCall []struct{} + errReturns struct { result1 error } errReturnsOnCall map[int]struct { @@ -91,9 +87,8 @@ type FakeRowsScanner struct { } NextStub func() bool nextMutex sync.RWMutex - nextArgsForCall []struct { - } - nextReturns struct { + nextArgsForCall []struct{} + nextReturns struct { result1 bool } nextReturnsOnCall map[int]struct { @@ -117,8 +112,7 @@ type FakeRowsScanner struct { func (fake *FakeRowsScanner) Close() error { fake.closeMutex.Lock() ret, specificReturn := fake.closeReturnsOnCall[len(fake.closeArgsForCall)] - fake.closeArgsForCall = append(fake.closeArgsForCall, struct { - }{}) + fake.closeArgsForCall = append(fake.closeArgsForCall, struct{}{}) fake.recordInvocation("Close", []interface{}{}) fake.closeMutex.Unlock() if fake.CloseStub != nil { @@ -169,8 +163,7 @@ func (fake *FakeRowsScanner) CloseReturnsOnCall(i int, result1 error) { func (fake *FakeRowsScanner) ColumnTypes() ([]*sql.ColumnType, error) { fake.columnTypesMutex.Lock() ret, specificReturn := fake.columnTypesReturnsOnCall[len(fake.columnTypesArgsForCall)] - fake.columnTypesArgsForCall = append(fake.columnTypesArgsForCall, struct { - }{}) + fake.columnTypesArgsForCall = append(fake.columnTypesArgsForCall, struct{}{}) fake.recordInvocation("ColumnTypes", []interface{}{}) fake.columnTypesMutex.Unlock() if fake.ColumnTypesStub != nil { @@ -224,8 +217,7 @@ func (fake *FakeRowsScanner) ColumnTypesReturnsOnCall(i int, result1 []*sql.Colu func (fake *FakeRowsScanner) Columns() ([]string, error) { fake.columnsMutex.Lock() ret, specificReturn := fake.columnsReturnsOnCall[len(fake.columnsArgsForCall)] - fake.columnsArgsForCall = append(fake.columnsArgsForCall, struct { - }{}) + fake.columnsArgsForCall = append(fake.columnsArgsForCall, struct{}{}) fake.recordInvocation("Columns", []interface{}{}) fake.columnsMutex.Unlock() if fake.ColumnsStub != nil { @@ -279,8 +271,7 @@ func (fake *FakeRowsScanner) ColumnsReturnsOnCall(i int, result1 []string, resul func (fake *FakeRowsScanner) Err() error { fake.errMutex.Lock() ret, specificReturn := fake.errReturnsOnCall[len(fake.errArgsForCall)] - fake.errArgsForCall = append(fake.errArgsForCall, struct { - }{}) + fake.errArgsForCall = append(fake.errArgsForCall, struct{}{}) fake.recordInvocation("Err", []interface{}{}) fake.errMutex.Unlock() if fake.ErrStub != nil { @@ -331,8 +322,7 @@ func (fake *FakeRowsScanner) ErrReturnsOnCall(i int, result1 error) { func (fake *FakeRowsScanner) Next() bool { fake.nextMutex.Lock() ret, specificReturn := fake.nextReturnsOnCall[len(fake.nextArgsForCall)] - fake.nextArgsForCall = append(fake.nextArgsForCall, struct { - }{}) + fake.nextArgsForCall = append(fake.nextArgsForCall, struct{}{}) fake.recordInvocation("Next", []interface{}{}) fake.nextMutex.Unlock() if fake.NextStub != nil { diff --git a/scanner.go b/scanner.go index e446eae..6f28964 100644 --- a/scanner.go +++ b/scanner.go @@ -132,22 +132,26 @@ func rows(v interface{}, r RowsScanner, strict bool) (outerr error) { } // Initialization the tags from struct. -func initFieldTag(v reflect.Value, len int) map[string]reflect.Value { - fieldTagMap := make(map[string]reflect.Value, len) - typ := v.Type() - for i := 0; i < v.NumField(); i++ { +func initFieldTag(sliceItem reflect.Value, fieldTagMap *map[string]reflect.Value) { + typ := sliceItem.Type() + for i := 0; i < sliceItem.NumField(); i++ { + if typ.Field(i).Anonymous || typ.Field(i).Type.Kind() == reflect.Struct { + // found an embedded struct + sliceItemOfAnonymous := sliceItem.Field(i) + initFieldTag(sliceItemOfAnonymous, fieldTagMap) + } tag, ok := typ.Field(i).Tag.Lookup("db") if ok && tag != "" { - fieldTagMap[tag] = v.Field(i) + (*fieldTagMap)[tag] = sliceItem.Field(i) } } - return fieldTagMap } -func structPointers(stct reflect.Value, cols []string, strict bool) []interface{} { +func structPointers(sliceItem reflect.Value, cols []string, strict bool) []interface{} { pointers := make([]interface{}, 0, len(cols)) + fieldTag := make(map[string]reflect.Value, len(cols)) + initFieldTag(sliceItem, &fieldTag) - fieldTag := initFieldTag(stct, len(cols)) for _, colName := range cols { var fieldVal reflect.Value if v, ok := fieldTag[colName]; ok { @@ -156,7 +160,7 @@ func structPointers(stct reflect.Value, cols []string, strict bool) []interface{ if strict { fieldVal = reflect.ValueOf(nil) } else { - fieldVal = stct.FieldByName(strings.Title(colName)) + fieldVal = sliceItem.FieldByName(strings.Title(colName)) } } if !fieldVal.IsValid() || !fieldVal.CanSet() { diff --git a/scanner_test.go b/scanner_test.go index 2a11c5f..fa3a130 100644 --- a/scanner_test.go +++ b/scanner_test.go @@ -278,6 +278,39 @@ func TestRowStrictIgnoresFieldsWithoutDBTag(t *testing.T) { assert.Equal(t, "", item.Last) } +func TestRowScansNestedFields(t *testing.T) { + rows := fakeRowsWithRecords(t, []string{"p.First", "p.Last"}, + []interface{}{"Brett", "Jones"}, + ) + + var res struct { + Item struct { + First string `db:"p.First"` + Last string `db:"p.Last"` + } + } + + require.NoError(t, scan.Row(&res, rows)) + assert.Equal(t, "Brett", res.Item.First) + assert.Equal(t, "Jones", res.Item.Last) +} + +func TestRowStrictScansNestedFields(t *testing.T) { + rows := fakeRowsWithRecords(t, []string{"p.First", "p.Last"}, + []interface{}{"Brett", "Jones"}, + ) + + var res struct { + Item struct { + First string `db:"p.First"` + Last string `db:"p.Last"` + } + } + + require.NoError(t, scan.RowStrict(&res, rows)) + assert.Equal(t, "Brett", res.Item.First) + assert.Equal(t, "Jones", res.Item.Last) +} func TestRowsStrictIgnoresFieldsWithoutDBTag(t *testing.T) { rows := fakeRowsWithRecords(t, []string{"First", "Last"}, []interface{}{"Brett", "Jones"},