Enhance SQLAlchemy engine with safe_execute and inspector methods for Snowflake #8920
marimo/_sql/engines/sqlalchemy.py
Outdated
| from sqlalchemy import inspect, text | ||
|
|
||
| _use_database_dialect_command: dict[str, str] = { | ||
| "snowflake": f"""USE DATABASE {f'"{database}"' if re.search(r"[^A-Za-z0-9_]", database) else database}""", |
[^A-Za-z0-9_] looks like this is a common constant, let's extract it out and put at the top
Moved it, and improved the identifier quoting for Snowflake.
feat: implement identifier quoting for Snowflake in SQLAlchemyEngine
marimo/_sql/engines/sqlalchemy.py
Outdated
| def decorator(func: Callable[..., Any]) -> Callable[..., Any]: | ||
| @functools.wraps(func) | ||
| def wrapper(*args: Any, **kwargs: Any) -> Any: | ||
| try: | ||
| return func(*args, **kwargs) | ||
| except silent_exceptions: | ||
| return fallback | ||
| except Exception: | ||
| getattr(LOGGER, log_level)(message, exc_info=True) | ||
| return fallback | ||
|
|
||
| return wrapper |
Could we make this strongly typed, maybe with generics?
Otherwise the decorator overrides the wrapped function's return type. I get this type-check error:
Function decorator obscures type of function because its type is Any basedpyright[reportAny]
Pull request overview
Adds multi-database catalog discovery to the SQLAlchemy engine (starting with Snowflake) and introduces a safe_execute helper to standardize exception handling/logging during introspection.
Changes:
- Added safe_execute decorator and applied it across schema/table/detail discovery helpers.
- Implemented per-database inspection via _get_inspector, including Snowflake USE DATABASE switching and SHOW DATABASES discovery.
- Expanded unit tests to cover safe_execute, inspector helpers, and Snowflake database listing behavior.
Reviewed changes
Copilot reviewed 2 out of 2 changed files in this pull request and generated 5 comments.
| File | Description |
|---|---|
| marimo/_sql/engines/sqlalchemy.py | Adds safe_execute, Snowflake database discovery, and per-database inspector context manager used across introspection. |
| tests/_sql/test_sqlalchemy.py | Adds tests for safe_execute, _get_inspector, per-database schema/table/detail helpers, and Snowflake database-name discovery. |
|
|
||
|
|
||
| # ------------------------ | ||
| # Decoracotors |
Typo in section header: "Decoracotors" should be "Decorators".
| # Decoracotors | |
| # Decorators |
marimo/_sql/engines/sqlalchemy.py
Outdated
| database_names = [ | ||
| str(row[name_col_index]) | ||
| if re.search(r"[^A-Za-z0-9_]", str(row[name_col_index])) | ||
| else str(row[name_col_index]).lower() | ||
| for row in result.fetchall() | ||
| ] |
This normalization lowercases any database name that doesn’t contain “special” characters, but Snowflake quoted identifiers can be mixed-case (e.g. MyDb) without special characters. Lowercasing those will change the identifier and can make subsequent USE DATABASE / inspection fail. Consider preserving the exact name returned by SHOW DATABASES and only applying case-insensitive matching for display/default selection, or detect mixed-case and avoid lowercasing/ensure quoting.
Good point, but I think the current behavior is actually correct for Snowflake's semantics.
Snowflake stores unquoted identifiers internally as UPPERCASE. So when SHOW DATABASES returns DB_A, it means the database was created with an unquoted identifier (e.g., CREATE DATABASE db_a or CREATE DATABASE DB_A) — Snowflake normalized it to DB_A. Lowercasing it to db_a is safe because Snowflake resolves unquoted identifiers case-insensitively, so USE DATABASE db_a and USE DATABASE DB_A are equivalent.
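To make the two positions in this thread concrete, here is a small sketch that models Snowflake's rule of folding unquoted identifiers to uppercase. The function names and the regex constant are hypothetical; the sketch only covers names without special characters, which are the ones this normalization lowercases and then uses unquoted.

```python
import re

# Hypothetical constant: characters that would force quoting.
_NEEDS_QUOTING_RE = re.compile(r"[^A-Za-z0-9_]")


def resolve_unquoted(identifier: str) -> str:
    """Model how Snowflake resolves an unquoted identifier: fold to uppercase."""
    return identifier.upper()


def normalize(stored_name: str) -> str:
    """The normalization under review: lowercase names without special characters."""
    if _NEEDS_QUOTING_RE.search(stored_name):
        return stored_name
    return stored_name.lower()


def survives_round_trip(stored_name: str) -> bool:
    """True if the normalized name, used unquoted, still resolves to the stored name."""
    return resolve_unquoted(normalize(stored_name)) == stored_name
```

Under this model DB_A survives the round trip, which matches the reply above, while a database created as "MyDb" does not, which is the case the review comment is concerned with.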
marimo/_sql/engines/sqlalchemy.py
Outdated
| import re | ||
|
|
||
| from sqlalchemy import inspect, text | ||
|
|
||
| _use_database_dialect_command: dict[str, str] = { | ||
| "snowflake": f"""USE DATABASE {f'"{database}"' if re.search(r"[^A-Za-z0-9_]", database) else database}""", |
Snowflake identifier quoting here is not robust: it doesn’t escape embedded double-quotes and it also treats mixed-case identifiers (e.g. a quoted Snowflake database like "MyDb") as safe to use unquoted, which can resolve to a different database (Snowflake uppercases unquoted identifiers). Prefer using SQLAlchemy’s identifier preparer (or a dedicated quoting helper) to quote/escape identifiers correctly rather than interpolating into the SQL string.
| import re | |
| from sqlalchemy import inspect, text | |
| _use_database_dialect_command: dict[str, str] = { | |
| "snowflake": f"""USE DATABASE {f'"{database}"' if re.search(r"[^A-Za-z0-9_]", database) else database}""", | |
| from sqlalchemy import inspect, text | |
| from sqlalchemy.sql.compiler import IdentifierPreparer | |
| # Robustly quote the database identifier for Snowflake, preserving case | |
| # and correctly escaping embedded double-quotes. | |
| if self.dialect == "snowflake": | |
| # Prefer SQLAlchemy's identifier preparer when available. | |
| try: | |
| preparer: IdentifierPreparer = self._connection.dialect.identifier_preparer # type: ignore[attr-defined] | |
| quoted_database = preparer.quote(database) | |
| except Exception: | |
| # Fallback: minimal, but correct, SQL identifier quoting. | |
| escaped_database = database.replace('"', '""') | |
| quoted_database = f'"{escaped_database}"' | |
| else: | |
| quoted_database = database | |
| _use_database_dialect_command: dict[str, str] = { | |
| "snowflake": f"USE DATABASE {quoted_database}", |
Same as before.
marimo/_sql/engines/sqlalchemy.py
Outdated
| schema_names = self._get_schema_names( | ||
| database if database is not None else "" | ||
| ) |
get_schemas(database=None) (used by get_default_database() for connection validation) currently calls _get_schema_names(""), which for Snowflake will try to execute USE DATABASE with an empty identifier via _get_inspector and fail. This can cause Snowflake connections without a URL database to be treated as invalid. Consider letting _get_inspector accept Optional[str] and short-circuit to self.inspector (or skip USE DATABASE) when database is falsy, or adjust get_default_database()’s validity check for Snowflake (e.g., SHOW DATABASES).
| schema_names = self._get_schema_names( | |
| database if database is not None else "" | |
| ) | |
| if database is None: | |
| # When no database is specified, avoid issuing a "USE DATABASE" | |
| # statement with an empty identifier (e.g., on Snowflake). | |
| # In this context (e.g., connection validation), we treat the | |
| # absence of an explicit database as having no enumerated schemas. | |
| schema_names: list[str] = [] | |
| else: | |
| schema_names = self._get_schema_names(database) |
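The other option mentioned in the comment, letting _get_inspector accept an optional database and short-circuit, could look roughly like the sketch below. It uses stand-in classes instead of SQLAlchemy so that only the control flow is shown; FakeConnection, SketchEngine, and the string "inspectors" are all hypothetical, and the quoting assumes the name is passed exactly as SHOW DATABASES returned it.

```python
from contextlib import contextmanager
from typing import Iterator, Optional


class FakeConnection:
    """Stand-in for a database connection; records executed statements."""

    def __init__(self) -> None:
        self.executed: list[str] = []

    def execute(self, statement: str) -> None:
        self.executed.append(statement)


class SketchEngine:
    """Hypothetical stand-in for the engine; only the control flow matters."""

    def __init__(self, dialect: str) -> None:
        self.dialect = dialect
        self.connection = FakeConnection()
        self.inspector = "default-inspector"

    @contextmanager
    def get_inspector(self, database: Optional[str]) -> Iterator[str]:
        # Falsy database (None or ""): never build a USE DATABASE statement.
        if not database or self.dialect != "snowflake":
            yield self.inspector
            return
        # Assumption: `database` is the exact stored name, so it is always quoted.
        quoted = '"' + database.replace('"', '""') + '"'
        self.connection.execute(f"USE DATABASE {quoted}")
        yield f"inspector-for-{database}"
```

Compared with returning an empty schema list, this keeps connection validation working through the default inspector when no database is given.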
marimo/_sql/engines/sqlalchemy.py
Outdated
| default_lower = self.default_database.lower() | ||
| for db in database_names: | ||
| if db.lower() == default_lower: | ||
| return [str(db).lower()] |
When the default database is present, this returns [str(db).lower()] even if the matched database name was preserved as case-sensitive (e.g., quoted/mixed-case or containing special characters). That lowercasing can make subsequent USE DATABASE fail for case-sensitive identifiers and is inconsistent with the earlier normalization rules. Return the matched db value as-is (or apply the same normalization logic used when building database_names).
| return [str(db).lower()] | |
| return [db] |
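A minimal sketch of the suggested behavior: compare case-insensitively, but hand back the name with its original spelling. The helper name match_default_database is hypothetical and not part of the PR.

```python
from typing import Optional


def match_default_database(
    database_names: list[str], default_database: str
) -> Optional[str]:
    """Find the default database case-insensitively, preserving its stored spelling."""
    default_folded = default_database.casefold()
    for name in database_names:
        if name.casefold() == default_folded:
            # Return the name untouched so a later USE DATABASE sees the real identifier.
            return name
    return None
```

The caller can then wrap the result in a single-element list, or fall back to the full list when the result is None.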
…for better type safety
…in SQLAlchemyEngine
@Light2Dark I’ve refactored and improved the code in the latest commit, focusing on identifier quoting, normalization, and handling edge cases for database/schema retrieval in SQLAlchemyEngine. All recent changes are now ready for a new review.
| except silent_exceptions: | ||
| return fallback | ||
| except Exception: | ||
| getattr(LOGGER, log_level)(message, exc_info=True) |
safe_execute can raise unexpectedly if log_level is not a valid logger attribute: getattr(LOGGER, log_level)(...) will raise AttributeError, which bypasses the intended fallback behavior. Consider validating log_level (e.g., map allowed strings to bound methods, or default to LOGGER.warning) so the decorator remains safe even with unexpected input.
| getattr(LOGGER, log_level)(message, exc_info=True) | |
| log_method = getattr(LOGGER, log_level, LOGGER.warning) | |
| log_method(message, exc_info=True) |
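The first option in the comment, mapping allowed strings to bound methods, might be sketched as follows. Note that getattr with a default still returns any existing attribute, including non-callable ones such as LOGGER.name, so an explicit mapping is slightly stricter. The names _LOG_METHODS and resolve_log_method are hypothetical.

```python
import logging
from typing import Callable

LOGGER = logging.getLogger("sketch")

# Only these level names are accepted; anything else falls back to warning.
_LOG_METHODS: dict[str, Callable[..., None]] = {
    "debug": LOGGER.debug,
    "info": LOGGER.info,
    "warning": LOGGER.warning,
    "error": LOGGER.error,
}


def resolve_log_method(log_level: str) -> Callable[..., None]:
    """Return the logger method for `log_level`, defaulting to LOGGER.warning."""
    return _LOG_METHODS.get(log_level.lower(), LOGGER.warning)
```

Inside the decorator the call would then become resolve_log_method(log_level)(message, exc_info=True), which cannot raise AttributeError or call a non-callable attribute.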
| def _quote_identifier(self, identifier: str) -> str: | ||
| """Quote an identifier based on the SQL dialect's quoting rules.""" | ||
| dialect_quoting: dict[str, tuple[re.Pattern[str], str, str]] = { | ||
| "snowflake": (_SNOWFLAKE_NEEDS_QUOTING_RE, '"', '"'), | ||
| } | ||
|
|
||
| if self.dialect not in dialect_quoting: | ||
| return identifier | ||
|
|
||
| pattern, open_quote, close_quote = dialect_quoting[self.dialect] | ||
| if pattern.search(identifier) or identifier != identifier.lower(): | ||
| escaped = identifier.replace( | ||
| close_quote, close_quote + close_quote | ||
| ) | ||
| return f"{open_quote}{escaped}{close_quote}" | ||
| return identifier |
Snowflake quoting logic is inconsistent with _get_snowflake_database_names(): that method preserves case-sensitive identifiers as returned by SHOW DATABASES, including identifiers that may be all-lowercase (created via quoted identifiers). For such names, _quote_identifier currently returns the identifier unquoted (because it only quotes when identifier != identifier.lower()), which can make USE DATABASE <name> resolve to a different (uppercased) identifier or fail. Consider adjusting the normalization/quoting strategy so any case-sensitive Snowflake database name (including all-lowercase) is quoted when issuing USE DATABASE (e.g., keep raw SHOW names for execution, or track whether a name should be quoted).
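One way to read the "keep raw SHOW names for execution" suggestion is to always quote the stored name, as in this sketch. It assumes the name is passed exactly as SHOW DATABASES returned it, with no lowercasing applied first; both function names are hypothetical.

```python
def quote_snowflake_identifier(raw_name: str) -> str:
    """Quote a name exactly as stored, doubling any embedded double quotes."""
    return '"' + raw_name.replace('"', '""') + '"'


def use_database_statement(raw_name: str) -> str:
    """Build a USE DATABASE statement for a raw (non-normalized) database name."""
    return f"USE DATABASE {quote_snowflake_identifier(raw_name)}"
```

Because Snowflake stores unquoted identifiers in uppercase, quoting the stored DB_A resolves to the same database as the unquoted form, while an all-lowercase or mixed-case stored name keeps its case. The trade-off is that this only works if display-oriented lowercasing is kept separate from the name used for execution.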
| def _get_database_names(self) -> list[str]: | ||
| """Get database names using dialect-specific queries. | ||
|
|
||
| Returns a single-element list with the default database when | ||
| the dialect has no dedicated discovery mechanism. | ||
| """ | ||
| if self.dialect.lower() == "snowflake": | ||
| return self._get_snowflake_database_names() | ||
|
|
||
| return [self.default_database] if self.default_database else [] |
_get_database_names() treats self.default_database == "" as absent because it uses a truthiness check. get_default_database() explicitly uses empty-string to represent a detached-but-valid connection, so this change can cause get_databases() to return an empty list for valid connections (e.g., sqlite:///). Consider checking is not None instead of truthiness so the empty-string sentinel still yields a single database entry.
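The difference between the two checks is small but visible with the empty-string sentinel. A minimal sketch, with a hypothetical function name standing in for the non-Snowflake branch:

```python
from typing import Optional


def fallback_database_names(default_database: Optional[str]) -> list[str]:
    """Return the default database as a single-element list unless it is None."""
    # "" is a valid sentinel (detached-but-valid connection); only None means absent.
    return [default_database] if default_database is not None else []
```

With a truthiness check the first assertion in a test like `fallback_database_names("") == [""]` would fail, since "" is falsy.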
🚀 Development release published. You may be able to view the changes at https://marimo.app?v=0.21.2-dev100
📝 Summary
This pull request adds multi-database engine support, starting with Snowflake, and improves type safety and exception handling in the SQLAlchemy engine integration.
Continues the work from #8824.
Closes
🔍 Description of Changes
Multi-Database Support
- _get_inspector context manager that yields an appropriate SQLAlchemy Inspector for a given database. For dialects that require a USE DATABASE command (e.g., Snowflake), it opens a connection, executes the command, and yields an inspector bound to that connection. For all other dialects, it falls back to the existing self.inspector.
- _get_snowflake_database_names to fetch all available databases when the target database is not specified or not found.
- Helper methods (_get_schema_names, _get_table_names, _get_columns, _fetch_primary_keys, _fetch_indexes) now accept a database parameter and use _get_inspector, enabling per-database inspection across all operations.
Type Safety & Code Quality
- ReflectedColumn and ReflectedIndex types to resolve mypy errors.
- Exception chaining with raise ... from err in _get_snowflake_database_names (B904 compliance).
📋 Checklist