Skip to content

Commit

Permalink
feat: add Unmarshalers using generics to replace Decoder.Register
Browse files Browse the repository at this point in the history
This greatly improves the performance of custom unmarshaler functions by removing the need to invoke them through reflect.Value.Call.

benchmark                      old ns/op     new ns/op     delta
BenchmarkDecode/register-8     302           51.7          -82.91%

benchmark                      old allocs     new allocs     delta
BenchmarkDecode/register-8     5              2              -60.00%

benchmark                      old bytes     new bytes     delta
BenchmarkDecode/register-8     80            16            -80.00%

Decoder.Register is now deprecated in favor of Decoder.WithUnmarshalers.

This change requires Go 1.18 or newer.
  • Loading branch information
jszwec committed Feb 14, 2023
1 parent 2b7b86b commit 46f1e77
Show file tree
Hide file tree
Showing 5 changed files with 514 additions and 276 deletions.
44 changes: 14 additions & 30 deletions decode.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,31 +33,16 @@ var (

type decodeFunc func(s string, v reflect.Value) error

func decodeFuncValue(f reflect.Value) decodeFunc {
isIface := f.Type().In(1).Kind() == reflect.Interface

func decodeFuncValue(f func([]byte, any) error) decodeFunc {
return func(s string, v reflect.Value) error {
if isIface && v.Type().Kind() == reflect.Interface && v.IsNil() {
return &UnmarshalTypeError{Value: s, Type: v.Type()}
}

out := f.Call([]reflect.Value{
reflect.ValueOf([]byte(s)),
v,
})
err, _ := out[0].Interface().(error)
return err
return f([]byte(s), v.Interface())
}
}

func decodeFuncValuePtr(f reflect.Value) decodeFunc {
func decodeFuncValuePtr(f func([]byte, any) error) decodeFunc {
return func(s string, v reflect.Value) error {
out := f.Call([]reflect.Value{
reflect.ValueOf([]byte(s)),
v.Addr(),
})
err, _ := out[0].Interface().(error)
return err
v = v.Addr()
return f([]byte(s), v.Interface())
}
}

Expand Down Expand Up @@ -124,7 +109,7 @@ func decodeFieldUnmarshaler(s string, v reflect.Value) error {
return v.Interface().(Unmarshaler).UnmarshalCSV([]byte(s))
}

func decodePtr(typ reflect.Type, funcMap map[reflect.Type]reflect.Value, ifaceFuncs []reflect.Value) (decodeFunc, error) {
func decodePtr(typ reflect.Type, funcMap map[reflect.Type]func([]byte, any) error, ifaceFuncs []ifaceDecodeFunc) (decodeFunc, error) {
next, err := decodeFn(typ.Elem(), funcMap, ifaceFuncs)
if err != nil {
return nil, err
Expand All @@ -138,7 +123,7 @@ func decodePtr(typ reflect.Type, funcMap map[reflect.Type]reflect.Value, ifaceFu
}, nil
}

func decodeInterface(funcMap map[reflect.Type]reflect.Value, ifaceFuncs []reflect.Value) decodeFunc {
func decodeInterface(funcMap map[reflect.Type]func([]byte, any) error, ifaceFuncs []ifaceDecodeFunc) decodeFunc {
return func(s string, v reflect.Value) error {
if v.NumMethod() != 0 {
return &UnmarshalTypeError{
Expand All @@ -163,8 +148,8 @@ func decodeInterface(funcMap map[reflect.Type]reflect.Value, ifaceFuncs []reflec
return decodeFuncValue(f)(s, el)
}
for _, f := range ifaceFuncs {
if typ.AssignableTo(f.Type().In(1)) {
return decodeFuncValue(f)(s, el)
if typ.AssignableTo(f.argType) {
return decodeFuncValue(f.f)(s, el)
}
}
if typ.Implements(csvUnmarshaler) {
Expand Down Expand Up @@ -195,7 +180,7 @@ func decodeBytes(s string, v reflect.Value) error {
return nil
}

func decodeFn(typ reflect.Type, funcMap map[reflect.Type]reflect.Value, ifaceFuncs []reflect.Value) (decodeFunc, error) {
func decodeFn(typ reflect.Type, funcMap map[reflect.Type]func([]byte, any) error, ifaceFuncs []ifaceDecodeFunc) (decodeFunc, error) {
if f, ok := funcMap[typ]; ok {
return decodeFuncValue(f), nil
}
Expand All @@ -204,12 +189,11 @@ func decodeFn(typ reflect.Type, funcMap map[reflect.Type]reflect.Value, ifaceFun
}

for _, f := range ifaceFuncs {
argType := f.Type().In(1)
if typ.AssignableTo(argType) {
return decodeFuncValue(f), nil
if typ.AssignableTo(f.argType) {
return decodeFuncValue(f.f), nil
}
if reflect.PtrTo(typ).AssignableTo(argType) {
return decodeFuncValuePtr(f), nil
if reflect.PtrTo(typ).AssignableTo(f.argType) {
return decodeFuncValuePtr(f.f), nil
}
}

Expand Down
159 changes: 137 additions & 22 deletions decoder.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,8 +54,13 @@ type Decoder struct {
record []string
cache []decField
unused []int
funcMap map[reflect.Type]reflect.Value
ifaceFuncs []reflect.Value
funcMap map[reflect.Type]func([]byte, any) error
ifaceFuncs []ifaceDecodeFunc
}

// ifaceDecodeFunc pairs a custom decode function with the type of its second
// argument. Within this file, values are only created for argument types
// whose kind is reflect.Interface.
type ifaceDecodeFunc struct {
// f is the type-erased decode function; it receives the field's text as
// bytes and the destination value.
f func([]byte, any) error
// argType is the type f was registered for; decodeFn and decodeInterface
// compare candidate types against it with AssignableTo.
argType reflect.Type
}

// NewDecoder returns a new decoder that reads from r.
Expand Down Expand Up @@ -104,22 +109,23 @@ func NewDecoder(r Reader, header ...string) (dec *Decoder, err error) {
// the decoding if record's field is an empty string.
//
// Examples of struct field tags and their meanings:
// // Decode matches this field with "myName" header column.
// Field int `csv:"myName"`
//
// // Decode matches this field with "Field" header column.
// Field int
// // Decode matches this field with "myName" header column.
// Field int `csv:"myName"`
//
// // Decode matches this field with "myName" header column and decoding is not
// // Decode matches this field with "Field" header column.
// Field int
//
// // Decode matches this field with "myName" header column and decoding is not
// // called if record's field is an empty string.
// Field int `csv:"myName,omitempty"`
// Field int `csv:"myName,omitempty"`
//
// // Decode matches this field with "Field" header column and decoding is not
// // Decode matches this field with "Field" header column and decoding is not
// // called if record's field is an empty string.
// Field int `csv:",omitempty"`
// Field int `csv:",omitempty"`
//
// // Decode ignores this field.
// Field int `csv:"-"`
// // Decode ignores this field.
// Field int `csv:"-"`
//
// // Decode treats this field exactly as if it was an embedded field and
// // matches header columns that start with "my_prefix_" to all fields of this
Expand Down Expand Up @@ -238,7 +244,8 @@ func (d *Decoder) Unused() []int {

// Register registers a custom decoding function for a concrete type or interface.
// The argument f must be of type:
// func([]byte, T) error
//
// func([]byte, T) error
//
// T must be a concrete type such as *time.Time, or interface that has at least one
// method.
Expand All @@ -248,12 +255,15 @@ func (d *Decoder) Unused() []int {
// in order they were registered.
//
// Register panics if:
// - f does not match the right signature
// - f is an empty interface
// - f was already registered
// - f does not match the right signature
// - f is an empty interface
// - f was already registered
//
// Register is based on the encoding/json proposal:
// https://github.com/golang/go/issues/5901.
//
// Deprecated: use UnmarshalFunc together with Decoder.WithUnmarshalers instead.
// The benefits are type safety and much better performance.
func (d *Decoder) Register(f interface{}) {
v := reflect.ValueOf(f)
typ := v.Type()
Expand All @@ -271,20 +281,51 @@ func (d *Decoder) Register(f interface{}) {
}

if d.funcMap == nil {
d.funcMap = make(map[reflect.Type]reflect.Value)
d.funcMap = make(map[reflect.Type]func([]byte, any) error)
}

if _, ok := d.funcMap[argType]; ok {
panic("csvutil: func " + typ.String() + " already registered")
}

d.funcMap[argType] = v
isIface := argType.Kind() == reflect.Interface
isArgPtr := v.Type().In(1).Kind() == reflect.Ptr

fn := func(data []byte, in any) error {
dst := reflect.ValueOf(in)

if isIface && !dst.IsValid() {
return &UnmarshalTypeError{Value: string(data), Type: argType}
}

if !isIface && isArgPtr && dst.Kind() != reflect.Pointer {
dst = dst.Addr()
}

out := v.Call([]reflect.Value{
reflect.ValueOf(data),
dst,
})
err, _ := out[0].Interface().(error)
return err
}

d.funcMap[argType] = fn

if argType.Kind() == reflect.Interface {
d.ifaceFuncs = append(d.ifaceFuncs, v)
d.ifaceFuncs = append(d.ifaceFuncs, ifaceDecodeFunc{
f: fn,
argType: argType,
})
}
}

// WithUnmarshalers sets the provided Unmarshalers for the decoder, replacing
// any custom functions the decoder held before.
//
// The decoder keeps its own copy of u's lookup tables. Register writes into
// the decoder's map and appends to its slice, so sharing them with u would
// let a later Register call modify an Unmarshalers value that is documented
// as immutable and may be used by other decoders.
//
// A nil u removes all custom functions from the decoder.
func (d *Decoder) WithUnmarshalers(u *Unmarshalers) {
	if u == nil {
		d.funcMap, d.ifaceFuncs = nil, nil
		return
	}

	funcMap := make(map[reflect.Type]func([]byte, any) error, len(u.funcMap))
	for typ, fn := range u.funcMap {
		funcMap[typ] = fn
	}
	d.funcMap = funcMap

	// The zero-capacity reslice forces append to allocate a fresh backing
	// array, so appending to d.ifaceFuncs never writes into u's spare capacity.
	d.ifaceFuncs = append(u.ifaceFuncs[:0:0], u.ifaceFuncs...)
}

func (d *Decoder) decodeSlice(slice reflect.Value) error {
typ := slice.Type().Elem()
if walkType(typ).Kind() != reflect.Struct {
Expand Down Expand Up @@ -417,9 +458,9 @@ fieldLoop:
}

// wrapDecodeError provides the given error with more context such as:
// - column name (field)
// - line number
// - column within record
// - column name (field)
// - line number
// - column within record
//
// Line and Column info is available only if the used Reader supports 'FieldPos'
// that is available e.g. in csv.Reader (since Go1.17).
Expand Down Expand Up @@ -537,3 +578,77 @@ func indirect(v reflect.Value) reflect.Value {
}
}
}

// Unmarshalers stores custom unmarshal functions. Unmarshalers is immutable.
//
// Values are created with UnmarshalFunc and combined with NewUnmarshalers.
type Unmarshalers struct {
// funcMap maps the argument type of each stored function to its
// type-erased wrapper.
funcMap map[reflect.Type]func([]byte, any) error
// ifaceFuncs lists the functions stored for interface argument types, in
// the order they were added.
ifaceFuncs []ifaceDecodeFunc
}

// NewUnmarshalers merges the provided Unmarshalers into one and returns it.
// If Unmarshalers contain duplicate function signatures, the one that was
// provided first wins.
//
// Nil entries in us are skipped. The inputs are only read; the result owns a
// newly allocated map. Calling NewUnmarshalers without arguments returns an
// empty, non-nil Unmarshalers.
func NewUnmarshalers(us ...*Unmarshalers) *Unmarshalers {
	out := &Unmarshalers{
		funcMap: make(map[reflect.Type]func([]byte, any) error),
	}

	for _, u := range us {
		if u == nil {
			// Tolerate nil so optional sets can be passed without a guard
			// at every call site.
			continue
		}

		for typ, fn := range u.funcMap {
			// First registration wins: never overwrite an existing entry.
			if _, dup := out.funcMap[typ]; !dup {
				out.funcMap[typ] = fn
			}
		}

		// Interface functions are tried in slice order during decoding, so
		// appending keeps earlier arguments ahead of later ones.
		out.ifaceFuncs = append(out.ifaceFuncs, u.ifaceFuncs...)
	}

	return out
}

// UnmarshalFunc stores the provided function in Unmarshalers and returns it.
//
// Type Parameter T must be a concrete type such as *time.Time, or interface
// that has at least one method.
//
// During decoding, fields are matched by the concrete type first. If match is not
// found then Decoder looks if field implements any of the registered interfaces
// in order they were registered.
//
// The stored wrapper returns an *UnmarshalTypeError, carrying the raw data and
// T's type, when it is handed a value that is not a T (including a nil value
// for an interface T); f is not called in that case.
//
// UnmarshalFunc panics if T is an empty interface.
func UnmarshalFunc[T any](f func([]byte, T) error) *Unmarshalers {
	argType := reflect.TypeOf(f).In(1)
	isIface := argType.Kind() == reflect.Interface

	// Reject the empty interface before building anything: it would match
	// every field type.
	if isIface && argType.NumMethod() == 0 {
		panic("csvutil: func argument type must not be an empty interface")
	}

	// fn adapts f to the type-erased signature used by the decoder. A single
	// checked assertion serves both concrete and interface T, so a mismatch is
	// reported as an error rather than a runtime panic.
	fn := func(data []byte, v any) error {
		dst, ok := v.(T)
		if !ok {
			return &UnmarshalTypeError{Value: string(data), Type: argType}
		}
		return f(data, dst)
	}

	u := &Unmarshalers{
		funcMap: map[reflect.Type]func([]byte, any) error{argType: fn},
	}
	if isIface {
		// Interface functions are additionally kept in an ordered list so
		// they can be matched against types that implement T.
		u.ifaceFuncs = []ifaceDecodeFunc{{f: fn, argType: argType}}
	}
	return u
}

0 comments on commit 46f1e77

Please sign in to comment.