Skip to content

Commit

Permalink
Merge pull request #69 from Xhy-5000/batch
Browse files Browse the repository at this point in the history
Add batch function
  • Loading branch information
leeqvip committed Aug 19, 2020
2 parents 2bdfefa + 22e4254 commit 76472f1
Show file tree
Hide file tree
Showing 6 changed files with 268 additions and 0 deletions.
55 changes: 55 additions & 0 deletions casbin/internal_enforcer.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,34 @@ def _add_policy(self, sec, ptype, rule):

return rule_added

def _add_policies(self, sec, ptype, rules):
"""adds rules to the current policy."""

if self.adapter and self.auto_save and not self.model.has_policies(sec, ptype, rules):
try:
rules_added = self.adapter.add_policies(sec, ptype, rules)
return rules_added
except:
print("Batch function haven't been implemented in this adapter")

for rule in rules:
rule_added = self.model.add_policy(sec, ptype, rule)
if not rule_added:
continue

if self.adapter and self.auto_save:
if self.adapter.add_policy(sec, ptype, rule) is False:
continue

if self.watcher:
self.watcher.update()

return True

else:
return False


def _remove_policy(self, sec, ptype, rule):
"""removes a rule from the current policy."""
rule_removed = self.model.remove_policy(sec, ptype, rule)
Expand All @@ -37,6 +65,33 @@ def _remove_policy(self, sec, ptype, rule):

return rule_removed

def _remove_policies(self, sec, ptype, rules):
"""removes rules from the current policy."""
if self.adapter and self.auto_save:
try:
rules_added = self.adapter.remove_policies(sec, ptype, rules)
return rules_added
except:
print("Batch function haven't been implemented in this adapter")

for rule in rules:
rule_removed = self.model.remove_policy(sec, ptype, rule)
if not rule_removed:
continue

if self.adapter and self.auto_save:
if self.adapter.remove_policy(sec, ptype, rule) is False:
continue

if self.watcher:
self.watcher.update()

return True

else:
return False


def _remove_filtered_policy(self, sec, ptype, field_index, *field_values):
"""removes rules based on field filters from the current policy."""
rule_removed = self.model.remove_filtered_policy(sec, ptype, field_index, *field_values)
Expand Down
61 changes: 61 additions & 0 deletions casbin/management_enforcer.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,10 +115,30 @@ def add_named_policy(self, ptype, *params):

return rule_added

def add_policies(self, rules):
"""adds authorization rules to the current policy.
If one of the rules already exists, the function returns false and the rule will not be added.
Otherwise the function returns true by adding the new rules.
"""
return self.add_named_policies('p', rules)

def add_named_policies(self, ptype, rules):
"""adds authorization rules to the current policy.
If one of the rules already exists, the function returns false and the rule will not be added.
Otherwise the function returns true by adding the new rules.
"""
return self._add_policies('p', ptype, rules)

def remove_policy(self, *params):
"""removes an authorization rule from the current policy."""
return self.remove_named_policy('p', *params)

def remove_policies(self, rules):
'''removes authorization rules from the current policy.'''
return self.remove_named_policies('p', rules)

def remove_filtered_policy(self, field_index, *field_values):
"""removes an authorization rule from the current policy, field filters can be specified."""
return self.remove_filtered_named_policy('p', field_index, *field_values)
Expand All @@ -139,6 +159,10 @@ def remove_named_policy(self, ptype, *params):

return rule_removed

def remove_named_policies(self, ptype, rules):
'''removes authorization rules from the current named policy.'''
return self._remove_policies('p', ptype, rules)

def remove_filtered_named_policy(self, ptype, field_index, *field_values):
"""removes an authorization rule from the current named policy, field filters can be specified."""
return self._remove_filtered_policy('p', ptype, field_index, *field_values)
Expand Down Expand Up @@ -170,6 +194,15 @@ def add_grouping_policy(self, *params):
"""
return self.add_named_grouping_policy('g', *params)

def add_grouping_policies(self, rules):
"""adds role inheritance rules to the current policy.
If the rule already exists, the function returns false for the corresponding policy rule and the rule will not be added.
Otherwise the function returns true for the corresponding policy rule by adding the new rule.
"""

return self.add_named_grouping_policies('g', rules)

def add_named_grouping_policy(self, ptype, *params):
"""adds a named role inheritance rule to the current policy.
Expand All @@ -192,10 +225,34 @@ def add_named_grouping_policy(self, ptype, *params):
self.build_role_links()
return rule_added

