Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 10 additions & 1 deletion sqlspec/builder/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,13 @@
MathExpression,
StringExpression,
)
from sqlspec.builder._factory import SQLFactory, sql
from sqlspec.builder._factory import (
SQLFactory,
build_copy_from_statement,
build_copy_statement,
build_copy_to_statement,
sql,
)
from sqlspec.builder._insert import Insert
from sqlspec.builder._join import JoinBuilder
from sqlspec.builder._merge import Merge
Expand Down Expand Up @@ -127,6 +133,9 @@
"UpdateTableClauseMixin",
"WhereClauseMixin",
"WindowFunctionBuilder",
"build_copy_from_statement",
"build_copy_statement",
"build_copy_to_statement",
"extract_expression",
"parse_column_expression",
"parse_condition_expression",
Expand Down
148 changes: 147 additions & 1 deletion sqlspec/builder/_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@
"""

import logging
from typing import TYPE_CHECKING, Any, Union
from collections.abc import Mapping, Sequence
from typing import TYPE_CHECKING, Any, Union, cast

import sqlglot
from sqlglot import exp
Expand Down Expand Up @@ -46,6 +47,8 @@
from sqlspec.exceptions import SQLBuilderError

if TYPE_CHECKING:
from collections.abc import Mapping, Sequence

from sqlspec.builder._expression_wrappers import ExpressionWrapper


Expand Down Expand Up @@ -73,6 +76,9 @@
"Truncate",
"Update",
"WindowFunctionBuilder",
"build_copy_from_statement",
"build_copy_statement",
"build_copy_to_statement",
"sql",
)

Expand Down Expand Up @@ -108,6 +114,96 @@
}


def _normalize_copy_dialect(dialect: "DialectType | None") -> str:
    """Return the dialect name used to render a COPY statement.

    Args:
        dialect: Dialect supplied by the caller. May be ``None``, a dialect
            name, or a non-string dialect object (a class or an instance).

    Returns:
        ``"postgres"`` when no dialect is given, the string unchanged when a
        name is given, otherwise the lower-cased class name of the dialect
        object.
    """
    if dialect is None:
        return "postgres"
    if isinstance(dialect, str):
        return dialect
    # str() of a class or of a plain instance yields its repr
    # ("<class '...'>" / "<... object at 0x...>"), which is not a dialect name.
    # NOTE(review): assumes the dialect is registered under its lower-cased
    # class name, as sqlglot's built-in dialects are - confirm for custom ones.
    dialect_cls = dialect if isinstance(dialect, type) else type(dialect)
    return dialect_cls.__name__.lower()


def _to_copy_schema(table: str, columns: "Sequence[str] | None") -> exp.Expression:
    """Build the table node that a COPY statement operates on.

    Args:
        table: Table name, handed to ``exp.table_`` as-is.
        columns: Optional column names; ``None`` or an empty sequence means
            no column list.

    Returns:
        The bare table expression when no columns are given, otherwise an
        ``exp.Schema`` wrapping the table with one column node per name.
    """
    table_node = exp.table_(table)
    if columns:
        return exp.Schema(this=table_node, expressions=list(map(exp.column, columns)))
    return table_node


def _copy_option_value(value: Any) -> exp.Expression:
    """Convert one COPY option value into a sqlglot expression node."""
    # bool must be tested before int/float: bool is a subclass of int.
    if isinstance(value, bool):
        return exp.Boolean(this=value)
    if value is None:
        return exp.null()
    if isinstance(value, (int, float)):
        return exp.Literal.number(value)
    if isinstance(value, (list, tuple)):
        # Every element is stringified and emitted as a string literal.
        return exp.Array(expressions=[exp.Literal.string(str(item)) for item in value])
    return exp.Literal.string(str(value))


def _build_copy_expression(
    *, direction: str, table: str, location: str, columns: "Sequence[str] | None", options: "Mapping[str, Any] | None"
) -> exp.Copy:
    """Assemble an ``exp.Copy`` node from its parts.

    Args:
        direction: Either ``"from"`` or ``"to"``; stored on the node as
            ``kind=True`` and ``kind=False`` respectively.
        table: Table name for the COPY target.
        location: File/location string, emitted as a string literal.
        columns: Optional column names for the table.
        options: Optional COPY parameters. Keys are upper-cased; values are
            converted by ``_copy_option_value``.

    Returns:
        The assembled ``exp.Copy`` expression.

    Raises:
        SQLBuilderError: If ``direction`` is neither ``"from"`` nor ``"to"``.
    """
    # Reject unknown directions instead of silently leaving ``kind`` unset,
    # which would build a statement the caller did not ask for.
    if direction not in {"from", "to"}:
        msg = f"Invalid COPY direction {direction!r}; expected 'from' or 'to'."
        raise SQLBuilderError(msg)

    copy_args: dict[str, Any] = {
        "this": _to_copy_schema(table, columns),
        "files": [exp.Literal.string(location)],
        "kind": direction == "from",
    }

    if options:
        copy_args["params"] = [
            exp.CopyParameter(this=exp.Var(this=str(key).upper()), expression=_copy_option_value(value))
            for key, value in options.items()
        ]

    return exp.Copy(**copy_args)


def build_copy_statement(
    *,
    direction: str,
    table: str,
    location: str,
    columns: "Sequence[str] | None" = None,
    options: "Mapping[str, Any] | None" = None,
    dialect: DialectType | None = None,
) -> SQL:
    """Render a COPY statement and wrap the resulting text in ``SQL``.

    Args:
        direction: Copy direction, forwarded to ``_build_copy_expression``.
        table: Table name.
        location: File/location string for the statement.
        columns: Optional column names.
        options: Optional COPY parameters.
        dialect: Dialect used for rendering; normalized by
            ``_normalize_copy_dialect`` before use.

    Returns:
        A ``SQL`` object built from the rendered statement text.
    """
    copy_node = _build_copy_expression(
        direction=direction,
        table=table,
        location=location,
        columns=columns,
        options=options,
    )
    render_dialect = _normalize_copy_dialect(dialect)
    return SQL(copy_node.sql(dialect=render_dialect))


def build_copy_from_statement(
    table: str,
    source: str,
    *,
    columns: "Sequence[str] | None" = None,
    options: "Mapping[str, Any] | None" = None,
    dialect: DialectType | None = None,
) -> SQL:
    """Build a COPY statement with direction ``"from"``.

    Args:
        table: Table name.
        source: Passed to ``build_copy_statement`` as the location.
        columns: Optional column names.
        options: Optional COPY parameters.
        dialect: Optional dialect used for rendering.

    Returns:
        The ``SQL`` object produced by ``build_copy_statement``.
    """
    copy_kwargs: dict[str, Any] = {
        "direction": "from",
        "table": table,
        "location": source,
        "columns": columns,
        "options": options,
        "dialect": dialect,
    }
    return build_copy_statement(**copy_kwargs)


def build_copy_to_statement(
    table: str,
    target: str,
    *,
    columns: "Sequence[str] | None" = None,
    options: "Mapping[str, Any] | None" = None,
    dialect: DialectType | None = None,
) -> SQL:
    """Build a COPY statement with direction ``"to"``.

    Args:
        table: Table name.
        target: Passed to ``build_copy_statement`` as the location.
        columns: Optional column names.
        options: Optional COPY parameters.
        dialect: Optional dialect used for rendering.

    Returns:
        The ``SQL`` object produced by ``build_copy_statement``.
    """
    copy_kwargs: dict[str, Any] = {
        "direction": "to",
        "table": table,
        "location": target,
        "columns": columns,
        "options": options,
        "dialect": dialect,
    }
    return build_copy_statement(**copy_kwargs)


class SQLFactory:
"""Factory for creating SQL builders and column expressions."""

Expand Down Expand Up @@ -479,6 +575,56 @@ def comment_on(self, dialect: DialectType = None) -> "CommentOn":
"""
return CommentOn(dialect=dialect or self.dialect)

