diff --git a/canopen/pdo/__init__.py b/canopen/pdo/__init__.py index 533309f8..b749dfee 100644 --- a/canopen/pdo/__init__.py +++ b/canopen/pdo/__init__.py @@ -1,4 +1,6 @@ +import itertools import logging +from collections.abc import Iterator from canopen import node from canopen.pdo.base import PdoBase, PdoMap, PdoMaps, PdoVariable @@ -24,23 +26,36 @@ class PDO(PdoBase): :param tpdo: TPDO object holding the Transmit PDO mappings """ - def __init__(self, node, rpdo, tpdo): + def __init__(self, node, rpdo: PdoBase, tpdo: PdoBase): super(PDO, self).__init__(node) self.rx = rpdo.map self.tx = tpdo.map - self.map = {} - # the object 0x1A00 equals to key '1' so we remove 1 from the key + self.map = PdoMaps(0, 0, self) + # Combine RX and TX entries, but only via mapping parameter index. Relative index + # numbers would be ambiguous. + # The object 0x1A00 equals to key '1' so we remove 1 from the key for key, value in self.rx.items(): - self.map[0x1A00 + (key - 1)] = value + self.map.maps[self.rx.map_offset + (key - 1)] = value + self.map.maps[self.rx.com_offset + (key - 1)] = value for key, value in self.tx.items(): - self.map[0x1600 + (key - 1)] = value + self.map.maps[self.tx.map_offset + (key - 1)] = value + self.map.maps[self.tx.com_offset + (key - 1)] = value + + def __iter__(self) -> Iterator[int]: + return itertools.chain( + (self.rx.map_offset + i - 1 for i in self.rx), + (self.tx.map_offset + i - 1 for i in self.tx), + ) + + def __len__(self) -> int: + return len(self.rx) + len(self.tx) class RPDO(PdoBase): """Receive PDO to transfer data from somewhere to the represented node. - Properties 0x1400 to 0x1403 | Mapping 0x1600 to 0x1603. + Properties 0x1400 to 0x15FF | Mapping 0x1600 to 0x17FF. :param object node: Parent node for this object. """ @@ -65,7 +80,7 @@ def stop(self): class TPDO(PdoBase): """Transmit PDO to broadcast data from the represented node to the network. - Properties 0x1800 to 0x1803 | Mapping 0x1A00 to 0x1A03. + Properties 0x1800 to 0x19FF | Mapping 0x1A00 to 0x1BFF. :param object node: Parent node for this object. """ diff --git a/canopen/pdo/base.py b/canopen/pdo/base.py index 216fc550..55781fd7 100644 --- a/canopen/pdo/base.py +++ b/canopen/pdo/base.py @@ -1,6 +1,7 @@ from __future__ import annotations import binascii +import contextlib import logging import math import threading @@ -33,7 +34,7 @@ class PdoBase(Mapping): def __init__(self, node: Union[LocalNode, RemoteNode]): self.network: canopen.network.Network = canopen.network._UNINITIALIZED_NETWORK - self.map: Optional[PdoMaps] = None + self.map: PdoMaps # must initialize in derived classes self.node: Union[LocalNode, RemoteNode] = node def __iter__(self): @@ -45,8 +46,7 @@ def __getitem__(self, key: Union[int, str]): raise KeyError("PDO index zero requested for 1-based sequence") if ( 0 < key <= 512 # By PDO Index - or 0x1600 <= key <= 0x17FF # By RPDO ID (512) - or 0x1A00 <= key <= 0x1BFF # By TPDO ID (512) + or 0x1600 <= key <= 0x1BFF # By RPDO / TPDO mapping or communication record ): return self.map[key] for pdo_map in self.map.values(): @@ -144,10 +144,10 @@ def stop(self): pdo_map.stop() -class PdoMaps(Mapping): +class PdoMaps(Mapping[int, 'PdoMap']): """A collection of transmit or receive maps.""" - def __init__(self, com_offset, map_offset, pdo_node: PdoBase, cob_base=None): + def __init__(self, com_offset: int, map_offset: int, pdo_node: PdoBase, cob_base=None): """ :param com_offset: :param map_offset: @@ -155,6 +155,11 @@ def __init__(self, com_offset, map_offset, pdo_node: PdoBase, cob_base=None): :param cob_base: """ self.maps: dict[int, PdoMap] = {} + self.com_offset = com_offset + self.map_offset = map_offset + if not com_offset and not map_offset: + # Skip generating entries without parameter index offsets + return for map_no in range(512): if com_offset + map_no in pdo_node.node.object_dictionary: new_map = PdoMap( @@ -167,7 +172,14 @@ def __init__(self, com_offset, map_offset, pdo_node: PdoBase, cob_base=None): self.maps[map_no + 1] = new_map def __getitem__(self, key: int) -> PdoMap: - return self.maps[key] + try: + return self.maps[key] + except KeyError: + with contextlib.suppress(KeyError): + return self.maps[key + 1 - self.map_offset] + with contextlib.suppress(KeyError): + return self.maps[key + 1 - self.com_offset] + raise def __iter__(self) -> Iterator[int]: return iter(self.maps) diff --git a/test/test_pdo.py b/test/test_pdo.py index 9eb6fb2f..50d218de 100644 --- a/test/test_pdo.py +++ b/test/test_pdo.py @@ -50,9 +50,12 @@ def test_pdo_getitem(self): self.assertEqual(node.tpdo[1]['BOOLEAN value 2'].raw, True) # Test different types of access - by_mapping_record = node.pdo[0x1600] + by_mapping_record = node.pdo[0x1A00] self.assertIsInstance(by_mapping_record, canopen.pdo.PdoMap) self.assertEqual(by_mapping_record['INTEGER16 value'].raw, -3) + self.assertIs(node.tpdo[0x1A00], by_mapping_record) + self.assertIs(node.tpdo[0x1800], by_mapping_record) + self.assertIs(node.pdo[0x1800], by_mapping_record) by_object_name = node.pdo['INTEGER16 value'] self.assertIsInstance(by_object_name, canopen.pdo.PdoVariable) self.assertIs(by_object_name.od, node.object_dictionary['INTEGER16 value']) @@ -68,7 +71,7 @@ def test_pdo_getitem(self): self.assertEqual(by_object_index.raw, 0xf) self.assertIs(node.pdo['0x2002'], by_object_index) self.assertIs(node.tpdo[0x2002], by_object_index) - self.assertIs(node.pdo[0x1600][0x2002], by_object_index) + self.assertIs(node.pdo[0x1A00][0x2002], by_object_index) self.assertRaises(KeyError, lambda: node.pdo[0]) self.assertRaises(KeyError, lambda: node.tpdo[0]) @@ -76,6 +79,20 @@ def test_pdo_getitem(self): self.assertRaises(KeyError, lambda: node.pdo[0x1BFF]) self.assertRaises(KeyError, lambda: node.tpdo[0x1BFF]) + def test_pdo_iterate(self): + node = self.node + pdo_iter = iter(node.pdo.items()) + prev = 0 # To check strictly increasing record index number + for rpdo, (index, pdo) in zip(node.rpdo.values(), pdo_iter): + self.assertIs(rpdo, pdo) + self.assertGreater(index, prev) + prev = index + # Continue consuming from pdo_iter + for tpdo, (index, pdo) in zip(node.tpdo.values(), pdo_iter): + self.assertIs(tpdo, pdo) + self.assertGreater(index, prev) + prev = index + def test_pdo_maps_iterate(self): node = self.node self.assertEqual(len(node.pdo), sum(1 for _ in node.pdo))