Skip to content

Commit

Permalink
Correct edgedb.Client.close() timeout behavior
Browse files Browse the repository at this point in the history
Also added tests in test_sync_query.py, and use only sync client in
sync_* tests.
  • Loading branch information
fantix committed Oct 20, 2022
1 parent 6d0d6ab commit 33a912c
Show file tree
Hide file tree
Showing 13 changed files with 628 additions and 110 deletions.
81 changes: 46 additions & 35 deletions edgedb/_testbase.py
Original file line number Diff line number Diff line change
Expand Up @@ -352,8 +352,13 @@ def connection(self):
def is_proto_lt_1_0(self):
return self.connection._protocol.is_legacy

@property
def dbname(self):
return self._impl._working_params.database


class ConnectedTestCaseMixin:
is_client_async = True

@classmethod
def make_test_client(
Expand All @@ -362,11 +367,17 @@ def make_test_client(
database='edgedb',
user='edgedb',
password='test',
connection_class=asyncio_client.AsyncIOConnection,
connection_class=...,
):
conargs = cls.get_connect_args(
cluster=cluster, database=database, user=user, password=password)
return TestAsyncIOClient(
if connection_class is ...:
connection_class = (
asyncio_client.AsyncIOConnection
if cls.is_client_async
else blocking_client.BlockingIOConnection
)
return (TestAsyncIOClient if cls.is_client_async else TestClient)(
connection_class=connection_class,
max_concurrency=1,
**conargs,
Expand All @@ -384,6 +395,10 @@ def get_connect_args(cls, *,
database=database))
return conargs

@classmethod
def adapt_call(cls, coro):
return cls.loop.run_until_complete(coro)


class DatabaseTestCase(ClusterTestCase, ConnectedTestCaseMixin):
SETUP = None
Expand All @@ -398,15 +413,15 @@ class DatabaseTestCase(ClusterTestCase, ConnectedTestCaseMixin):

def setUp(self):
if self.SETUP_METHOD:
self.loop.run_until_complete(
self.adapt_call(
self.client.execute(self.SETUP_METHOD))

super().setUp()

def tearDown(self):
try:
if self.TEARDOWN_METHOD:
self.loop.run_until_complete(
self.adapt_call(
self.client.execute(self.TEARDOWN_METHOD))
finally:
try:
Expand All @@ -431,7 +446,7 @@ def setUpClass(cls):
if not class_set_up:
script = f'CREATE DATABASE {dbname};'
cls.admin_client = cls.make_test_client()
cls.loop.run_until_complete(cls.admin_client.execute(script))
cls.adapt_call(cls.admin_client.execute(script))

cls.client = cls.make_test_client(database=dbname)

Expand All @@ -440,11 +455,17 @@ def setUpClass(cls):
if script:
# The setup is expected to contain a CREATE MIGRATION,
# which needs to be wrapped in a transaction.
async def execute():
async for tr in cls.client.transaction():
async with tr:
await tr.execute(script)
cls.loop.run_until_complete(execute())
if cls.is_client_async:
async def execute():
async for tr in cls.client.transaction():
async with tr:
await tr.execute(script)
else:
def execute():
for tr in cls.client.transaction():
with tr:
tr.execute(script)
cls.adapt_call(execute())

@classmethod
def get_database_name(cls):
Expand Down Expand Up @@ -507,19 +528,22 @@ def tearDownClass(cls):

try:
if script:
cls.loop.run_until_complete(
cls.adapt_call(
cls.client.execute(script))
finally:
try:
cls.loop.run_until_complete(cls.client.aclose())
if cls.is_client_async:
cls.adapt_call(cls.client.aclose())
else:
cls.client.close()

dbname = cls.get_database_name()
script = f'DROP DATABASE {dbname};'

retry = cls.TEARDOWN_RETRY_DROP_DB
for i in range(retry):
try:
cls.loop.run_until_complete(
cls.adapt_call(
cls.admin_client.execute(script))
except edgedb.errors.ExecutionError:
if i < retry - 1:
Expand All @@ -536,8 +560,11 @@ def tearDownClass(cls):
finally:
try:
if cls.admin_client is not None:
cls.loop.run_until_complete(
cls.admin_client.aclose())
if cls.is_client_async:
cls.adapt_call(
cls.admin_client.aclose())
else:
cls.admin_client.close()
finally:
super().tearDownClass()

Expand All @@ -549,27 +576,11 @@ class AsyncQueryTestCase(DatabaseTestCase):
class SyncQueryTestCase(DatabaseTestCase):
BASE_TEST_CLASS = True
TEARDOWN_RETRY_DROP_DB = 5
is_client_async = False

def setUp(self):
super().setUp()

cls = type(self)
cls.async_client = cls.client

conargs = cls.get_connect_args().copy()
conargs.update(dict(database=cls.async_client.dbname))

cls.client = TestClient(
connection_class=blocking_client.BlockingIOConnection,
max_concurrency=1,
**conargs
)

def tearDown(self):
cls = type(self)
cls.client.close()
cls.client = cls.async_client
del cls.async_client
@classmethod
def adapt_call(cls, result):
return result


_lock_cnt = 0
Expand Down
19 changes: 15 additions & 4 deletions edgedb/asyncio_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,6 @@

class AsyncIOConnection(base_client.BaseConnection):
__slots__ = ("_loop",)
_close_exceptions = (Exception, asyncio.CancelledError)

def __init__(self, loop, *args, **kwargs):
super().__init__(*args, **kwargs)
Expand All @@ -61,6 +60,18 @@ async def connect_addr(self, addr, timeout):
async def sleep(self, seconds):
await asyncio.sleep(seconds)

async def aclose(self):
"""Send graceful termination message wait for connection to drop."""
if not self.is_closed():
try:
self._protocol.terminate()
await self._protocol.wait_for_disconnect()
except (Exception, asyncio.CancelledError):
self.terminate()
raise
finally:
self._cleanup()

def _protocol_factory(self):
return asyncio_proto.AsyncIOProtocol(self._params, self._loop)

Expand Down Expand Up @@ -104,7 +115,7 @@ async def _connect_addr(self, addr):
if tr is not None:
tr.close()
raise con_utils.wrap_error(e) from e
except Exception:
except BaseException:
if tr is not None:
tr.close()
raise
Expand All @@ -125,9 +136,9 @@ async def close(self, *, wait=True):
if self._con is None:
return
if wait:
await self._con.close()
await self._con.aclose()
else:
self._pool._loop.create_task(self._con.close())
self._pool._loop.create_task(self._con.aclose())

async def wait_until_released(self, timeout=None):
await self._release_event.wait()
Expand Down
13 changes: 0 additions & 13 deletions edgedb/base_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,6 @@ class BaseConnection(metaclass=abc.ABCMeta):
_log_listeners: typing.Set[
typing.Callable[[BaseConnection_T, errors.EdgeDBMessage], None]
]
_close_exceptions = (Exception,)
__slots__ = (
"__weakref__",
"_protocol",
Expand Down Expand Up @@ -313,18 +312,6 @@ def terminate(self):
finally:
self._cleanup()

async def close(self):
"""Send graceful termination message wait for connection to drop."""
if not self.is_closed():
try:
self._protocol.terminate()
await self._protocol.wait_for_disconnect()
except self._close_exceptions:
self.terminate()
raise
finally:
self._cleanup()

def __repr__(self):
if self.is_closed():
return '<{classname} [closed] {id:#x}>'.format(
Expand Down
51 changes: 39 additions & 12 deletions edgedb/blocking_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,23 @@ def is_closed(self):
return not (proto and proto.sock is not None and
proto.sock.fileno() >= 0 and proto.connected)

async def close(self, timeout=None):
"""Send graceful termination message wait for connection to drop."""
if not self.is_closed():
try:
self._protocol.terminate()
if timeout is None:
await self._protocol.wait_for_disconnect()
else:
await self._protocol.wait_for(
self._protocol.wait_for_disconnect(), timeout
)
except Exception:
self.terminate()
raise
finally:
self._cleanup()

def _dispatch_log_message(self, msg):
for cb in self._log_listeners:
cb(self, msg)
Expand All @@ -119,13 +136,13 @@ class _PoolConnectionHolder(base_client.PoolConnectionHolder):
__slots__ = ()
_event_class = threading.Event

async def close(self, *, wait=True):
async def close(self, *, wait=True, timeout=None):
if self._con is None:
return
await self._con.close()
await self._con.close(timeout=timeout)

async def wait_until_released(self, timeout=None):
self._release_event.wait(timeout)
return self._release_event.wait(timeout)


class _PoolImpl(base_client.BasePoolImpl):
Expand Down Expand Up @@ -200,17 +217,27 @@ async def close(self, timeout=None):
if timeout is None:
for ch in self._holders:
await ch.wait_until_released()
for ch in self._holders:
await ch.close()
else:
remaining = timeout
deadline = time.monotonic() + timeout
for ch in self._holders:
secs = deadline - time.monotonic()
if secs <= 0:
raise TimeoutError
if not await ch.wait_until_released(secs):
raise TimeoutError
for ch in self._holders:
start = time.monotonic()
await ch.wait_until_released(remaining)
remaining -= time.monotonic() - start
if remaining <= 0:
self.terminate()
return
for ch in self._holders:
await ch.close()
secs = deadline - time.monotonic()
if secs <= 0:
raise TimeoutError
await ch.close(timeout=secs)
except TimeoutError as e:
self.terminate()
raise errors.InterfaceError(
"client is not fully closed in {} seconds; "
"terminating now.".format(timeout)
) from e
except Exception:
self.terminate()
raise
Expand Down
1 change: 1 addition & 0 deletions edgedb/protocol/asyncio_proto.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import asyncio

from edgedb import errors
from edgedb import compat
from edgedb.pgproto.pgproto cimport (
WriteBuffer,
ReadBuffer,
Expand Down
1 change: 1 addition & 0 deletions edgedb/protocol/blocking_proto.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -26,5 +26,6 @@ cdef class BlockingIOProtocol(protocol.SansIOProtocolBackwardsCompatible):

cdef:
readonly object sock
float deadline

cdef _disconnect(self)
Loading

0 comments on commit 33a912c

Please sign in to comment.