From 6c3b398a204cccb857c5f44b16415f4d5c4a58e2 Mon Sep 17 00:00:00 2001 From: arkluc Date: Sat, 23 May 2026 19:08:12 +0800 Subject: [PATCH 1/7] tf get takes forward_tolerance --- dimos/protocol/tf/test_tf.py | 98 ++++++++++++++++++++++++++++++++++++ dimos/protocol/tf/tf.py | 59 ++++++++++++++++++---- 2 files changed, 146 insertions(+), 11 deletions(-) diff --git a/dimos/protocol/tf/test_tf.py b/dimos/protocol/tf/test_tf.py index 433599f99f..2cf370b072 100644 --- a/dimos/protocol/tf/test_tf.py +++ b/dimos/protocol/tf/test_tf.py @@ -15,6 +15,7 @@ # limitations under the License. import math +import threading import time import pytest @@ -371,6 +372,103 @@ def test_time_tolerance(self) -> None: result = ttbuffer.get("world", "robot", time_point=base_time + 0.5, time_tolerance=0.1) assert result is None + def test_forward_tolerance_returns_when_buffer_fills(self) -> None: + ttbuffer = MultiTBuffer() + base_time = time.time() + + def publish_after_delay() -> None: + time.sleep(0.05) + ttbuffer.receive_transform( + Transform( + translation=Vector3(1.0, 0.0, 0.0), + frame_id="world", + child_frame_id="robot", + ts=base_time, + ) + ) + + publisher = threading.Thread(target=publish_after_delay) + publisher.start() + + t0 = time.monotonic() + result = ttbuffer.get( + "world", "robot", time_point=base_time, time_tolerance=0.1, forward_tolerance=1.0 + ) + elapsed = time.monotonic() - t0 + publisher.join() + + assert result is not None + assert result.translation.x == 1.0 + assert elapsed < 0.5 + + def test_forward_tolerance_times_out(self) -> None: + ttbuffer = MultiTBuffer() + t0 = time.monotonic() + result = ttbuffer.get( + "world", "robot", time_point=time.time(), forward_tolerance=0.1 + ) + elapsed = time.monotonic() - t0 + assert result is None + assert 0.05 < elapsed < 1.0 + + def test_forward_tolerance_fast_path_when_already_available(self) -> None: + ttbuffer = MultiTBuffer() + base_time = time.time() + ttbuffer.receive_transform( + Transform( + translation=Vector3(2.0, 0.0, 0.0), + frame_id="world", + child_frame_id="robot", + ts=base_time, + ) + ) + + t0 = time.monotonic() + result = ttbuffer.get( + "world", "robot", time_point=base_time, forward_tolerance=10.0 + ) + elapsed = time.monotonic() - t0 + + assert result is not None + assert result.translation.x == 2.0 + assert elapsed < 0.05 + + def test_forward_tolerance_wakes_on_chain_completion(self) -> None: + ttbuffer = MultiTBuffer() + base_time = time.time() + + ttbuffer.receive_transform( + Transform( + translation=Vector3(1.0, 0.0, 0.0), + frame_id="world", + child_frame_id="robot", + ts=base_time, + ) + ) + + def publish_after_delay() -> None: + time.sleep(0.05) + ttbuffer.receive_transform( + Transform( + translation=Vector3(0.0, 2.0, 0.0), + frame_id="robot", + child_frame_id="sensor", + ts=base_time, + ) + ) + + publisher = threading.Thread(target=publish_after_delay) + publisher.start() + + result = ttbuffer.get( + "world", "sensor", time_point=base_time, time_tolerance=0.1, forward_tolerance=1.0 + ) + publisher.join() + + assert result is not None + assert result.translation.x == 1.0 + assert result.translation.y == 2.0 + def test_same_frame_returns_identity(self) -> None: ttbuffer = MultiTBuffer() diff --git a/dimos/protocol/tf/tf.py b/dimos/protocol/tf/tf.py index a182464ca7..40dfd225f6 100644 --- a/dimos/protocol/tf/tf.py +++ b/dimos/protocol/tf/tf.py @@ -18,6 +18,7 @@ from collections import deque from dataclasses import field from functools import reduce +import threading import time from dimos.memory.timeseries.inmemory import InMemoryStore @@ -59,6 +60,8 @@ def get( child_frame: str, time_point: float | None = None, time_tolerance: float | None = None, + *, + forward_tolerance: float = 0.0, ) -> Transform | None: ... def receive_transform(self, *args: Transform) -> None: ... @@ -112,13 +115,16 @@ class MultiTBuffer: def __init__(self, buffer_size: float = 10.0) -> None: self.buffers: dict[tuple[str, str], TBuffer] = {} self.buffer_size = buffer_size + self._cv = threading.Condition() def receive_transform(self, *args: Transform) -> None: - for transform in args: - key = (transform.frame_id, transform.child_frame_id) - if key not in self.buffers: - self.buffers[key] = TBuffer(self.buffer_size) - self.buffers[key].add(transform) + with self._cv: + for transform in args: + key = (transform.frame_id, transform.child_frame_id) + if key not in self.buffers: + self.buffers[key] = TBuffer(self.buffer_size) + self.buffers[key].add(transform) + self._cv.notify_all() def get_frames(self) -> set[str]: frames = set() @@ -164,7 +170,7 @@ def get_transform( return None - def get( + def _get( self, parent_frame: str, child_frame: str, @@ -179,13 +185,40 @@ def get( complex = self.get_transform_search(parent_frame, child_frame, time_point, time_tolerance) if complex is None: - logger.warning( - f"No direct transform found between '{parent_frame}' and '{child_frame}' at '{to_human_readable(time_point or time.time())}', {self}" - ) return None return reduce(lambda t1, t2: t1 + t2, complex) + def get( + self, + parent_frame: str, + child_frame: str, + time_point: float | None = None, + time_tolerance: float | None = None, + *, + forward_tolerance: float = 0.0, + ) -> Transform | None: + result = self._get(parent_frame, child_frame, time_point, time_tolerance) + if result is not None: + return result + + if forward_tolerance > 0: + deadline = time.monotonic() + forward_tolerance + with self._cv: + while True: + result = self._get(parent_frame, child_frame, time_point, time_tolerance) + if result is not None: + return result + remaining = deadline - time.monotonic() + if remaining <= 0: + break + self._cv.wait(timeout=remaining) + + logger.warning( + f"No direct transform found between '{parent_frame}' and '{child_frame}' at '{to_human_readable(time_point or time.time())}', {self}" + ) + return None + def get_transform_search( self, parent_frame: str, @@ -322,8 +355,10 @@ def get( child_frame: str, time_point: float | None = None, time_tolerance: float | None = None, + *, + forward_tolerance: float = 0.0, ) -> Transform | None: - return super().get(parent_frame, child_frame, time_point, time_tolerance) + return super().get(parent_frame, child_frame, time_point, time_tolerance, forward_tolerance=forward_tolerance) def get_pose( self, @@ -331,8 +366,10 @@ def get_pose( child_frame: str, time_point: float | None = None, time_tolerance: float | None = None, + *, + forward_tolerance: float = 0.0, ) -> PoseStamped | None: - tf = self.get(parent_frame, child_frame, time_point, time_tolerance) + tf = self.get(parent_frame, child_frame, time_point, time_tolerance, forward_tolerance=forward_tolerance) if not tf: return None return tf.to_pose() From 25dd354a3fccc2b5ca6ccd8aeed52a6038f5126f Mon Sep 17 00:00:00 2001 From: "autofix-ci[bot]" <114827586+autofix-ci[bot]@users.noreply.github.com> Date: Sat, 23 May 2026 11:34:44 +0000 Subject: [PATCH 2/7] [autofix.ci] apply automated fixes --- dimos/protocol/tf/test_tf.py | 8 ++------ dimos/protocol/tf/tf.py | 16 ++++++++++++++-- 2 files changed, 16 insertions(+), 8 deletions(-) diff --git a/dimos/protocol/tf/test_tf.py b/dimos/protocol/tf/test_tf.py index 2cf370b072..a3e70de646 100644 --- a/dimos/protocol/tf/test_tf.py +++ b/dimos/protocol/tf/test_tf.py @@ -404,9 +404,7 @@ def publish_after_delay() -> None: def test_forward_tolerance_times_out(self) -> None: ttbuffer = MultiTBuffer() t0 = time.monotonic() - result = ttbuffer.get( - "world", "robot", time_point=time.time(), forward_tolerance=0.1 - ) + result = ttbuffer.get("world", "robot", time_point=time.time(), forward_tolerance=0.1) elapsed = time.monotonic() - t0 assert result is None assert 0.05 < elapsed < 1.0 @@ -424,9 +422,7 @@ def test_forward_tolerance_fast_path_when_already_available(self) -> None: ) t0 = time.monotonic() - result = ttbuffer.get( - "world", "robot", time_point=base_time, forward_tolerance=10.0 - ) + result = ttbuffer.get("world", "robot", time_point=base_time, forward_tolerance=10.0) elapsed = time.monotonic() - t0 assert result is not None diff --git a/dimos/protocol/tf/tf.py b/dimos/protocol/tf/tf.py index 40dfd225f6..3440d4c87a 100644 --- a/dimos/protocol/tf/tf.py +++ b/dimos/protocol/tf/tf.py @@ -358,7 +358,13 @@ def get( *, forward_tolerance: float = 0.0, ) -> Transform | None: - return super().get(parent_frame, child_frame, time_point, time_tolerance, forward_tolerance=forward_tolerance) + return super().get( + parent_frame, + child_frame, + time_point, + time_tolerance, + forward_tolerance=forward_tolerance, + ) def get_pose( self, @@ -369,7 +375,13 @@ def get_pose( *, forward_tolerance: float = 0.0, ) -> PoseStamped | None: - tf = self.get(parent_frame, child_frame, time_point, time_tolerance, forward_tolerance=forward_tolerance) + tf = self.get( + parent_frame, + child_frame, + time_point, + time_tolerance, + forward_tolerance=forward_tolerance, + ) if not tf: return None return tf.to_pose() From 3d2e0d6d40f68036aba43deb155975dfe4dc91e1 Mon Sep 17 00:00:00 2001 From: arkluc Date: Sat, 23 May 2026 19:39:35 +0800 Subject: [PATCH 3/7] be cautious --- dimos/protocol/tf/test_tf.py | 2 +- dimos/protocol/tf/tf.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/dimos/protocol/tf/test_tf.py b/dimos/protocol/tf/test_tf.py index a3e70de646..281e99c0d7 100644 --- a/dimos/protocol/tf/test_tf.py +++ b/dimos/protocol/tf/test_tf.py @@ -407,7 +407,7 @@ def test_forward_tolerance_times_out(self) -> None: result = ttbuffer.get("world", "robot", time_point=time.time(), forward_tolerance=0.1) elapsed = time.monotonic() - t0 assert result is None - assert 0.05 < elapsed < 1.0 + assert 0.08 < elapsed < 1.0 def test_forward_tolerance_fast_path_when_already_available(self) -> None: ttbuffer = MultiTBuffer() diff --git a/dimos/protocol/tf/tf.py b/dimos/protocol/tf/tf.py index 3440d4c87a..dbfe4cabda 100644 --- a/dimos/protocol/tf/tf.py +++ b/dimos/protocol/tf/tf.py @@ -128,7 +128,7 @@ def receive_transform(self, *args: Transform) -> None: def get_frames(self) -> set[str]: frames = set() - for parent, child in self.buffers: + for parent, child in list(self.buffers): frames.add(parent) frames.add(child) return frames @@ -136,7 +136,7 @@ def get_frames(self) -> set[str]: def get_connections(self, frame_id: str) -> set[str]: """Get all frames connected to the given frame (both as parent and child).""" connections = set() - for parent, child in self.buffers: + for parent, child in list(self.buffers): if parent == frame_id: connections.add(child) if child == frame_id: From 1c1bb6c45c56e88e4454b1d43ddc4caa874a8177 Mon Sep 17 00:00:00 2001 From: arkluc Date: Sun, 24 May 2026 16:37:11 +0800 Subject: [PATCH 4/7] condition for read; drop self from log --- dimos/protocol/tf/tf.py | 38 ++++++++++++++++++++------------------ 1 file changed, 20 insertions(+), 18 deletions(-) diff --git a/dimos/protocol/tf/tf.py b/dimos/protocol/tf/tf.py index dbfe4cabda..f8777dc5ff 100644 --- a/dimos/protocol/tf/tf.py +++ b/dimos/protocol/tf/tf.py @@ -128,19 +128,21 @@ def receive_transform(self, *args: Transform) -> None: def get_frames(self) -> set[str]: frames = set() - for parent, child in list(self.buffers): - frames.add(parent) - frames.add(child) + with self._cv: + for parent, child in self.buffers: + frames.add(parent) + frames.add(child) return frames def get_connections(self, frame_id: str) -> set[str]: """Get all frames connected to the given frame (both as parent and child).""" connections = set() - for parent, child in list(self.buffers): - if parent == frame_id: - connections.add(child) - if child == frame_id: - connections.add(parent) + with self._cv: + for parent, child in self.buffers: + if parent == frame_id: + connections.add(child) + if child == frame_id: + connections.add(parent) return connections def get_transform( @@ -198,13 +200,13 @@ def get( *, forward_tolerance: float = 0.0, ) -> Transform | None: - result = self._get(parent_frame, child_frame, time_point, time_tolerance) - if result is not None: - return result + with self._cv: + result = self._get(parent_frame, child_frame, time_point, time_tolerance) + if result is not None: + return result - if forward_tolerance > 0: - deadline = time.monotonic() + forward_tolerance - with self._cv: + if forward_tolerance > 0: + deadline = time.monotonic() + forward_tolerance while True: result = self._get(parent_frame, child_frame, time_point, time_tolerance) if result is not None: @@ -214,10 +216,10 @@ def get( break self._cv.wait(timeout=remaining) - logger.warning( - f"No direct transform found between '{parent_frame}' and '{child_frame}' at '{to_human_readable(time_point or time.time())}', {self}" - ) - return None + logger.warning( + f"No direct transform found between '{parent_frame}' and '{child_frame}' at '{to_human_readable(time_point or time.time())}'" + ) + return None def get_transform_search( self, From f3cc37094fa2667bf1fb63c923ece03d8904f98e Mon Sep 17 00:00:00 2001 From: arkluc Date: Sun, 24 May 2026 16:52:22 +0800 Subject: [PATCH 5/7] seperate _wait_get --- dimos/protocol/tf/tf.py | 60 ++++++++++++++++++++++++----------------- 1 file changed, 36 insertions(+), 24 deletions(-) diff --git a/dimos/protocol/tf/tf.py b/dimos/protocol/tf/tf.py index f8777dc5ff..87e03cf49d 100644 --- a/dimos/protocol/tf/tf.py +++ b/dimos/protocol/tf/tf.py @@ -179,17 +179,39 @@ def _get( time_point: float | None = None, time_tolerance: float | None = None, ) -> Transform | None: - simple = self.get_transform(parent_frame, child_frame, time_point, time_tolerance) + with self._cv: + simple = self.get_transform(parent_frame, child_frame, time_point, time_tolerance) - if simple is not None: - return simple + if simple is not None: + return simple - complex = self.get_transform_search(parent_frame, child_frame, time_point, time_tolerance) + complex = self.get_transform_search( + parent_frame, child_frame, time_point, time_tolerance + ) - if complex is None: - return None + if complex is None: + return None - return reduce(lambda t1, t2: t1 + t2, complex) + return reduce(lambda t1, t2: t1 + t2, complex) + + def _wait_get( + self, + parent_frame: str, + child_frame: str, + time_point: float | None, + time_tolerance: float | None, + forward_tolerance: float, + ) -> Transform | None: + deadline = time.monotonic() + forward_tolerance + with self._cv: + while True: + result = self._get(parent_frame, child_frame, time_point, time_tolerance) + if result is not None: + return result + remaining = deadline - time.monotonic() + if remaining <= 0: + return None + self._cv.wait(timeout=remaining) def get( self, @@ -200,26 +222,16 @@ def get( *, forward_tolerance: float = 0.0, ) -> Transform | None: - with self._cv: - result = self._get(parent_frame, child_frame, time_point, time_tolerance) - if result is not None: - return result - - if forward_tolerance > 0: - deadline = time.monotonic() + forward_tolerance - while True: - result = self._get(parent_frame, child_frame, time_point, time_tolerance) - if result is not None: - return result - remaining = deadline - time.monotonic() - if remaining <= 0: - break - self._cv.wait(timeout=remaining) - + result = self._get(parent_frame, child_frame, time_point, time_tolerance) + if result is None and forward_tolerance > 0: + result = self._wait_get( + parent_frame, child_frame, time_point, time_tolerance, forward_tolerance + ) + if result is None: logger.warning( f"No direct transform found between '{parent_frame}' and '{child_frame}' at '{to_human_readable(time_point or time.time())}'" ) - return None + return result def get_transform_search( self, From 13d463aae43af70229ab7db0d38edb87d84ca102 Mon Sep 17 00:00:00 2001 From: arkluc Date: Sun, 24 May 2026 17:46:00 +0800 Subject: [PATCH 6/7] buffer fix --- dimos/protocol/tf/tf.py | 92 ++++++++++++++++++++++------------------- 1 file changed, 50 insertions(+), 42 deletions(-) diff --git a/dimos/protocol/tf/tf.py b/dimos/protocol/tf/tf.py index 87e03cf49d..3e0ab8267b 100644 --- a/dimos/protocol/tf/tf.py +++ b/dimos/protocol/tf/tf.py @@ -159,18 +159,19 @@ def get_transform( ts=time_point if time_point is not None else time.time(), ) - # Check forward direction - key = (parent_frame, child_frame) - if key in self.buffers: - return self.buffers[key].get(time_point, time_tolerance) # type: ignore[arg-type] + with self._cv: + # Check forward direction + key = (parent_frame, child_frame) + if key in self.buffers: + return self.buffers[key].get(time_point, time_tolerance) # type: ignore[arg-type] - # Check reverse direction and return inverse - reverse_key = (child_frame, parent_frame) - if reverse_key in self.buffers: - transform = self.buffers[reverse_key].get(time_point, time_tolerance) # type: ignore[arg-type] - return transform.inverse() if transform else None + # Check reverse direction and return inverse + reverse_key = (child_frame, parent_frame) + if reverse_key in self.buffers: + transform = self.buffers[reverse_key].get(time_point, time_tolerance) # type: ignore[arg-type] + return transform.inverse() if transform else None - return None + return None def _get( self, @@ -241,36 +242,37 @@ def get_transform_search( time_tolerance: float | None = None, ) -> list[Transform] | None: """Search for shortest transform chain between parent and child frames using BFS.""" - # Check if direct transform exists (already checked in get_transform, but for clarity) - direct = self.get_transform(parent_frame, child_frame, time_point, time_tolerance) - if direct is not None: - return [direct] + with self._cv: + # Check if direct transform exists (already checked in get_transform, but for clarity) + direct = self.get_transform(parent_frame, child_frame, time_point, time_tolerance) + if direct is not None: + return [direct] - # BFS to find shortest path - queue: deque[tuple[str, list[Transform]]] = deque([(parent_frame, [])]) - visited = {parent_frame} + # BFS to find shortest path + queue: deque[tuple[str, list[Transform]]] = deque([(parent_frame, [])]) + visited = {parent_frame} - while queue: - current_frame, path = queue.popleft() + while queue: + current_frame, path = queue.popleft() - if current_frame == child_frame: - return path + if current_frame == child_frame: + return path - # Get all connections for current frame - connections = self.get_connections(current_frame) + # Get all connections for current frame + connections = self.get_connections(current_frame) - for next_frame in connections: - if next_frame not in visited: - visited.add(next_frame) + for next_frame in connections: + if next_frame not in visited: + visited.add(next_frame) - # Get the transform between current and next frame - transform = self.get_transform( - current_frame, next_frame, time_point, time_tolerance - ) - if transform: - queue.append((next_frame, [*path, transform])) + # Get the transform between current and next frame + transform = self.get_transform( + current_frame, next_frame, time_point, time_tolerance + ) + if transform: + queue.append((next_frame, [*path, transform])) - return None + return None def graph(self) -> str: import subprocess @@ -279,7 +281,9 @@ def connection_str(connection: tuple[str, str]) -> str: (frame_from, frame_to) = connection return f"{frame_from} -> {frame_to}" - graph_str = "\n".join(map(connection_str, self.buffers.keys())) + with self._cv: + keys = list(self.buffers.keys()) + graph_str = "\n".join(map(connection_str, keys)) try: result = subprocess.run( @@ -293,11 +297,14 @@ def connection_str(connection: tuple[str, str]) -> str: return "no diagon installed" def __str__(self) -> str: - if not self.buffers: + with self._cv: + buffers = list(self.buffers.values()) + + if not buffers: return f"{self.__class__.__name__}(empty)" - lines = [f"{self.__class__.__name__}({len(self.buffers)} buffers):"] - for buffer in self.buffers.values(): + lines = [f"{self.__class__.__name__}({len(buffers)} buffers):"] + for buffer in buffers: lines.append(f" {buffer}") return "\n".join(lines) @@ -354,11 +361,12 @@ def publish_static(self, *args: Transform) -> None: def publish_all(self) -> None: """Publish all transforms currently stored in all buffers.""" all_transforms = [] - for buffer in self.buffers.values(): - # Get the latest transform from each buffer - latest = buffer.get() # get() with no args returns latest - if latest: - all_transforms.append(latest) + with self._cv: + for buffer in self.buffers.values(): + # Get the latest transform from each buffer + latest = buffer.get() # get() with no args returns latest + if latest: + all_transforms.append(latest) if all_transforms: self.publish(*all_transforms) From d6578267066ecb78c12f91e981fffea5430d72d2 Mon Sep 17 00:00:00 2001 From: "autofix-ci[bot]" <114827586+autofix-ci[bot]@users.noreply.github.com> Date: Sun, 24 May 2026 09:46:37 +0000 Subject: [PATCH 7/7] [autofix.ci] apply automated fixes --- dimos/protocol/tf/tf.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dimos/protocol/tf/tf.py b/dimos/protocol/tf/tf.py index 3e0ab8267b..05c600d733 100644 --- a/dimos/protocol/tf/tf.py +++ b/dimos/protocol/tf/tf.py @@ -364,7 +364,7 @@ def publish_all(self) -> None: with self._cv: for buffer in self.buffers.values(): # Get the latest transform from each buffer - latest = buffer.get() # get() with no args returns latest + latest = buffer.get() # get() with no args returns latest if latest: all_transforms.append(latest)