diff --git a/test/model/test_controller.py b/test/model/test_controller.py index 054b2badf..6894dfbb5 100644 --- a/test/model/test_controller.py +++ b/test/model/test_controller.py @@ -902,8 +902,8 @@ async def test_get_association_groups(controller, uuid4, mock_command): }, ) - node_id = 52 - result = await controller.async_get_association_groups(node_id) + association_address = association_pkg.Association(node_id=52) + result = await controller.async_get_association_groups(association_address) assert result[1].max_nodes == 10 assert result[1].is_lifeline is True @@ -924,7 +924,7 @@ async def test_get_association_groups(controller, uuid4, mock_command): assert ack_commands[0] == { "command": "controller.get_association_groups", "messageId": uuid4, - "nodeId": node_id, + "nodeId": association_address.node_id, } @@ -935,31 +935,44 @@ async def test_get_associations(controller, uuid4, mock_command): {"command": "controller.get_associations"}, { "associations": { - 1: { - "nodeId": 10, - }, - 2: { - "nodeId": 30, - "endpoint": 0, - }, + "1": [ + {"nodeId": 10}, + ], + "2": [ + {"nodeId": 11}, + {"nodeId": 20}, + ], + "3": [ + {"nodeId": 30, "endpoint": 0}, + {"nodeId": 40, "endpoint": 1}, + ], } }, ) - node_id = 52 - result = await controller.async_get_associations(node_id) + association_address = association_pkg.Association(node_id=52) + result = await controller.async_get_associations(association_address) + + assert result[1][0].node_id == 10 + assert result[1][0].endpoint is None + + assert result[2][0].node_id == 11 + assert result[2][0].endpoint is None - assert result[1].node_id == 10 - assert result[1].endpoint is None + assert result[2][1].node_id == 20 + assert result[2][1].endpoint is None - assert result[2].node_id == 30 - assert result[2].endpoint == 0 + assert result[3][0].node_id == 30 + assert result[3][0].endpoint == 0 + + assert result[3][1].node_id == 40 + assert result[3][1].endpoint == 1 assert len(ack_commands) == 1 assert ack_commands[0] == { "command": "controller.get_associations", "messageId": uuid4, - "nodeId": node_id, + "nodeId": association_address.node_id, } @@ -971,17 +984,19 @@ async def test_is_association_allowed(controller, uuid4, mock_command): {"allowed": True}, ) - node_id = 52 + association_address = association_pkg.Association(node_id=52) group = 0 association = association_pkg.Association(node_id=5, endpoint=0) - assert await controller.async_is_association_allowed(node_id, group, association) + assert await controller.async_is_association_allowed( + association_address, group, association + ) assert len(ack_commands) == 1 assert ack_commands[0] == { "command": "controller.is_association_allowed", "messageId": uuid4, - "nodeId": node_id, + "nodeId": association_address.node_id, "group": group, "association": {"nodeId": 5, "endpoint": 0}, } @@ -995,24 +1010,48 @@ async def test_add_associations(controller, uuid4, mock_command): {}, ) - node_id = 52 + association_address = association_pkg.Association(node_id=52) group = 0 associations = [ association_pkg.Association(node_id=5, endpoint=0), association_pkg.Association(node_id=10), ] - await controller.async_add_associations(node_id, group, associations) + await controller.async_add_associations(association_address, group, associations) assert len(ack_commands) == 1 assert ack_commands[0] == { "command": "controller.add_associations", "messageId": uuid4, - "nodeId": node_id, + "nodeId": association_address.node_id, + "group": group, + "associations": [ + {"nodeId": associations[0].node_id, "endpoint": associations[0].endpoint}, + {"nodeId": associations[1].node_id}, + ], + } + + association_address = association_pkg.Association(node_id=52, endpoint=111) + group = 1 + associations = [ + association_pkg.Association(node_id=11), + association_pkg.Association(node_id=6, endpoint=1), + ] + + await controller.async_add_associations( + association_address, group, associations, True + ) + + assert len(ack_commands) == 2 + assert ack_commands[1] == { + "command": "controller.add_associations", + "messageId": uuid4, + "nodeId": association_address.node_id, + "endpoint": association_address.endpoint, "group": group, "associations": [ - {"nodeId": 5, "endpoint": 0}, - {"nodeId": 10, "endpoint": None}, + {"nodeId": associations[0].node_id}, + {"nodeId": associations[1].node_id, "endpoint": associations[1].endpoint}, ], } @@ -1025,24 +1064,48 @@ async def test_remove_associations(controller, uuid4, mock_command): {}, ) - node_id = 52 + association_address = association_pkg.Association(node_id=52) group = 0 associations = [ association_pkg.Association(node_id=5, endpoint=0), association_pkg.Association(node_id=10), ] - await controller.async_remove_associations(node_id, group, associations) + await controller.async_remove_associations(association_address, group, associations) assert len(ack_commands) == 1 assert ack_commands[0] == { "command": "controller.remove_associations", "messageId": uuid4, - "nodeId": node_id, + "nodeId": association_address.node_id, + "group": group, + "associations": [ + {"nodeId": associations[0].node_id, "endpoint": associations[0].endpoint}, + {"nodeId": associations[1].node_id}, + ], + } + + association_address = association_pkg.Association(node_id=53, endpoint=112) + group = 1 + associations = [ + association_pkg.Association(node_id=11), + association_pkg.Association(node_id=6, endpoint=1), + ] + + await controller.async_remove_associations( + association_address, group, associations, True + ) + + assert len(ack_commands) == 2 + assert ack_commands[1] == { + "command": "controller.remove_associations", + "messageId": uuid4, + "nodeId": association_address.node_id, + "endpoint": association_address.endpoint, "group": group, "associations": [ - {"nodeId": 5, "endpoint": 0}, - {"nodeId": 10, "endpoint": None}, + {"nodeId": associations[0].node_id}, + {"nodeId": associations[1].node_id, "endpoint": associations[1].endpoint}, ], } @@ -1065,6 +1128,16 @@ async def test_remove_node_from_all_associations(controller, uuid4, mock_command "nodeId": node_id, } + node_id = 53 + await controller.async_remove_node_from_all_associations(node_id, True) + + assert len(ack_commands) == 2 + assert ack_commands[1] == { + "command": "controller.remove_node_from_all_associations", + "messageId": uuid4, + "nodeId": node_id, + } + async def test_get_node_neighbors(controller, uuid4, mock_command): """Test get node neighbors.""" diff --git a/zwave_js_server/model/controller/__init__.py b/zwave_js_server/model/controller/__init__.py index 645ae3650..eb2fa979d 100644 --- a/zwave_js_server/model/controller/__init__.py +++ b/zwave_js_server/model/controller/__init__.py @@ -420,18 +420,21 @@ async def async_is_failed_node(self, node_id: int) -> bool: return cast(bool, data["failed"]) async def async_get_association_groups( - self, node_id: int + self, source: Association ) -> Dict[int, AssociationGroup]: """Send getAssociationGroups command to Controller.""" + source_data = {"nodeId": source.node_id} + if source.endpoint is not None: + source_data["endpoint"] = source.endpoint data = await self.client.async_send_command( { "command": "controller.get_association_groups", - "nodeId": node_id, + **source_data, } ) groups = {} for key, group in data["groups"].items(): - groups[key] = AssociationGroup( + groups[int(key)] = AssociationGroup( max_nodes=group["maxNodes"], is_lifeline=group["isLifeline"], multi_channel=group["multiChannel"], @@ -441,84 +444,124 @@ async def async_get_association_groups( ) return groups - async def async_get_associations(self, node_id: int) -> Dict[int, Association]: + async def async_get_associations( + self, source: Association + ) -> Dict[int, List[Association]]: """Send getAssociations command to Controller.""" + source_data = {"nodeId": source.node_id} + if source.endpoint is not None: + source_data["endpoint"] = source.endpoint data = await self.client.async_send_command( { "command": "controller.get_associations", - "nodeId": node_id, + **source_data, } ) - associations = {} - for key, association in data["associations"].items(): - associations[key] = Association( - node_id=association["nodeId"], endpoint=association.get("endpoint") - ) - return associations + associations_map = {} + for key, associations in data["associations"].items(): + associations_map[int(key)] = [ + Association( + node_id=association["nodeId"], endpoint=association.get("endpoint") + ) + for association in associations + ] + return associations_map async def async_is_association_allowed( - self, node_id: int, group: int, association: Association + self, source: Association, group: int, association: Association ) -> bool: """Send isAssociationAllowed command to Controller.""" + source_data = {"nodeId": source.node_id} + if source.endpoint is not None: + source_data["endpoint"] = source.endpoint + + association_data = {"nodeId": association.node_id} + if association.endpoint is not None: + association_data["endpoint"] = association.endpoint data = await self.client.async_send_command( { "command": "controller.is_association_allowed", - "nodeId": node_id, + **source_data, "group": group, - "association": { - "nodeId": association.node_id, - "endpoint": association.endpoint, - }, + "association": association_data, } ) return cast(bool, data["allowed"]) async def async_add_associations( - self, node_id: int, group: int, associations: List[Association] + self, + source: Association, + group: int, + associations: List[Association], + wait_for_result: bool = False, ) -> None: """Send addAssociations command to Controller.""" - await self.client.async_send_command( - { - "command": "controller.add_associations", - "nodeId": node_id, - "group": group, - "associations": [ - { - "nodeId": association.node_id, - "endpoint": association.endpoint, - } - for association in associations - ], - } - ) + source_data = {"nodeId": source.node_id} + if source.endpoint is not None: + source_data["endpoint"] = source.endpoint + + associations_data = [] + for association in associations: + association_data = {"nodeId": association.node_id} + if association.endpoint is not None: + association_data["endpoint"] = association.endpoint + associations_data.append(association_data) + + cmd = { + "command": "controller.add_associations", + **source_data, + "group": group, + "associations": associations_data, + } + if wait_for_result: + await self.client.async_send_command(cmd) + else: + await self.client.async_send_command_no_wait(cmd) async def async_remove_associations( - self, node_id: int, group: int, associations: List[Association] + self, + source: Association, + group: int, + associations: List[Association], + wait_for_result: bool = False, ) -> None: """Send removeAssociations command to Controller.""" - await self.client.async_send_command( - { - "command": "controller.remove_associations", - "nodeId": node_id, - "group": group, - "associations": [ - { - "nodeId": association.node_id, - "endpoint": association.endpoint, - } - for association in associations - ], - } - ) + source_data = {"nodeId": source.node_id} + if source.endpoint is not None: + source_data["endpoint"] = source.endpoint + + associations_data = [] + for association in associations: + association_data = {"nodeId": association.node_id} + if association.endpoint is not None: + association_data["endpoint"] = association.endpoint + associations_data.append(association_data) + + cmd = { + "command": "controller.remove_associations", + **source_data, + "group": group, + "associations": associations_data, + } + if wait_for_result: + await self.client.async_send_command(cmd) + else: + await self.client.async_send_command_no_wait(cmd) - async def async_remove_node_from_all_associations(self, node_id: int) -> None: + async def async_remove_node_from_all_associations( + self, + node_id: int, + wait_for_result: bool = False, + ) -> None: """Send removeNodeFromAllAssociations command to Controller.""" - await self.client.async_send_command( - { - "command": "controller.remove_node_from_all_associations", - "nodeId": node_id, - } - ) + cmd = { + "command": "controller.remove_node_from_all_associations", + "nodeId": node_id, + } + if wait_for_result: + await self.client.async_send_command(cmd) + else: + await self.client.async_send_command_no_wait(cmd) async def async_get_node_neighbors(self, node_id: int) -> List[int]: """Send getNodeNeighbors command to Controller to get node's neighbors."""