diff --git a/api/pubsub.py b/api/pubsub.py index ed9901b9..9f5bf89d 100644 --- a/api/pubsub.py +++ b/api/pubsub.py @@ -5,6 +5,7 @@ """Pub/Sub implementation""" +import logging import asyncio import json @@ -14,6 +15,8 @@ from .models import Subscription, SubscriptionStats from .config import PubSubSettings +logger = logging.getLogger(__name__) + class PubSub: """Pub/Sub implementation class @@ -39,7 +42,9 @@ def __init__(self, host=None, db_number=None): host = self._settings.redis_host if db_number is None: db_number = self._settings.redis_db_number - self._redis = aioredis.from_url(f'redis://{host}/{db_number}') + self._redis = aioredis.from_url( + f'redis://{host}/{db_number}', health_check_interval=30 + ) # self._subscriptions is a dict that matches a subscription id # (key) with a Subscription object ('sub') and a redis # PubSub object ('redis_sub'). For instance: @@ -135,9 +140,24 @@ async def listen(self, sub_id, user=None): f"not owned by {user}") while True: self._subscriptions[sub_id]['last_poll'] = datetime.utcnow() - msg = await sub['redis_sub'].get_message( - ignore_subscribe_messages=True, timeout=1.0 - ) + msg = None + try: + msg = await sub['redis_sub'].get_message( + ignore_subscribe_messages=True, timeout=1.0 + ) + except aioredis.ConnectionError: + async with self._lock: + channel = self._subscriptions[sub_id]['sub'].channel + new_redis_sub = self._redis.pubsub() + await new_redis_sub.subscribe(channel) + self._subscriptions[sub_id]['redis_sub'] = new_redis_sub + sub['redis_sub'] = new_redis_sub + continue + except aioredis.RedisError as exc: + # log the error and continue + logger.error("Redis error occurred: %s", exc) + return None # Handle any exceptions gracefully + if msg is None: continue msg_data = json.loads(msg['data'])