thread executor
oklahomer committed Jul 12, 2015
1 parent f77f384 commit c590e03
Showing 5 changed files with 127 additions and 65 deletions.
49 changes: 6 additions & 43 deletions sarah/bot_base.py
@@ -1,16 +1,16 @@
# -*- coding: utf-8 -*-
import abc
from collections import OrderedDict
from concurrent.futures import ThreadPoolExecutor
from concurrent.futures import ThreadPoolExecutor, Future
from functools import wraps
import imp
import importlib
import logging
from queue import Queue, Empty
import threading
from apscheduler.schedulers.background import BackgroundScheduler
import sys
from typing import Callable, List, Tuple, Union, Optional
from sarah.thread import ThreadExecuter


class Command(object):
@@ -45,13 +45,13 @@ def __init__(self,
self.plugins = plugins
self.worker = ThreadPoolExecutor(max_workers=max_workers) \
if max_workers else None
self.message_worker = ThreadExecuter()

# Reset to ease tests in one file
self.__commands[self.__class__.__name__] = OrderedDict()
self.__schedules[self.__class__.__name__] = OrderedDict()

self.stop_event = threading.Event()
self.sending_queue = None
self.scheduler = BackgroundScheduler()
self.load_plugins(self.plugins)
self.add_schedule_jobs(self.schedules)
@@ -66,6 +66,7 @@ def add_schedule_job(self, command: Command) -> None:

def stop(self) -> None:
self.stop_event.set()
self.message_worker.shutdown(wait=False)
if self.worker:
self.worker.shutdown(wait=False)

@@ -83,46 +84,8 @@ def wrapper(self, *args, **kwargs):

return wrapper

def supervise_enqueued_message(self) -> None:
""" Supervise the message queue, and send queued messages to chat room
Send messages queued via concurrent_sending_message(). One message is
sent at a time, so it is suitable to ensure thread-safety and avoid
sending multiple message in concurrent jobs.
"""
self.sending_queue = Queue()

def _supervise(stop_event: threading.Event) -> None:
while not stop_event.is_set():
try:
func = self.sending_queue.get()
except Empty:
# Queue is empty.
# sending_queue.empty() doesn't guarantee its emptiness,
# so just call get() and see if exception is raised.
continue
except Exception as e:
logging.error('Error on getting task from queue. %s', e)
continue

try:
func()
except Exception as e:
logging.error('Error on sending response. %s', e)
return

t = threading.Thread(target=_supervise, args=(self.stop_event,))
t.setDaemon(True)
t.start()

def enqueue_sending_message(self, function, *args, **kwargs):
if self.sending_queue is None:
msg = ("To utilize this method, supervise_enqueued_message() must "
"be called in run()")
logging.error(msg)
raise SarahException(msg)
else:
self.sending_queue.put_nowait(lambda: function(*args, **kwargs))
def enqueue_sending_message(self, function, *args, **kwargs) -> Future:
return self.message_worker.submit(function, *args, **kwargs)

def load_plugins(self, plugins: Union[List, Tuple]) -> None:
for module_config in plugins:
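
For context on the new API: enqueue_sending_message() now just submits the callable to the single-threaded message_worker and hands back the resulting concurrent.futures.Future, so a caller may block on delivery when it needs to. A minimal sketch, assuming `bot` is an instance of some concrete subclass of this base class and `send_func` is whatever callable actually delivers the message (both hypothetical names, not part of this commit):

future = bot.enqueue_sending_message(send_func, 'room-id', 'hello')  # queued on the single worker thread
future.result(timeout=5)  # optional: block until the worker has actually run send_func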
10 changes: 5 additions & 5 deletions sarah/hipchat.py
@@ -1,4 +1,5 @@
# -*- coding: utf-8 -*-
from concurrent.futures import Future
import logging
import re
from sleekxmpp import ClientXMPP, Message
@@ -54,7 +55,6 @@ def job_function():
def run(self) -> None:
if not self.client.connect():
raise SarahHipChatException('Couldn\'t connect to server.')
self.supervise_enqueued_message()
self.scheduler.start()
self.client.process(block=True)

@@ -116,7 +116,7 @@ def join_rooms(self, event: Dict) -> None:
wait=True)

@concurrent
def message(self, msg: Message) -> None:
def message(self, msg: Message) -> Optional[Future]:
if msg['delay']['stamp']:
# Avoid answering to all past messages when joining the room.
# xep_0203 plugin required.
@@ -148,14 +148,14 @@ def message(self, msg: Message) -> None:
'text': text,
'from': msg['from']})
except Exception as e:
self.enqueue_sending_message(lambda: msg.reply(
'Something went wrong with "%s"' % msg['body']).send())
logging.error('Error occurred. '
'command: %s. input: %s. error: %s.' % (
command.name, msg['body'], e
))
return self.enqueue_sending_message(lambda: msg.reply(
'Something went wrong with "%s"' % msg['body']).send())
else:
self.enqueue_sending_message(lambda: msg.reply(ret).send())
return self.enqueue_sending_message(lambda: msg.reply(ret).send())

def stop(self) -> None:
super().stop()
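
Since message() now returns the Future produced by enqueue_sending_message() (or None when the message is skipped), callers can synchronize on the reply instead of sleeping, which is what the updated tests below rely on. A hypothetical sketch, assuming `hipchat` is a configured bot instance and `msg` an incoming sleekxmpp Message:

future = hipchat.message(msg)
if future is not None:  # None means the message was skipped, e.g. a replayed message with a delay stamp
    future.result(timeout=5)  # blocks until msg.reply(...).send() has run on the worker thread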
11 changes: 5 additions & 6 deletions sarah/slack.py
@@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
# https://api.slack.com/rtm
from concurrent.futures import Future

import logging
from typing import Optional, Union, List, Tuple, Dict
@@ -26,8 +27,6 @@ def setup_client(self, token: str) -> None:
self.client = SlackClient(token=token)

def run(self) -> None:
self.supervise_enqueued_message()

response = self.client.get('rtm.start')
self.ws = WebSocketApp(response['url'],
on_message=self.message,
@@ -94,7 +93,7 @@ def message(self, ws: WebSocketApp, event: str) -> None:
def handle_hello(self, content: Dict) -> None:
logging.info('Successfully connected to the server.')

def handle_message(self, content: Dict) -> None:
def handle_message(self, content: Dict) -> Optional[Future]:
required_props = ('type', 'channel', 'user', 'text', 'ts')
missing_props = [p for p in required_props if p not in content]

@@ -107,9 +106,9 @@ def handle_message(self, content: Dict) -> None:
# TODO Check command and return results
# Just returning the exact same text for now.
# self.send_message(content['channel'], content['text'])
self.enqueue_sending_message(self.send_message,
content['channel'],
content['text'])
return self.enqueue_sending_message(self.send_message,
content['channel'],
content['text'])

def on_error(self, ws: WebSocketApp, error) -> None:
logging.error(error)
95 changes: 95 additions & 0 deletions sarah/thread.py
@@ -0,0 +1,95 @@
# -*- coding: utf-8 -*-
from concurrent.futures.thread import _WorkItem as WorkItem
from concurrent.futures import Executor, Future
import logging
from queue import Queue
import threading
import weakref
import atexit

# Provide the same interface as ThreadPoolExecutor, but create only one thread.
# The worker is created as a daemon thread. This is done to allow the
# interpreter to exit when there is still an idle thread in ThreadExecuter
# (i.e. shutdown() was not called). However, letting the worker die with the
# interpreter has two undesirable properties:
# - The worker would still be running during interpreter shutdown,
# meaning that it would fail in unpredictable ways.
# - The worker could be killed while evaluating a work item, which could
# be bad if the callable being evaluated has external side-effects, e.g.
# writing to a file.
#
# To work around this problem, an exit handler is installed which tells the
# worker to exit when its work queue is empty and then waits until the thread
# finishes.

_shutdown = False


def _python_exit():
global _shutdown
_shutdown = True


atexit.register(_python_exit)


def _worker(executor_reference, work_queue):
try:
while True:
work_item = work_queue.get(block=True)
if work_item is not None:
work_item.run()
continue
executor = executor_reference()
# Exit if:
# - The interpreter is shutting down OR
# - The executor that owns the worker has been collected OR
# - The executor that owns the worker has been shutdown.
if _shutdown or executor is None or executor._shutdown:
# Notify other workers
work_queue.put(None)
return
del executor
except BaseException:
logging.critical('Exception in worker', exc_info=True)


class ThreadExecuter(Executor):
def __init__(self):
""" Initialize a new ThreadExecutor instance. """
self._work_queue = Queue()
self._shutdown = False
self._shutdown_lock = threading.Lock()

def weakref_cb(_, q=self._work_queue):
q.put(None)

t = threading.Thread(target=_worker,
args=(weakref.ref(self, weakref_cb),
self._work_queue))
t.daemon = True
t.start()
self._thread = t

def submit(self, fn, *args, **kwargs):
with self._shutdown_lock:
if self._shutdown:
raise RuntimeError(
'cannot schedule new futures after shutdown')

f = Future()
w = WorkItem(f, fn, args, kwargs)

self._work_queue.put(w)
return f

submit.__doc__ = Executor.submit.__doc__

def shutdown(self, wait=True):
with self._shutdown_lock:
self._shutdown = True
self._work_queue.put(None)
if wait:
self._thread.join()

shutdown.__doc__ = Executor.shutdown.__doc__
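
As a rough usage sketch (not part of this commit), ThreadExecuter behaves like a ThreadPoolExecutor with exactly one worker: submitted callables run strictly in submission order on a single daemon thread, which is what keeps outgoing chat messages from being sent concurrently.

from sarah.thread import ThreadExecuter

executor = ThreadExecuter()
first = executor.submit(lambda: 1 + 1)  # runs first on the worker thread
second = executor.submit(print, 'runs after first')  # executed only once `first` has finished
assert first.result(timeout=5) == 2
executor.shutdown(wait=True)  # enqueues the sentinel, then joins the worker thread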
27 changes: 16 additions & 11 deletions tests/test_hipchat.py
@@ -1,4 +1,6 @@
# -*- coding: utf-8 -*-
import concurrent
from concurrent.futures import ALL_COMPLETED
from time import sleep
import pytest
import logging
@@ -153,21 +155,29 @@ def hipchat(self, request):
('sarah.plugins.echo',)),
max_workers=4)
h.client.connect = lambda: True
h.client.process = lambda: True
h.client.process = lambda *args, **kwargs: True
request.addfinalizer(h.stop)

t = Thread(target=h.run)
t.start()

return h

def wait_future_finish(self, future):
sleep(.5) # Why would I need this line?? Check later.

ret = concurrent.futures.wait([future], 5, return_when=ALL_COMPLETED)
if len(ret.not_done) > 0:
logging.error("Jobs are not finished.")
assert future in ret.done

def test_skip_message(self, hipchat):
msg = Message(hipchat.client, stype='normal')
msg['body'] = 'test body'

msg.reply = MagicMock()

hipchat.message(msg)
self.wait_future_finish(hipchat.message(msg))
assert msg.reply.call_count == 0

def test_echo_message(self, hipchat):
@@ -176,9 +186,7 @@ def test_echo_message(self, hipchat):

msg.reply = MagicMock()

hipchat.message(msg)

sleep(.1)
self.wait_future_finish(hipchat.message(msg))
assert msg.reply.call_count == 1
assert msg.reply.call_args == call('spam')

@@ -190,19 +198,16 @@ def test_count_message(self, hipchat):

msg.reply = MagicMock()

hipchat.message(msg)
sleep(.1)
self.wait_future_finish(hipchat.message(msg))
assert msg.reply.call_count == 1
assert msg.reply.call_args == call('1')

hipchat.message(msg)
sleep(.1)
self.wait_future_finish(hipchat.message(msg))
assert msg.reply.call_count == 2
assert msg.reply.call_args == call('2')

msg['body'] = '.count egg'
hipchat.message(msg)
sleep(.1)
self.wait_future_finish(hipchat.message(msg))
assert msg.reply.call_count == 3
assert msg.reply.call_args == call('1')

