/
base.py
99 lines (75 loc) · 2.78 KB
/
base.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
import datetime
import json
import logging
import sys
import time
import unittest
from huey import RedisHuey
from huey.consumer import Consumer
from huey.registry import registry
def b(s):
    """Return ``s`` as a byte string.

    On Python 3+ a text ``str`` is encoded as UTF-8, and a value that is
    already ``bytes`` is returned unchanged (the original ``== 3`` check and
    unconditional ``.encode`` broke on both of those cases). On Python 2,
    ``str`` is already a byte string, so ``s`` is returned as-is.
    """
    # On Python 2 ``bytes`` is an alias of ``str``, so this check is a no-op
    # there; on Python 3 it avoids calling .encode on a bytes object.
    if sys.version_info[0] >= 3 and not isinstance(s, bytes):
        return s.encode('utf-8')
    return s
# Shared Huey instance used by the test cases in this module, created at
# import time under the name 'testing'. NOTE(review): blocking=False and the
# short read_timeout presumably keep queue reads from stalling the tests —
# confirm against the RedisHuey documentation.
test_huey = RedisHuey('testing', blocking=False, read_timeout=0.1)
# Logger used by the consumer.
# CaptureLogs attaches its handler to this logger to record consumer output.
logger = logging.getLogger('huey.consumer')
# Create a log handler that will track messages generated by the consumer.
class CaptureLogs(logging.Handler):
    """Logging handler that records consumer log messages while active.

    Use it as a context manager. On entry it attaches itself to the
    module-level ``logger`` and sets that logger's level to INFO. On exit it
    detaches and restores the level the logger had before entry, so one
    test's capture does not alter logging for the tests that follow.
    """

    def __init__(self, *args, **kwargs):
        # Messages (with %-args merged) in the order they were emitted.
        self.messages = []
        logging.Handler.__init__(self, *args, **kwargs)
        # Level of ``logger`` saved by __enter__; None while not entered.
        self._saved_level = None

    def emit(self, record):
        """Append ``record.getMessage()`` to :attr:`messages`."""
        self.messages.append(record.getMessage())

    def __enter__(self):
        """Attach to ``logger``, force INFO level, and return this handler."""
        # Remember the current level so __exit__ can undo the change below.
        self._saved_level = logger.level
        logger.addHandler(self)
        logger.setLevel(logging.INFO)
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Detach from ``logger`` and restore its previous level.

        Returns None, so exceptions raised inside the ``with`` body propagate.
        """
        logger.removeHandler(self)
        # ``is not None`` rather than truthiness: 0 (NOTSET) is a valid level.
        if self._saved_level is not None:
            logger.setLevel(self._saved_level)
            self._saved_level = None
class BaseTestCase(unittest.TestCase):
    """Common base for the test cases in this module; adds no behaviour yet."""


class HueyTestCase(BaseTestCase):
    """Base class for tests that drive a consumer against ``test_huey``.

    ``setUp`` builds a consumer, subscribes to the Huey event stream, swaps in
    the periodic tasks returned by :meth:`get_periodic_tasks`, and replaces
    ``time.sleep`` with a no-op. ``tearDown`` undoes all of this, restoring
    the process-global patches even if the consumer/Redis cleanup fails.
    """

    def setUp(self):
        self.huey = test_huey
        self.consumer = self.get_consumer(workers=2, scheduler_interval=10)

        # Subscribe to the event stream; assertTaskEvents reads from it.
        pubsub = self.huey.events.listener()
        self.events = pubsub.listen()
        next(self.events)  # Consume the "subscribe" event.

        # Swap the global periodic-task registry for this test only.
        self._periodic_tasks = registry._periodic_tasks
        registry._periodic_tasks = self.get_periodic_tasks()

        # Make time.sleep a no-op so code under test does not actually wait.
        self._sleep = time.sleep
        time.sleep = lambda x: None

    def tearDown(self):
        """Stop the consumer, close the stream, flush Huey, undo patches.

        The monkeypatched globals are restored in an outer ``finally`` so a
        failure while stopping the consumer or flushing Redis cannot leave
        ``time.sleep`` as a no-op (or the registry swapped) for later tests.
        """
        try:
            try:
                if self.consumer is not None:
                    self.consumer.stop()
            finally:
                self.events.close()
                self.huey.flush()
        finally:
            registry._periodic_tasks = self._periodic_tasks
            time.sleep = self._sleep

    def get_consumer(self, **kwargs):
        """Build the Consumer used by the test; override to customise it."""
        return Consumer(self.huey, **kwargs)

    def get_periodic_tasks(self):
        """Return the periodic tasks installed for the test (none by default)."""
        return []

    def assertTaskEvents(self, *states):
        """Assert the next events on the stream match ``states`` in order.

        Each item of ``states`` is a ``(status, task)`` pair. One event is
        read per pair; its JSON ``data`` payload must carry that ``status``
        and ``task.task_id`` as ``id``.
        """
        for status, task in states:
            raw_event = next(self.events)
            event_data = json.loads(raw_event['data'].decode('utf-8'))
            self.assertEqual(event_data['status'], status)
            self.assertEqual(event_data['id'], task.task_id)

    def assertLogs(self, capture, expected):
        """Assert ``capture.messages`` matches ``expected`` prefix-by-prefix.

        The number of captured messages must equal ``len(expected)`` and each
        message must start with the corresponding expected prefix.

        NOTE: this shadows ``unittest.TestCase.assertLogs`` (a context manager
        on Python 3.4+) with an unrelated signature; the name is kept for
        backward compatibility with existing callers.
        """
        messages = capture.messages
        self.assertEqual(
            len(messages), len(expected),
            'expected %d log messages, got %d: %r'
            % (len(expected), len(messages), messages))
        for log, prefix in zip(messages, expected):
            self.assertTrue(
                log.startswith(prefix),
                '%r does not start with %r' % (log, prefix))

    def worker(self, task, ts=None):
        """Run ``task`` on a freshly created worker and return the worker.

        ``ts`` defaults to the current naive UTC time.
        """
        worker = self.consumer._create_worker()
        ts = ts or datetime.datetime.utcnow()
        worker.handle_task(task, ts)
        return worker

    def scheduler(self, ts=None, periodic=False):
        """Run one scheduler loop at ``ts`` (default: now, UTC) and return it.

        With ``periodic=True`` the scheduler's ``_counter`` is set to ``_q``
        before looping. NOTE(review): this appears intended to make this loop
        enqueue periodic tasks; it relies on Scheduler internals — confirm.
        """
        scheduler = self.consumer._create_scheduler()
        ts = ts or datetime.datetime.utcnow()
        if periodic:
            scheduler._counter = scheduler._q
        scheduler.loop(ts)
        return scheduler