Skip to content

Commit

Permalink
Added the remove_policy function and optimized policy related code.
Browse files Browse the repository at this point in the history
  • Loading branch information
leeqvip committed Jul 15, 2019
1 parent 3b89ade commit 6b8746f
Show file tree
Hide file tree
Showing 2 changed files with 69 additions and 5 deletions.
27 changes: 22 additions & 5 deletions casbin/model/policy.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,17 @@ def __init__(self):
self.model = {}

def build_role_links(self, rm):
"""initializes the roles in RBAC."""

if "g" not in self.model.keys():
return

for ast in self.model["g"].values():
ast.build_role_links(rm)

def print_policy(self):
"""prints the policy to log."""

log.log_print("Policy:")
for sec in ["p", "g"]:
if sec not in self.model.keys():
Expand All @@ -22,6 +26,8 @@ def print_policy(self):
log.log_print(key, ": ", ast.value, ": ", ast.policy)

def clear_policy(self):
"""clears all current policy."""

for sec in ["p", "g"]:
if sec not in self.model.keys():
continue
Expand All @@ -30,6 +36,8 @@ def clear_policy(self):
self.model[sec][key].policy = []

def get_policy(self, sec, ptype):
"""gets all rules in a policy."""

return self.model[sec][ptype].policy

def get_filtered_policy(self, sec, ptype, field_index, *field_values):
Expand All @@ -55,11 +63,7 @@ def has_policy(self, sec, ptype, rule):
if ptype not in self.model[sec]:
return False

for r in self.model[sec][ptype].policy:
if rule == r:
return True

return False
return rule in self.model[sec][ptype].policy

def add_policy(self, sec, ptype, rule):
"""adds a policy rule to the model."""
Expand All @@ -70,6 +74,19 @@ def add_policy(self, sec, ptype, rule):

return False

def remove_policy(self, sec, ptype, rule):
"""removePolicy removes a policy rule from the model."""

if sec not in self.model.keys():
return False
if ptype not in self.model[sec]:
return False

if not self.has_policy(sec, ptype, rule):
return False

return self.model[sec][ptype].policy.remove(rule)

def get_values_for_field_in_policy(self, sec, ptype, field_index):
"""gets all values for a field for all rules in a policy, duplicated values are removed."""

Expand Down
47 changes: 47 additions & 0 deletions tests/model/test_policy.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
from unittest import TestCase
from casbin.model import Model
from tests.test_enforcer import get_examples


class TestPolicy(TestCase):
def test_get_policy(self):
m = Model()
m.load_model(get_examples("basic_model.conf"))

rule = ['admin', 'domain1', 'data1', 'read']

m.add_policy('p', 'p', rule)

self.assertTrue(m.get_policy('p', 'p') == [rule])

def test_has_policy(self):
m = Model()
m.load_model(get_examples("basic_model.conf"))

rule = ['admin', 'domain1', 'data1', 'read']
m.add_policy('p', 'p', rule)

self.assertTrue(m.has_policy('p', 'p', rule))

def test_add_policy(self):
m = Model()
m.load_model(get_examples("basic_model.conf"))

rule = ['admin', 'domain1', 'data1', 'read']

self.assertFalse(m.has_policy('p', 'p', rule))

m.add_policy('p', 'p', rule)
self.assertTrue(m.has_policy('p', 'p', rule))

def test_remove_policy(self):
m = Model()
m.load_model(get_examples("basic_model.conf"))

rule = ['admin', 'domain1', 'data1', 'read']
m.add_policy('p', 'p', rule)
self.assertTrue(m.has_policy('p', 'p', rule))

m.remove_policy('p', 'p', rule)
self.assertFalse(m.has_policy('p', 'p', rule))
self.assertFalse(m.remove_policy('p', 'p', rule))

0 comments on commit 6b8746f

Please sign in to comment.