diff --git a/sqlspec/builder/__init__.py b/sqlspec/builder/__init__.py index 866158a2..8293087b 100644 --- a/sqlspec/builder/__init__.py +++ b/sqlspec/builder/__init__.py @@ -40,7 +40,13 @@ MathExpression, StringExpression, ) -from sqlspec.builder._factory import SQLFactory, sql +from sqlspec.builder._factory import ( + SQLFactory, + build_copy_from_statement, + build_copy_statement, + build_copy_to_statement, + sql, +) from sqlspec.builder._insert import Insert from sqlspec.builder._join import JoinBuilder from sqlspec.builder._merge import Merge @@ -127,6 +133,9 @@ "UpdateTableClauseMixin", "WhereClauseMixin", "WindowFunctionBuilder", + "build_copy_from_statement", + "build_copy_statement", + "build_copy_to_statement", "extract_expression", "parse_column_expression", "parse_condition_expression", diff --git a/sqlspec/builder/_factory.py b/sqlspec/builder/_factory.py index d33b4748..50cebedf 100644 --- a/sqlspec/builder/_factory.py +++ b/sqlspec/builder/_factory.py @@ -4,7 +4,8 @@ """ import logging -from typing import TYPE_CHECKING, Any, Union +from collections.abc import Mapping, Sequence +from typing import TYPE_CHECKING, Any, Union, cast import sqlglot from sqlglot import exp @@ -46,6 +47,8 @@ from sqlspec.exceptions import SQLBuilderError if TYPE_CHECKING: + from collections.abc import Mapping, Sequence + from sqlspec.builder._expression_wrappers import ExpressionWrapper @@ -73,6 +76,9 @@ "Truncate", "Update", "WindowFunctionBuilder", + "build_copy_from_statement", + "build_copy_statement", + "build_copy_to_statement", "sql", ) @@ -108,6 +114,96 @@ } +def _normalize_copy_dialect(dialect: DialectType | None) -> str: + if dialect is None: + return "postgres" + if isinstance(dialect, str): + return dialect + return str(dialect) + + +def _to_copy_schema(table: str, columns: "Sequence[str] | None") -> exp.Expression: + base = exp.table_(table) + if not columns: + return base + column_nodes = [exp.column(column_name) for column_name in columns] + return exp.Schema(this=base, expressions=column_nodes) + + +def _build_copy_expression( + *, direction: str, table: str, location: str, columns: "Sequence[str] | None", options: "Mapping[str, Any] | None" +) -> exp.Copy: + copy_args: dict[str, Any] = {"this": _to_copy_schema(table, columns), "files": [exp.Literal.string(location)]} + + if direction == "from": + copy_args["kind"] = True + elif direction == "to": + copy_args["kind"] = False + + if options: + params: list[exp.CopyParameter] = [] + for key, value in options.items(): + identifier = exp.Var(this=str(key).upper()) + value_expression: exp.Expression + if isinstance(value, bool): + value_expression = exp.Boolean(this=value) + elif value is None: + value_expression = exp.null() + elif isinstance(value, (int, float)): + value_expression = exp.Literal.number(value) + elif isinstance(value, (list, tuple)): + elements = [exp.Literal.string(str(item)) for item in value] + value_expression = exp.Array(expressions=elements) + else: + value_expression = exp.Literal.string(str(value)) + params.append(exp.CopyParameter(this=identifier, expression=value_expression)) + copy_args["params"] = params + + return exp.Copy(**copy_args) + + +def build_copy_statement( + *, + direction: str, + table: str, + location: str, + columns: "Sequence[str] | None" = None, + options: "Mapping[str, Any] | None" = None, + dialect: DialectType | None = None, +) -> SQL: + expression = _build_copy_expression( + direction=direction, table=table, location=location, columns=columns, options=options + ) + rendered = 
expression.sql(dialect=_normalize_copy_dialect(dialect)) + return SQL(rendered) + + +def build_copy_from_statement( + table: str, + source: str, + *, + columns: "Sequence[str] | None" = None, + options: "Mapping[str, Any] | None" = None, + dialect: DialectType | None = None, +) -> SQL: + return build_copy_statement( + direction="from", table=table, location=source, columns=columns, options=options, dialect=dialect + ) + + +def build_copy_to_statement( + table: str, + target: str, + *, + columns: "Sequence[str] | None" = None, + options: "Mapping[str, Any] | None" = None, + dialect: DialectType | None = None, +) -> SQL: + return build_copy_statement( + direction="to", table=table, location=target, columns=columns, options=options, dialect=dialect + ) + + class SQLFactory: """Factory for creating SQL builders and column expressions.""" @@ -479,6 +575,56 @@ def comment_on(self, dialect: DialectType = None) -> "CommentOn": """ return CommentOn(dialect=dialect or self.dialect) + def copy_from( + self, + table: str, + source: str, + *, + columns: "Sequence[str] | None" = None, + options: "Mapping[str, Any] | None" = None, + dialect: DialectType | None = None, + ) -> SQL: + """Build a COPY ... FROM statement.""" + + effective_dialect = dialect or self.dialect + return build_copy_from_statement(table, source, columns=columns, options=options, dialect=effective_dialect) + + def copy_to( + self, + table: str, + target: str, + *, + columns: "Sequence[str] | None" = None, + options: "Mapping[str, Any] | None" = None, + dialect: DialectType | None = None, + ) -> SQL: + """Build a COPY ... TO statement.""" + + effective_dialect = dialect or self.dialect + return build_copy_to_statement(table, target, columns=columns, options=options, dialect=effective_dialect) + + def copy( + self, + table: str, + *, + source: str | None = None, + target: str | None = None, + columns: "Sequence[str] | None" = None, + options: "Mapping[str, Any] | None" = None, + dialect: DialectType | None = None, + ) -> SQL: + """Build a COPY statement, inferring direction from provided arguments.""" + + if (source is None and target is None) or (source is not None and target is not None): + msg = "Provide either 'source' or 'target' (but not both) to sql.copy()." + raise SQLBuilderError(msg) + + if source is not None: + return self.copy_from(table, source, columns=columns, options=options, dialect=dialect) + + target_value = cast("str", target) + return self.copy_to(table, target_value, columns=columns, options=options, dialect=dialect) + @staticmethod def _looks_like_sql(candidate: str, expected_type: str | None = None) -> bool: """Determine if a string looks like SQL. 
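For reviewers, a minimal usage sketch of the COPY helpers added above, mirroring the expectations in tests/unit/test_builder/test_copy_helpers.py later in this diff; the rendered SQL shown in the comments comes from those tests and may vary with the sqlglot dialect (postgres is the default).

    from sqlspec.builder import build_copy_from_statement, sql

    # Standalone helper: COPY ... FROM with an explicit column list and options.
    load_stmt = build_copy_from_statement(
        "public.users",
        "s3://bucket/data.parquet",
        columns=("id", "name"),
        options={"format": "parquet"},
    )
    print(load_stmt.sql)
    # COPY "public.users" (id, name) FROM 's3://bucket/data.parquet' WITH (FORMAT 'parquet')

    # Factory entry point: the direction is inferred from source= vs. target=;
    # passing both (or neither) raises SQLBuilderError.
    export_stmt = sql.copy("users", target="s3://bucket/out.csv", options={"format": "csv", "header": True})
    print(export_stmt.sql)
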
diff --git a/sqlspec/loader.py b/sqlspec/loader.py index 8a1fee96..2d679f2a 100644 --- a/sqlspec/loader.py +++ b/sqlspec/loader.py @@ -13,7 +13,12 @@ from urllib.parse import unquote, urlparse from sqlspec.core import SQL, get_cache, get_cache_config -from sqlspec.exceptions import SQLFileNotFoundError, SQLFileParseError, StorageOperationFailedError +from sqlspec.exceptions import ( + FileNotFoundInStorageError, + SQLFileNotFoundError, + SQLFileParseError, + StorageOperationFailedError, +) from sqlspec.storage.registry import storage_registry as default_storage_registry from sqlspec.utils.correlation import CorrelationContext from sqlspec.utils.logging import get_logger @@ -259,9 +264,11 @@ def _read_file_content(self, path: str | Path) -> str: return backend.read_text(path_str, encoding=self.encoding) except KeyError as e: raise SQLFileNotFoundError(path_str) from e + except FileNotFoundInStorageError as e: + raise SQLFileNotFoundError(path_str) from e + except FileNotFoundError as e: + raise SQLFileNotFoundError(path_str) from e except StorageOperationFailedError as e: - if "not found" in str(e).lower() or "no such file" in str(e).lower(): - raise SQLFileNotFoundError(path_str) from e raise SQLFileParseError(path_str, path_str, e) from e except Exception as e: raise SQLFileParseError(path_str, path_str, e) from e diff --git a/sqlspec/storage/_utils.py b/sqlspec/storage/_utils.py index 02359fea..a98905a5 100644 --- a/sqlspec/storage/_utils.py +++ b/sqlspec/storage/_utils.py @@ -1,12 +1,44 @@ """Shared utilities for storage backends.""" from pathlib import Path +from typing import Any, Final -__all__ = ("resolve_storage_path",) +from sqlspec.utils.module_loader import ensure_pyarrow + +FILE_PROTOCOL: Final[str] = "file" +FILE_SCHEME_PREFIX: Final[str] = "file://" + +__all__ = ("FILE_PROTOCOL", "FILE_SCHEME_PREFIX", "import_pyarrow", "import_pyarrow_parquet", "resolve_storage_path") + + +def import_pyarrow() -> "Any": + """Import PyArrow with optional dependency guard. + + Returns: + PyArrow module. + """ + + ensure_pyarrow() + import pyarrow as pa + + return pa + + +def import_pyarrow_parquet() -> "Any": + """Import PyArrow parquet module with optional dependency guard. + + Returns: + PyArrow parquet module. + """ + + ensure_pyarrow() + import pyarrow.parquet as pq + + return pq def resolve_storage_path( - path: "str | Path", base_path: str = "", protocol: str = "file", strip_file_scheme: bool = True + path: "str | Path", base_path: str = "", protocol: str = FILE_PROTOCOL, strip_file_scheme: bool = True ) -> str: """Resolve path relative to base_path with protocol-specific handling. 
@@ -43,10 +75,10 @@ def resolve_storage_path( path_str = str(path) - if strip_file_scheme and path_str.startswith("file://"): - path_str = path_str.removeprefix("file://") + if strip_file_scheme and path_str.startswith(FILE_SCHEME_PREFIX): + path_str = path_str.removeprefix(FILE_SCHEME_PREFIX) - if protocol == "file": + if protocol == FILE_PROTOCOL: path_obj = Path(path_str) if path_obj.is_absolute(): diff --git a/sqlspec/storage/backends/fsspec.py b/sqlspec/storage/backends/fsspec.py index 21df5cee..3f0b0fae 100644 --- a/sqlspec/storage/backends/fsspec.py +++ b/sqlspec/storage/backends/fsspec.py @@ -1,12 +1,13 @@ # pyright: reportPrivateUsage=false import logging from pathlib import Path -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, cast from mypy_extensions import mypyc_attr -from sqlspec.storage._utils import resolve_storage_path -from sqlspec.utils.module_loader import ensure_fsspec, ensure_pyarrow +from sqlspec.storage._utils import import_pyarrow_parquet, resolve_storage_path +from sqlspec.storage.errors import execute_sync_storage_operation +from sqlspec.utils.module_loader import ensure_fsspec from sqlspec.utils.sync_tools import async_ if TYPE_CHECKING: @@ -158,7 +159,15 @@ def base_uri(self) -> str: def read_bytes(self, path: str | Path, **kwargs: Any) -> bytes: """Read bytes from an object.""" resolved_path = resolve_storage_path(path, self.base_path, self.protocol, strip_file_scheme=False) - return self.fs.cat(resolved_path, **kwargs) # type: ignore[no-any-return] # pyright: ignore + return cast( + "bytes", + execute_sync_storage_operation( + lambda: self.fs.cat(resolved_path, **kwargs), + backend=self.backend_type, + operation="read_bytes", + path=resolved_path, + ), + ) def write_bytes(self, path: str | Path, data: bytes, **kwargs: Any) -> None: """Write bytes to an object.""" @@ -169,8 +178,11 @@ def write_bytes(self, path: str | Path, data: bytes, **kwargs: Any) -> None: if parent_dir and not self.fs.exists(parent_dir): self.fs.makedirs(parent_dir, exist_ok=True) - with self.fs.open(resolved_path, mode="wb", **kwargs) as f: - f.write(data) # pyright: ignore + def _action() -> None: + with self.fs.open(resolved_path, mode="wb", **kwargs) as file_obj: + file_obj.write(data) # pyright: ignore + + execute_sync_storage_operation(_action, backend=self.backend_type, operation="write_bytes", path=resolved_path) def read_text(self, path: str | Path, encoding: str = "utf-8", **kwargs: Any) -> str: """Read text from an object.""" @@ -189,37 +201,65 @@ def exists(self, path: str | Path, **kwargs: Any) -> bool: def delete(self, path: str | Path, **kwargs: Any) -> None: """Delete an object.""" resolved_path = resolve_storage_path(path, self.base_path, self.protocol, strip_file_scheme=False) - self.fs.rm(resolved_path, **kwargs) + execute_sync_storage_operation( + lambda: self.fs.rm(resolved_path, **kwargs), + backend=self.backend_type, + operation="delete", + path=resolved_path, + ) def copy(self, source: str | Path, destination: str | Path, **kwargs: Any) -> None: """Copy an object.""" source_path = resolve_storage_path(source, self.base_path, self.protocol, strip_file_scheme=False) dest_path = resolve_storage_path(destination, self.base_path, self.protocol, strip_file_scheme=False) - self.fs.copy(source_path, dest_path, **kwargs) + execute_sync_storage_operation( + lambda: self.fs.copy(source_path, dest_path, **kwargs), + backend=self.backend_type, + operation="copy", + path=f"{source_path}->{dest_path}", + ) def move(self, source: str | Path, 
destination: str | Path, **kwargs: Any) -> None: """Move an object.""" source_path = resolve_storage_path(source, self.base_path, self.protocol, strip_file_scheme=False) dest_path = resolve_storage_path(destination, self.base_path, self.protocol, strip_file_scheme=False) - self.fs.mv(source_path, dest_path, **kwargs) + execute_sync_storage_operation( + lambda: self.fs.mv(source_path, dest_path, **kwargs), + backend=self.backend_type, + operation="move", + path=f"{source_path}->{dest_path}", + ) def read_arrow(self, path: str | Path, **kwargs: Any) -> "ArrowTable": """Read an Arrow table from storage.""" - ensure_pyarrow() - import pyarrow.parquet as pq + pq = import_pyarrow_parquet() resolved_path = resolve_storage_path(path, self.base_path, self.protocol, strip_file_scheme=False) - with self.fs.open(resolved_path, mode="rb", **kwargs) as f: - return pq.read_table(f) + return cast( + "ArrowTable", + execute_sync_storage_operation( + lambda: self._read_arrow(resolved_path, pq, kwargs), + backend=self.backend_type, + operation="read_arrow", + path=resolved_path, + ), + ) def write_arrow(self, path: str | Path, table: "ArrowTable", **kwargs: Any) -> None: """Write an Arrow table to storage.""" - ensure_pyarrow() - import pyarrow.parquet as pq + pq = import_pyarrow_parquet() resolved_path = resolve_storage_path(path, self.base_path, self.protocol, strip_file_scheme=False) - with self.fs.open(resolved_path, mode="wb") as f: - pq.write_table(table, f, **kwargs) # pyright: ignore + + def _action() -> None: + with self.fs.open(resolved_path, mode="wb") as file_obj: + pq.write_table(table, file_obj, **kwargs) # pyright: ignore + + execute_sync_storage_operation(_action, backend=self.backend_type, operation="write_arrow", path=resolved_path) + + def _read_arrow(self, resolved_path: str, pq: Any, options: "dict[str, Any]") -> Any: + with self.fs.open(resolved_path, mode="rb", **options) as file_obj: + return pq.read_table(file_obj) def list_objects(self, prefix: str = "", recursive: bool = True, **kwargs: Any) -> list[str]: """List objects with optional prefix.""" @@ -273,15 +313,23 @@ def sign(self, path: str, expires_in: int = 3600, for_upload: bool = False) -> s return f"{self._fs_uri}{resolved_path}" def _stream_file_batches(self, obj_path: str | Path) -> "Iterator[ArrowRecordBatch]": - import pyarrow.parquet as pq - - with self.fs.open(obj_path, mode="rb") as f: - parquet_file = pq.ParquetFile(f) # pyright: ignore[reportArgumentType] + pq = import_pyarrow_parquet() + + file_handle = execute_sync_storage_operation( + lambda: self.fs.open(obj_path, mode="rb"), + backend=self.backend_type, + operation="stream_open", + path=str(obj_path), + ) + + with file_handle as stream: + parquet_file = execute_sync_storage_operation( + lambda: pq.ParquetFile(stream), backend=self.backend_type, operation="stream_arrow", path=str(obj_path) + ) yield from parquet_file.iter_batches() def stream_arrow(self, pattern: str, **kwargs: Any) -> "Iterator[ArrowRecordBatch]": - ensure_pyarrow() - + import_pyarrow_parquet() for obj_path in self.glob(pattern, **kwargs): yield from self._stream_file_batches(obj_path) @@ -303,8 +351,6 @@ def stream_arrow_async(self, pattern: str, **kwargs: Any) -> "AsyncIterator[Arro Returns: AsyncIterator of Arrow record batches """ - ensure_pyarrow() - return _ArrowStreamer(self, pattern, **kwargs) async def read_text_async(self, path: str | Path, encoding: str = "utf-8", **kwargs: Any) -> str: diff --git a/sqlspec/storage/backends/local.py b/sqlspec/storage/backends/local.py index 
6ffc0411..88567e40 100644 --- a/sqlspec/storage/backends/local.py +++ b/sqlspec/storage/backends/local.py @@ -6,13 +6,16 @@ import shutil from collections.abc import AsyncIterator, Iterator +from functools import partial from pathlib import Path -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, cast from urllib.parse import unquote, urlparse from mypy_extensions import mypyc_attr -from sqlspec.utils.module_loader import ensure_pyarrow +from sqlspec.exceptions import FileNotFoundInStorageError +from sqlspec.storage._utils import import_pyarrow_parquet +from sqlspec.storage.errors import execute_sync_storage_operation from sqlspec.utils.sync_tools import async_ if TYPE_CHECKING: @@ -107,23 +110,32 @@ def _resolve_path(self, path: "str | Path") -> Path: def read_bytes(self, path: "str | Path", **kwargs: Any) -> bytes: """Read bytes from file.""" resolved = self._resolve_path(path) - return resolved.read_bytes() + try: + return execute_sync_storage_operation( + lambda: resolved.read_bytes(), backend=self.backend_type, operation="read_bytes", path=str(resolved) + ) + except FileNotFoundInStorageError as error: + raise FileNotFoundError(str(resolved)) from error def write_bytes(self, path: "str | Path", data: bytes, **kwargs: Any) -> None: """Write bytes to file.""" resolved = self._resolve_path(path) - resolved.parent.mkdir(parents=True, exist_ok=True) - resolved.write_bytes(data) + + def _action() -> None: + resolved.parent.mkdir(parents=True, exist_ok=True) + resolved.write_bytes(data) + + execute_sync_storage_operation(_action, backend=self.backend_type, operation="write_bytes", path=str(resolved)) def read_text(self, path: "str | Path", encoding: str = "utf-8", **kwargs: Any) -> str: """Read text from file.""" - return self._resolve_path(path).read_text(encoding=encoding) + data = self.read_bytes(path, **kwargs) + return data.decode(encoding) def write_text(self, path: "str | Path", data: str, encoding: str = "utf-8", **kwargs: Any) -> None: """Write text to file.""" - resolved = self._resolve_path(path) - resolved.parent.mkdir(parents=True, exist_ok=True) - resolved.write_text(data, encoding=encoding) + encoded = data.encode(encoding) + self.write_bytes(path, encoded, **kwargs) def list_objects(self, prefix: str = "", recursive: bool = True, **kwargs: Any) -> list[str]: """List objects in directory.""" @@ -163,10 +175,14 @@ def exists(self, path: "str | Path", **kwargs: Any) -> bool: def delete(self, path: "str | Path", **kwargs: Any) -> None: """Delete file or directory.""" resolved = self._resolve_path(path) - if resolved.is_dir(): - shutil.rmtree(resolved) - elif resolved.exists(): - resolved.unlink() + + def _action() -> None: + if resolved.is_dir(): + shutil.rmtree(resolved) + elif resolved.exists(): + resolved.unlink() + + execute_sync_storage_operation(_action, backend=self.backend_type, operation="delete", path=str(resolved)) def copy(self, source: "str | Path", destination: "str | Path", **kwargs: Any) -> None: """Copy file or directory.""" @@ -174,17 +190,22 @@ def copy(self, source: "str | Path", destination: "str | Path", **kwargs: Any) - dst = self._resolve_path(destination) dst.parent.mkdir(parents=True, exist_ok=True) - if src.is_dir(): - shutil.copytree(src, dst, dirs_exist_ok=True) - else: - shutil.copy2(src, dst) + def _action() -> None: + if src.is_dir(): + shutil.copytree(src, dst, dirs_exist_ok=True) + else: + shutil.copy2(src, dst) + + execute_sync_storage_operation(_action, backend=self.backend_type, operation="copy", 
path=f"{src}->{dst}") def move(self, source: "str | Path", destination: "str | Path", **kwargs: Any) -> None: """Move file or directory.""" src = self._resolve_path(source) dst = self._resolve_path(destination) dst.parent.mkdir(parents=True, exist_ok=True) - shutil.move(str(src), str(dst)) + execute_sync_storage_operation( + lambda: shutil.move(str(src), str(dst)), backend=self.backend_type, operation="move", path=f"{src}->{dst}" + ) def glob(self, pattern: str, **kwargs: Any) -> list[str]: """Find files matching pattern.""" @@ -210,6 +231,14 @@ def glob(self, pattern: str, **kwargs: Any) -> list[str]: def get_metadata(self, path: "str | Path", **kwargs: Any) -> dict[str, Any]: """Get file metadata.""" resolved = self._resolve_path(path) + return execute_sync_storage_operation( + lambda: self._collect_metadata(resolved), + backend=self.backend_type, + operation="get_metadata", + path=str(resolved), + ) + + def _collect_metadata(self, resolved: "Path") -> dict[str, Any]: if not resolved.exists(): return {} @@ -233,19 +262,28 @@ def is_path(self, path: "str | Path") -> bool: def read_arrow(self, path: "str | Path", **kwargs: Any) -> "ArrowTable": """Read Arrow table from file.""" - ensure_pyarrow() - import pyarrow.parquet as pq - - return pq.read_table(str(self._resolve_path(path))) # pyright: ignore + pq = import_pyarrow_parquet() + resolved = self._resolve_path(path) + return cast( + "ArrowTable", + execute_sync_storage_operation( + lambda: pq.read_table(str(resolved), **kwargs), + backend=self.backend_type, + operation="read_arrow", + path=str(resolved), + ), + ) def write_arrow(self, path: "str | Path", table: "ArrowTable", **kwargs: Any) -> None: """Write Arrow table to file.""" - ensure_pyarrow() - import pyarrow.parquet as pq - + pq = import_pyarrow_parquet() resolved = self._resolve_path(path) - resolved.parent.mkdir(parents=True, exist_ok=True) - pq.write_table(table, str(resolved)) # pyright: ignore + + def _action() -> None: + resolved.parent.mkdir(parents=True, exist_ok=True) + pq.write_table(table, str(resolved), **kwargs) # pyright: ignore + + execute_sync_storage_operation(_action, backend=self.backend_type, operation="write_arrow", path=str(resolved)) def stream_arrow(self, pattern: str, **kwargs: Any) -> Iterator["ArrowRecordBatch"]: """Stream Arrow record batches from files matching pattern. @@ -253,19 +291,21 @@ def stream_arrow(self, pattern: str, **kwargs: Any) -> Iterator["ArrowRecordBatc Yields: Arrow record batches from matching files. 
""" - ensure_pyarrow() - import pyarrow.parquet as pq - + pq = import_pyarrow_parquet() files = self.glob(pattern) for file_path in files: resolved = self._resolve_path(file_path) - parquet_file = pq.ParquetFile(str(resolved)) + resolved_str = str(resolved) + parquet_file = execute_sync_storage_operation( + partial(pq.ParquetFile, resolved_str), + backend=self.backend_type, + operation="stream_arrow", + path=resolved_str, + ) yield from parquet_file.iter_batches() # pyright: ignore def sign(self, path: "str | Path", expires_in: int = 3600, for_upload: bool = False) -> str: """Generate a signed URL (returns file:// URI for local files).""" - # For local files, just return a file:// URI - # No actual signing needed for local files return self._resolve_path(path).as_uri() # Async methods using sync_tools.async_ diff --git a/sqlspec/storage/backends/obstore.py b/sqlspec/storage/backends/obstore.py index 375010e2..0e825a79 100644 --- a/sqlspec/storage/backends/obstore.py +++ b/sqlspec/storage/backends/obstore.py @@ -7,7 +7,9 @@ import fnmatch import io import logging +import re from collections.abc import AsyncIterator, Iterator +from functools import partial from pathlib import Path, PurePosixPath from typing import Any, Final, cast from urllib.parse import urlparse @@ -15,9 +17,10 @@ from mypy_extensions import mypyc_attr from sqlspec.exceptions import StorageOperationFailedError -from sqlspec.storage._utils import resolve_storage_path +from sqlspec.storage._utils import import_pyarrow, import_pyarrow_parquet, resolve_storage_path +from sqlspec.storage.errors import execute_sync_storage_operation from sqlspec.typing import ArrowRecordBatch, ArrowTable -from sqlspec.utils.module_loader import ensure_obstore, ensure_pyarrow +from sqlspec.utils.module_loader import ensure_obstore from sqlspec.utils.sync_tools import async_ __all__ = ("ObStoreBackend",) @@ -46,7 +49,7 @@ def __aiter__(self) -> "_AsyncArrowIterator": return self async def __anext__(self) -> ArrowRecordBatch: - import pyarrow.parquet as pq + pq = import_pyarrow_parquet() if self._files_iterator is None: files = self.backend.glob(self.pattern, **self.kwargs) @@ -196,8 +199,13 @@ def read_bytes(self, path: "str | Path", **kwargs: Any) -> bytes: # pyright: ig else: resolved_path = resolve_storage_path(path, self.base_path, self.protocol, strip_file_scheme=True) - result = self.store.get(resolved_path) - return cast("bytes", result.bytes().to_bytes()) + def _action() -> bytes: + result = self.store.get(resolved_path) + return cast("bytes", result.bytes().to_bytes()) + + return execute_sync_storage_operation( + _action, backend=self.backend_type, operation="read_bytes", path=resolved_path + ) def write_bytes(self, path: "str | Path", data: bytes, **kwargs: Any) -> None: # pyright: ignore[reportUnusedParameter] """Write bytes using obstore.""" @@ -205,7 +213,13 @@ def write_bytes(self, path: "str | Path", data: bytes, **kwargs: Any) -> None: resolved_path = self._resolve_path_for_local_store(path) else: resolved_path = resolve_storage_path(path, self.base_path, self.protocol, strip_file_scheme=True) - self.store.put(resolved_path, data) + + execute_sync_storage_operation( + lambda: self.store.put(resolved_path, data), + backend=self.backend_type, + operation="write_bytes", + path=resolved_path, + ) def read_text(self, path: "str | Path", encoding: str = "utf-8", **kwargs: Any) -> str: """Read text using obstore.""" @@ -240,19 +254,31 @@ def exists(self, path: "str | Path", **kwargs: Any) -> bool: # pyright: ignore[ def delete(self, path: 
"str | Path", **kwargs: Any) -> None: # pyright: ignore[reportUnusedParameter] """Delete object using obstore.""" resolved_path = resolve_storage_path(path, self.base_path, self.protocol, strip_file_scheme=True) - self.store.delete(resolved_path) + execute_sync_storage_operation( + lambda: self.store.delete(resolved_path), backend=self.backend_type, operation="delete", path=resolved_path + ) def copy(self, source: "str | Path", destination: "str | Path", **kwargs: Any) -> None: # pyright: ignore[reportUnusedParameter] """Copy object using obstore.""" source_path = resolve_storage_path(source, self.base_path, self.protocol, strip_file_scheme=True) dest_path = resolve_storage_path(destination, self.base_path, self.protocol, strip_file_scheme=True) - self.store.copy(source_path, dest_path) + execute_sync_storage_operation( + lambda: self.store.copy(source_path, dest_path), + backend=self.backend_type, + operation="copy", + path=f"{source_path}->{dest_path}", + ) def move(self, source: "str | Path", destination: "str | Path", **kwargs: Any) -> None: # pyright: ignore[reportUnusedParameter] """Move object using obstore.""" source_path = resolve_storage_path(source, self.base_path, self.protocol, strip_file_scheme=True) dest_path = resolve_storage_path(destination, self.base_path, self.protocol, strip_file_scheme=True) - self.store.rename(source_path, dest_path) + execute_sync_storage_operation( + lambda: self.store.rename(source_path, dest_path), + backend=self.backend_type, + operation="move", + path=f"{source_path}->{dest_path}", + ) def glob(self, pattern: str, **kwargs: Any) -> list[str]: """Find objects matching pattern. @@ -338,19 +364,23 @@ def is_path(self, path: "str | Path") -> bool: def read_arrow(self, path: "str | Path", **kwargs: Any) -> ArrowTable: """Read Arrow table using obstore.""" - ensure_pyarrow() - import pyarrow.parquet as pq - + pq = import_pyarrow_parquet() resolved_path = resolve_storage_path(path, self.base_path, self.protocol, strip_file_scheme=True) data = self.read_bytes(resolved_path) - return pq.read_table(io.BytesIO(data), **kwargs) + return cast( + "ArrowTable", + execute_sync_storage_operation( + lambda: pq.read_table(io.BytesIO(data), **kwargs), + backend=self.backend_type, + operation="read_arrow", + path=resolved_path, + ), + ) def write_arrow(self, path: "str | Path", table: ArrowTable, **kwargs: Any) -> None: """Write Arrow table using obstore.""" - ensure_pyarrow() - import pyarrow as pa - import pyarrow.parquet as pq - + pa = import_pyarrow() + pq = import_pyarrow_parquet() resolved_path = resolve_storage_path(path, self.base_path, self.protocol, strip_file_scheme=True) schema = table.schema @@ -358,8 +388,6 @@ def write_arrow(self, path: "str | Path", table: ArrowTable, **kwargs: Any) -> N new_fields = [] for field in schema: if str(field.type).startswith("decimal64"): - import re - match = re.match(r"decimal64\((\d+),\s*(\d+)\)", str(field.type)) if match: precision, scale = int(match.group(1)), int(match.group(2)) @@ -371,7 +399,12 @@ def write_arrow(self, path: "str | Path", table: ArrowTable, **kwargs: Any) -> N table = table.cast(pa.schema(new_fields)) buffer = io.BytesIO() - pq.write_table(table, buffer, **kwargs) + execute_sync_storage_operation( + lambda: pq.write_table(table, buffer, **kwargs), + backend=self.backend_type, + operation="write_arrow", + path=resolved_path, + ) buffer.seek(0) self.write_bytes(resolved_path, buffer.read()) @@ -381,16 +414,21 @@ def stream_arrow(self, pattern: str, **kwargs: Any) -> Iterator[ArrowRecordBatch 
Yields: Iterator of Arrow record batches from matching objects. """ - ensure_pyarrow() - import pyarrow.parquet as pq - + pq = import_pyarrow_parquet() for obj_path in self.glob(pattern, **kwargs): resolved_path = resolve_storage_path(obj_path, self.base_path, self.protocol, strip_file_scheme=True) - result = self.store.get(resolved_path) + result = execute_sync_storage_operation( + partial(self.store.get, resolved_path), + backend=self.backend_type, + operation="stream_read", + path=resolved_path, + ) bytes_obj = result.bytes() data = bytes_obj.to_bytes() buffer = io.BytesIO(data) - parquet_file = pq.ParquetFile(buffer) + parquet_file = execute_sync_storage_operation( + partial(pq.ParquetFile, buffer), backend=self.backend_type, operation="stream_arrow", path=resolved_path + ) yield from parquet_file.iter_batches() def sign(self, path: str, expires_in: int = 3600, for_upload: bool = False) -> str: @@ -518,18 +556,14 @@ async def get_metadata_async(self, path: "str | Path", **kwargs: Any) -> dict[st async def read_arrow_async(self, path: "str | Path", **kwargs: Any) -> ArrowTable: """Read Arrow table from storage asynchronously.""" - ensure_pyarrow() - import pyarrow.parquet as pq - + pq = import_pyarrow_parquet() resolved_path = resolve_storage_path(path, self.base_path, self.protocol, strip_file_scheme=True) data = await self.read_bytes_async(resolved_path) - return pq.read_table(io.BytesIO(data), **kwargs) + return cast("ArrowTable", pq.read_table(io.BytesIO(data), **kwargs)) async def write_arrow_async(self, path: "str | Path", table: ArrowTable, **kwargs: Any) -> None: """Write Arrow table to storage asynchronously.""" - ensure_pyarrow() - import pyarrow.parquet as pq - + pq = import_pyarrow_parquet() resolved_path = resolve_storage_path(path, self.base_path, self.protocol, strip_file_scheme=True) buffer = io.BytesIO() pq.write_table(table, buffer, **kwargs) diff --git a/sqlspec/storage/errors.py b/sqlspec/storage/errors.py new file mode 100644 index 00000000..c1b81bd2 --- /dev/null +++ b/sqlspec/storage/errors.py @@ -0,0 +1,104 @@ +"""Storage error normalization helpers.""" + +import errno +import logging +from typing import TYPE_CHECKING, Any, TypeVar + +from sqlspec.exceptions import FileNotFoundInStorageError, StorageOperationFailedError + +if TYPE_CHECKING: + from collections.abc import Awaitable, Callable, Mapping + +__all__ = ("StorageError", "execute_async_storage_operation", "execute_sync_storage_operation", "raise_storage_error") + + +logger = logging.getLogger(__name__) + +T = TypeVar("T") + + +_NOT_FOUND_NAMES = {"NotFoundError", "ObjectNotFound", "NoSuchKey", "NoSuchBucket", "NoSuchFile"} + + +class StorageError: + """Normalized view of a storage backend exception.""" + + __slots__ = ("backend", "message", "operation", "original", "path", "retryable") + + def __init__( + self, message: str, backend: str, operation: str, path: str | None, retryable: bool, original: Exception + ) -> None: + self.message = message + self.backend = backend + self.operation = operation + self.path = path + self.retryable = retryable + self.original = original + + +def _is_missing_error(error: Exception) -> bool: + if isinstance(error, FileNotFoundError): + return True + + number = getattr(error, "errno", None) + if number in {errno.ENOENT, errno.ENOTDIR}: + return True + + name = error.__class__.__name__ + return name in _NOT_FOUND_NAMES + + +def _is_retryable(error: Exception) -> bool: + if isinstance(error, (ConnectionError, TimeoutError)): + return True + + name = error.__class__.__name__ + 
return bool("Timeout" in name or "Temporary" in name) + + +def _normalize_storage_error(error: Exception, *, backend: str, operation: str, path: str | None) -> "StorageError": + message = f"{backend} {operation} failed" + if path: + message = f"{message} for {path}" + + return StorageError( + message=message, backend=backend, operation=operation, path=path, retryable=_is_retryable(error), original=error + ) + + +def raise_storage_error(error: Exception, *, backend: str, operation: str, path: str | None) -> None: + is_missing = _is_missing_error(error) + normalized = _normalize_storage_error(error, backend=backend, operation=operation, path=path) + + log_extra: Mapping[str, Any] = { + "storage_backend": backend, + "storage_operation": operation, + "storage_path": path, + "exception_type": error.__class__.__name__, + "retryable": normalized.retryable, + } + + if is_missing: + logger.debug("Storage object missing during %s", operation, extra=log_extra) + raise FileNotFoundInStorageError(normalized.message) from error + + logger.warning("Storage operation %s failed", operation, extra=log_extra, exc_info=error) + raise StorageOperationFailedError(normalized.message) from error + + +def execute_sync_storage_operation(func: "Callable[[], T]", *, backend: str, operation: str, path: str | None) -> T: + try: + return func() + except Exception as error: + raise_storage_error(error, backend=backend, operation=operation, path=path) + raise + + +async def execute_async_storage_operation( + func: "Callable[[], Awaitable[T]]", *, backend: str, operation: str, path: str | None +) -> T: + try: + return await func() + except Exception as error: + raise_storage_error(error, backend=backend, operation=operation, path=path) + raise diff --git a/sqlspec/utils/fixtures.py b/sqlspec/utils/fixtures.py index 628303af..60d54446 100644 --- a/sqlspec/utils/fixtures.py +++ b/sqlspec/utils/fixtures.py @@ -11,9 +11,9 @@ from sqlspec.storage import storage_registry from sqlspec.utils.serializers import from_json as decode_json +from sqlspec.utils.serializers import schema_dump from sqlspec.utils.serializers import to_json as encode_json from sqlspec.utils.sync_tools import async_ -from sqlspec.utils.type_guards import schema_dump if TYPE_CHECKING: from sqlspec.typing import SupportedSchemaModel @@ -146,14 +146,17 @@ def _serialize_data(data: Any) -> str: """ if isinstance(data, (list, tuple)): serialized_items: list[Any] = [] + for item in data: if isinstance(item, (str, int, float, bool, type(None))): serialized_items.append(item) else: serialized_items.append(schema_dump(item)) + return encode_json(serialized_items) if isinstance(data, (str, int, float, bool, type(None))): return encode_json(data) + return encode_json(schema_dump(data)) diff --git a/sqlspec/utils/serializers.py b/sqlspec/utils/serializers.py index 35fac632..7cb13826 100644 --- a/sqlspec/utils/serializers.py +++ b/sqlspec/utils/serializers.py @@ -1,16 +1,94 @@ -"""JSON serialization utilities for SQLSpec. +"""Serialization utilities for SQLSpec. -Re-exports common JSON encoding and decoding functions from the core -serialization module for convenient access. - -Provides NumPy array serialization hooks for framework integrations -that support custom type encoders and decoders (e.g., Litestar). +Provides JSON helpers, serializer pipelines, optional dependency hooks, +and cache instrumentation aligned with the core pipeline counters. 
""" -from typing import Any, Literal, overload +import os +from threading import RLock +from typing import TYPE_CHECKING, Any, Final, Literal, cast, overload from sqlspec._serialization import decode_json, encode_json -from sqlspec.typing import NUMPY_INSTALLED +from sqlspec.typing import NUMPY_INSTALLED, UNSET, ArrowReturnFormat, attrs_asdict +from sqlspec.utils.arrow_helpers import convert_dict_to_arrow +from sqlspec.utils.type_guards import ( + dataclass_to_dict, + has_dict_attribute, + is_attrs_instance, + is_dataclass_instance, + is_dict, + is_msgspec_struct, + is_pydantic_model, +) + +if TYPE_CHECKING: + from collections.abc import Callable, Iterable + +__all__ = ( + "SchemaSerializer", + "from_json", + "get_collection_serializer", + "get_serializer_metrics", + "numpy_array_dec_hook", + "numpy_array_enc_hook", + "numpy_array_predicate", + "reset_serializer_cache", + "schema_dump", + "serialize_collection", + "to_json", +) + +DEBUG_ENV_FLAG: Final[str] = "SQLSPEC_DEBUG_PIPELINE_CACHE" +_PRIMITIVE_TYPES: Final[tuple[type[Any], ...]] = (str, bytes, int, float, bool) + + +def _is_truthy(value: "str | None") -> bool: + if value is None: + return False + normalized = value.strip().lower() + return normalized in {"1", "true", "yes", "on"} + + +def _metrics_enabled() -> bool: + return _is_truthy(os.getenv(DEBUG_ENV_FLAG)) + + +class _SerializerCacheMetrics: + __slots__ = ("hits", "max_size", "misses", "size") + + def __init__(self) -> None: + self.hits = 0 + self.misses = 0 + self.size = 0 + self.max_size = 0 + + def record_hit(self, cache_size: int) -> None: + if not _metrics_enabled(): + return + self.hits += 1 + self.size = cache_size + self.max_size = max(self.max_size, cache_size) + + def record_miss(self, cache_size: int) -> None: + if not _metrics_enabled(): + return + self.misses += 1 + self.size = cache_size + self.max_size = max(self.max_size, cache_size) + + def reset(self) -> None: + self.hits = 0 + self.misses = 0 + self.size = 0 + self.max_size = 0 + + def snapshot(self) -> "dict[str, int]": + return { + "hits": self.hits if _metrics_enabled() else 0, + "misses": self.misses if _metrics_enabled() else 0, + "max_size": self.max_size if _metrics_enabled() else 0, + "size": self.size if _metrics_enabled() else 0, + } @overload @@ -93,7 +171,7 @@ def numpy_array_enc_hook(value: Any) -> Any: return value -def numpy_array_dec_hook(value: Any) -> "Any": +def numpy_array_dec_hook(value: Any) -> Any: """Decode list to NumPy array. Converts Python lists to NumPy arrays when appropriate. 
@@ -164,4 +242,155 @@ def numpy_array_predicate(value: Any) -> bool: return isinstance(value, np.ndarray) -__all__ = ("from_json", "numpy_array_dec_hook", "numpy_array_enc_hook", "numpy_array_predicate", "to_json") +class SchemaSerializer: + """Serializer pipeline that caches conversions for repeated schema dumps.""" + + __slots__ = ("_dump", "_key") + + def __init__(self, key: "tuple[type[Any] | None, bool]", dump: "Callable[[Any], dict[str, Any]]") -> None: + self._key = key + self._dump = dump + + @property + def key(self) -> "tuple[type[Any] | None, bool]": + return self._key + + def dump_one(self, item: Any) -> dict[str, Any]: + return self._dump(item) + + def dump_many(self, items: "Iterable[Any]") -> list[dict[str, Any]]: + return [self._dump(item) for item in items] + + def to_arrow( + self, items: "Iterable[Any]", *, return_format: "ArrowReturnFormat" = "table", batch_size: int | None = None + ) -> Any: + payload = self.dump_many(items) + return convert_dict_to_arrow(payload, return_format=return_format, batch_size=batch_size) + + +_SERIALIZER_LOCK: RLock = RLock() +_SCHEMA_SERIALIZERS: dict[tuple[type[Any] | None, bool], SchemaSerializer] = {} +_SERIALIZER_METRICS = _SerializerCacheMetrics() + + +def _make_serializer_key(sample: Any, exclude_unset: bool) -> tuple[type[Any] | None, bool]: + if sample is None or isinstance(sample, dict): + return (None, exclude_unset) + return (type(sample), exclude_unset) + + +def _build_dump_function(sample: Any, exclude_unset: bool) -> "Callable[[Any], dict[str, Any]]": + if sample is None or isinstance(sample, dict): + return lambda value: cast("dict[str, Any]", value) + + if is_dataclass_instance(sample): + + def _dump_dataclass(value: Any) -> dict[str, Any]: + return dataclass_to_dict(value, exclude_empty=exclude_unset) + + return _dump_dataclass + if is_pydantic_model(sample): + + def _dump_pydantic(value: Any) -> dict[str, Any]: + return cast("dict[str, Any]", value.model_dump(exclude_unset=exclude_unset)) + + return _dump_pydantic + if is_msgspec_struct(sample): + if exclude_unset: + + def _dump(value: Any) -> dict[str, Any]: + return {f: val for f in value.__struct_fields__ if (val := getattr(value, f, None)) != UNSET} + + return _dump + + return lambda value: {f: getattr(value, f, None) for f in value.__struct_fields__} + + if is_attrs_instance(sample): + + def _dump_attrs(value: Any) -> dict[str, Any]: + return attrs_asdict(value, recurse=True) + + return _dump_attrs + + if has_dict_attribute(sample): + + def _dump_dict_attr(value: Any) -> dict[str, Any]: + return dict(value.__dict__) + + return _dump_dict_attr + + return lambda value: dict(value) + + +def get_collection_serializer(sample: Any, *, exclude_unset: bool = True) -> "SchemaSerializer": + """Return cached serializer pipeline for the provided sample object.""" + + key = _make_serializer_key(sample, exclude_unset) + with _SERIALIZER_LOCK: + pipeline = _SCHEMA_SERIALIZERS.get(key) + if pipeline is not None: + _SERIALIZER_METRICS.record_hit(len(_SCHEMA_SERIALIZERS)) + return pipeline + + dump = _build_dump_function(sample, exclude_unset) + pipeline = SchemaSerializer(key, dump) + _SCHEMA_SERIALIZERS[key] = pipeline + _SERIALIZER_METRICS.record_miss(len(_SCHEMA_SERIALIZERS)) + return pipeline + + +def serialize_collection(items: "Iterable[Any]", *, exclude_unset: bool = True) -> list[Any]: + """Serialize a collection using cached pipelines keyed by item type.""" + + serialized: list[Any] = [] + cache: dict[tuple[type[Any] | None, bool], SchemaSerializer] = {} + + for item in 
items: + if isinstance(item, _PRIMITIVE_TYPES) or item is None or isinstance(item, dict): + serialized.append(item) + continue + + key = _make_serializer_key(item, exclude_unset) + pipeline = cache.get(key) + if pipeline is None: + pipeline = get_collection_serializer(item, exclude_unset=exclude_unset) + cache[key] = pipeline + serialized.append(pipeline.dump_one(item)) + return serialized + + +def reset_serializer_cache() -> None: + """Clear cached serializer pipelines.""" + + with _SERIALIZER_LOCK: + _SCHEMA_SERIALIZERS.clear() + _SERIALIZER_METRICS.reset() + + +def get_serializer_metrics() -> dict[str, int]: + """Return cache metrics aligned with the core pipeline counters.""" + + with _SERIALIZER_LOCK: + metrics = _SERIALIZER_METRICS.snapshot() + metrics["size"] = len(_SCHEMA_SERIALIZERS) + return metrics + + +def schema_dump(data: Any, *, exclude_unset: bool = True) -> dict[str, Any]: + """Dump a schema model or dict to a plain dictionary. + + Args: + data: Schema model instance or dictionary to dump. + exclude_unset: Whether to exclude unset fields (for models that support it). + + Returns: + A plain dictionary representation of the schema model. + """ + if is_dict(data): + return data + + if isinstance(data, _PRIMITIVE_TYPES) or data is None: + return cast("dict[str, Any]", data) + + serializer = get_collection_serializer(data, exclude_unset=exclude_unset) + return serializer.dump_one(data) diff --git a/sqlspec/utils/type_guards.py b/sqlspec/utils/type_guards.py index 81c4094e..eaaf9627 100644 --- a/sqlspec/utils/type_guards.py +++ b/sqlspec/utils/type_guards.py @@ -20,7 +20,6 @@ DataclassProtocol, DTOData, Struct, - attrs_asdict, attrs_has, ) @@ -123,7 +122,6 @@ "is_string_literal", "is_typed_dict", "is_typed_parameter", - "schema_dump", "supports_arrow_native", "supports_arrow_results", "supports_limit", @@ -828,36 +826,6 @@ def dataclass_to_dict( return cast("dict[str, Any]", ret) -def schema_dump(data: Any, exclude_unset: bool = True) -> "dict[str, Any]": - """Dump a data object to a dictionary. - - Args: - data: :type:`dict[str, Any]` | :class:`DataclassProtocol` | :class:`msgspec.Struct` | :class:`pydantic.BaseModel` | :class:`AttrsInstance` - exclude_unset: :type:`bool` Whether to exclude unset values. 
- - Returns: - :type:`dict[str, Any]` - """ - from sqlspec._typing import UNSET - - if is_dict(data): - return data - if is_dataclass(data): - return dataclass_to_dict(data, exclude_empty=exclude_unset) - if is_pydantic_model(data): - return data.model_dump(exclude_unset=exclude_unset) # type: ignore[no-any-return] - if is_msgspec_struct(data): - if exclude_unset: - return {f: val for f in data.__struct_fields__ if (val := getattr(data, f, None)) != UNSET} - return {f: getattr(data, f, None) for f in data.__struct_fields__} - if is_attrs_instance(data): - return attrs_asdict(data) - - if has_dict_attribute(data): - return data.__dict__ - return cast("dict[str, Any]", data) - - def can_extract_parameters(obj: Any) -> "TypeGuard[FilterParameterProtocol]": """Check if an object can extract parameters.""" from sqlspec.protocols import FilterParameterProtocol diff --git a/tests/integration/test_adapters/test_asyncpg/test_execute_many.py b/tests/integration/test_adapters/test_asyncpg/test_execute_many.py index 61ee67a8..45256d5d 100644 --- a/tests/integration/test_adapters/test_asyncpg/test_execute_many.py +++ b/tests/integration/test_adapters/test_asyncpg/test_execute_many.py @@ -21,11 +21,10 @@ async def asyncpg_batch_session(asyncpg_async_driver: AsyncpgDriver) -> "AsyncGe name TEXT NOT NULL, value INTEGER DEFAULT 0, category TEXT - ) + ); + TRUNCATE TABLE test_batch RESTART IDENTITY """ ) - await asyncpg_async_driver.execute_script("TRUNCATE TABLE test_batch RESTART IDENTITY") - try: yield asyncpg_async_driver finally: diff --git a/tests/integration/test_storage/test_storage_integration.py b/tests/integration/test_storage/test_storage_integration.py index 01745e87..e54bd6d5 100644 --- a/tests/integration/test_storage/test_storage_integration.py +++ b/tests/integration/test_storage/test_storage_integration.py @@ -561,6 +561,7 @@ def test_local_backend_error_handling(tmp_path: Path) -> None: @pytest.mark.skipif(not FSSPEC_INSTALLED, reason="fsspec not installed") def test_fsspec_s3_error_handling(minio_service: "MinioService", minio_default_bucket_name: str) -> None: """Test FSSpec S3 backend error handling.""" + from sqlspec.exceptions import FileNotFoundInStorageError from sqlspec.storage.backends.fsspec import FSSpecBackend backend = FSSpecBackend.from_config({ @@ -574,7 +575,7 @@ def test_fsspec_s3_error_handling(minio_service: "MinioService", minio_default_b }) # Test reading nonexistent file - with pytest.raises(FileNotFoundError): + with pytest.raises(FileNotFoundInStorageError): backend.read_text("nonexistent.txt") diff --git a/tests/unit/test_builder/test_copy_helpers.py b/tests/unit/test_builder/test_copy_helpers.py new file mode 100644 index 00000000..11d0edd2 --- /dev/null +++ b/tests/unit/test_builder/test_copy_helpers.py @@ -0,0 +1,43 @@ +from sqlglot import parse_one + +from sqlspec.builder import build_copy_from_statement, build_copy_to_statement, sql +from sqlspec.core import SQL + + +def test_build_copy_from_statement_generates_expected_sql() -> None: + statement = build_copy_from_statement( + "public.users", "s3://bucket/data.parquet", columns=("id", "name"), options={"format": "parquet"} + ) + + assert isinstance(statement, SQL) + rendered = statement.sql + assert rendered == "COPY \"public.users\" (id, name) FROM 's3://bucket/data.parquet' WITH (FORMAT 'parquet')" + + expression = parse_one(rendered, read="postgres") + assert expression.args["kind"] is True + + +def test_build_copy_to_statement_generates_expected_sql() -> None: + statement = build_copy_to_statement( + 
"public.users", "s3://bucket/output.parquet", options={"format": "parquet", "compression": "gzip"} + ) + + assert isinstance(statement, SQL) + rendered = statement.sql + assert rendered == ( + "COPY \"public.users\" TO 's3://bucket/output.parquet' WITH (FORMAT 'parquet', COMPRESSION 'gzip')" + ) + + expression = parse_one(rendered, read="postgres") + assert expression.args["kind"] is False + + +def test_sql_factory_copy_helpers() -> None: + statement = sql.copy_from("users", "s3://bucket/in.csv", columns=("id", "name"), options={"format": "csv"}) + assert isinstance(statement, SQL) + assert statement.sql.startswith("COPY users") + + to_statement = sql.copy("users", target="s3://bucket/out.csv", options={"format": "csv", "header": True}) + assert isinstance(to_statement, SQL) + parsed = parse_one(to_statement.sql, read="postgres") + assert parsed.args["kind"] is False diff --git a/tests/unit/test_storage/test_errors.py b/tests/unit/test_storage/test_errors.py new file mode 100644 index 00000000..2fcc5ec9 --- /dev/null +++ b/tests/unit/test_storage/test_errors.py @@ -0,0 +1,32 @@ +import pytest + +from sqlspec.exceptions import FileNotFoundInStorageError, StorageOperationFailedError +from sqlspec.storage.errors import ( + _normalize_storage_error, # pyright: ignore + execute_sync_storage_operation, + raise_storage_error, +) + + +def test_raise_normalized_storage_error_for_missing_file() -> None: + with pytest.raises(FileNotFoundInStorageError) as excinfo: + raise_storage_error(FileNotFoundError("missing"), backend="local", operation="read_bytes", path="file.txt") + + assert "local read_bytes failed" in str(excinfo.value) + + +def test_normalize_storage_error_marks_retryable() -> None: + normalized = _normalize_storage_error( + TimeoutError("timed out"), backend="fsspec", operation="write_bytes", path="s3://bucket/key" + ) + assert normalized.retryable is True + + +def test_execute_with_normalized_errors_wraps_generic_failure() -> None: + def _boom() -> None: + raise RuntimeError("boom") + + with pytest.raises(StorageOperationFailedError) as excinfo: + execute_sync_storage_operation(_boom, backend="obstore", operation="write_bytes", path="object") + + assert "obstore write_bytes failed" in str(excinfo.value) diff --git a/tests/unit/test_utils/test_serializers.py b/tests/unit/test_utils/test_serializers.py index bee53a40..53fdfa3c 100644 --- a/tests/unit/test_utils/test_serializers.py +++ b/tests/unit/test_utils/test_serializers.py @@ -128,7 +128,7 @@ def test_from_json_basic_types() -> None: assert from_json('"hello"') == "hello" assert from_json("42") == 42 - assert from_json("3.14") == 3.14 + assert from_json("3.14") == pytest.approx(3.14) assert from_json("true") is True assert from_json("false") is False @@ -192,7 +192,7 @@ def test_from_json_numeric_edge_cases() -> None: assert from_json("9223372036854775807") == 9223372036854775807 assert from_json("-42") == -42 - assert from_json("-3.14") == -3.14 + assert from_json("-3.14") == pytest.approx(-3.14) assert from_json("0") == 0 assert from_json("0.0") == 0.0 @@ -450,12 +450,21 @@ def test_module_all_exports() -> None: """Test that __all__ contains the expected exports.""" from sqlspec.utils.serializers import __all__ - assert "from_json" in __all__ - assert "to_json" in __all__ - assert "numpy_array_enc_hook" in __all__ - assert "numpy_array_dec_hook" in __all__ - assert "numpy_array_predicate" in __all__ - assert len(__all__) == 5 + expected = { + "SchemaSerializer", + "from_json", + "get_collection_serializer", + 
"get_serializer_metrics", + "numpy_array_dec_hook", + "numpy_array_enc_hook", + "numpy_array_predicate", + "reset_serializer_cache", + "schema_dump", + "serialize_collection", + "to_json", + } + + assert set(__all__) == expected def test_error_messages_are_helpful() -> None: diff --git a/tests/unit/test_utils/test_type_guards.py b/tests/unit/test_utils/test_type_guards.py index 6ad9ee97..13e37cbf 100644 --- a/tests/unit/test_utils/test_type_guards.py +++ b/tests/unit/test_utils/test_type_guards.py @@ -12,6 +12,14 @@ from sqlglot import exp from typing_extensions import TypedDict +from sqlspec.typing import PYARROW_INSTALLED +from sqlspec.utils.serializers import ( + get_collection_serializer, + get_serializer_metrics, + reset_serializer_cache, + schema_dump, + serialize_collection, +) from sqlspec.utils.type_guards import ( dataclass_to_dict, expression_has_limit, @@ -59,7 +67,6 @@ is_schema_without_field, is_string_literal, is_typed_dict, - schema_dump, ) pytestmark = pytest.mark.xdist_group("utils") @@ -84,6 +91,12 @@ class SampleTypedDict(TypedDict): optional_field: "str | None" +@dataclass +class _SerializerRecord: + identifier: int + name: str + + class MockSQLGlotExpression: """Mock SQLGlot expression for testing type guard functions. @@ -781,6 +794,13 @@ def test_schema_dump_with_dict() -> None: assert result is data +def test_schema_dump_with_primitives() -> None: + """Test schema_dump returns primitive payload unchanged.""" + payload = "primary" + result = schema_dump(payload) + assert result == payload # type: ignore[comparison-overlap] + + def test_schema_dump_with_dataclass() -> None: """Test schema_dump converts dataclass to dict.""" instance = SampleDataclass(name="test", age=25) @@ -814,6 +834,49 @@ def __init__(self) -> None: assert result == expected +def test_serializer_pipeline_reuses_entry() -> None: + reset_serializer_cache() + metrics = get_serializer_metrics() + assert metrics["size"] == 0 + + sample = _SerializerRecord(identifier=1, name="first") + pipeline = get_collection_serializer(sample) + metrics = get_serializer_metrics() + assert metrics["size"] == 1 + + same_pipeline = get_collection_serializer(_SerializerRecord(identifier=2, name="second")) + assert pipeline is same_pipeline + + +def test_serializer_metrics_track_hits_and_misses(monkeypatch: pytest.MonkeyPatch) -> None: + reset_serializer_cache() + monkeypatch.setenv("SQLSPEC_DEBUG_PIPELINE_CACHE", "1") + + sample = _SerializerRecord(identifier=1, name="instrumented") + get_collection_serializer(sample) + metrics = get_serializer_metrics() + assert metrics["misses"] == 1 + + get_collection_serializer(sample) + metrics = get_serializer_metrics() + assert metrics["hits"] == 1 + + +def test_serialize_collection_mixed_models() -> None: + items = [_SerializerRecord(identifier=1, name="alpha"), {"identifier": 2, "name": "beta"}] + serialized = serialize_collection(items) + assert serialized == [{"identifier": 1, "name": "alpha"}, {"identifier": 2, "name": "beta"}] + + +@pytest.mark.skipif(not PYARROW_INSTALLED, reason="PyArrow not installed") +def test_serializer_pipeline_arrow_conversion() -> None: + sample = _SerializerRecord(identifier=1, name="alpha") + pipeline = get_collection_serializer(sample) + table = pipeline.to_arrow([sample, _SerializerRecord(identifier=2, name="beta")]) + assert table.num_rows == 2 + assert table.column(0).to_pylist() == [1, 2] + + @pytest.mark.parametrize( "guard_func,test_obj,expected", [ diff --git a/uv.lock b/uv.lock index 6759942b..8c59ec3b 100644 --- a/uv.lock +++ b/uv.lock 
@@ -971,11 +971,11 @@ wheels = [ [[package]] name = "cloudpickle" -version = "3.1.1" +version = "3.1.2" source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/52/39/069100b84d7418bc358d81669d5748efb14b9cceacd2f9c75f550424132f/cloudpickle-3.1.1.tar.gz", hash = "sha256:b216fa8ae4019d5482a8ac3c95d8f6346115d8835911fd4aefd1a445e4242c64", size = 22113, upload-time = "2025-01-14T17:02:05.085Z" } +sdist = { url = "https://files.pythonhosted.org/packages/27/fb/576f067976d320f5f0114a8d9fa1215425441bb35627b1993e5afd8111e5/cloudpickle-3.1.2.tar.gz", hash = "sha256:7fda9eb655c9c230dab534f1983763de5835249750e85fbcef43aaa30a9a2414", size = 22330, upload-time = "2025-11-03T09:25:26.604Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/7e/e8/64c37fadfc2816a7701fa8a6ed8d87327c7d54eacfbfb6edab14a2f2be75/cloudpickle-3.1.1-py3-none-any.whl", hash = "sha256:c8c5a44295039331ee9dad40ba100a9c7297b6f988e50e87ccdf3765a668350e", size = 20992, upload-time = "2025-01-14T17:02:02.417Z" }, + { url = "https://files.pythonhosted.org/packages/88/39/799be3f2f0f38cc727ee3b4f1445fe6d5e4133064ec2e4115069418a5bb6/cloudpickle-3.1.2-py3-none-any.whl", hash = "sha256:9acb47f6afd73f60dc1df93bb801b472f05ff42fa6c84167d25cb206be1fbf4a", size = 22228, upload-time = "2025-11-03T09:25:25.534Z" }, ] [[package]] @@ -1326,7 +1326,7 @@ wheels = [ [[package]] name = "fastapi" -version = "0.120.4" +version = "0.121.0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "annotated-doc" }, @@ -1334,9 +1334,9 @@ dependencies = [ { name = "starlette" }, { name = "typing-extensions" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/3f/3a/0bf90d5189d7f62dc2bd0523899629ca59b58ff4290d631cd3bb5c8889d4/fastapi-0.120.4.tar.gz", hash = "sha256:2d856bc847893ca4d77896d4504ffdec0fb04312b705065fca9104428eca3868", size = 339716, upload-time = "2025-10-31T18:37:28.81Z" } +sdist = { url = "https://files.pythonhosted.org/packages/8c/e3/77a2df0946703973b9905fd0cde6172c15e0781984320123b4f5079e7113/fastapi-0.121.0.tar.gz", hash = "sha256:06663356a0b1ee93e875bbf05a31fb22314f5bed455afaaad2b2dad7f26e98fa", size = 342412, upload-time = "2025-11-03T10:25:54.818Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/ed/47/14a76b926edc3957c8a8258423db789d3fa925d2fed800102fce58959413/fastapi-0.120.4-py3-none-any.whl", hash = "sha256:9bdf192308676480d3593e10fd05094e56d6fdc7d9283db26053d8104d5f82a0", size = 108235, upload-time = "2025-10-31T18:37:27.038Z" }, + { url = "https://files.pythonhosted.org/packages/dd/2c/42277afc1ba1a18f8358561eee40785d27becab8f80a1f945c0a3051c6eb/fastapi-0.121.0-py3-none-any.whl", hash = "sha256:8bdf1b15a55f4e4b0d6201033da9109ea15632cb76cf156e7b8b4019f2172106", size = 109183, upload-time = "2025-11-03T10:25:53.27Z" }, ] [[package]] @@ -2007,7 +2007,7 @@ wheels = [ [[package]] name = "google-genai" -version = "1.47.0" +version = "1.48.0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "anyio" }, @@ -2019,9 +2019,9 @@ dependencies = [ { name = "typing-extensions" }, { name = "websockets" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/9f/97/784fba9bc6c41263ff90cb9063eadfdd755dde79cfa5a8d0e397b067dcf9/google_genai-1.47.0.tar.gz", hash = "sha256:ecece00d0a04e6739ea76cc8dad82ec9593d9380aaabef078990e60574e5bf59", size = 241471, upload-time = "2025-10-29T22:01:02.88Z" } +sdist = { url = 
"https://files.pythonhosted.org/packages/c4/40/e8d4b60e45fb2c8f8e1cd5e52e29741d207ce844303e69b3546d06627ced/google_genai-1.48.0.tar.gz", hash = "sha256:d78fe33125a881461be5cb008564b1d73f309cd6b390d328c68fe706b142acea", size = 242952, upload-time = "2025-11-03T17:31:07.569Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/89/ef/e080e8d67c270ea320956bb911a9359664fc46d3b87d1f029decd33e5c4c/google_genai-1.47.0-py3-none-any.whl", hash = "sha256:e3851237556cbdec96007d8028b4b1f2425cdc5c099a8dc36b72a57e42821b60", size = 241506, upload-time = "2025-10-29T22:01:00.982Z" }, + { url = "https://files.pythonhosted.org/packages/25/7d/a5c02159099546ec01131059294d1f0174eee33a872fc888c6c99e8cd6d9/google_genai-1.48.0-py3-none-any.whl", hash = "sha256:919c1e96948a565e27b5b2a1d23f32865d9647e2236a8ffe1ca999a4922bf887", size = 242904, upload-time = "2025-11-03T17:31:05.465Z" }, ] [[package]]