Skip to content

Commit

Permalink
Improve the framer: let PDUs encode and decode themselves
Browse files Browse the repository at this point in the history
  • Loading branch information
dewet22 committed May 31, 2022
1 parent 00ed7e4 commit 75e7ab7
Show file tree
Hide file tree
Showing 8 changed files with 391 additions and 374 deletions.
226 changes: 95 additions & 131 deletions givenergy_modbus/framer.py

Large diffs are not rendered by default.

134 changes: 73 additions & 61 deletions givenergy_modbus/pdu/base.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,16 @@
import logging
import struct
from abc import ABC
from typing import Optional
from typing import Optional, Type

from givenergy_modbus.codec import PayloadDecoder, PayloadEncoder
from givenergy_modbus.exceptions import InvalidFrame
from givenergy_modbus.exceptions import InvalidFrame, InvalidPduState

_logger = logging.getLogger(__name__)


class BasePDU(ABC):
"""Base of the PDU Message handler class tree.
"""Base of the PDU Message network_timeout_handler class tree.
The Protocol Data Unit (PDU) defines the basic unit of message exchange for Modbus. It is routed to devices with
specific addresses, and targets specific operations through function codes. This tree defines the hierarchy of
Expand All @@ -26,18 +26,16 @@ class BasePDU(ABC):
"""

_builder: PayloadEncoder
function_code: int
data_adapter_serial_number: str = 'AB1234G567' # for client requests this seems ignored
main_function_code: int
raw_frame: bytes

def _set_attribute_if_present(self, attr: str, **kwargs):
if attr in kwargs:
setattr(self, attr, kwargs[attr])

def __init__(self, **kwargs):
self._set_attribute_if_present('data_adapter_serial_number', **kwargs)
self._set_attribute_if_present('padding', **kwargs)
self._set_attribute_if_present('slave_address', **kwargs)
self._set_attribute_if_present('check', **kwargs)

def encode(self) -> bytes:
"""Encode PDU message from instance attributes."""
Expand All @@ -46,63 +44,59 @@ def encode(self) -> bytes:
self._builder.add_serial_number(self.data_adapter_serial_number)
self._encode_function_data()
# self._update_check_code()
return self._builder.to_string()
inner_frame = self._builder.to_string()
mbap_header = struct.pack('>HHHBB', 0x5959, 0x1, len(inner_frame) + 2, 0x1, self.function_code)
self.raw_frame = mbap_header + inner_frame
return self.raw_frame

@classmethod
def decode_bytes(cls, data: bytes) -> 'BasePDU':
"""Decode raw byte frame to populated PDU instance."""
_logger.debug(f'{cls.__name__}.decode_bytes(0x{data.hex()})')
attrs = {}
decoder = PayloadDecoder(data)
attrs['_tid'] = decoder.decode_16bit_uint()
attrs['_pid'] = decoder.decode_16bit_uint()
attrs['_len'] = decoder.decode_16bit_uint()
remaining_bytes = decoder.remaining_bytes
attrs['_uid'] = decoder.decode_8bit_uint()
attrs['main_function_code'] = decoder.decode_8bit_uint()
if attrs['_tid'] != 0x5959:
raise InvalidFrame(f'Transaction ID != 0x5959, attrs: {attrs}', data)
if attrs['_pid'] != 0x01:
raise InvalidFrame(f'Protocol ID != 0x0001, attrs: {attrs}', data)
if attrs['_len'] != remaining_bytes:
raise InvalidFrame(
f'Header length {attrs["_len"]} != remaining bytes {remaining_bytes}, attrs: {attrs}', data

t_id = decoder.decode_16bit_uint()
if t_id != 0x5959:
raise InvalidFrame(f'Transaction ID 0x{t_id:04x} != 0x5959', data)

p_id = decoder.decode_16bit_uint()
if p_id != 0x0001:
raise InvalidFrame(f'Protocol ID 0x{p_id:04x} != 0x0001', data)

header_len = decoder.decode_16bit_uint()
remaining_frame_len = decoder.remaining_bytes # includes 2 bytes for uid and function code
if header_len != remaining_frame_len:
raise InvalidFrame(f'Header length {header_len} != remaining frame length {remaining_frame_len}', data)

u_id = decoder.decode_8bit_uint()
if u_id != 0x01:
raise InvalidFrame(f'Unit ID 0x{u_id:02x} != 0x01', data)

function_code = decoder.decode_8bit_uint()
decoder_class = cls.lookup_main_function_decoder(function_code)

try:
pdu = decoder_class.decode_main_function(decoder)
pdu.raw_frame = data
pdu.ensure_valid_state()
except InvalidPduState:
raise
except Exception as e:
raise InvalidFrame(str(e), data)

if not decoder.decoding_complete:
_logger.error(
f'Decoder did not fully consume frame for {pdu}: decoded {decoder.decoded_bytes}b but '
f'packet header specified length={decoder.payload_size}. '
f'Remaining payload: [{decoder.remaining_payload.hex()}]'
)
if attrs['_uid'] != 0x01:
raise InvalidFrame(f'Unit ID != 0x01, attrs: {attrs}', data)

candidate_decoder_classes = cls.__subclasses__()
_logger.debug(
f'Candidate decoders for function code {attrs["main_function_code"]}: '
f'{", ".join([c.__name__ for c in candidate_decoder_classes])}'
)

for c in candidate_decoder_classes:
cls_main_function_code = getattr(c, 'main_function_code', None)
if cls_main_function_code == attrs['main_function_code']:
_logger.debug(f'Passing off to {c.__name__}.decode_main_function(0x{decoder.remaining_payload.hex()})')
try:
pdu = c._decode_main_function(decoder, **attrs)
except struct.error as e:
raise InvalidFrame(str(e), data)
if not decoder.decoding_complete:
_logger.error(
f'Decoder did not fully consume frame for {pdu}: decoded {decoder.decoded_bytes}b but '
f'packet header specified length={decoder.payload_size}. '
f'Remaining payload: [{decoder.remaining_payload.hex()}]'
)
pdu.ensure_valid_state()
if not decoder.remaining_bytes == 0:
_logger.warning(
f'Decoder buffer not exhausted, {decoder.remaining_bytes} bytes remain: '
f'0x{decoder.remaining_payload.hex()}'
)
return pdu
_logger.debug(f'{c.__name__} disregarded, it handles function code {cls_main_function_code}')
raise InvalidFrame(f'Found no decoder for function code {attrs["main_function_code"]}', data)
return pdu

@classmethod
def lookup_main_function_decoder(cls, function_code: int) -> Type['BasePDU']:
raise NotImplementedError()

@classmethod
def _decode_main_function(cls, decoder: PayloadDecoder, **attrs) -> 'BasePDU':
def decode_main_function(cls, decoder: PayloadDecoder, **attrs) -> 'BasePDU':
raise NotImplementedError()

def _encode_function_data(self) -> None:
Expand All @@ -126,15 +120,15 @@ def has_same_shape(self, o: object):
"""
if isinstance(o, BasePDU):
return self.shape_hash() == o.shape_hash()
return NotImplemented
raise NotImplementedError()

