Skip to content

Commit

Permalink
fix: revert commit #819 for #820 (#822)
Browse files Browse the repository at this point in the history
Signed-off-by: closetool <c299999999@qq.com>
  • Loading branch information
kilosonc committed Jun 24, 2021
1 parent 2c4ba4a commit d3ac22c
Show file tree
Hide file tree
Showing 6 changed files with 55 additions and 98 deletions.
84 changes: 43 additions & 41 deletions enforcer.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@ package casbin
import (
"errors"
"fmt"
"os"

"github.com/Knetic/govaluate"
"github.com/casbin/casbin/v2/effector"
Expand All @@ -30,8 +29,6 @@ import (
"github.com/casbin/casbin/v2/util"
)

type RoleManagerMap map[string]rbac.RoleManager

// Enforcer is the main interface for authorization enforcement and policy management.
type Enforcer struct {
modelPath string
Expand All @@ -42,7 +39,7 @@ type Enforcer struct {
adapter persist.Adapter
watcher persist.Watcher
dispatcher persist.Dispatcher
rmMap RoleManagerMap
rmMap map[string]rbac.RoleManager

enabled bool
autoSave bool
Expand Down Expand Up @@ -135,21 +132,7 @@ func NewEnforcer(params ...interface{}) (*Enforcer, error) {

// InitWithFile initializes an enforcer with a model file and a policy file.
func (e *Enforcer) InitWithFile(modelPath string, policyPath string) error {
var a persist.Adapter

// When the policy path is empty, the user only passes the model path.
// If the policy path is not empty, we have to make sure the file exists.
if policyPath != "" {
exists, err := util.PathExists(policyPath)
if err != nil {
return err
}
if !exists {
return os.ErrNotExist
}
a = fileadapter.NewAdapter(policyPath)
}

a := fileadapter.NewAdapter(policyPath)
return e.InitWithAdapter(modelPath, a)
}

Expand Down Expand Up @@ -202,6 +185,7 @@ func (e *Enforcer) SetLogger(logger log.Logger) {
}

func (e *Enforcer) initialize() {
e.rmMap = map[string]rbac.RoleManager{}
e.eft = effector.NewDefaultEffector()
e.watcher = nil

Expand All @@ -210,9 +194,7 @@ func (e *Enforcer) initialize() {
e.autoBuildRoleLinks = true
e.autoNotifyWatcher = true
e.autoNotifyDispatcher = true

e.rmMap = RoleManagerMap{}
e.initRmMap(e.model, e.rmMap)
e.initRmMap()
}

// LoadModel reloads the model from the model CONF file.
Expand Down Expand Up @@ -289,34 +271,38 @@ func (e *Enforcer) ClearPolicy() {

// LoadPolicy reloads the policy from file/database.
func (e *Enforcer) LoadPolicy() error {
newModel := model.NewModel()
e.model.CopyTo(&newModel)
newModel.ClearPolicy()
oldModel := e.model
e.model = model.NewModel()
e.model.SetLogger(oldModel.GetLogger())

var err error
rmMap := RoleManagerMap{}
defer func() {
if err == nil {
e.model = newModel
e.rmMap = rmMap
if err != nil {
e.model = oldModel
}
}()

if err = e.adapter.LoadPolicy(newModel); err != nil {
if err = e.model.LoadModelFromText(oldModel.ToText()); err != nil {
return err
}
if err = e.adapter.LoadPolicy(e.model); err != nil && err.Error() != "invalid file path, file path cannot be empty" {
return err
}

if err = e.model.SortPoliciesByPriority(); err != nil {
return err
}

if err = newModel.SortPoliciesByPriority(); err != nil {
if err = e.clearRmMap(); err != nil {
return err
}

if e.autoBuildRoleLinks {
err = e.buildRoleLinks(newModel, rmMap)
err = e.BuildRoleLinks()
if err != nil {
return err
}
}

return nil
}

Expand All @@ -338,6 +324,7 @@ func (e *Enforcer) loadFilteredPolicy(filter interface{}) error {
return err
}

e.initRmMap()
e.model.PrintPolicy()
if e.autoBuildRoleLinks {
err := e.BuildRoleLinks()
Expand Down Expand Up @@ -389,10 +376,23 @@ func (e *Enforcer) SavePolicy() error {
return nil
}

func (e *Enforcer) initRmMap(model model.Model, rmMap RoleManagerMap) {
for ptype := range model["g"] {
rmMap[ptype] = defaultrolemanager.NewRoleManager(10)
func (e *Enforcer) initRmMap() {
for ptype := range e.model["g"] {
if rm, ok := e.rmMap[ptype]; ok {
_ = rm.Clear()
} else {
e.rmMap[ptype] = defaultrolemanager.NewRoleManager(10)
}
}
}

func (e *Enforcer) clearRmMap() error {
for ptype := range e.model["g"] {
if err := e.rmMap[ptype].Clear(); err != nil {
return err
}
}
return nil
}

// EnableEnforce changes the enforcing state of Casbin, when Casbin is disabled, all access will be allowed by the Enforce() function.
Expand Down Expand Up @@ -432,12 +432,14 @@ func (e *Enforcer) EnableAutoBuildRoleLinks(autoBuildRoleLinks bool) {

// BuildRoleLinks manually rebuild the role inheritance relations.
func (e *Enforcer) BuildRoleLinks() error {
return e.buildRoleLinks(e.model, e.rmMap)
}
for _, rm := range e.rmMap {
err := rm.Clear()
if err != nil {
return err
}
}

func (e *Enforcer) buildRoleLinks(model model.Model, rmMap RoleManagerMap) error {
e.initRmMap(model, rmMap)
return model.BuildRoleLinks(rmMap)
return e.model.BuildRoleLinks(e.rmMap)
}

// BuildIncrementalRoleLinks provides incremental build the role inheritance relations.
Expand Down
4 changes: 0 additions & 4 deletions model/assertion.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,10 +36,6 @@ type Assertion struct {
priorityIndex int
}

func (ast *Assertion) copyTo(dest *Assertion) {
*dest = *ast
}

func (ast *Assertion) buildIncrementalRoleLinks(rm rbac.RoleManager, op PolicyOp, rules [][]string) error {
ast.RM = rm
count := strings.Count(ast.Value, "_")
Expand Down
16 changes: 3 additions & 13 deletions model/model.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ package model

import (
"fmt"
"regexp"
"sort"
"strconv"
"strings"
Expand Down Expand Up @@ -205,18 +206,6 @@ func (model Model) PrintModel() {
model.GetLogger().LogModel(modelInfo)
}

func (model Model) CopyTo(dest *Model) {
for modelKey, modelValue := range model {
astMap := make(AssertionMap)
(*dest)[modelKey] = astMap
for key, value := range modelValue {
ast := new(Assertion)
value.copyTo(ast)
astMap[key] = ast
}
}
}

func (model Model) SortPoliciesByPriority() error {
for ptype, assertion := range model["p"] {
for index, token := range assertion.Tokens {
Expand Down Expand Up @@ -250,9 +239,10 @@ func (model Model) SortPoliciesByPriority() error {
func (model Model) ToText() string {
tokenPatterns := make(map[string]string)

pPattern, rPattern := regexp.MustCompile("^p_"), regexp.MustCompile("^r_")
for _, ptype := range []string{"r", "p"} {
for _, token := range model[ptype][ptype].Tokens {
tokenPatterns[token] = strings.Replace(token, "_", ".", -1)
tokenPatterns[token] = rPattern.ReplaceAllString(pPattern.ReplaceAllString(token, "p."), "r.")
}
}
s := strings.Builder{}
Expand Down
28 changes: 1 addition & 27 deletions model/model_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,11 @@
package model

import (
"fmt"
"github.com/casbin/casbin/v2/config"
"io/ioutil"
"path/filepath"
"reflect"
"strings"
"testing"

"github.com/casbin/casbin/v2/config"
)

var (
Expand Down Expand Up @@ -126,29 +123,6 @@ func TestModel_AddDef(t *testing.T) {
}
}

func TestModel_CopyTo(t *testing.T) {
a := NewModel()
a["p"] = make(AssertionMap)
a["p"]["p"] = new(Assertion)
a["p"]["p"].Policy = [][]string{{"1"}, {"2"}}

b := NewModel()
a.CopyTo(&b)

if fmt.Sprintf("%p", a["p"]) == fmt.Sprintf("%p", b["p"]) {
t.Fatal(`the memory address of a["p"] and b["p"] should not be equal`)
}

if fmt.Sprintf("%p", a["p"]["p"]) == fmt.Sprintf("%p", b["p"]["p"]) {
t.Fatal(`the memory address of a["p"]["p"] and b["p"]["p"] should not be equal`)
}

a["p"]["p"].Policy = nil
if reflect.DeepEqual(a["p"]["p"], b["p"]["p"]) {
t.Fatal(`the a["p"]["p"] and b["p"]["p"] should not be equal`)
}
}

func TestModelToTest(t *testing.T) {
testModelToText(t, "r.sub == p.sub && r.obj == p.obj && r_func(r.act, p.act) && testr_func(r.act, p.act)", "r_sub == p_sub && r_obj == p_obj && r_func(r_act, p_act) && testr_func(r_act, p_act)")
testModelToText(t, "r.sub == p.sub && r.obj == p.obj && p_func(r.act, p.act) && testp_func(r.act, p.act)", "r_sub == p_sub && r_obj == p_obj && p_func(r_act, p_act) && testp_func(r_act, p_act)")
Expand Down
8 changes: 8 additions & 0 deletions model_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
package casbin

import (
"fmt"
"testing"

"github.com/casbin/casbin/v2/log"
Expand Down Expand Up @@ -547,6 +548,13 @@ func newTestSubject(name string, age int) testSub {

func TestABACPolicy(t *testing.T) {
e, _ := NewEnforcer("examples/abac_rule_model.conf", "examples/abac_rule_policy.csv")
m := e.GetModel()
for sec, ast := range m {
fmt.Println(sec)
for ptype, p := range ast {
fmt.Println(ptype, p)
}
}
sub1 := newTestSubject("alice", 16)
sub2 := newTestSubject("alice", 20)
sub3 := newTestSubject("alice", 65)
Expand Down
13 changes: 0 additions & 13 deletions util/util.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
package util

import (
"os"
"regexp"
"sort"
"strings"
Expand Down Expand Up @@ -198,15 +197,3 @@ func RemoveDuplicateElement(s []string) []string {
}
return result
}

func PathExists(path string) (bool, error) {
_, err := os.Stat(path)
if err == nil {
return true, nil
}
if os.IsExist(err) {
return true, nil
} else {
return false, err
}
}

0 comments on commit d3ac22c

Please sign in to comment.