Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
69 changes: 69 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,75 @@ async with async_session() as session:
await session.commit()
```

## Soft Deletion Support

The adapter supports soft deletion, which marks records as deleted instead of physically removing them from the database. This is useful for:

- Maintaining audit trails
- Implementing undo functionality
- Preserving historical data
- Debugging and compliance requirements

### Basic Usage with Soft Deletion

To enable soft deletion, you need to:

1. Create a custom database model with a boolean `is_deleted` column
2. Pass the soft delete attribute to the adapter

```python
import casbin_async_sqlalchemy_adapter
import casbin
from sqlalchemy import Column, Boolean, Integer, String
from sqlalchemy.ext.asyncio import create_async_engine

# Define a custom model with soft delete support
class CasbinRuleSoftDelete(casbin_async_sqlalchemy_adapter.Base):
__tablename__ = "casbin_rule"

id = Column(Integer, primary_key=True)
ptype = Column(String(255))
v0 = Column(String(255))
v1 = Column(String(255))
v2 = Column(String(255))
v3 = Column(String(255))
v4 = Column(String(255))
v5 = Column(String(255))

# Add the soft delete column
is_deleted = Column(Boolean, default=False, index=True, nullable=False)

# Create adapter with soft delete support
engine = create_async_engine('sqlite+aiosqlite:///test.db')
adapter = casbin_async_sqlalchemy_adapter.Adapter(
engine,
db_class=CasbinRuleSoftDelete,
db_class_softdelete_attribute=CasbinRuleSoftDelete.is_deleted
)

# Create the table
await adapter.create_table()

e = casbin.AsyncEnforcer('path/to/model.conf', adapter)

# When you delete a policy, it will be soft-deleted (marked as deleted)
await e.delete_permission_for_user("alice", "data1", "read")