def shape_hash(self) -> int:
"""Calculates the "shape hash" for a given message."""
return hash(self._shape_hash_keys())

def _shape_hash_keys(self) -> tuple:
"""Defines which keys to compare to see if two messages have the same shape."""
return (type(self), self.main_function_code) + self._extra_shape_hash_keys()
return (type(self), self.function_code) + self._extra_shape_hash_keys()

def _extra_shape_hash_keys(self) -> tuple:
"""Allows extra message-specific keys to be mixed in."""
Expand All @@ -144,6 +138,17 @@ def _extra_shape_hash_keys(self) -> tuple:
class ClientIncomingMessage(BasePDU, ABC):
"""Root of the hierarchy for PDUs clients are expected to receive and handle."""

@classmethod
def lookup_main_function_decoder(cls, function_code: int) -> Type['ClientIncomingMessage']:
from givenergy_modbus.pdu import HeartbeatRequest, TransparentResponse

if function_code == 1:
return HeartbeatRequest
elif function_code == 2:
return TransparentResponse
else:
raise NotImplementedError(f'ClientIncomingMessage main function #{function_code} decoder')

def expected_response(self) -> Optional['ClientOutgoingMessage']:
"""Create a template of a correctly shaped Response expected for this Request."""
raise NotImplementedError()
Expand All @@ -152,9 +157,16 @@ def expected_response(self) -> Optional['ClientOutgoingMessage']:
class ClientOutgoingMessage(BasePDU, ABC):
"""Root of the hierarchy for PDUs clients are expected to send to servers."""

def expected_response(self) -> Optional['ClientIncomingMessage']:
"""Create a template of a correctly shaped Response expected for this Request."""
raise NotImplementedError()
@classmethod
def lookup_main_function_decoder(cls, function_code: int) -> Type['ClientOutgoingMessage']:
from givenergy_modbus.pdu import HeartbeatResponse, TransparentRequest

if function_code == 1:
return HeartbeatResponse
elif function_code == 2:
return TransparentRequest
else:
raise NotImplementedError(f'ClientOutgoingMessage main function #{function_code} decoder')


ServerIncomingMessage = ClientOutgoingMessage
Expand Down
4 changes: 2 additions & 2 deletions givenergy_modbus/pdu/heartbeat.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
class HeartbeatMessage(BasePDU, ABC):
"""Root of the hierarchy for 1/Heartbeat function PDUs."""

main_function_code = 1
function_code = 1
data_adapter_type: int

def __init__(self, **kwargs):
Expand All @@ -33,7 +33,7 @@ def _decode_function_data(self, decoder):
self.data_adapter_type = decoder.decode_8bit_uint()

