diff --git a/test/model/test_controller.py b/test/model/test_controller.py index 6894dfbb5..b2dd49aeb 100644 --- a/test/model/test_controller.py +++ b/test/model/test_controller.py @@ -902,7 +902,7 @@ async def test_get_association_groups(controller, uuid4, mock_command): }, ) - association_address = association_pkg.Association(node_id=52) + association_address = association_pkg.AssociationAddress(node_id=52) result = await controller.async_get_association_groups(association_address) assert result[1].max_nodes == 10 @@ -950,7 +950,7 @@ async def test_get_associations(controller, uuid4, mock_command): }, ) - association_address = association_pkg.Association(node_id=52) + association_address = association_pkg.AssociationAddress(node_id=52) result = await controller.async_get_associations(association_address) assert result[1][0].node_id == 10 @@ -984,9 +984,9 @@ async def test_is_association_allowed(controller, uuid4, mock_command): {"allowed": True}, ) - association_address = association_pkg.Association(node_id=52) + association_address = association_pkg.AssociationAddress(node_id=52) group = 0 - association = association_pkg.Association(node_id=5, endpoint=0) + association = association_pkg.AssociationAddress(node_id=5, endpoint=0) assert await controller.async_is_association_allowed( association_address, group, association @@ -1010,11 +1010,11 @@ async def test_add_associations(controller, uuid4, mock_command): {}, ) - association_address = association_pkg.Association(node_id=52) + association_address = association_pkg.AssociationAddress(node_id=52) group = 0 associations = [ - association_pkg.Association(node_id=5, endpoint=0), - association_pkg.Association(node_id=10), + association_pkg.AssociationAddress(node_id=5, endpoint=0), + association_pkg.AssociationAddress(node_id=10), ] await controller.async_add_associations(association_address, group, associations) @@ -1031,11 +1031,11 @@ async def test_add_associations(controller, uuid4, mock_command): ], } - association_address = association_pkg.Association(node_id=52, endpoint=111) + association_address = association_pkg.AssociationAddress(node_id=52, endpoint=111) group = 1 associations = [ - association_pkg.Association(node_id=11), - association_pkg.Association(node_id=6, endpoint=1), + association_pkg.AssociationAddress(node_id=11), + association_pkg.AssociationAddress(node_id=6, endpoint=1), ] await controller.async_add_associations( @@ -1064,11 +1064,11 @@ async def test_remove_associations(controller, uuid4, mock_command): {}, ) - association_address = association_pkg.Association(node_id=52) + association_address = association_pkg.AssociationAddress(node_id=52) group = 0 associations = [ - association_pkg.Association(node_id=5, endpoint=0), - association_pkg.Association(node_id=10), + association_pkg.AssociationAddress(node_id=5, endpoint=0), + association_pkg.AssociationAddress(node_id=10), ] await controller.async_remove_associations(association_address, group, associations) @@ -1085,11 +1085,11 @@ async def test_remove_associations(controller, uuid4, mock_command): ], } - association_address = association_pkg.Association(node_id=53, endpoint=112) + association_address = association_pkg.AssociationAddress(node_id=53, endpoint=112) group = 1 associations = [ - association_pkg.Association(node_id=11), - association_pkg.Association(node_id=6, endpoint=1), + association_pkg.AssociationAddress(node_id=11), + association_pkg.AssociationAddress(node_id=6, endpoint=1), ] await controller.async_remove_associations( diff --git a/zwave_js_server/model/association.py b/zwave_js_server/model/association.py index c676b91f4..712a75cd1 100644 --- a/zwave_js_server/model/association.py +++ b/zwave_js_server/model/association.py @@ -16,7 +16,7 @@ class AssociationGroup: @dataclass -class Association: +class AssociationAddress: """Represent a association dict type.""" node_id: int diff --git a/zwave_js_server/model/controller/__init__.py b/zwave_js_server/model/controller/__init__.py index eb2fa979d..4fa9092f7 100644 --- a/zwave_js_server/model/controller/__init__.py +++ b/zwave_js_server/model/controller/__init__.py @@ -12,7 +12,7 @@ ) from ...event import Event, EventBase from ...util.helpers import convert_base64_to_bytes, convert_bytes_to_base64 -from ..association import Association, AssociationGroup +from ..association import AssociationAddress, AssociationGroup from ..node import Node from .data_model import ControllerDataType from .event_model import CONTROLLER_EVENT_MODEL_MAP @@ -420,7 +420,7 @@ async def async_is_failed_node(self, node_id: int) -> bool: return cast(bool, data["failed"]) async def async_get_association_groups( - self, source: Association + self, source: AssociationAddress ) -> Dict[int, AssociationGroup]: """Send getAssociationGroups command to Controller.""" source_data = {"nodeId": source.node_id} @@ -445,8 +445,8 @@ async def async_get_association_groups( return groups async def async_get_associations( - self, source: Association - ) -> Dict[int, List[Association]]: + self, source: AssociationAddress + ) -> Dict[int, List[AssociationAddress]]: """Send getAssociations command to Controller.""" source_data = {"nodeId": source.node_id} if source.endpoint is not None: @@ -457,18 +457,19 @@ async def async_get_associations( **source_data, } ) - associations_map = {} - for key, associations in data["associations"].items(): - associations_map[int(key)] = [ - Association( - node_id=association["nodeId"], endpoint=association.get("endpoint") + associations = {} + for key, association_addresses in data["associations"].items(): + associations[int(key)] = [ + AssociationAddress( + node_id=association_address["nodeId"], + endpoint=association_address.get("endpoint"), ) - for association in associations + for association_address in association_addresses ] - return associations_map + return associations async def async_is_association_allowed( - self, source: Association, group: int, association: Association + self, source: AssociationAddress, group: int, association: AssociationAddress ) -> bool: """Send isAssociationAllowed command to Controller.""" source_data = {"nodeId": source.node_id} @@ -490,9 +491,9 @@ async def async_is_association_allowed( async def async_add_associations( self, - source: Association, + source: AssociationAddress, group: int, - associations: List[Association], + associations: List[AssociationAddress], wait_for_result: bool = False, ) -> None: """Send addAssociations command to Controller.""" @@ -520,9 +521,9 @@ async def async_add_associations( async def async_remove_associations( self, - source: Association, + source: AssociationAddress, group: int, - associations: List[Association], + associations: List[AssociationAddress], wait_for_result: bool = False, ) -> None: """Send removeAssociations command to Controller."""