From f8e5161569802db64d1d7d1d4864f8a77c4fe2dd Mon Sep 17 00:00:00 2001 From: Kiall Mac Innes Date: Tue, 6 Jan 2015 14:13:49 +0000 Subject: [PATCH] Retry transactions on database deadlocks Retry transactions upon database deadlocks, helping to ensure all requests are processed even while there is contention on the domain.serial column. Additionally, due to the copy.deepcopy() introduced in the retry decorator, code which previously relied on the input values to central being mutated in place (only the tests, as everything else was over RPC so couldn't have been mutated in place) can no longer rely on this behavior. Change-Id: Id470608d7cc6c34c133803ba34b9bf242dc5e6ae Closes-Bug: 1408336 --- designate/central/service.py | 68 +++++++++++++- designate/tests/test_central/test_service.py | 96 +++++++++++++++++--- 2 files changed, 149 insertions(+), 15 deletions(-) diff --git a/designate/central/service.py b/designate/central/service.py index d45e2f38a..9faf16edb 100644 --- a/designate/central/service.py +++ b/designate/central/service.py @@ -16,17 +16,20 @@ # under the License. import re import collections +import copy import functools import threading import itertools import string import random +import time from oslo.config import cfg from oslo import messaging from oslo_log import log as logging from oslo_utils import excutils from oslo_concurrency import lockutils +from oslo_db import exception as db_exception from designate.i18n import _LI from designate.i18n import _LC @@ -46,21 +49,78 @@ LOG = logging.getLogger(__name__) DOMAIN_LOCKS = threading.local() NOTIFICATION_BUFFER = threading.local() +RETRY_STATE = threading.local() +def _retry_on_deadlock(exc): + """Filter to trigger a retry when a Deadlock is received.""" + # TODO(kiall): This is a total leak of the SQLA Driver, we'll need a better + # way to handle this. + if isinstance(exc, db_exception.DBDeadlock): + LOG.warn(_LW("Deadlock detected. 
Retrying...")) + return True + return False + + +def retry(cb=None, retries=50, delay=150): + """A retry decorator that ignores attempts at creating nested retries""" + def outer(f): + @functools.wraps(f) + def wrapper(self, *args, **kwargs): + if not hasattr(RETRY_STATE, 'held'): + # Create the state vars if necessary + RETRY_STATE.held = False + RETRY_STATE.retries = 0 + + if not RETRY_STATE.held: + # We're the outermost retry decorator + RETRY_STATE.held = True + + try: + while True: + try: + result = f(self, *copy.deepcopy(args), + **copy.deepcopy(kwargs)) + break + except Exception as exc: + RETRY_STATE.retries += 1 + if RETRY_STATE.retries >= retries: + # Exceeded retry attempts, raise. + raise + elif cb is not None and cb(exc) is False: + # We're not setup to retry on this exception. + raise + else: + # Retry, with a delay. + time.sleep(delay / float(1000)) + + finally: + RETRY_STATE.held = False + RETRY_STATE.retries = 0 + + else: + # We're an inner retry decorator, just pass on through. 
+ result = f(self, *copy.deepcopy(args), **copy.deepcopy(kwargs)) + + return result + return wrapper + return outer + + +# TODO(kiall): Get this a better home :) def transaction(f): - # TODO(kiall): Get this a better home :) + @retry(cb=_retry_on_deadlock) @functools.wraps(f) def wrapper(self, *args, **kwargs): self.storage.begin() try: result = f(self, *args, **kwargs) + self.storage.commit() + return result except Exception: with excutils.save_and_reraise_exception(): self.storage.rollback() - else: - self.storage.commit() - return result + return wrapper diff --git a/designate/tests/test_central/test_service.py b/designate/tests/test_central/test_service.py index 52d10b2c7..0804aee2d 100644 --- a/designate/tests/test_central/test_service.py +++ b/designate/tests/test_central/test_service.py @@ -17,10 +17,11 @@ import copy import random -from mock import patch -from oslo_log import log as logging import testtools from testtools.matchers import GreaterThan +from mock import patch +from oslo_log import log as logging +from oslo_db import exception as db_exception from designate import exceptions from designate import objects @@ -865,6 +866,39 @@ def test_update_domain_name_fail(self): with testtools.ExpectedException(exceptions.BadRequest): self.central_service.update_domain(self.admin_context, domain) + def test_update_domain_deadlock_retry(self): + # Create a domain + domain = self.create_domain(name='example.org.') + original_serial = domain.serial + + # Update the Object + domain.email = 'info@example.net' + + # Due to Python's scoping of i - we need to make it a mutable type + # for the counter to work.. In Py3, we can use the nonlocal keyword. 
+ i = [False] + + def fail_once_then_pass(): + if i[0] is True: + return self.central_service.storage.session.commit() + else: + i[0] = True + raise db_exception.DBDeadlock() + + with patch.object(self.central_service.storage, 'commit', + side_effect=fail_once_then_pass): + # Perform the update + domain = self.central_service.update_domain( + self.admin_context, domain) + + # Ensure i[0] is True, indicating the side_effect code above was + # triggered + self.assertTrue(i[0]) + + # Ensure the domain was updated correctly + self.assertTrue(domain.serial > original_serial) + self.assertEqual('info@example.net', domain.email) + def test_delete_domain(self): # Create a domain domain = self.create_domain() @@ -1242,6 +1276,40 @@ def test_update_recordset(self): self.assertEqual(recordset.ttl, 1800) self.assertThat(new_serial, GreaterThan(original_serial)) + def test_update_recordset_deadlock_retry(self): + # Create a domain + domain = self.create_domain() + + # Create a recordset + recordset = self.create_recordset(domain) + + # Update the recordset + recordset.ttl = 1800 + + # Due to Python's scoping of i - we need to make it a mutable type + # for the counter to work.. In Py3, we can use the nonlocal keyword. 
+ i = [False] + + def fail_once_then_pass(): + if i[0] is True: + return self.central_service.storage.session.commit() + else: + i[0] = True + raise db_exception.DBDeadlock() + + with patch.object(self.central_service.storage, 'commit', + side_effect=fail_once_then_pass): + # Perform the update + recordset = self.central_service.update_recordset( + self.admin_context, recordset) + + # Ensure i[0] is True, indicating the side_effect code above was + # triggered + self.assertTrue(i[0]) + + # Ensure the recordset was updated correctly + self.assertEqual(1800, recordset.ttl) + def test_update_recordset_with_record_create(self): # Create a domain domain = self.create_domain() @@ -2400,16 +2468,20 @@ def test_create_pool(self): # Compare the actual values of attributes and nameservers for k in range(0, len(values['attributes'])): - self.assertEqual( - pool['attributes'][k].to_primitive()['designate_object.data'], - values['attributes'][k].to_primitive()['designate_object.data'] + self.assertDictContainsSubset( + values['attributes'][k].to_primitive() + ['designate_object.data'], + pool['attributes'][k].to_primitive() + ['designate_object.data'] ) for k in range(0, len(values['nameservers'])): - self.assertEqual( - pool['nameservers'][k].to_primitive()['designate_object.data'], + self.assertDictContainsSubset( values['nameservers'][k].to_primitive() - ['designate_object.data']) + ['designate_object.data'], + pool['nameservers'][k].to_primitive() + ['designate_object.data'] + ) def test_get_pool(self): # Create a server pool @@ -2509,7 +2581,7 @@ def test_update_pool(self): for r in nameserver_values]) # Update pool - self.central_service.update_pool(self.admin_context, pool) + pool = self.central_service.update_pool(self.admin_context, pool) # GET the pool pool = self.central_service.get_pool(self.admin_context, pool.id) @@ -2523,14 +2595,16 @@ def test_update_pool(self): pool['attributes'][0].to_primitive()['designate_object.data'] expected_attributes = \ 
pool_attributes[0].to_primitive()['designate_object.data'] - self.assertEqual(actual_attributes, expected_attributes) + self.assertDictContainsSubset( + expected_attributes, actual_attributes) for k in range(0, len(pool_nameservers)): actual_nameservers = \ pool['nameservers'][k].to_primitive()['designate_object.data'] expected_nameservers = \ pool_nameservers[k].to_primitive()['designate_object.data'] - self.assertEqual(actual_nameservers, expected_nameservers) + self.assertDictContainsSubset( + expected_nameservers, actual_nameservers) def test_delete_pool(self): # Create a server pool