Skip to content

Commit

Permalink
Add ServerComms abstract base class
Browse files Browse the repository at this point in the history
  • Loading branch information
c-mita committed Jun 3, 2016
1 parent 91d591f commit 1913212
Show file tree
Hide file tree
Showing 2 changed files with 155 additions and 0 deletions.
68 changes: 68 additions & 0 deletions malcolm/core/servercomms.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
from malcolm.core.loggable import Loggable

# Sentinel object to stop the send loop
SERVER_STOP = object()


class ServerComms(Loggable):
"""Abstract class for dispatching requests to a process and responses to a
client"""

def __init__(self, name, process):
super(ServerComms, self).__init__(logger_name=name)
self.process = process
self.q = self.process.create_queue()
self._send_spawned = None

def send_loop(self):
"""Service self.q, sending responses to client"""
while True:
response = self.q.get()
if response is SERVER_STOP:
break
try:
self.send_to_client(response)
except:
self.log_exception(
"Exception sending response %s", response.to_dict())

def send_to_client(response):
"""Abstract method to dispatch response to a client
Args:
response (Response): The message to pass to the client
"""
raise NotImplementedError(
"Abstract method that must be implemented by deriving class")

def send_to_process(self, request):
"""Send request to process"""
self.process.q.put(request)

def start(self):
"""Start communications"""
self._send_spawned = self.process.spawn(self.send_loop)
self.start_recv_loop()

def start_recv_loop(self):
"""Abstract method to start a recieve loop to dispatch requests to
Process"""
raise NotImplementedError(
"Abstract method that must be implemented by deriving class")

def stop(self, timeout=None):
"""Request all communications be stopped and wait for finish
Args:
timeout (float): Time in seconds to wait for comms to stop.
None means wait forever.
"""
self.q.put(SERVER_STOP)
self._send_spawned.wait(timeout=timeout)
self.stop_recv_loop()

def stop_recv_loop(self):
"""Abstract method to stop the receive loop created by
start_recv_loop"""
raise NotImplementedError(
"Abstract method that must be implemented by deriving class")
87 changes: 87 additions & 0 deletions tests/test_core/test_servercomms.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
import unittest
import sys
import os
sys.path.append(os.path.join(os.path.dirname(__file__), "..", ".."))

from pkg_resources import require
require("mock")
from mock import Mock

from malcolm.core.servercomms import ServerComms, SERVER_STOP
from malcolm.core.syncfactory import SyncFactory

class TestServerComms(unittest.TestCase):

def setUp(self):
self.process = Mock()

def test_init(self):
server = ServerComms("server", self.process)
self.process.create_queue.assert_called_once_with()
self.assertEqual(
server.q, self.process.create_queue.return_value)

def test_not_implemented_error(self):
server = ServerComms("server", self.process)
self.assertRaises(NotImplementedError, server.send_to_client)
self.assertRaises(NotImplementedError, server.start_recv_loop)
self.assertRaises(NotImplementedError, server.stop_recv_loop)

def test_loop_starts(self):
self.process.spawn = lambda x: x()
server = ServerComms("server", self.process)
server.send_loop = Mock()
server.start_recv_loop = Mock()
server.start()
server.send_loop.assert_called_once_with()
server.start_recv_loop.assert_called_once_with()

def test_loop_stops(self):
self.process.spawn = lambda x: x()
self.process.create_queue = Mock(
return_value=Mock(get=Mock(return_value=SERVER_STOP)))
server = ServerComms("server", self.process)
server.start_recv_loop = Mock()
server.stop_recv_loop = Mock()
server.send_loop = Mock(side_effect = server.send_loop)
server.start()
server.send_loop.assert_called_once_with()

def test_start_stop(self):
self.process.sync_factory = SyncFactory("s")
self.process.spawn = self.process.sync_factory.spawn
self.process.create_queue = self.process.sync_factory.create_queue
server = ServerComms("server", self.process)
server.send_loop = Mock(side_effect = server.send_loop)
server.start_recv_loop = Mock()
server.stop_recv_loop = Mock()
server.start()
self.assertFalse(server._send_spawned.ready())
server.start_recv_loop.assert_called_once_with()
server.stop(0.1)
self.assertTrue(server._send_spawned.ready())
server.send_loop.assert_called_once_with()
server.stop_recv_loop.assert_called_once_with()

def test_send_to_client(self):
request = Mock()
dummy_queue = Mock()
dummy_queue.get = Mock(side_effect = [request, SERVER_STOP])
self.process.create_queue = Mock(return_value = dummy_queue)
self.process.spawn = Mock(side_effect = lambda x: x())
server = ServerComms("server", self.process)
server.send_to_client = Mock(
side_effect = server.send_to_client)
server.start_recv_loop = Mock()
server.start()
server.send_to_client.assert_called_once_with(request)

def test_send_to_process(self):
self.process.q = Mock()
server = ServerComms("server", self.process)
request = Mock()
server.send_to_process(request)
self.process.q.put.assert_called_once_with(request)

if __name__ == "__main__":
unittest.main(verbosity=2)

0 comments on commit 1913212

Please sign in to comment.