def add_named_grouping_policies(self, ptype, rules):
"""adds role inheritance rules to the current policy.
If the rule already exists, the function returns false for the corresponding policy rule and the rule will not be added.
Otherwise the function returns true for the corresponding policy rule by adding the new rule.
"""

# return self._add_policies('g', ptype, rules)
for params in rules:
if len(params) == 1:
str_slice = params[0]
rule_added = self._add_policy('g', ptype, str_slice)
else:
rule_added = self._add_policy('g', ptype, params.copy())

if self.auto_build_role_links:
self.build_role_links()

return True

def remove_grouping_policy(self, *params):
"""removes a role inheritance rule from the current policy."""
return self.remove_named_grouping_policy('g', *params)

def remove_grouping_policies(self, rules):
'''removes role inheritance rules from the current policy.'''
return self.remove_named_grouping_policies('g', rules)

def remove_filtered_grouping_policy(self, field_index, *field_values):
"""removes a role inheritance rule from the current policy, field filters can be specified."""
return self.remove_filtered_named_grouping_policy('g', field_index, *field_values)
Expand All @@ -218,6 +275,10 @@ def remove_named_grouping_policy(self, ptype, *params):
self.build_role_links()
return rule_removed

def remove_named_grouping_policies(self, ptype, rules):
'''removes role inheritance rules from the current named policy.'''
return self._remove_policies('g', ptype, rules)

def remove_filtered_named_grouping_policy(self, ptype, field_index, *field_values):
"""removes a role inheritance rule from the current named policy, field filters can be specified."""
rule_removed = self._remove_filtered_policy('g', ptype, field_index, *field_values)
Expand Down
41 changes: 41 additions & 0 deletions casbin/model/policy.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,19 @@ def has_policy(self, sec, ptype, rule):

return rule in self.model[sec][ptype].policy

def has_policies(self, sec, ptype, rules):
'''determines whether a model has any of the specified policies. If one is found we return false.'''
if sec not in self.model.keys():
return False
if ptype not in self.model[sec]:
return False

for rule in rules:
if rule not in self.model[sec][ptype].policy:
return False

return True

def add_policy(self, sec, ptype, rule):
"""adds a policy rule to the model."""

Expand All @@ -74,6 +87,17 @@ def add_policy(self, sec, ptype, rule):

return False

def add_policies(self, sec, ptype, rules):
"""adds rules to the current policy."""

for rule in rules:
if not self.has_policy(sec, ptype, rule):
self.model[sec][ptype].policy.append(rule)
else:
return False

return True

def remove_policy(self, sec, ptype, rule):
"""removes a policy rule from the model."""

Expand All @@ -89,6 +113,23 @@ def remove_policy(self, sec, ptype, rule):

return rule not in self.model[sec][ptype].policy

def remove_policies(self, sec, ptype, rules):
'''removes policy rules from the model.'''

if sec not in self.model.keys():
return False
if ptype not in self.model[sec]:
return False

for rule in rules:
if not self.has_policy(sec, ptype, rule):
return False

self.model[sec][ptype].policy.remove(rule)

return rules not in self.model[sec][ptype].policy


def remove_filtered_policy(self, sec, ptype, field_index, *field_values):
"""removes policy rules based on field filters from the model."""
tmp = []
Expand Down
12 changes: 12 additions & 0 deletions casbin/persist/batch_adapter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
class BatchAdapter:
'''the interface for Casbin adapters with multiple add and remove policy functions'''

def add_policies(self, sec, ptype, rules):
'''AddPolicies adds policy rules to the storage.
This is part of the Auto-Save feature.'''
pass

