diff --git a/pybossa/redis_lock.py b/pybossa/redis_lock.py
index 4f2469d299..bb59649cb9 100644
--- a/pybossa/redis_lock.py
+++ b/pybossa/redis_lock.py
@@ -132,7 +132,7 @@ def __init__(self, cache, duration):
         self._redis = cache
         self._duration = duration
 
-    def acquire_lock(self, resource_id, client_id, limit, pipeline=None):
+    def acquire_lock(self, resource_id, client_id, limit):
         """
         Acquire a lock on a resource.
         :param resource_id: resource on which lock is needed
@@ -145,13 +145,31 @@ def acquire_lock(self, resource_id, client_id, limit, pipeline=None):
         self._release_expired_locks(resource_id, timestamp)
         if self._redis.hexists(resource_id, client_id):
             return True
-        num_acquired = self._redis.hlen(resource_id)
-        if num_acquired < limit:
-            cache = pipeline or self._redis
-            cache.hset(resource_id, client_id, expiration)
-            cache.expire(resource_id, timedelta(seconds=self._duration))
+
+        pipeline = self._redis.pipeline()
+        if limit == float('inf'):
+            pipeline.hset(resource_id, client_id, expiration)
+            pipeline.expire(resource_id, timedelta(seconds=self._duration))
+            pipeline.execute()
             return True
-        return False
+
+        # Get a mutex lock for updating redis hash with default TTL 3s
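+        # Note: redis-py raises LockError when the mutex cannot be acquired
+        # within blocking_timeout; the return inside the finally block below
+        # swallows that exception and reports the lock as not acquired.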
+        lock_name = f"{resource_id}_update_mutex"
+        result = False
+
+        try:
+            with self._redis.lock(lock_name, timeout=3, blocking_timeout=1) as mutex:
+                if not mutex.locked():
+                    return False
+
+                num_acquired = self._redis.hlen(resource_id)
+                if num_acquired < limit:
+                    pipeline.hset(resource_id, client_id, expiration)
+                    pipeline.expire(resource_id, timedelta(seconds=self._duration))
+                    pipeline.execute()
+                    result = True
+        finally:
+            return result
 
     def has_lock(self, resource_id, client_id):
         """
@@ -177,6 +195,7 @@ def release_lock(self, resource_id, client_id, pipeline=None):
         the database.
         :param resource_id: resource on which lock is being held
         :param client_id: id of client holding the lock
+        :param pipeline: object that can queue multiple commands for later execution
         """
         cache = pipeline or self._redis
         cache.hset(resource_id, client_id, time() + EXPIRE_LOCK_DELAY)
diff --git a/pybossa/sched.py b/pybossa/sched.py
index 4cc5d68815..ab3a888ed6 100644
--- a/pybossa/sched.py
+++ b/pybossa/sched.py
@@ -356,7 +356,7 @@ def template_get_locked_task(project_id, user_id=None, user_ip=None,
     for task_id, taskcount, n_answers, calibration, _, _, timeout in rows:
         timeout = timeout or TIMEOUT
         remaining = float('inf') if calibration else n_answers - taskcount
-        if acquire_lock(task_id, user_id, remaining, timeout):
+        if acquire_locks(task_id, user_id, remaining, timeout):
             # reserve tasks
             acquire_reserve_task_lock(project_id, task_id, user_id, timeout)
             return _lock_task_for_user(task_id, project_id, user_id, timeout, calibration)
@@ -509,7 +509,6 @@ def locked_task_sql(project_id, user_id=None, limit=1, rand_within_priority=Fals
     LIMIT :limit;
     '''.format(' '.join(filters), task_category_filters, ','.join(order_by))
-    print(sql)
     return text(sql)
@@ -577,16 +576,12 @@ def has_lock(task_id, user_id, timeout):
     return lock_manager.has_lock(task_users_key, user_id)
 
 
-def acquire_lock(task_id, user_id, limit, timeout, pipeline=None, execute=True):
-    redis_conn = sentinel.master
-    pipeline = pipeline or redis_conn.pipeline(transaction=True)
-    lock_manager = LockManager(redis_conn, timeout)
+def acquire_locks(task_id, user_id, limit, timeout):
+    lock_manager = LockManager(sentinel.master, timeout)
     task_users_key = get_task_users_key(task_id)
     user_tasks_key = get_user_tasks_key(user_id)
-    if lock_manager.acquire_lock(task_users_key, user_id, limit, pipeline=pipeline):
-        lock_manager.acquire_lock(user_tasks_key, task_id, float('inf'), pipeline=pipeline)
-        if execute:
-            return all(not isinstance(r, Exception) for r in pipeline.execute())
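+    # Two hashes track each lock: task_users_key caps the number of holders
+    # at `limit`, while the reverse user->tasks mapping is unbounded and is
+    # only written once the capped acquisition succeeds.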
+    if lock_manager.acquire_lock(task_users_key, user_id, limit):
+        lock_manager.acquire_lock(user_tasks_key, task_id, float('inf'))
         return True
     return False
 
@@ -686,7 +681,7 @@ def lock_task_for_user(task_id, project_id, user_id):
     for task_id, taskcount, n_answers, calibration, timeout in rows:
         timeout = timeout or TIMEOUT
         remaining = float('inf') if calibration else n_answers - taskcount
-        if acquire_lock(task_id, user_id, remaining, timeout):
+        if acquire_locks(task_id, user_id, remaining, timeout):
             # reserve tasks
             acquire_reserve_task_lock(project_id, task_id, user_id, timeout)
             return _lock_task_for_user(task_id, project_id, user_id, timeout, calibration)
diff --git a/pybossa/view/projects.py b/pybossa/view/projects.py
index 7f17a9fa1d..44849f31f4 100644
--- a/pybossa/view/projects.py
+++ b/pybossa/view/projects.py
@@ -3225,8 +3225,11 @@ def add_coowner(short_name, user_name=None):
         if user.id in project.owners_ids:
             flash(gettext('User is already an owner'), 'warning')
         else:
+            old_list = project.owners_ids.copy()
             project.owners_ids.append(user.id)
             project_repo.update(project)
+            auditlogger.log_event(project, current_user, 'update',
+                                  'project.coowners', old_list, project.owners_ids)
             flash(gettext('User was added to list of owners'), 'success')
         return redirect_content_type(url_for(".coowners", short_name=short_name))
     return abort(404)
@@ -3248,8 +3251,11 @@ def del_coowner(short_name, user_name=None):
         elif user.id not in project.owners_ids:
             flash(gettext('User is not a project owner'), 'error')
         else:
+            old_list = project.owners_ids.copy()
             project.owners_ids.remove(user.id)
             project_repo.update(project)
+            auditlogger.log_event(project, current_user, 'update',
+                                  'project.coowners', old_list, project.owners_ids)
             flash(gettext('User was deleted from the list of owners'), 'success')
 
     return redirect_content_type(url_for('.coowners', short_name=short_name))
diff --git a/settings_test.py.tmpl b/settings_test.py.tmpl
index 24f3a066ac..263eb7aa9c 100644
--- a/settings_test.py.tmpl
+++ b/settings_test.py.tmpl
@@ -84,7 +84,7 @@ LDAP_PYBOSSA_FIELDS = {'fullname': 'givenName',
 WEEKLY_ADMIN_REPORTS_EMAIL = ['admin@admin.com']
 
 FLASK_PROFILER = {
-    "enabled": True,
+    "enabled": False,  # disabled so the sqlite storage backend works with multithreaded tests
     "storage": {
         "engine": "sqlite"
     },
diff --git a/test/test_api/test_project_api.py b/test/test_api/test_project_api.py
index c24e217d5a..a5afb17b77 100644
--- a/test/test_api/test_project_api.py
+++ b/test/test_api/test_project_api.py
@@ -17,6 +17,7 @@
 # along with PYBOSSA. If not, see <http://www.gnu.org/licenses/>.
 import copy
 import json
+import threading
 from unittest.mock import patch, call, MagicMock
 
 from nose.tools import assert_equal
@@ -26,7 +27,7 @@
 from pybossa.repositories import ResultRepository
 from pybossa.repositories import TaskRepository
 from pybossa.sched import Schedulers
-from test import db, with_context
+from test import db, with_context, flask_app
 from test.factories import (ProjectFactory, TaskFactory, TaskRunFactory,
                             AnonymousTaskRunFactory, UserFactory,
                             CategoryFactory, AuditlogFactory)
@@ -1285,7 +1286,6 @@ def test_newtask_allow_anonymous_contributors(self, passwd_needed):
         err_msg = "There should be a question"
         assert task['info'].get('question') == 'answer', err_msg
 
-
     @with_context
     def test_newtask(self):
         """Test API project new_task method and authentication"""
@@ -1335,6 +1335,42 @@ def test_newtask(self):
         task = json.loads(res.data)
         assert task['id'] == tasks[0].id
 
+    @with_context
+    @patch('pybossa.api.pwd_manager.ProjectPasswdManager.password_needed')
+    def test_newtask_without_race_condition(self, password_needed):
+        """Test API project new_task method without race condition.
+        Simulate 10 users concurrently requesting a task whose n_answers varies from 1 to 10.
+        """
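+        # Each loop iteration creates one task with n_answers slots and fires
+        # 10 concurrent newtask requests; the locked scheduler should hand
+        # the task to exactly n_answers of the competing users.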
+        password_needed.return_value = False
+        concurrent_user = 10
+
+        n_answers_list = list(range(1, concurrent_user + 1))
+        for n_answers in n_answers_list:
+            project = ProjectFactory.create()
+            project.info['sched'] = Schedulers.locked
+            project_repo.save(project)
+            users = UserFactory.create_batch(concurrent_user)
+            responses = []
+
+            def api_call(user):
+                with patch.dict(flask_app.config, {'RATE_LIMIT_BY_USER_ID': True}):
+                    url = f'/api/project/{project.id}/newtask?api_key={user.api_key}'
+                    # self.set_proj_passwd_cookie(project, user)
+                    res = self.app.get(url)
+                    if res.status_code == 200 and res.data != b'{}':
+                        responses.append(json.loads(res.data))
+
+            task = TaskFactory.create(n_answers=n_answers, project=project)
+
+            threads = []
+            for u in users:
+                thread = threading.Thread(target=api_call, args=(u,))
+                threads.append(thread)
+                thread.start()
+
+            for thread in threads:
+                thread.join()
+            assert_equal(len(responses), task.n_answers)
 
     @with_context
     @patch('pybossa.repositories.project_repository.uploader')
diff --git a/test/test_locked_sched.py b/test/test_locked_sched.py
index ea8456b6db..b0790a1bf4 100644
--- a/test/test_locked_sched.py
+++ b/test/test_locked_sched.py
@@ -15,6 +15,9 @@
 #
 # You should have received a copy of the GNU Affero General Public License
 # along with PYBOSSA. If not, see <http://www.gnu.org/licenses/>.
+import threading
+
+from nose.tools import assert_equal
 
 from test.helper import sched
 from pybossa.core import project_repo, task_repo
@@ -22,7 +25,7 @@ from pybossa.sched import (
     Schedulers,
     get_task_users_key,
-    acquire_lock,
+    acquire_locks,
     has_lock,
     get_task_id_and_duration_for_project_user,
     get_task_id_project_id_key,
@@ -243,14 +246,41 @@ def test_user_logout_unlocks_locked_tasks(self, release_lock):
         assert get_task_users_key(task2.id) in key_args
 
     @with_context
-    def test_acquire_lock_no_pipeline(self):
+    def test_acquire_locks_no_pipeline(self):
         task_id = 1
         user_id = 1
         limit = 1
         timeout = 100
-        acquire_lock(task_id, user_id, limit, timeout)
+        acquire_locks(task_id, user_id, limit, timeout)
         assert has_lock(task_id, user_id, limit)
 
+    @with_context
+    def test_acquire_locks_concurrently(self):
+        """Test acquire_locks with 10 concurrent users competing for limit slots, with limit looping from 1 to 10"""
+        concurrent_users = 10
+        task_id = 1
+        user_ids = list(range(concurrent_users))
+        limits = list(range(1, concurrent_users + 1))
+        timeout = 100
+
+        for limit in limits:
+            results = [False] * concurrent_users
+
+            def call_acquire_locks(u_id):
+                result = acquire_locks(task_id, u_id, limit, timeout)
+                results[u_id] = result
+
+            threads = []
+            for user_id in user_ids:
+                thread = threading.Thread(target=call_acquire_locks,
+                                          args=(user_id,))
+                threads.append(thread)
+                thread.start()
+
+            for thread in threads:
+                thread.join()
+            assert_equal(sum(results), limit)
+
     @with_context
     def test_get_task_id_and_duration_for_project_user_missing(self):
         user = UserFactory.create()
@@ -259,7 +289,7 @@ def test_get_task_id_and_duration_for_project_user_missing(self):
         task = TaskFactory.create_batch(1, project=project, n_answers=1)[0]
         limit = 1
         timeout = 100
-        acquire_lock(task.id, user.id, limit, timeout)
+        acquire_locks(task.id, user.id, limit, timeout)
         task_id, _ = get_task_id_and_duration_for_project_user(project.id,
                                                                user.id)
         # Redis client returns bytes string in Python3
@@ -507,7 +537,7 @@ def test_lock_expiration(self):
         res = self.app.get('api/project/{}/newtask?api_key={}'
                            .format(project.id, owner.api_key))
         # fake expired user lock
-        acquire_lock(task1.id, 1000, 2, -10)
+        acquire_locks(task1.id, 1000, 2, -10)
         res = self.app.get('api/project/{}/newtask?api_key={}'
                            .format(project.id, owner.api_key))
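The behavior these tests pin down can also be exercised directly against
LockManager. Below is a minimal standalone sketch of capped lock acquisition
under contention; it assumes a local Redis reachable via redis-py, and the
resource key name is illustrative rather than pybossa's real key format:

import threading

from redis import StrictRedis

from pybossa.redis_lock import LockManager

redis_conn = StrictRedis()  # assumes redis on localhost:6379
manager = LockManager(redis_conn, 60)  # locks expire after 60 seconds

winners = []

def grab(client_id):
    # All ten threads race for the same resource; at most `limit` may win.
    if manager.acquire_lock('task:42:users', client_id, limit=2):
        winners.append(client_id)

threads = [threading.Thread(target=grab, args=(i,)) for i in range(10)]
for t in threads:
    t.start()
for t in threads:
    t.join()

# Without the mutex, two racing clients could both pass the hlen check and
# exceed the limit; with it, no more than `limit` clients hold the lock.
assert len(winners) <= 2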