From cf3eb5b33ef36c9a8b2eecc06246292ce14b0c1c Mon Sep 17 00:00:00 2001 From: Gord Allott Date: Thu, 3 Sep 2020 11:12:56 +0100 Subject: [PATCH] allows more types with the Pointable interface Signed-off-by: Gord Allott --- pointer.go | 60 ++++++++++++++++++++++++------------------------- pointer_test.go | 24 ++++++++++++++++++++ 2 files changed, 54 insertions(+), 30 deletions(-) diff --git a/pointer.go b/pointer.go index b284eb7..7df9853 100644 --- a/pointer.go +++ b/pointer.go @@ -114,16 +114,16 @@ func getSingleImpl(node interface{}, decodedToken string, nameProvider *swag.Nam rValue := reflect.Indirect(reflect.ValueOf(node)) kind := rValue.Kind() - switch kind { + if rValue.Type().Implements(jsonPointableType) { + r, err := node.(JSONPointable).JSONLookup(decodedToken) + if err != nil { + return nil, kind, err + } + return r, kind, nil + } + switch kind { case reflect.Struct: - if rValue.Type().Implements(jsonPointableType) { - r, err := node.(JSONPointable).JSONLookup(decodedToken) - if err != nil { - return nil, kind, err - } - return r, kind, nil - } nm, ok := nameProvider.GetGoNameForType(rValue.Type(), decodedToken) if !ok { return nil, kind, fmt.Errorf("object has no field %q", decodedToken) @@ -161,17 +161,17 @@ func getSingleImpl(node interface{}, decodedToken string, nameProvider *swag.Nam func setSingleImpl(node, data interface{}, decodedToken string, nameProvider *swag.NameProvider) error { rValue := reflect.Indirect(reflect.ValueOf(node)) - switch rValue.Kind() { - case reflect.Struct: - if ns, ok := node.(JSONSetable); ok { // pointer impl - return ns.JSONSet(decodedToken, data) - } + if ns, ok := node.(JSONSetable); ok { // pointer impl + return ns.JSONSet(decodedToken, data) + } - if rValue.Type().Implements(jsonSetableType) { - return node.(JSONSetable).JSONSet(decodedToken, data) - } + if rValue.Type().Implements(jsonSetableType) { + return node.(JSONSetable).JSONSet(decodedToken, data) + } + switch rValue.Kind() { + case reflect.Struct: nm, ok := nameProvider.GetGoNameForType(rValue.Type(), decodedToken) if !ok { return fmt.Errorf("object has no field %q", decodedToken) @@ -270,22 +270,22 @@ func (p *Pointer) set(node, data interface{}, nameProvider *swag.NameProvider) e rValue := reflect.Indirect(reflect.ValueOf(node)) kind := rValue.Kind() - switch kind { - - case reflect.Struct: - if rValue.Type().Implements(jsonPointableType) { - r, err := node.(JSONPointable).JSONLookup(decodedToken) - if err != nil { - return err - } - fld := reflect.ValueOf(r) - if fld.CanAddr() && fld.Kind() != reflect.Interface && fld.Kind() != reflect.Map && fld.Kind() != reflect.Slice && fld.Kind() != reflect.Ptr { - node = fld.Addr().Interface() - continue - } - node = r + if rValue.Type().Implements(jsonPointableType) { + r, err := node.(JSONPointable).JSONLookup(decodedToken) + if err != nil { + return err + } + fld := reflect.ValueOf(r) + if fld.CanAddr() && fld.Kind() != reflect.Interface && fld.Kind() != reflect.Map && fld.Kind() != reflect.Slice && fld.Kind() != reflect.Ptr { + node = fld.Addr().Interface() continue } + node = r + continue + } + + switch kind { + case reflect.Struct: nm, ok := nameProvider.GetGoNameForType(rValue.Type(), decodedToken) if !ok { return fmt.Errorf("object has no field %q", decodedToken) diff --git a/pointer_test.go b/pointer_test.go index eabd586..020b19d 100644 --- a/pointer_test.go +++ b/pointer_test.go @@ -167,6 +167,21 @@ func (p pointableImpl) JSONLookup(token string) (interface{}, error) { return nil, fmt.Errorf("object has no field %q", token) } +type pointableMap map[string]string + +func (p pointableMap) JSONLookup(token string) (interface{}, error) { + if token == "swap" { + return p["swapped"], nil + } + + v, ok := p[token] + if ok { + return v, nil + } + + return nil, fmt.Errorf("object has no key %q", token) +} + func TestPointableInterface(t *testing.T) { p := &pointableImpl{"hello"} @@ -177,6 +192,15 @@ func TestPointableInterface(t *testing.T) { result, _, err = GetForToken(p, "something") assert.Error(t, err) assert.Nil(t, result) + + pm := pointableMap{"swapped": "hello", "a": "world"} + result, _, err = GetForToken(pm, "swap") + assert.NoError(t, err) + assert.Equal(t, pm["swapped"], result) + + result, _, err = GetForToken(pm, "a") + assert.NoError(t, err) + assert.Equal(t, pm["a"], result) } func TestGetNode(t *testing.T) {