Skip to content

Commit

Permalink
Merge pull request #120 from pradeepp28/dont-overwrite-pointers
Browse files Browse the repository at this point in the history
should not overwrite pointers directly, instead check embedded values…
  • Loading branch information
vdemeester committed Jul 5, 2019
2 parents f757d86 + eb76876 commit 27e96fb
Show file tree
Hide file tree
Showing 2 changed files with 167 additions and 3 deletions.
5 changes: 3 additions & 2 deletions merge.go
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,7 @@ func deepMerge(dst, src reflect.Value, visited map[uintptr]*visit, depth int, co
continue
}

if srcElement.IsValid() && (overwrite || (!dstElement.IsValid() || isEmptyValue(dstElement))) {
if srcElement.IsValid() && ((srcElement.Kind() != reflect.Ptr && overwrite) || !dstElement.IsValid() || isEmptyValue(dstElement)) {
if dst.IsNil() {
dst.Set(reflect.MakeMap(dst.Type()))
}
Expand Down Expand Up @@ -184,7 +184,7 @@ func deepMerge(dst, src reflect.Value, visited map[uintptr]*visit, depth int, co
}

if src.Kind() != reflect.Interface {
if dst.IsNil() || overwrite {
if dst.IsNil() || (src.Kind() != reflect.Ptr && overwrite) {
if dst.CanSet() && (overwrite || isEmptyValue(dst)) {
dst.Set(src)
}
Expand Down Expand Up @@ -213,6 +213,7 @@ func deepMerge(dst, src reflect.Value, visited map[uintptr]*visit, depth int, co
dst.Set(src)
}
}

return
}

Expand Down
165 changes: 164 additions & 1 deletion mergo_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@ import (
"testing"
"time"

"github.com/stretchr/testify/assert"

"gopkg.in/yaml.v2"
)

Expand Down Expand Up @@ -347,7 +349,7 @@ func TestEmptyToNotEmptyMaps(t *testing.T) {
func TestMapsWithOverwrite(t *testing.T) {
m := map[string]simpleTest{
"a": {}, // overwritten by 16
"b": {42}, // not overwritten by empty value
"b": {42}, // overwritten by 0, as map Value is not addressable and it doesn't check for b is set or not set in `n`
"c": {13}, // overwritten by 12
"d": {61},
}
Expand All @@ -374,6 +376,167 @@ func TestMapsWithOverwrite(t *testing.T) {
}
}

func TestMapWithEmbeddedStructPointer(t *testing.T) {
m := map[string]*simpleTest{
"a": {}, // overwritten by 16
"b": {42}, // not overwritten by empty value
"c": {13}, // overwritten by 12
"d": {61},
}
n := map[string]*simpleTest{
"a": {16},
"b": {},
"c": {12},
"e": {14},
}
expect := map[string]*simpleTest{
"a": {16},
"b": {42},
"c": {12},
"d": {61},
"e": {14},
}

if err := Merge(&m, n, WithOverride); err != nil {
t.Fatalf(err.Error())
}

assert.Equalf(t, expect, m, "Test Failed")
if !reflect.DeepEqual(m, expect) {
t.Fatalf("Test failed:\ngot :\n%#v\n\nwant :\n%#v\n\n", m, expect)
}
}

func TestMergeUsingStructAndMap(t *testing.T) {
type multiPtr struct {
Text string
Number int
}
type final struct {
Msg1 string
Msg2 string
}
type params struct {
Name string
Multi *multiPtr
Final *final
}
type config struct {
Foo string
Bar string
Params *params
}

cases := []struct {
name string
overwrite bool
changes *config
target *config
output *config
}{
{
name: "Should overwrite values in target for non-nil values in source",
overwrite: true,
changes: &config{
Bar: "from changes",
Params: &params{
Final: &final{
Msg1: "from changes",
Msg2: "from changes",
},
},
},
target: &config{
Foo: "from target",
Params: &params{
Name: "from target",
Multi: &multiPtr{
Text: "from target",
Number: 5,
},
Final: &final{
Msg1: "from target",
Msg2: "",
},
},
},
output: &config{
Foo: "from target",
Bar: "from changes",
Params: &params{
Name: "from target",
Multi: &multiPtr{
Text: "from target",
Number: 5,
},
Final: &final{
Msg1: "from changes",
Msg2: "from changes",
},
},
},
},
{
name: "Should not overwrite values in target for non-nil values in source",
overwrite: false,
changes: &config{
Bar: "from changes",
Params: &params{
Final: &final{
Msg1: "from changes",
Msg2: "from changes",
},
},
},
target: &config{
Foo: "from target",
Params: &params{
Name: "from target",
Multi: &multiPtr{
Text: "from target",
Number: 5,
},
Final: &final{
Msg1: "from target",
Msg2: "",
},
},
},
output: &config{
Foo: "from target",
Bar: "from changes",
Params: &params{
Name: "from target",
Multi: &multiPtr{
Text: "from target",
Number: 5,
},
Final: &final{
Msg1: "from target",
Msg2: "from changes",
},
},
},
},
}

for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
var err error
if tc.overwrite {
err = Merge(tc.target, *tc.changes, WithOverride)
} else {
err = Merge(tc.target, *tc.changes)
}
if err != nil {
t.Error(err)
}
if !reflect.DeepEqual(tc.target, tc.output) {
t.Fatalf("Test failed:\ngot :\n%#v\n\nwant :\n%#v\n\n", tc.target, tc.output)
}
})
}
}
func TestMaps(t *testing.T) {
m := map[string]simpleTest{
"a": {},
Expand Down

0 comments on commit 27e96fb

Please sign in to comment.