@classmethod
def _decode_main_function(cls, decoder: PayloadDecoder, **attrs) -> 'HeartbeatMessage':
def decode_main_function(cls, decoder: PayloadDecoder, **attrs) -> 'HeartbeatMessage':
attrs['data_adapter_serial_number'] = decoder.decode_serial_number()
attrs['data_adapter_type'] = decoder.decode_8bit_uint()
return cls(**attrs)
Expand Down
4 changes: 2 additions & 2 deletions givenergy_modbus/pdu/null.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ class NullResponse(TransparentResponse):
data payload seems to be invariably just a series of nulls.
"""

inner_function_code = 0
transparent_function_code = 0

def __init__(self, **kwargs):
super().__init__(**kwargs)
Expand All @@ -26,7 +26,7 @@ def _encode_function_data(self) -> None:
self._update_check_code()

@classmethod
def _decode_inner_function(cls, decoder: PayloadDecoder, **attrs) -> 'NullResponse':
def decode_transparent_function(cls, decoder: PayloadDecoder, **attrs) -> 'NullResponse':
if decoder.remaining_bytes != 126:
_logger.warning(
f'remaining bytes: {decoder.remaining_bytes}b 0x{decoder.remaining_payload.hex()} attrs: {attrs}'
Expand Down
22 changes: 12 additions & 10 deletions givenergy_modbus/pdu/read_registers.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ def __init__(self, **kwargs):
self.register_count = kwargs.get('register_count', 0)

@classmethod
def _decode_inner_function(cls, decoder: PayloadDecoder, **attrs) -> 'ReadRegistersMessage':
def decode_transparent_function(cls, decoder: PayloadDecoder, **attrs) -> 'ReadRegistersMessage':
attrs['base_register'] = decoder.decode_16bit_uint()
attrs['register_count'] = decoder.decode_16bit_uint()
if issubclass(cls, ReadRegistersResponse) and not attrs.get('error', False):
Expand All @@ -34,12 +34,14 @@ def _extra_shape_hash_keys(self) -> tuple:

def _ensure_registers_spec_correct(self):
if self.base_register is None:
raise ValueError('Base register must be set', self)
raise InvalidPduState('Base register must be set', self)
if self.base_register < 0 or 0xFFFF < self.base_register:
raise ValueError('Base register must be an unsigned 16-bit int', self)
raise InvalidPduState('Base register must be an unsigned 16-bit int', self)

if self.register_count is None:
raise ValueError('Register count must be set', self)
raise InvalidPduState('Register count must be set', self)
if self.register_count == 0 and not self.error:
_logger.warning(f'Register count of 0 does not make sense: {self}')


class ReadRegistersRequest(ReadRegistersMessage, TransparentRequest, ABC):
Expand All @@ -53,7 +55,7 @@ def _encode_function_data(self):

def _update_check_code(self):
crc_builder = PayloadEncoder()
crc_builder.add_8bit_uint(self.inner_function_code)
crc_builder.add_8bit_uint(self.transparent_function_code)
crc_builder.add_16bit_uint(self.base_register)
crc_builder.add_16bit_uint(self.register_count)
self.check = crc_builder.calculate_crc()
Expand All @@ -66,7 +68,7 @@ def ensure_valid_state(self):
if self.register_count != 1 and self.base_register % 60 != 0:
_logger.warning(f'Base register {self.base_register} not aligned on 60-byte boundary')
if self.register_count <= 0 or 60 < self.register_count:
raise ValueError('Register count must be in (0,60]', self)
raise InvalidPduState('Register count must be in (0,60]', self)


class ReadRegistersResponse(ReadRegistersMessage, TransparentResponse, ABC):
Expand All @@ -88,8 +90,8 @@ def ensure_valid_state(self) -> None:
self._ensure_registers_spec_correct()

if not self.error:
if self.register_count != 1 and self.base_register % 60 != 0:
_logger.warning(f'Base register {self.base_register} not aligned on 60-byte boundary')
# if self.register_count != 1 and self.base_register % 60 != 0:
# _logger.warning(f'Base register {self.base_register} not aligned on 60-byte boundary')
if self.register_count != len(self.register_values):
raise InvalidPduState(
f'register_count={self.register_count} but len(register_values)={len(self.register_values)}.',
Expand Down Expand Up @@ -147,7 +149,7 @@ def is_suspicious(self) -> bool:
class ReadHoldingRegisters(ReadRegistersMessage, ABC):
"""Request & Response PDUs for function #3/Read Holding Registers."""

inner_function_code = 3
transparent_function_code = 3


class ReadHoldingRegistersRequest(ReadHoldingRegisters, ReadRegistersRequest):
Expand All @@ -169,7 +171,7 @@ def expected_response(self):
class ReadInputRegisters(ReadRegistersMessage, ABC):
"""Request & Response PDUs for function #4/Read Input Registers."""

inner_function_code = 4
transparent_function_code = 4


class ReadInputRegistersRequest(ReadInputRegisters, ReadRegistersRequest):
Expand Down

0 comments on commit 75e7ab7

Please sign in to comment.