diff --git a/data/azure-pipelines.yml b/data/azure-pipelines.yml index 22f85c5392..3eecef090d 100644 --- a/data/azure-pipelines.yml +++ b/data/azure-pipelines.yml @@ -56,9 +56,13 @@ jobs: touch tests/show_gui cp tests/.coveragerc . pytest --junitxml=junit/test-results.xml --cov=src --cov-config=.coveragerc tests + displayName: 'Run pytest with coverage' + condition: eq(variables['python.version'], '3.7') + + - script: | coverage xml coverage html - displayName: 'Run pytest with coverage' + displayName: 'Generate coverage report' condition: eq(variables['python.version'], '3.7') - script: pytest --junitxml=junit/test-results.xml tests diff --git a/data/requirements.txt b/data/requirements.txt index 644ca14247..b366703d3e 100644 --- a/data/requirements.txt +++ b/data/requirements.txt @@ -1,8 +1,8 @@ -numpy; sys_platform != 'win32' -numpy!=1.16.0; sys_platform == 'win32' +numpy>=1.9; sys_platform != 'win32' +numpy>=1.9,!=1.16.0; sys_platform == 'win32' pyqt5; sys_platform != 'win32' and sys_platform != 'linux' pyqt5!=5.11.1,!=5.11.2,!=5.11.3; sys_platform == 'win32' -pyqt5!=5.12,!=5.12.1; sys_platform == 'linux' +pyqt5!=5.12,!=5.12.1,!=5.12.2; sys_platform == 'linux' psutil pyzmq cython diff --git a/src/urh/awre/AutoAssigner.py b/src/urh/awre/AutoAssigner.py new file mode 100644 index 0000000000..6b37d9a1a3 --- /dev/null +++ b/src/urh/awre/AutoAssigner.py @@ -0,0 +1,65 @@ +import numpy as np + +from urh.cythonext import util +from urh.signalprocessing.Message import Message + + +def auto_assign_participants(messages, participants): + """ + + :type messages: list of Message + :type participants: list of Participant + :return: + """ + if len(participants) == 0: + return + + if len(participants) == 1: + for message in messages: # type: Message + message.participant = participants[0] + return + + # Try to assign participants based on SRC_ADDRESS label and participant address + for msg in filter(lambda m: m.participant is None, messages): + src_address = msg.get_src_address_from_data() + if src_address: + try: + msg.participant = next(p for p in participants if p.address_hex == src_address) + except StopIteration: + pass + + # Assign remaining participants based on RSSI of messages + rssis = np.array([msg.rssi for msg in messages], dtype=np.float32) + min_rssi, max_rssi = util.minmax(rssis) + center_spacing = (max_rssi - min_rssi) / (len(participants) - 1) + centers = [min_rssi + i * center_spacing for i in range(0, len(participants))] + rssi_assigned_centers = [] + + for rssi in rssis: + center_index = np.argmin(np.abs(rssi - centers)) + rssi_assigned_centers.append(int(center_index)) + + participants.sort(key=lambda participant: participant.relative_rssi) + for message, center_index in zip(messages, rssi_assigned_centers): + if message.participant is None: + message.participant = participants[center_index] + + +def auto_assign_participant_addresses(messages, participants): + """ + + :type messages: list of Message + :type participants: list of Participant + :return: + """ + participants_without_address = [p for p in participants if not p.address_hex] + + if len(participants_without_address) == 0: + return + + for msg in messages: + if msg.participant in participants_without_address: + src_address = msg.get_src_address_from_data() + if src_address: + participants_without_address.remove(msg.participant) + msg.participant.address_hex = src_address diff --git a/src/urh/awre/CommonRange.py b/src/urh/awre/CommonRange.py index bffe36911e..a7f1c50aba 100644 --- a/src/urh/awre/CommonRange.py +++ 
b/src/urh/awre/CommonRange.py @@ -1,65 +1,296 @@ -class CommonRange(object): +import copy +import itertools + +import numpy as np + +from urh.util import util +from urh.util.GenericCRC import GenericCRC - __slots__ = ["start", "end", "__bits", "__hex", "messages"] - def __init__(self, start: int, end: int, bits: str): +class CommonRange(object): + def __init__(self, start, length, value: np.ndarray = None, score=0, field_type="Generic", message_indices=None, + range_type="bit", byte_order="big"): """ - :param start: Start of the common range - :param end: End of the common range - :param bits: Value of the common range + :param start: + :param length: + :param value: Value for this common range as string """ self.start = start - self.end = end - self.__bits = bits - self.__hex = ("{0:0"+str(len(self.__bits)//4)+"x}").format(int(self.__bits, 2)) - self.messages = set() - """:type: set of int """ + self.length = length - @property - def bits(self) -> str: - return self.__bits + self.__byte_order = byte_order + self.sync_end = 0 + + if isinstance(value, str): + value = np.array(list(map(lambda x: int(x, 16), value)), dtype=np.uint8) + + self.values = [value] if value is not None else [] + self.score = score + self.field_type = field_type # can also be length, address etc. + + self.range_type = range_type.lower() # one of bit/hex/byte + + self.message_indices = set() if message_indices is None else set(message_indices) + """ + Set of message indices, this range applies to + """ @property - def hex_value(self) -> str: - return self.__hex + def end(self): + return self.start + self.length - 1 @property - def byte_len(self) -> int: - return (self.end - self.start) // 8 + def bit_start(self): + return self.__convert_number(self.start) + self.sync_end - def __len__(self): - return self.end - self.start + @property + def bit_end(self): + return self.__convert_number(self.start) + self.__convert_number(self.length) - 1 + self.sync_end - def __hash__(self): - return hash(self.start) + hash(self.end) + hash(self.bits) + @property + def length_in_bits(self): + return self.bit_end - self.bit_start - 1 - def pos_of_hex(self, hex_str) -> tuple: - try: - start = 4 * self.hex_value.index(hex_str) - return start, start + 4 * len(hex_str) - except ValueError: + @property + def value(self): + if len(self.values) == 0: return None + elif len(self.values) == 1: + return self.values[0] + else: + raise ValueError("This range has multiple values!") - @staticmethod - def from_hex(hex_str): - return CommonRange(start=0, end=0, bits="{0:b}".format(int(hex_str, 16))) + @value.setter + def value(self, val): + if len(self.values) == 0: + self.values = [val] + elif len(self.values) == 1: + self.values[0] = val + else: + raise ValueError("This range has multiple values!") - def __eq__(self, other): - if isinstance(other, CommonRange): - return self.start == other.start and self.end == other.end and self.bits == other.bits + @property + def byte_order(self): + if self.byte_order_is_unknown: + return "big" + return self.__byte_order + + @byte_order.setter + def byte_order(self, val: str): + self.__byte_order = val + + @property + def byte_order_is_unknown(self) -> bool: + return self.__byte_order is None + + def matches(self, start: int, value: np.ndarray): + return self.start == start and \ + self.length == len(value) and \ + self.value.tobytes() == value.tobytes() + + def __convert_number(self, n): + if self.range_type == "bit": + return n + elif self.range_type == "hex": + return n * 4 + elif self.range_type == "byte": + 
return n * 8 else: + raise ValueError("Unknown range type {}".format(self.range_type)) + + def __repr__(self): + result = "{} {}-{} ({} {})".format(self.field_type, self.bit_start, + self.bit_end, self.length, self.range_type) + + result += " Values: " + " ".join(map(util.convert_numbers_to_hex_string, self.values)) + if self.score is not None: + result += " Score: " + str(self.score) + result += " Message indices: {" + ",".join(map(str, sorted(self.message_indices))) + "}" + return result + + def __eq__(self, other): + if not isinstance(other, CommonRange): return False + return self.bit_start == other.bit_start and \ + self.bit_end == other.bit_end and \ + self.field_type == other.field_type + + def __hash__(self): + return hash((self.start, self.length, self.field_type)) + def __lt__(self, other): - if isinstance(other, CommonRange): - if self.start != other.start: - return self.start < other.start - else: - return self.end <= other.end + return self.bit_start < other.bit_start + + def overlaps_with(self, other) -> bool: + if not isinstance(other, CommonRange): + raise ValueError("Need another bit range to compare") + return any(i in range(self.bit_start, self.bit_end) + for i in range(other.bit_start, other.bit_end)) + + def ensure_not_overlaps(self, start: int, end: int): + """ + + :param start: + :param end: + :rtype: list of CommonRange + """ + if end < self.start or start > self.end: + # Other range is right or left of our range -> no overlapping + return [copy.deepcopy(self)] + + if start <= self.start < end < self.end: + # overlaps on the left + result = copy.deepcopy(self) + result.length -= end - result.start + result.start = end + result.value = result.value[result.start-self.start:(result.start-self.start)+result.length] + return [result] + + if self.start < start <= self.end <= end: + # overlaps on the right + result = copy.deepcopy(self) + result.length -= self.end + 1 - start + result.value = result.value[:result.length] + return [result] + + if self.start < start and self.end > end: + # overlaps in the middle + left = copy.deepcopy(self) + right = copy.deepcopy(self) + + left.length -= (left.end + 1 - start) + left.value = self.value[:left.length] + + right.start = end + 1 + right.length = self.end - end + right.value = self.value[right.start-self.start:(right.start-self.start)+right.length] + return [left, right] + + return [] + + +class ChecksumRange(CommonRange): + def __init__(self, start, length, crc: GenericCRC, data_range_start, data_range_end, value: np.ndarray = None, + score=0, field_type="Generic", message_indices=None, range_type="bit"): + super().__init__(start, length, value, score, field_type, message_indices, range_type) + self.data_range_start = data_range_start + self.data_range_end = data_range_end + self.crc = crc + + @property + def data_range_bit_start(self): + return self.data_range_start + self.sync_end + + @property + def data_range_bit_end(self): + return self.data_range_end + self.sync_end + + def __eq__(self, other): + return super().__eq__(other) \ + and self.data_range_start == other.data_range_start \ + and self.data_range_end == other.data_range_end \ + and self.crc == other.crc + + def __hash__(self): + return hash((self.start, self.length, self.data_range_start, self.data_range_end, self.crc)) + + def __repr__(self): + return super().__repr__() + " \t" + \ + "{}".format(self.crc.caption) + \ + " Datarange: {}-{} ".format(self.data_range_start, self.data_range_end) + + +class EmptyCommonRange(CommonRange): + """ + Empty Common Bit Range, 
to indicate, that no common Bit Range was found + """ + + def __init__(self, field_type="Generic"): + super().__init__(0, 0, "") + self.field_type = field_type + + def __eq__(self, other): + return isinstance(other, EmptyCommonRange) \ + and other.field_type == self.field_type + + def __repr__(self): + return "No " + self.field_type + + def __hash__(self): + return hash(super) + + +class CommonRangeContainer(object): + """ + This is the raw equivalent of a Message Type: + A container of common ranges + """ + + def __init__(self, ranges: list, message_indices: set = None): + + assert isinstance(ranges, list) + + self.__ranges = ranges # type: list[CommonRange] + self.__ranges.sort() + + if message_indices is None: + self.update_message_indices() + else: + self.message_indices = message_indices + + @property + def ranges_overlap(self) -> bool: + return self.has_overlapping_ranges(self.__ranges) + + def update_message_indices(self): + if len(self) == 0: + self.message_indices = set() else: - return -1 + self.message_indices = set(self[0].message_indices) + for i in range(1, len(self)): + self.message_indices.intersection_update(self[i].message_indices) + + def add_range(self, rng: CommonRange): + self.__ranges.append(rng) + self.__ranges.sort() + + def add_ranges(self, ranges: list): + self.__ranges.extend(ranges) + self.__ranges.sort() + + def has_same_ranges(self, ranges: list) -> bool: + return self.__ranges == ranges + + def has_same_ranges_as_container(self, container): + if not isinstance(container, CommonRangeContainer): + return False + + return self.__ranges == container.__ranges + + @staticmethod + def has_overlapping_ranges(ranges: list) -> bool: + for rng1, rng2 in itertools.combinations(ranges, 2): + if rng1.overlaps_with(rng2): + return True + return False + + def __len__(self): + return len(self.__ranges) + + def __iter__(self): + return self.__ranges.__iter__() + + def __getitem__(self, item): + return self.__ranges[item] def __repr__(self): - return "{}-{} {} ({})".format(self.start, self.end, self.hex_value, self.messages) - + from pprint import pformat + return pformat(self.__ranges) + + def __eq__(self, other): + if not isinstance(other, CommonRangeContainer): + return False + + return self.__ranges == other.__ranges and self.message_indices == other.message_indices diff --git a/src/urh/awre/FormatFinder.py b/src/urh/awre/FormatFinder.py index c78aa5ede8..119ab8352d 100644 --- a/src/urh/awre/FormatFinder.py +++ b/src/urh/awre/FormatFinder.py @@ -1,112 +1,427 @@ +import copy +import math +from collections import defaultdict + import numpy as np -import time +from urh.awre import AutoAssigner +from urh.awre.CommonRange import CommonRange, EmptyCommonRange, CommonRangeContainer, ChecksumRange +from urh.awre.Preprocessor import Preprocessor +from urh.awre.engines.AddressEngine import AddressEngine +from urh.awre.engines.ChecksumEngine import ChecksumEngine +from urh.awre.engines.LengthEngine import LengthEngine +from urh.awre.engines.SequenceNumberEngine import SequenceNumberEngine +from urh.cythonext import awre_util +from urh.signalprocessing.ChecksumLabel import ChecksumLabel from urh.signalprocessing.FieldType import FieldType -from urh.util.Logger import logger +from urh.signalprocessing.Message import Message +from urh.signalprocessing.MessageType import MessageType +from urh.signalprocessing.ProtocoLabel import ProtocolLabel +from urh.util.WSPChecksum import WSPChecksum -from urh.awre.components.Address import Address -from urh.awre.components.Component import Component 
-from urh.awre.components.Flags import Flags -from urh.awre.components.Length import Length -from urh.awre.components.Preamble import Preamble -from urh.awre.components.SequenceNumber import SequenceNumber -from urh.awre.components.Type import Type -from urh.cythonext import util class FormatFinder(object): - MIN_MESSAGES_PER_CLUSTER = 2 # If there is only one message per cluster it is not very significant + MIN_MESSAGES_PER_CLUSTER = 2 - def __init__(self, protocol, participants=None, field_types=None): + def __init__(self, messages, participants=None, shortest_field_length=None): """ - :type protocol: urh.signalprocessing.ProtocolAnalyzer.ProtocolAnalyzer + :type messages: list of Message :param participants: """ if participants is not None: - protocol.auto_assign_participants(participants) + AutoAssigner.auto_assign_participants(messages, participants) + + existing_message_types_by_msg = {i: msg.message_type for i, msg in enumerate(messages)} + self.existing_message_types = defaultdict(list) + for i, message_type in existing_message_types_by_msg.items(): + self.existing_message_types[message_type].append(i) + + preprocessor = Preprocessor(self.get_bitvectors_from_messages(messages), existing_message_types_by_msg) + self.preamble_starts, self.preamble_lengths, sync_len = preprocessor.preprocess() + self.sync_ends = self.preamble_starts + self.preamble_lengths + sync_len - self.protocol = protocol - self.bitvectors = [np.array(msg.decoded_bits, dtype=np.int8) for msg in self.protocol.messages] - self.len_cluster = self.cluster_lengths() - self.xor_matrix = self.build_xor_matrix() + n = shortest_field_length + if n is None: + # 0 = no sync found + n = 8 if sync_len >= 8 else 4 if sync_len >= 4 else 1 if sync_len >= 1 else 0 + for i, value in enumerate(self.sync_ends): + # In doubt it is better to under estimate the sync end + if n > 0: + self.sync_ends[i] = n * max(int(math.floor((value - self.preamble_starts[i]) / n)), 1) + \ + self.preamble_starts[i] + else: + self.sync_ends[i] = self.preamble_starts[i] - mt = self.protocol.message_types + if self.sync_ends[i] - self.preamble_starts[i] < self.preamble_lengths[i]: + self.preamble_lengths[i] = self.sync_ends[i] - self.preamble_starts[i] - field_types = FieldType.load_from_xml() if field_types is None else field_types + self.bitvectors = self.get_bitvectors_from_messages(messages, self.sync_ends) + self.hexvectors = self.get_hexvectors(self.bitvectors) + self.current_iteration = 0 - self.preamble_component = Preamble(fieldtypes=field_types, priority=0, messagetypes=mt) - self.length_component = Length(fieldtypes=field_types, length_cluster=self.len_cluster, priority=1, - predecessors=[self.preamble_component], messagetypes=mt) - self.address_component = Address(fieldtypes=field_types, xor_matrix=self.xor_matrix, priority=2, - predecessors=[self.preamble_component], messagetypes=mt) - self.sequence_number_component = SequenceNumber(fieldtypes=field_types, priority=3, - predecessors=[self.preamble_component]) - self.type_component = Type(priority=4, predecessors=[self.preamble_component]) - self.flags_component = Flags(priority=5, predecessors=[self.preamble_component]) + participants = list(sorted(set(msg.participant for msg in messages if msg.participant is not None))) + self.participant_indices = [participants.index(msg.participant) if msg.participant is not None else -1 + for msg in messages] + self.known_participant_addresses = { + participants.index(p): np.array([int(h, 16) for h in p.address_hex], dtype=np.uint8) + for p in 
participants if p and p.address_hex + } + + @property + def message_types(self): + """ - def build_component_order(self): + :rtype: list of MessageType """ - Build the order of component based on their priority and predecessors + return sorted(self.existing_message_types.keys(), key=lambda x: x.name) - :rtype: list of Component + def perform_iteration_for_message_type(self, message_type: MessageType): """ - present_components = [item for item in self.__dict__.values() if isinstance(item, Component) and item.enabled] - result = [None] * len(present_components) - used_prios = set() - for component in present_components: - index = component.priority % len(present_components) - if index in used_prios: - raise ValueError("Duplicate priority: {}".format(component.priority)) - used_prios.add(index) + Perform a field inference iteration for messages of the given message type + This routine will return newly found fields as a set of Common Ranges + + :param message_type: + :rtype: set of CommonRange + """ + indices = self.existing_message_types[message_type] + engines = [] + + # We can take an arbitrary sync end to correct the already labeled fields for this message type, + # because if the existing labels would have different sync positions, + # they would not belong to the same message type in the first place + sync_end = self.sync_ends[indices[0]] if indices else 0 + already_labeled = [(lbl.start - sync_end, lbl.end - sync_end) for lbl in message_type if lbl.start >= sync_end] + + if not message_type.get_first_label_with_type(FieldType.Function.LENGTH): + engines.append(LengthEngine([self.bitvectors[i] for i in indices], already_labeled=already_labeled)) - result[index] = component + if not message_type.get_first_label_with_type(FieldType.Function.SRC_ADDRESS): + engines.append(AddressEngine([self.hexvectors[i] for i in indices], + [self.participant_indices[i] for i in indices], + self.known_participant_addresses, + already_labeled=already_labeled)) + elif not message_type.get_first_label_with_type(FieldType.Function.DST_ADDRESS): + engines.append(AddressEngine([self.hexvectors[i] for i in indices], + [self.participant_indices[i] for i in indices], + self.known_participant_addresses, + already_labeled=already_labeled, + src_field_present=True)) - # Check if predecessors are valid - for i, component in enumerate(result): - if any(i < result.index(pre) for pre in component.predecessors): - raise ValueError("Component {} comes before at least one of its predecessors".format(component)) + if not message_type.get_first_label_with_type(FieldType.Function.SEQUENCE_NUMBER): + engines.append(SequenceNumberEngine([self.bitvectors[i] for i in indices], already_labeled=already_labeled)) + if not message_type.get_first_label_with_type(FieldType.Function.CHECKSUM): + # If checksum was not found in first iteration, it will also not be found in next one + if self.current_iteration == 0: + engines.append(ChecksumEngine([self.bitvectors[i] for i in indices], already_labeled=already_labeled)) + result = set() + for engine in engines: + high_scored_ranges = engine.find() # type: list[CommonRange] + high_scored_ranges = self.retransform_message_indices(high_scored_ranges, indices, self.sync_ends) + merged_ranges = self.merge_common_ranges(high_scored_ranges) + result.update(merged_ranges) return result - def perform_iteration(self): - for component in self.build_component_order(): - # OPEN: Create new message types e.g. 
for addresses - component.find_field(self.protocol.messages) + def perform_iteration(self) -> bool: + new_field_found = False + + for message_type in self.existing_message_types.copy(): + new_fields_for_message_type = self.perform_iteration_for_message_type(message_type) + new_fields_for_message_type.update( + self.get_preamble_and_sync(self.preamble_starts, self.preamble_lengths, self.sync_ends, + message_type_indices=self.existing_message_types[message_type]) + ) + + self.remove_overlapping_fields(new_fields_for_message_type, message_type) + containers = self.create_common_range_containers(new_fields_for_message_type) + + # Store addresses of participants if we found a SRC address field + participants_with_unknown_address = set(self.participant_indices) - set(self.known_participant_addresses) + participants_with_unknown_address.discard(-1) + + if participants_with_unknown_address: + for container in containers: + src_range = next((rng for rng in container if rng.field_type == "source address"), None) + if src_range is None: + continue + for msg_index in src_range.message_indices: + if len(participants_with_unknown_address) == 0: + break + p = self.participant_indices[msg_index] + if p not in self.known_participant_addresses: + hex_vector = self.hexvectors[msg_index] + self.known_participant_addresses[p] = hex_vector[src_range.start:src_range.end + 1] + participants_with_unknown_address.discard(p) + + new_field_found |= len(containers) > 0 + + if len(containers) == 1: + for rng in containers[0]: + self.add_range_to_message_type(rng, message_type) + elif len(containers) > 1: + del self.existing_message_types[message_type] + + for i, container in enumerate(containers): + new_message_type = copy.deepcopy(message_type) # type: MessageType + + if i > 0: + new_message_type.name = "Inferred #{}".format(i) + new_message_type.give_new_id() + + for rng in container: + self.add_range_to_message_type(rng, new_message_type) - def cluster_lengths(self): + self.existing_message_types[new_message_type].extend(sorted(container.message_indices)) + + return new_field_found + + def run(self, max_iterations=10): + self.current_iteration = 0 + while self.perform_iteration() and self.current_iteration < max_iterations: + self.current_iteration += 1 + + @staticmethod + def remove_overlapping_fields(common_ranges, message_type: MessageType): + """ + Remove all fields from a set of CommonRanges which overlap with fields of the existing message type + + :type common_ranges: set of CommonRange + :param message_type: + :return: """ - This method clusters some bitvectors based on their length. 
An example output is + if len(message_type) == 0: + return - 2: [0.5, 1] - 4: [1, 0.75, 1, 1] + for rng in common_ranges.copy(): + for lbl in message_type: # type: ProtocolLabel + if any(i in range(rng.bit_start, rng.bit_end) for i in range(lbl.start, lbl.end)): + common_ranges.discard(rng) + break + + @staticmethod + def merge_common_ranges(common_ranges): + """ + Merge common ranges if possible + + :type common_ranges: list of CommonRange + :rtype: list of CommonRange + """ + merged_ranges = [] + for common_range in common_ranges: + assert isinstance(common_range, CommonRange) + try: + same_range = next(rng for rng in merged_ranges + if rng.bit_start == common_range.bit_start + and rng.bit_end == common_range.bit_end + and rng.field_type == common_range.field_type) + same_range.values.extend(common_range.values) + same_range.message_indices.update(common_range.message_indices) + except StopIteration: + merged_ranges.append(common_range) - Meaning there were two message lengths: 2 and 4 bit. - (0.5, 1) means, the first bit was equal in 50% of cases (meaning maximum difference) and bit 2 was equal in all messages + return merged_ranges - A simple XOR would not work as it would be error prone. + @staticmethod + def add_range_to_message_type(common_range: CommonRange, message_type: MessageType): + field_type = FieldType.from_caption(common_range.field_type) + label = message_type.add_protocol_label(name=common_range.field_type, + start=common_range.bit_start, end=common_range.bit_end, + auto_created=True, + type=field_type + ) + label.display_endianness = common_range.byte_order + + if field_type.function == FieldType.Function.CHECKSUM: + assert isinstance(label, ChecksumLabel) + assert isinstance(common_range, ChecksumRange) + label.data_ranges = [(common_range.data_range_bit_start, common_range.data_range_bit_end)] + + if isinstance(common_range.crc, WSPChecksum): + label.category = ChecksumLabel.Category.wsp + else: + label.checksum = copy.copy(common_range.crc) + + @staticmethod + def get_hexvectors(bitvectors: list): + result = awre_util.get_hexvectors(bitvectors) + return result + + @staticmethod + def get_bitvectors_from_messages(messages: list, sync_ends: np.ndarray = None): + if sync_ends is None: + sync_ends = defaultdict(lambda: None) + + return [np.array(msg.decoded_bits[sync_ends[i]:], dtype=np.uint8, order="C") for i, msg in enumerate(messages)] + + @staticmethod + def create_common_range_containers(label_set: set, num_messages: int = None): + """ + Create message types from set of labels. + Handle overlapping conflicts and create multiple message types if needed - :rtype: dict[int, tuple[np.ndarray, int]] + :param label_set: + :param num_messages: + :return: + :rtype: list of CommonRangeContainer """ + if num_messages is None: + message_indices = sorted(set(i for rng in label_set for i in rng.message_indices)) + else: + message_indices = range(num_messages) - number_ones = dict() # dict of tuple. 
0 = number ones vector, 1 = number of blocks for this vector - for vector in self.bitvectors: - vec_len = 4 * (len(vector) // 4) - if vec_len == 0: + result = [] + for i in message_indices: + labels = sorted(set(rng for rng in label_set if i in rng.message_indices + and not isinstance(rng, EmptyCommonRange))) + + container = next((container for container in result if container.has_same_ranges(labels)), None) + if container is None: + result.append(CommonRangeContainer(labels, message_indices={i})) + else: + container.message_indices.add(i) + + result = FormatFinder.handle_overlapping_conflict(result) + + return result + + @staticmethod + def handle_overlapping_conflict(containers): + """ + Handle overlapping conflicts for a list of CommonRangeContainers + + :type containers: list of CommonRangeContainer + :return: + """ + result = [] + for container in containers: + if container.ranges_overlap: + conflicted_handled = FormatFinder.__handle_container_overlapping_conflict(container) + else: + conflicted_handled = container + + try: + same_rng_container = next(c for c in result if c.has_same_ranges_as_container(conflicted_handled)) + same_rng_container.message_indices.update(conflicted_handled.message_indices) + except StopIteration: + result.append(conflicted_handled) + + return result + + @staticmethod + def __handle_container_overlapping_conflict(container: CommonRangeContainer): + """ + Handle overlapping conflict for a CommRangeContainer. + We can assert that all labels in the container share the same message indices + because we partitioned them in a step before. + If two or more labels overlap we have three ways to resolve the conflict: + + 1. Choose the range with the highest score + 2. If multiple ranges overlap choose the ranges that maximize the overall (cumulated) score + 3. If the overlapping is very small i.e. only 1 or 2 bits we can adjust the start/end of the conflicting ranges + + The ranges inside the container _must_ be sorted i.e. the range with lowest start must be at front + + :param container: + :return: + """ + partitions = [] # type: list[list[CommonRange]] + # partition the container into overlapping partitions + # results in something like [[A], [B,C], [D], [E,F,G]]] where B and C and E, F, G are overlapping + for cur_rng in container: + if len(partitions) == 0: + partitions.append([cur_rng]) continue - if vec_len not in number_ones: - number_ones[vec_len] = [np.zeros(vec_len, dtype=int), 0] + last_rng = partitions[-1][-1] # type: CommonRange + if cur_rng.overlaps_with(last_rng): + partitions[-1].append(cur_rng) + else: + partitions.append([cur_rng]) - number_ones[vec_len][0] += vector[0:vec_len] - number_ones[vec_len][1] += 1 + # Todo: Adjust start/end of conflicting ranges if overlapping is very small (i.e. 1 or 2 bits) - # Calculate the relative numbers and normalize the equalness so e.g. 
0.3 becomes 0.7 - return {vl: (np.vectorize(lambda x: x if x >= 0.5 else 1 - x)(number_ones[vl][0] / number_ones[vl][1])) - for vl in number_ones if number_ones[vl][1] >= self.MIN_MESSAGES_PER_CLUSTER} + result = [] + # Go through these partitions and handle overlapping conflicts + for partition in partitions: + possible_solutions = [] + for i, rng in enumerate(partition): + # Append every range to this solution that does not overlap with current rng + solution = [rng] + [r for r in partition[i + 1:] if not rng.overlaps_with(r)] + possible_solutions.append(solution) - def build_xor_matrix(self): - t = time.time() - xor_matrix = util.build_xor_matrix(self.bitvectors) - logger.debug("XOR matrix: {}s".format(time.time()-t)) - return xor_matrix + # Take solution that maximizes score. In case of tie, choose solution with shorter total length. + # if there is still a tie prefer solution that contains a length field as is is very likely to be correct + # if nothing else helps break tie by names of field types to prevent randomness + best_solution = max(possible_solutions, + key=lambda sol: (sum(r.score for r in sol), + -sum(r.length_in_bits for r in sol), + "length" in {r.field_type for r in sol}, + "".join(r.field_type[0] for r in sol))) + result.extend(best_solution) + + return CommonRangeContainer(result, message_indices=container.message_indices) + + @staticmethod + def retransform_message_indices(common_ranges, message_type_indices: list, sync_ends) -> list: + """ + Retransform the found message indices of an engine to the original index space + based on the message indices of the message type. + + Furthermore, set the sync_end of the common ranges so bit_start and bit_end + match the position in the original space + + :type common_ranges: list of CommonRange + :param message_type_indices: Messages belonging to the message type the engine ran for + :type sync_ends: np.ndarray + :return: + """ + result = [] + for common_range in common_ranges: + # Retransform message indices into original space + message_indices = np.fromiter((message_type_indices[i] for i in common_range.message_indices), + dtype=int, count=len(common_range.message_indices)) + + # If we have different sync_ends we need to create a new common range for each different sync_length + matching_sync_ends = sync_ends[message_indices] + for sync_end in np.unique(matching_sync_ends): + rng = copy.deepcopy(common_range) + rng.sync_end = sync_end + rng.message_indices = set(message_indices[np.nonzero(matching_sync_ends == sync_end)]) + result.append(rng) + + return result + + @staticmethod + def get_preamble_and_sync(preamble_starts, preamble_lengths, sync_ends, message_type_indices): + """ + Get preamble and sync common ranges based on the data + + :type preamble_starts: np.ndarray + :type preamble_lengths: np.ndarray + :type sync_ends: np.ndarray + :type message_type_indices: list + :rtype: set of CommonRange + """ + assert len(preamble_starts) == len(preamble_lengths) == len(sync_ends) + + result = set() # type: set[CommonRange] + for i in message_type_indices: + preamble = CommonRange(preamble_starts[i], preamble_lengths[i], field_type="preamble", message_indices={i}) + existing_preamble = next((rng for rng in result if preamble == rng), None) + if existing_preamble is not None: + existing_preamble.message_indices.add(i) + elif preamble_lengths[i] > 0: + result.add(preamble) + + preamble_end = preamble_starts[i] + preamble_lengths[i] + sync_end = sync_ends[i] + sync = CommonRange(preamble_end, sync_end - preamble_end, 
field_type="synchronization", message_indices={i}) + existing_sync = next((rng for rng in result if sync == rng), None) + if existing_sync is not None: + existing_sync.message_indices.add(i) + elif sync_end - preamble_end > 0: + result.add(sync) + + return result diff --git a/src/urh/awre/Histogram.py b/src/urh/awre/Histogram.py new file mode 100644 index 0000000000..b863de6cea --- /dev/null +++ b/src/urh/awre/Histogram.py @@ -0,0 +1,116 @@ +from collections import defaultdict + +import numpy as np + +from urh.awre.CommonRange import CommonRange +from urh.cythonext import awre_util + + +class Histogram(object): + """ + Create a histogram based on the equalness of vectors + """ + + def __init__(self, vectors, indices=None, normalize=True, debug=False): + """ + + :type vectors: list of np.ndarray + :param indices: Indices of vectors for which the Histogram shall be created. + This is useful for clustering. + If None Histogram will be created over all bitvectors + :type: list of int + :param normalize: + """ + self.__vectors = vectors # type: list[np.ndarray] + self.__active_indices = list(range(len(vectors))) if indices is None else indices + + self.normalize = normalize + self.data = self.__create_histogram() + + def __create_histogram(self): + return awre_util.create_difference_histogram(self.__vectors, self.__active_indices) + + def __repr__(self): + return str(self.data.tolist()) + + def find_common_ranges(self, alpha=0.95, range_type="bit"): + """ + Find all common ranges where at least alpha percent of numbers are equal + + :param range_type: on of bit/hex/byte + :param alpha: + :return: + """ + data_indices = np.argwhere(self.data >= alpha).flatten() + + if len(data_indices) < 2: + return [] + + result = [] + start, length = None, 0 + for i in range(1, len(data_indices)): + if start is None: + start = data_indices[i - 1] + length = 1 + + if data_indices[i] - data_indices[i - 1] == 1: + length += 1 + else: + if length >= 2: + value = self.__get_value_for_common_range(start, length) + result.append(CommonRange(start, length, value, message_indices=set(self.__active_indices), + range_type=range_type)) + + start, length = None, 0 + + if i == len(data_indices) - 1 and length >= 2: + value = self.__get_value_for_common_range(start, length) + result.append(CommonRange(start, length, value, message_indices=set(self.__active_indices), + range_type=range_type)) + + return result + + def __get_value_for_common_range(self, start: int, length: int): + """ + Get the value for a range of common numbers. This is the value that appears most. 
+ + :param start: Start of the common bit range + :param length: Length of the common bit range + :return: + """ + values = defaultdict(list) + for i in self.__active_indices: + vector = self.__vectors[i] + values[vector[start:start + length].tostring()].append(i) + value = max(values, key=lambda x: len(x)) + indices = values[value] + return self.__vectors[indices[0]][start:start + length] + + def __vector_to_string(self, data_vector) -> str: + lut = {i: "{0:x}".format(i) for i in range(16)} + return "".join(lut[x] if x in lut else " {} ".format(x) for x in data_vector) + + def plot(self): + import matplotlib.pyplot as plt + self.subplot_on(plt) + plt.show() + + def subplot_on(self, plt): + plt.grid() + plt.plot(self.data) + plt.xticks(np.arange(4, len(self.data), 4)) + plt.xlabel("Bit position") + if self.normalize: + plt.ylabel("Number common bits (normalized)") + else: + plt.ylabel("Number common bits") + plt.ylim(ymin=0) + + +if __name__ == "__main__": + bv1 = np.array([1, 0, 1, 0, 1, 1, 1, 1], dtype=np.int8) + bv2 = np.array([1, 0, 1, 0, 1, 0, 0, 0], dtype=np.int8) + bv3 = np.array([1, 0, 1, 0, 1, 1, 1, 1], dtype=np.int8) + bv4 = np.array([1, 0, 1, 0, 0, 0, 0, 0], dtype=np.int8) + h = Histogram([bv1, bv2, bv3, bv4]) + h.plot() diff --git a/src/urh/awre/MessageTypeBuilder.py b/src/urh/awre/MessageTypeBuilder.py new file mode 100644 index 0000000000..a35a02a577 --- /dev/null +++ b/src/urh/awre/MessageTypeBuilder.py @@ -0,0 +1,55 @@ +from urh.signalprocessing.ChecksumLabel import ChecksumLabel + +from urh.signalprocessing.FieldType import FieldType +from urh.signalprocessing.MessageType import MessageType +from urh.signalprocessing.ProtocoLabel import ProtocolLabel + + +class MessageTypeBuilder(object): + def __init__(self, name: str): + self.name = name + self.message_type = MessageType(name) + + def add_label(self, label_type: FieldType.Function, length: int, name: str=None): + try: + start = self.message_type[-1].end + color_index = self.message_type[-1].color_index + 1 + except IndexError: + start, color_index = 0, 0 + + if name is None: + name = label_type.value + + lbl = ProtocolLabel(name, start, start+length-1, color_index, field_type=FieldType(label_type.name, label_type)) + self.message_type.append(lbl) + + def add_checksum_label(self, length, checksum, data_start=None, data_end=None, name: str=None): + label_type = FieldType.Function.CHECKSUM + try: + start = self.message_type[-1].end + color_index = self.message_type[-1].color_index + 1 + except IndexError: + start, color_index = 0, 0 + + if name is None: + name = label_type.value + + if data_start is None: + # End of sync or preamble + sync_label = self.message_type.get_first_label_with_type(FieldType.Function.SYNC) + if sync_label: + data_start = sync_label.end + else: + preamble_label = self.message_type.get_first_label_with_type(FieldType.Function.PREAMBLE) + if preamble_label: + data_start = preamble_label.end + else: + data_start = 0 + + if data_end is None: + data_end = start + + lbl = ChecksumLabel(name, start, start+length-1, color_index, field_type=FieldType(label_type.name, label_type)) + lbl.data_ranges = [(data_start, data_end)] + lbl.checksum = checksum + self.message_type.append(lbl) diff --git a/src/urh/awre/Preprocessor.py b/src/urh/awre/Preprocessor.py new file mode 100644 index 0000000000..8503b19ee2 --- /dev/null +++ b/src/urh/awre/Preprocessor.py @@ -0,0 +1,271 @@ +import itertools +import math +import os +import time +from collections import defaultdict + +import numpy as np + +from urh.cythonext import 
awre_util +from urh.signalprocessing.FieldType import FieldType + + +class Preprocessor(object): + """ + This class preprocesses the messages in the following ways + 1) Identify preamble / length of preamble + 2) Identify sync word(s) + 3) Align all given messages on the identified preamble information + """ + + _DEBUG_ = False + + def __init__(self, bitvectors: list, existing_message_types: dict = None): + self.bitvectors = bitvectors # type: list[np.ndarray] + self.existing_message_types = existing_message_types if existing_message_types is not None else dict() + + def preprocess(self) -> (np.ndarray, int): + raw_preamble_positions = self.get_raw_preamble_positions() + existing_sync_words = self.__get_existing_sync_words() + if len(existing_sync_words) == 0: + sync_words = self.find_possible_syncs(raw_preamble_positions) + else: + # NOTE: This does not cover the case if protocol has multiple sync words and not all of them were labeled + sync_words = existing_sync_words + + preamble_starts = raw_preamble_positions[:, 0] + preamble_lengths = self.get_preamble_lengths_from_sync_words(sync_words, preamble_starts=preamble_starts) + sync_len = len(sync_words[0]) if len(sync_words) > 0 else 0 + return preamble_starts, preamble_lengths, sync_len + + def get_preamble_lengths_from_sync_words(self, sync_words: list, preamble_starts: np.ndarray): + """ + Get the preamble lengths based on the found sync words for all messages. + If there should be more than one sync word in a message, use the first one. + + :param sync_words: + :param preamble_starts: + :return: + """ + # If there should be varying sync word lengths we need to return an array of sync lengths per message + assert all(len(sync_word) == len(sync_words[0]) for sync_word in sync_words) + + byte_sync_words = [bytes(map(int, sync_word)) for sync_word in sync_words] + + result = np.zeros(len(self.bitvectors), dtype=np.uint32) + + for i, bitvector in enumerate(self.bitvectors): + preamble_lengths = [] + bits = bitvector.tobytes() + + for sync_word in byte_sync_words: + sync_start = bits.find(sync_word) + if sync_start != -1: + if sync_start - preamble_starts[i] >= 2: + preamble_lengths.append(sync_start - preamble_starts[i]) + + # Consider case where sync word starts with preamble pattern + sync_start = bits.find(sync_word, sync_start + 1, sync_start + 2 * len(sync_word)) + + if sync_start != -1: + if sync_start - preamble_starts[i] >= 2: + preamble_lengths.append(sync_start - preamble_starts[i]) + + preamble_lengths.sort() + + if len(preamble_lengths) == 0: + result[i] = 0 + elif len(preamble_lengths) == 1: + result[i] = preamble_lengths[0] + else: + # consider all indices not more than one byte before first one + preamble_lengths = list(filter(lambda x: x < preamble_lengths[0] + 7, preamble_lengths)) + + # take the smallest preamble_length, but prefer a greater one if it is divisible by 8 (or 4) + preamble_length = next((pl for pl in preamble_lengths if pl % 8 == 0), None) + if preamble_length is None: + preamble_length = next((pl for pl in preamble_lengths if pl % 4 == 0), None) + if preamble_length is None: + preamble_length = preamble_lengths[0] + result[i] = preamble_length + + return result + + def find_possible_syncs(self, raw_preamble_positions=None): + difference_matrix = self.get_difference_matrix() + if raw_preamble_positions is None: + raw_preamble_positions = self.get_raw_preamble_positions() + return self.determine_sync_candidates(raw_preamble_positions, difference_matrix, n_gram_length=4) + + @staticmethod + def 
merge_possible_sync_words(possible_sync_words: dict, n_gram_length: int): + """ + Merge possible sync words by looking for common prefixes + + :param possible_sync_words: dict of possible sync words and their frequencies + :return: + """ + result = defaultdict(int) + if len(possible_sync_words) < 2: + return possible_sync_words.copy() + + for sync1, sync2 in itertools.combinations(possible_sync_words, 2): + common_prefix = os.path.commonprefix([sync1, sync2]) + if len(common_prefix) > n_gram_length: + result[common_prefix] += possible_sync_words[sync1] + possible_sync_words[sync2] + else: + result[sync1] += possible_sync_words[sync1] + result[sync2] += possible_sync_words[sync2] + return result + + def determine_sync_candidates(self, + raw_preamble_positions: np.ndarray, + difference_matrix: np.ndarray, + n_gram_length=4) -> list: + + possible_sync_words = awre_util.find_possible_sync_words(difference_matrix, raw_preamble_positions, + self.bitvectors, n_gram_length) + + self.__debug("Possible sync words", possible_sync_words) + if len(possible_sync_words) == 0: + return [] + + possible_sync_words = self.merge_possible_sync_words(possible_sync_words, n_gram_length) + self.__debug("Merged sync words", possible_sync_words) + + scores = self.__score_sync_lengths(possible_sync_words) + + sorted_scores = sorted(scores, reverse=True, key=scores.get) + estimated_sync_length = sorted_scores[0] + if estimated_sync_length % 8 != 0: + for other in filter(lambda x: 0 < estimated_sync_length-x < 7, sorted_scores): + if other % 8 == 0: + estimated_sync_length = other + break + + # Now we look at all possible sync words with this length + sync_words = {word: frequency for word, frequency in possible_sync_words.items() + if len(word) == estimated_sync_length} + self.__debug("Sync words", sync_words) + + additional_syncs = self.__find_additional_sync_words(estimated_sync_length, sync_words, possible_sync_words) + + if additional_syncs: + self.__debug("Found addtional sync words", additional_syncs) + sync_words.update(additional_syncs) + + result = [] + for sync_word in sorted(sync_words, key=sync_words.get, reverse=True): + # Convert bytes back to string + result.append("".join(str(c) for c in sync_word)) + + return result + + def __find_additional_sync_words(self, sync_length: int, present_sync_words, possible_sync_words) -> dict: + """ + Look for additional sync words, in case we had varying preamble lengths and multiple sync words + (see test_with_three_syncs_different_preamble_lengths for an example) + + :param sync_length: + :type present_sync_words: dict + :type possible_sync_words: dict + :return: + """ + np_syn = [np.fromiter(map(int, sync_word), dtype=np.uint8, count=len(sync_word)) + for sync_word in present_sync_words] + + messages_without_sync = [i for i, bv in enumerate(self.bitvectors) + if not any(awre_util.find_occurrences(bv, s, return_after_first=True) for s in np_syn)] + + result = dict() + if len(messages_without_sync) == 0: + return result + + # Is there another sync word that applies to all messages without sync? 
+ additional_candidates = {word: score for word, score in possible_sync_words.items() + if len(word) > sync_length and not any(s in word for s in present_sync_words)} + + for sync in sorted(additional_candidates, key=additional_candidates.get, reverse=True): + if len(messages_without_sync) == 0: + break + + score = additional_candidates[sync] + s = sync[:sync_length] + np_s = np.fromiter(s, dtype=np.uint8, count=len(s)) + matching = [i for i in messages_without_sync + if awre_util.find_occurrences(self.bitvectors[i], np_s, return_after_first=True)] + if matching: + result[s] = score + for m in matching: + messages_without_sync.remove(m) + + return result + + def get_raw_preamble_positions(self) -> np.ndarray: + """ + Return a 2D numpy array where first column is the start of preamble + second and third columns are lower and upper bound for preamble length by message, respectively + """ + result = np.zeros((len(self.bitvectors), 3), dtype=np.uint32) + + for i, bitvector in enumerate(self.bitvectors): + if i in self.existing_message_types: + preamble_label = self.existing_message_types[i].get_first_label_with_type(FieldType.Function.PREAMBLE) + else: + preamble_label = None + + if preamble_label is None: + start, lower, upper = awre_util.get_raw_preamble_position(bitvector) + else: + # If this message is already labeled with a preamble we just use it's values + start, lower, upper = preamble_label.start, preamble_label.end, preamble_label.end + + result[i, 0] = start + result[i, 1] = lower - start + result[i, 2] = upper - start + + return result + + def get_difference_matrix(self) -> np.ndarray: + """ + Return a matrix of the first difference index between all messages + :return: + """ + return awre_util.get_difference_matrix(self.bitvectors) + + def __score_sync_lengths(self, possible_sync_words: dict): + sync_lengths = defaultdict(int) + for sync_word, score in possible_sync_words.items(): + sync_lengths[len(sync_word)] += score + + self.__debug("Sync lengths", sync_lengths) + + return sync_lengths + + def __get_existing_sync_words(self) -> list: + result = [] + for i, bitvector in enumerate(self.bitvectors): + if i in self.existing_message_types: + sync_label = self.existing_message_types[i].get_first_label_with_type(FieldType.Function.SYNC) + else: + sync_label = None + + if sync_label is not None: + result.append("".join(map(str, bitvector[sync_label.start:sync_label.end]))) + return result + + def __debug(self, *args): + if self._DEBUG_: + print("[PREPROCESSOR]", *args) + + @staticmethod + def get_next_multiple_of_n(number: int, n: int): + return n * int(math.ceil(number / n)) + + @staticmethod + def lower_multiple_of_n(number: int, n: int): + return n * int(math.floor(number / n)) + + @staticmethod + def get_next_lower_multiple_of_two(number: int): + return number if number % 2 == 0 else number - 1 diff --git a/src/urh/awre/ProtocolGenerator.py b/src/urh/awre/ProtocolGenerator.py new file mode 100644 index 0000000000..b17003daa2 --- /dev/null +++ b/src/urh/awre/ProtocolGenerator.py @@ -0,0 +1,260 @@ +import math +import struct +from array import array +from collections import defaultdict + +from urh.util import util + +from urh.awre.MessageTypeBuilder import MessageTypeBuilder +from urh.signalprocessing.ChecksumLabel import ChecksumLabel +from urh.signalprocessing.FieldType import FieldType +from urh.signalprocessing.Message import Message +from urh.signalprocessing.MessageType import MessageType +from urh.signalprocessing.Participant import Participant +from 
urh.signalprocessing.ProtocoLabel import ProtocolLabel +from urh.signalprocessing.ProtocolAnalyzer import ProtocolAnalyzer + + +class ProtocolGenerator(object): + DEFAULT_PREAMBLE = "10101010" + DEFAULT_SYNC = "1001" + BROADCAST_ADDRESS = "0xffff" + + def __init__(self, message_types: list, participants: list = None, preambles_by_mt=None, + syncs_by_mt=None, little_endian=False, length_in_bytes=True, sequence_numbers=None, + sequence_number_increment=1, message_type_codes=None): + """ + + :param message_types: + :param participants: + :param preambles_by_mt: + :param syncs_by_mt: + :param byte_order: + :param length_in_bytes: If false length will be given in bit + """ + self.participants = participants if participants is not None else [] + + self.protocol = ProtocolAnalyzer(None) + self.protocol.message_types = message_types + + self.length_in_bytes = length_in_bytes + self.little_endian = little_endian + + preambles_by_mt = dict() if preambles_by_mt is None else preambles_by_mt + + self.preambles_by_message_type = defaultdict(lambda: self.DEFAULT_PREAMBLE) + for mt, preamble in preambles_by_mt.items(): + self.preambles_by_message_type[mt] = self.to_bits(preamble) + + syncs_by_mt = dict() if syncs_by_mt is None else syncs_by_mt + + self.syncs_by_message_type = defaultdict(lambda: self.DEFAULT_SYNC) + for mt, sync in syncs_by_mt.items(): + self.syncs_by_message_type[mt] = self.to_bits(sync) + + sequence_numbers = dict() if sequence_numbers is None else sequence_numbers + self.sequence_numbers = defaultdict(lambda: 0) + self.sequence_number_increment = sequence_number_increment + + for mt, seq in sequence_numbers.items(): + self.sequence_numbers[mt] = seq + + if message_type_codes is None: + message_type_codes = dict() + for i, mt in enumerate(self.message_types): + message_type_codes[mt] = i + self.message_type_codes = message_type_codes + + + @property + def messages(self): + return self.protocol.messages + + @property + def message_types(self): + return self.protocol.message_types + + def __get_address_for_participant(self, participant: Participant): + if participant is None: + return self.to_bits(self.BROADCAST_ADDRESS) + + address = "0x" + participant.address_hex if not participant.address_hex.startswith( + "0x") else participant.address_hex + return self.to_bits(address) + + @staticmethod + def to_bits(bit_or_hex_str: str): + if bit_or_hex_str.startswith("0x"): + lut = {"{0:x}".format(i): "{0:04b}".format(i) for i in range(16)} + return "".join(lut[c] for c in bit_or_hex_str[2:]) + else: + return bit_or_hex_str + + def decimal_to_bits(self, number: int, num_bits: int) -> str: + len_formats = {8: "B", 16: "H", 32: "I", 64: "Q"} + if num_bits not in len_formats: + raise ValueError("Invalid length for length field: {} bits".format(num_bits)) + + struct_format = "<" if self.little_endian else ">" + struct_format += len_formats[num_bits] + + byte_length = struct.pack(struct_format, number) + return "".join("{0:08b}".format(byte) for byte in byte_length) + + def generate_message(self, message_type=None, data="0x00", source: Participant = None, + destination: Participant = None): + for participant in (source, destination): + if isinstance(participant, Participant) and participant not in self.participants: + self.participants.append(participant) + + if isinstance(message_type, MessageType): + message_type_index = self.protocol.message_types.index(message_type) + elif isinstance(message_type, int): + message_type_index = message_type + else: + message_type_index = 0 + + data = 
self.to_bits(data) + + mt = self.protocol.message_types[message_type_index] # type: MessageType + mt.sort() + + bits = [] + + start = 0 + + data_label_present = mt.get_first_label_with_type(FieldType.Function.DATA) is not None + + if data_label_present: + message_length = mt[-1].end - 1 + else: + message_length = mt[-1].end - 1 + len(data) + + checksum_labels = [] + + for lbl in mt: # type: ProtocolLabel + bits.append("0" * (lbl.start - start)) + len_field = lbl.end - lbl.start # in bits + + if isinstance(lbl, ChecksumLabel): + checksum_labels.append(lbl) + continue # processed last + + if lbl.field_type.function == FieldType.Function.PREAMBLE: + preamble = self.preambles_by_message_type[mt] + assert len(preamble) == len_field + bits.append(preamble) + message_length -= len(preamble) + elif lbl.field_type.function == FieldType.Function.SYNC: + sync = self.syncs_by_message_type[mt] + assert len(sync) == len_field + bits.append(sync) + message_length -= len(sync) + elif lbl.field_type.function == FieldType.Function.LENGTH: + value = int(math.ceil(message_length / 8)) + + if not self.length_in_bytes: + value *= 8 + + bits.append(self.decimal_to_bits(value, len_field)) + elif lbl.field_type.function == FieldType.Function.TYPE: + bits.append(self.decimal_to_bits(self.message_type_codes[mt] % (2 ** len_field), len_field)) + elif lbl.field_type.function == FieldType.Function.SEQUENCE_NUMBER: + bits.append(self.decimal_to_bits(self.sequence_numbers[mt] % (2 ** len_field), len_field)) + elif lbl.field_type.function == FieldType.Function.DST_ADDRESS: + dst_bits = self.__get_address_for_participant(destination) + + if len(dst_bits) != len_field: + raise ValueError( + "Length of dst ({0} bits) != length dst field ({1} bits)".format(len(dst_bits), len_field)) + + bits.append(dst_bits) + elif lbl.field_type.function == FieldType.Function.SRC_ADDRESS: + src_bits = self.__get_address_for_participant(source) + + if len(src_bits) != len_field: + raise ValueError( + "Length of src ({0} bits) != length src field ({1} bits)".format(len(src_bits), len_field)) + + bits.append(src_bits) + elif lbl.field_type.function == FieldType.Function.DATA: + if len(data) != len_field: + raise ValueError( + "Length of data ({} bits) != length data field ({} bits)".format(len(data), len_field)) + bits.append(data) + + start = lbl.end + + if not data_label_present: + bits.append(data) + + msg = Message.from_plain_bits_str("".join(bits)) + msg.message_type = mt + msg.participant = source + self.sequence_numbers[mt] += self.sequence_number_increment + + for checksum_label in checksum_labels: + msg[checksum_label.start:checksum_label.end] = checksum_label.calculate_checksum_for_message(msg, False) + + self.protocol.messages.append(msg) + + def to_file(self, filename: str): + self.protocol.to_xml_file(filename, [], self.participants, write_bits=True) + + def export_to_latex(self, filename: str, number: int): + def export_message_type_to_latex(message_type, f): + f.write(" \\begin{itemize}\n") + for lbl in message_type: # type: ProtocolLabel + if lbl.field_type.function == FieldType.Function.SYNC: + sync = array("B", map(int, self.syncs_by_message_type[message_type])) + f.write(" \\item {}: \\texttt{{0x{}}}\n".format(lbl.name, util.bit2hex(sync))) + elif lbl.field_type.function == FieldType.Function.PREAMBLE: + preamble = array("B", map(int, self.preambles_by_message_type[message_type])) + f.write(" \\item {}: \\texttt{{0x{}}}\n".format(lbl.name, util.bit2hex(preamble))) + elif lbl.field_type.function == FieldType.Function.CHECKSUM: 
+ f.write(" \\item {}: {}\n".format(lbl.name, lbl.checksum.caption)) + elif lbl.field_type.function in (FieldType.Function.LENGTH, FieldType.Function.SEQUENCE_NUMBER) and lbl.length > 8: + f.write(" \\item {}: {} bit (\\textbf{{{} endian}})\n".format(lbl.name, lbl.length, "little" if self.little_endian else "big")) + elif lbl.field_type.function == FieldType.Function.DATA: + f.write(" \\item payload: {} byte\n".format(lbl.length // 8)) + else: + f.write(" \\item {}: {} bit\n".format(lbl.name, lbl.length)) + f.write(" \\end{itemize}\n") + + with open(filename, "a") as f: + f.write("\\subsection{{Protocol {}}}\n".format(number)) + + if len(self.participants) > 1: + f.write("There were {} participants involved in communication: ".format(len(self.participants))) + f.write(", ".join("{} (\\texttt{{0x{}}})".format(p.name, p.address_hex) for p in self.participants[:-1])) + f.write(" and {} (\\texttt{{0x{}}})".format(self.participants[-1].name, self.participants[-1].address_hex)) + f.write(".\n") + + if len(self.message_types) == 1: + f.write("The protocol has one message type with the following fields:\n") + export_message_type_to_latex(self.message_types[0], f) + else: + f.write("The protocol has {} message types with the following fields:\n".format(len(self.message_types))) + f.write("\\begin{itemize}\n") + for mt in self.message_types: + f.write(" \\item \\textbf{{{}}}\n".format(mt.name)) + export_message_type_to_latex(mt, f) + f.write("\\end{itemize}\n") + + f.write("\n") + + +if __name__ == '__main__': + mb = MessageTypeBuilder("test") + mb.add_label(FieldType.Function.PREAMBLE, 8) + mb.add_label(FieldType.Function.SYNC, 4) + mb.add_label(FieldType.Function.LENGTH, 8) + mb.add_label(FieldType.Function.SEQUENCE_NUMBER, 16) + mb.add_label(FieldType.Function.SRC_ADDRESS, 16) + mb.add_label(FieldType.Function.DST_ADDRESS, 16) + pg = ProtocolGenerator([mb.message_type], [], little_endian=False) + pg.generate_message(data="1" * 8) + pg.generate_message(data="1" * 16) + pg.generate_message(data="0xab", source=Participant("Alice", "A", "1234"), + destination=Participant("Bob", "B", "4567")) + pg.to_file("/tmp/test.proto") diff --git a/src/urh/awre/components/Address.py b/src/urh/awre/components/Address.py deleted file mode 100644 index 8df6d820fd..0000000000 --- a/src/urh/awre/components/Address.py +++ /dev/null @@ -1,315 +0,0 @@ -from collections import defaultdict - -import numpy as np -from urh import constants -from urh.awre.CommonRange import CommonRange -from urh.cythonext import util -from urh.awre.components.Component import Component -from urh.signalprocessing.MessageType import MessageType - - -class Address(Component): - MIN_ADDRESS_LENGTH = 8 # Address should be at least one byte - - def __init__(self, fieldtypes, xor_matrix, priority=2, predecessors=None, enabled=True, backend=None, messagetypes=None): - super().__init__(priority, predecessors, enabled, backend, messagetypes) - self.xor_matrix = xor_matrix - - self.dst_field_type = next((ft for ft in fieldtypes if ft.function == ft.Function.DST_ADDRESS), None) - self.src_field_type = next((ft for ft in fieldtypes if ft.function == ft.Function.SRC_ADDRESS), None) - - self.dst_field_name = self.dst_field_type.caption if self.dst_field_type else "DST address" - self.src_field_name = self.src_field_type.caption if self.src_field_type else "SRC address" - - def _py_find_field(self, messages, verbose=False): - """ - - :type messages: list of urh.signalprocessing.Message.Message - :return: - """ - msg_indices_per_participant = 
defaultdict(list) - """:type : dict[urh.signalprocessing.Participant.Participant, list[int]] """ - - for i, msg in enumerate(messages): - msg_indices_per_participant[msg.participant].append(i) - - - # Cluster participants - equal_ranges_per_participant = defaultdict(list) - """:type : dict[urh.signalprocessing.Participant.Participant, list[CommonRange]] """ - - alignment = 8 - - # Step 1: Find equal ranges for participants by evaluating the XOR matrix participant wise - for participant, participant_msg_indices in msg_indices_per_participant.items(): - for i, msg_index in enumerate(participant_msg_indices): - msg = messages[msg_index] - bitvector_str = msg.decoded_bits_str - - for other_index in participant_msg_indices[i+1:]: - other_msg = messages[other_index] - xor_vec = self.xor_matrix[msg_index, other_index][self.xor_matrix[msg_index, other_index] != -1] # -1 = End of Vector - - # addresses are searched across message types, as we assume them to be in almost every message - # therefore we need to consider message types of both messages we compare and ignore already labeled areas - unlabeled_ranges = msg.message_type.unlabeled_ranges_with_other_mt(other_msg.message_type) - for rng_start, rng_end in unlabeled_ranges: - start = 0 - # The last 1 marks end of sequence, and prevents swallowing long zero sequences at the end - cmp_vector = np.append(xor_vec[rng_start:rng_end], 1) - for end in np.where(cmp_vector == 1)[0]: - if end - start >= self.MIN_ADDRESS_LENGTH: - equal_range_start = alignment * ((rng_start + start) // alignment) - equal_range_end = alignment * ((rng_start + end) // alignment) - bits = bitvector_str[equal_range_start:equal_range_end] - - # Did we already found this range? - cr = next((cr for cr in equal_ranges_per_participant[participant] if - cr.start == equal_range_start and cr.end == equal_range_end - and cr.bits == bits), None) - - # If not: Create it - if cr is None: - cr = CommonRange(equal_range_start, equal_range_end, bits) - equal_ranges_per_participant[participant].append(cr) - - cr.messages.add(msg_index) - cr.messages.add(other_index) - - start = end + alignment - - if verbose: - print(constants.color.BOLD + "Result after Step 1" +constants.color.END) - self.__print_ranges(equal_ranges_per_participant) - - # Step 2: Now we want to find our address candidates. 
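Editor's note: a self-contained sketch with made-up bitvectors of the Step 1 idea in the removed Address component: XOR two decoded messages of the same participant and treat every byte-aligned run of equal bits of at least MIN_ADDRESS_LENGTH as an equal-range candidate.

import numpy as np

bits_a = np.array([0, 1, 1, 0, 1, 0, 1, 1] * 4, dtype=np.uint8)
bits_b = bits_a.copy()
bits_b[24:] ^= 1                          # last byte differs between the two messages
xor_vec = np.append(bits_a ^ bits_b, 1)   # trailing 1 terminates the final run of zeros

start, alignment, candidates = 0, 8, []
for end in np.where(xor_vec == 1)[0]:
    if end - start >= 8:                  # MIN_ADDRESS_LENGTH
        candidates.append((int(alignment * (start // alignment)),
                           int(alignment * (end // alignment))))
        start = end + alignment

print(candidates)                         # [(0, 24)] -> the first three bytes are an equal range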
- # We do this by weighting them in order of LCS they share with each other - scored_candidates = self.find_candidates([cr for crl in equal_ranges_per_participant.values() for cr in crl]) - """:type : dict[str, int] """ - - try: - highscored = next(self.choose_candidate_pair(scored_candidates)) - assert len(highscored[0]) == len(highscored[1]) - except (StopIteration, AssertionError): - return - - if verbose: - print(scored_candidates) - print(sorted(scored_candidates, key=scored_candidates.get, reverse=True)) - - # Now get the common_ranges we need - scored_candidates_per_participant = defaultdict(list) - """:type : dict[urh.signalprocessing.Participant.Participant, list[CommonRange]] """ - - for participant, ranges in equal_ranges_per_participant.items(): - for equal_range in ranges: - for h in highscored: - rng = equal_range.pos_of_hex(h) - if rng is not None: - start, end = rng - bits = equal_range.bits[start:end] - rel_start = equal_range.start + start - rel_end = rel_start + (end - start) - cr = next((cr for cr in scored_candidates_per_participant[participant] if cr.start == rel_start - and cr.end == rel_end and - cr.bits == bits), None) - if cr is None: - cr = CommonRange(rel_start, rel_end, bits) - scored_candidates_per_participant[participant].append(cr) - - cr.messages.update(equal_range.messages) - - # Now we have the highscored ranges per participant - # If there is a crossmatch of the ranges we are good and found the addresses! - # We have something like: - # - # Participant: Alice (A): Participant: Bob (B): - # ======================= ===================== - # - # Range Value Messages Range Value Messages - # ----- ----- -------- ----- ----- -------- - # 72-96 1b6033 {1, 5, 9, 13, 17, 20} 72-96 78e289 {11, 3, 15, 7} - # 88-112 1b6033 {2, 6, 10, 14, 18} 88-112 78e289 {4, 8, 12, 16, 19} - # 112-136 78e289 {2, 6, 10, 14, 18} 112-136 1b6033 {0, 4, 8, 12, 16, 19} - # - - # If the value doubles for the same participant in other range, then we need to create a new message type - # We consider the default case (=default message type) to have addresses followed by each other - # Furthermore, we assume if there is only one address per message type, it is the destination address - clusters = {"default": defaultdict(set), "ack": defaultdict(set)} - """:type: dict[str, dict[tuple[int.int],set[int]]]""" - - all_candidates = [cr for crl in scored_candidates_per_participant.values() for cr in crl] - # Check for crossmatch and cluster in together and splitted addresses - # Perform a merge by only saving the ranges and applying messages - for candidate in sorted(all_candidates): - if any(c.start == candidate.start and c.end == candidate.end and c.bits != candidate.bits for c in all_candidates): - # Crossmatch! 
This is a address - if any(c.start == candidate.end or c.end == candidate.start for c in all_candidates): - clusters["default"][(candidate.start, candidate.end)].update(candidate.messages) - else: - clusters["ack"][(candidate.start, candidate.end)].update(candidate.messages) - - msg_clusters = {cname: set(i for s in ranges.values() for i in s) for cname, ranges in clusters.items()} - - # If there are no addresses in default message type prevent evaluating everything as ACK - if not msg_clusters["default"]: - msg_clusters["ack"] = set() - scored_candidates_per_participant.clear() - - self.assign_messagetypes(messages, msg_clusters) - - # Now try to find the addresses of the participants to separate SRC and DST address later - self.assign_participant_addresses(messages, list(scored_candidates_per_participant.keys()), highscored) - - for participant, ranges in scored_candidates_per_participant.items(): - for rng in ranges: - for msg_index in rng.messages: - msg = messages[msg_index] - - if msg.message_type.name == "ack": - field_type = self.dst_field_type - name = self.dst_field_name - elif msg.participant: - if rng.hex_value == msg.participant.address_hex: - name = self.src_field_name - field_type = self.src_field_type - else: - name = self.dst_field_name - field_type = self.dst_field_type - else: - name = "Address" - field_type = None - - if not any(lbl.name == name and lbl.auto_created for lbl in msg.message_type): - msg.message_type.add_protocol_label(rng.start, rng.end - 1, name=name, - auto_created=True, type=field_type) - - - @staticmethod - def find_candidates(candidates): - """ - Find candidate addresses using LCS algorithm - perform a scoring based on how often a candidate appears in a longer candidate - - Input is something like - ------------------------ - ['1b6033', '1b6033fd57', '701b603378e289', '20701b603378e289000c62', - '1b603300', '78e289757e', '7078e2891b6033000000', '207078e2891b6033000000'] - - Output like - ----------- - {'1b6033': 18, '1b6033fd57': 1, '701b603378e289': 2, '207078e2891b6033000000': 1, - '57': 1, '7078e2891b6033000000': 2, '78e289757e': 1, '20701b603378e289000c62': 1, - '78e289': 4, '1b603300': 3} - - :type candidates: list of CommonRange - :return: - """ - - result = defaultdict(int) - for i, c_i in enumerate(candidates): - for j in range(i, len(candidates)): - lcs = util.longest_common_substring(c_i.hex_value, candidates[j].hex_value) - if lcs: - result[lcs] += 1 - - return result - - @staticmethod - def choose_candidate_pair(candidates): - """ - Choose a pair of address candidates ensuring they have the same length and starting with the highest scored ones - - :type candidates: dict[str, int] - :param candidates: Count how often the longest common substrings appeared in the messages - :return: - """ - highscored = sorted(candidates, key=candidates.get, reverse=True) - for i, h_i in enumerate(highscored): - for h_j in highscored[i+1:]: - if len(h_i) == len(h_j): - yield (h_i, h_j) - - @staticmethod - def assign_participant_addresses(messages, participants, hex_addresses): - """ - - :type participants: list[urh.signalprocessing.Participant.Participant] - :type hex_addresses: tuple[str] - :type messages: list[urh.signalprocessing.Message.Message] - :return: - """ - try: - participants.remove(None) - except ValueError: - pass - - if len(participants) != len(hex_addresses): - return - - if len(participants) == 0: - return # No chance - - - score = {p: {addr: 0 for addr in hex_addresses} for p in participants} - - for i in range(1, len(messages)): - msg = 
messages[i] - prev_msg = messages[i-1] - - if msg.message_type.name == "ack": - addr = next(addr for addr in hex_addresses if addr in msg.decoded_hex_str) - if addr in prev_msg.decoded_hex_str: - score[prev_msg.participant][addr] += 1 - - for p in participants: - p.address_hex = max(score[p], key=score[p].get) - - def __print_clustered(self, clustered_addresses): - for bl in sorted(clustered_addresses): - print(constants.color.BOLD + "Byte length " + str(bl) + constants.color.END) - for (start, end), bits in sorted(clustered_addresses[bl].items()): - print(start, end, bits) - - def __print_ranges(self, equal_ranges_per_participant): - for parti in sorted(equal_ranges_per_participant): - if parti is None: - continue - - print("\n" + constants.color.UNDERLINE + str(parti.name) + " (" + parti.shortname+ ")" + constants.color.END) - address1 = "000110110110000000110011" - address2 = "011110001110001010001001" - - assert len(address1) % 8 == 0 - assert len(address2) % 8 == 0 - - print("address1", constants.color.BLUE, address1 + " (" +hex(int("".join(map(str, address1)), 2)) +")", constants.color.END) - print("address2", constants.color.GREEN, address2 + " (" + hex(int("".join(map(str, address2)), 2)) + ")", - constants.color.END) - - print() - - for common_range in sorted(equal_ranges_per_participant[parti]): - assert isinstance(common_range, CommonRange) - bits_str = common_range.bits - format_start = "" - if address1 in bits_str and address2 not in bits_str: - format_start = constants.color.BLUE - if address2 in bits_str and address1 not in bits_str: - format_start = constants.color.GREEN - if address1 in bits_str and address2 in bits_str: - format_start = constants.color.RED + constants.color.BOLD - - # For Bob the adress 1b60330 is found to be 0x8db0198000 which is correct, - # as it starts with a leading 1 in all messages. - # This is the last Bit of e0003 (Broadcast) or 78e289 (Other address) - # Code to verify: hex(int("1000"+bin(int("1b6033",16))[2:]+"000",2)) - # Therefore we need to check for partial bits inside the address candidates to be sure we find the correct ones - occurences = len(common_range.messages) - print(common_range.start, common_range.end, - "({})\t".format(occurences), - format_start + common_range.hex_value + "\033[0m", common_range.byte_len, - bits_str, "(" + ",".join(map(str, common_range.messages)) + ")") - - diff --git a/src/urh/awre/components/Component.py b/src/urh/awre/components/Component.py deleted file mode 100644 index 156613525d..0000000000 --- a/src/urh/awre/components/Component.py +++ /dev/null @@ -1,112 +0,0 @@ -from abc import ABCMeta - -from urh.signalprocessing.Message import Message -from urh.signalprocessing.MessageType import MessageType -from urh.signalprocessing.ProtocoLabel import ProtocolLabel -from enum import Enum - -from urh.util.Logger import logger - - -class Component(metaclass=ABCMeta): - """ - A component is the basic building block of our AWRE algorithm. - A component can be a Preamble or Sync or Length Field finding routine. - Components can have a priority which determines the order in which they are processed by the algorithm. - Additionally, components can have a set of predecessors to define hard dependencies. - """ - - - EQUAL_BIT_TRESHOLD = 0.9 - - class Backend(Enum): - python = 1 - cython = 2 - plainc = 3 - - def __init__(self, priority=0, predecessors=None, enabled=True, backend=None, messagetypes=None): - """ - - :param priority: Priority for this Component. 
0 is highest priority - :type priority: int - :param predecessors: List of preceding components, that need to be run before this one - :type predecessors: list of Component or None - :param messagetypes: Message types of the examined protocol - :type messagetypes: list[MessageType] - """ - self.enabled = enabled - self.backend = backend if backend is not None else self.Backend.python - self.priority = abs(priority) - self.predecessors = predecessors if isinstance(predecessors, list) else [] - """:type: list of Component """ - - self.messagetypes = messagetypes - - def find_field(self, messages): - """ - Wrapper method selecting the backend to assign the protocol field. - Various strategies are possible e.g.: - 1) Heuristics e.g. for Preamble - 2) Scoring based e.g. for Length - 3) Fulltext search for addresses based on participant subgroups - - :param messages: messages a field shall be searched for - :type messages: list of Message - """ - try: - if self.backend == self.Backend.python: - self._py_find_field(messages) - elif self.backend == self.Backend.cython: - self._cy_find_field(messages) - elif self.backend == self.Backend.plainc: - self._c_find_field(messages) - else: - raise ValueError("Unsupported backend {}".format(self.backend)) - except NotImplementedError: - logger.info("Skipped {} because not implemented yet".format(self.__class__.__name__)) - - def _py_find_field(self, messages): - raise NotImplementedError() - - def _cy_find_field(self, messages): - raise NotImplementedError() - - def _c_find_field(self, messages): - raise NotImplementedError() - - - def assign_messagetypes(self, messages, clusters): - """ - Assign message types based on the clusters. Following rules: - 1) Messages from different clusters will get different message types - 2) Messages from same clusters will get same message type - 3) The new message type will copy over the existing labels - 4) No new message type will be set for messages, that already have a custom message type assigned - - For messages with clustername "default" no new message type will be created - - :param messages: Messages, that messagetype needs to be clustered - :param clusters: clusters for the messages - :type messages: list[Message] - :type clusters: dict[str, set[int]] - :return: - """ - for clustername, clustercontent in clusters.items(): - if clustername == "default": - # Do not force the default message type - continue - - for msg_i in clustercontent: - msg = messages[msg_i] - if msg.message_type == self.messagetypes[0]: - # Message has default message type - # Copy the existing labels and create a new message type - # if it was not already done - try: - msg_type = next(mtype for mtype in self.messagetypes if mtype.name == clustername) - except StopIteration: - msg_type = MessageType(name=clustername, iterable=msg.message_type) - msg_type.assigned_by_logic_analyzer = True - self.messagetypes.append(msg_type) - msg.message_type = msg_type - diff --git a/src/urh/awre/components/Flags.py b/src/urh/awre/components/Flags.py deleted file mode 100644 index f9ec755f1f..0000000000 --- a/src/urh/awre/components/Flags.py +++ /dev/null @@ -1,8 +0,0 @@ -from urh.awre.components.Component import Component - -class Flags(Component): - def __init__(self, priority=2, predecessors=None, enabled=True, backend=None): - super().__init__(priority, predecessors, enabled, backend) - - def _py_find_field(self, messages): - raise NotImplementedError("Todo") \ No newline at end of file diff --git a/src/urh/awre/components/Length.py 
b/src/urh/awre/components/Length.py deleted file mode 100644 index a357c6eb3a..0000000000 --- a/src/urh/awre/components/Length.py +++ /dev/null @@ -1,139 +0,0 @@ -import math -from collections import defaultdict - -import numpy as np - -from urh.awre.components.Component import Component -from urh.signalprocessing.FieldType import FieldType -from urh.signalprocessing.Interval import Interval -from urh.signalprocessing.MessageType import MessageType -from urh.signalprocessing.ProtocoLabel import ProtocolLabel - - -class Length(Component): - """ - The length is defined as byte length and found by finding equal ranges in the length clustered blocks. - A length field should be a common equal range in all clusters. - """ - - def __init__(self, fieldtypes, length_cluster, priority=2, predecessors=None, - enabled=True, backend=None, messagetypes=None): - super().__init__(priority, predecessors, enabled, backend, messagetypes) - - self.length_field_type = next((ft for ft in fieldtypes if ft.function == ft.Function.LENGTH), None) - self.length_field_name = self.length_field_type.caption if self.length_field_type else "Length" - - self.length_cluster = length_cluster - """ - An example length cluster is - - 2: [0.5, 1] - 4: [1, 0.75, 1, 1] - - Meaning there were two message lengths: 2 and 4 bit. - (0.5, 1) means, the first bit was equal in 50% of cases (meaning maximum difference) and bit 2 was equal in all messages - - A simple XOR would not work as it would be very error prone. - """ - - def _py_find_field(self, messages): - """ - - :type messages: list of urh.signalprocessing.Message.Message - :return: - """ - messages_by_type = defaultdict(list) - """:type : dict[urh.signalprocessing.MessageType.MessageType, list[urh.signalprocessing.Message.Message]] """ - - for msg in messages: - messages_by_type[msg.message_type].append(msg) - - # First we get the common ranges per message length - common_ranges_by_length = defaultdict(lambda: defaultdict(list)) - """:type: dict[urh.signalprocessing.MessageType.MessageType, dict[int, List[(int,int)]]]""" - - for message_type in messages_by_type.keys(): - unlabeled_ranges = message_type.unlabeled_ranges - for vec_len in set(4 * (len(msg.decoded_bits) // 4) for msg in messages_by_type[message_type]): - try: - cluster = self.length_cluster[vec_len] - except KeyError: - continue # Skip message lengths that appear only once - - for rng_start, rng_end in unlabeled_ranges: - start = 0 - for end in np.where(cluster[rng_start:rng_end] < self.EQUAL_BIT_TRESHOLD)[0]: - if start < end - 1: - common_ranges_by_length[message_type][vec_len].append( - (rng_start + start, rng_start + end - 1)) - start = end + 1 - - # Now we merge the ranges together to get our candidate ranges - common_intervals_by_type = {message_type: [] for message_type in common_ranges_by_length.keys()} - """:type: dict[urh.signalprocessing.MessageType.MessageType, list[Interval]]""" - - for message_type in common_intervals_by_type.keys(): - msg_lens = sorted(common_ranges_by_length[message_type].keys()) - for interval in common_ranges_by_length[message_type][msg_lens[0]]: - candidate = Interval(interval[0], interval[1]) - for other_len in msg_lens[1:]: - matches = [] - for other_interval in common_ranges_by_length[message_type][other_len]: - oi = Interval(other_interval[0], other_interval[1]) - if oi.overlaps_with(candidate): - candidate = candidate.find_common_interval(oi) - matches.append(candidate) - - if not matches: - candidate = None - break - else: - candidate = Interval.find_greatest(matches) - - 
if candidate: - common_intervals_by_type[message_type].append(candidate) - - # Now we have the common intervals and need to check which one is the length - for message_type, intervals in common_intervals_by_type.items(): - assert isinstance(message_type, MessageType) - # Exclude Synchronization (or preamble if not present) from length calculation - sync_lbl = self.find_lbl_function_in(FieldType.Function.SYNC, message_type) - if sync_lbl: - sync_len = self.__nbits2bytes(sync_lbl.end) - else: - preamble_lbl = self.find_lbl_function_in(FieldType.Function.PREAMBLE, message_type) - sync_len = self.__nbits2bytes(preamble_lbl.end) if preamble_lbl is not None else 0 - - scores = defaultdict(int) - weights = {-4: 1, -3: 2, -2: 3, -1: 4, 0: 5} - - for common_interval in intervals: - for msg in messages_by_type[message_type]: - bits = msg.decoded_bits - byte_len = self.__nbits2bytes(len(bits)) - sync_len - start, end = common_interval.start, common_interval.end - for byte_start in range(start, end, 8): - byte_end = byte_start + 8 if byte_start + 8 <= end else end - try: - byte = int("".join(["1" if bit else "0" for bit in bits[byte_start:byte_end]]), 2) - diff = byte - byte_len - if diff in weights: - scores[(byte_start, byte_end)] += weights[diff] - except ValueError: - pass # Byte_end or byte_start was out of bits --> too close on the end - - try: - start, end = max(scores, key=scores.__getitem__) - if not any((lbl.field_type.function == FieldType.Function.LENGTH or lbl.name == "Length") and lbl.auto_created - for lbl in message_type): - message_type.add_protocol_label(start=start, end=end - 1, name=self.length_field_name, - auto_created=True, type=self.length_field_type) - except ValueError: - continue - - def __nbits2bytes(self, nbits): - return int(math.ceil(nbits / 8)) - - @staticmethod - def find_lbl_function_in(function: FieldType.Function, message_type: MessageType) -> ProtocolLabel: - return next((lbl for lbl in message_type if lbl.field_type and lbl.field_type.function == function), None) diff --git a/src/urh/awre/components/Preamble.py b/src/urh/awre/components/Preamble.py deleted file mode 100644 index fb129ba7b6..0000000000 --- a/src/urh/awre/components/Preamble.py +++ /dev/null @@ -1,121 +0,0 @@ -from collections import defaultdict -from urh.awre.components.Component import Component -from urh.signalprocessing.FieldType import FieldType -from urh.signalprocessing.Message import Message - - -class Preamble(Component): - """ - Assign Preamble and SoF. 
- - """ - def __init__(self, fieldtypes, priority=0, predecessors=None, enabled=True, backend=None, messagetypes=None): - """ - - :type fieldtypes: list of FieldType - :param priority: - :param predecessors: - :param enabled: - :param backend: - :param messagetypes: - """ - super().__init__(priority, predecessors, enabled, backend, messagetypes) - - self.preamble_field_type = next((ft for ft in fieldtypes if ft.function == ft.Function.PREAMBLE), None) - self.sync_field_type = next((ft for ft in fieldtypes if ft.function == ft.Function.SYNC), None) - - self.preamble_name = self.preamble_field_type.caption if self.preamble_field_type else "Preamble" - self.sync_name = self.sync_field_type.caption if self.sync_field_type else "Synchronization" - - def _py_find_field(self, messages): - """ - - :type messages: list of Message - :return: - """ - preamble_ranges = defaultdict(list) - """:type: dict[MessageType, list] """ - - for msg in messages: - rng = self.__find_preamble_range(msg) - if rng: - preamble_ranges[msg.message_type].append(rng) - - preamble_ends = defaultdict(int) - for message_type, ranges in preamble_ranges.items(): - start, end = max(ranges, key=ranges.count) - message_type.add_protocol_label(start=start, end=end, name=self.preamble_name, - auto_created=True, type=self.preamble_field_type) - - preamble_ends[message_type] = end + 1 - - for message_type in preamble_ranges.keys(): - messages = [msg for msg in messages if msg.message_type == message_type] - first_field = next((field for field in message_type if field.start > preamble_ends[message_type]), None) - search_end = first_field.start if first_field is not None else None - sync_range = self.__find_sync_range(messages, preamble_ends[message_type], search_end) - - if sync_range: - message_type.add_protocol_label(start=sync_range[0], end=sync_range[1]-1, name=self.sync_name, - auto_created=True, type=self.sync_field_type) - - - def __find_preamble_range(self, message: Message): - search_start = 0 - - if len(message.message_type) == 0: - search_end = len(message.decoded_bits) - else: - search_end = message.message_type[0].start - - bits = message.decoded_bits - - # Skip sequences of equal bits - try: - first_difference = next((i for i in range(search_start, search_end-1) if bits[i] != bits[i+1]), None) - except IndexError: - # see: https://github.com/jopohl/urh/issues/290 - first_difference = None - - if first_difference is None: - return None - - try: - preamble_end = next((i-1 for i in range(first_difference, search_end, 4) - if bits[i] == bits[i+1] or bits[i] != bits[i+2] or bits[i] == bits[i+3]), search_end) - except IndexError: - return None - - if preamble_end - first_difference > 4: - return first_difference, preamble_end - else: - return None - - - def __find_sync_range(self, messages, preamble_end: int, search_end: int): - """ - Finding the synchronization works by finding the first difference between two messages. 
- This is performed for all messages and the most frequent first difference is chosen - - :type messages: list of Message - :param preamble_end: End of preamble = start of search - :param search_end: End of search = start of first other label - """ - - possible_sync_pos = defaultdict(int) - - - for i, msg in enumerate(messages): - bits_i = msg.decoded_bits[preamble_end:search_end] - for j in range(i, len(messages)): - bits_j = messages[j].decoded_bits[preamble_end:search_end] - first_diff = next((k for k, (bit_i, bit_j) in enumerate(zip(bits_i, bits_j)) if bit_i != bit_j), None) - if first_diff is not None: - first_diff = preamble_end + 4 * (first_diff // 4) - if (first_diff - preamble_end) >= 4: - possible_sync_pos[(preamble_end, first_diff)] += 1 - try: - sync_interval = max(possible_sync_pos, key=possible_sync_pos.__getitem__) - return sync_interval - except ValueError: - return None diff --git a/src/urh/awre/components/SequenceNumber.py b/src/urh/awre/components/SequenceNumber.py deleted file mode 100644 index 30d06ad16b..0000000000 --- a/src/urh/awre/components/SequenceNumber.py +++ /dev/null @@ -1,21 +0,0 @@ -from urh.awre.components.Component import Component - -class SequenceNumber(Component): - def __init__(self, fieldtypes, priority=2, predecessors=None, enabled=True, backend=None): - """ - - :type fieldtypes: list of FieldType - :param priority: - :param predecessors: - :param enabled: - :param backend: - :param messagetypes: - """ - super().__init__(priority, predecessors, enabled, backend) - - self.seqnr_field_type = next((ft for ft in fieldtypes if ft.function == ft.Function.SEQUENCE_NUMBER), None) - self.seqnr_field_name = self.seqnr_field_type.caption if self.seqnr_field_type else "Sequence Number" - - - def _py_find_field(self, messages): - raise NotImplementedError("Todo") \ No newline at end of file diff --git a/src/urh/awre/components/Type.py b/src/urh/awre/components/Type.py deleted file mode 100644 index fc40d595b8..0000000000 --- a/src/urh/awre/components/Type.py +++ /dev/null @@ -1,8 +0,0 @@ -from urh.awre.components.Component import Component - -class Type(Component): - def __init__(self, priority=2, predecessors=None, enabled=True, backend=None): - super().__init__(priority, predecessors, enabled, backend) - - def _py_find_field(self, messages): - raise NotImplementedError("Todo") \ No newline at end of file diff --git a/src/urh/awre/engines/AddressEngine.py b/src/urh/awre/engines/AddressEngine.py new file mode 100644 index 0000000000..7aa50f1248 --- /dev/null +++ b/src/urh/awre/engines/AddressEngine.py @@ -0,0 +1,399 @@ +import itertools +import math +from array import array +from collections import defaultdict, Counter + +import numpy as np + +from urh.awre.CommonRange import CommonRange +from urh.awre.engines.Engine import Engine +from urh.cythonext import awre_util +from urh.util.Logger import logger + + +class AddressEngine(Engine): + def __init__(self, msg_vectors, participant_indices, known_participant_addresses: dict = None, + already_labeled: list = None, src_field_present=False): + """ + + :param msg_vectors: Message data behind synchronization + :type msg_vectors: list of np.ndarray + :param participant_indices: list of participant indices + where ith position holds participants index for ith messages + :type participant_indices: list of int + """ + assert len(msg_vectors) == len(participant_indices) + + self.minimum_score = 0.1 + + self.msg_vectors = msg_vectors + self.participant_indices = participant_indices + self.already_labeled = [] + + 
self.src_field_present = src_field_present + + if already_labeled is not None: + for start, end in already_labeled: + # convert it to hex + self.already_labeled.append((int(math.ceil(start / 4)), int(math.ceil(end / 4)))) + + self.message_indices_by_participant = defaultdict(list) + for i, participant_index in enumerate(self.participant_indices): + self.message_indices_by_participant[participant_index].append(i) + + if known_participant_addresses is None: + self.known_addresses_by_participant = dict() # type: dict[int, np.ndarray] + else: + self.known_addresses_by_participant = known_participant_addresses # type: dict[int, np.ndarray] + + @staticmethod + def cross_swap_check(rng1: CommonRange, rng2: CommonRange): + return (rng1.start == rng2.start + rng1.length or rng1.start == rng2.start - rng1.length) \ + and rng1.value.tobytes() == rng2.value.tobytes() + + @staticmethod + def ack_check(rng1: CommonRange, rng2: CommonRange): + return rng1.start == rng2.start and rng1.length == rng2.length and rng1.value.tobytes() != rng2.value.tobytes() + + def find(self): + addresses_by_participant = {p: [addr.tostring()] for p, addr in self.known_addresses_by_participant.items()} + addresses_by_participant.update(self.find_addresses()) + self._debug("Addresses by participant", addresses_by_participant) + + # Find the address candidates by participant in messages + ranges_by_participant = defaultdict(list) # type: dict[int, list[CommonRange]] + + addresses = [np.array(np.frombuffer(a, dtype=np.uint8)) + for address_list in addresses_by_participant.values() + for a in address_list] + + already_labeled_cols = array("L", [e for rng in self.already_labeled for e in range(*rng)]) + + # Find occurrences of address candidates in messages and create common ranges over matching positions + for i, msg_vector in enumerate(self.msg_vectors): + participant = self.participant_indices[i] + for address in addresses: + for index in awre_util.find_occurrences(msg_vector, address, already_labeled_cols): + common_ranges = ranges_by_participant[participant] + rng = next((cr for cr in common_ranges if cr.matches(index, address)), None) # type: CommonRange + if rng is not None: + rng.message_indices.add(i) + else: + common_ranges.append(CommonRange(index, len(address), address, + message_indices={i}, + range_type="hex")) + + num_messages_by_participant = defaultdict(int) + for participant in self.participant_indices: + num_messages_by_participant[participant] += 1 + + # Look for cross swapped values between participant clusters + for p1, p2 in itertools.combinations(ranges_by_participant, 2): + ranges1_set, ranges2_set = set(ranges_by_participant[p1]), set(ranges_by_participant[p2]) + + for rng1, rng2 in itertools.product(ranges_by_participant[p1], ranges_by_participant[p2]): + if rng1 in ranges2_set and rng2 in ranges1_set: + if self.cross_swap_check(rng1, rng2): + rng1.score += len(rng2.message_indices) / num_messages_by_participant[p2] + rng2.score += len(rng1.message_indices) / num_messages_by_participant[p1] + elif self.ack_check(rng1, rng2): + # Add previous score in divisor to add bonus to ranges that apply to all messages + rng1.score += len(rng2.message_indices) / (num_messages_by_participant[p2] + rng1.score) + rng2.score += len(rng1.message_indices) / (num_messages_by_participant[p1] + rng2.score) + + if len(ranges_by_participant) == 1 and not self.src_field_present: + for p, ranges in ranges_by_participant.items(): + for rng in sorted(ranges): + try: + if np.array_equal(rng.value, 
self.known_addresses_by_participant[p]): + # Only one participant in this iteration and address already known -> Highscore + rng.score = 1 + break # Take only the first (leftmost) range + except KeyError: + pass + + high_scored_ranges_by_participant = defaultdict(list) + + address_length = self.__estimate_address_length(ranges_by_participant) + + # Get highscored ranges by participant + for participant, common_ranges in ranges_by_participant.items(): + # Sort by negative score so ranges with highest score appear first + # Secondary sort by tuple to ensure order when ranges have same score + sorted_ranges = sorted(filter(lambda cr: cr.score > self.minimum_score, common_ranges), + key=lambda cr: (-cr.score, cr)) + if len(sorted_ranges) == 0: + addresses_by_participant[participant] = dict() + continue + + addresses_by_participant[participant] = {a for a in addresses_by_participant.get(participant, []) + if len(a) == address_length} + + for rng in filter(lambda r: r.length == address_length, sorted_ranges): + rng.score = min(rng.score, 1.0) + high_scored_ranges_by_participant[participant].append(rng) + + # Now we find the most probable address for all participants + self.__assign_participant_addresses(addresses_by_participant, high_scored_ranges_by_participant) + + # Eliminate participants for which we could not assign an address + for participant, address in addresses_by_participant.copy().items(): + if address is None: + del addresses_by_participant[participant] + + # Now we can separate SRC and DST + for participant, ranges in high_scored_ranges_by_participant.items(): + try: + address = addresses_by_participant[participant] + except KeyError: + high_scored_ranges_by_participant[participant] = [] + continue + + result = [] + + for rng in sorted(ranges, key=lambda r: r.score, reverse=True): + rng.field_type = "source address" if rng.value.tostring() == address else "destination address" + if len(result) == 0: + result.append(rng) + else: + subset = next((r for r in result if rng.message_indices.issubset(r.message_indices)), None) + if subset is not None: + if rng.field_type == subset.field_type: + # Avoid adding same address type twice + continue + + if rng.length != subset.length or (rng.start != subset.end + 1 and rng.end + 1 != subset.start): + # Ensure addresses are next to each other + continue + + result.append(rng) + + high_scored_ranges_by_participant[participant] = result + + self.__find_broadcast_fields(high_scored_ranges_by_participant, addresses_by_participant) + + result = [rng for ranges in high_scored_ranges_by_participant.values() for rng in ranges] + # If we did not find a SRC address, lower the score a bit, + # so DST fields do not win later e.g. 
again length fields in case of tie + if not any(rng.field_type == "source address" for rng in result): + for rng in result: + rng.score *= 0.95 + + return result + + def __estimate_address_length(self, ranges_by_participant: dict): + """ + Estimate the address length which is assumed to be the same for all participants + + :param ranges_by_participant: + :return: + """ + address_lengths = [] + for participant, common_ranges in ranges_by_participant.items(): + sorted_ranges = sorted(filter(lambda cr: cr.score > self.minimum_score, common_ranges), + key=lambda cr: (-cr.score, cr)) + + max_scored = [r for r in sorted_ranges if r.score == sorted_ranges[0].score] + + # Prevent overestimation of address length by looking for substrings + for rng in max_scored[:]: + same_message_rng = [r for r in sorted_ranges + if r not in max_scored and r.score > 0 and r.message_indices == rng.message_indices] + + if len(same_message_rng) > 1 and all( + r.value.tobytes() in rng.value.tobytes() for r in same_message_rng): + # remove the longer range and add the smaller ones + max_scored.remove(rng) + max_scored.extend(same_message_rng) + + possible_address_lengths = [r.length for r in max_scored] + + # Count possible address lengths. + frequencies = Counter(possible_address_lengths) + # Take the most common one. On tie, take the shorter one + try: + addr_len = max(frequencies, key=lambda x: (frequencies[x], -x)) + address_lengths.append(addr_len) + except ValueError: # max() arg is an empty sequence + pass + + # Take most common address length of participants, to ensure they all have same address length + counted = Counter(address_lengths) + try: + address_length = max(counted, key=lambda x: (counted[x], -x)) + return address_length + except ValueError: # max() arg is an empty sequence + return 0 + + def __assign_participant_addresses(self, addresses_by_participant, high_scored_ranges_by_participant): + scored_participants_addresses = dict() + for participant in addresses_by_participant: + scored_participants_addresses[participant] = defaultdict(int) + + for participant, addresses in addresses_by_participant.items(): + if participant in self.known_addresses_by_participant: + address = self.known_addresses_by_participant[participant].tostring() + scored_participants_addresses[participant][address] = 9999999999 + continue + + for i in self.message_indices_by_participant[participant]: + matching = [rng for rng in high_scored_ranges_by_participant[participant] + if i in rng.message_indices and rng.value.tostring() in addresses] + + if len(matching) == 1: + address = matching[0].value.tostring() + # only one address, so probably a destination and not a source + scored_participants_addresses[participant][address] *= 0.9 + + # Since this is probably an ACK, the address is probably SRC of participant of previous message + if i > 0 and self.participant_indices[i - 1] != participant: + prev_participant = self.participant_indices[i - 1] + prev_matching = [rng for rng in high_scored_ranges_by_participant[prev_participant] + if i - 1 in rng.message_indices and rng.value.tostring() in addresses] + if len(prev_matching) > 1: + for prev_rng in filter(lambda r: r.value.tostring() == address, prev_matching): + scored_participants_addresses[prev_participant][address] += prev_rng.score + + elif len(matching) > 1: + # more than one address, so there must be a source address included + for rng in matching: + scored_participants_addresses[participant][rng.value.tostring()] += rng.score + + minimum_score = 0.5 + taken_addresses = set() 
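Editor's note: a worked example (made-up candidate lengths, counted in hex digits) of the tie-break in __estimate_address_length above: the most frequent length wins and, on a tie, the shorter one is preferred.

from collections import Counter

frequencies = Counter([4, 4, 6, 6, 8])                           # candidate address lengths
addr_len = max(frequencies, key=lambda x: (frequencies[x], -x))
assert addr_len == 4                                             # 4 and 6 both occur twice; the shorter wins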
+ self._debug("Scored addresses", scored_participants_addresses) + + # If all participants have exactly one possible address and they all differ, we can assign them right away + if all(len(addresses) == 1 for addresses in scored_participants_addresses.values()): + all_addresses = [list(addresses)[0] for addresses in scored_participants_addresses.values()] + if len(all_addresses) == len(set(all_addresses)): # ensure all addresses are different + for p, addresses in scored_participants_addresses.items(): + addresses_by_participant[p] = list(addresses)[0] + return + + for participant, addresses in sorted(scored_participants_addresses.items()): + try: + # sort filtered results to prevent randomness for equal scores + found_address = max(sorted( + filter(lambda a: a not in taken_addresses and addresses[a] >= minimum_score, addresses), + reverse=True + ), key=addresses.get) + except ValueError: + # Could not assign address for this participant + addresses_by_participant[participant] = None + continue + + addresses_by_participant[participant] = found_address + taken_addresses.add(found_address) + + def __find_broadcast_fields(self, high_scored_ranges_by_participant, addresses_by_participant: dict): + """ + Last we check for messages that were sent to broadcast + 1. we search for messages that have a SRC address but no DST address + 2. we look at other messages that have this SRC field and find the corresponding DST position + 3. we evaluate the value of message without DST from 1 and compare these values with each other. + if they match, we found the broadcast address + :param high_scored_ranges_by_participant: + :return: + """ + if -1 in addresses_by_participant: + # broadcast address is already known + return + + broadcast_bag = defaultdict(list) # type: dict[CommonRange, list[int]] + for common_ranges in high_scored_ranges_by_participant.values(): + src_address_fields = sorted(filter(lambda r: r.field_type == "source address", common_ranges)) + dst_address_fields = sorted(filter(lambda r: r.field_type == "destination address", common_ranges)) + msg_with_dst = {i for dst_address_field in dst_address_fields for i in dst_address_field.message_indices} + + for src_address_field in src_address_fields: # type: CommonRange + msg_without_dst = {i for i in src_address_field.message_indices if i not in msg_with_dst} + if len(msg_without_dst) == 0: + continue + try: + matching_dst = next(dst for dst in dst_address_fields + if all(i in dst.message_indices + for i in src_address_field.message_indices - msg_without_dst)) + except StopIteration: + continue + for msg in msg_without_dst: + broadcast_bag[matching_dst].append(msg) + + if len(broadcast_bag) == 0: + return + + broadcast_address = None + for dst, messages in broadcast_bag.items(): + for msg_index in messages: + value = self.msg_vectors[msg_index][dst.start:dst.end + 1] + if broadcast_address is None: + broadcast_address = value + elif value.tobytes() != broadcast_address.tobytes(): + # Address is not common across messages so it can't be a broadcast address + return + + addresses_by_participant[-1] = broadcast_address.tobytes() + for dst, messages in broadcast_bag.items(): + dst.values.append(broadcast_address) + dst.message_indices.update(messages) + + def find_addresses(self) -> dict: + already_assigned = list(self.known_addresses_by_participant.keys()) + if len(already_assigned) == len(self.message_indices_by_participant): + self._debug("Skipping find addresses as already known.") + return dict() + + common_ranges_by_participant = dict() + for 
participant, message_indices in self.message_indices_by_participant.items(): + # Cluster by length + length_clusters = defaultdict(list) + for i in message_indices: + length_clusters[len(self.msg_vectors[i])].append(i) + + common_ranges_by_length = self.find_common_ranges_by_cluster(self.msg_vectors, length_clusters, range_type="hex") + common_ranges_by_participant[participant] = [] + for ranges in common_ranges_by_length.values(): + common_ranges_by_participant[participant].extend(self.ignore_already_labeled(ranges, + self.already_labeled)) + + self._debug("Common ranges by participant:", common_ranges_by_participant) + + result = defaultdict(set) + participants = sorted(common_ranges_by_participant) # type: list[int] + + if len(participants) < 2: + return result + + # If we already know the address length we do not need to bother with other candidates + if len(already_assigned) > 0: + addr_len = len(self.known_addresses_by_participant[already_assigned[0]]) + if any(len(self.known_addresses_by_participant[i]) != addr_len for i in already_assigned): + logger.warning("Addresses do not have a common length. Assuming length of {}".format(addr_len)) + else: + addr_len = None + + for p1, p2 in itertools.combinations(participants, 2): + p1_already_assigned = p1 in already_assigned + p2_already_assigned = p2 in already_assigned + + if p1_already_assigned and p2_already_assigned: + continue + + # common ranges are not merged yet, so there is only one element in values + values1 = [cr.value for cr in common_ranges_by_participant[p1]] + values2 = [cr.value for cr in common_ranges_by_participant[p2]] + for seq1, seq2 in itertools.product(values1, values2): + lcs = self.find_longest_common_sub_sequences(seq1, seq2) + vals = lcs if len(lcs) > 0 else [seq1, seq2] + # Address candidate must be at least 2 values long + for val in filter(lambda v: len(v) >= 2, vals): + if addr_len is not None and len(val) != addr_len: + continue + if not p1_already_assigned and not p2_already_assigned: + result[p1].add(val.tostring()) + result[p2].add(val.tostring()) + elif p1_already_assigned and val.tostring() != self.known_addresses_by_participant[p1].tostring(): + result[p2].add(val.tostring()) + elif p2_already_assigned and val.tostring() != self.known_addresses_by_participant[p2].tostring(): + result[p1].add(val.tostring()) + return result diff --git a/src/urh/awre/engines/ChecksumEngine.py b/src/urh/awre/engines/ChecksumEngine.py new file mode 100644 index 0000000000..9ad5d7a3a7 --- /dev/null +++ b/src/urh/awre/engines/ChecksumEngine.py @@ -0,0 +1,121 @@ +import copy +import math +from collections import defaultdict + +import numpy as np +from urh.util.WSPChecksum import WSPChecksum + +from urh.awre.CommonRange import ChecksumRange +from urh.awre.engines.Engine import Engine +from urh.cythonext import awre_util +from urh.util.GenericCRC import GenericCRC + + +class ChecksumEngine(Engine): + def __init__(self, bitvectors, n_gram_length=8, minimum_score=0.9, already_labeled: list = None): + """ + :type bitvectors: list of np.ndarray + :param bitvectors: bitvectors behind the synchronization + """ + self.bitvectors = bitvectors + self.n_gram_length = n_gram_length + self.minimum_score = minimum_score + if already_labeled is None: + self.already_labeled_cols = set() + else: + self.already_labeled_cols = {e for rng in already_labeled for e in range(*rng)} + + def find(self): + result = list() + bitvectors_by_n_gram_length = defaultdict(list) + for i, bitvector in enumerate(self.bitvectors): + bin_num = 
int(math.ceil(len(bitvector) / self.n_gram_length)) + bitvectors_by_n_gram_length[bin_num].append(i) + + crc = GenericCRC() + for length, message_indices in bitvectors_by_n_gram_length.items(): + checksums_for_length = [] + for index in message_indices: + bits = self.bitvectors[index] + data_start, data_stop, crc_start, crc_stop = WSPChecksum.search_for_wsp_checksum(bits) + if (data_start, data_stop, crc_start, crc_stop) != (0, 0, 0, 0): + checksum_range = ChecksumRange(start=crc_start, length=crc_stop-crc_start, + data_range_start=data_start, data_range_end=data_stop, + crc=WSPChecksum(), score=1/len(message_indices), + field_type="checksum", message_indices={index}) + try: + present = next(c for c in checksums_for_length if c == checksum_range) + present.message_indices.add(index) + except StopIteration: + checksums_for_length.append(checksum_range) + continue + + crc_object, data_start, data_stop, crc_start, crc_stop = crc.guess_all(bits, + ignore_positions=self.already_labeled_cols) + + if (crc_object, data_start, data_stop, crc_start, crc_stop) != (0, 0, 0, 0, 0): + checksum_range = ChecksumRange(start=crc_start, length=crc_stop - crc_start, + data_range_start=data_start, data_range_end=data_stop, + crc=copy.copy(crc_object), score=1 / len(message_indices), + field_type="checksum", message_indices={index} + ) + + try: + present = next(rng for rng in checksums_for_length if rng == checksum_range) + present.message_indices.add(index) + continue + except StopIteration: + pass + + checksums_for_length.append(checksum_range) + + matching = awre_util.check_crc_for_messages(message_indices, self.bitvectors, + data_start, data_stop, + crc_start, crc_stop, + *crc_object.get_parameters()) + + checksum_range.message_indices.update(matching) + + # Score ranges + for rng in checksums_for_length: + rng.score = len(rng.message_indices) / len(message_indices) + + try: + result.append(max(checksums_for_length, key=lambda x: x.score)) + except ValueError: + pass # no checksums found for this length + + self._debug("Found Checksums", result) + try: + max_scored = max(filter(lambda x: len(x.message_indices) >= 2 and x.score >= self.minimum_score, result), + key=lambda x: x.score) + except ValueError: + return [] + + result = list(filter(lambda x: x.crc == max_scored.crc, result)) + self._debug("Filtered Checksums", result) + + return result + + @staticmethod + def calc_score(diff_frequencies: dict) -> float: + """ + Calculate the score based on the distribution of differences + 1. high if one constant (!= zero) dominates + 2. Other constants (!= zero) should lower the score, zero means sequence number stays same for some messages + + :param diff_frequencies: Frequencies of decimal differences between columns of subsequent messages + e.g. 
{-255: 3, 1: 1020} means -255 appeared 3 times and 1 appeared 1020 times + :return: a score between 0 and 1 + """ + total = sum(diff_frequencies.values()) + num_zeros = sum(v for k, v in diff_frequencies.items() if k == 0) + if num_zeros == total: + return 0 + + try: + most_frequent = ChecksumEngine.get_most_frequent(diff_frequencies) + except ValueError: + return 0 + + return diff_frequencies[most_frequent] / (total - num_zeros) diff --git a/src/urh/awre/engines/Engine.py b/src/urh/awre/engines/Engine.py new file mode 100644 index 0000000000..3198797d5c --- /dev/null +++ b/src/urh/awre/engines/Engine.py @@ -0,0 +1,85 @@ +from urh.awre.CommonRange import CommonRange +from urh.awre.Histogram import Histogram +import numpy as np +from urh.cythonext import awre_util +import itertools + + +class Engine(object): + _DEBUG_ = False + + def _debug(self, *args): + if self._DEBUG_: + print("[{}]".format(self.__class__.__name__), *args) + + @staticmethod + def find_common_ranges_by_cluster(msg_vectors, clustered_bitvectors, alpha=0.95, range_type="bit"): + """ + + :param alpha: How many percent of values must be equal per range? + :param range_type: Describes what kind of range this is: bit, hex or byte. + Needed for conversion of range start / end later + :type msg_vectors: list of np.ndarray + :type clustered_bitvectors: dict + :rtype: dict[int, list of CommonRange] + """ + histograms = { + cluster: Histogram(msg_vectors, message_indices) + for cluster, message_indices in clustered_bitvectors.items() + } + + common_ranges_by_cluster = { + cluster: histogram.find_common_ranges(alpha=alpha, range_type=range_type) + for cluster, histogram in histograms.items() + } + + return common_ranges_by_cluster + + @staticmethod + def find_common_ranges_exhaustive(msg_vectors, msg_indices, range_type="bit") -> list: + result = [] + + for i, j in itertools.combinations(msg_indices, 2): + for rng in Histogram(msg_vectors, indices=[i, j]).find_common_ranges(alpha=1, range_type=range_type): + try: + common_range = next(cr for cr in result if cr.start == rng.start and cr.value.tobytes() == rng.value.tobytes()) + common_range.message_indices.update({i, j}) + except StopIteration: + result.append(rng) + + return result + + @staticmethod + def ignore_already_labeled(common_ranges, already_labeled): + """ + Shrink the common ranges so that they not overlap with already labeled ranges. 
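Editor's note: a worked example of the difference-based score above, using the frequencies from the docstring plus a made-up zero count, and the get_most_frequent filter shown in SequenceNumberEngine: zeros are excluded and the dominant non-zero difference determines the score.

diff_frequencies = {-255: 3, 1: 1020, 0: 2}
total = sum(diff_frequencies.values())                                   # 1025
num_zeros = sum(v for k, v in diff_frequencies.items() if k == 0)        # 2
most_frequent = max(filter(lambda x: x not in (0, -1), diff_frequencies),
                    key=diff_frequencies.get)                            # 1 (appears 1020 times)
score = diff_frequencies[most_frequent] / (total - num_zeros)
print(round(score, 3))                                                   # 0.997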
+ Empty common ranges are removed after shrinking + + :type common_ranges: list of CommonRange + :type already_labeled: list of tuple + :return: list of CommonRange + """ + result = [] + for common_range in common_ranges: + range_result = [common_range] + for start, end in already_labeled: + for rng in range_result[:]: + range_result.remove(rng) + range_result.extend(rng.ensure_not_overlaps(start, end)) + result.extend(range_result) + + return result + + @staticmethod + def find_longest_common_sub_sequences(seq1, seq2) -> list: + result = [] + if seq1 is None or seq2 is None: + return result + + indices = awre_util.find_longest_common_sub_sequence_indices(seq1, seq2) + for ind in indices: + s = seq1[slice(*ind)] + if len(s) > 0: + result.append(s) + + return result diff --git a/src/urh/awre/engines/LengthEngine.py b/src/urh/awre/engines/LengthEngine.py new file mode 100644 index 0000000000..413a6fa834 --- /dev/null +++ b/src/urh/awre/engines/LengthEngine.py @@ -0,0 +1,193 @@ +import math +from collections import defaultdict + +import numpy as np + +from urh.awre.CommonRange import CommonRange, EmptyCommonRange +from urh.awre.engines.Engine import Engine +from urh.cythonext import awre_util + + +class LengthEngine(Engine): + def __init__(self, bitvectors, already_labeled=None): + """ + + :type bitvectors: list of np.ndarray + :param bitvectors: bitvectors behind the synchronization + """ + self.bitvectors = bitvectors + self.already_labeled = [] if already_labeled is None else already_labeled + + def find(self, n_gram_length=8, minimum_score=0.1): + # Consider the n_gram_length + bitvectors_by_n_gram_length = defaultdict(list) + for i, bitvector in enumerate(self.bitvectors): + bin_num = int(math.ceil(len(bitvector) / n_gram_length)) + bitvectors_by_n_gram_length[bin_num].append(i) + + common_ranges_by_length = self.find_common_ranges_by_cluster(self.bitvectors, + bitvectors_by_n_gram_length, + alpha=0.7) + + for length, ranges in common_ranges_by_length.items(): + common_ranges_by_length[length] = self.ignore_already_labeled(ranges, self.already_labeled) + + self.filter_common_ranges(common_ranges_by_length) + self._debug("Common Ranges:", common_ranges_by_length) + + scored_ranges = self.score_ranges(common_ranges_by_length, n_gram_length) + self._debug("Scored Ranges", scored_ranges) + + # Take the ranges with highest score per cluster if it's score surpasses the minimum score + high_scores_by_length = self.choose_high_scored_ranges(scored_ranges, bitvectors_by_n_gram_length, + minimum_score) + self._debug("Highscored Ranges", high_scores_by_length) + return high_scores_by_length.values() + + @staticmethod + def filter_common_ranges(common_ranges_by_length: dict): + """ + Ranges must be common along length clusters + but their values must differ, so now we rule out all ranges that are + 1. common across clusters AND + 2. 
have same value + + :return: + """ + + ranges = [r for rng in common_ranges_by_length.values() for r in rng] + for rng in ranges: + count = len([r for r in ranges if rng.start == r.start + and rng.length == r.length + and rng.value.tobytes() == r.value.tobytes()] + ) + if count < 2: + continue + + for length in common_ranges_by_length: + try: + common_ranges_by_length[length].remove(rng) + except ValueError: + pass + + @staticmethod + def score_ranges(common_ranges_by_length: dict, n_gram_length: int): + """ + Calculate score for the common ranges + + :param common_ranges_by_length: + :param n_gram_length: + :return: + """ + + # The window length must be smaller than common range's length + # and is something like 8 in case of on 8 bit integer. + # We make this generic so e.g. 4 bit integers are supported as well + if n_gram_length == 8: + window_lengths = [8, 16, 32, 64] + else: + window_lengths = [n_gram_length * i for i in range(1, 5)] + + scored_ranges = dict() + for length in common_ranges_by_length: + scored_ranges[length] = dict() + for window_length in window_lengths: + scored_ranges[length][window_length] = [] + + byteorders = ["big", "little"] if n_gram_length == 8 else ["big"] + for window_length in window_lengths: + for length, common_ranges in common_ranges_by_length.items(): + for common_range in filter(lambda cr: cr.length >= window_length, common_ranges): + bits = common_range.value + rng_byte_order = "big" + + max_score = max_start = -1 + for start in range(0, len(bits) + 1 - window_length, n_gram_length): + for byteorder in byteorders: + score = LengthEngine.score_bits(bits[start:start + window_length], + length, position=start, byteorder=byteorder) + + if score > max_score: + max_score = score + max_start = start + rng_byte_order = byteorder + + rng = CommonRange(common_range.start + max_start, window_length, + common_range.value[max_start:max_start + window_length], + score=max_score, field_type="length", + message_indices=common_range.message_indices, + range_type=common_range.range_type, + byte_order=rng_byte_order) + scored_ranges[length][window_length].append(rng) + + return scored_ranges + + def choose_high_scored_ranges(self, scored_ranges: dict, bitvectors_by_n_gram_length: dict, minimum_score: float): + + # Set for every window length the highest scored range as candidate + possible_window_lengths = defaultdict(int) + for length, ranges_by_window_length in scored_ranges.items(): + for window_length, ranges in ranges_by_window_length.items(): + try: + ranges_by_window_length[window_length] = max(filter(lambda x: x.score >= minimum_score, ranges), + key=lambda x: x.score) + possible_window_lengths[window_length] += 1 + except ValueError: + ranges_by_window_length[window_length] = None + + try: + # Choose window length -> window length that has a result most often and choose greater on tie + chosen_window_length = max(possible_window_lengths, key=lambda x: (possible_window_lengths[x], x)) + except ValueError: + return dict() + + high_scores_by_length = dict() + + # Choose all ranges with highest score per cluster if score surpasses the minimum score + for length, ranges_by_window_length in scored_ranges.items(): + try: + if ranges_by_window_length[chosen_window_length]: + high_scores_by_length[length] = ranges_by_window_length[chosen_window_length] + except KeyError: + continue + + # If there are length clusters with only one message see if we can assign a range from other clusters + for length, msg_indices in bitvectors_by_n_gram_length.items(): + if 
len(msg_indices) != 1: + continue + + msg_index = msg_indices[0] + bitvector = self.bitvectors[msg_index] + max_score, best_match = 0, None + + for rng in high_scores_by_length.values(): + bits = bitvector[rng.start:rng.end + 1] + if len(bits) > 0: + score = self.score_bits(bits, length, rng.start) + if score > max_score: + best_match, max_score = rng, score + + if best_match is not None: + high_scores_by_length[length] = CommonRange(best_match.start, best_match.length, + value=bitvector[best_match.start:best_match.end + 1], + score=max_score, field_type="length", + message_indices={msg_index}, range_type="bit") + + return high_scores_by_length + + @staticmethod + def score_bits(bits: np.ndarray, target_length: int, position: int, byteorder="big"): + value = awre_util.bit_array_to_number(bits, len(bits)) + if byteorder == "little": + if len(bits) > 8 and len(bits) % 8 == 0: + n = len(bits) // 8 + value = int.from_bytes(value.to_bytes(n, byteorder="big"), byteorder="little", signed=False) + + # Length field should be at front, so we give lower scores for large starts + f = (1 / (1 + 0.25 * position)) + + return f * LengthEngine.gauss(value, target_length) + + @staticmethod + def gauss(x, mu, sigma=2): + return np.exp(-0.5 * np.power((x - mu) / sigma, 2)) diff --git a/src/urh/awre/engines/SequenceNumberEngine.py b/src/urh/awre/engines/SequenceNumberEngine.py new file mode 100644 index 0000000000..aa64406963 --- /dev/null +++ b/src/urh/awre/engines/SequenceNumberEngine.py @@ -0,0 +1,137 @@ +import numpy as np + +from urh.awre.CommonRange import CommonRange +from urh.awre.engines.Engine import Engine +from urh.cythonext import awre_util + + +class SequenceNumberEngine(Engine): + def __init__(self, bitvectors, n_gram_length=8, minimum_score=0.75, already_labeled: list = None): + """ + + :type bitvectors: list of np.ndarray + :param bitvectors: bitvectors behind the synchronization + """ + self.bitvectors = bitvectors + self.n_gram_length = n_gram_length + self.minimum_score = minimum_score + if already_labeled is None: + self.already_labeled_cols = set() + else: + self.already_labeled_cols = {e // n_gram_length for rng in already_labeled for e in range(*rng)} + + def find(self): + n = self.n_gram_length + + if len(self.bitvectors) < 3: + # We need at least 3 bitvectors to properly find a sequence number + return [] + + diff_matrix = self.create_difference_matrix(self.bitvectors, self.n_gram_length) + diff_frequencies_by_column = dict() + + for j in range(diff_matrix.shape[1]): + unique, counts = np.unique(diff_matrix[:, j], return_counts=True) + diff_frequencies_by_column[j] = dict(zip(unique, counts)) + + self._debug("Diff_frequencies_by_column", diff_frequencies_by_column) + scores_by_column = dict() + for column, frequencies in diff_frequencies_by_column.items(): + if column not in self.already_labeled_cols: + scores_by_column[column] = self.calc_score(frequencies) + else: + scores_by_column[column] = 0 + + self._debug("Scores by column", scores_by_column) + result = [] + for candidate_column in sorted(scores_by_column, key=scores_by_column.get, reverse=True): + score = scores_by_column[candidate_column] + if score < self.minimum_score: + continue + + most_common_diff = self.get_most_frequent(diff_frequencies_by_column[candidate_column]) + message_indices = np.flatnonzero( + # get all rows that have the most common difference or zero + (diff_matrix[:, candidate_column] == most_common_diff) | (diff_matrix[:, candidate_column] == 0) + ) + + # For example, index 1 in diff matrix corresponds to 
index 1 and 2 of messages + message_indices = set(message_indices) | set(message_indices + 1) + values = set() + for i in message_indices: + values.add(self.bitvectors[i][candidate_column * n:(candidate_column + 1) * n].tobytes()) + + matching_ranges = [r for r in result if r.message_indices == message_indices] + + try: + matching_range = next(r for r in matching_ranges if r.start == (candidate_column - 1) * n + and (r.byte_order_is_unknown or r.byte_order == "big")) + matching_range.length += n + matching_range.byte_order = "big" + matching_range.values.extend(list(values)) + continue + except StopIteration: + pass + + try: + matching_range = next(r for r in matching_ranges if r.start == (candidate_column + 1) * n + and (r.byte_order_is_unknown or r.byte_order == "little")) + matching_range.start -= n + matching_range.length += n + matching_range.byte_order = "little" + matching_range.values.extend(list(values)) + continue + except StopIteration: + pass + + new_range = CommonRange(start=candidate_column * n, length=n, score=score, + field_type="sequence number", message_indices=message_indices, + byte_order=None) + new_range.values.extend(list(values)) + result.append(new_range) + + # At least three different values needed to reliably identify a sequence number + return [rng for rng in result if len(set(rng.values)) > 2] + + @staticmethod + def get_most_frequent(diff_frequencies: dict): + return max(filter(lambda x: x not in (0, -1), diff_frequencies), key=diff_frequencies.get) + + @staticmethod + def calc_score(diff_frequencies: dict) -> float: + """ + Calculate the score based on the distribution of differences + 1. high if one constant (!= zero) dominates + 2. Other constants (!= zero) should lower the score, zero means sequence number stays same for some messages + + :param diff_frequencies: Frequencies of decimal differences between columns of subsequent messages + e.g. {0: 3, 1: 1020} means 0 appeared 3 times and 1 appeared 1020 times + :return: a score between 0 and 1 + """ + total = sum(diff_frequencies.values()) + num_zeros = sum(v for k, v in diff_frequencies.items() if k == 0) + if num_zeros == total: + return 0 + + try: + most_frequent = SequenceNumberEngine.get_most_frequent(diff_frequencies) + except ValueError: + return 0 + + return diff_frequencies[most_frequent] / (total - num_zeros) + + @staticmethod + def create_difference_matrix(bitvectors, n_gram_length: int): + """ + Create the difference matrix e.g. 
+ 10 20 0 + 1 2 3 + 4 5 6 + + means first eight bits of messages 1 and 2 (row 1) differ by 10 if they are considered as decimal number + + :type bitvectors: list of np.ndarray + :type n_gram_length: int + :rtype: np.ndarray + """ + return awre_util.create_seq_number_difference_matrix(bitvectors, n_gram_length) diff --git a/src/urh/awre/components/__init__.py b/src/urh/awre/engines/__init__.py similarity index 100% rename from src/urh/awre/components/__init__.py rename to src/urh/awre/engines/__init__.py diff --git a/src/urh/controller/CompareFrameController.py b/src/urh/controller/CompareFrameController.py index 93769a416d..c7e9424f20 100644 --- a/src/urh/controller/CompareFrameController.py +++ b/src/urh/controller/CompareFrameController.py @@ -1,6 +1,7 @@ import locale import math import os +import traceback from collections import defaultdict from datetime import datetime @@ -9,8 +10,10 @@ QModelIndex from PyQt5.QtGui import QContextMenuEvent, QIcon from PyQt5.QtWidgets import QMessageBox, QAbstractItemView, QUndoStack, QMenu, QWidget, QHeaderView +from urh.util.Errors import Errors from urh import constants +from urh.awre import AutoAssigner from urh.controller.dialogs.MessageTypeDialog import MessageTypeDialog from urh.controller.dialogs.ProtocolLabelDialog import ProtocolLabelDialog from urh.models.LabelValueTableModel import LabelValueTableModel @@ -84,7 +87,7 @@ def __init__(self, plugin_manager: PluginManager, project_manager: ProjectManage self.assign_message_type_action.setChecked(True) self.assign_labels_action = self.analyze_menu.addAction(self.tr("Assign labels")) self.assign_labels_action.setCheckable(True) - self.assign_labels_action.setChecked(False) + self.assign_labels_action.setChecked(True) self.assign_participant_address_action = self.analyze_menu.addAction(self.tr("Assign participant addresses")) self.assign_participant_address_action.setCheckable(True) self.assign_participant_address_action.setChecked(True) @@ -440,6 +443,15 @@ def add_protocol_from_file(self, filename: str) -> ProtocolAnalyzer: messsage_type.name += " (" + os.path.split(filename)[1].rstrip(".xml").rstrip(".proto") + ")" self.proto_analyzer.message_types.append(messsage_type) + update_project = False + for msg in pa.messages: + if msg.participant is not None and msg.participant not in self.project_manager.participants: + self.project_manager.participants.append(msg.participant) + update_project = True + + if update_project: + self.project_manager.project_updated.emit() + self.message_type_table_model.update() self.add_protocol(protocol=pa) @@ -1011,7 +1023,7 @@ def on_btn_analyze_clicked(self): if self.assign_participants_action.isChecked(): for protocol in self.protocol_list: - protocol.auto_assign_participants(self.protocol_model.participants) + AutoAssigner.auto_assign_participants(protocol.messages, self.protocol_model.participants) self.refresh_assigned_participants_ui() self.ui.progressBarLogicAnalyzer.setFormat("%p% (Assign message type by rules)") @@ -1024,16 +1036,23 @@ def on_btn_analyze_clicked(self): self.ui.progressBarLogicAnalyzer.setValue(75) if self.assign_labels_action.isChecked(): - self.proto_analyzer.auto_assign_labels() - self.protocol_model.update() - self.label_value_model.update() - self.message_type_table_model.update() - self.ui.tblViewMessageTypes.clearSelection() + try: + self.proto_analyzer.auto_assign_labels() + self.protocol_model.update() + self.label_value_model.update() + self.message_type_table_model.update() + self.ui.tblViewMessageTypes.clearSelection() + except 
Exception as e: + logger.exception(e) + Errors.generic_error("Failed to assign labels", + "An error occurred during automatic label assignment", + traceback.format_exc()) self.ui.progressBarLogicAnalyzer.setValue(90) if self.assign_participant_address_action.isChecked(): - self.proto_analyzer.auto_assign_participant_addresses(self.protocol_model.participants) + AutoAssigner.auto_assign_participant_addresses(self.proto_analyzer.messages, + self.protocol_model.participants) self.ui.progressBarLogicAnalyzer.setValue(100) self.unsetCursor() diff --git a/src/urh/controller/MainController.py b/src/urh/controller/MainController.py index ead9d3d946..46ac56f9c3 100644 --- a/src/urh/controller/MainController.py +++ b/src/urh/controller/MainController.py @@ -208,6 +208,7 @@ def create_connects(self): self.on_show_field_types_config_action_triggered) self.compare_frame_controller.load_protocol_clicked.connect(self.on_compare_frame_controller_load_protocol_clicked) + self.compare_frame_controller.ui.listViewParticipants.doubleClicked.connect(self.on_project_settings_action_triggered) self.ui.lnEdtTreeFilter.textChanged.connect(self.on_file_tree_filter_text_changed) @@ -257,6 +258,7 @@ def add_protocol_file(self, filename): proto = self.compare_frame_controller.add_protocol_from_file(filename) if proto: self.__add_empty_frame_for_filename(proto, filename) + self.ui.tabWidget.setCurrentWidget(self.ui.tab_protocol) def add_fuzz_profile(self, filename): self.ui.tabWidget.setCurrentIndex(2) diff --git a/src/urh/controller/widgets/SignalFrame.py b/src/urh/controller/widgets/SignalFrame.py index 820128135f..db1c6bf5ab 100644 --- a/src/urh/controller/widgets/SignalFrame.py +++ b/src/urh/controller/widgets/SignalFrame.py @@ -266,6 +266,8 @@ def set_empty_frame_visibilities(self): self.ui.btnCloseSignal, self.ui.lineEditSignalName): w.hide() + self.adjustSize() + def cancel_filtering(self): self.filter_abort_wanted = True diff --git a/src/urh/cythonext/awre_util.pyx b/src/urh/cythonext/awre_util.pyx new file mode 100644 index 0000000000..833bfb7b9f --- /dev/null +++ b/src/urh/cythonext/awre_util.pyx @@ -0,0 +1,384 @@ +# noinspection PyUnresolvedReferences +cimport numpy as np +import numpy as np + + +from libc.math cimport floor, ceil, pow +from libc.stdlib cimport malloc, free + +from libcpp cimport bool +from libc.stdint cimport uint8_t, uint16_t, uint32_t, uint64_t, int32_t, int8_t, int64_t + +from array import array + + +from urh.cythonext.util import crc + +cpdef set find_longest_common_sub_sequence_indices(np.uint8_t[::1] seq1, np.uint8_t[::1] seq2): + cdef unsigned int i, j, longest = 0, counter = 0, len_bits1 = len(seq1), len_bits2 = len(seq2) + cdef unsigned short max_results = 10, current_result = 0 + + cdef unsigned int[:, ::1] m = np.zeros((len_bits1+1, len_bits2+1), dtype=np.uint32, order="C") + cdef unsigned int[:, ::1] result_indices = np.zeros((max_results, 2), dtype=np.uint32, order="C") + + for i in range(0, len_bits1): + for j in range(0, len_bits2): + if seq1[i] == seq2[j]: + counter = m[i, j] + 1 + m[i+1, j+1] = counter + + if counter > longest: + longest = counter + + current_result = 0 + result_indices[current_result, 0] = i - counter + 1 + result_indices[current_result, 1] = i + 1 + elif counter == longest: + if current_result < max_results - 1: + current_result += 1 + result_indices[current_result, 0] = i - counter + 1 + result_indices[current_result, 1] = i + 1 + + cdef set result = set() + for i in range(current_result+1): + result.add((result_indices[i, 0], result_indices[i, 1])) + + 
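+    # result now holds up to max_results (start, end) index pairs, given as
+    # positions in seq1, one for each longest common subsequence found above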
return result + +cpdef uint32_t find_first_difference(uint8_t[::1] bits1, uint8_t[::1] bits2, uint32_t len_bits1, uint32_t len_bits2) nogil: + cdef uint32_t i, smaller_len = min(len_bits1, len_bits2) + + for i in range(0, smaller_len): + if bits1[i] != bits2[i]: + return i + + return smaller_len + +cpdef np.ndarray[np.uint32_t, ndim=2, mode="c"] get_difference_matrix(list bitvectors): + cdef uint32_t i, j, N = len(bitvectors) + cdef np.ndarray[np.uint32_t, ndim=2, mode="c"] result = np.zeros((N, N), dtype=np.uint32, order="C") + + cdef uint8_t[::1] bitvector_i + cdef uint32_t len_bitvector_i + + for i in range(N): + bitvector_i = bitvectors[i] + len_bitvector_i = len(bitvector_i) + for j in range(i + 1, N): + result[i, j] = find_first_difference(bitvector_i, bitvectors[j], len_bitvector_i, len(bitvectors[j])) + + return result + +cpdef list get_hexvectors(list bitvectors): + cdef list result = [] + cdef uint8_t[::1] bitvector + cdef size_t i, j, M, N = len(bitvectors) + + cdef np.ndarray[np.uint8_t, mode="c"] hexvector + cdef size_t len_bitvector + + for i in range(0, N): + bitvector = bitvectors[i] + len_bitvector = len(bitvector) + + M = ceil(len_bitvector / 4) + hexvector = np.zeros(M, dtype=np.uint8, order="C") + + for j in range(0, M): + hexvector[j] = bit_array_to_number(bitvector, min(len_bitvector, 4*j+4), 4*j) + + result.append(hexvector) + + return result + + +cdef int lower_multiple_of_n(int number, int n) nogil: + return n * floor(number / n) + +cdef int64_t find(uint8_t[:] data, int64_t len_data, uint8_t element, int64_t start=0) nogil: + cdef int64_t i + for i in range(start, len_data): + if data[i] == element: + return i + return -1 + +cpdef tuple get_raw_preamble_position(uint8_t[:] bitvector): + cdef int64_t N = len(bitvector) + if N == 0: + return 0, 0 + + cdef int64_t i, j, n, m, start = -1 + cdef double k = 0 + + cdef int64_t lower = 0, upper = 0 + cdef uint8_t a, b + + cdef uint8_t* preamble_pattern = NULL + cdef int64_t len_preamble_pattern, preamble_end + + cdef bool preamble_end_reached + + while k < 2 and start < N: + start += 1 + + a = bitvector[start] + b = 1 if a == 0 else 0 + + # now we search for the pattern a^n b^m + n = find(bitvector, N, b, start) - start + + if n <= 0: + return 0, 0, 0 + + m = find(bitvector, N, a, start+n) - n - start + + if m <= 0: + return 0, 0, 0 + + #preamble_pattern = a * n + b * m + len_preamble_pattern = n + m + preamble_pattern = malloc(len_preamble_pattern * sizeof(uint8_t)) + + for j in range(0, n): + preamble_pattern[j] = a + for j in range(n, len_preamble_pattern): + preamble_pattern[j] = b + + preamble_end = start + preamble_end_reached = False + for i in range(start, N, len_preamble_pattern): + if preamble_end_reached: + break + for j in range(0, len_preamble_pattern): + if bitvector[i+j] != preamble_pattern[j]: + preamble_end_reached = True + preamble_end = i + break + + free(preamble_pattern) + + upper = start + lower_multiple_of_n(preamble_end + 1 - start, len_preamble_pattern) + lower = upper - len_preamble_pattern + + k = (upper - start) / len_preamble_pattern + + if k > 2: + return start, lower, upper + else: + # no preamble found + return 0, 0, 0 + + +cpdef dict find_possible_sync_words(np.ndarray[np.uint32_t, ndim=2, mode="c"] difference_matrix, + np.ndarray[np.uint32_t, ndim=2, mode="c"] raw_preamble_positions, + list bitvectors, int n_gram_length): + cdef dict possible_sync_words = dict() + + cdef uint32_t i, j, num_rows = difference_matrix.shape[0], num_cols = difference_matrix.shape[1] + cdef uint32_t sync_len, 
sync_end, start, index, k, n + + cdef bytes sync_word + + cdef np.ndarray[np.uint8_t, mode="c"] bitvector + + cdef uint8_t ij_ctr = 0 + cdef uint32_t* ij_arr = malloc(2 * sizeof(uint32_t)) + cdef uint8_t* temp = NULL + + for i in range(0, num_rows): + for j in range(i + 1, num_cols): + # position of first difference between message i and j + sync_end = difference_matrix[i, j] + + if sync_end == 0: + continue + + ij_arr[0] = i + ij_arr[1] = j + + for k in range(0, 2): + for ij_ctr in range(0, 2): + index = ij_arr[ij_ctr] + start = raw_preamble_positions[index, 0] + raw_preamble_positions[index, k + 1] + + # We take the next lower multiple of n for the sync len + # If in doubt, it is better to underestimate the sync len to prevent it from + # taking needed values from other fields, e.g. leading zeros of a length field + sync_len = max(0, lower_multiple_of_n(sync_end - start, n_gram_length)) + + if sync_len >= 2: + bitvector = bitvectors[index] + if sync_len == 2: + # The sync word must not be empty or just the two bits "10" or "01", because + # that would be indistinguishable from the preamble + if bitvector[start] == 0 and bitvector[start+1] == 1: + continue + if bitvector[start] == 1 and bitvector[start+1] == 0: + continue + + temp = malloc(sync_len * sizeof(uint8_t)) + for n in range(0, sync_len): + temp[n] = bitvector[start+n] + sync_word = temp[:sync_len] + free(temp) + + possible_sync_words.setdefault(sync_word, 0) + if (start + sync_len) % n_gram_length == 0: + # if the sync end aligns nicely with the n-gram length, give it a larger score + possible_sync_words[sync_word] += 1 + else: + possible_sync_words[sync_word] += 0.5 + + free(ij_arr) + return possible_sync_words + +cpdef np.ndarray[np.float64_t] create_difference_histogram(list vectors, list active_indices): + """ + Return a histogram of common ranges. E.g. [1, 1, 0.75, 0.8] means 75% of the values in the third column are equal + + :param vectors: Vectors over whose differences the histogram will be created + :param active_indices: Active indices of vectors. Vectors with an index not in this list will be ignored + :return: + """ + cdef unsigned long i,j,k,index_i,index_j, L = len(active_indices) + cdef unsigned long longest = 0, len_vector, len_vector_i + for i in active_indices: + len_vector = len(vectors[i]) + if len_vector > longest: + longest = len_vector + + cdef np.ndarray[np.float64_t] histogram = np.zeros(longest, dtype=np.float64) + cdef double n = (len(active_indices) * (len(active_indices) - 1)) // 2 + + cdef np.ndarray[np.uint8_t] bitvector_i, bitvector_j + + for i in range(0, L - 1): + index_i = active_indices[i] + bitvector_i = vectors[index_i] + len_vector_i = len(bitvector_i) + for j in range(i+1, L): + index_j = active_indices[j] + bitvector_j = vectors[index_j] + for k in range(0, min(len_vector_i, len(bitvector_j))): + if bitvector_i[k] == bitvector_j[k]: + histogram[k] += 1 / n + return histogram + +cpdef list find_occurrences(np.uint8_t[::1] a, np.uint8_t[::1] b, + unsigned long[:] ignore_indices=None, bool return_after_first=False): + """ + Find the indices of occurrences of b in a.
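+    Positions listed in ignore_indices are not allowed to be part of a match, and if
+    return_after_first is True the search stops after the first occurrence.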
+ + :param a: Larger array + :param b: Subarray to search for + :return: List of start indices of b in a + """ + cdef unsigned long i, j + cdef unsigned long len_a = len(a), len_b = len(b) + + cdef bool ignore_indices_present = ignore_indices is not None + + if len_b > len_a: + return [] + + cdef list result = [] + cdef bool found + for i in range(0, (len_a-len_b) + 1): + found = True + for j in range(0, len_b): + if ignore_indices_present: + if i+j in ignore_indices: + found = False + break + + if a[i+j] != b[j]: + found = False + break + if found: + if return_after_first: + return [i] + else: + result.append(i) + + return result + +cpdef unsigned long long bit_array_to_number(uint8_t[::1] bits, int64_t end, int64_t start=0) nogil: + if end < 1: + return 0 + + cdef long long i, acc = 1 + cdef unsigned long long result = 0 + + for i in range(start, end): + result += bits[end-1-i+start] * acc + acc *= 2 + + return result + +cpdef np.ndarray[np.int32_t, ndim=2, mode="c"] create_seq_number_difference_matrix(list bitvectors, int n_gram_length): + """ + Create the difference matrix e.g. + 10 20 0 + 1 2 3 + 4 5 6 + + means first eight bits of messages 1 and 2 (row 1) differ by 10 if they are considered as decimal number + + :type bitvectors: list of np.ndarray + :type n_gram_length: int + :rtype: np.ndarray + """ + cdef size_t max_len = len(max(bitvectors, key=len)) + cdef size_t i, j, k, index, N = len(bitvectors), M = ceil(max_len / n_gram_length) + cdef uint8_t[::1] bv1, bv2 + cdef size_t len_bv1, len_bv2 + cdef int32_t diff + cdef int32_t n_gram_power_two = pow(2, n_gram_length) + + cdef np.ndarray[np.int32_t, ndim=2, mode="c"] result = np.full((N - 1, M), -1, dtype=np.int32) + for i in range(1, N): + bv1 = bitvectors[i - 1] + bv2 = bitvectors[i] + len_bv1 = len(bv1) + len_bv2 = len(bv2) + k = min(len_bv1, len_bv2) + for j in range(0, k, n_gram_length): + index = j / n_gram_length + if index < M: + diff = bit_array_to_number(bv2, min(len_bv2, j + n_gram_length), j) -\ + bit_array_to_number(bv1, min(len_bv1, j+n_gram_length), j) + # add + n_gram_power_two because in C modulo can be negative + result[i - 1, index] = (diff + n_gram_power_two) % n_gram_power_two + + return result + +cpdef set check_crc_for_messages(list message_indices, list bitvectors, + unsigned long data_start, unsigned long data_stop, + unsigned long crc_start, unsigned long crc_stop, + unsigned char[:] crc_polynomial, unsigned char[:] crc_start_value, + unsigned char[:] crc_final_xor, + bool crc_lsb_first, bool crc_reverse_polynomial, + bool crc_reverse_all, bool crc_little_endian): + """ + Check a configurable subset of bitvectors for a matching CRC and return the indices of the + vectors who match the CRC with the given parameters + :return: + """ + cdef set result = set() + cdef unsigned long j, index, end = len(message_indices) + cdef np.ndarray[np.uint8_t] bits + cdef unsigned char[:] crc_input + cdef unsigned long long check + + for j in range(0, end): + index = message_indices[j] + bits = bitvectors[index] + crc_input = bits[data_start:data_stop] + #check = int("".join(map(str, bits[crc_start:crc_stop])), 2) + check = bit_array_to_number(bits[crc_start:crc_stop], crc_stop - crc_start) + if crc(crc_input, crc_polynomial, crc_start_value, crc_final_xor, + crc_lsb_first, crc_reverse_polynomial, + crc_reverse_all, crc_little_endian) == check: + result.add(index) + + return result diff --git a/src/urh/cythonext/util.pyx b/src/urh/cythonext/util.pyx index 714773132e..037010fd57 100644 --- a/src/urh/cythonext/util.pyx +++ 
b/src/urh/cythonext/util.pyx @@ -6,9 +6,11 @@ import numpy as np # because it can lead to OS X error: https://github.com/jopohl/urh/issues/273 # np.import_array() +from libc.stdint cimport uint8_t, uint16_t, uint32_t, uint64_t +from libc.stdlib cimport malloc, calloc, free cimport cython from cython.parallel import prange -from libc.math cimport log10 +from libc.math cimport log10,pow from libcpp cimport bool cpdef tuple minmax(float[:] arr): @@ -29,48 +31,6 @@ cpdef tuple minmax(float[:] arr): return minimum, maximum - -cpdef np.ndarray[np.int8_t, ndim=3] build_xor_matrix(list bitvectors): - cdef unsigned int maximum = 0 - cdef np.int8_t[:] bitvector_i, bitvector_j - cdef int i, j, l - for i in range(0, len(bitvectors)): - bitvector_i = bitvectors[i] - if maximum < len(bitvector_i): - maximum = len(bitvector_i) - - cdef np.ndarray[np.int8_t, ndim=3] result = np.full((len(bitvectors), len(bitvectors), maximum), -1, dtype=np.int8, order="C") - - for i in range(len(bitvectors)): - bitvector_i = bitvectors[i] - for j in range(i+1, len(bitvectors)): - bitvector_j = bitvectors[j] - l = min(len(bitvector_i), len(bitvector_j)) - for k in range(0, l): - result[i,j,k] = bitvector_i[k] ^ bitvector_j[k] - - return result - - -cpdef str longest_common_substring(str s1, str s2): - cdef int len_s1 = len(s1) - cdef int len_s2 = len(s2) - cdef np.int_t[:, ::1] m = np.zeros((len_s1+1, len_s2+1), dtype=np.int, order="C") - cdef int longest = 0 - cdef int x_longest = 0 - cdef int x, y - - for x in range(1, 1 + len_s1): - for y in range(1, 1 + len_s2): - if s1[x - 1] == s2[y - 1]: - m[x, y] = m[x - 1, y - 1] + 1 - if m[x, y] > longest: - longest = m[x, y] - x_longest = x - else: - m[x, y] = 0 - return s1[x_longest - longest: x_longest] - cpdef np.ndarray[np.float32_t, ndim=2] arr2decibel(np.ndarray[np.complex64_t, ndim=2] arr): cdef long long x = arr.shape[0] cdef long long y = arr.shape[1] @@ -83,11 +43,11 @@ cpdef np.ndarray[np.float32_t, ndim=2] arr2decibel(np.ndarray[np.complex64_t, nd result[i, j] = factor * log10(arr[i, j].real * arr[i, j].real + arr[i, j].imag * arr[i, j].imag) return result -cpdef unsigned long long arr_to_number(unsigned char[:] inpt, bool reverse, unsigned int start = 0): - cdef unsigned long long result = 0 +cpdef uint64_t arr_to_number(uint8_t[:] inpt, bool reverse = False, unsigned int start = 0): + cdef uint64_t result = 0 cdef unsigned int i, len_inpt = len(inpt) for i in range(start, len_inpt): - if reverse == False: + if not reverse: if inpt[len_inpt - 1 - i + start]: result |= (1 << (i-start)) else: @@ -95,16 +55,16 @@ cpdef unsigned long long arr_to_number(unsigned char[:] inpt, bool reverse, unsi result |= (1 << (i-start)) return result -cpdef unsigned long long crc(unsigned char[:] inpt, unsigned char[:] polynomial, unsigned char[:] start_value, unsigned char[:] final_xor, bool lsb_first, bool reverse_polynomial, bool reverse_all, bool little_endian): +cpdef uint64_t crc(uint8_t[:] inpt, uint8_t[:] polynomial, uint8_t[:] start_value, uint8_t[:] final_xor, bool lsb_first, bool reverse_polynomial, bool reverse_all, bool little_endian): cdef unsigned int len_inpt = len(inpt) cdef unsigned int i, idx, poly_order = len(polynomial) - cdef unsigned long long crc_mask = (2**(poly_order - 1) - 1) - cdef unsigned long long poly_mask = (crc_mask + 1) >> 1 - cdef unsigned long long poly_int = arr_to_number(polynomial, reverse_polynomial, 1) & crc_mask + cdef uint64_t crc_mask = pow(2, poly_order - 1) - 1 + cdef uint64_t poly_mask = (crc_mask + 1) >> 1 + cdef uint64_t poly_int = 
arr_to_number(polynomial, reverse_polynomial, 1) & crc_mask cdef unsigned short j, x # start value - cdef unsigned long long temp, crc = arr_to_number(start_value, False, 0) & crc_mask + cdef uint64_t temp, crc = arr_to_number(start_value, False, 0) & crc_mask for i in range(0, len_inpt+7, 8): for j in range(0, 8): @@ -131,54 +91,118 @@ cpdef unsigned long long crc(unsigned char[:] inpt, unsigned char[:] polynomial, temp = 0 for i in range(0, poly_order - 1): if crc & (1 << i): - temp |= (1 << (poly_order -2 -i)) + temp |= (1 << (poly_order - 2 - i)) crc = temp & crc_mask # little endian encoding, different for 16, 32, 64 bit if poly_order - 1 == 16 and little_endian: crc = ((crc << 8) & 0xFF00) | (crc >> 8) elif poly_order - 1 == 32 and little_endian: - crc = ((crc << 24) & 0xFF000000) | ((crc << 8) & 0x00FF0000) | ((crc >> 8) & 0x0000FF00) | (crc >> 24) + crc = ((crc << 24) & 0xFF000000) | ((crc << 8) & 0x00FF0000) | ((crc >> 8) & 0x0000FF00) | (crc >> 24) elif poly_order - 1 == 64 and little_endian: - crc = ((crc << 56) & 0xFF00000000000000) | (crc >> 56) \ - | ((crc >> 40) & 0x000000000000FF00) | ((crc << 40) & 0x00FF000000000000) \ - | ((crc << 24) & 0x0000FF0000000000) | ((crc >> 24) & 0x0000000000FF0000) \ - | ((crc << 8) & 0x000000FF00000000) | ((crc >> 8) & 0x00000000FF000000) + crc = ((crc << 56) & 0xFF00000000000000) | (crc >> 56) \ + | ((crc >> 40) & 0x000000000000FF00) | ((crc << 40) & 0x00FF000000000000) \ + | ((crc << 24) & 0x0000FF0000000000) | ((crc >> 24) & 0x0000000000FF0000) \ + | ((crc << 8) & 0x000000FF00000000) | ((crc >> 8) & 0x00000000FF000000) return crc & crc_mask -cpdef tuple get_crc_datarange(unsigned char[:] inpt, unsigned char[:] polynomial, unsigned char[:] vrfy_crc, unsigned char[:] start_value, unsigned char[:] final_xor, bool lsb_first, bool reverse_polynomial, bool reverse_all, bool little_endian): +cpdef np.ndarray[np.uint64_t, ndim=1] calculate_cache(uint8_t[:] polynomial, bool reverse_polynomial=False, uint8_t bits=8): + cdef uint8_t j, poly_order = len(polynomial) + cdef uint64_t crc_mask = pow(2, poly_order - 1) - 1 + cdef uint64_t poly_mask = (crc_mask + 1) >> 1 + cdef uint64_t poly_int = arr_to_number(polynomial, reverse_polynomial, 1) & crc_mask + cdef uint64_t crcv, i + cdef np.ndarray[np.uint64_t, ndim=1] cache = np.zeros( pow(2, bits), dtype = np.uint64) + # Caching + for i in range(0, len(cache)): + crcv = i << (poly_order - 1 - bits) + for _ in range(0, bits): + if (crcv & poly_mask) > 0: + crcv = (crcv << 1) & crc_mask + crcv ^= poly_int + else: + crcv = (crcv << 1) & crc_mask + cache[i] = crcv + return cache + +cpdef uint64_t cached_crc(uint64_t[:] cache, uint8_t bits, uint8_t[:] inpt, uint8_t[:] polynomial, uint8_t[:] start_value, uint8_t[:] final_xor, bool lsb_first, bool reverse_polynomial, bool reverse_all, bool little_endian): cdef unsigned int len_inpt = len(inpt) - cdef unsigned int i, idx, offset, data_end = 0, poly_order = len(polynomial) - cdef np.ndarray[np.uint64_t, ndim=1] steps = np.empty(len_inpt+2, dtype=np.uint64) - cdef unsigned long long temp - cdef unsigned long long crc_mask = (2**(poly_order - 1) - 1) - cdef unsigned long long poly_mask = (crc_mask + 1) >> 1 - cdef unsigned long long poly_int = arr_to_number(polynomial, reverse_polynomial, 1) & crc_mask - cdef unsigned long long final_xor_int = arr_to_number(final_xor, False, 0) & crc_mask - cdef unsigned long long vrfy_crc_int = arr_to_number(vrfy_crc, False, 0) & crc_mask - cdef unsigned long long crcvalue = arr_to_number(start_value, False, 0) & crc_mask - cdef 
unsigned short j = 0, len_crc = poly_order - 1 - cdef bool found + cdef unsigned int i, poly_order = len(polynomial) + cdef uint64_t crc_mask = pow(2, poly_order - 1) - 1 + cdef uint64_t poly_mask = (crc_mask + 1) >> 1 + cdef uint64_t poly_int = arr_to_number(polynomial, reverse_polynomial, 1) & crc_mask + cdef uint64_t temp, crcv, data, pos + cdef uint8_t j + + # For inputs smaller than 8 bits, call normal function + if len_inpt < bits: + return crc(inpt, polynomial, start_value, final_xor, lsb_first, reverse_polynomial, reverse_all, little_endian) + + # CRC + crcv = arr_to_number(start_value, False, 0) & crc_mask + for i in range(0, len_inpt - bits + 1, bits): + data = 0 + if lsb_first: + for j in range(0, bits): + if inpt[i + j]: + data |= (1 << j) + else: + for j in range(0, bits): + if inpt[i + bits - 1 - j]: + data |= (1 << j) + pos = (crcv >> (poly_order - bits - 1)) ^ data + crcv = ((crcv << bits) ^ cache[pos]) & crc_mask + + # Are we done? + if len_inpt % bits > 0: + # compute rest of crc inpt[-(len_inpt%8):] with normal function + # Set start_value to current crc value + for i in range(0, len(start_value)): + start_value[len(start_value) - 1 - i] = True if (crcv & (1 << i)) > 0 else False + crcv = crc(inpt[len_inpt-(len_inpt%bits):len_inpt], polynomial, start_value, final_xor, lsb_first, reverse_polynomial, reverse_all, little_endian) + else: + # final XOR + crcv ^= arr_to_number(final_xor, False, 0) & crc_mask - # Find data_end (beginning of crc) - if len_inpt <= len_crc or len_crc != len(vrfy_crc): - return 0, 0 - for data_end in range(len_inpt - len_crc, -1, -1): - i = 0 - for j in range(0, len_crc): - if vrfy_crc[j] == inpt[data_end+j]: - i += 1 - else: - continue - if i == len_crc: - break - if data_end <= 0: # Could not find crc position + # reverse all bits + if reverse_all: + temp = 0 + for i in range(0, poly_order - 1): + if crcv & (1 << i): + temp |= (1 << (poly_order - 2 - i)) + crcv = temp & crc_mask + + # little endian encoding, different for 16, 32, 64 bit + if poly_order - 1 == 16 and little_endian: + crcv = ((crcv << 8) & 0xFF00) | (crcv >> 8) + elif poly_order - 1 == 32 and little_endian: + crcv = ((crcv << 24) & 0xFF000000) | ((crcv << 8) & 0x00FF0000) | ((crcv >> 8) & 0x0000FF00) | (crcv >> 24) + elif poly_order - 1 == 64 and little_endian: + crcv = ((crcv << 56) & 0xFF00000000000000) | (crcv >> 56) \ + | ((crcv >> 40) & 0x000000000000FF00) | ((crcv << 40) & 0x00FF000000000000) \ + | ((crcv << 24) & 0x0000FF0000000000) | ((crcv >> 24) & 0x0000000000FF0000) \ + | ((crcv << 8) & 0x000000FF00000000) | ((crcv >> 8) & 0x00000000FF000000) + return crcv & crc_mask + +cpdef tuple get_crc_datarange(uint8_t[:] inpt, uint8_t[:] polynomial, uint64_t vrfy_crc_start, uint8_t[:] start_value, uint8_t[:] final_xor, bool lsb_first, bool reverse_polynomial, bool reverse_all, bool little_endian): + cdef uint32_t len_inpt = len(inpt), poly_order = len(polynomial) + cdef uint8_t j = 0, len_crc = poly_order - 1 + + if vrfy_crc_start-1+len_crc >= len_inpt or vrfy_crc_start < 2: return 0, 0 - # leads to https://github.com/jopohl/urh/issues/463 - #step = [1] + [0] * (len_inpt - 1) - step = [0] * len_inpt + cdef uint64_t* steps = calloc(len_inpt+2, sizeof(uint64_t)) + cdef uint64_t temp + cdef uint64_t crc_mask = pow(2, poly_order - 1) - 1 + cdef uint64_t poly_mask = (crc_mask + 1) >> 1 + cdef uint64_t poly_int = arr_to_number(polynomial, reverse_polynomial, 1) & crc_mask + cdef uint64_t final_xor_int = arr_to_number(final_xor, False, 0) & crc_mask + cdef uint64_t vrfy_crc_int = 
arr_to_number(inpt[vrfy_crc_start:vrfy_crc_start+len_crc], False, 0) & crc_mask + cdef uint64_t crcvalue = arr_to_number(start_value, False, 0) & crc_mask + cdef bool found + cdef uint32_t i, idx, offset, data_end = vrfy_crc_start + cdef uint8_t* step = calloc(len_inpt, sizeof(uint8_t)) step[0] = 1 # crcvalue is initialized with start_value @@ -201,45 +225,52 @@ cpdef tuple get_crc_datarange(unsigned char[:] inpt, unsigned char[:] polynomial # Save steps XORed with final_xor steps[idx] = crcvalue ^ final_xor_int - # Reverse and little endian - for i in range(0, data_end): - # reverse all bits - if reverse_all: - temp = 0 - for j in range(0, poly_order - 1): - if steps[i] & (1 << j): - temp |= (1 << (poly_order -2 - j)) - steps[j] = temp & crc_mask + free(step) - # little endian encoding, different for 16, 32, 64 bit - if poly_order - 1 == 16 and little_endian: - steps[i] = ((steps[i] << 8) & 0xFF00) | (steps[i] >> 8) - elif poly_order - 1 == 32 and little_endian: - steps[i] = ((steps[i] << 24) & 0xFF000000) | ((steps[i] << 8) & 0x00FF0000) | ((steps[i] >> 8) & 0x0000FF00) | (steps[i] >> 24) - elif poly_order - 1 == 64 and little_endian: - steps[i] = ((steps[i] << 56) & 0xFF00000000000000) | (steps[i] >> 56) \ - | ((steps[i] >> 40) & 0x000000000000FF00) | ((steps[i] << 40) & 0x00FF000000000000) \ - | ((steps[i] << 24) & 0x0000FF0000000000) | ((steps[i] >> 24) & 0x0000000000FF0000) \ - | ((steps[i] << 8) & 0x000000FF00000000) | ((steps[i] >> 8) & 0x00000000FF000000) + # Reverse and little endian + if reverse_all or little_endian: + for i in range(0, data_end): + # reverse all bits + if reverse_all: + temp = 0 + for j in range(0, poly_order - 1): + if steps[i] & (1 << j): + temp |= (1 << (poly_order -2 - j)) + steps[j] = temp & crc_mask + + # little endian encoding, different for 16, 32, 64 bit + if poly_order - 1 == 16 and little_endian: + steps[i] = ((steps[i] << 8) & 0xFF00) | (steps[i] >> 8) + elif poly_order - 1 == 32 and little_endian: + steps[i] = ((steps[i] << 24) & 0xFF000000) | ((steps[i] << 8) & 0x00FF0000) | ((steps[i] >> 8) & 0x0000FF00) | (steps[i] >> 24) + elif poly_order - 1 == 64 and little_endian: + steps[i] = ((steps[i] << 56) & 0xFF00000000000000) | (steps[i] >> 56) \ + | ((steps[i] >> 40) & 0x000000000000FF00) | ((steps[i] << 40) & 0x00FF000000000000) \ + | ((steps[i] << 24) & 0x0000FF0000000000) | ((steps[i] >> 24) & 0x0000000000FF0000) \ + | ((steps[i] << 8) & 0x000000FF00000000) | ((steps[i] >> 8) & 0x00000000FF000000) # Test data range from 0...start_crc until start_crc-1...start_crc # Compute start value crcvalue = crc(inpt[:data_end], polynomial, start_value, final_xor, lsb_first, reverse_polynomial, reverse_all, little_endian) - if vrfy_crc_int == crcvalue: - return 0, data_end - found = False - i = 0 - while i < data_end - 1: - offset = 0 - while (inpt[i + offset] == False and i+offset < data_end - 1): # skip leading 0s in data (doesn't change crc...) - offset += 1 - # XOR delta=crc(10000...) to last crc value to create next crc value - crcvalue ^= steps[data_end-i-offset-1] - if found: - return i, data_end # Return start_data, end_data + try: if vrfy_crc_int == crcvalue: - found = True - i += 1 + offset + return 0, data_end + found = False - # No beginning found - return 0, 0 \ No newline at end of file + i = 0 + while i < data_end - 1: + offset = 0 + while inpt[i + offset] == False and i+offset < data_end - 1: # skip leading 0s in data (doesn't change crc...) + offset += 1 + # XOR delta=crc(10000...) 
to last crc value to create next crc value + crcvalue ^= steps[data_end-i-offset-1] + if found: + return i, data_end # Return start_data, end_data + if vrfy_crc_int == crcvalue: + found = True + i += 1 + offset + + # No beginning found + return 0, 0 + finally: + free(steps) \ No newline at end of file diff --git a/src/urh/models/ParticipantTableModel.py b/src/urh/models/ParticipantTableModel.py index d403e14a4e..58919b95d5 100644 --- a/src/urh/models/ParticipantTableModel.py +++ b/src/urh/models/ParticipantTableModel.py @@ -37,7 +37,7 @@ def headerData(self, section, orientation, role=Qt.DisplayRole): return super().headerData(section, orientation, role) def data(self, index: QModelIndex, role=Qt.DisplayRole): - if role == Qt.DisplayRole: + if role == Qt.DisplayRole or role == Qt.EditRole: i = index.row() j = index.column() part = self.participants[i] diff --git a/src/urh/signalprocessing/FieldType.py b/src/urh/signalprocessing/FieldType.py index 41325837e7..dfce28b76c 100644 --- a/src/urh/signalprocessing/FieldType.py +++ b/src/urh/signalprocessing/FieldType.py @@ -16,10 +16,12 @@ class Function(Enum): SRC_ADDRESS = "source address" DST_ADDRESS = "destination address" SEQUENCE_NUMBER = "sequence number" + TYPE = "type" + DATA = "data" CHECKSUM = "checksum" CUSTOM = "custom" - def __init__(self, caption: str, function: Function, display_format_index:int = None): + def __init__(self, caption: str, function: Function, display_format_index: int = None): self.caption = caption self.function = function @@ -41,6 +43,14 @@ def __eq__(self, other): def __repr__(self): return "FieldType: {0} - {1} ({2})".format(self.function.name, self.caption, self.display_format_index) + @staticmethod + def from_caption(caption: str): + try: + ft_function = FieldType.Function(caption) + except ValueError: + return None + return FieldType(caption, ft_function) + @staticmethod def default_field_types(): """ diff --git a/src/urh/signalprocessing/Message.py b/src/urh/signalprocessing/Message.py index e22f36e00d..bf84d6f35e 100644 --- a/src/urh/signalprocessing/Message.py +++ b/src/urh/signalprocessing/Message.py @@ -466,6 +466,12 @@ def from_plain_bits_str(bits, pause=0): plain_bits = list(map(int, bits)) return Message(plain_bits=plain_bits, pause=pause, message_type=MessageType("none")) + @staticmethod + def from_plain_hex_str(hex_str, pause=0): + lut = {"{0:x}".format(i): "{0:04b}".format(i) for i in range(16)} + bits = "".join((lut[h] for h in hex_str)) + return Message.from_plain_bits_str(bits, pause) + def to_xml(self, decoders=None, include_message_type=False, write_bits=False) -> ET.Element: root = ET.Element("message") root.set("message_type_id", self.message_type.id) diff --git a/src/urh/signalprocessing/MessageType.py b/src/urh/signalprocessing/MessageType.py index 78c0bbe9b3..f7a93b47c8 100644 --- a/src/urh/signalprocessing/MessageType.py +++ b/src/urh/signalprocessing/MessageType.py @@ -67,6 +67,26 @@ def unlabeled_ranges(self): """ return self.__get_unlabeled_ranges_from_labels(self) + def __create_label(self, name: str, start: int, end: int, color_index: int, auto_created: bool, + field_type: FieldType): + if field_type is not None: + if field_type.function == FieldType.Function.CHECKSUM: + # If we have sync or preamble labels start behind last one: + pre_sync_label_ends = [lbl.end for lbl in self if lbl.is_preamble or lbl.is_sync] + if len(pre_sync_label_ends) > 0: + range_start = max(pre_sync_label_ends) + else: + range_start = 0 + + if range_start >= start: + range_start = 0 + + return 
ChecksumLabel(name=name, start=start, end=end, color_index=color_index, field_type=field_type, + auto_created=auto_created, data_range_start=range_start) + + return ProtocolLabel(name=name, start=start, end=end, color_index=color_index, field_type=field_type, + auto_created=auto_created) + @staticmethod def __get_unlabeled_ranges_from_labels(labels): """ @@ -93,10 +113,19 @@ def unlabeled_ranges_with_other_mt(self, other_message_type): labels.sort() return self.__get_unlabeled_ranges_from_labels(labels) + def get_first_label_with_type(self, field_type: FieldType.Function) -> ProtocolLabel: + return next((lbl for lbl in self if lbl.field_type and lbl.field_type.function == field_type), None) + + def num_labels_with_type(self, field_type: FieldType.Function) -> int: + return len([lbl for lbl in self if lbl.field_type and lbl.field_type.function == field_type]) + def append(self, lbl: ProtocolLabel): super().append(lbl) self.sort() + def give_new_id(self): + self.__id = str(uuid.uuid4()) + def add_protocol_label(self, start: int, end: int, name=None, color_ind=None, auto_created=False, type: FieldType = None) -> ProtocolLabel: @@ -119,6 +148,10 @@ def add_protocol_label(self, start: int, end: int, name=None, color_ind=None, return proto_label # Return label to set editor focus after adding + def add_protocol_label_start_length(self, start: int, length: int, name=None, color_ind=None, + auto_created=False, type: FieldType = None) -> ProtocolLabel: + return self.add_protocol_label(start, start + length - 1, name, color_ind, auto_created, type) + def add_label(self, lbl: ProtocolLabel, allow_overlapping=True): if allow_overlapping or not any(lbl.overlaps_with(l) for l in self): added = self.add_protocol_label(lbl.start, lbl.end - 1, name=lbl.name, color_ind=lbl.color_index) @@ -131,20 +164,6 @@ def remove(self, lbl: ProtocolLabel): else: logger.warning(lbl.name + " is not in set, so cant be removed") - def to_xml(self) -> ET.Element: - result = ET.Element("message_type", attrib={"name": self.name, "id": self.id, - "assigned_by_ruleset": "1" if self.assigned_by_ruleset else "0", - "assigned_by_logic_analyzer": "1" if self.assigned_by_logic_analyzer else "0"}) - for lbl in self: - try: - result.append(lbl.to_xml()) - except TypeError: - logger.error("Could not save label: " + str(lbl)) - - result.append(self.ruleset.to_xml()) - - return result - def change_field_type_of_label(self, label: ProtocolLabel, field_type: FieldType): if not isinstance(label, ProtocolLabel) and hasattr(label, "field_type"): # In case of SimulatorProtocolLabel @@ -158,25 +177,19 @@ def change_field_type_of_label(self, label: ProtocolLabel, field_type: FieldType else: label.field_type = field_type - def __create_label(self, name: str, start: int, end: int, color_index: int, auto_created: bool, - field_type: FieldType): - if field_type is not None: - if field_type.function == FieldType.Function.CHECKSUM: - # If we have sync or preamble labels start behind last one: - pre_sync_label_ends = [lbl.end for lbl in self if lbl.is_preamble or lbl.is_sync] - if len(pre_sync_label_ends) > 0: - range_start = max(pre_sync_label_ends) - else: - range_start = 0 - - if range_start >= start: - range_start = 0 + def to_xml(self) -> ET.Element: + result = ET.Element("message_type", attrib={"name": self.name, "id": self.id, + "assigned_by_ruleset": "1" if self.assigned_by_ruleset else "0", + "assigned_by_logic_analyzer": "1" if self.assigned_by_logic_analyzer else "0"}) + for lbl in self: + try: + result.append(lbl.to_xml()) + except 
TypeError: + logger.error("Could not save label: " + str(lbl)) - return ChecksumLabel(name=name, start=start, end=end, color_index=color_index, field_type=field_type, - auto_created=auto_created, data_range_start=range_start) + result.append(self.ruleset.to_xml()) - return ProtocolLabel(name=name, start=start, end=end, color_index=color_index, field_type=field_type, - auto_created=auto_created) + return result @staticmethod def from_xml(tag: ET.Element): diff --git a/src/urh/signalprocessing/ProtocoLabel.py b/src/urh/signalprocessing/ProtocoLabel.py index a4520f30c7..ab42e42a33 100644 --- a/src/urh/signalprocessing/ProtocoLabel.py +++ b/src/urh/signalprocessing/ProtocoLabel.py @@ -1,6 +1,6 @@ +import copy import xml.etree.ElementTree as ET -import copy from PyQt5.QtCore import Qt from urh.signalprocessing.FieldType import FieldType @@ -40,7 +40,10 @@ def __init__(self, name: str, start: int, end: int, color_index: int, fuzz_creat self.fuzz_created = fuzz_created - self.__field_type = field_type # type: FieldType + if field_type is None: + self.__field_type = FieldType.from_caption(name) + else: + self.__field_type = field_type # type: FieldType self.display_format_index = 0 if field_type is None else field_type.display_format_index self.display_bit_order_index = 0 @@ -68,6 +71,10 @@ def is_preamble(self) -> bool: def is_sync(self) -> bool: return self.field_type is not None and self.field_type.function == FieldType.Function.SYNC + @property + def length(self) -> int: + return self.end - self.start + @property def field_type(self) -> FieldType: return self.__field_type @@ -80,6 +87,13 @@ def field_type(self, value: FieldType): if hasattr(value, "display_format_index"): self.display_format_index = value.display_format_index + @property + def field_type_function(self): + if self.field_type is not None: + return self.field_type.function + else: + return None + @property def name(self): if not self.__name: @@ -149,10 +163,13 @@ def __lt__(self, other): return False def __eq__(self, other): - return self.start == other.start and self.end == other.end and self.name == other.name and self.field_type == other.field_type + return self.start == other.start and \ + self.end == other.end and \ + self.name == other.name and \ + self.field_type_function == other.field_type_function def __hash__(self): - return hash("{}/{}/{}".format(self.start, self.end, self.name)) + return hash((self.start, self.end, self.name, self.field_type_function)) def __repr__(self): return "Protocol Label - start: {0} end: {1} name: {2}".format(self.start, self.end, self.name) diff --git a/src/urh/signalprocessing/ProtocolAnalyzer.py b/src/urh/signalprocessing/ProtocolAnalyzer.py index 2a2010b272..17a2a693e6 100644 --- a/src/urh/signalprocessing/ProtocolAnalyzer.py +++ b/src/urh/signalprocessing/ProtocolAnalyzer.py @@ -1,6 +1,5 @@ import array import copy -import sys import xml.etree.ElementTree as ET from xml.dom import minidom @@ -8,14 +7,14 @@ from PyQt5.QtCore import QObject, pyqtSignal, Qt from urh import constants -from urh.awre.FormatFinder import FormatFinder -from urh.cythonext import signal_functions, util +from urh.cythonext import signal_functions from urh.signalprocessing.Encoding import Encoding from urh.signalprocessing.FieldType import FieldType from urh.signalprocessing.Message import Message from urh.signalprocessing.MessageType import MessageType from urh.signalprocessing.Modulator import Modulator from urh.signalprocessing.Participant import Participant +from urh.signalprocessing.ProtocoLabel import 
ProtocolLabel from urh.signalprocessing.Signal import Signal from urh.util import util as urh_util from urh.util.Logger import logger @@ -120,6 +119,10 @@ def plain_hex_str(self): def plain_ascii_str(self): return [msg.plain_ascii_str for msg in self.messages] + @property + def decoded_bits(self): + return [msg.decoded_bits for msg in self.messages] + @property def decoded_proto_bits_str(self): """ @@ -223,7 +226,7 @@ def get_protocol_from_signal(self): bit_len = signal.bit_len ppseq = signal_functions.grab_pulse_lens(signal.qad, signal.qad_center, signal.tolerance, - signal.modulation_type, signal.bit_len) + signal.modulation_type, signal.bit_len) bit_data, pauses, bit_sample_pos = self._ppseq_to_bits(ppseq, bit_len, pause_threshold=signal.pause_threshold) if signal.message_length_divisor > 1 and signal.modulation_type_str == "ASK": @@ -658,89 +661,15 @@ def update_auto_message_types(self): message.message_type = message_type break - def auto_assign_participants(self, participants): - """ - - :type participants: list of Participant - :return: - """ - if len(participants) == 0: - return - - if len(participants) == 1: - for message in self.messages: - message.participant = participants[0] - return - - # Try to assign participants based on SRC_ADDRESS label and participant address - for msg in filter(lambda m: m.participant is None, self.messages): - src_address = msg.get_src_address_from_data() - if src_address: - try: - msg.participant = next(p for p in participants if p.address_hex == src_address) - except StopIteration: - pass - - # Assign remaining participants based on RSSI of messages - rssis = np.array([msg.rssi for msg in self.messages], dtype=np.float32) - min_rssi, max_rssi = util.minmax(rssis) - center_spacing = (max_rssi - min_rssi) / (len(participants) - 1) - centers = [min_rssi + i * center_spacing for i in range(0, len(participants))] - rssi_assigned_centers = [] - - for rssi in rssis: - center_index = np.argmin(np.abs(rssi - centers)) - rssi_assigned_centers.append(int(center_index)) - - participants.sort(key=lambda participant: participant.relative_rssi) - for message, center_index in zip(self.messages, rssi_assigned_centers): - if message.participant is None: - message.participant = participants[center_index] - - def auto_assign_participant_addresses(self, participants): - """ - - :type participants: list of Participant - :return: - """ - participants_without_address = [p for p in participants if not p.address_hex] - - if len(participants_without_address) == 0: - return - - for msg in self.messages: - if msg.participant in participants_without_address: - src_address = msg.get_src_address_from_data() - if src_address: - participants_without_address.remove(msg.participant) - msg.participant.address_hex = src_address - - def auto_assign_decodings(self, decodings): - """ - :type decodings: list of Encoding - """ - nrz_decodings = [decoding for decoding in decodings if decoding.is_nrz or decoding.is_nrzi] - fallback = nrz_decodings[0] if nrz_decodings else None - candidate_decodings = [decoding for decoding in decodings - if decoding not in nrz_decodings and not decoding.contains_cut] - - for message in self.messages: - decoder_found = False - - for decoder in candidate_decodings: - if decoder.applies_for_message(message.plain_bits): - message.decoder = decoder - decoder_found = True - break - - if not decoder_found and fallback: - message.decoder = fallback - def auto_assign_labels(self): - format_finder = FormatFinder(self) - - # OPEN: Perform multiple iterations with varying 
priorities later - format_finder.perform_iteration() + from urh.awre.FormatFinder import FormatFinder + format_finder = FormatFinder(self.messages) + format_finder.run(max_iterations=10) + + self.message_types[:] = format_finder.message_types + for msg_type, indices in format_finder.existing_message_types.items(): + for i in indices: + self.messages[i].message_type = msg_type @staticmethod def get_protocol_from_string(message_strings: list, is_hex=None, default_pause=0, sample_rate=1e6): diff --git a/src/urh/util/GenericCRC.py b/src/urh/util/GenericCRC.py index cc0aa5b98f..b39621b7a4 100755 --- a/src/urh/util/GenericCRC.py +++ b/src/urh/util/GenericCRC.py @@ -3,8 +3,8 @@ from collections import OrderedDict from xml.etree import ElementTree as ET -from urh.util import util from urh.cythonext import util as c_util +from urh.util import util class GenericCRC(object): @@ -25,10 +25,38 @@ class GenericCRC(object): # x^16+x^13+x^12+x^11+x^10+x^8+x^6+x^5+x^2+x^0 ("16_dnp", array.array("B", [1, 0, 0, 1, 1, 1, 1, 0, 1, 0, 1, 1, 0, 0, 1, 0, 1])), + + # x^8 + x^2 + x + 1 + ("8_ccitt", array.array("B", [1, + 0, 0, 0, 0, 0, 1, 1, 1])) + ]) + + STANDARD_CHECKSUMS = OrderedDict([ + # see method guess_standard_parameters_and_datarange for default parameters + # Links: + # - https://en.wikipedia.org/wiki/Cyclic_redundancy_check + # - http://reveng.sourceforge.net/crc-catalogue/1-15.htm + # - https://crccalc.com/ + ("CRC8 (default)", dict(polynomial="0xD5")), + ("CRC8 CCITT", dict(polynomial="0x07")), + ("CRC8 Bluetooth", dict(polynomial="0xA7", ref_in=True, ref_out=True)), + ("CRC8 DARC", dict(polynomial="0x39", ref_in=True, ref_out=True)), + ("CRC8 NRSC-5", dict(polynomial="0x31", start_value=1)), + ("CRC16 (default)", dict(polynomial="0x8005", ref_in=True, ref_out=True)), + ("CRC16 CCITT", dict(polynomial="0x1021", ref_in=True, ref_out=True)), + ("CRC16 NRSC-5", dict(polynomial="0x080B", start_value=1, ref_in=True, ref_out=True)), + ("CRC16 CC1101", dict(polynomial="0x8005", start_value=1)), + ("CRC16 CDMA2000", dict(polynomial="0xC867", start_value=1)), + ("CRC32 (default)", dict(polynomial="0x04C11DB7", start_value=1, final_xor=1, ref_in=True, ref_out=True)), ]) def __init__(self, polynomial="16_standard", start_value=False, final_xor=False, reverse_polynomial=False, reverse_all=False, little_endian=False, lsb_first=False): + if isinstance(polynomial, str): + self.caption = polynomial + else: + self.caption = "" + self.polynomial = self.choose_polynomial(polynomial) self.reverse_polynomial = reverse_polynomial self.reverse_all = reverse_all @@ -37,6 +65,8 @@ def __init__(self, polynomial="16_standard", start_value=False, final_xor=False, self.start_value = self.__read_parameter(start_value) self.final_xor = self.__read_parameter(final_xor) + self.cache = [] + self.__cache_bits = 8 def __read_parameter(self, value): if isinstance(value, bool) or isinstance(value, int): @@ -52,7 +82,12 @@ def __eq__(self, other): return False return all(getattr(self, attrib) == getattr(other, attrib) for attrib in ( - "polynomial", "reverse_polynomial", "reverse_all", "little_endian", "lsb_first", "start_value", "final_xor")) + "polynomial", "reverse_polynomial", "reverse_all", "little_endian", "lsb_first", "start_value", + "final_xor")) + + def __hash__(self): + return hash((self.polynomial.tobytes(), self.reverse_polynomial, self.reverse_all, self.little_endian, + self.lsb_first, self.start_value.tobytes(), self.final_xor.tobytes())) @property def poly_order(self): @@ -83,7 +118,11 @@ def polynomial_to_html(self) -> 
str: return result def set_polynomial_from_hex(self, hex_str: str): + old = self.polynomial self.polynomial = array.array("B", [1]) + util.hex2bit(hex_str) + if self.polynomial != old: + self.cache = [] + self.__cache_bits = 8 def choose_polynomial(self, polynomial): if isinstance(polynomial, str): @@ -93,6 +132,10 @@ def choose_polynomial(self, polynomial): else: return polynomial + def get_parameters(self): + return self.polynomial, self.start_value, self.final_xor, \ + self.lsb_first, self.reverse_polynomial, self.reverse_all, self.little_endian + def crc(self, inpt): result = c_util.crc(array.array("B", inpt), array.array("B", self.polynomial), @@ -101,21 +144,41 @@ def crc(self, inpt): self.lsb_first, self.reverse_polynomial, self.reverse_all, self.little_endian) return util.number_to_bits(result, self.poly_order - 1) - def get_crc_datarange(self, inpt, vrfy_crc): + def cached_crc(self, inpt, bits=8): + if len(self.cache) == 0: + self.calculate_cache(bits) + result = c_util.cached_crc(self.cache, + self.__cache_bits, + array.array("B", inpt), + array.array("B", self.polynomial), + array.array("B", self.start_value), + array.array("B", self.final_xor), + self.lsb_first, self.reverse_polynomial, self.reverse_all, self.little_endian) + return util.number_to_bits(result, self.poly_order - 1) + + def calculate_cache(self, bits=8): + if 0 < bits < self.poly_order: + self.__cache_bits = bits + else: + self.__cache_bits = 8 if self.poly_order > 8 else self.poly_order - 1 + self.cache = c_util.calculate_cache(array.array("B", self.polynomial), self.reverse_polynomial, + self.__cache_bits) + + def get_crc_datarange(self, inpt, vrfy_crc_start): return c_util.get_crc_datarange(array.array("B", inpt), - array.array("B", self.polynomial), - array.array("B", vrfy_crc), - array.array("B", self.start_value), - array.array("B", self.final_xor), - self.lsb_first, self.reverse_polynomial, self.reverse_all, self.little_endian) + array.array("B", self.polynomial), + vrfy_crc_start, + array.array("B", self.start_value), + array.array("B", self.final_xor), + self.lsb_first, self.reverse_polynomial, self.reverse_all, self.little_endian) def reference_crc(self, inpt): len_inpt = len(inpt) if len(self.start_value) < self.poly_order - 1: return False - crc = copy.copy(self.start_value[0:(self.poly_order-1)]) + crc = copy.copy(self.start_value[0:(self.poly_order - 1)]) - for i in range(0, len_inpt+7, 8): + for i in range(0, len_inpt + 7, 8): for j in range(0, 8): if self.lsb_first: @@ -156,7 +219,7 @@ def reference_crc(self, inpt): elif self.poly_order - 1 == 64 and self.little_endian: for pos1, pos2 in [(0, 7), (1, 6), (2, 5), (3, 4)]: self.__swap_bytes(crc, pos1, pos2) - #return crc + # return crc return array.array("B", crc) def calculate(self, bits: array.array): @@ -167,11 +230,62 @@ def __swap_bytes(array, pos1: int, pos2: int): array[pos1 * 8:pos1 * 8 + 8], array[pos2 * 8:pos2 * 8 + 8] = \ array[pos2 * 8: pos2 * 8 + 8], array[pos1 * 8:pos1 * 8 + 8] + @staticmethod + def from_standard_checksum(name: str): + result = GenericCRC() + result.set_individual_parameters(**GenericCRC.STANDARD_CHECKSUMS[name]) + result.caption = name + return result + + def set_individual_parameters(self, polynomial, start_value=0, final_xor=0, ref_in=False, ref_out=False, + little_endian=False, reverse_polynomial=False): + # Set polynomial from hex or bit array + old = self.polynomial + if isinstance(polynomial, str): + self.set_polynomial_from_hex(polynomial) + else: + self.polynomial = polynomial + # Clear cache if polynomial 
changes + if self.polynomial != old: + self.cache = [] + self.__cache_bits = 8 + + # Set start value completely or 0000/FFFF + if isinstance(start_value, int): + self.start_value = array.array("B", [start_value] * (self.poly_order - 1)) + elif isinstance(start_value, array.array) and len(start_value) == self.poly_order - 1: + self.start_value = start_value + else: + raise ValueError("Invalid start value length") + + # Set final xor completely or 0000/FFFF + if isinstance(final_xor, int): + self.final_xor = array.array("B", [final_xor] * (self.poly_order - 1)) + elif isinstance(final_xor, array.array) and len(final_xor) == self.poly_order - 1: + self.final_xor = final_xor + else: + raise ValueError("Invalid final xor length") + + # Set boolean parameters + old_reverse = self.reverse_polynomial + self.reverse_polynomial = reverse_polynomial + if self.reverse_polynomial != old_reverse: + self.cache = [] + self.__cache_bits = 8 + + self.reverse_all = ref_out + self.little_endian = little_endian + self.lsb_first = ref_in + def set_crc_parameters(self, i): # Bit 0,1 = Polynomial val = (i >> 0) & 3 + old = self.polynomial self.polynomial = self.choose_polynomial(val) poly_order = len(self.polynomial) + if (self.polynomial != old): + self.cache = [] + self.__cache_bits = 8 # Bit 2 = Start Value val = (i >> 2) & 1 @@ -183,10 +297,14 @@ def set_crc_parameters(self, i): # Bit 4 = Reverse Polynomial val = (i >> 4) & 1 + old_reverse = self.reverse_polynomial if val == 0: self.reverse_polynomial = False else: self.reverse_polynomial = True + if (self.reverse_polynomial != old_reverse): + self.cache = [] + self.__cache_bits = 8 # Bit 5 = Reverse (all) Result val = (i >> 5) & 1 @@ -209,22 +327,104 @@ def set_crc_parameters(self, i): else: self.lsb_first = True + @classmethod + def __initialize_standard_checksums(cls): + for name in cls.STANDARD_CHECKSUMS: + polynomial = cls.STANDARD_CHECKSUMS[name]["polynomial"] + if isinstance(polynomial, str): + polynomial = array.array("B", [1]) + util.hex2bit(polynomial) + cls.STANDARD_CHECKSUMS[name]["polynomial"] = polynomial + + n = len(polynomial) - 1 + try: + start_val = cls.STANDARD_CHECKSUMS[name]["start_value"] + except KeyError: + start_val = 0 + + if isinstance(start_val, int): + cls.STANDARD_CHECKSUMS[name]["start_value"] = array.array("B", [start_val] * n) + + try: + final_xor = cls.STANDARD_CHECKSUMS[name]["final_xor"] + except KeyError: + final_xor = 0 + + if isinstance(final_xor, int): + cls.STANDARD_CHECKSUMS[name]["final_xor"] = array.array("B", [final_xor] * n) + + def guess_all(self, bits, trash_max=7, ignore_positions: set = None): + """ + + :param bits: + :param trash_max: + :param ignore_positions: columns to ignore (e.g. 
if already another label on them) + :return: a CRC object, data_range_start, data_range_end, crc_start, crc_end + """ + self.__initialize_standard_checksums() + + ignore_positions = set() if ignore_positions is None else ignore_positions + for i in range(0, trash_max): + ret = self.guess_standard_parameters_and_datarange(bits, i) + if ret == (0, 0, 0): + continue # nothing found + + crc_start, crc_end = len(bits) - i - ret[0].poly_order + 1, len(bits) - i + if not any(i in ignore_positions for i in range(crc_start, crc_end)): + return ret[0], ret[1], ret[2], crc_start, crc_end + return 0, 0, 0, 0, 0 + + def bruteforce_all(self, inpt, trash_max=7): + polynomial_sizes = [16, 8] + len_input = len(inpt) + for s in polynomial_sizes: + for i in range(len_input - s - trash_max, len_input - s): + ret = self.bruteforce_parameters_and_data_range(inpt, i) + if ret != (0, 0, 0): + return ret[0], ret[1], ret[2], i, i + s + return 0, 0, 0, 0, 0 + def guess_standard_parameters(self, inpt, vrfy_crc): # Tests all standard parameters and return parameter_value (else False), if a valid CRC could be computed. # Note: vfry_crc is included inpt! for i in range(0, 2 ** 8): self.set_crc_parameters(i) - if self.crc(inpt) == vrfy_crc: + if len(vrfy_crc) == self.poly_order and self.crc(inpt) == vrfy_crc: return i return False - def guess_standard_parameters_and_datarange(self, inpt, vrfy_crc): + def guess_standard_parameters_and_datarange(self, inpt, trash): + """ + Tests standard parameters from dict and return polynomial object, if a valid CRC could be computed + and determines start and end of crc datarange (end is set before crc) + Note: vfry_crc is included inpt! + """ + # Test longer polynomials first, because smaller polynomials have higher risk of false positive + for name, parameters in sorted(self.STANDARD_CHECKSUMS.items(), + key=lambda x: len(x[1]["polynomial"]), + reverse=True): + self.caption = name + data_begin, data_end = c_util.get_crc_datarange(inpt, + parameters["polynomial"], + max(0, + len(inpt) - trash - len(parameters["polynomial"])) + 1, + parameters["start_value"], + parameters["final_xor"], + parameters.get("ref_in", False), + parameters.get("reverse_polynomial", False), + parameters.get("ref_out", False), + parameters.get("little_endian", False)) + if (data_begin, data_end) != (0, 0): + self.set_individual_parameters(**parameters) + return self, data_begin, data_end + return 0, 0, 0 + + def bruteforce_parameters_and_data_range(self, inpt, vrfy_crc_start): # Tests all standard parameters and return parameter_value (else False), if a valid CRC could be computed # and determines start and end of crc datarange (end is set before crc) # Note: vfry_crc is included inpt! 
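The cache that calculate_cache() builds and cached_crc() consumes is the classic table-driven CRC optimization: precompute the CRC contribution of every possible input chunk once, then walk the message chunk by chunk instead of bit by bit. A minimal pure-Python sketch of the same idea for CRC-16/CCITT-FALSE (polynomial 0x1021, start value 0xFFFF) follows; the Cython routines in c_util remain the authoritative implementation, and whether their "CRC16 CCITT" entry uses exactly these parameters is an assumption here.

def build_crc16_table(poly=0x1021):
    # Precompute the CRC of every possible byte value (the "cache").
    table = []
    for byte in range(256):
        crc = byte << 8
        for _ in range(8):
            crc = ((crc << 1) ^ poly) & 0xFFFF if crc & 0x8000 else (crc << 1) & 0xFFFF
        table.append(crc)
    return table

CRC16_TABLE = build_crc16_table()

def crc16_ccitt_false(data: bytes) -> int:
    crc = 0xFFFF  # start value
    for byte in data:
        # Consume one whole byte per step using the precomputed table.
        crc = ((crc << 8) & 0xFFFF) ^ CRC16_TABLE[(crc >> 8) ^ byte]
    return crc

assert crc16_ccitt_false(b"123456789") == 0x29B1  # well-known check value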
for i in range(0, 2 ** 8): self.set_crc_parameters(i) - data_begin, data_end = self.get_crc_datarange(inpt, vrfy_crc) + data_begin, data_end = self.get_crc_datarange(inpt, vrfy_crc_start) if (data_begin, data_end) != (0, 0): return i, data_begin, data_end return 0, 0, 0 @@ -273,6 +473,8 @@ def to_xml(self): root.set("polynomial", util.convert_bits_to_string(self.polynomial, 0)) root.set("start_value", util.convert_bits_to_string(self.start_value, 0)) root.set("final_xor", util.convert_bits_to_string(self.final_xor, 0)) + root.set("ref_in", str(int(self.lsb_first))) + root.set("ref_out", str(int(self.reverse_all))) return root @classmethod @@ -280,8 +482,11 @@ def from_xml(cls, tag: ET.Element): polynomial = tag.get("polynomial", "1010") start_value = tag.get("start_value", "0000") final_xor = tag.get("final_xor", "0000") + ref_in = bool(int(tag.get("ref_in", "0"))) + ref_out = bool(int(tag.get("ref_out", "0"))) return GenericCRC(polynomial=util.string2bits(polynomial), - start_value=util.string2bits(start_value), final_xor=util.string2bits(final_xor)) + start_value=util.string2bits(start_value), final_xor=util.string2bits(final_xor), + lsb_first=ref_in, reverse_all=ref_out) @staticmethod def bit2str(inpt): @@ -291,6 +496,10 @@ def bit2str(inpt): def str2bit(inpt): return [True if x == "1" else False for x in inpt] + @staticmethod + def int2bit(inpt): + return [True if x == "1" else False for x in '{0:08b}'.format(inpt)] + @staticmethod def str2arr(inpt): return array.array("B", GenericCRC.str2bit(inpt)) diff --git a/src/urh/util/WSPChecksum.py b/src/urh/util/WSPChecksum.py index 52b68ae35b..700e53e95e 100644 --- a/src/urh/util/WSPChecksum.py +++ b/src/urh/util/WSPChecksum.py @@ -1,10 +1,11 @@ import array import copy +from enum import Enum +from xml.etree import ElementTree as ET from urh.util import util from urh.util.GenericCRC import GenericCRC -from enum import Enum -from xml.etree import ElementTree as ET + class WSPChecksum(object): """ @@ -23,6 +24,16 @@ class ChecksumMode(Enum): def __init__(self, mode=ChecksumMode.auto): self.mode = mode + self.caption = str(mode) + + def __eq__(self, other): + if not isinstance(other, WSPChecksum): + return False + + return self.mode == other.mode + + def __hash__(self): + return hash(self.mode) def calculate(self, msg: array.array) -> array.array: """ @@ -55,6 +66,26 @@ def calculate(self, msg: array.array) -> array.array: except IndexError: return None + @classmethod + def search_for_wsp_checksum(cls, bits_behind_sync): + data_start, data_stop, crc_start, crc_stop = 0, 0, 0, 0 + + if bits_behind_sync[-4:].tobytes() != array.array("B", [1, 0, 1, 1]).tobytes(): + return 0, 0, 0, 0 # Check for EOF + + rorg = bits_behind_sync[0:4].tobytes() + if rorg == array.array("B", [0, 1, 0, 1]).tobytes() or rorg == array.array("B", [0, 1, 1, 0]).tobytes(): + # Switch telegram + if cls.checksum4(bits_behind_sync[-8:]).tobytes() == bits_behind_sync[-8:-4].tobytes(): + crc_start = len(bits_behind_sync) - 8 + crc_stop = len(bits_behind_sync) - 4 + data_stop = crc_start + return data_start, data_stop, crc_start, crc_stop + + # todo: Find crc8 and checksum8 + + return 0, 0, 0, 0 + @classmethod def checksum4(cls, bits: array.array) -> array.array: hash = 0 @@ -82,5 +113,5 @@ def to_xml(self) -> ET.Element: return root @classmethod - def from_xml(cls, tag: ET.Element): + def from_xml(cls, tag: ET.Element): return WSPChecksum(mode=WSPChecksum.ChecksumMode[tag.get("mode", "auto")]) diff --git a/src/urh/util/util.py b/src/urh/util/util.py index 28d1fd0645..fb29120b92 
100644 --- a/src/urh/util/util.py +++ b/src/urh/util/util.py @@ -8,15 +8,17 @@ from xml.dom import minidom from xml.etree import ElementTree as ET +import numpy as np from PyQt5.QtCore import Qt from PyQt5.QtGui import QFontDatabase, QFont from PyQt5.QtGui import QIcon from PyQt5.QtWidgets import QApplication, QSplitter from PyQt5.QtWidgets import QDialog, QVBoxLayout, QPlainTextEdit, QTableWidgetItem + from urh import constants from urh.util.Logger import logger -PROJECT_PATH = None # for referencing in external program calls +PROJECT_PATH = None # for referencing in external program calls BCD_ERROR_SYMBOL = "?" BCD_LUT = {"{0:04b}".format(i): str(i) if i < 10 else BCD_ERROR_SYMBOL for i in range(16)} @@ -51,7 +53,7 @@ def set_shared_library_path(): if shared_lib_dir: if sys.platform == "win32": - current_path = os.environ.get("PATH", '') + current_path = os.environ.get("PATH", '') if not current_path.startswith(shared_lib_dir): os.environ["PATH"] = shared_lib_dir + os.pathsep + current_path else: @@ -211,7 +213,7 @@ def convert_string_to_bits(value: str, display_format: int, target_num_bits: int if len(result) < target_num_bits: # pad with zeros - return result + array.array("B", [0] * (target_num_bits-len(result))) + return result + array.array("B", [0] * (target_num_bits - len(result))) else: return result[:target_num_bits] @@ -241,6 +243,10 @@ def number_to_bits(n: int, length: int) -> array.array: return array.array("B", map(int, fmt.format(n))) +def bits_to_number(bits: array.array) -> int: + return int("".join(map(str, bits)), 2) + + def aggregate_bits(bits: array.array, size=4) -> array.array: result = array.array("B", []) @@ -257,6 +263,17 @@ def aggregate_bits(bits: array.array, size=4) -> array.array: return result +def convert_numbers_to_hex_string(arr: np.ndarray): + """ + Convert an array like [0, 1, 10, 2] to string 012a2 + + :param arr: + :return: + """ + lut = {i: "{0:x}".format(i) for i in range(16)} + return "".join(lut[x] if x in lut else " {} ".format(x) for x in arr) + + def clip(value, minimum, maximum): return max(minimum, min(value, maximum)) diff --git a/tests/auto_interpretation/test_estimate_tolerance.py b/tests/auto_interpretation/test_estimate_tolerance.py index d2762d85e0..3d3578cc85 100644 --- a/tests/auto_interpretation/test_estimate_tolerance.py +++ b/tests/auto_interpretation/test_estimate_tolerance.py @@ -33,10 +33,8 @@ def test_tolerance_estimation(self): [1, 9, 3, 3, 2, 9, 1, 4, 2, 4, 2, 8, 1, 4, 2, 4, 2, 4, 2, 2, 1, 5, 2, 3, 3, 3, 2, 2, 106, 104, 104, 103, 105, 104, 104, 105, 104, 104, 104, 104, 105, 103, 104, 106, 103, 105, 104, 103, 105, 103, 105, 105, 104, 104, 104, 103, 105, 104, 104, 105, 104, 104, 104, 104, 105, 104, 103, 106, 207, 104, 105],] found_tolerances = [] - print() for i, plateau_lengths in enumerate(data): found_tolerances.append(AutoInterpretation.estimate_tolerance_from_plateau_lengths(plateau_lengths)) - print(found_tolerances) estimated_tolerance = AutoInterpretation.get_most_frequent_value(found_tolerances) self.assertIn(estimated_tolerance, range(4, 7)) diff --git a/tests/auto_interpretation/test_message_segmentation.py b/tests/auto_interpretation/test_message_segmentation.py index dc10aca85c..00c14c68e8 100644 --- a/tests/auto_interpretation/test_message_segmentation.py +++ b/tests/auto_interpretation/test_message_segmentation.py @@ -58,12 +58,9 @@ def test_segmentation_ask_50(self): data = np.concatenate((msg1, msg2, msg3)) segments = segment_messages_from_magnitudes(np.abs(data), noise_threshold=0) - print(segments) 
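Conceptually, segment_messages_from_magnitudes splits the magnitude vector into messages wherever it rises above the noise threshold. A toy sketch of that idea is shown below; it deliberately ignores the real function's pause handling and tolerance logic, which are not reproduced here.

def toy_segment(magnitudes, noise_threshold):
    # Group consecutive samples above the threshold into (start, end) pairs.
    segments, start = [], None
    for i, m in enumerate(magnitudes):
        if m > noise_threshold and start is None:
            start = i
        elif m <= noise_threshold and start is not None:
            segments.append((start, i - 1))
            start = None
    if start is not None:
        segments.append((start, len(magnitudes) - 1))
    return segments

assert toy_segment([0, 0, 5, 6, 0, 0, 7, 0], noise_threshold=1) == [(2, 3), (6, 6)]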
self.assertEqual(len(segments), 3) self.assertEqual(segments, [(0, 999), (10999, 12599), (32599, 34199)]) - print(merge_message_segments_for_ook(segments)) - def test_segmentation_elektromaten(self): signal = Signal(get_path_for_data_file("elektromaten.coco"), "") segments = segment_messages_from_magnitudes(np.abs(signal.data), noise_threshold=0.0167) diff --git a/tests/awre/AWRETestCase.py b/tests/awre/AWRETestCase.py new file mode 100644 index 0000000000..24f2ef74e7 --- /dev/null +++ b/tests/awre/AWRETestCase.py @@ -0,0 +1,65 @@ +import os +import tempfile +import unittest + +import numpy +from urh.awre.FormatFinder import FormatFinder + +from tests.utils_testing import get_path_for_data_file +from urh.signalprocessing.ProtocolAnalyzer import ProtocolAnalyzer + +from urh.signalprocessing.MessageType import MessageType + +from urh.awre.MessageTypeBuilder import MessageTypeBuilder +from urh.awre.ProtocolGenerator import ProtocolGenerator +from urh.signalprocessing.FieldType import FieldType + + +class AWRETestCase(unittest.TestCase): + def setUp(self): + numpy.set_printoptions(linewidth=80) + self.field_types = self.__init_field_types() + + def get_format_finder_from_protocol_file(self, filename: str, clear_participant_addresses=True, return_messages=False): + proto_file = get_path_for_data_file(filename) + protocol = ProtocolAnalyzer(signal=None, filename=proto_file) + protocol.from_xml_file(filename=proto_file, read_bits=True) + + self.clear_message_types(protocol.messages) + + ff = FormatFinder(protocol.messages) + if clear_participant_addresses: + ff.known_participant_addresses.clear() + + if return_messages: + return ff, protocol.messages + else: + return ff + + @staticmethod + def __init_field_types(): + result = [] + for field_type_function in FieldType.Function: + result.append(FieldType(field_type_function.value, field_type_function)) + return result + + @staticmethod + def clear_message_types(messages: list): + mt = MessageType("empty") + for msg in messages: + msg.message_type = mt + + @staticmethod + def save_protocol(name, protocol_generator, silent=False): + filename = os.path.join(tempfile.gettempdir(), name + ".proto") + if isinstance(protocol_generator, ProtocolGenerator): + protocol_generator.to_file(filename) + elif isinstance(protocol_generator, ProtocolAnalyzer): + participants = list(set(msg.participant for msg in protocol_generator.messages)) + protocol_generator.to_xml_file(filename, [], participants=participants, write_bits=True) + info = "Protocol written to " + filename + if not silent: + print() + print("-" * len(info)) + print(info) + print("-" * len(info)) diff --git a/tests/awre/AWRExperiments.py b/tests/awre/AWRExperiments.py new file mode 100644 index 0000000000..b4bd4a833d --- /dev/null +++ b/tests/awre/AWRExperiments.py @@ -0,0 +1,791 @@ +import array +import multiprocessing +import os +import random +import time +from collections import defaultdict + +import matplotlib.pyplot as plt +import numpy as np + +from tests.awre.AWRETestCase import AWRETestCase +from tests.utils_testing import get_path_for_data_file +from urh.awre.FormatFinder import FormatFinder +from urh.awre.MessageTypeBuilder import MessageTypeBuilder +from urh.awre.Preprocessor import Preprocessor +from urh.awre.ProtocolGenerator import ProtocolGenerator +from urh.awre.engines.Engine import Engine +from urh.signalprocessing.FieldType import FieldType +from urh.signalprocessing.Message import Message +from urh.signalprocessing.MessageType import MessageType +from 
urh.signalprocessing.Participant import Participant +from urh.signalprocessing.ProtocolAnalyzer import ProtocolAnalyzer +from urh.util.GenericCRC import GenericCRC + + +def run_for_num_broken(protocol_nr, num_broken: list, num_messages: int, num_runs: int) -> list: + random.seed(0) + np.random.seed(0) + + result = [] + for broken in num_broken: + tmp_accuracies = np.empty(num_runs, dtype=np.float64) + tmp_accuracies_without_broken = np.empty(num_runs, dtype=np.float64) + for i in range(num_runs): + protocol, expected_labels = AWRExperiments.get_protocol(protocol_nr, + num_messages=num_messages, + num_broken_messages=broken, + silent=True) + + AWRExperiments.run_format_finder_for_protocol(protocol) + accuracy = AWRExperiments.calculate_accuracy(protocol.messages, expected_labels) + accuracy_without_broken = AWRExperiments.calculate_accuracy(protocol.messages, expected_labels, broken) + tmp_accuracies[i] = accuracy + tmp_accuracies_without_broken[i] = accuracy_without_broken + + avg_accuracy = np.mean(tmp_accuracies) + avg_accuracy_without_broken = np.mean(tmp_accuracies_without_broken) + + result.append((avg_accuracy, avg_accuracy_without_broken)) + print("Protocol {} with {} broken: {:>3}% {:>3}%".format(protocol_nr, broken, int(avg_accuracy), + int(avg_accuracy_without_broken))) + + return result + + +class AWRExperiments(AWRETestCase): + @staticmethod + def _prepare_protocol_1() -> ProtocolGenerator: + alice = Participant("Alice", address_hex="dead") + bob = Participant("Bob", address_hex="beef") + + mb = MessageTypeBuilder("data") + mb.add_label(FieldType.Function.PREAMBLE, 8) + mb.add_label(FieldType.Function.SYNC, 16) + mb.add_label(FieldType.Function.LENGTH, 8) + mb.add_label(FieldType.Function.SRC_ADDRESS, 16) + mb.add_label(FieldType.Function.DST_ADDRESS, 16) + mb.add_label(FieldType.Function.SEQUENCE_NUMBER, 8) + + pg = ProtocolGenerator([mb.message_type], + syncs_by_mt={mb.message_type: "0x1337"}, + participants=[alice, bob]) + return pg + + @staticmethod + def _prepare_protocol_2() -> ProtocolGenerator: + alice = Participant("Alice", address_hex="dead01") + bob = Participant("Bob", address_hex="beef24") + + mb = MessageTypeBuilder("data") + mb.add_label(FieldType.Function.PREAMBLE, 72) + mb.add_label(FieldType.Function.SYNC, 16) + mb.add_label(FieldType.Function.LENGTH, 8) + mb.add_label(FieldType.Function.SRC_ADDRESS, 24) + mb.add_label(FieldType.Function.DST_ADDRESS, 24) + mb.add_label(FieldType.Function.SEQUENCE_NUMBER, 16) + + pg = ProtocolGenerator([mb.message_type], + syncs_by_mt={mb.message_type: "0x1337"}, + preambles_by_mt={mb.message_type: "10" * 36}, + sequence_number_increment=32, + participants=[alice, bob]) + + return pg + + @staticmethod + def _prepare_protocol_3() -> ProtocolGenerator: + alice = Participant("Alice", address_hex="1337") + bob = Participant("Bob", address_hex="beef") + + checksum = GenericCRC.from_standard_checksum("CRC8 CCITT") + + mb = MessageTypeBuilder("data") + mb.add_label(FieldType.Function.PREAMBLE, 16) + mb.add_label(FieldType.Function.SYNC, 16) + mb.add_label(FieldType.Function.LENGTH, 8) + mb.add_label(FieldType.Function.SRC_ADDRESS, 16) + mb.add_label(FieldType.Function.DST_ADDRESS, 16) + mb.add_label(FieldType.Function.SEQUENCE_NUMBER, 8) + mb.add_label(FieldType.Function.DATA, 10 * 8) + mb.add_checksum_label(8, checksum) + + mb_ack = MessageTypeBuilder("ack") + mb_ack.add_label(FieldType.Function.PREAMBLE, 16) + mb_ack.add_label(FieldType.Function.SYNC, 16) + mb_ack.add_label(FieldType.Function.LENGTH, 8) + 
mb_ack.add_label(FieldType.Function.DST_ADDRESS, 16) + mb_ack.add_checksum_label(8, checksum) + + pg = ProtocolGenerator([mb.message_type, mb_ack.message_type], + syncs_by_mt={mb.message_type: "0x9a7d", mb_ack.message_type: "0x9a7d"}, + preambles_by_mt={mb.message_type: "10" * 8, mb_ack.message_type: "10" * 8}, + participants=[alice, bob]) + + return pg + + @staticmethod + def _prepare_protocol_4() -> ProtocolGenerator: + alice = Participant("Alice", address_hex="1337") + bob = Participant("Bob", address_hex="beef") + + checksum = GenericCRC.from_standard_checksum("CRC16 CCITT") + + mb = MessageTypeBuilder("data1") + mb.add_label(FieldType.Function.PREAMBLE, 16) + mb.add_label(FieldType.Function.SYNC, 16) + mb.add_label(FieldType.Function.LENGTH, 8) + mb.add_label(FieldType.Function.SRC_ADDRESS, 16) + mb.add_label(FieldType.Function.DST_ADDRESS, 16) + mb.add_label(FieldType.Function.DATA, 8 * 8) + mb.add_checksum_label(16, checksum) + + mb2 = MessageTypeBuilder("data2") + mb2.add_label(FieldType.Function.PREAMBLE, 16) + mb2.add_label(FieldType.Function.SYNC, 16) + mb2.add_label(FieldType.Function.LENGTH, 8) + mb2.add_label(FieldType.Function.SRC_ADDRESS, 16) + mb2.add_label(FieldType.Function.DST_ADDRESS, 16) + mb2.add_label(FieldType.Function.DATA, 64 * 8) + mb2.add_checksum_label(16, checksum) + + mb_ack = MessageTypeBuilder("ack") + mb_ack.add_label(FieldType.Function.PREAMBLE, 16) + mb_ack.add_label(FieldType.Function.SYNC, 16) + mb_ack.add_label(FieldType.Function.LENGTH, 8) + mb_ack.add_label(FieldType.Function.DST_ADDRESS, 16) + mb_ack.add_checksum_label(16, checksum) + + mt1, mt2, mt3 = mb.message_type, mb2.message_type, mb_ack.message_type + + preamble = "10001000" * 2 + + pg = ProtocolGenerator([mt1, mt2, mt3], + syncs_by_mt={mt1: "0x9a7d", mt2: "0x9a7d", mt3: "0x9a7d"}, + preambles_by_mt={mt1: preamble, mt2: preamble, mt3: preamble}, + participants=[alice, bob]) + + return pg + + @staticmethod + def _prepare_protocol_5() -> ProtocolGenerator: + alice = Participant("Alice", address_hex="1337") + bob = Participant("Bob", address_hex="beef") + carl = Participant("Carl", address_hex="cafe") + + mb = MessageTypeBuilder("data") + mb.add_label(FieldType.Function.PREAMBLE, 16) + mb.add_label(FieldType.Function.SYNC, 16) + mb.add_label(FieldType.Function.LENGTH, 8) + mb.add_label(FieldType.Function.SRC_ADDRESS, 16) + mb.add_label(FieldType.Function.DST_ADDRESS, 16) + mb.add_label(FieldType.Function.SEQUENCE_NUMBER, 8) + + mb_ack = MessageTypeBuilder("ack") + mb_ack.add_label(FieldType.Function.PREAMBLE, 16) + mb_ack.add_label(FieldType.Function.SYNC, 16) + mb_ack.add_label(FieldType.Function.LENGTH, 8) + mb_ack.add_label(FieldType.Function.DST_ADDRESS, 16) + + pg = ProtocolGenerator([mb.message_type, mb_ack.message_type], + syncs_by_mt={mb.message_type: "0x9a7d", mb_ack.message_type: "0x9a7d"}, + preambles_by_mt={mb.message_type: "10" * 8, mb_ack.message_type: "10" * 8}, + participants=[alice, bob, carl]) + + return pg + + @staticmethod + def _prepare_protocol_6() -> ProtocolGenerator: + alice = Participant("Alice", address_hex="24") + broadcast = Participant("Bob", address_hex="ff") + + mb = MessageTypeBuilder("data") + mb.add_label(FieldType.Function.SYNC, 16) + mb.add_label(FieldType.Function.LENGTH, 8) + mb.add_label(FieldType.Function.SRC_ADDRESS, 8) + mb.add_label(FieldType.Function.SEQUENCE_NUMBER, 8) + + pg = ProtocolGenerator([mb.message_type], + syncs_by_mt={mb.message_type: "0x8e88"}, + preambles_by_mt={mb.message_type: "10" * 8}, + participants=[alice, broadcast]) + + return 
pg + + @staticmethod + def _prepare_protocol_7() -> ProtocolGenerator: + alice = Participant("Alice", address_hex="313370") + bob = Participant("Bob", address_hex="031337") + charly = Participant("Charly", address_hex="110000") + daniel = Participant("Daniel", address_hex="001100") + # broadcast = Participant("Broadcast", address_hex="ff") #TODO: Sometimes messages to broadcast + + checksum = GenericCRC.from_standard_checksum("CRC16 CC1101") + + mb = MessageTypeBuilder("data") + mb.add_label(FieldType.Function.PREAMBLE, 16) + mb.add_label(FieldType.Function.SYNC, 16) + mb.add_label(FieldType.Function.LENGTH, 8) + mb.add_label(FieldType.Function.DST_ADDRESS, 24) + mb.add_label(FieldType.Function.SRC_ADDRESS, 24) + mb.add_label(FieldType.Function.DATA, 8 * 8) + mb.add_checksum_label(16, checksum) + + mb_ack = MessageTypeBuilder("ack") + mb_ack.add_label(FieldType.Function.PREAMBLE, 8) + mb_ack.add_label(FieldType.Function.SYNC, 16) + mb_ack.add_label(FieldType.Function.DST_ADDRESS, 24) + mb_ack.add_checksum_label(16, checksum) + + mb_kex = MessageTypeBuilder("kex") + mb_kex.add_label(FieldType.Function.PREAMBLE, 24) + mb_kex.add_label(FieldType.Function.SYNC, 16) + mb_kex.add_label(FieldType.Function.DST_ADDRESS, 24) + mb_kex.add_label(FieldType.Function.SRC_ADDRESS, 24) + mb_kex.add_label(FieldType.Function.DATA, 64 * 8) + mb_kex.add_checksum_label(16, checksum) + + pg = ProtocolGenerator([mb.message_type, mb_ack.message_type, mb_kex.message_type], + syncs_by_mt={mb.message_type: "0x0420", mb_ack.message_type: "0x2222", + mb_kex.message_type: "0x6767"}, + preambles_by_mt={mb.message_type: "10" * 8, mb_ack.message_type: "10" * 4, + mb_kex.message_type: "10" * 12}, + participants=[alice, bob, charly, daniel]) + + return pg + + @staticmethod + def _prepare_protocol_8() -> ProtocolGenerator: + alice = Participant("Alice") + + mb = MessageTypeBuilder("data1") + mb.add_label(FieldType.Function.PREAMBLE, 4) + mb.add_label(FieldType.Function.SYNC, 4) + mb.add_label(FieldType.Function.LENGTH, 16) + mb.add_label(FieldType.Function.SEQUENCE_NUMBER, 16) + mb.add_label(FieldType.Function.DATA, 8 * 542) + + mb2 = MessageTypeBuilder("data2") + mb2.add_label(FieldType.Function.PREAMBLE, 4) + mb2.add_label(FieldType.Function.SYNC, 4) + mb2.add_label(FieldType.Function.LENGTH, 16) + mb2.add_label(FieldType.Function.SEQUENCE_NUMBER, 16) + mb2.add_label(FieldType.Function.DATA, 8 * 260) + + pg = ProtocolGenerator([mb.message_type, mb2.message_type], + syncs_by_mt={mb.message_type: "0x9", mb2.message_type: "0x9"}, + preambles_by_mt={mb.message_type: "10" * 2, mb2.message_type: "10" * 2}, + sequence_number_increment=32, + participants=[alice], + little_endian=True) + + return pg + + def test_export_to_latex(self): + filename = os.path.expanduser("~/GIT/publications/awre/USENIX/protocols.tex") + if os.path.isfile(filename): + os.remove(filename) + + for i in range(1, 9): + pg = getattr(self, "_prepare_protocol_" + str(i))() + pg.export_to_latex(filename, i) + + @classmethod + def get_protocol(cls, protocol_number: int, num_messages, num_broken_messages=0, silent=False): + if protocol_number == 1: + pg = cls._prepare_protocol_1() + elif protocol_number == 2: + pg = cls._prepare_protocol_2() + elif protocol_number == 3: + pg = cls._prepare_protocol_3() + elif protocol_number == 4: + pg = cls._prepare_protocol_4() + elif protocol_number == 5: + pg = cls._prepare_protocol_5() + elif protocol_number == 6: + pg = cls._prepare_protocol_6() + elif protocol_number == 7: + pg = cls._prepare_protocol_7() + elif 
protocol_number == 8: + pg = cls._prepare_protocol_8() + else: + raise ValueError("Unknown protocol number") + + messages_types_with_data_field = [mt for mt in pg.protocol.message_types + if mt.get_first_label_with_type(FieldType.Function.DATA)] + i = -1 + while len(pg.protocol.messages) < num_messages: + i += 1 + source = pg.participants[i % len(pg.participants)] + destination = pg.participants[(i + 1) % len(pg.participants)] + if i % 2 == 0: + data_bytes = 8 + else: + # data_bytes = 16 + data_bytes = 64 + + if len(messages_types_with_data_field) == 0: + # set data automatically + data = "".join(random.choice(["0", "1"]) for _ in range(data_bytes * 8)) + pg.generate_message(data=data, source=source, destination=destination) + else: + # search for message type with right data length + mt = messages_types_with_data_field[i % len(messages_types_with_data_field)] + data_length = mt.get_first_label_with_type(FieldType.Function.DATA).length + data = "".join(random.choice(["0", "1"]) for _ in range(data_length)) + pg.generate_message(message_type=mt, data=data, source=source, destination=destination) + + ack_message_type = next((mt for mt in pg.protocol.message_types if "ack" in mt.name), None) + if ack_message_type: + pg.generate_message(message_type=ack_message_type, data="", source=destination, destination=source) + + for i in range(num_broken_messages): + msg = pg.protocol.messages[i] + pos = random.randint(0, len(msg.plain_bits) // 2) + msg.plain_bits[pos:] = array.array("B", + [random.randint(0, 1) for _ in range(len(msg.plain_bits) - pos)]) + + if num_broken_messages == 0: + cls.save_protocol("protocol{}_{}_messages".format(protocol_number, num_messages), pg, silent=silent) + else: + cls.save_protocol("protocol{}_{}_broken".format(protocol_number, num_broken_messages), pg, silent=silent) + + expected_message_types = [msg.message_type for msg in pg.protocol.messages] + + # Delete message type information -> no prior knowledge + cls.clear_message_types(pg.protocol.messages) + + # Delete data labels if present + for mt in expected_message_types: + data_lbl = mt.get_first_label_with_type(FieldType.Function.DATA) + if data_lbl: + mt.remove(data_lbl) + + return pg.protocol, expected_message_types + + @staticmethod + def calculate_accuracy(messages, expected_labels, num_broken_messages=0): + """ + Calculate the accuracy of labels compared to expected labels + Accuracy is 100% when labels == expected labels + Accuracy drops by 1 / len(expected_labels) for every expected label not present in labels + + :type messages: list of Message + :type expected_labels: list of MessageType + :return: + """ + accuracy = sum(len(set(expected_labels[i]) & set(messages[i].message_type)) / len(expected_labels[i]) + for i in range(num_broken_messages, len(messages))) + try: + accuracy /= (len(messages) - num_broken_messages) + except ZeroDivisionError: + accuracy = 0 + + return accuracy * 100 + + def test_against_num_messages(self): + num_messages = list(range(1, 24, 1)) + accuracies = defaultdict(list) + + protocols = [1, 2, 3, 4, 5, 6, 7, 8] + + random.seed(0) + np.random.seed(0) + for protocol_nr in protocols: + for n in num_messages: + protocol, expected_labels = self.get_protocol(protocol_nr, num_messages=n) + self.run_format_finder_for_protocol(protocol) + + accuracy = self.calculate_accuracy(protocol.messages, expected_labels) + accuracies["protocol {}".format(protocol_nr)].append(accuracy) + + self.__plot(num_messages, accuracies, xlabel="Number of messages", ylabel="Accuracy in %", grid=True) + 
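calculate_accuracy scores each message by how many of its expected labels the format finder actually recovered. A toy walk-through of that formula, with hypothetical label sets chosen only for illustration:

# Each message contributes |expected ∩ found| / |expected|,
# averaged over the non-broken messages and scaled to percent.
expected = [{"preamble", "sync", "length"}, {"preamble", "sync"}]
found = [{"preamble", "sync"}, {"preamble", "sync"}]

per_message = [len(e & f) / len(e) for e, f in zip(expected, found)]
accuracy = 100 * sum(per_message) / len(per_message)
assert round(accuracy, 1) == 83.3  # (2/3 + 1) / 2 = 0.833...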
self.__export_to_csv("/tmp/accuray-vs-messages", num_messages, accuracies) + + def test_against_error(self): + Engine._DEBUG_ = False + Preprocessor._DEBUG_ = False + + num_runs = 100 + + num_messages = 30 + num_broken_messages = list(range(0, num_messages + 1)) + accuracies = defaultdict(list) + accuracies_without_broken = defaultdict(list) + + protocols = [1, 2, 3, 4, 5, 6, 7, 8] + + random.seed(0) + np.random.seed(0) + + with multiprocessing.Pool() as p: + result = p.starmap(run_for_num_broken, + [(i, num_broken_messages, num_messages, num_runs) for i in protocols]) + for i, acc in enumerate(result): + accuracies["protocol {}".format(i + 1)] = [a[0] for a in acc] + accuracies_without_broken["protocol {}".format(i + 1)] = [a[1] for a in acc] + + self.__plot(100 * np.array(num_broken_messages) / num_messages, accuracies, + title="Overall Accuracy vs percentage of broken messages", + xlabel="Broken messages in %", + ylabel="Accuracy in %", grid=True) + self.__plot(100 * np.array(num_broken_messages) / num_messages, accuracies_without_broken, + title=" Accuracy of unbroken vs percentage of broken messages", + xlabel="Broken messages in %", + ylabel="Accuracy in %", grid=True) + self.__export_to_csv("/tmp/accuray-vs-error", num_broken_messages, accuracies, relative=num_messages) + self.__export_to_csv("/tmp/accuray-vs-error-without-broken", num_broken_messages, accuracies_without_broken, + relative=num_messages) + + def test_performance(self): + Engine._DEBUG_ = False + Preprocessor._DEBUG_ = False + + num_messages = list(range(200, 205, 5)) + protocols = [1] + + random.seed(0) + np.random.seed(0) + + performances = defaultdict(list) + + for protocol_nr in protocols: + print("Running for protocol", protocol_nr) + for messages in num_messages: + protocol, _ = self.get_protocol(protocol_nr, messages, silent=True) + + t = time.time() + self.run_format_finder_for_protocol(protocol) + performances["protocol {}".format(protocol_nr)].append(time.time() - t) + + # self.__plot(num_messages, performances, xlabel="Number of messages", ylabel="Time in seconds", grid=True) + + def test_performance_real_protocols(self): + Engine._DEBUG_ = False + Preprocessor._DEBUG_ = False + + num_runs = 100 + + num_messages = list(range(8, 512, 4)) + protocol_names = ["enocean", "homematic", "rwe"] + + random.seed(0) + np.random.seed(0) + + performances = defaultdict(list) + + for protocol_name in protocol_names: + for messages in num_messages: + if protocol_name == "homematic": + protocol = self.generate_homematic(messages, save_protocol=False) + elif protocol_name == "enocean": + protocol = self.generate_enocean(messages, save_protocol=False) + elif protocol_name == "rwe": + protocol = self.generate_rwe(messages, save_protocol=False) + else: + raise ValueError("Unknown protocol name") + + tmp_performances = np.empty(num_runs, dtype=np.float64) + for i in range(num_runs): + print("\r{0} with {1:02d} messages ({2}/{3} runs)".format(protocol_name, messages, i + 1, num_runs), + flush=True, end="") + + t = time.time() + self.run_format_finder_for_protocol(protocol) + tmp_performances[i] = time.time() - t + self.clear_message_types(protocol.messages) + + mean_performance = tmp_performances.mean() + print(" {:.2f}s".format(mean_performance)) + performances["{}".format(protocol_name)].append(mean_performance) + + self.__plot(num_messages, performances, xlabel="Number of messages", ylabel="Time in seconds", grid=True) + self.__export_to_csv("/tmp/performance.csv", num_messages, performances) + + @staticmethod + def 
__export_to_csv(filename: str, x: list, y: dict, relative=None): + if not filename.endswith(".csv"): + filename += ".csv" + + with open(filename, "w") as f: + f.write("N,") + if relative is not None: + f.write("NRel,") + for y_cap in sorted(y): + f.write(y_cap + ",") + f.write("\n") + + for i, x_val in enumerate(x): + f.write("{},".format(x_val)) + if relative is not None: + f.write("{},".format(100 * x_val / relative)) + + for y_cap in sorted(y): + f.write("{},".format(y[y_cap][i])) + f.write("\n") + + @staticmethod + def __plot(x: list, y: dict, xlabel: str, ylabel: str, grid=False, title=None): + plt.xlabel(xlabel) + plt.ylabel(ylabel) + + for y_cap, y_values in sorted(y.items()): + plt.plot(x, y_values, label=y_cap) + + if grid: + plt.grid() + + if title: + plt.title(title) + + plt.legend() + plt.show() + + @staticmethod + def run_format_finder_for_protocol(protocol: ProtocolAnalyzer): + ff = FormatFinder(protocol.messages) + ff.known_participant_addresses.clear() + ff.run() + + for msg_type, indices in ff.existing_message_types.items(): + for i in indices: + protocol.messages[i].message_type = msg_type + + @classmethod + def generate_homematic(cls, num_messages: int, save_protocol=True): + mb_m_frame = MessageTypeBuilder("mframe") + mb_c_frame = MessageTypeBuilder("cframe") + mb_r_frame = MessageTypeBuilder("rframe") + mb_a_frame = MessageTypeBuilder("aframe") + + participants = [Participant("CCU", address_hex="3927cc"), Participant("Switch", address_hex="3101cc")] + + checksum = GenericCRC.from_standard_checksum("CRC16 CC1101") + for mb_builder in [mb_m_frame, mb_c_frame, mb_r_frame, mb_a_frame]: + mb_builder.add_label(FieldType.Function.PREAMBLE, 32) + mb_builder.add_label(FieldType.Function.SYNC, 32) + mb_builder.add_label(FieldType.Function.LENGTH, 8) + mb_builder.add_label(FieldType.Function.SEQUENCE_NUMBER, 8) + mb_builder.add_label(FieldType.Function.TYPE, 16) + mb_builder.add_label(FieldType.Function.SRC_ADDRESS, 24) + mb_builder.add_label(FieldType.Function.DST_ADDRESS, 24) + if mb_builder.name == "mframe": + mb_builder.add_label(FieldType.Function.DATA, 16, name="command") + elif mb_builder.name == "cframe": + mb_builder.add_label(FieldType.Function.DATA, 16 * 4, name="command+challenge+magic") + elif mb_builder.name == "rframe": + mb_builder.add_label(FieldType.Function.DATA, 32 * 4, name="cipher") + elif mb_builder.name == "aframe": + mb_builder.add_label(FieldType.Function.DATA, 10 * 4, name="command + auth") + mb_builder.add_checksum_label(16, checksum) + + message_types = [mb_m_frame.message_type, mb_c_frame.message_type, mb_r_frame.message_type, + mb_a_frame.message_type] + preamble = "0xaaaaaaaa" + sync = "0xe9cae9ca" + initial_sequence_number = 36 + pg = ProtocolGenerator(message_types, participants, + preambles_by_mt={mt: preamble for mt in message_types}, + syncs_by_mt={mt: sync for mt in message_types}, + sequence_numbers={mt: initial_sequence_number for mt in message_types}, + message_type_codes={mb_m_frame.message_type: 42560, + mb_c_frame.message_type: 40962, + mb_r_frame.message_type: 40963, + mb_a_frame.message_type: 32770}) + + for i in range(num_messages): + mt = pg.message_types[i % 4] + data_length = mt.get_first_label_with_type(FieldType.Function.DATA).length + data = "".join(random.choice(["0", "1"]) for _ in range(data_length)) + pg.generate_message(mt, data, source=pg.participants[i % 2], destination=pg.participants[(i + 1) % 2]) + + if save_protocol: + cls.save_protocol("homematic", pg) + + cls.clear_message_types(pg.messages) + return pg.protocol + 
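All of these generators follow the same pattern: build message types with MessageTypeBuilder, feed them to a ProtocolGenerator, then hand the generated messages to the FormatFinder. A minimal, self-contained sketch of that flow is given below; the field sizes and sync word are arbitrary example values, not ones prescribed by the library.

import random

from urh.awre.FormatFinder import FormatFinder
from urh.awre.MessageTypeBuilder import MessageTypeBuilder
from urh.awre.ProtocolGenerator import ProtocolGenerator
from urh.signalprocessing.FieldType import FieldType

# Describe one message type with a preamble, sync and length field.
mb = MessageTypeBuilder("demo")
mb.add_label(FieldType.Function.PREAMBLE, 8)
mb.add_label(FieldType.Function.SYNC, 16)
mb.add_label(FieldType.Function.LENGTH, 8)

# Generate a handful of messages with random payload bytes.
pg = ProtocolGenerator([mb.message_type], syncs_by_mt={mb.message_type: "0x9a9d"})
for _ in range(10):
    pg.generate_message(data=pg.decimal_to_bits(random.randint(0, 255), 8))

# Let the format finder infer the field layout from the generated messages.
ff = FormatFinder(pg.protocol.messages)
ff.run()
print(ff.message_types)  # inferred message types with their labels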
+ @classmethod + def generate_enocean(cls, num_messages: int, save_protocol=True): + filename = get_path_for_data_file("enocean_bits.txt") + enocean_bits = [] + with open(filename, "r") as f: + for line in map(str.strip, f): + enocean_bits.append(line) + + protocol = ProtocolAnalyzer(None) + message_type = MessageType("empty") + for i in range(num_messages): + msg = Message.from_plain_bits_str(enocean_bits[i % len(enocean_bits)]) + msg.message_type = message_type + protocol.messages.append(msg) + + if save_protocol: + cls.save_protocol("enocean", protocol) + + return protocol + + @classmethod + def generate_rwe(cls, num_messages: int, save_protocol=True): + proto_file = get_path_for_data_file("rwe.proto.xml") + protocol = ProtocolAnalyzer(signal=None, filename=proto_file) + protocol.from_xml_file(filename=proto_file, read_bits=True) + messages = protocol.messages + + result = ProtocolAnalyzer(None) + message_type = MessageType("empty") + for i in range(num_messages): + msg = messages[i % len(messages)] # type: Message + msg.message_type = message_type + result.messages.append(msg) + + if save_protocol: + cls.save_protocol("rwe", result) + + return result + + def test_export_latex_table(self): + def bold_latex(s): + return r"\textbf{" + str(s) + r"}" + + comments = { + 1: "common protocol", + 2: "unusual field sizes", + 3: "contains ack and CRC8 CCITT", + 4: "contains ack and CRC16 CCITT", + 5: "three participants with ack frame", + 6: "short address", + 7: "four participants, varying preamble size, varying sync words", + 8: "nibble fields + LE" + } + + bold = {i: defaultdict(bool) for i in range(1, 9)} + bold[2][FieldType.Function.PREAMBLE] = True + bold[2][FieldType.Function.SRC_ADDRESS] = True + bold[2][FieldType.Function.DST_ADDRESS] = True + + bold[3][FieldType.Function.CHECKSUM] = True + + bold[4][FieldType.Function.CHECKSUM] = True + + bold[6][FieldType.Function.SRC_ADDRESS] = True + + bold[7][FieldType.Function.PREAMBLE] = True + bold[7][FieldType.Function.SYNC] = True + bold[7][FieldType.Function.SRC_ADDRESS] = True + bold[7][FieldType.Function.DST_ADDRESS] = True + + bold[8][FieldType.Function.PREAMBLE] = True + bold[8][FieldType.Function.SYNC] = True + + filename = os.path.expanduser("~/GIT/publications/awre/USENIX/protocol_table.tex") + rowcolors = [r"\rowcolor{black!10}", r"\rowcolor{black!20}"] + + with open(filename, "w") as f: + f.write(r"\begin{table*}[!h]" + "\n") + f.write( + "\t" + r"\caption{Properties of tested protocols whereby $\times$ means field is not present and $N_P$ is the number of participants.}" + "\n") + f.write("\t" + r"\label{tab:protocols}" + "\n") + f.write("\t" + r"\centering" + "\n") + f.write("\t" + r"\begin{tabularx}{\linewidth}{cp{2.5cm}llcccccccc}" + "\n") + f.write("\t\t" + r"\hline" + "\n") + f.write("\t\t" + r"\rowcolor{black!90}" + "\n") + f.write("\t\t" + r"\textcolor{white}{\textbf{\#}} & " + r"\textcolor{white}{\textbf{Comment}} & " + r"\textcolor{white}{$\mathbf{ N_P }$} & " + r"\textcolor{white}{\textbf{Message}} & " + r"\textcolor{white}{\textbf{Even/odd}} & " + r"\multicolumn{7}{c}{\textcolor{white}{\textbf{Size of field in bit (BE=Big Endian, LE=Little Endian)}}}\\" + "\n\t\t" + r"\rowcolor{black!90}" + "\n\t\t" + r"& & & \textcolor{white}{\textbf{Type}} & \textcolor{white}{\textbf{message data}} &" + r"\textcolor{white}{Preamble} & " + r"\textcolor{white}{Sync} & " + r"\textcolor{white}{Length} & " + r"\textcolor{white}{SRC} & " + r"\textcolor{white}{DST} & " + r"\textcolor{white}{SEQ Nr} & " + r"\textcolor{white}{CRC} \\" + "\n") + 
f.write("\t\t" + r"\hline" + "\n") + + rowcolor_index = 0 + for i in range(1, 9): + pg = getattr(self, "_prepare_protocol_" + str(i))() + assert isinstance(pg, ProtocolGenerator) + + try: + data1 = next(mt for mt in pg.message_types if mt.name == "data1") + data2 = next(mt for mt in pg.message_types if mt.name == "data2") + + data1_len = data1.get_first_label_with_type(FieldType.Function.DATA).length // 8 + data2_len = data2.get_first_label_with_type(FieldType.Function.DATA).length // 8 + + except StopIteration: + data1_len, data2_len = 8, 64 + + rowcolor = rowcolors[rowcolor_index % len(rowcolors)] + rowcount = 0 + for j, mt in enumerate(pg.message_types): + if mt.name == "data2": + continue + + rowcount += 1 + if j == 0: + protocol_nr, participants = str(i), len(pg.participants) + if participants > 2: + participants = bold_latex(participants) + else: + protocol_nr, participants = " ", " " + + f.write("\t\t" + rowcolor + "\n") + + if len(pg.message_types) == 1 or ( + mt.name == "data1" and "ack" not in {m.name for m in pg.message_types}): + f.write("\t\t{} & {} & {} & {} &".format(protocol_nr, comments[i], participants, + mt.name.replace("1", ""))) + elif j == len(pg.message_types) - 1: + f.write( + "\t\t{} & \\multirow{{{}}}{{\\linewidth}}{{{}}} & {} & {} &".format(protocol_nr, -rowcount, + comments[i], + participants, + mt.name.replace("1", + ""))) + else: + f.write("\t\t{} & & {} & {} &".format(protocol_nr, participants, mt.name.replace("1", ""))) + data_lbl = mt.get_first_label_with_type(FieldType.Function.DATA) + + if mt.name == "data1" or mt.name == "data2": + f.write("{}/{} byte &".format(data1_len, data2_len)) + elif mt.name == "data" and data_lbl is None: + f.write("{}/{} byte &".format(data1_len, data2_len)) + elif data_lbl is not None: + f.write("{0}/{0} byte & ".format(data_lbl.length // 8)) + else: + f.write(r"$ \times $ & ") + + for t in (FieldType.Function.PREAMBLE, FieldType.Function.SYNC, FieldType.Function.LENGTH, + FieldType.Function.SRC_ADDRESS, FieldType.Function.DST_ADDRESS, + FieldType.Function.SEQUENCE_NUMBER, + FieldType.Function.CHECKSUM): + lbl = mt.get_first_label_with_type(t) + if lbl is not None: + if bold[i][lbl.field_type.function]: + f.write(bold_latex(lbl.length)) + else: + f.write(str(lbl.length)) + if lbl.length > 8 and t in (FieldType.Function.LENGTH, FieldType.Function.SEQUENCE_NUMBER): + f.write(" ({})".format(bold_latex("LE") if pg.little_endian else "BE")) + else: + f.write(r"$ \times $") + + if t != FieldType.Function.CHECKSUM: + f.write(" & ") + else: + f.write(r"\\" + "\n") + + rowcolor_index += 1 + + f.write("\t" + r"\end{tabularx}" + "\n") + + f.write(r"\end{table*}" + "\n") diff --git a/tests/awre/TestAWREHistograms.py b/tests/awre/TestAWREHistograms.py new file mode 100644 index 0000000000..3a6a4341f2 --- /dev/null +++ b/tests/awre/TestAWREHistograms.py @@ -0,0 +1,179 @@ +import random +from collections import defaultdict + +import matplotlib.pyplot as plt + +from tests.awre.AWRETestCase import AWRETestCase +from urh.awre.FormatFinder import FormatFinder +from urh.awre.Histogram import Histogram +from urh.awre.MessageTypeBuilder import MessageTypeBuilder +from urh.awre.ProtocolGenerator import ProtocolGenerator +from urh.signalprocessing.FieldType import FieldType +from urh.signalprocessing.Participant import Participant + +SHOW_PLOTS = True + +class TestAWREHistograms(AWRETestCase): + def test_very_simple_protocol(self): + """ + Test a very simple protocol consisting just of a preamble, sync and some random data + :return: + """ + mb = 
MessageTypeBuilder("very_simple_test") + mb.add_label(FieldType.Function.PREAMBLE, 8) + mb.add_label(FieldType.Function.SYNC, 8) + + num_messages = 10 + + pg = ProtocolGenerator([mb.message_type], syncs_by_mt={mb.message_type: "0x9a"}) + for _ in range(num_messages): + pg.generate_message(data=pg.decimal_to_bits(random.randint(0, 255), 8)) + + self.save_protocol("very_simple", pg) + + h = Histogram(FormatFinder.get_bitvectors_from_messages(pg.protocol.messages)) + if SHOW_PLOTS: + h.plot() + + def test_simple_protocol(self): + """ + Test a simple protocol with preamble, sync and length field and some random data + :return: + """ + mb = MessageTypeBuilder("simple_test") + mb.add_label(FieldType.Function.PREAMBLE, 8) + mb.add_label(FieldType.Function.SYNC, 16) + mb.add_label(FieldType.Function.LENGTH, 8) + + num_messages_by_data_length = {8: 5, 16: 10, 32: 15} + pg = ProtocolGenerator([mb.message_type], syncs_by_mt={mb.message_type: "0x9a9d"}) + for data_length, num_messages in num_messages_by_data_length.items(): + for _ in range(num_messages): + pg.generate_message(data=pg.decimal_to_bits(random.randint(0, 2 ** data_length - 1), data_length)) + + self.save_protocol("simple", pg) + + plt.subplot("221") + plt.title("All messages") + format_finder = FormatFinder(pg.protocol.messages) + + for i, sync_end in enumerate(format_finder.sync_ends): + self.assertEqual(sync_end, 24, msg=str(i)) + + h = Histogram(format_finder.bitvectors) + h.subplot_on(plt) + + bitvectors = FormatFinder.get_bitvectors_from_messages(pg.protocol.messages) + bitvectors_by_length = defaultdict(list) + for bitvector in bitvectors: + bitvectors_by_length[len(bitvector)].append(bitvector) + + for i, (message_length, bitvectors) in enumerate(bitvectors_by_length.items()): + plt.subplot(2, 2, i + 2) + plt.title("Messages with length {} ({})".format(message_length, len(bitvectors))) + Histogram(bitvectors).subplot_on(plt) + + if SHOW_PLOTS: + plt.show() + + def test_medium_protocol(self): + """ + Test a protocol with preamble, sync, length field, 2 participants and addresses and seq nr and random data + :return: + """ + mb = MessageTypeBuilder("medium_test") + mb.add_label(FieldType.Function.PREAMBLE, 8) + mb.add_label(FieldType.Function.SYNC, 8) + mb.add_label(FieldType.Function.LENGTH, 8) + mb.add_label(FieldType.Function.SRC_ADDRESS, 16) + mb.add_label(FieldType.Function.DST_ADDRESS, 16) + mb.add_label(FieldType.Function.SEQUENCE_NUMBER, 16) + + alice = Participant("Alice", "A", "1234", color_index=0) + bob = Participant("Bob", "B", "5a9d", color_index=1) + + num_messages = 100 + pg = ProtocolGenerator([mb.message_type], syncs_by_mt={mb.message_type: "0x1c"}, little_endian=False) + for i in range(num_messages): + len_data = random.randint(1, 5) + data = "".join(pg.decimal_to_bits(random.randint(0, 2 ** 8 - 1), 8) for _ in range(len_data)) + if i % 2 == 0: + source, dest = alice, bob + else: + source, dest = bob, alice + pg.generate_message(data=data, source=source, destination=dest) + + self.save_protocol("medium", pg) + + plt.subplot(2, 2, 1) + plt.title("All messages") + bitvectors = FormatFinder.get_bitvectors_from_messages(pg.protocol.messages) + h = Histogram(bitvectors) + h.subplot_on(plt) + + for i, (participant, bitvectors) in enumerate( + sorted(self.get_bitvectors_by_participant(pg.protocol.messages).items())): + plt.subplot(2, 2, i + 3) + plt.title("Messages with participant {} ({})".format(participant.shortname, len(bitvectors))) + Histogram(bitvectors).subplot_on(plt) + + if SHOW_PLOTS: + plt.show() + + def 
get_bitvectors_by_participant(self, messages): + import numpy as np + result = defaultdict(list) + for msg in messages: # type: Message + result[msg.participant].append(np.array(msg.decoded_bits, dtype=np.uint8, order="C")) + return result + + def test_ack_protocol(self): + """ + Test a protocol with acks + :return: + """ + mb = MessageTypeBuilder("data") + mb.add_label(FieldType.Function.PREAMBLE, 8) + mb.add_label(FieldType.Function.SYNC, 8) + mb.add_label(FieldType.Function.LENGTH, 8) + mb.add_label(FieldType.Function.DST_ADDRESS, 16) + mb.add_label(FieldType.Function.SRC_ADDRESS, 16) + mb.add_label(FieldType.Function.SEQUENCE_NUMBER, 16) + + mb_ack = MessageTypeBuilder("ack") + mb_ack.add_label(FieldType.Function.PREAMBLE, 8) + mb_ack.add_label(FieldType.Function.SYNC, 8) + mb_ack.add_label(FieldType.Function.LENGTH, 8) + mb_ack.add_label(FieldType.Function.DST_ADDRESS, 16) + + alice = Participant("Alice", "A", "1234", color_index=0) + bob = Participant("Bob", "B", "5a9d", color_index=1) + + num_messages = 50 + pg = ProtocolGenerator([mb.message_type, mb_ack.message_type], + syncs_by_mt={mb.message_type: "0xbf", mb_ack.message_type: "0xbf"}, + little_endian=False) + for i in range(num_messages): + if i % 2 == 0: + source, dest = alice, bob + else: + source, dest = bob, alice + pg.generate_message(data="0xffff", source=source, destination=dest) + pg.generate_message(data="", source=dest, destination=source, message_type=mb_ack.message_type) + + self.save_protocol("proto_with_acks", pg) + + plt.subplot(2, 2, 1) + plt.title("All messages") + bitvectors = FormatFinder.get_bitvectors_from_messages(pg.protocol.messages) + h = Histogram(bitvectors) + h.subplot_on(plt) + + for i, (participant, bitvectors) in enumerate( + sorted(self.get_bitvectors_by_participant(pg.protocol.messages).items())): + plt.subplot(2, 2, i + 3) + plt.title("Messages with participant {} ({})".format(participant.shortname, len(bitvectors))) + Histogram(bitvectors).subplot_on(plt) + + if SHOW_PLOTS: + plt.show() diff --git a/tests/awre/__init__.py b/tests/awre/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/awre/test_address_engine.py b/tests/awre/test_address_engine.py new file mode 100644 index 0000000000..1bb9f2a767 --- /dev/null +++ b/tests/awre/test_address_engine.py @@ -0,0 +1,386 @@ +import random +from array import array + +import numpy as np + +from tests.awre.AWRETestCase import AWRETestCase +from tests.utils_testing import get_path_for_data_file +from urh.awre.FormatFinder import FormatFinder +from urh.awre.MessageTypeBuilder import MessageTypeBuilder +from urh.awre.ProtocolGenerator import ProtocolGenerator +from urh.awre.engines.AddressEngine import AddressEngine +from urh.signalprocessing.FieldType import FieldType +from urh.signalprocessing.Message import Message +from urh.signalprocessing.Participant import Participant +from urh.signalprocessing.ProtocolAnalyzer import ProtocolAnalyzer +from urh.util import util + + +class TestAddressEngine(AWRETestCase): + def setUp(self): + super().setUp() + self.alice = Participant("Alice", "A", address_hex="1234") + self.bob = Participant("Bob", "B", address_hex="cafe") + + def test_one_participant(self): + """ + Test a simple protocol with + preamble, sync and length field (8 bit) and some random data + + :return: + """ + mb = MessageTypeBuilder("simple_address_test") + mb.add_label(FieldType.Function.PREAMBLE, 8) + mb.add_label(FieldType.Function.SYNC, 16) + mb.add_label(FieldType.Function.LENGTH, 8) + 
mb.add_label(FieldType.Function.SRC_ADDRESS, 16) + + num_messages_by_data_length = {8: 5, 16: 10, 32: 15} + pg = ProtocolGenerator([mb.message_type], + syncs_by_mt={mb.message_type: "0x9a9d"}, + participants=[self.alice]) + for data_length, num_messages in num_messages_by_data_length.items(): + for i in range(num_messages): + pg.generate_message(data=pg.decimal_to_bits(22 * i, data_length), source=self.alice) + + #self.save_protocol("address_one_participant", pg) + + self.clear_message_types(pg.protocol.messages) + ff = FormatFinder(pg.protocol.messages) + + address_engine = AddressEngine(ff.hexvectors, ff.participant_indices) + address_dict = address_engine.find_addresses() + + self.assertEqual(len(address_dict), 0) + + def test_two_participants(self): + mb = MessageTypeBuilder("address_two_participants") + mb.add_label(FieldType.Function.PREAMBLE, 8) + mb.add_label(FieldType.Function.SYNC, 16) + mb.add_label(FieldType.Function.LENGTH, 8) + mb.add_label(FieldType.Function.SRC_ADDRESS, 16) + mb.add_label(FieldType.Function.DST_ADDRESS, 16) + + num_messages = 50 + + pg = ProtocolGenerator([mb.message_type], + syncs_by_mt={mb.message_type: "0x9a9d"}, + participants=[self.alice, self.bob]) + + for i in range(num_messages): + if i % 2 == 0: + source, destination = self.alice, self.bob + data_length = 8 + else: + source, destination = self.bob, self.alice + data_length = 16 + pg.generate_message(data=pg.decimal_to_bits(4 * i, data_length), source=source, destination=destination) + + #self.save_protocol("address_two_participants", pg) + + self.clear_message_types(pg.protocol.messages) + ff = FormatFinder(pg.protocol.messages) + + address_engine = AddressEngine(ff.hexvectors, ff.participant_indices) + address_dict = address_engine.find_addresses() + self.assertEqual(len(address_dict), 2) + addresses_1 = list(map(util.convert_numbers_to_hex_string, address_dict[0])) + addresses_2 = list(map(util.convert_numbers_to_hex_string, address_dict[1])) + self.assertIn(self.alice.address_hex, addresses_1) + self.assertIn(self.alice.address_hex, addresses_2) + self.assertIn(self.bob.address_hex, addresses_1) + self.assertIn(self.bob.address_hex, addresses_2) + + ff.known_participant_addresses.clear() + self.assertEqual(len(ff.known_participant_addresses), 0) + + ff.perform_iteration() + + self.assertEqual(len(ff.known_participant_addresses), 2) + self.assertIn(bytes([int(h, 16) for h in self.alice.address_hex]), + map(bytes, ff.known_participant_addresses.values())) + self.assertIn(bytes([int(h, 16) for h in self.bob.address_hex]), + map(bytes, ff.known_participant_addresses.values())) + + self.assertEqual(len(ff.message_types), 1) + mt = ff.message_types[0] + dst_addr = mt.get_first_label_with_type(FieldType.Function.DST_ADDRESS) + self.assertIsNotNone(dst_addr) + self.assertEqual(dst_addr.start, 32) + self.assertEqual(dst_addr.length, 16) + src_addr = mt.get_first_label_with_type(FieldType.Function.SRC_ADDRESS) + self.assertIsNotNone(src_addr) + self.assertEqual(src_addr.start, 48) + self.assertEqual(src_addr.length, 16) + + def test_two_participants_with_ack_messages(self): + mb = MessageTypeBuilder("data") + mb.add_label(FieldType.Function.PREAMBLE, 8) + mb.add_label(FieldType.Function.SYNC, 16) + mb.add_label(FieldType.Function.LENGTH, 8) + mb.add_label(FieldType.Function.DST_ADDRESS, 16) + mb.add_label(FieldType.Function.SRC_ADDRESS, 16) + mb_ack = MessageTypeBuilder("ack") + mb_ack.add_label(FieldType.Function.PREAMBLE, 8) + mb_ack.add_label(FieldType.Function.SYNC, 16) + 
mb_ack.add_label(FieldType.Function.LENGTH, 8) + mb_ack.add_label(FieldType.Function.DST_ADDRESS, 16) + + num_messages = 50 + + pg = ProtocolGenerator([mb.message_type, mb_ack.message_type], + syncs_by_mt={mb.message_type: "0x6768", mb_ack.message_type: "0x6768"}, + participants=[self.alice, self.bob]) + + random.seed(0) + for i in range(num_messages): + if i % 2 == 0: + source, destination = self.alice, self.bob + data_length = 8 + else: + source, destination = self.bob, self.alice + data_length = 16 + pg.generate_message(data=pg.decimal_to_bits(random.randint(0, 2 ** (data_length - 1)), data_length), + source=source, destination=destination) + pg.generate_message(data="", message_type=mb_ack.message_type, destination=source, source=destination) + + #self.save_protocol("address_two_participants_with_acks", pg) + + self.clear_message_types(pg.protocol.messages) + ff = FormatFinder(pg.protocol.messages) + address_engine = AddressEngine(ff.hexvectors, ff.participant_indices) + address_dict = address_engine.find_addresses() + self.assertEqual(len(address_dict), 2) + addresses_1 = list(map(util.convert_numbers_to_hex_string, address_dict[0])) + addresses_2 = list(map(util.convert_numbers_to_hex_string, address_dict[1])) + self.assertIn(self.alice.address_hex, addresses_1) + self.assertIn(self.alice.address_hex, addresses_2) + self.assertIn(self.bob.address_hex, addresses_1) + self.assertIn(self.bob.address_hex, addresses_2) + + ff.known_participant_addresses.clear() + ff.perform_iteration() + self.assertEqual(len(ff.message_types), 2) + mt = ff.message_types[1] + dst_addr = mt.get_first_label_with_type(FieldType.Function.DST_ADDRESS) + self.assertIsNotNone(dst_addr) + self.assertEqual(dst_addr.start, 32) + self.assertEqual(dst_addr.length, 16) + src_addr = mt.get_first_label_with_type(FieldType.Function.SRC_ADDRESS) + self.assertIsNotNone(src_addr) + self.assertEqual(src_addr.start, 48) + self.assertEqual(src_addr.length, 16) + + mt = ff.message_types[0] + dst_addr = mt.get_first_label_with_type(FieldType.Function.DST_ADDRESS) + self.assertIsNotNone(dst_addr) + self.assertEqual(dst_addr.start, 32) + self.assertEqual(dst_addr.length, 16) + + def test_two_participants_with_ack_messages_and_type(self): + mb = MessageTypeBuilder("data") + mb.add_label(FieldType.Function.PREAMBLE, 8) + mb.add_label(FieldType.Function.SYNC, 16) + mb.add_label(FieldType.Function.LENGTH, 8) + mb.add_label(FieldType.Function.TYPE, 8) + mb.add_label(FieldType.Function.DST_ADDRESS, 16) + mb.add_label(FieldType.Function.SRC_ADDRESS, 16) + mb_ack = MessageTypeBuilder("ack") + mb_ack.add_label(FieldType.Function.PREAMBLE, 8) + mb_ack.add_label(FieldType.Function.SYNC, 16) + mb_ack.add_label(FieldType.Function.LENGTH, 8) + mb_ack.add_label(FieldType.Function.DST_ADDRESS, 16) + + num_messages = 50 + + pg = ProtocolGenerator([mb.message_type, mb_ack.message_type], + syncs_by_mt={mb.message_type: "0x6768", mb_ack.message_type: "0x6768"}, + participants=[self.alice, self.bob]) + + random.seed(0) + for i in range(num_messages): + if i % 2 == 0: + source, destination = self.alice, self.bob + data_length = 8 + else: + source, destination = self.bob, self.alice + data_length = 16 + pg.generate_message(data=pg.decimal_to_bits(random.randint(0, 2 ** (data_length - 1)), data_length), + source=source, destination=destination) + pg.generate_message(data="", message_type=mb_ack.message_type, destination=source, source=destination) + + #self.save_protocol("address_two_participants_with_acks_and_types", pg) + + 
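These tests compare participant addresses by converting the nibble arrays held by the engine into hex strings with util.convert_numbers_to_hex_string, the helper added to util.py in this change. A quick illustration of its behavior, with example values matching the addresses used in these tests:

import numpy as np
from urh.util import util

# One entry per nibble (0-15); each value maps to its hex digit.
assert util.convert_numbers_to_hex_string(np.array([1, 2, 3, 4])) == "1234"
assert util.convert_numbers_to_hex_string(np.array([12, 10, 15, 14])) == "cafe"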
self.clear_message_types(pg.protocol.messages) + ff = FormatFinder(pg.protocol.messages) + address_engine = AddressEngine(ff.hexvectors, ff.participant_indices) + address_dict = address_engine.find_addresses() + self.assertEqual(len(address_dict), 2) + addresses_1 = list(map(util.convert_numbers_to_hex_string, address_dict[0])) + addresses_2 = list(map(util.convert_numbers_to_hex_string, address_dict[1])) + self.assertIn(self.alice.address_hex, addresses_1) + self.assertIn(self.alice.address_hex, addresses_2) + self.assertIn(self.bob.address_hex, addresses_1) + self.assertIn(self.bob.address_hex, addresses_2) + + ff.known_participant_addresses.clear() + ff.perform_iteration() + self.assertEqual(len(ff.message_types), 2) + mt = ff.message_types[1] + dst_addr = mt.get_first_label_with_type(FieldType.Function.DST_ADDRESS) + self.assertIsNotNone(dst_addr) + self.assertEqual(dst_addr.start, 40) + self.assertEqual(dst_addr.length, 16) + src_addr = mt.get_first_label_with_type(FieldType.Function.SRC_ADDRESS) + self.assertIsNotNone(src_addr) + self.assertEqual(src_addr.start, 56) + self.assertEqual(src_addr.length, 16) + + mt = ff.message_types[0] + dst_addr = mt.get_first_label_with_type(FieldType.Function.DST_ADDRESS) + self.assertIsNotNone(dst_addr) + self.assertEqual(dst_addr.start, 32) + self.assertEqual(dst_addr.length, 16) + + def test_three_participants_with_ack(self): + alice = Participant("Alice", address_hex="1337") + bob = Participant("Bob", address_hex="4711") + carl = Participant("Carl", address_hex="cafe") + + mb = MessageTypeBuilder("data") + mb.add_label(FieldType.Function.PREAMBLE, 16) + mb.add_label(FieldType.Function.SYNC, 16) + mb.add_label(FieldType.Function.LENGTH, 8) + mb.add_label(FieldType.Function.SRC_ADDRESS, 16) + mb.add_label(FieldType.Function.DST_ADDRESS, 16) + mb.add_label(FieldType.Function.SEQUENCE_NUMBER, 16) + + mb_ack = MessageTypeBuilder("ack") + mb_ack.add_label(FieldType.Function.PREAMBLE, 16) + mb_ack.add_label(FieldType.Function.SYNC, 16) + mb_ack.add_label(FieldType.Function.LENGTH, 8) + mb_ack.add_label(FieldType.Function.DST_ADDRESS, 16) + + pg = ProtocolGenerator([mb.message_type, mb_ack.message_type], + syncs_by_mt={mb.message_type: "0x9a7d", mb_ack.message_type: "0x9a7d"}, + preambles_by_mt={mb.message_type: "10" * 8, mb_ack.message_type: "10" * 8}, + participants=[alice, bob, carl]) + + i = -1 + while len(pg.protocol.messages) < 20: + i += 1 + source = pg.participants[i % len(pg.participants)] + destination = pg.participants[(i + 1) % len(pg.participants)] + if i % 2 == 0: + data_bytes = 8 + else: + data_bytes = 16 + + data = "".join(random.choice(["0", "1"]) for _ in range(data_bytes * 8)) + pg.generate_message(data=data, source=source, destination=destination) + + if "ack" in (msg_type.name for msg_type in pg.protocol.message_types): + pg.generate_message(message_type=1, data="", source=destination, destination=source) + + self.clear_message_types(pg.protocol.messages) + ff = FormatFinder(pg.protocol.messages) + ff.known_participant_addresses.clear() + self.assertEqual(len(ff.known_participant_addresses), 0) + ff.run() + + # Since there are ACKS in this protocol, the engine must be able to assign the correct participant addresses + # IN CORRECT ORDER! 
+ self.assertEqual(util.convert_numbers_to_hex_string(ff.known_participant_addresses[0]), "1337") + self.assertEqual(util.convert_numbers_to_hex_string(ff.known_participant_addresses[1]), "4711") + self.assertEqual(util.convert_numbers_to_hex_string(ff.known_participant_addresses[2]), "cafe") + + def test_protocol_with_acks_and_checksum(self): + proto_file = get_path_for_data_file("ack_frames_with_crc.proto.xml") + protocol = ProtocolAnalyzer(signal=None, filename=proto_file) + protocol.from_xml_file(filename=proto_file, read_bits=True) + + self.clear_message_types(protocol.messages) + + ff = FormatFinder(protocol.messages) + ff.known_participant_addresses.clear() + + ff.run() + self.assertEqual(util.convert_numbers_to_hex_string(ff.known_participant_addresses[0]), "1337") + self.assertEqual(util.convert_numbers_to_hex_string(ff.known_participant_addresses[1]), "4711") + + for mt in ff.message_types: + preamble = mt.get_first_label_with_type(FieldType.Function.PREAMBLE) + self.assertEqual(preamble.start, 0) + self.assertEqual(preamble.length, 16) + sync = mt.get_first_label_with_type(FieldType.Function.SYNC) + self.assertEqual(sync.start, 16) + self.assertEqual(sync.length, 16) + length = mt.get_first_label_with_type(FieldType.Function.LENGTH) + self.assertEqual(length.start, 32) + self.assertEqual(length.length, 8) + + def test_address_engine_performance(self): + ff, messages = self.get_format_finder_from_protocol_file("35_messages.proto.xml", return_messages=True) + + engine = AddressEngine(ff.hexvectors, ff.participant_indices) + engine.find() + + def test_paper_example(self): + alice = Participant("Alice", "A") + bob = Participant("Bob", "B") + participants = [alice, bob] + msg1 = Message.from_plain_hex_str("aabb1234") + msg1.participant = alice + msg2 = Message.from_plain_hex_str("aabb6789") + msg2.participant = alice + msg3 = Message.from_plain_hex_str("bbaa4711") + msg3.participant = bob + msg4 = Message.from_plain_hex_str("bbaa1337") + msg4.participant = bob + + protocol = ProtocolAnalyzer(None) + protocol.messages.extend([msg1, msg2, msg3, msg4]) + #self.save_protocol("paper_example", protocol) + + bitvectors = FormatFinder.get_bitvectors_from_messages(protocol.messages) + hexvectors = FormatFinder.get_hexvectors(bitvectors) + address_engine = AddressEngine(hexvectors, participant_indices=[participants.index(msg.participant) for msg in + protocol.messages]) + + def test_find_common_sub_sequence(self): + from urh.cythonext import awre_util + str1 = "0612345678" + str2 = "0756781234" + + seq1 = np.array(list(map(int, str1)), dtype=np.uint8, order="C") + seq2 = np.array(list(map(int, str2)), dtype=np.uint8, order="C") + + indices = awre_util.find_longest_common_sub_sequence_indices(seq1, seq2) + self.assertEqual(len(indices), 2) + for ind in indices: + s = str1[slice(*ind)] + self.assertIn(s, ("5678", "1234")) + self.assertIn(s, str1) + self.assertIn(s, str2) + + def test_find_first_occurrence(self): + from urh.cythonext import awre_util + str1 = "00" * 100 + "1234500012345" + "00" * 100 + str2 = "12345" + + seq1 = np.array(list(map(int, str1)), dtype=np.uint8, order="C") + seq2 = np.array(list(map(int, str2)), dtype=np.uint8, order="C") + indices = awre_util.find_occurrences(seq1, seq2) + self.assertEqual(len(indices), 2) + index = indices[0] + self.assertEqual(str1[index:index + len(str2)], str2) + + # Test with ignoring indices + indices = awre_util.find_occurrences(seq1, seq2, array("L", list(range(0, 205)))) + self.assertEqual(len(indices), 1) + + # Test with ignoring indices + 
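+        # str1 contains "12345" at index 200 and again at index 208; ignoring the
+        # first 210 positions therefore excludes both occurrences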
indices = awre_util.find_occurrences(seq1, seq2, array("L", list(range(0, 210)))) + self.assertEqual(len(indices), 0) + + self.assertEqual(awre_util.find_occurrences(seq1, np.ones(10, dtype=np.uint8)), []) diff --git a/tests/awre/test_awre_preprocessing.py b/tests/awre/test_awre_preprocessing.py new file mode 100644 index 0000000000..45ad40b4b8 --- /dev/null +++ b/tests/awre/test_awre_preprocessing.py @@ -0,0 +1,256 @@ +import random + +from tests.awre.AWRETestCase import AWRETestCase +from urh.awre.FormatFinder import FormatFinder +from urh.awre.MessageTypeBuilder import MessageTypeBuilder +from urh.awre.Preprocessor import Preprocessor +from urh.awre.ProtocolGenerator import ProtocolGenerator +from urh.signalprocessing.FieldType import FieldType +from urh.signalprocessing.Message import Message +from urh.signalprocessing.Participant import Participant +from urh.signalprocessing.ProtocolAnalyzer import ProtocolAnalyzer +import numpy as np + + +class TestAWREPreprocessing(AWRETestCase): + def test_very_simple_sync_word_finding(self): + preamble = "10101010" + sync = "1101" + + pg = self.build_protocol_generator(preamble_syncs=[(preamble, sync)], + num_messages=(20,), + data=(lambda i: 10 * i,)) + + preprocessor = Preprocessor([np.array(msg.plain_bits, dtype=np.uint8) for msg in pg.protocol.messages]) + + possible_syncs = preprocessor.find_possible_syncs() + #self.save_protocol("very_simple_sync_test", pg) + self.assertGreaterEqual(len(possible_syncs), 1) + self.assertEqual(preprocessor.find_possible_syncs()[0], sync) + + def test_simple_sync_word_finding(self): + preamble = "10101010" + sync = "1001" + + pg = self.build_protocol_generator(preamble_syncs=[(preamble, sync), (preamble + "1010", sync)], + num_messages=(20, 5), + data=(lambda i: 10 * i, lambda i: 22 * i)) + + preprocessor = Preprocessor([np.array(msg.plain_bits, dtype=np.uint8) for msg in pg.protocol.messages]) + + possible_syncs = preprocessor.find_possible_syncs() + #self.save_protocol("simple_sync_test", pg) + self.assertGreaterEqual(len(possible_syncs), 1) + self.assertEqual(preprocessor.find_possible_syncs()[0], sync) + + def test_sync_word_finding_odd_preamble(self): + preamble = "0101010" + sync = "1101" + pg = self.build_protocol_generator(preamble_syncs=[(preamble, sync), (preamble + "10", sync)], + num_messages=(20, 5), + data=(lambda i: 10 * i, lambda i: i)) + + # If we have a odd preamble length, the last bit of the preamble is counted to the sync + preprocessor = Preprocessor([np.array(msg.plain_bits, dtype=np.uint8) for msg in pg.protocol.messages]) + possible_syncs = preprocessor.find_possible_syncs() + + #self.save_protocol("odd_preamble", pg) + self.assertEqual(preamble[-1] + sync[:-1], possible_syncs[0]) + + def test_sync_word_finding_special_preamble(self): + preamble = "111001110011100" + sync = "0110" + pg = self.build_protocol_generator(preamble_syncs=[(preamble, sync), (preamble + "10", sync)], + num_messages=(20, 5), + data=(lambda i: 10 * i, lambda i: i)) + + # If we have a odd preamble length, the last bit of the preamble is counted to the sync + preprocessor = Preprocessor([np.array(msg.plain_bits, dtype=np.uint8) for msg in pg.protocol.messages]) + possible_syncs = preprocessor.find_possible_syncs() + + #self.save_protocol("special_preamble", pg) + self.assertEqual(sync, possible_syncs[0]) + + def test_sync_word_finding_errored_preamble(self): + preamble = "00010101010" # first bits are wrong + sync = "0110" + pg = self.build_protocol_generator(preamble_syncs=[(preamble, sync), (preamble + "10", 
sync)], + num_messages=(20, 5), + data=(lambda i: 10 * i, lambda i: i)) + + # If we have a odd preamble length, the last bit of the preamble is counted to the sync + preprocessor = Preprocessor([np.array(msg.plain_bits, dtype=np.uint8) for msg in pg.protocol.messages]) + possible_syncs = preprocessor.find_possible_syncs() + + #self.save_protocol("errored_preamble", pg) + self.assertIn(preamble[-1] + sync[:-1], possible_syncs) + + def test_sync_word_finding_with_two_sync_words(self): + preamble = "0xaaaa" + sync1, sync2 = "0x1234", "0xcafe" + pg = self.build_protocol_generator(preamble_syncs=[(preamble, sync1), (preamble, sync2)], + num_messages=(15, 10), + data=(lambda i: 12 * i, lambda i: 16 * i)) + preprocessor = Preprocessor([np.array(msg.plain_bits, dtype=np.uint8) for msg in pg.protocol.messages]) + possible_syncs = preprocessor.find_possible_syncs() + #self.save_protocol("two_syncs", pg) + self.assertGreaterEqual(len(possible_syncs), 2) + self.assertIn(ProtocolGenerator.to_bits(sync1), possible_syncs) + self.assertIn(ProtocolGenerator.to_bits(sync2), possible_syncs) + + def test_multiple_sync_words(self): + hex_messages = [ + "aaS1234", + "aaScafe", + "aaSdead", + "aaSbeef", + ] + + for i in range(1, 256): + messages = [] + sync = "{0:02x}".format(i) + if sync.startswith("a"): + continue + + for msg in hex_messages: + messages.append(Message.from_plain_hex_str(msg.replace("S", sync))) + + for i in range(1, len(messages)): + messages[i].message_type = messages[0].message_type + + ff = FormatFinder(messages) + ff.run() + + self.assertEqual(len(ff.message_types), 1, msg=sync) + + preamble = ff.message_types[0].get_first_label_with_type(FieldType.Function.PREAMBLE) + self.assertEqual(preamble.start, 0, msg=sync) + self.assertEqual(preamble.length, 8, msg=sync) + + sync = ff.message_types[0].get_first_label_with_type(FieldType.Function.SYNC) + self.assertEqual(sync.start, 8, msg=sync) + self.assertEqual(sync.length, 8, msg=sync) + + def test_sync_word_finding_varying_message_length(self): + hex_messages = [ + "aaaa9a7d0f1337471100009a44ebdd13517bf9", + "aaaa9a7d4747111337000134a4473c002b909630b11df37e34728c79c60396176aff2b5384e82f31511581d0cbb4822ad1b6734e2372ad5cf4af4c9d6b067e5f7ec359ec443c3b5ddc7a9e", + "aaaa9a7d0f13374711000205ee081d26c86b8c", + "aaaa9a7d474711133700037cae4cda789885f88f5fb29adc9acf954cb2850b9d94e7f3b009347c466790e89f2bcd728987d4670690861bbaa120f71f14d4ef8dc738a6d7c30e7d2143c267", + "aaaa9a7d0f133747110004c2906142300427f3" + ] + + messages = [Message.from_plain_hex_str(hex_msg) for hex_msg in hex_messages] + for i in range(1, len(messages)): + messages[i].message_type = messages[0].message_type + + ff = FormatFinder(messages) + ff.run() + + self.assertEqual(len(ff.message_types), 1) + preamble = ff.message_types[0].get_first_label_with_type(FieldType.Function.PREAMBLE) + self.assertEqual(preamble.start, 0) + self.assertEqual(preamble.length, 16) + + sync = ff.message_types[0].get_first_label_with_type(FieldType.Function.SYNC) + self.assertEqual(sync.start, 16) + self.assertEqual(sync.length, 16) + + def test_sync_word_finding_common_prefix(self): + """ + Messages are very similiar (odd and even ones are the same) + However, they do not have two different sync words! 
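+        Both sync word candidates start with the real sync word and only differ
+        in the bits that follow it.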
+ The algorithm needs to check for a common prefix of the two found sync words + + :return: + """ + sync = "0x1337" + num_messages = 10 + + alice = Participant("Alice", address_hex="dead01") + bob = Participant("Bob", address_hex="beef24") + + mb = MessageTypeBuilder("protocol_with_one_message_type") + mb.add_label(FieldType.Function.PREAMBLE, 72) + mb.add_label(FieldType.Function.SYNC, 16) + mb.add_label(FieldType.Function.LENGTH, 8) + mb.add_label(FieldType.Function.SRC_ADDRESS, 24) + mb.add_label(FieldType.Function.DST_ADDRESS, 24) + mb.add_label(FieldType.Function.SEQUENCE_NUMBER, 16) + + pg = ProtocolGenerator([mb.message_type], + syncs_by_mt={mb.message_type: "0x1337"}, + preambles_by_mt={mb.message_type: "10" * 36}, + participants=[alice, bob]) + + random.seed(0) + for i in range(num_messages): + if i % 2 == 0: + source, destination = alice, bob + data_length = 8 + else: + source, destination = bob, alice + data_length = 16 + pg.generate_message(data=pg.decimal_to_bits(random.randint(0, 2 ** (data_length - 1)), data_length), + source=source, destination=destination) + + preprocessor = Preprocessor([np.array(msg.plain_bits, dtype=np.uint8) for msg in pg.protocol.messages]) + possible_syncs = preprocessor.find_possible_syncs() + #self.save_protocol("sync_by_common_prefix", pg) + self.assertEqual(len(possible_syncs), 1) + + # +0000 is okay, because this will get fixed by correction in FormatFinder + self.assertIn(possible_syncs[0], [ProtocolGenerator.to_bits(sync), ProtocolGenerator.to_bits(sync) + "0000"]) + + def test_with_given_preamble_and_sync(self): + preamble = "10101010" + sync = "10011" + pg = self.build_protocol_generator(preamble_syncs=[(preamble, sync)], + num_messages=(20,), + data=(lambda i: 10 * i,)) + + # If we have a odd preamble length, the last bit of the preamble is counted to the sync + preprocessor = Preprocessor([np.array(msg.plain_bits, dtype=np.uint8) for msg in pg.protocol.messages], + existing_message_types={i: msg.message_type for i, msg in + enumerate(pg.protocol.messages)}) + preamble_starts, preamble_lengths, sync_len = preprocessor.preprocess() + + #self.save_protocol("given_preamble", pg) + + self.assertTrue(all(preamble_start == 0 for preamble_start in preamble_starts)) + self.assertTrue(all(preamble_length == len(preamble) for preamble_length in preamble_lengths)) + self.assertEqual(sync_len, len(sync)) + + @staticmethod + def build_protocol_generator(preamble_syncs: list, num_messages: tuple, data: tuple) -> ProtocolGenerator: + message_types = [] + preambles_by_mt = dict() + syncs_by_mt = dict() + + assert len(preamble_syncs) == len(num_messages) == len(data) + + for i, (preamble, sync_word) in enumerate(preamble_syncs): + assert isinstance(preamble, str) + assert isinstance(sync_word, str) + + preamble, sync_word = map(ProtocolGenerator.to_bits, (preamble, sync_word)) + + mb = MessageTypeBuilder("message type #{0}".format(i)) + mb.add_label(FieldType.Function.PREAMBLE, len(preamble)) + mb.add_label(FieldType.Function.SYNC, len(sync_word)) + + message_types.append(mb.message_type) + preambles_by_mt[mb.message_type] = preamble + syncs_by_mt[mb.message_type] = sync_word + + pg = ProtocolGenerator(message_types, preambles_by_mt=preambles_by_mt, syncs_by_mt=syncs_by_mt) + for i, msg_type in enumerate(message_types): + for j in range(num_messages[i]): + if callable(data[i]): + msg_data = pg.decimal_to_bits(data[i](j), num_bits=8) + else: + msg_data = data[i] + + pg.generate_message(message_type=msg_type, data=msg_data) + + return pg diff --git 
a/tests/awre/test_awre_real_protocols.py b/tests/awre/test_awre_real_protocols.py new file mode 100644 index 0000000000..944c5bc752 --- /dev/null +++ b/tests/awre/test_awre_real_protocols.py @@ -0,0 +1,149 @@ +from tests.awre.AWRETestCase import AWRETestCase +from tests.utils_testing import get_path_for_data_file +from urh.awre.CommonRange import CommonRange +from urh.awre.FormatFinder import FormatFinder +from urh.awre.Preprocessor import Preprocessor +from urh.awre.ProtocolGenerator import ProtocolGenerator +from urh.signalprocessing.FieldType import FieldType +from urh.signalprocessing.Message import Message +from urh.signalprocessing.MessageType import MessageType +from urh.signalprocessing.Participant import Participant +from urh.signalprocessing.ProtocolAnalyzer import ProtocolAnalyzer +import numpy as np + +class TestAWRERealProtocols(AWRETestCase): + def setUp(self): + super().setUp() + alice = Participant("Alice", "A") + bob = Participant("Bob", "B") + self.participants = [alice, bob] + + def test_format_finding_enocean(self): + enocean_protocol = ProtocolAnalyzer(None) + with open(get_path_for_data_file("enocean_bits.txt")) as f: + for line in f: + enocean_protocol.messages.append(Message.from_plain_bits_str(line.replace("\n", ""))) + enocean_protocol.messages[-1].message_type = enocean_protocol.default_message_type + + ff = FormatFinder(enocean_protocol.messages) + ff.perform_iteration() + + message_types = ff.message_types + self.assertEqual(len(message_types), 1) + + preamble = message_types[0].get_first_label_with_type(FieldType.Function.PREAMBLE) + self.assertEqual(preamble.start, 0) + self.assertEqual(preamble.length, 8) + + sync = message_types[0].get_first_label_with_type(FieldType.Function.SYNC) + self.assertEqual(sync.start, 8) + self.assertEqual(sync.length, 4) + + checksum = message_types[0].get_first_label_with_type(FieldType.Function.CHECKSUM) + self.assertEqual(checksum.start, 56) + self.assertEqual(checksum.length, 4) + + self.assertIsNone(message_types[0].get_first_label_with_type(FieldType.Function.SRC_ADDRESS)) + self.assertIsNone(message_types[0].get_first_label_with_type(FieldType.Function.DST_ADDRESS)) + self.assertIsNone(message_types[0].get_first_label_with_type(FieldType.Function.LENGTH)) + self.assertIsNone(message_types[0].get_first_label_with_type(FieldType.Function.SEQUENCE_NUMBER)) + + def test_format_finding_rwe(self): + ff, messages = self.get_format_finder_from_protocol_file("rwe.proto.xml", return_messages=True) + ff.run() + + sync1, sync2 = "0x9a7d9a7d", "0x67686768" + + preprocessor = Preprocessor([np.array(msg.plain_bits, dtype=np.uint8) for msg in messages]) + possible_syncs = preprocessor.find_possible_syncs() + self.assertIn(ProtocolGenerator.to_bits(sync1), possible_syncs) + self.assertIn(ProtocolGenerator.to_bits(sync2), possible_syncs) + + ack_messages = (3, 5, 7, 9, 11, 13, 15, 17, 20) + ack_message_type = next(mt for mt, messages in ff.existing_message_types.items() if ack_messages[0] in messages) + self.assertTrue(all(ack_msg in ff.existing_message_types[ack_message_type] for ack_msg in ack_messages)) + + for mt in ff.message_types: + preamble = mt.get_first_label_with_type(FieldType.Function.PREAMBLE) + self.assertEqual(preamble.start, 0) + self.assertEqual(preamble.length, 32) + + sync = mt.get_first_label_with_type(FieldType.Function.SYNC) + self.assertEqual(sync.start, 32) + self.assertEqual(sync.length, 32) + + length = mt.get_first_label_with_type(FieldType.Function.LENGTH) + self.assertEqual(length.start, 64) + 
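+            # 32 bit preamble + 32 bit sync => the length field starts at bit 64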
self.assertEqual(length.length, 8) + + dst = mt.get_first_label_with_type(FieldType.Function.DST_ADDRESS) + self.assertEqual(dst.length, 24) + + if mt == ack_message_type or 1 in ff.existing_message_types[mt]: + self.assertEqual(dst.start, 72) + else: + self.assertEqual(dst.start, 88) + + if mt != ack_message_type and 1 not in ff.existing_message_types[mt]: + src = mt.get_first_label_with_type(FieldType.Function.SRC_ADDRESS) + self.assertEqual(src.start, 112) + self.assertEqual(src.length, 24) + elif 1 in ff.existing_message_types[mt]: + # long ack + src = mt.get_first_label_with_type(FieldType.Function.SRC_ADDRESS) + self.assertEqual(src.start, 96) + self.assertEqual(src.length, 24) + + crc = mt.get_first_label_with_type(FieldType.Function.CHECKSUM) + self.assertIsNotNone(crc) + + def test_homematic(self): + proto_file = get_path_for_data_file("homematic.proto.xml") + protocol = ProtocolAnalyzer(signal=None, filename=proto_file) + protocol.message_types = [] + protocol.from_xml_file(filename=proto_file, read_bits=True) + # prevent interfering with preassinged labels + protocol.message_types = [MessageType("Default")] + + participants = sorted({msg.participant for msg in protocol.messages}) + + self.clear_message_types(protocol.messages) + ff = FormatFinder(protocol.messages, participants=participants) + ff.known_participant_addresses.clear() + ff.perform_iteration() + + self.assertGreater(len(ff.message_types), 0) + + for i, message_type in enumerate(ff.message_types): + preamble = message_type.get_first_label_with_type(FieldType.Function.PREAMBLE) + self.assertEqual(preamble.start, 0) + self.assertEqual(preamble.length, 32) + + sync = message_type.get_first_label_with_type(FieldType.Function.SYNC) + self.assertEqual(sync.start, 32) + self.assertEqual(sync.length, 32) + + length = message_type.get_first_label_with_type(FieldType.Function.LENGTH) + self.assertEqual(length.start, 64) + self.assertEqual(length.length, 8) + + seq = message_type.get_first_label_with_type(FieldType.Function.SEQUENCE_NUMBER) + self.assertEqual(seq.start, 72) + self.assertEqual(seq.length, 8) + + src = message_type.get_first_label_with_type(FieldType.Function.SRC_ADDRESS) + self.assertEqual(src.start, 96) + self.assertEqual(src.length, 24) + + dst = message_type.get_first_label_with_type(FieldType.Function.DST_ADDRESS) + self.assertEqual(dst.start, 120) + self.assertEqual(dst.length, 24) + + checksum = message_type.get_first_label_with_type(FieldType.Function.CHECKSUM) + self.assertEqual(checksum.length, 16) + self.assertIn("CC1101", checksum.checksum.caption) + + for msg_index in ff.existing_message_types[message_type]: + msg_len = len(protocol.messages[msg_index]) + self.assertEqual(checksum.start, msg_len-16) + self.assertEqual(checksum.end, msg_len) diff --git a/tests/awre/test_checksum_engine.py b/tests/awre/test_checksum_engine.py new file mode 100644 index 0000000000..97cd31ebcd --- /dev/null +++ b/tests/awre/test_checksum_engine.py @@ -0,0 +1,102 @@ +import array + +import numpy as np + +from tests.awre.AWRETestCase import AWRETestCase +from urh.awre.CommonRange import ChecksumRange +from urh.awre.FormatFinder import FormatFinder +from urh.awre.MessageTypeBuilder import MessageTypeBuilder +from urh.awre.ProtocolGenerator import ProtocolGenerator +from urh.awre.engines.ChecksumEngine import ChecksumEngine +from urh.signalprocessing.FieldType import FieldType +from urh.util import util +from urh.util.GenericCRC import GenericCRC +from urh.cythonext import util as c_util + +class 
TestChecksumEngine(AWRETestCase): + def test_find_crc8(self): + messages = ["aabbcc7d", "abcdee24", "dacafe33"] + message_bits = [np.array(msg, dtype=np.uint8) for msg in map(util.hex2bit, messages)] + + checksum_engine = ChecksumEngine(message_bits, n_gram_length=8) + result = checksum_engine.find() + self.assertEqual(len(result), 1) + checksum_range = result[0] # type: ChecksumRange + self.assertEqual(checksum_range.length, 8) + self.assertEqual(checksum_range.start, 24) + + reference = GenericCRC() + reference.set_polynomial_from_hex("0x07") + self.assertEqual(checksum_range.crc.polynomial, reference.polynomial) + + self.assertEqual(checksum_range.message_indices, {0, 1, 2}) + + def test_find_crc16(self): + messages = ["12345678347B", "abcdefffABBD", "cafe1337CE12"] + message_bits = [np.array(msg, dtype=np.uint8) for msg in map(util.hex2bit, messages)] + + checksum_engine = ChecksumEngine(message_bits, n_gram_length=8) + result = checksum_engine.find() + self.assertEqual(len(result), 1) + checksum_range = result[0] # type: ChecksumRange + self.assertEqual(checksum_range.start, 32) + self.assertEqual(checksum_range.length, 16) + + reference = GenericCRC() + reference.set_polynomial_from_hex("0x8005") + self.assertEqual(checksum_range.crc.polynomial, reference.polynomial) + + self.assertEqual(checksum_range.message_indices, {0, 1, 2}) + + def test_find_crc32(self): + messages = ["deadcafe5D7F3F5A", "47111337E3319242", "beefaffe0DCD0E15"] + message_bits = [np.array(msg, dtype=np.uint8) for msg in map(util.hex2bit, messages)] + + checksum_engine = ChecksumEngine(message_bits, n_gram_length=8) + result = checksum_engine.find() + self.assertEqual(len(result), 1) + checksum_range = result[0] # type: ChecksumRange + self.assertEqual(checksum_range.start, 32) + self.assertEqual(checksum_range.length, 32) + + reference = GenericCRC() + reference.set_polynomial_from_hex("0x04C11DB7") + self.assertEqual(checksum_range.crc.polynomial, reference.polynomial) + + self.assertEqual(checksum_range.message_indices, {0, 1, 2}) + + def test_find_generated_crc16(self): + mb = MessageTypeBuilder("data") + mb.add_label(FieldType.Function.PREAMBLE, 8) + mb.add_label(FieldType.Function.SYNC, 16) + mb.add_label(FieldType.Function.LENGTH, 8) + mb.add_label(FieldType.Function.DATA, 32) + mb.add_checksum_label(16, GenericCRC.from_standard_checksum("CRC16 CCITT")) + + mb2 = MessageTypeBuilder("data2") + mb2.add_label(FieldType.Function.PREAMBLE, 8) + mb2.add_label(FieldType.Function.SYNC, 16) + mb2.add_label(FieldType.Function.LENGTH, 8) + mb2.add_label(FieldType.Function.DATA, 16) + + mb2.add_checksum_label(16, GenericCRC.from_standard_checksum("CRC16 CCITT")) + + pg = ProtocolGenerator([mb.message_type, mb2.message_type], syncs_by_mt={mb.message_type: "0x1234", mb2.message_type: "0x1234"}) + + num_messages = 5 + + for i in range(num_messages): + pg.generate_message(data="{0:032b}".format(i), message_type=mb.message_type) + pg.generate_message(data="{0:016b}".format(i), message_type=mb2.message_type) + + #self.save_protocol("crc16_test", pg) + self.clear_message_types(pg.protocol.messages) + + ff = FormatFinder(pg.protocol.messages) + ff.run() + + self.assertEqual(len(ff.message_types), 2) + for mt in ff.message_types: + checksum_label = mt.get_first_label_with_type(FieldType.Function.CHECKSUM) + self.assertEqual(checksum_label.length, 16) + self.assertEqual(checksum_label.checksum.caption, "CRC16 CCITT") diff --git a/tests/awre/test_common_range.py b/tests/awre/test_common_range.py new file mode 100644 index 
0000000000..2d6517861b --- /dev/null +++ b/tests/awre/test_common_range.py @@ -0,0 +1,35 @@ +import unittest + +from urh.awre.CommonRange import CommonRange + + +class TestCommonRange(unittest.TestCase): + def test_ensure_not_overlaps(self): + test_range = CommonRange(start=4, length=8, value="12345678") + self.assertEqual(test_range.end, 11) + + # no overlapping + self.assertEqual(test_range, test_range.ensure_not_overlaps(0, 3)[0]) + self.assertEqual(test_range, test_range.ensure_not_overlaps(20, 24)[0]) + + # overlapping on left + result = test_range.ensure_not_overlaps(2, 6)[0] + self.assertEqual(result.start, 6) + self.assertEqual(result.end, 11) + + # overlapping on right + result = test_range.ensure_not_overlaps(6, 14)[0] + self.assertEqual(result.start, 4) + self.assertEqual(result.end, 5) + + # full overlapping + self.assertEqual(len(test_range.ensure_not_overlaps(3, 14)), 0) + + # overlapping in the middle + result = test_range.ensure_not_overlaps(6, 9) + self.assertEqual(len(result), 2) + left, right = result[0], result[1] + self.assertEqual(left.start, 4) + self.assertEqual(left.end, 5) + self.assertEqual(right.start, 10) + self.assertEqual(right.end, 11) diff --git a/tests/awre/test_format_finder.py b/tests/awre/test_format_finder.py new file mode 100644 index 0000000000..191bc74186 --- /dev/null +++ b/tests/awre/test_format_finder.py @@ -0,0 +1,102 @@ +import numpy as np + +from tests.awre.AWRETestCase import AWRETestCase +from urh.awre.CommonRange import CommonRange, CommonRangeContainer +from urh.awre.FormatFinder import FormatFinder + + +class TestFormatFinder(AWRETestCase): + def test_create_message_types_1(self): + rng1 = CommonRange(0, 8, "1" * 8, score=1, field_type="Length") + rng1.message_indices = {0, 1, 2} + rng2 = CommonRange(8, 8, "1" * 8, score=1, field_type="Address") + rng2.message_indices = {0, 1, 2} + + message_types = FormatFinder.create_common_range_containers({rng1, rng2}) + self.assertEqual(len(message_types), 1) + + expected = CommonRangeContainer([rng1, rng2], message_indices={0, 1, 2}) + self.assertEqual(message_types[0], expected) + + def test_create_message_types_2(self): + rng1 = CommonRange(0, 8, "1" * 8, score=1, field_type="Length") + rng1.message_indices = {0, 2, 4, 6, 8, 12} + rng2 = CommonRange(8, 8, "1" * 8, score=1, field_type="Address") + rng2.message_indices = {1, 2, 3, 4, 5, 12} + rng3 = CommonRange(16, 8, "1" * 8, score=1, field_type="Seq") + rng3.message_indices = {1, 3, 5, 7, 12} + + message_types = FormatFinder.create_common_range_containers({rng1, rng2, rng3}) + expected1 = CommonRangeContainer([rng1], message_indices={0, 6, 8}) + expected2 = CommonRangeContainer([rng1, rng2], message_indices={2, 4}) + expected3 = CommonRangeContainer([rng1, rng2, rng3], message_indices={12}) + expected4 = CommonRangeContainer([rng2, rng3], message_indices={1, 3, 5}) + expected5 = CommonRangeContainer([rng3], message_indices={7}) + + self.assertEqual(len(message_types), 5) + + self.assertIn(expected1, message_types) + self.assertIn(expected2, message_types) + self.assertIn(expected3, message_types) + self.assertIn(expected4, message_types) + self.assertIn(expected5, message_types) + + def test_retransform_message_indices(self): + sync_ends = np.array([12, 12, 12, 14, 14]) + + rng = CommonRange(0, 8, "1" * 8, score=1, field_type="length", message_indices={0, 1, 2, 3, 4}) + retransformed_ranges = FormatFinder.retransform_message_indices([rng], [0, 1, 2, 3, 4], sync_ends) + + # two different sync ends + self.assertEqual(len(retransformed_ranges), 2) + 
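+        # the relative start 0 is shifted by the sync end of each message:
+        # 0 + 12 = 12 for messages 0-2 and 0 + 14 = 14 for messages 3-4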
+ expected1 = CommonRange(12, 8, "1" * 8, score=1, field_type="length", message_indices={0, 1, 2}) + expected2 = CommonRange(14, 8, "1" * 8, score=1, field_type="length", message_indices={3, 4}) + + self.assertIn(expected1, retransformed_ranges) + self.assertIn(expected2, retransformed_ranges) + + def test_handle_no_overlapping_conflict(self): + rng1 = CommonRange(0, 8, "1" * 8, score=1, field_type="Length") + rng1.message_indices = {0, 1, 2} + rng2 = CommonRange(8, 8, "1" * 8, score=1, field_type="Address") + rng2.message_indices = {0, 1, 2} + + container = CommonRangeContainer([rng1, rng2], message_indices={0, 1, 2}) + + # no conflict + result = FormatFinder.handle_overlapping_conflict([container]) + self.assertEqual(len(result), 1) + self.assertEqual(len(result[0]), 2) + self.assertIn(rng1, result[0]) + self.assertEqual(result[0].message_indices, {0, 1, 2}) + self.assertIn(rng2, result[0]) + + def test_handle_easy_overlapping_conflict(self): + # Easy conflict: First Label has higher score + rng1 = CommonRange(8, 8, "1" * 8, score=1, field_type="Length") + rng1.message_indices = {0, 1, 2} + rng2 = CommonRange(8, 8, "1" * 8, score=0.8, field_type="Address") + rng2.message_indices = {0, 1, 2} + + container = CommonRangeContainer([rng1, rng2], message_indices={0, 1, 2}) + result = FormatFinder.handle_overlapping_conflict([container]) + self.assertEqual(len(result), 1) + self.assertEqual(len(result[0]), 1) + self.assertIn(rng1, result[0]) + self.assertEqual(result[0].message_indices, {0, 1, 2}) + + def test_handle_medium_overlapping_conflict(self): + rng1 = CommonRange(8, 8, "1" * 8, score=1, field_type="Length") + rng2 = CommonRange(4, 10, "1" * 8, score=0.8, field_type="Address") + rng3 = CommonRange(15, 20, "1" * 8, score=1, field_type="Seq") + rng4 = CommonRange(60, 80, "1" * 8, score=0.8, field_type="Type") + rng5 = CommonRange(70, 90, "1" * 8, score=0.9, field_type="Data") + + container = CommonRangeContainer([rng1, rng2, rng3, rng4, rng5]) + result = FormatFinder.handle_overlapping_conflict([container]) + self.assertEqual(len(result), 1) + self.assertEqual(len(result[0]), 3) + self.assertIn(rng1, result[0]) + self.assertIn(rng3, result[0]) + self.assertIn(rng5, result[0]) diff --git a/tests/awre/test_generated_protocols.py b/tests/awre/test_generated_protocols.py new file mode 100644 index 0000000000..0aa055b747 --- /dev/null +++ b/tests/awre/test_generated_protocols.py @@ -0,0 +1,236 @@ +from tests.awre.AWRETestCase import AWRETestCase +from urh.awre import AutoAssigner +from urh.awre.FormatFinder import FormatFinder +from urh.awre.MessageTypeBuilder import MessageTypeBuilder +from urh.awre.Preprocessor import Preprocessor +from urh.awre.ProtocolGenerator import ProtocolGenerator +from urh.signalprocessing.FieldType import FieldType +from urh.signalprocessing.Participant import Participant +from urh.util import util + + +class TestGeneratedProtocols(AWRETestCase): + def __check_addresses(self, messages, format_finder, known_participant_addresses): + """ + Use the AutoAssigner used also in main GUI to test assigned participant addresses to get same results + as in main program and not rely on cache of FormatFinder, because values there might be false + but SRC address labels still on right position which is the basis for Auto Assigner + + :param messages: + :param format_finder: + :param known_participant_addresses: + :return: + """ + + for msg_type, indices in format_finder.existing_message_types.items(): + for i in indices: + messages[i].message_type = msg_type + + participants = 
list(set(m.participant for m in messages)) + for p in participants: + p.address_hex = "" + AutoAssigner.auto_assign_participant_addresses(messages, participants) + + for i in range(len(participants)): + self.assertIn(participants[i].address_hex, + list(map(util.convert_numbers_to_hex_string, known_participant_addresses.values())), + msg=" [ " + " ".join(p.address_hex for p in participants) + " ]") + + def test_without_preamble(self): + alice = Participant("Alice", address_hex="24") + broadcast = Participant("Broadcast", address_hex="ff") + + mb = MessageTypeBuilder("data") + mb.add_label(FieldType.Function.SYNC, 16) + mb.add_label(FieldType.Function.LENGTH, 8) + mb.add_label(FieldType.Function.SRC_ADDRESS, 8) + mb.add_label(FieldType.Function.SEQUENCE_NUMBER, 8) + + pg = ProtocolGenerator([mb.message_type], + syncs_by_mt={mb.message_type: "0x8e88"}, + preambles_by_mt={mb.message_type: "10" * 8}, + participants=[alice, broadcast]) + + for i in range(20): + data_bits = 16 if i % 2 == 0 else 32 + source = pg.participants[i % 2] + destination = pg.participants[(i + 1) % 2] + pg.generate_message(data="1010" * (data_bits // 4), source=source, destination=destination) + + #self.save_protocol("without_preamble", pg) + self.clear_message_types(pg.messages) + ff = FormatFinder(pg.messages) + ff.known_participant_addresses.clear() + + ff.run() + self.assertEqual(len(ff.message_types), 1) + + mt = ff.message_types[0] + sync = mt.get_first_label_with_type(FieldType.Function.SYNC) + self.assertEqual(sync.start, 0) + self.assertEqual(sync.length, 16) + + length = mt.get_first_label_with_type(FieldType.Function.LENGTH) + self.assertEqual(length.start, 16) + self.assertEqual(length.length, 8) + + dst = mt.get_first_label_with_type(FieldType.Function.SRC_ADDRESS) + self.assertEqual(dst.start, 24) + self.assertEqual(dst.length, 8) + + seq = mt.get_first_label_with_type(FieldType.Function.SEQUENCE_NUMBER) + self.assertEqual(seq.start, 32) + self.assertEqual(seq.length, 8) + + def test_without_preamble_random_data(self): + ff = self.get_format_finder_from_protocol_file("without_ack_random_data.proto.xml") + ff.run() + + self.assertEqual(len(ff.message_types), 1) + + mt = ff.message_types[0] + sync = mt.get_first_label_with_type(FieldType.Function.SYNC) + self.assertEqual(sync.start, 0) + self.assertEqual(sync.length, 16) + + length = mt.get_first_label_with_type(FieldType.Function.LENGTH) + self.assertEqual(length.start, 16) + self.assertEqual(length.length, 8) + + dst = mt.get_first_label_with_type(FieldType.Function.SRC_ADDRESS) + self.assertEqual(dst.start, 24) + self.assertEqual(dst.length, 8) + + seq = mt.get_first_label_with_type(FieldType.Function.SEQUENCE_NUMBER) + self.assertEqual(seq.start, 32) + self.assertEqual(seq.length, 8) + + def test_without_preamble_random_data2(self): + ff = self.get_format_finder_from_protocol_file("without_ack_random_data2.proto.xml") + ff.run() + + self.assertEqual(len(ff.message_types), 1) + + mt = ff.message_types[0] + sync = mt.get_first_label_with_type(FieldType.Function.SYNC) + self.assertEqual(sync.start, 0) + self.assertEqual(sync.length, 16) + + length = mt.get_first_label_with_type(FieldType.Function.LENGTH) + self.assertEqual(length.start, 16) + self.assertEqual(length.length, 8) + + dst = mt.get_first_label_with_type(FieldType.Function.SRC_ADDRESS) + self.assertEqual(dst.start, 24) + self.assertEqual(dst.length, 8) + + seq = mt.get_first_label_with_type(FieldType.Function.SEQUENCE_NUMBER) + self.assertEqual(seq.start, 32) + self.assertEqual(seq.length, 8) + + 
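+    # The XML-based tests below clear the known participant addresses first and
+    # check that FormatFinder rediscovers them from the recorded messages.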
def test_with_checksum(self): + ff = self.get_format_finder_from_protocol_file("with_checksum.proto.xml", clear_participant_addresses=False) + known_participant_addresses = ff.known_participant_addresses.copy() + ff.known_participant_addresses.clear() + ff.run() + + self.assertIn(known_participant_addresses[0].tostring(), + list(map(bytes, ff.known_participant_addresses.values()))) + self.assertIn(known_participant_addresses[1].tostring(), + list(map(bytes, ff.known_participant_addresses.values()))) + + self.assertEqual(len(ff.message_types), 3) + + def test_with_only_one_address(self): + ff = self.get_format_finder_from_protocol_file("only_one_address.proto.xml", clear_participant_addresses=False) + known_participant_addresses = ff.known_participant_addresses.copy() + ff.known_participant_addresses.clear() + + ff.run() + + self.assertIn(known_participant_addresses[0].tostring(), + list(map(bytes, ff.known_participant_addresses.values()))) + self.assertIn(known_participant_addresses[1].tostring(), + list(map(bytes, ff.known_participant_addresses.values()))) + + def test_with_four_broken(self): + ff, messages = self.get_format_finder_from_protocol_file("four_broken.proto.xml", + clear_participant_addresses=False, + return_messages=True) + + assert isinstance(ff, FormatFinder) + known_participant_addresses = ff.known_participant_addresses.copy() + ff.known_participant_addresses.clear() + + ff.run() + + self.__check_addresses(messages, ff, known_participant_addresses) + + for i in range(4, len(messages)): + mt = next(mt for mt, indices in ff.existing_message_types.items() if i in indices) + self.assertIsNotNone(mt.get_first_label_with_type(FieldType.Function.SEQUENCE_NUMBER)) + + def test_with_one_address_one_message_type(self): + ff, messages = self.get_format_finder_from_protocol_file("one_address_one_mt.proto.xml", + clear_participant_addresses=False, + return_messages=True) + + self.assertEqual(len(messages), 17) + self.assertEqual(len(ff.hexvectors), 17) + + known_participant_addresses = ff.known_participant_addresses.copy() + ff.known_participant_addresses.clear() + + ff.run() + + self.assertEqual(len(ff.message_types), 1) + + self.assertIn(known_participant_addresses[0].tostring(), + list(map(bytes, ff.known_participant_addresses.values()))) + self.assertIn(known_participant_addresses[1].tostring(), + list(map(bytes, ff.known_participant_addresses.values()))) + + def test_without_preamble_24_messages(self): + ff, messages = self.get_format_finder_from_protocol_file("no_preamble24.proto.xml", + clear_participant_addresses=False, + return_messages=True) + + known_participant_addresses = ff.known_participant_addresses.copy() + ff.known_participant_addresses.clear() + + ff.run() + + self.assertEqual(len(ff.message_types), 1) + + self.assertIn(known_participant_addresses[0].tostring(), + list(map(bytes, ff.known_participant_addresses.values()))) + self.assertIn(known_participant_addresses[1].tostring(), + list(map(bytes, ff.known_participant_addresses.values()))) + + def test_with_three_syncs_different_preamble_lengths(self): + ff, messages = self.get_format_finder_from_protocol_file("three_syncs.proto.xml", return_messages=True) + preprocessor = Preprocessor(ff.get_bitvectors_from_messages(messages)) + sync_words = preprocessor.find_possible_syncs() + self.assertIn("0000010000100000", sync_words, msg="Sync 1") + self.assertIn("0010001000100010", sync_words, msg="Sync 2") + self.assertIn("0110011101100111", sync_words, msg="Sync 3") + + ff.run() + + expected_sync_ends = [32, 24, 40, 24, 32, 
24, 40, 24, 32, 24, 40, 24, 32, 24, 40, 24] + + for i, (s1, s2) in enumerate(zip(expected_sync_ends, ff.sync_ends)): + self.assertEqual(s1, s2, msg=str(i)) + + def test_with_four_participants(self): + ff, messages = self.get_format_finder_from_protocol_file("four_participants.proto.xml", + clear_participant_addresses=False, + return_messages=True) + + known_participant_addresses = ff.known_participant_addresses.copy() + ff.known_participant_addresses.clear() + + ff.run() + + self.__check_addresses(messages, ff, known_participant_addresses) + self.assertEqual(len(ff.message_types), 3) diff --git a/tests/awre/test_length_engine.py b/tests/awre/test_length_engine.py new file mode 100644 index 0000000000..435a19e87f --- /dev/null +++ b/tests/awre/test_length_engine.py @@ -0,0 +1,167 @@ +import random + +from tests.awre.AWRETestCase import AWRETestCase +from urh.awre.FormatFinder import FormatFinder +from urh.awre.MessageTypeBuilder import MessageTypeBuilder +from urh.awre.ProtocolGenerator import ProtocolGenerator +from urh.awre.engines.LengthEngine import LengthEngine +from urh.signalprocessing.FieldType import FieldType +from urh.signalprocessing.ProtocoLabel import ProtocolLabel + + +class TestLengthEngine(AWRETestCase): + def test_simple_protocol(self): + """ + Test a simple protocol with + preamble, sync and length field (8 bit) and some random data + + :return: + """ + mb = MessageTypeBuilder("simple_length_test") + mb.add_label(FieldType.Function.PREAMBLE, 8) + mb.add_label(FieldType.Function.SYNC, 16) + mb.add_label(FieldType.Function.LENGTH, 8) + + num_messages_by_data_length = {8: 5, 16: 10, 32: 15} + pg = ProtocolGenerator([mb.message_type], + syncs_by_mt={mb.message_type: "0x9a9d"}) + random.seed(0) + for data_length, num_messages in num_messages_by_data_length.items(): + for i in range(num_messages): + pg.generate_message(data="".join([random.choice(["0", "1"]) for _ in range(data_length)])) + + #self.save_protocol("simple_length", pg) + + self.clear_message_types(pg.protocol.messages) + ff = FormatFinder(pg.protocol.messages) + + length_engine = LengthEngine(ff.bitvectors) + highscored_ranges = length_engine.find(n_gram_length=8) + self.assertEqual(len(highscored_ranges), 3) + + ff.perform_iteration() + self.assertEqual(len(ff.message_types), 1) + self.assertGreater(len(ff.message_types[0]), 0) + label = ff.message_types[0].get_first_label_with_type(FieldType.Function.LENGTH) + self.assertEqual(label.start, 24) + self.assertEqual(label.length, 8) + + def test_easy_protocol(self): + """ + preamble, sync, sequence number, length field (8 bit) and some random data + + :return: + """ + mb = MessageTypeBuilder("easy_length_test") + mb.add_label(FieldType.Function.PREAMBLE, 16) + mb.add_label(FieldType.Function.SYNC, 16) + mb.add_label(FieldType.Function.LENGTH, 8) + mb.add_label(FieldType.Function.SEQUENCE_NUMBER, 8) + + num_messages_by_data_length = {32: 10, 64: 15, 16: 5, 24: 7} + pg = ProtocolGenerator([mb.message_type], + preambles_by_mt={mb.message_type: "10" * 8}, + syncs_by_mt={mb.message_type: "0xcafe"}) + for data_length, num_messages in num_messages_by_data_length.items(): + for i in range(num_messages): + if i % 4 == 0: + data = "1" * data_length + elif i % 4 == 1: + data = "0" * data_length + elif i % 4 == 2: + data = "10" * (data_length // 2) + else: + data = "01" * (data_length // 2) + + pg.generate_message(data=data) + + #self.save_protocol("easy_length", pg) + + self.clear_message_types(pg.protocol.messages) + ff = FormatFinder(pg.protocol.messages) + + length_engine = 
LengthEngine(ff.bitvectors) + highscored_ranges = length_engine.find(n_gram_length=8) + self.assertEqual(len(highscored_ranges), 4) + + ff.perform_iteration() + self.assertEqual(len(ff.message_types), 1) + self.assertGreater(len(ff.message_types[0]), 0) + label = ff.message_types[0].get_first_label_with_type(FieldType.Function.LENGTH) + self.assertIsInstance(label, ProtocolLabel) + self.assertEqual(label.start, 32) + self.assertEqual(label.length, 8) + + def test_medium_protocol(self): + """ + Protocol with two message types. Length field only present in one of them + + :return: + """ + mb1 = MessageTypeBuilder("data") + mb1.add_label(FieldType.Function.PREAMBLE, 8) + mb1.add_label(FieldType.Function.SYNC, 8) + mb1.add_label(FieldType.Function.LENGTH, 8) + mb1.add_label(FieldType.Function.SEQUENCE_NUMBER, 8) + + mb2 = MessageTypeBuilder("ack") + mb2.add_label(FieldType.Function.PREAMBLE, 8) + mb2.add_label(FieldType.Function.SYNC, 8) + + pg = ProtocolGenerator([mb1.message_type, mb2.message_type], + syncs_by_mt={mb1.message_type: "11110011", + mb2.message_type: "11110011"}) + num_messages_by_data_length = {8: 5, 16: 10, 32: 5} + for data_length, num_messages in num_messages_by_data_length.items(): + for i in range(num_messages): + pg.generate_message(data=pg.decimal_to_bits(10 * i, data_length), message_type=mb1.message_type) + pg.generate_message(message_type=mb2.message_type, data="0xaf") + + #self.save_protocol("medium_length", pg) + + self.clear_message_types(pg.protocol.messages) + ff = FormatFinder(pg.protocol.messages) + + ff.perform_iteration() + self.assertEqual(len(ff.message_types), 2) + length_mt = next( + mt for mt in ff.message_types if mt.get_first_label_with_type(FieldType.Function.LENGTH) is not None) + length_label = length_mt.get_first_label_with_type(FieldType.Function.LENGTH) + + for i, sync_end in enumerate(ff.sync_ends): + self.assertEqual(sync_end, 16, msg=str(i)) + + self.assertEqual(16, length_label.start) + self.assertEqual(8, length_label.length) + + def test_little_endian_16_bit(self): + mb = MessageTypeBuilder("little_endian_16_length_test") + mb.add_label(FieldType.Function.PREAMBLE, 8) + mb.add_label(FieldType.Function.SYNC, 16) + mb.add_label(FieldType.Function.LENGTH, 16) + + num_messages_by_data_length = {256*8: 5, 16: 4, 512: 2} + pg = ProtocolGenerator([mb.message_type], + syncs_by_mt={mb.message_type: "0x9a9d"}, + little_endian=True) + + random.seed(0) + for data_length, num_messages in num_messages_by_data_length.items(): + for i in range(num_messages): + pg.generate_message(data="".join([random.choice(["0", "1"]) for _ in range(data_length)])) + + #self.save_protocol("little_endian_16_length_test", pg) + + self.clear_message_types(pg.protocol.messages) + ff = FormatFinder(pg.protocol.messages) + + length_engine = LengthEngine(ff.bitvectors) + highscored_ranges = length_engine.find(n_gram_length=8) + self.assertEqual(len(highscored_ranges), 3) + + ff.perform_iteration() + self.assertEqual(len(ff.message_types), 1) + self.assertGreater(len(ff.message_types[0]), 0) + label = ff.message_types[0].get_first_label_with_type(FieldType.Function.LENGTH) + self.assertEqual(label.start, 24) + self.assertEqual(label.length, 16) diff --git a/tests/awre/test_partially_labeled.py b/tests/awre/test_partially_labeled.py new file mode 100644 index 0000000000..cf8c459bff --- /dev/null +++ b/tests/awre/test_partially_labeled.py @@ -0,0 +1,198 @@ +import copy +import random + +from urh.signalprocessing.MessageType import MessageType + +from urh.awre.FormatFinder import 
FormatFinder + +from urh.awre.ProtocolGenerator import ProtocolGenerator +from urh.signalprocessing.FieldType import FieldType + +from tests.awre.AWRETestCase import AWRETestCase +from urh.awre.MessageTypeBuilder import MessageTypeBuilder +from urh.signalprocessing.Participant import Participant +from urh.signalprocessing.ProtocolAnalyzer import ProtocolAnalyzer + + +class TestPartiallyLabeled(AWRETestCase): + """ + Some tests if there are already information about the message types present + + """ + def test_fully_labeled(self): + """ + For fully labeled protocol, nothing should be done + + :return: + """ + protocol = self.__prepare_example_protocol() + message_types = sorted(copy.deepcopy(protocol.message_types), key=lambda x: x.name) + ff = FormatFinder(protocol.messages) + ff.perform_iteration() + self.assertEqual(len(message_types), len(ff.message_types)) + + for mt1, mt2 in zip(message_types, ff.message_types): + self.assertTrue(self.__message_types_have_same_labels(mt1, mt2)) + + def test_one_message_type_empty(self): + """ + Empty the "ACK" message type, the labels should be find by FormatFinder + + :return: + """ + protocol = self.__prepare_example_protocol() + n_message_types = len(protocol.message_types) + ack_mt = next(mt for mt in protocol.message_types if mt.name == "ack") + ack_mt.clear() + self.assertEqual(len(ack_mt), 0) + + ff = FormatFinder(protocol.messages) + ff.perform_iteration() + self.assertEqual(n_message_types, len(ff.message_types)) + + self.assertEqual(len(ack_mt), 4, msg=str(ack_mt)) + + def test_given_address_information(self): + """ + Empty both message types and see if addresses are found, when information of participant addresses is given + + :return: + """ + protocol = self.__prepare_example_protocol() + self.clear_message_types(protocol.messages) + + ff = FormatFinder(protocol.messages) + ff.perform_iteration() + self.assertEqual(2, len(ff.message_types)) + + self.assertIsNotNone(ff.message_types[0].get_first_label_with_type(FieldType.Function.PREAMBLE)) + self.assertIsNotNone(ff.message_types[1].get_first_label_with_type(FieldType.Function.PREAMBLE)) + self.assertIsNotNone(ff.message_types[0].get_first_label_with_type(FieldType.Function.SYNC)) + self.assertIsNotNone(ff.message_types[1].get_first_label_with_type(FieldType.Function.SYNC)) + self.assertIsNotNone(ff.message_types[0].get_first_label_with_type(FieldType.Function.LENGTH)) + self.assertIsNotNone(ff.message_types[1].get_first_label_with_type(FieldType.Function.LENGTH)) + self.assertIsNotNone(ff.message_types[0].get_first_label_with_type(FieldType.Function.DST_ADDRESS)) + self.assertIsNotNone(ff.message_types[1].get_first_label_with_type(FieldType.Function.DST_ADDRESS)) + self.assertIsNone(ff.message_types[0].get_first_label_with_type(FieldType.Function.SRC_ADDRESS)) + self.assertIsNotNone(ff.message_types[1].get_first_label_with_type(FieldType.Function.SRC_ADDRESS)) + + def test_type_part_already_labeled(self): + protocol = self.__prepare_simple_example_protocol() + self.clear_message_types(protocol.messages) + ff = FormatFinder(protocol.messages) + + # overlaps type + ff.message_types[0].add_protocol_label_start_length(32, 8) + ff.perform_iteration() + self.assertEqual(1, len(ff.message_types)) + + self.assertIsNotNone(ff.message_types[0].get_first_label_with_type(FieldType.Function.PREAMBLE)) + self.assertIsNotNone(ff.message_types[0].get_first_label_with_type(FieldType.Function.SYNC)) + self.assertIsNotNone(ff.message_types[0].get_first_label_with_type(FieldType.Function.LENGTH)) + 
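+        # the manually added label (bits 32-39) only blocks the TYPE field,
+        # so the address fields can still be inferred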
self.assertIsNotNone(ff.message_types[0].get_first_label_with_type(FieldType.Function.DST_ADDRESS)) + self.assertIsNotNone(ff.message_types[0].get_first_label_with_type(FieldType.Function.SRC_ADDRESS)) + + def test_length_part_already_labeled(self): + protocol = self.__prepare_simple_example_protocol() + self.clear_message_types(protocol.messages) + ff = FormatFinder(protocol.messages) + + # overlaps length + ff.message_types[0].add_protocol_label_start_length(24, 8) + ff.perform_iteration() + self.assertEqual(1, len(ff.message_types)) + + self.assertIsNotNone(ff.message_types[0].get_first_label_with_type(FieldType.Function.PREAMBLE)) + self.assertIsNotNone(ff.message_types[0].get_first_label_with_type(FieldType.Function.SYNC)) + self.assertIsNone(ff.message_types[0].get_first_label_with_type(FieldType.Function.LENGTH)) + self.assertIsNotNone(ff.message_types[0].get_first_label_with_type(FieldType.Function.DST_ADDRESS)) + self.assertIsNotNone(ff.message_types[0].get_first_label_with_type(FieldType.Function.SRC_ADDRESS)) + + def test_address_part_already_labeled(self): + protocol = self.__prepare_simple_example_protocol() + self.clear_message_types(protocol.messages) + ff = FormatFinder(protocol.messages) + + # overlaps dst address + ff.message_types[0].add_protocol_label_start_length(40, 16) + ff.perform_iteration() + self.assertEqual(1, len(ff.message_types)) + + self.assertIsNotNone(ff.message_types[0].get_first_label_with_type(FieldType.Function.PREAMBLE)) + self.assertIsNotNone(ff.message_types[0].get_first_label_with_type(FieldType.Function.SYNC)) + self.assertIsNotNone(ff.message_types[0].get_first_label_with_type(FieldType.Function.LENGTH)) + self.assertIsNone(ff.message_types[0].get_first_label_with_type(FieldType.Function.DST_ADDRESS)) + self.assertIsNotNone(ff.message_types[0].get_first_label_with_type(FieldType.Function.SRC_ADDRESS)) + + @staticmethod + def __message_types_have_same_labels(mt1: MessageType, mt2: MessageType): + if len(mt1) != len(mt2): + return False + + for i, lbl in enumerate(mt1): + if lbl != mt2[i]: + return False + + return True + + def __prepare_example_protocol(self) -> ProtocolAnalyzer: + alice = Participant("Alice", "A", address_hex="1234") + bob = Participant("Bob", "B", address_hex="cafe") + + mb = MessageTypeBuilder("data") + mb.add_label(FieldType.Function.PREAMBLE, 8) + mb.add_label(FieldType.Function.SYNC, 16) + mb.add_label(FieldType.Function.LENGTH, 8) + mb.add_label(FieldType.Function.TYPE, 8) + mb.add_label(FieldType.Function.DST_ADDRESS, 16) + mb.add_label(FieldType.Function.SRC_ADDRESS, 16) + mb_ack = MessageTypeBuilder("ack") + mb_ack.add_label(FieldType.Function.PREAMBLE, 8) + mb_ack.add_label(FieldType.Function.SYNC, 16) + mb_ack.add_label(FieldType.Function.LENGTH, 8) + mb_ack.add_label(FieldType.Function.DST_ADDRESS, 16) + + num_messages = 50 + + pg = ProtocolGenerator([mb.message_type, mb_ack.message_type], + syncs_by_mt={mb.message_type: "0x6768", mb_ack.message_type: "0x6768"}, + participants=[alice, bob]) + + random.seed(0) + for i in range(num_messages): + if i % 2 == 0: + source, destination = alice, bob + data_length = 8 + else: + source, destination = bob, alice + data_length = 16 + pg.generate_message(data=pg.decimal_to_bits(random.randint(0, 2 ** (data_length - 1)), data_length), + source=source, destination=destination) + pg.generate_message(data="", message_type=mb_ack.message_type, destination=source, source=destination) + + #self.save_protocol("labeled_protocol", pg) + + return pg.protocol + + def 
__prepare_simple_example_protocol(self): + random.seed(0) + alice = Participant("Alice", "A", address_hex="1234") + bob = Participant("Bob", "B", address_hex="cafe") + + mb = MessageTypeBuilder("data") + mb.add_label(FieldType.Function.PREAMBLE, 8) + mb.add_label(FieldType.Function.SYNC, 16) + mb.add_label(FieldType.Function.LENGTH, 8) + mb.add_label(FieldType.Function.TYPE, 8) + mb.add_label(FieldType.Function.DST_ADDRESS, 16) + mb.add_label(FieldType.Function.SRC_ADDRESS, 16) + + pg = ProtocolGenerator([mb.message_type], + syncs_by_mt={mb.message_type: "0x6768"}, + participants=[alice, bob]) + + for i in range(10): + pg.generate_message(data="".join([random.choice(["0", "1"]) for _ in range(16)]), source=alice, destination=bob) + pg.generate_message(data="".join([random.choice(["0", "1"]) for _ in range(8)]), source=bob, destination=alice) + + return pg.protocol diff --git a/tests/awre/test_sequence_number_engine.py b/tests/awre/test_sequence_number_engine.py new file mode 100644 index 0000000000..6b02e9535b --- /dev/null +++ b/tests/awre/test_sequence_number_engine.py @@ -0,0 +1,182 @@ +from tests.awre.AWRETestCase import AWRETestCase +from urh.awre.CommonRange import CommonRange +from urh.awre.FormatFinder import FormatFinder +from urh.awre.MessageTypeBuilder import MessageTypeBuilder +from urh.awre.ProtocolGenerator import ProtocolGenerator +from urh.awre.engines.SequenceNumberEngine import SequenceNumberEngine +from urh.signalprocessing.FieldType import FieldType +from urh.signalprocessing.Participant import Participant + + +class TestSequenceNumberEngine(AWRETestCase): + def test_simple_protocol(self): + """ + Test a simple protocol with + preamble, sync and increasing sequence number (8 bit) and some constant data + + :return: + """ + mb = MessageTypeBuilder("simple_seq_test") + mb.add_label(FieldType.Function.PREAMBLE, 8) + mb.add_label(FieldType.Function.SYNC, 16) + mb.add_label(FieldType.Function.SEQUENCE_NUMBER, 8) + + num_messages = 20 + + pg = ProtocolGenerator([mb.message_type], + syncs_by_mt={mb.message_type: "0x9a9d"}) + + for i in range(num_messages): + pg.generate_message(data="0xcafe") + + #self.save_protocol("simple_sequence_number", pg) + + self.clear_message_types(pg.protocol.messages) + ff = FormatFinder(pg.protocol.messages) + + seq_engine = SequenceNumberEngine(ff.bitvectors, n_gram_length=8) + highscored_ranges = seq_engine.find() + self.assertEqual(len(highscored_ranges), 1) + + ff.perform_iteration() + self.assertEqual(len(ff.message_types), 1) + self.assertGreater(len(ff.message_types[0]), 0) + self.assertEqual(ff.message_types[0].num_labels_with_type(FieldType.Function.SEQUENCE_NUMBER), 1) + label = ff.message_types[0].get_first_label_with_type(FieldType.Function.SEQUENCE_NUMBER) + self.assertEqual(label.start, 24) + self.assertEqual(label.length, 8) + + def test_16bit_seq_nr(self): + mb = MessageTypeBuilder("16bit_seq_test") + mb.add_label(FieldType.Function.PREAMBLE, 8) + mb.add_label(FieldType.Function.SYNC, 16) + mb.add_label(FieldType.Function.SEQUENCE_NUMBER, 16) + + num_messages = 10 + + pg = ProtocolGenerator([mb.message_type], + syncs_by_mt={mb.message_type: "0x9a9d"}, sequence_number_increment=64) + + for i in range(num_messages): + pg.generate_message(data="0xcafe") + + #self.save_protocol("16bit_seq", pg) + + bitvectors = FormatFinder.get_bitvectors_from_messages(pg.protocol.messages, sync_ends=[24]*num_messages) + seq_engine = SequenceNumberEngine(bitvectors, n_gram_length=8) + highscored_ranges = seq_engine.find() + 
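+        # with an increment of 64 the counter exceeds one byte after a few messages,
+        # so both bytes of the 16 bit field vary and can be detected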
self.assertEqual(len(highscored_ranges), 1) + + self.clear_message_types(pg.protocol.messages) + ff = FormatFinder(pg.protocol.messages) + ff.perform_iteration() + + self.assertEqual(len(ff.message_types), 1) + self.assertGreater(len(ff.message_types[0]), 0) + self.assertEqual(ff.message_types[0].num_labels_with_type(FieldType.Function.SEQUENCE_NUMBER), 1) + label = ff.message_types[0].get_first_label_with_type(FieldType.Function.SEQUENCE_NUMBER) + self.assertEqual(label.start, 24) + self.assertEqual(label.length, 16) + + def test_16bit_seq_nr_with_zeros_in_first_part(self): + mb = MessageTypeBuilder("16bit_seq_first_byte_zero_test") + mb.add_label(FieldType.Function.PREAMBLE, 8) + mb.add_label(FieldType.Function.SYNC, 16) + mb.add_label(FieldType.Function.LENGTH, 8) + mb.add_label(FieldType.Function.SEQUENCE_NUMBER, 16) + + num_messages = 10 + + pg = ProtocolGenerator([mb.message_type], + syncs_by_mt={mb.message_type: "0x9a9d"}, sequence_number_increment=1) + + for i in range(num_messages): + pg.generate_message(data="0xcafe" + "abc" * i) + + #self.save_protocol("16bit_seq_first_byte_zero_test", pg) + + bitvectors = FormatFinder.get_bitvectors_from_messages(pg.protocol.messages, sync_ends=[24]*num_messages) + seq_engine = SequenceNumberEngine(bitvectors, n_gram_length=8) + highscored_ranges = seq_engine.find() + self.assertEqual(len(highscored_ranges), 1) + + self.clear_message_types(pg.protocol.messages) + ff = FormatFinder(pg.protocol.messages) + ff.perform_iteration() + self.assertEqual(len(ff.message_types), 1) + self.assertGreater(len(ff.message_types[0]), 0) + self.assertEqual(ff.message_types[0].num_labels_with_type(FieldType.Function.SEQUENCE_NUMBER), 1) + label = ff.message_types[0].get_first_label_with_type(FieldType.Function.SEQUENCE_NUMBER) + + # Not consider constants as part of SEQ Nr! 
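+        # with increment 1 and only 10 messages the upper byte of the 16 bit counter
+        # stays zero, so only the lower byte (bits 40-47) is labeled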
+ self.assertEqual(label.start, 40) + self.assertEqual(label.length, 8) + + def test_no_sequence_number(self): + """ + Ensure no sequence number is labeled, when it cannot be found + + :return: + """ + alice = Participant("Alice", address_hex="dead") + bob = Participant("Bob", address_hex="beef") + + mb = MessageTypeBuilder("protocol_with_one_message_type") + mb.add_label(FieldType.Function.PREAMBLE, 8) + mb.add_label(FieldType.Function.SYNC, 16) + mb.add_label(FieldType.Function.LENGTH, 8) + mb.add_label(FieldType.Function.SRC_ADDRESS, 16) + mb.add_label(FieldType.Function.DST_ADDRESS, 16) + + num_messages = 3 + + pg = ProtocolGenerator([mb.message_type], + syncs_by_mt={mb.message_type: "0x1337"}, + participants=[alice, bob]) + + for i in range(num_messages): + if i % 2 == 0: + source, destination = alice, bob + else: + source, destination = bob, alice + pg.generate_message(data="", source=source, destination=destination) + + #self.save_protocol("protocol_1", pg) + + # Delete message type information -> no prior knowledge + self.clear_message_types(pg.protocol.messages) + + ff = FormatFinder(pg.protocol.messages) + ff.known_participant_addresses.clear() + ff.perform_iteration() + + self.assertEqual(len(ff.message_types), 1) + + self.assertEqual(ff.message_types[0].num_labels_with_type(FieldType.Function.SEQUENCE_NUMBER), 0) + + def test_sequence_number_little_endian_16_bit(self): + mb = MessageTypeBuilder("16bit_seq_test") + mb.add_label(FieldType.Function.PREAMBLE, 8) + mb.add_label(FieldType.Function.SYNC, 16) + mb.add_label(FieldType.Function.SEQUENCE_NUMBER, 16) + + num_messages = 8 + + pg = ProtocolGenerator([mb.message_type], + syncs_by_mt={mb.message_type: "0x9a9d"}, + little_endian=True, sequence_number_increment=64) + + for i in range(num_messages): + pg.generate_message(data="0xcafe") + + #self.save_protocol("16bit_litte_endian_seq", pg) + + self.clear_message_types(pg.protocol.messages) + ff = FormatFinder(pg.protocol.messages) + ff.perform_iteration() + + self.assertEqual(len(ff.message_types), 1) + self.assertEqual(ff.message_types[0].num_labels_with_type(FieldType.Function.SEQUENCE_NUMBER), 1) + label = ff.message_types[0].get_first_label_with_type(FieldType.Function.SEQUENCE_NUMBER) + self.assertEqual(label.start, 24) + self.assertEqual(label.length, 16) diff --git a/tests/data/35_messages.proto.xml b/tests/data/35_messages.proto.xml new file mode 100644 index 0000000000..730aa210c0 --- /dev/null +++ b/tests/data/35_messages.proto.xml @@ -0,0 +1,56 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/tests/data/ack_frames_with_crc.proto.xml b/tests/data/ack_frames_with_crc.proto.xml new file mode 100644 index 0000000000..c90be7b58b --- /dev/null +++ b/tests/data/ack_frames_with_crc.proto.xml @@ -0,0 +1,77 @@ + + + + 'Non Return To Zero (NRZ)', + 'Non Return To Zero Inverted (NRZ-I)', 'Invert', + 'Manchester I', 'Edge Trigger', + 'Manchester II', 'Edge Trigger', 'Invert', + 'Differential Manchester', 'Edge Trigger', 'Differential Encoding', + 'WSP', 'Wireless Short Packet (WSP)', + 'Nexa', 'Substitution', '100000:0;', 'Substitution', '10:1;', + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/tests/data/awre_consistent_addresses.txt b/tests/data/awre_consistent_addresses.txt deleted file mode 100644 index 15a1a22834..0000000000 --- a/tests/data/awre_consistent_addresses.txt +++ /dev/null @@ -1,21 +0,0 @@ 
-1010101010101010101010101010101010011010011111011001101001111101001011011000000001110000111000000000000000000011000110110110000000110011000000000000000100000000010000000011000000010100111101110100100010011100010110010100010100000000000000000000111111001001101101100000000000000001000000000000000000000000000000100000000100000100000101110000000000000010000000010000001010001100001111010111011100000000110110111011101101111011100110001100111010001111 -1010101010101010101010101010101001100111011010000110011101101000000001100001101101100000001100110111100011100010100010010011001000101011 -10101010101010101010101010101010100110100111110110011010011111010011001000000000011100000001101101100000001100110111100011100010100010011100000000000001000001000000001000000110011100100100100110010111111110101010000011011101011011110000111101010111101010001001010010000111111101010111100101011010101001110111011011111110010101111100111000100100110100001111101111001011111001011000100010101011011001110001111010110101111100100110011001010011111011111011010101100111011010000011101010011101 -1010101010101010101010101010101010011010011111011001101001111101000000110111100011100010100010010111010101111110 -101010101010101010101010101010100110011101101000011001110110100000010011001000000111000001111000111000101000100100011011011000000011001100000000000000000000000000000000101001000011011110100110001001011001101110000110011000001101110100000111 -1010101010101010101010101010101001100111011010000110011101101000000000110001101101100000001100111111110101010111 -10101010101010101010101010101010011001110110100001100111011010000001101000100000011100000001101101100000001100110111100011100010100010010000000000001100011000100000111001100110101001001000110001101101000011000111110011111110110100011111001110111100100011101110101011100001011000011011000001010111 -1010101010101010101010101010101001100111011010000110011101101000000000110111100011100010100010010111010101111110 -101010101010101010101010101010100110011101101000011001110110100000010011001000000111000001111000111000101000100100011011011000000011001100000000000000000000000000000001010110110100111000100100010001011010101101010011100001100100011011110101 -1010101010101010101010101010101001100111011010000110011101101000000000110001101101100000001100111111110101010111 -10101010101010101010101010101010011001110110100001100111011010000001101000100000011100000001101101100000001100110111100011100010100010010000000000001100011000100000111100111010000110001010100101011000010001101001000000010100101100011110011100010001000011000001001010000010000100111010010010110101 -1010101010101010101010101010101001100111011010000110011101101000000000110111100011100010100010010111010101111110 -101010101010101010101010101010100110011101101000011001110110100000010011001000000111000001111000111000101000100100011011011000000011001100000000000000000000000000000010101011111011101010110000010001110000110000011110000001110101101011001001 -1010101010101010101010101010101001100111011010000110011101101000000000110001101101100000001100111111110101010111 -10101010101010101010101010101010011001110110100001100111011010000001010000100000011100000001101101100000001100110111100011100010100010010000000000001100011000100001000010011100100000101001100001010000001111001101001011110010111001000100001010011101 -1010101010101010101010101010101001100111011010000110011101101000000000110111100011100010100010010111010101111110 
-101010101010101010101010101010100110011101101000011001110110100000010011001000000111000001111000111000101000100100011011011000000011001100000000000000000000000000000011110000010100111010111110000000011011011111011111010011010011011011110100 -1010101010101010101010101010101001100111011010000110011101101000000000110001101101100000001100111111110101010111 -10101010101010101010101010101010011001110110100001100111011010000001010000100000011100000001101101100000001100110111100011100010100010010000000000001100011000100001000110010011100100111011000011111110001110100011111011110100000000010011011000011001 -10101010101010101010101010101010011001110110100001100111011010000001011101100000011100000111100011100010100010010001101101100000001100110000000000000000000000000000010010011111101101101111001010101100111100000101101000011101001010110110000111100110000000111100010000100111 -1010101010101010101010101010101001100111011010000110011101101000000000110001101101100000001100111111110101010111 \ No newline at end of file diff --git a/tests/data/enocean_bits.txt b/tests/data/enocean_bits.txt index 052d708a9c..278e75dbcb 100644 --- a/tests/data/enocean_bits.txt +++ b/tests/data/enocean_bits.txt @@ -1,12 +1,12 @@ -11110101010100101100001000000000000001011000001110000000010010010111 -11110101010100101100001000000000000001011000001110000000010010010111 -11110101010100101100001000000000000001011000001110000000010010010111 -11110101010100101010000000000000000001011000001110000000010001010111 -11110101010100101010000000000000000001011000001110000000010001010111 -11110101010100101010000000000000000001011000001110000000010001010111 -11110101010100101100011000000000000001011000001110000000010011010111 -11110101010100101100011000000000000001011000001110000000010011010111 -11110101010100101100011000000000000001011000001110000000010011010111 -11110101010100101010000000000000000001011000001110000000010001010111 -11110101010100101010000000000000000001011000001110000000010001010111 -11110101010100101010000000000000000001011000001110000000010001010111 \ No newline at end of file +1010101010010110000101010000000000101100000111000000001010011011 +1010101010010110000101010000000000101100000111000000001010011011 +1010101010010110000101010000000000101100000111000000001010011011 +1010101010010101000000000000000000101100000111000000001000101011 +1010101010010101000000000000000000101100000111000000001000101011 +1010101010010101000000000000000000101100000111000000001000101011 +1010101010010110000100000000000000101100000111000000001001001011 +1010101010010110000100000000000000101100000111000000001001001011 +1010101010010110000100000000000000101100000111000000001001001011 +1010101010010101000000000000000000101100000111000000001000101011 +1010101010010101000000000000000000101100000111000000001000101011 +1010101010010101000000000000000000101100000111000000001000101011 \ No newline at end of file diff --git a/tests/data/four_broken.proto.xml b/tests/data/four_broken.proto.xml new file mode 100644 index 0000000000..367e83f5d0 --- /dev/null +++ b/tests/data/four_broken.proto.xml @@ -0,0 +1,35 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/tests/data/four_participants.proto.xml b/tests/data/four_participants.proto.xml new file mode 100644 index 0000000000..eeef293456 --- /dev/null +++ b/tests/data/four_participants.proto.xml @@ -0,0 +1,71 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/tests/data/homematic.proto.xml b/tests/data/homematic.proto.xml new 
file mode 100644 index 0000000000..f7c8093111 --- /dev/null +++ b/tests/data/homematic.proto.xml @@ -0,0 +1,153 @@ + + + + 'Non Return To Zero (NRZ)', + 'Non Return To Zero Inverted (NRZ-I)', 'Invert', + 'Manchester I', 'Edge Trigger', + 'Manchester II', 'Edge Trigger', 'Invert', + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/tests/data/no_preamble24.proto.xml b/tests/data/no_preamble24.proto.xml new file mode 100644 index 0000000000..9ab8c9ffd3 --- /dev/null +++ b/tests/data/no_preamble24.proto.xml @@ -0,0 +1,43 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/tests/data/one_address_one_mt.proto.xml b/tests/data/one_address_one_mt.proto.xml new file mode 100644 index 0000000000..0e3217c481 --- /dev/null +++ b/tests/data/one_address_one_mt.proto.xml @@ -0,0 +1,36 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/tests/data/only_one_address.proto.xml b/tests/data/only_one_address.proto.xml new file mode 100644 index 0000000000..1e386efc87 --- /dev/null +++ b/tests/data/only_one_address.proto.xml @@ -0,0 +1,34 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/tests/data/rwe.proto.xml b/tests/data/rwe.proto.xml new file mode 100644 index 0000000000..bb3504dd9b --- /dev/null +++ b/tests/data/rwe.proto.xml @@ -0,0 +1,36 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/tests/data/three_syncs.proto.xml b/tests/data/three_syncs.proto.xml new file mode 100644 index 0000000000..255939e68a --- /dev/null +++ b/tests/data/three_syncs.proto.xml @@ -0,0 +1,64 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/tests/data/with_checksum.proto.xml b/tests/data/with_checksum.proto.xml new file mode 100644 index 0000000000..d99bdd69a7 --- /dev/null +++ b/tests/data/with_checksum.proto.xml @@ -0,0 +1,74 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/tests/data/without_ack_random_data.proto.xml b/tests/data/without_ack_random_data.proto.xml new file mode 100644 index 0000000000..5af005724c --- /dev/null +++ b/tests/data/without_ack_random_data.proto.xml @@ -0,0 +1,34 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/tests/data/without_ack_random_data2.proto.xml b/tests/data/without_ack_random_data2.proto.xml new file mode 100644 index 0000000000..bada9632df --- /dev/null +++ b/tests/data/without_ack_random_data2.proto.xml @@ -0,0 +1,31 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/tests/test_CRC.py b/tests/test_CRC.py index 26b9e8ca92..7d53723f54 100644 --- a/tests/test_CRC.py +++ b/tests/test_CRC.py @@ -1,10 +1,12 @@ -import unittest, time +import time +import unittest from urh.signalprocessing.Encoding import Encoding from urh.util import util from urh.util.GenericCRC import GenericCRC from urh.util.WSPChecksum import WSPChecksum + class TestCRC(unittest.TestCase): def test_crc(self): # http://depa.usst.edu.cn/chenjq/www2/software/crc/CRC_Javascript/CRCcalculation.htm @@ -32,6 +34,16 @@ def test_crc(self): self.assertEqual(util.bit2hex(c.crc(e.str2bit(value[4:-8]))), expect) + def test_crc8(self): + messages = ["aabbcc", "abcdee", "dacafe"] + + expected = ["7d", "24", "33"] + crc = GenericCRC(polynomial=GenericCRC.DEFAULT_POLYNOMIALS["8_ccitt"]) + + for msg, expect in zip(messages, expected): + bits = util.hex2bit(msg) + self.assertEqual(util.bit2hex(crc.crc(bits)), expect) + def 
test_different_crcs(self): c = GenericCRC(polynomial="16_standard", start_value=False, final_xor=False, reverse_polynomial=False, reverse_all=False, lsb_first=False, little_endian=False) @@ -92,6 +104,86 @@ def test_different_crcs(self): self.assertEqual(crc_new, crc_old) c.reverse_all = False + def test_cache(self): + c = GenericCRC(polynomial="16_standard", start_value=False, final_xor=False, + reverse_polynomial=False, reverse_all=False, lsb_first=False, little_endian=False) + c.calculate_cache(8) + self.assertEqual(len(c.cache), 256) + + def test_different_crcs_fast(self): + c = GenericCRC(polynomial="16_standard", start_value=False, final_xor=False, + reverse_polynomial=False, reverse_all=False, lsb_first=False, little_endian=False) + bitstring_set = [ + "10101010", + "00000001", + "000000010", + "000000011", + "0000000100000001", + "101001001010101010101011101111111000000000000111101010011101011", + "101001001010101101111010110111101010010110111010", + "00000000000000000000000000000000100000000000000000000000000000000001111111111111", + "1111111111111111111111111111111110111111111111111111110111111111111111110000000000" + "1"] + + for j in c.DEFAULT_POLYNOMIALS: + c.polynomial = c.choose_polynomial(j) + for i in bitstring_set: + for cache in [8, 4, 7, 12, 16]: + c.calculate_cache(cache) + # Standard + crc_new = c.cached_crc(c.str2bit(i)) + crc_old = c.reference_crc(c.str2bit(i)) + self.assertEqual(crc_old, crc_new) + + # Special final xor + c.final_xor = c.str2bit("0000111100001111") + crc_new = c.cached_crc(c.str2bit(i)) + crc_old = c.reference_crc(c.str2bit(i)) + self.assertEqual(crc_old, crc_new) + c.final_xor = [False] * 16 + + # Special start value + c.start_value = c.str2bit("1010101010101010") + crc_new = c.cached_crc(c.str2bit(i)) + crc_old = c.reference_crc(c.str2bit(i)) + self.assertEqual(crc_old, crc_new) + c.start_value = [False] * 16 + + # little_endian + c.little_endian = True + crc_new = c.cached_crc(c.str2bit(i)) + crc_old = c.reference_crc(c.str2bit(i)) + self.assertEqual(crc_old, crc_new) + c.little_endian = False + + # reverse all + c.reverse_all = True + crc_new = c.cached_crc(c.str2bit(i)) + crc_old = c.reference_crc(c.str2bit(i)) + self.assertEqual(crc_old, crc_new) + c.reverse_all = False + + # reverse_polynomial + # We need to clear the cache before and after + c.cache = [] + # + c.reverse_polynomial = True + crc_new = c.cached_crc(c.str2bit(i)) + crc_old = c.reference_crc(c.str2bit(i)) + self.assertEqual(crc_old, crc_new) + c.reverse_polynomial = False + # + c.cache = [] + + # TODO: Does only work for cachesize = 8 + # lsb_first + c.calculate_cache(8) + c.lsb_first = True + crc_new = c.cached_crc(c.str2bit(i)) + crc_old = c.reference_crc(c.str2bit(i)) + self.assertEqual(crc_old, crc_new) + c.lsb_first = False + def test_reverse_engineering(self): c = GenericCRC(polynomial="16_standard", start_value=False, final_xor=False, reverse_polynomial=False, reverse_all=False, lsb_first=False, little_endian=False) @@ -126,17 +218,17 @@ def test_not_aligned_data_len(self): self.assertEqual(val, crcs[j]) inpt = "0" + inpt - def test_guess_standard_parameters_and_datarange(self): + def test_bruteforce_parameters_and_data_range(self): c = GenericCRC(polynomial="16_ccitt", start_value=False, final_xor=False, reverse_polynomial=False, reverse_all=False, lsb_first=False, little_endian=False) inpt = 
"101010101010101010000000111000000000000011100000001011010010110100000000111000000101001010000100000000000100111001111110010000000011011111111001001101100001100010100000000000111011110100010" vrfy_crc = "0011101111010001" - result = c.guess_standard_parameters_and_datarange(c.str2arr(inpt), c.str2arr(vrfy_crc)) + result = c.bruteforce_parameters_and_data_range(c.str2arr(inpt), len(inpt)-len(vrfy_crc)-1) self.assertEqual(result, (2, 84, 172)) self.assertEqual(vrfy_crc, c.bit2str(c.crc(c.str2arr(inpt[result[1]:result[2]])))) - def test_guess_standard_parameters_and_datarange_improved(self): + def test_bruteforce_parameters_and_data_range_improved(self): c = GenericCRC(polynomial="16_ccitt", start_value=False, final_xor=False, reverse_polynomial=False, reverse_all=False, lsb_first=False, little_endian=False) inpt = "101010101010101010000000111000000000000011100000001011010010110100000000111000000101001010000100000000000100111001111110010000000011011111111001001101100001100010100000000000111011110100010" @@ -146,15 +238,15 @@ def test_guess_standard_parameters_and_datarange_improved(self): runs = 100 for i in range(0, runs): t = time.time() - result = c.guess_standard_parameters_and_datarange(c.str2arr(inpt), c.str2arr(vrfy_crc)) + result = c.bruteforce_parameters_and_data_range(c.str2arr(inpt), len(inpt)-len(vrfy_crc)-1) t1 += time.time() - t - #print(result, c.bit2str(c.crc(c.str2arr(inpt[result[1]:result[2]])))) - self.assertEqual(result[0], 2) # Parameters = 2 - self.assertEqual(result[1], len(inpt) - 1 - 16 - 88) # start of datarange - self.assertEqual(result[2], len(inpt) - 1 - 16) # end of datarange - inpt = "0"+inpt if i%2 == 0 else "1"+inpt - #print("Performance:", t1/runs) - self.assertLess(t1/runs, 0.1) # Should be faster than 100ms in average + # print(result, c.bit2str(c.crc(c.str2arr(inpt[result[1]:result[2]])))) + self.assertEqual(result[0], 2) # Parameters = 2 + self.assertEqual(result[1], len(inpt) - 1 - 16 - 88) # start of datarange + self.assertEqual(result[2], len(inpt) - 1 - 16) # end of datarange + inpt = "0" + inpt if i % 2 == 0 else "1" + inpt + # print("Performance:", t1/runs) + self.assertLess(t1 / runs, 0.1) # Should be faster than 100ms in average def test_adaptive_crc_calculation(self): c = GenericCRC(polynomial="16_ccitt", start_value=False, final_xor=False, @@ -171,4 +263,4 @@ def test_adaptive_crc_calculation(self): c.start_value = crc1 crcx = c.crc(c.str2arr(delta)) - self.assertEqual(crcx, crc2) \ No newline at end of file + self.assertEqual(crcx, crc2) diff --git a/tests/test_auto_assignments.py b/tests/test_auto_assignments.py index 26eeccd385..1be3c1b51d 100644 --- a/tests/test_auto_assignments.py +++ b/tests/test_auto_assignments.py @@ -3,6 +3,7 @@ from tests.utils_testing import get_path_for_data_file from urh import constants +from urh.awre import AutoAssigner from urh.signalprocessing.Encoding import Encoding from urh.signalprocessing.Message import Message from urh.signalprocessing.MessageType import MessageType @@ -97,24 +98,10 @@ def test_two_assign_participants_by_rssi(self): alice, alice, bob, bob, alice, alice, bob, bob, alice, alice, bob, bob, alice, bob]] - proto1.auto_assign_participants([alice, bob]) + AutoAssigner.auto_assign_participants(proto1.messages, [alice, bob]) for i, message in enumerate(proto1.messages): self.assertEqual(message.participant, excpected_partis[0][i]) - proto2.auto_assign_participants([alice, bob]) + AutoAssigner.auto_assign_participants(proto2.messages, [alice, bob]) for i, message in enumerate(proto2.messages): 
self.assertEqual(message.participant, excpected_partis[1][i]) - - def test_assign_decodings(self): - self.undecoded_protocol = ProtocolAnalyzer(None) - with open(get_path_for_data_file("undecoded.txt")) as f: - for line in f: - self.undecoded_protocol.messages.append(Message.from_plain_bits_str(line.replace("\n", ""))) - - self.undecoded_protocol.auto_assign_decodings(self.decodings) - - for i, message in enumerate(self.undecoded_protocol.messages): - if message.plain_hex_str[8:16] == "9a7d9a7d": - self.assertEqual(message.decoder.name, "DeWhitening Special", msg=str(i)) - elif message.plain_hex_str[8:16] == "67686768": - self.assertEqual(message.decoder.name, "DeWhitening", msg=str(i)) diff --git a/tests/test_awre.py b/tests/test_awre.py deleted file mode 100644 index 7a87999362..0000000000 --- a/tests/test_awre.py +++ /dev/null @@ -1,226 +0,0 @@ -import unittest - -from tests.utils_testing import get_path_for_data_file -from urh.awre.CommonRange import CommonRange -from urh.awre.FormatFinder import FormatFinder -from urh.awre.components.Address import Address -from urh.awre.components.Component import Component -from urh.awre.components.Flags import Flags -from urh.awre.components.Length import Length -from urh.awre.components.Preamble import Preamble -from urh.awre.components.SequenceNumber import SequenceNumber -from urh.awre.components.Type import Type -from urh.signalprocessing.FieldType import FieldType -from urh.signalprocessing.Message import Message -from urh.signalprocessing.Participant import Participant -from urh.signalprocessing.ProtocoLabel import ProtocolLabel -from urh.signalprocessing.ProtocolAnalyzer import ProtocolAnalyzer - - -class TestAWRE(unittest.TestCase): - def setUp(self): - self.field_types = FieldType.default_field_types() - - self.preamble_field_type = self.__field_type_with_function(self.field_types, FieldType.Function.PREAMBLE) - self.sync_field_type = self.__field_type_with_function(self.field_types, FieldType.Function.SYNC) - self.length_field_type = self.__field_type_with_function(self.field_types, FieldType.Function.LENGTH) - self.sequence_number_field_type = self.__field_type_with_function(self.field_types, - FieldType.Function.SEQUENCE_NUMBER) - self.dst_address_field_type = self.__field_type_with_function(self.field_types, FieldType.Function.DST_ADDRESS) - self.src_address_field_type = self.__field_type_with_function(self.field_types, FieldType.Function.SRC_ADDRESS) - - self.protocol = ProtocolAnalyzer(None) - with open(get_path_for_data_file("awre_consistent_addresses.txt")) as f: - for line in f: - self.protocol.messages.append(Message.from_plain_bits_str(line.replace("\n", ""))) - self.protocol.messages[-1].message_type = self.protocol.default_message_type - - # Assign participants - alice = Participant("Alice", "A") - bob = Participant("Bob", "B") - alice_indices = {1, 2, 5, 6, 9, 10, 13, 14, 17, 18, 20, 22, 23, 26, 27, 30, 31, 34, 35, 38, 39, 41} - for i, message in enumerate(self.protocol.messages): - message.participant = alice if i in alice_indices else bob - - self.participants = [alice, bob] - - self.zero_crc_protocol = ProtocolAnalyzer(None) - with open(get_path_for_data_file("awre_zeroed_crc.txt")) as f: - for line in f: - self.zero_crc_protocol.messages.append(Message.from_plain_bits_str(line.replace("\n", ""))) - self.zero_crc_protocol.messages[-1].message_type = self.protocol.default_message_type - - for i, message in enumerate(self.zero_crc_protocol.messages): - message.participant = alice if i in alice_indices else bob - - 
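For readers skimming the removed file: the legacy setUp that ends just above loads every protocol the same way, from a text file of plain bit strings with participants assigned by message index. A condensed, hypothetical helper capturing that pattern, using only the classes and calls visible in the deleted code, might look like this:

```python
# Hypothetical helper mirroring the deleted setUp: build a ProtocolAnalyzer
# from a file with one plain bit string per line and assign participants by index.
from urh.signalprocessing.Message import Message
from urh.signalprocessing.Participant import Participant
from urh.signalprocessing.ProtocolAnalyzer import ProtocolAnalyzer

def load_bits_protocol(path: str, alice: Participant, bob: Participant,
                       alice_indices: set) -> ProtocolAnalyzer:
    proto = ProtocolAnalyzer(None)
    with open(path) as f:
        for line in f:
            bits = line.strip()
            if not bits:
                continue
            msg = Message.from_plain_bits_str(bits)
            msg.message_type = proto.default_message_type
            proto.messages.append(msg)

    for i, message in enumerate(proto.messages):
        message.participant = alice if i in alice_indices else bob

    return proto
```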
@staticmethod - def __field_type_with_function(field_types, function) -> FieldType: - return next(ft for ft in field_types if ft.function == function) - - def test_build_component_order(self): - expected_default = [Preamble(fieldtypes=[]), Length(fieldtypes=[], length_cluster=None), - Address(fieldtypes=[], xor_matrix=None), SequenceNumber(fieldtypes=[]), Type(), Flags()] - - format_finder = FormatFinder(self.protocol) - - for expected, actual in zip(expected_default, format_finder.build_component_order()): - assert type(expected) == type(actual) - - expected_swapped = [Preamble(fieldtypes=[]), Address(fieldtypes=[], xor_matrix=None), - Length(fieldtypes=[], length_cluster=None), SequenceNumber(fieldtypes=[]), Type(), Flags()] - format_finder.length_component.priority = 2 - format_finder.address_component.priority = 1 - - for expected, actual in zip(expected_swapped, format_finder.build_component_order()): - assert type(expected) == type(actual) - - # Test duplicate Priority - format_finder.sequence_number_component.priority = 4 - with self.assertRaises(ValueError) as context: - format_finder.build_component_order() - self.assertTrue('Duplicate priority' in context.exception) - format_finder.sequence_number_component.priority = 3 - self.assertTrue(format_finder.build_component_order()) - - def test_format_finding_rwe(self): - preamble_start, preamble_end = 0, 31 - sync_start, sync_end = 32, 63 - length_start, length_end = 64, 71 - ack_address_start, ack_address_end = 72, 95 - dst_address_start, dst_address_end = 88, 111 - src_address_start, src_address_end = 112, 135 - - preamble_label = ProtocolLabel(name=self.preamble_field_type.caption, field_type=self.preamble_field_type, - start=preamble_start, end=preamble_end, color_index=0) - sync_label = ProtocolLabel(name=self.sync_field_type.caption, field_type=self.sync_field_type, - start=sync_start, end=sync_end, color_index=1) - length_label = ProtocolLabel(name=self.length_field_type.caption, field_type=self.length_field_type, - start=length_start, end=length_end, color_index=2) - ack_address_label = ProtocolLabel(name=self.dst_address_field_type.caption, - field_type=self.dst_address_field_type, - start=ack_address_start, end=ack_address_end, color_index=3) - dst_address_label = ProtocolLabel(name=self.dst_address_field_type.caption, - field_type=self.dst_address_field_type, - start=dst_address_start, end=dst_address_end, color_index=4) - src_address_label = ProtocolLabel(name=self.src_address_field_type.caption, - field_type=self.src_address_field_type, - start=src_address_start, end=src_address_end, color_index=5) - - ff = FormatFinder(protocol=self.protocol, participants=self.participants, field_types=self.field_types) - ff.perform_iteration() - - self.assertIn(preamble_label, self.protocol.default_message_type) - self.assertIn(sync_label, self.protocol.default_message_type) - self.assertIn(length_label, self.protocol.default_message_type) - self.assertIn(dst_address_label, self.protocol.default_message_type) - self.assertIn(src_address_label, self.protocol.default_message_type) - - self.assertEqual(len(self.protocol.message_types), 2) - self.assertEqual(self.protocol.message_types[1].name, "ack") - self.assertIn(ack_address_label, self.protocol.message_types[1]) - - ack_messages = (1, 3, 5, 7, 9, 11, 13, 15, 17, 20) - for i, msg in enumerate(self.protocol.messages): - if i in ack_messages: - self.assertEqual(msg.message_type.name, "ack", msg=i) - else: - self.assertEqual(msg.message_type.name.lower(), "default", msg=i) - - def 
test_format_finding_rwe_zeroed_crc(self): - ff = FormatFinder(self.zero_crc_protocol, self.participants) - ff.perform_iteration() - - def test_format_finding_enocean(self): - enocean_protocol = ProtocolAnalyzer(None) - with open(get_path_for_data_file("enocean_bits.txt")) as f: - for line in f: - enocean_protocol.messages.append(Message.from_plain_bits_str(line.replace("\n", ""))) - enocean_protocol.messages[-1].message_type = enocean_protocol.default_message_type - - preamble_start = 3 - preamble_end = 10 - sof_start = 11 - sof_end = 14 - - preamble_label = ProtocolLabel(name=self.preamble_field_type.caption, field_type=self.preamble_field_type, - start=preamble_start, end=preamble_end, color_index=0) - sync_label = ProtocolLabel(name=self.sync_field_type.caption, field_type=self.sync_field_type, - start=sof_start, end=sof_end, color_index=1) - - ff = FormatFinder(enocean_protocol, self.participants, field_types=self.field_types) - ff.perform_iteration() - - self.assertEqual(len(enocean_protocol.message_types), 1) - - self.assertIn(preamble_label, enocean_protocol.default_message_type) - self.assertIn(sync_label, enocean_protocol.default_message_type) - self.assertTrue( - not any(lbl.name == self.length_field_type.caption for lbl in enocean_protocol.default_message_type)) - self.assertTrue(not any("address" in lbl.name.lower() for lbl in enocean_protocol.default_message_type)) - - def test_address_candidate_finding(self): - fh = CommonRange.from_hex - - candidates_participant_1 = [fh('1b6033'), fh('1b6033fd57'), fh('701b603378e289'), fh('20701b603378e289000c62')] - candidates_participant_2 = [fh('1b603300'), fh('78e289757e'), fh('7078e2891b6033000000'), - fh('207078e2891b6033000000')] - - expected_address1 = '1b6033' - expected_address2 = '78e289' - - # print(Address.find_candidates(candidates_participant_1)) - # print(Address.find_candidates(candidates_participant_2)) - combined = candidates_participant_1 + candidates_participant_2 - combined.sort(key=len) - score = Address.find_candidates(combined) - # print(score) - # print("=================") - # print(sorted(score, key=lambda k: score[k], reverse=True)) - # print() - - highscored = sorted(score, key=lambda k: score[k], reverse=True)[:2] - self.assertIn(expected_address1, highscored) - self.assertIn(expected_address2, highscored) - - def test_message_type_assign(self): - clusters = {"ack": {1, 17, 3, 20, 5, 7, 9, 11, 13, 15}, "Default": {0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 19}} - com = Component(messagetypes=self.protocol.message_types) - com.assign_messagetypes(self.protocol.messages, clusters) - - for clustername, msg_indices in clusters.items(): - for msg in msg_indices: - self.assertEqual(self.protocol.messages[msg].message_type.name, clustername, msg=str(msg)) - - # do it again and ensure nothing changes - com.assign_messagetypes(self.protocol.messages, clusters) - for clustername, msg_indices in clusters.items(): - for msg in msg_indices: - self.assertEqual(self.protocol.messages[msg].message_type.name, clustername, msg=str(msg)) - - def test_choose_candidate(self): - - candidates1 = {'78e289': 8, '207078e2891b6033000000': 1, '57': 1, '20701b603378e289000c62': 1, '1b6033fd57': 1, - '1b603300': 3, '7078e2891b6033000000': 2, '78e289757e': 1, '1b6033': 14, '701b603378e289': 2} - candidates2 = {'1b603300': 4, '701b603378e289': 2, '20701b603378e289000c62': 1, '000': 3, '0000': 19, - '1b6033': 11, '78e2890000': 1, '00': 4, '7078e2891b6033000000': 2, '207078e2891b6033000000': 1, - '78e289000': 1, '78e289': 7, '0': 7, '1b60330000': 3} - 
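The deleted test_choose_candidate, whose candidate dictionaries are defined above and whose assertions follow right after this sketch, expected Address.choose_candidate_pair to yield ("1b6033", "78e289") for both dictionaries. A deliberately naive illustration of that idea, not the removed implementation, is to drop degenerate values and keep the two most frequent remaining candidates:

```python
# Naive illustration of address-candidate selection; the removed
# Address.choose_candidate_pair was more elaborate (and returned a generator).
def choose_candidate_pair_naive(candidates: dict) -> tuple:
    # Ignore degenerate values such as all-zero runs or very short strings
    plausible = {value: count for value, count in candidates.items()
                 if set(value) != {"0"} and len(value) >= 4}
    # Keep the two most frequently seen candidates
    return tuple(sorted(plausible, key=plausible.get, reverse=True)[:2])
```

Filtering out the all-zero strings first matters for the second dictionary, where "0000" alone occurs 19 times and would otherwise displace the real addresses.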
- self.assertEqual(next(Address.choose_candidate_pair(candidates1)), ("1b6033", "78e289")) - self.assertEqual(next(Address.choose_candidate_pair(candidates2)), ("1b6033", "78e289")) - - def test_format_finding_without_participants(self): - for msg in self.zero_crc_protocol.messages: - msg.participant = None - - ff = FormatFinder(self.zero_crc_protocol, []) - ff.perform_iteration() - - def test_assign_participant_addresses(self): - clusters = {"ack": {1, 17, 3, 20, 5, 7, 9, 11, 13, 15}, "default": {0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 19}} - com = Component(messagetypes=self.protocol.message_types) - com.assign_messagetypes(self.protocol.messages, clusters) - - Address.assign_participant_addresses(self.protocol.messages, self.participants, ("78e289", "1b6033")) - - self.assertEqual(self.participants[0].address_hex, "78e289") - self.assertEqual(self.participants[1].address_hex, "1b6033") diff --git a/tests/test_simulator.py b/tests/test_simulator.py index cc286ce5bf..ff7756a05a 100644 --- a/tests/test_simulator.py +++ b/tests/test_simulator.py @@ -261,7 +261,7 @@ def test_external_program_simulator(self): self.assertTrue(os.path.isfile(file_name)) def __demodulate(self, connection: socket.socket): - connection.settimeout(0.1) + connection.settimeout(0.5) time.sleep(self.TIMEOUT) total_data = [] diff --git a/tests/test_simulator_dialog.py b/tests/test_simulator_dialog.py index 97f7b168e5..2d8591aaf4 100644 --- a/tests/test_simulator_dialog.py +++ b/tests/test_simulator_dialog.py @@ -26,8 +26,6 @@ def setUp(self): simulator_manager.add_items([msg1, msg2], 0, simulator_manager.rootItem) simulator_manager.add_label(5, 15, "test", parent_item=simulator_manager.rootItem.children[0]) - print(self.form.simulator_tab_controller.simulator_config.tx_needed) - self.dialog = SimulatorDialog(self.form.simulator_tab_controller.simulator_config, self.form.generator_tab_controller.modulators, self.form.simulator_tab_controller.sim_expression_parser,