From 37a648e86c0ac2cbca96d99b5269c00e70d2bd41 Mon Sep 17 00:00:00 2001 From: Tom Cobb Date: Mon, 11 Jul 2016 15:31:29 +0100 Subject: [PATCH] Make all controllers take process --- malcolm/controllers/clientcontroller.py | 8 ++--- .../controllers/scanpointtickercontroller.py | 3 -- malcolm/core/controller.py | 6 +++- .../test_controllers/test_clientcontroller.py | 2 +- .../test_countercontroller.py | 12 +++---- .../test_controllers/test_hellocontroller.py | 4 +-- .../test_scanpointtickercontroller.py | 6 ++-- tests/test_core/test_controller.py | 8 ++--- tests/test_core/test_system_core.py | 8 ++--- tests/test_wscomms/test_system_wscomms.py | 35 +++++++++++-------- 10 files changed, 48 insertions(+), 44 deletions(-) diff --git a/malcolm/controllers/clientcontroller.py b/malcolm/controllers/clientcontroller.py index 54befef4d..0730324f0 100644 --- a/malcolm/controllers/clientcontroller.py +++ b/malcolm/controllers/clientcontroller.py @@ -17,10 +17,9 @@ def __init__(self, process, block): process (Process): The process this should run under block (Block): The local block we should be controlling """ - super(ClientController, self).__init__(block=block) - self.process = process + super(ClientController, self).__init__(block=block, process=process) request = Request.Subscribe( - None, self, [process.name, "remoteBlocks", "value"]) + None, self, [self.process.name, "remoteBlocks", "value"]) request.set_id(self.REMOTE_BLOCKS_ID) self.process.q.put(request) @@ -75,8 +74,7 @@ def call_server_method(self, method_name, parameters, returns): request = Request.Post(None, q, [self.block.name, method_name], parameters) self.client_comms.q.put(request) - with self.block.lock_released(): - response = q.get() + response = q.get() assert response.type_ == response.RETURN, \ "Expected Return, got %s" % response.type_ returns.update(response.value) diff --git a/malcolm/controllers/scanpointtickercontroller.py b/malcolm/controllers/scanpointtickercontroller.py index 38f666aed..2c7da18f4 100644 --- a/malcolm/controllers/scanpointtickercontroller.py +++ b/malcolm/controllers/scanpointtickercontroller.py @@ -14,9 +14,6 @@ @RunnableDeviceStateMachine.insert class ScanPointTickerController(Controller): - def __init__(self, block): - super(ScanPointTickerController, self).__init__(block) - def create_attributes(self): self.value = Attribute("value", NumberMeta("meta", "Value", numpy.float64)) diff --git a/malcolm/core/controller.py b/malcolm/core/controller.py index b77b8f670..6b9738438 100644 --- a/malcolm/core/controller.py +++ b/malcolm/core/controller.py @@ -13,15 +13,17 @@ class Controller(Loggable): """Implement the logic that takes a Block through its statemachine""" - def __init__(self, block): + def __init__(self, process, block): """ Args: + process (Process): The process this should run under block (Block): Block instance to add Methods and Attributes to """ logger_name = "%s.controller" % block.name super(Controller, self).__init__(logger_name) self.writeable_methods = OrderedDict() + self.process = process self.block = block for attribute in self.create_attributes(): @@ -29,6 +31,8 @@ def __init__(self, block): for method in self.create_methods(): block.add_method(method) + self.process.add_block(block) + def create_methods(self): """Abstract method that should provide Method instances for Block diff --git a/tests/test_controllers/test_clientcontroller.py b/tests/test_controllers/test_clientcontroller.py index 25d656e06..d5c5972f0 100644 --- a/tests/test_controllers/test_clientcontroller.py +++ b/tests/test_controllers/test_clientcontroller.py @@ -24,7 +24,7 @@ class TestClientController(unittest.TestCase): def setUp(self): # Serialized version of the block we want source = Block("blockname") - HelloController(source) + HelloController(MagicMock(), source) self.serialized = source.to_dict() # Setup client controller prerequisites self.b = Block("blockname") diff --git a/tests/test_controllers/test_countercontroller.py b/tests/test_controllers/test_countercontroller.py index 8a0be4f3f..3fd38905d 100644 --- a/tests/test_controllers/test_countercontroller.py +++ b/tests/test_controllers/test_countercontroller.py @@ -12,7 +12,7 @@ class TestCounterController(unittest.TestCase): def test_init(self): block = Mock() - c = CounterController(block) + c = CounterController(Mock(), block) self.assertIs(block, c.block) self.assertEquals(2, len(block.add_method.call_args_list)) method_1 = block.add_method.call_args_list[0][0][0] @@ -23,7 +23,7 @@ def test_init(self): self.assertEquals(c.reset, method_2.func) def test_increment_increments(self): - c = CounterController(Mock()) + c = CounterController(Mock(), Mock()) self.assertEquals(0, c.counter.value) c.increment() self.assertEquals(1, c.counter.value) @@ -31,26 +31,26 @@ def test_increment_increments(self): self.assertEquals(2, c.counter.value) def test_increment_calls_on_changed(self): - c = CounterController(Mock()) + c = CounterController(Mock(), Mock()) c.counter.on_changed = Mock(side_effect=c.counter.on_changed) c.increment() c.counter.on_changed.assert_called_once_with([["value"], 1]) def test_reset_sets_zero(self): - c = CounterController(Mock()) + c = CounterController(Mock(), Mock()) c.counter.value = 1234 c.reset() self.assertEquals(0, c.counter.value) def test_reset_calls_on_changed(self): - c = CounterController(Mock()) + c = CounterController(Mock(), Mock()) c.counter.value = 1234 c.counter.on_changed = Mock(side_effect=c.counter.on_changed) c.reset() c.counter.on_changed.assert_called_once_with([["value"], 0]) def test_put_changes_value(self): - c = CounterController(Mock()) + c = CounterController(Mock(), Mock()) c.counter.parent = c.block c.counter.put(32) self.assertEqual(c.counter.value, 32) diff --git a/tests/test_controllers/test_hellocontroller.py b/tests/test_controllers/test_hellocontroller.py index cd1d97d25..cf2089201 100644 --- a/tests/test_controllers/test_hellocontroller.py +++ b/tests/test_controllers/test_hellocontroller.py @@ -4,7 +4,7 @@ import setup_malcolm_paths import unittest -from mock import Mock +from mock import Mock, MagicMock from malcolm.controllers.hellocontroller import HelloController @@ -13,7 +13,7 @@ class TestHelloController(unittest.TestCase): def setUp(self): self.block = Mock() - self.c = HelloController(self.block) + self.c = HelloController(MagicMock(), self.block) def test_init(self): self.assertIs(self.block, self.c.block) diff --git a/tests/test_controllers/test_scanpointtickercontroller.py b/tests/test_controllers/test_scanpointtickercontroller.py index 281befa47..36ab5abb5 100644 --- a/tests/test_controllers/test_scanpointtickercontroller.py +++ b/tests/test_controllers/test_scanpointtickercontroller.py @@ -20,7 +20,7 @@ class TestScanPointTickerController(unittest.TestCase): def test_init(self, pgmd_mock, nmd_mock, smd_mock): attr_id = "epics:nt/NTAttribute:1.0" block = MagicMock() - sptc = ScanPointTickerController(block) + sptc = ScanPointTickerController(MagicMock(), block) self.assertEqual(block, sptc.block) self.assertEqual(RunnableDeviceStateMachine, type(sptc.stateMachine)) self.assertEqual("RunnableDeviceStateMachine", sptc.stateMachine.name) @@ -42,7 +42,7 @@ def test_configure(self): an = MagicMock() e = MagicMock() block = MagicMock() - sptc = ScanPointTickerController(block) + sptc = ScanPointTickerController(MagicMock(), block) sptc.configure(g, an, e) @@ -60,7 +60,7 @@ def test_run(self, sleep_mock): e = MagicMock() e.__float__ = MagicMock(return_value=0.1) block = MagicMock() - sptc = ScanPointTickerController(block) + sptc = ScanPointTickerController(MagicMock(), block) sptc.value.set_value = MagicMock(side_effect=sptc.value.set_value) sptc.configure(g, an, e) diff --git a/tests/test_core/test_controller.py b/tests/test_core/test_controller.py index c94890f73..e0fc467b5 100644 --- a/tests/test_core/test_controller.py +++ b/tests/test_core/test_controller.py @@ -30,13 +30,11 @@ def setUp(self): self.m1 = MagicMock() self.m2 = MagicMock() b._methods.__getitem__.side_effect = [self.m1, self.m2] - self.c = DummyController(b) + self.c = DummyController(MagicMock(), b) def test_init(self): - b = MagicMock() - self.c = DummyController(b) - self.assertEqual(self.c.block, b) - b.add_method.assert_has_calls( + self.c.process.add_block.assert_called_once_with(self.c.block) + self.c.block.add_method.assert_has_calls( [call(self.c.say_goodbye.Method), call(self.c.say_hello.Method)]) self.assertEqual(self.c.state.name, "State") diff --git a/tests/test_core/test_system_core.py b/tests/test_core/test_system_core.py index ae2f28e09..bbb630a82 100644 --- a/tests/test_core/test_system_core.py +++ b/tests/test_core/test_system_core.py @@ -4,6 +4,7 @@ import setup_malcolm_paths import unittest +from mock import MagicMock # logging # import logging @@ -24,7 +25,7 @@ class TestHelloControllerSystem(unittest.TestCase): def test_hello_controller_good_input(self): block = Block("hello") - HelloController(block) + HelloController(MagicMock(), block) result = block.say_hello(name="me") self.assertEquals(result.greeting, "Hello me") @@ -32,8 +33,7 @@ def test_hello_controller_with_process(self): sync_factory = SyncFactory("sched") process = Process("proc", sync_factory) block = Block("hello") - HelloController(block) - process.add_block(block) + HelloController(process, block) process.start() q = sync_factory.create_queue() req = Request.Post(response_queue=q, context="ClientConnection", @@ -54,7 +54,7 @@ def test_counter_controller_subscribe(self): sync_factory = SyncFactory("sched") process = Process("proc", sync_factory) block = Block("counting") - CounterController(block) + CounterController(process, block) process.add_block(block) process.start() q = sync_factory.create_queue() diff --git a/tests/test_wscomms/test_system_wscomms.py b/tests/test_wscomms/test_system_wscomms.py index df4631182..49fcc86e7 100644 --- a/tests/test_wscomms/test_system_wscomms.py +++ b/tests/test_wscomms/test_system_wscomms.py @@ -26,21 +26,17 @@ from malcolm.wscomms.wsclientcomms import WSClientComms -class TestSystemWSComms(unittest.TestCase): +class TestSystemWSCommsServerOnly(unittest.TestCase): def setUp(self): - sync_factory = SyncFactory("sync") - self.process = Process("proc", sync_factory) + self.sf = SyncFactory("sync") + self.process = Process("proc", self.sf) block = Block("hello") - self.process.add_block(block) - HelloController(block) + HelloController(self.process, block) self.sc = WSServerComms("sc", self.process, 8888) self.process.start() self.sc.start() def tearDown(self): - if hasattr(self, "cc"): - self.cc.stop() - self.cc.wait() self.sc.stop() self.sc.wait() self.process.stop() @@ -71,16 +67,27 @@ def send_message(self): def test_server_and_simple_client(self): self.send_message() - def test_server_with_malcolm_client(self): - self.cc = WSClientComms("cc", self.process, "ws://localhost:8888/ws") +class TestSystemWSCommsServerAndClient(TestSystemWSCommsServerOnly): + def setUp(self): + super(TestSystemWSCommsServerAndClient, self).setUp() + self.process2 = Process("proc2", self.sf) + self.block2 = Block("hello") + ClientController(self.process2, self.block2) + self.cc = WSClientComms("cc", self.process2, "ws://localhost:8888/ws") + self.process2.start() self.cc.start() - # Don't add to process as we already have a block of that name - block2 = Block("hello") - ClientController(self.process, block2) + + def tearDown(self): + super(TestSystemWSCommsServerAndClient, self).tearDown() + self.cc.stop() + self.cc.wait() + self.process2.stop() + + def test_server_with_malcolm_client(self): # Normally we would wait for it to be connected here, but it isn't # attached to a process so just sleep for a bit time.sleep(0.1) - ret = block2.say_hello("me2") + ret = self.block2.say_hello("me2") self.assertEqual(ret, dict(greeting="Hello me2"))