Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Replace internal joined error with errors.Join #25

Merged
merged 3 commits into from
Jun 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 0 additions & 50 deletions error.go

This file was deleted.

11 changes: 11 additions & 0 deletions internal/errors/errors.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
package errors

import "errors"

func New(text string) error {
return errors.New(text)
}

func As(err error, target interface{}) bool {
return errors.As(err, target)
}
9 changes: 9 additions & 0 deletions internal/errors/join.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
//go:build go1.20

package errors

import "errors"

func Join(errs ...error) error {
return errors.Join(errs...)
}
61 changes: 61 additions & 0 deletions internal/errors/join_go1_19.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
//go:build !go1.20

// Copyright 2022 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

package errors

// Join returns an error that wraps the given errors.
// Any nil error values are discarded.
// Join returns nil if every value in errs is nil.
// The error formats as the concatenation of the strings obtained
// by calling the Error method of each element of errs, with a newline
// between each string.
//
// A non-nil error returned by Join implements the Unwrap() []error method.
func Join(errs ...error) error {
n := 0
for _, err := range errs {
if err != nil {
n++
}
}
if n == 0 {
return nil
}
e := &joinError{
errs: make([]error, 0, n),
}
for _, err := range errs {
if err != nil {
e.errs = append(e.errs, err)
}
}
return e
}

type joinError struct {
errs []error
}

func (e *joinError) Error() string {
// Since Join returns nil if every value in errs is nil,
// e.errs cannot be empty.
if len(e.errs) == 1 {
return e.errs[0].Error()
}

b := []byte(e.errs[0].Error())
for _, err := range e.errs[1:] {
b = append(b, '\n')
b = append(b, err.Error()...)
}
// At this point, b has at least one byte '\n'.
// return unsafe.String(&b[0], len(b))
return string(b)
}

func (e *joinError) Unwrap() []error {
return e.errs
}
66 changes: 30 additions & 36 deletions mapstructure.go
Original file line number Diff line number Diff line change
Expand Up @@ -160,12 +160,13 @@ package mapstructure

import (
"encoding/json"
"errors"
"fmt"
"reflect"
"sort"
"strconv"
"strings"

"github.com/go-viper/mapstructure/v2/internal/errors"
)

