diff --git a/channels_redis/core.py b/channels_redis/core.py index 1111fc2..a164059 100644 --- a/channels_redis/core.py +++ b/channels_redis/core.py @@ -15,7 +15,13 @@ from channels.exceptions import ChannelFull from channels.layers import BaseChannelLayer -from .utils import _consistent_hash, _wrap_close, create_pool, decode_hosts +from .utils import ( + _close_redis, + _consistent_hash, + _wrap_close, + create_pool, + decode_hosts, +) logger = logging.getLogger(__name__) @@ -86,7 +92,7 @@ async def flush(self): async with self._lock: for index in list(self._connections): connection = self._connections.pop(index) - await connection.close(close_connection_pool=True) + await _close_redis(connection) class RedisChannelLayer(BaseChannelLayer): diff --git a/channels_redis/pubsub.py b/channels_redis/pubsub.py index 78db68e..21771ab 100644 --- a/channels_redis/pubsub.py +++ b/channels_redis/pubsub.py @@ -6,7 +6,13 @@ import msgpack from redis import asyncio as aioredis -from .utils import _consistent_hash, _wrap_close, create_pool, decode_hosts +from .utils import ( + _close_redis, + _consistent_hash, + _wrap_close, + create_pool, + decode_hosts, +) logger = logging.getLogger(__name__) @@ -285,7 +291,7 @@ async def flush(self): # The pool was created just for this client, so make sure it is closed, # otherwise it will schedule the connection to be closed inside the # __del__ method, which doesn't have a loop running anymore. - await self._redis.close(close_connection_pool=True) + await _close_redis(self._redis) self._redis = None self._pubsub = None self._subscribed_to = set() diff --git a/channels_redis/utils.py b/channels_redis/utils.py index 98e06ca..6f15050 100644 --- a/channels_redis/utils.py +++ b/channels_redis/utils.py @@ -35,6 +35,16 @@ def _wrapper(self, *args, **kwargs): loop.close = types.MethodType(_wrapper, loop) +async def _close_redis(connection): + """ + Handle compatibility with redis-py 4.x and 5.x close methods + """ + try: + await connection.aclose(close_connection_pool=True) + except AttributeError: + await connection.close(close_connection_pool=True) + + def decode_hosts(hosts): """ Takes the value of the "hosts" argument and returns diff --git a/tests/test_pubsub.py b/tests/test_pubsub.py index 78ad080..3c00dd6 100644 --- a/tests/test_pubsub.py +++ b/tests/test_pubsub.py @@ -8,6 +8,7 @@ from asgiref.sync import async_to_sync from channels_redis.pubsub import RedisPubSubChannelLayer +from channels_redis.utils import _close_redis TEST_HOSTS = ["redis://localhost:6379"] @@ -239,10 +240,10 @@ async def test_auto_reconnect(channel_layer): channel_name3 = await channel_layer.new_channel(prefix="test-gr-chan-3") await channel_layer.group_add("test-group", channel_name1) await channel_layer.group_add("test-group", channel_name2) - await channel_layer._shards[0]._redis.close(close_connection_pool=True) + await _close_redis(channel_layer._shards[0]._redis) await channel_layer.group_add("test-group", channel_name3) await channel_layer.group_discard("test-group", channel_name2) - await channel_layer._shards[0]._redis.close(close_connection_pool=True) + await _close_redis(channel_layer._shards[0]._redis) await asyncio.sleep(1) await channel_layer.group_send("test-group", {"type": "message.1"}) # Make sure we get the message on the two channels that were in diff --git a/tests/test_pubsub_sentinel.py b/tests/test_pubsub_sentinel.py index 049e39b..41b7de2 100644 --- a/tests/test_pubsub_sentinel.py +++ b/tests/test_pubsub_sentinel.py @@ -6,6 +6,7 @@ from asgiref.sync import async_to_sync from channels_redis.pubsub import RedisPubSubChannelLayer +from channels_redis.utils import _close_redis SENTINEL_MASTER = "sentinel" SENTINEL_KWARGS = {"password": "channels_redis"} @@ -188,10 +189,10 @@ async def test_auto_reconnect(channel_layer): channel_name3 = await channel_layer.new_channel(prefix="test-gr-chan-3") await channel_layer.group_add("test-group", channel_name1) await channel_layer.group_add("test-group", channel_name2) - await channel_layer._shards[0]._redis.close(close_connection_pool=True) + await _close_redis(channel_layer._shards[0]._redis) await channel_layer.group_add("test-group", channel_name3) await channel_layer.group_discard("test-group", channel_name2) - await channel_layer._shards[0]._redis.close(close_connection_pool=True) + await _close_redis(channel_layer._shards[0]._redis) await asyncio.sleep(1) await channel_layer.group_send("test-group", {"type": "message.1"}) # Make sure we get the message on the two channels that were in