Skip to content

Commit

Permalink
Allow set parameter for acquire for context manager
Browse files Browse the repository at this point in the history
  • Loading branch information
tavor118sn authored and ionelmc committed May 1, 2024
1 parent 1893d3a commit e5f596f
Show file tree
Hide file tree
Showing 3 changed files with 29 additions and 3 deletions.
8 changes: 8 additions & 0 deletions docs/usage.rst
Expand Up @@ -50,6 +50,14 @@ The above example could be rewritten using context manager::
print("Got the lock. Doing some work ...")
time.sleep(5)

You can pass `blocking=False` parameter to the contex manager (default value
is True, will raise a NotAcquired exception if lock won't be acquired)::

conn = StrictRedis()
with redis_lock.Lock(conn, "name-of-the-lock", blocking=False):
print("Got the lock. Doing some work ...")
time.sleep(5)

In cases, where lock not necessarily in acquired state, and
user need to ensure, that it has a matching ``id``, example::

Expand Down
14 changes: 11 additions & 3 deletions src/redis_lock/__init__.py
Expand Up @@ -99,11 +99,12 @@ class Lock(object):
extend_script = None
reset_script = None
reset_all_script = None
blocking = None

_lock_renewal_interval: float
_lock_renewal_thread: Union[threading.Thread, None]

def __init__(self, redis_client, name, expire=None, id=None, auto_renewal=False, strict=True, signal_expire=1000):
def __init__(self, redis_client, name, expire=None, id=None, auto_renewal=False, strict=True, signal_expire=1000, blocking=True):
"""
:param redis_client:
An instance of :class:`~StrictRedis`.
Expand Down Expand Up @@ -131,6 +132,9 @@ def __init__(self, redis_client, name, expire=None, id=None, auto_renewal=False,
If set ``True`` then the ``redis_client`` needs to be an instance of ``redis.StrictRedis``.
:param signal_expire:
Advanced option to override signal list expiration in milliseconds. Increase it for very slow clients. Default: ``1000``.
:param blocking:
Boolean value specifying whether lock should be blocking or not.
Used in `__enter__` method.
"""
if strict and not isinstance(redis_client, StrictRedis):
raise ValueError("redis_client must be instance of StrictRedis. Use strict=False if you know what you're doing.")
Expand Down Expand Up @@ -164,6 +168,8 @@ def __init__(self, redis_client, name, expire=None, id=None, auto_renewal=False,
self._lock_renewal_interval = float(expire) * 2 / 3 if auto_renewal else None
self._lock_renewal_thread = None

self.blocking = blocking

self.register_scripts(redis_client)

@classmethod
Expand Down Expand Up @@ -318,9 +324,11 @@ def _stop_lock_renewer(self):
logger_for_refresh_exit.debug("Renewal thread for Lock(%r) exited.", self._name)

def __enter__(self):
acquired = self.acquire(blocking=True)
acquired = self.acquire(blocking=self.blocking)
if not acquired:
raise AssertionError(f"Lock({self._name}) wasn't acquired, but blocking=True was used!")
if self.blocking:
raise AssertionError(f"Lock({self._name}) wasn't acquired, but blocking=True was used!")
raise NotAcquired(f"Lock({self._name}) is not acquired or it already expired.")
return self

def __exit__(self, exc_type=None, exc_value=None, traceback=None):
Expand Down
10 changes: 10 additions & 0 deletions tests/test_redis_lock.py
Expand Up @@ -341,6 +341,16 @@ def test_double_acquire(conn):
pytest.raises(AlreadyAcquired, lock.acquire)


def test_enter_already_acquired_with_not_blocking(conn):
lock = Lock(conn, "foobar")
acquired = lock.acquire()
assert acquired

with pytest.raises(NotAcquired):
with Lock(conn, "foobar", blocking=False):
pass


def test_plain(conn):
with Lock(conn, "foobar"):
time.sleep(0.01)
Expand Down

0 comments on commit e5f596f

Please sign in to comment.