Skip to content

Commit

Permalink
Fix hasDeepCopy to use values of --pointer-receiver option
Browse files Browse the repository at this point in the history
  • Loading branch information
egawata committed Jul 9, 2021
1 parent bfdfaf5 commit 02bcfef
Show file tree
Hide file tree
Showing 3 changed files with 106 additions and 13 deletions.
26 changes: 13 additions & 13 deletions main.go
Original file line number Diff line number Diff line change
Expand Up @@ -213,7 +213,7 @@ func (o %s%s) DeepCopy() %s%s {
var cp %s = %s%s
`, ptr, kind, ptr, kind, ptr, kind, kind, ptr, source)

walkType(source, "cp", p.Name, obj, &buf, imports, skips, generating, 0)
walkType(source, "cp", p.Name, obj, &buf, imports, skips, generating, 0, pointer)

if pointer {
buf.WriteString("return &cp\n}")
Expand Down Expand Up @@ -317,7 +317,7 @@ func exprFilter(t types.Type, sel string, x string) object {
return m
}

func walkType(source, sink, x string, m types.Type, w io.Writer, imports map[string]string, skips skips, generating []object, depth int) {
func walkType(source, sink, x string, m types.Type, w io.Writer, imports map[string]string, skips skips, generating []object, depth int, isPtrRecv bool) {
initial := depth == 0
if m == nil {
return
Expand All @@ -331,7 +331,7 @@ func walkType(source, sink, x string, m types.Type, w io.Writer, imports map[str
}
}

if v, ok := m.(methoder); ok && !initial && reuseDeepCopy(source, sink, v, false, generating, w) {
if v, ok := m.(methoder); ok && !initial && reuseDeepCopy(source, sink, v, false, generating, w, isPtrRecv) {
return
}

Expand All @@ -350,7 +350,7 @@ func walkType(source, sink, x string, m types.Type, w io.Writer, imports map[str
if _, ok := skips[sel]; ok {
continue
}
walkType(source+"."+fname, sink+"."+fname, x, field.Type(), w, imports, skips, generating, depth)
walkType(source+"."+fname, sink+"."+fname, x, field.Type(), w, imports, skips, generating, depth, isPtrRecv)
}
case *types.Slice:
kind := getElemType(v.Elem(), x, imports)
Expand Down Expand Up @@ -383,7 +383,7 @@ func walkType(source, sink, x string, m types.Type, w io.Writer, imports map[str

if !skipSlice {
baseSel := "[" + idx + "]"
walkType(source+baseSel, sink+baseSel, x, v.Elem(), &b, imports, skips, generating, depth)
walkType(source+baseSel, sink+baseSel, x, v.Elem(), &b, imports, skips, generating, depth, isPtrRecv)
}

if b.Len() > 0 {
Expand All @@ -399,14 +399,14 @@ func walkType(source, sink, x string, m types.Type, w io.Writer, imports map[str
case *types.Pointer:
fmt.Fprintf(w, "if %s != nil {\n", source)

if e, ok := v.Elem().(methoder); !ok || initial || !reuseDeepCopy(source, sink, e, true, generating, w) {
if e, ok := v.Elem().(methoder); !ok || initial || !reuseDeepCopy(source, sink, e, true, generating, w, isPtrRecv) {
kind := getElemType(v.Elem(), x, imports)

fmt.Fprintf(w, `%s = new(%s)
*%s = *%s
`, sink, kind, sink, source)

walkType(source, sink, x, v.Elem(), w, imports, skips, generating, depth)
walkType(source, sink, x, v.Elem(), w, imports, skips, generating, depth, isPtrRecv)
}

fmt.Fprintf(w, "}\n")
Expand Down Expand Up @@ -451,7 +451,7 @@ func walkType(source, sink, x string, m types.Type, w io.Writer, imports map[str

if !skipKey {
copyKSink := selToIdent(sink) + "_" + key
walkType(key, copyKSink, x, v.Key(), &b, imports, skips, generating, depth)
walkType(key, copyKSink, x, v.Key(), &b, imports, skips, generating, depth, isPtrRecv)

if b.Len() > 0 {
ksink = copyKSink
Expand All @@ -464,7 +464,7 @@ func walkType(source, sink, x string, m types.Type, w io.Writer, imports map[str

if !skipValue {
copyVSink := selToIdent(sink) + "_" + val
walkType(val, copyVSink, x, v.Elem(), &b, imports, skips, generating, depth)
walkType(val, copyVSink, x, v.Elem(), &b, imports, skips, generating, depth, isPtrRecv)

if b.Len() > 0 {
vsink = copyVSink
Expand Down Expand Up @@ -496,10 +496,10 @@ func getElemType(t types.Type, x string, imports map[string]string) string {
return kind
}

func hasDeepCopy(v methoder, generating []object, pointer bool) (hasMethod, isPointer bool) {
func hasDeepCopy(v methoder, generating []object, isPtrRecv bool) (hasMethod, isPointer bool) {
for _, t := range generating {
if types.Identical(v, t) {
return true, pointer
return true, isPtrRecv
}
}

Expand Down Expand Up @@ -532,8 +532,8 @@ func hasDeepCopy(v methoder, generating []object, pointer bool) (hasMethod, isPo
return false, false
}

func reuseDeepCopy(source, sink string, v methoder, pointer bool, generating []object, w io.Writer) bool {
hasMethod, isPointer := hasDeepCopy(v, generating, pointer)
func reuseDeepCopy(source, sink string, v methoder, pointer bool, generating []object, w io.Writer, isPtrRecv bool) bool {
hasMethod, isPointer := hasDeepCopy(v, generating, isPtrRecv)

if hasMethod {
if pointer == isPointer {
Expand Down
80 changes: 80 additions & 0 deletions main_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,10 @@ func Test_run(t *testing.T) {
{name: "issue 10, slice with element that contains pointer and value", types: typesVal{"StructCH"}, path: "./testdata", want: []byte(Issue10StructCH)},
{name: "issue 12, nested slices", types: typesVal{"I12NestedSlices"}, path: "./testdata", want: []byte(Issue12NestedSlices)},
{name: "issue 12, map with slice value", types: typesVal{"I12StructWithMapOfSlices"}, path: "./testdata", want: []byte(Issue12MapWithSliceValues)},
{name: "issue 15, parent has child value, value receiver", types: typesVal{"ParentHasChildValue", "Child"}, path: "./testdata", want: []byte(I15ParentHasChildValueValueRecv)},
{name: "issue 15, parent has child pointer, value receiver", types: typesVal{"ParentHasChildPointer", "Child"}, path: "./testdata", want: []byte(I15ParentHasChildPointerValueRecv)},
{name: "issue 15, parent has child value, pointer receiver", pointer: true, types: typesVal{"ParentHasChildValue", "Child"}, path: "./testdata", want: []byte(I15ParentHasChildValuePointerRecv)},
{name: "issue 15, parent has child pointer, pointer receiver", pointer: true, types: typesVal{"ParentHasChildPointer", "Child"}, path: "./testdata", want: []byte(I15ParentHasChildPointerPointerRecv)},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
Expand Down Expand Up @@ -422,4 +426,80 @@ func (o I12StructWithMapOfSlices) DeepCopy() I12StructWithMapOfSlices {
}
return cp
}`

I15ParentHasChildValueValueRecv = `// generated by deep-copy; DO NOT EDIT.
package testdata
// DeepCopy generates a deep copy of ParentHasChildValue
func (o ParentHasChildValue) DeepCopy() ParentHasChildValue {
var cp ParentHasChildValue = o
cp.c = o.c.DeepCopy()
return cp
}
// DeepCopy generates a deep copy of Child
func (o Child) DeepCopy() Child {
var cp Child = o
return cp
}`

I15ParentHasChildPointerValueRecv = `// generated by deep-copy; DO NOT EDIT.
package testdata
// DeepCopy generates a deep copy of ParentHasChildPointer
func (o ParentHasChildPointer) DeepCopy() ParentHasChildPointer {
var cp ParentHasChildPointer = o
if o.c != nil {
retV := o.c.DeepCopy()
cp.c = &retV
}
return cp
}
// DeepCopy generates a deep copy of Child
func (o Child) DeepCopy() Child {
var cp Child = o
return cp
}`

I15ParentHasChildValuePointerRecv = `// generated by deep-copy; DO NOT EDIT.
package testdata
// DeepCopy generates a deep copy of *ParentHasChildValue
func (o *ParentHasChildValue) DeepCopy() *ParentHasChildValue {
var cp ParentHasChildValue = *o
{
retV := o.c.DeepCopy()
cp.c = *retV
}
return &cp
}
// DeepCopy generates a deep copy of *Child
func (o *Child) DeepCopy() *Child {
var cp Child = *o
return &cp
}`

I15ParentHasChildPointerPointerRecv = `// generated by deep-copy; DO NOT EDIT.
package testdata
// DeepCopy generates a deep copy of *ParentHasChildPointer
func (o *ParentHasChildPointer) DeepCopy() *ParentHasChildPointer {
var cp ParentHasChildPointer = *o
if o.c != nil {
cp.c = o.c.DeepCopy()
}
return &cp
}
// DeepCopy generates a deep copy of *Child
func (o *Child) DeepCopy() *Child {
var cp Child = *o
return &cp
}`
)
13 changes: 13 additions & 0 deletions testdata/issue_15_pointer_receiver.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
package testdata

type ParentHasChildValue struct {
c Child
}

type ParentHasChildPointer struct {
c *Child
}

type Child struct {
s string
}

0 comments on commit 02bcfef

Please sign in to comment.