diff --git a/issue84_test.go b/issue84_test.go new file mode 100644 index 0000000..aa60526 --- /dev/null +++ b/issue84_test.go @@ -0,0 +1,82 @@ +package mergo + +import ( + "testing" +) + +type DstStructIssue84 struct { + A int + B int + C int +} + +type DstNestedStructIssue84 struct { + A struct { + A int + B int + C int + } + B int + C int +} + +func TestIssue84MergeMapWithNilValueToStructWithOverride(t *testing.T) { + p1 := DstStructIssue84{ + A: 0, B: 1, C: 2, + } + p2 := map[string]interface{}{ + "A": 3, "B": 4, "C": 0, + } + if err := Map(&p1, p2, WithOverride); err != nil { + t.Fatalf("Error during the merge: %v", err) + } + if p1.C != 0 { + t.Error("C field should become '0'") + } +} + +func TestIssue84MergeMapWithoutKeyExistsToStructWithOverride(t *testing.T) { + p1 := DstStructIssue84{ + A: 0, B: 1, C: 2, + } + p2 := map[string]interface{}{ + "A": 3, "B": 4, + } + if err := Map(&p1, p2, WithOverride); err != nil { + t.Fatalf("Error during the merge: %v", err) + } + if p1.C != 2 { + t.Error("C field should be '2'") + } +} + +func TestIssue84MergeNestedMapWithNilValueToStructWithOverride(t *testing.T) { + p1 := DstNestedStructIssue84{ + A: struct { + A int + B int + C int + }{A: 1, B: 2, C: 0}, + B: 0, + C: 2, + } + p2 := map[string]interface{}{ + "A": map[string]interface{}{ + "A": 0, "B": 0, "C": 5, + }, "B": 4, "C": 0, + } + if err := Map(&p1, p2, WithOverride); err != nil { + t.Fatalf("Error during the merge: %v", err) + } + if p1.B != 4 { + t.Error("A.C field should become '4'") + } + + if p1.A.C != 5 { + t.Error("A.C field should become '5'") + } + + if p1.A.B != 0 || p1.A.A != 0 { + t.Error("A.A and A.B field should become '0'") + } +} diff --git a/map.go b/map.go index 6ea38e6..3f5afa8 100644 --- a/map.go +++ b/map.go @@ -72,6 +72,7 @@ func deepMap(dst, src reflect.Value, visited map[uintptr]*visit, depth int, conf case reflect.Struct: srcMap := src.Interface().(map[string]interface{}) for key := range srcMap { + config.overwriteWithEmptyValue = true srcValue := srcMap[key] fieldName := changeInitialCase(key, unicode.ToUpper) dstElement := dst.FieldByName(fieldName) diff --git a/merge.go b/merge.go index dd7e714..f8de6c5 100644 --- a/merge.go +++ b/merge.go @@ -26,9 +26,10 @@ func hasExportedField(dst reflect.Value) (exported bool) { } type Config struct { - Overwrite bool - AppendSlice bool - Transformers Transformers + Overwrite bool + AppendSlice bool + Transformers Transformers + overwriteWithEmptyValue bool } type Transformers interface { @@ -40,6 +41,8 @@ type Transformers interface { // short circuiting on recursive types. func deepMerge(dst, src reflect.Value, visited map[uintptr]*visit, depth int, config *Config) (err error) { overwrite := config.Overwrite + overwriteWithEmptySrc := config.overwriteWithEmptyValue + config.overwriteWithEmptyValue = false if !src.IsValid() { return @@ -74,7 +77,7 @@ func deepMerge(dst, src reflect.Value, visited map[uintptr]*visit, depth int, co } } } else { - if dst.CanSet() && !isEmptyValue(src) && (overwrite || isEmptyValue(dst)) { + if dst.CanSet() && (!isEmptyValue(src) || overwriteWithEmptySrc) && (overwrite || isEmptyValue(dst)) { dst.Set(src) } } @@ -125,7 +128,7 @@ func deepMerge(dst, src reflect.Value, visited map[uintptr]*visit, depth int, co dstSlice = reflect.ValueOf(dstElement.Interface()) } - if !isEmptyValue(src) && (overwrite || isEmptyValue(dst)) && !config.AppendSlice { + if (!isEmptyValue(src) || overwriteWithEmptySrc) && (overwrite || isEmptyValue(dst)) && !config.AppendSlice { dstSlice = srcSlice } else if config.AppendSlice { if srcSlice.Type() != dstSlice.Type() { @@ -151,7 +154,7 @@ func deepMerge(dst, src reflect.Value, visited map[uintptr]*visit, depth int, co if !dst.CanSet() { break } - if !isEmptyValue(src) && (overwrite || isEmptyValue(dst)) && !config.AppendSlice { + if (!isEmptyValue(src) || overwriteWithEmptySrc) && (overwrite || isEmptyValue(dst)) && !config.AppendSlice { dst.Set(src) } else if config.AppendSlice { if src.Type() != dst.Type() { @@ -191,7 +194,7 @@ func deepMerge(dst, src reflect.Value, visited map[uintptr]*visit, depth int, co return } default: - if dst.CanSet() && !isEmptyValue(src) && (overwrite || isEmptyValue(dst)) { + if dst.CanSet() && (!isEmptyValue(src) || overwriteWithEmptySrc) && (overwrite || isEmptyValue(dst)) { dst.Set(src) } }