From c57fcb3746c4bfdab1b65363aa9e9edc7b6cab28 Mon Sep 17 00:00:00 2001 From: Martti T Date: Mon, 11 Mar 2024 22:49:58 +0200 Subject: [PATCH] Default binder can use `UnmarshalParams(params []string) error` interface to bind multiple input values at one go. (#2607) --- bind.go | 64 +++++++---- bind_test.go | 302 +++++++++++++++++++++++++++++++++++++++++++-------- 2 files changed, 301 insertions(+), 65 deletions(-) diff --git a/bind.go b/bind.go index 5e29be8e5..b6146e8ca 100644 --- a/bind.go +++ b/bind.go @@ -30,6 +30,13 @@ type BindUnmarshaler interface { UnmarshalParam(param string) error } +// bindMultipleUnmarshaler is used by binder to unmarshal multiple values from request at once to +// type implementing this interface. For example request could have multiple query fields `?a=1&a=2&b=test` in that case +// for `a` following slice `["1", "2"] will be passed to unmarshaller. +type bindMultipleUnmarshaler interface { + UnmarshalParams(params []string) error +} + // BindPathParams binds path params to bindable object func (b *DefaultBinder) BindPathParams(c Context, i interface{}) error { names := c.ParamNames() @@ -217,8 +224,15 @@ func (b *DefaultBinder) bindData(destination interface{}, data map[string][]stri continue } + if ok, err := unmarshalInputsToField(typeField.Type.Kind(), inputValue, structField); ok { + if err != nil { + return err + } + continue + } + // Call this first, in case we're dealing with an alias to an array type - if ok, err := unmarshalField(typeField.Type.Kind(), inputValue[0], structField); ok { + if ok, err := unmarshalInputToField(typeField.Type.Kind(), inputValue[0], structField); ok { if err != nil { return err } @@ -245,7 +259,7 @@ func (b *DefaultBinder) bindData(destination interface{}, data map[string][]stri func setWithProperType(valueKind reflect.Kind, val string, structField reflect.Value) error { // But also call it here, in case we're dealing with an array of BindUnmarshalers - if ok, err := unmarshalField(valueKind, val, structField); ok { + if ok, err := unmarshalInputToField(valueKind, val, structField); ok { return err } @@ -286,33 +300,39 @@ func setWithProperType(valueKind reflect.Kind, val string, structField reflect.V return nil } -func unmarshalField(valueKind reflect.Kind, val string, field reflect.Value) (bool, error) { - switch valueKind { - case reflect.Ptr: - return unmarshalFieldPtr(val, field) - default: - return unmarshalFieldNonPtr(val, field) +func unmarshalInputsToField(valueKind reflect.Kind, values []string, field reflect.Value) (bool, error) { + if valueKind == reflect.Ptr { + if field.IsNil() { + field.Set(reflect.New(field.Type().Elem())) + } + field = field.Elem() } -} -func unmarshalFieldNonPtr(value string, field reflect.Value) (bool, error) { fieldIValue := field.Addr().Interface() - if unmarshaler, ok := fieldIValue.(BindUnmarshaler); ok { - return true, unmarshaler.UnmarshalParam(value) - } - if unmarshaler, ok := fieldIValue.(encoding.TextUnmarshaler); ok { - return true, unmarshaler.UnmarshalText([]byte(value)) + unmarshaler, ok := fieldIValue.(bindMultipleUnmarshaler) + if !ok { + return false, nil } - - return false, nil + return true, unmarshaler.UnmarshalParams(values) } -func unmarshalFieldPtr(value string, field reflect.Value) (bool, error) { - if field.IsNil() { - // Initialize the pointer to a nil value - field.Set(reflect.New(field.Type().Elem())) +func unmarshalInputToField(valueKind reflect.Kind, val string, field reflect.Value) (bool, error) { + if valueKind == reflect.Ptr { + if field.IsNil() { + field.Set(reflect.New(field.Type().Elem())) + } + field = field.Elem() } - return unmarshalFieldNonPtr(value, field.Elem()) + + fieldIValue := field.Addr().Interface() + switch unmarshaler := fieldIValue.(type) { + case BindUnmarshaler: + return true, unmarshaler.UnmarshalParam(val) + case encoding.TextUnmarshaler: + return true, unmarshaler.UnmarshalText([]byte(val)) + } + + return false, nil } func setIntField(value string, bitSize int, field reflect.Value) error { diff --git a/bind_test.go b/bind_test.go index 05f8ef43c..9647d6566 100644 --- a/bind_test.go +++ b/bind_test.go @@ -8,6 +8,7 @@ import ( "encoding/json" "encoding/xml" "errors" + "fmt" "io" "mime/multipart" "net/http" @@ -653,49 +654,6 @@ func TestBindSetWithProperType(t *testing.T) { assert.Error(t, setWithProperType(typ.Field(0).Type.Kind(), "5", val.Field(0))) } -func TestBindSetFields(t *testing.T) { - - ts := new(bindTestStruct) - val := reflect.ValueOf(ts).Elem() - // Int - if assert.NoError(t, setIntField("5", 0, val.FieldByName("I"))) { - assert.Equal(t, 5, ts.I) - } - if assert.NoError(t, setIntField("", 0, val.FieldByName("I"))) { - assert.Equal(t, 0, ts.I) - } - - // Uint - if assert.NoError(t, setUintField("10", 0, val.FieldByName("UI"))) { - assert.Equal(t, uint(10), ts.UI) - } - if assert.NoError(t, setUintField("", 0, val.FieldByName("UI"))) { - assert.Equal(t, uint(0), ts.UI) - } - - // Float - if assert.NoError(t, setFloatField("15.5", 0, val.FieldByName("F32"))) { - assert.Equal(t, float32(15.5), ts.F32) - } - if assert.NoError(t, setFloatField("", 0, val.FieldByName("F32"))) { - assert.Equal(t, float32(0.0), ts.F32) - } - - // Bool - if assert.NoError(t, setBoolField("true", val.FieldByName("B"))) { - assert.Equal(t, true, ts.B) - } - if assert.NoError(t, setBoolField("", val.FieldByName("B"))) { - assert.Equal(t, false, ts.B) - } - - ok, err := unmarshalFieldNonPtr("2016-12-06T19:09:05Z", val.FieldByName("T")) - if assert.NoError(t, err) { - assert.Equal(t, ok, true) - assert.Equal(t, Timestamp(time.Date(2016, 12, 6, 19, 9, 5, 0, time.UTC)), ts.T) - } -} - func BenchmarkBindbindDataWithTags(b *testing.B) { b.ReportAllocs() ts := new(bindTestStructWithTags) @@ -1138,3 +1096,261 @@ func TestDefaultBinder_BindBody(t *testing.T) { }) } } + +type unixTimestamp struct { + Time time.Time +} + +func (t *unixTimestamp) UnmarshalParam(param string) error { + n, err := strconv.ParseInt(param, 10, 64) + if err != nil { + return fmt.Errorf("'%s' is not an integer", param) + } + *t = unixTimestamp{Time: time.Unix(n, 0)} + return err +} + +type IntArrayA []int + +// UnmarshalParam converts value to *Int64Slice. This allows the API to accept +// a comma-separated list of integers as a query parameter. +func (i *IntArrayA) UnmarshalParam(value string) error { + var values = strings.Split(value, ",") + var numbers = make([]int, 0, len(values)) + + for _, v := range values { + n, err := strconv.ParseInt(v, 10, 64) + if err != nil { + return fmt.Errorf("'%s' is not an integer", v) + } + + numbers = append(numbers, int(n)) + } + + *i = append(*i, numbers...) + return nil +} + +func TestBindUnmarshalParamExtras(t *testing.T) { + // this test documents how bind handles `BindUnmarshaler` interface: + // NOTE: BindUnmarshaler chooses first input value to be bound. + + t.Run("nok, unmarshalling fails", func(t *testing.T) { + e := New() + req := httptest.NewRequest(http.MethodGet, "/?t=xxxx", nil) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + result := struct { + V unixTimestamp `query:"t"` + }{} + err := c.Bind(&result) + + assert.EqualError(t, err, "code=400, message='xxxx' is not an integer, internal='xxxx' is not an integer") + }) + + t.Run("ok, target is struct", func(t *testing.T) { + e := New() + req := httptest.NewRequest(http.MethodGet, "/?t=1710095540&t=1710095541", nil) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + result := struct { + V unixTimestamp `query:"t"` + }{} + err := c.Bind(&result) + + assert.NoError(t, err) + expect := unixTimestamp{ + Time: time.Unix(1710095540, 0), + } + assert.Equal(t, expect, result.V) + }) + + t.Run("ok, target is an alias to slice and is nil, append only values from first", func(t *testing.T) { + e := New() + req := httptest.NewRequest(http.MethodGet, "/?a=1,2,3&a=4,5,6", nil) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + result := struct { + V IntArrayA `query:"a"` + }{} + err := c.Bind(&result) + + assert.NoError(t, err) + assert.Equal(t, IntArrayA([]int{1, 2, 3}), result.V) + }) + + t.Run("ok, target is an alias to slice and is nil, single input", func(t *testing.T) { + e := New() + req := httptest.NewRequest(http.MethodGet, "/?a=1,2", nil) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + result := struct { + V IntArrayA `query:"a"` + }{} + err := c.Bind(&result) + + assert.NoError(t, err) + assert.Equal(t, IntArrayA([]int{1, 2}), result.V) + }) + + t.Run("ok, target is pointer an alias to slice and is nil", func(t *testing.T) { + e := New() + req := httptest.NewRequest(http.MethodGet, "/?a=1&a=4,5,6", nil) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + result := struct { + V *IntArrayA `query:"a"` + }{} + err := c.Bind(&result) + + assert.NoError(t, err) + var expected = IntArrayA([]int{1}) + assert.Equal(t, &expected, result.V) + }) + + t.Run("ok, target is pointer an alias to slice and is NOT nil", func(t *testing.T) { + e := New() + req := httptest.NewRequest(http.MethodGet, "/?a=1&a=4,5,6", nil) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + result := struct { + V *IntArrayA `query:"a"` + }{} + result.V = new(IntArrayA) // NOT nil + + err := c.Bind(&result) + + assert.NoError(t, err) + var expected = IntArrayA([]int{1}) + assert.Equal(t, &expected, result.V) + }) +} + +type unixTimestampLast struct { + Time time.Time +} + +// this is silly example for `bindMultipleUnmarshaler` for type that uses last input value for unmarshalling +func (t *unixTimestampLast) UnmarshalParams(params []string) error { + lastInput := params[len(params)-1] + n, err := strconv.ParseInt(lastInput, 10, 64) + if err != nil { + return fmt.Errorf("'%s' is not an integer", lastInput) + } + *t = unixTimestampLast{Time: time.Unix(n, 0)} + return err +} + +type IntArrayB []int + +func (i *IntArrayB) UnmarshalParams(params []string) error { + var numbers = make([]int, 0, len(params)) + + for _, param := range params { + var values = strings.Split(param, ",") + for _, v := range values { + n, err := strconv.ParseInt(v, 10, 64) + if err != nil { + return fmt.Errorf("'%s' is not an integer", v) + } + numbers = append(numbers, int(n)) + } + } + + *i = append(*i, numbers...) + return nil +} + +func TestBindUnmarshalParams(t *testing.T) { + // this test documents how bind handles `bindMultipleUnmarshaler` interface: + + t.Run("nok, unmarshalling fails", func(t *testing.T) { + e := New() + req := httptest.NewRequest(http.MethodGet, "/?t=xxxx", nil) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + result := struct { + V unixTimestampLast `query:"t"` + }{} + err := c.Bind(&result) + + assert.EqualError(t, err, "code=400, message='xxxx' is not an integer, internal='xxxx' is not an integer") + }) + + t.Run("ok, target is struct", func(t *testing.T) { + e := New() + req := httptest.NewRequest(http.MethodGet, "/?t=1710095540&t=1710095541", nil) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + result := struct { + V unixTimestampLast `query:"t"` + }{} + err := c.Bind(&result) + + assert.NoError(t, err) + expect := unixTimestampLast{ + Time: time.Unix(1710095541, 0), + } + assert.Equal(t, expect, result.V) + }) + + t.Run("ok, target is an alias to slice and is nil, append multiple inputs", func(t *testing.T) { + e := New() + req := httptest.NewRequest(http.MethodGet, "/?a=1,2,3&a=4,5,6", nil) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + result := struct { + V IntArrayB `query:"a"` + }{} + err := c.Bind(&result) + + assert.NoError(t, err) + assert.Equal(t, IntArrayB([]int{1, 2, 3, 4, 5, 6}), result.V) + }) + + t.Run("ok, target is an alias to slice and is nil, single input", func(t *testing.T) { + e := New() + req := httptest.NewRequest(http.MethodGet, "/?a=1,2", nil) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + result := struct { + V IntArrayB `query:"a"` + }{} + err := c.Bind(&result) + + assert.NoError(t, err) + assert.Equal(t, IntArrayB([]int{1, 2}), result.V) + }) + + t.Run("ok, target is pointer an alias to slice and is nil", func(t *testing.T) { + e := New() + req := httptest.NewRequest(http.MethodGet, "/?a=1&a=4,5,6", nil) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + result := struct { + V *IntArrayB `query:"a"` + }{} + err := c.Bind(&result) + + assert.NoError(t, err) + var expected = IntArrayB([]int{1, 4, 5, 6}) + assert.Equal(t, &expected, result.V) + }) + + t.Run("ok, target is pointer an alias to slice and is NOT nil", func(t *testing.T) { + e := New() + req := httptest.NewRequest(http.MethodGet, "/?a=1&a=4,5,6", nil) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + result := struct { + V *IntArrayB `query:"a"` + }{} + result.V = new(IntArrayB) // NOT nil + + err := c.Bind(&result) + + assert.NoError(t, err) + var expected = IntArrayB([]int{1, 4, 5, 6}) + assert.Equal(t, &expected, result.V) + }) +}