Skip to content

Commit

Permalink
Keep connection alive (#18)
Browse files Browse the repository at this point in the history
  • Loading branch information
nickw444 committed Dec 24, 2018
1 parent 91071b5 commit c06d369
Show file tree
Hide file tree
Showing 4 changed files with 52 additions and 14 deletions.
2 changes: 1 addition & 1 deletion nessclient/cli/events.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,5 +30,5 @@ def on_event_received(event: BaseEvent):
client.update(),
))

client.close()
loop.run_until_complete(client.close())
loop.close()
2 changes: 1 addition & 1 deletion nessclient/cli/send_command.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,5 +14,5 @@ def send_command(host: str, port: int, command: str):
client = Client(host=host, port=port, loop=loop)

loop.run_until_complete(client.send_command(command))
client.close()
loop.run_until_complete(client.close())
loop.close()
47 changes: 42 additions & 5 deletions nessclient/client.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import asyncio
import datetime
import logging
from asyncio import sleep
from typing import Optional, Callable
Expand All @@ -18,7 +19,8 @@ def __init__(self,
connection: Optional[Connection] = None,
host: Optional[str] = None,
port: Optional[int] = None,
loop: Optional[asyncio.AbstractEventLoop] = None):
loop: Optional[asyncio.AbstractEventLoop] = None,
update_interval: int = 60):
if connection is None:
assert host is not None
assert port is not None
Expand All @@ -31,6 +33,8 @@ def __init__(self,
self._closed = False
self._backoff = Backoff()
self._connect_lock = asyncio.Lock()
self._last_recv: Optional[datetime.datetime] = None
self._update_interval = update_interval

async def arm_away(self, code: Optional[str] = None) -> None:
command = 'A{}E'.format(code if code else '')
Expand Down Expand Up @@ -65,13 +69,20 @@ async def update(self) -> None:

async def _connect(self) -> None:
async with self._connect_lock:
if self._should_reconnect():
_LOGGER.debug('Closing stale connection and reconnecting')
await self._connection.close()

while not self._connection.connected:
_LOGGER.debug('Attempting to connect')
try:
await self._connection.connect()
except (ConnectionRefusedError, OSError) as e:
_LOGGER.warning('Failed to connect: %s', e)
await sleep(self._backoff.duration())

self._last_recv = datetime.datetime.now()

self._backoff.reset()

async def send_command(self, command: str) -> None:
Expand All @@ -85,16 +96,23 @@ async def send_command(self, command: str) -> None:
await self._connect()
return await self._connection.write(packet.encode().encode('ascii'))

async def keepalive(self) -> None:
async def _recv_loop(self) -> None:
while not self._closed:
await self._connect()

while True:
data = await self._connection.read()
if data is None:
_LOGGER.debug("Received None data from connection.read()")
break

decoded_data = data.decode('utf-8').strip()
self._last_recv = datetime.datetime.now()
try:
decoded_data = data.decode('utf-8').strip()
except UnicodeDecodeError:
_LOGGER.warning("Failed to decode data", exc_info=True)
continue

_LOGGER.debug("Decoding data: '%s'", decoded_data)
if len(decoded_data) > 0:
pkt = Packet.decode(decoded_data)
Expand All @@ -104,9 +122,28 @@ async def keepalive(self) -> None:

self.alarm.handle_event(event)

def close(self) -> None:
def _should_reconnect(self) -> bool:
now = datetime.datetime.now()
return self._last_recv is not None and self._last_recv < now - datetime.timedelta(
seconds=self._update_interval + 30)

async def _update_loop(self) -> None:
"""Schedule a state update to keep the connection alive"""
await asyncio.sleep(self._update_interval)
while not self._closed:
_LOGGER.debug("Forcing a keepalive state update")
await self.update()
await asyncio.sleep(self._update_interval)

async def keepalive(self) -> None:
await asyncio.gather(
self._recv_loop(),
self._update_loop(),
)

async def close(self) -> None:
self._closed = True
self._connection.close()
await self._connection.close()

def on_state_change(self, f: Callable[[ArmingState], None]
) -> Callable[[ArmingState], None]:
Expand Down
15 changes: 8 additions & 7 deletions nessclient/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from abc import ABC, abstractmethod
from typing import Optional

LOGGER = logging.getLogger(__name__)
_LOGGER = logging.getLogger(__name__)


class Connection(ABC):
Expand All @@ -18,7 +18,7 @@ async def write(self, data: bytes) -> None:
raise NotImplementedError()

@abstractmethod
def close(self) -> None:
async def close(self) -> None:
raise NotImplementedError()

@abstractmethod
Expand Down Expand Up @@ -52,7 +52,7 @@ async def connect(self) -> bool:
self._reader, self._writer = await asyncio.open_connection(
host=self._host,
port=self._port,
loop=self._loop
loop=self._loop,
)
return True

Expand All @@ -61,16 +61,16 @@ async def read(self) -> Optional[bytes]:

try:
data = await self._reader.readuntil(b'\n')
except asyncio.IncompleteReadError as e:
LOGGER.warning(
except (asyncio.IncompleteReadError, TimeoutError) as e:
_LOGGER.warning(
"Got exception: %s. Most likely the other side has "
"disconnected!", e)
self._writer = None
self._reader = None
return None

if data is None:
LOGGER.warning("Empty response received")
_LOGGER.warning("Empty response received")
self._writer = None
self._reader = None
return None
Expand All @@ -83,8 +83,9 @@ async def write(self, data: bytes) -> None:
self._writer.write(data)
await self._writer.drain()

def close(self) -> None:
async def close(self) -> None:
if self.connected and self._writer is not None:
self._writer.close()
await self._writer.wait_closed() # type: ignore
self._writer = None
self._reader = None

0 comments on commit c06d369

Please sign in to comment.