Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 22 additions & 0 deletions sqlit/domains/explorer/app/schema_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,28 @@ def _run_with_retry(self, fn: Callable[[], ResultT], database: str | None) -> Re
raise
return self._run(fn)

def execute_cursor_query(
self,
query: str,
database: str | None = None,
) -> tuple[list[str], list[tuple[Any, ...]]]:
"""Execute a cursor-returning metadata query on the serialized DB executor."""

def fetch() -> tuple[list[str], list[tuple[Any, ...]]]:
cursor = self.session.connection.cursor()
try:
cursor.execute(query)
description = cursor.description or []
result_columns = [column[0] for column in description]
rows = [tuple(row) for row in cursor.fetchall()]
return result_columns, rows
finally:
close = getattr(cursor, "close", None)
if callable(close):
close()

return self._run_with_retry(fetch, database)

def list_databases(self) -> list[str]:
inspector = self.session.provider.schema_inspector
return self._run_with_retry(
Expand Down
20 changes: 16 additions & 4 deletions sqlit/domains/explorer/ui/mixins/tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -381,8 +381,20 @@ def _show_table_metadata_result(
self.query_input.text = query_text
self.notify(f"{message_prefix}: {table_name} ({len(rows)})")

def _fetch_cursor_result(self: TreeMixinHost, query: str) -> tuple[list[str], list[tuple[Any, ...]]]:
"""Execute a metadata query and return cursor columns/rows."""
def _fetch_cursor_result(
self: TreeMixinHost,
query: str,
database: str | None = None,
) -> tuple[list[str], list[tuple[Any, ...]]]:
"""Execute a metadata query and return cursor columns/rows.

Prefer the schema service so raw connection access stays serialized through
the session executor, matching other explorer metadata operations.
"""
schema_service = self._get_schema_service()
if schema_service is not None and hasattr(schema_service, "execute_cursor_query"):
return schema_service.execute_cursor_query(query, database)

if self.current_connection is None:
raise RuntimeError("No active database connection")
cursor = self.current_connection.cursor()
Expand Down Expand Up @@ -454,7 +466,7 @@ def _show_mysql_full_columns(self: TreeMixinHost, data: Any) -> bool:
table_ref = self._format_mysql_table_reference(data.database, data.name)
query = f"SHOW FULL COLUMNS FROM {table_ref}"
try:
result_columns, rows = self._fetch_cursor_result(query)
result_columns, rows = self._fetch_cursor_result(query, data.database)
except Exception as error:
self.notify(f"Error getting table columns: {error}", severity="error")
return True
Expand Down Expand Up @@ -527,7 +539,7 @@ def _show_mysql_indexes(self: TreeMixinHost, data: Any) -> bool:
table_ref = self._format_mysql_table_reference(data.database, data.name)
query = f"SHOW INDEX FROM {table_ref}"
try:
result_columns, rows = self._fetch_cursor_result(query)
result_columns, rows = self._fetch_cursor_result(query, data.database)
except Exception as error:
self.notify(f"Error getting table indexes: {error}", severity="error")
return True
Expand Down
72 changes: 53 additions & 19 deletions sqlit/domains/explorer/ui/mixins/tree_filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,16 +32,9 @@ class TreeFilterMixin:
_tree_original_labels: dict[int, str] = {}
_tree_filter_applied: bool = False
_tree_filter_scope_path: str | None = None
_tree_filter_scope_node: Any | None = None

_TREE_FILTER_LOADABLE_FOLDERS = {
"databases",
"tables",
"views",
"indexes",
"triggers",
"sequences",
"procedures",
}
_TREE_FILTER_LOADABLE_FOLDERS = {"tables"}

def action_tree_filter(self: TreeFilterMixinHost) -> None:
"""Open the tree filter."""
Expand Down Expand Up @@ -70,13 +63,15 @@ def action_table_filter(self: TreeFilterMixinHost) -> None:
pass

scope_path = expansion_state.get_node_path(cast(Any, self), tables_node) or None
self._begin_tree_filter_session(scope_path=scope_path)
self._begin_tree_filter_session(scope_path=scope_path, scope_node=tables_node)
self._remember_tree_filter_path(scope_path, include_self=True)
self._ensure_tree_filter_search_nodes_loaded()
self._update_tree_filter()
self._update_footer_bindings()

def _begin_tree_filter_session(self: TreeFilterMixinHost, *, scope_path: str | None) -> None:
def _begin_tree_filter_session(
self: TreeFilterMixinHost, *, scope_path: str | None, scope_node: Any | None = None
) -> None:
"""Reset transient filter state and show the filter input for a new session."""
self._tree_filter_visible = True
self._tree_filter_text = ""
Expand All @@ -91,6 +86,7 @@ def _begin_tree_filter_session(self: TreeFilterMixinHost, *, scope_path: str | N
self._tree_original_labels = {}
self._tree_filter_applied = False
self._tree_filter_scope_path = scope_path
self._tree_filter_scope_node = scope_node
self.tree_filter_input.show()

def action_tree_filter_close(self: TreeFilterMixinHost) -> None:
Expand All @@ -108,6 +104,7 @@ def _close_tree_filter_state(self: TreeFilterMixinHost, *, restore_tree: bool) -
self._tree_filter_regex_error = None
self._tree_filter_typing = False
self._tree_filter_scope_path = None
self._tree_filter_scope_node = None
self.tree_filter_input.hide()
self._restore_tree_labels()
if restore_tree:
Expand Down Expand Up @@ -295,7 +292,11 @@ def _update_tree_filter(self: TreeFilterMixinHost) -> None:
self._restore_tree_labels()
search_root = self._get_tree_filter_search_root()
is_scoped_filter = bool(getattr(self, "_tree_filter_scope_path", None))
total = self._count_all_nodes(search_root if is_scoped_filter else None)
total = (
self._count_all_nodes(search_root if is_scoped_filter else None)
if search_root is not None
else 0
)
raw_text = self._tree_filter_text
self._tree_filter_fuzzy = raw_text.startswith("~")
self._tree_filter_regex_mode = False
Expand Down Expand Up @@ -330,6 +331,12 @@ def _update_tree_filter(self: TreeFilterMixinHost) -> None:
if is_scoped_filter:
self._ensure_tree_filter_search_nodes_loaded()

if search_root is None:
self._tree_filter_matches = []
self._tree_filter_match_index = 0
self.tree_filter_input.set_filter(self._tree_filter_text, 0, 0)
return

# Find all matching nodes. The default Explorer filter keeps main's
# connection/database-only behavior; Table Filter searches inside its scoped subtree.
matches: list[Any] = []
Expand Down Expand Up @@ -387,14 +394,35 @@ def _remember_tree_filter_path(self: TreeFilterMixinHost, path: str | None, *, i
cast(Any, self)._pending_tree_cursor_path = path
cast(Any, self)._pending_tree_cursor_connection = ""

def _get_tree_filter_search_root(self: TreeFilterMixinHost) -> Any:
def _get_tree_filter_search_root(self: TreeFilterMixinHost) -> Any | None:
"""Return the subtree that should be searched by the active explorer filter."""
path = getattr(self, "_tree_filter_scope_path", None)
if path:
scoped_node = expansion_state.find_node_by_path(cast(Any, self), self.object_tree.root, path)
if scoped_node is not None:
return scoped_node
return self.object_tree.root
if not path:
return self.object_tree.root

scoped_node = getattr(self, "_tree_filter_scope_node", None)
if scoped_node is not None and self._node_is_attached_to_tree(scoped_node):
return scoped_node

scoped_node = expansion_state.find_node_by_path(cast(Any, self), self.object_tree.root, path)
if scoped_node is not None:
self._tree_filter_scope_node = scoped_node
return scoped_node

# A scoped Table Filter must never fall back to the whole tree. If the
# Tables folder is temporarily unavailable during async refresh/filter
# updates, searching from root can select similarly named system tables
# such as INFORMATION_SCHEMA.ROUTINES and make metadata shortcuts show
# columns like ROUTINE_NAME.
return None

def _node_is_attached_to_tree(self: TreeFilterMixinHost, node: Any) -> bool:
current = node
while current is not None:
if current is self.object_tree.root:
return True
current = getattr(current, "parent", None)
return False

def _extract_tree_filter_regex_query(self: TreeFilterMixinHost, raw_text: str) -> str | None:
"""Return regex pattern when the filter text uses a regex prefix."""
Expand Down Expand Up @@ -493,7 +521,11 @@ def _ensure_tree_filter_search_nodes_loaded(self: TreeFilterMixinHost) -> bool:
return False

started = False
stack = [self._get_tree_filter_search_root()]
search_root = self._get_tree_filter_search_root()
if search_root is None:
return False

stack = [search_root]
while stack:
node = stack.pop()
if self._tree_filter_should_load_node(node):
Expand Down Expand Up @@ -575,6 +607,8 @@ def _apply_filter_to_tree(self: TreeFilterMixinHost) -> None:
ancestor_ids = set()
pending_ids = set()
scope_node = self._get_tree_filter_search_root()
if scope_node is None:
return

def collect_pending(node: Any) -> None:
if self._tree_filter_node_has_pending_load(node):
Expand Down
96 changes: 56 additions & 40 deletions tests/ui/explorer/test_table_metadata_shortcuts.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,35 +96,25 @@ def test_show_table_indexes_displays_matching_indexes_in_results(self):
assert mixin._last_result_row_count == 1
assert mixin.query_input.text == "-- Indexes for users"


def test_show_table_columns_uses_mysql_full_columns(self):
def test_show_table_columns_uses_serialized_mysql_full_columns(self):
mixin = self._create_mixin()
mixin.current_provider.metadata.db_type = "mysql"
mixin.current_provider.dialect.quote_identifier.side_effect = lambda name: f"`{name}`"
cursor = MagicMock()
cursor.description = [
("Field",),
("Type",),
("Collation",),
("Null",),
("Key",),
("Default",),
("Extra",),
("Privileges",),
("Comment",),
]
cursor.fetchall.return_value = [
("id", "int", None, "NO", "PRI", None, "auto_increment", "select,insert", ""),
("email", "varchar(255)", "utf8mb4_0900_ai_ci", "YES", "UNI", None, "", "select,insert", "login email"),
]
mixin.current_connection = MagicMock()
mixin.current_connection.cursor.return_value = cursor
mixin._get_schema_service = MagicMock()
schema_service = MagicMock()
schema_service.execute_cursor_query.return_value = (
["Field", "Type", "Collation", "Null", "Key", "Default", "Extra", "Privileges", "Comment"],
[
("id", "int", None, "NO", "PRI", None, "auto_increment", "select,insert", ""),
("email", "varchar(255)", "utf8mb4_0900_ai_ci", "YES", "UNI", None, "", "select,insert", "login email"),
],
)
mixin._get_schema_service = MagicMock(return_value=schema_service)

mixin.action_show_table_columns()

cursor.execute.assert_called_once_with("SHOW FULL COLUMNS FROM `app_db`.`users`")
mixin._get_schema_service.assert_not_called()
schema_service.execute_cursor_query.assert_called_once_with("SHOW FULL COLUMNS FROM `app_db`.`users`", "app_db")
mixin.current_connection.cursor.assert_not_called()
mixin._replace_results_table.assert_called_once_with(
["Field", "Type", "Collation", "Null", "Key", "Default", "Extra", "Privileges", "Comment"],
[
Expand All @@ -134,31 +124,41 @@ def test_show_table_columns_uses_mysql_full_columns(self):
)
assert mixin.query_input.text == "SHOW FULL COLUMNS FROM `app_db`.`users`"

def test_show_table_indexes_uses_mysql_show_index(self):
def test_show_table_columns_reports_mysql_show_error_after_serialized_retry(self):
mixin = self._create_mixin()
mixin.current_provider.metadata.db_type = "mysql"
mixin.current_provider.dialect.quote_identifier.side_effect = lambda name: f"`{name}`"
cursor = MagicMock()
cursor.description = [
("Table",),
("Non_unique",),
("Key_name",),
("Seq_in_index",),
("Column_name",),
("Index_type",),
]
cursor.fetchall.return_value = [
("users", 0, "PRIMARY", 1, "id", "BTREE"),
("users", 1, "idx_users_email", 1, "email", "BTREE"),
]
mixin.current_connection = MagicMock()
mixin.current_connection.cursor.return_value = cursor
mixin._get_schema_service = MagicMock()
schema_service = MagicMock()
schema_service.execute_cursor_query.side_effect = RuntimeError((0, ""))
mixin._get_schema_service = MagicMock(return_value=schema_service)

mixin.action_show_table_columns()

schema_service.execute_cursor_query.assert_called_once_with("SHOW FULL COLUMNS FROM `app_db`.`users`", "app_db")
mixin.current_connection.cursor.assert_not_called()
mixin.notify.assert_called_once_with("Error getting table columns: (0, '')", severity="error")
mixin._replace_results_table.assert_not_called()

def test_show_table_indexes_uses_serialized_mysql_show_index(self):
mixin = self._create_mixin()
mixin.current_provider.metadata.db_type = "mysql"
mixin.current_provider.dialect.quote_identifier.side_effect = lambda name: f"`{name}`"
mixin.current_connection = MagicMock()
schema_service = MagicMock()
schema_service.execute_cursor_query.return_value = (
["Table", "Non_unique", "Key_name", "Seq_in_index", "Column_name", "Index_type"],
[
("users", 0, "PRIMARY", 1, "id", "BTREE"),
("users", 1, "idx_users_email", 1, "email", "BTREE"),
],
)
mixin._get_schema_service = MagicMock(return_value=schema_service)

mixin.action_show_table_indexes()

cursor.execute.assert_called_once_with("SHOW INDEX FROM `app_db`.`users`")
mixin._get_schema_service.assert_not_called()
schema_service.execute_cursor_query.assert_called_once_with("SHOW INDEX FROM `app_db`.`users`", "app_db")
mixin.current_connection.cursor.assert_not_called()
mixin._replace_results_table.assert_called_once_with(
["Table", "Non_unique", "Key_name", "Seq_in_index", "Column_name", "Index_type"],
[
Expand All @@ -168,6 +168,22 @@ def test_show_table_indexes_uses_mysql_show_index(self):
)
assert mixin.query_input.text == "SHOW INDEX FROM `app_db`.`users`"

def test_show_table_indexes_reports_mysql_show_error_after_serialized_retry(self):
mixin = self._create_mixin()
mixin.current_provider.metadata.db_type = "mysql"
mixin.current_provider.dialect.quote_identifier.side_effect = lambda name: f"`{name}`"
mixin.current_connection = MagicMock()
schema_service = MagicMock()
schema_service.execute_cursor_query.side_effect = RuntimeError((0, ""))
mixin._get_schema_service = MagicMock(return_value=schema_service)

mixin.action_show_table_indexes()

schema_service.execute_cursor_query.assert_called_once_with("SHOW INDEX FROM `app_db`.`users`", "app_db")
mixin.current_connection.cursor.assert_not_called()
mixin.notify.assert_called_once_with("Error getting table indexes: (0, '')", severity="error")
mixin._replace_results_table.assert_not_called()

def test_show_table_indexes_matches_table_case_insensitively(self):
mixin = self._create_mixin()
mixin.object_tree.cursor_node.data = TableNode(database="app_db", schema="public", name="User")
Expand Down
23 changes: 23 additions & 0 deletions tests/ui/test_tree_filter_tables_scope.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,6 +234,29 @@ def test_table_filter_from_table_uses_ancestor_tables_subtree() -> None:
assert views.parent is None


def test_table_filter_does_not_fall_back_to_root_when_scope_is_missing() -> None:
host, _database, tables, _users, _views = build_host()
info_schema = host.object_tree.root.add("information_schema")
info_schema.data = DatabaseNode(name="information_schema")
info_tables = info_schema.add("Tables")
info_tables.data = FolderNode(folder_type="tables", database="information_schema")
routines = info_tables.add("ROUTINES")
routines.data = TableNode(database="information_schema", schema="", name="ROUTINES")

host._tree_filter_visible = True
host._tree_filter_scope_path = "db:main/folder:tables"
host._tree_filter_scope_node = None
tables.remove()
host._tree_filter_text = "routine"

host._update_tree_filter()

assert host._tree_filter_matches == []
assert host.object_tree.cursor_node is None
assert routines.parent is info_tables
assert host.tree_filter_input.last_filter == ("routine", 0, 0)


def test_table_filter_accept_moves_cursor_to_matched_table() -> None:
host, database, _tables, users, _views = build_host()
host.object_tree.cursor_node = database
Expand Down
Loading