Skip to content

Commit

Permalink
tui: Allow merging threads in live mode TUI
Browse files Browse the repository at this point in the history
This allows users to choose between seeing allocations for a single
thread at a time, or seeing allocations across all threads.

Signed-off-by: Marta Gomez Macias <mgmacias@google.com>
  • Loading branch information
mgmacias95 authored and godlygeek committed May 30, 2024
1 parent 71c729a commit d0e5866
Show file tree
Hide file tree
Showing 4 changed files with 1,141 additions and 625 deletions.
1 change: 1 addition & 0 deletions news/589.feature.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Add a button in the live-mode TUI to show allocations from all threads at once.
54 changes: 41 additions & 13 deletions src/memray/reporters/tui.py
Original file line number Diff line number Diff line change
Expand Up @@ -286,6 +286,7 @@ class AllocationTable(Widget):
sort_column_id = reactive(default_sort_column_id)
snapshot = reactive(_EMPTY_SNAPSHOT)
current_thread = reactive(0)
merge_threads = reactive(False, init=False)

columns = [
"Location",
Expand Down Expand Up @@ -371,6 +372,10 @@ def watch_current_thread(self) -> None:
"""Called when the current_thread attribute changes."""
self.populate_table()

def watch_merge_threads(self) -> None:
"""Called when the merge_threads attribute changes."""
self.populate_table()

def watch_snapshot(self) -> None:
"""Called when the snapshot attribute changes."""
self.populate_table()
Expand Down Expand Up @@ -413,7 +418,9 @@ def populate_table(self) -> None:
new_locations = set()

for location, result in sorted_allocations:
if self.current_thread not in result.thread_ids:
if not self.merge_threads and (
self.current_thread not in result.thread_ids
):
continue

total_color = self._get_color(result.total_memory, total_allocations)
Expand Down Expand Up @@ -514,6 +521,7 @@ class TUI(Screen[None]):
Binding("q,esc", "app.quit", "Quit"),
Binding("<,left", "previous_thread", "Previous Thread"),
Binding(">,right", "next_thread", "Next Thread"),
Binding("m", "toggle_merge_threads", "Merge Threads"),
Binding("t", "sort(1)", "Sort by Total"),
Binding("o", "sort(3)", "Sort by Own"),
Binding("a", "sort(5)", "Sort by Allocations"),
Expand All @@ -539,6 +547,7 @@ def __init__(self, pid: Optional[int], cmd_line: Optional[str], native: bool):
self.native = native
self._seen_threads: Set[int] = set()
self._max_memory_seen = 0
self._merge_threads = False
super().__init__()

@property
Expand All @@ -547,16 +556,36 @@ def current_thread(self) -> int:

def action_previous_thread(self) -> None:
"""An action to switch to previous thread."""
self.thread_idx = (self.thread_idx - 1) % len(self.threads)
if not self._merge_threads:
self.thread_idx = (self.thread_idx - 1) % len(self.threads)

def action_next_thread(self) -> None:
"""An action to switch to next thread."""
self.thread_idx = (self.thread_idx + 1) % len(self.threads)
if not self._merge_threads:
self.thread_idx = (self.thread_idx + 1) % len(self.threads)

def action_sort(self, col_number: int) -> None:
"""An action to sort the table rows based on a given column attribute."""
self.update_sort_key(col_number)

def _populate_header_thread_labels(self, thread_idx: int) -> None:
if self._merge_threads:
tid_label = "[b]TID[/]: *"
thread_label = "[b]All threads[/]"
else:
tid_label = f"[b]TID[/]: {hex(self.current_thread)}"
thread_label = f"[b]Thread[/] {thread_idx + 1} of {len(self.threads)}"

self.query_one("#tid", Label).update(tid_label)
self.query_one("#thread", Label).update(thread_label)

def action_toggle_merge_threads(self) -> None:
"""An action to toggle showing allocations from all threads together."""
self._merge_threads = not self._merge_threads
redraw_footer(self.app)
self.app.query_one(AllocationTable).merge_threads = self._merge_threads
self._populate_header_thread_labels(self.thread_idx)

def action_toggle_pause(self) -> None:
"""Toggle pause on keypress"""
if self.paused or not self.disconnected:
Expand All @@ -572,18 +601,12 @@ def action_scroll_grid(self, direction: str) -> None:

def watch_thread_idx(self, thread_idx: int) -> None:
"""Called when the thread_idx attribute changes."""
self.query_one("#tid", Label).update(f"[b]TID[/]: {hex(self.current_thread)}")
self.query_one("#thread", Label).update(
f"[b]Thread[/] {thread_idx + 1} of {len(self.threads)}"
)
self._populate_header_thread_labels(thread_idx)
self.query_one(AllocationTable).current_thread = self.current_thread

def watch_threads(self, threads: List[int]) -> None:
def watch_threads(self) -> None:
"""Called when the threads attribute changes."""
self.query_one("#tid", Label).update(f"[b]TID[/]: {hex(self.current_thread)}")
self.query_one("#thread", Label).update(
f"[b]Thread[/] {self.thread_idx + 1} of {len(threads)}"
)
self._populate_header_thread_labels(self.thread_idx)

def watch_disconnected(self) -> None:
self.update_label()
Expand Down Expand Up @@ -666,6 +689,11 @@ def rewrite_bindings(self, bindings: Bindings) -> None:
elif self.disconnected:
del bindings["space"]

if self._merge_threads:
bindings.pop("less_than_sign")
bindings.pop("greater_than_sign")
update_key_description(bindings, "m", "Unmerge Threads")

@property
def active_bindings(self) -> Dict[str, Any]:
bindings = super().active_bindings.copy()
Expand All @@ -674,7 +702,7 @@ def active_bindings(self) -> Dict[str, Any]:


class UpdateThread(threading.Thread):
def __init__(self, app: App[None], reader: SocketReader) -> None:
def __init__(self, app: "TUIApp", reader: SocketReader) -> None:
self._app = app
self._reader = reader
self._update_requested = threading.Event()
Expand Down
Loading

0 comments on commit d0e5866

Please sign in to comment.