Skip to content

Commit

Permalink
fix: lost model data if LoadPolicy failed
Browse files Browse the repository at this point in the history
Signed-off-by: closetool <c299999999@qq.com>
  • Loading branch information
kilosonc committed Jul 1, 2021
1 parent d3ac22c commit eaf1fc1
Show file tree
Hide file tree
Showing 4 changed files with 67 additions and 18 deletions.
25 changes: 7 additions & 18 deletions enforcer.go
Original file line number Diff line number Diff line change
Expand Up @@ -271,20 +271,21 @@ func (e *Enforcer) ClearPolicy() {

// LoadPolicy reloads the policy from file/database.
func (e *Enforcer) LoadPolicy() error {
needToRebuild := false
oldModel := e.model
e.model = model.NewModel()
e.model.SetLogger(oldModel.GetLogger())
e.model = oldModel.Copy()
e.ClearPolicy()

var err error
defer func() {
if err != nil {
e.model = oldModel
if e.autoBuildRoleLinks && needToRebuild {
_ = e.BuildRoleLinks()
}
}
}()

if err = e.model.LoadModelFromText(oldModel.ToText()); err != nil {
return err
}
if err = e.adapter.LoadPolicy(e.model); err != nil && err.Error() != "invalid file path, file path cannot be empty" {
return err
}
Expand All @@ -293,11 +294,8 @@ func (e *Enforcer) LoadPolicy() error {
return err
}

if err = e.clearRmMap(); err != nil {
return err
}

if e.autoBuildRoleLinks {
needToRebuild = true
err = e.BuildRoleLinks()
if err != nil {
return err
Expand Down Expand Up @@ -386,15 +384,6 @@ func (e *Enforcer) initRmMap() {
}
}

func (e *Enforcer) clearRmMap() error {
for ptype := range e.model["g"] {
if err := e.rmMap[ptype].Clear(); err != nil {
return err
}
}
return nil
}

// EnableEnforce changes the enforcing state of Casbin, when Casbin is disabled, all access will be allowed by the Enforce() function.
func (e *Enforcer) EnableEnforce(enable bool) {
e.enabled = enable
Expand Down
21 changes: 21 additions & 0 deletions enforcer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -532,3 +532,24 @@ func TestPriorityExplicit(t *testing.T) {
true, true, false, false, false, false, true, true,
})
}

func TestFailedToLoadPolicy(t *testing.T) {
e, _ := NewEnforcer("examples/rbac_with_pattern_model.conf", "examples/rbac_with_pattern_policy.csv")
e.AddNamedMatchingFunc("g2", "matchingFunc", util.KeyMatch2)
testEnforce(t, e, "alice", "/book/1", "GET", true)
testEnforce(t, e, "bob", "/pen/3", "GET", true)
e.SetAdapter(fileadapter.NewAdapter("not found"))
_ = e.LoadPolicy()
testEnforce(t, e, "alice", "/book/1", "GET", true)
testEnforce(t, e, "bob", "/pen/3", "GET", true)
}

func TestReloadPolicyWithFunc(t *testing.T) {
e, _ := NewEnforcer("examples/rbac_with_pattern_model.conf", "examples/rbac_with_pattern_policy.csv")
e.AddNamedMatchingFunc("g2", "matchingFunc", util.KeyMatch2)
testEnforce(t, e, "alice", "/book/1", "GET", true)
testEnforce(t, e, "bob", "/pen/3", "GET", true)
_ = e.LoadPolicy()
testEnforce(t, e, "alice", "/book/1", "GET", true)
testEnforce(t, e, "bob", "/pen/3", "GET", true)
}
24 changes: 24 additions & 0 deletions model/assertion.go
Original file line number Diff line number Diff line change
Expand Up @@ -96,3 +96,27 @@ func (ast *Assertion) setLogger(logger log.Logger) {
func (ast *Assertion) initPriorityIndex() {
ast.priorityIndex = -1
}

func (ast *Assertion) copy() *Assertion {
tokens := append([]string(nil), ast.Tokens...)
policy := make([][]string, len(ast.Policy))

for i, p := range ast.Policy {
policy[i] = append(policy[i], p...)
}
policyMap := make(map[string]int)
for k, v := range ast.PolicyMap {
policyMap[k] = v
}

newAst := &Assertion{
Key: ast.Key,
Value: ast.Value,
PolicyMap: policyMap,
Tokens: tokens,
Policy: policy,
priorityIndex: ast.priorityIndex,
}

return newAst
}
15 changes: 15 additions & 0 deletions model/model.go
Original file line number Diff line number Diff line change
Expand Up @@ -271,3 +271,18 @@ func (model Model) ToText() string {
writeString("m")
return s.String()
}

func (model Model) Copy() Model {
newModel := NewModel()

for sec, m := range model {
newAstMap := make(AssertionMap)
for ptype, ast := range m {
newAstMap[ptype] = ast.copy()
}
newModel[sec] = newAstMap
}

newModel.SetLogger(model.GetLogger())
return newModel
}

0 comments on commit eaf1fc1

Please sign in to comment.