From e838492dda478fbfd1f76e491b3a4aa23882d083 Mon Sep 17 00:00:00 2001 From: Franco Fichtner Date: Thu, 7 Jul 2022 13:02:41 +0200 Subject: [PATCH] Revert "VPN: IPsec: Status Overview - cleanup, remove vici library in favour of port package" This reverts commit becf4e934282a4e10a1d667efb35b34d2a59b850. --- Makefile | 1 - plist | 7 + src/opnsense/scripts/ipsec/vici/__init__.py | 1 + src/opnsense/scripts/ipsec/vici/compat.py | 14 + src/opnsense/scripts/ipsec/vici/exception.py | 13 + src/opnsense/scripts/ipsec/vici/protocol.py | 206 ++++++++++ src/opnsense/scripts/ipsec/vici/session.py | 388 ++++++++++++++++++ .../scripts/ipsec/vici/test/__init__.py | 0 .../scripts/ipsec/vici/test/test_protocol.py | 144 +++++++ 9 files changed, 773 insertions(+), 1 deletion(-) create mode 100755 src/opnsense/scripts/ipsec/vici/__init__.py create mode 100755 src/opnsense/scripts/ipsec/vici/compat.py create mode 100755 src/opnsense/scripts/ipsec/vici/exception.py create mode 100755 src/opnsense/scripts/ipsec/vici/protocol.py create mode 100755 src/opnsense/scripts/ipsec/vici/session.py create mode 100755 src/opnsense/scripts/ipsec/vici/test/__init__.py create mode 100755 src/opnsense/scripts/ipsec/vici/test/test_protocol.py diff --git a/Makefile b/Makefile index 481279eec96..6bcc04b7dd7 100644 --- a/Makefile +++ b/Makefile @@ -184,7 +184,6 @@ CORE_DEPENDS?= ca_root_nss \ py${CORE_PYTHON}-requests \ py${CORE_PYTHON}-sqlite3 \ py${CORE_PYTHON}-ujson \ - py${CORE_PYTHON}-vici \ radvd \ rrdtool \ samplicator \ diff --git a/plist b/plist index 62848c43ef4..86f318a0829 100644 --- a/plist +++ b/plist @@ -806,6 +806,13 @@ /usr/local/opnsense/scripts/ipsec/list_spd.py /usr/local/opnsense/scripts/ipsec/list_status.py /usr/local/opnsense/scripts/ipsec/spddelete.py +/usr/local/opnsense/scripts/ipsec/vici/__init__.py +/usr/local/opnsense/scripts/ipsec/vici/compat.py +/usr/local/opnsense/scripts/ipsec/vici/exception.py +/usr/local/opnsense/scripts/ipsec/vici/protocol.py +/usr/local/opnsense/scripts/ipsec/vici/session.py +/usr/local/opnsense/scripts/ipsec/vici/test/__init__.py +/usr/local/opnsense/scripts/ipsec/vici/test/test_protocol.py /usr/local/opnsense/scripts/netflow/dump_log.py /usr/local/opnsense/scripts/netflow/export_details.py /usr/local/opnsense/scripts/netflow/flowctl_stats.py diff --git a/src/opnsense/scripts/ipsec/vici/__init__.py b/src/opnsense/scripts/ipsec/vici/__init__.py new file mode 100755 index 00000000000..d314325b6cf --- /dev/null +++ b/src/opnsense/scripts/ipsec/vici/__init__.py @@ -0,0 +1 @@ +from .session import Session diff --git a/src/opnsense/scripts/ipsec/vici/compat.py b/src/opnsense/scripts/ipsec/vici/compat.py new file mode 100755 index 00000000000..01af987d837 --- /dev/null +++ b/src/opnsense/scripts/ipsec/vici/compat.py @@ -0,0 +1,14 @@ +# Help functions for compatibility between python version 2 and 3 + + +# From https://legacy.python.org/dev/peps/pep-0469 +try: + dict.iteritems +except AttributeError: + # python 3 + def iteritems(d): + return iter(d.items()) +else: + # python 2 + def iteritems(d): + return d.iteritems() diff --git a/src/opnsense/scripts/ipsec/vici/exception.py b/src/opnsense/scripts/ipsec/vici/exception.py new file mode 100755 index 00000000000..757ac51a9d3 --- /dev/null +++ b/src/opnsense/scripts/ipsec/vici/exception.py @@ -0,0 +1,13 @@ +"""Exception types that may be thrown by this library.""" + +class DeserializationException(Exception): + """Encountered an unexpected byte sequence or missing element type.""" + +class SessionException(Exception): + """Session request exception.""" + +class CommandException(Exception): + """Command result exception.""" + +class EventUnknownException(Exception): + """Event unknown exception.""" diff --git a/src/opnsense/scripts/ipsec/vici/protocol.py b/src/opnsense/scripts/ipsec/vici/protocol.py new file mode 100755 index 00000000000..37022946384 --- /dev/null +++ b/src/opnsense/scripts/ipsec/vici/protocol.py @@ -0,0 +1,206 @@ +import io +import socket +import struct + +from collections import namedtuple +from collections import OrderedDict + +from .compat import iteritems +from .exception import DeserializationException + + +class Transport(object): + HEADER_LENGTH = 4 + MAX_SEGMENT = 512 * 1024 + + def __init__(self, sock): + self.socket = sock + + def send(self, packet): + self.socket.sendall(struct.pack("!I", len(packet)) + packet) + + def receive(self): + raw_length = self._recvall(self.HEADER_LENGTH) + length, = struct.unpack("!I", raw_length) + payload = self._recvall(length) + return payload + + def close(self): + self.socket.shutdown(socket.SHUT_RDWR) + self.socket.close() + + def _recvall(self, count): + """Ensure to read count bytes from the socket""" + data = b"" + while len(data) < count: + buf = self.socket.recv(count - len(data)) + if not buf: + raise socket.error('Connection closed') + data += buf + return data + + +class Packet(object): + CMD_REQUEST = 0 # Named request message + CMD_RESPONSE = 1 # Unnamed response message for a request + CMD_UNKNOWN = 2 # Unnamed response if requested command is unknown + EVENT_REGISTER = 3 # Named event registration request + EVENT_UNREGISTER = 4 # Named event de-registration request + EVENT_CONFIRM = 5 # Unnamed confirmation for event (de-)registration + EVENT_UNKNOWN = 6 # Unnamed response if event (de-)registration failed + EVENT = 7 # Named event message + + ParsedPacket = namedtuple( + "ParsedPacket", + ["response_type", "payload"] + ) + + ParsedEventPacket = namedtuple( + "ParsedEventPacket", + ["response_type", "event_type", "payload"] + ) + + @classmethod + def _named_request(cls, request_type, request, message=None): + request = request.encode("UTF-8") + payload = struct.pack("!BB", request_type, len(request)) + request + if message is not None: + return payload + message + else: + return payload + + @classmethod + def request(cls, command, message=None): + return cls._named_request(cls.CMD_REQUEST, command, message) + + @classmethod + def register_event(cls, event_type): + return cls._named_request(cls.EVENT_REGISTER, event_type) + + @classmethod + def unregister_event(cls, event_type): + return cls._named_request(cls.EVENT_UNREGISTER, event_type) + + @classmethod + def parse(cls, packet): + stream = FiniteStream(packet) + response_type, = struct.unpack("!B", stream.read(1)) + + if response_type == cls.EVENT: + length, = struct.unpack("!B", stream.read(1)) + event_type = stream.read(length) + return cls.ParsedEventPacket(response_type, event_type, stream) + else: + return cls.ParsedPacket(response_type, stream) + + +class Message(object): + SECTION_START = 1 # Begin a new section having a name + SECTION_END = 2 # End a previously started section + KEY_VALUE = 3 # Define a value for a named key in the section + LIST_START = 4 # Begin a named list for list items + LIST_ITEM = 5 # Define an unnamed item value in the current list + LIST_END = 6 # End a previously started list + + @classmethod + def serialize(cls, message): + def encode_named_type(marker, name): + name = name.encode("UTF-8") + return struct.pack("!BB", marker, len(name)) + name + + def encode_blob(value): + if not isinstance(value, bytes): + value = str(value).encode("UTF-8") + return struct.pack("!H", len(value)) + value + + def serialize_list(lst): + segment = bytes() + for item in lst: + segment += struct.pack("!B", cls.LIST_ITEM) + encode_blob(item) + return segment + + def serialize_dict(d): + segment = bytes() + for key, value in iteritems(d): + if isinstance(value, dict): + segment += ( + encode_named_type(cls.SECTION_START, key) + + serialize_dict(value) + + struct.pack("!B", cls.SECTION_END) + ) + elif isinstance(value, list): + segment += ( + encode_named_type(cls.LIST_START, key) + + serialize_list(value) + + struct.pack("!B", cls.LIST_END) + ) + else: + segment += ( + encode_named_type(cls.KEY_VALUE, key) + + encode_blob(value) + ) + return segment + + return serialize_dict(message) + + @classmethod + def deserialize(cls, stream): + def decode_named_type(stream): + length, = struct.unpack("!B", stream.read(1)) + return stream.read(length).decode("UTF-8") + + def decode_blob(stream): + length, = struct.unpack("!H", stream.read(2)) + return stream.read(length) + + def decode_list_item(stream): + marker, = struct.unpack("!B", stream.read(1)) + while marker == cls.LIST_ITEM: + yield decode_blob(stream) + marker, = struct.unpack("!B", stream.read(1)) + + if marker != cls.LIST_END: + raise DeserializationException( + "Expected end of list at {pos}".format(pos=stream.tell()) + ) + + section = OrderedDict() + section_stack = [] + while stream.has_more(): + element_type, = struct.unpack("!B", stream.read(1)) + if element_type == cls.SECTION_START: + section_name = decode_named_type(stream) + new_section = OrderedDict() + section[section_name] = new_section + section_stack.append(section) + section = new_section + + elif element_type == cls.LIST_START: + list_name = decode_named_type(stream) + section[list_name] = [item for item in decode_list_item(stream)] + + elif element_type == cls.KEY_VALUE: + key = decode_named_type(stream) + section[key] = decode_blob(stream) + + elif element_type == cls.SECTION_END: + if len(section_stack): + section = section_stack.pop() + else: + raise DeserializationException( + "Unexpected end of section at {pos}".format( + pos=stream.tell() + ) + ) + + if len(section_stack): + raise DeserializationException("Expected end of section") + return section + + +class FiniteStream(io.BytesIO): + def __len__(self): + return len(self.getvalue()) + + def has_more(self): + return self.tell() < len(self) diff --git a/src/opnsense/scripts/ipsec/vici/session.py b/src/opnsense/scripts/ipsec/vici/session.py new file mode 100755 index 00000000000..1383fa77836 --- /dev/null +++ b/src/opnsense/scripts/ipsec/vici/session.py @@ -0,0 +1,388 @@ +import collections +import socket + +from .exception import SessionException, CommandException, EventUnknownException +from .protocol import Transport, Packet, Message + + +class Session(object): + def __init__(self, sock=None): + if sock is None: + sock = socket.socket(socket.AF_UNIX) + sock.connect("/var/run/charon.vici") + self.handler = SessionHandler(Transport(sock)) + + def version(self): + """Retrieve daemon and system specific version information. + + :return: daemon and system specific version information + :rtype: dict + """ + return self.handler.request("version") + + def stats(self): + """Retrieve IKE daemon statistics and load information. + + :return: IKE daemon statistics and load information + :rtype: dict + """ + return self.handler.request("stats") + + def reload_settings(self): + """Reload strongswan.conf settings and any plugins supporting reload. + """ + self.handler.request("reload-settings") + + def initiate(self, sa): + """Initiate an SA. + + :param sa: the SA to initiate + :type sa: dict + :return: generator for logs emitted as dict + :rtype: generator + """ + return self.handler.streamed_request("initiate", "control-log", sa) + + def terminate(self, sa): + """Terminate an SA. + + :param sa: the SA to terminate + :type sa: dict + :return: generator for logs emitted as dict + :rtype: generator + """ + return self.handler.streamed_request("terminate", "control-log", sa) + + def redirect(self, sa): + """Redirect an IKE_SA. + + :param sa: the SA to redirect + :type sa: dict + """ + self.handler.request("redirect", sa) + + def install(self, policy): + """Install a trap, drop or bypass policy defined by a CHILD_SA config. + + :param policy: policy to install + :type policy: dict + """ + self.handler.request("install", policy) + + def uninstall(self, policy): + """Uninstall a trap, drop or bypass policy defined by a CHILD_SA config. + + :param policy: policy to uninstall + :type policy: dict + """ + self.handler.request("uninstall", policy) + + def list_sas(self, filters=None): + """Retrieve active IKE_SAs and associated CHILD_SAs. + + :param filters: retrieve only matching IKE_SAs (optional) + :type filters: dict + :return: generator for active IKE_SAs and associated CHILD_SAs as dict + :rtype: generator + """ + return self.handler.streamed_request("list-sas", "list-sa", filters) + + def list_policies(self, filters=None): + """Retrieve installed trap, drop and bypass policies. + + :param filters: retrieve only matching policies (optional) + :type filters: dict + :return: generator for installed trap, drop and bypass policies as dict + :rtype: generator + """ + return self.handler.streamed_request("list-policies", "list-policy", + filters) + + def list_conns(self, filters=None): + """Retrieve loaded connections. + + :param filters: retrieve only matching configuration names (optional) + :type filters: dict + :return: generator for loaded connections as dict + :rtype: generator + """ + return self.handler.streamed_request("list-conns", "list-conn", + filters) + + def get_conns(self): + """Retrieve connection names loaded exclusively over vici. + + :return: connection names + :rtype: dict + """ + return self.handler.request("get-conns") + + def list_certs(self, filters=None): + """Retrieve loaded certificates. + + :param filters: retrieve only matching certificates (optional) + :type filters: dict + :return: generator for loaded certificates as dict + :rtype: generator + """ + return self.handler.streamed_request("list-certs", "list-cert", filters) + + def load_conn(self, connection): + """Load a connection definition into the daemon. + + :param connection: connection definition + :type connection: dict + """ + self.handler.request("load-conn", connection) + + def unload_conn(self, name): + """Unload a connection definition. + + :param name: connection definition name + :type name: dict + """ + self.handler.request("unload-conn", name) + + def load_cert(self, certificate): + """Load a certificate into the daemon. + + :param certificate: PEM or DER encoded certificate + :type certificate: dict + """ + self.handler.request("load-cert", certificate) + + def load_key(self, private_key): + """Load a private key into the daemon. + + :param private_key: PEM or DER encoded key + """ + self.handler.request("load-key", private_key) + + def load_shared(self, secret): + """Load a shared IKE PSK, EAP or XAuth secret into the daemon. + + :param secret: shared IKE PSK, EAP or XAuth secret + :type secret: dict + """ + self.handler.request("load-shared", secret) + + def flush_certs(self, filter=None): + """Flush the volatile certificate cache. + + Flush the certificate stored temporarily in the cache. The filter + allows to flush only a certain type of certificates, e.g. CRLs. + + :param filter: flush only certificates of a given type (optional) + :type filter: dict + """ + self.handler.request("flush-certs", filter) + + def clear_creds(self): + """Clear credentials loaded over vici. + + Clear all loaded certificate, private key and shared key credentials. + This affects only credentials loaded over vici, but additionally + flushes the credential cache. + """ + self.handler.request("clear-creds") + + def load_pool(self, pool): + """Load a virtual IP pool. + + Load an in-memory virtual IP and configuration attribute pool. + Existing pools with the same name get updated, if possible. + + :param pool: virtual IP and configuration attribute pool + :type pool: dict + """ + return self.handler.request("load-pool", pool) + + def unload_pool(self, pool_name): + """Unload a virtual IP pool. + + Unload a previously loaded virtual IP and configuration attribute pool. + Unloading fails for pools with leases currently online. + + :param pool_name: pool by name + :type pool_name: dict + """ + self.handler.request("unload-pool", pool_name) + + def get_pools(self, options): + """Retrieve loaded pools. + + :param options: filter by name and/or retrieve leases (optional) + :type options: dict + :return: loaded pools + :rtype: dict + """ + return self.handler.request("get-pools", options) + + def listen(self, event_types): + """Register and listen for the given events. + + :param event_types: event types to register + :type event_types: list + :return: generator for streamed event responses as (event_type, dict) + :rtype: generator + """ + return self.handler.listen(event_types) + + +class SessionHandler(object): + """Handles client command execution requests over vici.""" + + def __init__(self, transport): + self.transport = transport + + def _communicate(self, packet): + """Send packet over transport and parse response. + + :param packet: packet to send + :type packet: :py:class:`vici.protocol.Packet` + :return: parsed packet in a tuple with message type and payload + :rtype: :py:class:`collections.namedtuple` + """ + self.transport.send(packet) + return Packet.parse(self.transport.receive()) + + def _register_unregister(self, event_type, register): + """Register or unregister for the given event. + + :param event_type: event to register + :type event_type: str + :param register: whether to register or unregister + :type register: bool + """ + if register: + packet = Packet.register_event(event_type) + else: + packet = Packet.unregister_event(event_type) + response = self._communicate(packet) + if response.response_type == Packet.EVENT_UNKNOWN: + raise EventUnknownException( + "Unknown event type '{event}'".format(event=event_type) + ) + elif response.response_type != Packet.EVENT_CONFIRM: + raise SessionException( + "Unexpected response type {type}, " + "expected '{confirm}' (EVENT_CONFIRM)".format( + type=response.response_type, + confirm=Packet.EVENT_CONFIRM, + ) + ) + + def request(self, command, message=None): + """Send request with an optional message. + + :param command: command to send + :type command: str + :param message: message (optional) + :type message: str + :return: command result + :rtype: dict + """ + if message is not None: + message = Message.serialize(message) + packet = Packet.request(command, message) + response = self._communicate(packet) + + if response.response_type != Packet.CMD_RESPONSE: + raise SessionException( + "Unexpected response type {type}, " + "expected '{response}' (CMD_RESPONSE)".format( + type=response.response_type, + response=Packet.CMD_RESPONSE + ) + ) + + command_response = Message.deserialize(response.payload) + if "success" in command_response: + if command_response["success"] != b"yes": + raise CommandException( + "Command failed: {errmsg}".format( + errmsg=command_response["errmsg"] + ) + ) + + return command_response + + def streamed_request(self, command, event_stream_type, message=None): + """Send command request and collect and return all emitted events. + + :param command: command to send + :type command: str + :param event_stream_type: event type emitted on command execution + :type event_stream_type: str + :param message: message (optional) + :type message: str + :return: generator for streamed event responses as dict + :rtype: generator + """ + if message is not None: + message = Message.serialize(message) + + self._register_unregister(event_stream_type, True); + + try: + packet = Packet.request(command, message) + self.transport.send(packet) + exited = False + while True: + response = Packet.parse(self.transport.receive()) + if response.response_type == Packet.EVENT: + if not exited: + try: + yield Message.deserialize(response.payload) + except GeneratorExit: + exited = True + pass + else: + break + + if response.response_type == Packet.CMD_RESPONSE: + command_response = Message.deserialize(response.payload) + else: + raise SessionException( + "Unexpected response type {type}, " + "expected '{response}' (CMD_RESPONSE)".format( + type=response.response_type, + response=Packet.CMD_RESPONSE + ) + ) + + finally: + self._register_unregister(event_stream_type, False); + + # evaluate command result, if any + if "success" in command_response: + if command_response["success"] != b"yes": + raise CommandException( + "Command failed: {errmsg}".format( + errmsg=command_response["errmsg"] + ) + ) + + def listen(self, event_types): + """Register and listen for the given events. + + :param event_types: event types to register + :type event_types: list + :return: generator for streamed event responses as (event_type, dict) + :rtype: generator + """ + for event_type in event_types: + self._register_unregister(event_type, True) + + try: + while True: + response = Packet.parse(self.transport.receive()) + if response.response_type == Packet.EVENT: + try: + yield response.event_type, Message.deserialize(response.payload) + except GeneratorExit: + break + + finally: + for event_type in event_types: + self._register_unregister(event_type, False) diff --git a/src/opnsense/scripts/ipsec/vici/test/__init__.py b/src/opnsense/scripts/ipsec/vici/test/__init__.py new file mode 100755 index 00000000000..e69de29bb2d diff --git a/src/opnsense/scripts/ipsec/vici/test/test_protocol.py b/src/opnsense/scripts/ipsec/vici/test/test_protocol.py new file mode 100755 index 00000000000..a1f202d79cb --- /dev/null +++ b/src/opnsense/scripts/ipsec/vici/test/test_protocol.py @@ -0,0 +1,144 @@ +import pytest + +from ..protocol import Packet, Message, FiniteStream +from ..exception import DeserializationException + + +class TestPacket(object): + # test data definitions for outgoing packet types + cmd_request = b"\x00\x0c" b"command_type" + cmd_request_msg = b"\x00\x07" b"command" b"payload" + event_register = b"\x03\x0a" b"event_type" + event_unregister = b"\x04\x0a" b"event_type" + + # test data definitions for incoming packet types + cmd_response = b"\x01" b"reply" + cmd_unknown = b"\x02" + event_confirm = b"\x05" + event_unknown = b"\x06" + event = b"\x07\x03" b"log" b"message" + + def test_request(self): + assert Packet.request("command_type") == self.cmd_request + assert Packet.request("command", b"payload") == self.cmd_request_msg + + def test_register_event(self): + assert Packet.register_event("event_type") == self.event_register + + def test_unregister_event(self): + assert Packet.unregister_event("event_type") == self.event_unregister + + def test_parse(self): + parsed_cmd_response = Packet.parse(self.cmd_response) + assert parsed_cmd_response.response_type == Packet.CMD_RESPONSE + assert parsed_cmd_response.payload.getvalue() == self.cmd_response + + parsed_cmd_unknown = Packet.parse(self.cmd_unknown) + assert parsed_cmd_unknown.response_type == Packet.CMD_UNKNOWN + assert parsed_cmd_unknown.payload.getvalue() == self.cmd_unknown + + parsed_event_confirm = Packet.parse(self.event_confirm) + assert parsed_event_confirm.response_type == Packet.EVENT_CONFIRM + assert parsed_event_confirm.payload.getvalue() == self.event_confirm + + parsed_event_unknown = Packet.parse(self.event_unknown) + assert parsed_event_unknown.response_type == Packet.EVENT_UNKNOWN + assert parsed_event_unknown.payload.getvalue() == self.event_unknown + + parsed_event = Packet.parse(self.event) + assert parsed_event.response_type == Packet.EVENT + assert parsed_event.payload.getvalue() == self.event + + +class TestMessage(object): + """Message (de)serialization test.""" + + # data definitions for test of de(serialization) + # serialized messages holding a section + ser_sec_unclosed = b"\x01\x08unclosed" + ser_sec_single = b"\x01\x07section\x02" + ser_sec_nested = b"\x01\x05outer\x01\x0asubsection\x02\x02" + + # serialized messages holding a list + ser_list_invalid = b"\x04\x07invalid\x05\x00\x02e1\x02\x03sec\x06" + ser_list_0_item = b"\x04\x05empty\x06" + ser_list_1_item = b"\x04\x01l\x05\x00\x02e1\x06" + ser_list_2_item = b"\x04\x01l\x05\x00\x02e1\x05\x00\x02e2\x06" + + # serialized messages with key value pairs + ser_kv_pair = b"\x03\x03key\x00\x05value" + ser_kv_zero = b"\x03\x0azerolength\x00\x00" + + # deserialized messages holding a section + des_sec_single = { "section": {} } + des_sec_nested = { "outer": { "subsection": {} } } + + # deserialized messages holding a list + des_list_0_item = { "empty": [] } + des_list_1_item = { "l": [ b"e1" ] } + des_list_2_item = { "l": [ b"e1", b"e2" ] } + + # deserialized messages with key value pairs + des_kv_pair = { "key": b"value" } + des_kv_zero = { "zerolength": b"" } + + def test_section_serialization(self): + assert Message.serialize(self.des_sec_single) == self.ser_sec_single + assert Message.serialize(self.des_sec_nested) == self.ser_sec_nested + + def test_list_serialization(self): + assert Message.serialize(self.des_list_0_item) == self.ser_list_0_item + assert Message.serialize(self.des_list_1_item) == self.ser_list_1_item + assert Message.serialize(self.des_list_2_item) == self.ser_list_2_item + + def test_key_serialization(self): + assert Message.serialize(self.des_kv_pair) == self.ser_kv_pair + assert Message.serialize(self.des_kv_zero) == self.ser_kv_zero + + def test_section_deserialization(self): + single = Message.deserialize(FiniteStream(self.ser_sec_single)) + nested = Message.deserialize(FiniteStream(self.ser_sec_nested)) + + assert single == self.des_sec_single + assert nested == self.des_sec_nested + + with pytest.raises(DeserializationException): + Message.deserialize(FiniteStream(self.ser_sec_unclosed)) + + def test_list_deserialization(self): + l0 = Message.deserialize(FiniteStream(self.ser_list_0_item)) + l1 = Message.deserialize(FiniteStream(self.ser_list_1_item)) + l2 = Message.deserialize(FiniteStream(self.ser_list_2_item)) + + assert l0 == self.des_list_0_item + assert l1 == self.des_list_1_item + assert l2 == self.des_list_2_item + + with pytest.raises(DeserializationException): + Message.deserialize(FiniteStream(self.ser_list_invalid)) + + def test_key_deserialization(self): + pair = Message.deserialize(FiniteStream(self.ser_kv_pair)) + zerolength = Message.deserialize(FiniteStream(self.ser_kv_zero)) + + assert pair == self.des_kv_pair + assert zerolength == self.des_kv_zero + + def test_roundtrip(self): + message = { + "key1": "value1", + "section1": { + "sub-section": { + "key2": b"value2", + }, + "list1": [ "item1", "item2" ], + }, + } + serialized_message = FiniteStream(Message.serialize(message)) + deserialized_message = Message.deserialize(serialized_message) + + # ensure that list items and key values remain as undecoded bytes + deserialized_section = deserialized_message["section1"] + assert deserialized_message["key1"] == b"value1" + assert deserialized_section["sub-section"]["key2"] == b"value2" + assert deserialized_section["list1"] == [ b"item1", b"item2" ]