From cb4c7226abdff72789999ae65cac8746be0465c7 Mon Sep 17 00:00:00 2001 From: Albert Tugushev Date: Wed, 16 Sep 2020 20:05:10 +0700 Subject: [PATCH] Fix a bug where a command would hang due to closed connection (#26) * Fix a bug where a command would hang due to closed connection * Try to fix flaky test * Fix timeouts * Fix test --- ansq/tcp/connection.py | 15 +++++++++++++-- ansq/tcp/exceptions.py | 4 ++++ tests/test_send_commands.py | 35 ++++++++++++++++++++++++++++++++--- 3 files changed, 49 insertions(+), 5 deletions(-) diff --git a/ansq/tcp/connection.py b/ansq/tcp/connection.py index 5583e95..8b994c5 100644 --- a/ansq/tcp/connection.py +++ b/ansq/tcp/connection.py @@ -7,7 +7,12 @@ from typing import Any, AsyncGenerator, Callable, Optional, Union from ansq.tcp import consts -from ansq.tcp.exceptions import NSQUnauthorized, ProtocolError, get_exception +from ansq.tcp.exceptions import ( + ConnectionClosedError, + NSQUnauthorized, + ProtocolError, + get_exception, +) from ansq.tcp.types import ( ConnectionStatus, NSQCommands, @@ -122,6 +127,11 @@ async def _do_close( finally: pass + for future, callback in self._cmd_waiters: + if not future.cancelled(): + future.set_exception(ConnectionClosedError("Connection is closed")) + callback is not None and callback(None) + if self._message_queue.qsize() > 0: self._message_queue.get_nowait() @@ -163,7 +173,8 @@ async def execute( await self._reconnect_task assert self._reader, "You should call `connect` method first" - assert self._status or command == NSQCommands.CLS, "Connection is closed" + if not self._status and not (command == NSQCommands.CLS): + raise ConnectionClosedError("Connection is closed") future = self._loop.create_future() if command in ( diff --git a/ansq/tcp/exceptions.py b/ansq/tcp/exceptions.py index 9994a95..5c7180c 100644 --- a/ansq/tcp/exceptions.py +++ b/ansq/tcp/exceptions.py @@ -1,6 +1,10 @@ from typing import Union +class ConnectionClosedError(Exception): + pass + + class NSQException(Exception): """XXX""" diff --git a/tests/test_send_commands.py b/tests/test_send_commands.py index 74856e1..9e76014 100644 --- a/tests/test_send_commands.py +++ b/tests/test_send_commands.py @@ -1,9 +1,11 @@ -from time import time +import asyncio +from time import sleep, time import pytest from ansq import open_connection from ansq.tcp.connection import NSQConnection +from ansq.tcp.exceptions import ConnectionClosedError @pytest.mark.asyncio @@ -68,9 +70,10 @@ async def test_command_without_connection(): nsq = NSQConnection() assert nsq.status.is_init - with pytest.raises(AssertionError) as e: + with pytest.raises( + AssertionError, match="^You should call `connect` method first$", + ): await nsq.pub("test_topic", "test_message") - assert str(e.value) == "You should call `connect` method first" await nsq.close() assert nsq.status.is_init @@ -87,3 +90,29 @@ async def test_command_sub(): await nsq.close() assert nsq.is_closed + + +@pytest.mark.asyncio +async def test_command_with_closed_connection(): + nsq = await open_connection() + await nsq.close() + + with pytest.raises(ConnectionClosedError, match="^Connection is closed$"): + await nsq.pub("test_topic", "test_message") + + +@pytest.mark.asyncio +async def test_command_with_concurrently_closed_connection(): + nsq = await open_connection() + + async def close(): + await nsq.close() + + async def blocking_wait_and_pub(): + sleep(0.1) + await nsq.pub("test_topic", "test_message") + + with pytest.raises(ConnectionClosedError, match="^Connection is closed$"): + await asyncio.wait_for( + asyncio.gather(close(), blocking_wait_and_pub()), timeout=1, + )