def copy_from(
    self,
    table: str,
    source: str,
    *,
    columns: "Sequence[str] | None" = None,
    options: "Mapping[str, Any] | None" = None,
    dialect: DialectType | None = None,
) -> SQL:
    """Build a COPY ... FROM statement.

    Delegates to ``build_copy_from_statement``. A falsy ``dialect`` falls
    back to ``self.dialect``.
    """
    return build_copy_from_statement(
        table,
        source,
        columns=columns,
        options=options,
        dialect=dialect or self.dialect,
    )

def copy_to(
    self,
    table: str,
    target: str,
    *,
    columns: "Sequence[str] | None" = None,
    options: "Mapping[str, Any] | None" = None,
    dialect: DialectType | None = None,
) -> SQL:
    """Build a COPY ... TO statement.

    Delegates to ``build_copy_to_statement``. A falsy ``dialect`` falls
    back to ``self.dialect``.
    """
    return build_copy_to_statement(
        table,
        target,
        columns=columns,
        options=options,
        dialect=dialect or self.dialect,
    )

def copy(
    self,
    table: str,
    *,
    source: str | None = None,
    target: str | None = None,
    columns: "Sequence[str] | None" = None,
    options: "Mapping[str, Any] | None" = None,
    dialect: DialectType | None = None,
) -> SQL:
    """Build a COPY statement, inferring direction from provided arguments.

    Exactly one of ``source`` / ``target`` must be given: ``source`` routes
    to ``copy_from`` and ``target`` routes to ``copy_to``.

    Raises:
        SQLBuilderError: If both or neither of ``source`` and ``target``
            are provided.
    """
    # Both None or both set: the direction cannot be inferred.
    if (source is None) == (target is None):
        msg = "Provide either 'source' or 'target' (but not both) to sql.copy()."
        raise SQLBuilderError(msg)

    if source is None:
        # The guard above guarantees target is set; cast narrows the type.
        return self.copy_to(table, cast("str", target), columns=columns, options=options, dialect=dialect)

    return self.copy_from(table, source, columns=columns, options=options, dialect=dialect)

