Skip to content

Commit

Permalink
Prohibit concurrent operations on the same transaction object (#430)
Browse files Browse the repository at this point in the history
Co-authored-by: Elvis Pranskevichus <elvis@edgedb.com>
  • Loading branch information
fantix and elprans committed May 26, 2023
1 parent 2de7e3f commit f1fa612
Show file tree
Hide file tree
Showing 4 changed files with 86 additions and 13 deletions.
30 changes: 27 additions & 3 deletions edgedb/asyncio_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@


import asyncio
import contextlib
import logging
import socket
import ssl
Expand Down Expand Up @@ -273,11 +274,12 @@ def _warn_on_long_close(self):

class AsyncIOIteration(transaction.BaseTransaction, abstract.AsyncIOExecutor):

__slots__ = ("_managed",)
__slots__ = ("_managed", "_locked")

def __init__(self, retry, client, iteration):
super().__init__(retry, client, iteration)
self._managed = False
self._locked = False

async def __aenter__(self):
if self._managed:
Expand All @@ -287,8 +289,9 @@ async def __aenter__(self):
return self

async def __aexit__(self, extype, ex, tb):
self._managed = False
return await self._exit(extype, ex)
with self._exclusive():
self._managed = False
return await self._exit(extype, ex)

async def _ensure_transaction(self):
if not self._managed:
Expand All @@ -298,6 +301,27 @@ async def _ensure_transaction(self):
)
await super()._ensure_transaction()

async def _query(self, query_context: abstract.QueryContext):
with self._exclusive():
return await super()._query(query_context)

async def _execute(self, execute_context: abstract.ExecuteContext) -> None:
with self._exclusive():
await super()._execute(execute_context)

@contextlib.contextmanager
def _exclusive(self):
if self._locked:
raise errors.InterfaceError(
"concurrent queries within the same transaction "
"are not allowed"
)
self._locked = True
try:
yield
finally:
self._locked = False


class AsyncIORetry(transaction.BaseRetry):

Expand Down
38 changes: 28 additions & 10 deletions edgedb/blocking_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
#


import contextlib
import datetime
import queue
import socket
Expand Down Expand Up @@ -271,22 +272,25 @@ async def close(self, timeout=None):

class Iteration(transaction.BaseTransaction, abstract.Executor):

__slots__ = ("_managed",)
__slots__ = ("_managed", "_lock")

def __init__(self, retry, client, iteration):
super().__init__(retry, client, iteration)
self._managed = False
self._lock = threading.Lock()

def __enter__(self):
if self._managed:
raise errors.InterfaceError(
'cannot enter context: already in a `with` block')
self._managed = True
return self
with self._exclusive():
if self._managed:
raise errors.InterfaceError(
'cannot enter context: already in a `with` block')
self._managed = True
return self

def __exit__(self, extype, ex, tb):
self._managed = False
return self._client._iter_coroutine(self._exit(extype, ex))
with self._exclusive():
self._managed = False
return self._client._iter_coroutine(self._exit(extype, ex))

async def _ensure_transaction(self):
if not self._managed:
Expand All @@ -297,10 +301,24 @@ async def _ensure_transaction(self):
await super()._ensure_transaction()

def _query(self, query_context: abstract.QueryContext):
return self._client._iter_coroutine(super()._query(query_context))
with self._exclusive():
return self._client._iter_coroutine(super()._query(query_context))

def _execute(self, execute_context: abstract.ExecuteContext) -> None:
self._client._iter_coroutine(super()._execute(execute_context))
with self._exclusive():
self._client._iter_coroutine(super()._execute(execute_context))

@contextlib.contextmanager
def _exclusive(self):
if not self._lock.acquire(blocking=False):
raise errors.InterfaceError(
"concurrent queries within the same transaction "
"are not allowed"
)
try:
yield
finally:
self._lock.release()


class Retry(transaction.BaseRetry):
Expand Down
15 changes: 15 additions & 0 deletions tests/test_async_tx.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
# limitations under the License.
#

import asyncio
import itertools

import edgedb
Expand Down Expand Up @@ -89,3 +90,17 @@ async def test_async_transaction_commit_failure(self):
async with tx:
await tx.execute("start migration to {};")
self.assertEqual(await self.client.query_single("select 42"), 42)

async def test_async_transaction_exclusive(self):
async for tx in self.client.transaction():
async with tx:
query = "select sys::_sleep(0.01)"
f1 = self.loop.create_task(tx.execute(query))
f2 = self.loop.create_task(tx.execute(query))
with self.assertRaisesRegex(
edgedb.InterfaceError,
"concurrent queries within the same transaction "
"are not allowed"
):
await asyncio.wait_for(f1, timeout=5)
await asyncio.wait_for(f2, timeout=5)
16 changes: 16 additions & 0 deletions tests/test_sync_tx.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
#

import itertools
from concurrent.futures import ThreadPoolExecutor

import edgedb

Expand Down Expand Up @@ -97,3 +98,18 @@ def test_sync_transaction_commit_failure(self):
with tx:
tx.execute("start migration to {};")
self.assertEqual(self.client.query_single("select 42"), 42)

def test_sync_transaction_exclusive(self):
for tx in self.client.transaction():
with tx:
query = "select sys::_sleep(0.01)"
with ThreadPoolExecutor(max_workers=2) as executor:
f1 = executor.submit(tx.execute, query)
f2 = executor.submit(tx.execute, query)
with self.assertRaisesRegex(
edgedb.InterfaceError,
"concurrent queries within the same transaction "
"are not allowed"
):
f1.result(timeout=5)
f2.result(timeout=5)

0 comments on commit f1fa612

Please sign in to comment.