diff --git a/hotdata_marimo/databases.py b/hotdata_marimo/databases.py index 540b1f4..13e8c0a 100644 --- a/hotdata_marimo/databases.py +++ b/hotdata_marimo/databases.py @@ -47,7 +47,7 @@ def databases_panel(client: HotdataClient): gap=1, ) rows: list[dict[str, object]] = [ - {"name": db.name, "id": db.id, "sql_prefix": f"{db.name}.{{schema}}.{{table}}"} + {"description": db.description or db.id, "id": db.id, "sql_prefix": f"{db.id}.{{schema}}.{{table}}"} for db in dbs ] return mo.vstack( @@ -127,13 +127,16 @@ def _rebuild_database_pick(self) -> None: message="(create one first)", ) return - options = {db.name: db.name for db in dbs} - value = current if current in options else next(iter(options)) + options = {db.description or db.id: db.id for db in dbs} + # current holds the previously selected database ID (.value returns the dict value). + # mo.ui.dropdown validates value= against option keys (labels), not values. + default_key = next(iter(options)) + selected_key = next((k for k, v in options.items() if v == current), default_key) self.database = mo.ui.dropdown( options=options, label="Database", full_width=True, - value=value, + value=selected_key, ) def _maybe_create(self) -> None: @@ -153,7 +156,7 @@ def _maybe_create(self) -> None: tables = _parse_table_names(self.tables.value) try: self._create_result = self._client.create_managed_database( - db_name, + description=db_name, schema=schema, tables=tables or None, ) @@ -209,7 +212,7 @@ def result_panel(self): db = self._create_result return mo.callout( mo.md( - f"Created **{db.name}** (`{db.id}`). " + f"Created **{db.description or db.id}** (`{db.id}`). " "Load parquet into a declared table below." ), kind="success", diff --git a/hotdata_marimo/sql_editor.py b/hotdata_marimo/sql_editor.py index ead9d9a..e04f1b0 100644 --- a/hotdata_marimo/sql_editor.py +++ b/hotdata_marimo/sql_editor.py @@ -20,8 +20,10 @@ def __init__( default_sql: str = "", label: str = "SQL", run_label: str = "Run on Hotdata", + database: str | None = None, ) -> None: self._client = client + self._database = database self.sql = mo.ui.text_area(default_sql, label=label) self.run = mo.ui.button( value=0, @@ -103,7 +105,7 @@ def _execute_or_cached(self) -> QueryResult | None: title="Running on Hotdata", subtitle="Re-running last query and waiting for results…", ): - result = self._client.execute_sql(self._cached_sql or "") + result = self._client.execute_sql(self._cached_sql, database=self._database) self._result_cache = result self._last_rerun_n = rerun_n return result @@ -113,7 +115,7 @@ def _execute_or_cached(self) -> QueryResult | None: title="Running on Hotdata", subtitle="Executing query and waiting for results…", ): - result = self._client.execute_sql(sql_text) + result = self._client.execute_sql(sql_text, database=self._database) self._result_cache = result self._cached_sql = sql_text self._last_run_n = run_n @@ -195,7 +197,8 @@ def sql_editor( default_sql: str = "", label: str = "SQL", run_label: str = "Run on Hotdata", + database: str | None = None, ) -> SqlEditor: return SqlEditor( - client, default_sql=default_sql, label=label, run_label=run_label + client, default_sql=default_sql, label=label, run_label=run_label, database=database ) diff --git a/hotdata_marimo/sql_engine.py b/hotdata_marimo/sql_engine.py index 5b7a6e9..33b2ce9 100644 --- a/hotdata_marimo/sql_engine.py +++ b/hotdata_marimo/sql_engine.py @@ -37,9 +37,12 @@ def __init__( self, connection: HotdataClient, engine_name: VariableName | None = None, + *, + default_database: str | None = None, ) -> None: super().__init__(connection, engine_name) self._connections_cache: list[Any] | None = None + self._default_database = default_database @property def source(self) -> str: @@ -291,7 +294,7 @@ def get_table_details( ) def execute(self, query: str) -> Any: - qr = self._connection.execute_sql(query) + qr = self._connection.execute_sql(query, database=self._default_database) fmt = self.sql_output_format() def to_polars() -> Any: @@ -365,7 +368,11 @@ def register_hotdata_sql_engine() -> None: def unregister_hotdata_sql_engine() -> None: """Remove :class:`HotdataMarimoEngine` from Marimo's registry (mostly for tests).""" + global _ORIGINAL_ENGINE_TO_CONNECTION from marimo._sql.get_engines import SUPPORTED_ENGINES while HotdataMarimoEngine in SUPPORTED_ENGINES: SUPPORTED_ENGINES.remove(HotdataMarimoEngine) + if _ORIGINAL_ENGINE_TO_CONNECTION is not None: + _set_engine_to_data_source_connection(_ORIGINAL_ENGINE_TO_CONNECTION) + _ORIGINAL_ENGINE_TO_CONNECTION = None diff --git a/hotdata_marimo/table_browser.py b/hotdata_marimo/table_browser.py index e25d6f4..390b879 100644 --- a/hotdata_marimo/table_browser.py +++ b/hotdata_marimo/table_browser.py @@ -12,6 +12,8 @@ resolve_connection_picker, ) +__all__ = ["TableBrowser", "connection_picker", "table_browser"] + class TableBrowser: """Pick a fully qualified `connection.schema.table` and inspect columns. @@ -43,58 +45,26 @@ def __init__( ) self._table_pick_ctx: str | None = None - self._rebuilt_table_pick_this_run = False self._init_table_pick() - def _init_table_pick(self) -> None: - if self._conn_pick is not None: - self.table_pick = empty_dropdown( - label="Table", - message="(select connection above)", - ) - self._empty_catalog = True - self._all_names = [] - self._table_pick_ctx = "" - return - - names = self._names_for_active_connection() - self._all_names = names - if not names: - self.table_pick = empty_dropdown( - label="Table", - message="(no tables in catalog)", - ) - self._empty_catalog = True - else: - self._empty_catalog = False - self.table_pick = mo.ui.dropdown( - options={n: n for n in names}, - label="Table", - full_width=True, - searchable=True, - ) - self._table_pick_ctx = self._active_connection_id() - def _active_connection_id(self) -> str | None: if self._override_connection_id is not None: return self._override_connection_id or None if self._conn_pick is not None: - v = self._conn_pick.value # type: ignore[attr-defined] - return v if v else None - if self._implicit_connection_id is None: - return None + return self._conn_pick.value or None # type: ignore[attr-defined] return self._implicit_connection_id or None def _names_for_active_connection(self) -> list[str]: cid = self._active_connection_id() - if cid is None or cid == "": + if not cid: return [] return self._client.list_qualified_table_names( limit=self._table_limit, connection_id=cid, ) - def _rebuild_table_pick(self, names: list[str]) -> None: + def _set_table_pick(self, names: list[str]) -> None: + """Create or replace the table dropdown for the given names list.""" self._all_names = names if not names: self._empty_catalog = True @@ -111,7 +81,32 @@ def _rebuild_table_pick(self, names: list[str]) -> None: searchable=True, ) self._table_pick_ctx = self._active_connection_id() - self._rebuilt_table_pick_this_run = True + + def _init_table_pick(self) -> None: + if self._conn_pick is not None: + self._all_names = [] + self._empty_catalog = True + self.table_pick = empty_dropdown( + label="Table", + message="(select connection above)", + ) + self._table_pick_ctx = "" + return + self._set_table_pick(self._names_for_active_connection()) + + def _sync_table_catalog(self) -> bool: + """Refresh the table dropdown when the active connection changes. + + Returns True if the dropdown was rebuilt (so the caller knows not to + read ``.value`` on the new widget in the same Marimo run). + """ + if self._conn_pick is not None: + _ = self._conn_pick.value # type: ignore[attr-defined] + cid = self._active_connection_id() + if not cid or cid == self._table_pick_ctx: + return False + self._set_table_pick(self._names_for_active_connection()) + return True @property def selected_connection_id(self) -> str | None: @@ -122,30 +117,17 @@ def selected_table(self) -> str | None: v = self.table_pick.value return v if v else None - def _sync_table_catalog(self) -> None: - """Refresh the table dropdown when the active connection changes.""" - if self._conn_pick is not None: - _ = self._conn_pick.value # type: ignore[attr-defined] - cid = self._active_connection_id() - if not cid: - return - if cid == self._table_pick_ctx: - return - self._rebuild_table_pick(self._names_for_active_connection()) - @property def ui(self): - self._rebuilt_table_pick_this_run = False - self._sync_table_catalog() - - if not self._rebuilt_table_pick_this_run: + rebuilt = self._sync_table_catalog() + if not rebuilt: _ = self.table_pick.value - sel = None if self._rebuilt_table_pick_this_run else self.selected_table + sel = None if rebuilt else self.selected_table cid = self._active_connection_id() conn_header = ( - mo.md(f"**Connection** `{self._active_connection_id()}`") - if self._active_connection_id() + mo.md(f"**Connection** `{cid}`") + if cid else None ) if not sel: diff --git a/hotdata_marimo/workspace_selector.py b/hotdata_marimo/workspace_selector.py index 0b9425d..5393558 100644 --- a/hotdata_marimo/workspace_selector.py +++ b/hotdata_marimo/workspace_selector.py @@ -26,6 +26,8 @@ def __init__( self._api_key = api_key self._host = host or default_host() self._session_id = session_id + self._client_cache: HotdataClient | None = None + self._client_cache_wid: str | None = None selection = resolve_workspace_selection(api_key, self._host, session_id) self._explicit = selection.source == "explicit_env" if self._explicit: @@ -64,12 +66,16 @@ def workspace_id(self) -> str: @property def client(self) -> HotdataClient: - return HotdataClient( - self._api_key, - self.workspace_id, - host=self._host, - session_id=self._session_id, - ) + wid = self.workspace_id + if self._client_cache is None or self._client_cache_wid != wid: + self._client_cache = HotdataClient( + self._api_key, + wid, + host=self._host, + session_id=self._session_id, + ) + self._client_cache_wid = wid + return self._client_cache @property def ui(self): diff --git a/pyproject.toml b/pyproject.toml index d3b9482..42870ff 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -10,7 +10,7 @@ readme = "README.md" requires-python = ">=3.10" license = { text = "MIT" } dependencies = [ - "hotdata-runtime>=0.1.1", + "hotdata-runtime>=0.2.1", "hotdata>=0.2.0", "marimo>=0.10.0", ] diff --git a/tests/test_databases_marimo.py b/tests/test_databases_marimo.py index c4b7c0f..5a20d26 100644 --- a/tests/test_databases_marimo.py +++ b/tests/test_databases_marimo.py @@ -17,7 +17,7 @@ def test_databases_panel_empty_state(mock_client): def test_databases_panel_lists_managed_databases(mock_client): mock_client.list_managed_databases.return_value = [ - ManagedDatabase(id="c1", name="sales", source_type="managed"), + ManagedDatabase(id="c1", description="sales", default_connection_id="conn_c1"), ] with patch("hotdata_marimo.databases.mo.vstack", return_value="panel"), patch( "hotdata_marimo.databases.mo.md", side_effect=lambda x: x @@ -30,8 +30,8 @@ def test_managed_database_writer_creates_database(mock_client): mock_client.list_managed_databases.return_value = [] mock_client.create_managed_database.return_value = ManagedDatabase( id="conn_new", - name="sales", - source_type="managed", + description="sales", + default_connection_id="conn_c1", ) create = MagicMock() create.value = 1 @@ -71,7 +71,7 @@ def test_managed_database_writer_creates_database(mock_client): panel = writer.result_panel mock_client.create_managed_database.assert_called_once_with( - "sales", + description="sales", schema="public", tables=["orders", "customers"], ) @@ -80,7 +80,7 @@ def test_managed_database_writer_creates_database(mock_client): def test_managed_database_writer_loads_parquet(mock_client): mock_client.list_managed_databases.return_value = [ - ManagedDatabase(id="c1", name="sales", source_type="managed"), + ManagedDatabase(id="c1", description="sales", default_connection_id="conn_c1"), ] mock_client.upload_parquet.return_value = "upl_1" mock_client.load_managed_table.return_value = LoadManagedTableResult( diff --git a/uv.lock b/uv.lock index 932a1ed..201dca8 100644 --- a/uv.lock +++ b/uv.lock @@ -169,7 +169,7 @@ wheels = [ [[package]] name = "hotdata" -version = "0.2.0" +version = "0.2.3" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "pydantic" }, @@ -177,9 +177,9 @@ dependencies = [ { name = "typing-extensions" }, { name = "urllib3" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/ce/0f/1e9e024aa13f8d4bf8f9fb1bce777da6ca19da05b8435f2ba5cd5f87ec80/hotdata-0.2.0.tar.gz", hash = "sha256:e1131c05ed34d2f39ddee84930eb6694ed46971d7a442df5932689b28a6c9b4f", size = 108780, upload-time = "2026-05-19T04:01:38.345Z" } +sdist = { url = "https://files.pythonhosted.org/packages/4f/20/5d016d4aec39fe04eb77a6394651e3b18f6ecc701dc678563889debd79ed/hotdata-0.2.3.tar.gz", hash = "sha256:bc415af4ac475e5bd5fe3320d1c14aaac92942462a0ef9dac22b89bcc120ad55", size = 118187, upload-time = "2026-05-23T04:41:10.835Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/9a/e7/63b4820963ec475fe16403d363e5ddec237cfe01a39c2d7aff6a6d48d720/hotdata-0.2.0-py3-none-any.whl", hash = "sha256:d3d644a3b607f4891a784b8d5afa30a00bd9e437db013fd0581bf8bca501ac0d", size = 256603, upload-time = "2026-05-19T04:01:36.253Z" }, + { url = "https://files.pythonhosted.org/packages/09/87/d3cb845ba01e5b4e9bfb1e59d0032a246b94497e470d171f2ee2a56bd850/hotdata-0.2.3-py3-none-any.whl", hash = "sha256:aed2ae884d184cf143572c84d068a9ceedbe021a6d14005332647a46aa7be11c", size = 275718, upload-time = "2026-05-23T04:41:09.355Z" }, ] [[package]] @@ -200,7 +200,7 @@ dev = [ [package.metadata] requires-dist = [ { name = "hotdata", specifier = ">=0.2.0" }, - { name = "hotdata-runtime", specifier = ">=0.1.1" }, + { name = "hotdata-runtime", specifier = ">=0.2.1" }, { name = "marimo", specifier = ">=0.10.0" }, ] @@ -209,16 +209,16 @@ dev = [{ name = "pytest", specifier = ">=8.0" }] [[package]] name = "hotdata-runtime" -version = "0.1.1" +version = "0.2.1" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "hotdata" }, { name = "pandas", version = "2.3.3", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.11'" }, { name = "pandas", version = "3.0.3", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.11'" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/86/0b/b2889774abaa555be7625999c8730361d86f588aec7219918c616817cdb1/hotdata_runtime-0.1.1.tar.gz", hash = "sha256:3ed64b430f258b3505cf2d1f6635069fc1afef6df6fc3fca5e52ac578e69ead7", size = 57795, upload-time = "2026-05-19T05:13:15.451Z" } +sdist = { url = "https://files.pythonhosted.org/packages/4f/ab/6629c4555d02ef68712391acf1e5b11ce90dcb278462ada503b05d76f74a/hotdata_runtime-0.2.1.tar.gz", hash = "sha256:ef4901e9f3ea01fee087e6b452e1b9b5c14f3b1192ca8ef08ce34fe9b003e5e2", size = 65672, upload-time = "2026-05-25T00:38:40.168Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/fa/7b/98cf841d7900e4eb198d1a393828a0999d9b4d54ef792cec5fa3eb4c5a01/hotdata_runtime-0.1.1-py3-none-any.whl", hash = "sha256:51da53100329fbf634abbe95073b2edbbdad174886263b40652091a88f41f0ad", size = 10210, upload-time = "2026-05-19T05:13:14.28Z" }, + { url = "https://files.pythonhosted.org/packages/8f/3b/93d09fda8dd064150aeed7f0581a28c8d24d3410d0007986fafffb56737b/hotdata_runtime-0.2.1-py3-none-any.whl", hash = "sha256:bb9765059fe26ca407ce6f08ae471473a463cae5bcb1ab46d07353863bca1e1e", size = 10411, upload-time = "2026-05-25T00:38:38.731Z" }, ] [[package]]