34 changes: 15 additions & 19 deletions ibis/backends/clickhouse/__init__.py
Expand Up @@ -21,10 +21,9 @@
import ibis.expr.schema as sch
import ibis.expr.types as ir
from ibis import util
from ibis.backends.base import BaseBackend
from ibis.backends.base import BaseBackend, CanCreateDatabase
from ibis.backends.clickhouse.compiler import translate
from ibis.backends.clickhouse.datatypes import parse, serialize
from ibis.formats.pandas import PandasData

if TYPE_CHECKING:
import pandas as pd
Expand Down Expand Up @@ -58,7 +57,7 @@ def insert(self, obj, settings: Mapping[str, Any] | None = None, **kwargs):
return self._client.con.insert_df(self.name, obj, settings=settings, **kwargs)


class Backend(BaseBackend):
class Backend(BaseBackend, CanCreateDatabase):
name = "clickhouse"

# ClickHouse itself does, but the client driver does not
Expand Down Expand Up @@ -194,7 +193,9 @@ def version(self) -> str:

@property
def current_database(self) -> str:
return self.con.database
with closing(self.raw_sql("SELECT currentDatabase()")) as result:
[(db,)] = result.result_rows
return db

def list_databases(self, like: str | None = None) -> list[str]:
with closing(self.raw_sql("SELECT name FROM system.databases")) as result:
Expand Down Expand Up @@ -282,14 +283,9 @@ def to_pyarrow(
external_tables=external_tables,
**kwargs,
) as reader:
t = reader.read_all()
table = reader.read_all()

if isinstance(expr, ir.Scalar):
return t[0][0]
elif isinstance(expr, ir.Column):
return t[0]
else:
return t
return expr.__pyarrow_result__(table)

def to_pyarrow_batches(
self,
Expand Down Expand Up @@ -392,13 +388,11 @@ def execute(
if df.empty:
df = pd.DataFrame(columns=schema.names)

result = PandasData.convert_table(df, schema)
if isinstance(expr, ir.Scalar):
return result.iat[0, 0]
elif isinstance(expr, ir.Column):
return result.iloc[:, 0]
else:
return result
# TODO: remove the extra conversion
#
# the extra __pandas_result__ call is to work around slight differences
# in single column conversion and whole table conversion
return expr.__pandas_result__(table.__pandas_result__(df))

def compile(self, expr: ir.Expr, limit: str | None = None, params=None, **_: Any):
table_expr = expr.as_table()
Expand Down Expand Up @@ -473,6 +467,8 @@ def raw_sql(
def fetch_from_cursor(self, cursor, schema):
import pandas as pd

from ibis.formats.pandas import PandasData

df = pd.DataFrame.from_records(iter(cursor), columns=schema.names)
return PandasData.convert_table(df, schema)

Expand Down Expand Up @@ -520,7 +516,7 @@ def _get_schema_using_query(self, query: str) -> sch.Schema:
def has_operation(cls, operation: type[ops.Value]) -> bool:
from ibis.backends.clickhouse.compiler.values import translate_val

return operation in translate_val.registry
return translate_val.dispatch(operation) is not translate_val.dispatch(object)

def create_database(
self, name: str, *, force: bool = False, engine: str = "Atomic"
Expand Down
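
A note on the `has_operation` change above: `translate_val` is a `functools.singledispatch` function (see `values.py` below), and `dispatch(operation)` resolves the implementation that would actually run for that type, so comparing it against the `object` fallback also catches rules registered on a parent class, something a plain `registry` membership test cannot do. A minimal sketch with hypothetical operation classes standing in for ibis's real ones:

```python
# Sketch only: Cast/SpecialCast/FancyOp are hypothetical stand-ins for ibis
# operation classes; the dispatch test mirrors `Backend.has_operation` above.
import functools


@functools.singledispatch
def translate_val(op, **_):
    raise NotImplementedError(type(op))  # fallback registered for `object`


class Cast:  # pretend: ops.Cast
    pass


class SpecialCast(Cast):  # a subclass with no rule of its own
    pass


class FancyOp:  # an operation with no ClickHouse rule at all
    pass


@translate_val.register(Cast)
def _cast(op, **_):
    return "CAST(...)"


def has_operation(operation: type) -> bool:
    # `dispatch` resolves what would run for this type; if that is the
    # `object` fallback, nothing more specific was ever registered.
    return translate_val.dispatch(operation) is not translate_val.dispatch(object)


assert has_operation(Cast)
assert has_operation(SpecialCast)  # inherited rule is found
assert not has_operation(FancyOp)
```
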
39 changes: 33 additions & 6 deletions ibis/backends/clickhouse/compiler/values.py
Expand Up @@ -8,6 +8,7 @@
from typing import Any, Literal, Mapping

import sqlglot as sg
from sqlglot.dialects.dialect import rename_func
from toolz import flip

import ibis
Expand All @@ -23,6 +24,16 @@
# TODO: Find a way to remove all the dialect="clickhouse" kwargs


# TODO: This is a hack to get around the fact that sqlglot 17.8.6 is broken for
# ClickHouse's isNaN
sg.dialects.clickhouse.ClickHouse.Generator.TRANSFORMS.update(
{
sg.exp.IsNan: rename_func("isNaN"),
sg.exp.StartsWith: rename_func("startsWith"),
}
)


@functools.singledispatch
def translate_val(op, **_):
"""Translate a value expression into sqlglot."""
Expand Down Expand Up @@ -254,7 +265,9 @@ def _xor(op, **kw):
left = _parenthesize(op.left, raw_left)
raw_right = translate_val(op.right, **kw)
right = _parenthesize(op.right, raw_right)
return f"xor({left}, {right})"
# clickhouse has xor but sqlglot's compilation of this function is broken
# in 17.6.0
return f"(({left} or {right}) and not ({left} and {right}))"


@translate_val.register(ops.Arbitrary)
Expand Down Expand Up @@ -386,7 +399,7 @@ def _log(op, **kw):

# base is translated at this point
if has_base:
if base != "2" and base != "10":
if base not in ("2", "10"):
raise ValueError(f"Base {base} for logarithm not supported!")
else:
func += base
Expand All @@ -395,9 +408,9 @@ def _log(op, **kw):


@translate_val.register(tuple)
def _node_list(op, punct="()", **kw):
def _node_list(op, **kw):
values = ", ".join(map(_sql, map(partial(translate_val, **kw), op)))
return f"{punct[0]}{values}{punct[1]}"
return f"({values})"


def _interval_format(op):
Expand Down Expand Up @@ -922,6 +935,12 @@ def _map_get(op, **kw):
return f"if(mapContains({arg}, {key}), {arg}[{key}], {default})"


@translate_val.register(ops.ArrayConcat)
def _array_concat(op, **kw):
args = ", ".join(map(_sql, map(partial(translate_val, **kw), op.arg)))
return f"arrayConcat({args})"


def _binary_infix(symbol: str):
def formatter(op, **kw):
left = translate_val(op_left := op.left, **kw)
Expand Down Expand Up @@ -1054,7 +1073,6 @@ def formatter(op, **kw):
# because clickhouse's greatest and least don't support varargs
ops.Where: "if",
ops.ArrayLength: "length",
ops.ArrayConcat: "arrayConcat",
ops.Unnest: "arrayJoin",
ops.Degrees: "degrees",
ops.Radians: "radians",
Expand Down Expand Up @@ -1393,7 +1411,7 @@ def _array_remove(op, **kw):

@translate_val.register(ops.ArrayUnion)
def _array_union(op, **kw):
return translate_val(ops.ArrayDistinct(ops.ArrayConcat(op.left, op.right)), **kw)
return translate_val(ops.ArrayDistinct(ops.ArrayConcat((op.left, op.right))), **kw)


@translate_val.register(ops.ArrayZip)
Expand All @@ -1405,3 +1423,12 @@ def _array_zip(op: ops.ArrayZip, **kw: Any) -> str:
sql_arg = sql_arg.sql(dialect="clickhouse")
arglist.append(sql_arg)
return f"arrayZip({', '.join(arglist)})"


@translate_val.register(ops.CountDistinctStar)
def _count_distinct_star(op: ops.CountDistinctStar, **kw: Any) -> str:
column_list = ", ".join(map(_sql, map(sg.column, op.arg.schema.names)))
if op.where is not None:
return f"countDistinctIf(({column_list}), {translate_val(op.where, **kw)})"
else:
return f"countDistinct(({column_list}))"
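
The `_xor` workaround above rewrites `xor(a, b)` using the standard boolean identity: a XOR b = (a OR b) AND NOT (a AND b). A quick truth-table check of that identity in plain Python (no ClickHouse or sqlglot required):

```python
# Verify the rewrite emitted by _xor over every boolean combination.
from itertools import product

for a, b in product([False, True], repeat=2):
    rewritten = (a or b) and not (a and b)
    assert rewritten == (a != b), (a, b)  # `!=` is XOR for booleans
print("xor rewrite agrees with XOR on all inputs")
```
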
76 changes: 29 additions & 47 deletions ibis/backends/clickhouse/tests/conftest.py
Expand Up @@ -2,16 +2,16 @@

import contextlib
import os
from typing import TYPE_CHECKING, Callable
from typing import TYPE_CHECKING, Any, Callable, Iterable

import pytest

import ibis
import ibis.expr.types as ir
from ibis import util
from ibis.backends.tests.base import (
RoundHalfToEven,
ServiceBackendTest,
ServiceSpec,
UnorderedComparator,
)

Expand All @@ -32,28 +32,22 @@ class TestConf(UnorderedComparator, ServiceBackendTest, RoundHalfToEven):
supported_to_timestamp_units = {'s'}
supports_floating_modulus = False
supports_json = False
data_volume = "/var/lib/clickhouse/user_files/ibis"
service_name = "clickhouse"
deps = ("clickhouse_connect",)

@property
def native_bool(self) -> bool:
[(value,)] = self.connection.con.query("SELECT true").result_set
return isinstance(value, bool)

@classmethod
def service_spec(cls, data_dir: Path) -> ServiceSpec:
return ServiceSpec(
name=cls.name(),
data_volume="/var/lib/clickhouse/user_files/ibis",
files=data_dir.joinpath("parquet").glob("*.parquet"),
)
@property
def test_files(self) -> Iterable[Path]:
return self.data_dir.joinpath("parquet").glob("*.parquet")

@staticmethod
def _load_data(
data_dir: Path,
script_dir: Path,
host: str = CLICKHOUSE_HOST,
port: int = CLICKHOUSE_PORT,
user: str = CLICKHOUSE_USER,
password: str = CLICKHOUSE_PASS,
self,
*,
database: str = IBIS_TEST_CLICKHOUSE_DB,
**_,
) -> None:
Expand All @@ -66,35 +60,28 @@ def _load_data(
script_dir
Location of scripts defining schemas
"""
cc = pytest.importorskip("clickhouse_connect")

client = cc.get_client(
host=host,
port=port,
user=user,
password=password,
settings={
"allow_experimental_object_type": 1,
"output_format_json_named_tuples_as_objects": 1,
},
)
import clickhouse_connect as cc

con = self.connection
client = con.con

with contextlib.suppress(cc.driver.exceptions.DatabaseError):
client.command(f"CREATE DATABASE {database} ENGINE = Atomic")

with open(script_dir / 'schema' / 'clickhouse.sql') as schema:
for stmt in filter(None, map(str.strip, schema.read().split(";"))):
client.command(stmt)
util.consume(map(client.command, self.ddl_script))

def postload(self, **kw: Any):
# reconnect to set the database to the test database
self.connection = self.connect(database=IBIS_TEST_CLICKHOUSE_DB, **kw)

@staticmethod
def connect(data_directory: Path):
pytest.importorskip("clickhouse_connect")
def connect(*, tmpdir, worker_id, **kw: Any):
return ibis.clickhouse.connect(
host=CLICKHOUSE_HOST,
port=CLICKHOUSE_PORT,
password=CLICKHOUSE_PASS,
database=IBIS_TEST_CLICKHOUSE_DB,
user=CLICKHOUSE_USER,
**kw,
)

@staticmethod
Expand All @@ -114,27 +101,22 @@ def least(f: Callable[..., ir.Value], *args: ir.Value) -> ir.Value:
return f(*args)


@pytest.fixture(scope='module')
def con(tmp_path_factory, data_directory, script_directory, worker_id):
return TestConf.load_data(
data_directory,
script_directory,
tmp_path_factory,
worker_id,
).connect(data_directory)
@pytest.fixture(scope='session')
def con(tmp_path_factory, data_dir, worker_id):
return TestConf.load_data(data_dir, tmp_path_factory, worker_id).connection


@pytest.fixture(scope='module')
@pytest.fixture(scope='session')
def db(con):
return con.database()


@pytest.fixture(scope='module')
def alltypes(db):
return db.functional_alltypes
@pytest.fixture(scope='session')
def alltypes(con):
return con.tables.functional_alltypes


@pytest.fixture(scope='module')
@pytest.fixture(scope='session')
def df(alltypes):
return alltypes.execute()

Expand Down
2 changes: 2 additions & 0 deletions ibis/backends/clickhouse/tests/test_aggregations.py
@@ -1,3 +1,5 @@
from __future__ import annotations

from operator import methodcaller

import numpy as np
Expand Down
7 changes: 5 additions & 2 deletions ibis/backends/clickhouse/tests/test_client.py
@@ -1,3 +1,5 @@
from __future__ import annotations

import pandas as pd
import pandas.testing as tm
import pyarrow as pa
Expand Down Expand Up @@ -72,8 +74,9 @@ def logger(x):

expected = 'DESCRIBE ibis_testing.functional_alltypes'

assert len(queries) == 1
assert queries[0] == expected
# there might be other queries in the list; we only check that a describe
# table query was logged
assert expected in queries


def test_sql_query_limits(alltypes):
Expand Down
2 changes: 2 additions & 0 deletions ibis/backends/clickhouse/tests/test_functions.py
@@ -1,3 +1,5 @@
from __future__ import annotations

import math
from datetime import date, datetime
from operator import methodcaller
Expand Down
2 changes: 2 additions & 0 deletions ibis/backends/clickhouse/tests/test_literals.py
@@ -1,3 +1,5 @@
from __future__ import annotations

import pytest
from pandas import Timestamp
from pytest import param
Expand Down
2 changes: 2 additions & 0 deletions ibis/backends/clickhouse/tests/test_operators.py
@@ -1,3 +1,5 @@
from __future__ import annotations

import operator
from datetime import date, datetime

Expand Down
2 changes: 2 additions & 0 deletions ibis/backends/clickhouse/tests/test_select.py
@@ -1,3 +1,5 @@
from __future__ import annotations

import pandas as pd
import pandas.testing as tm
import pytest
Expand Down
2 changes: 2 additions & 0 deletions ibis/backends/clickhouse/tests/test_types.py
@@ -1,3 +1,5 @@
from __future__ import annotations

import pytest
from pytest import param

Expand Down
204 changes: 62 additions & 142 deletions ibis/backends/conftest.py
Expand Up @@ -4,12 +4,10 @@
import importlib
import importlib.metadata
import itertools
import os
import platform
import sys
from functools import lru_cache
from pathlib import Path
from typing import Any, TextIO
from typing import Any, Iterable

import _pytest
import numpy as np
Expand All @@ -20,17 +18,10 @@
from packaging.version import parse as vparse

import ibis
import ibis.common.exceptions as com
from ibis import util
from ibis.backends.base import _get_backend_names

SANDBOXED = (
any(key.startswith("NIX_") for key in os.environ)
and os.environ.get("IN_NIX_SHELL") != "impure"
)
LINUX = platform.system() == "Linux"
MACOS = platform.system() == "Darwin"
WINDOWS = platform.system() == "Windows"
CI = os.environ.get("CI") is not None
from ibis.backends.base import CanCreateDatabase, CanCreateSchema, _get_backend_names
from ibis.conftest import WINDOWS

TEST_TABLES = {
"functional_alltypes": ibis.schema(
Expand Down Expand Up @@ -126,19 +117,7 @@


@pytest.fixture(scope='session')
def script_directory() -> Path:
"""Return the test script directory.
Returns
-------
Path
Test script directory
"""
return Path(__file__).absolute().parents[2] / "ci"


@pytest.fixture(scope='session')
def data_directory() -> Path:
def data_dir() -> Path:
"""Return the test data directory.
Returns
Expand Down Expand Up @@ -178,7 +157,7 @@ def recreate_database(
def init_database(
url: sa.engine.url.URL,
database: str,
schema: TextIO | None = None,
schema: Iterable[str] | None = None,
recreate: bool = True,
isolation_level: str | None = "AUTOCOMMIT",
**kwargs: Any,
Expand Down Expand Up @@ -220,11 +199,7 @@ def init_database(

if schema:
with engine.begin() as conn:
for stmt in filter(
None,
map(str.strip, schema.read().split(';')),
):
conn.exec_driver_sql(stmt)
util.consume(map(conn.exec_driver_sql, schema))

return engine

Expand Down Expand Up @@ -398,56 +373,6 @@ def pytest_runtest_call(item):

backend = next(iter(backend))

for marker in item.iter_markers(name="min_server_version"):
kwargs = marker.kwargs
if backend not in kwargs:
continue

funcargs = item.funcargs
con = funcargs.get(
"con",
getattr(
funcargs.get("backend"),
"connection",
None,
),
)

if con is None:
continue

min_server_version = kwargs.pop(backend)
server_version = con.version
condition = vparse(server_version) < vparse(min_server_version)
item.add_marker(
pytest.mark.xfail(
condition,
reason=(
f"unsupported functionality for server version {server_version}"
),
**kwargs,
)
)

for marker in item.iter_markers(name="min_version"):
kwargs = marker.kwargs
if backend not in kwargs:
continue

min_version = kwargs.pop(backend)
reason = kwargs.pop("reason", None)
version = getattr(importlib.import_module(backend), "__version__", None)
if condition := version is None: # pragma: no cover
if reason is None:
reason = f"{backend} backend module has no __version__ attribute"
else:
condition = vparse(version) < vparse(min_version)
if reason is None:
reason = f"test requires {backend}>={version}; got version {version}"
else:
reason = f"{backend}@{version} (<{min_version}): {reason}"
item.add_marker(pytest.mark.xfail(condition, reason=reason, **kwargs))

# Ibis hasn't exposed existing functionality
# This xfails so that you know when it starts to pass
for marker in item.iter_markers(name="notimpl"):
Expand All @@ -457,13 +382,9 @@ def pytest_runtest_call(item):
and "raises" not in marker.kwargs.keys()
):
raise ValueError("notimpl requires a raises")
reason = marker.kwargs.get("reason")
item.add_marker(
pytest.mark.xfail(
reason=reason or f'Feature not yet exposed in {backend}',
**{k: v for k, v in marker.kwargs.items() if k != "reason"},
)
)
kwargs = marker.kwargs.copy()
kwargs.setdefault("reason", f"Feature not yet exposed in {backend}")
item.add_marker(pytest.mark.xfail(**kwargs))

# Functionality is unavailable upstream (but could be)
# This xfails so that you know when it starts to pass
Expand All @@ -474,23 +395,16 @@ def pytest_runtest_call(item):
and "raises" not in marker.kwargs.keys()
):
raise ValueError("notyet requires a raises")
reason = marker.kwargs.get("reason")
item.add_marker(
pytest.mark.xfail(
reason=reason or f'Feature not available upstream for {backend}',
**{k: v for k, v in marker.kwargs.items() if k != "reason"},
)
)

kwargs = marker.kwargs.copy()
kwargs.setdefault("reason", f"Feature not available upstream for {backend}")
item.add_marker(pytest.mark.xfail(**kwargs))

for marker in item.iter_markers(name="never"):
if backend in marker.args[0]:
if "reason" not in marker.kwargs.keys():
raise ValueError("never requires a reason")
item.add_marker(
pytest.mark.xfail(
**marker.kwargs,
)
)
item.add_marker(pytest.mark.xfail(**marker.kwargs))

# Something has been exposed as broken by a new test and it shouldn't be
# imperative for a contributor to fix it just because they happened to
Expand All @@ -502,16 +416,13 @@ def pytest_runtest_call(item):
and "raises" not in marker.kwargs.keys()
):
raise ValueError("broken requires a raises")
reason = marker.kwargs.get("reason")
item.add_marker(
pytest.mark.xfail(
reason=reason or f"Feature is failing on {backend}",
**{k: v for k, v in marker.kwargs.items() if k != "reason"},
)
)

kwargs = marker.kwargs.copy()
kwargs.setdefault("reason", f"Feature is failing on {backend}")
item.add_marker(pytest.mark.xfail(**kwargs))

for marker in item.iter_markers(name="xfail_version"):
kwargs = marker.kwargs
kwargs = marker.kwargs.copy()
if backend not in kwargs:
continue

Expand All @@ -530,11 +441,11 @@ def pytest_runtest_call(item):


@pytest.fixture(params=_get_backends_to_test(), scope='session')
def backend(request, data_directory, script_directory, tmp_path_factory, worker_id):
def backend(request, data_dir, tmp_path_factory, worker_id):
"""Return an instance of BackendTest, loaded with data."""

cls = _get_backend_conf(request.param)
return cls.load_data(data_directory, script_directory, tmp_path_factory, worker_id)
return cls.load_data(data_dir, tmp_path_factory, worker_id)


@pytest.fixture(scope="session")
Expand All @@ -543,31 +454,48 @@ def con(backend):
return backend.connection


def _setup_backend(
request, data_directory, script_directory, tmp_path_factory, worker_id
):
if (backend := request.param) == "duckdb" and platform.system() == "Windows":
@pytest.fixture(scope='session')
def con_create_database(con):
if isinstance(con, CanCreateDatabase):
return con
else:
pytest.skip(f"{con.name} backend cannot create databases")


@pytest.fixture(scope='session')
def con_create_schema(con):
if isinstance(con, CanCreateSchema):
return con
else:
pytest.skip(f"{con.name} backend cannot create schemas")


@pytest.fixture(scope='session')
def con_create_database_schema(con):
if isinstance(con, CanCreateDatabase) and isinstance(con, CanCreateSchema):
return con
else:
pytest.skip(f"{con.name} backend cannot create both database and schemas")


def _setup_backend(request, data_dir, tmp_path_factory, worker_id):
if (backend := request.param) == "duckdb" and WINDOWS:
pytest.xfail(
"windows prevents two connections to the same duckdb file "
"even in the same process"
)
return None
else:
cls = _get_backend_conf(backend)
return cls.load_data(
data_directory, script_directory, tmp_path_factory, worker_id
)
return cls.load_data(data_dir, tmp_path_factory, worker_id)


@pytest.fixture(
params=_get_backends_to_test(discard=("dask", "pandas")),
scope='session',
)
def ddl_backend(request, data_directory, script_directory, tmp_path_factory, worker_id):
def ddl_backend(request, data_dir, tmp_path_factory, worker_id):
"""Set up the backends that are SQL-based."""
return _setup_backend(
request, data_directory, script_directory, tmp_path_factory, worker_id
)
return _setup_backend(request, data_dir, tmp_path_factory, worker_id)


@pytest.fixture(scope='session')
Expand All @@ -591,13 +519,9 @@ def ddl_con(ddl_backend):
),
scope='session',
)
def alchemy_backend(
request, data_directory, script_directory, tmp_path_factory, worker_id
):
def alchemy_backend(request, data_dir, tmp_path_factory, worker_id):
"""Set up the SQLAlchemy-based backends."""
return _setup_backend(
request, data_directory, script_directory, tmp_path_factory, worker_id
)
return _setup_backend(request, data_dir, tmp_path_factory, worker_id)


@pytest.fixture(scope='session')
Expand All @@ -610,10 +534,10 @@ def alchemy_con(alchemy_backend):
params=_get_backends_to_test(keep=("dask", "pandas", "pyspark")),
scope='session',
)
def udf_backend(request, data_directory, script_directory, tmp_path_factory, worker_id):
def udf_backend(request, data_dir, tmp_path_factory, worker_id):
"""Runs the UDF-supporting backends."""
cls = _get_backend_conf(request.param)
return cls.load_data(data_directory, script_directory, tmp_path_factory, worker_id)
return cls.load_data(data_dir, tmp_path_factory, worker_id)


@pytest.fixture(scope='session')
Expand Down Expand Up @@ -772,12 +696,6 @@ def temp_view(ddl_con) -> str:
ddl_con.drop_view(name, force=True)


@pytest.fixture(scope='session')
def current_data_db(ddl_con) -> str:
"""Return current database name."""
return ddl_con.current_database


@pytest.fixture
def alternate_current_database(ddl_con, ddl_backend) -> str:
"""Create a temporary database and yield its name. Drops the created
Expand All @@ -786,18 +704,20 @@ def alternate_current_database(ddl_con, ddl_backend) -> str:
Parameters
----------
ddl_con : ibis.backends.base.Client
current_data_db : str
Yields
-------
------
str
"""
name = util.gen_name('database')
try:
ddl_con.create_database(name)
except NotImplementedError:
pytest.skip(f"{ddl_backend.name()} doesn't have create_database method.")
except AttributeError:
pytest.skip(f"{ddl_backend.name()} doesn't have a `create_database` method.")
yield name
ddl_con.drop_database(name, force=True)

with contextlib.suppress(com.UnsupportedOperationError):
ddl_con.drop_database(name, force=True)


@pytest.fixture
Expand Down
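
The `con_create_database` / `con_create_schema` fixtures added in this conftest gate tests on the new capability classes (`CanCreateDatabase`, `CanCreateSchema`) rather than probing for methods. A rough, self-contained sketch of the same pattern; the backend classes and the `CanCreateDatabase` ABC below are illustrative stand-ins, not ibis's real ones:

```python
# Hypothetical minimal version of the capability-gated fixture pattern.
import abc

import pytest


class CanCreateDatabase(abc.ABC):
    @abc.abstractmethod
    def create_database(self, name: str, force: bool = False) -> None:
        ...


class DuckLikeBackend(CanCreateDatabase):
    name = "ducklike"

    def create_database(self, name: str, force: bool = False) -> None:
        print(f"CREATE DATABASE {name}")


class DruidLikeBackend:
    name = "druidlike"  # no database-creation support


@pytest.fixture(params=[DuckLikeBackend(), DruidLikeBackend()], ids=lambda b: b.name)
def con(request):
    return request.param


@pytest.fixture
def con_create_database(con):
    # Skip, rather than fail, when the connected backend lacks the capability.
    if isinstance(con, CanCreateDatabase):
        return con
    pytest.skip(f"{con.name} backend cannot create databases")


def test_create_database(con_create_database):
    con_create_database.create_database("tmp_db")  # runs only for ducklike
```
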
8 changes: 6 additions & 2 deletions ibis/backends/dask/aggcontext.py
@@ -1,8 +1,9 @@
from __future__ import annotations

import operator
from typing import Any, Callable, Dict, Tuple, Union
from typing import TYPE_CHECKING, Any, Callable, Dict, Tuple, Union

import dask.dataframe as dd
from dask.dataframe.groupby import SeriesGroupBy

import ibis
from ibis.backends.pandas.aggcontext import (
Expand All @@ -15,6 +16,9 @@
)
from ibis.backends.pandas.aggcontext import Transform as PandasTransform

if TYPE_CHECKING:
from dask.dataframe.groupby import SeriesGroupBy

# TODO Consolidate this logic with the pandas aggcontext.
# This file is almost a direct port of the pandas aggcontext.
# https://github.com/ibis-project/ibis/issues/5911
Expand Down
1 change: 1 addition & 0 deletions ibis/backends/dask/core.py
Expand Up @@ -437,6 +437,7 @@ def execute_and_reset(
kwargs : Dict[str, object]
Additional arguments that can potentially be used by individual node
execution
Returns
-------
result : Union[
Expand Down
6 changes: 6 additions & 0 deletions ibis/backends/dask/execution/arrays.py
@@ -1,6 +1,7 @@
from __future__ import annotations

import itertools
from functools import partial

import dask.dataframe as dd
import dask.dataframe.groupby as ddgb
Expand Down Expand Up @@ -51,3 +52,8 @@ def execute_array_collect(op, data, where, aggcontext=None, **kwargs):
@execute_node.register(ops.ArrayCollect, ddgb.SeriesGroupBy, type(None))
def execute_array_collect_grouped_series(op, data, where, **kwargs):
return data.agg(collect_list)


@execute_node.register(ops.ArrayConcat, tuple)
def execute_array_concat(op, args, **kwargs):
return execute_node(op, *map(partial(execute, **kwargs), args), **kwargs)
10 changes: 10 additions & 0 deletions ibis/backends/dask/execution/generic.py
Expand Up @@ -45,6 +45,8 @@
execute_between,
execute_cast_series_array,
execute_cast_series_generic,
execute_count_distinct_star_frame,
execute_count_distinct_star_frame_filter,
execute_count_star_frame,
execute_count_star_frame_filter,
execute_count_star_frame_groupby,
Expand Down Expand Up @@ -107,6 +109,14 @@
((dd.DataFrame, type(None)), execute_count_star_frame),
((dd.DataFrame, dd.Series), execute_count_star_frame_filter),
],
ops.CountDistinctStar: [
(
(ddgb.DataFrameGroupBy, type(None)),
execute_count_star_frame_groupby,
),
((dd.DataFrame, type(None)), execute_count_distinct_star_frame),
((dd.DataFrame, dd.Series), execute_count_distinct_star_frame_filter),
],
ops.NullIfZero: [((dd.Series,), execute_null_if_zero_series)],
ops.Between: [
(
Expand Down
2 changes: 1 addition & 1 deletion ibis/backends/dask/execution/selection.py
Expand Up @@ -53,7 +53,7 @@ def compute_projection(
return data

assert isinstance(parent.table, ops.Join)
assert node == parent.table.left or node == parent.table.right
assert node in (parent.table.left, parent.table.right)

mapping = remap_overlapping_column_names(
parent.table,
Expand Down
28 changes: 17 additions & 11 deletions ibis/backends/dask/execution/util.py
Expand Up @@ -219,19 +219,25 @@ def compute_sort_key(
`execute` the expression and sort by the new derived column.
"""
name = ibis.util.guid()
if key.name in data:
return name, data[key.name]
if isinstance(key, str):
return key, None
if key.output_shape.is_columnar():
if key.name in data:
return name, data[key.name]
if isinstance(key, str):
return key, None
else:
if scope is None:
scope = Scope()
scope = scope.merge_scopes(
Scope({t: data}, timecontext)
for t in an.find_immediate_parent_tables(key)
)
new_column = execute(key, scope=scope, **kwargs)
new_column.name = name
return name, new_column
else:
if scope is None:
scope = Scope()
scope = scope.merge_scopes(
Scope({t: data}, timecontext) for t in an.find_immediate_parent_tables(key)
raise NotImplementedError(
"Scalar sort keys are not yet supported in the dask backend"
)
new_column = execute(key, scope=scope, **kwargs)
new_column.name = name
return name, new_column


def compute_sorted_frame(
Expand Down
86 changes: 35 additions & 51 deletions ibis/backends/dask/tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,26 +1,25 @@
from __future__ import annotations

from typing import TYPE_CHECKING, Any
from typing import Any

import dask
import pandas as pd
import pandas.testing as tm
import pytest

import ibis
from ibis.backends.conftest import TEST_TABLES
from ibis.backends.pandas.tests.conftest import TestConf as PandasTest
from ibis.backends.tests.data import array_types, win

if TYPE_CHECKING:
from pathlib import Path

dd = pytest.importorskip("dask.dataframe")
from ibis.backends.tests.data import array_types, json_types, win

# FIXME: Dask issue with non-deterministic groupby results; relates to the
# shuffle method on a local cluster. Manually setting the shuffle method
# avoids the issue: https://github.com/dask/dask/issues/10034.
dask.config.set({"dataframe.shuffle.method": "tasks"})

# TODO: support pyarrow string column types across ibis
dask.config.set({"dataframe.convert-string": False})

# It's necessary that NPARTITIONS > 1 in order to test cross partitioning bugs.
NPARTITIONS = 2

Expand All @@ -32,52 +31,35 @@ def npartitions():

class TestConf(PandasTest):
supports_structs = False
deps = ("dask.dataframe",)

@staticmethod
def connect(data_directory: Path):
# Note - we use `dd.from_pandas(pd.read_csv(...))` instead of
# `dd.read_csv` due to https://github.com/dask/dask/issues/6970

return ibis.dask.connect(
{
"functional_alltypes": dd.from_pandas(
pd.read_parquet(
data_directory / "parquet" / "functional_alltypes.parquet"
),
npartitions=NPARTITIONS,
),
"batting": dd.from_pandas(
pd.read_parquet(data_directory / "parquet" / "batting.parquet"),
npartitions=NPARTITIONS,
),
"awards_players": dd.from_pandas(
pd.read_parquet(
data_directory / "parquet" / "awards_players.parquet"
),
npartitions=NPARTITIONS,
),
'diamonds': dd.from_pandas(
pd.read_parquet(data_directory / "parquet" / "diamonds.parquet"),
npartitions=NPARTITIONS,
),
'json_t': dd.from_pandas(
pd.DataFrame(
{
"js": [
'{"a": [1,2,3,4], "b": 1}',
'{"a":null,"b":2}',
'{"a":"foo", "c":null}',
"null",
"[42,47,55]",
"[]",
]
}
),
npartitions=NPARTITIONS,
),
"win": dd.from_pandas(win, npartitions=NPARTITIONS),
"array_types": dd.from_pandas(array_types, npartitions=NPARTITIONS),
}
def connect(*, tmpdir, worker_id, **kw):
return ibis.dask.connect(**kw)

def _load_data(self, **_: Any) -> None:
import dask.dataframe as dd

con = self.connection
for table_name in TEST_TABLES:
path = self.data_dir / "parquet" / f"{table_name}.parquet"
con.create_table(
table_name,
dd.from_pandas(pd.read_parquet(path), npartitions=NPARTITIONS),
)

con.create_table(
"array_types",
dd.from_pandas(array_types, npartitions=NPARTITIONS),
overwrite=True,
)
con.create_table(
"win", dd.from_pandas(win, npartitions=NPARTITIONS), overwrite=True
)
con.create_table(
"json_t",
dd.from_pandas(json_types, npartitions=NPARTITIONS),
overwrite=True,
)

@classmethod
Expand All @@ -93,6 +75,8 @@ def assert_series_equal(

@pytest.fixture
def dataframe(npartitions):
dd = pytest.importorskip("dask.dataframe")

return dd.from_pandas(
pd.DataFrame(
{
Expand Down
8 changes: 4 additions & 4 deletions ibis/backends/dask/tests/execution/conftest.py
Expand Up @@ -62,8 +62,8 @@ def df(npartitions):


@pytest.fixture(scope='module')
def batting_df(data_directory):
df = dd.read_parquet(data_directory / 'parquet' / 'batting.parquet')
def batting_df(data_dir):
df = dd.read_parquet(data_dir / 'parquet' / 'batting.parquet')
# Dask dataframe thinks the columns are of type int64,
# but when computed they are all float64.
non_float_cols = ['playerID', 'yearID', 'stint', 'teamID', 'lgID', 'G']
Expand All @@ -73,8 +73,8 @@ def batting_df(data_directory):


@pytest.fixture(scope='module')
def awards_players_df(data_directory):
return dd.read_parquet(data_directory / 'parquet' / 'awards_players.parquet')
def awards_players_df(data_dir):
return dd.read_parquet(data_dir / 'parquet' / 'awards_players.parquet')


@pytest.fixture(scope='module')
Expand Down
2 changes: 2 additions & 0 deletions ibis/backends/dask/tests/execution/test_arrays.py
@@ -1,3 +1,5 @@
from __future__ import annotations

import operator

import numpy as np
Expand Down
2 changes: 2 additions & 0 deletions ibis/backends/dask/tests/execution/test_cast.py
@@ -1,3 +1,5 @@
from __future__ import annotations

import decimal

import pytest
Expand Down
2 changes: 2 additions & 0 deletions ibis/backends/dask/tests/execution/test_functions.py
@@ -1,3 +1,5 @@
from __future__ import annotations

import decimal
import functools
import math
Expand Down
2 changes: 2 additions & 0 deletions ibis/backends/dask/tests/execution/test_join.py
@@ -1,3 +1,5 @@
from __future__ import annotations

import pandas as pd
import pytest
from pandas import Timedelta, date_range
Expand Down
2 changes: 2 additions & 0 deletions ibis/backends/dask/tests/execution/test_maps.py
@@ -1,3 +1,5 @@
from __future__ import annotations

import numpy as np
import pandas as pd
import pytest
Expand Down
2 changes: 2 additions & 0 deletions ibis/backends/dask/tests/execution/test_operations.py
@@ -1,3 +1,5 @@
from __future__ import annotations

import operator
from operator import methodcaller

Expand Down
2 changes: 2 additions & 0 deletions ibis/backends/dask/tests/execution/test_strings.py
@@ -1,3 +1,5 @@
from __future__ import annotations

from warnings import catch_warnings

import pytest
Expand Down
2 changes: 2 additions & 0 deletions ibis/backends/dask/tests/execution/test_structs.py
@@ -1,3 +1,5 @@
from __future__ import annotations

from collections import OrderedDict

import pandas as pd
Expand Down
2 changes: 2 additions & 0 deletions ibis/backends/dask/tests/execution/test_temporal.py
@@ -1,3 +1,5 @@
from __future__ import annotations

import datetime
from operator import methodcaller

Expand Down
2 changes: 2 additions & 0 deletions ibis/backends/dask/tests/execution/test_util.py
@@ -1,3 +1,5 @@
from __future__ import annotations

import pytest

from ibis.backends.dask.execution.util import assert_identical_grouping_keys
Expand Down
2 changes: 2 additions & 0 deletions ibis/backends/dask/tests/execution/test_window.py
@@ -1,3 +1,5 @@
from __future__ import annotations

import io
from datetime import date
from operator import methodcaller
Expand Down
2 changes: 2 additions & 0 deletions ibis/backends/dask/tests/test_client.py
@@ -1,3 +1,5 @@
from __future__ import annotations

import re

import dask.dataframe as dd
Expand Down
2 changes: 2 additions & 0 deletions ibis/backends/dask/tests/test_core.py
@@ -1,3 +1,5 @@
from __future__ import annotations

import pytest
from dask.dataframe.utils import tm

Expand Down
2 changes: 2 additions & 0 deletions ibis/backends/dask/tests/test_dispatcher.py
@@ -1,3 +1,5 @@
from __future__ import annotations

import pytest
from multipledispatch import Dispatcher

Expand Down
2 changes: 2 additions & 0 deletions ibis/backends/dask/tests/test_udf.py
@@ -1,3 +1,5 @@
from __future__ import annotations

import collections

import numpy as np
Expand Down
108 changes: 64 additions & 44 deletions ibis/backends/datafusion/__init__.py
Expand Up @@ -5,14 +5,14 @@
from pathlib import Path
from typing import TYPE_CHECKING, Any, Mapping

import datafusion
import pyarrow as pa

import ibis.common.exceptions as com
import ibis.expr.analysis as an
import ibis.expr.operations as ops
import ibis.expr.schema as sch
import ibis.expr.types as ir
from ibis.backends.base import BaseBackend
from ibis.backends.base import BaseBackend, CanCreateDatabase, CanCreateSchema
from ibis.backends.datafusion.compiler import translate
from ibis.util import gen_name, normalize_filename

Expand All @@ -21,14 +21,17 @@
except ImportError:
from datafusion import SessionContext

import datafusion
try:
from datafusion import SessionConfig
except ImportError:
SessionConfig = None

if TYPE_CHECKING:
import pandas as pd


class Backend(BaseBackend):
name = 'datafusion'
class Backend(BaseBackend, CanCreateDatabase, CanCreateSchema):
name = "datafusion"
builder = None
supports_in_memory_tables = False

Expand All @@ -39,8 +42,7 @@ def version(self):
return importlib.metadata.version("datafusion")

def do_connect(
self,
config: Mapping[str, str | Path] | SessionContext | None = None,
self, config: Mapping[str, str | Path] | SessionContext | None = None
) -> None:
"""Create a Datafusion backend for use with Ibis.
Expand All @@ -58,18 +60,65 @@ def do_connect(
if isinstance(config, SessionContext):
self._context = config
else:
self._context = SessionContext()
if SessionConfig is not None:
df_config = SessionConfig().with_information_schema(True)
else:
df_config = None
self._context = SessionContext(df_config)

config = config or {}
if not config:
config = {}

for name, path in config.items():
self.register(path, table_name=name)

@property
def current_database(self) -> str:
raise NotImplementedError()

@property
def current_schema(self) -> str:
raise NotImplementedError()

def list_databases(self, like: str | None = None) -> list[str]:
raise NotImplementedError()
code = "SELECT DISTINCT table_catalog FROM information_schema.tables"
if like:
code += f" WHERE table_catalog LIKE {like!r}"
result = self._context.sql(code).to_pydict()
return result["table_catalog"]

def create_database(self, name: str, force: bool = False) -> None:
code = "CREATE DATABASE"
if force:
code += " IF NOT EXISTS"
code += f" {name}"
self._context.sql(code)

def drop_database(self, name: str, force: bool = False) -> None:
raise com.UnsupportedOperationError(
"DataFusion does not support dropping databases"
)

def list_schemas(self, like: str | None = None) -> list[str]:
return self._filter_with_like(self._context.catalog().names(), like=like)

def create_schema(
self, name: str, database: str | None = None, force: bool = False
) -> None:
create_stmt = "CREATE SCHEMA"
if force:
create_stmt += " IF NOT EXISTS"

create_stmt += " "
create_stmt += ".".join(filter(None, [database, name]))
self._context.sql(create_stmt)

def drop_schema(
self, name: str, database: str | None = None, force: bool = False
) -> None:
raise com.UnsupportedOperationError(
"DataFusion does not support dropping schemas"
)

def list_tables(
self,
Expand Down Expand Up @@ -99,7 +148,7 @@ def table(self, name: str, schema: sch.Schema | None = None) -> ir.Table:
A table expression
"""
catalog = self._context.catalog()
database = catalog.database('public')
database = catalog.database()
table = database.table(name)
schema = sch.schema(table.schema)
return ops.DatabaseTable(name, schema, self).to_expr()
Expand Down Expand Up @@ -292,27 +341,7 @@ def _get_frame(
limit: int | str | None = None,
**kwargs: Any,
) -> datafusion.DataFrame:
if isinstance(expr, ir.Table):
return self.compile(expr, params, **kwargs)
elif isinstance(expr, ir.Column):
# expression must be named for the projection
expr = expr.as_table()
return self.compile(expr, params, **kwargs)
elif isinstance(expr, ir.Scalar):
if an.find_immediate_parent_tables(expr.op()):
# there are associated datafusion tables so convert the expr
# to a selection which we can directly convert to a datafusion
# plan
expr = expr.as_table()
frame = self.compile(expr, params, **kwargs)
else:
# doesn't have any tables associated so create a plan from a
# dummy datafusion table
compiled = self.compile(expr, params, **kwargs)
frame = self._context.empty_table().select(compiled)
return frame
else:
raise com.IbisError(f"Cannot execute expression of type: {type(expr)}")
return self.compile(expr.as_table(), params, **kwargs)

def to_pyarrow_batches(
self,
Expand All @@ -334,25 +363,16 @@ def execute(
limit: int | str | None = "default",
**kwargs: Any,
):
output = self.to_pyarrow(expr, params=params, limit=limit, **kwargs)
if isinstance(expr, ir.Table):
return output.to_pandas()
elif isinstance(expr, ir.Column):
series = output.to_pandas()
series.name = expr.get_name()
return series
elif isinstance(expr, ir.Scalar):
return output.as_py()
else:
raise com.IbisError(f"Cannot execute expression of type: {type(expr)}")
output = self.to_pyarrow(expr.as_table(), params=params, limit=limit, **kwargs)
return expr.__pandas_result__(output.to_pandas(timestamp_as_object=True))

def compile(
self,
expr: ir.Expr,
params: Mapping[ir.Expr, object] | None = None,
**kwargs: Any,
):
return translate(expr.op())
return translate(expr.op(), ctx=self._context)

@classmethod
@lru_cache
Expand Down
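
The DataFusion `list_databases` above depends on `information_schema` being enabled through `SessionConfig().with_information_schema(True)` at connect time. A standalone sketch of the same query path using the `datafusion` package directly (assuming a version that ships `SessionConfig`; the default catalog name is typically `datafusion`):

```python
# Sketch: list catalogs via information_schema, as the backend now does.
from datafusion import SessionConfig, SessionContext

config = SessionConfig().with_information_schema(True)
ctx = SessionContext(config)

# Create something so information_schema.tables has a user table too.
ctx.sql("CREATE TABLE t AS SELECT 1 AS a")

catalogs = ctx.sql(
    "SELECT DISTINCT table_catalog FROM information_schema.tables"
).to_pydict()["table_catalog"]
print(catalogs)  # e.g. ['datafusion']
```
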
602 changes: 424 additions & 178 deletions ibis/backends/datafusion/compiler.py

Large diffs are not rendered by default.

75 changes: 16 additions & 59 deletions ibis/backends/datafusion/tests/conftest.py
@@ -1,86 +1,43 @@
from __future__ import annotations

from typing import TYPE_CHECKING
from typing import Any

import pytest

import ibis
import ibis.expr.types as ir
from ibis.backends.conftest import TEST_TABLES
from ibis.backends.tests.base import BackendTest, RoundAwayFromZero

if TYPE_CHECKING:
from pathlib import Path

pa = pytest.importorskip("pyarrow")


class TestConf(BackendTest, RoundAwayFromZero):
# check_names = False
# additional_skipped_operations = frozenset({ops.StringSQLLike})
# supports_divide_by_zero = True
# returned_timestamp_unit = 'ns'
native_bool = False
supports_structs = False
supports_json = False
supports_arrays = False
stateful = False
deps = ("datafusion",)

@staticmethod
def connect(data_directory: Path):
# can be various types:
# pyarrow.RecordBatch
# parquet file path
# csv file path
client = ibis.datafusion.connect({})
client.register(
data_directory / "csv" / 'functional_alltypes.csv',
table_name='functional_alltypes',
schema=pa.schema(
[
('id', 'int64'),
('bool_col', 'int8'),
('tinyint_col', 'int8'),
('smallint_col', 'int16'),
('int_col', 'int32'),
('bigint_col', 'int64'),
('float_col', 'float32'),
('double_col', 'float64'),
('date_string_col', 'string'),
('string_col', 'string'),
('timestamp_col', 'string'),
('year', 'int64'),
('month', 'int64'),
]
),
)
client.register(
data_directory / "parquet" / 'batting.parquet', table_name='batting'
)
client.register(
data_directory / "parquet" / 'awards_players.parquet',
table_name='awards_players',
)
client.register(
data_directory / "parquet" / 'diamonds.parquet', table_name='diamonds'
)
return client
def _load_data(self, **_: Any) -> None:
con = self.connection
for table_name in TEST_TABLES:
path = self.data_dir / "parquet" / f"{table_name}.parquet"
con.register(path, table_name=table_name)

@property
def functional_alltypes(self) -> ir.Table:
t = self.connection.table('functional_alltypes')
return t.mutate(
bool_col=t.bool_col == 1,
timestamp_col=t.timestamp_col.cast('timestamp'),
)
@staticmethod
def connect(*, tmpdir, worker_id, **kw):
return ibis.datafusion.connect(**kw)


@pytest.fixture(scope='session')
def client(data_directory):
return TestConf.connect(data_directory)
def con(data_dir, tmp_path_factory, worker_id):
return TestConf.load_data(data_dir, tmp_path_factory, worker_id).connection


@pytest.fixture(scope='session')
def alltypes(client):
return client.table("functional_alltypes")
def alltypes(con):
return con.table("functional_alltypes")


@pytest.fixture(scope='session')
Expand Down
10 changes: 6 additions & 4 deletions ibis/backends/datafusion/tests/test_register.py
@@ -1,3 +1,5 @@
from __future__ import annotations

import pandas as pd
import pyarrow as pa
import pyarrow.dataset as ds
Expand All @@ -11,13 +13,13 @@ def conn():
return ibis.datafusion.connect()


def test_read_csv(conn, data_directory):
t = conn.read_csv(data_directory / "csv" / "functional_alltypes.csv")
def test_read_csv(conn, data_dir):
t = conn.read_csv(data_dir / "csv" / "functional_alltypes.csv")
assert t.count().execute()


def test_read_parquet(conn, data_directory):
t = conn.read_parquet(data_directory / "parquet" / "functional_alltypes.parquet")
def test_read_parquet(conn, data_dir):
t = conn.read_parquet(data_dir / "parquet" / "functional_alltypes.parquet")
assert t.count().execute()


Expand Down
2 changes: 2 additions & 0 deletions ibis/backends/datafusion/tests/test_select.py
@@ -1,3 +1,5 @@
from __future__ import annotations

import pytest

from ibis.backends.datafusion.tests.conftest import BackendTest
Expand Down
2 changes: 2 additions & 0 deletions ibis/backends/datafusion/tests/test_udf.py
@@ -1,3 +1,5 @@
from __future__ import annotations

import pandas.testing as tm
import pytest

Expand Down
49 changes: 32 additions & 17 deletions ibis/backends/druid/__init__.py
Expand Up @@ -8,7 +8,6 @@
from typing import Any, Iterable

import sqlalchemy as sa
from pydruid.db.sqlalchemy import DruidDialect

import ibis.backends.druid.datatypes as ddt
import ibis.expr.datatypes as dt
Expand All @@ -21,6 +20,11 @@ class Backend(BaseAlchemyBackend):
compiler = DruidCompiler
supports_create_or_replace = False

@property
def current_database(self) -> str:
# https://druid.apache.org/docs/latest/querying/sql-metadata-tables.html#schemata-table
return "druid"

def do_connect(
self,
host: str = "localhost",
Expand Down Expand Up @@ -50,27 +54,38 @@ def do_connect(
# workaround a broken pydruid `has_table` implementation
engine.dialect.has_table = self._has_table

@staticmethod
def _new_sa_metadata():
meta = sa.MetaData()

@sa.event.listens_for(meta, "column_reflect")
def column_reflect(inspector, table, column_info):
if isinstance(typ := column_info["type"], sa.DateTime):
column_info["type"] = ddt.DruidDateTime()
elif isinstance(typ, (sa.LargeBinary, sa.BINARY, sa.VARBINARY)):
column_info["type"] = ddt.DruidBinary()
elif isinstance(typ, sa.String):
column_info["type"] = ddt.DruidString()

return meta

@contextlib.contextmanager
def _safe_raw_sql(self, query, *args, **kwargs):
if not isinstance(query, str):
query = str(
query.compile(
dialect=DruidDialect(), compile_kwargs=dict(literal_binds=True)
)
)
query = query.compile(
dialect=self.con.dialect, compile_kwargs=dict(literal_binds=True)
)

with warnings.catch_warnings():
warnings.filterwarnings(
"ignore",
message="Dialect druid:rest will not make use of SQL compilation caching",
category=sa.exc.SAWarning,
)
with self.begin() as con:
yield con.exec_driver_sql(query, *args, **kwargs)
yield con.execute(query, *args, **kwargs)

def _metadata(self, query: str) -> Iterable[tuple[str, dt.DataType]]:
query = f"EXPLAIN PLAN FOR {query}"
with self.begin() as con:
result = con.exec_driver_sql(query).scalar()
result = self._scalar_query(f"EXPLAIN PLAN FOR {query}")

(plan,) = json.loads(result)
for column in plan["signature"]:
Expand All @@ -87,12 +102,12 @@ def _get_temp_view_definition(
raise NotImplementedError()

def _has_table(self, connection, table_name: str, schema) -> bool:
query = sa.text(
"""\
SELECT COUNT(*) > 0 as c
FROM INFORMATION_SCHEMA.TABLES
WHERE TABLE_NAME = :table_name"""
).bindparams(table_name=table_name)
t = sa.table(
"TABLES", sa.column("TABLE_NAME", sa.TEXT), schema="INFORMATION_SCHEMA"
)
query = sa.select(
sa.func.sum(sa.cast(t.c.TABLE_NAME == table_name, sa.INTEGER))
).compile(dialect=self.con.dialect)

return bool(connection.execute(query).scalar())

Expand Down
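
The `_new_sa_metadata` hook above swaps in Druid-aware result types at reflection time through SQLAlchemy's `column_reflect` event. A small generic sketch of that mechanism against an in-memory SQLite database; the substituted type here is illustrative, not Druid's:

```python
# Sketch: intercept column reflection and substitute a custom type,
# the same shape as the Druid listener above (SQLite stands in for Druid).
import sqlalchemy as sa


class ParsedTimestamp(sa.types.TypeDecorator):
    impl = sa.TIMESTAMP
    cache_ok = True


engine = sa.create_engine("sqlite://")
with engine.begin() as con:
    con.exec_driver_sql("CREATE TABLE events (ts TIMESTAMP, payload BLOB)")

meta = sa.MetaData()


@sa.event.listens_for(meta, "column_reflect")
def column_reflect(inspector, table, column_info):
    # Replace the reflected type before the Table object is assembled.
    if isinstance(column_info["type"], sa.DateTime):
        column_info["type"] = ParsedTimestamp()


events = sa.Table("events", meta, autoload_with=engine)
print({c.name: type(c.type).__name__ for c in events.columns})
# e.g. {'ts': 'ParsedTimestamp', 'payload': 'BLOB'}
```
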
13 changes: 13 additions & 0 deletions ibis/backends/druid/compiler.py
@@ -1,5 +1,10 @@
from __future__ import annotations

import contextlib

import sqlalchemy as sa

import ibis.backends.druid.datatypes as ddt
from ibis.backends.base.sql.alchemy import AlchemyCompiler, AlchemyExprTranslator
from ibis.backends.druid.registry import operation_registry

Expand All @@ -9,6 +14,14 @@ class DruidExprTranslator(AlchemyExprTranslator):
_rewrites = AlchemyExprTranslator._rewrites.copy()
_dialect_name = "druid"

type_mapper = ddt.DruidType

def translate(self, op):
result = super().translate(op)
with contextlib.suppress(AttributeError):
result = result.scalar_subquery()
return sa.type_coerce(result, self.type_mapper.from_ibis(op.output_dtype))


rewrites = DruidExprTranslator.rewrites

Expand Down
55 changes: 55 additions & 0 deletions ibis/backends/druid/datatypes.py
Expand Up @@ -2,16 +2,45 @@

import parsy
import sqlalchemy as sa
import sqlalchemy.types as sat
from dateutil.parser import parse as timestamp_parse
from sqlalchemy.ext.compiler import compiles

import ibis.expr.datatypes as dt
from ibis.backends.base.sql.alchemy.datatypes import AlchemyType
from ibis.common.parsing import (
LANGLE,
RANGLE,
spaceless_string,
)


class DruidDateTime(sat.TypeDecorator):
impl = sa.TIMESTAMP

cache_ok = True

def process_result_value(self, value, dialect):
return None if value is None else timestamp_parse(value)


class DruidBinary(sa.LargeBinary):
def result_processor(self, dialect, coltype):
def process(value):
return None if value is None else value.encode("utf-8")

return process


class DruidString(sat.TypeDecorator):
impl = sa.String

cache_ok = True

def process_result_value(self, value, dialect):
return value


@compiles(sa.BIGINT, "druid")
@compiles(sa.BigInteger, "druid")
def _bigint(element, compiler, **kw):
Expand Down Expand Up @@ -47,3 +76,29 @@ def parse(text: str) -> dt.DataType:

ty.become(primitive | array | json)
return ty.parse(text)


class DruidType(AlchemyType):
dialect = "hive"

@classmethod
def to_ibis(cls, typ, nullable=True):
if isinstance(typ, DruidDateTime):
return dt.Timestamp(nullable=nullable)
elif isinstance(typ, DruidBinary):
return dt.Binary(nullable=nullable)
elif isinstance(typ, DruidString):
return dt.String(nullable=nullable)
else:
return super().to_ibis(typ, nullable=nullable)

@classmethod
def from_ibis(cls, dtype):
if dtype.is_timestamp():
return DruidDateTime()
elif dtype.is_binary():
return DruidBinary()
elif dtype.is_string():
return DruidString()
else:
return super().from_ibis(dtype)
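
`DruidDateTime.process_result_value` above is the standard SQLAlchemy `TypeDecorator` hook for massaging values as they come off the cursor; Druid hands timestamps back as strings, so they get parsed on fetch. A self-contained sketch of the same mechanism, using SQLite as a stand-in driver that also stores strings:

```python
# Sketch: parse string timestamps on fetch, as DruidDateTime does above.
import sqlalchemy as sa
from dateutil.parser import parse as timestamp_parse


class StringBackedDateTime(sa.types.TypeDecorator):
    impl = sa.TEXT
    cache_ok = True

    def process_result_value(self, value, dialect):
        # Called for every fetched value; NULLs must pass through untouched.
        return None if value is None else timestamp_parse(value)


meta = sa.MetaData()
t = sa.Table("t", meta, sa.Column("ts", StringBackedDateTime()))

engine = sa.create_engine("sqlite://")
meta.create_all(engine)

with engine.begin() as con:
    con.execute(sa.insert(t).values(ts="2023-06-01 12:34:56"))
    print(con.execute(sa.select(t.c.ts)).scalar_one())
    # -> datetime.datetime(2023, 6, 1, 12, 34, 56)
```
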
42 changes: 13 additions & 29 deletions ibis/backends/druid/tests/conftest.py
Expand Up @@ -6,17 +6,13 @@
import time
from concurrent.futures import ThreadPoolExecutor, as_completed
from itertools import chain, repeat
from typing import TYPE_CHECKING, Any
from typing import TYPE_CHECKING, Any, Iterable

import pytest
from requests import Session

import ibis
from ibis.backends.tests.base import (
RoundHalfToEven,
ServiceBackendTest,
ServiceSpec,
)
from ibis.backends.tests.base import RoundHalfToEven, ServiceBackendTest

if TYPE_CHECKING:
from pathlib import Path
Expand Down Expand Up @@ -102,17 +98,14 @@ class TestConf(ServiceBackendTest, RoundHalfToEven):
native_bool = True
supports_structs = False
supports_json = False # it does, but we haven't implemented it
service_name = "druid-middlemanager"
deps = ("pydruid.db.sqlalchemy",)

@classmethod
def service_spec(cls, data_dir: Path) -> ServiceSpec:
return ServiceSpec(
name="druid-middlemanager",
data_volume="/data",
files=data_dir.joinpath("parquet").glob("*.parquet"),
)
@property
def test_files(self) -> Iterable[Path]:
return self.data_dir.joinpath("parquet").glob("*.parquet")

@staticmethod
def _load_data(data_dir: Path, script_dir: Path, **_: Any) -> None:
def _load_data(self, **_: Any) -> None:
"""Load test data into a druid backend instance.
Parameters
Expand All @@ -122,28 +115,19 @@ def _load_data(data_dir: Path, script_dir: Path, **_: Any) -> None:
script_dir
Location of scripts defining schemas
"""
# copy data into the volume mount
queries = filter(
None,
map(
str.strip,
(script_dir / "schema" / "druid.sql").read_text().split(";"),
),
)

# run queries concurrently using threads; lots of time is spent on IO
# making requests to check whether data loading is complete
with Session() as session, ThreadPoolExecutor() as executor:
for fut in as_completed(
executor.submit(run_query, session, query) for query in queries
executor.submit(run_query, session, query) for query in self.ddl_script
):
fut.result()

@staticmethod
def connect(_: Path):
return ibis.connect(DRUID_URL)
def connect(*, tmpdir, worker_id, **kw):
return ibis.connect(DRUID_URL, **kw)


@pytest.fixture(scope='session')
def con():
return ibis.connect(DRUID_URL)
def con(data_dir, tmp_path_factory, worker_id):
return TestConf.load_data(data_dir, tmp_path_factory, worker_id).connection
99 changes: 73 additions & 26 deletions ibis/backends/duckdb/__init__.py
Expand Up @@ -28,6 +28,7 @@
import ibis.expr.schema as sch
import ibis.expr.types as ir
from ibis import util
from ibis.backends.base import CanCreateSchema
from ibis.backends.base.sql.alchemy import BaseAlchemyBackend
from ibis.backends.duckdb.compiler import DuckDBSQLCompiler
from ibis.backends.duckdb.datatypes import DuckDBType, parse
Expand All @@ -53,9 +54,10 @@ def _format_kwargs(kwargs: Mapping[str, Any]):
bindparams, pieces = [], []
for name, value in kwargs.items():
bindparam = sa.bindparam(name, value)
if not isinstance(
bindparam.type, sa.sql.sqltypes.NullType
): # the parameter type is not null
if isinstance(paramtype := bindparam.type, sa.String):
# special case strings to avoid double escaping backslashes
pieces.append(f"{name} = '{value!s}'")
elif not isinstance(paramtype, sa.types.NullType):
bindparams.append(bindparam)
pieces.append(f"{name} = :{name}")
else: # fallback to string strategy
Expand All @@ -70,17 +72,51 @@ def _format_kwargs(kwargs: Mapping[str, Any]):
}


class Backend(BaseAlchemyBackend):
class Backend(BaseAlchemyBackend, CanCreateSchema):
name = "duckdb"
compiler = DuckDBSQLCompiler
supports_create_or_replace = True

@property
def current_database(self) -> str:
return "main"
return self._scalar_query(sa.select(sa.func.current_database()))

def list_databases(self, like: str | None = None) -> list[str]:
s = sa.table(
"schemata",
sa.column("catalog_name", sa.TEXT()),
schema="information_schema",
)

query = sa.select(sa.distinct(s.c.catalog_name)).order_by(s.c.catalog_name)
with self.begin() as con:
results = list(con.execute(query).scalars())
return self._filter_with_like(results, like=like)

@property
def current_schema(self) -> str:
return self._scalar_query(sa.select(sa.func.current_schema()))

def list_schemas(self, like: str | None = None) -> list[str]:
s = sa.table(
"schemata",
sa.column("catalog_name", sa.TEXT()),
sa.column("schema_name", sa.TEXT()),
schema="information_schema",
)

query = (
sa.select(s.c.schema_name)
.where(s.c.catalog_name == sa.func.current_database())
.order_by(s.c.schema_name)
)
with self.begin() as con:
results = list(con.execute(query).scalars())
return self._filter_with_like(results, like=like)

@staticmethod
def _convert_kwargs(kwargs: MutableMapping) -> None:
read_only = kwargs.pop("read_only", "False").capitalize()
read_only = str(kwargs.pop("read_only", "False")).capitalize()
try:
kwargs["read_only"] = ast.literal_eval(read_only)
except ValueError as e:
Expand Down Expand Up @@ -118,7 +154,7 @@ def column_reflect(inspector, table, column_info):
engine = inspector.engine
colname = column_info["name"]
if (coltype := complex_type_info_cache.get(colname)) is None:
quote = engine.dialect.identifier_preparer.quote_identifier
quote = engine.dialect.identifier_preparer.quote
quoted_colname = quote(colname)
quoted_tablename = quote(table.name)
with engine.connect() as con:
Expand Down Expand Up @@ -223,6 +259,30 @@ def _load_extensions(self, extensions):
c.install_extension(extension)
c.load_extension(extension)

def create_schema(
self, name: str, database: str | None = None, force: bool = False
) -> None:
if database is not None:
raise exc.UnsupportedOperationError(
"DuckDB cannot create a schema in another database."
)
name = self._quote(name)
if_not_exists = "IF NOT EXISTS " * force
with self.begin() as con:
con.exec_driver_sql(f"CREATE SCHEMA {if_not_exists}{name}")

def drop_schema(
self, name: str, database: str | None = None, force: bool = False
) -> None:
if database is not None:
raise exc.UnsupportedOperationError(
"DuckDB cannot drop a schema in another database."
)
name = self._quote(name)
if_exists = "IF EXISTS " * force
with self.begin() as con:
con.exec_driver_sql(f"DROP SCHEMA {if_exists}{name}")

def register(
self,
source: str | Path | Any,
Expand Down Expand Up @@ -735,19 +795,7 @@ def to_pyarrow(
cursor = con.execute(sql)
table = cursor.cursor.fetch_arrow_table()

if isinstance(expr, ir.Table):
return table
elif isinstance(expr, ir.Column):
# Column will be a ChunkedArray, `combine_chunks` will
# flatten it
if len(table.columns[0]):
return table.columns[0].combine_chunks()
else:
return pa.array(table.columns[0])
elif isinstance(expr, ir.Scalar):
return table.columns[0][0]
else:
raise ValueError
return expr.__pyarrow_result__(table)

@util.experimental
def to_torch(
Expand Down Expand Up @@ -825,6 +873,7 @@ def to_parquet(
>>> # partition on multiple columns
>>> con.to_parquet(penguins, "penguins_hive_dir", partition_by=("year", "island")) # doctest: +SKIP
"""
self._run_pre_execute_hooks(expr)
query = self._to_sql(expr, params=params)
args = ["FORMAT 'parquet'", *(f"{k.upper()} {v!r}" for k, v in kwargs.items())]
copy_cmd = f"COPY ({query}) TO {str(path)!r} ({', '.join(args)})"
Expand Down Expand Up @@ -859,6 +908,7 @@ def to_csv(
**kwargs
DuckDB CSV writer arguments. https://duckdb.org/docs/data/csv.html#parameters
"""
self._run_pre_execute_hooks(expr)
query = self._to_sql(expr, params=params)
args = [
"FORMAT 'csv'",
@@ -971,9 +1021,7 @@ def _register_udfs(self, expr: ir.Expr) -> None:
self, f"_compile_{udf_node.__input_type__.name.lower()}_udf"
)
with contextlib.suppress(duckdb.InvalidInputException):
con.connection.driver_connection.remove_function(
udf_node.__class__.__name__
)
con.connection.remove_function(udf_node.__class__.__name__)

registration_func = compile_func(udf_node)
registration_func(con)
@@ -985,7 +1033,7 @@ def _compile_udf(self, udf_node: ops.ScalarUDF) -> None:
output_type = DuckDBType.to_string(udf_node.output_dtype)

def register_udf(con):
return con.connection.driver_connection.create_function(
return con.connection.create_function(
name,
func,
input_types,
@@ -1013,8 +1061,7 @@ def _insert_dataframe(
columns = list(df.columns)
t = sa.table(table_name, *map(sa.column, columns))

quote = self.con.dialect.identifier_preparer.quote
table_name = quote(table_name)
table_name = self._quote(table_name)

# the table name df here matters, and *must* match the input variable's
# name because duckdb will look up this name in the outer scope of the
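A note on the `_insert_dataframe` comment above: DuckDB's replacement scans resolve an unknown table name in SQL against Python variables visible to the call, which is why the local DataFrame variable's name matters here. A minimal sketch of that behaviour with plain `duckdb` and `pandas` (not ibis code):

```python
import duckdb
import pandas as pd

df = pd.DataFrame({"a": [1, 2, 3]})

con = duckdb.connect()
# "df" is not a registered table; DuckDB's replacement scan resolves the name
# against the in-scope pandas DataFrame variable called `df`.
con.execute("CREATE TABLE t AS SELECT * FROM df")
print(con.execute("SELECT count(*) FROM t").fetchone())  # (3,)
```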
10 changes: 9 additions & 1 deletion ibis/backends/duckdb/compiler.py
@@ -1,5 +1,6 @@
from __future__ import annotations

import sqlalchemy as sa
from sqlalchemy.ext.compiler import compiles

import ibis.backends.base.sql.alchemy.datatypes as sat
@@ -13,7 +14,9 @@ class DuckDBSQLExprTranslator(AlchemyExprTranslator):
_registry = operation_registry
_rewrites = AlchemyExprTranslator._rewrites.copy()
_has_reduction_filter_syntax = True
_supports_tuple_syntax = True
_dialect_name = "duckdb"

type_mapper = DuckDBType


@@ -39,7 +42,12 @@ def compile_uint(element, compiler, **kw):

@compiles(sat.ArrayType, "duckdb")
def compile_array(element, compiler, **kw):
return f"{compiler.process(element.value_type, **kw)}[]"
if isinstance(value_type := element.value_type, sa.types.NullType):
# duckdb infers empty arrays with no other context as array<int32>
typ = "INTEGER"
else:
typ = compiler.process(value_type, **kw)
return f"{typ}[]"


rewrites = DuckDBSQLExprTranslator.rewrites
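The `compile_array` change above special-cases `NullType` element types because DuckDB itself falls back to an integer element type when an empty list gives it nothing else to go on. A quick check against DuckDB directly (a sketch; the exact string may vary by DuckDB version):

```python
import duckdb

con = duckdb.connect()
# an empty list literal with no other context is typed as an integer array
print(con.execute("SELECT typeof([])").fetchone())  # e.g. ('INTEGER[]',)
```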
6 changes: 6 additions & 0 deletions ibis/backends/duckdb/datatypes.py
@@ -26,6 +26,7 @@ def parse(text: str, default_decimal_parameters=(18, 3)) -> dt.DataType:
"""Parse a DuckDB type into an ibis data type."""
primitive = (
spaceless_string("interval").result(dt.Interval('us'))
| spaceless_string("hugeint", "int128").result(dt.Decimal(38, 0))
| spaceless_string("bigint", "int8", "long").result(dt.int64)
| spaceless_string("boolean", "bool", "logical").result(dt.boolean)
| spaceless_string("blob", "bytea", "binary", "varbinary").result(dt.binary)
@@ -101,9 +102,14 @@ def parse(text: str, default_decimal_parameters=(18, 3)) -> dt.DataType:
ducktypes.SmallInteger: dt.Int16,
ducktypes.Integer: dt.Int32,
ducktypes.BigInteger: dt.Int64,
ducktypes.HugeInteger: dt.Decimal(38, 0),
ducktypes.UInt8: dt.UInt8,
ducktypes.UTinyInteger: dt.UInt8,
ducktypes.UInt16: dt.UInt16,
ducktypes.USmallInteger: dt.UInt16,
ducktypes.UInt32: dt.UInt32,
ducktypes.UInteger: dt.UInt32,
ducktypes.UInt64: dt.UInt64,
ducktypes.UBigInteger: dt.UInt64,
}

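With the new `hugeint`/`int128` entries, DuckDB's 128-bit integers surface in ibis as 38-digit, zero-scale decimals. A small illustration against the parser touched above (a sketch using the module path shown in this diff):

```python
import ibis.expr.datatypes as dt
from ibis.backends.duckdb.datatypes import parse

assert parse("hugeint") == dt.Decimal(38, 0)
assert parse("int128") == dt.Decimal(38, 0)
assert parse("bigint") == dt.int64
```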
3 changes: 1 addition & 2 deletions ibis/backends/duckdb/registry.py
@@ -353,7 +353,6 @@ def _try_cast(t, op):
)
),
ops.TryCast: _try_cast,
ops.ArrayConcat: fixed_arity(sa.func.array_concat, 2),
ops.ArrayRepeat: fixed_arity(
lambda arg, times: sa.func.flatten(
sa.func.array(
@@ -413,7 +412,7 @@ def _try_cast(t, op):
ops.RegexReplace: fixed_arity(
lambda *args: sa.func.regexp_replace(*args, sa.text("'g'")), 3
),
ops.RegexSearch: fixed_arity(lambda x, y: x.op("SIMILAR TO")(y), 2),
ops.RegexSearch: fixed_arity(sa.func.regexp_matches, 2),
ops.StringContains: fixed_arity(sa.func.contains, 2),
ops.ApproxMedian: reduction(
# without inline text, duckdb fails with
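The `RegexSearch` rule now maps to DuckDB's `regexp_matches` function instead of `SIMILAR TO`, so `re_search` follows regular-expression semantics rather than SQL pattern matching. A sketch of how this surfaces to users (assuming an in-memory connection):

```python
import ibis

con = ibis.duckdb.connect()
t = ibis.memtable({"s": ["abc", "bcd", "cde"]})

# the generated SQL now uses regexp_matches(s, '^b') rather than s SIMILAR TO '^b'
print(con.execute(t[t.s.re_search("^b")]))  # only the row containing "bcd"
```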
61 changes: 25 additions & 36 deletions ibis/backends/duckdb/tests/conftest.py
@@ -1,60 +1,49 @@
from __future__ import annotations

from typing import TYPE_CHECKING, Any
from typing import TYPE_CHECKING, Iterator

import pytest

import ibis
from ibis.backends.conftest import SANDBOXED, TEST_TABLES
from ibis.backends.conftest import TEST_TABLES
from ibis.backends.tests.base import BackendTest, RoundAwayFromZero
from ibis.conftest import SANDBOXED

if TYPE_CHECKING:
from pathlib import Path

from ibis.backends.base import BaseBackend


class TestConf(BackendTest, RoundAwayFromZero):
supports_map = True
deps = "duckdb", "duckdb_engine"
stateful = False

def __init__(self, data_directory: Path, **kwargs: Any) -> None:
self.connection = self.connect(data_directory, **kwargs)

def preload(self):
if not SANDBOXED:
self.connection._load_extensions(
["httpfs", "postgres_scanner", "sqlite_scanner"]
)

script_dir = data_directory.parent
schema = script_dir.joinpath("schema", "duckdb.sql").read_text()

with self.connection.begin() as con:
for stmt in filter(None, map(str.strip, schema.split(";"))):
con.exec_driver_sql(stmt)

for table in TEST_TABLES:
src = data_directory / "csv" / f"{table}.csv"
con.exec_driver_sql(
f"COPY {table} FROM {str(src)!r} (DELIMITER ',', HEADER)"
)

@staticmethod
def _load_data(data_dir, script_dir, **_: Any) -> None:
"""Load test data into a DuckDB backend instance.
Parameters
----------
data_dir
Location of test data
"""
return TestConf(data_directory=data_dir)
@property
def ddl_script(self) -> Iterator[str]:
parquet_dir = self.data_dir / "parquet"
for table in TEST_TABLES:
yield (
f"""
CREATE OR REPLACE TABLE {table} AS
SELECT * FROM read_parquet('{parquet_dir / f'{table}.parquet'}')
"""
)
yield from super().ddl_script

@staticmethod
def connect(data_directory: Path, **kwargs: Any) -> BaseBackend:
pytest.importorskip("duckdb")
return ibis.duckdb.connect(**kwargs) # type: ignore
def connect(*, tmpdir, worker_id, **kw) -> BaseBackend:
# extension directory per test worker to prevent simultaneous downloads
return ibis.duckdb.connect(
extension_directory=str(tmpdir.mktemp(f"{worker_id}_exts")), **kw
)


@pytest.fixture
def con(data_directory, tmp_path: Path):
return TestConf(data_directory, extension_directory=str(tmp_path)).connection
@pytest.fixture(scope="session")
def con(data_dir, tmp_path_factory, worker_id):
return TestConf.load_data(data_dir, tmp_path_factory, worker_id).connection
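The per-worker `extension_directory` above is just a DuckDB configuration setting forwarded through `connect`; the same knob is available outside the test suite (a sketch; the path is illustrative):

```python
import ibis

# keep downloaded extensions in a project-local cache instead of the default
con = ibis.duckdb.connect(extension_directory=".duckdb-extensions")
```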
31 changes: 31 additions & 0 deletions ibis/backends/duckdb/tests/test_datatypes.py
@@ -1,3 +1,5 @@
from __future__ import annotations

import duckdb_engine
import pytest
import sqlalchemy as sa
@@ -78,6 +80,8 @@
("TIMESTAMP_US", dt.Timestamp("UTC", scale=6)),
("TIMESTAMP_NS", dt.Timestamp("UTC", scale=9)),
("JSON", dt.json),
("HUGEINT", dt.Decimal(38, 0)),
("INT128", dt.Decimal(38, 0)),
]
],
)
@@ -131,3 +135,30 @@ def test_generate_quoted_struct():
result = typ.compile(dialect=duckdb_engine.Dialect())
expected = 'STRUCT("in come" TEXT, "my count" BIGINT, thing INTEGER)'
assert result == expected


@pytest.mark.xfail(
condition=vparse(duckdb_engine.__version__) < vparse("0.9.2"),
raises=AssertionError,
reason="mapping from UINTEGER query metadata fixed in 0.9.2",
)
def test_read_uint8_from_parquet(tmp_path):
import numpy as np

import ibis

con = ibis.duckdb.connect()

# There is an incorrect mapping in duckdb-engine from UInteger -> UInt8
# In order to get something that reads as a UInt8, we cast to UInt32 (UInteger)
t = ibis.memtable({"a": np.array([1, 2, 3, 4], dtype="uint32")})
assert t.a.type() == dt.uint32

parqpath = tmp_path / "uint.parquet"

con.to_parquet(t, parqpath)

# If this doesn't fail, then things are working
t2 = con.read_parquet(parqpath)

assert t2.schema() == t.schema()
32 changes: 21 additions & 11 deletions ibis/backends/duckdb/tests/test_register.py
@@ -1,3 +1,5 @@
from __future__ import annotations

import os
import sqlite3
import tempfile
@@ -13,17 +15,17 @@
import ibis
import ibis.common.exceptions as exc
import ibis.expr.datatypes as dt
from ibis.backends.conftest import LINUX, SANDBOXED
from ibis.conftest import LINUX, SANDBOXED


def test_read_csv(data_directory):
t = ibis.read_csv(data_directory / "csv" / "functional_alltypes.csv")
def test_read_csv(data_dir):
t = ibis.read_csv(data_dir / "csv" / "functional_alltypes.csv")
assert t.count().execute()


def test_read_csv_with_columns(data_directory):
def test_read_csv_with_columns(data_dir):
t = ibis.read_csv(
data_directory / "csv" / "awards_players.csv",
data_dir / "csv" / "awards_players.csv",
header=True,
columns={
'playerID': 'VARCHAR',
@@ -38,16 +40,16 @@ def test_read_csv_with_columns(data_directory):
assert t.count().execute()


def test_read_parquet(data_directory):
t = ibis.read_parquet(data_directory / "parquet" / "functional_alltypes.parquet")
def test_read_parquet(data_dir):
t = ibis.read_parquet(data_dir / "parquet" / "functional_alltypes.parquet")
assert t.count().execute()


@pytest.mark.xfail_version(
duckdb=["duckdb<0.7.0"], reason="read_json_auto doesn't exist", raises=exc.IbisError
)
def test_read_json(data_directory, tmp_path):
pqt = ibis.read_parquet(data_directory / "parquet" / "functional_alltypes.parquet")
def test_read_json(data_dir, tmp_path):
pqt = ibis.read_parquet(data_dir / "parquet" / "functional_alltypes.parquet")

path = tmp_path.joinpath("ft.json")
path.write_text(pqt.execute().to_json(orient="records", lines=True))
Expand Down Expand Up @@ -159,13 +161,13 @@ def test_register_sqlite(con, tmp_path):
reason="nix on linux cannot download duckdb extensions or data due to sandboxing",
raises=duckdb.IOException,
)
def test_attach_sqlite(data_directory, tmp_path):
def test_attach_sqlite(data_dir, tmp_path):
import sqlite3

test_db_path = tmp_path / "test.db"
with sqlite3.connect(test_db_path) as scon:
for line in (
Path(data_directory.parent / "schema" / "sqlite.sql").read_text().split(";")
Path(data_dir.parent / "schema" / "sqlite.sql").read_text().split(";")
):
scon.execute(line)

@@ -331,3 +333,11 @@ def test_register_recordbatchreader_warns(con):
t = con.read_in_memory(reader, table_name=t.get_name())
res = t.execute()
tm.assert_frame_equal(res, sol)


def test_csv_with_slash_n_null(con, tmp_path):
data_path = tmp_path / "data.csv"
data_path.write_text("a\n1\n3\n\\N\n")
t = con.read_csv(data_path, nullstr="\\N")
col = t.a.execute()
assert pd.isna(col.iat[-1])
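The new `nullstr` passthrough in `test_csv_with_slash_n_null` corresponds to the option of the same name in DuckDB's CSV reader; a sketch of the raw-SQL equivalent (the file name is illustrative):

```python
import duckdb

con = duckdb.connect()
# '\N' values in the file are read back as NULLs
con.execute(r"SELECT * FROM read_csv_auto('data.csv', nullstr='\N')")
```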
24 changes: 15 additions & 9 deletions ibis/backends/flink/compiler/core.py
@@ -6,8 +6,9 @@

import ibis.common.exceptions as com
import ibis.expr.operations as ops
from ibis.backends.base.sql.compiler import DDL, DML, Compiler, Select, SelectBuilder
from ibis.backends.base.sql.compiler import Compiler, Select, SelectBuilder
from ibis.backends.flink.translator import FlinkExprTranslator
from ibis.backends.flink.utils import translate_literal


class FlinkSelectBuilder(SelectBuilder):
@@ -45,16 +46,21 @@ class FlinkCompiler(Compiler):


def translate(op: ops.TableNode) -> str:
# TODO(chloeh13q): support translation of non-select exprs (e.g. literals)
ast = FlinkCompiler.to_ast(op)
return translate_language(ast.queries[0])
return translate_op(op)


@functools.singledispatch
def translate_language(language: DML | DDL) -> str:
raise com.OperationNotDefinedError(f'No translation rule for {type(language)}')
def translate_op(op: ops.TableNode) -> str:
raise com.OperationNotDefinedError(f'No translation rule for {type(op)}')


@translate_language.register(Select)
def _select(language: Select):
return language.compile()
@translate_op.register(ops.Literal)
def _literal(op: ops.Literal) -> str:
return translate_literal(op)


@translate_op.register(ops.Selection)
@translate_op.register(ops.Aggregation)
@translate_op.register(ops.Limit)
def _(op: ops.Selection | ops.Aggregation | ops.Limit) -> str:
return FlinkCompiler.to_sql(op) # to_sql uses to_ast, which builds a select tree
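With `translate_op` dispatching on the operation type, bare literals no longer need to be wrapped in a SELECT before translation. A small sketch, consistent with the new `test_literals.py` later in this diff:

```python
import ibis
from ibis.backends.flink.compiler.core import translate

assert translate(ibis.literal(5).op()) == "5"
assert translate(ibis.literal(True).op()) == "TRUE"
assert translate(ibis.literal("I can't").op()) == "'I can''t'"
```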
170 changes: 167 additions & 3 deletions ibis/backends/flink/registry.py
@@ -1,19 +1,27 @@
from __future__ import annotations

from typing import TYPE_CHECKING

import ibis.common.exceptions as com
import ibis.expr.operations as ops
from ibis.backends.base.sql.compiler import ExprTranslator
from ibis.backends.base.sql.registry import helpers
from ibis.backends.base.sql.registry import helpers, window
from ibis.backends.base.sql.registry import (
operation_registry as base_operation_registry,
)
from ibis.backends.flink.utils import translate_literal
from ibis.common.temporal import TimestampUnit

if TYPE_CHECKING:
from ibis.backends.base.sql.compiler import ExprTranslator

operation_registry = base_operation_registry.copy()


def _count_star(translator: ExprTranslator, op: ops.Node) -> str:
return "count(*)"


def _timestamp_from_unix(translator, op):
def _timestamp_from_unix(translator: ExprTranslator, op: ops.Node) -> str:
arg, unit = op.args

if unit == TimestampUnit.MILLISECOND:
@@ -23,9 +31,165 @@ def _timestamp_from_unix(translator, op):
raise ValueError(f"{unit!r} unit is not supported!")


def _extract_field(sql_attr: str) -> str:
def extract_field_formatter(translator: ExprTranslator, op: ops.Node) -> str:
arg = translator.translate(op.args[0])
if sql_attr == "epochseconds":
return f"UNIX_SECONDS({arg})"
else:
return f"EXTRACT({sql_attr} from {arg})"

return extract_field_formatter


def _filter(translator: ExprTranslator, op: ops.Node) -> str:
bool_expr = translator.translate(op.bool_expr)
true_expr = translator.translate(op.true_expr)
false_null_expr = translator.translate(op.false_null_expr)

# [TODO](chloeh13q): It's preferable to use the FILTER syntax instead of CASE WHEN
# to let the planner do more optimizations to reduce the state size; besides, FILTER
# is more compliant with the SQL standard.
# For example,
# ```
# COUNT(DISTINCT CASE WHEN flag = 'app' THEN user_id ELSE NULL END) AS app_uv
# ```
# is equivalent to
# ```
    # COUNT(DISTINCT user_id) FILTER (WHERE flag = 'app') AS app_uv
# ```
return f"CASE WHEN {bool_expr} THEN {true_expr} ELSE {false_null_expr} END"


def _literal(translator: ExprTranslator, op: ops.Literal) -> str:
return translate_literal(op)


def _format_window_start(translator: ExprTranslator, boundary):
if boundary is None:
return 'UNBOUNDED PRECEDING'

if isinstance(boundary.value, ops.Literal) and boundary.value.value == 0:
return "CURRENT ROW"

value = translator.translate(boundary.value)
return f'{value} PRECEDING'


def _format_window_end(translator: ExprTranslator, boundary):
if boundary is None:
raise com.UnsupportedOperationError(
"OVER RANGE FOLLOWING windows are not supported in Flink yet"
)

value = boundary.value
if isinstance(value, ops.Cast):
value = boundary.value.arg
if isinstance(value, ops.Literal):
if value.value != 0:
raise com.UnsupportedOperationError(
"OVER RANGE FOLLOWING windows are not supported in Flink yet"
)

return "CURRENT ROW"


def _format_window_frame(translator: ExprTranslator, func, frame):
components = []

if frame.group_by:
partition_args = ', '.join(map(translator.translate, frame.group_by))
components.append(f'PARTITION BY {partition_args}')

(order_by,) = frame.order_by
if order_by.descending is True:
raise com.UnsupportedOperationError(
"Flink only supports windows ordered in ASCENDING mode"
)
components.append(f'ORDER BY {translator.translate(order_by)}')

if frame.start is None and frame.end is None:
# no-op, default is full sample
pass
elif not isinstance(func, translator._forbids_frame_clause):
# [NOTE] Flink allows
# "ROWS BETWEEN INTERVAL [...] PRECEDING AND CURRENT ROW"
# but not
# "RANGE BETWEEN [...] PRECEDING AND CURRENT ROW",
        # but `.over(rows=(-ibis.interval(...), 0))` is not allowed in Ibis
if isinstance(frame, ops.RangeWindowFrame):
if not frame.start.value.output_dtype.is_interval():
# [TODO] need to expand support for range-based interval windowing on expr
# side, for now only ibis intervals can be used
raise com.UnsupportedOperationError(
"Data Type mismatch between ORDER BY and RANGE clause"
)

start = _format_window_start(translator, frame.start)
end = _format_window_end(translator, frame.end)

frame = f'{frame.how.upper()} BETWEEN {start} AND {end}'
components.append(frame)

return 'OVER ({})'.format(' '.join(components))


def _window(translator: ExprTranslator, op: ops.Node) -> str:
frame = op.frame
if not frame.order_by:
raise com.UnsupportedOperationError(
"Flink engine does not support generic window clause with no order by"
)
if len(frame.order_by) > 1:
raise com.UnsupportedOperationError(
"Windows in Flink can only be ordered by a single time column"
)

_unsupported_reductions = translator._unsupported_reductions

func = op.func.__window_op__

if isinstance(func, _unsupported_reductions):
raise com.UnsupportedOperationError(
f'{type(func)} is not supported in window functions'
)

if isinstance(func, ops.CumulativeOp):
arg = window.cumulative_to_window(translator, func, op.frame)
return translator.translate(arg)

if isinstance(frame, ops.RowsWindowFrame):
if frame.max_lookback is not None:
raise NotImplementedError(
'Rows with max lookback is not implemented for SQL-based backends.'
)

window_formatted = _format_window_frame(translator, func, frame)

arg_formatted = translator.translate(func.__window_op__)
result = f'{arg_formatted} {window_formatted}'

if isinstance(func, ops.RankBase):
return f'({result} - 1)'
else:
return result


operation_registry.update(
{
ops.CountStar: _count_star,
ops.ExtractYear: _extract_field("year"), # equivalent to YEAR(date)
ops.ExtractQuarter: _extract_field("quarter"), # equivalent to QUARTER(date)
ops.ExtractMonth: _extract_field("month"), # equivalent to MONTH(date)
ops.ExtractWeekOfYear: _extract_field("week"), # equivalent to WEEK(date)
ops.ExtractDayOfYear: _extract_field("doy"), # equivalent to DAYOFYEAR(date)
ops.ExtractDay: _extract_field("day"), # equivalent to DAYOFMONTH(date)
ops.ExtractHour: _extract_field("hour"), # equivalent to HOUR(timestamp)
ops.ExtractMinute: _extract_field("minute"), # equivalent to MINUTE(timestamp)
ops.ExtractSecond: _extract_field("second"), # equivalent to SECOND(timestamp)
ops.Literal: _literal,
ops.TimestampFromUNIX: _timestamp_from_unix,
ops.Where: _filter,
ops.Window: _window,
}
)
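The `_window` and `_format_window_frame` helpers above produce Flink `OVER` clauses like the ones in the window snapshots further down. Roughly (a sketch mirroring `test_rows_window`):

```python
import ibis
from ibis.backends.flink.compiler.core import translate

t = ibis.table([("f", "float64")], name="table")
expr = t.f.sum().over(rows=(-1000, 0), order_by=t.f)
print(translate(expr.as_table().op()))
# SELECT sum(t0.`f`) OVER (ORDER BY t0.`f` ASC
#   ROWS BETWEEN 1000 PRECEDING AND CURRENT ROW) AS `Sum(f)`
# FROM table t0
```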
24 changes: 24 additions & 0 deletions ibis/backends/flink/tests/conftest.py
@@ -1,10 +1,34 @@
from __future__ import annotations

import pytest

import ibis
import ibis.expr.types as ir
from ibis.backends.conftest import TEST_TABLES


@pytest.fixture
def simple_schema():
return [
('a', 'int8'),
('b', 'int16'),
('c', 'int32'),
('d', 'int64'),
('e', 'float32'),
('f', 'float64'),
('g', 'string'),
('h', 'boolean'),
('i', 'timestamp'),
('j', 'date'),
('k', 'time'),
]


@pytest.fixture
def simple_table(simple_schema):
return ibis.table(simple_schema, name='table')


@pytest.fixture
def batting() -> ir.Table:
return ibis.table(schema=TEST_TABLES["batting"], name="batting")
@@ -0,0 +1 @@
TIMESTAMP '2017-01-01 04:55:59'
@@ -0,0 +1 @@
TIMESTAMP '2017-01-01 04:55:59.001122'
@@ -0,0 +1 @@
TIME '04:55:59'
@@ -0,0 +1 @@
TIMESTAMP '2017-01-01 04:55:59'
@@ -0,0 +1 @@
TIME '04:55:59'
@@ -0,0 +1 @@
TIMESTAMP '2017-01-01 04:55:59'
@@ -0,0 +1,5 @@
SELECT t0.`b`, count(*) AS `total`, avg(t0.`a`) AS `avg_a`,
avg(CASE WHEN t0.`g` = 'A' THEN t0.`a` ELSE NULL END) AS `avg_a_A`,
avg(CASE WHEN t0.`g` = 'B' THEN t0.`a` ELSE NULL END) AS `avg_a_B`
FROM table t0
GROUP BY t0.`b`
@@ -0,0 +1,5 @@
SELECT EXTRACT(year from t0.`i`) AS `year`,
EXTRACT(month from t0.`i`) AS `month`, count(*) AS `total`,
count(DISTINCT t0.`b`) AS `b_unique`
FROM table t0
GROUP BY EXTRACT(year from t0.`i`), EXTRACT(month from t0.`i`)
@@ -0,0 +1,7 @@
SELECT t0.`a`, avg(abs(t0.`the_sum`)) AS `mad`
FROM (
SELECT t1.`a`, t1.`c`, sum(t1.`b`) AS `the_sum`
FROM table t1
GROUP BY t1.`a`, t1.`c`
) t0
GROUP BY t0.`a`
@@ -0,0 +1,2 @@
SELECT EXTRACT(day from t0.`i`) AS `tmp`
FROM table t0
@@ -0,0 +1,2 @@
SELECT EXTRACT(doy from t0.`i`) AS `tmp`
FROM table t0
@@ -0,0 +1,2 @@
SELECT EXTRACT(hour from t0.`i`) AS `tmp`
FROM table t0
@@ -0,0 +1,2 @@
SELECT EXTRACT(minute from t0.`i`) AS `tmp`
FROM table t0
@@ -0,0 +1,2 @@
SELECT EXTRACT(month from t0.`i`) AS `tmp`
FROM table t0
@@ -0,0 +1,2 @@
SELECT EXTRACT(quarter from t0.`i`) AS `tmp`
FROM table t0
@@ -0,0 +1,2 @@
SELECT EXTRACT(second from t0.`i`) AS `tmp`
FROM table t0
@@ -0,0 +1,2 @@
SELECT EXTRACT(week from t0.`i`) AS `tmp`
FROM table t0
@@ -0,0 +1,2 @@
SELECT EXTRACT(year from t0.`i`) AS `tmp`
FROM table t0
@@ -0,0 +1,4 @@
SELECT t0.*
FROM table t0
WHERE ((t0.`c` > 0) OR (t0.`c` < 0)) AND
(t0.`g` IN ('A', 'B'))
@@ -0,0 +1,4 @@
SELECT t0.`g`, sum(t0.`b`) AS `b_sum`
FROM table t0
GROUP BY t0.`g`
HAVING count(*) >= 1000
@@ -0,0 +1,2 @@
SELECT count(DISTINCT CASE WHEN t0.`g` = 'A' THEN t0.`b` ELSE NULL END) AS `CountDistinct(b, Equals(g, 'A'))`
FROM table t0
@@ -0,0 +1,6 @@
SELECT t0.`ExtractYear(i)`, count(*) AS `ExtractYear(i)_count`
FROM (
SELECT EXTRACT(year from t1.`i`) AS `ExtractYear(i)`
FROM table t1
) t0
GROUP BY t0.`ExtractYear(i)`
@@ -0,0 +1,2 @@
SELECT sum(t0.`f`) OVER (ORDER BY t0.`f` ASC RANGE BETWEEN INTERVAL '00 08:20:00.000000' DAY TO SECOND PRECEDING AND CURRENT ROW) AS `Sum(f)`
FROM table t0
@@ -0,0 +1,2 @@
SELECT sum(t0.`f`) OVER (ORDER BY t0.`f` ASC ROWS BETWEEN 1000 PRECEDING AND CURRENT ROW) AS `Sum(f)`
FROM table t0
2 changes: 2 additions & 0 deletions ibis/backends/flink/tests/test_join.py
@@ -1,3 +1,5 @@
from __future__ import annotations

import pytest

from ibis.backends.flink.compiler.core import translate
80 changes: 80 additions & 0 deletions ibis/backends/flink/tests/test_literals.py
@@ -0,0 +1,80 @@
from __future__ import annotations

import datetime

import pandas as pd
import pytest
from pytest import param

import ibis
import ibis.expr.datatypes as dt
from ibis.backends.flink.compiler.core import translate


@pytest.mark.parametrize(
"value,expected",
[
param(5, "5", id="int"),
param(1.5, "1.5", id="float"),
param(True, "TRUE", id="true"),
param(False, "FALSE", id="false"),
],
)
def test_simple_literals(value, expected):
expr = ibis.literal(value)
result = translate(expr.op())
assert result == expected


@pytest.mark.parametrize(
"value,expected",
[
param("simple", "'simple'", id="simple"),
param("I can't", "'I can''t'", id="nested_quote"),
param('An "escape"', """'An "escape"'""", id="nested_token"),
],
)
def test_string_literals(value, expected):
expr = ibis.literal(value)
result = translate(expr.op())
assert result == expected


@pytest.mark.parametrize(
"value,expected",
[
param(
datetime.timedelta(seconds=70),
"INTERVAL '00 00:01:10.000000' DAY TO SECOND",
id="70seconds",
),
param(
ibis.interval(months=50), "INTERVAL '04-02' YEAR TO MONTH", id="50months"
),
],
)
def test_translate_interval_literal(value, expected):
expr = ibis.literal(value)
result = translate(expr.op())
assert result == expected


@pytest.mark.parametrize(
("case", "dtype"),
[
param(datetime.datetime(2017, 1, 1, 4, 55, 59), dt.timestamp, id="datetime"),
param(
datetime.datetime(2017, 1, 1, 4, 55, 59, 1122),
dt.timestamp,
id="datetime_with_microseconds",
),
param("2017-01-01 04:55:59", dt.timestamp, id="string_timestamp"),
param(pd.Timestamp("2017-01-01 04:55:59"), dt.timestamp, id="timestamp"),
param(datetime.time(4, 55, 59), dt.time, id="time"),
param("04:55:59", dt.time, id="string_time"),
],
)
def test_literal_timestamp_or_time(snapshot, case, dtype):
expr = ibis.literal(case, type=dtype)
result = translate(expr.op())
snapshot.assert_match(result, "out.sql")
119 changes: 90 additions & 29 deletions ibis/backends/flink/tests/test_translator.py
@@ -1,40 +1,19 @@
from __future__ import annotations

import pytest
from pytest import param

import ibis
from ibis.backends.flink.compiler.core import translate


@pytest.fixture
def schema():
return [
('a', 'int8'),
('b', 'int16'),
('c', 'int32'),
('d', 'int64'),
('e', 'float32'),
('f', 'float64'),
('g', 'string'),
('h', 'boolean'),
('i', 'timestamp'),
('j', 'date'),
('k', 'time'),
]


@pytest.fixture
def table(schema):
return ibis.table(schema, name='table')


def test_translate_sum(snapshot, table):
expr = table.a.sum()
def test_translate_sum(snapshot, simple_table):
expr = simple_table.a.sum()
result = translate(expr.as_table().op())
snapshot.assert_match(str(result), "out.sql")


def test_translate_count_star(snapshot, table):
expr = table.group_by(table.i).size()
def test_translate_count_star(snapshot, simple_table):
expr = simple_table.group_by(simple_table.i).size()
result = translate(expr.as_table().op())
snapshot.assert_match(str(result), "out.sql")

@@ -46,7 +25,89 @@ def test_translate_count_star(snapshot, table):
param("s", id="timestamp_s"),
],
)
def test_translate_timestamp_from_unix(snapshot, table, unit):
expr = table.d.to_timestamp(unit=unit)
def test_translate_timestamp_from_unix(snapshot, simple_table, unit):
expr = simple_table.d.to_timestamp(unit=unit)
result = translate(expr.as_table().op())
snapshot.assert_match(result, "out.sql")


def test_translate_complex_projections(snapshot, simple_table):
expr = (
simple_table.group_by(['a', 'c'])
.aggregate(the_sum=simple_table.b.sum())
.group_by('a')
.aggregate(mad=lambda x: x.the_sum.abs().mean())
)
result = translate(expr.as_table().op())
snapshot.assert_match(result, "out.sql")


def test_translate_filter(snapshot, simple_table):
expr = simple_table[
((simple_table.c > 0) | (simple_table.c < 0)) & simple_table.g.isin(['A', 'B'])
]
result = translate(expr.as_table().op())
snapshot.assert_match(result, "out.sql")


@pytest.mark.parametrize(
"kind",
[
"year",
"quarter",
"month",
"week_of_year",
"day_of_year",
"day",
"hour",
"minute",
"second",
],
)
def test_translate_extract_fields(snapshot, simple_table, kind):
expr = getattr(simple_table.i, kind)().name("tmp")
result = translate(expr.as_table().op())
snapshot.assert_match(result, "out.sql")


def test_translate_complex_groupby_aggregation(snapshot, simple_table):
keys = [simple_table.i.year().name('year'), simple_table.i.month().name('month')]
b_unique = simple_table.b.nunique()
expr = simple_table.group_by(keys).aggregate(
total=simple_table.count(), b_unique=b_unique
)
result = translate(expr.as_table().op())
snapshot.assert_match(result, "out.sql")


def test_translate_simple_filtered_agg(snapshot, simple_table):
expr = simple_table.b.nunique(where=simple_table.g == 'A')
result = translate(expr.as_table().op())
snapshot.assert_match(result, "out.sql")


def test_translate_complex_filtered_agg(snapshot, simple_table):
expr = simple_table.group_by('b').aggregate(
total=simple_table.count(),
avg_a=simple_table.a.mean(),
avg_a_A=simple_table.a.mean(where=simple_table.g == 'A'),
avg_a_B=simple_table.a.mean(where=simple_table.g == 'B'),
)
result = translate(expr.as_table().op())
snapshot.assert_match(result, "out.sql")


def test_translate_value_counts(snapshot, simple_table):
expr = simple_table.i.year().value_counts()
result = translate(expr.as_table().op())
snapshot.assert_match(result, "out.sql")


def test_translate_having(snapshot, simple_table):
expr = (
simple_table.group_by('g')
.having(simple_table.count() >= 1000)
.aggregate(simple_table.b.sum().name('b_sum'))
)
result = translate(expr.as_table().op())
snapshot.assert_match(result, "out.sql")
88 changes: 88 additions & 0 deletions ibis/backends/flink/tests/test_window.py
@@ -0,0 +1,88 @@
from __future__ import annotations

import pytest
from pytest import param

import ibis
from ibis.backends.flink.compiler.core import translate
from ibis.common.exceptions import UnsupportedOperationError


def test_window_requires_order_by(simple_table):
expr = simple_table.mutate(simple_table.c - simple_table.c.mean())
with pytest.raises(
UnsupportedOperationError,
match="Flink engine does not support generic window clause with no order by",
):
translate(expr.as_table().op())


def test_window_does_not_support_multiple_order_by(simple_table):
expr = simple_table.f.sum().over(
rows=(-1, 1),
group_by=[simple_table.g, simple_table.a],
order_by=[simple_table.f, simple_table.d],
)
with pytest.raises(
UnsupportedOperationError,
match="Windows in Flink can only be ordered by a single time column",
):
translate(expr.as_table().op())


def test_window_does_not_support_desc_order(simple_table):
expr = simple_table.f.sum().over(
rows=(-1, 1),
group_by=[simple_table.g, simple_table.a],
order_by=[simple_table.f.desc()],
)
with pytest.raises(
UnsupportedOperationError,
match="Flink only supports windows ordered in ASCENDING mode",
):
translate(expr.as_table().op())


@pytest.mark.parametrize(
("window", "err"),
[
param(
{"rows": (-1, 1)},
"OVER RANGE FOLLOWING windows are not supported in Flink yet",
id="bounded_rows_following",
),
param(
{"rows": (-1, None)},
"OVER RANGE FOLLOWING windows are not supported in Flink yet",
id="unbounded_rows_following",
),
param(
{"rows": (-500, 1)},
"OVER RANGE FOLLOWING windows are not supported in Flink yet",
id="casted_bounded_rows_following",
),
param(
{"range": (-1000, 0)},
"Data Type mismatch between ORDER BY and RANGE clause",
id="int_range",
),
],
)
def test_window_invalid_start_end(simple_table, window, err):
expr = simple_table.f.sum().over(**window, order_by=simple_table.f)
with pytest.raises(UnsupportedOperationError, match=err):
translate(expr.as_table().op())


def test_range_window(snapshot, simple_table):
expr = simple_table.f.sum().over(
range=(-ibis.interval(minutes=500), 0), order_by=simple_table.f
)
result = translate(expr.as_table().op())
snapshot.assert_match(result, "out.sql")


def test_rows_window(snapshot, simple_table):
expr = simple_table.f.sum().over(rows=(-1000, 0), order_by=simple_table.f)
result = translate(expr.as_table().op())
snapshot.assert_match(result, "out.sql")
2 changes: 2 additions & 0 deletions ibis/backends/flink/translator.py
@@ -1,3 +1,5 @@
from __future__ import annotations

from ibis.backends.base.sql.compiler import ExprTranslator
from ibis.backends.flink.registry import operation_registry
