Skip to content

Commit

Permalink
Return error instead of causing panic for non-API functions.
Browse files Browse the repository at this point in the history
  • Loading branch information
hsluoyz committed Jul 12, 2019
1 parent 9b590d6 commit ce253cf
Show file tree
Hide file tree
Showing 21 changed files with 293 additions and 315 deletions.
147 changes: 99 additions & 48 deletions enforcer.go
Expand Up @@ -51,7 +51,7 @@ type Enforcer struct {
// MySQL DB:
// a := mysqladapter.NewDBAdapter("mysql", "mysql_username:mysql_password@tcp(127.0.0.1:3306)/")
// e := casbin.NewEnforcer("path/to/basic_model.conf", a)
func NewEnforcer(params ...interface{}) *Enforcer {
func NewEnforcer(params ...interface{}) (*Enforcer, error) {
e := &Enforcer{}

parsedParamLen := 0
Expand All @@ -69,50 +69,76 @@ func NewEnforcer(params ...interface{}) *Enforcer {
case string:
switch p1 := params[1].(type) {
case string:
e.InitWithFile(p0, p1)
err := e.InitWithFile(p0, p1)
if err != nil {
return nil, err
}
default:
e.InitWithAdapter(p0, p1.(persist.Adapter))
err := e.InitWithAdapter(p0, p1.(persist.Adapter))
if err != nil {
return nil, err
}
}
default:
switch params[1].(type) {
case string:
panic("Invalid parameters for enforcer.")
return nil, errors.New("invalid parameters for enforcer")
default:
e.InitWithModelAndAdapter(p0.(model.Model), params[1].(persist.Adapter))
err := e.InitWithModelAndAdapter(p0.(model.Model), params[1].(persist.Adapter))
if err != nil {
return nil, err
}
}
}
} else if len(params)-parsedParamLen == 1 {
switch p0 := params[0].(type) {
case string:
e.InitWithFile(p0, "")
err := e.InitWithFile(p0, "")
if err != nil {
return nil, err
}
default:
e.InitWithModelAndAdapter(p0.(model.Model), nil)
err := e.InitWithModelAndAdapter(p0.(model.Model), nil)
if err != nil {
return nil, err
}
}
} else if len(params)-parsedParamLen == 0 {
e.InitWithFile("", "")
err := e.InitWithFile("", "")
if err != nil {
return nil, err
}
} else {
panic("Invalid parameters for enforcer.")
return nil, errors.New("invalid parameters for enforcer")
}

return e
return e, nil
}

// InitWithFile initializes an enforcer with a model file and a policy file.
func (e *Enforcer) InitWithFile(modelPath string, policyPath string) {
func (e *Enforcer) InitWithFile(modelPath string, policyPath string) error {
a := fileadapter.NewAdapter(policyPath)
e.InitWithAdapter(modelPath, a)
return e.InitWithAdapter(modelPath, a)
}

// InitWithAdapter initializes an enforcer with a database adapter.
func (e *Enforcer) InitWithAdapter(modelPath string, adapter persist.Adapter) {
m := NewModel(modelPath, "")
e.InitWithModelAndAdapter(m, adapter)
func (e *Enforcer) InitWithAdapter(modelPath string, adapter persist.Adapter) error {
m, err := NewModel(modelPath, "")
if err != nil {
return err
}

err = e.InitWithModelAndAdapter(m, adapter)
if err != nil {
return err
}

e.modelPath = modelPath
return nil
}

// InitWithModelAndAdapter initializes an enforcer with a model and a database adapter.
func (e *Enforcer) InitWithModelAndAdapter(m model.Model, adapter persist.Adapter) {
func (e *Enforcer) InitWithModelAndAdapter(m model.Model, adapter persist.Adapter) error {
e.adapter = adapter

e.model = m
Expand All @@ -124,9 +150,13 @@ func (e *Enforcer) InitWithModelAndAdapter(m model.Model, adapter persist.Adapte
// Do not initialize the full policy when using a filtered adapter
fa, ok := e.adapter.(persist.FilteredAdapter)
if e.adapter != nil && (!ok || ok && !fa.IsFiltered()) {
// error intentionally ignored
e.LoadPolicy()
err := e.LoadPolicy()
if err != nil {
return err
}
}

return nil
}

func (e *Enforcer) initialize() {
Expand All @@ -140,29 +170,44 @@ func (e *Enforcer) initialize() {
}

// NewModel creates a model.
func NewModel(text ...string) model.Model {
func NewModel(text ...string) (model.Model, error) {
m := make(model.Model)

if len(text) == 2 {
if text[0] != "" {
m.LoadModel(text[0])
err := m.LoadModel(text[0])
if err != nil {
return nil, err
}
}
} else if len(text) == 1 {
m.LoadModelFromText(text[0])
} else if len(text) != 0 {
panic("Invalid parameters for model.")
return nil, errors.New("invalid parameters for model")
}

return m
return m, nil
}

// LoadModel reloads the model from the model CONF file.
// Because the policy is attached to a model, so the policy is invalidated and needs to be reloaded by calling LoadPolicy().
func (e *Enforcer) LoadModel() {
e.model = NewModel()
e.model.LoadModel(e.modelPath)
func (e *Enforcer) LoadModel() error {
var err error

e.model, err = NewModel()
if err != nil {
return err
}

err = e.model.LoadModel(e.modelPath)
if err != nil {
return err
}

e.model.PrintModel()
e.fm = model.LoadFunctionMap()

return nil
}

// GetModel gets the current model.
Expand All @@ -187,10 +232,9 @@ func (e *Enforcer) SetAdapter(adapter persist.Adapter) {
}

// SetWatcher sets the current watcher.
func (e *Enforcer) SetWatcher(watcher persist.Watcher) {
func (e *Enforcer) SetWatcher(watcher persist.Watcher) error {
e.watcher = watcher
// error intentionally ignored
watcher.SetUpdateCallback(func(string) { e.LoadPolicy() })
return watcher.SetUpdateCallback(func(string) { e.LoadPolicy() })
}

// SetRoleManager sets the current role manager.
Expand All @@ -217,7 +261,10 @@ func (e *Enforcer) LoadPolicy() error {

e.model.PrintPolicy()
if e.autoBuildRoleLinks {
e.BuildRoleLinks()
err := e.BuildRoleLinks()
if err != nil {
return err
}
}
return nil
}
Expand All @@ -241,7 +288,10 @@ func (e *Enforcer) LoadFilteredPolicy(filter interface{}) error {

e.model.PrintPolicy()
if e.autoBuildRoleLinks {
e.BuildRoleLinks()
err := e.BuildRoleLinks()
if err != nil {
return err
}
}
return nil
}
Expand Down Expand Up @@ -290,16 +340,19 @@ func (e *Enforcer) EnableAutoBuildRoleLinks(autoBuildRoleLinks bool) {
}

// BuildRoleLinks manually rebuild the role inheritance relations.
func (e *Enforcer) BuildRoleLinks() {
// error intentionally ignored
e.rm.Clear()
e.model.BuildRoleLinks(e.rm)
func (e *Enforcer) BuildRoleLinks() error {
err := e.rm.Clear()
if err != nil {
return err
}

return e.model.BuildRoleLinks(e.rm)
}

// Enforce decides whether a "subject" can access a "object" with the operation "action", input parameters are usually: (sub, obj, act).
func (e *Enforcer) Enforce(rvals ...interface{}) bool {
func (e *Enforcer) Enforce(rvals ...interface{}) (bool, error) {
if !e.enabled {
return true
return true, nil
}

functions := make(map[string]govaluate.ExpressionFunction)
Expand All @@ -316,7 +369,7 @@ func (e *Enforcer) Enforce(rvals ...interface{}) bool {
expString := e.model["m"]["m"].Value
expression, err := govaluate.NewEvaluableExpressionWithFunctions(expString, functions)
if err != nil {
panic(err)
return false, err
}

rTokens := make(map[string]int, len(e.model["r"]["r"].Tokens))
Expand All @@ -341,19 +394,19 @@ func (e *Enforcer) Enforce(rvals ...interface{}) bool {
policyEffects = make([]effect.Effect, policyLen)
matcherResults = make([]float64, policyLen)
if len(e.model["r"]["r"].Tokens) != len(rvals) {
panic(
return false, errors.New(
fmt.Sprintf(
"Invalid Request Definition size: expected %d got %d rvals: %v",
"invalid request size: expected %d, got %d, rvals: %v",
len(e.model["r"]["r"].Tokens),
len(rvals),
rvals))
}
for i, pvals := range e.model["p"]["p"].Policy {
// log.LogPrint("Policy Rule: ", pvals)
if len(e.model["p"]["p"].Tokens) != len(pvals) {
panic(
return false, errors.New(
fmt.Sprintf(
"Invalid Policy Rule size: expected %d got %d pvals: %v",
"invalid policy size: expected %d, got %d, pvals: %v",
len(e.model["p"]["p"].Tokens),
len(pvals),
pvals))
Expand All @@ -365,8 +418,7 @@ func (e *Enforcer) Enforce(rvals ...interface{}) bool {
// log.LogPrint("Result: ", result)

if err != nil {
policyEffects[i] = effect.Indeterminate
panic(err)
return false, err
}

switch result := result.(type) {
Expand All @@ -383,7 +435,7 @@ func (e *Enforcer) Enforce(rvals ...interface{}) bool {
matcherResults[i] = result
}
default:
panic(errors.New("matcher result should be bool, int or float"))
return false, errors.New("matcher result should be bool, int or float")
}

if j, ok := parameters.pTokens["p_eft"]; ok {
Expand Down Expand Up @@ -414,8 +466,7 @@ func (e *Enforcer) Enforce(rvals ...interface{}) bool {
// log.LogPrint("Result: ", result)

if err != nil {
policyEffects[0] = effect.Indeterminate
panic(err)
return false, err
}

if result.(bool) {
Expand All @@ -429,7 +480,7 @@ func (e *Enforcer) Enforce(rvals ...interface{}) bool {

result, err := e.eft.MergeEffects(e.model["e"]["e"].Value, policyEffects, matcherResults)
if err != nil {
panic(err)
return false, err
}

// Log request.
Expand All @@ -446,7 +497,7 @@ func (e *Enforcer) Enforce(rvals ...interface{}) bool {
log.LogPrint(reqStr)
}

return result
return result, nil
}

// assumes bounds have already been checked
Expand Down
23 changes: 16 additions & 7 deletions enforcer_cached.go
Expand Up @@ -27,13 +27,18 @@ type CachedEnforcer struct {
}

// NewCachedEnforcer creates a cached enforcer via file or DB.
func NewCachedEnforcer(params ...interface{}) *CachedEnforcer {
func NewCachedEnforcer(params ...interface{}) (*CachedEnforcer, error) {
e := &CachedEnforcer{}
e.Enforcer = NewEnforcer(params...)
var err error
e.Enforcer, err = NewEnforcer(params...)
if err != nil {
return nil, err
}

e.enableCache = true
e.m = make(map[string]bool)
e.locker = new(sync.RWMutex)
return e
return e, nil
}

// EnableCache determines whether to enable cache on Enforce(). When enableCache is enabled, cached result (true | false) will be returned for previous decisions.
Expand All @@ -43,7 +48,7 @@ func (e *CachedEnforcer) EnableCache(enableCache bool) {

// Enforce decides whether a "subject" can access a "object" with the operation "action", input parameters are usually: (sub, obj, act).
// if rvals is not string , ingore the cache
func (e *CachedEnforcer) Enforce(rvals ...interface{}) bool {
func (e *CachedEnforcer) Enforce(rvals ...interface{}) (bool, error) {
if !e.enableCache {
return e.Enforcer.Enforce(rvals...)
}
Expand All @@ -58,11 +63,15 @@ func (e *CachedEnforcer) Enforce(rvals ...interface{}) bool {
}

if res, ok := e.getCachedResult(key); ok {
return res
return res, nil
} else {
res := e.Enforcer.Enforce(rvals...)
res, err := e.Enforcer.Enforce(rvals...)
if err != nil {
return false, err
}

e.setCachedResult(key, res)
return res
return res, nil
}
}

Expand Down

0 comments on commit ce253cf

Please sign in to comment.