Skip to content

Commit

Permalink
reflect: permit DeepCopyOption
Browse files Browse the repository at this point in the history
  • Loading branch information
changkun committed Mar 28, 2022
1 parent 1f2c918 commit 0b7e31d
Show file tree
Hide file tree
Showing 2 changed files with 104 additions and 29 deletions.
4 changes: 3 additions & 1 deletion README.md
Expand Up @@ -40,7 +40,9 @@ Package reflect implements the proposal [go.dev/issue/51520](https://go.dev/issu
// memory representations of the source value but may result in unexpected
// consequences in follow-up usage, the caller should clear these values
// depending on their usage context.
func DeepCopy[T any](src T) (dst T)
//
// To change these predefined behaviors, use provided DeepCopyOption.
func DeepCopy[T any](src T, opts ...DeepCopyOption) (dst T)
```

_Warning_: Not largely tested. Use it with care.
Expand Down
129 changes: 101 additions & 28 deletions deepcopy.go
Expand Up @@ -14,10 +14,52 @@ import (
"reflect"
"strings"
"unsafe"

_ "unsafe" // for go:linkname
)

// DeepCopyOption represents an option to customize deep copied results.
type DeepCopyOption func(opt *copyConfig)

type copyConfig struct {
disallowCopyUnexported bool
disallowCopyCircular bool
disallowCopyBidirectionalChan bool
disallowCopyTypes []reflect.Type
}

// DisallowCopyUnexported returns a DeepCopyOption that disables the behavior
// of copying unexported fields.
func DisallowCopyUnexported() DeepCopyOption {
return func(opt *copyConfig) {
opt.disallowCopyUnexported = true
}
}

// DisallowCopyCircular returns a DeepCopyOption that disables the behavior
// of copying circular structures.
func DisallowCopyCircular() DeepCopyOption {
return func(opt *copyConfig) {
opt.disallowCopyCircular = true
}
}

// DisallowCopyBidirectionalChan returns a DeepCopyOption that disables
// the behavior of producing new channel when a bidirectional channel is copied.
func DisallowCopyBidirectionalChan() DeepCopyOption {
return func(opt *copyConfig) {
opt.disallowCopyBidirectionalChan = true
}
}

// DisallowTypes returns a DeepCopyOption that disallows copying any types
// that are in given values.
func DisallowTypes(val ...any) DeepCopyOption {
return func(opt *copyConfig) {
for i := range val {
opt.disallowCopyTypes = append(opt.disallowCopyTypes, reflect.TypeOf(val[i]))
}
}
}

// DeepCopy copies src to dst recursively.
//
// Two values of identical type are deeply copied if one of the following
Expand Down Expand Up @@ -55,17 +97,33 @@ import (
// memory representations of the source value but may result in unexpected
// consequences in follow-up usage, the caller should clear these values
// depending on their usage context.
func DeepCopy[T any](src T) (dst T) {
//
// To change these predefined behaviors, use provided DeepCopyOption.
func DeepCopy[T any](src T, opts ...DeepCopyOption) (dst T) {
ptrs := map[uintptr]any{}
ret := copyAny(src, ptrs)
conf := &copyConfig{}
for _, opt := range opts {
opt(conf)
}

ret := copyAny(src, ptrs, conf)
if v, ok := ret.(T); ok {
dst = v
return
}
panic(fmt.Sprintf("reflect: internal error: copied value is not typed in %T, got %T", src, ret))
}

func copyAny(src any, ptrs map[uintptr]any) (dst any) {
func copyAny(src any, ptrs map[uintptr]any, copyConf *copyConfig) (dst any) {

if len(copyConf.disallowCopyTypes) != 0 {
for i := range copyConf.disallowCopyTypes {
if reflect.TypeOf(src) == copyConf.disallowCopyTypes[i] {
panic(fmt.Sprintf("reflect: deep copying type %T is disallowed", src))
}
}
}

v := reflect.ValueOf(src)
if !v.IsValid() {
return src
Expand All @@ -77,30 +135,30 @@ func copyAny(src any, ptrs map[uintptr]any) (dst any) {
reflect.Int64, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32,
reflect.Uint64, reflect.Uintptr, reflect.Float32, reflect.Float64,
reflect.Complex64, reflect.Complex128, reflect.Func:
dst = copyPremitive(src, ptrs)
dst = copyPremitive(src, ptrs, copyConf)
case reflect.String:
dst = strings.Clone(src.(string))
case reflect.Slice:
dst = copySlice(src, ptrs)
dst = copySlice(src, ptrs, copyConf)
case reflect.Array:
dst = copyArray(src, ptrs)
dst = copyArray(src, ptrs, copyConf)
case reflect.Map:
dst = copyMap(src, ptrs)
dst = copyMap(src, ptrs, copyConf)
case reflect.Ptr, reflect.UnsafePointer:
dst = copyPointer(src, ptrs)
dst = copyPointer(src, ptrs, copyConf)
case reflect.Struct:
dst = copyStruct(src, ptrs)
dst = copyStruct(src, ptrs, copyConf)
case reflect.Interface:
dst = copyAny(src, ptrs)
dst = copyAny(src, ptrs, copyConf)
case reflect.Chan:
dst = copyChan(src, ptrs)
dst = copyChan(src, ptrs, copyConf)
default:
panic(fmt.Sprintf("reflect: internal error: unknown type %v", v.Kind()))
}
return
}

func copyPremitive(src any, ptr map[uintptr]any) (dst any) {
func copyPremitive(src any, ptr map[uintptr]any, copyConf *copyConfig) (dst any) {
kind := reflect.ValueOf(src).Kind()
switch kind {
case reflect.Array, reflect.Chan, reflect.Interface, reflect.Map, reflect.Ptr, reflect.Slice, reflect.Struct, reflect.UnsafePointer:
Expand All @@ -110,7 +168,7 @@ func copyPremitive(src any, ptr map[uintptr]any) (dst any) {
return
}

func copySlice(x any, ptrs map[uintptr]any) any {
func copySlice(x any, ptrs map[uintptr]any, copyConf *copyConfig) any {
v := reflect.ValueOf(x)
kind := v.Kind()
if kind != reflect.Slice {
Expand All @@ -121,15 +179,15 @@ func copySlice(x any, ptrs map[uintptr]any) any {
t := reflect.TypeOf(x)
dc := reflect.MakeSlice(t, size, size)
for i := 0; i < size; i++ {
iv := reflect.ValueOf(copyAny(v.Index(i).Interface(), ptrs))
iv := reflect.ValueOf(copyAny(v.Index(i).Interface(), ptrs, copyConf))
if iv.IsValid() {
dc.Index(i).Set(iv)
}
}
return dc.Interface()
}

func copyArray(x any, ptrs map[uintptr]any) any {
func copyArray(x any, ptrs map[uintptr]any, copyConf *copyConfig) any {
v := reflect.ValueOf(x)
if v.Kind() != reflect.Array {
panic(fmt.Errorf("reflect: internal error: must be an Array; got %v", v.Kind()))
Expand All @@ -138,13 +196,13 @@ func copyArray(x any, ptrs map[uintptr]any) any {
size := t.Len()
dc := reflect.New(reflect.ArrayOf(size, t.Elem())).Elem()
for i := 0; i < size; i++ {
item := copyAny(v.Index(i).Interface(), ptrs)
item := copyAny(v.Index(i).Interface(), ptrs, copyConf)
dc.Index(i).Set(reflect.ValueOf(item))
}
return dc.Interface()
}

func copyMap(x any, ptrs map[uintptr]any) any {
func copyMap(x any, ptrs map[uintptr]any, copyConf *copyConfig) any {
v := reflect.ValueOf(x)
if v.Kind() != reflect.Map {
panic(fmt.Errorf("reflect: internal error: must be a Map; got %v", v.Kind()))
Expand All @@ -153,27 +211,30 @@ func copyMap(x any, ptrs map[uintptr]any) any {
dc := reflect.MakeMapWithSize(t, v.Len())
iter := v.MapRange()
for iter.Next() {
item := copyAny(iter.Value().Interface(), ptrs)
k := copyAny(iter.Key().Interface(), ptrs)
item := copyAny(iter.Value().Interface(), ptrs, copyConf)
k := copyAny(iter.Key().Interface(), ptrs, copyConf)
dc.SetMapIndex(reflect.ValueOf(k), reflect.ValueOf(item))
}
return dc.Interface()
}

func copyPointer(x any, ptrs map[uintptr]any) any {
func copyPointer(x any, ptrs map[uintptr]any, copyConf *copyConfig) any {
v := reflect.ValueOf(x)
if v.Kind() != reflect.Pointer {
panic(fmt.Errorf("reflect: internal error: must be a Pointer or Ptr; got %v", v.Kind()))
}
addr := uintptr(v.UnsafePointer())
if dc, ok := ptrs[addr]; ok {
if copyConf.disallowCopyCircular {
panic("reflect: deep copy dircular value is disallowed")
}
return dc
}
t := reflect.TypeOf(x)
dc := reflect.New(t.Elem())
ptrs[addr] = dc.Interface()
if !v.IsNil() {
item := copyAny(v.Elem().Interface(), ptrs)
item := copyAny(v.Elem().Interface(), ptrs, copyConf)
iv := reflect.ValueOf(item)
if iv.IsValid() {
dc.Elem().Set(reflect.ValueOf(item))
Expand All @@ -182,21 +243,30 @@ func copyPointer(x any, ptrs map[uintptr]any) any {
return dc.Interface()
}

func copyStruct(x any, ptrs map[uintptr]any) any {
func copyStruct(x any, ptrs map[uintptr]any, copyConf *copyConfig) any {
v := reflect.ValueOf(x)
if v.Kind() != reflect.Struct {
panic(fmt.Errorf("reflect: internal error: must be a Struct; got %v", v.Kind()))
}
t := reflect.TypeOf(x)
dc := reflect.New(t)
for i := 0; i < t.NumField(); i++ {
item := copyAny(valueInterfaceUnsafe(v.Field(i)), ptrs)
setField(dc.Elem().Field(i), reflect.ValueOf(item))
if copyConf.disallowCopyUnexported {
f := t.Field(i)
if f.PkgPath != "" {
continue
}
item := copyAny(v.Field(i).Interface(), ptrs, copyConf)
dc.Elem().Field(i).Set(reflect.ValueOf(item))
} else {
item := copyAny(valueInterfaceUnsafe(v.Field(i)), ptrs, copyConf)
setField(dc.Elem().Field(i), reflect.ValueOf(item))
}
}
return dc.Elem().Interface()
}

func copyChan(x any, ptrs map[uintptr]any) any {
func copyChan(x any, ptrs map[uintptr]any, copyConf *copyConfig) any {
v := reflect.ValueOf(x)
if v.Kind() != reflect.Chan {
panic(fmt.Errorf("reflect: internal error: must be a Chan; got %v", v.Kind()))
Expand All @@ -206,7 +276,10 @@ func copyChan(x any, ptrs map[uintptr]any) any {
var dc any
switch dir {
case reflect.BothDir:
dc = reflect.MakeChan(t, v.Cap()).Interface()
if !copyConf.disallowCopyBidirectionalChan {
dc = reflect.MakeChan(t, v.Cap()).Interface()
}
fallthrough
case reflect.SendDir, reflect.RecvDir:
dc = x
}
Expand Down

0 comments on commit 0b7e31d

Please sign in to comment.