diff --git a/fakeredis/_server.py b/fakeredis/_server.py index 989fdc6f..70ec1dce 100644 --- a/fakeredis/_server.py +++ b/fakeredis/_server.py @@ -202,6 +202,7 @@ def __init__( kwds.pop("server", None) kwds.pop("connected", None) kwds.pop("version", None) + kwds.pop("lua_modules", None) super().__init__(**kwds) @classmethod diff --git a/test/conftest.py b/test/conftest.py index 1ecd3738..f8a030bf 100644 --- a/test/conftest.py +++ b/test/conftest.py @@ -85,3 +85,40 @@ def factory(db=2): return cls('localhost', port=6380, db=db, decode_responses=decode_responses) return factory + + +@pytest_asyncio.fixture( + name='async_redis', + params=[ + pytest.param('fake', marks=pytest.mark.fake), + pytest.param('real', marks=pytest.mark.real) + ] +) +async def _req_aioredis2(request) -> redis.asyncio.Redis: + server_version = request.getfixturevalue('real_redis_version') + if request.param != 'fake' and not server_version: + pytest.skip('Redis is not running') + server_version = _create_version(server_version) or (6,) + min_server_marker = _marker_version_value(request, 'min_server') + max_server_marker = _marker_version_value(request, 'max_server') + if server_version < min_server_marker: + pytest.skip(f'Redis server {min_server_marker} or more required but {server_version} found') + if server_version > max_server_marker: + pytest.skip(f'Redis server {max_server_marker} or less required but {server_version} found') + lua_modules_marker = request.node.get_closest_marker('load_lua_modules') + lua_modules = set(lua_modules_marker.args) if lua_modules_marker else None + + if request.param == 'fake': + fake_server = request.getfixturevalue('fake_server') + ret = fakeredis.FakeAsyncRedis(server=fake_server, lua_modules=lua_modules) + else: + ret = redis.asyncio.Redis(host='localhost', port=6380, db=2) + fake_server = None + if not fake_server or fake_server.connected: + await ret.flushall() + + yield ret + + if not fake_server or fake_server.connected: + await ret.flushall() + await ret.connection_pool.disconnect() diff --git a/test/test_asyncredis.py b/test/test_asyncredis.py index 9e8b48c8..e922faca 100644 --- a/test/test_asyncredis.py +++ b/test/test_asyncredis.py @@ -1,11 +1,7 @@ import asyncio -import json import re import sys -from fakeredis._server import _create_version -from test.conftest import _marker_version_value - if sys.version_info >= (3, 11): from asyncio import timeout as async_timeout else: @@ -21,7 +17,7 @@ pytestmark = [ ] fake_only = pytest.mark.parametrize( - 'req_aioredis2', + 'async_redis', [pytest.param('fake', marks=pytest.mark.fake)], indirect=True ) @@ -30,58 +26,21 @@ ]) -@pytest_asyncio.fixture( - name='req_aioredis2', - params=[ - pytest.param('fake', marks=pytest.mark.fake), - pytest.param('real', marks=pytest.mark.real) - ] -) -async def _req_aioredis2(request) -> redis.asyncio.Redis: - server_version = request.getfixturevalue('real_redis_version') - if request.param != 'fake' and not server_version: - pytest.skip('Redis is not running') - server_version = _create_version(server_version) or (6,) - min_server_marker = _marker_version_value(request, 'min_server') - max_server_marker = _marker_version_value(request, 'max_server') - if server_version < min_server_marker: - pytest.skip(f'Redis server {min_server_marker} or more required but {server_version} found') - if server_version > max_server_marker: - pytest.skip(f'Redis server {max_server_marker} or less required but {server_version} found') - lua_modules_marker = request.node.get_closest_marker('load_lua_modules') - lua_modules = set(lua_modules_marker.args) if lua_modules_marker else None - - if request.param == 'fake': - fake_server = request.getfixturevalue('fake_server') - ret = aioredis.FakeRedis(server=fake_server, lua_modules=lua_modules) - else: - ret = redis.asyncio.Redis(host='localhost', port=6380, db=2) - fake_server = None - if not fake_server or fake_server.connected: - await ret.flushall() - - yield ret - - if not fake_server or fake_server.connected: - await ret.flushall() - await ret.connection_pool.disconnect() - - @pytest_asyncio.fixture -async def conn(req_aioredis2: redis.asyncio.Redis): +async def conn(async_redis: redis.asyncio.Redis): """A single connection, rather than a pool.""" - async with req_aioredis2.client() as conn: + async with async_redis.client() as conn: yield conn -async def test_ping(req_aioredis2: redis.asyncio.Redis): - pong = await req_aioredis2.ping() +async def test_ping(async_redis: redis.asyncio.Redis): + pong = await async_redis.ping() assert pong is True -async def test_types(req_aioredis2: redis.asyncio.Redis): - await req_aioredis2.hset('hash', mapping={'key1': 'value1', 'key2': 'value2', 'key3': 123}) - result = await req_aioredis2.hgetall('hash') +async def test_types(async_redis: redis.asyncio.Redis): + await async_redis.hset('hash', mapping={'key1': 'value1', 'key2': 'value2', 'key3': 123}) + result = await async_redis.hgetall('hash') assert result == { b'key1': b'value1', b'key2': b'value2', @@ -89,29 +48,29 @@ async def test_types(req_aioredis2: redis.asyncio.Redis): } -async def test_transaction(req_aioredis2: redis.asyncio.Redis): - async with req_aioredis2.pipeline(transaction=True) as tr: +async def test_transaction(async_redis: redis.asyncio.Redis): + async with async_redis.pipeline(transaction=True) as tr: tr.set('key1', 'value1') tr.set('key2', 'value2') ok1, ok2 = await tr.execute() assert ok1 assert ok2 - result = await req_aioredis2.get('key1') + result = await async_redis.get('key1') assert result == b'value1' -async def test_transaction_fail(req_aioredis2: redis.asyncio.Redis): - await req_aioredis2.set('foo', '1') - async with req_aioredis2.pipeline(transaction=True) as tr: +async def test_transaction_fail(async_redis: redis.asyncio.Redis): + await async_redis.set('foo', '1') + async with async_redis.pipeline(transaction=True) as tr: await tr.watch('foo') - await req_aioredis2.set('foo', '2') # Different connection + await async_redis.set('foo', '2') # Different connection tr.multi() tr.get('foo') with pytest.raises(redis.asyncio.WatchError): await tr.execute() -async def test_pubsub(req_aioredis2, event_loop): +async def test_pubsub(async_redis, event_loop): queue = asyncio.Queue() async def reader(ps): @@ -122,11 +81,11 @@ async def reader(ps): break queue.put_nowait(message) - async with async_timeout(5), req_aioredis2.pubsub() as ps: + async with async_timeout(5), async_redis.pubsub() as ps: await ps.subscribe('channel') task = event_loop.create_task(reader(ps)) - await req_aioredis2.publish('channel', 'message1') - await req_aioredis2.publish('channel', 'message2') + await async_redis.publish('channel', 'message1') + await async_redis.publish('channel', 'message2') result1 = await queue.get() result2 = await queue.get() assert result1 == { @@ -141,13 +100,13 @@ async def reader(ps): 'type': 'message', 'data': b'message2' } - await req_aioredis2.publish('channel', 'stop') + await async_redis.publish('channel', 'stop') await task @pytest.mark.slow -async def test_pubsub_timeout(req_aioredis2: redis.asyncio.Redis): - async with req_aioredis2.pubsub() as ps: +async def test_pubsub_timeout(async_redis: redis.asyncio.Redis): + async with async_redis.pubsub() as ps: await ps.subscribe('channel') await ps.get_message(timeout=0.5) # Subscription message message = await ps.get_message(timeout=0.5) @@ -155,8 +114,8 @@ async def test_pubsub_timeout(req_aioredis2: redis.asyncio.Redis): @pytest.mark.slow -async def test_pubsub_disconnect(req_aioredis2: redis.asyncio.Redis): - async with req_aioredis2.pubsub() as ps: +async def test_pubsub_disconnect(async_redis: redis.asyncio.Redis): + async with async_redis.pubsub() as ps: await ps.subscribe('channel') await ps.connection.disconnect() message = await ps.get_message(timeout=0.5) # Subscription message @@ -165,9 +124,9 @@ async def test_pubsub_disconnect(req_aioredis2: redis.asyncio.Redis): assert message is None -async def test_blocking_ready(req_aioredis2, conn): +async def test_blocking_ready(async_redis, conn): """Blocking command which does not need to block.""" - await req_aioredis2.rpush('list', 'x') + await async_redis.rpush('list', 'x') result = await conn.blpop('list', timeout=1) assert result == (b'list', b'x') @@ -180,12 +139,12 @@ async def test_blocking_timeout(conn): @pytest.mark.slow -async def test_blocking_unblock(req_aioredis2, conn, event_loop): +async def test_blocking_unblock(async_redis, conn, event_loop): """Blocking command that gets unblocked after some time.""" async def unblock(): await asyncio.sleep(0.1) - await req_aioredis2.rpush('list', 'y') + await async_redis.rpush('list', 'y') task = event_loop.create_task(unblock()) result = await conn.blpop('list', timeout=1) @@ -193,99 +152,99 @@ async def unblock(): await task -async def test_wrongtype_error(req_aioredis2: redis.asyncio.Redis): - await req_aioredis2.set('foo', 'bar') +async def test_wrongtype_error(async_redis: redis.asyncio.Redis): + await async_redis.set('foo', 'bar') with pytest.raises(redis.asyncio.ResponseError, match='^WRONGTYPE'): - await req_aioredis2.rpush('foo', 'baz') + await async_redis.rpush('foo', 'baz') -async def test_syntax_error(req_aioredis2: redis.asyncio.Redis): +async def test_syntax_error(async_redis: redis.asyncio.Redis): with pytest.raises(redis.asyncio.ResponseError, match="^wrong number of arguments for 'get' command$"): - await req_aioredis2.execute_command('get') + await async_redis.execute_command('get') @testtools.run_test_if_lupa class TestScripts: - async def test_no_script_error(self, req_aioredis2: redis.asyncio.Redis): + async def test_no_script_error(self, async_redis: redis.asyncio.Redis): with pytest.raises(redis.exceptions.NoScriptError): - await req_aioredis2.evalsha('0123456789abcdef0123456789abcdef', 0) + await async_redis.evalsha('0123456789abcdef0123456789abcdef', 0) @pytest.mark.max_server('6.2.7') - async def test_failed_script_error6(self, req_aioredis2): - await req_aioredis2.set('foo', 'bar') + async def test_failed_script_error6(self, async_redis): + await async_redis.set('foo', 'bar') with pytest.raises(redis.asyncio.ResponseError, match='^Error running script'): - await req_aioredis2.eval('return redis.call("ZCOUNT", KEYS[1])', 1, 'foo') + await async_redis.eval('return redis.call("ZCOUNT", KEYS[1])', 1, 'foo') @pytest.mark.min_server('7') - async def test_failed_script_error7(self, req_aioredis2): - await req_aioredis2.set('foo', 'bar') + async def test_failed_script_error7(self, async_redis): + await async_redis.set('foo', 'bar') with pytest.raises(redis.asyncio.ResponseError): - await req_aioredis2.eval('return redis.call("ZCOUNT", KEYS[1])', 1, 'foo') + await async_redis.eval('return redis.call("ZCOUNT", KEYS[1])', 1, 'foo') @fake_only @testtools.run_test_if_redispy_ver('lt', '5.1.0b1') -async def test_repr_redis_until_51(req_aioredis2: redis.asyncio.Redis): +async def test_repr_redis_until_51(async_redis: redis.asyncio.Redis): assert re.fullmatch( r'ConnectionPool,db=0>>', - repr(req_aioredis2.connection_pool) + repr(async_redis.connection_pool) ) @testtools.run_test_if_redispy_ver('gte', '5.1') -async def test_repr_redis_51(req_aioredis2: redis.asyncio.Redis): +async def test_repr_redis_51(async_redis: redis.asyncio.Redis): assert re.fullmatch( r',db=0)>)>', - repr(req_aioredis2.connection_pool) + repr(async_redis.connection_pool) ) @fake_only @pytest.mark.disconnected -async def test_not_connected(req_aioredis2: redis.asyncio.Redis): +async def test_not_connected(async_redis: redis.asyncio.Redis): with pytest.raises(redis.asyncio.ConnectionError): - await req_aioredis2.ping() + await async_redis.ping() @fake_only -async def test_disconnect_server(req_aioredis2, fake_server): - await req_aioredis2.ping() +async def test_disconnect_server(async_redis, fake_server): + await async_redis.ping() fake_server.connected = False with pytest.raises(redis.asyncio.ConnectionError): - await req_aioredis2.ping() + await async_redis.ping() fake_server.connected = True -async def test_type(req_aioredis2: redis.asyncio.Redis): - await req_aioredis2.set('string_key', "value") - await req_aioredis2.lpush("list_key", "value") - await req_aioredis2.sadd("set_key", "value") - await req_aioredis2.zadd("zset_key", {"value": 1}) - await req_aioredis2.hset('hset_key', 'key', 'value') +async def test_type(async_redis: redis.asyncio.Redis): + await async_redis.set('string_key', "value") + await async_redis.lpush("list_key", "value") + await async_redis.sadd("set_key", "value") + await async_redis.zadd("zset_key", {"value": 1}) + await async_redis.hset('hset_key', 'key', 'value') - assert b'string' == await req_aioredis2.type('string_key') # noqa: E721 - assert b'list' == await req_aioredis2.type('list_key') # noqa: E721 - assert b'set' == await req_aioredis2.type('set_key') # noqa: E721 - assert b'zset' == await req_aioredis2.type('zset_key') # noqa: E721 - assert b'hash' == await req_aioredis2.type('hset_key') # noqa: E721 - assert b'none' == await req_aioredis2.type('none_key') # noqa: E721 + assert b'string' == await async_redis.type('string_key') # noqa: E721 + assert b'list' == await async_redis.type('list_key') # noqa: E721 + assert b'set' == await async_redis.type('set_key') # noqa: E721 + assert b'zset' == await async_redis.type('zset_key') # noqa: E721 + assert b'hash' == await async_redis.type('hset_key') # noqa: E721 + assert b'none' == await async_redis.type('none_key') # noqa: E721 -async def test_xdel(req_aioredis2: redis.asyncio.Redis): +async def test_xdel(async_redis: redis.asyncio.Redis): stream = "stream" # deleting from an empty stream doesn't do anything - assert await req_aioredis2.xdel(stream, 1) == 0 + assert await async_redis.xdel(stream, 1) == 0 - m1 = await req_aioredis2.xadd(stream, {"foo": "bar"}) - m2 = await req_aioredis2.xadd(stream, {"foo": "bar"}) - m3 = await req_aioredis2.xadd(stream, {"foo": "bar"}) + m1 = await async_redis.xadd(stream, {"foo": "bar"}) + m2 = await async_redis.xadd(stream, {"foo": "bar"}) + m3 = await async_redis.xadd(stream, {"foo": "bar"}) # xdel returns the number of deleted elements - assert await req_aioredis2.xdel(stream, m1) == 1 - assert await req_aioredis2.xdel(stream, m2, m3) == 2 + assert await async_redis.xdel(stream, m1) == 1 + assert await async_redis.xdel(stream, m2, m3) == 2 @pytest.mark.fake @@ -315,9 +274,9 @@ async def test_from_url_with_version(): @fake_only -async def test_from_url_with_server(req_aioredis2, fake_server): +async def test_from_url_with_server(async_redis, fake_server): r2 = aioredis.FakeRedis.from_url('redis://localhost', server=fake_server) - await req_aioredis2.set('foo', 'bar') + await async_redis.set('foo', 'bar') assert await r2.get('foo') == b'bar' await r2.connection_pool.disconnect() @@ -393,33 +352,3 @@ async def test_init_args(): assert await r3.get('bar') == b'baz' assert await r4.get('bar') == b'baz' assert await r1.get('bar') is None - - -@pytest.mark.load_lua_modules('cjson') -async def test_asgi_ratelimit_script(req_aioredis2: redis.Redis): - script = """ -local ruleset = cjson.decode(ARGV[1]) - --- Set limits -for i, key in pairs(KEYS) do - redis.call('SET', key, ruleset[key][1], 'EX', ruleset[key][2], 'NX') -end - --- Check limits -for i = 1, #KEYS do - local value = redis.call('GET', KEYS[i]) - if value and tonumber(value) < 1 then - return ruleset[KEYS[i]][2] - end -end - --- Decrease limits -for i, key in pairs(KEYS) do - redis.call('DECR', key) -end -return 0 -""" - - script = req_aioredis2.register_script(script) - ruleset = {"path:get:user:name": (1, 1)} - await script(keys=list(ruleset.keys()), args=[json.dumps(ruleset)]) diff --git a/test/test_lua_modules.py b/test/test_lua_modules.py new file mode 100644 index 00000000..9a1776ee --- /dev/null +++ b/test/test_lua_modules.py @@ -0,0 +1,70 @@ +import json + +import pytest +import redis + +pytestmark = [ +] +pytestmark.extend([ + pytest.mark.asyncio, +]) + + +@pytest.mark.load_lua_modules('cjson') +async def test_async_asgi_ratelimit_script(async_redis: redis.Redis): + script = """ +local ruleset = cjson.decode(ARGV[1]) + +-- Set limits +for i, key in pairs(KEYS) do + redis.call('SET', key, ruleset[key][1], 'EX', ruleset[key][2], 'NX') +end + +-- Check limits +for i = 1, #KEYS do + local value = redis.call('GET', KEYS[i]) + if value and tonumber(value) < 1 then + return ruleset[KEYS[i]][2] + end +end + +-- Decrease limits +for i, key in pairs(KEYS) do + redis.call('DECR', key) +end +return 0 +""" + + script = async_redis.register_script(script) + ruleset = {"path:get:user:name": (1, 1)} + await script(keys=list(ruleset.keys()), args=[json.dumps(ruleset)]) + + +@pytest.mark.load_lua_modules('cjson') +def test_asgi_ratelimit_script(r: redis.Redis): + script = """ +local ruleset = cjson.decode(ARGV[1]) + +-- Set limits +for i, key in pairs(KEYS) do + redis.call('SET', key, ruleset[key][1], 'EX', ruleset[key][2], 'NX') +end + +-- Check limits +for i = 1, #KEYS do + local value = redis.call('GET', KEYS[i]) + if value and tonumber(value) < 1 then + return ruleset[KEYS[i]][2] + end +end + +-- Decrease limits +for i, key in pairs(KEYS) do + redis.call('DECR', key) +end +return 0 +""" + + script = r.register_script(script) + ruleset = {"path:get:user:name": (1, 1)} + script(keys=list(ruleset.keys()), args=[json.dumps(ruleset)]) diff --git a/test/test_mixins/test_scripting.py b/test/test_mixins/test_scripting.py index 19d31ef7..93ad8a60 100644 --- a/test/test_mixins/test_scripting.py +++ b/test/test_mixins/test_scripting.py @@ -1,6 +1,5 @@ from __future__ import annotations -import json import logging import pytest @@ -604,36 +603,6 @@ def test_hscan_cursors_are_bytes(r: redis.Redis): assert isinstance(result, bytes) -@pytest.mark.load_lua_modules('cjson') -def test_asgi_ratelimit_script(r: redis.Redis): - script = """ -local ruleset = cjson.decode(ARGV[1]) - --- Set limits -for i, key in pairs(KEYS) do - redis.call('SET', key, ruleset[key][1], 'EX', ruleset[key][2], 'NX') -end - --- Check limits -for i = 1, #KEYS do - local value = redis.call('GET', KEYS[i]) - if value and tonumber(value) < 1 then - return ruleset[KEYS[i]][2] - end -end - --- Decrease limits -for i, key in pairs(KEYS) do - redis.call('DECR', key) -end -return 0 -""" - - script = r.register_script(script) - ruleset = {"path:get:user:name": (1, 1)} - script(keys=list(ruleset.keys()), args=[json.dumps(ruleset)]) - - @pytest.mark.xfail # TODO def test_deleting_while_scan(r: redis.Redis): for i in range(100):