From 054d03452adf348e8c535d8725c5ea02126fb44c Mon Sep 17 00:00:00 2001
From: Brendan Cooley
Date: Wed, 7 Feb 2024 11:57:07 -0500
Subject: [PATCH] chore: remove database/duckdb support

---
 pyproject.toml                     |    7 +-
 src/patito/_pydantic/dtypes.py     |   10 +-
 src/patito/database.py             |  658 -------
 src/patito/duckdb.py               | 2793 ----------------------------
 src/patito/pydantic.py             |    7 +-
 tests/test_database.py             |  568 ------
 tests/test_dtypes.py               |    1 -
 tests/test_duckdb/__init__.py      |    0
 tests/test_duckdb/test_database.py |  276 ---
 tests/test_duckdb/test_relation.py | 1063 ----------
 tests/test_model.py                |    2 +-
 11 files changed, 10 insertions(+), 5375 deletions(-)
 delete mode 100644 src/patito/database.py
 delete mode 100644 src/patito/duckdb.py
 delete mode 100644 tests/test_database.py
 delete mode 100644 tests/test_duckdb/__init__.py
 delete mode 100644 tests/test_duckdb/test_database.py
 delete mode 100644 tests/test_duckdb/test_relation.py

diff --git a/pyproject.toml b/pyproject.toml
index 3716242..4465839 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -136,4 +136,9 @@ module = ["tests.test_validators"]
 warn_unused_ignores = false
 
 [tool.ruff.lint]
-select = ["I"]
\ No newline at end of file
+select = ["E4", "E7", "E9", "F", "I"]
+ignore = []
+
+# Allow fix for all enabled rules (when `--fix`) is provided.
+fixable = ["ALL"]
+unfixable = []
\ No newline at end of file
diff --git a/src/patito/_pydantic/dtypes.py b/src/patito/_pydantic/dtypes.py
index d806796..0a19d7d 100644
--- a/src/patito/_pydantic/dtypes.py
+++ b/src/patito/_pydantic/dtypes.py
@@ -1,16 +1,11 @@
-import itertools
 from enum import Enum
 from typing import (
     Any,
     Dict,
     FrozenSet,
     List,
-    Literal,
     Optional,
     Sequence,
-    Tuple,
-    cast,
-    get_args,
 )
 
 import polars as pl
@@ -70,8 +65,7 @@ class PydanticStringFormat(Enum):
 
 
 def parse_composite_dtype(dtype: DataTypeClass | DataType) -> str:
-    """for serialization, converts polars dtype to string representation
-    """
+    """for serialization, converts polars dtype to string representation"""
     if dtype in pl.NESTED_DTYPES:
         if dtype == pl.Struct or isinstance(dtype, pl.Struct):
             raise NotImplementedError("Structs not yet supported by patito")
