Skip to content

Commit

Permalink
Similarly to slices, make sure that map keys/values are copied correctly
Browse files Browse the repository at this point in the history
If map keys/values are simple structs, generate the copy code in a
buffer, and check whether the buffer has content before attempting to
generate temporary variables for keys/values.
  • Loading branch information
urandom committed Mar 10, 2020
1 parent 225ef72 commit 8a5655e
Show file tree
Hide file tree
Showing 3 changed files with 81 additions and 25 deletions.
58 changes: 35 additions & 23 deletions main.go
Original file line number Diff line number Diff line change
Expand Up @@ -318,7 +318,7 @@ func walkType(source, sink, x string, m types.Type, w io.Writer, imports map[str
walkType(source+"."+fname, sink+"."+fname, x, field.Type(), w, imports, skips, false)
}
case *types.Slice:
kind, _ := getElemType(v.Elem(), x, imports, false)
kind := getElemType(v.Elem(), x, imports, false)

sel := sink + "[i]"
if initial {
Expand Down Expand Up @@ -355,7 +355,7 @@ func walkType(source, sink, x string, m types.Type, w io.Writer, imports map[str

fmt.Fprintf(w, "}\n")
case *types.Pointer:
kind, _ := getElemType(v.Elem(), x, imports, true)
kind := getElemType(v.Elem(), x, imports, true)

fmt.Fprintf(w, "if %s != nil {\n", source)

Expand All @@ -370,23 +370,25 @@ func walkType(source, sink, x string, m types.Type, w io.Writer, imports map[str

fmt.Fprintf(w, "}\n")
case *types.Chan:
kind, _ := getElemType(v.Elem(), x, imports, false)
kind := getElemType(v.Elem(), x, imports, false)

fmt.Fprintf(w, `if %s != nil {
%s = make(chan %s, cap(%s))
}
`, source, sink, kind, source)
case *types.Map:
kkind, kbasic := getElemType(v.Key(), x, imports, false)
vkind, vbasic := getElemType(v.Elem(), x, imports, false)
kkind := getElemType(v.Key(), x, imports, false)
vkind := getElemType(v.Elem(), x, imports, false)

sel := sink + "[k]"
if initial {
sel = "[k]"
}

var skipKey, skipValue bool
sel = sel[strings.Index(sel, ".")+1:]
if _, ok := skips[sel]; ok {
kbasic, vbasic = true, true
skipKey, skipValue = true, true
}

fmt.Fprintf(w, `if %s != nil {
Expand All @@ -395,15 +397,31 @@ func walkType(source, sink, x string, m types.Type, w io.Writer, imports map[str
`, source, sink, kkind, vkind, source, source)

ksink, vsink := "k", "v"
if !kbasic {
ksink = "cpk"
fmt.Fprintf(w, "var %s %s\n", ksink, kkind)
walkType("k", ksink, x, v.Key(), w, imports, skips, false)

var b bytes.Buffer

if !skipKey {
copyKSink := "cpk"
walkType("k", copyKSink, x, v.Key(), &b, imports, skips, false)

if b.Len() > 0 {
ksink = copyKSink
fmt.Fprintf(w, "var %s %s\n", ksink, kkind)
b.WriteTo(w)
}
}
if !vbasic {
vsink = "cpv"
fmt.Fprintf(w, "var %s %s\n", vsink, vkind)
walkType("v", vsink, x, v.Elem(), w, imports, skips, false)

b.Reset()

if !skipValue {
copyVSink := "cpv"
walkType("v", copyVSink, x, v.Elem(), &b, imports, skips, false)

if b.Len() > 0 {
vsink = copyVSink
fmt.Fprintf(w, "var %s %s\n", vsink, vkind)
b.WriteTo(w)
}
}

fmt.Fprintf(w, "%s[%s] = %s", sink, ksink, vsink)
Expand All @@ -413,7 +431,7 @@ func walkType(source, sink, x string, m types.Type, w io.Writer, imports map[str

}

func getElemType(t types.Type, x string, imports map[string]string, rawkind bool) (string, bool) {
func getElemType(t types.Type, x string, imports map[string]string, rawkind bool) string {
obj := objFromType(t)
var name, kind string
if obj != nil {
Expand All @@ -433,19 +451,13 @@ func getElemType(t types.Type, x string, imports map[string]string, rawkind bool
kind += t.String()
}

var pointer, noncopy bool
switch t.(type) {
case *types.Pointer:
pointer = true
case *types.Basic, *types.Interface:
noncopy = true
}
_, pointer := t.(*types.Pointer)

if !rawkind && pointer && kind[0] != '*' {
kind = "*" + kind
}

return kind, noncopy
return kind
}

func reuseDeepCopy(source, sink string, v methoder, pointer bool, w io.Writer) bool {
Expand Down
38 changes: 36 additions & 2 deletions main_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,9 @@ func Test_run(t *testing.T) {
{name: "alpha - with DeepCopy method", types: typesVal{"Alpha"}, path: "./testdata", want: []byte(AlphaPointer)},
{name: "slicepointer, skip slice member", types: typesVal{"SlicePointer"}, skips: skipsVal{{"[i]": struct{}{}}}, path: "./testdata", want: []byte(SlicePointer)},
{name: "foo, alpha, skips", types: typesVal{"Foo", "Alpha"}, skips: skipsVal{{"Map[k]": struct{}{}, "ch": struct{}{}}, {"D": struct{}{}, "E": struct{}{}}}, path: "./testdata", want: []byte(FooAlphaSkips)},
{name: "issue 1, struct with slice of simple structs", types: typesVal{"I3WithSlice"}, pointer: true, path: "./testdata", want: []byte(Issue1SliceSimpleStruct)},
{name: "issue 3, struct with slice of simple structs", types: typesVal{"I3WithSlice"}, pointer: true, path: "./testdata", want: []byte(Issue3SliceSimpleStruct)},
{name: "issue 3, struct with map of simple struct keys", types: typesVal{"I3WithMap"}, pointer: true, path: "./testdata", want: []byte(Issue3MapSimpleStructKey)},
{name: "issue 3, struct with map of simple struct values", types: typesVal{"I3WithMapVal"}, path: "./testdata", want: []byte(Issue3MapSimpleStructVal)},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
Expand Down Expand Up @@ -229,7 +231,7 @@ func (o Alpha) DeepCopy() Alpha {
return cp
}`

Issue1SliceSimpleStruct = `// generated by deep-copy; DO NOT EDIT.
Issue3SliceSimpleStruct = `// generated by deep-copy; DO NOT EDIT.
package testdata
Expand All @@ -242,5 +244,37 @@ func (o *I3WithSlice) DeepCopy() *I3WithSlice {
copy(cp.a, o.a)
}
return &cp
}`
Issue3MapSimpleStructKey = `// generated by deep-copy; DO NOT EDIT.
package testdata
// DeepCopy generates a deep copy of *I3WithMap
func (o *I3WithMap) DeepCopy() *I3WithMap {
var cp I3WithMap
cp = *o
if o.a != nil {
cp.a = make(map[I3SimpleStruct]string, len(o.a))
for k, v := range o.a {
cp.a[k] = v
}
}
return &cp
}`
Issue3MapSimpleStructVal = `// generated by deep-copy; DO NOT EDIT.
package testdata
// DeepCopy generates a deep copy of I3WithMapVal
func (o I3WithMapVal) DeepCopy() I3WithMapVal {
var cp I3WithMapVal
cp = o
if o.a != nil {
cp.a = make(map[string]I3SimpleStruct, len(o.a))
for k, v := range o.a {
cp.a[k] = v
}
}
return cp
}`
)
10 changes: 10 additions & 0 deletions testdata/issue_3_slice_of_simple_struct.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,3 +9,13 @@ type I3SimpleStruct struct {
foo string
bar int
}

type I3WithMap struct {
a map[I3SimpleStruct]string
b int
}

type I3WithMapVal struct {
a map[string]I3SimpleStruct
b int
}

0 comments on commit 8a5655e

Please sign in to comment.