diff --git a/dimos/protocol/tf/test_tf.py b/dimos/protocol/tf/test_tf.py index 433599f99f..281e99c0d7 100644 --- a/dimos/protocol/tf/test_tf.py +++ b/dimos/protocol/tf/test_tf.py @@ -15,6 +15,7 @@ # limitations under the License. import math +import threading import time import pytest @@ -371,6 +372,99 @@ def test_time_tolerance(self) -> None: result = ttbuffer.get("world", "robot", time_point=base_time + 0.5, time_tolerance=0.1) assert result is None + def test_forward_tolerance_returns_when_buffer_fills(self) -> None: + ttbuffer = MultiTBuffer() + base_time = time.time() + + def publish_after_delay() -> None: + time.sleep(0.05) + ttbuffer.receive_transform( + Transform( + translation=Vector3(1.0, 0.0, 0.0), + frame_id="world", + child_frame_id="robot", + ts=base_time, + ) + ) + + publisher = threading.Thread(target=publish_after_delay) + publisher.start() + + t0 = time.monotonic() + result = ttbuffer.get( + "world", "robot", time_point=base_time, time_tolerance=0.1, forward_tolerance=1.0 + ) + elapsed = time.monotonic() - t0 + publisher.join() + + assert result is not None + assert result.translation.x == 1.0 + assert elapsed < 0.5 + + def test_forward_tolerance_times_out(self) -> None: + ttbuffer = MultiTBuffer() + t0 = time.monotonic() + result = ttbuffer.get("world", "robot", time_point=time.time(), forward_tolerance=0.1) + elapsed = time.monotonic() - t0 + assert result is None + assert 0.08 < elapsed < 1.0 + + def test_forward_tolerance_fast_path_when_already_available(self) -> None: + ttbuffer = MultiTBuffer() + base_time = time.time() + ttbuffer.receive_transform( + Transform( + translation=Vector3(2.0, 0.0, 0.0), + frame_id="world", + child_frame_id="robot", + ts=base_time, + ) + ) + + t0 = time.monotonic() + result = ttbuffer.get("world", "robot", time_point=base_time, forward_tolerance=10.0) + elapsed = time.monotonic() - t0 + + assert result is not None + assert result.translation.x == 2.0 + assert elapsed < 0.05 + + def test_forward_tolerance_wakes_on_chain_completion(self) -> None: + ttbuffer = MultiTBuffer() + base_time = time.time() + + ttbuffer.receive_transform( + Transform( + translation=Vector3(1.0, 0.0, 0.0), + frame_id="world", + child_frame_id="robot", + ts=base_time, + ) + ) + + def publish_after_delay() -> None: + time.sleep(0.05) + ttbuffer.receive_transform( + Transform( + translation=Vector3(0.0, 2.0, 0.0), + frame_id="robot", + child_frame_id="sensor", + ts=base_time, + ) + ) + + publisher = threading.Thread(target=publish_after_delay) + publisher.start() + + result = ttbuffer.get( + "world", "sensor", time_point=base_time, time_tolerance=0.1, forward_tolerance=1.0 + ) + publisher.join() + + assert result is not None + assert result.translation.x == 1.0 + assert result.translation.y == 2.0 + def test_same_frame_returns_identity(self) -> None: ttbuffer = MultiTBuffer() diff --git a/dimos/protocol/tf/tf.py b/dimos/protocol/tf/tf.py index a182464ca7..05c600d733 100644 --- a/dimos/protocol/tf/tf.py +++ b/dimos/protocol/tf/tf.py @@ -18,6 +18,7 @@ from collections import deque from dataclasses import field from functools import reduce +import threading import time from dimos.memory.timeseries.inmemory import InMemoryStore @@ -59,6 +60,8 @@ def get( child_frame: str, time_point: float | None = None, time_tolerance: float | None = None, + *, + forward_tolerance: float = 0.0, ) -> Transform | None: ... def receive_transform(self, *args: Transform) -> None: ... @@ -112,29 +115,34 @@ class MultiTBuffer: def __init__(self, buffer_size: float = 10.0) -> None: self.buffers: dict[tuple[str, str], TBuffer] = {} self.buffer_size = buffer_size + self._cv = threading.Condition() def receive_transform(self, *args: Transform) -> None: - for transform in args: - key = (transform.frame_id, transform.child_frame_id) - if key not in self.buffers: - self.buffers[key] = TBuffer(self.buffer_size) - self.buffers[key].add(transform) + with self._cv: + for transform in args: + key = (transform.frame_id, transform.child_frame_id) + if key not in self.buffers: + self.buffers[key] = TBuffer(self.buffer_size) + self.buffers[key].add(transform) + self._cv.notify_all() def get_frames(self) -> set[str]: frames = set() - for parent, child in self.buffers: - frames.add(parent) - frames.add(child) + with self._cv: + for parent, child in self.buffers: + frames.add(parent) + frames.add(child) return frames def get_connections(self, frame_id: str) -> set[str]: """Get all frames connected to the given frame (both as parent and child).""" connections = set() - for parent, child in self.buffers: - if parent == frame_id: - connections.add(child) - if child == frame_id: - connections.add(parent) + with self._cv: + for parent, child in self.buffers: + if parent == frame_id: + connections.add(child) + if child == frame_id: + connections.add(parent) return connections def get_transform( @@ -151,40 +159,80 @@ def get_transform( ts=time_point if time_point is not None else time.time(), ) - # Check forward direction - key = (parent_frame, child_frame) - if key in self.buffers: - return self.buffers[key].get(time_point, time_tolerance) # type: ignore[arg-type] + with self._cv: + # Check forward direction + key = (parent_frame, child_frame) + if key in self.buffers: + return self.buffers[key].get(time_point, time_tolerance) # type: ignore[arg-type] - # Check reverse direction and return inverse - reverse_key = (child_frame, parent_frame) - if reverse_key in self.buffers: - transform = self.buffers[reverse_key].get(time_point, time_tolerance) # type: ignore[arg-type] - return transform.inverse() if transform else None + # Check reverse direction and return inverse + reverse_key = (child_frame, parent_frame) + if reverse_key in self.buffers: + transform = self.buffers[reverse_key].get(time_point, time_tolerance) # type: ignore[arg-type] + return transform.inverse() if transform else None - return None + return None - def get( + def _get( self, parent_frame: str, child_frame: str, time_point: float | None = None, time_tolerance: float | None = None, ) -> Transform | None: - simple = self.get_transform(parent_frame, child_frame, time_point, time_tolerance) + with self._cv: + simple = self.get_transform(parent_frame, child_frame, time_point, time_tolerance) + + if simple is not None: + return simple - if simple is not None: - return simple + complex = self.get_transform_search( + parent_frame, child_frame, time_point, time_tolerance + ) + + if complex is None: + return None - complex = self.get_transform_search(parent_frame, child_frame, time_point, time_tolerance) + return reduce(lambda t1, t2: t1 + t2, complex) - if complex is None: + def _wait_get( + self, + parent_frame: str, + child_frame: str, + time_point: float | None, + time_tolerance: float | None, + forward_tolerance: float, + ) -> Transform | None: + deadline = time.monotonic() + forward_tolerance + with self._cv: + while True: + result = self._get(parent_frame, child_frame, time_point, time_tolerance) + if result is not None: + return result + remaining = deadline - time.monotonic() + if remaining <= 0: + return None + self._cv.wait(timeout=remaining) + + def get( + self, + parent_frame: str, + child_frame: str, + time_point: float | None = None, + time_tolerance: float | None = None, + *, + forward_tolerance: float = 0.0, + ) -> Transform | None: + result = self._get(parent_frame, child_frame, time_point, time_tolerance) + if result is None and forward_tolerance > 0: + result = self._wait_get( + parent_frame, child_frame, time_point, time_tolerance, forward_tolerance + ) + if result is None: logger.warning( - f"No direct transform found between '{parent_frame}' and '{child_frame}' at '{to_human_readable(time_point or time.time())}', {self}" + f"No direct transform found between '{parent_frame}' and '{child_frame}' at '{to_human_readable(time_point or time.time())}'" ) - return None - - return reduce(lambda t1, t2: t1 + t2, complex) + return result def get_transform_search( self, @@ -194,36 +242,37 @@ def get_transform_search( time_tolerance: float | None = None, ) -> list[Transform] | None: """Search for shortest transform chain between parent and child frames using BFS.""" - # Check if direct transform exists (already checked in get_transform, but for clarity) - direct = self.get_transform(parent_frame, child_frame, time_point, time_tolerance) - if direct is not None: - return [direct] + with self._cv: + # Check if direct transform exists (already checked in get_transform, but for clarity) + direct = self.get_transform(parent_frame, child_frame, time_point, time_tolerance) + if direct is not None: + return [direct] - # BFS to find shortest path - queue: deque[tuple[str, list[Transform]]] = deque([(parent_frame, [])]) - visited = {parent_frame} + # BFS to find shortest path + queue: deque[tuple[str, list[Transform]]] = deque([(parent_frame, [])]) + visited = {parent_frame} - while queue: - current_frame, path = queue.popleft() + while queue: + current_frame, path = queue.popleft() - if current_frame == child_frame: - return path + if current_frame == child_frame: + return path - # Get all connections for current frame - connections = self.get_connections(current_frame) + # Get all connections for current frame + connections = self.get_connections(current_frame) - for next_frame in connections: - if next_frame not in visited: - visited.add(next_frame) + for next_frame in connections: + if next_frame not in visited: + visited.add(next_frame) - # Get the transform between current and next frame - transform = self.get_transform( - current_frame, next_frame, time_point, time_tolerance - ) - if transform: - queue.append((next_frame, [*path, transform])) + # Get the transform between current and next frame + transform = self.get_transform( + current_frame, next_frame, time_point, time_tolerance + ) + if transform: + queue.append((next_frame, [*path, transform])) - return None + return None def graph(self) -> str: import subprocess @@ -232,7 +281,9 @@ def connection_str(connection: tuple[str, str]) -> str: (frame_from, frame_to) = connection return f"{frame_from} -> {frame_to}" - graph_str = "\n".join(map(connection_str, self.buffers.keys())) + with self._cv: + keys = list(self.buffers.keys()) + graph_str = "\n".join(map(connection_str, keys)) try: result = subprocess.run( @@ -246,11 +297,14 @@ def connection_str(connection: tuple[str, str]) -> str: return "no diagon installed" def __str__(self) -> str: - if not self.buffers: + with self._cv: + buffers = list(self.buffers.values()) + + if not buffers: return f"{self.__class__.__name__}(empty)" - lines = [f"{self.__class__.__name__}({len(self.buffers)} buffers):"] - for buffer in self.buffers.values(): + lines = [f"{self.__class__.__name__}({len(buffers)} buffers):"] + for buffer in buffers: lines.append(f" {buffer}") return "\n".join(lines) @@ -307,11 +361,12 @@ def publish_static(self, *args: Transform) -> None: def publish_all(self) -> None: """Publish all transforms currently stored in all buffers.""" all_transforms = [] - for buffer in self.buffers.values(): - # Get the latest transform from each buffer - latest = buffer.get() # get() with no args returns latest - if latest: - all_transforms.append(latest) + with self._cv: + for buffer in self.buffers.values(): + # Get the latest transform from each buffer + latest = buffer.get() # get() with no args returns latest + if latest: + all_transforms.append(latest) if all_transforms: self.publish(*all_transforms) @@ -322,8 +377,16 @@ def get( child_frame: str, time_point: float | None = None, time_tolerance: float | None = None, + *, + forward_tolerance: float = 0.0, ) -> Transform | None: - return super().get(parent_frame, child_frame, time_point, time_tolerance) + return super().get( + parent_frame, + child_frame, + time_point, + time_tolerance, + forward_tolerance=forward_tolerance, + ) def get_pose( self, @@ -331,8 +394,16 @@ def get_pose( child_frame: str, time_point: float | None = None, time_tolerance: float | None = None, + *, + forward_tolerance: float = 0.0, ) -> PoseStamped | None: - tf = self.get(parent_frame, child_frame, time_point, time_tolerance) + tf = self.get( + parent_frame, + child_frame, + time_point, + time_tolerance, + forward_tolerance=forward_tolerance, + ) if not tf: return None return tf.to_pose()