diff --git a/socketio/asyncio_redis_manager.py b/socketio/asyncio_redis_manager.py index e568ce9c..7ecc34c7 100644 --- a/socketio/asyncio_redis_manager.py +++ b/socketio/asyncio_redis_manager.py @@ -13,19 +13,14 @@ def _parse_redis_url(url): p = urlparse(url) if p.scheme != 'redis': raise ValueError('Invalid redis url') - if ':' in p.netloc: - host, port = p.netloc.split(':') - port = int(port) - else: - host = p.netloc or 'localhost' - port = 6379 + host = p.hostname or 'localhost' + port = p.port or 6379 + password = p.password if p.path: db = int(p.path[1:]) else: db = 0 - if not host: - raise ValueError('Invalid redis hostname') - return host, port, db + return host, port, password, db class AsyncRedisManager(AsyncPubSubManager): # pragma: no cover @@ -53,15 +48,14 @@ class AsyncRedisManager(AsyncPubSubManager): # pragma: no cover name = 'aioredis' def __init__(self, url='redis://localhost:6379/0', channel='socketio', - write_only=False, password=None): + write_only=False): if aioredis is None: raise RuntimeError('Redis package is not installed ' '(Run "pip install aioredis" in your ' 'virtualenv).') - self.host, self.port, self.db = _parse_redis_url(url) + self.host, self.port, self.password, self.db = _parse_redis_url(url) self.pub = None self.sub = None - self.password = password super().__init__(channel=channel, write_only=write_only) async def _publish(self, data): diff --git a/tests/test_asyncio_redis_manager.py b/tests/test_asyncio_redis_manager.py index efa9721d..6beb1e99 100644 --- a/tests/test_asyncio_redis_manager.py +++ b/tests/test_asyncio_redis_manager.py @@ -9,26 +9,37 @@ class TestAsyncRedisManager(unittest.TestCase): def test_default_url(self): self.assertEqual(asyncio_redis_manager._parse_redis_url('redis://'), - ('localhost', 6379, 0)) + ('localhost', 6379, None, 0)) def test_only_host_url(self): self.assertEqual( asyncio_redis_manager._parse_redis_url('redis://redis.host'), - ('redis.host', 6379, 0)) + ('redis.host', 6379, None, 0)) def test_no_db_url(self): self.assertEqual( asyncio_redis_manager._parse_redis_url('redis://redis.host:123/1'), - ('redis.host', 123, 1)) + ('redis.host', 123, None, 1)) def test_no_port_url(self): self.assertEqual( asyncio_redis_manager._parse_redis_url('redis://redis.host/1'), - ('redis.host', 6379, 1)) + ('redis.host', 6379, None, 1)) + + def test_password(self): + self.assertEqual( + asyncio_redis_manager._parse_redis_url('redis://:pw@redis.host/1'), + ('redis.host', 6379, 'pw', 1)) def test_no_host_url(self): - self.assertRaises(ValueError, asyncio_redis_manager._parse_redis_url, - 'redis://:123/1') + self.assertEqual( + asyncio_redis_manager._parse_redis_url('redis://:123/1'), + ('localhost', 123, None, 1)) + + def test_no_host_password_url(self): + self.assertEqual( + asyncio_redis_manager._parse_redis_url('redis://:pw@:123/1'), + ('localhost', 123, 'pw', 1)) def test_bad_port_url(self): self.assertRaises(ValueError, asyncio_redis_manager._parse_redis_url,