Skip to content

Commit

Permalink
Make registered converters take priority over TextUnmarshaler. (#85)
Browse files Browse the repository at this point in the history
Fixes a bug in #84
  • Loading branch information
kisielk committed Jun 12, 2017
1 parent ca549c5 commit 8b11008
Show file tree
Hide file tree
Showing 4 changed files with 45 additions and 17 deletions.
18 changes: 7 additions & 11 deletions cache.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,25 +18,25 @@ var invalidPath = errors.New("schema: invalid path")
func newCache() *cache {
c := cache{
m: make(map[reflect.Type]*structInfo),
conv: make(map[reflect.Kind]Converter),
regconv: make(map[reflect.Type]Converter),
tag: "schema",
}
for k, v := range converters {
c.conv[k] = v
}
return &c
}

// cache caches meta-data about a struct.
type cache struct {
l sync.RWMutex
m map[reflect.Type]*structInfo
conv map[reflect.Kind]Converter
regconv map[reflect.Type]Converter
tag string
}

// registerConverter registers a converter function for a custom type.
func (c *cache) registerConverter(value interface{}, converterFunc Converter) {
c.regconv[reflect.TypeOf(value)] = converterFunc
}

// parsePath parses a path in dotted notation verifying that it is a valid
// path to a struct field.
//
Expand Down Expand Up @@ -178,7 +178,7 @@ func (c *cache) createField(field reflect.StructField, info *structInfo) {
}
}
if isStruct = ft.Kind() == reflect.Struct; !isStruct {
if conv := c.converter(ft); conv == nil {
if c.converter(ft) == nil && builtinConverters[ft.Kind()] == nil {
// Type is not supported.
return
}
Expand All @@ -196,11 +196,7 @@ func (c *cache) createField(field reflect.StructField, info *structInfo) {

// converter returns the converter for a type.
func (c *cache) converter(t reflect.Type) Converter {
conv := c.regconv[t]
if conv == nil {
conv = c.conv[t.Kind()]
}
return conv
return c.regconv[t]
}

// ----------------------------------------------------------------------------
Expand Down
2 changes: 1 addition & 1 deletion converter.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ var (
)

// Default converters for basic types.
var converters = map[reflect.Kind]Converter{
var builtinConverters = map[reflect.Kind]Converter{
boolType: convertBool,
float32Type: convertFloat32,
float64Type: convertFloat64,
Expand Down
23 changes: 18 additions & 5 deletions decoder.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ func (d *Decoder) IgnoreUnknownKeys(i bool) {

// RegisterConverter registers a converter function for a custom type.
func (d *Decoder) RegisterConverter(value interface{}, converterFunc Converter) {
d.cache.regconv[reflect.TypeOf(value)] = converterFunc
d.cache.registerConverter(value, converterFunc)
}

// Decode decodes a map[string][]string to a struct.
Expand Down Expand Up @@ -196,9 +196,12 @@ func (d *Decoder) decode(v reflect.Value, path string, parts []pathPart, values
// Try to get a converter for the element type.
conv := d.cache.converter(elemT)
if conv == nil {
// As we are not dealing with slice of structs here, we don't need to check if the type
// implements TextUnmarshaler interface
return fmt.Errorf("schema: converter not found for %v", elemT)
conv = builtinConverters[elemT.Kind()]
if conv == nil {
// As we are not dealing with slice of structs here, we don't need to check if the type
// implements TextUnmarshaler interface
return fmt.Errorf("schema: converter not found for %v", elemT)
}
}

for key, value := range values {
Expand Down Expand Up @@ -284,6 +287,16 @@ func (d *Decoder) decode(v reflect.Value, path string, parts []pathPart, values
if d.zeroEmpty {
v.Set(reflect.Zero(t))
}
} else if conv != nil {
if value := conv(val); value.IsValid() {
v.Set(value.Convert(t))
} else {
return ConversionError{
Key: path,
Type: t,
Index: -1,
}
}
} else if m := isTextUnmarshaler(v); m.IsValid {
// If the value implements the encoding.TextUnmarshaler interface
// apply UnmarshalText as the converter
Expand All @@ -295,7 +308,7 @@ func (d *Decoder) decode(v reflect.Value, path string, parts []pathPart, values
Err: err,
}
}
} else if conv != nil {
} else if conv := builtinConverters[t.Kind()]; conv != nil {
if value := conv(val); value.IsValid() {
v.Set(value.Convert(t))
} else {
Expand Down
19 changes: 19 additions & 0 deletions decoder_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1638,3 +1638,22 @@ func TestAnonymousStructField(t *testing.T) {
}
}
}

// Test to ensure that a registered converter overrides the default text unmarshaler.
func TestRegisterConverterOverridesTextUnmarshaler(t *testing.T) {
type MyTime time.Time
s1 := &struct {
MyTime
}{}
decoder := NewDecoder()

ts := time.Date(2009, time.November, 10, 23, 0, 0, 0, time.UTC)
decoder.RegisterConverter(s1.MyTime, func(s string) reflect.Value { return reflect.ValueOf(ts) })

v1 := map[string][]string{"MyTime": {"4"}, "Bb": {"5"}}
decoder.Decode(s1, v1)

if s1.MyTime != MyTime(ts) {
t.Errorf("s1.Aa: expected %v, got %v", ts, s1.MyTime)
}
}

0 comments on commit 8b11008

Please sign in to comment.