Skip to content

Commit

Permalink
Add WithoutDereference config to prevent incorrect bool pointer merges
Browse files Browse the repository at this point in the history
  • Loading branch information
Anonymous committed Oct 28, 2022
1 parent c42713b commit 0e73161
Show file tree
Hide file tree
Showing 4 changed files with 102 additions and 13 deletions.
73 changes: 73 additions & 0 deletions issue131_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,9 @@ import (
type foz struct {
A *bool
B string
C *bool
D *bool
E *bool
}

func TestIssue131MergeWithOverwriteWithEmptyValue(t *testing.T) {
Expand All @@ -30,3 +33,73 @@ func TestIssue131MergeWithOverwriteWithEmptyValue(t *testing.T) {
t.Errorf("dest.B not merged in properly: %v != %v", src.B, dest.B)
}
}

func TestIssue131MergeWithoutDereferenceWithOverride(t *testing.T) {
src := foz{
A: func(v bool) *bool { return &v }(false),
B: "src",
C: nil,
D: func(v bool) *bool { return &v }(false),
E: func(v bool) *bool { return &v }(true),
}
dest := foz{
A: func(v bool) *bool { return &v }(true),
B: "dest",
C: func(v bool) *bool { return &v }(false),
D: nil,
E: func(v bool) *bool { return &v }(false),
}
if err := mergo.Merge(&dest, src, mergo.WithoutDereference, mergo.WithOverride); err != nil {
t.Error(err)
}
if *src.A != *dest.A {
t.Errorf("dest.A not merged in properly: %v != %v", *src.A, *dest.A)
}
if src.B != dest.B {
t.Errorf("dest.B not merged in properly: %v != %v", src.B, dest.B)
}
if *dest.C != false {
t.Errorf("dest.C not merged in properly: %v != %v", *src.C, *dest.C)
}
if *dest.D != false {
t.Errorf("dest.D not merged in properly: %v != %v", src.D, *dest.D)
}
if *dest.E != true {
t.Errorf("dest.E not merged in properly: %v != %v", *src.E, *dest.E)
}
}

func TestIssue131MergeWithoutDereference(t *testing.T) {
src := foz{
A: func(v bool) *bool { return &v }(false),
B: "src",
C: nil,
D: func(v bool) *bool { return &v }(false),
E: func(v bool) *bool { return &v }(true),
}
dest := foz{
A: func(v bool) *bool { return &v }(true),
B: "dest",
C: func(v bool) *bool { return &v }(false),
D: nil,
E: func(v bool) *bool { return &v }(false),
}
if err := mergo.Merge(&dest, src, mergo.WithoutDereference); err != nil {
t.Error(err)
}
if *src.A == *dest.A {
t.Errorf("dest.A should not have been merged: %v == %v", *src.A, *dest.A)
}
if src.B == dest.B {
t.Errorf("dest.B should not have been merged: %v == %v", src.B, dest.B)
}
if *dest.C != false {
t.Errorf("dest.C not merged in properly: %v != %v", *src.C, *dest.C)
}
if *dest.D != false {
t.Errorf("dest.D not merged in properly: %v != %v", src.D, *dest.D)
}
if *dest.E == true {
t.Errorf("dest.Eshould not have been merged: %v == %v", *src.E, *dest.E)
}
}
2 changes: 1 addition & 1 deletion map.go
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ func deepMap(dst, src reflect.Value, visited map[uintptr]*visit, depth int, conf
}
fieldName := field.Name
fieldName = changeInitialCase(fieldName, unicode.ToLower)
if v, ok := dstMap[fieldName]; !ok || (isEmptyValue(reflect.ValueOf(v)) || overwrite) {
if v, ok := dstMap[fieldName]; !ok || (isEmptyValue(reflect.ValueOf(v), !config.ShouldNotDereference) || overwrite) {
dstMap[fieldName] = src.Field(i).Interface()
}
}
Expand Down
33 changes: 23 additions & 10 deletions merge.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ func isExportedComponent(field *reflect.StructField) bool {
type Config struct {
Transformers Transformers
Overwrite bool
ShouldNotDereference bool
AppendSlice bool
TypeCheck bool
overwriteWithEmptyValue bool
Expand Down Expand Up @@ -95,7 +96,7 @@ func deepMerge(dst, src reflect.Value, visited map[uintptr]*visit, depth int, co
}
}
} else {
if dst.CanSet() && (isReflectNil(dst) || overwrite) && (!isEmptyValue(src) || overwriteWithEmptySrc) {
if dst.CanSet() && (isReflectNil(dst) || overwrite) && (!isEmptyValue(src, !config.ShouldNotDereference) || overwriteWithEmptySrc) {
dst.Set(src)
}
}
Expand Down Expand Up @@ -162,7 +163,7 @@ func deepMerge(dst, src reflect.Value, visited map[uintptr]*visit, depth int, co
dstSlice = reflect.ValueOf(dstElement.Interface())
}

if (!isEmptyValue(src) || overwriteWithEmptySrc || overwriteSliceWithEmptySrc) && (overwrite || isEmptyValue(dst)) && !config.AppendSlice && !sliceDeepCopy {
if (!isEmptyValue(src, !config.ShouldNotDereference) || overwriteWithEmptySrc || overwriteSliceWithEmptySrc) && (overwrite || isEmptyValue(dst, !config.ShouldNotDereference)) && !config.AppendSlice && !sliceDeepCopy {
if typeCheck && srcSlice.Type() != dstSlice.Type() {
return fmt.Errorf("cannot override two slices with different type (%s, %s)", srcSlice.Type(), dstSlice.Type())
}
Expand Down Expand Up @@ -194,11 +195,11 @@ func deepMerge(dst, src reflect.Value, visited map[uintptr]*visit, depth int, co
dst.SetMapIndex(key, dstSlice)
}
}
if dstElement.IsValid() && !isEmptyValue(dstElement) && (reflect.TypeOf(srcElement.Interface()).Kind() == reflect.Map || reflect.TypeOf(srcElement.Interface()).Kind() == reflect.Slice) {
if dstElement.IsValid() && !isEmptyValue(dstElement, !config.ShouldNotDereference) && (reflect.TypeOf(srcElement.Interface()).Kind() == reflect.Map || reflect.TypeOf(srcElement.Interface()).Kind() == reflect.Slice) {
continue
}

if srcElement.IsValid() && ((srcElement.Kind() != reflect.Ptr && overwrite) || !dstElement.IsValid() || isEmptyValue(dstElement)) {
if srcElement.IsValid() && ((srcElement.Kind() != reflect.Ptr && overwrite) || !dstElement.IsValid() || isEmptyValue(dstElement, !config.ShouldNotDereference)) {
if dst.IsNil() {
dst.Set(reflect.MakeMap(dst.Type()))
}
Expand All @@ -209,7 +210,7 @@ func deepMerge(dst, src reflect.Value, visited map[uintptr]*visit, depth int, co
if !dst.CanSet() {
break
}
if (!isEmptyValue(src) || overwriteWithEmptySrc || overwriteSliceWithEmptySrc) && (overwrite || isEmptyValue(dst)) && !config.AppendSlice && !sliceDeepCopy {
if (!isEmptyValue(src, !config.ShouldNotDereference) || overwriteWithEmptySrc || overwriteSliceWithEmptySrc) && (overwrite || isEmptyValue(dst, !config.ShouldNotDereference)) && !config.AppendSlice && !sliceDeepCopy {
dst.Set(src)
} else if config.AppendSlice {
if src.Type() != dst.Type() {
Expand Down Expand Up @@ -244,12 +245,18 @@ func deepMerge(dst, src reflect.Value, visited map[uintptr]*visit, depth int, co

if src.Kind() != reflect.Interface {
if dst.IsNil() || (src.Kind() != reflect.Ptr && overwrite) {
if dst.CanSet() && (overwrite || isEmptyValue(dst)) {
if dst.CanSet() && (overwrite || isEmptyValue(dst, !config.ShouldNotDereference)) {
dst.Set(src)
}
} else if src.Kind() == reflect.Ptr {
if err = deepMerge(dst.Elem(), src.Elem(), visited, depth+1, config); err != nil {
return
if !config.ShouldNotDereference {
if err = deepMerge(dst.Elem(), src.Elem(), visited, depth+1, config); err != nil {
return
}
} else {
if overwriteWithEmptySrc || (overwrite && !src.IsNil()) || dst.IsNil() {
dst.Set(src)
}
}
} else if dst.Elem().Type() == src.Type() {
if err = deepMerge(dst.Elem(), src, visited, depth+1, config); err != nil {
Expand All @@ -262,7 +269,7 @@ func deepMerge(dst, src reflect.Value, visited map[uintptr]*visit, depth int, co
}

if dst.IsNil() || overwrite {
if dst.CanSet() && (overwrite || isEmptyValue(dst)) {
if dst.CanSet() && (overwrite || isEmptyValue(dst, !config.ShouldNotDereference)) {
dst.Set(src)
}
break
Expand All @@ -275,7 +282,7 @@ func deepMerge(dst, src reflect.Value, visited map[uintptr]*visit, depth int, co
break
}
default:
mustSet := (isEmptyValue(dst) || overwrite) && (!isEmptyValue(src) || overwriteWithEmptySrc)
mustSet := (isEmptyValue(dst, !config.ShouldNotDereference) || overwrite) && (!isEmptyValue(src, !config.ShouldNotDereference) || overwriteWithEmptySrc)
if mustSet {
if dst.CanSet() {
dst.Set(src)
Expand Down Expand Up @@ -326,6 +333,12 @@ func WithOverrideEmptySlice(config *Config) {
config.overwriteSliceWithEmptyValue = true
}

// WithoutDereference prevents dereferencing pointers when evaluating whether they are empty
// (i.e. a non-nil pointer is never considered empty).
func WithoutDereference(config *Config) {
config.ShouldNotDereference = true
}

// WithAppendSlice will make merge append slices instead of overwriting it.
func WithAppendSlice(config *Config) {
config.AppendSlice = true
Expand Down
7 changes: 5 additions & 2 deletions mergo.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ type visit struct {
}

// From src/pkg/encoding/json/encode.go.
func isEmptyValue(v reflect.Value) bool {
func isEmptyValue(v reflect.Value, shouldDereference bool) bool {
switch v.Kind() {
case reflect.Array, reflect.Map, reflect.Slice, reflect.String:
return v.Len() == 0
Expand All @@ -50,7 +50,10 @@ func isEmptyValue(v reflect.Value) bool {
if v.IsNil() {
return true
}
return isEmptyValue(v.Elem())
if shouldDereference {
return isEmptyValue(v.Elem(), shouldDereference)
}
return false
case reflect.Func:
return v.IsNil()
case reflect.Invalid:
Expand Down

0 comments on commit 0e73161

Please sign in to comment.