Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

prevent event loop polling from stopping on active redis connections #1734

Open
wants to merge 8 commits into
base: main
Choose a base branch
from
7 changes: 7 additions & 0 deletions kombu/transport/redis.py
Original file line number Diff line number Diff line change
Expand Up @@ -1310,6 +1310,9 @@ def __init__(self, *args, **kwargs):
# All channels share the same poller.
self.cycle = MultiChannelPoller()

# tracks event loop (on_tick) entries by connection
self.on_poll_start_by_connection = {}
pawl marked this conversation as resolved.
Show resolved Hide resolved

def driver_version(self):
return redis.__version__

Expand All @@ -1328,6 +1331,9 @@ def _on_disconnect(connection):
if cycle.fds:
# stop polling in the event loop
try:
on_poll_start = self.on_poll_start_by_connection.pop(
connection
)
loop.on_tick.remove(on_poll_start)
except KeyError:
pass
Expand All @@ -1337,6 +1343,7 @@ def on_poll_start():
cycle_poll_start()
[add_reader(fd, on_readable, fd) for fd in cycle.fds]
loop.on_tick.add(on_poll_start)
self.on_poll_start_by_connection[connection] = on_poll_start
loop.call_repeatedly(10, cycle.maybe_restore_messages)
health_check_interval = connection.client.transport_options.get(
'health_check_interval',
Expand Down
32 changes: 32 additions & 0 deletions t/integration/test_redis.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import redis

import kombu
from kombu.asynchronous import Hub
from kombu.transport.redis import Transport

from .common import (BaseExchangeTypes, BaseMessage, BasePriority,
Expand Down Expand Up @@ -222,3 +223,34 @@ def connect_timeout(self):
# note the host/port here is irrelevant because
# connect will raise a socket.timeout
kombu.Connection('redis://localhost:12345').connect()


@pytest.mark.env('redis')
def test_RedisEventLoopCleanup(connection):
"""Disconnection removes the respective function in event loop."""
hub = Hub()

# register 1st connection with event loop
conn_1 = connection
chan_1 = conn_1.channel()
chan_1.basic_consume('test', False, None, 1)
conn_1.register_with_event_loop(hub)
assert len(hub.on_tick) == 1
first_on_tick = list(hub.on_tick)[0]

# register 2nd connection with event loop
conn_2 = get_connection(
hostname=os.environ.get('REDIS_HOST', 'localhost'),
port=os.environ.get('REDIS_6379_TCP', '6379'),
vhost=None,
)
conn_2.register_with_event_loop(hub)
assert len(hub.on_tick) == 2

hub.run_once()

# disconnect channel connection, expected event loop entry removed
chan_1.close()

assert len(hub.on_tick) == 1
assert list(hub.on_tick)[0] is not first_on_tick
29 changes: 28 additions & 1 deletion t/unit/transport/test_redis.py
Original file line number Diff line number Diff line change
Expand Up @@ -1097,13 +1097,40 @@ def test_register_with_event_loop__on_disconnect__loop_cleanup(self, fds):
loop.on_tick = set()
redis.Transport.register_with_event_loop(transport, conn, loop)
assert len(loop.on_tick) == 1
transport.cycle._on_connection_disconnect(self.connection)
transport.cycle._on_connection_disconnect(conn)
if fds:
assert len(loop.on_tick) == 0
else:
# on_tick shouldn't be cleared when polling hasn't started
assert len(loop.on_tick) == 1

def test_register_with_event_loop__on_disconnect__per_connection(self):
"""Disconnection removes the respective function in event loop."""
connection = Connection(transport=Transport)
transport = connection.transport
transport.cycle = Mock(name='cycle')
transport.cycle.fds = {12: 'LISTEN', 13: 'BRPOP'}

# create a first connection and register with event loop
conn = Mock(name='conn')
conn.client = Mock(name='client', transport_options={})
loop = Mock(name='loop')
loop.on_tick = set()
redis.Transport.register_with_event_loop(transport, conn, loop)
assert len(loop.on_tick) == 1
first_on_tick = list(loop.on_tick)[0]

# create a second connection and register with event loop
conn_2 = Mock(name='conn_2')
conn_2.client = Mock(name='client', transport_options={})
redis.Transport.register_with_event_loop(transport, conn_2, loop)
assert len(loop.on_tick) == 2

# disconnect the first connection, expected event loop entry removed
transport.cycle._on_connection_disconnect(conn)
assert len(loop.on_tick) == 1
assert list(loop.on_tick)[0] is not first_on_tick

def test_configurable_health_check(self):
transport = self.connection.transport
transport.cycle = Mock(name='cycle')
Expand Down
Loading