# The record remains in the database with is_deleted=True
# Load policy will automatically filter out soft-deleted records
await e.load_policy()
```

### How Soft Deletion Works

When soft deletion is enabled:

- **Delete operations** set the `is_deleted` flag to `True` instead of removing records
- **Load operations** automatically filter out records where `is_deleted=True`
- **Save policy** marks removed rules as deleted while preserving the records
- **Update operations** only affect non-deleted records

This feature maintains full backward compatibility - when `db_class_softdelete_attribute` is not provided, the adapter functions with hard deletion as before.

### Getting Help

Expand Down
167 changes: 141 additions & 26 deletions casbin_async_sqlalchemy_adapter/adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,8 @@

from casbin import persist
from casbin.persist.adapters.asyncio import AsyncAdapter
from sqlalchemy import Column, Integer, String, delete, insert
from sqlalchemy import or_
from sqlalchemy import Column, Integer, String, Boolean, delete, insert
from sqlalchemy import or_, not_
from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession
from sqlalchemy.future import select
from sqlalchemy.orm import declarative_base, sessionmaker
Expand Down Expand Up @@ -66,6 +66,7 @@ def __init__(
self,
engine,
db_class=None,
db_class_softdelete_attribute=None,
filtered=False,
db_session: Optional[AsyncSession] = None,
):
Expand All @@ -74,9 +75,18 @@ def __init__(
else:
self._engine = engine

self.softdelete_attribute = None

if db_class is None:
db_class = CasbinRule
else:
if db_class_softdelete_attribute is not None and not isinstance(db_class_softdelete_attribute.type, Boolean):
msg = f"The type of db_class_softdelete_attribute needs to be {str(Boolean)!r}. "
msg += f"An attribute of type {str(type(db_class_softdelete_attribute.type))!r} was given."
raise ValueError(msg)
# Softdelete is only supported when using custom class
self.softdelete_attribute = db_class_softdelete_attribute

for attr in (
"id",
"ptype",
Expand Down Expand Up @@ -121,7 +131,9 @@ async def create_table(self):
async def load_policy(self, model):
"""loads all policy rules from the storage."""
async with self._session_scope() as session:
lines = await session.execute(select(self._db_class))
stmt = select(self._db_class)
stmt = self._softdelete_query(stmt)
lines = await session.execute(stmt)
for line in lines.scalars():
persist.load_policy_line(str(line), model)

Expand All @@ -132,6 +144,7 @@ async def load_filtered_policy(self, model, filter) -> None:
"""loads all policy rules from the storage."""
async with self._session_scope() as session:
stmt = select(self._db_class)
stmt = self._softdelete_query(stmt)
stmt = self.filter_query(stmt, filter)
result = await session.execute(stmt)
for line in result.scalars():
Expand All @@ -144,6 +157,12 @@ def filter_query(self, stmt, filter):
stmt = stmt.where(getattr(self._db_class, attr).in_(getattr(filter, attr)))
return stmt.order_by(self._db_class.id)

def _softdelete_query(self, stmt):
"""Filter out soft-deleted records if soft delete is enabled."""
if self.softdelete_attribute is not None:
stmt = stmt.where(not_(self.softdelete_attribute))
return stmt

async def _save_policy_line(self, ptype, rule, session=None):
if session is not None:
# Use provided session
Expand All @@ -161,15 +180,62 @@ async def _save_policy_line(self, ptype, rule, session=None):

async def save_policy(self, model):
"""saves all policy rules to the storage."""
# Use the default strategy when soft delete is not enabled
if self.softdelete_attribute is None:
async with self._session_scope() as session:
stmt = delete(self._db_class)
await session.execute(stmt)
for sec in ["p", "g"]:
if sec not in model.model.keys():
continue
for ptype, ast in model.model[sec].items():
for rule in ast.policy:
await self._save_policy_line(ptype, rule, session)
return True

# Custom strategy for softdelete since it does not make sense to recreate all of the
# entries when using soft delete
async with self._session_scope() as session:
stmt = delete(self._db_class)
await session.execute(stmt)
stmt = select(self._db_class)
stmt = self._softdelete_query(stmt)

# Get entries that are not part of the model anymore
result = await session.execute(stmt)
lines_before_changes = result.scalars().all()

# Create new entries in the database
for sec in ["p", "g"]:
if sec not in model.model.keys():
continue
for ptype, ast in model.model[sec].items():
for rule in ast.policy:
await self._save_policy_line(ptype, rule, session)
# Filter for rule in the database
filter_stmt = select(self._db_class).where(self._db_class.ptype == ptype)
filter_stmt = self._softdelete_query(filter_stmt)
for index, value in enumerate(rule):
v_value = getattr(self._db_class, "v{}".format(index))
filter_stmt = filter_stmt.where(v_value == value)
# If the rule is not present, create an entry in the database
result = await session.execute(filter_stmt)
if result.scalar_one_or_none() is None:
await self._save_policy_line(ptype, rule, session=session)

for line in lines_before_changes:
ptype = line.ptype
sec = ptype[0] # derived from persist.load_policy_line function
fields_with_None = [
line.v0,
line.v1,
line.v2,
line.v3,
line.v4,
line.v5,
]
rule = [element for element in fields_with_None if element is not None]
# If the rule is not part of the model, set the deletion flag to True
if not model.has_policy(sec, ptype, rule):
setattr(line, self.softdelete_attribute.name, True)

return True

async def add_policy(self, sec, ptype, rule):
Expand All @@ -196,42 +262,75 @@ async def add_policies(self, sec, ptype, rules):
async def remove_policy(self, sec, ptype, rule):
"""removes a policy rule from the storage."""
async with self._session_scope() as session:
stmt = delete(self._db_class).where(self._db_class.ptype == ptype)
for i, v in enumerate(rule):
stmt = stmt.where(getattr(self._db_class, "v{}".format(i)) == v)
r = await session.execute(stmt)

return True if r.rowcount > 0 else False
if self.softdelete_attribute is None:
stmt = delete(self._db_class).where(self._db_class.ptype == ptype)
for i, v in enumerate(rule):
stmt = stmt.where(getattr(self._db_class, "v{}".format(i)) == v)
r = await session.execute(stmt)
return True if r.rowcount > 0 else False
else:
stmt = select(self._db_class).where(self._db_class.ptype == ptype)
stmt = self._softdelete_query(stmt)
for i, v in enumerate(rule):
stmt = stmt.where(getattr(self._db_class, "v{}".format(i)) == v)
result = await session.execute(stmt)
lines = result.scalars().all()
for line in lines:
setattr(line, self.softdelete_attribute.name, True)
return True if len(lines) > 0 else False

async def remove_policies(self, sec, ptype, rules):
"""remove policy rules from the storage."""
if not rules:
return
async with self._session_scope() as session:
stmt = delete(self._db_class).where(self._db_class.ptype == ptype)
rules = zip(*rules)
for i, rule in enumerate(rules):
stmt = stmt.where(or_(getattr(self._db_class, "v{}".format(i)) == v for v in rule))
await session.execute(stmt)
if self.softdelete_attribute is None:
stmt = delete(self._db_class).where(self._db_class.ptype == ptype)
rules_zipped = zip(*rules)
for i, rule in enumerate(rules_zipped):
stmt = stmt.where(or_(getattr(self._db_class, "v{}".format(i)) == v for v in rule))
await session.execute(stmt)
else:
stmt = select(self._db_class).where(self._db_class.ptype == ptype)
stmt = self._softdelete_query(stmt)
rules_zipped = zip(*rules)
for i, rule in enumerate(rules_zipped):
stmt = stmt.where(or_(getattr(self._db_class, "v{}".format(i)) == v for v in rule))
result = await session.execute(stmt)
lines = result.scalars().all()
for line in lines:
setattr(line, self.softdelete_attribute.name, True)

async def remove_filtered_policy(self, sec, ptype, field_index, *field_values):
"""removes policy rules that match the filter from the storage.
This is part of the Auto-Save feature.
"""
async with self._session_scope() as session:
stmt = delete(self._db_class).where(self._db_class.ptype == ptype)

if not (0 <= field_index <= 5):
return False
if not (1 <= field_index + len(field_values) <= 6):
return False
for i, v in enumerate(field_values):
if v != "":
v_value = getattr(self._db_class, "v{}".format(field_index + i))
stmt = stmt.where(v_value == v)
r = await session.execute(stmt)

return True if r.rowcount > 0 else False
if self.softdelete_attribute is None:
stmt = delete(self._db_class).where(self._db_class.ptype == ptype)
for i, v in enumerate(field_values):
if v != "":
v_value = getattr(self._db_class, "v{}".format(field_index + i))
stmt = stmt.where(v_value == v)
r = await session.execute(stmt)
return True if r.rowcount > 0 else False
else:
stmt = select(self._db_class).where(self._db_class.ptype == ptype)
stmt = self._softdelete_query(stmt)
for i, v in enumerate(field_values):
if v != "":
v_value = getattr(self._db_class, "v{}".format(field_index + i))
stmt = stmt.where(v_value == v)
result = await session.execute(stmt)
lines = result.scalars().all()
for line in lines:
setattr(line, self.softdelete_attribute.name, True)
return True if len(lines) > 0 else False

async def update_policy(self, sec: str, ptype: str, old_rule: List[str], new_rule: List[str]) -> None:
"""
Expand All @@ -247,6 +346,7 @@ async def update_policy(self, sec: str, ptype: str, old_rule: List[str], new_rul

async with self._session_scope() as session:
stmt = select(self._db_class).where(self._db_class.ptype == ptype)
stmt = self._softdelete_query(stmt)

# locate the old rule
for index, value in enumerate(old_rule):
Expand Down Expand Up @@ -307,9 +407,24 @@ async def _update_filtered_policies(self, new_rules, filter) -> List[List[str]]:
# Load old policies

stmt = select(self._db_class).where(self._db_class.ptype == filter.ptype)
stmt = self._softdelete_query(stmt)
filtered_stmt = self.filter_query(stmt, filter)
result = await session.execute(filtered_stmt)
old_rules = result.scalars().all()
old_rules_db = result.scalars().all()

# Convert database objects to rule lists
old_rules = []
for line in old_rules_db:
fields_with_None = [
line.v0,
line.v1,
line.v2,
line.v3,
line.v4,
line.v5,
]
rule = [element for element in fields_with_None if element is not None]
old_rules.append(rule)

# Delete old policies

Expand Down
Loading