@@ -156,7 +150,7 @@ def valid_polars_dtypes_for_annotation(
 
     Args:
         annotation (type[Any] | None): python type annotation
-    
+
     Returns:
         FrozenSet[DataTypeClass | DataType]: set of polars dtypes
     """
diff --git a/src/patito/database.py b/src/patito/database.py
deleted file mode 100644
index 3477d79..0000000
--- a/src/patito/database.py
+++ /dev/null
@@ -1,658 +0,0 @@
-"""Module containing utilities for retrieving data from external databases."""
-import glob
-import hashlib
-import inspect
-import re
-from datetime import datetime, timedelta
-from functools import wraps
-from pathlib import Path
-from typing import (
-    TYPE_CHECKING,
-    Any,
-    Callable,
-    Dict,
-    Generic,
-    Optional,
-    Type,
-    TypeVar,
-    Union,
-    cast,
-    overload,
-)
-
-import polars as pl
-import pyarrow as pa  # type: ignore[import]
-import pyarrow.parquet as pq  # type: ignore[import]
-from typing_extensions import Literal, ParamSpec, Protocol
-
-from patito import xdg
-
-if TYPE_CHECKING:
-    from patito import Model
-
-
-P = ParamSpec("P")
-DF = TypeVar("DF", bound=Union[pl.DataFrame, pl.LazyFrame], covariant=True)
-
-# Increment this integer whenever you make backwards-incompatible changes to
-# the parquet caching implemented in WrappedQueryFunc, then such caches
-# are ejected the next time the wrapper tries to read from them.
-CACHE_VERSION = 1 - - -class QueryConstructor(Protocol[P]): - """A function taking arbitrary arguments and returning an SQL query string.""" - - __name__: str - - def __call__(self, *args: P.args, **kwargs: P.kwargs) -> str: - """ - Return SQL query constructed from the given parameters. - - Args: - *args: Positional arguments used to build SQL query. - **kwargs: Keyword arguments used to build SQL query. - """ - ... # pragma: no cover - - -class DatabaseQuery(Generic[P, DF]): - """A class acting as a function that returns a polars.DataFrame when called.""" - - _cache: Union[bool, Path] - - def __init__( # noqa: C901 - self, - query_constructor: QueryConstructor[P], - cache_directory: Path, - query_handler: Callable[..., pa.Table], - ttl: timedelta, - lazy: bool = False, - cache: Union[str, Path, bool] = False, - model: Union[Type["Model"], None] = None, - query_handler_kwargs: Optional[Dict[Any, Any]] = None, - ) -> None: - """ - Convert SQL string query function to polars.DataFrame function. - - Args: - query_constructor: A function that takes arbitrary arguments and returns - an SQL query string. - cache_directory: Path to directory to store parquet cache files in. - query_handler: Function used to execute SQL queries and return arrow - tables. - ttl: See Database.query for documentation. - lazy: See Database.query for documentation. - cache: See Database.query for documentation. - model: See Database.query for documentation. - query_handler_kwargs: Arbitrary keyword arguments forwarded to the provided - query handler. - - Raises: - ValueError: If the given path does not have a '.parquet' file extension. - """ - if not isinstance(cache, bool) and Path(cache).suffix != ".parquet": - raise ValueError("Cache paths must have the '.parquet' file extension!") - - if isinstance(cache, (Path, str)): - self._cache = cache_directory.joinpath(cache) - else: - self._cache = cache - self._query_constructor = query_constructor - self.cache_directory = cache_directory - - self._query_handler_kwargs = query_handler_kwargs or {} - # Unless explicitly specified otherwise by the end-user, we retrieve query - # results as arrow tables with column types directly supported by polars. - # Otherwise the resulting parquet files that are written to disk can not be - # lazily read with polars.scan_parquet. - self._query_handler_kwargs.setdefault("cast_to_polars_equivalent_types", True) - - # We construct the new function with the same parameter signature as - # wrapped_function, but with polars.DataFrame as the return type. 
- @wraps(query_constructor) - def cached_func(*args: P.args, **kwargs: P.kwargs) -> DF: - query = query_constructor(*args, **kwargs) - cache_path = self.cache_path(*args, **kwargs) - if cache_path and cache_path.exists(): - metadata: Dict[bytes, bytes] = pq.read_schema(cache_path).metadata or {} - - # Check if the cache file was produced by an identical SQL query - is_same_query = metadata.get(b"query") == query.encode("utf-8") - - # Check if the cache is too old to be re-used - cache_created_time = datetime.fromisoformat( - metadata.get( - b"query_start_time", b"1900-01-01T00:00:00.000000" - ).decode("utf-8") - ) - is_fresh_cache = (datetime.now() - cache_created_time) < ttl - - # Check if the cache was produced by an incompatible version - cache_version = int.from_bytes( - metadata.get( - b"cache_version", - (0).to_bytes(length=16, byteorder="little", signed=False), - ), - byteorder="little", - signed=False, - ) - is_compatible_version = cache_version >= CACHE_VERSION - - if is_same_query and is_fresh_cache and is_compatible_version: - if lazy: - return pl.scan_parquet(cache_path) # type: ignore - else: - return pl.read_parquet(cache_path) # type: ignore - - arrow_table = query_handler(query, **self._query_handler_kwargs) - if cache_path: - cache_path.parent.mkdir(parents=True, exist_ok=True) - # We write the cache *before* any potential model validation since - # we don't want to lose the result of an expensive query just because - # the model specification is wrong. - # We also use pyarrow.parquet.write_table instead of - # polars.write_parquet since we want to write the arrow table's metadata - # to the parquet file, such as the executed query, time, etc.. - # This metadata is not preserved by polars. - metadata = arrow_table.schema.metadata - metadata[ - b"wrapped_function_name" - ] = self._query_constructor.__name__.encode("utf-8") - # Store the cache version as an 16-bit unsigned little-endian number - metadata[b"cache_version"] = CACHE_VERSION.to_bytes( - length=16, - byteorder="little", - signed=False, - ) - pq.write_table( - table=arrow_table.replace_schema_metadata(metadata), - where=cache_path, - # In order to support nanosecond-resolution timestamps, we must - # use parquet version >= 2.6. - version="2.6", - ) - - polars_df = cast(pl.DataFrame, pl.from_arrow(arrow_table)) - if model: - model.validate(polars_df) - - if lazy: - if cache_path: - # Delete in-memory representation of data and read from the new - # parquet file instead. That way we get consistent memory pressure - # the first and subsequent times this function is invoked. - del polars_df, arrow_table - return pl.scan_parquet(source=cache_path) # type: ignore - else: - return polars_df.lazy() # type: ignore - else: - return polars_df # type: ignore - - self._cached_func = cached_func - - def cache_path(self, *args: P.args, **kwargs: P.kwargs) -> Optional[Path]: - """ - Return the path to the parquet cache that would store the result of the query. - - Args: - args: The positional arguments passed to the wrapped function. - kwargs: The keyword arguments passed to the wrapped function. - - Returns: - A deterministic path to a parquet cache. None if caching is disabled. 
- """ - # We convert args+kwargs to kwargs-only and use it to format the string - function_signature = inspect.signature(self._query_constructor) - bound_arguments = function_signature.bind(*args, **kwargs) - - if isinstance(self._cache, Path): - # Interpret relative paths relative to the main query cache directory - return Path(str(self._cache).format(**bound_arguments.arguments)) - elif self._cache is True: - directory: Path = self.cache_directory / self._query_constructor.__name__ - directory.mkdir(exist_ok=True, parents=True) - sql_query = self.query_string(*args, **kwargs) - sql_query_hash = hashlib.sha1( # noqa: S324,S303 - sql_query.encode("utf-8") - ).hexdigest() - return directory / f"{sql_query_hash}.parquet" - else: - return None - - def __call__(self, *args: P.args, **kwargs: P.kwargs) -> DF: # noqa: D102 - return self._cached_func(*args, **kwargs) - - def query_string(self, *args: P.args, **kwargs: P.kwargs) -> str: - """ - Return the query to be executed for the given parameters. - - Args: - *args: Positional arguments used to construct the query string. - *kwargs: Keyword arguments used to construct the query string. - - Returns: - The query string produced for the given input parameters. - """ - return self._query_constructor(*args, **kwargs) - - def refresh_cache(self, *args: P.args, **kwargs: P.kwargs) -> DF: - """ - Force query execution by refreshing the cache. - - Args: - *args: Positional arguments used to construct the SQL query string. - *kwargs: Keyword arguments used to construct the SQL query string. - - Returns: - A DataFrame representing the result of the newly executed query. - """ - cache_path = self.cache_path(*args, **kwargs) - if cache_path and cache_path.exists(): - cache_path.unlink() - return self._cached_func(*args, **kwargs) - - def clear_caches(self) -> None: - """Delete all parquet cache files produced by this query wrapper.""" - if self._cache is False: - # Caching is not enabled, so this is simply a no-op - return - - if self._cache is True: - glob_pattern = str( - self.cache_directory / self._query_constructor.__name__ / "*.parquet" - ) - else: - # We replace all formatting specifiers of the form '{variable}' with - # recursive globs '**' (in case strings containing '/' are inserted) and - # search for all occurrences of such file paths. - # For example if cache="{a}/{b}.parquet" is specified, we search for - # all files matching the glob pattern '**/**.parquet'. - glob_pattern = re.sub( # noqa: PD005 - # We specify the reluctant qualifier (?) in order to get narrow matches - pattern=r"\{.+?\}", - repl="**", - string=str(self._cache), - ) - - for parquet_path in glob.iglob(glob_pattern): - try: - metadata: Dict[bytes, bytes] = ( - pq.read_schema(where=parquet_path).metadata or {} - ) - if metadata.get( - b"wrapped_function_name" - ) == self._query_constructor.__name__.encode("utf-8"): - Path(parquet_path).unlink() - except Exception: # noqa: S112 - # If we can't read the parquet metadata for some reason, - # it is probably not a cache anyway. - continue - - -class Database: - """ - Construct manager for executing SQL queries and caching the results. - - Args: - query_handler: The function that the Database object should use for executing - SQL queries. Its first argument should be the SQL query string to execute, - and it should return the query result as an arrow table, for instance - pyarrow.Table. - cache_directory: Path to the directory where caches should be stored as parquet - files. 
If not provided, the `XDG`_ Base Directory Specification will be - used to determine the suitable cache directory, by default - ``~/.cache/patito`` or ``${XDG_CACHE_HOME}/patito``. - default_ttl: The default Time To Live (TTL), or with other words, how long to - wait until caches are refreshed due to old age. The given default TTL can be - overwritten by specifying the ``ttl`` parameter in - :func:`Database.query`. The default is 52 weeks. - - Examples: - We start by importing the necessary modules: - - >>> from pathlib import Path - ... - >>> import patito as pt - >>> import pyarrow as pa - - In order to construct a ``Database``, we need to provide the constructor with - a function that can *execute* query strings. How to construct this function will - depend on what you actually want to run your queries against, for example a - local or remote database. For the purposes of demonstration we will use - SQLite since it is built into Python's standard library, but you can use - anything; for example Snowflake or PostgresQL. - - We will use Python's standard library - `documentation `_ - to create an in-memory SQLite database. - It will contain a single table named ``movies`` containing some dummy data. - The details do not really matter here, the only important part is that we - construct a database which we can run SQL queries against. - - >>> import sqlite3 - ... - >>> def dummy_database() -> sqlite3.Cursor: - ... connection = sqlite3.connect(":memory:") - ... cursor = connection.cursor() - ... cursor.execute("CREATE TABLE movies(title, year, score)") - ... data = [ - ... ("Monty Python Live at the Hollywood Bowl", 1982, 7.9), - ... ("Monty Python's The Meaning of Life", 1983, 7.5), - ... ("Monty Python's Life of Brian", 1979, 8.0), - ... ] - ... cursor.executemany("INSERT INTO movies VALUES(?, ?, ?)", data) - ... connection.commit() - ... return cursor - - Using this dummy database, we are now able to construct a function which accepts - SQL queries as its first parameter, executes the query, and returns the query - result in the form of an Arrow table. - - >>> def query_handler(query: str) -> pa.Table: - ... cursor = dummy_database() - ... cursor.execute(query) - ... columns = [description[0] for description in cursor.description] - ... data = [dict(zip(columns, row)) for row in cursor.fetchall()] - ... return pa.Table.from_pylist(data) - - We can now construct a ``Database`` object, providing ``query_handler`` - as the way to execute SQL queries. - - >>> db = pt.Database(query_handler=query_handler) - - The resulting object can now be used to execute SQL queries against the database - and return the result in the form of a polars ``DataFrame`` object. - - >>> db.query("select * from movies order by year limit 1") - shape: (1, 3) - ┌──────────────────────────────┬──────┬───────┐ - │ title ┆ year ┆ score │ - │ --- ┆ --- ┆ --- │ - │ str ┆ i64 ┆ f64 │ - ╞══════════════════════════════╪══════╪═══════╡ - │ Monty Python's Life of Brian ┆ 1979 ┆ 8.0 │ - └──────────────────────────────┴──────┴───────┘ - - But the main way to use a ``Database`` object is to use the - ``@Database.as_query`` decarator to wrap functions which return SQL - query *strings*. - - >>> @db.as_query() - >>> def movies(newer_than_year: int): - ... return f"select * from movies where year > {newer_than_year}" - - This decorator will convert the function from producing query strings, to - actually executing the given query and return the query result in the form of - a polars ``DataFrame`` object. 
- - >>> movies(newer_than_year=1980) - shape: (2, 3) - ┌───────────────────────────────────┬──────┬───────┐ - │ title ┆ year ┆ score │ - │ --- ┆ --- ┆ --- │ - │ str ┆ i64 ┆ f64 │ - ╞═══════════════════════════════════╪══════╪═══════╡ - │ Monty Python Live at the Hollywo… ┆ 1982 ┆ 7.9 │ - │ Monty Python's The Meaning of Li… ┆ 1983 ┆ 7.5 │ - └───────────────────────────────────┴──────┴───────┘ - - Caching is not enabled by default, but it can be enabled by specifying - ``cache=True`` to the ``@db.as_query(...)`` decorator. Other arguments are - also accepted, such as ``lazy=True`` if you want to retrieve the results in the - form of a ``LazyFrame`` instead of a ``DataFrame``, ``ttl`` if you want to - specify another TTL, and any additional keyword arguments are forwarded to - ``query_executor`` when the SQL query is executed. You can read more about these - parameters in the documentation of :func:`Database.query`. - - .. _XDG: https://specifications.freedesktop.org/basedir-spec/basedir-spec-latest.html - """ - - Query = DatabaseQuery - - def __init__( # noqa: D107 - self, - query_handler: Callable[..., pa.Table], - cache_directory: Optional[Path] = None, - default_ttl: timedelta = timedelta(weeks=52), # noqa: B008 - ) -> None: - self.query_handler = query_handler - self.cache_directory = cache_directory or xdg.cache_home(application="patito") - self.default_ttl = default_ttl - - self.cache_directory.mkdir(exist_ok=True, parents=True) - - # With lazy = False a DataFrame-producing wrapper is returned - @overload - def as_query( - self, - *, - lazy: Literal[False] = False, - cache: Union[str, Path, bool] = False, - ttl: Optional[timedelta] = None, - model: Union[Type["Model"], None] = None, - **kwargs: Any, # noqa: ANN401 - ) -> Callable[[QueryConstructor[P]], DatabaseQuery[P, pl.DataFrame]]: - ... # pragma: no cover - - # With lazy = True a LazyFrame-producing wrapper is returned - @overload - def as_query( - self, - *, - lazy: Literal[True], - cache: Union[str, Path, bool] = False, - ttl: Optional[timedelta] = None, - model: Union[Type["Model"], None] = None, - **kwargs: Any, # noqa: ANN401 - ) -> Callable[[QueryConstructor[P]], DatabaseQuery[P, pl.LazyFrame]]: - ... # pragma: no cover - - def as_query( - self, - *, - lazy: bool = False, - cache: Union[str, Path, bool] = False, - ttl: Optional[timedelta] = None, - model: Union[Type["Model"], None] = None, - **kwargs: Any, # noqa: ANN401 - ) -> Callable[ - [QueryConstructor[P]], DatabaseQuery[P, Union[pl.DataFrame, pl.LazyFrame]] - ]: - """ - Execute the returned query string and return a polars dataframe. - - Args: - lazy: If the result should be returned as a LazyFrame rather than a - DataFrame. Allows more efficient reading from parquet caches if caching - is enabled. - cache: If queries should be cached in order to save time and costs. - The cache will only be used if the exact same SQL string has - been executed before. - If the parameter is specified as ``True``, a parquet file is - created for each unique query string, and is located at: - artifacts/query_cache//.parquet - If the a string or ``pathlib.Path`` object is provided, the given path - will be used, but it must have a '.parquet' file extension. - Relative paths are interpreted relative to artifacts/query_cache/ - in the workspace root. The given parquet path will be overwritten - if the query string changes, so only the latest query string value - will be cached. - ttl: The Time to Live (TTL) of the cache specified as a datetime.timedelta - object. 
When the cache becomes older than the specified TTL, the query - will be re-executed on the next invocation of the query function - and the cache will refreshed. - model: An optional Patito model used to validate the content of the - dataframe before return. - **kwargs: Connection parameters forwarded to sql_to_polars, for example - db_params. - - Returns: - A new function which returns a polars DataFrame based on the query - specified by the original function's return string. - """ - - def wrapper(query_constructor: QueryConstructor) -> DatabaseQuery: - return self.Query( - query_constructor=query_constructor, - lazy=lazy, - cache=cache, - ttl=ttl if ttl is not None else self.default_ttl, - cache_directory=self.cache_directory, - model=model, - query_handler=_with_query_metadata(self.query_handler), - query_handler_kwargs=kwargs, - ) - - return wrapper - - # With lazy=False, a DataFrame is returned - @overload - def query( - self, - query: str, - *, - lazy: Literal[False] = False, - cache: Union[str, Path, bool] = False, - ttl: Optional[timedelta] = None, - model: Union[Type["Model"], None] = None, - **kwargs: Any, # noqa: ANN401 - ) -> pl.DataFrame: - ... # pragma: no cover - - # With lazy=True, a LazyFrame is returned - @overload - def query( - self, - query: str, - *, - lazy: Literal[True], - cache: Union[str, Path, bool] = False, - ttl: Optional[timedelta] = None, - model: Union[Type["Model"], None] = None, - **kwargs: Any, # noqa: ANN401 - ) -> pl.LazyFrame: - ... # pragma: no cover - - def query( - self, - query: str, - *, - lazy: bool = False, - cache: Union[str, Path, bool] = False, - ttl: Optional[timedelta] = None, - model: Union[Type["Model"], None] = None, - **kwargs: Any, # noqa: ANN401 - ) -> Union[pl.DataFrame, pl.LazyFrame]: - """ - Execute the given query and return the query result as a DataFrame or LazyFrame. - - See :ref:`Database.as_query` for a more powerful way to build and execute - queries. - - Args: - query: The query string to be executed, for instance an SQL query. - lazy: If the query result should be returned in the form of a LazyFrame - instead of a DataFrame. - cache: If the query result should be saved and re-used the next time the - same query is executed. Can also be provided as a path. See - :func:`Database.as_query` for full documentation. - ttl: How long to use cached results until the query is re-executed anyway. - model: A :ref:`Model` to optionally validate the query result. - **kwargs: All additional keyword arguments are forwarded to the query - handler which executes the given query. - - Returns: - The result of the query in the form of a ``DataFrame`` if ``lazy=False``, or - a ``LazyFrame`` otherwise. - - Examples: - We will use DuckDB as our example database. - - >>> import duckdb - >>> import patito as pt - - We will construct a really simple query source from an in-memory database. - - >>> db = duckdb.connect(":memory:") - >>> query_handler = lambda query: db.cursor().query(query).arrow() - >>> query_source = pt.Database(query_handler=query_handler) - - We can now use :func:`Database.query` in order to execute queries against - the in-memory database. - - >>> query_source.query("select 1 as a, 2 as b, 3 as c") - shape: (1, 3) - ┌─────┬─────┬─────┐ - │ a ┆ b ┆ c │ - │ --- ┆ --- ┆ --- │ - │ i32 ┆ i32 ┆ i32 │ - ╞═════╪═════╪═════╡ - │ 1 ┆ 2 ┆ 3 │ - └─────┴─────┴─────┘ - """ - - def __direct_query() -> str: - """ - A regular named function in order to store parquet files correctly. - - Returns: - The user-provided query string. 
- """ - return query - - return self.as_query( - lazy=lazy, # type: ignore - cache=cache, - ttl=ttl, - model=model, - **kwargs, - )(__direct_query)() - - -def _with_query_metadata(query_handler: Callable[P, pa.Table]) -> Callable[P, pa.Table]: - """ - Wrap SQL-query handler with additional logic. - - Args: - query_handler: Function accepting an SQL query as its first argument and - returning an Arrow table. - - Returns: - New function that returns Arrow table with additional metedata. Arrow types - which are not supported by polars directly have also been converted to - compatible ones where applicable. - """ - - @wraps(query_handler) - def wrapped_query_handler( - *args: P.args, - **kwargs: P.kwargs, - ) -> pa.Table: - cast_to_polars_equivalent_types = kwargs.pop( - "cast_to_polars_equivalent_types", True - ) - start_time = datetime.now() - arrow_table = query_handler(*args, **kwargs) - finish_time = datetime.now() - metadata: dict = arrow_table.schema.metadata or {} - if cast_to_polars_equivalent_types: - # We perform a round-trip to polars and back in order to get an arrow table - # with column types that are directly supported by polars. - arrow_table = pl.from_arrow(arrow_table).to_arrow() - - # Store additional metadata which is useful when the arrow table is written to a - # parquet file as a caching mechanism. - metadata.update( - { - "query": args[0], - "query_start_time": start_time.isoformat(), - "query_end_time": finish_time.isoformat(), - } - ) - return arrow_table.replace_schema_metadata(metadata) - - return wrapped_query_handler - - -__all__ = ["Database"] diff --git a/src/patito/duckdb.py b/src/patito/duckdb.py deleted file mode 100644 index 524916e..0000000 --- a/src/patito/duckdb.py +++ /dev/null @@ -1,2793 +0,0 @@ -""" -Module which wraps around the duckdb module in an opiniated manner. -""" -from __future__ import annotations - -import hashlib -from collections.abc import Collection, Iterable, Iterator -from pathlib import Path -from typing import ( - TYPE_CHECKING, - Any, - Dict, - Generic, - List, - Optional, - Set, - Tuple, - Type, - TypeVar, - Union, - cast, -) - -import numpy as np -import polars as pl -import pyarrow as pa # type: ignore[import] -from pydantic import create_model -from typing_extensions import Literal - -from patito import sql -from patito.exceptions import MultipleRowsReturned, RowDoesNotExist -from patito.polars import DataFrame -from patito.pydantic import Model, ModelType - -try: - import pandas as pd - - _PANDAS_AVAILABLE = True -except ImportError: - _PANDAS_AVAILABLE = False - -if TYPE_CHECKING: - import duckdb - - -# Types which can be used to instantiate a DuckDB Relation object -RelationSource = Union[ - DataFrame, - pl.DataFrame, - "pd.DataFrame", - Path, - str, - "duckdb.DuckDBPyRelation", - "Relation", -] - -# Used to refer to type(self) in Relation methods which preserve the type. -# Hard-coding Relation or Relation[ModelType] does not work for subclasses -# that return type(self) since that will refer to the parent class. 
-# See relevant SO answer: https://stackoverflow.com/a/63178532 -RelationType = TypeVar("RelationType", bound="Relation") - -# The SQL types supported by DuckDB -# See: https://duckdb.org/docs/sql/data_types/overview -# fmt: off -DuckDBSQLType = Literal[ - "BIGINT", "INT8", "LONG", - "BLOB", "BYTEA", "BINARY", "VARBINARY", - "BOOLEAN", "BOOL", "LOGICAL", - "DATE", - "DOUBLE", "FLOAT8", "NUMERIC", "DECIMAL", - "HUGEINT", - "INTEGER", "INT4", "INT", "SIGNED", - "INTERVAL", - "REAL", "FLOAT4", "FLOAT", - "SMALLINT", "INT2", "SHORT", - "TIME", - "TIMESTAMP", "DATETIME", - "TIMESTAMP WITH TIMEZONE", "TIMESTAMPTZ", - "TINYINT", "INT1", - "UBIGINT", - "UINTEGER", - "USMALLINT", - "UTINYINT", - "UUID", - "VARCHAR", "CHAR", "BPCHAR", "TEXT", "STRING", -] -# fmt: on - -# Used for backward-compatible patches -POLARS_VERSION: Optional[Tuple[int, int, int]] -try: - POLARS_VERSION = cast( - Tuple[int, int, int], - tuple(map(int, pl.__version__.split("."))), - ) -except ValueError: # pragma: no cover - POLARS_VERSION = None - - -def create_pydantic_model(relation: "duckdb.DuckDBPyRelation") -> Type[Model]: - """Create pydantic model deserialization of the given relation.""" - pydantic_annotations = {column: (Any, ...) for column in relation.columns} - return create_model( # type: ignore - relation.alias, - __base__=Model, - **pydantic_annotations, # pyright: ignore - ) - - -def _enum_type_name(field_properties: dict) -> str: - """ - Return enum DuckDB SQL type name based on enum values. - - The same enum values, regardless of ordering, will always be given the same name. - """ - enum_values = ", ".join(repr(value) for value in sorted(field_properties["enum"])) - value_hash = hashlib.md5(enum_values.encode("utf-8")).hexdigest() # noqa: #S303 - return f"enum__{value_hash}" - - -def _is_missing_enum_type_exception(exception: BaseException) -> bool: - """ - Return True if the given exception might be caused by missing enum type definitions. - - Args: - exception: Exception raised by DuckDB. - - Returns: - True if the exception might be caused by a missing SQL enum type definition. - """ - description = str(exception) - # DuckDB version <= 0.3.4 - old_exception = description.startswith("Not implemented Error: DataType") - # DuckDB version >= 0.4.0 - new_exception = description.startswith("Catalog Error: Type with name enum_") - return old_exception or new_exception - - -class Relation(Generic[ModelType]): - # The database connection which the given relation belongs to - database: Database - - # The underlying DuckDB relation object which this class wraps around - _relation: duckdb.DuckDBPyRelation - - # Can be set by subclasses in order to specify the serialization class for rows. - # Must accept column names as keyword arguments. - model: Optional[Type[ModelType]] = None - - # The alias that can be used to refer to the relation in queries - alias: str - - def __init__( # noqa: C901 - self, - derived_from: RelationSource, - database: Optional[Database] = None, - model: Optional[Type[ModelType]] = None, - ) -> None: - """ - Create a new relation object containing data to be queried with DuckDB. - - Args: - derived_from: Data to be represented as a DuckDB relation object. - Can be one of the following types: - - - A pandas or polars DataFrame. - - An SQL query represented as a string. - - A ``Path`` object pointing to a CSV or a parquet file. - The path must point to an existing file with either a ``.csv`` - or ``.parquet`` file extension. - - A native DuckDB relation object (``duckdb.DuckDBPyRelation``). 
- - A ``patito.duckdb.Relation`` object. - - database: Which database to load the relation into. If not provided, - the default DuckDB database will be used. - - model: Sub-class of ``patito.Model`` which specifies how to deserialize rows - when fetched with methods such as - :ref:`Relation.get()` and ``__iter__()``. - - Will also be used to create a strict table schema if - :ref:`Relation.create_table()`. - schema should be constructed. - - If not provided, a dynamic model fitting the relation schema will be - created when required. - - Can also be set later dynamically by invoking - :ref:`Relation.set_model()`. - - Raises: - ValueError: If any one of the following cases are encountered: - - - If a provided ``Path`` object does not have a ``.csv`` or - ``.parquet`` file extension. - - If a database and relation object is provided, but the relation object - does not belong to the database. - - TypeError: If the type of ``derived_from`` is not supported. - - Examples: - Instantiated from a dataframe: - - >>> import patito as pt - >>> df = pt.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]}) - >>> pt.duckdb.Relation(df).filter("a > 2").to_df() - shape: (1, 2) - ┌─────┬─────┐ - │ a ┆ b │ - │ --- ┆ --- │ - │ i64 ┆ i64 │ - ╞═════╪═════╡ - │ 3 ┆ 6 │ - └─────┴─────┘ - - Instantiated from an SQL query: - - >>> pt.duckdb.Relation("select 1 as a, 2 as b").to_df() - shape: (1, 2) - ┌─────┬─────┐ - │ a ┆ b │ - │ --- ┆ --- │ - │ i64 ┆ i64 │ - ╞═════╪═════╡ - │ 1 ┆ 2 │ - └─────┴─────┘ - """ - import duckdb - - if isinstance(derived_from, Relation): - if ( - database is not None - and derived_from.database.connection is not database.connection - ): - raise ValueError( - "Relations can't be casted between database connections." - ) - self.database = derived_from.database - self._relation = derived_from._relation - self.model = derived_from.model - return - - if database is None: - self.database = Database.default() - else: - self.database = database - - if isinstance(derived_from, duckdb.DuckDBPyRelation): - relation = derived_from - elif isinstance(derived_from, str): - relation = self.database.connection.from_query(derived_from) - elif _PANDAS_AVAILABLE and isinstance(derived_from, pd.DataFrame): - # We must replace pd.NA with np.nan in order for it to be considered - # as null by DuckDB. Otherwise it will casted to the string - # or even segfault. - derived_from = derived_from.fillna(np.nan) - relation = self.database.connection.from_df(derived_from) - elif isinstance(derived_from, pl.DataFrame): - relation = self.database.connection.from_arrow(derived_from.to_arrow()) - elif isinstance(derived_from, Path): - if derived_from.suffix.lower() == ".parquet": - relation = self.database.connection.from_parquet(str(derived_from)) - elif derived_from.suffix.lower() == ".csv": - relation = self.database.connection.from_csv_auto(str(derived_from)) - else: - raise ValueError( - f"Unsupported file suffix {derived_from.suffix!r} for data import!" - ) - else: - raise TypeError # pragma: no cover - - self._relation = relation - if model is not None: - self.model = model # pyright: ignore - - def aggregate( - self, - *aggregations: str, - group_by: Union[str, Iterable[str]], - **named_aggregations: str, - ) -> Relation: - """ - Return relation formed by ``GROUP BY`` SQL aggregation(s). - - Args: - aggregations: Zero or more aggregation expressions such as - "sum(column_name)" and "count(distinct column_name)". 
- named_aggregations: Zero or more aggregated expressions where the keyword is - used to name the given aggregation. For example, - ``my_column="sum(column_name)"`` is inserted as - ``"sum(column_name) as my_column"`` in the executed SQL query. - group_by: A single column name or iterable collection of column names to - group by. - - Examples: - >>> import patito as pt - >>> df = pt.DataFrame({"a": [1, 2, 3], "b": ["X", "Y", "X"]}) - >>> relation = pt.duckdb.Relation(df) - >>> relation.aggregate( - ... "b", - ... "sum(a)", - ... "greatest(b)", - ... max_a="max(a)", - ... group_by="b", - ... ).to_df() - shape: (2, 4) - ┌─────┬────────┬─────────────┬───────┐ - │ b ┆ sum(a) ┆ greatest(b) ┆ max_a │ - │ --- ┆ --- ┆ --- ┆ --- │ - │ str ┆ f64 ┆ str ┆ i64 │ - ╞═════╪════════╪═════════════╪═══════╡ - │ X ┆ 4.0 ┆ X ┆ 3 │ - │ Y ┆ 2.0 ┆ Y ┆ 2 │ - └─────┴────────┴─────────────┴───────┘ - """ - expression = ", ".join( - aggregations - + tuple( - f"{expression} as {column_name}" - for column_name, expression in named_aggregations.items() - ) - ) - relation = self._relation.aggregate( - aggr_expr=expression, - group_expr=group_by if isinstance(group_by, str) else ", ".join(group_by), - ) - return self._wrap(relation=relation, schema_change=True) - - def add_suffix( - self, - suffix: str, - include: Optional[Collection[str]] = None, - exclude: Optional[Collection[str]] = None, - ) -> Relation: - """ - Add a suffix to all the columns of the relation. - - Args: - suffix: A string to append to add to all columns names. - include: If provided, only the given columns will be renamed. - exclude: If provided, the given columns will `not` be renamed. - - Raises: - TypeError: If both include `and` exclude are provided at the same time. - - Examples: - >>> import patito as pt - >>> relation = pt.duckdb.Relation("select 1 as column_1, 2 as column_2") - >>> relation.add_suffix("_renamed").to_df() - shape: (1, 2) - ┌──────────────────┬──────────────────┐ - │ column_1_renamed ┆ column_2_renamed │ - │ --- ┆ --- │ - │ i64 ┆ i64 │ - ╞══════════════════╪══════════════════╡ - │ 1 ┆ 2 │ - └──────────────────┴──────────────────┘ - - >>> relation.add_suffix("_renamed", include=["column_1"]).to_df() - shape: (1, 2) - ┌──────────────────┬──────────┐ - │ column_1_renamed ┆ column_2 │ - │ --- ┆ --- │ - │ i64 ┆ i64 │ - ╞══════════════════╪══════════╡ - │ 1 ┆ 2 │ - └──────────────────┴──────────┘ - - >>> relation.add_suffix("_renamed", exclude=["column_1"]).to_df() - shape: (1, 2) - ┌──────────┬──────────────────┐ - │ column_1 ┆ column_2_renamed │ - │ --- ┆ --- │ - │ i64 ┆ i64 │ - ╞══════════╪══════════════════╡ - │ 1 ┆ 2 │ - └──────────┴──────────────────┘ - """ - if include is not None and exclude is not None: - raise TypeError("Both include and exclude provided at the same time!") - elif include is not None: - included = lambda column: column in include - elif exclude is not None: - included = lambda column: column not in exclude - else: - included = lambda _: True # noqa: E731 - - return self.select( - ", ".join( - f"{column} as {column}{suffix}" if included(column) else column - for column in self.columns - ) - ) - - def add_prefix( - self, - prefix: str, - include: Optional[Iterable[str]] = None, - exclude: Optional[Iterable[str]] = None, - ) -> Relation: - """ - Add a prefix to all the columns of the relation. - - Args: - prefix: A string to prepend to add to all the columns names. - include: If provided, only the given columns will be renamed. - exclude: If provided, the given columns will `not` be renamed. 
- - Raises: - TypeError: If both include `and` exclude are provided at the same time. - - Examples: - >>> import patito as pt - >>> relation = pt.duckdb.Relation("select 1 as column_1, 2 as column_2") - >>> relation.add_prefix("renamed_").to_df() - shape: (1, 2) - ┌──────────────────┬──────────────────┐ - │ renamed_column_1 ┆ renamed_column_2 │ - │ --- ┆ --- │ - │ i64 ┆ i64 │ - ╞══════════════════╪══════════════════╡ - │ 1 ┆ 2 │ - └──────────────────┴──────────────────┘ - - >>> relation.add_prefix("renamed_", include=["column_1"]).to_df() - shape: (1, 2) - ┌──────────────────┬──────────┐ - │ renamed_column_1 ┆ column_2 │ - │ --- ┆ --- │ - │ i64 ┆ i64 │ - ╞══════════════════╪══════════╡ - │ 1 ┆ 2 │ - └──────────────────┴──────────┘ - - >>> relation.add_prefix("renamed_", exclude=["column_1"]).to_df() - shape: (1, 2) - ┌──────────┬──────────────────┐ - │ column_1 ┆ renamed_column_2 │ - │ --- ┆ --- │ - │ i64 ┆ i64 │ - ╞══════════╪══════════════════╡ - │ 1 ┆ 2 │ - └──────────┴──────────────────┘ - """ - if include is not None and exclude is not None: - raise TypeError("Both include and exclude provided at the same time!") - elif include is not None: - included = lambda column: column in include - elif exclude is not None: - included = lambda column: column not in exclude - else: - included = lambda _: True - - return self.select( - ", ".join( - f"{column} as {prefix}{column}" if included(column) else column - for column in self.columns - ) - ) - - def all(self, *filters: str, **equalities: Union[int, float, str]) -> bool: - """ - Return ``True`` if the given predicate(s) are true for all rows in the relation. - - See :func:`Relation.filter()` for additional information regarding the - parameters. - - Args: - filters: SQL predicates to satisfy. - equalities: SQL equality predicates to satisfy. - - Examples: - >>> import patito as pt - >>> df = pt.DataFrame( - ... { - ... "even_number": [2, 4, 6], - ... "odd_number": [1, 3, 5], - ... "zero": [0, 0, 0], - ... } - ... ) - >>> relation = pt.duckdb.Relation(df) - >>> relation.all(zero=0) - True - >>> relation.all( - ... "even_number % 2 = 0", - ... "odd_number % 2 = 1", - ... zero=0, - ... ) - True - >>> relation.all(zero=1) - False - >>> relation.all("odd_number % 2 = 0") - False - """ - return self.filter(*filters, **equalities).count() == self.count() - - def case( - self, - *, - from_column: str, - to_column: str, - mapping: Dict[sql.SQLLiteral, sql.SQLLiteral], - default: sql.SQLLiteral, - ) -> Relation: - """ - Map values of one column over to a new column. - - Args: - from_column: Name of column defining the domain of the mapping. - to_column: Name of column to insert the mapped values into. - mapping: Dictionary defining the mapping. The dictionary keys represent the - input values, while the dictionary values represent the output values. - Items are inserted into the SQL case statement by their repr() string - value. - default: Default output value for inputs which have no provided mapping. - - Examples: - The following case statement... - - >>> import patito as pt - >>> db = pt.duckdb.Database() - >>> relation = db.to_relation("select 1 as a union select 2 as a") - >>> relation.case( - ... from_column="a", - ... to_column="b", - ... mapping={1: "one", 2: "two"}, - ... default="three", - ... ).order(by="a").to_df() - shape: (2, 2) - ┌─────┬─────┐ - │ a ┆ b │ - │ --- ┆ --- │ - │ i64 ┆ str │ - ╞═════╪═════╡ - │ 1 ┆ one │ - │ 2 ┆ two │ - └─────┴─────┘ - - ... is equivalent with: - - >>> case_statement = pt.sql.Case( - ... on_column="a", - ... 
mapping={1: "one", 2: "two"}, - ... default="three", - ... as_column="b", - ... ) - >>> relation.select(f"*, {case_statement}").order(by="a").to_df() - shape: (2, 2) - ┌─────┬─────┐ - │ a ┆ b │ - │ --- ┆ --- │ - │ i64 ┆ str │ - ╞═════╪═════╡ - │ 1 ┆ one │ - │ 2 ┆ two │ - └─────┴─────┘ - """ - - case_statement = sql.Case( - on_column=from_column, - mapping=mapping, - default=default, - as_column=to_column, - ) - new_relation = self._relation.project(f"*, {case_statement}") - return self._wrap(relation=new_relation, schema_change=True) - - def cast( - self: RelationType, - model: Optional[ModelType] = None, - strict: bool = False, - include: Optional[Collection[str]] = None, - exclude: Optional[Collection[str]] = None, - ) -> RelationType: - """ - Cast the columns of the relation to types compatible with the associated model. - - The associated model must either be set by invoking - :ref:`Relation.set_model() ` or provided with the - ``model`` parameter. - - Any columns of the relation that are not part of the given model schema will be - left as-is. - - Args: - model: If :ref:`Relation.set_model() ` has not - been invoked or is intended to be overwritten. - strict: If set to ``False``, columns which are technically compliant with - the specified field type, will not be casted. For example, a column - annotated with ``int`` is technically compliant with ``SMALLINT``, even - if ``INTEGER`` is the default SQL type associated with ``int``-annotated - fields. If ``strict`` is set to ``True``, the resulting dtypes will - be forced to the default dtype associated with each python type. - include: If provided, only the given columns will be casted. - exclude: If provided, the given columns will `not` be casted. - - Returns: - New relation where the columns have been casted according to the model - schema. - - Examples: - >>> import patito as pt - >>> class Schema(pt.Model): - ... float_column: float - ... - >>> relation = pt.duckdb.Relation("select 1 as float_column") - >>> relation.types["float_column"] - INTEGER - >>> relation.cast(model=Schema).types["float_column"] - DOUBLE - - >>> relation = pt.duckdb.Relation("select 1::FLOAT as float_column") - >>> relation.cast(model=Schema).types["float_column"] - FLOAT - >>> relation.cast(model=Schema, strict=True).types["float_column"] - DOUBLE - - >>> class Schema(pt.Model): - ... column_1: float - ... column_2: float - ... - >>> relation = pt.duckdb.Relation( - ... "select 1 as column_1, 2 as column_2" - ... ).set_model(Schema) - >>> relation.types - {'column_1': INTEGER, 'column_2': INTEGER} - >>> relation.cast(include=["column_1"]).types - {'column_1': DOUBLE, 'column_2': INTEGER} - >>> relation.cast(exclude=["column_1"]).types - {'column_1': INTEGER, 'column_2': DOUBLE} - """ - if model is not None: - relation = self.set_model(model) - schema = model - elif self.model is not None: - relation = self - schema = cast(ModelType, self.model) - else: - class_name = self.__class__.__name__ - raise TypeError( - f"{class_name}.cast() invoked without " - f"{class_name}.model having been set! " - f"You should invoke {class_name}.set_model() first " - "or explicitly provide a model to .cast()." - ) - - if include is not None and exclude is not None: - raise ValueError( - "Both include and exclude provided to " - f"{self.__class__.__name__}.cast()!" 
- ) - elif include is not None: - include = set(include) - elif exclude is not None: - include = set(relation.columns) - set(exclude) - else: - include = set(relation.columns) - - new_columns = [] - for column, current_type in relation.types.items(): - if column not in schema.columns: - new_columns.append(column) - elif column in include and ( - strict or current_type not in schema.valid_sql_types[column] - ): - new_type = schema.sql_types[column] - new_columns.append(f"{column}::{new_type} as {column}") - else: - new_columns.append(column) - return cast(RelationType, self.select(*new_columns)) - - def coalesce( - self: RelationType, - **column_expressions: Union[str, int, float], - ) -> RelationType: - """ - Replace null-values in given columns with respective values. - - For example, ``coalesce(column_name=value)`` is compiled to: - ``f"coalesce({column_name}, {repr(value)}) as column_name"`` in the resulting - SQL. - - Args: - column_expressions: Keywords indicate which columns to coalesce, while the - string representation of the respective arguments are used as the - null-replacement. - - Return: - Relation: Relation where values have been filled in for nulls in the given - columns. - - Examples: - >>> import patito as pt - >>> df = pt.DataFrame( - ... { - ... "a": [1, None, 3], - ... "b": ["four", "five", None], - ... "c": [None, 8.0, 9.0], - ... } - ... ) - >>> relation = pt.duckdb.Relation(df) - >>> relation.coalesce(a=2, b="six").to_df() - shape: (3, 3) - ┌─────┬──────┬──────┐ - │ a ┆ b ┆ c │ - │ --- ┆ --- ┆ --- │ - │ i64 ┆ str ┆ f64 │ - ╞═════╪══════╪══════╡ - │ 1 ┆ four ┆ null │ - │ 2 ┆ five ┆ 8.0 │ - │ 3 ┆ six ┆ 9.0 │ - └─────┴──────┴──────┘ - """ - projections = [] - for column in self.columns: - if column in column_expressions: - expression = column_expressions[column] - projections.append(f"coalesce({column}, {expression!r}) as {column}") - else: - projections.append(column) - return cast(RelationType, self.select(*projections)) - - @property - def columns(self) -> List[str]: - """ - Return the columns of the relation as a list of strings. - - Examples: - >>> import patito as pt - >>> pt.duckdb.Relation("select 1 as a, 2 as b").columns - ['a', 'b'] - """ - # Under certain specific circumstances columns are suffixed with - # :1, which need to be removed from the column name. - return [column.partition(":")[0] for column in self._relation.columns] - - def count(self) -> int: - """ - Return the number of rows in the given relation. - - Returns: - Number of rows in the relation as an integer. - - Examples: - >>> import patito as pt - >>> relation = pt.duckdb.Relation("select 1 as a") - >>> relation.count() - 1 - >>> (relation + relation).count() - 2 - - The :ref:`Relation.__len__()` method invokes - ``Relation.count()`` under the hood, and is equivalent: - - >>> len(relation) - 1 - >>> len(relation + relation) - 2 - """ - return cast(Tuple[int], self._relation.aggregate("count(*)").fetchone())[0] - - def create_table(self: RelationType, name: str) -> RelationType: - """ - Create new database table based on relation. - - If ``self.model`` is set with - :ref:`Relation.set_model()`, then the model is used - to infer the table schema. Otherwise, a permissive table schema is created based - on the relation data. - - Returns: - Relation: A relation pointing to the newly created table. 
- - Examples: - >>> from typing import Literal - >>> import patito as pt - - >>> df = pt.DataFrame({"enum_column": ["A", "A", "B"]}) - >>> relation = pt.duckdb.Relation(df) - >>> relation.create_table("permissive_table").types - {'enum_column': VARCHAR} - - >>> class TableSchema(pt.Model): - ... enum_column: Literal["A", "B", "C"] - ... - >>> relation.set_model(TableSchema).create_table("strict_table").types - {'enum_column': enum__7ba49365cc1b0fd57e61088b3bc9aa25} - """ - if self.model is not None: - self.database.create_table(name=name, model=self.model) - self.insert_into(table=name) - else: - self._relation.create(table_name=name) - return cast(RelationType, self.database.table(name)) - - def create_view( - self: RelationType, - name: str, - replace: bool = False, - ) -> RelationType: - """ - Create new database view based on relation. - - Returns: - Relation: A relation pointing to the newly created view. - - Examples: - >>> import patito as pt - >>> db = pt.duckdb.Database() - >>> df = pt.DataFrame({"column": ["A", "A", "B"]}) - >>> relation = db.to_relation(df) - >>> relation.create_view("my_view") - >>> db.query("select * from my_view").to_df() - shape: (3, 1) - ┌────────┐ - │ column │ - │ --- │ - │ str │ - ╞════════╡ - │ A │ - │ A │ - │ B │ - └────────┘ - """ - self._relation.create_view(view_name=name, replace=replace) - return cast(RelationType, self.database.view(name)) - - def drop(self, *columns: str) -> Relation: - """ - Remove specified column(s) from relation. - - Args: - columns (str): Any number of string column names to be dropped. - - Examples: - >>> import patito as pt - >>> relation = pt.duckdb.Relation("select 1 as a, 2 as b, 3 as c") - >>> relation.columns - ['a', 'b', 'c'] - >>> relation.drop("c").columns - ['a', 'b'] - >>> relation.drop("b", "c").columns - ['a'] - """ - new_columns = self.columns.copy() - for column in columns: - new_columns.remove(column) - return self[new_columns] - - def distinct(self: RelationType) -> RelationType: - """ - Drop all duplicate rows of the relation. - - Example: - >>> import patito as pt - >>> df = pt.DataFrame( - ... [[1, 2, 3], [1, 2, 3], [3, 2, 1]], - ... schema=["a", "b", "c"], - ... orient="row", - ... ) - >>> relation = pt.duckdb.Relation(df) - >>> relation.to_df() - shape: (3, 3) - ┌─────┬─────┬─────┐ - │ a ┆ b ┆ c │ - │ --- ┆ --- ┆ --- │ - │ i64 ┆ i64 ┆ i64 │ - ╞═════╪═════╪═════╡ - │ 1 ┆ 2 ┆ 3 │ - │ 1 ┆ 2 ┆ 3 │ - │ 3 ┆ 2 ┆ 1 │ - └─────┴─────┴─────┘ - >>> relation.distinct().to_df() - shape: (2, 3) - ┌─────┬─────┬─────┐ - │ a ┆ b ┆ c │ - │ --- ┆ --- ┆ --- │ - │ i64 ┆ i64 ┆ i64 │ - ╞═════╪═════╪═════╡ - │ 1 ┆ 2 ┆ 3 │ - │ 3 ┆ 2 ┆ 1 │ - └─────┴─────┴─────┘ - """ - return self._wrap(self._relation.distinct(), schema_change=False) - - def except_(self: RelationType, other: RelationSource) -> RelationType: - """ - Remove all rows that can be found in the other other relation. - - Args: - other: Another relation or something that can be casted to a relation. - - Returns: - New relation without the rows that can be found in the other relation. - - Example: - >>> import patito as pt - >>> relation_123 = pt.duckdb.Relation( - ... "select 1 union select 2 union select 3" - ... 
) - >>> relation_123.order(by="1").to_df() - shape: (3, 1) - ┌─────┐ - │ 1 │ - │ --- │ - │ i64 │ - ╞═════╡ - │ 1 │ - │ 2 │ - │ 3 │ - └─────┘ - >>> relation_2 = pt.duckdb.Relation("select 2") - >>> relation_2.to_df() - shape: (1, 1) - ┌─────┐ - │ 2 │ - │ --- │ - │ i64 │ - ╞═════╡ - │ 2 │ - └─────┘ - >>> relation_123.except_(relation_2).order(by="1").to_df() - shape: (2, 1) - ┌─────┐ - │ 1 │ - │ --- │ - │ i64 │ - ╞═════╡ - │ 1 │ - │ 3 │ - └─────┘ - """ - return self._wrap( - self._relation.except_(self.database.to_relation(other)._relation), - schema_change=False, - ) - - def execute(self) -> duckdb.DuckDBPyRelation: - """ - Execute built relation query and return result object. - - Returns: - A native ``duckdb.DuckDBPyResult`` object representing the executed query. - - Examples: - >>> import patito as pt - >>> relation = pt.duckdb.Relation( - ... "select 1 as a, 2 as b union select 3 as a, 4 as b" - ... ) - >>> result = relation.aggregate("sum(a)", group_by="").execute() - >>> result.description - [('sum(a)', 'NUMBER', None, None, None, None, None)] - >>> result.fetchall() - [(4,)] - """ - # A star-select is here performed in order to work around certain DuckDB bugs - return self._relation.project("*").execute() - - def get(self, *filters: str, **equalities: Union[str, int, float]) -> ModelType: - """ - Fetch the single row that matches the given filter(s). - - If you expect a relation to already return one row, you can use get() without - any arguments to return that row. - - Raises: - RuntimeError: RuntimeError is thrown if not exactly one single row matches - the given filter. - - Args: - filters (str): A conjunction of SQL where clauses. - equalities (Any): A conjunction of SQL equality clauses. The keyword name - is the column and the parameter is the value of the equality. - - Returns: - Model: A Patito model representing the given row. - - Examples: - >>> import patito as pt - >>> import polars as pl - >>> df = pt.DataFrame({"product_id": [1, 2, 3], "price": [10, 10, 20]}) - >>> relation = pt.duckdb.Relation(df).set_alias("my_relation") - - The ``.get()`` method will by default return a dynamically constructed - Patito model if no model has been associated with the given relation: - - >>> relation.get(product_id=1) - my_relation(product_id=1, price=10) - - If a Patito model has been associated with the relation, by the use of - :ref:`Relation.set_model()`, then the given model - will be used to represent the return type: - - >>> class Product(pt.Model): - ... product_id: int = pt.Field(unique=True) - ... price: float - ... - >>> relation.set_model(Product).get(product_id=1) - Product(product_id=1, price=10.0) - - You can invoke ``.get()`` without any arguments on relations containing - exactly one row: - - >>> relation.filter(product_id=1).get() - my_relation(product_id=1, price=10) - - If the given predicate matches multiple rows a ``MultipleRowsReturned`` - exception will be raised: - - >>> try: - ... relation.get(price=10) - ... except pt.exceptions.MultipleRowsReturned as e: - ... print(e) - ... - Relation.get(price=10) returned 2 rows! - - If the given predicate matches zero rows a ``RowDoesNotExist`` exception - will be raised: - - >>> try: - ... relation.get(price=0) - ... except pt.exceptions.RowDoesNotExist as e: - ... print(e) - ... - Relation.get(price=0) returned 0 rows! 
- """ - if filters or equalities: - relation = self.filter(*filters, **equalities) - else: - relation = self - result = relation.execute() - row = result.fetchone() - if row is None or result.fetchone() is not None: - args = [repr(f) for f in filters] - args.extend(f"{key}={value!r}" for key, value in equalities.items()) - args_string = ",".join(args) - - num_rows = relation.count() - if num_rows == 0: - raise RowDoesNotExist(f"Relation.get({args_string}) returned 0 rows!") - else: - raise MultipleRowsReturned( - f"Relation.get({args_string}) returned {num_rows} rows!" - ) - return self._to_model(row=row) - - def _to_model(self, row: tuple) -> ModelType: - """ - Cast row tuple to proper return type. - - If self.model is set, either by a class variable of a subclass or by the - invocation of Relation.set_model(), that type is used to construct the return - value. Otherwise, a pydantic model is dynamically created based on the column - schema of the relation. - """ - kwargs = {column: value for column, value in zip(self.columns, row)} - if self.model: - return self.model(**kwargs) - else: - RowModel = create_pydantic_model(relation=self._relation) - return cast( - ModelType, - RowModel(**kwargs), - ) - - def filter( - self: RelationType, - *filters: str, - **equalities: Union[str, int, float], - ) -> RelationType: - """ - Return subset of rows of relation that satisfy the given predicates. - - The method returns self if no filters are provided. - - Args: - filters: A conjunction of SQL ``WHERE`` clauses. - equalities: A conjunction of SQL equality clauses. The keyword name - is the column and the parameter is the value of the equality. - - Returns: - Relation: A new relation where all rows satisfy the given criteria. - - Examples: - >>> import patito as pt - >>> df = pt.DataFrame( - ... { - ... "number": [1, 2, 3, 4], - ... "string": ["A", "A", "B", "B"], - ... } - ... ) - >>> relation = pt.duckdb.Relation(df) - >>> relation.filter("number % 2 = 0").to_df() - shape: (2, 2) - ┌────────┬────────┐ - │ number ┆ string │ - │ --- ┆ --- │ - │ i64 ┆ str │ - ╞════════╪════════╡ - │ 2 ┆ A │ - │ 4 ┆ B │ - └────────┴────────┘ - - >>> relation.filter(number=1, string="A").to_df() - shape: (1, 2) - ┌────────┬────────┐ - │ number ┆ string │ - │ --- ┆ --- │ - │ i64 ┆ str │ - ╞════════╪════════╡ - │ 1 ┆ A │ - └────────┴────────┘ - """ - if not filters and not equalities: - return self - - clauses: List[str] = [] - if filters: - clauses.extend(filters) - if equalities: - clauses.extend(f"{key}={value!r}" for key, value in equalities.items()) - filter_string = " and ".join(f"({clause})" for clause in clauses) - return self._wrap(self._relation.filter(filter_string), schema_change=False) - - def join( - self: RelationType, - other: RelationSource, - *, - on: str, - how: Literal["inner", "left"] = "inner", - ) -> RelationType: - """ - Join relation with other relation source based on condition. - - See :ref:`duckdb.Relation.inner_join() ` and - :ref:`Relation.left_join() ` for alternative method - shortcuts instead of using ``how``. - - Args: - other: A source which can be casted to a ``Relation`` object, and be used - as the right table in the join. - on: Join condition following the ``INNER JOIN ... ON`` in the SQL query. - how: Either ``"left"`` or ``"inner"`` for what type of SQL join operation to - perform. - - Returns: - Relation: New relation based on the joined relations. - - Example: - >>> import patito as pt - >>> products_df = pt.DataFrame( - ... { - ... 
"product_name": ["apple", "banana", "oranges"], - ... "supplier_id": [2, 1, 3], - ... } - ... ) - >>> products = pt.duckdb.Relation(products_df) - >>> supplier_df = pt.DataFrame( - ... { - ... "id": [1, 2], - ... "supplier_name": ["Banana Republic", "Applies Inc."], - ... } - ... ) - >>> suppliers = pt.duckdb.Relation(supplier_df) - >>> products.set_alias("p").join( - ... suppliers.set_alias("s"), - ... on="p.supplier_id = s.id", - ... how="inner", - ... ).to_df() - shape: (2, 4) - ┌──────────────┬─────────────┬─────┬─────────────────┐ - │ product_name ┆ supplier_id ┆ id ┆ supplier_name │ - │ --- ┆ --- ┆ --- ┆ --- │ - │ str ┆ i64 ┆ i64 ┆ str │ - ╞══════════════╪═════════════╪═════╪═════════════════╡ - │ apple ┆ 2 ┆ 2 ┆ Applies Inc. │ - │ banana ┆ 1 ┆ 1 ┆ Banana Republic │ - └──────────────┴─────────────┴─────┴─────────────────┘ - - >>> products.set_alias("p").join( - ... suppliers.set_alias("s"), - ... on="p.supplier_id = s.id", - ... how="left", - ... ).to_df() - shape: (3, 4) - ┌──────────────┬─────────────┬──────┬─────────────────┐ - │ product_name ┆ supplier_id ┆ id ┆ supplier_name │ - │ --- ┆ --- ┆ --- ┆ --- │ - │ str ┆ i64 ┆ i64 ┆ str │ - ╞══════════════╪═════════════╪══════╪═════════════════╡ - │ apple ┆ 2 ┆ 2 ┆ Applies Inc. │ - │ banana ┆ 1 ┆ 1 ┆ Banana Republic │ - │ oranges ┆ 3 ┆ null ┆ null │ - └──────────────┴─────────────┴──────┴─────────────────┘ - """ - return self._wrap( - self._relation.join( - self.database.to_relation(other)._relation, condition=on, how=how - ), - schema_change=True, - ) - - def inner_join(self: RelationType, other: RelationSource, on: str) -> RelationType: - """ - Inner join relation with other relation source based on condition. - - Args: - other: A source which can be casted to a ``Relation`` object, and be used - as the right table in the join. - on: Join condition following the ``INNER JOIN ... ON`` in the SQL query. - - Returns: - Relation: New relation based on the joined relations. - - Example: - >>> import patito as pt - >>> products_df = pt.DataFrame( - ... { - ... "product_name": ["apple", "banana", "oranges"], - ... "supplier_id": [2, 1, 3], - ... } - ... ) - >>> products = pt.duckdb.Relation(products_df) - >>> supplier_df = pt.DataFrame( - ... { - ... "id": [1, 2], - ... "supplier_name": ["Banana Republic", "Applies Inc."], - ... } - ... ) - >>> suppliers = pt.duckdb.Relation(supplier_df) - >>> products.set_alias("p").inner_join( - ... suppliers.set_alias("s"), - ... on="p.supplier_id = s.id", - ... ).to_df() - shape: (2, 4) - ┌──────────────┬─────────────┬─────┬─────────────────┐ - │ product_name ┆ supplier_id ┆ id ┆ supplier_name │ - │ --- ┆ --- ┆ --- ┆ --- │ - │ str ┆ i64 ┆ i64 ┆ str │ - ╞══════════════╪═════════════╪═════╪═════════════════╡ - │ apple ┆ 2 ┆ 2 ┆ Applies Inc. │ - │ banana ┆ 1 ┆ 1 ┆ Banana Republic │ - └──────────────┴─────────────┴─────┴─────────────────┘ - """ - return self._wrap( - self._relation.join( - other_rel=self.database.to_relation(other)._relation, - condition=on, - how="inner", - ), - schema_change=True, - ) - - def left_join(self: RelationType, other: RelationSource, on: str) -> RelationType: - """ - Left join relation with other relation source based on condition. - - Args: - other: A source which can be casted to a Relation object, and be used as - the right table in the join. - on: Join condition following the ``LEFT JOIN ... ON`` in the SQL query. - - Returns: - Relation: New relation based on the joined tables. - - Example: - >>> import patito as pt - >>> products_df = pt.DataFrame( - ... { - ... 
"product_name": ["apple", "banana", "oranges"], - ... "supplier_id": [2, 1, 3], - ... } - ... ) - >>> products = pt.duckdb.Relation(products_df) - >>> supplier_df = pt.DataFrame( - ... { - ... "id": [1, 2], - ... "supplier_name": ["Banana Republic", "Applies Inc."], - ... } - ... ) - >>> suppliers = pt.duckdb.Relation(supplier_df) - >>> products.set_alias("p").left_join( - ... suppliers.set_alias("s"), - ... on="p.supplier_id = s.id", - ... ).to_df() - shape: (3, 4) - ┌──────────────┬─────────────┬──────┬─────────────────┐ - │ product_name ┆ supplier_id ┆ id ┆ supplier_name │ - │ --- ┆ --- ┆ --- ┆ --- │ - │ str ┆ i64 ┆ i64 ┆ str │ - ╞══════════════╪═════════════╪══════╪═════════════════╡ - │ apple ┆ 2 ┆ 2 ┆ Applies Inc. │ - │ banana ┆ 1 ┆ 1 ┆ Banana Republic │ - │ oranges ┆ 3 ┆ null ┆ null │ - └──────────────┴─────────────┴──────┴─────────────────┘ - """ - return self._wrap( - self._relation.join( - other_rel=self.database.to_relation(other)._relation, - condition=on, - how="left", - ), - schema_change=True, - ) - - def limit(self: RelationType, n: int, *, offset: int = 0) -> RelationType: - """ - Remove all but the first n rows. - - Args: - n: The number of rows to keep. - offset: Disregard the first ``offset`` rows before starting to count which - rows to keep. - - Returns: - New relation with only n rows. - - Example: - >>> import patito as pt - >>> relation = ( - ... pt.duckdb.Relation("select 1 as column") - ... + pt.duckdb.Relation("select 2 as column") - ... + pt.duckdb.Relation("select 3 as column") - ... + pt.duckdb.Relation("select 4 as column") - ... ) - >>> relation.limit(2).to_df() - shape: (2, 1) - ┌────────┐ - │ column │ - │ --- │ - │ i64 │ - ╞════════╡ - │ 1 │ - │ 2 │ - └────────┘ - >>> relation.limit(2, offset=2).to_df() - shape: (2, 1) - ┌────────┐ - │ column │ - │ --- │ - │ i64 │ - ╞════════╡ - │ 3 │ - │ 4 │ - └────────┘ - """ - return self._wrap(self._relation.limit(n=n, offset=offset), schema_change=False) - - def order(self: RelationType, by: Union[str, Iterable[str]]) -> RelationType: - """ - Change the order of the rows of the relation. - - Args: - by: An ``ORDER BY`` SQL expression such as ``"age DESC"`` or - ``("age DESC", "name ASC")``. - - Returns: - New relation where the rows have been ordered according to ``by``. - - Example: - >>> import patito as pt - >>> df = pt.DataFrame( - ... { - ... "name": ["Alice", "Bob", "Charles", "Diana"], - ... "age": [20, 20, 30, 35], - ... } - ... ) - >>> df - shape: (4, 2) - ┌─────────┬─────┐ - │ name ┆ age │ - │ --- ┆ --- │ - │ str ┆ i64 │ - ╞═════════╪═════╡ - │ Alice ┆ 20 │ - │ Bob ┆ 20 │ - │ Charles ┆ 30 │ - │ Diana ┆ 35 │ - └─────────┴─────┘ - >>> relation = pt.duckdb.Relation(df) - >>> relation.order(by="age desc").to_df() - shape: (4, 2) - ┌─────────┬─────┐ - │ name ┆ age │ - │ --- ┆ --- │ - │ str ┆ i64 │ - ╞═════════╪═════╡ - │ Diana ┆ 35 │ - │ Charles ┆ 30 │ - │ Alice ┆ 20 │ - │ Bob ┆ 20 │ - └─────────┴─────┘ - >>> relation.order(by=["age desc", "name desc"]).to_df() - shape: (4, 2) - ┌─────────┬─────┐ - │ name ┆ age │ - │ --- ┆ --- │ - │ str ┆ i64 │ - ╞═════════╪═════╡ - │ Diana ┆ 35 │ - │ Charles ┆ 30 │ - │ Bob ┆ 20 │ - │ Alice ┆ 20 │ - └─────────┴─────┘ - """ - order_expr = by if isinstance(by, str) else ", ".join(by) - return self._wrap( - self._relation.order(order_expr=order_expr), - schema_change=False, - ) - - def insert_into( - self: RelationType, - table: str, - ) -> RelationType: - """ - Insert all rows of the relation into a given table. 
- - The relation must contain all the columns present in the target table. - Extra columns are ignored and the column order is automatically matched - with the target table. - - Args: - table: Name of table for which to insert values into. - - Returns: - Relation: The original relation, i.e. ``self``. - - Examples: - >>> import patito as pt - >>> db = pt.duckdb.Database() - >>> db.to_relation("select 1 as a").create_table("my_table") - >>> db.table("my_table").to_df() - shape: (1, 1) - ┌─────┐ - │ a │ - │ --- │ - │ i64 │ - ╞═════╡ - │ 1 │ - └─────┘ - >>> db.to_relation("select 2 as a").insert_into("my_table") - >>> db.table("my_table").to_df() - shape: (2, 1) - ┌─────┐ - │ a │ - │ --- │ - │ i64 │ - ╞═════╡ - │ 1 │ - │ 2 │ - └─────┘ - """ - table_relation = self.database.table(table) - missing_columns = set(table_relation.columns) - set(self.columns) - if missing_columns: - raise TypeError( - f"Relation is missing column(s) {missing_columns} " - f"in order to be inserted into table '{table}'!", - ) - - reordered_relation = self[table_relation.columns] - reordered_relation._relation.insert_into(table_name=table) - return self - - def intersect(self: RelationType, other: RelationSource) -> RelationType: - """ - Return a new relation containing the rows that are present in both relations. - - This is a set operation which will remove duplicate rows as well. - - Args: - other: Another relation with the same column names. - - Returns: - Relation[Model]: A new relation with only those rows that are present in - both relations. - - Example: - >>> import patito as pt - >>> df1 = pt.DataFrame({"a": [1, 1, 2], "b": [1, 1, 2]}) - >>> df2 = pt.DataFrame({"a": [1, 1, 3], "b": [1, 1, 3]}) - >>> pt.duckdb.Relation(df1).intersect(pt.duckdb.Relation(df2)).to_df() - shape: (1, 2) - ┌─────┬─────┐ - │ a ┆ b │ - │ --- ┆ --- │ - │ i64 ┆ i64 │ - ╞═════╪═════╡ - │ 1 ┆ 1 │ - └─────┴─────┘ - """ - other = self.database.to_relation(other) - return self._wrap( - self._relation.intersect(other._relation), - schema_change=False, - ) - - def select( - self, - *projections: Union[str, int, float], - **named_projections: Union[str, int, float], - ) -> Relation: - """ - Return relation based on one or more SQL ``SELECT`` projections. - - Keyword arguments are converted into ``{arg} as {keyword}`` in the executed SQL - query. - - Args: - *projections: One or more strings representing SQL statements to be - selected. For example ``"2"`` or ``"another_column"``. - **named_projections: One ore more keyword arguments where the keyword - specifies the name of the new column and the value is an SQL statement - defining the content of the new column. For example - ``new_column="2 * another_column"``. - - Examples: - >>> import patito as pt - >>> db = pt.duckdb.Database() - >>> relation = db.to_relation(pt.DataFrame({"original_column": [1, 2, 3]})) - >>> relation.select("*").to_df() - shape: (3, 1) - ┌─────────────────┐ - │ original_column │ - │ --- │ - │ i64 │ - ╞═════════════════╡ - │ 1 │ - │ 2 │ - │ 3 │ - └─────────────────┘ - >>> relation.select("*", multiplied_column="2 * original_column").to_df() - shape: (3, 2) - ┌─────────────────┬───────────────────┐ - │ original_column ┆ multiplied_column │ - │ --- ┆ --- │ - │ i64 ┆ i64 │ - ╞═════════════════╪═══════════════════╡ - │ 1 ┆ 2 │ - │ 2 ┆ 4 │ - │ 3 ┆ 6 │ - └─────────────────┴───────────────────┘ - """ - # We expand '*' to an explicit list of columns in order to support redefining - # columns within the star expressed columns. 
- expanded_projections: list = list(projections) - try: - star_index = projections.index("*") - if named_projections: - # Allow explicitly named projections to overwrite star-selected columns - expanded_projections[star_index : star_index + 1] = [ - column for column in self.columns if column not in named_projections - ] - else: - expanded_projections[star_index : star_index + 1] = self.columns - except ValueError: - pass - - projection = ", ".join( - expanded_projections - + list( # pyright: ignore - f"{expression} as {column_name}" - for column_name, expression in named_projections.items() - ) - ) - try: - relation = self._relation.project(projection) - except RuntimeError as exc: # pragma: no cover - # We might get a RunTime error if the enum type has not - # been created yet. If so, we create all enum types for - # this model. - if self.model is not None and _is_missing_enum_type_exception(exc): - self.database.create_enum_types(model=self.model) - relation = self._relation.project(projection) - else: - raise exc - return self._wrap(relation=relation, schema_change=True) - - def rename(self, **columns: str) -> Relation: - """ - Rename columns as specified. - - Args: - **columns: A set of keyword arguments where the keyword is the old column - name and the value is the new column name. - - Raises: - ValueError: If any of the given keywords do not exist as columns in the - relation. - - Examples: - >>> import patito as pt - >>> relation = pt.duckdb.Relation("select 1 as a, 2 as b") - >>> relation.rename(b="c").to_df().select(["a", "c"]) - shape: (1, 2) - ┌─────┬─────┐ - │ a ┆ c │ - │ --- ┆ --- │ - │ i64 ┆ i64 │ - ╞═════╪═════╡ - │ 1 ┆ 2 │ - └─────┴─────┘ - """ - existing_columns = set(self.columns) - missing = set(columns.keys()) - set(existing_columns) - if missing: - raise ValueError( - f"Column '{missing.pop()}' can not be renamed as it does not exist. " - f"The columns of the relation are: {', '.join(existing_columns)}." - ) - # If we rename a column to overwrite another existing one, the column should - # be overwritten. - existing_columns = set(existing_columns) - set(columns.values()) - relation = self._relation.project( - ", ".join( - f"{column} as {columns.get(column, column)}" - for column in existing_columns - ) - ) - return self._wrap(relation=relation, schema_change=True) - - def set_alias(self: RelationType, name: str) -> RelationType: - """ - Set SQL alias for the given relation to be used in further queries. - - Args: - name: The new alias for the given relation. - - Returns: - Relation: A new relation containing the same query but addressable with the - new alias. - - Example: - >>> import patito as pt - >>> relation_1 = pt.duckdb.Relation("select 1 as a, 2 as b") - >>> relation_2 = pt.duckdb.Relation("select 1 as a, 3 as c") - >>> relation_1.set_alias("x").inner_join( - ... relation_2.set_alias("y"), - ... on="x.a = y.a", - ... ).select("x.a", "y.a", "b", "c").to_df() - shape: (1, 4) - ┌─────┬─────┬─────┬─────┐ - │ a ┆ a:1 ┆ b ┆ c │ - │ --- ┆ --- ┆ --- ┆ --- │ - │ i64 ┆ i64 ┆ i64 ┆ i64 │ - ╞═════╪═════╪═════╪═════╡ - │ 1 ┆ 1 ┆ 2 ┆ 3 │ - └─────┴─────┴─────┴─────┘ - """ - return self._wrap( - self._relation.set_alias(name), - schema_change=False, - ) - - def set_model(self, model): # type: ignore[no-untyped-def] # noqa: ANN - """ - Associate a give Patito model with the relation. 
- - The returned relation has an associated ``.model`` attribute which can in turn - be used by several methods such as :ref:`Relation.get()`, - :ref:`Relation.create_table()`, and - :ref:`Relation.__iter__`. - - Args: - model: A Patito Model class specifying the intended schema of the relation. - - Returns: - Relation[model]: A new relation with the associated model. - - Example: - >>> from typing import Literal - >>> import patito as pt - >>> class MySchema(pt.Model): - ... float_column: float - ... enum_column: Literal["A", "B", "C"] - ... - >>> relation = pt.duckdb.Relation( - ... "select 1 as float_column, 'A' as enum_column" - ... ) - >>> relation.get() - query_relation(float_column=1, enum_column='A') - >>> relation.set_model(MySchema).get() - MySchema(float_column=1.0, enum_column='A') - >>> relation.create_table("unmodeled_table").types - {'float_column': INTEGER, 'enum_column': VARCHAR} - >>> relation.set_model(MySchema).create_table("modeled_table").types - {'float_column': DOUBLE, - 'enum_column': enum__7ba49365cc1b0fd57e61088b3bc9aa25} - """ - # We are not able to annotate the generic instance of type(self)[type(model)] - # due to the lack of higher-kinded generics in python as of this writing. - # See: https://github.com/python/typing/issues/548 - # This cast() will be wrong for sub-classes of Relation... - return cast( - Relation[model], - type(self)( - derived_from=self._relation, - database=self.database, - model=model, - ), - ) - - @property - def types(self): # type: ignore[no-untyped-def] # noqa - """ - Return the SQL types of all the columns of the given relation. - - Returns: - dict[str, str]: A dictionary where the keys are the column names and the - values are SQL types as strings. - - Examples: - >>> import patito as pt - >>> pt.duckdb.Relation("select 1 as a, 'my_value' as b").types - {'a': INTEGER, 'b': VARCHAR} - """ - return dict(zip(self.columns, self._relation.types)) - - def to_pandas(self) -> "pd.DataFrame": - """ - Return a pandas DataFrame representation of relation object. - - Returns: A ``pandas.DataFrame`` object containing all the data of the relation. - - Example: - >>> import patito as pt - >>> pt.duckdb.Relation("select 1 as column union select 2 as column").order( - ... by="1" - ... ).to_pandas() - column - 0 1 - 1 2 - """ - return self._relation.to_df() - - def to_df(self) -> DataFrame: - """ - Return a polars DataFrame representation of relation object. - - Returns: A ``patito.DataFrame`` object which inherits from ``polars.DataFrame``. - - Example: - >>> import patito as pt - >>> pt.duckdb.Relation("select 1 as column union select 2 as column").order( - ... by="1" - ... ).to_df() - shape: (2, 1) - ┌────────┐ - │ column │ - │ --- │ - │ i64 │ - ╞════════╡ - │ 1 │ - │ 2 │ - └────────┘ - """ - # Here we do a star-select to work around certain weird issues with DuckDB - self._relation = self._relation.project("*") - arrow_table = cast(pa.lib.Table, self._relation.to_arrow_table()) - try: - # We cast `INTEGER`-typed columns to `pl.Int64` when converting to Polars - # because polars is much more eager to store integer Series as 64-bit - # integers. Otherwise there must be done a lot of manual casting whenever - # you cross the boundary between DuckDB and polars. - return DataFrame._from_arrow(arrow_table).with_columns( - pl.col(pl.Int32).cast(pl.Int64) - ) - except (pa.ArrowInvalid, pl.ArrowError): # pragma: no cover - # Empty relations with enum columns can sometimes produce errors. - # As a last-ditch effort, we convert such columns to VARCHAR. 
- casted_columns = [ - f"{field.name}::VARCHAR as {field.name}" - if isinstance(field.type, pa.DictionaryType) - else field.name - for field in arrow_table.schema - ] - non_enum_relation = self._relation.project(", ".join(casted_columns)) - arrow_table = non_enum_relation.to_arrow_table() - return DataFrame._from_arrow(arrow_table).with_columns( - pl.col(pl.Int32).cast(pl.Int64) - ) - - def to_series(self) -> pl.Series: - """ - Convert the given relation to a polars Series. - - Raises: - TypeError: If the given relation does not contain exactly one column. - - Returns: A ``polars.Series`` object containing the data of the relation. - - Example: - >>> import patito as pt - >>> relation = pt.duckdb.Relation("select 1 as a union select 2 as a") - >>> relation.order(by="a").to_series() - shape: (2,) - Series: 'a' [i32] - [ - 1 - 2 - ] - """ - if len(self._relation.columns) != 1: - raise TypeError( - f"{self.__class__.__name__}.to_series() was invoked on a relation with " - f"{len(self._relation.columns)} columns, while exactly 1 is required!" - ) - dataframe: DataFrame = DataFrame._from_arrow(self._relation.to_arrow_table()) - return dataframe.to_series(index=0).alias(name=self.columns[0]) - - def union(self: RelationType, other: RelationSource) -> RelationType: - """ - Produce a new relation that contains the rows of both relations. - - The ``+`` operator can also be used to union two relations. - - The two relations must have the same column names, but not necessarily in the - same order as reordering of columns is automatically performed, unlike regular - SQL. - - Duplicates are `not` dropped. - - Args: - other: A ``patito.duckdb.Relation`` object or something that can be - *casted* to ``patito.duckdb.Relation``. - See :ref:`Relation`. - - Returns: - New relation containing the rows of both ``self`` and ``other``. - - Raises: - TypeError: If the two relations do not contain the same columns. - - Examples: - >>> import patito as pt - >>> relation_1 = pt.duckdb.Relation("select 1 as a") - >>> relation_2 = pt.duckdb.Relation("select 2 as a") - >>> relation_1.union(relation_2).to_df() - shape: (2, 1) - ┌─────┐ - │ a │ - │ --- │ - │ i64 │ - ╞═════╡ - │ 1 │ - │ 2 │ - └─────┘ - - >>> (relation_1 + relation_2).to_df() - shape: (2, 1) - ┌─────┐ - │ a │ - │ --- │ - │ i64 │ - ╞═════╡ - │ 1 │ - │ 2 │ - └─────┘ - """ - other_relation = self.database.to_relation(other) - if set(self.columns) != set(other_relation.columns): - msg = "Union between relations with different column names is not allowed." - additional_left = set(self.columns) - set(other_relation.columns) - additional_right = set(other_relation.columns) - set(self.columns) - if additional_left: - msg += f" Additional columns in left relation: {additional_left}." - if additional_right: - msg += f" Additional columns in right relation: {additional_right}." - raise TypeError(msg) - if other_relation.columns != self.columns: - reordered_relation = other_relation[self.columns] - else: - reordered_relation = other_relation - unioned_relation = self._relation.union(reordered_relation._relation) - return self._wrap(relation=unioned_relation, schema_change=False) - - def with_columns( - self, - **named_projections: Union[str, int, float], - ) -> Relation: - """ - Return relations with additional columns. - - If the provided columns expressions already exists as a column on the relation, - the given column is overwritten. 
- - Args: - named_projections: A set of column expressions, where the keyword is used - as the column name, while the right-hand argument is a valid SQL - expression. - - Returns: - Relation with the given columns appended, or possibly overwritten. - - Examples: - >>> import patito as pt - >>> db = pt.duckdb.Database() - >>> relation = db.to_relation("select 1 as a, 2 as b") - >>> relation.with_columns(c="a + b").to_df() - shape: (1, 3) - ┌─────┬─────┬─────┐ - │ a ┆ b ┆ c │ - │ --- ┆ --- ┆ --- │ - │ i64 ┆ i64 ┆ i64 │ - ╞═════╪═════╪═════╡ - │ 1 ┆ 2 ┆ 3 │ - └─────┴─────┴─────┘ - """ - return self.select("*", **named_projections) - - def with_missing_defaultable_columns( - self: RelationType, - include: Optional[Iterable[str]] = None, - exclude: Optional[Iterable[str]] = None, - ) -> RelationType: - """ - Add missing defaultable columns filled with the default values of correct type. - - Make sure to invoke :ref:`Relation.set_model()` with - the correct model schema before executing - ``Relation.with_missing_default_columns()``. - - Args: - include: If provided, only fill in default values for missing columns part - of this collection of column names. - exclude: If provided, do `not` fill in default values for missing columns - part of this collection of column names. - - Returns: - Relation: New relation where missing columns with default values according - to the schema have been filled in. - - Example: - >>> import patito as pt - >>> class MyModel(pt.Model): - ... non_default_column: int - ... another_non_default_column: int - ... default_column: int = 42 - ... another_default_column: int = 42 - ... - >>> relation = pt.duckdb.Relation( - ... "select 1 as non_default_column, 2 as default_column" - ... ) - >>> relation.to_df() - shape: (1, 2) - ┌────────────────────┬────────────────┐ - │ non_default_column ┆ default_column │ - │ --- ┆ --- │ - │ i64 ┆ i64 │ - ╞════════════════════╪════════════════╡ - │ 1 ┆ 2 │ - └────────────────────┴────────────────┘ - >>> relation.set_model(MyModel).with_missing_defaultable_columns().to_df() - shape: (1, 3) - ┌────────────────────┬────────────────┬────────────────────────┐ - │ non_default_column ┆ default_column ┆ another_default_column │ - │ --- ┆ --- ┆ --- │ - │ i64 ┆ i64 ┆ i64 │ - ╞════════════════════╪════════════════╪════════════════════════╡ - │ 1 ┆ 2 ┆ 42 │ - └────────────────────┴────────────────┴────────────────────────┘ - """ - if self.model is None: - class_name = self.__class__.__name__ - raise TypeError( - f"{class_name}.with_missing_default_columns() invoked without " - f"{class_name}.model having been set! " - f"You should invoke {class_name}.set_model() first!" 
- ) - elif include is not None and exclude is not None: - raise TypeError("Both include and exclude provided at the same time!") - - missing_columns = set(self.model.columns) - set(self.columns) - defaultable_columns = self.model.defaults.keys() - missing_defaultable_columns = missing_columns & defaultable_columns - - if exclude is not None: - missing_defaultable_columns -= set(exclude) - elif include is not None: - missing_defaultable_columns = missing_defaultable_columns & set(include) - - projection = "*" - for column_name in missing_defaultable_columns: - sql_type = self.model.sql_types[column_name] - default_value = self.model.defaults[column_name] - projection += f", {default_value!r}::{sql_type} as {column_name}" - - try: - relation = self._relation.project(projection) - except Exception as exc: # pragma: no cover - # We might get a RunTime error if the enum type has not - # been created yet. If so, we create all enum types for - # this model. - if _is_missing_enum_type_exception(exc): - self.database.create_enum_types(model=self.model) - relation = self._relation.project(projection) - else: - raise exc - return self._wrap(relation=relation, schema_change=False) - - def with_missing_nullable_columns( - self: RelationType, - include: Optional[Iterable[str]] = None, - exclude: Optional[Iterable[str]] = None, - ) -> RelationType: - """ - Add missing nullable columns filled with correctly typed nulls. - - Make sure to invoke :ref:`Relation.set_model()` with - the correct model schema before executing - ``Relation.with_missing_nullable_columns()``. - - Args: - include: If provided, only fill in null values for missing columns part of - this collection of column names. - exclude: If provided, do `not` fill in null values for missing columns - part of this collection of column names. - - Returns: - Relation: New relation where missing nullable columns have been filled in - with null values. - - Example: - >>> from typing import Optional - >>> import patito as pt - >>> class MyModel(pt.Model): - ... non_nullable_column: int - ... nullable_column: Optional[int] - ... another_nullable_column: Optional[int] - ... - >>> relation = pt.duckdb.Relation("select 1 as nullable_column") - >>> relation.to_df() - shape: (1, 1) - ┌─────────────────┐ - │ nullable_column │ - │ --- │ - │ i64 │ - ╞═════════════════╡ - │ 1 │ - └─────────────────┘ - >>> relation.set_model(MyModel).with_missing_nullable_columns().to_df() - shape: (1, 2) - ┌─────────────────┬─────────────────────────┐ - │ nullable_column ┆ another_nullable_column │ - │ --- ┆ --- │ - │ i64 ┆ i64 │ - ╞═════════════════╪═════════════════════════╡ - │ 1 ┆ null │ - └─────────────────┴─────────────────────────┘ - """ - if self.model is None: - class_name = self.__class__.__name__ - raise TypeError( - f"{class_name}.with_missing_nullable_columns() invoked without " - f"{class_name}.model having been set! " - f"You should invoke {class_name}.set_model() first!" 
- ) - elif include is not None and exclude is not None: - raise TypeError("Both include and exclude provided at the same time!") - - missing_columns = set(self.model.columns) - set(self.columns) - missing_nullable_columns = self.model.nullable_columns & missing_columns - - if exclude is not None: - missing_nullable_columns -= set(exclude) - elif include is not None: - missing_nullable_columns = missing_nullable_columns & set(include) - - projection = "*" - for missing_nullable_column in missing_nullable_columns: - sql_type = self.model.sql_types[missing_nullable_column] - projection += f", null::{sql_type} as {missing_nullable_column}" - - try: - relation = self._relation.project(projection) - except Exception as exc: # pragma: no cover - # We might get a RunTime error if the enum type has not - # been created yet. If so, we create all enum types for - # this model. - if _is_missing_enum_type_exception(exc): - self.database.create_enum_types(model=self.model) - relation = self._relation.project(projection) - else: - raise exc - return self._wrap(relation=relation, schema_change=False) - - def __add__(self: RelationType, other: RelationSource) -> RelationType: - """ - Execute ``self.union(other)``. - - See :ref:`Relation.union()` for full documentation. - """ - return self.union(other) - - def __eq__(self, other: object) -> bool: - """Check if Relation is equal to a Relation-able data source.""" - other_relation = self.database.to_relation(other) # type: ignore - # Check if the number of rows are equal, and then check if each row is equal. - # Use zip(self, other_relation, strict=True) when we upgrade to Python 3.10. - return self.count() == other_relation.count() and all( - row == other_row for row, other_row in zip(self, other_relation) - ) - - def __getitem__(self, key: Union[str, Iterable[str]]) -> Relation: - """ - Return Relation with selected columns. - - Uses :ref:`Relation.select()` under-the-hood in order to - perform the selection. Can technically be used to rename columns, - define derived columns, and so on, but prefer the use of Relation.select() for - such use cases. - - Args: - key: Columns to select, either a single column represented as a string, or - an iterable of strings. - - Returns: - New relation only containing the column subset specified. - - Example: - >>> import patito as pt - >>> relation = pt.duckdb.Relation("select 1 as a, 2 as b, 3 as c") - >>> relation.to_df() - shape: (1, 3) - ┌─────┬─────┬─────┐ - │ a ┆ b ┆ c │ - │ --- ┆ --- ┆ --- │ - │ i64 ┆ i64 ┆ i64 │ - ╞═════╪═════╪═════╡ - │ 1 ┆ 2 ┆ 3 │ - └─────┴─────┴─────┘ - >>> relation[["a", "b"]].to_df() - shape: (1, 2) - ┌─────┬─────┐ - │ a ┆ b │ - │ --- ┆ --- │ - │ i64 ┆ i64 │ - ╞═════╪═════╡ - │ 1 ┆ 2 │ - └─────┴─────┘ - >>> relation["a"].to_df() - shape: (1, 1) - ┌─────┐ - │ a │ - │ --- │ - │ i64 │ - ╞═════╡ - │ 1 │ - └─────┘ - """ - projection = key if isinstance(key, str) else ", ".join(key) - return self._wrap( - relation=self._relation.project(projection), - schema_change=True, - ) - - def __iter__(self) -> Iterator[ModelType]: - """ - Iterate over rows in relation. - - If :ref:`Relation.set_model()` has been invoked - first, the given model will be used to deserialize each row. Otherwise a Patito - model is dynamically constructed which fits the schema of the relation. - - Returns: - Iterator[Model]: An iterator of patito Model objects representing each row. 
- - Example: - >>> from typing import Literal - >>> import patito as pt - >>> df = pt.DataFrame({"float_column": [1, 2], "enum_column": ["A", "B"]}) - >>> relation = pt.duckdb.Relation(df).set_alias("my_relation") - >>> for row in relation: - ... print(row) - ... - float_column=1 enum_column='A' - float_column=2 enum_column='B' - >>> list(relation) - [my_relation(float_column=1, enum_column='A'), - my_relation(float_column=2, enum_column='B')] - - >>> class MySchema(pt.Model): - ... float_column: float - ... enum_column: Literal["A", "B", "C"] - ... - >>> relation = relation.set_model(MySchema) - >>> for row in relation: - ... print(row) - ... - float_column=1.0 enum_column='A' - float_column=2.0 enum_column='B' - >>> list(relation) - [MySchema(float_column=1.0, enum_column='A'), - MySchema(float_column=2.0, enum_column='B')] - """ - result = self._relation.execute() - while True: - row_tuple = result.fetchone() - if not row_tuple: - return - else: - yield self._to_model(row_tuple) - - def __len__(self) -> int: - """ - Return the number of rows in the relation. - - See :ref:`Relation.count()` for full documentation. - """ - return self.count() - - def __str__(self) -> str: - """ - Return string representation of Relation object. - - Includes an expression tree, the result columns, and a result preview. - - Example: - >>> import patito as pt - >>> products = pt.duckdb.Relation( - ... pt.DataFrame( - ... { - ... "product_name": ["apple", "red_apple", "banana", "oranges"], - ... "supplier_id": [2, 2, 1, 3], - ... } - ... ) - ... ).set_alias("products") - >>> print(str(products)) # xdoctest: +SKIP - --------------------- - --- Relation Tree --- - --------------------- - arrow_scan(94609350519648, 140317161740928, 140317161731168, 1000000)\ - - --------------------- - -- Result Columns -- - --------------------- - - product_name (VARCHAR) - - supplier_id (BIGINT)\ - - --------------------- - -- Result Preview -- - --------------------- - product_name supplier_id - VARCHAR BIGINT - [ Rows: 4] - apple 2 - red_apple 2 - banana 1 - oranges 3 - - >>> suppliers = pt.duckdb.Relation( - ... pt.DataFrame( - ... { - ... "id": [1, 2], - ... "supplier_name": ["Banana Republic", "Applies Inc."], - ... } - ... ) - ... ).set_alias("suppliers") - >>> relation = ( - ... products.set_alias("p") - ... .inner_join( - ... suppliers.set_alias("s"), - ... on="p.supplier_id = s.id", - ... ) - ... .aggregate( - ... "supplier_name", - ... num_products="count(product_name)", - ... group_by=["supplier_id", "supplier_name"], - ... ) - ... ) - >>> print(str(relation)) # xdoctest: +SKIP - --------------------- - --- Relation Tree --- - --------------------- - Aggregate [supplier_name, count(product_name)] - Join INNER p.supplier_id = s.id - arrow_scan(94609350519648, 140317161740928, 140317161731168, 1000000) - arrow_scan(94609436221024, 140317161740928, 140317161731168, 1000000)\ - - --------------------- - -- Result Columns -- - --------------------- - - supplier_name (VARCHAR) - - num_products (BIGINT)\ - - --------------------- - -- Result Preview -- - --------------------- - supplier_name num_products - VARCHAR BIGINT - [ Rows: 2] - Applies Inc. 2 - Banana Republic 1 - - """ - return str(self._relation) - - def _wrap( - self: RelationType, - relation: "duckdb.DuckDBPyRelation", - schema_change: bool = False, - ) -> RelationType: - """ - Wrap DuckDB Relation object in same Relation wrapper class as self. - - This will preserve the type of the relation, even for subclasses Relation. 
- It should therefore only be used for relations which can be considered schema- - compatible with the original relation. Otherwise set schema_change to True - in order to create a Relation base object instead. - """ - return type(self)( - derived_from=relation, - database=self.database, - model=self.model if not schema_change else None, - ) - - -class Database: - # Types created in order to represent enum strings - enum_types: Set[str] - - def __init__( - self, - path: Optional[Path] = None, - read_only: bool = False, - **kwargs: Any, # noqa: ANN401 - ) -> None: - """ - Instantiate a new DuckDB database, either persisted to disk or in-memory. - - Args: - path: Optional path to store all the data to. If ``None`` the data is - persisted in-memory only. - read_only: If the database connection should be a read-only connection. - **kwargs: Additional keywords forwarded to ``duckdb.connect()``. - - Examples: - >>> import patito as pt - >>> db = pt.duckdb.Database() - >>> db.to_relation("select 1 as a, 2 as b").create_table("my_table") - >>> db.query("select * from my_table").to_df() - shape: (1, 2) - ┌─────┬─────┐ - │ a ┆ b │ - │ --- ┆ --- │ - │ i64 ┆ i64 │ - ╞═════╪═════╡ - │ 1 ┆ 2 │ - └─────┴─────┘ - """ - import duckdb - - self.path = path - self.connection = duckdb.connect( - database=str(path) if path else ":memory:", - read_only=read_only, - **kwargs, - ) - self.enum_types: Set[str] = set() - - @classmethod - def default(cls) -> Database: - """ - Return the default DuckDB database. - - Returns: - A patito :ref:`Database` object wrapping around the given - connection. - - Example: - >>> import patito as pt - >>> db = pt.duckdb.Database.default() - >>> db.query("select 1 as a, 2 as b").to_df() - shape: (1, 2) - ┌─────┬─────┐ - │ a ┆ b │ - │ --- ┆ --- │ - │ i64 ┆ i64 │ - ╞═════╪═════╡ - │ 1 ┆ 2 │ - └─────┴─────┘ - """ - import duckdb - - return cls.from_connection(duckdb.default_connection) - - @classmethod - def from_connection(cls, connection: "duckdb.DuckDBPyConnection") -> Database: - """ - Create database from native DuckDB connection object. - - Args: - connection: A native DuckDB connection object created with - ``duckdb.connect()``. - - Returns: - A :ref:`Database` object wrapping around the given - connection. - - Example: - >>> import duckdb - >>> import patito as pt - >>> connection = duckdb.connect() - >>> database = pt.duckdb.Database.from_connection(connection) - """ - obj = cls.__new__(cls) - obj.connection = connection - obj.enum_types = set() - return obj - - def to_relation( - self, - derived_from: RelationSource, - ) -> Relation: - """ - Create a new relation object based on data source. - - The given data will be represented as a relation associated with the database. - ``Database(x).to_relation(y)`` is equivalent to - ``Relation(y, database=Database(x))``. - - Args: - derived_from (RelationSource): One of either a polars or pandas - ``DataFrame``, a ``pathlib.Path`` to a parquet or CSV file, a SQL query - string, or an existing relation. 
- - Example: - >>> import patito as pt - >>> db = pt.duckdb.Database() - >>> db.to_relation("select 1 as a, 2 as b").to_df() - shape: (1, 2) - ┌─────┬─────┐ - │ a ┆ b │ - │ --- ┆ --- │ - │ i64 ┆ i64 │ - ╞═════╪═════╡ - │ 1 ┆ 2 │ - └─────┴─────┘ - >>> db.to_relation(pt.DataFrame({"c": [3, 4], "d": ["5", "6"]})).to_df() - shape: (2, 2) - ┌─────┬─────┐ - │ c ┆ d │ - │ --- ┆ --- │ - │ i64 ┆ str │ - ╞═════╪═════╡ - │ 3 ┆ 5 │ - │ 4 ┆ 6 │ - └─────┴─────┘ - """ - return Relation( - derived_from=derived_from, - database=self, - ) - - def execute( - self, - query: str, - *parameters: Collection[Union[str, int, float, bool]], - ) -> None: - """ - Execute SQL query in DuckDB database. - - Args: - query: A SQL statement to execute. Does `not` have to be terminated with - a semicolon (``;``). - parameters: One or more sets of parameters to insert into prepared - statements. The values are replaced in place of the question marks - (``?``) in the prepared query. - - Example: - >>> import patito as pt - >>> db = pt.duckdb.Database() - >>> db.execute("create table my_table (x bigint);") - >>> db.execute("insert into my_table values (1), (2), (3)") - >>> db.table("my_table").to_df() - shape: (3, 1) - ┌─────┐ - │ x │ - │ --- │ - │ i64 │ - ╞═════╡ - │ 1 │ - │ 2 │ - │ 3 │ - └─────┘ - - Parameters can be specified when executing prepared queries. - - >>> db.execute("delete from my_table where x = ?", (2,)) - >>> db.table("my_table").to_df() - shape: (2, 1) - ┌─────┐ - │ x │ - │ --- │ - │ i64 │ - ╞═════╡ - │ 1 │ - │ 3 │ - └─────┘ - - Multiple parameter sets can be specified when executing multiple prepared - queries. - - >>> db.execute( - ... "delete from my_table where x = ?", - ... (1,), - ... (3,), - ... ) - >>> db.table("my_table").to_df() - shape: (0, 1) - ┌─────┐ - │ x │ - │ --- │ - │ i64 │ - ╞═════╡ - └─────┘ - """ - duckdb_parameters: Union[ - Collection[Union[str, int, float, bool]], - Collection[Collection[Union[str, int, float, bool]]], - None, - ] - if parameters is None or len(parameters) == 0: - duckdb_parameters = [] - multiple_parameter_sets = False - elif len(parameters) == 1: - duckdb_parameters = parameters[0] - multiple_parameter_sets = False - else: - duckdb_parameters = parameters - multiple_parameter_sets = True - - self.connection.execute( - query=query, - parameters=duckdb_parameters, - multiple_parameter_sets=multiple_parameter_sets, - ) - - def query(self, query: str, alias: str = "query_relation") -> Relation: - """ - Execute arbitrary SQL select query and return the relation. - - Args: - query: Arbitrary SQL select query. - alias: The alias to assign to the resulting relation, to be used in further - queries. - - Returns: A relation representing the data produced by the given query. - - Example: - >>> import patito as pt - >>> db = pt.duckdb.Database() - >>> relation = db.query("select 1 as a, 2 as b, 3 as c") - >>> relation.to_df() - shape: (1, 3) - ┌─────┬─────┬─────┐ - │ a ┆ b ┆ c │ - │ --- ┆ --- ┆ --- │ - │ i64 ┆ i64 ┆ i64 │ - ╞═════╪═════╪═════╡ - │ 1 ┆ 2 ┆ 3 │ - └─────┴─────┴─────┘ - - >>> relation = db.query("select 1 as a, 2 as b, 3 as c", alias="my_alias") - >>> relation.select("my_alias.a").to_df() - shape: (1, 1) - ┌─────┐ - │ a │ - │ --- │ - │ i64 │ - ╞═════╡ - │ 1 │ - └─────┘ - """ - return Relation( - self.connection.query(query=query, alias=alias), - database=self, - ) - - def empty_relation(self, schema: Type[ModelType]) -> Relation[ModelType]: - """ - Create relation with zero rows, but correct schema that matches the given model. 
- - Args: - schema: A patito model which specifies the column names and types of the - given relation. - - Example: - >>> import patito as pt - >>> class Schema(pt.Model): - ... string_column: str - ... bool_column: bool - ... - >>> db = pt.duckdb.Database() - >>> empty_relation = db.empty_relation(Schema) - >>> empty_relation.to_df() - shape: (0, 2) - ┌───────────────┬─────────────┐ - │ string_column ┆ bool_column │ - │ --- ┆ --- │ - │ str ┆ bool │ - ╞═══════════════╪═════════════╡ - └───────────────┴─────────────┘ - >>> non_empty_relation = db.query( - ... "select 'dummy' as string_column, true as bool_column" - ... ) - >>> non_empty_relation.union(empty_relation).to_df() - shape: (1, 2) - ┌───────────────┬─────────────┐ - │ string_column ┆ bool_column │ - │ --- ┆ --- │ - │ str ┆ bool │ - ╞═══════════════╪═════════════╡ - │ dummy ┆ true │ - └───────────────┴─────────────┘ - """ - return self.to_relation(schema.examples()).limit(0) - - def table(self, name: str) -> Relation: - """ - Return relation representing all the data in the given table. - - Args: - name: The name of the table. - - Example: - >>> import patito as pt - >>> df = pt.DataFrame({"a": [1, 2], "b": [3, 4]}) - >>> db = pt.duckdb.Database() - >>> relation = db.to_relation(df) - >>> relation.create_table(name="my_table") - >>> db.table("my_table").to_df() - shape: (2, 2) - ┌─────┬─────┐ - │ a ┆ b │ - │ --- ┆ --- │ - │ i64 ┆ i64 │ - ╞═════╪═════╡ - │ 1 ┆ 3 │ - │ 2 ┆ 4 │ - └─────┴─────┘ - """ - return Relation( - self.connection.table(name), - database=self.from_connection(self.connection), - ) - - def view(self, name: str) -> Relation: - """ - Return relation representing all the data in the given view. - - Args: - name: The name of the view. - - Example: - >>> import patito as pt - >>> df = pt.DataFrame({"a": [1, 2], "b": [3, 4]}) - >>> db = pt.duckdb.Database() - >>> relation = db.to_relation(df) - >>> relation.create_view(name="my_view") - >>> db.view("my_view").to_df() - shape: (2, 2) - ┌─────┬─────┐ - │ a ┆ b │ - │ --- ┆ --- │ - │ i64 ┆ i64 │ - ╞═════╪═════╡ - │ 1 ┆ 3 │ - │ 2 ┆ 4 │ - └─────┴─────┘ - """ - return Relation( - self.connection.view(name), - database=self.from_connection(self.connection), - ) - - def create_table( - self, - name: str, - model: Type[ModelType], - ) -> Relation[ModelType]: - """ - Create table with schema matching the provided Patito model. - - See :ref:`Relation.insert_into()` for how to insert - data into the table after creation. - The :ref:`Relation.create_table()` method can also - be used to create a table from a given relation `and` insert the data at the - same time. - - Args: - name: Name of new database table. - model (Type[Model]): Patito model indicating names and types of table - columns. - Returns: - Relation[ModelType]: Relation pointing to the new table. - - Example: - >>> from typing import Optional - >>> import patito as pt - >>> class MyModel(pt.Model): - ... str_column: str - ... nullable_string_column: Optional[str] - ... 
- >>> db = pt.duckdb.Database() - >>> db.create_table(name="my_table", model=MyModel) - >>> db.table("my_table").types - {'str_column': VARCHAR, 'nullable_string_column': VARCHAR} - """ - self.create_enum_types(model=model) - schema = model.schema() - non_nullable = schema.get("required", []) - columns = [] - for column_name, sql_type in model.sql_types.items(): - column = f"{column_name} {sql_type}" - if column_name in non_nullable: - column += " not null" - columns.append(column) - self.connection.execute(f"create table {name} ({','.join(columns)})") - # TODO: Fix typing - return self.table(name).set_model(model) # pyright: ignore - - def create_enum_types(self, model: Type[ModelType]) -> None: - """ - Define SQL enum types in DuckDB database. - - Args: - model: Model for which all Literal-annotated or enum-annotated string fields - will get respective DuckDB enum types. - - Example: - >>> import patito as pt - >>> class EnumModel(pt.Model): - ... enum_column: Literal["A", "B", "C"] - ... - >>> db = pt.duckdb.Database() - >>> db.create_enum_types(EnumModel) - >>> db.enum_types - {'enum__7ba49365cc1b0fd57e61088b3bc9aa25'} - """ - import duckdb - - for props in model._schema_properties().values(): - if "enum" not in props or props["type"] != "string": - # DuckDB enums only support string values - continue - - enum_type_name = _enum_type_name(field_properties=props) - if enum_type_name in self.enum_types: - # This enum type has already been created - continue - - enum_values = ", ".join(repr(value) for value in sorted(props["enum"])) - try: - self.connection.execute( - f"create type {enum_type_name} as enum ({enum_values})" - ) - except duckdb.CatalogException as e: - if "already exists" not in str(e): - raise e # pragma: no cover - self.enum_types.add(enum_type_name) - - def create_view( - self, - name: str, - data: RelationSource, - ) -> Relation: - """Create a view based on the given data source.""" - return self.to_relation(derived_from=data).create_view(name) - - def __contains__(self, table: str) -> bool: - """ - Return ``True`` if the database contains a table with the given name. - - Args: - table: The name of the table to be checked for. 
- - Examples: - >>> import patito as pt - >>> db = pt.duckdb.Database() - >>> "my_table" in db - False - >>> db.to_relation("select 1 as a, 2 as b").create_table(name="my_table") - >>> "my_table" in db - True - """ - try: - self.connection.table(table_name=table) - return True - except Exception: - return False diff --git a/src/patito/pydantic.py b/src/patito/pydantic.py index 44120c7..2da71d4 100644 --- a/src/patito/pydantic.py +++ b/src/patito/pydantic.py @@ -5,7 +5,6 @@ import json from collections.abc import Iterable from datetime import date, datetime -from functools import cached_property from typing import ( TYPE_CHECKING, Any, @@ -25,14 +24,12 @@ ) import polars as pl -from polars.datatypes import DataType, DataTypeClass, PolarsDataType, convert +from polars.datatypes import DataType, DataTypeClass from pydantic import ( # noqa: F401 BaseModel, - ConfigDict, create_model, field_serializer, fields, - JsonDict, ) from pydantic._internal._model_construction import ( ModelMetaclass as PydanticModelMetaclass, @@ -40,13 +37,11 @@ from patito._pydantic.dtypes import ( default_polars_dtype_for_annotation, - dtype_from_string, parse_composite_dtype, valid_polars_dtypes_for_annotation, validate_annotation, validate_polars_dtype, ) -from patito._pydantic.repr import display_as_type from patito.polars import DataFrame, LazyFrame from patito.validators import validate diff --git a/tests/test_database.py b/tests/test_database.py deleted file mode 100644 index 2ac320c..0000000 --- a/tests/test_database.py +++ /dev/null @@ -1,568 +0,0 @@ -import os -import sqlite3 -from datetime import datetime, timedelta -from pathlib import Path -from tempfile import TemporaryDirectory -from typing import TYPE_CHECKING, List, Optional - -import patito as pt -import polars as pl -import pytest - -if TYPE_CHECKING: - import pyarrow as pa # type: ignore -else: - # Python 3.7 does not support pyarrow - pa = pytest.importorskip("pyarrow") - - -class LoggingQuerySource(pt.Database): - """A dummy query source with an associated query execution log.""" - - executed_queries: List[str] - - -@pytest.fixture() -def query_cache(tmp_path) -> LoggingQuerySource: - """ - Return dummy query cache with query execution logger. - - Args: - tmp_path: Test-specific temporary directory provided by pytest. - - Returns: - A cacher which also keeps track of the executed queries. - """ - # Keep track of the executed queries in a mutable list - executed_queries = [] - - # Unless other is specified, some dummy data is always returned - def query_handler(query, mock_data: Optional[dict] = None) -> pa.Table: - executed_queries.append(query) - data = {"column": [1, 2, 3]} if mock_data is None else mock_data - return pa.Table.from_pydict(data) - - query_cache = LoggingQuerySource( - query_handler=query_handler, - cache_directory=tmp_path, - default_ttl=timedelta(weeks=52), - ) - - # Attach the query execution log as an attribute of the query source - query_cache.executed_queries = executed_queries - return query_cache - - -@pytest.fixture -def query_source(tmpdir) -> LoggingQuerySource: - """ - A QuerySource connected to an in-memory SQLite3 database with dummy data. - - Args: - tmpdir: Test-specific temporary directory provided by pytest. - - Returns: - A query source which also keeps track of the executed queries. 
- """ - # Keep track of the executed queries in a mutable list - executed_queries = [] - - def dummy_database() -> sqlite3.Cursor: - connection = sqlite3.connect(":memory:") - cursor = connection.cursor() - cursor.execute("CREATE TABLE movies(title, year, score)") - data = [ - ("Monty Python Live at the Hollywood Bowl", 1982, 7.9), - ("Monty Python's The Meaning of Life", 1983, 7.5), - ("Monty Python's Life of Brian", 1979, 8.0), - ] - cursor.executemany("INSERT INTO movies VALUES(?, ?, ?)", data) - connection.commit() - return cursor - - def query_handler(query: str) -> pa.Table: - cursor = dummy_database() - cursor.execute(query) - executed_queries.append(query) - columns = [description[0] for description in cursor.description] - data = [dict(zip(columns, row)) for row in cursor.fetchall()] - return pa.Table.from_pylist(data) - - # Attach the query execution log as an attribute of the query source - tmp_dir = Path(tmpdir) - query_cache = LoggingQuerySource( - query_handler=query_handler, - cache_directory=tmp_dir, - ) - query_cache.executed_queries = executed_queries - return query_cache - - -def test_uncached_query(query_cache: LoggingQuerySource): - """It should not cache queries by default.""" - - @query_cache.as_query() - def products(): - return "query" - - # First time it is called we should execute the query - products() - assert query_cache.executed_queries == ["query"] - # And no cache file is created - assert not any(query_cache.cache_directory.iterdir()) - - # The next time the query is executed again - products() - assert query_cache.executed_queries == ["query", "query"] - # And still no cache file - assert not any(query_cache.cache_directory.iterdir()) - - -def test_cached_query(query_cache: LoggingQuerySource): - """It should cache queries if so parametrized.""" - - # We enable cache for the given query - @query_cache.as_query(cache=True) - def products(version: int): - return f"query {version}" - - # The cache is stored in the "products" sub-folder - cache_dir = query_cache.cache_directory / "products" - - # First time the query is executed - products(version=1) - assert query_cache.executed_queries == ["query 1"] - # And the result is stored in a cache file - assert len(list(cache_dir.iterdir())) == 1 - - # The next time the query is *not* executed - products(version=1) - assert query_cache.executed_queries == ["query 1"] - # And the cache file persists - assert len(list(cache_dir.iterdir())) == 1 - - # But if we change the query itself, it is executed - products(version=2) - assert query_cache.executed_queries == ["query 1", "query 2"] - # And it is cached in a separate file - assert len(list(cache_dir.iterdir())) == 2 - - # If we delete the cache file, the query is re-executed - for cache_file in cache_dir.iterdir(): - cache_file.unlink() - products(version=1) - assert query_cache.executed_queries == ["query 1", "query 2", "query 1"] - # And the cache file is rewritten - assert len(list(cache_dir.iterdir())) == 1 - - # We clear the cache with .clear_cache() - products.refresh_cache(version=1) - assert query_cache.executed_queries == ["query 1", "query 2", "query 1", "query 1"] - # We can also clear caches that have never existed - products.refresh_cache(version=3) - assert query_cache.executed_queries[-1] == "query 3" - - -def test_cached_query_with_explicit_path( - query_cache: LoggingQuerySource, - tmpdir: Path, -) -> None: - """It should cache queries in the provided path.""" - cache_path = Path(tmpdir / "name.parquet") - - # This time we specify an explicit path - 
@query_cache.as_query(cache=cache_path) - def products(version): - return f"query {version}" - - # At first the path does not exist - assert not cache_path.exists() - - # We then execute and cache the query - products(version=1) - assert cache_path.exists() - assert query_cache.executed_queries == ["query 1"] - - # And the next time it is reused - products(version=1) - assert query_cache.executed_queries == ["query 1"] - assert cache_path.exists() - - # If the query changes, it is re-executed - products(version=2) - assert query_cache.executed_queries == ["query 1", "query 2"] - - # If a non-parquet file is specified, it will raise - with pytest.raises( - ValueError, - match=r"Cache paths must have the '\.parquet' file extension\!", - ): - - @query_cache.as_query(cache=tmpdir / "name.csv") - def products(version): - return f"query {version}" - - -def test_cached_query_with_relative_path(query_cache: LoggingQuerySource) -> None: - """Relative paths should be interpreted relative to the cache directory.""" - relative_path = Path("foo/bar.parquet") - - @query_cache.as_query(cache=relative_path) - def products(): - return "query" - - products() - assert (query_cache.cache_directory / "foo" / "bar.parquet").exists() - - -def test_cached_query_with_format_string(query_cache: LoggingQuerySource) -> None: - """Strings with placeholders should be interpolated.""" - - @query_cache.as_query(cache="version-{version}.parquet") - def products(version: int): - return f"query {version}" - - # It should work for both positional arguments... - products(1) - assert (query_cache.cache_directory / "version-1.parquet").exists() - # ... and keywords - products(version=2) - assert (query_cache.cache_directory / "version-2.parquet").exists() - - -def test_cached_query_with_format_path(query_cache: LoggingQuerySource) -> None: - """Paths with placeholders should be interpolated.""" - - @query_cache.as_query( - cache=query_cache.cache_directory / "version-{version}.parquet" - ) - def products(version: int): - return f"query {version}" - - # It should work for both positional arguments... - products(1) - assert (query_cache.cache_directory / "version-1.parquet").exists() - # ... and keywords - products(version=2) - assert (query_cache.cache_directory / "version-2.parquet").exists() - - -def test_cache_ttl(query_cache: LoggingQuerySource, monkeypatch): - """It should automatically refresh the cache according to the TTL.""" - - # We freeze the time during the execution of this test - class FrozenDatetime: - def __init__(self, year: int, month: int, day: int) -> None: - self.frozen_time = datetime(year=year, month=month, day=day) - monkeypatch.setattr(pt.database, "datetime", self) # pyright: ignore - - def now(self): - return self.frozen_time - - @staticmethod - def fromisoformat(*args, **kwargs): - return datetime.fromisoformat(*args, **kwargs) - - # The cache should be cleared every week - @query_cache.as_query(cache=True, ttl=timedelta(weeks=1)) - def users(): - return "query" - - # The first time the query should be executed - FrozenDatetime(year=2000, month=1, day=1) - users() - assert query_cache.executed_queries == ["query"] - - # The next time it should not be executed - users() - assert query_cache.executed_queries == ["query"] - - # Even if we advance the time by one day, - # the cache should still be used. 
- FrozenDatetime(year=2000, month=1, day=2) - users() - assert query_cache.executed_queries == ["query"] - - # Then we let one week pass, and the cache should be cleared - FrozenDatetime(year=2000, month=1, day=8) - users() - assert query_cache.executed_queries == ["query", "query"] - - # But then it will be reused for another week - users() - assert query_cache.executed_queries == ["query", "query"] - - -@pytest.mark.parametrize("cache", [True, False]) -def test_lazy_query(query_cache: LoggingQuerySource, cache: bool): - """It should return a LazyFrame when specified with lazy=True.""" - - @query_cache.as_query(lazy=True, cache=cache) - def lazy(): - return "query" - - @query_cache.as_query(lazy=False, cache=cache) - def eager(): - return "query" - - # We invoke it twice, first not hitting the cache, and then hitting it - assert lazy().collect().frame_equal(eager()) - assert lazy().collect().frame_equal(eager()) - - -def test_model_query_model_validation(query_cache: LoggingQuerySource): - """It should validate the data model.""" - - class CorrectModel(pt.Model): - column: int - - @query_cache.as_query(model=CorrectModel) - def correct_data(): - return "" - - assert isinstance(correct_data(), pl.DataFrame) - - class IncorrectModel(pt.Model): - column: str - - @query_cache.as_query(model=IncorrectModel) - def incorrect_data(): - return "" - - with pytest.raises(pt.exceptions.ValidationError): - incorrect_data() - - -def test_custom_forwarding_of_parameters_to_query_function( - query_cache: LoggingQuerySource, -): - """It should forward all additional parameters to the sql_to_arrow function.""" - - # The dummy cacher accepts a "data" parameter, specifying the data to be returned - data = {"actual_data": [10, 20, 30]} - - @query_cache.as_query(mock_data=data) - def custom_data(): - return "select 1, 2, 3 as dummy_column" - - assert custom_data().frame_equal(pl.DataFrame(data)) - - # It should also work without type normalization - @query_cache.as_query(mock_data=data, cast_to_polars_equivalent_types=False) - def non_normalized_custom_data(): - return "select 1, 2, 3 as dummy_column" - - assert non_normalized_custom_data().frame_equal(pl.DataFrame(data)) - - -def test_clear_caches(query_cache: LoggingQuerySource): - """It should clear all cache files with .clear_all_caches().""" - - @query_cache.as_query(cache=True) - def products(version: int): - return f"query {version}" - - # The cache is stored in the "products" sub-directory - products_cache_dir = query_cache.cache_directory / "products" - - # We produce two cache files - products(version=1) - products(version=2) - assert query_cache.executed_queries == ["query 1", "query 2"] - assert len(list(products_cache_dir.iterdir())) == 2 - - # We also insert another parquet file that should *not* be deleted - dummy_parquet_path = products_cache_dir / "dummy.parquet" - pl.DataFrame().write_parquet(dummy_parquet_path) - - # And an invalid parquet file - invalid_parquet_path = products_cache_dir / "invalid.parquet" - invalid_parquet_path.write_bytes(b"invalid content") - - # We delete all caches, but not the dummy parquet file - products.clear_caches() - assert len(list(products_cache_dir.iterdir())) == 2 - assert dummy_parquet_path.exists() - assert invalid_parquet_path.exists() - - # The next time both queries need to be re-executed - products(version=1) - products(version=2) - assert query_cache.executed_queries == ["query 1", "query 2"] * 2 - assert len(list(products_cache_dir.iterdir())) == 4 - - # If caching is not enabled, clear_caches should 
be a NO-OP - @query_cache.as_query(cache=False) - def uncached_products(version: int): - return f"query {version}" - - uncached_products.clear_caches() - - -def test_clear_caches_with_formatted_paths(query_cache: LoggingQuerySource): - """Formatted paths should also be properly cleared.""" - # We specify another temporary cache directory to see if caches can be cleared - # irregardless of the cache directory's location. - tmp_dir = TemporaryDirectory() - cache_dir = Path(tmp_dir.name) - - @query_cache.as_query(cache=cache_dir / "{a}" / "{b}.parquet") - def users(a: int, b: int): - return f"query {a}.{b}" - - users(1, 1) - users(1, 2) - users(2, 1) - - assert query_cache.executed_queries == ["query 1.1", "query 1.2", "query 2.1"] - - assert {str(path.relative_to(cache_dir)) for path in cache_dir.rglob("*")} == { - # Both directories have been created - "1", - "2", - # Two cache files for a=1 - "1/1.parquet", - "1/2.parquet", - # One cache file for a=2 - "2/1.parquet", - } - - # We insert another parquet file that should *not* be cleared - pl.DataFrame().write_parquet(cache_dir / "1" / "3.parquet") - - # Only directories and non-cached files should be kept - users.clear_caches() - assert {str(path.relative_to(cache_dir)) for path in cache_dir.rglob("*")} == { - "1", - "2", - "1/3.parquet", - } - tmp_dir.cleanup() - - -def test_ejection_of_incompatible_caches(query_cache: LoggingQuerySource): - """It should clear old, incompatible caches.""" - - cache_path = query_cache.cache_directory / "my_cache.parquet" - - @query_cache.as_query(cache=cache_path) - def my_query(): - return "my query" - - # Write a parquet file without any metadata - pl.DataFrame().write_parquet(cache_path) - - # The existing parquet file without metadata should be overwritten - df = my_query() - assert not df.is_empty() - assert query_cache.executed_queries == ["my query"] - - # Now we decrement the version number of the cache in order to overwrite it - arrow_table = pa.parquet.read_table(cache_path) # noqa - metadata = arrow_table.schema.metadata - assert ( - int.from_bytes(metadata[b"cache_version"], "little") - == pt.database.CACHE_VERSION # pyright: ignore - ) - metadata[b"cache_version"] = ( - pt.database.CACHE_VERSION - 1 # pyright: ignore - ).to_bytes( - length=16, - byteorder="little", - signed=False, - ) - pa.parquet.write_table( - arrow_table.replace_schema_metadata(metadata), - where=cache_path, - ) - - # The query should now be re-executed - my_query() - assert query_cache.executed_queries == ["my query"] * 2 - - # Deleting the cache_version alltogether should also retrigger the query - del metadata[b"cache_version"] - pa.parquet.write_table( - arrow_table.replace_schema_metadata(metadata), - where=cache_path, - ) - my_query() - assert query_cache.executed_queries == ["my query"] * 3 - - -def test_adherence_to_xdg_directory_standard(monkeypatch, tmpdir): - """It should use XDG Cache Home when no cache directory is specified.""" - xdg_cache_home = tmpdir / ".cache" - os.environ["XDG_CACHE_HOME"] = str(xdg_cache_home) - query_source = pt.Database(query_handler=lambda query: pa.Table()) - assert query_source.cache_directory == xdg_cache_home / "patito" - - del os.environ["XDG_CACHE_HOME"] - query_source = pt.Database(query_handler=lambda query: pa.Table()) - assert query_source.cache_directory == Path("~/.cache/patito").resolve() - - -def test_invoking_query_source_directly_with_query_string( - query_source: LoggingQuerySource, -): - """It should accept SQL queries directly, not ony query constructors.""" - sql = 
"select * from movies" - movies = query_source.query(sql) - assert query_source.executed_queries == [sql] - assert len(list(query_source.cache_directory.iterdir())) == 0 - assert movies.height == 3 - - for _ in range(2): - query_source.query(sql, cache=True) - assert query_source.executed_queries == [sql] * 2 - assert ( - len(list((query_source.cache_directory / "__direct_query").iterdir())) == 1 - ) - - assert query_source.query(sql, lazy=True).collect().frame_equal(movies) - - -@pytest.mark.skip(reason="TODO: Future feature to implement") -def test_custom_kwarg_hashing(tmp_path): - """You should be able to hash the keyword arguments passed to the query handler.""" - - executed_queries = [] - - def query_handler(query: str, prod=False) -> pa.Table: - executed_queries.append(query) - return pa.Table.from_pydict({"column": [1, 2, 3]}) - - def query_handler_hasher(query: str, prod: bool) -> bytes: - return bytes(prod) - - dummy_source = pt.Database( - query_handler=query_handler, - query_handler_hasher=query_handler_hasher, # pyright: ignore - cache_directory=tmp_path, - ) - - # The first time the query should be executed - sql_query = "select * from my_table" - dummy_source.query(sql_query, cache=True) - assert executed_queries == [sql_query] - assert len(list(dummy_source.cache_directory.rglob("*.parquet"))) == 1 - - # The second time the dev query has been cached - dummy_source.query(sql_query, cache=True) - assert executed_queries == [sql_query] - assert len(list(dummy_source.cache_directory.rglob("*.parquet"))) == 1 - - # The production query has never executed, so a new query is executed - dummy_source.query(sql_query, cache=True, prod=True) - assert executed_queries == [sql_query] * 2 - assert len(list(dummy_source.cache_directory.rglob("*.parquet"))) == 2 - - # Then the production query cache is used - dummy_source.query(sql_query, cache=True, prod=True) - assert executed_queries == [sql_query] * 2 - assert len(list(dummy_source.cache_directory.rglob("*.parquet"))) == 2 - - # And the dev query cache still remains - dummy_source.query(sql_query, cache=True) - assert executed_queries == [sql_query] * 2 - assert len(list(dummy_source.cache_directory.rglob("*.parquet"))) == 2 diff --git a/tests/test_dtypes.py b/tests/test_dtypes.py index 52a8260..839b392 100644 --- a/tests/test_dtypes.py +++ b/tests/test_dtypes.py @@ -10,7 +10,6 @@ DURATION_DTYPES, FLOAT_DTYPES, INTEGER_DTYPES, - PT_BASE_SUPPORTED_DTYPES, STRING_DTYPES, TIME_DTYPES, DataTypeGroup, diff --git a/tests/test_duckdb/__init__.py b/tests/test_duckdb/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/tests/test_duckdb/test_database.py b/tests/test_duckdb/test_database.py deleted file mode 100644 index 88e670b..0000000 --- a/tests/test_duckdb/test_database.py +++ /dev/null @@ -1,276 +0,0 @@ -"""Tests for patito.Database.""" -import enum -from typing import Optional - -import patito as pt -import polars as pl -import pytest -from typing_extensions import Literal - -# Skip test module if DuckDB is not installed -if not pt._DUCKDB_AVAILABLE: - pytest.skip("DuckDB not installed", allow_module_level=True) - - -def test_database(tmp_path): - """Test functionality of Database class.""" - # Create a new in-memory database - db = pt.duckdb.Database() - - # Insert a simple dataframe as a new table - table_df = pl.DataFrame( - { - "column_1": [1, 2, 3], - "column_2": ["a", "b", "c"], - } - ) - db.to_relation(table_df).create_table(name="table_name_1") - - # Check that a round-trip to and from the database preserves the 
data - db_table = db.table("table_name_1").to_df() - assert db_table is not table_df - assert table_df.frame_equal(db_table) - - # Check that new database objects are isolated from previous ones - another_db = pt.duckdb.Database() - with pytest.raises( - Exception, - match=r"Catalog Error\: Table 'table_name_1' does not exist!", - ): - db_table = another_db.table("table_name_1") - - # Check the parquet reading functionality - parquet_path = tmp_path / "tmp.parquet" - table_df.write_parquet(str(parquet_path), compression="snappy") - new_relation = another_db.to_relation(parquet_path) - new_relation.create_table(name="parquet_table") - assert another_db.table("parquet_table").count() == 3 - - -def test_file_database(tmp_path): - """Check if the Database can be persisted to a file.""" - # Insert some data into a file-backed database - db_path = tmp_path / "tmp.db" - file_db = pt.duckdb.Database(path=db_path) - file_db.to_relation("select 1 as a, 2 as b").create_table(name="table") - before_df = file_db.table("table").to_df() - - # Delete the database - del file_db - - # And restore tha data from the file - restored_db = pt.duckdb.Database(path=db_path) - after_df = restored_db.table("table").to_df() - - # The data should still be the same - assert before_df.frame_equal(after_df) - - -def test_database_create_table(): - """Tests for patito.Database.create_table().""" - - # A pydantic basemodel is used to specify the table schema - # We inherit here in order to make sure that inheritance works as intended - class BaseModel(pt.Model): - int_column: int - optional_int_column: Optional[int] - str_column: str - - class Model(BaseModel): - optional_str_column: Optional[str] - bool_column: bool - optional_bool_column: Optional[bool] - enum_column: Literal["a", "b", "c"] - - # We crate the table schema - db = pt.duckdb.Database() - table = db.create_table(name="test_table", model=Model) - - # We insert some dummy data into the new table - dummy_relation = db.to_relation(Model.examples({"optional_int_column": [1, None]})) - dummy_relation.insert_into(table="test_table") - - # But we should not be able to insert null data in non-optional columns - null_relation = dummy_relation.drop("int_column").select("null as int_column, *") - with pytest.raises( - Exception, - match=("Constraint Error: NOT NULL constraint failed: test_table.int_column"), - ): - null_relation.insert_into(table="test_table") - - # Check if the correct columns and types have been set - assert table.columns == [ - "int_column", - "optional_int_column", - "str_column", - "optional_str_column", - "bool_column", - "optional_bool_column", - "enum_column", - ] - assert list(table.types.values()) == [ - "INTEGER", - "INTEGER", - "VARCHAR", - "VARCHAR", - "BOOLEAN", - "BOOLEAN", - pt.duckdb._enum_type_name( # pyright: ignore - field_properties=Model.model_json_schema()["properties"]["enum_column"] - ), - ] - - -def test_create_view(): - """It should be able to create a view from a relation source.""" - db = pt.duckdb.Database() - df = pt.DataFrame({"a": [1, 2], "b": [3.0, 4.0]}) - db.create_view(name="my_view", data=df) - assert db.view("my_view").to_df().frame_equal(df) - - -def test_validate_non_nullable_enum_columns(): - """Enum columns should be null-validated.""" - - class EnumModel(pt.Model): - non_nullable_enum_column: Literal["a", "b", "c"] - nullable_enum_column: Optional[Literal["a", "b", "c"]] - - db = pt.duckdb.Database() - db.create_table(name="enum_table", model=EnumModel) - - # We allow null values in nullable_enum_column - 
valid_relation = db.to_relation( - "select 'a' as non_nullable_enum_column, null as nullable_enum_column" - ) - valid_relation.insert_into("enum_table") - - # But we do not allow it in non_nullable_enum_column - invalid_relation = db.to_relation( - "select null as non_nullable_enum_column, 'a' as nullable_enum_column" - ) - with pytest.raises( - Exception, - match=( - "Constraint Error: " - "NOT NULL constraint failed: " - "enum_table.non_nullable_enum_column" - ), - ): - invalid_relation.insert_into(table="enum_table") - - # The non-nullable enum column should do enum value validation - invalid_relation = db.to_relation( - "select 'd' as non_nullable_enum_column, 'a' as nullable_enum_column" - ) - with pytest.raises( - Exception, - match="Conversion Error: Could not convert string 'd' to UINT8", - ): - invalid_relation.insert_into(table="enum_table") - - # And the nullable enum column should do enum value validation - invalid_relation = db.to_relation( - "select 'a' as non_nullable_enum_column, 'd' as nullable_enum_column" - ) - with pytest.raises( - Exception, - match="Conversion Error: Could not convert string 'd' to UINT8", - ): - invalid_relation.insert_into(table="enum_table") - - -def test_table_existence_check(): - """You should be able to check for the existence of a table.""" - - class Model(pt.Model): - column_1: str - column_2: int - - # At first there is no table named "test_table" - db = pt.duckdb.Database() - assert "test_table" not in db - - # We create the table - db.create_table(name="test_table", model=Model) - - # And now the table should exist - assert "test_table" in db - - -def test_creating_enums_several_tiems(): - """Enums should be able to be defined several times.""" - - class EnumModel(pt.Model): - enum_column: Literal["a", "b", "c"] - - db = pt.duckdb.Database() - db.create_enum_types(EnumModel) - db.enum_types = set() - db.create_enum_types(EnumModel) - - -def test_use_of_same_enum_types_from_literal_annotation(): - """Identical literals should get the same DuckDB SQL enum type.""" - - class Table1(pt.Model): - column_1: Literal["a", "b"] - - class Table2(pt.Model): - column_2: Optional[Literal["b", "a"]] - - db = pt.duckdb.Database() - db.create_table(name="table_1", model=Table1) - db.create_table(name="table_2", model=Table2) - - assert ( - db.table("table_1").types["column_1"] == db.table("table_2").types["column_2"] - ) - - -def test_use_of_same_enum_types_from_enum_annotation(): - """Identical enums should get the same DuckDB SQL enum type.""" - - class ABEnum(enum.Enum): - ONE = "a" - TWO = "b" - - class BAEnum(enum.Enum): - TWO = "b" - ONE = "a" - - class Table1(pt.Model): - column_1: ABEnum - - class Table2(pt.Model): - column_2: Optional[BAEnum] - - db = pt.duckdb.Database() - db.create_table(name="table_1", model=Table1) - db.create_table(name="table_2", model=Table2) - - assert ( - db.table("table_1").types["column_1"] == db.table("table_2").types["column_2"] - ) - - -def test_execute(): - """It should be able to execute prepared statements.""" - db = pt.duckdb.Database() - db.execute("create table my_table (a int, b int, c int)") - db.execute("insert into my_table select ? as a, ? as b, ? as c", [2, 3, 4]) - assert ( - db.table("my_table") - .to_df() - .frame_equal(pt.DataFrame({"a": [2], "b": [3], "c": [4]})) - ) - db.execute( - "insert into my_table select ? as a, ? as b, ? 
as c", - [5, 6, 7], - [8, 9, 10], - ) - assert ( - db.table("my_table") - .to_df() - .frame_equal(pt.DataFrame({"a": [2, 5, 8], "b": [3, 6, 9], "c": [4, 7, 10]})) - ) diff --git a/tests/test_duckdb/test_relation.py b/tests/test_duckdb/test_relation.py deleted file mode 100644 index 60d8acb..0000000 --- a/tests/test_duckdb/test_relation.py +++ /dev/null @@ -1,1063 +0,0 @@ -import re -from datetime import date, timedelta -from pathlib import Path -from typing import Optional -from unittest.mock import MagicMock - -import patito as pt -import polars as pl -import pytest -from typing_extensions import Literal - -# Skip test module if DuckDB is not installed -if not pt._DUCKDB_AVAILABLE: - pytest.skip("DuckDB not installed", allow_module_level=True) - - -def test_relation(): - """Test functionality of Relation class.""" - # Create a new in-memory database with dummy data - db = pt.duckdb.Database() - table_df = pl.DataFrame( - { - "column_1": [1, 2, 3], - "column_2": ["a", "b", "c"], - } - ) - db.to_relation(table_df).create_table(name="table_name") - table_relation = db.table("table_name") - - # A projection can be done in several different ways - assert table_relation.select("column_1", "column_2") == table_relation.select( - "column_1, column_2" - ) - assert ( - table_relation.select("column_1, column_2") - == table_relation[["column_1, column_2"]] - ) - assert table_relation[["column_1, column_2"]] == table_relation - assert table_relation.select("column_1") != table_relation.select("column_2") - - # We can also use kewyrod arguments to rename columns - assert tuple(table_relation.select(column_3="column_1::varchar || column_2")) == ( - {"column_3": "1a"}, - {"column_3": "2b"}, - {"column_3": "3c"}, - ) - - # The .get() method should only work if the filter matches a single row - assert table_relation.get(column_1=1).column_2 == "a" - - # But raise if not exactly one matching row is found - with pytest.raises(RuntimeError, match="Relation.get(.*) returned 0 rows!"): - assert table_relation.get("column_1 = 4") - with pytest.raises(RuntimeError, match="Relation.get(.*) returned 2 rows!"): - assert table_relation.get("column_1 > 1") - - # The .get() should also accept a positional string - assert table_relation.get("column_1 < 2").column_2 == "a" - - # And several positional strings - assert table_relation.get("column_1 > 1", "column_1 < 3").column_2 == "b" - - # And a mix of positional and keyword arguments - assert table_relation.get("column_1 < 2", column_2="a").column_2 == "a" - - # Order by statements shoud be respected when iterating over the relation - assert tuple(table_relation.order("column_1 desc")) == ( - {"column_1": 3, "column_2": "c"}, - {"column_1": 2, "column_2": "b"}, - {"column_1": 1, "column_2": "a"}, - ) - - # The plus operator acts as a union all - assert ( - db.to_relation(table_df[:1]) - + db.to_relation(table_df[1:2]) - + db.to_relation(table_df[2:]) - ) == db.to_relation(table_df) - - # The union all must *not* remove duplicates - assert db.to_relation(table_df) + db.to_relation(table_df) != db.to_relation( - table_df - ) - assert db.to_relation(table_df) + db.to_relation(table_df) == db.to_relation( - pl.concat([table_df, table_df]) - ) - - # You should be able to subscript columns - assert table_relation["column_1"] == table_relation.select("column_1") - assert table_relation[["column_1", "column_2"]] == table_relation - - # The relation's columns can be retrieved - assert table_relation.columns == ["column_1", "column_2"] - - # You should be able to prefix and 
suffix all columns of a relation - assert table_relation.add_prefix("prefix_").columns == [ - "prefix_column_1", - "prefix_column_2", - ] - assert table_relation.add_suffix("_suffix").columns == [ - "column_1_suffix", - "column_2_suffix", - ] - - # You can drop one or more columns - assert table_relation.drop("column_1").columns == ["column_2"] - assert table_relation.select("*, 1 as column_3").drop( - "column_1", "column_2" - ).columns == ["column_3"] - - # You can rename columns - assert set(table_relation.rename(column_1="new_name").columns) == { - "new_name", - "column_2", - } - - # A value error must be raised if the source column does not exist - with pytest.raises( - ValueError, - match=( - "Column 'a' can not be renamed as it does not exist. " - "The columns of the relation are: column_[12], column_[12]" - ), - ): - table_relation.rename(a="new_name") - - # Null values should be correctly handled - none_df = pl.DataFrame({"column_1": [1, None]}) - none_relation = db.to_relation(none_df) - assert none_relation.filter("column_1 is null") == none_df.filter( - pl.col("column_1").is_null() - ) - - # The .inner_join() method should work as INNER JOIN, not LEFT or OUTER JOIN - left_relation = db.to_relation( - pl.DataFrame( - { - "left_primary_key": [1, 2], - "left_foreign_key": [10, 20], - } - ) - ) - right_relation = db.to_relation( - pl.DataFrame( - { - "right_primary_key": [10], - } - ) - ) - joined_table = pl.DataFrame( - { - "left_primary_key": [1], - "left_foreign_key": [10], - "right_primary_key": [10], - } - ) - assert ( - left_relation.set_alias("l").inner_join( - right_relation.set_alias("r"), - on="l.left_foreign_key = r.right_primary_key", - ) - == joined_table - ) - - # But the .left_join() method performs a LEFT JOIN - left_joined_table = pl.DataFrame( - { - "left_primary_key": [1, 2], - "left_foreign_key": [10, 20], - "right_primary_key": [10, None], - } - ) - assert ( - left_relation.set_alias("l").left_join( - right_relation.set_alias("r"), - on="l.left_foreign_key = r.right_primary_key", - ) - == left_joined_table - ) - - -def test_star_select(): - """It should select all columns with star.""" - df = pt.DataFrame({"a": [1, 2], "b": [3, 4]}) - relation = pt.duckdb.Relation(df) - assert relation.select("*") == relation - - -def test_casting_relations_between_database_connections(): - """It should raise when you try to mix databases.""" - db_1 = pt.duckdb.Database() - relation_1 = db_1.query("select 1 as a") - db_2 = pt.duckdb.Database() - relation_2 = db_2.query("select 1 as a") - with pytest.raises( - ValueError, - match="Relations can't be casted between database connections.", - ): - relation_1 + relation_2 # pyright: ignore - - -def test_creating_relation_from_pandas_df(): - """It should be able to create a relation from a pandas dataframe.""" - pd = pytest.importorskip("pandas") - pandas_df = pd.DataFrame({"a": [1, 2]}) - relation = pt.duckdb.Relation(pandas_df) - pd.testing.assert_frame_equal(relation.to_pandas(), pandas_df) - - -def test_creating_relation_from_a_csv_file(tmp_path): - """It should be able to create a relation from a CSV path.""" - df = pl.DataFrame({"a": [1, 2]}) - csv_path = tmp_path / "test.csv" - df.write_csv(csv_path) - relation = pt.duckdb.Relation(csv_path) - assert relation.to_df().frame_equal(df) - - -def test_creating_relation_from_a_parquet_file(tmp_path): - """It should be able to create a relation from a parquet path.""" - df = pl.DataFrame({"a": [1, 2]}) - parquet_path = tmp_path / "test.parquet" - df.write_parquet(parquet_path, 
compression="uncompressed") - relation = pt.duckdb.Relation(parquet_path) - assert relation.to_df().frame_equal(df) - - -def test_creating_relation_from_a_unknown_file_format(tmp_path): - """It should raise when you try to create relation from unknown path.""" - with pytest.raises( - ValueError, - match="Unsupported file suffix '.unknown' for data import!", - ): - pt.duckdb.Relation(Path("test.unknown")) - - with pytest.raises( - ValueError, - match="Unsupported file suffix '' for data import!", - ): - pt.duckdb.Relation(Path("test")) - - -def test_relation_with_default_database(): - """It should be constructable with the default DuckDB cursor.""" - import duckdb - - relation_a = pt.duckdb.Relation("select 1 as a") - assert relation_a.database.connection is duckdb.default_connection - - relation_a.create_view("table_a") - del relation_a - - relation_b = pt.duckdb.Relation("select 1 as b") - relation_b.create_view("table_b") - del relation_b - - default_database = pt.duckdb.Database.default() - joined_relation = default_database.query( - """ - select * - from table_a - inner join table_b - on a = b - """ - ) - assert joined_relation.to_df().frame_equal(pl.DataFrame({"a": [1], "b": [1]})) - - -def test_with_columns(): - """It should be able to crate new additional columns.""" - db = pt.duckdb.Database() - relation = db.to_relation("select 1 as a, 2 as b") - - # We can define a new column - extended_relation = relation.with_columns(c="a + b") - correct_extended = pl.DataFrame({"a": [1], "b": [2], "c": [3]}) - assert extended_relation.to_df().frame_equal(correct_extended) - - # Or even overwrite an existing column - overwritten_relation = relation.with_columns(a="a + b") - correct_overwritten = db.to_relation("select 2 as b, 3 as a").to_df() - assert overwritten_relation.to_df().frame_equal(correct_overwritten) - - -def test_rename_to_existing_column(): - """Renaming a column to overwrite another should work.""" - db = pt.duckdb.Database() - relation = db.to_relation("select 1 as a, 2 as b") - renamed_relation = relation.rename(b="a") - assert renamed_relation.columns == ["a"] - assert renamed_relation.get().a == 2 - - -def test_add_suffix(): - """It should be able to add suffixes to all column names.""" - db = pt.duckdb.Database() - relation = db.to_relation("select 1 as a, 2 as b") - assert relation.add_suffix("x").columns == ["ax", "bx"] - assert relation.add_suffix("x", exclude=["a"]).columns == ["a", "bx"] - assert relation.add_suffix("x", include=["a"]).columns == ["ax", "b"] - - with pytest.raises( - TypeError, - match="Both include and exclude provided at the same time!", - ): - relation.add_suffix("x", exclude=["a"], include=["b"]) - - -def test_add_prefix(): - """It should be able to add prefixes to all column names.""" - db = pt.duckdb.Database() - relation = db.to_relation("select 1 as a, 2 as b") - assert relation.add_prefix("x").columns == ["xa", "xb"] - assert relation.add_prefix("x", exclude=["a"]).columns == ["a", "xb"] - assert relation.add_prefix("x", include=["a"]).columns == ["xa", "b"] - - with pytest.raises( - TypeError, - match="Both include and exclude provided at the same time!", - ): - relation.add_prefix("x", exclude=["a"], include=["b"]) - - -def test_relation_aggregate_method(): - """Test for Relation.aggregate().""" - db = pt.duckdb.Database() - relation = db.to_relation( - pl.DataFrame( - { - "a": [1, 1, 2], - "b": [10, 100, 1000], - "c": [1, 2, 1], - } - ) - ) - aggregated_relation = relation.aggregate( - "a", - b_sum="sum(b)", - group_by="a", - ) - assert 
tuple(aggregated_relation) == ( - {"a": 1, "b_sum": 110}, - {"a": 2, "b_sum": 1000}, - ) - - aggregated_relation_with_multiple_group_by = relation.aggregate( - "a", - "c", - b_sum="sum(b)", - group_by=["a", "c"], - ) - assert tuple(aggregated_relation_with_multiple_group_by) == ( - {"a": 1, "c": 1, "b_sum": 10}, - {"a": 1, "c": 2, "b_sum": 100}, - {"a": 2, "c": 1, "b_sum": 1000}, - ) - - -def test_relation_all_method(): - """Test for Relation.all().""" - db = pt.duckdb.Database() - relation = db.to_relation( - pl.DataFrame( - { - "a": [1, 2, 3], - "b": [100, 100, 100], - } - ) - ) - - assert not relation.all(a=100) - assert relation.all(b=100) - assert relation.all("a < 4", b=100) - - -def test_relation_case_method(): - db = pt.duckdb.Database() - - df = pl.DataFrame( - { - "shelf_classification": ["A", "B", "A", "C", "D"], - "weight": [1, 2, 3, 4, 5], - } - ) - - correct_df = df.with_columns( - pl.Series([10, 20, 10, 0, None], dtype=pl.Int32).alias("max_weight") - ) - correct_mapped_actions = db.to_relation(correct_df) - - mapped_actions = db.to_relation(df).case( - from_column="shelf_classification", - to_column="max_weight", - mapping={"A": 10, "B": 20, "D": None}, - default=0, - ) - assert mapped_actions == correct_mapped_actions - - # We can also use the Case class - case_statement = pt.sql.Case( - on_column="shelf_classification", - mapping={"A": 10, "B": 20, "D": None}, - default=0, - ) - alt_mapped_actions = db.to_relation(df).select(f"*, {case_statement} as max_weight") - assert alt_mapped_actions == correct_mapped_actions - - -def test_relation_coalesce_method(): - """Test for Relation.coalesce().""" - db = pt.duckdb.Database() - df = pl.DataFrame( - {"column_1": [1.0, None], "column_2": [None, "2"], "column_3": [3.0, None]} - ) - relation = db.to_relation(df) - coalesce_result = relation.coalesce(column_1=10, column_2="20").to_df() - correct_coalesce_result = pl.DataFrame( - { - "column_1": [1.0, 10.0], - "column_2": ["20", "2"], - "column_3": [3.0, None], - } - ) - assert coalesce_result.frame_equal(correct_coalesce_result) - - -def test_relation_union_method(): - """Test for Relation.union and Relation.__add__.""" - db = pt.duckdb.Database() - left = db.to_relation("select 1 as a, 2 as b") - right = db.to_relation("select 200 as b, 100 as a") - correct_union = pl.DataFrame( - { - "a": [1, 100], - "b": [2, 200], - } - ) - assert left + right == correct_union - assert right + left == correct_union[["b", "a"]][::-1] - - assert left.union(right) == correct_union - assert right.union(left) == correct_union[["b", "a"]][::-1] - - incompatible = db.to_relation("select 1 as a") - with pytest.raises( - TypeError, - match="Union between relations with different column names is not allowed.", - ): - incompatible + right # pyright: ignore - with pytest.raises( - TypeError, - match="Union between relations with different column names is not allowed.", - ): - left + incompatible # pyright: ignore - - -def test_relation_model_functionality(): - """The end-user should be able to specify the constructor for row values.""" - db = pt.duckdb.Database() - - # We have two rows in our relation - first_row_relation = db.to_relation("select 1 as a, 2 as b") - second_row_relation = db.to_relation("select 3 as a, 4 as b") - relation = first_row_relation + second_row_relation - - # Iterating over the relation should yield the same as .get() - iterator_value = tuple(relation)[0] - get_value = relation.get("a = 1") - assert iterator_value == get_value - assert iterator_value.a == 1 - assert get_value.a == 1 
- assert iterator_value.b == 2 - assert get_value.b == 2 - - # The end-user should be able to specify a custom row constructor - model_mock = MagicMock(return_value="mock_return") - new_relation = relation.set_model(model_mock) - assert new_relation.get("a = 1") == "mock_return" - model_mock.assert_called_with(a=1, b=2) - - # We create a custom model - class MyModel(pt.Model): - a: int - b: str - - # Some dummy data - dummy_df = MyModel.examples({"a": [1, 2], "b": ["one", "two"]}) - dummy_relation = db.to_relation(dummy_df) - - # Initially the relation has no custom model and it is dynamically constructed - assert dummy_relation.model is None - assert not isinstance( - dummy_relation.limit(1).get(), - MyModel, - ) - - # MyRow can be specified as the deserialization class with Relation.set_model() - assert isinstance( - dummy_relation.set_model(MyModel).limit(1).get(), - MyModel, - ) - - # A custom relation class which specifies this as the default model - class MyRelation(pt.duckdb.Relation): - model = MyModel - - assert isinstance( - MyRelation(dummy_relation._relation, database=db).limit(1).get(), - MyModel, - ) - - # But the model is "lost" when we use schema-changing methods - assert not isinstance( - dummy_relation.set_model(MyModel).limit(1).select("a").get(), - MyModel, - ) - - -def test_row_sql_type_functionality(): - """Tests for mapping pydantic types to DuckDB SQL types.""" - - # Two nullable and two non-nullable columns - class OptionalRow(pt.Model): - a: str - b: float - c: Optional[str] - d: Optional[bool] - - assert OptionalRow.non_nullable_columns == {"a", "b"} - assert OptionalRow.nullable_columns == {"c", "d"} - - # All different types of SQL types - class TypeModel(pt.Model): - a: str - b: int - c: float - d: Optional[bool] - - assert TypeModel.sql_types == { - "a": "VARCHAR", - "b": "INTEGER", - "c": "DOUBLE", - "d": "BOOLEAN", - } - - -def test_fill_missing_columns(): - """Tests for Relation.with_missing_{nullable,defaultable}_columns.""" - - class MyRow(pt.Model): - # This can't be filled - a: str - # This can be filled with default value - b: Optional[str] = "default_value" - # This can be filled with null - c: Optional[str] - # This can be filled with null, but will be set - d: Optional[float] - # This can befilled with null, but with a different type - e: Optional[bool] - - # We check if defaults are easily retrievable from the model - assert MyRow.defaults == {"b": "default_value"} - - db = pt.duckdb.Database() - df = pl.DataFrame({"a": ["mandatory"], "d": [10.5]}) - relation = db.to_relation(df).set_model(MyRow) - - # Missing nullable columns b, c, and e are filled in with nulls - filled_nullables = relation.with_missing_nullable_columns() - assert filled_nullables.set_model(None).get() == { - "a": "mandatory", - "b": None, - "c": None, - "d": 10.5, - "e": None, - } - # And these nulls are properly typed - assert filled_nullables.types == { - "a": "VARCHAR", - "b": "VARCHAR", - "c": "VARCHAR", - "d": "DOUBLE", - "e": "BOOLEAN", - } - - # Now we fill in the b column with "default_value" - filled_defaults = relation.with_missing_defaultable_columns() - assert filled_defaults.set_model(None).get().dict() == { - "a": "mandatory", - "b": "default_value", - "d": 10.5, - } - assert filled_defaults.types == { - "a": "VARCHAR", - "b": "VARCHAR", - "d": "DOUBLE", - } - - # We now exclude the b column from being filled with default values - excluded_default = relation.with_missing_defaultable_columns(exclude=["b"]) - assert excluded_default.set_model(None).get().dict() == { 
- "a": "mandatory", - "d": 10.5, - } - - # We can also specify that we only want to fill a subset - included_defualts = relation.with_missing_defaultable_columns(include=["b"]) - assert included_defualts.set_model(None).get().dict() == { - "a": "mandatory", - "b": "default_value", - "d": 10.5, - } - - # We now exclude column b and c from being filled with null values - excluded_nulls = relation.with_missing_nullable_columns(exclude=["b", "c"]) - assert excluded_nulls.set_model(None).get().dict() == { - "a": "mandatory", - "d": 10.5, - "e": None, - } - - # Only specify that we want to fill column e with nulls - included_nulls = relation.with_missing_nullable_columns(include=["e"]) - assert included_nulls.set_model(None).get().dict() == { - "a": "mandatory", - "d": 10.5, - "e": None, - } - - # We should raise if both include and exclude is specified - with pytest.raises( - TypeError, match="Both include and exclude provided at the same time!" - ): - relation.with_missing_nullable_columns(include={"x"}, exclude={"y"}) - - with pytest.raises( - TypeError, match="Both include and exclude provided at the same time!" - ): - relation.with_missing_defaultable_columns(include={"x"}, exclude={"y"}) - - -def test_with_missing_nullable_enum_columns(): - """It should produce enums with null values correctly.""" - - class EnumModel(pt.Model): - enum_column: Optional[Literal["a", "b", "c"]] - other_column: int - - db = pt.duckdb.Database() - - # We insert data into a properly typed table in order to get the correct enum type - db.create_table(name="enum_table", model=EnumModel) - db.to_relation("select 'a' as enum_column, 1 as other_column").insert_into( - table="enum_table" - ) - table_relation = db.table("enum_table") - assert str(table_relation.types["enum_column"]).startswith("enum__") - - # We generate another dynamic relation where we expect the correct enum type - null_relation = ( - db.to_relation("select 2 as other_column") - .set_model(EnumModel) - .with_missing_nullable_columns() - ) - assert null_relation.types["enum_column"] == table_relation.types["enum_column"] - - # These two relations should now be unionable - union_relation = (null_relation + table_relation).order("other_column asc") - assert union_relation.types["enum_column"] == table_relation.types["enum_column"] - - with pl.StringCache(): - correct_union_df = pl.DataFrame( - { - "other_column": [1, 2], - "enum_column": pl.Series(["a", None]).cast(pl.Categorical), - } - ) - assert union_relation.to_df().frame_equal(correct_union_df) - - -def test_with_missing_nullable_enum_columns_without_table(): - """It should produce enums with null values correctly without a table.""" - - class EnumModel(pt.Model): - enum_column_1: Optional[Literal["a", "b", "c"]] - enum_column_2: Optional[Literal["a", "b", "c"]] - other_column: int - - # We should be able to create the correct type without a table - db = pt.duckdb.Database() - relation = db.to_relation("select 1 as other_column") - with pytest.raises( - TypeError, match=r".*You should invoke Relation.set_model\(\) first!" 
- ): - relation.with_missing_nullable_columns() - - model_relation = relation.set_model(EnumModel).with_missing_nullable_columns() - assert str(model_relation.types["enum_column_1"]).startswith("enum__") - assert ( - model_relation.types["enum_column_2"] == model_relation.types["enum_column_1"] - ) - - # And now we should be able to insert it into a new table - model_relation.create_table(name="enum_table") - table_relation = db.table("enum_table") - assert ( - table_relation.types["enum_column_1"] == model_relation.types["enum_column_1"] - ) - assert ( - table_relation.types["enum_column_2"] == model_relation.types["enum_column_1"] - ) - - -def test_with_missing_defualtable_enum_columns(): - """It should produce enums with default values correctly typed.""" - - class EnumModel(pt.Model): - enum_column: Optional[Literal["a", "b", "c"]] = "a" - other_column: int - - db = pt.duckdb.Database() - relation = db.to_relation("select 1 as other_column") - with pytest.raises( - TypeError, - match=r".*You should invoke Relation.set_model\(\) first!", - ): - relation.with_missing_defaultable_columns() - - model_relation = relation.set_model(EnumModel).with_missing_defaultable_columns() - assert str(model_relation.types["enum_column"]).startswith("enum__") - - -def test_relation_insert_into(): - """Relation.insert_into() should automatically order columnns correctly.""" - db = pt.duckdb.Database() - db.execute( - """ - create table foo ( - a integer, - b integer - ) - """ - ) - db.to_relation("select 2 as b, 1 as a").insert_into(table="foo") - row = db.table("foo").get() - assert row.a == 1 - assert row.b == 2 - - with pytest.raises( - TypeError, - match=re.escape( - "Relation is missing column(s) {'a'} " - "in order to be inserted into table 'foo'!" - ), - ): - db.to_relation("select 2 as b, 1 as c").insert_into(table="foo") - - -def test_polars_support(): - # Test converting a polars DataFrame to patito relation - df = pl.DataFrame(data={"column_1": ["a", "b", None], "column_2": [1, 2, None]}) - correct_dtypes = [pl.Utf8, pl.Int64] - assert df.dtypes == correct_dtypes - db = pt.duckdb.Database() - relation = db.to_relation(df) - assert relation.get(column_1="a").column_2 == 1 - - # Test converting back again the other way - roundtrip_df = relation.to_df() - assert roundtrip_df.frame_equal(df) - assert roundtrip_df.dtypes == correct_dtypes - - # Assert that .to_df() always returns a DataFrame. 
- assert isinstance(relation["column_1"].to_df(), pl.DataFrame) - - # Assert that .to_df() returns an empty DataFrame when the table has no rows - empty_dataframe = relation.filter(column_1="missing-column").to_df() - # assert empty_dataframe == pl.DataFrame(columns=["column_1", "column_2"]) - # assert empty_dataframe.frame_equal(pl.DataFrame(columns=["column_1", "column_2"])) - - # The datatype should be preserved - assert empty_dataframe.dtypes == correct_dtypes - - # A model should be able to be instantiated with a polars row - class MyModel(pt.Model): - a: int - b: str - - my_model_df = pl.DataFrame({"a": [1, 2], "b": ["x", "y"]}) - with pytest.raises( - ValueError, - match=r"MyModel._from_polars\(\) can only be invoked with exactly 1 row.*", - ): - MyModel.from_row(my_model_df) - - my_model = MyModel.from_row(my_model_df.head(1)) - assert my_model.a == 1 - assert my_model.b == "x" - - # Anything besides a polars dataframe should raise TypeError - with pytest.raises(TypeError): - MyModel.from_row(None) # pyright: ignore - - # But we can also skip validation if we want - unvalidated_model = MyModel.from_row( - pl.DataFrame().with_columns( - [ - pl.lit("string").alias("a"), - pl.lit(2).alias("b"), - ] - ), - validate=False, - ) - assert unvalidated_model.a == "string" - assert unvalidated_model.b == 2 - - -def test_series_vs_dataframe_behavior(): - """Test Relation.to_series().""" - db = pt.duckdb.Database() - relation = db.to_relation("select 1 as column_1, 2 as column_2") - - # Selecting multiple columns should yield a DataFrame - assert isinstance(relation[["column_1", "column_2"]].to_df(), pl.DataFrame) - - # Selecting a single column, but as an item in a list, should yield a DataFrame - assert isinstance(relation[["column_1"]].to_df(), pl.DataFrame) - - # Selecting a single column as a string should also yield a DataFrame - assert isinstance(relation["column_1"].to_df(), pl.DataFrame) - - # But .to_series() should yield a series - series = relation["column_1"].to_series() - assert isinstance(series, pl.Series) - - # The name should also be set correctly - assert series.name == "column_1" - - # And the content should be correct - correct_series = pl.Series([1], dtype=pl.Int32).alias("column_1") - assert series.series_equal(correct_series) - - # To series will raise a type error if invoked with anything other than 1 column - with pytest.raises(TypeError, match=r".*2 columns, while exactly 1 is required.*"): - relation.to_series() - - -def test_converting_enum_column_to_polars(): - """Enum types should be convertible to polars categoricals.""" - - class EnumModel(pt.Model): - enum_column: Literal["a", "b", "c"] - - db = pt.duckdb.Database() - db.create_table(name="enum_table", model=EnumModel) - db.execute( - """ - insert into enum_table - (enum_column) - values - ('a'), - ('a'), - ('b'); - """ - ) - enum_df = db.table("enum_table").to_df() - assert enum_df.frame_equal(pl.DataFrame({"enum_column": ["a", "a", "b"]})) - assert enum_df.dtypes == [pl.Categorical] - - -def test_non_string_enum(): - """It should handle other types than just string enums.""" - - class EnumModel(pt.Model): - enum_column: Literal[10, 11, 12] - - db = pt.duckdb.Database() - db.create_table(name="enum_table", model=EnumModel) - - db.execute( - """ - insert into enum_table - (enum_column) - values - (10), - (11), - (12); - """ - ) - enum_df = db.table("enum_table").to_df() - assert enum_df.frame_equal(pl.DataFrame({"enum_column": [10, 11, 12]})) - assert enum_df.dtypes == [pl.Int64] - - -def 
test_multiple_filters(): - """The filter method should AND multiple filters properly.""" - db = pt.duckdb.Database() - relation = db.to_relation("select 1 as a, 2 as b") - # The logical or should not make the filter valid for our row - assert relation.filter("(1 = 2) or b = 2", a=0).count() == 0 - assert relation.filter("a=0", "(1 = 2) or b = 2").count() == 0 - - -def test_no_filter(): - """No filters should return all rows.""" - db = pt.duckdb.Database() - relation = db.to_relation("select 1 as a, 2 as b") - # The logical or should not make the filter valid for our row - assert relation.filter().count() - - -def test_string_representation_of_relation(): - """It should have a string representation.""" - relation = pt.duckdb.Relation("select 1 as my_column") - relation_str = str(relation) - assert "my_column" in relation_str - - -def test_cast(): - """It should be able to cast to the correct SQL types based on model.""" - - class Schema(pt.Model): - float_column: float - - relation = pt.duckdb.Relation("select 1 as float_column, 2 as other_column") - with pytest.raises( - TypeError, - match=( - r"Relation\.cast\(\) invoked without Relation.model having been set\! " - r"You should invoke Relation\.set_model\(\) first or explicitly provide " - r"a model to \.cast\(\)." - ), - ): - relation.cast() - - # Originally the type of both columns are integers - modeled_relation = relation.set_model(Schema) - assert modeled_relation.types["float_column"] == "INTEGER" - assert modeled_relation.types["other_column"] == "INTEGER" - - # The casted variant has converted the float column to double - casted_relation = relation.set_model(Schema).cast() - assert casted_relation.types["float_column"] == "DOUBLE" - # But kept the other as-is - assert casted_relation.types["other_column"] == "INTEGER" - - # You can either set the model with .set_model() or provide it to cast - assert ( - relation.set_model(Schema) - .cast() - .to_df() - .frame_equal(relation.cast(Schema).to_df()) - ) - - # Other types that should be considered compatible should be kept as-is - compatible_relation = pt.duckdb.Relation("select 1::FLOAT as float_column") - assert compatible_relation.cast(Schema).types["float_column"] == "FLOAT" - - # Unless the strict parameter is specified - assert ( - compatible_relation.cast(Schema, strict=True).types["float_column"] == "DOUBLE" - ) - - # We can also specify a specific SQL type - class SpecificSQLTypeSchema(pt.Model): - float_column: float = pt.Field(sql_type="BIGINT") - - specific_cast_relation = relation.set_model(SpecificSQLTypeSchema).cast() - assert specific_cast_relation.types["float_column"] == "BIGINT" - - # Unknown types raise - class ObjectModel(pt.Model): - object_column: object - - with pytest.raises( - NotImplementedError, - match=r"No valid sql_type mapping found for column 'object_column'\.", - ): - pt.duckdb.Relation("select 1 as object_column").set_model(ObjectModel).cast() - - # Check for more specific type annotations - class TotalModel(pt.Model): - timedelta_column: timedelta - date_column: date - null_column: None - - df = pt.DataFrame( - { - "date_column": [date(2022, 9, 4)], - "null_column": [None], - } - ) - casted_relation = pt.duckdb.Relation(df, model=TotalModel).cast() - assert casted_relation.types == { - "date_column": "DATE", - "null_column": "INTEGER", - } - assert casted_relation.to_df().frame_equal(df) - - # It is possible to only cast a subset - class MyModel(pt.Model): - column_1: float - column_2: float - - relation = pt.duckdb.Relation("select 1 as column_1, 2 
as column_2").set_model( - MyModel - ) - assert relation.cast(include=[]).types == { - "column_1": "INTEGER", - "column_2": "INTEGER", - } - assert relation.cast(include=["column_1"]).types == { - "column_1": "DOUBLE", - "column_2": "INTEGER", - } - assert relation.cast(include=["column_1", "column_2"]).types == { - "column_1": "DOUBLE", - "column_2": "DOUBLE", - } - - assert relation.cast(exclude=[]).types == { - "column_1": "DOUBLE", - "column_2": "DOUBLE", - } - assert relation.cast(exclude=["column_1"]).types == { - "column_1": "INTEGER", - "column_2": "DOUBLE", - } - assert relation.cast(exclude=["column_1", "column_2"]).types == { - "column_1": "INTEGER", - "column_2": "INTEGER", - } - - # Providing both include and exclude should raise a value error - with pytest.raises( - ValueError, - match=r"Both include and exclude provided to Relation.cast\(\)\!", - ): - relation.cast(include=["column_1"], exclude=["column_2"]) - - -@pytest.mark.xfail(strict=True) -def test_casting_timedelta_column_back_and_forth(): - class TotalModel(pt.Model): - timedelta_column: timedelta - date_column: date - null_column: None - - df = pt.DataFrame( - { - "timedelta_column": [timedelta(seconds=90)], - "date_column": [date(2022, 9, 4)], - "null_column": [None], - } - ) - casted_relation = pt.duckdb.Relation(df, model=TotalModel).cast() - assert casted_relation.types == { - "timedelta_column": "INTERVAL", - "date_column": "DATE", - "null_column": "INTEGER", - } - assert casted_relation.to_df().frame_equal(df) diff --git a/tests/test_model.py b/tests/test_model.py index 3c609bd..22e7ead 100644 --- a/tests/test_model.py +++ b/tests/test_model.py @@ -3,7 +3,7 @@ import enum import re from datetime import date, datetime, timedelta -from typing import List, Literal, Optional, Type +from typing import Literal, Optional, Type import patito as pt import polars as pl