@staticmethod
def _looks_like_sql(candidate: str, expected_type: str | None = None) -> bool:
"""Determine if a string looks like SQL.
Expand Down
13 changes: 10 additions & 3 deletions sqlspec/loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,12 @@
from urllib.parse import unquote, urlparse

from sqlspec.core import SQL, get_cache, get_cache_config
from sqlspec.exceptions import SQLFileNotFoundError, SQLFileParseError, StorageOperationFailedError
from sqlspec.exceptions import (
FileNotFoundInStorageError,
SQLFileNotFoundError,
SQLFileParseError,
StorageOperationFailedError,
)
from sqlspec.storage.registry import storage_registry as default_storage_registry
from sqlspec.utils.correlation import CorrelationContext
from sqlspec.utils.logging import get_logger
Expand Down Expand Up @@ -259,9 +264,11 @@ def _read_file_content(self, path: str | Path) -> str:
return backend.read_text(path_str, encoding=self.encoding)
except KeyError as e:
raise SQLFileNotFoundError(path_str) from e
except FileNotFoundInStorageError as e:
raise SQLFileNotFoundError(path_str) from e
except FileNotFoundError as e:
raise SQLFileNotFoundError(path_str) from e
except StorageOperationFailedError as e:
if "not found" in str(e).lower() or "no such file" in str(e).lower():
raise SQLFileNotFoundError(path_str) from e
raise SQLFileParseError(path_str, path_str, e) from e
except Exception as e:
raise SQLFileParseError(path_str, path_str, e) from e
Expand Down
42 changes: 37 additions & 5 deletions sqlspec/storage/_utils.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,44 @@
"""Shared utilities for storage backends."""

from pathlib import Path
from typing import Any, Final

__all__ = ("resolve_storage_path",)
from sqlspec.utils.module_loader import ensure_pyarrow

# Protocol name for local-filesystem storage; used as the default ``protocol``
# of ``resolve_storage_path`` and compared against it there.
FILE_PROTOCOL: Final[str] = "file"
# URI scheme prefix that ``resolve_storage_path`` strips from incoming paths
# when ``strip_file_scheme`` is true.
FILE_SCHEME_PREFIX: Final[str] = "file://"

# Names exported by this module.
__all__ = ("FILE_PROTOCOL", "FILE_SCHEME_PREFIX", "import_pyarrow", "import_pyarrow_parquet", "resolve_storage_path")


def import_pyarrow() -> "Any":
    """Return the ``pyarrow`` module, importing it on demand.

    ``ensure_pyarrow()`` is called first as the optional-dependency guard,
    so the import only runs after that check.

    Returns:
        PyArrow module.
    """
    ensure_pyarrow()
    import pyarrow

    return pyarrow


def import_pyarrow_parquet() -> "Any":
    """Return the ``pyarrow.parquet`` module, importing it on demand.

    ``ensure_pyarrow()`` is called first as the optional-dependency guard,
    so the import only runs after that check.

    Returns:
        PyArrow parquet module.
    """
    ensure_pyarrow()
    from pyarrow import parquet

    return parquet


def resolve_storage_path(
path: "str | Path", base_path: str = "", protocol: str = "file", strip_file_scheme: bool = True
path: "str | Path", base_path: str = "", protocol: str = FILE_PROTOCOL, strip_file_scheme: bool = True
) -> str:
"""Resolve path relative to base_path with protocol-specific handling.

Expand Down Expand Up @@ -43,10 +75,10 @@ def resolve_storage_path(

path_str = str(path)

if strip_file_scheme and path_str.startswith("file://"):
path_str = path_str.removeprefix("file://")
if strip_file_scheme and path_str.startswith(FILE_SCHEME_PREFIX):
path_str = path_str.removeprefix(FILE_SCHEME_PREFIX)

if protocol == "file":
if protocol == FILE_PROTOCOL:
path_obj = Path(path_str)

if path_obj.is_absolute():
Expand Down
Loading
Loading