From a183878ad2e6e523fe9311380ce99b5b47cccf1e Mon Sep 17 00:00:00 2001 From: Peter Bourgon Date: Wed, 7 Aug 2019 10:47:57 -0700 Subject: [PATCH] auth/casbin: update to accommodate change in dep --- auth/casbin/middleware.go | 16 +++++++++++----- auth/casbin/middleware_test.go | 3 ++- 2 files changed, 13 insertions(+), 6 deletions(-) diff --git a/auth/casbin/middleware.go b/auth/casbin/middleware.go index fdbf6dc31..43f2f9daa 100644 --- a/auth/casbin/middleware.go +++ b/auth/casbin/middleware.go @@ -48,17 +48,23 @@ func NewEnforcer( subject string, object interface{}, action string, ) endpoint.Middleware { return func(next endpoint.Endpoint) endpoint.Endpoint { - return func(ctx context.Context, request interface{}) ( - response interface{}, err error, - ) { + return func(ctx context.Context, request interface{}) (response interface{}, err error) { casbinModel := ctx.Value(CasbinModelContextKey) casbinPolicy := ctx.Value(CasbinPolicyContextKey) + enforcer, err := stdcasbin.NewEnforcer(casbinModel, casbinPolicy) + if err != nil { + return nil, err + } - enforcer := stdcasbin.NewEnforcer(casbinModel, casbinPolicy) ctx = context.WithValue(ctx, CasbinEnforcerContextKey, enforcer) - if !enforcer.Enforce(subject, object, action) { + ok, err := enforcer.Enforce(subject, object, action) + if err != nil { + return nil, err + } + if !ok { return nil, ErrUnauthorized } + return next(ctx, request) } } diff --git a/auth/casbin/middleware_test.go b/auth/casbin/middleware_test.go index 5dbb003df..922158657 100644 --- a/auth/casbin/middleware_test.go +++ b/auth/casbin/middleware_test.go @@ -5,13 +5,14 @@ import ( "testing" stdcasbin "github.com/casbin/casbin" + "github.com/casbin/casbin/model" fileadapter "github.com/casbin/casbin/persist/file-adapter" ) func TestStructBaseContext(t *testing.T) { e := func(ctx context.Context, i interface{}) (interface{}, error) { return ctx, nil } - m := stdcasbin.NewModel() + m := model.NewModel() m.AddDef("r", "r", "sub, obj, act") m.AddDef("p", "p", "sub, obj, act") m.AddDef("e", "e", "some(where (p.eft == allow))")