Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Allow custom serializer and deserializers for task #92

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
10 changes: 10 additions & 0 deletions tasktiger/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,13 @@ def __init__(self, connection=None, config=None, setup_structlog=False):

# If non-empty, a worker excludes the given queues from processing.
'EXCLUDE_QUEUES': [],

# Serializer / Deserilaizer to use for serializing/deserializing tasks

'SERIALIZER': json.dumps,

'DESERIALIZER': json.loads

}
if config:
self.config.update(config)
Expand Down Expand Up @@ -193,6 +200,9 @@ def __init__(self, connection=None, config=None, setup_structlog=False):
# List of task functions that are executed periodically.
self.periodic_task_funcs = {}

self._serialize = self.config['SERIALIZER']
self._deserialize = self.config['DESERIALIZER']

def _get_current_task(self):
if g['current_tasks'] is None:
raise RuntimeError('Must be accessed from within a task')
Expand Down
13 changes: 6 additions & 7 deletions tasktiger/task.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import datetime
import json
import redis
import time

Expand Down Expand Up @@ -280,7 +279,7 @@ def delay(self, when=None):

# When using ALWAYS_EAGER, make sure we have serialized the task to
# ensure there are no serialization errors.
serialized_task = json.dumps(self._data)
serialized_task = self.tiger._serialize(self._data)

