Remove connection pools and long-running tasks
This will result in slightly higher connection times, but make it much
easier to not leak connections when using threads. We can revisit pools
at a later date if the need becomes apparent.
andrewgodwin committed Feb 21, 2018
1 parent 4390d28 commit 6613baf
Showing 2 changed files with 79 additions and 94 deletions.
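
For reference, the sketch below strips the new per-call connection handling down to a standalone script, assuming aioredis 1.x (the client this code targets). The ConnectionContextManager mirrors the class added at the bottom of the core.py diff; the demo() coroutine and the localhost address are invented here purely to show usage and are not part of the commit.

import asyncio

import aioredis


class ConnectionContextManager:
    """
    Opens a dedicated Redis connection on enter and closes it on exit,
    replacing the old shared, long-lived pools.
    """

    def __init__(self, kwargs):
        self.kwargs = kwargs

    async def __aenter__(self):
        self.conn = await aioredis.create_redis(**self.kwargs)
        return self.conn

    async def __aexit__(self, exc_type, exc, tb):
        self.conn.close()


async def demo():
    # Hypothetical connection settings; the real layer pulls these from self.hosts[index].
    async with ConnectionContextManager({"address": ("localhost", 6379)}) as connection:
        await connection.rpush("example-list", "hello")
        print(await connection.llen("example-list"))


asyncio.get_event_loop().run_until_complete(demo())
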
171 changes: 79 additions & 92 deletions channels_redis/core.py
@@ -63,8 +63,8 @@ def __init__(
self._setup_encryption(symmetric_encryption_keys)
# Buffered messages by process-local channel name
self.receive_buffer = {}
# Coroutines currently receiving the process-local channel.
self.receive_tasks = {}
# Coroutine currently receiving the process-local channel.
self.receive_lock = asyncio.Lock()

def decode_hosts(self, hosts):
"""
@@ -123,11 +123,9 @@ async def send(self, channel, message):
# channels, random for general channels
if "!" in channel:
index = self.consistent_hash(channel)
pool = await self.connection(index)
else:
index = next(self._send_index_generator)
pool = await self.connection(index)
with (await pool) as connection:
async with self.connection(index) as connection:
# Check the length of the list before send
# This can allow the list to leak slightly over capacity, but that's fine.
if await connection.llen(channel_key) >= self.get_capacity(channel):
@@ -148,30 +146,47 @@ async def receive(self, channel):
if "!" in channel:
real_channel = self.non_local_name(channel)
assert real_channel.endswith(self.client_prefix + "!"), "Wrong client prefix"
# Make sure a receive task is running
task = self.receive_tasks.get(real_channel, None)
if task is not None and task.done():
task = None
if task is None:
self.receive_tasks[real_channel] = asyncio.ensure_future(
self.receive_loop(real_channel),
)
# Wait on the receive buffer's contents
return await self.receive_buffer_lpop(channel)
# Launch our own receive loop task
loop = asyncio.get_event_loop()
task = loop.create_task(self.receive_loop(channel))
try:
# Wait on the receive buffer's contents
# TODO: Two coroutines rather than a poll
while True:
if self.receive_buffer.get(channel, None):
message = self.receive_buffer[channel][0]
if len(self.receive_buffer[channel]) == 1:
del self.receive_buffer[channel]
else:
self.receive_buffer[channel] = self.receive_buffer[channel][1:]
return message
else:
# See if we need to propagate a dead receiver exception
if task.done():
task.result()
# Sleep poll
await asyncio.sleep(self.local_poll_interval)
finally:
# Shut down the task
if not task.done():
task.cancel()
else:
# Do a plain direct receive
return (await self.receive_single(channel))[1]

async def receive_loop(self, channel):
async def receive_loop(self, specific_channel):
"""
Continuous-receiving loop that fetches results into the receive buffer.
Continuous-receiving loop that makes sure something is fetching results
for the channel passed in.
"""
assert "!" in channel, "receive_loop called on non-process-local channel"
assert "!" in specific_channel, "receive_loop called on non-process-local channel"
general_channel = self.non_local_name(specific_channel)
while True:
# Catch RuntimeErrors from the loop stopping while we release
# a connection. Wish there was a cleaner solution here.
real_channel, message = await self.receive_single(channel)
self.receive_buffer.setdefault(real_channel, []).append(message)
async with self.receive_lock:
real_channel, message = await self.receive_single(general_channel)
self.receive_buffer.setdefault(real_channel, []).append(message)
if real_channel == specific_channel:
return

async def receive_single(self, channel):
"""
@@ -186,8 +201,7 @@ async def receive_single(self, channel):
else:
index = next(self._receive_index_generator)
# Get that connection and receive off of it
pool = await self.connection(index)
with (await pool) as connection:
async with self.connection(index) as connection:
channel_key = self.prefix + channel
content = None
while content is None:
@@ -201,28 +215,6 @@
del message["__asgi_channel__"]
return channel, message

async def receive_buffer_lpop(self, channel):
"""
Atomic, async method that returns the left-hand item in a receive buffer.
"""
# TODO: Use locks or something, not a poll
while True:
if self.receive_buffer.get(channel, None):
message = self.receive_buffer[channel][0]
if len(self.receive_buffer[channel]) == 1:
del self.receive_buffer[channel]
else:
self.receive_buffer[channel] = self.receive_buffer[channel][1:]
return message
else:
# See if we need to propagate a dead receiver exception
real_channel = self.non_local_name(channel)
task = self.receive_tasks.get(real_channel, None)
if task is not None and task.done():
task.result()
# Sleep poll
await asyncio.sleep(self.local_poll_interval)

async def new_channel(self, prefix="specific."):
"""
Returns a new channel name that can be used by something in our
@@ -250,23 +242,12 @@ async def flush(self):
"""
# Go through each connection and remove all with prefix
for i in range(self.ring_size):
connection = await self.connection(i)
await connection.eval(
delete_prefix,
keys=[],
args=[self.prefix + "*"]
)

async def close(self):
# Stop all reader tasks
for task in self.receive_tasks.values():
task.cancel()
asyncio.wait(self.receive_tasks.values())
self.receive_tasks = {}
# Close up all pools
for pool in self.pools.values():
pool.close()
await pool.wait_closed()
async with self.connection(i) as connection:
await connection.eval(
delete_prefix,
keys=[],
args=[self.prefix + "*"]
)

### Groups extension ###

@@ -279,8 +260,7 @@ async def group_add(self, group, channel):
assert self.valid_channel_name(channel), "Channel name not valid"
# Get a connection to the right shard
group_key = self._group_key(group)
pool = await self.connection(self.consistent_hash(group))
with (await pool) as connection:
async with self.connection(self.consistent_hash(group)) as connection:
# Add to group sorted set with creation time as timestamp
await connection.zadd(
group_key,
@@ -299,11 +279,11 @@ async def group_discard(self, group, channel):
assert self.valid_group_name(group), "Group name not valid"
assert self.valid_channel_name(channel), "Channel name not valid"
key = self._group_key(group)
pool = await self.connection(self.consistent_hash(group))
await pool.zrem(
key,
channel,
)
async with self.connection(self.consistent_hash(group)) as connection:
await connection.zrem(
key,
channel,
)

async def group_send(self, group, message):
"""
@@ -312,8 +292,7 @@ async def group_send(self, group, message):
assert self.valid_group_name(group), "Group name not valid"
# Retrieve list of all channel names
key = self._group_key(group)
pool = await self.connection(self.consistent_hash(group))
with (await pool) as connection:
async with self.connection(self.consistent_hash(group)) as connection:
# Discard old channels based on group_expiry
await connection.zremrangebyscore(key, min=0, max=int(time.time()) - self.group_expiry)
# Return current lot
@@ -366,26 +345,6 @@ def consistent_hash(self, value):
ring_divisor = 4096 / float(self.ring_size)
return int(bigval / ring_divisor)

async def connection(self, index):
"""
Returns the correct connection for the index given.
Lazily instantiates pools.
"""
# Catch bad indexes
if not 0 <= index < self.ring_size:
raise ValueError("There are only %s hosts - you asked for %s!" % (self.ring_size, index))
# Check to see if the stored pools are for the right event loop
# TODO: Maybe cache from multiple event loops to avoid AsyncToSync wiping
# out the main thread's cache (but we'd need to cap the number of entries
# with an LRU strategy or something)
if self.pools_loop != asyncio.get_event_loop():
self.pools = {}
self.pools_loop = asyncio.get_event_loop()
# Make the new pool if it does not exist
if index not in self.pools:
self.pools[index] = await aioredis.create_redis_pool(**self.hosts[index])
return self.pools[index]

def make_fernet(self, key):
"""
Given a single encryption key, returns a Fernet instance using it.
@@ -398,3 +357,31 @@ def make_fernet(self, key):

def __str__(self):
return "%s(hosts=%s)" % (self.__class__.__name__, self.hosts)

### Connection handling ###

def connection(self, index):
"""
Returns the correct connection for the index given.
Lazily instantiates pools.
"""
# Catch bad indexes
if not 0 <= index < self.ring_size:
raise ValueError("There are only %s hosts - you asked for %s!" % (self.ring_size, index))
# Make a context manager
return self.ConnectionContextManager(self.hosts[index])

class ConnectionContextManager:
"""
Async context manager for connections
"""

def __init__(self, kwargs):
self.kwargs = kwargs

async def __aenter__(self):
self.conn = await aioredis.create_redis(**self.kwargs)
return self.conn

async def __aexit__(self, exc_type, exc, tb):
self.conn.close()
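
To make the new receive path above easier to follow, here is a minimal, self-contained sketch of the pattern it adopts: no background reader tasks survive between calls; each receive() spins up its own short-lived loop, an asyncio.Lock keeps only one of those loops talking to Redis at a time, and results fan out through a shared buffer that callers poll. The fake_redis_pop() helper is a stand-in invented for this sketch; the real code calls receive_single(), which blocks on Redis.

import asyncio
import random

receive_buffer = {}
receive_lock = asyncio.Lock()
local_poll_interval = 0.01


async def fake_redis_pop(general_channel):
    # Stand-in for receive_single(): pretend Redis handed us a message for
    # one of the specific channels behind the general "client!" channel.
    await asyncio.sleep(0.05)
    return "%s%d" % (general_channel, random.randint(0, 1)), {"type": "test.message"}


async def receive_loop(specific_channel, general_channel):
    # Keep pulling messages (one reader at a time, guarded by the lock) and
    # buffering them until we see one addressed to our own channel.
    while True:
        async with receive_lock:
            channel, message = await fake_redis_pop(general_channel)
            receive_buffer.setdefault(channel, []).append(message)
        if channel == specific_channel:
            return


async def receive(specific_channel, general_channel):
    loop = asyncio.get_event_loop()
    task = loop.create_task(receive_loop(specific_channel, general_channel))
    try:
        # Poll the shared buffer until our message shows up.
        while True:
            if receive_buffer.get(specific_channel):
                return receive_buffer[specific_channel].pop(0)
            if task.done():
                task.result()  # Propagate any exception from the reader
            await asyncio.sleep(local_poll_interval)
    finally:
        if not task.done():
            task.cancel()


async def main():
    print(await receive("example!0", "example!"))


asyncio.get_event_loop().run_until_complete(main())
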
2 changes: 0 additions & 2 deletions tests/test_core.py
@@ -18,7 +18,6 @@ async def channel_layer():
channel_layer = RedisChannelLayer(hosts=TEST_HOSTS, capacity=3)
await yield_(channel_layer)
await channel_layer.flush()
await channel_layer.close()


@pytest.mark.asyncio
@@ -60,7 +59,6 @@ async def test_send_specific_capacity(channel_layer):
with pytest.raises(ChannelFull):
await custom_channel_layer.send("one", {"type": "test.message"})
await custom_channel_layer.flush()
await custom_channel_layer.close()


@pytest.mark.asyncio
