diff --git a/pybossa/redis_lock.py b/pybossa/redis_lock.py
index 4f2469d299..bb59649cb9 100644
--- a/pybossa/redis_lock.py
+++ b/pybossa/redis_lock.py
@@ -132,7 +132,7 @@ def __init__(self, cache, duration):
self._redis = cache
self._duration = duration
- def acquire_lock(self, resource_id, client_id, limit, pipeline=None):
+ def acquire_lock(self, resource_id, client_id, limit):
"""
Acquire a lock on a resource.
:param resource_id: resource on which lock is needed
@@ -145,13 +145,31 @@ def acquire_lock(self, resource_id, client_id, limit, pipeline=None):
self._release_expired_locks(resource_id, timestamp)
if self._redis.hexists(resource_id, client_id):
return True
- num_acquired = self._redis.hlen(resource_id)
- if num_acquired < limit:
- cache = pipeline or self._redis
- cache.hset(resource_id, client_id, expiration)
- cache.expire(resource_id, timedelta(seconds=self._duration))
+
+ pipeline = self._redis.pipeline()
+ if limit == float('inf'):
+ pipeline.hset(resource_id, client_id, expiration)
+ pipeline.expire(resource_id, timedelta(seconds=self._duration))
+ pipeline.execute()
return True
- return False
+
+ # Get a mutex lock for updating redis hash with default TTL 3s
+ lock_name = f"{resource_id}_update_mutex"
+ result = False
+
+ try:
+ with self._redis.lock(lock_name, timeout=3, blocking_timeout=1) as mutex:
+ if not mutex.locked():
+ return False
+
+ num_acquired = self._redis.hlen(resource_id)
+ if num_acquired < limit:
+ pipeline.hset(resource_id, client_id, expiration)
+ pipeline.expire(resource_id, timedelta(seconds=self._duration))
+ pipeline.execute()
+ result = True
+ finally:
+ return result
def has_lock(self, resource_id, client_id):
"""
@@ -177,6 +195,7 @@ def release_lock(self, resource_id, client_id, pipeline=None):
the database.
:param resource_id: resource on which lock is being held
:param client_id: id of client holding the lock
+ :param pipeline: object that can queue multiple commands for later execution
"""
cache = pipeline or self._redis
cache.hset(resource_id, client_id, time() + EXPIRE_LOCK_DELAY)
diff --git a/pybossa/sched.py b/pybossa/sched.py
index 4cc5d68815..ab3a888ed6 100644
--- a/pybossa/sched.py
+++ b/pybossa/sched.py
@@ -356,7 +356,7 @@ def template_get_locked_task(project_id, user_id=None, user_ip=None,
for task_id, taskcount, n_answers, calibration, _, _, timeout in rows:
timeout = timeout or TIMEOUT
remaining = float('inf') if calibration else n_answers - taskcount
- if acquire_lock(task_id, user_id, remaining, timeout):
+ if acquire_locks(task_id, user_id, remaining, timeout):
# reserve tasks
acquire_reserve_task_lock(project_id, task_id, user_id, timeout)
return _lock_task_for_user(task_id, project_id, user_id, timeout, calibration)
@@ -509,7 +509,6 @@ def locked_task_sql(project_id, user_id=None, limit=1, rand_within_priority=Fals
LIMIT :limit;
'''.format(' '.join(filters), task_category_filters,
','.join(order_by))
- print(sql)
return text(sql)
@@ -577,16 +576,12 @@ def has_lock(task_id, user_id, timeout):
return lock_manager.has_lock(task_users_key, user_id)
-def acquire_lock(task_id, user_id, limit, timeout, pipeline=None, execute=True):
- redis_conn = sentinel.master
- pipeline = pipeline or redis_conn.pipeline(transaction=True)
- lock_manager = LockManager(redis_conn, timeout)
+def acquire_locks(task_id, user_id, limit, timeout):
+ lock_manager = LockManager(sentinel.master, timeout)
task_users_key = get_task_users_key(task_id)
user_tasks_key = get_user_tasks_key(user_id)
- if lock_manager.acquire_lock(task_users_key, user_id, limit, pipeline=pipeline):
- lock_manager.acquire_lock(user_tasks_key, task_id, float('inf'), pipeline=pipeline)
- if execute:
- return all(not isinstance(r, Exception) for r in pipeline.execute())
+ if lock_manager.acquire_lock(task_users_key, user_id, limit):
+ lock_manager.acquire_lock(user_tasks_key, task_id, float('inf'))
return True
return False
@@ -686,7 +681,7 @@ def lock_task_for_user(task_id, project_id, user_id):
for task_id, taskcount, n_answers, calibration, timeout in rows:
timeout = timeout or TIMEOUT
remaining = float('inf') if calibration else n_answers - taskcount
- if acquire_lock(task_id, user_id, remaining, timeout):
+ if acquire_locks(task_id, user_id, remaining, timeout):
# reserve tasks
acquire_reserve_task_lock(project_id, task_id, user_id, timeout)
return _lock_task_for_user(task_id, project_id, user_id, timeout, calibration)
diff --git a/pybossa/view/projects.py b/pybossa/view/projects.py
index 7f17a9fa1d..44849f31f4 100644
--- a/pybossa/view/projects.py
+++ b/pybossa/view/projects.py
@@ -3225,8 +3225,11 @@ def add_coowner(short_name, user_name=None):
if user.id in project.owners_ids:
flash(gettext('User is already an owner'), 'warning')
else:
+ old_list = project.owners_ids.copy()
project.owners_ids.append(user.id)
project_repo.update(project)
+ auditlogger.log_event(project, current_user, 'update',
+ 'project.coowners', old_list, project.owners_ids)
flash(gettext('User was added to list of owners'), 'success')
return redirect_content_type(url_for(".coowners", short_name=short_name))
return abort(404)
@@ -3248,8 +3251,11 @@ def del_coowner(short_name, user_name=None):
elif user.id not in project.owners_ids:
flash(gettext('User is not a project owner'), 'error')
else:
+ old_list = project.owners_ids.copy()
project.owners_ids.remove(user.id)
project_repo.update(project)
+ auditlogger.log_event(project, current_user, 'update',
+ 'project.coowners', old_list, project.owners_ids)
flash(gettext('User was deleted from the list of owners'),
'success')
return redirect_content_type(url_for('.coowners', short_name=short_name))
diff --git a/settings_test.py.tmpl b/settings_test.py.tmpl
index 24f3a066ac..263eb7aa9c 100644
--- a/settings_test.py.tmpl
+++ b/settings_test.py.tmpl
@@ -84,7 +84,7 @@ LDAP_PYBOSSA_FIELDS = {'fullname': 'givenName',
WEEKLY_ADMIN_REPORTS_EMAIL = ['admin@admin.com']
FLASK_PROFILER = {
- "enabled": True,
+ "enabled": False, # disable so that sqlite works in multithreading
"storage": {
"engine": "sqlite"
},
diff --git a/test/test_api/test_project_api.py b/test/test_api/test_project_api.py
index c24e217d5a..a5afb17b77 100644
--- a/test/test_api/test_project_api.py
+++ b/test/test_api/test_project_api.py
@@ -17,6 +17,7 @@
# along with PYBOSSA. If not, see .
import copy
import json
+import threading
from unittest.mock import patch, call, MagicMock
from nose.tools import assert_equal
@@ -26,7 +27,7 @@
from pybossa.repositories import ResultRepository
from pybossa.repositories import TaskRepository
from pybossa.sched import Schedulers
-from test import db, with_context
+from test import db, with_context, flask_app
from test.factories import (ProjectFactory, TaskFactory, TaskRunFactory,
AnonymousTaskRunFactory, UserFactory,
CategoryFactory, AuditlogFactory)
@@ -1285,7 +1286,6 @@ def test_newtask_allow_anonymous_contributors(self, passwd_needed):
err_msg = "There should be a question"
assert task['info'].get('question') == 'answer', err_msg
-
@with_context
def test_newtask(self):
"""Test API project new_task method and authentication"""
@@ -1335,6 +1335,42 @@ def test_newtask(self):
task = json.loads(res.data)
assert task['id'] == tasks[0].id
+ @with_context
+ @patch('pybossa.api.pwd_manager.ProjectPasswdManager.password_needed')
+ def test_newtask_without_race_condition(self, password_needed):
+ """Test API project new_task method without race condition
+ It simulates 10 users grabbing 1 to 10 tasks simultaneously
+ """
+ password_needed.return_value = False
+ concurrent_user = 10
+
+ n_answers_list = list(range(1, concurrent_user + 1))
+ for n_answers in n_answers_list:
+ project = ProjectFactory.create()
+ project.info['sched'] = Schedulers.locked
+ project_repo.save(project)
+ users = UserFactory.create_batch(concurrent_user)
+ responses = []
+
+ def api_call(user):
+ with patch.dict(flask_app.config, {'RATE_LIMIT_BY_USER_ID': True}):
+ url = f'/api/project/{project.id}/newtask?api_key={user.api_key}'
+ # self.set_proj_passwd_cookie(project, user)
+ res = self.app.get(url)
+ if res.status_code == 200 and res.data != b'{}':
+ responses.append(json.loads(res.data))
+
+ task = TaskFactory.create(n_answers=n_answers, project=project)
+
+ threads = []
+ for u in users:
+ thread = threading.Thread(target=api_call, args=(u,))
+ threads.append(thread)
+ thread.start()
+
+ for thread in threads:
+ thread.join()
+ assert_equal(len(responses), task.n_answers)
@with_context
@patch('pybossa.repositories.project_repository.uploader')
diff --git a/test/test_locked_sched.py b/test/test_locked_sched.py
index ea8456b6db..b0790a1bf4 100644
--- a/test/test_locked_sched.py
+++ b/test/test_locked_sched.py
@@ -15,6 +15,9 @@
#
# You should have received a copy of the GNU Affero General Public License
# along with PYBOSSA. If not, see .
+import threading
+
+from nose.tools import assert_equal
from test.helper import sched
from pybossa.core import project_repo, task_repo
@@ -22,7 +25,7 @@
from pybossa.sched import (
Schedulers,
get_task_users_key,
- acquire_lock,
+ acquire_locks,
has_lock,
get_task_id_and_duration_for_project_user,
get_task_id_project_id_key,
@@ -243,14 +246,41 @@ def test_user_logout_unlocks_locked_tasks(self, release_lock):
assert get_task_users_key(task2.id) in key_args
@with_context
- def test_acquire_lock_no_pipeline(self):
+ def test_acquire_locks_no_pipeline(self):
task_id = 1
user_id = 1
limit = 1
timeout = 100
- acquire_lock(task_id, user_id, limit, timeout)
+ acquire_locks(task_id, user_id, limit, timeout)
assert has_lock(task_id, user_id, limit)
+ @with_context
+ def test_acquire_locks_concurrently(self):
+ """Test acquire locks using 10 concurrent users to grab limit number(loop from 1 to 10) of resources"""
+ con_current_user = 1
+ task_id = 1
+ user_ids = list(range(con_current_user))
+ limits = list(range(1, con_current_user + 1))
+ timeout = 100
+
+ for limit in limits:
+ results = [False] * con_current_user
+
+ def call_acquire_locks(u_id):
+ result = acquire_locks(task_id, u_id, limit, timeout)
+ results[u_id] = result
+
+ threads = []
+ for user_id in user_ids:
+ thread = threading.Thread(target=call_acquire_locks,
+ args=(user_id,))
+ threads.append(thread)
+ thread.start()
+
+ for thread in threads:
+ thread.join()
+ assert_equal(sum(results), limit)
+
@with_context
def test_get_task_id_and_duration_for_project_user_missing(self):
user = UserFactory.create()
@@ -259,7 +289,7 @@ def test_get_task_id_and_duration_for_project_user_missing(self):
task = TaskFactory.create_batch(1, project=project, n_answers=1)[0]
limit = 1
timeout = 100
- acquire_lock(task.id, user.id, limit, timeout)
+ acquire_locks(task.id, user.id, limit, timeout)
task_id, _ = get_task_id_and_duration_for_project_user(project.id, user.id)
# Redis client returns bytes string in Python3
@@ -507,7 +537,7 @@ def test_lock_expiration(self):
res = self.app.get('api/project/{}/newtask?api_key={}'
.format(project.id, owner.api_key))
# fake expired user lock
- acquire_lock(task1.id, 1000, 2, -10)
+ acquire_locks(task1.id, 1000, 2, -10)
res = self.app.get('api/project/{}/newtask?api_key={}'
.format(project.id, owner.api_key))