Skip to content

Commit

Permalink
Fix proto2ros.dependencies.fix_dependency_cycles (#107)
Browse files Browse the repository at this point in the history
Signed-off-by: Michel Hidalgo <mhidalgo@theaiinstitute.com>
  • Loading branch information
mhidalgo-bdai committed Jun 13, 2024
1 parent dbb1bab commit 41b3177
Show file tree
Hide file tree
Showing 9 changed files with 101 additions and 40 deletions.
56 changes: 38 additions & 18 deletions proto2ros/proto2ros/dependencies.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,15 @@

"""This module provides APIs to manipulate dependencies between Protobuf <-> ROS message equivalences."""

import collections
import itertools
import warnings
from typing import List

from rosidl_adapter.parser import MessageSpecification

from proto2ros.compatibility import networkx as nx
from proto2ros.utilities import pairwise, to_ros_base_type
from proto2ros.utilities import to_ros_base_type


def message_dependency_graph(message_specs: List[MessageSpecification]) -> nx.DiGraph:
Expand All @@ -32,28 +34,46 @@ def message_dependency_graph(message_specs: List[MessageSpecification]) -> nx.Di
def fix_dependency_cycles(message_specs: List[MessageSpecification], quiet: bool = True) -> None:
"""Fixes dependency cycles among ROS message specifications.
ROS messages do not support recursive definitions, this functions works around this
limitation by type erasing the thinnest link (least number of offending fields) for
each cycle.
ROS messages do not support recursive definitions, so this function works around this
limitation by type erasing the least amount of offending fields.
"""
dependency_graph = message_dependency_graph(message_specs)

cycles = []
for cycle in nx.simple_cycles(dependency_graph):
cycle = [*cycle, cycle[0]] # close the loop
if not quiet:
message_types = [dependency_graph.nodes[node]["message"].base_type for node in cycle]
dependency_cycle_depiction = " -> ".join(str(type_) for type_ in message_types)
dependency_cycle_depiction += " -> " + str(message_types[0]) # close the loop
warnings.warn("Dependency cycle found: " + dependency_cycle_depiction, stacklevel=1)
cycles.append(cycle)

explicit_edges = []
for parent, child in pairwise(cycle):
message = dependency_graph.nodes[child]["message"]
if message.annotations["proto-class"] == "message":
explicit_edges.append((parent, child))

parent, child = min(explicit_edges, key=lambda edge: dependency_graph.number_of_edges(*edge))
for data in dependency_graph[parent][child].values():
field = data["field"]
if not quiet:
message_type = dependency_graph.nodes[parent]["message"].base_type
warnings.warn(f"Type erasing {field.name} member in {message_type} to break recursion", stacklevel=1)
field.annotations["type-erased"] = True
counter = collections.Counter(
sorted(
sorted(itertools.chain(*cycles)), # ensures an stable order (maintained by sorted)
key=dependency_graph.in_degree, # implicitly breaks ties by prioritizing the least common messages
),
)
while counter.total() > 0:
for node, _ in counter.most_common(): # greedily break cycles
message = dependency_graph.nodes[node]["message"]
if message.annotations["proto-class"] != "message":
continue
break
else:
raise RuntimeError("no candidate for type erasure found")
for cycle in list(cycles):
if node not in cycle:
continue
parent = cycle[cycle.index(node) - 1]
for data in dependency_graph[parent][node].values():
field = data["field"]
if not quiet:
message_type = dependency_graph.nodes[parent]["message"].base_type
warnings.warn(
f"Type erasing {field.type} {field.name} member in {message_type} to break recursion",
stacklevel=1,
)
field.annotations["type-erased"] = True
counter.subtract(cycle)
cycles.remove(cycle)
19 changes: 17 additions & 2 deletions proto2ros_tests/proto/test.proto
Original file line number Diff line number Diff line change
Expand Up @@ -47,12 +47,23 @@ message Fragment {

// Heterogeneous pair.
message Pair {
// Sequence first values.
// First value.
Value first = 1;
// Sequence second values.
// Second value.
Value second = 2;
}

// Heterogeneous list.
message List {
// Listed values.
repeated Value values = 1;
}

// Heterogeneous dict.
message Dict {
map<string, Value> items = 1;
}

// Heterogeneous value.
message Value {
oneof data {
Expand All @@ -62,6 +73,10 @@ message Value {
string text = 2;
// Pair value.
Pair pair = 3;
// List value.
List list = 4;
// Dict value.
Dict dict = 5;
}
}

Expand Down
3 changes: 3 additions & 0 deletions proto2ros_tests/test/generated/Dict.msg
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
# Heterogeneous dict.

proto2ros_tests/DictItemsEntry[] items
3 changes: 3 additions & 0 deletions proto2ros_tests/test/generated/DictItemsEntry.msg
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@

string key
proto2ros/Any value # is proto2ros_tests/Value (type-erased)
7 changes: 7 additions & 0 deletions proto2ros_tests/test/generated/HVACControlRequest.msg
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
# A control request for an HVAC system.

uint8 TEMPERATURE_SETPOINT_FIELD_SET=2

float64 air_flow_rate
sensor_msgs/Temperature temperature_setpoint
uint8 has_field 255
4 changes: 4 additions & 0 deletions proto2ros_tests/test/generated/List.msg
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
# Heterogeneous list.

# Listed values.
proto2ros/Any[] values # is proto2ros_tests/Value[] (type-erased)
8 changes: 4 additions & 4 deletions proto2ros_tests/test/generated/Pair.msg
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@
uint8 FIRST_FIELD_SET=1
uint8 SECOND_FIELD_SET=2

# Sequence first values.
proto2ros_tests/Value first
# Sequence second values.
proto2ros_tests/Value second
# First value.
proto2ros/Any first # is proto2ros_tests/Value (type-erased)
# Second value.
proto2ros/Any second # is proto2ros_tests/Value (type-erased)
uint8 has_field 255
8 changes: 7 additions & 1 deletion proto2ros_tests/test/generated/ValueOneOfData.msg
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,18 @@ int8 DATA_NOT_SET=0
int8 DATA_NUMBER_SET=1
int8 DATA_TEXT_SET=2
int8 DATA_PAIR_SET=3
int8 DATA_LIST_SET=4
int8 DATA_DICT_SET=5

# Numeric value.
float32 number
# Text value.
string text
# Pair value.
proto2ros/Any pair # is proto2ros_tests/Pair (type-erased)
proto2ros_tests/Pair pair
# List value.
proto2ros_tests/List list
# Dict value.
proto2ros_tests/Dict dict
int8 data_choice # deprecated
int8 which
33 changes: 18 additions & 15 deletions proto2ros_tests/test/test_proto2ros.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,21 +60,24 @@ def test_recursive_messages() -> None:


def test_circularly_dependent_messages() -> None:
proto_pair = test_pb2.Pair()
proto_pair.first.text = "interval"
proto_pair.second.pair.first.number = -0.5
proto_pair.second.pair.second.number = 0.5
ros_pair = proto2ros_tests.msg.Pair()
convert(proto_pair, ros_pair)
assert ros_pair.first.data.which == ros_pair.first.data.DATA_TEXT_SET
assert ros_pair.first.data.text == proto_pair.first.text
assert ros_pair.second.data.which == ros_pair.second.data.DATA_PAIR_SET
assert ros_pair.second.data.pair.type_name == "proto2ros_tests/Pair"
other_proto_pair = test_pb2.Pair()
convert(ros_pair, other_proto_pair)
assert other_proto_pair.first.text == proto_pair.first.text
assert other_proto_pair.second.pair.first.number == proto_pair.second.pair.first.number
assert other_proto_pair.second.pair.second.number == proto_pair.second.pair.second.number
proto_value = test_pb2.Value()
proto_pair_value = proto_value.dict.items["interval"]
proto_pair_value.pair.first.number = -0.5
proto_pair_value.pair.second.number = 0.5
proto_list_value = proto_value.dict.items["range"]
for number in (-0.1, 0.0, 0.1, -0.7, 0.3, 0.4):
value = proto_list_value.list.values.add()
value.number = number
ros_value = proto2ros_tests.msg.Value()
convert(proto_value, ros_value)
other_proto_value = test_pb2.Value()
convert(ros_value, other_proto_value)
assert "interval" in other_proto_value.dict.items
other_proto_pair_value = proto_value.dict.items["interval"]
assert other_proto_pair_value.pair.first.number == proto_pair_value.pair.first.number
assert other_proto_pair_value.pair.second.number == proto_pair_value.pair.second.number
other_proto_list_value = proto_value.dict.items["range"]
assert [v.number for v in other_proto_list_value.list.values] == [v.number for v in proto_list_value.list.values]


def test_messages_with_enums() -> None:
Expand Down

0 comments on commit 41b3177

Please sign in to comment.