Skip to content

Commit

Permalink
Add fast fetch_many to Task
Browse files Browse the repository at this point in the history
  • Loading branch information
Chronial committed Nov 6, 2018
1 parent 04b0947 commit 77f1a88
Show file tree
Hide file tree
Showing 4 changed files with 36 additions and 8 deletions.
2 changes: 1 addition & 1 deletion redis_tasks/queue.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ def get_task_ids(self, offset=0, length=-1):
reversed(connection.lrange(self.key, start, end))]

def get_tasks(self, offset=0, length=-1):
return [Task.fetch(x) for x in self.get_task_ids(offset, length)]
return Task.fetch_many(self.get_task_ids(offset, length))

@atomic_pipeline
def enqueue_call(self, *args, pipeline, **kwargs):
Expand Down
3 changes: 2 additions & 1 deletion redis_tasks/registries.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,8 @@ def get_task_ids(self, offset=0, length=-1):
return decode_list(connection.zrange(self.key, offset, end))

def get_tasks(self, offset=0, length=-1):
return [redis_tasks.task.Task.fetch(x) for x in self.get_task_ids(offset, length)]
return redis_tasks.task.Task.fetch_many(
self.get_task_ids(offset, length))

def empty(self): # TODO: test
def transaction(pipeline):
Expand Down
25 changes: 19 additions & 6 deletions redis_tasks/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,10 +83,11 @@ def __repr__(self):


class Task:
def __init__(self, func=None, args=None, kwargs=None, *, fetch_id=None):
def __init__(self, func=None, args=None, kwargs=None, *,
fetch_id=None, fetch_data=None):
if fetch_id:
self.id = fetch_id
self.refresh()
self.refresh(data=fetch_data)
return

self.id = str(uuid.uuid4())
Expand Down Expand Up @@ -255,11 +256,23 @@ def delete_many(cls, task_ids, *, pipeline):
if task_ids:
pipeline.delete(*(cls.key_for(task_id) for task_id in task_ids))

def refresh(self):
key = self.key
obj = {k.decode(): v for k, v in connection.hgetall(key).items()}
@classmethod
def fetch_many(cls, task_ids):
with connection.pipeline(transaction=False) as pipeline:
for task_id in task_ids:
pipeline.hgetall(cls.key_for(task_id))
results = pipeline.execute()
tasks = []
for task_id, data in zip(task_ids, results):
tasks.append(cls(fetch_id=task_id, fetch_data=data))
return tasks

def refresh(self, data=None):
if not data:
data = connection.hgetall(self.key)
obj = {k.decode(): v for k, v in data.items()}
if len(obj) == 0:
raise TaskDoesNotExist('No such task: {0}'.format(key))
raise TaskDoesNotExist('No such task: {0}'.format(self.key))

self.func_name = obj['func_name'].decode()
self.args = deserialize(obj['args'])
Expand Down
14 changes: 14 additions & 0 deletions tests/test_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -308,3 +308,17 @@ def test_init_save_fetch_delete(connection, assert_atomic, stub):
assert not connection.exists(t.key)
with pytest.raises(TaskDoesNotExist):
Task.fetch(t.id)


def test_fetch_many(stub):
tasks_data = [{"args": []}, {"args": ["foo"]}, {"args": ["bar"]},
{"args": ["foo", "bar"]}, {"args": []}]
tasks = [Task(stub, **d) for d in tasks_data]
for t in tasks:
t._save()

fetched = Task.fetch_many([t.id for t in tasks[1:4]])
assert len(fetched) == 3
for i, task in enumerate(fetched):
assert task.id == tasks[i+1].id
assert task.args == tasks_data[i+1]["args"]

0 comments on commit 77f1a88

Please sign in to comment.