Skip to content

Commit

Permalink
Refactored Redis connection utilities to share between layers. (#352)
Browse files Browse the repository at this point in the history
  • Loading branch information
sevdog committed Mar 28, 2023
1 parent 0c89a97 commit 62e8fe2
Show file tree
Hide file tree
Showing 3 changed files with 61 additions and 65 deletions.
45 changes: 3 additions & 42 deletions channels_redis/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from channels.exceptions import ChannelFull
from channels.layers import BaseChannelLayer

from .utils import _consistent_hash, _wrap_close
from .utils import _consistent_hash, _wrap_close, create_pool, decode_hosts

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -118,7 +118,7 @@ def __init__(
self.prefix = prefix
assert isinstance(self.prefix, str), "Prefix must be unicode"
# Configure the host objects
self.hosts = self.decode_hosts(hosts)
self.hosts = decode_hosts(hosts)
self.ring_size = len(self.hosts)
# Cached redis connection pools and the event loop they are from
self._layers = {}
Expand Down Expand Up @@ -146,46 +146,7 @@ def __init__(
self.receive_clean_locks = ChannelLock()

def create_pool(self, index):
host = self.hosts[index]

if "address" in host:
return aioredis.ConnectionPool.from_url(host["address"])
elif "master_name" in host:
sentinels = host.pop("sentinels")
master_name = host.pop("master_name")
sentinel_kwargs = host.pop("sentinel_kwargs", None)
return aioredis.sentinel.SentinelConnectionPool(
master_name,
aioredis.sentinel.Sentinel(sentinels, sentinel_kwargs=sentinel_kwargs),
**host,
)
else:
return aioredis.ConnectionPool(**host)

def decode_hosts(self, hosts):
"""
Takes the value of the "hosts" argument passed to the class and returns
a list of kwargs to use for the Redis connection constructor.
"""
# If no hosts were provided, return a default value
if not hosts:
return [{"address": "redis://localhost:6379"}]
# If they provided just a string, scold them.
if isinstance(hosts, (str, bytes)):
raise ValueError(
"You must pass a list of Redis hosts, even if there is only one."
)

# Decode each hosts entry into a kwargs dict
result = []
for entry in hosts:
if isinstance(entry, dict):
result.append(entry)
elif isinstance(entry, tuple):
result.append({"host": entry[0], "port": entry[1]})
else:
result.append({"address": entry})
return result
return create_pool(self.hosts[index])

def _setup_encryption(self, symmetric_encryption_keys):
# See if we can do encryption if they asked
Expand Down
29 changes: 6 additions & 23 deletions channels_redis/pubsub.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import msgpack
from redis import asyncio as aioredis

from .utils import _consistent_hash, _wrap_close
from .utils import _consistent_hash, _wrap_close, create_pool, decode_hosts

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -81,12 +81,6 @@ def __init__(
channel_layer=None,
**kwargs,
):
if hosts is None:
hosts = ["redis://localhost:6379"]
assert (
isinstance(hosts, list) and len(hosts) > 0
), "`hosts` must be a list with at least one Redis server"

self.prefix = prefix

self.on_disconnect = on_disconnect
Expand All @@ -102,7 +96,9 @@ def __init__(
self.groups = {}

# For each host, we create a `RedisSingleShardConnection` to manage the connection to that host.
self._shards = [RedisSingleShardConnection(host, self) for host in hosts]
self._shards = [
RedisSingleShardConnection(host, self) for host in decode_hosts(hosts)
]

def _get_shard(self, channel_or_group_name):
"""
Expand Down Expand Up @@ -247,9 +243,7 @@ async def flush(self):

class RedisSingleShardConnection:
def __init__(self, host, channel_layer):
self.host = host.copy() if type(host) is dict else {"address": host}
self.master_name = self.host.pop("master_name", None)
self.sentinel_kwargs = self.host.pop("sentinel_kwargs", None)
self.host = host
self.channel_layer = channel_layer
self._subscribed_to = set()
self._lock = asyncio.Lock()
Expand Down Expand Up @@ -331,18 +325,7 @@ def _receive_message(self, message):

def _ensure_redis(self):
if self._redis is None:
if self.master_name is None:
pool = aioredis.ConnectionPool.from_url(self.host["address"])
else:
# aioredis default timeout is way too low
pool = aioredis.sentinel.SentinelConnectionPool(
self.master_name,
aioredis.sentinel.Sentinel(
self.host["sentinels"],
socket_timeout=2,
sentinel_kwargs=self.sentinel_kwargs,
),
)
pool = create_pool(self.host)
self._redis = aioredis.Redis(connection_pool=pool)
self._pubsub = self._redis.pubsub()

Expand Down
52 changes: 52 additions & 0 deletions channels_redis/utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import binascii
import types

from redis import asyncio as aioredis


def _consistent_hash(value, ring_size):
"""
Expand Down Expand Up @@ -31,3 +33,53 @@ def _wrapper(self, *args, **kwargs):
return self.close(*args, **kwargs)

loop.close = types.MethodType(_wrapper, loop)


def decode_hosts(hosts):
"""
Takes the value of the "hosts" argument and returns
a list of kwargs to use for the Redis connection constructor.
"""
# If no hosts were provided, return a default value
if not hosts:
return [{"address": "redis://localhost:6379"}]
# If they provided just a string, scold them.
if isinstance(hosts, (str, bytes)):
raise ValueError(
"You must pass a list of Redis hosts, even if there is only one."
)

# Decode each hosts entry into a kwargs dict
result = []
for entry in hosts:
if isinstance(entry, dict):
result.append(entry)
elif isinstance(entry, (tuple, list)):
result.append({"host": entry[0], "port": entry[1]})
else:
result.append({"address": entry})
return result


def create_pool(host):
"""
Takes the value of the "host" argument and returns a suited connection pool to
the corresponding redis instance.
"""
# avoid side-effects from modifying host
host = host.copy()
if "address" in host:
address = host.pop("address")
return aioredis.ConnectionPool.from_url(address, **host)

master_name = host.pop("master_name", None)
if master_name is not None:
sentinels = host.pop("sentinels")
sentinel_kwargs = host.pop("sentinel_kwargs", None)
return aioredis.sentinel.SentinelConnectionPool(
master_name,
aioredis.sentinel.Sentinel(sentinels, sentinel_kwargs=sentinel_kwargs),
**host
)

return aioredis.ConnectionPool(**host)

0 comments on commit 62e8fe2

Please sign in to comment.