// DecodeHookFunc is the callback function that can be used for
Expand Down Expand Up @@ -414,7 +415,15 @@ func NewDecoder(config *DecoderConfig) (*Decoder, error) {
// Decode decodes the given raw interface to the target pointer specified
// by the configuration.
func (d *Decoder) Decode(input interface{}) error {
return d.decode("", input, reflect.ValueOf(d.config.Result).Elem())
err := d.decode("", input, reflect.ValueOf(d.config.Result).Elem())

// Retain some of the original behavior when multiple errors ocurr
var joinedErr interface{ Unwrap() []error }
if errors.As(err, &joinedErr) {
return fmt.Errorf("decoding failed due to the following error(s):\n\n%w", err)
}

return err
}

// Decodes an unknown data type into a specific reflection value.
Expand Down Expand Up @@ -881,7 +890,7 @@ func (d *Decoder) decodeMapFromMap(name string, dataVal reflect.Value, val refle
valElemType := valType.Elem()

// Accumulate errors
errors := make([]string, 0)
var errs []error

// If the input data is empty, then we just match what the input data is.
if dataVal.Len() == 0 {
Expand All @@ -903,15 +912,15 @@ func (d *Decoder) decodeMapFromMap(name string, dataVal reflect.Value, val refle
// First decode the key into the proper type
currentKey := reflect.Indirect(reflect.New(valKeyType))
if err := d.decode(fieldName, k.Interface(), currentKey); err != nil {
errors = appendErrors(errors, err)
errs = append(errs, err)
continue
}

// Next decode the data into the proper type
v := dataVal.MapIndex(k).Interface()
currentVal := reflect.Indirect(reflect.New(valElemType))
if err := d.decode(fieldName, v, currentVal); err != nil {
errors = appendErrors(errors, err)
errs = append(errs, err)
continue
}

Expand All @@ -921,12 +930,7 @@ func (d *Decoder) decodeMapFromMap(name string, dataVal reflect.Value, val refle
// Set the built up map to the value
val.Set(valMap)

// If we had errors, return those
if len(errors) > 0 {
return &joinedError{errors}
}

return nil
return errors.Join(errs...)
}

func (d *Decoder) decodeMapFromStruct(name string, dataVal reflect.Value, val reflect.Value, valMap reflect.Value) error {
Expand Down Expand Up @@ -1164,7 +1168,7 @@ func (d *Decoder) decodeSlice(name string, data interface{}, val reflect.Value)
}

// Accumulate any errors
errors := make([]string, 0)
var errs []error

for i := 0; i < dataVal.Len(); i++ {
currentData := dataVal.Index(i).Interface()
Expand All @@ -1175,19 +1179,14 @@ func (d *Decoder) decodeSlice(name string, data interface{}, val reflect.Value)

fieldName := name + "[" + strconv.Itoa(i) + "]"
if err := d.decode(fieldName, currentData, currentField); err != nil {
errors = appendErrors(errors, err)
errs = append(errs, err)
}
}

// Finally, set the value to the slice we built up
val.Set(valSlice)

// If there were errors, we return those
if len(errors) > 0 {
return &joinedError{errors}
}

return nil
return errors.Join(errs...)
}

func (d *Decoder) decodeArray(name string, data interface{}, val reflect.Value) error {
Expand Down Expand Up @@ -1233,27 +1232,22 @@ func (d *Decoder) decodeArray(name string, data interface{}, val reflect.Value)
}

// Accumulate any errors
errors := make([]string, 0)
var errs []error

for i := 0; i < dataVal.Len(); i++ {
currentData := dataVal.Index(i).Interface()
currentField := valArray.Index(i)

fieldName := name + "[" + strconv.Itoa(i) + "]"
if err := d.decode(fieldName, currentData, currentField); err != nil {
errors = appendErrors(errors, err)
errs = append(errs, err)
}
}

// Finally, set the value to the array we built up
val.Set(valArray)

// If there were errors, we return those
if len(errors) > 0 {
return &joinedError{errors}
}

return nil
return errors.Join(errs...)
}

func (d *Decoder) decodeStruct(name string, data interface{}, val reflect.Value) error {
Expand Down Expand Up @@ -1315,7 +1309,8 @@ func (d *Decoder) decodeStructFromMap(name string, dataVal, val reflect.Value) e
}

targetValKeysUnused := make(map[interface{}]struct{})
errors := make([]string, 0)

var errs []error

// This slice will keep track of all the structs we'll be decoding.
// There can be more than one struct if there are embedded structs
Expand Down Expand Up @@ -1369,8 +1364,7 @@ func (d *Decoder) decodeStructFromMap(name string, dataVal, val reflect.Value) e

if squash {
if fieldVal.Kind() != reflect.Struct {
errors = appendErrors(errors,
fmt.Errorf("%s: unsupported type for squash: %s", fieldType.Name, fieldVal.Kind()))
errs = append(errs, fmt.Errorf("%s: unsupported type for squash: %s", fieldType.Name, fieldVal.Kind()))
} else {
structs = append(structs, fieldVal)
}
Expand Down Expand Up @@ -1449,7 +1443,7 @@ func (d *Decoder) decodeStructFromMap(name string, dataVal, val reflect.Value) e
}

if err := d.decode(fieldName, rawMapVal.Interface(), fieldValue); err != nil {
errors = appendErrors(errors, err)
errs = append(errs, err)
}
}

Expand All @@ -1464,7 +1458,7 @@ func (d *Decoder) decodeStructFromMap(name string, dataVal, val reflect.Value) e

// Decode it as-if we were just decoding this map onto our map.
if err := d.decodeMap(name, remain, remainField.val); err != nil {
errors = appendErrors(errors, err)
errs = append(errs, err)
}

// Set the map to nil so we have none so that the next check will
Expand All @@ -1480,7 +1474,7 @@ func (d *Decoder) decodeStructFromMap(name string, dataVal, val reflect.Value) e
sort.Strings(keys)

err := fmt.Errorf("'%s' has invalid keys: %s", name, strings.Join(keys, ", "))
errors = appendErrors(errors, err)
errs = append(errs, err)
}

if d.config.ErrorUnset && len(targetValKeysUnused) > 0 {
Expand All @@ -1491,11 +1485,11 @@ func (d *Decoder) decodeStructFromMap(name string, dataVal, val reflect.Value) e
sort.Strings(keys)

err := fmt.Errorf("'%s' has unset fields: %s", name, strings.Join(keys, ", "))
errors = appendErrors(errors, err)
errs = append(errs, err)
}

if len(errors) > 0 {
return &joinedError{errors}
if err := errors.Join(errs...); err != nil {
return err
}

// Add the unused keys to the list of unused keys if we're tracking metadata
Expand Down
12 changes: 6 additions & 6 deletions mapstructure_examples_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -63,13 +63,13 @@ func ExampleDecode_errors() {

fmt.Println(err.Error())
// Output:
// 5 error(s) decoding:
// decoding failed due to the following error(s):
//
// * 'Age' expected type 'int', got unconvertible type 'string', value: 'bad value'
// * 'Emails[0]' expected type 'string', got unconvertible type 'int', value: '1'
// * 'Emails[1]' expected type 'string', got unconvertible type 'int', value: '2'
// * 'Emails[2]' expected type 'string', got unconvertible type 'int', value: '3'
// * 'Name' expected type 'string', got unconvertible type 'int', value: '123'
// 'Name' expected type 'string', got unconvertible type 'int', value: '123'
// 'Age' expected type 'int', got unconvertible type 'string', value: 'bad value'
// 'Emails[0]' expected type 'string', got unconvertible type 'int', value: '1'
// 'Emails[1]' expected type 'string', got unconvertible type 'int', value: '2'
// 'Emails[2]' expected type 'string', got unconvertible type 'int', value: '3'
}

func ExampleDecode_metadata() {
Expand Down
Loading
Loading