Skip to content

Commit

Permalink
Convert reconnect method to task.
Browse files Browse the repository at this point in the history
This allows us to cancel reconnect attempt when closing the connection and eliminates the need for boolean flag.
  • Loading branch information
denpamusic committed Mar 4, 2024
1 parent a4ca65b commit c7e25a6
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 30 deletions.
32 changes: 10 additions & 22 deletions pyplumio/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from serial import EIGHTBITS, PARITY_NONE, STOPBITS_ONE, SerialException

from pyplumio.exceptions import ConnectionFailedError
from pyplumio.helpers.task_manager import TaskManager
from pyplumio.helpers.timeout import timeout
from pyplumio.protocol import AsyncProtocol, Protocol

Expand All @@ -26,13 +27,12 @@
import serial_asyncio as pyserial_asyncio


class Connection(ABC):
class Connection(ABC, TaskManager):
"""Represents a connection.
All specific connection classes MUST be inherited from this class.
"""

_closing: bool
_protocol: Protocol
_reconnect_on_failure: bool
_kwargs: MutableMapping[str, Any]
Expand All @@ -44,13 +44,13 @@ def __init__(
**kwargs: Any,
) -> None:
"""Initialize a new connection."""
super().__init__()
if protocol is None:
protocol = AsyncProtocol()

if reconnect_on_failure:
protocol.on_connection_lost.add(self._connection_lost)
protocol.on_connection_lost.add(self._reconnect)

self._closing = False
self._reconnect_on_failure = reconnect_on_failure
self._protocol = protocol
self._kwargs = kwargs
Expand All @@ -76,44 +76,32 @@ async def _connect(self) -> None:
await self._open_connection(),
)
self.protocol.connection_established(reader, writer)
except (
OSError,
SerialException,
asyncio.TimeoutError,
) as connection_error:
raise ConnectionFailedError from connection_error
except (OSError, SerialException, asyncio.TimeoutError) as err:
raise ConnectionFailedError from err

async def _reconnect(self) -> None:
"""Try to connect and reconnect on failure."""
try:
return await self._connect()
await self._connect()
except ConnectionFailedError:
await self._connection_lost()

async def _connection_lost(self) -> None:
"""Resume connection on the connection loss."""
if not self._closing:
_LOGGER.error(
"Can't connect to the device, retrying in %.1f seconds",
RECONNECT_TIMEOUT,
)
await asyncio.sleep(RECONNECT_TIMEOUT)
await self._reconnect()
self.create_task(self._reconnect())

async def connect(self) -> None:
"""Open the connection.
Initialize a connection via connect or reconnect
routines, depending on '_reconnect_on_failure' property.
"""
if self._reconnect_on_failure:
await self._reconnect()
else:
await self._connect()
await (self._reconnect if self._reconnect_on_failure else self._connect)()

async def close(self) -> None:
"""Close the connection."""
self._closing = True
self.cancel_tasks()
await self.protocol.shutdown()

@property
Expand Down
13 changes: 5 additions & 8 deletions tests/test_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,9 +95,7 @@ def fixture_serial_connection(mock_protocol) -> pyplumio.connection.SerialConnec

async def test_tcp_connect(mock_protocol, asyncio_open_connection) -> None:
"""Test TCP connection logic."""
with patch(
"pyplumio.connection.Connection._connection_lost"
) as mock_connection_lost:
with patch("pyplumio.connection.Connection._reconnect") as mock_reconnect:
tcp_connection = pyplumio.connection.TcpConnection(
host=HOST,
port=PORT,
Expand All @@ -107,9 +105,7 @@ async def test_tcp_connect(mock_protocol, asyncio_open_connection) -> None:
)

assert tcp_connection.protocol == mock_protocol
mock_protocol.on_connection_lost.add.assert_called_once_with(
mock_connection_lost
)
mock_protocol.on_connection_lost.add.assert_called_once_with(mock_reconnect)

await tcp_connection.connect()
asyncio_open_connection.assert_called_once_with(host=HOST, port=PORT, timeout=10)
Expand Down Expand Up @@ -162,6 +158,7 @@ async def test_reconnect(
side_effect=(ConnectionFailedError, None),
) as mock_connect:
await tcp_connection.connect()
await tcp_connection.wait_until_done()

assert "Can't connect to the device" in caplog.text
assert mock_connect.call_count == 2
Expand All @@ -174,9 +171,9 @@ async def test_connection_lost(
"""Test that connection lost callback calls reconnect."""
await tcp_connection.connect()
on_connection_lost = mock_protocol.on_connection_lost.add.call_args.args[0]
with patch("pyplumio.connection.Connection._reconnect") as mock_reconnect:
with patch("pyplumio.connection.Connection._connect") as mock_connect:
await on_connection_lost()
mock_reconnect.assert_called_once()
mock_connect.assert_called_once()


async def test_reconnect_logic_selection() -> None:
Expand Down

0 comments on commit c7e25a6

Please sign in to comment.