Skip to content

Commit

Permalink
RDISCROWD-5170 fix race condition (#753)
Browse files Browse the repository at this point in the history
* RDISCROWD-5170 fix Race condition in obtaining task

* remove debug info
  • Loading branch information
XiChenn committed Aug 11, 2022
1 parent 39345f9 commit 908d03a
Show file tree
Hide file tree
Showing 5 changed files with 106 additions and 26 deletions.
33 changes: 26 additions & 7 deletions pybossa/redis_lock.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@ def __init__(self, cache, duration):
self._redis = cache
self._duration = duration

def acquire_lock(self, resource_id, client_id, limit, pipeline=None):
def acquire_lock(self, resource_id, client_id, limit):
"""
Acquire a lock on a resource.
:param resource_id: resource on which lock is needed
Expand All @@ -145,13 +145,31 @@ def acquire_lock(self, resource_id, client_id, limit, pipeline=None):
self._release_expired_locks(resource_id, timestamp)
if self._redis.hexists(resource_id, client_id):
return True
num_acquired = self._redis.hlen(resource_id)
if num_acquired < limit:
cache = pipeline or self._redis
cache.hset(resource_id, client_id, expiration)
cache.expire(resource_id, timedelta(seconds=self._duration))

pipeline = self._redis.pipeline()
if limit == float('inf'):
pipeline.hset(resource_id, client_id, expiration)
pipeline.expire(resource_id, timedelta(seconds=self._duration))
pipeline.execute()
return True
return False

# Get a mutex lock for updating redis hash with default TTL 3s
lock_name = f"{resource_id}_update_mutex"
result = False

try:
with self._redis.lock(lock_name, timeout=3, blocking_timeout=1) as mutex:
if not mutex.locked():
return False

num_acquired = self._redis.hlen(resource_id)
if num_acquired < limit:
pipeline.hset(resource_id, client_id, expiration)
pipeline.expire(resource_id, timedelta(seconds=self._duration))
pipeline.execute()
result = True
finally:
return result

def has_lock(self, resource_id, client_id):
"""
Expand All @@ -177,6 +195,7 @@ def release_lock(self, resource_id, client_id, pipeline=None):
the database.
:param resource_id: resource on which lock is being held
:param client_id: id of client holding the lock
:param pipeline: object that can queue multiple commands for later execution
"""
cache = pipeline or self._redis
cache.hset(resource_id, client_id, time() + EXPIRE_LOCK_DELAY)
Expand Down
17 changes: 6 additions & 11 deletions pybossa/sched.py
Original file line number Diff line number Diff line change
Expand Up @@ -356,7 +356,7 @@ def template_get_locked_task(project_id, user_id=None, user_ip=None,
for task_id, taskcount, n_answers, calibration, _, _, timeout in rows:
timeout = timeout or TIMEOUT
remaining = float('inf') if calibration else n_answers - taskcount
if acquire_lock(task_id, user_id, remaining, timeout):
if acquire_locks(task_id, user_id, remaining, timeout):
# reserve tasks
acquire_reserve_task_lock(project_id, task_id, user_id, timeout)
return _lock_task_for_user(task_id, project_id, user_id, timeout, calibration)
Expand Down Expand Up @@ -509,7 +509,6 @@ def locked_task_sql(project_id, user_id=None, limit=1, rand_within_priority=Fals
LIMIT :limit;
'''.format(' '.join(filters), task_category_filters,
','.join(order_by))
print(sql)
return text(sql)


Expand Down Expand Up @@ -577,16 +576,12 @@ def has_lock(task_id, user_id, timeout):
return lock_manager.has_lock(task_users_key, user_id)


def acquire_lock(task_id, user_id, limit, timeout, pipeline=None, execute=True):
redis_conn = sentinel.master
pipeline = pipeline or redis_conn.pipeline(transaction=True)
lock_manager = LockManager(redis_conn, timeout)
def acquire_locks(task_id, user_id, limit, timeout):
lock_manager = LockManager(sentinel.master, timeout)
task_users_key = get_task_users_key(task_id)
user_tasks_key = get_user_tasks_key(user_id)
if lock_manager.acquire_lock(task_users_key, user_id, limit, pipeline=pipeline):
lock_manager.acquire_lock(user_tasks_key, task_id, float('inf'), pipeline=pipeline)
if execute:
return all(not isinstance(r, Exception) for r in pipeline.execute())
if lock_manager.acquire_lock(task_users_key, user_id, limit):
lock_manager.acquire_lock(user_tasks_key, task_id, float('inf'))
return True
return False

Expand Down Expand Up @@ -686,7 +681,7 @@ def lock_task_for_user(task_id, project_id, user_id):
for task_id, taskcount, n_answers, calibration, timeout in rows:
timeout = timeout or TIMEOUT
remaining = float('inf') if calibration else n_answers - taskcount
if acquire_lock(task_id, user_id, remaining, timeout):
if acquire_locks(task_id, user_id, remaining, timeout):
# reserve tasks
acquire_reserve_task_lock(project_id, task_id, user_id, timeout)
return _lock_task_for_user(task_id, project_id, user_id, timeout, calibration)
Expand Down
2 changes: 1 addition & 1 deletion settings_test.py.tmpl
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ LDAP_PYBOSSA_FIELDS = {'fullname': 'givenName',
WEEKLY_ADMIN_REPORTS_EMAIL = ['admin@admin.com']

FLASK_PROFILER = {
"enabled": True,
"enabled": False, # disable so that sqlite works in multithreading
"storage": {
"engine": "sqlite"
},
Expand Down
40 changes: 38 additions & 2 deletions test/test_api/test_project_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
# along with PYBOSSA. If not, see <http://www.gnu.org/licenses/>.
import copy
import json
import threading
from unittest.mock import patch, call, MagicMock

from nose.tools import assert_equal
Expand All @@ -26,7 +27,7 @@
from pybossa.repositories import ResultRepository
from pybossa.repositories import TaskRepository
from pybossa.sched import Schedulers
from test import db, with_context
from test import db, with_context, flask_app
from test.factories import (ProjectFactory, TaskFactory, TaskRunFactory,
AnonymousTaskRunFactory, UserFactory,
CategoryFactory, AuditlogFactory)
Expand Down Expand Up @@ -1285,7 +1286,6 @@ def test_newtask_allow_anonymous_contributors(self, passwd_needed):
err_msg = "There should be a question"
assert task['info'].get('question') == 'answer', err_msg


@with_context
def test_newtask(self):
"""Test API project new_task method and authentication"""
Expand Down Expand Up @@ -1335,6 +1335,42 @@ def test_newtask(self):
task = json.loads(res.data)
assert task['id'] == tasks[0].id

@with_context
@patch('pybossa.api.pwd_manager.ProjectPasswdManager.password_needed')
def test_newtask_without_race_condition(self, password_needed):
    """Test API project new_task method without race condition.

    For each n_answers in 1..10, creates a project with a single task
    needing n_answers answers, then fires 10 users at /newtask
    concurrently; exactly n_answers of them must receive the task.

    :param password_needed: mocked so no project password is required
    """
    password_needed.return_value = False
    concurrent_user = 10

    n_answers_list = list(range(1, concurrent_user + 1))
    for n_answers in n_answers_list:
        project = ProjectFactory.create()
        project.info['sched'] = Schedulers.locked
        project_repo.save(project)
        users = UserFactory.create_batch(concurrent_user)
        # Appended to from worker threads; list.append is atomic in CPython.
        responses = []

        def api_call(user):
            # Rate-limit per user id so the burst of concurrent requests
            # from the test client is not throttled as a single client.
            with patch.dict(flask_app.config, {'RATE_LIMIT_BY_USER_ID': True}):
                url = f'/api/project/{project.id}/newtask?api_key={user.api_key}'
                res = self.app.get(url)
                # An empty JSON object means no task was available for
                # this user; only count real task payloads.
                if res.status_code == 200 and res.data != b'{}':
                    responses.append(json.loads(res.data))

        task = TaskFactory.create(n_answers=n_answers, project=project)

        threads = []
        for u in users:
            thread = threading.Thread(target=api_call, args=(u,))
            threads.append(thread)
            thread.start()

        for thread in threads:
            thread.join()
        # Without the redis-lock mutex fix, more than n_answers users
        # could be handed the task at once (the race condition).
        assert_equal(len(responses), task.n_answers)

@with_context
@patch('pybossa.repositories.project_repository.uploader')
Expand Down
40 changes: 35 additions & 5 deletions test/test_locked_sched.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,17 @@
#
# You should have received a copy of the GNU Affero General Public License
# along with PYBOSSA. If not, see <http://www.gnu.org/licenses/>.
import threading

from nose.tools import assert_equal

from test.helper import sched
from pybossa.core import project_repo, task_repo
from test.factories import TaskFactory, ProjectFactory, UserFactory
from pybossa.sched import (
Schedulers,
get_task_users_key,
acquire_lock,
acquire_locks,
has_lock,
get_task_id_and_duration_for_project_user,
get_task_id_project_id_key,
Expand Down Expand Up @@ -243,14 +246,41 @@ def test_user_logout_unlocks_locked_tasks(self, release_lock):
assert get_task_users_key(task2.id) in key_args

@with_context
def test_acquire_lock_no_pipeline(self):
def test_acquire_locks_no_pipeline(self):
task_id = 1
user_id = 1
limit = 1
timeout = 100
acquire_lock(task_id, user_id, limit, timeout)
acquire_locks(task_id, user_id, limit, timeout)
assert has_lock(task_id, user_id, limit)

@with_context
def test_acquire_locks_concurrently(self):
    """Test acquire_locks with 10 concurrent users grabbing limit
    (looping from 1 to 10) slots on one task; exactly `limit`
    acquisitions must succeed in each round.
    """
    # NOTE(review): this was 1, which ran a single thread and
    # contradicted the docstring's "10 concurrent users", making the
    # race-condition coverage vacuous; 10 restores the stated intent.
    concurrent_users = 10
    task_id = 1
    user_ids = list(range(concurrent_users))
    limits = list(range(1, concurrent_users + 1))
    timeout = 100

    for limit in limits:
        results = [False] * concurrent_users

        def call_acquire_locks(u_id):
            # Per-slot result; list item assignment is atomic in CPython.
            results[u_id] = acquire_locks(task_id, u_id, limit, timeout)

        threads = []
        for user_id in user_ids:
            thread = threading.Thread(target=call_acquire_locks,
                                      args=(user_id,))
            threads.append(thread)
            thread.start()

        for thread in threads:
            thread.join()
        # Locks persist between rounds (timeout=100s), so earlier
        # holders re-acquire via their existing lock and exactly
        # `limit` of the 10 attempts succeed each round.
        assert_equal(sum(results), limit)

@with_context
def test_get_task_id_and_duration_for_project_user_missing(self):
user = UserFactory.create()
Expand All @@ -259,7 +289,7 @@ def test_get_task_id_and_duration_for_project_user_missing(self):
task = TaskFactory.create_batch(1, project=project, n_answers=1)[0]
limit = 1
timeout = 100
acquire_lock(task.id, user.id, limit, timeout)
acquire_locks(task.id, user.id, limit, timeout)
task_id, _ = get_task_id_and_duration_for_project_user(project.id, user.id)

# Redis client returns bytes string in Python3
Expand Down Expand Up @@ -507,7 +537,7 @@ def test_lock_expiration(self):
res = self.app.get('api/project/{}/newtask?api_key={}'
.format(project.id, owner.api_key))
# fake expired user lock
acquire_lock(task1.id, 1000, 2, -10)
acquire_locks(task1.id, 1000, 2, -10)

res = self.app.get('api/project/{}/newtask?api_key={}'
.format(project.id, owner.api_key))
Expand Down

0 comments on commit 908d03a

Please sign in to comment.