Skip to content

Commit

Permalink
Merge pull request #37 from henriksod/feature/ProcessSubscriptionCons…
Browse files Browse the repository at this point in the history
…umer

Implemented ProcessSubscriptionConsumer
  • Loading branch information
dkumor committed Jul 24, 2022
2 parents ba45b7a + 540f16e commit 156fba2
Show file tree
Hide file tree
Showing 2 changed files with 160 additions and 3 deletions.
2 changes: 1 addition & 1 deletion rtcbot/base/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,4 +8,4 @@
NoClosedSubscription,
)
from .thread import ThreadedSubscriptionConsumer, ThreadedSubscriptionProducer
from .multiprocess import ProcessSubscriptionProducer
from .multiprocess import ProcessSubscriptionProducer, ProcessSubscriptionConsumer
161 changes: 159 additions & 2 deletions rtcbot/base/multiprocess.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,9 @@
import multiprocessing
import threading
import queue
import concurrent.futures

from .base import BaseSubscriptionProducer, SubscriptionClosed
from rtcbot.base import BaseSubscriptionProducer, BaseSubscriptionConsumer, SubscriptionClosed


class internalSubscriptionMessage:
Expand All @@ -30,7 +31,7 @@ def __init__(
"ProcessSubscriptionProducer"
)
else:
self.__splog = logger.getChild("ProcessSubscriptionConsumer")
self.__splog = logger.getChild("ProcessSubscriptionProducer")

self.__closeEvent = multiprocessing.Event()

Expand Down Expand Up @@ -150,3 +151,159 @@ def close(self):
self.__splog.debug("Process did not terminate in time. Killing it.")
self._producerProcess.terminate()
self._producerProcess.join()


class ProcessSubscriptionConsumer(BaseSubscriptionConsumer):
def __init__(
self,
directPutSubscriptionType=asyncio.Queue,
logger=None,
loop=None,
daemonProcess=True,
joinTimeout=1,
):
self._joinTimeout = joinTimeout
if logger is None:
self.__splog = logging.getLogger(self.__class__.__name__).getChild(
"ProcessSubscriptionConsumer"
)
else:
self.__splog = logger.getChild("ProcessSubscriptionConsumer")

self.__closeEvent = multiprocessing.Event()

self._taskLock = multiprocessing.Lock()
self._getEvent = multiprocessing.Event()

super().__init__(directPutSubscriptionType, logger=logger)

self._loop = loop
if self._loop is None:
self._loop = asyncio.get_event_loop()

self._consumerQueue = multiprocessing.Queue()

self.__queueReaderThread = threading.Thread(target=self.__queueReader)
self.__queueReaderThread.daemon = True
self.__queueReaderThread.start()

self._consumerProcess = multiprocessing.Process(target=self.__consumerSetup)
self._consumerProcess.daemon = daemonProcess
self._consumerProcess.start()

@property
def _shouldClose(self):
# We need to check the event
return self.__closeEvent.is_set()

@_shouldClose.setter
def _shouldClose(self, value):
self.__splog.debug("Setting _shouldClose to %s", value)
if value:
self.__closeEvent.set()
else:
self.__closeEvent.clear()

def _setReady(self, value):
self._loop.call_soon_threadsafe(super()._setReady, value)

def _setError(self, err):
self._loop.call_soon_threadsafe(super()._setError, err)

def _close(self):
self._loop.call_soon_threadsafe(super()._close)

def __queueReader(self):
while not self._shouldClose:
if self._getEvent.is_set():
timedout = False
while not self._shouldClose:
with self._taskLock:
# Only create a new task if it was finished, and did not time out
if not timedout:
self._getTask = asyncio.run_coroutine_threadsafe(
self._subscription.get(), self._loop
)
timedout = False
try:
self._consumerQueue.put(self._getTask.result(self._joinTimeout))
self._getEvent.clear()
break
except (asyncio.CancelledError, concurrent.futures.CancelledError):
self.__splog.debug("Subscription cancelled - checking for new tasks")
except (asyncio.TimeoutError, concurrent.futures.TimeoutError):
self.__splog.debug(f"No incoming data for {self._joinTimeout} seconds...")
timedout = True
except SubscriptionClosed:
self.__splog.debug(
"Incoming stream closed... Checking for new subscription"
)

def _get(self):
"""
This is not a coroutine - it is to be called in the worker thread.
If the worker thread is to be shut down, raises a SubscriptionClosed exception.
"""
self._getEvent.set()
while not self._shouldClose:
try:
return self._consumerQueue.get(timeout=self._joinTimeout)
except queue.Empty:
pass # No need to notify each time we check whether we chould close

self.__splog.debug(
"close() was called on the aio thread. raising SubscriptionClosed."
)
raise SubscriptionClosed("ProcessSubscriptionConsumer has been closed")

def putSubscription(self, subscription):
with self._taskLock:
super().putSubscription(subscription)

def _consumer(self):
"""
This is the function run in another thread. You override the function with your own logic.
The base implementation is used for testing
"""

# We are ready!
self._setReady(True)
# Have to think how to make this work
# in testing

def __consumerSetup(self):
# This function sets up the consumer. In particular, it receives KeyboardInterrupts

def handleInterrupt(sig, frame):
self.__splog.debug("Received KeyboardInterrupt - not notifying process")

old_handler = signal.signal(signal.SIGINT, handleInterrupt)
try:
self._consumer()
except:
self.__splog.exception("The remote process had an exception!")
# self._setReady(False)
self._shouldClose = True

signal.signal(signal.SIGINT, old_handler)

self.__splog.debug("Exiting remote process")

def close(self):
"""
Shuts down data gathering, and closes all subscriptions. Note that it is not recommended
to call this in an async function, since it waits until the background thread joins.
The object is meant to be used as a singleton, which is initialized at the start of your code,
and is closed when shutting down.
"""
with self._taskLock:
super().close()
self._consumerProcess.join(self._joinTimeout)
self.__queueReaderThread.join()
if self._consumerProcess.is_alive():
self.__splog.debug("Process did not terminate in time. Killing it.")
self._consumerProcess.terminate()
self._consumerProcess.join()

0 comments on commit 156fba2

Please sign in to comment.