Skip to content

Commit

Permalink
feat: Fixed the subjectPriority sorting algorithm and support for c…
Browse files Browse the repository at this point in the history
…hecking the subject role link loop (#322)

* fix: Fixed the `subjectPriority` sorting algorithm and support for checking the subject role link loop.

* fix: Run black
  • Loading branch information
amisadmin committed Sep 23, 2023
1 parent 45bcc8b commit f964e2a
Show file tree
Hide file tree
Showing 2 changed files with 79 additions and 31 deletions.
63 changes: 32 additions & 31 deletions casbin/model/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,14 +134,20 @@ def compare_policy(policy):
name = self.get_name_with_domain(domain, policy[sub_index])
return subject_hierarchy_map.get(name, 0)

assertion.policy = sorted(assertion.policy, key=compare_policy, reverse=True)
assertion.policy = sorted(assertion.policy, key=compare_policy)
for i, policy in enumerate(assertion.policy):
assertion.policy_map[",".join(policy)] = i

def get_subject_hierarchy_map(self, policies):
subject_hierarchy_map = {}
# Tree structure of role
policy_map = {}
"""
Get the subject hierarchy from the policy.
Select the lowest level subject in multiple rounds until all subjects are selected.
Return the subject hierarchy dictionary, the subject is the key, and the level is the value.
The level starts from 0 and increases in turn. The smaller the level, the higher the priority.
"""
# Init unsorted policy, and subject
unsorted_policy = []
unsorted_sub = set()
for policy in policies:
if len(policy) < 2:
raise RuntimeError("policy g expect 2 more params")
Expand All @@ -150,33 +156,28 @@ def get_subject_hierarchy_map(self, policies):
domain = policy[2]
child = self.get_name_with_domain(domain, policy[0])
parent = self.get_name_with_domain(domain, policy[1])
if parent not in policy_map.keys():
policy_map[parent] = [child]
else:
policy_map[parent].append(child)
if child not in subject_hierarchy_map.keys():
subject_hierarchy_map[child] = 0
if parent not in subject_hierarchy_map.keys():
subject_hierarchy_map[parent] = 0
subject_hierarchy_map[child] = 1
# Use queues for levelOrder
queue = []
for k, v in subject_hierarchy_map.items():
root = k
if v != 0:
continue
lv = 0
queue.append(root)
while len(queue) != 0:
sz = len(queue)
for _ in range(sz):
node = queue.pop(0)
subject_hierarchy_map[node] = lv
if node in policy_map.keys():
for child in policy_map[node]:
queue.append(child)
lv += 1
return subject_hierarchy_map
unsorted_policy.append([child, parent])
unsorted_sub.add(child)
unsorted_sub.add(parent)
# sort policy,and update sorted_sub_list
sorted_sub_list = []
while len(unsorted_policy) > 0:
# get all parent subject
parent_sub = {p[1] for p in unsorted_policy if p[1] != ""}
# remove parent subject from unsorted_sub
sorted_sub = unsorted_sub - parent_sub
if not sorted_sub:
raise RuntimeError("cycle dependency in subject hierarchy.subjects: {}".format(unsorted_sub))
# update sorted_sub_list
sorted_sub_list.append(sorted_sub)
# remove sorted subject, and update unsorted_policy
unsorted_policy = [p for p in unsorted_policy if p[0] not in sorted_sub]
# update unsorted_sub
unsorted_sub = unsorted_sub - sorted_sub
if len(unsorted_sub) > 0:
sorted_sub_list.append(unsorted_sub)
# Tree structure of subject
return {sub: i for i, subs in enumerate(sorted_sub_list) for sub in subs}

def get_name_with_domain(self, domain, name):
return "{}{}{}".format(domain, DEFAULT_SEPARATOR, name)
Expand Down
47 changes: 47 additions & 0 deletions tests/model/test_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
from unittest import TestCase

from casbin import Model
from casbin.model.model import DEFAULT_DOMAIN


class TestModel(TestCase):
m = Model()

def check_hierarchy(self, policies: list, subject_hierarchy_map: dict):
"""check_hierarchy checks the hierarchy of the subject hierarchy map"""
for policy in policies:
if len(policy) < 2:
raise RuntimeError("policy g expect 2 more params")
domain = DEFAULT_DOMAIN
if len(policy) != 2:
domain = policy[2]
child = self.m.get_name_with_domain(domain, policy[0])
parent = self.m.get_name_with_domain(domain, policy[1])
assert subject_hierarchy_map[child] < subject_hierarchy_map[parent]

def test_get_subject_hierarchy_map(self):
# test 1
policies = [
["A1", "B1"],
["A1", "B2"],
["A2", "B3"],
]
res = self.m.get_subject_hierarchy_map(policies)
self.check_hierarchy(policies, res)
# test 2
policies = [
["A1", "B1"],
["B1", "B2"],
["B2", "B3"],
["B1", "B4"],
["A1", "B2"],
]
res = self.m.get_subject_hierarchy_map(policies)
self.check_hierarchy(policies, res)
# test 3
policies = [
["B1", "B2"],
["B2", "B3"],
["B3", "B1"],
]
self.assertRaises(RuntimeError, self.m.get_subject_hierarchy_map, policies)

0 comments on commit f964e2a

Please sign in to comment.