From aa09cd1d27d2deb3747cbc1ed485cb6c1a0dee27 Mon Sep 17 00:00:00 2001 From: Rangel Reale Date: Wed, 20 Dec 2017 12:47:48 -0200 Subject: [PATCH] Form reader support for encoding.TextUnmarshaler This allows Context.Read to read components that support the text unmarshaler interface, like UUID libraries --- reader.go | 29 +++++++++++++++++++++++++++++ reader_test.go | 24 ++++++++++++++++++++++++ 2 files changed, 53 insertions(+) diff --git a/reader.go b/reader.go index b2be94b..7ac3c76 100644 --- a/reader.go +++ b/reader.go @@ -1,6 +1,7 @@ package routing import ( + "encoding" "encoding/json" "encoding/xml" "errors" @@ -19,6 +20,10 @@ const ( MIME_MULTIPART_FORM = "multipart/form-data" ) +var ( + textUnmarshalerType = reflect.TypeOf((*encoding.TextUnmarshaler)(nil)).Elem() +) + // DataReader is used by Context.Read() to read data from an HTTP request. type DataReader interface { // Read reads from the given HTTP request and populate the specified data. @@ -107,6 +112,13 @@ func readForm(form map[string][]string, prefix string, rv reflect.Value) error { name = prefix + "." + name } + // check if type implements a known type, like encoding.TextUnmarshaler + if ok, err := readFormFieldKnownType(form, name, rv.Field(i)); err != nil { + return err + } else if ok { + continue + } + if ft.Kind() != reflect.Struct { if err := readFormField(form, name, rv.Field(i)); err != nil { return err @@ -124,6 +136,23 @@ func readForm(form map[string][]string, prefix string, rv reflect.Value) error { return nil } +func readFormFieldKnownType(form map[string][]string, name string, rv reflect.Value) (bool, error) { + value, ok := form[name] + if !ok { + return false, nil + } + rv = indirect(rv) + rt := rv.Type() + + // check if type implements encoding.TextUnmarshaler + if rt.Implements(textUnmarshalerType) { + return true, rv.Interface().(encoding.TextUnmarshaler).UnmarshalText([]byte(value[0])) + } else if reflect.PtrTo(rt).Implements(textUnmarshalerType) { + return true, rv.Addr().Interface().(encoding.TextUnmarshaler).UnmarshalText([]byte(value[0])) + } + return false, nil +} + func readFormField(form map[string][]string, name string, rv reflect.Value) error { value, ok := form[name] if !ok { diff --git a/reader_test.go b/reader_test.go index 5c3c937..c121803 100644 --- a/reader_test.go +++ b/reader_test.go @@ -84,3 +84,27 @@ func TestDefaultDataReader(t *testing.T) { assert.Equal(t, expected, data, test.tag) } } + +type TU struct { + UValue string +} + +func (tu *TU) UnmarshalText(text []byte) error { + tu.UValue = "TU_" + string(text[:]) + return nil +} + +func TestTextUnmarshaler(t *testing.T) { + var a struct { + ATU TU `form:"atu"` + NTU string `form:"ntu"` + } + values := map[string][]string{ + "atu": []string{"ORIGINAL"}, + "ntu": []string{"ORIGINAL"}, + } + err := ReadFormData(values, &a) + assert.Nil(t, err) + assert.Equal(t, "TU_ORIGINAL", a.ATU.UValue) + assert.Equal(t, "ORIGINAL", a.NTU) +}