Skip to content

Commit

Permalink
Fix a bug in StructToMap, add more tests.
Browse files Browse the repository at this point in the history
  • Loading branch information
mennanov committed May 17, 2019
1 parent f2a0c0d commit ce90040
Show file tree
Hide file tree
Showing 2 changed files with 259 additions and 50 deletions.
100 changes: 51 additions & 49 deletions copy.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
package fieldmask_utils

import (
"fmt"
"reflect"

"github.com/pkg/errors"
Expand All @@ -15,12 +14,18 @@ func StructToStruct(filter FieldFilter, src, dst interface{}) error {
if dstVal.Kind() != reflect.Ptr {
return errors.Errorf("dst must be a pointer, %s given", dstVal.Kind())
}
srcVal := reflect.ValueOf(src).Elem()
dstVal = dstVal.Elem()
return copyWithFilter(filter, &srcVal, &dstVal)
srcVal := indirect(reflect.ValueOf(src))
if srcVal.Kind() != reflect.Struct {
return errors.Errorf("src kind must be a struct, %s given", srcVal.Kind())
}
dstVal = indirect(dstVal)
if dstVal.Kind() != reflect.Struct {
return errors.Errorf("dst kind must be a struct, %s given", dstVal.Kind())
}
return structToStruct(filter, &srcVal, &dstVal)
}

func copyWithFilter(filter FieldFilter, src, dst *reflect.Value) error {
func structToStruct(filter FieldFilter, src, dst *reflect.Value) error {
if src.Kind() != dst.Kind() {
return errors.Errorf("src kind %s differs from dst kind %s", src.Kind(), dst.Kind())
}
Expand All @@ -29,18 +34,20 @@ func copyWithFilter(filter FieldFilter, src, dst *reflect.Value) error {
case reflect.Struct:
for i := 0; i < src.NumField(); i++ {
fieldName := src.Type().Field(i).Name
srcField := src.FieldByName(fieldName)
dstField := dst.FieldByName(fieldName)

subFilter, ok := filter.Filter(fieldName)
if !ok {
// Skip this field.
continue
}

dstField := dst.FieldByName(fieldName)
if !dstField.CanSet() {
return errors.Errorf("Can't set a value on a destination field %s", fieldName)
}
if err := copyWithFilter(subFilter, &srcField, &dstField); err != nil {

srcField := src.FieldByName(fieldName)
if err := structToStruct(subFilter, &srcField, &dstField); err != nil {
return err
}
}
Expand All @@ -57,7 +64,7 @@ func copyWithFilter(filter FieldFilter, src, dst *reflect.Value) error {
}

srcElem, dstElem := src.Elem(), dst.Elem()
if err := copyWithFilter(filter, &srcElem, &dstElem); err != nil {
if err := structToStruct(filter, &srcElem, &dstElem); err != nil {
return err
}

Expand All @@ -67,9 +74,6 @@ func copyWithFilter(filter FieldFilter, src, dst *reflect.Value) error {
dst.Set(reflect.Zero(dst.Type()))
break
}
if !dst.Type().Implements(src.Type()) {
return errors.Errorf("dst %s does not implement src %s", dst.Type(), src.Type())
}
if dst.IsNil() {
if src.Elem().Kind() != reflect.Ptr {
// Non-pointer interface implementations are not addressable.
Expand All @@ -79,7 +83,7 @@ func copyWithFilter(filter FieldFilter, src, dst *reflect.Value) error {
}

srcElem, dstElem := src.Elem(), dst.Elem()
if err := copyWithFilter(filter, &srcElem, &dstElem); err != nil {
if err := structToStruct(filter, &srcElem, &dstElem); err != nil {
return err
}

Expand All @@ -96,7 +100,7 @@ func copyWithFilter(filter FieldFilter, src, dst *reflect.Value) error {
dstItem = reflect.New(dst.Type().Elem()).Elem()
}

if err := copyWithFilter(filter, &srcItem, &dstItem); err != nil {
if err := structToStruct(filter, &srcItem, &dstItem); err != nil {
return err
}

Expand All @@ -108,13 +112,13 @@ func copyWithFilter(filter FieldFilter, src, dst *reflect.Value) error {

case reflect.Array:
dstLen := dst.Len()
if dstLen != src.Len() {
return errors.Errorf("dst array size %d differs from src size %d", dstLen, src.Len())
if dstLen < src.Len() {
return errors.Errorf("dst array size %d is less than src size %d", dstLen, src.Len())
}
for i := 0; i < src.Len(); i++ {
srcItem := src.Index(i)
dstItem := dst.Index(i)
if err := copyWithFilter(filter, &srcItem, &dstItem); err != nil {
if err := structToStruct(filter, &srcItem, &dstItem); err != nil {
return errors.WithStack(err)
}
}
Expand All @@ -141,10 +145,8 @@ func StructToMap(filter FieldFilter, src interface{}, dst map[string]interface{}
// Skip this field.
continue
}
srcField, err := getField(src, fieldName)
if err != nil {
return errors.Wrap(err, fmt.Sprintf("failed to get the field %s from %T", fieldName, src))
}
srcField := srcVal.FieldByName(fieldName)

switch srcField.Kind() {
case reflect.Ptr, reflect.Interface:
if srcField.IsNil() {
Expand All @@ -165,8 +167,9 @@ func StructToMap(filter FieldFilter, src interface{}, dst map[string]interface{}
dst[fieldName] = newValue

case reflect.Array, reflect.Slice:
// Check if it is an array of values (non-pointers).
if srcField.Type().Elem().Kind() != reflect.Ptr {
// Check if it is a slice of primitive values.
itemKind := srcField.Type().Elem().Kind()
if itemKind != reflect.Ptr && itemKind != reflect.Struct && itemKind != reflect.Interface {
// Handle this array/slice as a regular non-nested data structure: copy it entirely to dst.
if srcField.Len() > 0 {
dst[fieldName] = srcField.Interface()
Expand All @@ -175,17 +178,37 @@ func StructToMap(filter FieldFilter, src interface{}, dst map[string]interface{}
}
continue
}
v := make([]map[string]interface{}, 0)

var newValue []map[string]interface{}
existingValue, ok := dst[fieldName]
if ok {
newValue = existingValue.([]map[string]interface{})
} else {
newValue = make([]map[string]interface{}, srcField.Len())
}

// Iterate over items of the slice/array.
for i := 0; i < srcField.Len(); i++ {
dstLen := len(newValue)
srcLen := srcField.Len()
if dstLen < srcLen {
return errors.Errorf("dst slice len %d is less than src slice len %d", dstLen, srcLen)
}
for i := 0; i < srcLen; i++ {
subValue := srcField.Index(i)
newDst := make(map[string]interface{})
if newValue[i] == nil {
newValue[i] = make(map[string]interface{})
}
newDst := newValue[i]
if err := StructToMap(subFilter, subValue.Interface(), newDst); err != nil {
return err
}
v = append(v, newDst)
if i < dstLen {
newValue[i] = newDst
} else {
newValue = append(newValue, newDst)
}
}
dst[fieldName] = v
dst[fieldName] = newValue

case reflect.Struct:
var newValue map[string]interface{}
Expand All @@ -208,27 +231,6 @@ func StructToMap(filter FieldFilter, src interface{}, dst map[string]interface{}
return nil
}

func getField(obj interface{}, name string) (reflect.Value, error) {
objValue := reflectValue(obj)
field := objValue.FieldByName(name)
if !field.IsValid() {
return reflect.ValueOf(nil), errors.Errorf("no such field: %s in obj %T", name, obj)
}
return field, nil
}

func reflectValue(obj interface{}) reflect.Value {
var val reflect.Value

if reflect.TypeOf(obj).Kind() == reflect.Ptr {
val = reflect.ValueOf(obj).Elem()
} else {
val = reflect.ValueOf(obj)
}

return val
}

func indirect(v reflect.Value) reflect.Value {
for v.Kind() == reflect.Ptr {
v = v.Elem()
Expand Down
Loading

0 comments on commit ce90040

Please sign in to comment.