diff --git a/go.mod b/go.mod index 2689d0e3091..003acc90a4e 100644 --- a/go.mod +++ b/go.mod @@ -19,6 +19,7 @@ require ( github.com/google/go-cmp v0.5.9 github.com/gorilla/mux v1.8.0 github.com/hashicorp/go-multierror v1.1.1 + github.com/huandu/go-clone v1.4.0 github.com/intel-go/cpuid v0.0.0-20220614022739-219e067757cb github.com/lima-vm/sshocker v0.3.0 github.com/mattn/go-isatty v0.0.16 diff --git a/go.sum b/go.sum index 28510db5e2d..e32e26f2e03 100644 --- a/go.sum +++ b/go.sum @@ -115,6 +115,10 @@ github.com/hashicorp/go-multierror v1.1.1 h1:H5DkEtf6CXdFp0N0Em5UCwQpXMWke8IA0+l github.com/hashicorp/go-multierror v1.1.1/go.mod h1:iw975J/qwKPdAO1clOe2L8331t/9/fmwbPZ6JB6eMoM= github.com/hinshun/vt10x v0.0.0-20220119200601-820417d04eec h1:qv2VnGeEQHchGaZ/u7lxST/RaJw+cv273q79D81Xbog= github.com/hinshun/vt10x v0.0.0-20220119200601-820417d04eec/go.mod h1:Q48J4R4DvxnHolD5P8pOtXigYlRuPLGl6moFx3ulM68= +github.com/huandu/go-assert v1.1.5 h1:fjemmA7sSfYHJD7CUqs9qTwwfdNAx7/j2/ZlHXzNB3c= +github.com/huandu/go-assert v1.1.5/go.mod h1:yOLvuqZwmcHIC5rIzrBhT7D3Q9c3GFnd0JrPVhn/06U= +github.com/huandu/go-clone v1.4.0 h1:NlnghW4lsmMoz+3N4yb4Ouff86ArRYPo/1aCsqQKKF4= +github.com/huandu/go-clone v1.4.0/go.mod h1:ReGivhG6op3GYr+UY3lS6mxjKp7MIGTknuU5TbTVaXE= github.com/inconshreveable/mousetrap v1.0.0 h1:Z8tu5sraLXCXIcARxBp/8cbvlwVa7Z1NHg9XEKhtSvM= github.com/inconshreveable/mousetrap v1.0.0/go.mod h1:PxqpIevigyE2G7u3NXJIT2ANytuPF1OarO4DADm73n8= github.com/intel-go/cpuid v0.0.0-20220614022739-219e067757cb h1:Fg0Y/RDZ6UPwl3o7/IzPbneDq8g9+gH6DPs42KFUsy8= diff --git a/pkg/limayaml/defaults.go b/pkg/limayaml/defaults.go index bfb8054a8e6..77b95c4d5c4 100644 --- a/pkg/limayaml/defaults.go +++ b/pkg/limayaml/defaults.go @@ -14,6 +14,7 @@ import ( "github.com/lima-vm/lima/pkg/guestagent/api" "github.com/lima-vm/lima/pkg/osutil" + "github.com/lima-vm/lima/pkg/reflectutil" "github.com/lima-vm/lima/pkg/store/dirnames" "github.com/lima-vm/lima/pkg/store/filenames" "github.com/sirupsen/logrus" @@ -61,6 +62,54 @@ func MACAddress(uniqueID string) string { return hw.String() } +// builtinDefault defines the built-in default values. +var builtinDefault = &LimaYAML{ + Arch: nil, // Resolved in FillDefault() + Images: nil, + CPUType: defaultCPUType(), + CPUs: pointer.Int(4), + Memory: pointer.String("4GiB"), + Disk: pointer.String("100GiB"), + Mounts: nil, + MountType: pointer.String(REVSSHFS), + Video: Video{ + Display: pointer.String("none"), + }, + Firmware: Firmware{ + LegacyBIOS: pointer.Bool(false), + }, + SSH: SSH{ + LocalPort: pointer.Int(0), // Resolved by the hostagent + LoadDotSSHPubKeys: pointer.Bool(true), + ForwardAgent: pointer.Bool(false), + ForwardX11: pointer.Bool(false), + ForwardX11Trusted: pointer.Bool(false), + }, + Provision: nil, + Containerd: Containerd{ + System: pointer.Bool(false), + User: pointer.Bool(true), + Archives: defaultContainerdArchives(), + }, + Probes: nil, + PortForwards: nil, + Message: "", + Networks: nil, + Env: nil, + DNS: nil, + HostResolver: HostResolver{ + Enabled: pointer.Bool(true), + IPv6: pointer.Bool(false), + Hosts: nil, + }, + PropagateProxyEnv: pointer.Bool(true), + CACertificates: CACertificates{ + RemoveDefaults: pointer.Bool(false), + Files: nil, + Certs: nil, + }, +} + // FillDefault updates undefined fields in y with defaults from d (or built-in default), and overwrites with values from o. // Both d and o may be empty. // @@ -76,318 +125,24 @@ func MACAddress(uniqueID string) string { // - DNS are picked from the highest priority where DNS is not empty. // - CACertificates Files and Certs are uniquely appended in d, y, o order func FillDefault(y, d, o *LimaYAML, filePath string) { - if y.Arch == nil { - y.Arch = d.Arch - } - if o.Arch != nil { - y.Arch = o.Arch - } - y.Arch = pointer.String(ResolveArch(y.Arch)) - - y.Images = append(append(o.Images, y.Images...), d.Images...) - for i := range y.Images { - img := &y.Images[i] - if img.Arch == "" { - img.Arch = *y.Arch - } - if img.Kernel != nil && img.Kernel.Arch == "" { - img.Kernel.Arch = img.Arch - } - if img.Initrd != nil && img.Initrd.Arch == "" { - img.Initrd.Arch = img.Arch - } - } - - cpuType := map[Arch]string{ - AARCH64: "cortex-a72", - // Since https://github.com/lima-vm/lima/pull/494, we use qemu64 cpu for better emulation of x86_64. - X8664: "qemu64", - RISCV64: "rv64", // FIXME: what is the right choice for riscv64? - } - for arch := range cpuType { - if IsNativeArch(arch) && IsAccelOS() { - if HasHostCPU() { - cpuType[arch] = "host" - } else if HasMaxCPU() { - cpuType[arch] = "max" - } - } - } - for k, v := range d.CPUType { - if len(v) > 0 { - cpuType[k] = v - } - } - for k, v := range y.CPUType { - if len(v) > 0 { - cpuType[k] = v - } - } - for k, v := range o.CPUType { - if len(v) > 0 { - cpuType[k] = v - } - } - y.CPUType = cpuType - - if y.CPUs == nil { - y.CPUs = d.CPUs - } - if o.CPUs != nil { - y.CPUs = o.CPUs - } - if y.CPUs == nil || *y.CPUs == 0 { - y.CPUs = pointer.Int(4) - } - - if y.Memory == nil { - y.Memory = d.Memory - } - if o.Memory != nil { - y.Memory = o.Memory - } - if y.Memory == nil || *y.Memory == "" { - y.Memory = pointer.String("4GiB") - } - - if y.Disk == nil { - y.Disk = d.Disk - } - if o.Disk != nil { - y.Disk = o.Disk - } - if y.Disk == nil || *y.Disk == "" { - y.Disk = pointer.String("100GiB") - } - - if y.Video.Display == nil { - y.Video.Display = d.Video.Display - } - if o.Video.Display != nil { - y.Video.Display = o.Video.Display - } - if y.Video.Display == nil || *y.Video.Display == "" { - y.Video.Display = pointer.String("none") - } - - if y.Firmware.LegacyBIOS == nil { - y.Firmware.LegacyBIOS = d.Firmware.LegacyBIOS - } - if o.Firmware.LegacyBIOS != nil { - y.Firmware.LegacyBIOS = o.Firmware.LegacyBIOS - } - if y.Firmware.LegacyBIOS == nil { - y.Firmware.LegacyBIOS = pointer.Bool(false) - } - - if y.SSH.LocalPort == nil { - y.SSH.LocalPort = d.SSH.LocalPort - } - if o.SSH.LocalPort != nil { - y.SSH.LocalPort = o.SSH.LocalPort - } - if y.SSH.LocalPort == nil { - // y.SSH.LocalPort value is not filled here (filled by the hostagent) - y.SSH.LocalPort = pointer.Int(0) - } - if y.SSH.LoadDotSSHPubKeys == nil { - y.SSH.LoadDotSSHPubKeys = d.SSH.LoadDotSSHPubKeys - } - if o.SSH.LoadDotSSHPubKeys != nil { - y.SSH.LoadDotSSHPubKeys = o.SSH.LoadDotSSHPubKeys - } - if y.SSH.LoadDotSSHPubKeys == nil { - y.SSH.LoadDotSSHPubKeys = pointer.Bool(true) - } - - if y.SSH.ForwardAgent == nil { - y.SSH.ForwardAgent = d.SSH.ForwardAgent - } - if o.SSH.ForwardAgent != nil { - y.SSH.ForwardAgent = o.SSH.ForwardAgent - } - if y.SSH.ForwardAgent == nil { - y.SSH.ForwardAgent = pointer.Bool(false) - } - - if y.SSH.ForwardX11 == nil { - y.SSH.ForwardX11 = d.SSH.ForwardX11 - } - if o.SSH.ForwardX11 != nil { - y.SSH.ForwardX11 = o.SSH.ForwardX11 - } - if y.SSH.ForwardX11 == nil { - y.SSH.ForwardX11 = pointer.Bool(false) - } - - if y.SSH.ForwardX11Trusted == nil { - y.SSH.ForwardX11Trusted = d.SSH.ForwardX11Trusted - } - if o.SSH.ForwardX11Trusted != nil { - y.SSH.ForwardX11Trusted = o.SSH.ForwardX11Trusted - } - if y.SSH.ForwardX11Trusted == nil { - y.SSH.ForwardX11Trusted = pointer.Bool(false) - } - - hosts := make(map[string]string) - // Values can be either names or IP addresses. Name values are canonicalized in the hostResolver. - for k, v := range d.HostResolver.Hosts { - hosts[Cname(k)] = v - } - for k, v := range y.HostResolver.Hosts { - hosts[Cname(k)] = v - } - for k, v := range o.HostResolver.Hosts { - hosts[Cname(k)] = v - } - y.HostResolver.Hosts = hosts - - y.Provision = append(append(o.Provision, y.Provision...), d.Provision...) - for i := range y.Provision { - provision := &y.Provision[i] - if provision.Mode == "" { - provision.Mode = ProvisionModeSystem - } - } - - if y.Containerd.System == nil { - y.Containerd.System = d.Containerd.System - } - if o.Containerd.System != nil { - y.Containerd.System = o.Containerd.System - } - if y.Containerd.System == nil { - y.Containerd.System = pointer.Bool(false) - } - if y.Containerd.User == nil { - y.Containerd.User = d.Containerd.User - } - if o.Containerd.User != nil { - y.Containerd.User = o.Containerd.User - } - if y.Containerd.User == nil { - y.Containerd.User = pointer.Bool(true) - } - - y.Containerd.Archives = append(append(o.Containerd.Archives, y.Containerd.Archives...), d.Containerd.Archives...) - if len(y.Containerd.Archives) == 0 { - y.Containerd.Archives = defaultContainerdArchives() - } - for i := range y.Containerd.Archives { - f := &y.Containerd.Archives[i] - if f.Arch == "" { - f.Arch = *y.Arch - } - } - - y.Probes = append(append(o.Probes, y.Probes...), d.Probes...) - for i := range y.Probes { - probe := &y.Probes[i] - if probe.Mode == "" { - probe.Mode = ProbeModeReadiness - } - if probe.Description == "" { - probe.Description = fmt.Sprintf("user probe %d/%d", i+1, len(y.Probes)) - } - } - - y.PortForwards = append(append(o.PortForwards, y.PortForwards...), d.PortForwards...) instDir := filepath.Dir(filePath) - for i := range y.PortForwards { - FillPortForwardDefaults(&y.PortForwards[i], instDir) - // After defaults processing the singular HostPort and GuestPort values should not be used again. - } + bd := builtinDefault - if y.HostResolver.Enabled == nil { - y.HostResolver.Enabled = d.HostResolver.Enabled - } - if o.HostResolver.Enabled != nil { - y.HostResolver.Enabled = o.HostResolver.Enabled - } - if y.HostResolver.Enabled == nil { - y.HostResolver.Enabled = pointer.Bool(true) + // *EXCEPTION*: Remove built-in containerd archives when the custom values are specified. + if len(d.Containerd.Archives)+len(y.Containerd.Archives)+len(o.Containerd.Archives) > 0 { + bd.Containerd.Archives = nil } - if y.HostResolver.IPv6 == nil { - y.HostResolver.IPv6 = d.HostResolver.IPv6 - } - if o.HostResolver.IPv6 != nil { - y.HostResolver.IPv6 = o.HostResolver.IPv6 - } - if y.HostResolver.IPv6 == nil { - y.HostResolver.IPv6 = pointer.Bool(false) - } - - if y.PropagateProxyEnv == nil { - y.PropagateProxyEnv = d.PropagateProxyEnv - } - if o.PropagateProxyEnv != nil { - y.PropagateProxyEnv = o.PropagateProxyEnv - } - if y.PropagateProxyEnv == nil { - y.PropagateProxyEnv = pointer.Bool(true) + // Merge bd, d, y, and o, into x. + // y is not altered yet, and is used later for exceptional rules. + xx, err := reflectutil.MergeMany(bd, d, y, o) + if err != nil { + panic(err) } + x := xx.(*LimaYAML) - networks := make([]Network, 0, len(d.Networks)+len(y.Networks)+len(o.Networks)) - iface := make(map[string]int) - for _, nw := range append(append(d.Networks, y.Networks...), o.Networks...) { - if i, ok := iface[nw.Interface]; ok { - if nw.VNLDeprecated != "" { - networks[i].VNLDeprecated = nw.VNLDeprecated - networks[i].SwitchPortDeprecated = nw.SwitchPortDeprecated - networks[i].Socket = "" - networks[i].Lima = "" - } - if nw.Socket != "" { - if nw.VNLDeprecated != "" { - // We can't return an error, so just log it, and prefer `socket` over `vnl` - logrus.Errorf("Network %q has both vnl=%q and socket=%q fields; ignoring vnl", - nw.Interface, nw.VNLDeprecated, nw.Socket) - } - networks[i].Socket = nw.Socket - networks[i].VNLDeprecated = "" - networks[i].SwitchPortDeprecated = 0 - networks[i].Lima = "" - } - if nw.Lima != "" { - if nw.VNLDeprecated != "" { - // We can't return an error, so just log it, and prefer `lima` over `vnl` - logrus.Errorf("Network %q has both vnl=%q and lima=%q fields; ignoring vnl", - nw.Interface, nw.VNLDeprecated, nw.Lima) - } - if nw.Socket != "" { - // We can't return an error, so just log it, and prefer `lima` over `socket` - logrus.Errorf("Network %q has both socket=%q and lima=%q fields; ignoring socket", - nw.Interface, nw.Socket, nw.Lima) - } - networks[i].Lima = nw.Lima - networks[i].Socket = "" - networks[i].VNLDeprecated = "" - networks[i].SwitchPortDeprecated = 0 - } - if nw.MACAddress != "" { - networks[i].MACAddress = nw.MACAddress - } - } else { - // unnamed network definitions are not combined/overwritten - if nw.Interface != "" { - iface[nw.Interface] = len(networks) - } - networks = append(networks, nw) - } - } - y.Networks = networks - for i := range y.Networks { - nw := &y.Networks[i] - if nw.MACAddress == "" { - // every interface in every limayaml file must get its own unique MAC address - nw.MACAddress = MACAddress(fmt.Sprintf("%s#%d", filePath, i)) - } - if nw.Interface == "" { - nw.Interface = "lima" + strconv.Itoa(i) - } - } + // *EXCEPTION*: Mounts are appended in d, y, o order, but "merged" when the Location matches a previous entry; + // the highest priority Writable setting wins. // Combine all mounts; highest priority entry determines writable status. // Only works for exact matches; does not normalize case or resolve symlinks. @@ -427,10 +182,10 @@ func FillDefault(y, d, o *LimaYAML, filePath string) { mounts = append(mounts, mount) } } - y.Mounts = mounts + x.Mounts = mounts - for i := range y.Mounts { - mount := &y.Mounts[i] + for i := range x.Mounts { + mount := &x.Mounts[i] if mount.SSHFS.Cache == nil { mount.SSHFS.Cache = pointer.Bool(true) } @@ -464,51 +219,143 @@ func FillDefault(y, d, o *LimaYAML, filePath string) { } } - if y.MountType == nil { - y.MountType = d.MountType - } - if o.MountType != nil { - y.MountType = o.MountType + // *EXCEPTION*: Networks are appended in d, y, o order + networks := make([]Network, 0, len(d.Networks)+len(y.Networks)+len(o.Networks)) + iface := make(map[string]int) + for _, nw := range append(append(d.Networks, y.Networks...), o.Networks...) { + if i, ok := iface[nw.Interface]; ok { + if nw.VNLDeprecated != "" { + networks[i].VNLDeprecated = nw.VNLDeprecated + networks[i].SwitchPortDeprecated = nw.SwitchPortDeprecated + networks[i].Socket = "" + networks[i].Lima = "" + } + if nw.Socket != "" { + if nw.VNLDeprecated != "" { + // We can't return an error, so just log it, and prefer `socket` over `vnl` + logrus.Errorf("Network %q has both vnl=%q and socket=%q fields; ignoring vnl", + nw.Interface, nw.VNLDeprecated, nw.Socket) + } + networks[i].Socket = nw.Socket + networks[i].VNLDeprecated = "" + networks[i].SwitchPortDeprecated = 0 + networks[i].Lima = "" + } + if nw.Lima != "" { + if nw.VNLDeprecated != "" { + // We can't return an error, so just log it, and prefer `lima` over `vnl` + logrus.Errorf("Network %q has both vnl=%q and lima=%q fields; ignoring vnl", + nw.Interface, nw.VNLDeprecated, nw.Lima) + } + if nw.Socket != "" { + // We can't return an error, so just log it, and prefer `lima` over `socket` + logrus.Errorf("Network %q has both socket=%q and lima=%q fields; ignoring socket", + nw.Interface, nw.Socket, nw.Lima) + } + networks[i].Lima = nw.Lima + networks[i].Socket = "" + networks[i].VNLDeprecated = "" + networks[i].SwitchPortDeprecated = 0 + } + if nw.MACAddress != "" { + networks[i].MACAddress = nw.MACAddress + } + } else { + // unnamed network definitions are not combined/overwritten + if nw.Interface != "" { + iface[nw.Interface] = len(networks) + } + networks = append(networks, nw) + } } - if y.MountType == nil || *y.MountType == "" { - y.MountType = pointer.String(REVSSHFS) + x.Networks = networks + for i := range x.Networks { + nw := &x.Networks[i] + if nw.MACAddress == "" { + // every interface in every limayaml file must get its own unique MAC address + nw.MACAddress = MACAddress(fmt.Sprintf("%s#%d", filePath, i)) + } + if nw.Interface == "" { + nw.Interface = "lima" + strconv.Itoa(i) + } } + // *EXCEPTION*: DNS are picked from the highest priority where DNS is not empty. // Note: DNS lists are not combined; highest priority setting is picked - if len(y.DNS) == 0 { - y.DNS = d.DNS + dns := y.DNS + if len(dns) == 0 { + dns = d.DNS } if len(o.DNS) > 0 { - y.DNS = o.DNS + dns = o.DNS } + x.DNS = dns - env := make(map[string]string) - for k, v := range d.Env { - env[k] = v - } - for k, v := range y.Env { - env[k] = v + // *EXCEPTION*: CACertificates Files and Certs are uniquely appended in d, y, o order + x.CACertificates.Files = unique(append(append(d.CACertificates.Files, y.CACertificates.Files...), o.CACertificates.Files...)) + x.CACertificates.Certs = unique(append(append(d.CACertificates.Certs, y.CACertificates.Certs...), o.CACertificates.Certs...)) + + // Fix up other fields + fixUp(x, instDir) + + // Return the result x as y + *y = *x +} + +func fixUp(x *LimaYAML, instDir string) { + // Resolve the default arch + x.Arch = pointer.String(ResolveArch(x.Arch)) + for i := range x.Images { + img := &x.Images[i] + if img.Arch == "" { + img.Arch = *x.Arch + } + if img.Kernel != nil && img.Kernel.Arch == "" { + img.Kernel.Arch = img.Arch + } + if img.Initrd != nil && img.Initrd.Arch == "" { + img.Initrd.Arch = img.Arch + } } - for k, v := range o.Env { - env[k] = v + for i := range x.Containerd.Archives { + f := &x.Containerd.Archives[i] + if f.Arch == "" { + f.Arch = *x.Arch + } } - y.Env = env - if y.CACertificates.RemoveDefaults == nil { - y.CACertificates.RemoveDefaults = d.CACertificates.RemoveDefaults - } - if o.CACertificates.RemoveDefaults != nil { - y.CACertificates.RemoveDefaults = o.CACertificates.RemoveDefaults + // Resolve the default provision mode + for i := range x.Provision { + provision := &x.Provision[i] + if provision.Mode == "" { + provision.Mode = ProvisionModeSystem + } } - if y.CACertificates.RemoveDefaults == nil { - y.CACertificates.RemoveDefaults = pointer.Bool(false) + + // Resolve the default probe mode + for i := range x.Probes { + probe := &x.Probes[i] + if probe.Mode == "" { + probe.Mode = ProbeModeReadiness + } + if probe.Description == "" { + probe.Description = fmt.Sprintf("user probe %d/%d", i+1, len(x.Probes)) + } } - caFiles := unique(append(append(d.CACertificates.Files, y.CACertificates.Files...), o.CACertificates.Files...)) - y.CACertificates.Files = caFiles + // Fill port forward defaults + for i := range x.PortForwards { + FillPortForwardDefaults(&x.PortForwards[i], instDir) + // After defaults processing the singular HostPort and GuestPort values should not be used again. + } - caCerts := unique(append(append(d.CACertificates.Certs, y.CACertificates.Certs...), o.CACertificates.Certs...)) - y.CACertificates.Certs = caCerts + // Fix up the host resolver. + // Values can be either names or IP addresses. Name values are canonicalized in the hostResolver. + hosts := make(map[string]string) + for k, v := range x.HostResolver.Hosts { + hosts[Cname(k)] = v + } + x.HostResolver.Hosts = hosts } func FillPortForwardDefaults(rule *PortForward, instDir string) { @@ -661,3 +508,22 @@ func unique(s []string) []string { } return list } + +func defaultCPUType() map[Arch]string { + cpuType := map[Arch]string{ + AARCH64: "cortex-a72", + // Since https://github.com/lima-vm/lima/pull/494, we use qemu64 cpu for better emulation of x86_64. + X8664: "qemu64", + RISCV64: "rv64", + } + for arch := range cpuType { + if IsNativeArch(arch) && IsAccelOS() { + if HasHostCPU() { + cpuType[arch] = "host" + } else if HasMaxCPU() { + cpuType[arch] = "max" + } + } + } + return cpuType +} diff --git a/pkg/reflectutil/reflectutil.go b/pkg/reflectutil/reflectutil.go new file mode 100644 index 00000000000..f72bfedbb52 --- /dev/null +++ b/pkg/reflectutil/reflectutil.go @@ -0,0 +1,99 @@ +package reflectutil + +import ( + "fmt" + "net" + "reflect" + + "github.com/huandu/go-clone" +) + +// NonAppendableSliceTypes are non-appendable slice types, such as net.IP . +var NonAppendableSliceTypes = map[reflect.Type]struct{}{ + reflect.TypeOf(net.IP{}): struct{}{}, // []byte +} + +// MergeMany merges vv and returns the new value. +// MergeMany does not alter vv. +// +// Maps are merged in "vv[0], vv[1], .., vv[N-1]" order. +// Slices are appended in the "vv[N-1], .., vv[1], vv[0]" order. +func MergeMany(vv ...interface{}) (interface{}, error) { + if l := len(vv); l < 2 { + return nil, fmt.Errorf("expected len(vv) >= 2, got %d", l) + } + x := vv[0] + for _, v := range vv[1:] { + var err error + x, err = Merge(x, v) + if err != nil { + return x, err + } + } + return x, nil +} + +// Merge merges o (override) into d (default) and returns the new value. +// Merge does not alter o and d. +// +// Maps are merged in the "d, o" order. +// Slices are appended in the "o, d" order. +func Merge(d, o interface{}) (interface{}, error) { + if o == nil { + return d, nil + } + dVal, oVal := reflect.ValueOf(d), reflect.ValueOf(o) + if dVal.Type() != oVal.Type() { + return nil, fmt.Errorf("type mismatch: %T vs %T", d, o) + } + x := clone.Clone(d) + xVal := reflect.ValueOf(x) + merge(xVal, oVal) + return x, nil +} + +func merge(xVal, oVal reflect.Value) { + switch k := xVal.Kind(); k { + case reflect.Pointer: + if !oVal.IsNil() { + if xVal.IsNil() { + xVal.Set(cloneVal(oVal)) + } else { + merge(xVal.Elem(), oVal.Elem()) + } + } + case reflect.Struct: + numField := xVal.NumField() + for i := 0; i < numField; i++ { + merge(xVal.Field(i), oVal.Field(i)) + } + case reflect.Map: + if xVal.IsNil() { + xVal.Set(reflect.MakeMap(oVal.Type())) + } + oValIter := oVal.MapRange() + for oValIter.Next() { + xVal.SetMapIndex(cloneVal(oValIter.Key()), cloneVal(oValIter.Value())) + } + case reflect.Array: + xVal.Set(cloneVal(oVal)) + case reflect.Slice: + if _, ok := NonAppendableSliceTypes[xVal.Type()]; ok { + xVal.Set(cloneVal(oVal)) + } else { + // o comes first + xVal.Set(reflect.AppendSlice(cloneVal(oVal), xVal)) + } + case reflect.Chan, reflect.Func, reflect.Interface: + panic(fmt.Errorf("unexpected kind %+v", k)) + default: + if xVal.CanSet() { + // oVal is not a pointer(-ish), so no need to clone oVal + xVal.Set(oVal) + } + } +} + +func cloneVal(v reflect.Value) reflect.Value { + return reflect.ValueOf(clone.Clone(v.Interface())) +} diff --git a/pkg/reflectutil/reflectutil_test.go b/pkg/reflectutil/reflectutil_test.go new file mode 100644 index 00000000000..fb84acc782e --- /dev/null +++ b/pkg/reflectutil/reflectutil_test.go @@ -0,0 +1,132 @@ +package reflectutil + +import ( + "encoding/json" + "net" + "testing" + + "github.com/xorcare/pointer" + "gotest.tools/v3/assert" +) + +type Foo struct { + BoolPtr, BoolPtr2, BoolPtr3, BoolPtr4 *bool + IntPtr, IntPtr2, IntPtr3, IntPtr4 *int + StrPtr, StrPtr2 *string + StrStrMap, StrStrMap2 map[string]string + StrSlice []string + StrArray [2]string + Struct FooChild +} + +type FooChild struct { + Int, Int2 int + Str, Str2 string + IP net.IP + IPSlice []net.IP +} + +func TestMerge(t *testing.T) { + d := &Foo{ + BoolPtr: pointer.Bool(true), + BoolPtr2: pointer.Bool(false), + BoolPtr3: nil, + BoolPtr4: pointer.Bool(true), + IntPtr: pointer.Int(42), + IntPtr2: pointer.Int(200), + IntPtr3: nil, + IntPtr4: pointer.Int(400), + StrPtr: pointer.String("hello"), + StrPtr2: pointer.String("world"), + StrStrMap: map[string]string{ + "a": "apple", + "b": "banana", + }, + StrStrMap2: nil, + StrSlice: []string{"alpha", "beta"}, + StrArray: [2]string{"alabama", "alaska"}, + Struct: FooChild{ + Int: -42, + Int2: -100, + Str: "bonjour", + Str2: "le monde", + IP: net.ParseIP("192.168.10.1"), + IPSlice: []net.IP{net.ParseIP("10.0.0.1"), net.ParseIP("10.0.0.2")}, + }, + } + + o := &Foo{ + BoolPtr: pointer.Bool(false), + BoolPtr2: pointer.Bool(true), + BoolPtr3: pointer.Bool(true), + BoolPtr4: nil, + IntPtr: pointer.Int(43), + IntPtr2: nil, + IntPtr3: pointer.Int(300), + IntPtr4: pointer.Int(0), + StrPtr: pointer.String("olleh"), + StrPtr2: nil, + StrStrMap: map[string]string{ + "b": "beer", + "c": "cider", + }, + StrSlice: []string{"gamma", "delta"}, + StrStrMap2: map[string]string{ + "a": "america", + "b": "brazil", + }, + StrArray: [2]string{"california", "colorado"}, + Struct: FooChild{ + Int: -43, + Str: "ruojnob", + IP: net.ParseIP("192.168.11.1"), + IPSlice: []net.IP{net.ParseIP("10.0.0.3")}, + }, + } + + expected := &Foo{ + BoolPtr: pointer.Bool(false), // overridden + BoolPtr2: pointer.Bool(true), // overridden + BoolPtr3: pointer.Bool(true), // overridden (d=nil) + BoolPtr4: pointer.Bool(true), // Not overridden (o=nil) + IntPtr: pointer.Int(43), // overridden + IntPtr2: pointer.Int(200), // Not overridden (o=nil) + IntPtr3: pointer.Int(300), // overridden (d=nil) + IntPtr4: pointer.Int(0), // overridden + StrPtr: pointer.String("olleh"), // overridden + StrPtr2: pointer.String("world"), // Not overridden (o=nil) + StrStrMap: map[string]string{ // merged (d, o) + "a": "apple", + "b": "beer", + "c": "cider", + }, + StrStrMap2: map[string]string{ // merged (d=nil, o) + "a": "america", + "b": "brazil", + }, + StrSlice: []string{"gamma", "delta", "alpha", "beta"}, // appended (o, d) + StrArray: [2]string{"california", "colorado"}, // overridden + Struct: FooChild{ + Int: -43, // overridden + Int2: 0, // overridden (o=zero) + Str: "ruojnob", // overridden + Str2: "", // overridden (o=empty) + IP: net.ParseIP("192.168.11.1"), // overridden + IPSlice: []net.IP{net.ParseIP("10.0.0.3"), net.ParseIP("10.0.0.1"), net.ParseIP("10.0.0.2")}, // appended (o, d) + }, + } + + x, err := Merge(d, o) + assert.NilError(t, err) + logX(t, d, "d") + logX(t, o, "o") + logX(t, x, "x") + assert.DeepEqual(t, expected, x) +} + +func logX(t testing.TB, x interface{}, format string, args ...interface{}) { + // Print in JSON for human readability + j, err := json.Marshal(x) + assert.NilError(t, err) + t.Logf(format+": %s", append(args, string(j))...) +}