Skip to content

Commit

Permalink
Enable efficient chord when using dynamicdb as backend store (#8783)
Browse files Browse the repository at this point in the history
* test

* add unit test

* test

* revert bad test chamnge

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
dingxiong and pre-commit-ci[bot] committed Jan 29, 2024
1 parent 2b3fde4 commit 86895a9
Show file tree
Hide file tree
Showing 4 changed files with 150 additions and 5 deletions.
2 changes: 1 addition & 1 deletion celery/backends/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -1080,7 +1080,7 @@ def on_chord_part_return(self, request, state, result, **kwargs):
)
finally:
deps.delete()
self.client.delete(key)
self.delete(key)
else:
self.expire(key, self.expires)

Expand Down
54 changes: 54 additions & 0 deletions celery/backends/dynamodb.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""AWS DynamoDB result store backend."""
from collections import namedtuple
from time import sleep, time
from typing import Any, Dict

from kombu.utils.url import _parse_url as parse_url

Expand Down Expand Up @@ -54,11 +55,15 @@ class DynamoDBBackend(KeyValueStoreBackend):
supports_autoexpire = True

_key_field = DynamoDBAttribute(name='id', data_type='S')
# Each record has either a value field or count field
_value_field = DynamoDBAttribute(name='result', data_type='B')
_count_filed = DynamoDBAttribute(name="chord_count", data_type='N')
_timestamp_field = DynamoDBAttribute(name='timestamp', data_type='N')
_ttl_field = DynamoDBAttribute(name='ttl', data_type='N')
_available_fields = None

implements_incr = True

def __init__(self, url=None, table_name=None, *args, **kwargs):
super().__init__(*args, **kwargs)

Expand Down Expand Up @@ -459,6 +464,40 @@ def _prepare_put_request(self, key, value):
})
return put_request

def _prepare_init_count_request(self, key: str) -> Dict[str, Any]:
"""Construct the counter initialization request parameters"""
timestamp = time()
return {
'TableName': self.table_name,
'Item': {
self._key_field.name: {
self._key_field.data_type: key
},
self._count_filed.name: {
self._count_filed.data_type: "0"
},
self._timestamp_field.name: {
self._timestamp_field.data_type: str(timestamp)
}
}
}

def _prepare_inc_count_request(self, key: str) -> Dict[str, Any]:
"""Construct the counter increment request parameters"""
return {
'TableName': self.table_name,
'Key': {
self._key_field.name: {
self._key_field.data_type: key
}
},
'UpdateExpression': f"set {self._count_filed.name} = {self._count_filed.name} + :num",
"ExpressionAttributeValues": {
":num": {"N": "1"},
},
"ReturnValues" : "UPDATED_NEW",
}

def _item_to_dict(self, raw_response):
"""Convert get_item() response to field-value pairs."""
if 'Item' not in raw_response:
Expand Down Expand Up @@ -491,3 +530,18 @@ def delete(self, key):
key = str(key)
request_parameters = self._prepare_get_request(key)
self.client.delete_item(**request_parameters)

def incr(self, key: bytes) -> int:
"""Atomically increase the chord_count and return the new count"""
key = str(key)
request_parameters = self._prepare_inc_count_request(key)
item_response = self.client.update_item(**request_parameters)
new_count: str = item_response["Attributes"][self._count_filed.name][self._count_filed.data_type]
return int(new_count)

def _apply_chord_incr(self, header_result_args, body, **kwargs):
chord_key = self.get_key_for_chord(header_result_args[0])
init_count_request = self._prepare_init_count_request(str(chord_key))
self.client.put_item(**init_count_request)
return super()._apply_chord_incr(
header_result_args, body, **kwargs)
4 changes: 2 additions & 2 deletions docs/userguide/canvas.rst
Original file line number Diff line number Diff line change
Expand Up @@ -1000,11 +1000,11 @@ Example implementation:
raise self.retry(countdown=interval, max_retries=max_retries)
This is used by all result backends except Redis and Memcached: they
This is used by all result backends except Redis, Memcached and DynamoDB: they
increment a counter after each task in the header, then applies the callback
when the counter exceeds the number of tasks in the set.

The Redis and Memcached approach is a much better solution, but not easily
The Redis, Memcached and DynamoDB approach is a much better solution, but not easily
implemented in other backends (suggestions welcome!).

.. note::
Expand Down
95 changes: 93 additions & 2 deletions t/unit/backends/test_dynamodb.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
from decimal import Decimal
from unittest.mock import MagicMock, Mock, patch, sentinel
from unittest.mock import ANY, MagicMock, Mock, call, patch, sentinel

import pytest

from celery import states
from celery import states, uuid
from celery.backends import dynamodb as module
from celery.backends.dynamodb import DynamoDBBackend
from celery.exceptions import ImproperlyConfigured
Expand Down Expand Up @@ -426,6 +426,34 @@ def test_prepare_put_request_with_ttl(self):
result = self.backend._prepare_put_request('abcdef', 'val')
assert result == expected

def test_prepare_init_count_request(self):
expected = {
'TableName': 'celery',
'Item': {
'id': {'S': 'abcdef'},
'chord_count': {'N': '0'},
'timestamp': {
'N': str(Decimal(self._static_timestamp))
},
}
}
with patch('celery.backends.dynamodb.time', self._mock_time):
result = self.backend._prepare_init_count_request('abcdef')
assert result == expected

def test_prepare_inc_count_request(self):
expected = {
'TableName': 'celery',
'Key': {
'id': {'S': 'abcdef'},
},
'UpdateExpression': 'set chord_count = chord_count + :num',
'ExpressionAttributeValues': {":num": {"N": "1"}},
'ReturnValues': 'UPDATED_NEW',
}
result = self.backend._prepare_inc_count_request('abcdef')
assert result == expected

def test_item_to_dict(self):
boto_response = {
'Item': {
Expand Down Expand Up @@ -517,6 +545,39 @@ def test_delete(self):
TableName='celery'
)

def test_inc(self):
mocked_incr_response = {
'Attributes': {
'chord_count': {
'N': '1'
}
},
'ResponseMetadata': {
'RequestId': '16d31c72-51f6-4538-9415-499f1135dc59',
'HTTPStatusCode': 200,
'HTTPHeaders': {
'date': 'Wed, 10 Jan 2024 17:53:41 GMT',
'x-amzn-requestid': '16d31c72-51f6-4538-9415-499f1135dc59',
'content-type': 'application/x-amz-json-1.0',
'x-amz-crc32': '3438282865',
'content-length': '40',
'server': 'Jetty(11.0.17)'
},
'RetryAttempts': 0
}
}
self.backend._client = MagicMock()
self.backend._client.update_item = MagicMock(return_value=mocked_incr_response)

assert self.backend.incr('1f3fab') == 1
self.backend.client.update_item.assert_called_once_with(
Key={'id': {'S': '1f3fab'}},
TableName='celery',
UpdateExpression='set chord_count = chord_count + :num',
ExpressionAttributeValues={":num": {"N": "1"}},
ReturnValues='UPDATED_NEW',
)

def test_backend_by_url(self, url='dynamodb://'):
from celery.app import backends
from celery.backends.dynamodb import DynamoDBBackend
Expand All @@ -537,3 +598,33 @@ def test_backend_params_by_url(self):
assert self.backend.write_capacity_units == 20
assert self.backend.time_to_live_seconds == 600
assert self.backend.endpoint_url is None

def test_apply_chord(self, unlock="celery.chord_unlock"):
self.app.tasks[unlock] = Mock()
chord_uuid = uuid()
header_result_args = (
chord_uuid,
[self.app.AsyncResult(x) for x in range(3)],
)
self.backend._client = MagicMock()
self.backend.apply_chord(header_result_args, None)
assert self.backend._client.put_item.call_args_list == [
call(
TableName="celery",
Item={
"id": {"S": f"b'chord-unlock-{chord_uuid}'"},
"chord_count": {"N": "0"},
"timestamp": {"N": ANY},
},
),
call(
TableName="celery",
Item={
"id": {"S": f"b'celery-taskset-meta-{chord_uuid}'"},
"result": {
"B": ANY,
},
"timestamp": {"N": ANY},
},
),
]

0 comments on commit 86895a9

Please sign in to comment.