if tiger.config['ALWAYS_EAGER'] and state == QUEUED:
return self.execute()
Expand Down Expand Up @@ -341,8 +340,8 @@ def from_id(self, tiger, queue, state, task_id, load_executions=0):
serialized_executions = []
# XXX: No timestamp for now
if serialized_data:
data = json.loads(serialized_data)
executions = [json.loads(e) for e in serialized_executions if e]
data = tiger._deserialize(serialized_data)
executions = [tiger._deserialize(e) for e in serialized_executions if e]
return Task(tiger, queue=queue, _data=data, _state=state,
_executions=executions)
else:
Expand Down Expand Up @@ -380,8 +379,8 @@ def tasks_from_queue(self, tiger, queue, state, skip=0, limit=1000,
results = pipeline.execute()

for serialized_data, serialized_executions, ts in zip(results[0], results[1:], tss):
data = json.loads(serialized_data)
executions = [json.loads(e) for e in serialized_executions if e]
data = tiger._deserialize(serialized_data)
executions = [tiger._deserialize(e) for e in serialized_executions if e]

task = Task(tiger, queue=queue, _data=data, _state=state,
_ts=ts, _executions=executions)
Expand All @@ -390,7 +389,7 @@ def tasks_from_queue(self, tiger, queue, state, skip=0, limit=1000,
else:
data = tiger.connection.mget([tiger._key('task', item[0]) for item in items])
for serialized_data, ts in zip(data, tss):
data = json.loads(serialized_data)
data = tiger._deserialize(serialized_data)
task = Task(tiger, queue=queue, _data=data, _state=state,
_ts=ts)
tasks.append(task)
Expand Down
29 changes: 29 additions & 0 deletions tasktiger/test_helpers.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
import json
import datetime
import decimal

from .task import Task
from .worker import Worker

Expand Down Expand Up @@ -30,3 +34,28 @@ def run_worker(self, tiger, raise_on_errors=True, **kwargs):
has_errors = True
if has_errors and raise_on_errors:
raise Exception('One or more tasks have failed.')


class CustomJSONEncoder(json.JSONEncoder):
"""
A JSON encoder that allows for more common Python data types.

In addition to the defaults handled by ``json``, this also supports:

* ``datetime.datetime``
* ``datetime.date``
* ``datetime.time``
* ``decimal.Decimal``

"""
def default(self, data):
if isinstance(data, (datetime.datetime, datetime.date, datetime.time)):
return data.isoformat()
elif isinstance(data, decimal.Decimal):
return str(data)
else:
return super(CustomJSONEncoder, self).default(data)


def custom_serializer(obj):
return json.dumps(obj, cls=CustomJSONEncoder)
7 changes: 3 additions & 4 deletions tasktiger/worker.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from collections import OrderedDict
import errno
import fcntl
import json
import os
import random
import select
Expand Down Expand Up @@ -327,7 +326,7 @@ def _execute_forked(self, tasks, log):
''.join(traceback.format_exception(*exc_info))
execution['success'] = success
execution['host'] = socket.gethostname()
serialized_execution = json.dumps(execution)
serialized_execution = self.tiger._serialize(execution)
for task in tasks:
self.connection.rpush(self._key('task', task.id, 'executions'),
serialized_execution)
Expand Down Expand Up @@ -544,7 +543,7 @@ def _process_queue_tasks(self, queue, queue_lock, task_ids, now, log):
tasks = []
for task_id, serialized_task in zip(task_ids, serialized_tasks):
if serialized_task:
task_data = json.loads(serialized_task)
task_data = self.tiger._deserialize(serialized_task)
else:
# In the rare case where we don't find the task which is
# queued (see ReliabilityTestCase.test_task_disappears),
Expand Down Expand Up @@ -739,7 +738,7 @@ def _mark_done():
self._key('task', task.id, 'executions'), -1)

if execution:
execution = json.loads(execution)
execution = self.tiger._deserialize(execution)

if execution and execution.get('retry'):
if 'retry_method' in execution:
Expand Down
36 changes: 35 additions & 1 deletion tests/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,12 @@
import time
from multiprocessing import Pool, Process

from decimal import Decimal

from tasktiger import (JobTimeoutException, StopRetry, Task, TaskNotFound,
Worker, exponential, fixed, linear)
from tasktiger._internal import serialize_func_name
from tasktiger.test_helpers import custom_serializer

from .config import DELAY
from .tasks import (batch_task, decorated_task, decorated_task_simple_func,
Expand All @@ -23,8 +26,10 @@


class BaseTestCase:
CONFIG = {}

def setup_method(self, method):
self.tiger = get_tiger()
self.tiger = get_tiger(**self.CONFIG)
self.conn = self.tiger.connection
self.conn.flushdb()

Expand Down Expand Up @@ -1012,3 +1017,32 @@ def test_single_worker_queue(self):
self._ensure_queues()

worker.join()


class TestCustomSerializer(BaseTestCase):

CONFIG = {
'SERIALIZER': custom_serializer
}

def test_task(self):
tmpfile = tempfile.NamedTemporaryFile()
task_args = (tmpfile.name, 'test', 5)
task_kwargs = dict(a=datetime.datetime.now(),
b=Decimal("5.05"))

self.tiger.delay(file_args_task, args=task_args, kwargs=task_kwargs)
queues = self._ensure_queues(queued={'default': 1})
task = queues['queued']['default'][0]
assert task['func'] == 'tests.tasks:file_args_task'

Worker(self.tiger).run(once=True)
self._ensure_queues(queued={'default': 0})
json_data = tmpfile.read().decode('utf8')
assert json.loads(json_data) == {
'args': ['test', 5],
'kwargs': {
'a': task_kwargs['a'].isoformat(),
'b': str(task_kwargs['b'])
}
}
9 changes: 6 additions & 3 deletions tests/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ def __exit__(self, *args):
setattr(self.orig_obj, self.func_name, self.orig_func)


def get_tiger():
def get_tiger(**kwargs):
"""
Sets up logging and returns a new tasktiger instance.
"""
Expand All @@ -38,7 +38,7 @@ def get_tiger():
)
logging.basicConfig(format='%(message)s')
conn = redis.Redis(db=TEST_DB, decode_responses=True)
tiger = TaskTiger(connection=conn, config={
config = {
# We need this 0 here so we don't pick up scheduled tasks when
# doing a single worker run.
'SELECT_TIMEOUT': 0,
Expand All @@ -56,7 +56,10 @@ def get_tiger():
},

'SINGLE_WORKER_QUEUES': ['swq'],
})
}

config.update(kwargs)
tiger = TaskTiger(connection=conn, config=config)
tiger.log.setLevel(logging.CRITICAL)
return tiger

Expand Down