Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
2 changed files
with
173 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,73 @@ | ||
from collections import OrderedDict | ||
|
||
from malcolm.core.loggable import Loggable | ||
|
||
# Sentinel object to stop the send loop | ||
CLIENT_STOP = object() | ||
|
||
|
||
class ClientComms(Loggable): | ||
"""Abstract class for dispatching requests to a server and resonses to | ||
a method""" | ||
|
||
def __init__(self, name, process): | ||
super(ClientComms, self).__init__(logger_name=name) | ||
self.process = process | ||
self.q = self.process.create_queue() | ||
self._send_spawned = None | ||
self._current_id = 1 | ||
self.requests = OrderedDict() | ||
|
||
def send_loop(self): | ||
"""Service self.q, sending requests to server""" | ||
while True: | ||
request = self.q.get() | ||
if request is CLIENT_STOP: | ||
break | ||
try: | ||
request.id_ = self._current_id | ||
self._current_id += 1 | ||
|
||
# TODO: Move request store into new method? | ||
self.requests[request.id_] = request | ||
self.send_to_server(request) | ||
except: | ||
self.log_exception( | ||
"Exception sending request %s", request.to_dict()) | ||
|
||
def send_to_server(self, request): | ||
"""Abstract method to dispatch request to a server | ||
Args: | ||
request (Request): The message to pass to the server | ||
""" | ||
raise NotImplementedError( | ||
"Abstract method that must be implemented by deriving class") | ||
|
||
def start(self): | ||
"""Start communications""" | ||
self._send_spawned = self.process.spawn(self.send_loop) | ||
self.start_recv_loop() | ||
|
||
def stop(self, timeout=None): | ||
"""Request all communications be stopped and wait for finish | ||
Args: | ||
timeout (float): Time in seconds to wait for comms to stop. | ||
None means wait forever. | ||
""" | ||
self.q.put(CLIENT_STOP) | ||
self._send_spawned.wait(timeout=timeout) | ||
self.stop_recv_loop() | ||
|
||
def start_recv_loop(self): | ||
"""Abstract method to start a receive loop to dispatch responses to a | ||
a Method""" | ||
raise NotImplementedError( | ||
"Abstract method that must be implemented by deriving class") | ||
|
||
def stop_recv_loop(self): | ||
"""Abstract method to stop the receive loop created by | ||
start_recv_loop""" | ||
raise NotImplementedError( | ||
"Abstract method that must be implemented by deriving class") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,100 @@ | ||
import unittest | ||
import sys | ||
import os | ||
from collections import OrderedDict | ||
sys.path.append(os.path.join(os.path.dirname(__file__), "..", "..")) | ||
|
||
from pkg_resources import require | ||
require("mock") | ||
from mock import Mock | ||
|
||
from malcolm.core.clientcomms import ClientComms, CLIENT_STOP | ||
from malcolm.core.syncfactory import SyncFactory | ||
|
||
class TestClientComms(unittest.TestCase): | ||
def test_init(self): | ||
process = Mock() | ||
client = ClientComms("c", process) | ||
process.create_queue.assert_called_once_with() | ||
self.assertEqual(client.q, process.create_queue.return_value) | ||
|
||
def test_not_implemented_error(self): | ||
client = ClientComms("c", Mock()) | ||
self.assertRaises(NotImplementedError, client.send_to_server, Mock()) | ||
self.assertRaises(NotImplementedError, client.start_recv_loop) | ||
self.assertRaises(NotImplementedError, client.stop_recv_loop) | ||
|
||
def test_send_logs_error(self): | ||
client = ClientComms("c", Mock()) | ||
client.send_to_server = Mock(side_effect=Exception) | ||
request = Mock() | ||
request.to_dict = Mock(return_value = "<to_dict>") | ||
client.q.get = Mock(side_effect = [request, CLIENT_STOP]) | ||
client.log_exception = Mock() | ||
client.send_loop() | ||
client.log_exception.assert_called_once_with( | ||
"Exception sending request %s", "<to_dict>") | ||
|
||
def test_requests_are_stored(self): | ||
client = ClientComms("c", Mock()) | ||
client._current_id = 1234 | ||
request = Mock() | ||
client.send_to_server = Mock() | ||
client.q.get = Mock(side_effect = [request, CLIENT_STOP]) | ||
client.send_loop() | ||
expected = OrderedDict({1234 : request}) | ||
self.assertEquals(expected, client.requests) | ||
|
||
def test_loop_starts(self): | ||
process = Mock(spawn = lambda x: x()) | ||
client = ClientComms("c", process) | ||
client.send_loop = Mock() | ||
client.start_recv_loop = Mock() | ||
client.log_exception = Mock() | ||
client.start() | ||
client.send_loop.assert_called_once_with() | ||
client.start_recv_loop.assert_called_once_with() | ||
client.log_exception.assert_not_called() | ||
|
||
def test_sends_to_server(self): | ||
client = ClientComms("c", Mock()) | ||
client.send_to_server = Mock() | ||
request = Mock() | ||
client.q.get = Mock(side_effect = [request, CLIENT_STOP]) | ||
client.log_exception = Mock() | ||
client.send_loop() | ||
client.send_to_server.assert_called_once_with(request) | ||
client.log_exception.assert_not_called() | ||
|
||
def test_start_stop(self): | ||
sync_factory = SyncFactory("s") | ||
process = Mock() | ||
process.spawn = sync_factory.spawn | ||
process.create_queue = sync_factory.create_queue | ||
client = ClientComms("c", process) | ||
client.send_loop = Mock(side_effect = client.send_loop) | ||
client.start_recv_loop = Mock() | ||
client.stop_recv_loop = Mock() | ||
client.log_exception = Mock() | ||
client.start() | ||
self.assertFalse(client._send_spawned.ready()) | ||
client.start_recv_loop.assert_called_once_with() | ||
client.stop(0.1) | ||
self.assertTrue(client._send_spawned.ready()) | ||
client.send_loop.assert_called_once_with() | ||
client.stop_recv_loop.assert_called_once_with() | ||
client.log_exception.assert_not_called() | ||
|
||
def test_request_id_provided(self): | ||
client = ClientComms("c", Mock()) | ||
client._current_id = 1234 | ||
client.send_to_server = Mock() | ||
request_1 = Mock(id_ = None) | ||
request_2 = Mock(id_ = None) | ||
client.q.get = Mock(side_effect = [request_1, request_2, CLIENT_STOP]) | ||
client.send_loop() | ||
self.assertEqual(1234, request_1.id_) | ||
self.assertEqual(1235, request_2.id_) | ||
|
||
if __name__ == "__main__": | ||
unittest.main(verbosity=2) |