Skip to content

Commit

Permalink
Merge pull request #510 from GopherJ/improve-locks
Browse files Browse the repository at this point in the history
fix: remove unnecessary locks
  • Loading branch information
hsluoyz committed Jul 3, 2020
2 parents a5d65be + 9a4650c commit 7841fd3
Show file tree
Hide file tree
Showing 11 changed files with 56 additions and 69 deletions.
3 changes: 3 additions & 0 deletions .github/workflows/default.yml
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,9 @@ jobs:
- name: Test
run: go test -v .

- name: Test Race
run: go test -race -v .

semantic-release:
runs-on: ubuntu-latest
steps:
Expand Down
7 changes: 0 additions & 7 deletions config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@ import (
"os"
"strconv"
"strings"
"sync"
)

var (
Expand All @@ -50,8 +49,6 @@ type ConfigInterface interface {

// Config represents an implementation of the ConfigInterface
type Config struct {
// map is not safe.
sync.RWMutex
// Section:key=value
data map[string]map[string]string
}
Expand Down Expand Up @@ -91,12 +88,10 @@ func (c *Config) AddConfig(section string, option string, value string) bool {
}

func (c *Config) parse(fname string) (err error) {
c.Lock()
f, err := os.Open(fname)
if err != nil {
return err
}
defer c.Unlock()
defer f.Close()

buf := bufio.NewReader(f)
Expand Down Expand Up @@ -227,8 +222,6 @@ func (c *Config) Strings(key string) []string {

// Set sets the value for the specific key in the Config
func (c *Config) Set(key string, value string) error {
c.Lock()
defer c.Unlock()
if len(key) == 0 {
return errors.New("key is empty")
}
Expand Down
5 changes: 1 addition & 4 deletions enforcer.go
Original file line number Diff line number Diff line change
Expand Up @@ -366,10 +366,7 @@ func (e *Enforcer) enforce(matcher string, explains *[]string, rvals ...interfac
return true, nil
}

functions := model.FunctionMap{}
for k, v := range e.fm {
functions[k] = v
}
functions := e.fm.GetFunctions()
if _, ok := e.model["g"]; ok {
for key, ast := range e.model["g"] {
rm := ast.RM
Expand Down
17 changes: 12 additions & 5 deletions enforcer_synced.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ package casbin

import (
"sync"
"sync/atomic"
"time"

"github.com/Knetic/govaluate"
Expand All @@ -28,7 +29,7 @@ type SyncedEnforcer struct {
*Enforcer
m sync.RWMutex
stopAutoLoad chan struct{}
autoLoadRunning bool
autoLoadRunning int32
}

// NewSyncedEnforcer creates a synchronized enforcer via file or DB.
Expand All @@ -41,21 +42,27 @@ func NewSyncedEnforcer(params ...interface{}) (*SyncedEnforcer, error) {
}

e.stopAutoLoad = make(chan struct{}, 1)
e.autoLoadRunning = 0
return e, nil
}

// IsAudoLoadingRunning check if SyncedEnforcer is auto loading policies
func (e *SyncedEnforcer) IsAudoLoadingRunning() bool {
return atomic.LoadInt32(&(e.autoLoadRunning)) != 0
}

// StartAutoLoadPolicy starts a go routine that will every specified duration call LoadPolicy
func (e *SyncedEnforcer) StartAutoLoadPolicy(d time.Duration) {
// Don't start another goroutine if there is already one running
if e.autoLoadRunning {
if e.IsAudoLoadingRunning() {
return
}
e.autoLoadRunning = true
atomic.StoreInt32(&(e.autoLoadRunning), int32(1))
ticker := time.NewTicker(d)
go func() {
defer func() {
ticker.Stop()
e.autoLoadRunning = false
atomic.StoreInt32(&(e.autoLoadRunning), int32(0))
}()
n := 1
log.LogPrintf("Start automatically load policy")
Expand All @@ -77,7 +84,7 @@ func (e *SyncedEnforcer) StartAutoLoadPolicy(d time.Duration) {

// StopAutoLoadPolicy causes the go routine to exit.
func (e *SyncedEnforcer) StopAutoLoadPolicy() {
if e.autoLoadRunning {
if e.IsAudoLoadingRunning() {
e.stopAutoLoad <- struct{}{}
}
}
Expand Down
4 changes: 2 additions & 2 deletions enforcer_synced_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,13 +47,13 @@ func TestSync(t *testing.T) {
func TestStopAutoLoadPolicy(t *testing.T) {
e, _ := NewSyncedEnforcer("examples/basic_model.conf", "examples/basic_policy.csv")
e.StartAutoLoadPolicy(5 * time.Millisecond)
if !e.autoLoadRunning {
if !e.IsAudoLoadingRunning() {
t.Error("auto load is not running")
}
e.StopAutoLoadPolicy()
// Need a moment, to exit goroutine
time.Sleep(10 * time.Millisecond)
if e.autoLoadRunning {
if e.IsAudoLoadingRunning() {
t.Error("auto load is still running")
}
}
19 changes: 13 additions & 6 deletions log/default_logger.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,29 +14,36 @@

package log

import "log"
import (
"log"
"sync/atomic"
)

// DefaultLogger is the implementation for a Logger using golang log.
type DefaultLogger struct {
enable bool
enable int32
}

func (l *DefaultLogger) EnableLog(enable bool) {
l.enable = enable
i := 0
if enable {
i = 1
}
atomic.StoreInt32(&(l.enable), int32(i))
}

func (l *DefaultLogger) IsEnabled() bool {
return l.enable
return atomic.LoadInt32(&(l.enable)) != 0
}

func (l *DefaultLogger) Print(v ...interface{}) {
if l.enable {
if l.IsEnabled() {
log.Print(v...)
}
}

func (l *DefaultLogger) Printf(format string, v ...interface{}) {
if l.enable {
if l.IsEnabled() {
log.Printf(format, v...)
}
}
4 changes: 2 additions & 2 deletions management_api.go
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,7 @@ func (e *Enforcer) AddNamedPolicy(ptype string, params ...interface{}) (bool, er
// AddNamedPolicies adds authorization rules to the current named policy.
// If the rule already exists, the function returns false for the corresponding rule and the rule will not be added.
// Otherwise the function returns true for the corresponding by adding the new rule.
func (e *Enforcer) AddNamedPolicies(ptype string, rules [][] string) (bool, error) {
func (e *Enforcer) AddNamedPolicies(ptype string, rules [][]string) (bool, error) {
return e.addPolicies("p", ptype, rules)
}

Expand Down Expand Up @@ -180,7 +180,7 @@ func (e *Enforcer) RemoveNamedPolicy(ptype string, params ...interface{}) (bool,
}

// RemoveNamedPolicies removes authorization rules from the current named policy.
func (e *Enforcer) RemoveNamedPolicies(ptype string, rules [][] string) (bool, error) {
func (e *Enforcer) RemoveNamedPolicies(ptype string, rules [][]string) (bool, error) {
return e.removePolicies("p", ptype, rules)
}

Expand Down
6 changes: 0 additions & 6 deletions model/assertion.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@ package model
import (
"errors"
"strings"
"sync"

"github.com/casbin/casbin/v2/log"
"github.com/casbin/casbin/v2/rbac"
Expand All @@ -32,12 +31,9 @@ type Assertion struct {
Policy [][]string
PolicyMap map[string]int
RM rbac.RoleManager
Mutex sync.RWMutex
}

func (ast *Assertion) buildIncrementalRoleLinks(rm rbac.RoleManager, op PolicyOp, rules [][]string) error {
ast.Mutex.RLock()
defer ast.Mutex.RUnlock()
ast.RM = rm
count := strings.Count(ast.Value, "_")
if count < 2 {
Expand Down Expand Up @@ -69,8 +65,6 @@ func (ast *Assertion) buildIncrementalRoleLinks(rm rbac.RoleManager, op PolicyOp
}

func (ast *Assertion) buildRoleLinks(rm rbac.RoleManager) error {
ast.Mutex.RLock()
defer ast.Mutex.RUnlock()
ast.RM = rm
count := strings.Count(ast.Value, "_")
if count < 2 {
Expand Down
28 changes: 23 additions & 5 deletions model/function.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,21 +15,27 @@
package model

import (
"sync"

"github.com/Knetic/govaluate"
"github.com/casbin/casbin/v2/util"
)

// FunctionMap represents the collection of Function.
type FunctionMap map[string]govaluate.ExpressionFunction
type FunctionMap struct {
fns *sync.Map
}
// [string]govaluate.ExpressionFunction

// AddFunction adds an expression function.
func (fm FunctionMap) AddFunction(name string, function govaluate.ExpressionFunction) {
fm[name] = function
func (fm *FunctionMap) AddFunction(name string, function govaluate.ExpressionFunction) {
fm.fns.LoadOrStore(name, function)
}

// LoadFunctionMap loads an initial function map.
func LoadFunctionMap() FunctionMap {
fm := make(FunctionMap)
fm := &FunctionMap{}
fm.fns = &sync.Map{}

fm.AddFunction("keyMatch", util.KeyMatchFunc)
fm.AddFunction("keyMatch2", util.KeyMatch2Func)
Expand All @@ -39,5 +45,17 @@ func LoadFunctionMap() FunctionMap {
fm.AddFunction("ipMatch", util.IPMatchFunc)
fm.AddFunction("globMatch", util.GlobMatchFunc)

return fm
return *fm
}

// GetFunctions return a map with all the functions
func (fm *FunctionMap) GetFunctions()(map[string]govaluate.ExpressionFunction) {
ret := make(map[string]govaluate.ExpressionFunction)

fm.fns.Range(func(k interface{}, v interface{})bool {
ret[k.(string)] = v.(govaluate.ExpressionFunction)
return true
})

return ret
}
30 changes: 0 additions & 30 deletions model/policy.go
Original file line number Diff line number Diff line change
Expand Up @@ -57,47 +57,34 @@ func (model Model) BuildRoleLinks(rm rbac.RoleManager) error {
func (model Model) PrintPolicy() {
log.LogPrint("Policy:")
for key, ast := range model["p"] {
ast.Mutex.RLock()
defer ast.Mutex.RUnlock()
log.LogPrint(key, ": ", ast.Value, ": ", ast.Policy)
}

for key, ast := range model["g"] {
ast.Mutex.RLock()
defer ast.Mutex.RUnlock()
log.LogPrint(key, ": ", ast.Value, ": ", ast.Policy)
}
}

// ClearPolicy clears all current policy.
func (model Model) ClearPolicy() {
for _, ast := range model["p"] {
ast.Mutex.Lock()
defer ast.Mutex.Unlock()
ast.Policy = nil
ast.PolicyMap = map[string]int{}
}

for _, ast := range model["g"] {
ast.Mutex.Lock()
defer ast.Mutex.Unlock()
ast.Policy = nil
ast.PolicyMap = map[string]int{}
}
}

// GetPolicy gets all rules in a policy.
func (model Model) GetPolicy(sec string, ptype string) [][]string {
model[sec][ptype].Mutex.RLock()
defer model[sec][ptype].Mutex.RUnlock()
return model[sec][ptype].Policy
}

// GetFilteredPolicy gets rules based on field filters from a policy.
func (model Model) GetFilteredPolicy(sec string, ptype string, fieldIndex int, fieldValues ...string) [][]string {
model[sec][ptype].Mutex.RLock()
defer model[sec][ptype].Mutex.RUnlock()

res := [][]string{}

for _, rule := range model[sec][ptype].Policy {
Expand All @@ -119,16 +106,12 @@ func (model Model) GetFilteredPolicy(sec string, ptype string, fieldIndex int, f

// HasPolicy determines whether a model has the specified policy rule.
func (model Model) HasPolicy(sec string, ptype string, rule []string) bool {
model[sec][ptype].Mutex.RLock()
defer model[sec][ptype].Mutex.RUnlock()
_, ok := model[sec][ptype].PolicyMap[strings.Join(rule, DefaultSep)]
return ok
}

// HasPolicies determines whether a model has any of the specified policies. If one is found we return false.
func (model Model) HasPolicies(sec string, ptype string, rules [][]string) bool {
model[sec][ptype].Mutex.RLock()
defer model[sec][ptype].Mutex.RUnlock()
for i := 0; i < len(rules); i++ {
if model.HasPolicy(sec, ptype, rules[i]) {
return true
Expand All @@ -140,16 +123,12 @@ func (model Model) HasPolicies(sec string, ptype string, rules [][]string) bool

// AddPolicy adds a policy rule to the model.
func (model Model) AddPolicy(sec string, ptype string, rule []string) {
model[sec][ptype].Mutex.Lock()
defer model[sec][ptype].Mutex.Unlock()
model[sec][ptype].Policy = append(model[sec][ptype].Policy, rule)
model[sec][ptype].PolicyMap[strings.Join(rule, DefaultSep)] = len(model[sec][ptype].Policy) - 1
}

// AddPolicies adds policy rules to the model.
func (model Model) AddPolicies(sec string, ptype string, rules [][]string) {
model[sec][ptype].Mutex.Lock()
defer model[sec][ptype].Mutex.Unlock()
for _, rule := range rules {
model[sec][ptype].Policy = append(model[sec][ptype].Policy, rule)
model[sec][ptype].PolicyMap[strings.Join(rule, DefaultSep)] = len(model[sec][ptype].Policy) - 1
Expand All @@ -158,8 +137,6 @@ func (model Model) AddPolicies(sec string, ptype string, rules [][]string) {

// RemovePolicy removes a policy rule from the model.
func (model Model) RemovePolicy(sec string, ptype string, rule []string) bool {
model[sec][ptype].Mutex.Lock()
defer model[sec][ptype].Mutex.Unlock()
index, ok := model[sec][ptype].PolicyMap[strings.Join(rule, DefaultSep)]
if !ok {
return false
Expand All @@ -176,8 +153,6 @@ func (model Model) RemovePolicy(sec string, ptype string, rule []string) bool {

// RemovePolicies removes policy rules from the model.
func (model Model) RemovePolicies(sec string, ptype string, rules [][]string) bool {
model[sec][ptype].Mutex.Lock()
defer model[sec][ptype].Mutex.Unlock()
for _, rule := range rules {
index, ok := model[sec][ptype].PolicyMap[strings.Join(rule, DefaultSep)]
if !ok {
Expand All @@ -195,8 +170,6 @@ func (model Model) RemovePolicies(sec string, ptype string, rules [][]string) bo

// RemoveFilteredPolicy removes policy rules based on field filters from the model.
func (model Model) RemoveFilteredPolicy(sec string, ptype string, fieldIndex int, fieldValues ...string) (bool, [][]string) {
model[sec][ptype].Mutex.Lock()
defer model[sec][ptype].Mutex.Unlock()
var tmp [][]string
var effects [][]string
res := false
Expand Down Expand Up @@ -234,9 +207,6 @@ func (model Model) RemoveFilteredPolicy(sec string, ptype string, fieldIndex int

// GetValuesForFieldInPolicy gets all values for a field for all rules in a policy, duplicated values are removed.
func (model Model) GetValuesForFieldInPolicy(sec string, ptype string, fieldIndex int) []string {
model[sec][ptype].Mutex.RLock()
defer model[sec][ptype].Mutex.RUnlock()

values := []string{}

for _, rule := range model[sec][ptype].Policy {
Expand Down
Loading

0 comments on commit 7841fd3

Please sign in to comment.