From c0138ccaae013ca214d72d25fc847c50ff12c8f5 Mon Sep 17 00:00:00 2001 From: Cody Fincher Date: Mon, 3 Nov 2025 17:36:17 +0000 Subject: [PATCH 1/7] feat: refactor SQL file loading and storage utilities to improve PyArrow integration --- sqlspec/loader.py | 13 ++++++--- sqlspec/storage/_utils.py | 42 +++++++++++++++++++++++++---- sqlspec/storage/backends/fsspec.py | 17 +++++------- sqlspec/storage/backends/local.py | 20 +++++--------- sqlspec/storage/backends/obstore.py | 31 +++++++-------------- 5 files changed, 69 insertions(+), 54 deletions(-) diff --git a/sqlspec/loader.py b/sqlspec/loader.py index 8a1fee96..2d679f2a 100644 --- a/sqlspec/loader.py +++ b/sqlspec/loader.py @@ -13,7 +13,12 @@ from urllib.parse import unquote, urlparse from sqlspec.core import SQL, get_cache, get_cache_config -from sqlspec.exceptions import SQLFileNotFoundError, SQLFileParseError, StorageOperationFailedError +from sqlspec.exceptions import ( + FileNotFoundInStorageError, + SQLFileNotFoundError, + SQLFileParseError, + StorageOperationFailedError, +) from sqlspec.storage.registry import storage_registry as default_storage_registry from sqlspec.utils.correlation import CorrelationContext from sqlspec.utils.logging import get_logger @@ -259,9 +264,11 @@ def _read_file_content(self, path: str | Path) -> str: return backend.read_text(path_str, encoding=self.encoding) except KeyError as e: raise SQLFileNotFoundError(path_str) from e + except FileNotFoundInStorageError as e: + raise SQLFileNotFoundError(path_str) from e + except FileNotFoundError as e: + raise SQLFileNotFoundError(path_str) from e except StorageOperationFailedError as e: - if "not found" in str(e).lower() or "no such file" in str(e).lower(): - raise SQLFileNotFoundError(path_str) from e raise SQLFileParseError(path_str, path_str, e) from e except Exception as e: raise SQLFileParseError(path_str, path_str, e) from e diff --git a/sqlspec/storage/_utils.py b/sqlspec/storage/_utils.py index 02359fea..a98905a5 100644 --- a/sqlspec/storage/_utils.py +++ b/sqlspec/storage/_utils.py @@ -1,12 +1,44 @@ """Shared utilities for storage backends.""" from pathlib import Path +from typing import Any, Final -__all__ = ("resolve_storage_path",) +from sqlspec.utils.module_loader import ensure_pyarrow + +FILE_PROTOCOL: Final[str] = "file" +FILE_SCHEME_PREFIX: Final[str] = "file://" + +__all__ = ("FILE_PROTOCOL", "FILE_SCHEME_PREFIX", "import_pyarrow", "import_pyarrow_parquet", "resolve_storage_path") + + +def import_pyarrow() -> "Any": + """Import PyArrow with optional dependency guard. + + Returns: + PyArrow module. + """ + + ensure_pyarrow() + import pyarrow as pa + + return pa + + +def import_pyarrow_parquet() -> "Any": + """Import PyArrow parquet module with optional dependency guard. + + Returns: + PyArrow parquet module. + """ + + ensure_pyarrow() + import pyarrow.parquet as pq + + return pq def resolve_storage_path( - path: "str | Path", base_path: str = "", protocol: str = "file", strip_file_scheme: bool = True + path: "str | Path", base_path: str = "", protocol: str = FILE_PROTOCOL, strip_file_scheme: bool = True ) -> str: """Resolve path relative to base_path with protocol-specific handling. 
@@ -43,10 +75,10 @@ def resolve_storage_path( path_str = str(path) - if strip_file_scheme and path_str.startswith("file://"): - path_str = path_str.removeprefix("file://") + if strip_file_scheme and path_str.startswith(FILE_SCHEME_PREFIX): + path_str = path_str.removeprefix(FILE_SCHEME_PREFIX) - if protocol == "file": + if protocol == FILE_PROTOCOL: path_obj = Path(path_str) if path_obj.is_absolute(): diff --git a/sqlspec/storage/backends/fsspec.py b/sqlspec/storage/backends/fsspec.py index 21df5cee..b96280f2 100644 --- a/sqlspec/storage/backends/fsspec.py +++ b/sqlspec/storage/backends/fsspec.py @@ -5,8 +5,8 @@ from mypy_extensions import mypyc_attr -from sqlspec.storage._utils import resolve_storage_path -from sqlspec.utils.module_loader import ensure_fsspec, ensure_pyarrow +from sqlspec.storage._utils import import_pyarrow_parquet, resolve_storage_path +from sqlspec.utils.module_loader import ensure_fsspec from sqlspec.utils.sync_tools import async_ if TYPE_CHECKING: @@ -205,8 +205,7 @@ def move(self, source: str | Path, destination: str | Path, **kwargs: Any) -> No def read_arrow(self, path: str | Path, **kwargs: Any) -> "ArrowTable": """Read an Arrow table from storage.""" - ensure_pyarrow() - import pyarrow.parquet as pq + pq = import_pyarrow_parquet() resolved_path = resolve_storage_path(path, self.base_path, self.protocol, strip_file_scheme=False) with self.fs.open(resolved_path, mode="rb", **kwargs) as f: @@ -214,8 +213,7 @@ def read_arrow(self, path: str | Path, **kwargs: Any) -> "ArrowTable": def write_arrow(self, path: str | Path, table: "ArrowTable", **kwargs: Any) -> None: """Write an Arrow table to storage.""" - ensure_pyarrow() - import pyarrow.parquet as pq + pq = import_pyarrow_parquet() resolved_path = resolve_storage_path(path, self.base_path, self.protocol, strip_file_scheme=False) with self.fs.open(resolved_path, mode="wb") as f: @@ -273,15 +271,14 @@ def sign(self, path: str, expires_in: int = 3600, for_upload: bool = False) -> s return f"{self._fs_uri}{resolved_path}" def _stream_file_batches(self, obj_path: str | Path) -> "Iterator[ArrowRecordBatch]": - import pyarrow.parquet as pq + pq = import_pyarrow_parquet() with self.fs.open(obj_path, mode="rb") as f: parquet_file = pq.ParquetFile(f) # pyright: ignore[reportArgumentType] yield from parquet_file.iter_batches() def stream_arrow(self, pattern: str, **kwargs: Any) -> "Iterator[ArrowRecordBatch]": - ensure_pyarrow() - + import_pyarrow_parquet() for obj_path in self.glob(pattern, **kwargs): yield from self._stream_file_batches(obj_path) @@ -303,8 +300,6 @@ def stream_arrow_async(self, pattern: str, **kwargs: Any) -> "AsyncIterator[Arro Returns: AsyncIterator of Arrow record batches """ - ensure_pyarrow() - return _ArrowStreamer(self, pattern, **kwargs) async def read_text_async(self, path: str | Path, encoding: str = "utf-8", **kwargs: Any) -> str: diff --git a/sqlspec/storage/backends/local.py b/sqlspec/storage/backends/local.py index 6ffc0411..78a2c88c 100644 --- a/sqlspec/storage/backends/local.py +++ b/sqlspec/storage/backends/local.py @@ -12,7 +12,7 @@ from mypy_extensions import mypyc_attr -from sqlspec.utils.module_loader import ensure_pyarrow +from sqlspec.storage._utils import import_pyarrow_parquet from sqlspec.utils.sync_tools import async_ if TYPE_CHECKING: @@ -233,19 +233,15 @@ def is_path(self, path: "str | Path") -> bool: def read_arrow(self, path: "str | Path", **kwargs: Any) -> "ArrowTable": """Read Arrow table from file.""" - ensure_pyarrow() - import pyarrow.parquet as pq - - return 
pq.read_table(str(self._resolve_path(path))) # pyright: ignore + pq = import_pyarrow_parquet() + return pq.read_table(str(self._resolve_path(path)), **kwargs) # pyright: ignore def write_arrow(self, path: "str | Path", table: "ArrowTable", **kwargs: Any) -> None: """Write Arrow table to file.""" - ensure_pyarrow() - import pyarrow.parquet as pq - + pq = import_pyarrow_parquet() resolved = self._resolve_path(path) resolved.parent.mkdir(parents=True, exist_ok=True) - pq.write_table(table, str(resolved)) # pyright: ignore + pq.write_table(table, str(resolved), **kwargs) # pyright: ignore def stream_arrow(self, pattern: str, **kwargs: Any) -> Iterator["ArrowRecordBatch"]: """Stream Arrow record batches from files matching pattern. @@ -253,9 +249,7 @@ def stream_arrow(self, pattern: str, **kwargs: Any) -> Iterator["ArrowRecordBatc Yields: Arrow record batches from matching files. """ - ensure_pyarrow() - import pyarrow.parquet as pq - + pq = import_pyarrow_parquet() files = self.glob(pattern) for file_path in files: resolved = self._resolve_path(file_path) @@ -264,8 +258,6 @@ def stream_arrow(self, pattern: str, **kwargs: Any) -> Iterator["ArrowRecordBatc def sign(self, path: "str | Path", expires_in: int = 3600, for_upload: bool = False) -> str: """Generate a signed URL (returns file:// URI for local files).""" - # For local files, just return a file:// URI - # No actual signing needed for local files return self._resolve_path(path).as_uri() # Async methods using sync_tools.async_ diff --git a/sqlspec/storage/backends/obstore.py b/sqlspec/storage/backends/obstore.py index 375010e2..979f7663 100644 --- a/sqlspec/storage/backends/obstore.py +++ b/sqlspec/storage/backends/obstore.py @@ -7,6 +7,7 @@ import fnmatch import io import logging +import re from collections.abc import AsyncIterator, Iterator from pathlib import Path, PurePosixPath from typing import Any, Final, cast @@ -15,9 +16,9 @@ from mypy_extensions import mypyc_attr from sqlspec.exceptions import StorageOperationFailedError -from sqlspec.storage._utils import resolve_storage_path +from sqlspec.storage._utils import import_pyarrow, import_pyarrow_parquet, resolve_storage_path from sqlspec.typing import ArrowRecordBatch, ArrowTable -from sqlspec.utils.module_loader import ensure_obstore, ensure_pyarrow +from sqlspec.utils.module_loader import ensure_obstore from sqlspec.utils.sync_tools import async_ __all__ = ("ObStoreBackend",) @@ -46,7 +47,7 @@ def __aiter__(self) -> "_AsyncArrowIterator": return self async def __anext__(self) -> ArrowRecordBatch: - import pyarrow.parquet as pq + pq = import_pyarrow_parquet() if self._files_iterator is None: files = self.backend.glob(self.pattern, **self.kwargs) @@ -338,19 +339,15 @@ def is_path(self, path: "str | Path") -> bool: def read_arrow(self, path: "str | Path", **kwargs: Any) -> ArrowTable: """Read Arrow table using obstore.""" - ensure_pyarrow() - import pyarrow.parquet as pq - + pq = import_pyarrow_parquet() resolved_path = resolve_storage_path(path, self.base_path, self.protocol, strip_file_scheme=True) data = self.read_bytes(resolved_path) return pq.read_table(io.BytesIO(data), **kwargs) def write_arrow(self, path: "str | Path", table: ArrowTable, **kwargs: Any) -> None: """Write Arrow table using obstore.""" - ensure_pyarrow() - import pyarrow as pa - import pyarrow.parquet as pq - + pa = import_pyarrow() + pq = import_pyarrow_parquet() resolved_path = resolve_storage_path(path, self.base_path, self.protocol, strip_file_scheme=True) schema = table.schema @@ -358,8 +355,6 @@ def 
write_arrow(self, path: "str | Path", table: ArrowTable, **kwargs: Any) -> N new_fields = [] for field in schema: if str(field.type).startswith("decimal64"): - import re - match = re.match(r"decimal64\((\d+),\s*(\d+)\)", str(field.type)) if match: precision, scale = int(match.group(1)), int(match.group(2)) @@ -381,9 +376,7 @@ def stream_arrow(self, pattern: str, **kwargs: Any) -> Iterator[ArrowRecordBatch Yields: Iterator of Arrow record batches from matching objects. """ - ensure_pyarrow() - import pyarrow.parquet as pq - + pq = import_pyarrow_parquet() for obj_path in self.glob(pattern, **kwargs): resolved_path = resolve_storage_path(obj_path, self.base_path, self.protocol, strip_file_scheme=True) result = self.store.get(resolved_path) @@ -518,18 +511,14 @@ async def get_metadata_async(self, path: "str | Path", **kwargs: Any) -> dict[st async def read_arrow_async(self, path: "str | Path", **kwargs: Any) -> ArrowTable: """Read Arrow table from storage asynchronously.""" - ensure_pyarrow() - import pyarrow.parquet as pq - + pq = import_pyarrow_parquet() resolved_path = resolve_storage_path(path, self.base_path, self.protocol, strip_file_scheme=True) data = await self.read_bytes_async(resolved_path) return pq.read_table(io.BytesIO(data), **kwargs) async def write_arrow_async(self, path: "str | Path", table: ArrowTable, **kwargs: Any) -> None: """Write Arrow table to storage asynchronously.""" - ensure_pyarrow() - import pyarrow.parquet as pq - + pq = import_pyarrow_parquet() resolved_path = resolve_storage_path(path, self.base_path, self.protocol, strip_file_scheme=True) buffer = io.BytesIO() pq.write_table(table, buffer, **kwargs) From d88f770a41b34bd4c13ba06076a9bdb8fca9cb60 Mon Sep 17 00:00:00 2001 From: Cody Fincher Date: Mon, 3 Nov 2025 17:43:03 +0000 Subject: [PATCH 2/7] feat: add type casting for ArrowTable in read_arrow methods across storage backends --- sqlspec/storage/backends/fsspec.py | 4 ++-- sqlspec/storage/backends/local.py | 4 ++-- sqlspec/storage/backends/obstore.py | 4 ++-- 3 files changed, 6 insertions(+), 6 deletions(-) diff --git a/sqlspec/storage/backends/fsspec.py b/sqlspec/storage/backends/fsspec.py index b96280f2..58dda5d6 100644 --- a/sqlspec/storage/backends/fsspec.py +++ b/sqlspec/storage/backends/fsspec.py @@ -1,7 +1,7 @@ # pyright: reportPrivateUsage=false import logging from pathlib import Path -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, cast from mypy_extensions import mypyc_attr @@ -209,7 +209,7 @@ def read_arrow(self, path: str | Path, **kwargs: Any) -> "ArrowTable": resolved_path = resolve_storage_path(path, self.base_path, self.protocol, strip_file_scheme=False) with self.fs.open(resolved_path, mode="rb", **kwargs) as f: - return pq.read_table(f) + return cast("ArrowTable", pq.read_table(f)) def write_arrow(self, path: str | Path, table: "ArrowTable", **kwargs: Any) -> None: """Write an Arrow table to storage.""" diff --git a/sqlspec/storage/backends/local.py b/sqlspec/storage/backends/local.py index 78a2c88c..8bd0344d 100644 --- a/sqlspec/storage/backends/local.py +++ b/sqlspec/storage/backends/local.py @@ -7,7 +7,7 @@ import shutil from collections.abc import AsyncIterator, Iterator from pathlib import Path -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, cast from urllib.parse import unquote, urlparse from mypy_extensions import mypyc_attr @@ -234,7 +234,7 @@ def is_path(self, path: "str | Path") -> bool: def read_arrow(self, path: "str | Path", **kwargs: Any) -> "ArrowTable": """Read 
Arrow table from file.""" pq = import_pyarrow_parquet() - return pq.read_table(str(self._resolve_path(path)), **kwargs) # pyright: ignore + return cast("ArrowTable", pq.read_table(str(self._resolve_path(path)), **kwargs)) def write_arrow(self, path: "str | Path", table: "ArrowTable", **kwargs: Any) -> None: """Write Arrow table to file.""" diff --git a/sqlspec/storage/backends/obstore.py b/sqlspec/storage/backends/obstore.py index 979f7663..887f7f60 100644 --- a/sqlspec/storage/backends/obstore.py +++ b/sqlspec/storage/backends/obstore.py @@ -342,7 +342,7 @@ def read_arrow(self, path: "str | Path", **kwargs: Any) -> ArrowTable: pq = import_pyarrow_parquet() resolved_path = resolve_storage_path(path, self.base_path, self.protocol, strip_file_scheme=True) data = self.read_bytes(resolved_path) - return pq.read_table(io.BytesIO(data), **kwargs) + return cast("ArrowTable", pq.read_table(io.BytesIO(data), **kwargs)) def write_arrow(self, path: "str | Path", table: ArrowTable, **kwargs: Any) -> None: """Write Arrow table using obstore.""" @@ -514,7 +514,7 @@ async def read_arrow_async(self, path: "str | Path", **kwargs: Any) -> ArrowTabl pq = import_pyarrow_parquet() resolved_path = resolve_storage_path(path, self.base_path, self.protocol, strip_file_scheme=True) data = await self.read_bytes_async(resolved_path) - return pq.read_table(io.BytesIO(data), **kwargs) + return cast("ArrowTable", pq.read_table(io.BytesIO(data), **kwargs)) async def write_arrow_async(self, path: "str | Path", table: ArrowTable, **kwargs: Any) -> None: """Write Arrow table to storage asynchronously.""" From d8e5e782d36f9d7c8e52a475c0249cfa4715d4cf Mon Sep 17 00:00:00 2001 From: Cody Fincher Date: Mon, 3 Nov 2025 21:08:54 +0000 Subject: [PATCH 3/7] feat: Add COPY statement builders and error handling for storage operations - Implemented `build_copy_from_statement`, `build_copy_to_statement`, and `build_copy_statement` functions for generating SQL COPY statements. - Enhanced `SQLFactory` with methods for `copy_from`, `copy_to`, and `copy` to facilitate easier usage of COPY operations. - Introduced `execute_sync_storage_operation` and `raise_storage_error` for consistent error handling across storage backends. - Updated storage backends (FSSpec, Local, ObStore) to utilize the new error handling mechanism. - Added tests for COPY statement builders and storage error handling to ensure correctness and reliability. - Refactored serializer utilities to improve schema dumping and caching mechanisms. 
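Note: a minimal sketch of how a backend is expected to route an I/O call through the error-handling helper introduced in this patch (the helper signature and exception mapping come from sqlspec/storage/errors.py below; the standalone function and the file path are illustrative, not part of the diff):

```python
from pathlib import Path

from sqlspec.exceptions import FileNotFoundInStorageError
from sqlspec.storage.errors import execute_sync_storage_operation


def read_local_bytes(resolved: Path) -> bytes:
    # A FileNotFoundError raised inside the callable surfaces as
    # FileNotFoundInStorageError; any other failure is re-raised as
    # StorageOperationFailedError with a normalized message.
    return execute_sync_storage_operation(
        lambda: resolved.read_bytes(),
        backend="local",
        operation="read_bytes",
        path=str(resolved),
    )


try:
    read_local_bytes(Path("/tmp/missing.parquet"))  # hypothetical path
except FileNotFoundInStorageError as error:
    print(error)  # -> "local read_bytes failed for /tmp/missing.parquet"
```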
--- sqlspec/builder/__init__.py | 11 +- sqlspec/builder/_factory.py | 148 ++++++++++++- sqlspec/storage/backends/fsspec.py | 75 +++++-- sqlspec/storage/backends/local.py | 88 ++++++-- sqlspec/storage/backends/obstore.py | 65 +++++- sqlspec/storage/errors.py | 104 +++++++++ sqlspec/utils/fixtures.py | 5 +- sqlspec/utils/serializers.py | 209 ++++++++++++++++++- sqlspec/utils/type_guards.py | 32 --- tests/unit/test_builder/test_copy_helpers.py | 43 ++++ tests/unit/test_storage/test_errors.py | 32 +++ tests/unit/test_utils/test_type_guards.py | 44 +++- 12 files changed, 768 insertions(+), 88 deletions(-) create mode 100644 sqlspec/storage/errors.py create mode 100644 tests/unit/test_builder/test_copy_helpers.py create mode 100644 tests/unit/test_storage/test_errors.py diff --git a/sqlspec/builder/__init__.py b/sqlspec/builder/__init__.py index 866158a2..8293087b 100644 --- a/sqlspec/builder/__init__.py +++ b/sqlspec/builder/__init__.py @@ -40,7 +40,13 @@ MathExpression, StringExpression, ) -from sqlspec.builder._factory import SQLFactory, sql +from sqlspec.builder._factory import ( + SQLFactory, + build_copy_from_statement, + build_copy_statement, + build_copy_to_statement, + sql, +) from sqlspec.builder._insert import Insert from sqlspec.builder._join import JoinBuilder from sqlspec.builder._merge import Merge @@ -127,6 +133,9 @@ "UpdateTableClauseMixin", "WhereClauseMixin", "WindowFunctionBuilder", + "build_copy_from_statement", + "build_copy_statement", + "build_copy_to_statement", "extract_expression", "parse_column_expression", "parse_condition_expression", diff --git a/sqlspec/builder/_factory.py b/sqlspec/builder/_factory.py index d33b4748..50cebedf 100644 --- a/sqlspec/builder/_factory.py +++ b/sqlspec/builder/_factory.py @@ -4,7 +4,8 @@ """ import logging -from typing import TYPE_CHECKING, Any, Union +from collections.abc import Mapping, Sequence +from typing import TYPE_CHECKING, Any, Union, cast import sqlglot from sqlglot import exp @@ -46,6 +47,8 @@ from sqlspec.exceptions import SQLBuilderError if TYPE_CHECKING: + from collections.abc import Mapping, Sequence + from sqlspec.builder._expression_wrappers import ExpressionWrapper @@ -73,6 +76,9 @@ "Truncate", "Update", "WindowFunctionBuilder", + "build_copy_from_statement", + "build_copy_statement", + "build_copy_to_statement", "sql", ) @@ -108,6 +114,96 @@ } +def _normalize_copy_dialect(dialect: DialectType | None) -> str: + if dialect is None: + return "postgres" + if isinstance(dialect, str): + return dialect + return str(dialect) + + +def _to_copy_schema(table: str, columns: "Sequence[str] | None") -> exp.Expression: + base = exp.table_(table) + if not columns: + return base + column_nodes = [exp.column(column_name) for column_name in columns] + return exp.Schema(this=base, expressions=column_nodes) + + +def _build_copy_expression( + *, direction: str, table: str, location: str, columns: "Sequence[str] | None", options: "Mapping[str, Any] | None" +) -> exp.Copy: + copy_args: dict[str, Any] = {"this": _to_copy_schema(table, columns), "files": [exp.Literal.string(location)]} + + if direction == "from": + copy_args["kind"] = True + elif direction == "to": + copy_args["kind"] = False + + if options: + params: list[exp.CopyParameter] = [] + for key, value in options.items(): + identifier = exp.Var(this=str(key).upper()) + value_expression: exp.Expression + if isinstance(value, bool): + value_expression = exp.Boolean(this=value) + elif value is None: + value_expression = exp.null() + elif isinstance(value, (int, float)): + 
value_expression = exp.Literal.number(value) + elif isinstance(value, (list, tuple)): + elements = [exp.Literal.string(str(item)) for item in value] + value_expression = exp.Array(expressions=elements) + else: + value_expression = exp.Literal.string(str(value)) + params.append(exp.CopyParameter(this=identifier, expression=value_expression)) + copy_args["params"] = params + + return exp.Copy(**copy_args) + + +def build_copy_statement( + *, + direction: str, + table: str, + location: str, + columns: "Sequence[str] | None" = None, + options: "Mapping[str, Any] | None" = None, + dialect: DialectType | None = None, +) -> SQL: + expression = _build_copy_expression( + direction=direction, table=table, location=location, columns=columns, options=options + ) + rendered = expression.sql(dialect=_normalize_copy_dialect(dialect)) + return SQL(rendered) + + +def build_copy_from_statement( + table: str, + source: str, + *, + columns: "Sequence[str] | None" = None, + options: "Mapping[str, Any] | None" = None, + dialect: DialectType | None = None, +) -> SQL: + return build_copy_statement( + direction="from", table=table, location=source, columns=columns, options=options, dialect=dialect + ) + + +def build_copy_to_statement( + table: str, + target: str, + *, + columns: "Sequence[str] | None" = None, + options: "Mapping[str, Any] | None" = None, + dialect: DialectType | None = None, +) -> SQL: + return build_copy_statement( + direction="to", table=table, location=target, columns=columns, options=options, dialect=dialect + ) + + class SQLFactory: """Factory for creating SQL builders and column expressions.""" @@ -479,6 +575,56 @@ def comment_on(self, dialect: DialectType = None) -> "CommentOn": """ return CommentOn(dialect=dialect or self.dialect) + def copy_from( + self, + table: str, + source: str, + *, + columns: "Sequence[str] | None" = None, + options: "Mapping[str, Any] | None" = None, + dialect: DialectType | None = None, + ) -> SQL: + """Build a COPY ... FROM statement.""" + + effective_dialect = dialect or self.dialect + return build_copy_from_statement(table, source, columns=columns, options=options, dialect=effective_dialect) + + def copy_to( + self, + table: str, + target: str, + *, + columns: "Sequence[str] | None" = None, + options: "Mapping[str, Any] | None" = None, + dialect: DialectType | None = None, + ) -> SQL: + """Build a COPY ... TO statement.""" + + effective_dialect = dialect or self.dialect + return build_copy_to_statement(table, target, columns=columns, options=options, dialect=effective_dialect) + + def copy( + self, + table: str, + *, + source: str | None = None, + target: str | None = None, + columns: "Sequence[str] | None" = None, + options: "Mapping[str, Any] | None" = None, + dialect: DialectType | None = None, + ) -> SQL: + """Build a COPY statement, inferring direction from provided arguments.""" + + if (source is None and target is None) or (source is not None and target is not None): + msg = "Provide either 'source' or 'target' (but not both) to sql.copy()." + raise SQLBuilderError(msg) + + if source is not None: + return self.copy_from(table, source, columns=columns, options=options, dialect=dialect) + + target_value = cast("str", target) + return self.copy_to(table, target_value, columns=columns, options=options, dialect=dialect) + @staticmethod def _looks_like_sql(candidate: str, expected_type: str | None = None) -> bool: """Determine if a string looks like SQL. 
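Note: a usage sketch of the COPY helpers added above (the rendered strings mirror the assertions in tests/unit/test_builder/test_copy_helpers.py further down in this patch; table and bucket names are illustrative):

```python
from sqlspec.builder import build_copy_from_statement, build_copy_to_statement, sql

load_stmt = build_copy_from_statement(
    "public.users",
    "s3://bucket/data.parquet",
    columns=("id", "name"),
    options={"format": "parquet"},
)
# COPY "public.users" (id, name) FROM 's3://bucket/data.parquet' WITH (FORMAT 'parquet')

unload_stmt = build_copy_to_statement(
    "public.users",
    "s3://bucket/output.parquet",
    options={"format": "parquet", "compression": "gzip"},
)
# COPY "public.users" TO 's3://bucket/output.parquet' WITH (FORMAT 'parquet', COMPRESSION 'gzip')

# The factory wrappers defer to the same builders; sql.copy() infers direction
# from whichever of source/target is provided and raises SQLBuilderError when
# both or neither are passed.
export_stmt = sql.copy("public.users", target="s3://bucket/output.parquet", options={"format": "parquet"})
```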
diff --git a/sqlspec/storage/backends/fsspec.py b/sqlspec/storage/backends/fsspec.py index 58dda5d6..3f0b0fae 100644 --- a/sqlspec/storage/backends/fsspec.py +++ b/sqlspec/storage/backends/fsspec.py @@ -6,6 +6,7 @@ from mypy_extensions import mypyc_attr from sqlspec.storage._utils import import_pyarrow_parquet, resolve_storage_path +from sqlspec.storage.errors import execute_sync_storage_operation from sqlspec.utils.module_loader import ensure_fsspec from sqlspec.utils.sync_tools import async_ @@ -158,7 +159,15 @@ def base_uri(self) -> str: def read_bytes(self, path: str | Path, **kwargs: Any) -> bytes: """Read bytes from an object.""" resolved_path = resolve_storage_path(path, self.base_path, self.protocol, strip_file_scheme=False) - return self.fs.cat(resolved_path, **kwargs) # type: ignore[no-any-return] # pyright: ignore + return cast( + "bytes", + execute_sync_storage_operation( + lambda: self.fs.cat(resolved_path, **kwargs), + backend=self.backend_type, + operation="read_bytes", + path=resolved_path, + ), + ) def write_bytes(self, path: str | Path, data: bytes, **kwargs: Any) -> None: """Write bytes to an object.""" @@ -169,8 +178,11 @@ def write_bytes(self, path: str | Path, data: bytes, **kwargs: Any) -> None: if parent_dir and not self.fs.exists(parent_dir): self.fs.makedirs(parent_dir, exist_ok=True) - with self.fs.open(resolved_path, mode="wb", **kwargs) as f: - f.write(data) # pyright: ignore + def _action() -> None: + with self.fs.open(resolved_path, mode="wb", **kwargs) as file_obj: + file_obj.write(data) # pyright: ignore + + execute_sync_storage_operation(_action, backend=self.backend_type, operation="write_bytes", path=resolved_path) def read_text(self, path: str | Path, encoding: str = "utf-8", **kwargs: Any) -> str: """Read text from an object.""" @@ -189,35 +201,65 @@ def exists(self, path: str | Path, **kwargs: Any) -> bool: def delete(self, path: str | Path, **kwargs: Any) -> None: """Delete an object.""" resolved_path = resolve_storage_path(path, self.base_path, self.protocol, strip_file_scheme=False) - self.fs.rm(resolved_path, **kwargs) + execute_sync_storage_operation( + lambda: self.fs.rm(resolved_path, **kwargs), + backend=self.backend_type, + operation="delete", + path=resolved_path, + ) def copy(self, source: str | Path, destination: str | Path, **kwargs: Any) -> None: """Copy an object.""" source_path = resolve_storage_path(source, self.base_path, self.protocol, strip_file_scheme=False) dest_path = resolve_storage_path(destination, self.base_path, self.protocol, strip_file_scheme=False) - self.fs.copy(source_path, dest_path, **kwargs) + execute_sync_storage_operation( + lambda: self.fs.copy(source_path, dest_path, **kwargs), + backend=self.backend_type, + operation="copy", + path=f"{source_path}->{dest_path}", + ) def move(self, source: str | Path, destination: str | Path, **kwargs: Any) -> None: """Move an object.""" source_path = resolve_storage_path(source, self.base_path, self.protocol, strip_file_scheme=False) dest_path = resolve_storage_path(destination, self.base_path, self.protocol, strip_file_scheme=False) - self.fs.mv(source_path, dest_path, **kwargs) + execute_sync_storage_operation( + lambda: self.fs.mv(source_path, dest_path, **kwargs), + backend=self.backend_type, + operation="move", + path=f"{source_path}->{dest_path}", + ) def read_arrow(self, path: str | Path, **kwargs: Any) -> "ArrowTable": """Read an Arrow table from storage.""" pq = import_pyarrow_parquet() resolved_path = resolve_storage_path(path, self.base_path, self.protocol, 
strip_file_scheme=False) - with self.fs.open(resolved_path, mode="rb", **kwargs) as f: - return cast("ArrowTable", pq.read_table(f)) + return cast( + "ArrowTable", + execute_sync_storage_operation( + lambda: self._read_arrow(resolved_path, pq, kwargs), + backend=self.backend_type, + operation="read_arrow", + path=resolved_path, + ), + ) def write_arrow(self, path: str | Path, table: "ArrowTable", **kwargs: Any) -> None: """Write an Arrow table to storage.""" pq = import_pyarrow_parquet() resolved_path = resolve_storage_path(path, self.base_path, self.protocol, strip_file_scheme=False) - with self.fs.open(resolved_path, mode="wb") as f: - pq.write_table(table, f, **kwargs) # pyright: ignore + + def _action() -> None: + with self.fs.open(resolved_path, mode="wb") as file_obj: + pq.write_table(table, file_obj, **kwargs) # pyright: ignore + + execute_sync_storage_operation(_action, backend=self.backend_type, operation="write_arrow", path=resolved_path) + + def _read_arrow(self, resolved_path: str, pq: Any, options: "dict[str, Any]") -> Any: + with self.fs.open(resolved_path, mode="rb", **options) as file_obj: + return pq.read_table(file_obj) def list_objects(self, prefix: str = "", recursive: bool = True, **kwargs: Any) -> list[str]: """List objects with optional prefix.""" @@ -273,8 +315,17 @@ def sign(self, path: str, expires_in: int = 3600, for_upload: bool = False) -> s def _stream_file_batches(self, obj_path: str | Path) -> "Iterator[ArrowRecordBatch]": pq = import_pyarrow_parquet() - with self.fs.open(obj_path, mode="rb") as f: - parquet_file = pq.ParquetFile(f) # pyright: ignore[reportArgumentType] + file_handle = execute_sync_storage_operation( + lambda: self.fs.open(obj_path, mode="rb"), + backend=self.backend_type, + operation="stream_open", + path=str(obj_path), + ) + + with file_handle as stream: + parquet_file = execute_sync_storage_operation( + lambda: pq.ParquetFile(stream), backend=self.backend_type, operation="stream_arrow", path=str(obj_path) + ) yield from parquet_file.iter_batches() def stream_arrow(self, pattern: str, **kwargs: Any) -> "Iterator[ArrowRecordBatch]": diff --git a/sqlspec/storage/backends/local.py b/sqlspec/storage/backends/local.py index 8bd0344d..88567e40 100644 --- a/sqlspec/storage/backends/local.py +++ b/sqlspec/storage/backends/local.py @@ -6,13 +6,16 @@ import shutil from collections.abc import AsyncIterator, Iterator +from functools import partial from pathlib import Path from typing import TYPE_CHECKING, Any, cast from urllib.parse import unquote, urlparse from mypy_extensions import mypyc_attr +from sqlspec.exceptions import FileNotFoundInStorageError from sqlspec.storage._utils import import_pyarrow_parquet +from sqlspec.storage.errors import execute_sync_storage_operation from sqlspec.utils.sync_tools import async_ if TYPE_CHECKING: @@ -107,23 +110,32 @@ def _resolve_path(self, path: "str | Path") -> Path: def read_bytes(self, path: "str | Path", **kwargs: Any) -> bytes: """Read bytes from file.""" resolved = self._resolve_path(path) - return resolved.read_bytes() + try: + return execute_sync_storage_operation( + lambda: resolved.read_bytes(), backend=self.backend_type, operation="read_bytes", path=str(resolved) + ) + except FileNotFoundInStorageError as error: + raise FileNotFoundError(str(resolved)) from error def write_bytes(self, path: "str | Path", data: bytes, **kwargs: Any) -> None: """Write bytes to file.""" resolved = self._resolve_path(path) - resolved.parent.mkdir(parents=True, exist_ok=True) - resolved.write_bytes(data) + + def 
_action() -> None: + resolved.parent.mkdir(parents=True, exist_ok=True) + resolved.write_bytes(data) + + execute_sync_storage_operation(_action, backend=self.backend_type, operation="write_bytes", path=str(resolved)) def read_text(self, path: "str | Path", encoding: str = "utf-8", **kwargs: Any) -> str: """Read text from file.""" - return self._resolve_path(path).read_text(encoding=encoding) + data = self.read_bytes(path, **kwargs) + return data.decode(encoding) def write_text(self, path: "str | Path", data: str, encoding: str = "utf-8", **kwargs: Any) -> None: """Write text to file.""" - resolved = self._resolve_path(path) - resolved.parent.mkdir(parents=True, exist_ok=True) - resolved.write_text(data, encoding=encoding) + encoded = data.encode(encoding) + self.write_bytes(path, encoded, **kwargs) def list_objects(self, prefix: str = "", recursive: bool = True, **kwargs: Any) -> list[str]: """List objects in directory.""" @@ -163,10 +175,14 @@ def exists(self, path: "str | Path", **kwargs: Any) -> bool: def delete(self, path: "str | Path", **kwargs: Any) -> None: """Delete file or directory.""" resolved = self._resolve_path(path) - if resolved.is_dir(): - shutil.rmtree(resolved) - elif resolved.exists(): - resolved.unlink() + + def _action() -> None: + if resolved.is_dir(): + shutil.rmtree(resolved) + elif resolved.exists(): + resolved.unlink() + + execute_sync_storage_operation(_action, backend=self.backend_type, operation="delete", path=str(resolved)) def copy(self, source: "str | Path", destination: "str | Path", **kwargs: Any) -> None: """Copy file or directory.""" @@ -174,17 +190,22 @@ def copy(self, source: "str | Path", destination: "str | Path", **kwargs: Any) - dst = self._resolve_path(destination) dst.parent.mkdir(parents=True, exist_ok=True) - if src.is_dir(): - shutil.copytree(src, dst, dirs_exist_ok=True) - else: - shutil.copy2(src, dst) + def _action() -> None: + if src.is_dir(): + shutil.copytree(src, dst, dirs_exist_ok=True) + else: + shutil.copy2(src, dst) + + execute_sync_storage_operation(_action, backend=self.backend_type, operation="copy", path=f"{src}->{dst}") def move(self, source: "str | Path", destination: "str | Path", **kwargs: Any) -> None: """Move file or directory.""" src = self._resolve_path(source) dst = self._resolve_path(destination) dst.parent.mkdir(parents=True, exist_ok=True) - shutil.move(str(src), str(dst)) + execute_sync_storage_operation( + lambda: shutil.move(str(src), str(dst)), backend=self.backend_type, operation="move", path=f"{src}->{dst}" + ) def glob(self, pattern: str, **kwargs: Any) -> list[str]: """Find files matching pattern.""" @@ -210,6 +231,14 @@ def glob(self, pattern: str, **kwargs: Any) -> list[str]: def get_metadata(self, path: "str | Path", **kwargs: Any) -> dict[str, Any]: """Get file metadata.""" resolved = self._resolve_path(path) + return execute_sync_storage_operation( + lambda: self._collect_metadata(resolved), + backend=self.backend_type, + operation="get_metadata", + path=str(resolved), + ) + + def _collect_metadata(self, resolved: "Path") -> dict[str, Any]: if not resolved.exists(): return {} @@ -234,14 +263,27 @@ def is_path(self, path: "str | Path") -> bool: def read_arrow(self, path: "str | Path", **kwargs: Any) -> "ArrowTable": """Read Arrow table from file.""" pq = import_pyarrow_parquet() - return cast("ArrowTable", pq.read_table(str(self._resolve_path(path)), **kwargs)) + resolved = self._resolve_path(path) + return cast( + "ArrowTable", + execute_sync_storage_operation( + lambda: pq.read_table(str(resolved), 
**kwargs), + backend=self.backend_type, + operation="read_arrow", + path=str(resolved), + ), + ) def write_arrow(self, path: "str | Path", table: "ArrowTable", **kwargs: Any) -> None: """Write Arrow table to file.""" pq = import_pyarrow_parquet() resolved = self._resolve_path(path) - resolved.parent.mkdir(parents=True, exist_ok=True) - pq.write_table(table, str(resolved), **kwargs) # pyright: ignore + + def _action() -> None: + resolved.parent.mkdir(parents=True, exist_ok=True) + pq.write_table(table, str(resolved), **kwargs) # pyright: ignore + + execute_sync_storage_operation(_action, backend=self.backend_type, operation="write_arrow", path=str(resolved)) def stream_arrow(self, pattern: str, **kwargs: Any) -> Iterator["ArrowRecordBatch"]: """Stream Arrow record batches from files matching pattern. @@ -253,7 +295,13 @@ def stream_arrow(self, pattern: str, **kwargs: Any) -> Iterator["ArrowRecordBatc files = self.glob(pattern) for file_path in files: resolved = self._resolve_path(file_path) - parquet_file = pq.ParquetFile(str(resolved)) + resolved_str = str(resolved) + parquet_file = execute_sync_storage_operation( + partial(pq.ParquetFile, resolved_str), + backend=self.backend_type, + operation="stream_arrow", + path=resolved_str, + ) yield from parquet_file.iter_batches() # pyright: ignore def sign(self, path: "str | Path", expires_in: int = 3600, for_upload: bool = False) -> str: diff --git a/sqlspec/storage/backends/obstore.py b/sqlspec/storage/backends/obstore.py index 887f7f60..0e825a79 100644 --- a/sqlspec/storage/backends/obstore.py +++ b/sqlspec/storage/backends/obstore.py @@ -9,6 +9,7 @@ import logging import re from collections.abc import AsyncIterator, Iterator +from functools import partial from pathlib import Path, PurePosixPath from typing import Any, Final, cast from urllib.parse import urlparse @@ -17,6 +18,7 @@ from sqlspec.exceptions import StorageOperationFailedError from sqlspec.storage._utils import import_pyarrow, import_pyarrow_parquet, resolve_storage_path +from sqlspec.storage.errors import execute_sync_storage_operation from sqlspec.typing import ArrowRecordBatch, ArrowTable from sqlspec.utils.module_loader import ensure_obstore from sqlspec.utils.sync_tools import async_ @@ -197,8 +199,13 @@ def read_bytes(self, path: "str | Path", **kwargs: Any) -> bytes: # pyright: ig else: resolved_path = resolve_storage_path(path, self.base_path, self.protocol, strip_file_scheme=True) - result = self.store.get(resolved_path) - return cast("bytes", result.bytes().to_bytes()) + def _action() -> bytes: + result = self.store.get(resolved_path) + return cast("bytes", result.bytes().to_bytes()) + + return execute_sync_storage_operation( + _action, backend=self.backend_type, operation="read_bytes", path=resolved_path + ) def write_bytes(self, path: "str | Path", data: bytes, **kwargs: Any) -> None: # pyright: ignore[reportUnusedParameter] """Write bytes using obstore.""" @@ -206,7 +213,13 @@ def write_bytes(self, path: "str | Path", data: bytes, **kwargs: Any) -> None: resolved_path = self._resolve_path_for_local_store(path) else: resolved_path = resolve_storage_path(path, self.base_path, self.protocol, strip_file_scheme=True) - self.store.put(resolved_path, data) + + execute_sync_storage_operation( + lambda: self.store.put(resolved_path, data), + backend=self.backend_type, + operation="write_bytes", + path=resolved_path, + ) def read_text(self, path: "str | Path", encoding: str = "utf-8", **kwargs: Any) -> str: """Read text using obstore.""" @@ -241,19 +254,31 @@ def exists(self, 
path: "str | Path", **kwargs: Any) -> bool: # pyright: ignore[ def delete(self, path: "str | Path", **kwargs: Any) -> None: # pyright: ignore[reportUnusedParameter] """Delete object using obstore.""" resolved_path = resolve_storage_path(path, self.base_path, self.protocol, strip_file_scheme=True) - self.store.delete(resolved_path) + execute_sync_storage_operation( + lambda: self.store.delete(resolved_path), backend=self.backend_type, operation="delete", path=resolved_path + ) def copy(self, source: "str | Path", destination: "str | Path", **kwargs: Any) -> None: # pyright: ignore[reportUnusedParameter] """Copy object using obstore.""" source_path = resolve_storage_path(source, self.base_path, self.protocol, strip_file_scheme=True) dest_path = resolve_storage_path(destination, self.base_path, self.protocol, strip_file_scheme=True) - self.store.copy(source_path, dest_path) + execute_sync_storage_operation( + lambda: self.store.copy(source_path, dest_path), + backend=self.backend_type, + operation="copy", + path=f"{source_path}->{dest_path}", + ) def move(self, source: "str | Path", destination: "str | Path", **kwargs: Any) -> None: # pyright: ignore[reportUnusedParameter] """Move object using obstore.""" source_path = resolve_storage_path(source, self.base_path, self.protocol, strip_file_scheme=True) dest_path = resolve_storage_path(destination, self.base_path, self.protocol, strip_file_scheme=True) - self.store.rename(source_path, dest_path) + execute_sync_storage_operation( + lambda: self.store.rename(source_path, dest_path), + backend=self.backend_type, + operation="move", + path=f"{source_path}->{dest_path}", + ) def glob(self, pattern: str, **kwargs: Any) -> list[str]: """Find objects matching pattern. @@ -342,7 +367,15 @@ def read_arrow(self, path: "str | Path", **kwargs: Any) -> ArrowTable: pq = import_pyarrow_parquet() resolved_path = resolve_storage_path(path, self.base_path, self.protocol, strip_file_scheme=True) data = self.read_bytes(resolved_path) - return cast("ArrowTable", pq.read_table(io.BytesIO(data), **kwargs)) + return cast( + "ArrowTable", + execute_sync_storage_operation( + lambda: pq.read_table(io.BytesIO(data), **kwargs), + backend=self.backend_type, + operation="read_arrow", + path=resolved_path, + ), + ) def write_arrow(self, path: "str | Path", table: ArrowTable, **kwargs: Any) -> None: """Write Arrow table using obstore.""" @@ -366,7 +399,12 @@ def write_arrow(self, path: "str | Path", table: ArrowTable, **kwargs: Any) -> N table = table.cast(pa.schema(new_fields)) buffer = io.BytesIO() - pq.write_table(table, buffer, **kwargs) + execute_sync_storage_operation( + lambda: pq.write_table(table, buffer, **kwargs), + backend=self.backend_type, + operation="write_arrow", + path=resolved_path, + ) buffer.seek(0) self.write_bytes(resolved_path, buffer.read()) @@ -379,11 +417,18 @@ def stream_arrow(self, pattern: str, **kwargs: Any) -> Iterator[ArrowRecordBatch pq = import_pyarrow_parquet() for obj_path in self.glob(pattern, **kwargs): resolved_path = resolve_storage_path(obj_path, self.base_path, self.protocol, strip_file_scheme=True) - result = self.store.get(resolved_path) + result = execute_sync_storage_operation( + partial(self.store.get, resolved_path), + backend=self.backend_type, + operation="stream_read", + path=resolved_path, + ) bytes_obj = result.bytes() data = bytes_obj.to_bytes() buffer = io.BytesIO(data) - parquet_file = pq.ParquetFile(buffer) + parquet_file = execute_sync_storage_operation( + partial(pq.ParquetFile, buffer), backend=self.backend_type, 
operation="stream_arrow", path=resolved_path + ) yield from parquet_file.iter_batches() def sign(self, path: str, expires_in: int = 3600, for_upload: bool = False) -> str: diff --git a/sqlspec/storage/errors.py b/sqlspec/storage/errors.py new file mode 100644 index 00000000..c1b81bd2 --- /dev/null +++ b/sqlspec/storage/errors.py @@ -0,0 +1,104 @@ +"""Storage error normalization helpers.""" + +import errno +import logging +from typing import TYPE_CHECKING, Any, TypeVar + +from sqlspec.exceptions import FileNotFoundInStorageError, StorageOperationFailedError + +if TYPE_CHECKING: + from collections.abc import Awaitable, Callable, Mapping + +__all__ = ("StorageError", "execute_async_storage_operation", "execute_sync_storage_operation", "raise_storage_error") + + +logger = logging.getLogger(__name__) + +T = TypeVar("T") + + +_NOT_FOUND_NAMES = {"NotFoundError", "ObjectNotFound", "NoSuchKey", "NoSuchBucket", "NoSuchFile"} + + +class StorageError: + """Normalized view of a storage backend exception.""" + + __slots__ = ("backend", "message", "operation", "original", "path", "retryable") + + def __init__( + self, message: str, backend: str, operation: str, path: str | None, retryable: bool, original: Exception + ) -> None: + self.message = message + self.backend = backend + self.operation = operation + self.path = path + self.retryable = retryable + self.original = original + + +def _is_missing_error(error: Exception) -> bool: + if isinstance(error, FileNotFoundError): + return True + + number = getattr(error, "errno", None) + if number in {errno.ENOENT, errno.ENOTDIR}: + return True + + name = error.__class__.__name__ + return name in _NOT_FOUND_NAMES + + +def _is_retryable(error: Exception) -> bool: + if isinstance(error, (ConnectionError, TimeoutError)): + return True + + name = error.__class__.__name__ + return bool("Timeout" in name or "Temporary" in name) + + +def _normalize_storage_error(error: Exception, *, backend: str, operation: str, path: str | None) -> "StorageError": + message = f"{backend} {operation} failed" + if path: + message = f"{message} for {path}" + + return StorageError( + message=message, backend=backend, operation=operation, path=path, retryable=_is_retryable(error), original=error + ) + + +def raise_storage_error(error: Exception, *, backend: str, operation: str, path: str | None) -> None: + is_missing = _is_missing_error(error) + normalized = _normalize_storage_error(error, backend=backend, operation=operation, path=path) + + log_extra: Mapping[str, Any] = { + "storage_backend": backend, + "storage_operation": operation, + "storage_path": path, + "exception_type": error.__class__.__name__, + "retryable": normalized.retryable, + } + + if is_missing: + logger.debug("Storage object missing during %s", operation, extra=log_extra) + raise FileNotFoundInStorageError(normalized.message) from error + + logger.warning("Storage operation %s failed", operation, extra=log_extra, exc_info=error) + raise StorageOperationFailedError(normalized.message) from error + + +def execute_sync_storage_operation(func: "Callable[[], T]", *, backend: str, operation: str, path: str | None) -> T: + try: + return func() + except Exception as error: + raise_storage_error(error, backend=backend, operation=operation, path=path) + raise + + +async def execute_async_storage_operation( + func: "Callable[[], Awaitable[T]]", *, backend: str, operation: str, path: str | None +) -> T: + try: + return await func() + except Exception as error: + raise_storage_error(error, backend=backend, operation=operation, 
path=path) + raise diff --git a/sqlspec/utils/fixtures.py b/sqlspec/utils/fixtures.py index 628303af..60d54446 100644 --- a/sqlspec/utils/fixtures.py +++ b/sqlspec/utils/fixtures.py @@ -11,9 +11,9 @@ from sqlspec.storage import storage_registry from sqlspec.utils.serializers import from_json as decode_json +from sqlspec.utils.serializers import schema_dump from sqlspec.utils.serializers import to_json as encode_json from sqlspec.utils.sync_tools import async_ -from sqlspec.utils.type_guards import schema_dump if TYPE_CHECKING: from sqlspec.typing import SupportedSchemaModel @@ -146,14 +146,17 @@ def _serialize_data(data: Any) -> str: """ if isinstance(data, (list, tuple)): serialized_items: list[Any] = [] + for item in data: if isinstance(item, (str, int, float, bool, type(None))): serialized_items.append(item) else: serialized_items.append(schema_dump(item)) + return encode_json(serialized_items) if isinstance(data, (str, int, float, bool, type(None))): return encode_json(data) + return encode_json(schema_dump(data)) diff --git a/sqlspec/utils/serializers.py b/sqlspec/utils/serializers.py index 35fac632..19bfa832 100644 --- a/sqlspec/utils/serializers.py +++ b/sqlspec/utils/serializers.py @@ -1,16 +1,44 @@ -"""JSON serialization utilities for SQLSpec. +"""Serialization utilities for SQLSpec. -Re-exports common JSON encoding and decoding functions from the core -serialization module for convenient access. - -Provides NumPy array serialization hooks for framework integrations -that support custom type encoders and decoders (e.g., Litestar). +Provides JSON helpers, serializer pipelines, and optional dependency hooks. """ -from typing import Any, Literal, overload +from __future__ import annotations + +import dataclasses +from threading import RLock +from typing import TYPE_CHECKING, Any, Literal, cast, overload from sqlspec._serialization import decode_json, encode_json -from sqlspec.typing import NUMPY_INSTALLED +from sqlspec.typing import ( + ATTRS_INSTALLED, + MSGSPEC_INSTALLED, + NUMPY_INSTALLED, + PYDANTIC_INSTALLED, + UNSET, + ArrowReturnFormat, + attrs_asdict, + attrs_has, +) +from sqlspec.utils.arrow_helpers import convert_dict_to_arrow +from sqlspec.utils.type_guards import has_dict_attribute + +if TYPE_CHECKING: + from collections.abc import Callable, Hashable, Iterable + +__all__ = ( + "SchemaSerializer", + "from_json", + "get_collection_serializer", + "get_serializer_metrics", + "numpy_array_dec_hook", + "numpy_array_enc_hook", + "numpy_array_predicate", + "reset_serializer_cache", + "schema_dump", + "serialize_collection", + "to_json", +) @overload @@ -93,7 +121,7 @@ def numpy_array_enc_hook(value: Any) -> Any: return value -def numpy_array_dec_hook(value: Any) -> "Any": +def numpy_array_dec_hook(value: Any) -> Any: """Decode list to NumPy array. Converts Python lists to NumPy arrays when appropriate. 
@@ -164,4 +192,165 @@ def numpy_array_predicate(value: Any) -> bool: return isinstance(value, np.ndarray) -__all__ = ("from_json", "numpy_array_dec_hook", "numpy_array_enc_hook", "numpy_array_predicate", "to_json") +class SchemaSerializer: + """Serializer pipeline that caches conversions for repeated schema dumps.""" + + __slots__ = ("_dump", "_key") + + def __init__(self, key: tuple[type[Any] | None, bool], dump: Callable[[Any], dict[str, Any]]) -> None: + self._key = key + self._dump = dump + + @property + def key(self) -> tuple[type[Any] | None, bool]: + return self._key + + def dump_one(self, item: Any) -> dict[str, Any]: + return self._dump(item) + + def dump_many(self, items: Iterable[Any]) -> list[dict[str, Any]]: + return [self._dump(item) for item in items] + + def to_arrow( + self, items: Iterable[Any], *, return_format: ArrowReturnFormat = "table", batch_size: int | None = None + ) -> Any: + payload = self.dump_many(items) + return convert_dict_to_arrow(payload, return_format=return_format, batch_size=batch_size) + + +_SERIALIZER_LOCK: RLock = RLock() +_SCHEMA_SERIALIZERS: dict[tuple[type[Any] | None, bool], SchemaSerializer] = {} + + +def _is_dataclass_instance(value: Any) -> bool: + return dataclasses.is_dataclass(value) and not isinstance(value, type) + + +def _is_pydantic_model(value: Any) -> bool: + if not PYDANTIC_INSTALLED: + return False + return hasattr(value, "model_dump") + + +def _is_msgspec_struct(value: Any) -> bool: + if not MSGSPEC_INSTALLED: + return False + return hasattr(value, "__struct_fields__") + + +def _is_attrs_instance(value: Any) -> bool: + if not ATTRS_INSTALLED: + return False + return attrs_has(type(value)) + + +def _make_serializer_key(sample: Any, exclude_unset: bool) -> tuple[type[Any] | None, bool]: + if sample is None or isinstance(sample, dict): + return (None, exclude_unset) + return (type(sample), exclude_unset) + + +def _build_dump_function(sample: Any, exclude_unset: bool) -> Callable[[Any], dict[str, Any]]: + if sample is None or isinstance(sample, dict): + return lambda value: cast("dict[str, Any]", value) + + if _is_dataclass_instance(sample): + + def _dump_dataclass(value: Any) -> dict[str, Any]: + return dataclasses.asdict(value) + + return _dump_dataclass + if _is_pydantic_model(sample): + + def _dump_pydantic(value: Any) -> dict[str, Any]: + return cast("dict[str, Any]", value.model_dump(exclude_unset=exclude_unset)) + + return _dump_pydantic + if _is_msgspec_struct(sample): + if exclude_unset: + + def _dump(value: Any) -> dict[str, Any]: + return {f: val for f in value.__struct_fields__ if (val := getattr(value, f, None)) != UNSET} + + return _dump + + return lambda value: {f: getattr(value, f, None) for f in value.__struct_fields__} + + if _is_attrs_instance(sample): + + def _dump_attrs(value: Any) -> dict[str, Any]: + return attrs_asdict(value, recurse=True) + + return _dump_attrs + + if has_dict_attribute(sample): + + def _dump_dict_attr(value: Any) -> dict[str, Any]: + return dict(value.__dict__) + + return _dump_dict_attr + + return lambda value: dict(value) + + +def get_collection_serializer(sample: Any, *, exclude_unset: bool = True) -> SchemaSerializer: + """Return cached serializer pipeline for the provided sample object.""" + + key = _make_serializer_key(sample, exclude_unset) + with _SERIALIZER_LOCK: + pipeline = _SCHEMA_SERIALIZERS.get(key) + if pipeline is not None: + return pipeline + + dump = _build_dump_function(sample, exclude_unset) + pipeline = SchemaSerializer(key, dump) + _SCHEMA_SERIALIZERS[key] = pipeline + 
return pipeline + + +def serialize_collection(items: Iterable[Any], *, exclude_unset: bool = True) -> list[Any]: + """Serialize a collection using cached pipelines keyed by item type.""" + + serialized: list[Any] = [] + cache: dict[tuple[type[Any] | None, bool], SchemaSerializer] = {} + + for item in items: + if isinstance(item, (str, bytes, int, float, bool)) or item is None or isinstance(item, dict): + serialized.append(item) + continue + + key = _make_serializer_key(item, exclude_unset) + pipeline = cache.get(key) + if pipeline is None: + pipeline = get_collection_serializer(item, exclude_unset=exclude_unset) + cache[key] = pipeline + serialized.append(pipeline.dump_one(item)) + return serialized + + +def reset_serializer_cache() -> None: + """Clear cached serializer pipelines.""" + + with _SERIALIZER_LOCK: + _SCHEMA_SERIALIZERS.clear() + + +def get_serializer_metrics() -> dict[str, Hashable]: + """Return metrics describing the serializer pipeline cache.""" + + with _SERIALIZER_LOCK: + return {"size": len(_SCHEMA_SERIALIZERS)} + + +def schema_dump(data: Any, *, exclude_unset: bool = True) -> dict[str, Any]: + """Dump a schema model or dict to a plain dictionary. + + Args: + data: Schema model instance or dictionary to dump. + exclude_unset: Whether to exclude unset fields (for models that support it). + + Returns: + A plain dictionary representation of the schema model. + """ + serializer = get_collection_serializer(data, exclude_unset=exclude_unset) + return serializer.dump_one(data) diff --git a/sqlspec/utils/type_guards.py b/sqlspec/utils/type_guards.py index 81c4094e..eaaf9627 100644 --- a/sqlspec/utils/type_guards.py +++ b/sqlspec/utils/type_guards.py @@ -20,7 +20,6 @@ DataclassProtocol, DTOData, Struct, - attrs_asdict, attrs_has, ) @@ -123,7 +122,6 @@ "is_string_literal", "is_typed_dict", "is_typed_parameter", - "schema_dump", "supports_arrow_native", "supports_arrow_results", "supports_limit", @@ -828,36 +826,6 @@ def dataclass_to_dict( return cast("dict[str, Any]", ret) -def schema_dump(data: Any, exclude_unset: bool = True) -> "dict[str, Any]": - """Dump a data object to a dictionary. - - Args: - data: :type:`dict[str, Any]` | :class:`DataclassProtocol` | :class:`msgspec.Struct` | :class:`pydantic.BaseModel` | :class:`AttrsInstance` - exclude_unset: :type:`bool` Whether to exclude unset values. 
- - Returns: - :type:`dict[str, Any]` - """ - from sqlspec._typing import UNSET - - if is_dict(data): - return data - if is_dataclass(data): - return dataclass_to_dict(data, exclude_empty=exclude_unset) - if is_pydantic_model(data): - return data.model_dump(exclude_unset=exclude_unset) # type: ignore[no-any-return] - if is_msgspec_struct(data): - if exclude_unset: - return {f: val for f in data.__struct_fields__ if (val := getattr(data, f, None)) != UNSET} - return {f: getattr(data, f, None) for f in data.__struct_fields__} - if is_attrs_instance(data): - return attrs_asdict(data) - - if has_dict_attribute(data): - return data.__dict__ - return cast("dict[str, Any]", data) - - def can_extract_parameters(obj: Any) -> "TypeGuard[FilterParameterProtocol]": """Check if an object can extract parameters.""" from sqlspec.protocols import FilterParameterProtocol diff --git a/tests/unit/test_builder/test_copy_helpers.py b/tests/unit/test_builder/test_copy_helpers.py new file mode 100644 index 00000000..11d0edd2 --- /dev/null +++ b/tests/unit/test_builder/test_copy_helpers.py @@ -0,0 +1,43 @@ +from sqlglot import parse_one + +from sqlspec.builder import build_copy_from_statement, build_copy_to_statement, sql +from sqlspec.core import SQL + + +def test_build_copy_from_statement_generates_expected_sql() -> None: + statement = build_copy_from_statement( + "public.users", "s3://bucket/data.parquet", columns=("id", "name"), options={"format": "parquet"} + ) + + assert isinstance(statement, SQL) + rendered = statement.sql + assert rendered == "COPY \"public.users\" (id, name) FROM 's3://bucket/data.parquet' WITH (FORMAT 'parquet')" + + expression = parse_one(rendered, read="postgres") + assert expression.args["kind"] is True + + +def test_build_copy_to_statement_generates_expected_sql() -> None: + statement = build_copy_to_statement( + "public.users", "s3://bucket/output.parquet", options={"format": "parquet", "compression": "gzip"} + ) + + assert isinstance(statement, SQL) + rendered = statement.sql + assert rendered == ( + "COPY \"public.users\" TO 's3://bucket/output.parquet' WITH (FORMAT 'parquet', COMPRESSION 'gzip')" + ) + + expression = parse_one(rendered, read="postgres") + assert expression.args["kind"] is False + + +def test_sql_factory_copy_helpers() -> None: + statement = sql.copy_from("users", "s3://bucket/in.csv", columns=("id", "name"), options={"format": "csv"}) + assert isinstance(statement, SQL) + assert statement.sql.startswith("COPY users") + + to_statement = sql.copy("users", target="s3://bucket/out.csv", options={"format": "csv", "header": True}) + assert isinstance(to_statement, SQL) + parsed = parse_one(to_statement.sql, read="postgres") + assert parsed.args["kind"] is False diff --git a/tests/unit/test_storage/test_errors.py b/tests/unit/test_storage/test_errors.py new file mode 100644 index 00000000..2fcc5ec9 --- /dev/null +++ b/tests/unit/test_storage/test_errors.py @@ -0,0 +1,32 @@ +import pytest + +from sqlspec.exceptions import FileNotFoundInStorageError, StorageOperationFailedError +from sqlspec.storage.errors import ( + _normalize_storage_error, # pyright: ignore + execute_sync_storage_operation, + raise_storage_error, +) + + +def test_raise_normalized_storage_error_for_missing_file() -> None: + with pytest.raises(FileNotFoundInStorageError) as excinfo: + raise_storage_error(FileNotFoundError("missing"), backend="local", operation="read_bytes", path="file.txt") + + assert "local read_bytes failed" in str(excinfo.value) + + +def test_normalize_storage_error_marks_retryable() 
-> None: + normalized = _normalize_storage_error( + TimeoutError("timed out"), backend="fsspec", operation="write_bytes", path="s3://bucket/key" + ) + assert normalized.retryable is True + + +def test_execute_with_normalized_errors_wraps_generic_failure() -> None: + def _boom() -> None: + raise RuntimeError("boom") + + with pytest.raises(StorageOperationFailedError) as excinfo: + execute_sync_storage_operation(_boom, backend="obstore", operation="write_bytes", path="object") + + assert "obstore write_bytes failed" in str(excinfo.value) diff --git a/tests/unit/test_utils/test_type_guards.py b/tests/unit/test_utils/test_type_guards.py index 6ad9ee97..a592058c 100644 --- a/tests/unit/test_utils/test_type_guards.py +++ b/tests/unit/test_utils/test_type_guards.py @@ -12,6 +12,14 @@ from sqlglot import exp from typing_extensions import TypedDict +from sqlspec.typing import PYARROW_INSTALLED +from sqlspec.utils.serializers import ( + get_collection_serializer, + get_serializer_metrics, + reset_serializer_cache, + schema_dump, + serialize_collection, +) from sqlspec.utils.type_guards import ( dataclass_to_dict, expression_has_limit, @@ -59,7 +67,6 @@ is_schema_without_field, is_string_literal, is_typed_dict, - schema_dump, ) pytestmark = pytest.mark.xdist_group("utils") @@ -84,6 +91,12 @@ class SampleTypedDict(TypedDict): optional_field: "str | None" +@dataclass +class _SerializerRecord: + identifier: int + name: str + + class MockSQLGlotExpression: """Mock SQLGlot expression for testing type guard functions. @@ -814,6 +827,35 @@ def __init__(self) -> None: assert result == expected +def test_serializer_pipeline_reuses_entry() -> None: + reset_serializer_cache() + metrics = get_serializer_metrics() + assert metrics["size"] == 0 + + sample = _SerializerRecord(identifier=1, name="first") + pipeline = get_collection_serializer(sample) + metrics = get_serializer_metrics() + assert metrics["size"] == 1 + + same_pipeline = get_collection_serializer(_SerializerRecord(identifier=2, name="second")) + assert pipeline is same_pipeline + + +def test_serialize_collection_mixed_models() -> None: + items = [_SerializerRecord(identifier=1, name="alpha"), {"identifier": 2, "name": "beta"}] + serialized = serialize_collection(items) + assert serialized == [{"identifier": 1, "name": "alpha"}, {"identifier": 2, "name": "beta"}] + + +@pytest.mark.skipif(not PYARROW_INSTALLED, reason="PyArrow not installed") +def test_serializer_pipeline_arrow_conversion() -> None: + sample = _SerializerRecord(identifier=1, name="alpha") + pipeline = get_collection_serializer(sample) + table = pipeline.to_arrow([sample, _SerializerRecord(identifier=2, name="beta")]) + assert table.num_rows == 2 + assert table.column(0).to_pylist() == [1, 2] + + @pytest.mark.parametrize( "guard_func,test_obj,expected", [ From e6ab25f50aea11e762dbd8882a57b95286f4e5d2 Mon Sep 17 00:00:00 2001 From: Cody Fincher Date: Mon, 3 Nov 2025 21:30:00 +0000 Subject: [PATCH 4/7] feat: enhance schema_dump and serializer metrics tracking --- sqlspec/utils/serializers.py | 150 ++++++++++++++-------- tests/unit/test_utils/test_type_guards.py | 21 +++ 2 files changed, 116 insertions(+), 55 deletions(-) diff --git a/sqlspec/utils/serializers.py b/sqlspec/utils/serializers.py index 19bfa832..7cb13826 100644 --- a/sqlspec/utils/serializers.py +++ b/sqlspec/utils/serializers.py @@ -1,30 +1,28 @@ """Serialization utilities for SQLSpec. -Provides JSON helpers, serializer pipelines, and optional dependency hooks. 
+Provides JSON helpers, serializer pipelines, optional dependency hooks, +and cache instrumentation aligned with the core pipeline counters. """ -from __future__ import annotations - -import dataclasses +import os from threading import RLock -from typing import TYPE_CHECKING, Any, Literal, cast, overload +from typing import TYPE_CHECKING, Any, Final, Literal, cast, overload from sqlspec._serialization import decode_json, encode_json -from sqlspec.typing import ( - ATTRS_INSTALLED, - MSGSPEC_INSTALLED, - NUMPY_INSTALLED, - PYDANTIC_INSTALLED, - UNSET, - ArrowReturnFormat, - attrs_asdict, - attrs_has, -) +from sqlspec.typing import NUMPY_INSTALLED, UNSET, ArrowReturnFormat, attrs_asdict from sqlspec.utils.arrow_helpers import convert_dict_to_arrow -from sqlspec.utils.type_guards import has_dict_attribute +from sqlspec.utils.type_guards import ( + dataclass_to_dict, + has_dict_attribute, + is_attrs_instance, + is_dataclass_instance, + is_dict, + is_msgspec_struct, + is_pydantic_model, +) if TYPE_CHECKING: - from collections.abc import Callable, Hashable, Iterable + from collections.abc import Callable, Iterable __all__ = ( "SchemaSerializer", @@ -40,6 +38,58 @@ "to_json", ) +DEBUG_ENV_FLAG: Final[str] = "SQLSPEC_DEBUG_PIPELINE_CACHE" +_PRIMITIVE_TYPES: Final[tuple[type[Any], ...]] = (str, bytes, int, float, bool) + + +def _is_truthy(value: "str | None") -> bool: + if value is None: + return False + normalized = value.strip().lower() + return normalized in {"1", "true", "yes", "on"} + + +def _metrics_enabled() -> bool: + return _is_truthy(os.getenv(DEBUG_ENV_FLAG)) + + +class _SerializerCacheMetrics: + __slots__ = ("hits", "max_size", "misses", "size") + + def __init__(self) -> None: + self.hits = 0 + self.misses = 0 + self.size = 0 + self.max_size = 0 + + def record_hit(self, cache_size: int) -> None: + if not _metrics_enabled(): + return + self.hits += 1 + self.size = cache_size + self.max_size = max(self.max_size, cache_size) + + def record_miss(self, cache_size: int) -> None: + if not _metrics_enabled(): + return + self.misses += 1 + self.size = cache_size + self.max_size = max(self.max_size, cache_size) + + def reset(self) -> None: + self.hits = 0 + self.misses = 0 + self.size = 0 + self.max_size = 0 + + def snapshot(self) -> "dict[str, int]": + return { + "hits": self.hits if _metrics_enabled() else 0, + "misses": self.misses if _metrics_enabled() else 0, + "max_size": self.max_size if _metrics_enabled() else 0, + "size": self.size if _metrics_enabled() else 0, + } + @overload def to_json(data: Any, *, as_bytes: Literal[False] = ...) -> str: ... 
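The counters above stay inert unless SQLSPEC_DEBUG_PIPELINE_CACHE is set to a truthy value ("1", "true", "yes", "on"), so the cache pays no bookkeeping cost by default. A minimal opt-in sketch, not part of the diff, assuming only the public helpers this patch exposes (get_collection_serializer, get_serializer_metrics, reset_serializer_cache) plus a hypothetical _Row dataclass used purely for illustration:

    import os
    from dataclasses import dataclass

    from sqlspec.utils.serializers import (
        get_collection_serializer,
        get_serializer_metrics,
        reset_serializer_cache,
    )

    @dataclass
    class _Row:
        identifier: int
        name: str

    # Enable instrumentation before touching the cache; the flag is read on each call.
    os.environ["SQLSPEC_DEBUG_PIPELINE_CACHE"] = "1"
    reset_serializer_cache()

    get_collection_serializer(_Row(identifier=1, name="first"))   # miss: builds and caches a pipeline
    get_collection_serializer(_Row(identifier=2, name="second"))  # hit: same (type, exclude_unset) key

    metrics = get_serializer_metrics()
    assert metrics == {"hits": 1, "misses": 1, "size": 1, "max_size": 1}

With the flag unset, the same sequence would report zeros for hits, misses, and max_size while still reflecting the live cache size, mirroring the snapshot() behaviour introduced above.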
@@ -197,22 +247,22 @@ class SchemaSerializer: __slots__ = ("_dump", "_key") - def __init__(self, key: tuple[type[Any] | None, bool], dump: Callable[[Any], dict[str, Any]]) -> None: + def __init__(self, key: "tuple[type[Any] | None, bool]", dump: "Callable[[Any], dict[str, Any]]") -> None: self._key = key self._dump = dump @property - def key(self) -> tuple[type[Any] | None, bool]: + def key(self) -> "tuple[type[Any] | None, bool]": return self._key def dump_one(self, item: Any) -> dict[str, Any]: return self._dump(item) - def dump_many(self, items: Iterable[Any]) -> list[dict[str, Any]]: + def dump_many(self, items: "Iterable[Any]") -> list[dict[str, Any]]: return [self._dump(item) for item in items] def to_arrow( - self, items: Iterable[Any], *, return_format: ArrowReturnFormat = "table", batch_size: int | None = None + self, items: "Iterable[Any]", *, return_format: "ArrowReturnFormat" = "table", batch_size: int | None = None ) -> Any: payload = self.dump_many(items) return convert_dict_to_arrow(payload, return_format=return_format, batch_size=batch_size) @@ -220,28 +270,7 @@ def to_arrow( _SERIALIZER_LOCK: RLock = RLock() _SCHEMA_SERIALIZERS: dict[tuple[type[Any] | None, bool], SchemaSerializer] = {} - - -def _is_dataclass_instance(value: Any) -> bool: - return dataclasses.is_dataclass(value) and not isinstance(value, type) - - -def _is_pydantic_model(value: Any) -> bool: - if not PYDANTIC_INSTALLED: - return False - return hasattr(value, "model_dump") - - -def _is_msgspec_struct(value: Any) -> bool: - if not MSGSPEC_INSTALLED: - return False - return hasattr(value, "__struct_fields__") - - -def _is_attrs_instance(value: Any) -> bool: - if not ATTRS_INSTALLED: - return False - return attrs_has(type(value)) +_SERIALIZER_METRICS = _SerializerCacheMetrics() def _make_serializer_key(sample: Any, exclude_unset: bool) -> tuple[type[Any] | None, bool]: @@ -250,23 +279,23 @@ def _make_serializer_key(sample: Any, exclude_unset: bool) -> tuple[type[Any] | return (type(sample), exclude_unset) -def _build_dump_function(sample: Any, exclude_unset: bool) -> Callable[[Any], dict[str, Any]]: +def _build_dump_function(sample: Any, exclude_unset: bool) -> "Callable[[Any], dict[str, Any]]": if sample is None or isinstance(sample, dict): return lambda value: cast("dict[str, Any]", value) - if _is_dataclass_instance(sample): + if is_dataclass_instance(sample): def _dump_dataclass(value: Any) -> dict[str, Any]: - return dataclasses.asdict(value) + return dataclass_to_dict(value, exclude_empty=exclude_unset) return _dump_dataclass - if _is_pydantic_model(sample): + if is_pydantic_model(sample): def _dump_pydantic(value: Any) -> dict[str, Any]: return cast("dict[str, Any]", value.model_dump(exclude_unset=exclude_unset)) return _dump_pydantic - if _is_msgspec_struct(sample): + if is_msgspec_struct(sample): if exclude_unset: def _dump(value: Any) -> dict[str, Any]: @@ -276,7 +305,7 @@ def _dump(value: Any) -> dict[str, Any]: return lambda value: {f: getattr(value, f, None) for f in value.__struct_fields__} - if _is_attrs_instance(sample): + if is_attrs_instance(sample): def _dump_attrs(value: Any) -> dict[str, Any]: return attrs_asdict(value, recurse=True) @@ -293,29 +322,31 @@ def _dump_dict_attr(value: Any) -> dict[str, Any]: return lambda value: dict(value) -def get_collection_serializer(sample: Any, *, exclude_unset: bool = True) -> SchemaSerializer: +def get_collection_serializer(sample: Any, *, exclude_unset: bool = True) -> "SchemaSerializer": """Return cached serializer pipeline for the provided sample 
object.""" key = _make_serializer_key(sample, exclude_unset) with _SERIALIZER_LOCK: pipeline = _SCHEMA_SERIALIZERS.get(key) if pipeline is not None: + _SERIALIZER_METRICS.record_hit(len(_SCHEMA_SERIALIZERS)) return pipeline dump = _build_dump_function(sample, exclude_unset) pipeline = SchemaSerializer(key, dump) _SCHEMA_SERIALIZERS[key] = pipeline + _SERIALIZER_METRICS.record_miss(len(_SCHEMA_SERIALIZERS)) return pipeline -def serialize_collection(items: Iterable[Any], *, exclude_unset: bool = True) -> list[Any]: +def serialize_collection(items: "Iterable[Any]", *, exclude_unset: bool = True) -> list[Any]: """Serialize a collection using cached pipelines keyed by item type.""" serialized: list[Any] = [] cache: dict[tuple[type[Any] | None, bool], SchemaSerializer] = {} for item in items: - if isinstance(item, (str, bytes, int, float, bool)) or item is None or isinstance(item, dict): + if isinstance(item, _PRIMITIVE_TYPES) or item is None or isinstance(item, dict): serialized.append(item) continue @@ -333,13 +364,16 @@ def reset_serializer_cache() -> None: with _SERIALIZER_LOCK: _SCHEMA_SERIALIZERS.clear() + _SERIALIZER_METRICS.reset() -def get_serializer_metrics() -> dict[str, Hashable]: - """Return metrics describing the serializer pipeline cache.""" +def get_serializer_metrics() -> dict[str, int]: + """Return cache metrics aligned with the core pipeline counters.""" with _SERIALIZER_LOCK: - return {"size": len(_SCHEMA_SERIALIZERS)} + metrics = _SERIALIZER_METRICS.snapshot() + metrics["size"] = len(_SCHEMA_SERIALIZERS) + return metrics def schema_dump(data: Any, *, exclude_unset: bool = True) -> dict[str, Any]: @@ -352,5 +386,11 @@ def schema_dump(data: Any, *, exclude_unset: bool = True) -> dict[str, Any]: Returns: A plain dictionary representation of the schema model. 
""" + if is_dict(data): + return data + + if isinstance(data, _PRIMITIVE_TYPES) or data is None: + return cast("dict[str, Any]", data) + serializer = get_collection_serializer(data, exclude_unset=exclude_unset) return serializer.dump_one(data) diff --git a/tests/unit/test_utils/test_type_guards.py b/tests/unit/test_utils/test_type_guards.py index a592058c..62b1bdc6 100644 --- a/tests/unit/test_utils/test_type_guards.py +++ b/tests/unit/test_utils/test_type_guards.py @@ -794,6 +794,13 @@ def test_schema_dump_with_dict() -> None: assert result is data +def test_schema_dump_with_primitives() -> None: + """Test schema_dump returns primitive payload unchanged.""" + payload = "primary" + result = schema_dump(payload) + assert result == payload + + def test_schema_dump_with_dataclass() -> None: """Test schema_dump converts dataclass to dict.""" instance = SampleDataclass(name="test", age=25) @@ -841,6 +848,20 @@ def test_serializer_pipeline_reuses_entry() -> None: assert pipeline is same_pipeline +def test_serializer_metrics_track_hits_and_misses(monkeypatch: pytest.MonkeyPatch) -> None: + reset_serializer_cache() + monkeypatch.setenv("SQLSPEC_DEBUG_PIPELINE_CACHE", "1") + + sample = _SerializerRecord(identifier=1, name="instrumented") + get_collection_serializer(sample) + metrics = get_serializer_metrics() + assert metrics["misses"] == 1 + + get_collection_serializer(sample) + metrics = get_serializer_metrics() + assert metrics["hits"] == 1 + + def test_serialize_collection_mixed_models() -> None: items = [_SerializerRecord(identifier=1, name="alpha"), {"identifier": 2, "name": "beta"}] serialized = serialize_collection(items) From 8ae61c08a63c4d18cd75bedadd8d7b29de946524 Mon Sep 17 00:00:00 2001 From: Cody Fincher Date: Mon, 3 Nov 2025 21:39:23 +0000 Subject: [PATCH 5/7] test: update asyncpg_batch_session fixture to truncate table within the script --- .../test_adapters/test_asyncpg/test_execute_many.py | 5 ++--- tests/unit/test_utils/test_type_guards.py | 2 +- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/tests/integration/test_adapters/test_asyncpg/test_execute_many.py b/tests/integration/test_adapters/test_asyncpg/test_execute_many.py index 61ee67a8..45256d5d 100644 --- a/tests/integration/test_adapters/test_asyncpg/test_execute_many.py +++ b/tests/integration/test_adapters/test_asyncpg/test_execute_many.py @@ -21,11 +21,10 @@ async def asyncpg_batch_session(asyncpg_async_driver: AsyncpgDriver) -> "AsyncGe name TEXT NOT NULL, value INTEGER DEFAULT 0, category TEXT - ) + ); + TRUNCATE TABLE test_batch RESTART IDENTITY """ ) - await asyncpg_async_driver.execute_script("TRUNCATE TABLE test_batch RESTART IDENTITY") - try: yield asyncpg_async_driver finally: diff --git a/tests/unit/test_utils/test_type_guards.py b/tests/unit/test_utils/test_type_guards.py index 62b1bdc6..13e37cbf 100644 --- a/tests/unit/test_utils/test_type_guards.py +++ b/tests/unit/test_utils/test_type_guards.py @@ -798,7 +798,7 @@ def test_schema_dump_with_primitives() -> None: """Test schema_dump returns primitive payload unchanged.""" payload = "primary" result = schema_dump(payload) - assert result == payload + assert result == payload # type: ignore[comparison-overlap] def test_schema_dump_with_dataclass() -> None: From 80e81faf309ebb0a10c36b4f7ef045bf74f80567 Mon Sep 17 00:00:00 2001 From: Cody Fincher Date: Mon, 3 Nov 2025 21:58:43 +0000 Subject: [PATCH 6/7] fix: update error handling in FSSpec S3 tests and correct numeric assertions in JSON deserialization --- .../test_storage/test_storage_integration.py 
| 3 ++- tests/unit/test_utils/test_serializers.py | 25 +++++++++++++------ 2 files changed, 19 insertions(+), 9 deletions(-) diff --git a/tests/integration/test_storage/test_storage_integration.py b/tests/integration/test_storage/test_storage_integration.py index 01745e87..e54bd6d5 100644 --- a/tests/integration/test_storage/test_storage_integration.py +++ b/tests/integration/test_storage/test_storage_integration.py @@ -561,6 +561,7 @@ def test_local_backend_error_handling(tmp_path: Path) -> None: @pytest.mark.skipif(not FSSPEC_INSTALLED, reason="fsspec not installed") def test_fsspec_s3_error_handling(minio_service: "MinioService", minio_default_bucket_name: str) -> None: """Test FSSpec S3 backend error handling.""" + from sqlspec.exceptions import FileNotFoundInStorageError from sqlspec.storage.backends.fsspec import FSSpecBackend backend = FSSpecBackend.from_config({ @@ -574,7 +575,7 @@ def test_fsspec_s3_error_handling(minio_service: "MinioService", minio_default_b }) # Test reading nonexistent file - with pytest.raises(FileNotFoundError): + with pytest.raises(FileNotFoundInStorageError): backend.read_text("nonexistent.txt") diff --git a/tests/unit/test_utils/test_serializers.py b/tests/unit/test_utils/test_serializers.py index bee53a40..cafbc80d 100644 --- a/tests/unit/test_utils/test_serializers.py +++ b/tests/unit/test_utils/test_serializers.py @@ -128,7 +128,7 @@ def test_from_json_basic_types() -> None: assert from_json('"hello"') == "hello" assert from_json("42") == 42 - assert from_json("3.14") == 3.14 + assert from_json("3.14") == math.pi assert from_json("true") is True assert from_json("false") is False @@ -192,7 +192,7 @@ def test_from_json_numeric_edge_cases() -> None: assert from_json("9223372036854775807") == 9223372036854775807 assert from_json("-42") == -42 - assert from_json("-3.14") == -3.14 + assert from_json("-3.14") == -math.pi assert from_json("0") == 0 assert from_json("0.0") == 0.0 @@ -450,12 +450,21 @@ def test_module_all_exports() -> None: """Test that __all__ contains the expected exports.""" from sqlspec.utils.serializers import __all__ - assert "from_json" in __all__ - assert "to_json" in __all__ - assert "numpy_array_enc_hook" in __all__ - assert "numpy_array_dec_hook" in __all__ - assert "numpy_array_predicate" in __all__ - assert len(__all__) == 5 + expected = { + "SchemaSerializer", + "from_json", + "get_collection_serializer", + "get_serializer_metrics", + "numpy_array_dec_hook", + "numpy_array_enc_hook", + "numpy_array_predicate", + "reset_serializer_cache", + "schema_dump", + "serialize_collection", + "to_json", + } + + assert set(__all__) == expected def test_error_messages_are_helpful() -> None: From 580b4d805adf56bd30969e9573cc04d6a411780e Mon Sep 17 00:00:00 2001 From: Cody Fincher Date: Mon, 3 Nov 2025 22:22:30 +0000 Subject: [PATCH 7/7] fix: update numeric assertions in JSON deserialization tests to use pytest.approx for better precision --- tests/unit/test_utils/test_serializers.py | 4 ++-- uv.lock | 18 +++++++++--------- 2 files changed, 11 insertions(+), 11 deletions(-) diff --git a/tests/unit/test_utils/test_serializers.py b/tests/unit/test_utils/test_serializers.py index cafbc80d..53fdfa3c 100644 --- a/tests/unit/test_utils/test_serializers.py +++ b/tests/unit/test_utils/test_serializers.py @@ -128,7 +128,7 @@ def test_from_json_basic_types() -> None: assert from_json('"hello"') == "hello" assert from_json("42") == 42 - assert from_json("3.14") == math.pi + assert from_json("3.14") == pytest.approx(3.14) assert from_json("true") is True 
assert from_json("false") is False @@ -192,7 +192,7 @@ def test_from_json_numeric_edge_cases() -> None: assert from_json("9223372036854775807") == 9223372036854775807 assert from_json("-42") == -42 - assert from_json("-3.14") == -math.pi + assert from_json("-3.14") == pytest.approx(-3.14) assert from_json("0") == 0 assert from_json("0.0") == 0.0 diff --git a/uv.lock b/uv.lock index 6759942b..8c59ec3b 100644 --- a/uv.lock +++ b/uv.lock @@ -971,11 +971,11 @@ wheels = [ [[package]] name = "cloudpickle" -version = "3.1.1" +version = "3.1.2" source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/52/39/069100b84d7418bc358d81669d5748efb14b9cceacd2f9c75f550424132f/cloudpickle-3.1.1.tar.gz", hash = "sha256:b216fa8ae4019d5482a8ac3c95d8f6346115d8835911fd4aefd1a445e4242c64", size = 22113, upload-time = "2025-01-14T17:02:05.085Z" } +sdist = { url = "https://files.pythonhosted.org/packages/27/fb/576f067976d320f5f0114a8d9fa1215425441bb35627b1993e5afd8111e5/cloudpickle-3.1.2.tar.gz", hash = "sha256:7fda9eb655c9c230dab534f1983763de5835249750e85fbcef43aaa30a9a2414", size = 22330, upload-time = "2025-11-03T09:25:26.604Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/7e/e8/64c37fadfc2816a7701fa8a6ed8d87327c7d54eacfbfb6edab14a2f2be75/cloudpickle-3.1.1-py3-none-any.whl", hash = "sha256:c8c5a44295039331ee9dad40ba100a9c7297b6f988e50e87ccdf3765a668350e", size = 20992, upload-time = "2025-01-14T17:02:02.417Z" }, + { url = "https://files.pythonhosted.org/packages/88/39/799be3f2f0f38cc727ee3b4f1445fe6d5e4133064ec2e4115069418a5bb6/cloudpickle-3.1.2-py3-none-any.whl", hash = "sha256:9acb47f6afd73f60dc1df93bb801b472f05ff42fa6c84167d25cb206be1fbf4a", size = 22228, upload-time = "2025-11-03T09:25:25.534Z" }, ] [[package]] @@ -1326,7 +1326,7 @@ wheels = [ [[package]] name = "fastapi" -version = "0.120.4" +version = "0.121.0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "annotated-doc" }, @@ -1334,9 +1334,9 @@ dependencies = [ { name = "starlette" }, { name = "typing-extensions" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/3f/3a/0bf90d5189d7f62dc2bd0523899629ca59b58ff4290d631cd3bb5c8889d4/fastapi-0.120.4.tar.gz", hash = "sha256:2d856bc847893ca4d77896d4504ffdec0fb04312b705065fca9104428eca3868", size = 339716, upload-time = "2025-10-31T18:37:28.81Z" } +sdist = { url = "https://files.pythonhosted.org/packages/8c/e3/77a2df0946703973b9905fd0cde6172c15e0781984320123b4f5079e7113/fastapi-0.121.0.tar.gz", hash = "sha256:06663356a0b1ee93e875bbf05a31fb22314f5bed455afaaad2b2dad7f26e98fa", size = 342412, upload-time = "2025-11-03T10:25:54.818Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/ed/47/14a76b926edc3957c8a8258423db789d3fa925d2fed800102fce58959413/fastapi-0.120.4-py3-none-any.whl", hash = "sha256:9bdf192308676480d3593e10fd05094e56d6fdc7d9283db26053d8104d5f82a0", size = 108235, upload-time = "2025-10-31T18:37:27.038Z" }, + { url = "https://files.pythonhosted.org/packages/dd/2c/42277afc1ba1a18f8358561eee40785d27becab8f80a1f945c0a3051c6eb/fastapi-0.121.0-py3-none-any.whl", hash = "sha256:8bdf1b15a55f4e4b0d6201033da9109ea15632cb76cf156e7b8b4019f2172106", size = 109183, upload-time = "2025-11-03T10:25:53.27Z" }, ] [[package]] @@ -2007,7 +2007,7 @@ wheels = [ [[package]] name = "google-genai" -version = "1.47.0" +version = "1.48.0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "anyio" }, @@ -2019,9 +2019,9 @@ dependencies = [ { name = "typing-extensions" }, { name = 
"websockets" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/9f/97/784fba9bc6c41263ff90cb9063eadfdd755dde79cfa5a8d0e397b067dcf9/google_genai-1.47.0.tar.gz", hash = "sha256:ecece00d0a04e6739ea76cc8dad82ec9593d9380aaabef078990e60574e5bf59", size = 241471, upload-time = "2025-10-29T22:01:02.88Z" } +sdist = { url = "https://files.pythonhosted.org/packages/c4/40/e8d4b60e45fb2c8f8e1cd5e52e29741d207ce844303e69b3546d06627ced/google_genai-1.48.0.tar.gz", hash = "sha256:d78fe33125a881461be5cb008564b1d73f309cd6b390d328c68fe706b142acea", size = 242952, upload-time = "2025-11-03T17:31:07.569Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/89/ef/e080e8d67c270ea320956bb911a9359664fc46d3b87d1f029decd33e5c4c/google_genai-1.47.0-py3-none-any.whl", hash = "sha256:e3851237556cbdec96007d8028b4b1f2425cdc5c099a8dc36b72a57e42821b60", size = 241506, upload-time = "2025-10-29T22:01:00.982Z" }, + { url = "https://files.pythonhosted.org/packages/25/7d/a5c02159099546ec01131059294d1f0174eee33a872fc888c6c99e8cd6d9/google_genai-1.48.0-py3-none-any.whl", hash = "sha256:919c1e96948a565e27b5b2a1d23f32865d9647e2236a8ffe1ca999a4922bf887", size = 242904, upload-time = "2025-11-03T17:31:05.465Z" }, ] [[package]]