Skip to content

Commit be79ffe

Browse files
author
xxx
committed
add_policies bulk insert
1 parent 44738cb commit be79ffe

File tree

2 files changed

+40
-9
lines changed

2 files changed

+40
-9
lines changed

casbin_async_sqlalchemy_adapter/adapter.py

Lines changed: 15 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717

1818
from casbin import persist
1919
from casbin.persist.adapters.asyncio import AsyncAdapter
20-
from sqlalchemy import Column, Integer, String, delete
20+
from sqlalchemy import Column, Integer, String, delete, insert
2121
from sqlalchemy import or_
2222
from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession
2323
from sqlalchemy.future import select
@@ -183,14 +183,20 @@ async def add_policy(self, sec, ptype, rule):
183183

184184
async def add_policies(self, sec, ptype, rules):
185185
"""adds a policy rules to the storage."""
186-
if self._external_session is not None:
187-
# Use external session to add all rules in the same transaction
188-
for rule in rules:
189-
await self._save_policy_line(ptype, rule, self._external_session)
190-
else:
191-
# Use individual sessions for each rule (original behavior)
192-
for rule in rules:
193-
await self._save_policy_line(ptype, rule)
186+
if not rules:
187+
return
188+
189+
# Build rows for executemany bulk insert
190+
rows = []
191+
for rule in rules:
192+
row = {"ptype": ptype}
193+
for i, v in enumerate(rule):
194+
row[f"v{i}"] = v
195+
rows.append(row)
196+
197+
async with self._session_scope() as session:
198+
stmt = insert(self._db_class)
199+
await session.execute(stmt, rows)
194200

195201
async def remove_policy(self, sec, ptype, rule):
196202
"""removes a policy rule from the storage."""

tests/test_adapter.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -393,5 +393,30 @@ async def test_update_filtered_policies(self):
393393
self.assertTrue(e.enforce("bob", "data2", "read"))
394394

395395

396+
class TestBulkInsert(IsolatedAsyncioTestCase):
397+
async def test_add_policies_bulk_internal_session(self):
398+
engine = create_async_engine("sqlite+aiosqlite://", future=True)
399+
adapter = Adapter(engine)
400+
await adapter.create_table()
401+
402+
rules = [
403+
("u1", "obj1", "read"),
404+
("u2", "obj2", "write"),
405+
("u3", "obj3", "read"),
406+
]
407+
await adapter.add_policies("p", "p", rules)
408+
409+
async_session = async_sessionmaker(engine, expire_on_commit=False, class_=AsyncSession)
410+
async with async_session() as s:
411+
# count inserted rows
412+
from sqlalchemy import select, func
413+
cnt = await s.execute(select(func.count()).select_from(CasbinRule).where(CasbinRule.ptype == "p"))
414+
assert cnt.scalar_one() == len(rules)
415+
416+
rows = (await s.execute(select(CasbinRule).order_by(CasbinRule.id))).scalars().all()
417+
tuples = [(r.v0, r.v1, r.v2) for r in rows]
418+
for r in rules:
419+
assert r in tuples
420+
396421
if __name__ == "__main__":
397422
unittest.main()

0 commit comments

Comments
 (0)