Skip to content

Commit

Permalink
feat: Added support for async watcher callbacks
Browse files Browse the repository at this point in the history
AsyncEnforcer can now also await Watcher callbacks if they are Coroutines.
  • Loading branch information
tanasecucliciu committed Jan 29, 2024
1 parent e409434 commit eabb1bd
Showing 1 changed file with 37 additions and 12 deletions.
49 changes: 37 additions & 12 deletions casbin/async_internal_enforcer.py
Expand Up @@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
import inspect

from casbin.core_enforcer import CoreEnforcer
from casbin.model import Model, FunctionMap
Expand Down Expand Up @@ -105,8 +106,12 @@ async def save_policy(self):
await self.adapter.save_policy(self.model)

if self.watcher:
if callable(getattr(self.watcher, "update_for_save_policy", None)):
self.watcher.update_for_save_policy(self.model)
update_for_save_policy = getattr(self.watcher, "update_for_save_policy", None)
if callable(update_for_save_policy):
if inspect.iscoroutinefunction(update_for_save_policy):
await update_for_save_policy(self.model)
else:
update_for_save_policy(self.model)
else:
self.watcher.update()

Expand All @@ -122,8 +127,12 @@ async def _add_policy(self, sec, ptype, rule):
return False

if self.watcher and self.auto_notify_watcher:
if callable(getattr(self.watcher, "update_for_add_policy", None)):
self.watcher.update_for_add_policy(sec, ptype, rule)
update_for_add_policy = getattr(self.watcher, "update_for_add_policy", None)
if callable(update_for_add_policy):
if inspect.iscoroutinefunction(update_for_add_policy):
await update_for_add_policy(sec, ptype, rule)
else:
update_for_add_policy(sec, ptype, rule)
else:
self.watcher.update()

Expand All @@ -144,8 +153,12 @@ async def _add_policies(self, sec, ptype, rules):
return False

if self.watcher and self.auto_notify_watcher:
if callable(getattr(self.watcher, "update_for_add_policies", None)):
self.watcher.update_for_add_policies(sec, ptype, rules)
update_for_add_policies = getattr(self.watcher, "update_for_add_policies", None)
if callable(update_for_add_policies):
if inspect.iscoroutinefunction(update_for_add_policies):
await update_for_add_policies(sec, ptype, rules)
else:
update_for_add_policies(sec, ptype, rules)
else:
self.watcher.update()

Expand Down Expand Up @@ -224,8 +237,12 @@ async def _remove_policy(self, sec, ptype, rule):
return False

if self.watcher and self.auto_notify_watcher:
if callable(getattr(self.watcher, "update_for_remove_policy", None)):
self.watcher.update_for_remove_policy(sec, ptype, rule)
update_for_remove_policy = getattr(self.watcher, "update_for_remove_policy", None)
if callable(update_for_remove_policy):
if inspect.iscoroutinefunction(update_for_remove_policy):
await update_for_remove_policy(sec, ptype, rule)
else:
update_for_remove_policy(sec, ptype, rule)
else:
self.watcher.update()

Expand All @@ -246,8 +263,12 @@ async def _remove_policies(self, sec, ptype, rules):
return False

if self.watcher and self.auto_notify_watcher:
if callable(getattr(self.watcher, "update_for_remove_policies", None)):
self.watcher.update_for_remove_policies(sec, ptype, rules)
update_for_remove_policies = getattr(self.watcher, "update_for_remove_policies", None)
if callable(update_for_remove_policies):
if inspect.iscoroutinefunction(update_for_remove_policies):
await update_for_remove_policies(sec, ptype, rules)
else:
update_for_remove_policies(sec, ptype, rules)
else:
self.watcher.update()

Expand All @@ -265,8 +286,12 @@ async def _remove_filtered_policy(self, sec, ptype, field_index, *field_values):
return False

if self.watcher and self.auto_notify_watcher:
if callable(getattr(self.watcher, "update_for_remove_filtered_policy", None)):
self.watcher.update_for_remove_filtered_policy(sec, ptype, field_index, *field_values)
update_for_remove_filtered_policy = getattr(self.watcher, "update_for_remove_filtered_policy", None)
if callable(update_for_remove_filtered_policy):
if inspect.iscoroutinefunction(update_for_remove_filtered_policy):
await update_for_remove_filtered_policy(sec, ptype, field_index, *field_values)
else:
update_for_remove_filtered_policy(sec, ptype, field_index, *field_values)
else:
self.watcher.update()

Expand Down

0 comments on commit eabb1bd

Please sign in to comment.