diff --git a/fakeredis.py b/fakeredis.py index 921e3ff..5134b3b 100644 --- a/fakeredis.py +++ b/fakeredis.py @@ -8,15 +8,15 @@ from datetime import datetime, timedelta import operator import sys -import threading import time import types import re import functools from itertools import count, islice +from uuid import uuid4 import redis -from redis.exceptions import ResponseError +from redis.exceptions import ResponseError, LockError import redis.client try: @@ -315,8 +315,8 @@ class _Lock(object): def __init__(self, redis, name, timeout): self.redis = redis self.name = name - self.lock = threading.Lock() - redis.set(name, self, ex=timeout) + self.timeout = timeout + self.id = str(uuid4()) def __enter__(self): self.acquire() @@ -326,11 +326,17 @@ def __exit__(self, exc_type, exc_value, traceback): self.release() def acquire(self, blocking=True, blocking_timeout=None): - return self.lock.acquire(blocking) + acquired = bool(self.redis.set(self.name, self.id, nx=True, ex=self.timeout)) + if not acquired and blocking: + raise ValueError('fakeredis can\'t do blocking locks') + + return acquired def release(self): - self.lock.release() - self.redis.delete(self.name) + if _decode(self.redis.get(self.name)) == self.id: + self.redis.delete(self.name) + else: + raise LockError('Cannot release an unlocked lock') def _check_conn(func): diff --git a/test_fakeredis.py b/test_fakeredis.py index 715559c..2ccdf13 100644 --- a/test_fakeredis.py +++ b/test_fakeredis.py @@ -3916,6 +3916,38 @@ def test_lock(self): self.assertTrue(self.redis.exists('bar')) self.assertFalse(self.redis.exists('bar')) + def test_acquiring_lock_twice(self): + lock = self.redis.lock('foo') + self.assertTrue(lock.acquire(blocking=False)) + self.assertFalse(lock.acquire(blocking=False)) + + def test_acquiring_lock_different_lock(self): + lock1 = self.redis.lock('foo') + lock2 = self.redis.lock('foo') + self.assertTrue(lock1.acquire(blocking=False)) + self.assertFalse(lock2.acquire(blocking=False)) + + def test_acquiring_lock_different_lock_release(self): + lock1 = self.redis.lock('foo') + lock2 = self.redis.lock('foo') + self.assertTrue(lock1.acquire(blocking=False)) + self.assertFalse(lock2.acquire(blocking=False)) + + # Test only releasing lock1 actually releases the lock + with self.assertRaises(redis.exceptions.LockError): + lock2.release() + self.assertFalse(lock2.acquire(blocking=False)) + lock1.release() + + # Locking with lock2 now has the lock + self.assertTrue(lock2.acquire(blocking=False)) + self.assertFalse(lock1.acquire(blocking=False)) + + def test_nested_lock(self): + with self.redis.lock('bar'): + acquired = self.redis.lock('bar').acquire(blocking=False) + self.assertFalse(acquired) + class DecodeMixin(object): decode_responses = True