def remove_policies(self, sec, ptype, rules):
'''RemovePolicies removes policy rules from the storage.
This is part of the Auto-Save feature.'''
pass
43 changes: 43 additions & 0 deletions tests/model/test_policy.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,18 @@ def test_has_policy(self):

self.assertTrue(m.has_policy('p', 'p', rule))

def test_has_policies(self):
m = Model()
m.load_model(get_examples("basic_model.conf"))

rules = [
['p', 'alice', 'data1', 'read'],
['p', 'bob', 'data2', 'write'],
]
m.add_policies('p', 'p', rules)

self.assertTrue(m.has_policies('p', 'p', rules))

def test_add_policy(self):
m = Model()
m.load_model(get_examples("basic_model.conf"))
Expand All @@ -34,6 +46,20 @@ def test_add_policy(self):
m.add_policy('p', 'p', rule)
self.assertTrue(m.has_policy('p', 'p', rule))

def test_add_policies(self):
m = Model()
m.load_model(get_examples("basic_model.conf"))

rules = [
['alice', 'data1', 'read'],
['bob', 'data2', 'write'],
]

self.assertFalse(m.has_policies('p', 'p', rules))

m.add_policies('p', 'p', rules)
self.assertTrue(m.has_policies('p', 'p', rules))

def test_add_role_policy(self):
m = Model()
m.load_model(get_examples("rbac_model.conf"))
Expand Down Expand Up @@ -65,6 +91,23 @@ def test_remove_policy(self):
self.assertFalse(m.has_policy('p', 'p', rule))
self.assertFalse(m.remove_policy('p', 'p', rule))

def test_remove_policies(self):
m = Model()
m.load_model(get_examples("basic_model.conf"))

rules = [
['alice', 'data1', 'read'],
['bob', 'data2', 'write'],
]

m.add_policies('p', 'p', rules)
self.assertTrue(m.has_policies('p', 'p', rules))

m.remove_policies('p', 'p', rules)
self.assertFalse(m.has_policies('p', 'p', rules))



def test_remove_filtered_policy(self):
m = Model()
m.load_model(get_examples("rbac_with_domains_model.conf"))
Expand Down
56 changes: 56 additions & 0 deletions tests/test_management_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,3 +85,59 @@ def test_modify_policy_api(self):
['eve', 'data3', 'read'],
['eve', 'data3', 'write'],
])

rules = [
["jack", "data4", "read"],
["jack", "data4", "read"],
["jack", "data4", "read"],
["katy", "data4", "write"],
["leyo", "data4", "read"],
["katy", "data4", "write"],
["katy", "data4", "write"],
["ham", "data4", "write"]
]

e.add_policies(rules)
e.add_named_policies('p', rules)
self.assertEqual(e.get_policy(),[
['alice', 'data1', 'read'],
['bob', 'data2', 'write'],
['data2_admin', 'data2', 'read'],
['data2_admin', 'data2', 'write'],
['eve', 'data3', 'read'],
['eve', 'data3', 'write'],
["jack", "data4", "read"],
["katy", "data4", "write"],
["leyo", "data4", "read"],
["ham", "data4", "write"],
])

e.remove_policies(rules)
e.remove_named_policies('p', rules)
self.assertEqual(e.get_policy(), [
['alice', 'data1', 'read'],
['bob', 'data2', 'write'],
['data2_admin', 'data2', 'read'],
['data2_admin', 'data2', 'write'],
['eve', 'data3', 'read'],
['eve', 'data3', 'write'],
])

grouping_rules = [
["ham", "data4_admin"],
["jack", "data5_admin"],
]

e.add_grouping_policies(grouping_rules)
e.add_named_grouping_policies('g', grouping_rules)
self.assertEqual(e.get_grouping_policy(),[
["alice", "data2_admin"],
["ham", "data4_admin"],
["jack", "data5_admin"],
])

e.remove_policies(grouping_rules)
e.remove_named_grouping_policies('g', grouping_rules)
self.assertEqual(e.get_grouping_policy(), [
["alice", "data2_admin"],
])

0 comments on commit 76472f1

Please sign in to comment.