@@ -1,2 +1,2 @@
SELECT
MOD(EXTRACT(dayofweek FROM DATE(2017, 1, 1)) + 5, 7) AS `DayOfWeekIndex_datetime_date_2017_ 1_ 1`
MOD(EXTRACT(dayofweek FROM DATE(2017, 1, 1)) + 5, 7) AS `DayOfWeekIndex_datetime_date_2017_1_1`
@@ -1,2 +1,2 @@
SELECT
INITCAP(CAST(DATE(2017, 1, 1) AS STRING FORMAT 'DAY')) AS `DayOfWeekName_datetime_date_2017_ 1_ 1`
INITCAP(CAST(DATE(2017, 1, 1) AS STRING FORMAT 'DAY')) AS `DayOfWeekName_datetime_date_2017_1_1`
@@ -1,2 +1,2 @@
SELECT
MOD(EXTRACT(dayofweek FROM datetime('2017-01-01T04:55:59')) + 5, 7) AS `DayOfWeekIndex_datetime_datetime_2017_ 1_ 1_ 4_ 55_ 59`
MOD(EXTRACT(dayofweek FROM datetime('2017-01-01T04:55:59')) + 5, 7) AS `DayOfWeekIndex_datetime_datetime_2017_1_1_4_55_59`
@@ -1,2 +1,2 @@
SELECT
INITCAP(CAST(datetime('2017-01-01T04:55:59') AS STRING FORMAT 'DAY')) AS `DayOfWeekName_datetime_datetime_2017_ 1_ 1_ 4_ 55_ 59`
INITCAP(CAST(datetime('2017-01-01T04:55:59') AS STRING FORMAT 'DAY')) AS `DayOfWeekName_datetime_datetime_2017_1_1_4_55_59`
@@ -1,2 +1,2 @@
SELECT
MOD(EXTRACT(dayofweek FROM datetime('2017-01-01T04:55:59')) + 5, 7) AS `DayOfWeekIndex_datetime_datetime_2017_ 1_ 1_ 4_ 55_ 59`
MOD(EXTRACT(dayofweek FROM datetime('2017-01-01T04:55:59')) + 5, 7) AS `DayOfWeekIndex_datetime_datetime_2017_1_1_4_55_59`
@@ -1,2 +1,2 @@
SELECT
INITCAP(CAST(datetime('2017-01-01T04:55:59') AS STRING FORMAT 'DAY')) AS `DayOfWeekName_datetime_datetime_2017_ 1_ 1_ 4_ 55_ 59`
INITCAP(CAST(datetime('2017-01-01T04:55:59') AS STRING FORMAT 'DAY')) AS `DayOfWeekName_datetime_datetime_2017_1_1_4_55_59`
@@ -1,2 +1,2 @@
SELECT
MOD(EXTRACT(dayofweek FROM DATE(2017, 1, 1)) + 5, 7) AS `DayOfWeekIndex_datetime_date_2017_ 1_ 1`
MOD(EXTRACT(dayofweek FROM DATE(2017, 1, 1)) + 5, 7) AS `DayOfWeekIndex_datetime_date_2017_1_1`
@@ -1,2 +1,2 @@
SELECT
INITCAP(CAST(DATE(2017, 1, 1) AS STRING FORMAT 'DAY')) AS `DayOfWeekName_datetime_date_2017_ 1_ 1`
INITCAP(CAST(DATE(2017, 1, 1) AS STRING FORMAT 'DAY')) AS `DayOfWeekName_datetime_date_2017_1_1`
@@ -1,3 +1,3 @@
SELECT
CAST(FLOOR(ieee_divide(`t0`.`double_col`, 0)) AS INT64) AS `FloorDivide_double_col_ 0`
CAST(FLOOR(ieee_divide(`t0`.`double_col`, 0)) AS INT64) AS `FloorDivide_double_col_0`
FROM `functional_alltypes` AS `t0`
@@ -1,3 +1,3 @@
SELECT
ieee_divide(`t0`.`double_col`, 0) AS `Divide_double_col_ 0`
ieee_divide(`t0`.`double_col`, 0) AS `Divide_double_col_0`
FROM `functional_alltypes` AS `t0`
@@ -1,2 +1,2 @@
SELECT
SHA(CAST('74657374' AS BYTES FORMAT 'HEX')) AS `tmp`
SHA1(CAST('74657374' AS BYTES FORMAT 'HEX')) AS `tmp`
@@ -1,2 +1,2 @@
SELECT
SHA('test') AS `tmp`
SHA1('test') AS `tmp`
@@ -1,2 +1,2 @@
SELECT
EXTRACT(year FROM DATE(2017, 1, 1)) AS `ExtractYear_datetime_date_2017_ 1_ 1`
EXTRACT(year FROM DATE(2017, 1, 1)) AS `ExtractYear_datetime_date_2017_1_1`
@@ -1,2 +1,2 @@
SELECT
EXTRACT(year FROM datetime('2017-01-01T04:55:59')) AS `ExtractYear_datetime_datetime_2017_ 1_ 1_ 4_ 55_ 59`
EXTRACT(year FROM datetime('2017-01-01T04:55:59')) AS `ExtractYear_datetime_datetime_2017_1_1_4_55_59`
@@ -1,2 +1,2 @@
SELECT
EXTRACT(year FROM DATE(2017, 1, 1)) AS `ExtractYear_datetime_date_2017_ 1_ 1`
EXTRACT(year FROM DATE(2017, 1, 1)) AS `ExtractYear_datetime_date_2017_1_1`
@@ -1,2 +1,2 @@
SELECT
EXTRACT(year FROM datetime('2017-01-01T04:55:59')) AS `ExtractYear_datetime_datetime_2017_ 1_ 1_ 4_ 55_ 59`
EXTRACT(year FROM datetime('2017-01-01T04:55:59')) AS `ExtractYear_datetime_datetime_2017_1_1_4_55_59`
@@ -1,2 +1,2 @@
SELECT
EXTRACT(year FROM datetime('2017-01-01T04:55:59')) AS `ExtractYear_datetime_datetime_2017_ 1_ 1_ 4_ 55_ 59`
EXTRACT(year FROM datetime('2017-01-01T04:55:59')) AS `ExtractYear_datetime_datetime_2017_1_1_4_55_59`
@@ -1,2 +1,2 @@
SELECT
EXTRACT(year FROM DATE(2017, 1, 1)) AS `ExtractYear_datetime_date_2017_ 1_ 1`
EXTRACT(year FROM DATE(2017, 1, 1)) AS `ExtractYear_datetime_date_2017_1_1`
@@ -1,3 +1,3 @@
SELECT
parse_timestamp('%F', `t0`.`date_string_col`, 'UTC') AS `StringToTimestamp_date_string_col_ '%F'`
parse_timestamp('%F', `t0`.`date_string_col`, 'UTC') AS `StringToTimestamp_date_string_col_'%F'`
FROM `functional_alltypes` AS `t0`
@@ -1,3 +1,3 @@
SELECT
parse_timestamp('%F %Z', CONCAT(`t0`.`date_string_col`, ' America/New_York'), 'UTC') AS `StringToTimestamp_StringConcat_ '%F %Z'`
parse_timestamp('%F %Z', CONCAT(`t0`.`date_string_col`, ' America/New_York'), 'UTC') AS `StringToTimestamp_StringConcat_'%F %Z'`
FROM `functional_alltypes` AS `t0`
27 changes: 26 additions & 1 deletion ibis/backends/bigquery/tests/unit/test_client.py
@@ -1,8 +1,9 @@
from __future__ import annotations

import pytest
import sqlglot as sg

from ibis.backends.bigquery import client
from ibis.backends.bigquery import _force_quote_table, client


@pytest.mark.parametrize(
@@ -30,3 +31,27 @@ def test_parse_project_and_dataset_raises_error():
expected_message = "data-project.my_dataset.table is not a BigQuery dataset"
with pytest.raises(ValueError, match=expected_message):
client.parse_project_and_dataset("my-project", "data-project.my_dataset.table")


@pytest.mark.parametrize(
"bq_path_str, expected",
[
("ibis-gbq.ibis_gbq_testing.argle", "`ibis-gbq`.`ibis_gbq_testing`.`argle`"),
(
"ibis-gbq.ibis_gbq_testing.28argle",
"`ibis-gbq`.`ibis_gbq_testing`.`28argle`",
),
("mytable-287a", "`mytable-287a`"),
("myproject.mydataset.my-table", "`myproject`.`mydataset`.`my-table`"),
("my-dataset.mytable", "`my-dataset`.`mytable`"),
(
"a-7b0a.dev_test_dataset.test_ibis5",
"`a-7b0a`.`dev_test_dataset`.`test_ibis5`",
),
],
)
def test_force_quoting(bq_path_str, expected):
table = sg.parse_one(bq_path_str, into=sg.exp.Table, read="bigquery")
table = _force_quote_table(table)

assert table.sql("bigquery") == expected
13 changes: 13 additions & 0 deletions ibis/backends/bigquery/tests/unit/test_compiler.py
@@ -13,6 +13,7 @@
import ibis.expr.datatypes as dt
import ibis.expr.operations as ops
from ibis import _
from ibis.backends.bigquery.compiler import BigQueryCompiler
from ibis.common.annotations import ValidationError

to_sql = ibis.bigquery.compile
@@ -633,3 +634,15 @@ def test_unnest(snapshot):
).select(level_two=lambda t: t.level_one.unnest())
)
snapshot.assert_match(result, "out_two_unnests.sql")


@pytest.mark.parametrize(
"fieldname, expected",
[
("TryCast(b, Float64)", "TryCast_b_Float64"),
("Cast(b, Int64)", "Cast_b_Int64"),
("that, is, a, lot, of, spaces", "that_is_a_lot_of_spaces"),
],
)
def test_field_names_strip_whitespace(fieldname, expected):
assert BigQueryCompiler._gen_valid_name(fieldname) == expected
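The snapshot churn at the top of this diff (aliases like `DayOfWeekIndex_datetime_date_2017_ 1_ 1` becoming `DayOfWeekIndex_datetime_date_2017_1_1`) is what this test pins down: generated field names no longer carry stray whitespace. Below is a minimal sketch of a sanitizer that reproduces the three parametrized cases; it is an illustration only, not the compiler's actual `_gen_valid_name`, which for instance keeps the quotes seen in the `farm_fingerprint` snapshot further down.

```python
import re


def gen_valid_name(name: str) -> str:
    # Collapse each run of non-identifier characters (spaces, commas,
    # parentheses, ...) into a single underscore and trim the ends.
    return re.sub(r"[^0-9A-Za-z_]+", "_", name).strip("_") or "tmp"


assert gen_valid_name("TryCast(b, Float64)") == "TryCast_b_Float64"
assert gen_valid_name("Cast(b, Int64)") == "Cast_b_Int64"
assert gen_valid_name("that, is, a, lot, of, spaces") == "that_is_a_lot_of_spaces"
```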
@@ -1,2 +1,2 @@
SELECT
farm_fingerprint(CAST('48656c6c6f2c20576f726c6421' AS BYTES FORMAT 'HEX')) AS `farm_fingerprint_0_b'Hello_ World_'`
farm_fingerprint(CAST('48656c6c6f2c20576f726c6421' AS BYTES FORMAT 'HEX')) AS `farm_fingerprint_0_b'Hello_World_'`
5 changes: 4 additions & 1 deletion ibis/backends/bigquery/udf/core.py
@@ -8,11 +8,14 @@
import inspect
import textwrap
from collections import ChainMap
from typing import Callable
from typing import TYPE_CHECKING

from ibis.backends.bigquery.udf.find import find_names
from ibis.backends.bigquery.udf.rewrite import rewrite

if TYPE_CHECKING:
from collections.abc import Callable


class SymbolTable(ChainMap):
"""ChainMap subclass implementing scope for the translator.
5 changes: 4 additions & 1 deletion ibis/backends/bigquery/udf/rewrite.py
@@ -1,7 +1,10 @@
from __future__ import annotations

import ast
from typing import Callable
from typing import TYPE_CHECKING

if TYPE_CHECKING:
from collections.abc import Callable


def matches(value: ast.AST, pattern: ast.AST) -> bool:
11 changes: 8 additions & 3 deletions ibis/backends/clickhouse/__init__.py
@@ -34,6 +34,7 @@
from pathlib import Path

import pandas as pd
import polars as pl


def _to_memtable(v):
@@ -422,8 +423,7 @@ def insert(
elif not isinstance(obj, ir.Table):
obj = ibis.memtable(obj)

query = sge.insert(self.compile(obj), into=name, dialect=self.name)

query = self._build_insert_query(target=name, source=obj)
external_tables = self._collect_in_memory_tables(obj, {})
external_data = self._normalize_external_tables(external_tables)
return self.con.command(query.sql(self.name), external_data=external_data)
@@ -586,7 +586,12 @@ def read_csv(
def create_table(
self,
name: str,
obj: pd.DataFrame | pa.Table | ir.Table | None = None,
obj: ir.Table
| pd.DataFrame
| pa.Table
| pl.DataFrame
| pl.LazyFrame
| None = None,
*,
schema: ibis.Schema | None = None,
database: str | None = None,
51 changes: 24 additions & 27 deletions ibis/backends/clickhouse/compiler.py
@@ -11,35 +11,38 @@
import ibis.expr.datatypes as dt
import ibis.expr.operations as ops
from ibis import util
from ibis.backends.sql.compiler import NULL, STAR, SQLGlotCompiler
from ibis.backends.sql.compiler import NULL, STAR, AggGen, SQLGlotCompiler
from ibis.backends.sql.datatypes import ClickHouseType
from ibis.backends.sql.dialects import ClickHouse
from ibis.backends.sql.rewrites import rewrite_sample_as_filter
from ibis.expr.rewrites import rewrite_stringslice


class ClickhouseAggGen(AggGen):
def aggregate(self, compiler, name, *args, where=None):
# Clickhouse aggregate functions all have filtering variants with a
# `If` suffix (e.g. `SumIf` instead of `Sum`).
if where is not None:
name += "If"
args += (where,)
return compiler.f[name](*args, dialect=compiler.dialect)


class ClickHouseCompiler(SQLGlotCompiler):
__slots__ = ()

dialect = ClickHouse
type_mapper = ClickHouseType
rewrites = (
rewrite_sample_as_filter,
rewrite_stringslice,
*SQLGlotCompiler.rewrites,
)

UNSUPPORTED_OPERATIONS = frozenset(
(
ops.RowID,
ops.CumeDist,
ops.PercentRank,
ops.Time,
ops.TimeDelta,
ops.StringToTimestamp,
ops.StringToDate,
ops.Levenshtein,
)
agg = ClickhouseAggGen()

UNSUPPORTED_OPS = (
ops.RowID,
ops.CumeDist,
ops.PercentRank,
ops.Time,
ops.TimeDelta,
ops.StringToTimestamp,
ops.StringToDate,
ops.Levenshtein,
)

SIMPLE_OPS = {
@@ -81,6 +84,7 @@ class ClickHouseCompiler(SQLGlotCompiler):
ops.ExtractSecond: "toSecond",
ops.ExtractWeekOfYear: "toISOWeek",
ops.ExtractYear: "toYear",
ops.ExtractIsoYear: "toISOYear",
ops.First: "any",
ops.IntegerRange: "range",
ops.IsInf: "isInfinite",
@@ -112,13 +116,6 @@ class ClickHouseCompiler(SQLGlotCompiler):
ops.Unnest: "arrayJoin",
}

def _aggregate(self, funcname: str, *args, where):
has_filter = where is not None
func = self.f[funcname + "If" * has_filter]
args += (where,) * has_filter

return func(*args, dialect=self.dialect)

@staticmethod
def _minimize_spec(start, end, spec):
if (
@@ -243,7 +240,7 @@ def visit_Sign(self, op, *, arg):
return self.f.intDivOrZero(arg, self.f.abs(arg))

def visit_Hash(self, op, *, arg):
return self.f.sipHash64(arg)
return self.f.reinterpretAsInt64(self.f.sipHash64(arg))

def visit_HashBytes(self, op, *, arg, how):
supported_algorithms = {
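The `ClickhouseAggGen` added above relies on ClickHouse's aggregate-function combinators: every aggregate has a filtering variant spelled with an `If` suffix that takes the predicate as an extra trailing argument. A rough string-level illustration of that rewrite follows (a hypothetical helper, not the sqlglot-based code path the compiler actually uses):

```python
def clickhouse_filtered_agg(name: str, *args: str, where: str | None = None) -> str:
    # sum(x) stays sum(x); a filtered sum becomes sumIf(x, predicate).
    if where is not None:
        name += "If"
        args += (where,)
    return f"{name}({', '.join(args)})"


assert clickhouse_filtered_agg("sum", "x") == "sum(x)"
assert clickhouse_filtered_agg("sum", "x", where="x > 0") == "sumIf(x, x > 0)"
```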
4 changes: 2 additions & 2 deletions ibis/backends/clickhouse/tests/conftest.py
@@ -2,7 +2,7 @@

import contextlib
import os
from typing import TYPE_CHECKING, Any, Callable
from typing import TYPE_CHECKING, Any

import pytest

@@ -12,7 +12,7 @@
from ibis.backends.tests.base import ServiceBackendTest

if TYPE_CHECKING:
from collections.abc import Iterable
from collections.abc import Callable, Iterable
from pathlib import Path

CLICKHOUSE_HOST = os.environ.get("IBIS_TEST_CLICKHOUSE_HOST", "localhost")
@@ -1,3 +1,3 @@
SELECT
sipHash64("t0"."string_col") AS "Hash(string_col)"
reinterpretAsInt64(sipHash64("t0"."string_col")) AS "Hash(string_col)"
FROM "functional_alltypes" AS "t0"
17 changes: 17 additions & 0 deletions ibis/backends/clickhouse/tests/test_datatypes.py
@@ -3,7 +3,9 @@
import hypothesis as h
import hypothesis.strategies as st
import pytest
import sqlglot as sg
import sqlglot.expressions as sge
from packaging.version import parse as vparse
from pytest import param

import ibis
@@ -192,6 +194,11 @@ def test_array_discovery_clickhouse(con):
),
nullable=False,
),
marks=pytest.mark.xfail(
vparse(sg.__version__) == vparse("24.0.0"),
reason="struct parsing for clickhouse broken in sqlglot 24",
raises=sg.ParseError,
),
id="named_tuple",
),
param(
@@ -203,6 +210,11 @@ def test_array_discovery_clickhouse(con):
),
nullable=False,
),
marks=pytest.mark.xfail(
vparse("24.0.0") <= vparse(sg.__version__) <= vparse("24.0.1"),
reason="struct parsing for clickhouse broken in sqlglot 24",
raises=sg.ParseError,
),
id="unnamed_tuple",
),
param(
@@ -214,6 +226,11 @@ def test_array_discovery_clickhouse(con):
),
nullable=False,
),
marks=pytest.mark.xfail(
vparse("24.0.0") <= vparse(sg.__version__) <= vparse("24.0.1"),
reason="struct parsing for clickhouse broken in sqlglot 24",
raises=sg.ParseError,
),
id="partially_named",
),
param(
14 changes: 7 additions & 7 deletions ibis/backends/clickhouse/tests/test_functions.py
@@ -116,8 +116,8 @@ def test_isnull_notnull(con, expr, expected):
("expr", "expected"),
[
(ibis.coalesce(5, None, 4), 5),
(ibis.coalesce(ibis.NA, 4, ibis.NA), 4),
(ibis.coalesce(ibis.NA, ibis.NA, 3.14), 3.14),
(ibis.coalesce(ibis.null(), 4, ibis.null()), 4),
(ibis.coalesce(ibis.null(), ibis.null(), 3.14), 3.14),
],
)
def test_coalesce(con, expr, expected):
@@ -127,13 +127,13 @@ def test_coalesce(con, expr, expected):
@pytest.mark.parametrize(
("expr", "expected"),
[
(ibis.NA.fillna(5), 5),
(L(5).fillna(10), 5),
(ibis.null().fill_null(5), 5),
(L(5).fill_null(10), 5),
(L(5).nullif(5), None),
(L(10).nullif(5), 10),
],
)
def test_fillna_nullif(con, expr, expected):
def test_fill_null_nullif(con, expr, expected):
result = con.execute(expr)
if expected is None:
assert pd.isnull(result)
@@ -150,7 +150,7 @@ def test_fillna_nullif(con, expr, expected):
(L(datetime(2015, 9, 1, hour=14, minute=48, second=5)), "DateTime"),
(L(date(2015, 9, 1)), "Date"),
param(
ibis.NA,
ibis.null(),
"Null",
marks=pytest.mark.xfail(
raises=AssertionError,
@@ -418,7 +418,7 @@ def test_numeric_builtins_work(alltypes, df):
def test_null_column(alltypes):
t = alltypes
nrows = t.count().execute()
expr = t.mutate(na_column=ibis.NA).na_column
expr = t.mutate(na_column=ibis.null()).na_column
result = expr.execute()
expected = pd.Series([None] * nrows, name="na_column")
tm.assert_series_equal(result, expected)
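The edits in this file are mechanical renames tracking the current ibis API: the `ibis.NA` singleton is now spelled `ibis.null()` and `fillna` became `fill_null` (hence `test_fillna_nullif` becoming `test_fill_null_nullif`). A quick sketch of the new spellings; building the expressions needs no backend connection:

```python
import ibis

# was: ibis.NA.fillna(5)
expr = ibis.null().fill_null(5)

# was: t.mutate(na_column=ibis.NA)
t = ibis.table({"a": "int64"}, name="t")
expr2 = t.mutate(na_column=ibis.null())
```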
2 changes: 1 addition & 1 deletion ibis/backends/clickhouse/tests/test_select.py
@@ -362,7 +362,7 @@ def test_count_name(assert_sql):
t = ibis.table(dict(a="string", b="bool"), name="t")

expr = t.group_by(t.a).agg(
A=t.count(where=~t.b).fillna(0), B=t.count(where=t.b).fillna(0)
A=t.count(where=~t.b).fill_null(0), B=t.count(where=t.b).fill_null(0)
)
assert_sql(expr)

2 changes: 1 addition & 1 deletion ibis/backends/conftest.py
@@ -205,7 +205,7 @@ def _get_backend_from_parts(parts: tuple[str, ...]) -> str | None:
return parts[index + 1]


def pytest_ignore_collect(collection_path, path, config):
def pytest_ignore_collect(collection_path, config):
# get the backend path part
backend = _get_backend_from_parts(collection_path.parts)
if backend is None or backend not in _get_backend_names():
6 changes: 6 additions & 0 deletions ibis/backends/dask/executor.py
@@ -160,6 +160,12 @@ def visit(cls, op: ops.Array, exprs):
lambda row: np.array(row, dtype=object), exprs, name=op.name, dtype=object
)

@classmethod
def visit(cls, op: ops.StructColumn, names, values):
return cls.rowwise(
lambda row: dict(zip(names, row)), values, name=op.name, dtype=object
)

@classmethod
def visit(cls, op: ops.ArrayConcat, arg):
dtype = PandasType.from_ibis(op.dtype)
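The new `ops.StructColumn` handler above builds each struct value row by row with `dict(zip(names, row))`. Here is a pandas-only sketch of what that produces; the real executor routes through `cls.rowwise` and dask metadata, which is omitted:

```python
import pandas as pd

names = ("a", "b")
df = pd.DataFrame({"x": [1, 2], "y": ["p", "q"]})

# Each output row becomes a dict keyed by the struct's field names.
structs = df.apply(lambda row: dict(zip(names, row)), axis=1)
print(structs.tolist())  # [{'a': 1, 'b': 'p'}, {'a': 2, 'b': 'q'}]
```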
5 changes: 4 additions & 1 deletion ibis/backends/dask/helpers.py
@@ -1,6 +1,6 @@
from __future__ import annotations

from typing import Callable
from typing import TYPE_CHECKING

import dask.array as da
import dask.dataframe as dd
@@ -9,6 +9,9 @@

from ibis.backends.pandas.helpers import PandasUtils

if TYPE_CHECKING:
from collections.abc import Callable


class DaskUtils(PandasUtils):
@classmethod
14 changes: 11 additions & 3 deletions ibis/backends/dask/kernels.py
@@ -16,17 +16,25 @@
ops.DateAdd: lambda row: row["left"] + row["right"],
}


def maybe_pandas_reduction(func):
def inner(df):
return df.reduction(func) if isinstance(df, dd.Series) else func(df)

return inner


reductions = {
**pandas_kernels.reductions,
ops.Mode: lambda x: x.mode().loc[0],
ops.ApproxMedian: lambda x: x.median_approximate(),
ops.BitAnd: lambda x: x.reduction(np.bitwise_and.reduce),
ops.BitOr: lambda x: x.reduction(np.bitwise_or.reduce),
ops.BitXor: lambda x: x.reduction(np.bitwise_xor.reduce),
ops.Arbitrary: lambda x: x.reduction(pandas_kernels.first),
# Window functions are calculated locally using pandas
ops.Last: lambda x: x.compute().iloc[-1] if isinstance(x, dd.Series) else x.iat[-1],
ops.First: lambda x: x.loc[0] if isinstance(x, dd.Series) else x.iat[0],
ops.Arbitrary: lambda x: x.reduction(pandas_kernels.arbitrary),
ops.Last: maybe_pandas_reduction(pandas_kernels.last),
ops.First: maybe_pandas_reduction(pandas_kernels.first),
}

serieswise = {
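`maybe_pandas_reduction` above lets a single pandas kernel serve both engines: a plain pandas Series is reduced eagerly, while a dask Series goes through `Series.reduction`, which applies the kernel per partition and once more to the combined partial results. A self-contained sketch under that assumption:

```python
import dask.dataframe as dd
import pandas as pd


def first(s):  # stand-in for the pandas `first` kernel
    return s.iloc[0]


def maybe_pandas_reduction(func):
    def inner(obj):
        return obj.reduction(func) if isinstance(obj, dd.Series) else func(obj)

    return inner


kernel = maybe_pandas_reduction(first)
s = pd.Series([10, 20, 30])
print(kernel(s))                                           # 10
print(kernel(dd.from_pandas(s, npartitions=2)).compute())  # 10
```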
10 changes: 5 additions & 5 deletions ibis/backends/dask/tests/test_window.py
@@ -20,7 +20,7 @@ def sort_kind():
return "mergesort"


default = pytest.mark.parametrize("default", [ibis.NA, ibis.literal("a")])
default = pytest.mark.parametrize("default", [ibis.null(), ibis.literal("a")])
row_offset = pytest.mark.parametrize("row_offset", list(map(ibis.literal, [-1, 1, 0])))
range_offset = pytest.mark.parametrize(
"range_offset",
@@ -48,7 +48,7 @@ def test_lead(con, t, df, row_offset, default, row_window):
expr = t.dup_strings.lead(row_offset, default=default).over(row_window)
result = expr.execute()
expected = df.dup_strings.shift(con.execute(-row_offset)).compute()
if default is not ibis.NA:
if default is not ibis.null():
expected = expected.fillna(con.execute(default))
tm.assert_series_equal(result, expected, check_names=False)

@@ -59,7 +59,7 @@ def test_lag(con, t, df, row_offset, default, row_window):
expr = t.dup_strings.lag(row_offset, default=default).over(row_window)
result = expr.execute()
expected = df.dup_strings.shift(con.execute(row_offset)).compute()
if default is not ibis.NA:
if default is not ibis.null():
expected = expected.fillna(con.execute(default))
tm.assert_series_equal(result, expected, check_names=False)

@@ -78,7 +78,7 @@ def test_lead_delta(con, t, pandas_df, range_offset, default, range_window):
.reindex(pandas_df.plain_datetimes_naive)
.reset_index(drop=True)
)
if default is not ibis.NA:
if default is not ibis.null():
expected = expected.fillna(con.execute(default))
tm.assert_series_equal(result, expected, check_names=False)

@@ -98,7 +98,7 @@ def test_lag_delta(t, con, pandas_df, range_offset, default, range_window):
.reindex(pandas_df.plain_datetimes_naive)
.reset_index(drop=True)
)
if default is not ibis.NA:
if default is not ibis.null():
expected = expected.fillna(con.execute(default))
tm.assert_series_equal(result, expected, check_names=False)

182 changes: 148 additions & 34 deletions ibis/backends/datafusion/__init__.py
@@ -14,7 +14,6 @@
import sqlglot as sg
import sqlglot.expressions as sge

import ibis
import ibis.common.exceptions as com
import ibis.expr.datatypes as dt
import ibis.expr.operations as ops
@@ -24,9 +23,10 @@
from ibis.backends.datafusion.compiler import DataFusionCompiler
from ibis.backends.sql import SQLBackend
from ibis.backends.sql.compiler import C
from ibis.common.dispatch import lazy_singledispatch
from ibis.expr.operations.udf import InputType
from ibis.formats.pyarrow import PyArrowType
from ibis.util import gen_name, normalize_filename
from ibis.formats.pyarrow import PyArrowSchema, PyArrowType
from ibis.util import deprecated, gen_name, normalize_filename

try:
from datafusion import ExecutionContext as SessionContext
@@ -40,6 +40,26 @@

if TYPE_CHECKING:
import pandas as pd
import polars as pl


def as_nullable(dtype: dt.DataType) -> dt.DataType:
"""Recursively convert a possibly non-nullable datatype to a nullable one."""
if dtype.is_struct():
return dtype.copy(
fields={name: as_nullable(typ) for name, typ in dtype.items()},
nullable=True,
)
elif dtype.is_array():
return dtype.copy(value_type=as_nullable(dtype.value_type), nullable=True)
elif dtype.is_map():
return dtype.copy(
key_type=as_nullable(dtype.key_type),
value_type=as_nullable(dtype.value_type),
nullable=True,
)
else:
return dtype.copy(nullable=True)


class Backend(SQLBackend, CanCreateCatalog, CanCreateDatabase, CanCreateSchema, NoUrl):
@@ -113,23 +133,11 @@ def _get_schema_using_query(self, query: str) -> sch.Schema:
pass

try:
result = (
self.raw_sql(f"DESCRIBE {table.sql(self.name)}")
.to_arrow_table()
.to_pydict()
)
df = self.con.table(name)
finally:
self.drop_view(name)
return sch.Schema(
{
name: self.compiler.type_mapper.from_string(
type_string, nullable=is_nullable == "YES"
)
for name, type_string, is_nullable in zip(
result["column_name"], result["data_type"], result["is_nullable"]
)
}
)

return PyArrowSchema.to_ibis(df.schema())

def _register_builtin_udfs(self):
from ibis.backends.datafusion import udfs
@@ -272,7 +280,13 @@ def list_tables(
list[str]
The list of the table names that match the pattern `like`.
"""
return self._filter_with_like(self.con.tables(), like)
database = database or "public"
query = (
sg.select("table_name")
.from_("information_schema.tables")
.where(sg.column("table_schema").eq(sge.convert(database)))
)
return self.raw_sql(query).to_pydict()["table_name"]

def get_schema(
self,
@@ -294,6 +308,10 @@ def get_schema(
table = database.table(table_name)
return sch.schema(table.schema)

@deprecated(
as_of="9.1",
instead="use the explicit `read_*` method for the filetype you are trying to read, e.g., read_parquet, read_csv, etc.",
)
def register(
self,
source: str | Path | pa.Table | pa.RecordBatch | pa.Dataset | pd.DataFrame,
@@ -492,8 +510,8 @@ def read_delta(
)

delta_table = DeltaTable(source_table, **kwargs)

return self.register(delta_table.to_pyarrow_dataset(), table_name=table_name)
self.con.register_dataset(table_name, delta_table.to_pyarrow_dataset())
return self.table(table_name)

def to_pyarrow_batches(
self,
@@ -512,7 +530,9 @@ def to_pyarrow_batches(

frame = self.con.sql(raw_sql)

schema = table_expr.schema()
schema = sch.Schema(
{name: as_nullable(typ) for name, typ in table_expr.schema().items()}
)
names = schema.names

struct_schema = schema.as_struct().to_pyarrow()
@@ -526,15 +546,12 @@ def make_gen():
# cast the struct array to the desired types to work around
# https://github.com/apache/arrow-datafusion-python/issues/534
.to_struct_array()
.cast(struct_schema)
.cast(struct_schema, safe=False)
)
for batch in frame.collect()
)

return pa.ipc.RecordBatchReader.from_batches(
schema.to_pyarrow(),
make_gen(),
)
return pa.ipc.RecordBatchReader.from_batches(schema.to_pyarrow(), make_gen())

def to_pyarrow(self, expr: ir.Expr, **kwargs: Any) -> pa.Table:
batch_reader = self.to_pyarrow_batches(expr, **kwargs)
@@ -550,7 +567,14 @@ def execute(self, expr: ir.Expr, **kwargs: Any):
def create_table(
self,
name: str,
obj: pd.DataFrame | pa.Table | ir.Table | None = None,
obj: ir.Table
| pd.DataFrame
| pa.Table
| pa.RecordBatchReader
| pa.RecordBatch
| pl.DataFrame
| pl.LazyFrame
| None = None,
*,
schema: sch.Schema | None = None,
database: str | None = None,
@@ -589,12 +613,10 @@ def create_table(

quoted = self.compiler.quoted

if obj is not None:
if not isinstance(obj, ir.Expr):
table = ibis.memtable(obj)
else:
table = obj
if isinstance(obj, ir.Expr):
table = obj

# If it's a memtable, it will get registered in the pre-execute hooks
self._run_pre_execute_hooks(table)

relname = "_"
@@ -610,10 +632,13 @@ def create_table(
sg.to_identifier(relname, quoted=quoted)
)
)
elif obj is not None:
_read_in_memory(obj, name, self, overwrite=overwrite)
return self.table(name, database=database)
else:
query = None

table_ident = sg.to_identifier(name, quoted=quoted)
table_ident = sg.table(name, db=database, quoted=quoted)

if query is None:
column_defs = [
@@ -670,3 +695,92 @@ def truncate_table(
ident = sg.table(name, db=db, catalog=catalog).sql(self.name)
with self._safe_raw_sql(sge.delete(ident)):
pass


@contextlib.contextmanager
def _create_and_drop_memtable(_conn, table_name, tmp_name, overwrite):
"""Workaround inability to overwrite tables in dataframe API.
Datafusion has helper methods for loading in-memory data, but these methods
don't allow overwriting tables.
The SQL interface allows creating tables from existing tables, so we register
the data as a table using the dataframe API, then run a
CREATE [OR REPLACE] TABLE table_name AS SELECT * FROM in_memory_thing
and that allows us to toggle the overwrite flag.
"""
src = sge.Create(
this=table_name,
kind="TABLE",
expression=sg.select("*").from_(tmp_name),
replace=overwrite,
)

yield

_conn.raw_sql(src)
_conn.drop_table(tmp_name)


@lazy_singledispatch
def _read_in_memory(
source: Any, table_name: str, _conn: Backend, overwrite: bool = False
):
raise NotImplementedError("No support for source or imports missing")


@_read_in_memory.register(dict)
def _pydict(source, table_name, _conn, overwrite: bool = False):
tmp_name = gen_name("pydict")
with _create_and_drop_memtable(_conn, table_name, tmp_name, overwrite):
_conn.con.from_pydict(source, name=tmp_name)


@_read_in_memory.register("polars.DataFrame")
def _polars(source, table_name, _conn, overwrite: bool = False):
tmp_name = gen_name("polars")
with _create_and_drop_memtable(_conn, table_name, tmp_name, overwrite):
_conn.con.from_polars(source, name=tmp_name)


@_read_in_memory.register("polars.LazyFrame")
def _polars_lazy(source, table_name, _conn, overwrite: bool = False):
tmp_name = gen_name("polars")
with _create_and_drop_memtable(_conn, table_name, tmp_name, overwrite):
_conn.con.from_polars(source.collect(), name=tmp_name)


@_read_in_memory.register("pyarrow.Table")
def _pyarrow_table(source, table_name, _conn, overwrite: bool = False):
tmp_name = gen_name("pyarrow")
with _create_and_drop_memtable(_conn, table_name, tmp_name, overwrite):
_conn.con.from_arrow_table(source, name=tmp_name)


@_read_in_memory.register("pyarrow.RecordBatchReader")
def _pyarrow_rbr(source, table_name, _conn, overwrite: bool = False):
tmp_name = gen_name("pyarrow")
with _create_and_drop_memtable(_conn, table_name, tmp_name, overwrite):
_conn.con.from_arrow_table(source.read_all(), name=tmp_name)


@_read_in_memory.register("pyarrow.RecordBatch")
def _pyarrow_rb(source, table_name, _conn, overwrite: bool = False):
tmp_name = gen_name("pyarrow")
with _create_and_drop_memtable(_conn, table_name, tmp_name, overwrite):
_conn.con.register_record_batches(tmp_name, [[source]])


@_read_in_memory.register("pyarrow.dataset.Dataset")
def _pyarrow_ds(source, table_name, _conn, overwrite: bool = False):
tmp_name = gen_name("pyarrow")
with _create_and_drop_memtable(_conn, table_name, tmp_name, overwrite):
_conn.con.register_dataset(tmp_name, source)


@_read_in_memory.register("pandas.DataFrame")
def _pandas(source: pd.DataFrame, table_name, _conn, overwrite: bool = False):
tmp_name = gen_name("pandas")
with _create_and_drop_memtable(_conn, table_name, tmp_name, overwrite):
_conn.con.from_pandas(source, name=tmp_name)
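All of the `_read_in_memory` dispatchers above funnel through `_create_and_drop_memtable`, whose docstring spells out the trick: DataFusion's dataframe API can register data but cannot overwrite an existing table, so the data is registered under a throwaway name and a `CREATE [OR REPLACE] TABLE ... AS SELECT` statement handles the optional overwrite. A minimal sketch of just that SQL-building step, with hypothetical names:

```python
def create_table_as(table_name: str, tmp_name: str, *, overwrite: bool) -> str:
    create = "CREATE OR REPLACE TABLE" if overwrite else "CREATE TABLE"
    return f"{create} {table_name} AS SELECT * FROM {tmp_name}"


print(create_table_as("my_table", "ibis_pydict_tmp123", overwrite=True))
# CREATE OR REPLACE TABLE my_table AS SELECT * FROM ibis_pydict_tmp123
```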
91 changes: 46 additions & 45 deletions ibis/backends/datafusion/compiler.py
@@ -11,13 +11,12 @@
import ibis.common.exceptions as com
import ibis.expr.datatypes as dt
import ibis.expr.operations as ops
from ibis.backends.sql.compiler import FALSE, NULL, STAR, SQLGlotCompiler
from ibis.backends.sql.compiler import FALSE, NULL, STAR, AggGen, SQLGlotCompiler
from ibis.backends.sql.datatypes import DataFusionType
from ibis.backends.sql.dialects import DataFusion
from ibis.backends.sql.rewrites import rewrite_sample_as_filter
from ibis.backends.sql.rewrites import exclude_nulls_from_array_collect
from ibis.common.temporal import IntervalUnit, TimestampUnit
from ibis.expr.operations.udf import InputType
from ibis.expr.rewrites import rewrite_stringslice
from ibis.formats.pyarrow import PyArrowType


@@ -26,42 +25,42 @@ class DataFusionCompiler(SQLGlotCompiler):

dialect = DataFusion
type_mapper = DataFusionType

rewrites = (
rewrite_sample_as_filter,
rewrite_stringslice,
exclude_nulls_from_array_collect,
*SQLGlotCompiler.rewrites,
)

UNSUPPORTED_OPERATIONS = frozenset(
(
ops.ArgMax,
ops.ArgMin,
ops.ArrayDistinct,
ops.ArrayFilter,
ops.ArrayFlatten,
ops.ArrayMap,
ops.ArrayZip,
ops.BitwiseNot,
ops.Clip,
ops.CountDistinctStar,
ops.DateDelta,
ops.Greatest,
ops.GroupConcat,
ops.IntervalFromInteger,
ops.Least,
ops.MultiQuantile,
ops.Quantile,
ops.RowID,
ops.Strftime,
ops.TimeDelta,
ops.TimestampBucket,
ops.TimestampDelta,
ops.TimestampNow,
ops.TypeOf,
ops.Unnest,
ops.StringToDate,
ops.StringToTimestamp,
)
agg = AggGen(supports_filter=True)

UNSUPPORTED_OPS = (
ops.ArgMax,
ops.ArgMin,
ops.ArrayDistinct,
ops.ArrayFilter,
ops.ArrayFlatten,
ops.ArrayMap,
ops.ArrayZip,
ops.BitwiseNot,
ops.Clip,
ops.CountDistinctStar,
ops.DateDelta,
ops.Greatest,
ops.GroupConcat,
ops.IntervalFromInteger,
ops.Least,
ops.MultiQuantile,
ops.Quantile,
ops.RowID,
ops.Strftime,
ops.TimeDelta,
ops.TimestampBucket,
ops.TimestampDelta,
ops.TimestampNow,
ops.TypeOf,
ops.Unnest,
ops.StringToDate,
ops.StringToTimestamp,
)

SIMPLE_OPS = {
Expand All @@ -72,8 +71,6 @@ class DataFusionCompiler(SQLGlotCompiler):
ops.BitXor: "bit_xor",
ops.Cot: "cot",
ops.ExtractMicrosecond: "extract_microsecond",
ops.First: "first_value",
ops.Last: "last_value",
ops.Median: "median",
ops.StringLength: "character_length",
ops.RegexSplit: "regex_split",
Expand All @@ -82,12 +79,6 @@ class DataFusionCompiler(SQLGlotCompiler):
ops.ArrayUnion: "array_union",
}

def _aggregate(self, funcname: str, *args, where):
expr = self.f[funcname](*args)
if where is not None:
return sg.exp.Filter(this=expr, expression=sg.exp.Where(this=where))
return expr

def _to_timestamp(self, value, target_dtype, literal=False):
tz = (
f'Some("{timezone}")'
Expand Down Expand Up @@ -149,7 +140,7 @@ def visit_Cast(self, op, *, arg, to):
return self.cast(arg, to)

def visit_Arbitrary(self, op, *, arg, where):
cond = ~arg.is_(None)
cond = ~arg.is_(NULL)
if where is not None:
cond &= where
return self.agg.first_value(arg, where=cond)
@@ -173,7 +164,7 @@ def visit_StandardDev(self, op, *, arg, how, where):
def visit_ScalarUDF(self, op, **kw):
input_type = op.__input_type__
if input_type in (InputType.PYARROW, InputType.BUILTIN):
return self.f[op.__func_name__](*kw.values())
return self.f.anon[op.__func_name__](*kw.values())
else:
raise NotImplementedError(
f"DataFusion only supports PyArrow UDFs: got a {input_type.name.lower()} UDF"
@@ -431,6 +422,16 @@ def visit_StringConcat(self, op, *, arg):
sg.or_(*any_args_null), self.cast(NULL, dt.string), self.f.concat(*arg)
)

def visit_First(self, op, *, arg, where):
cond = arg.is_(sg.not_(NULL, copy=False))
where = cond if where is None else sge.And(this=cond, expression=where)
return self.agg.first_value(arg, where=where)

def visit_Last(self, op, *, arg, where):
cond = arg.is_(sg.not_(NULL, copy=False))
where = cond if where is None else sge.And(this=cond, expression=where)
return self.agg.last_value(arg, where=where)

def visit_Aggregate(self, op, *, parent, groups, metrics):
"""Support `GROUP BY` expressions in `SELECT` since DataFusion does not."""
quoted = self.quoted
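Like the Druid compiler later in this diff, the DataFusion compiler drops its hand-rolled `_aggregate` helper in favour of `agg = AggGen(supports_filter=True)`, meaning filtered aggregates can use the standard SQL `FILTER` syntax instead of a per-backend rewrite. Roughly the shape being generated, as a string-level sketch (the real code builds sqlglot expressions):

```python
def filtered_agg(name: str, column: str, where: str | None = None) -> str:
    sql = f"{name}({column})"
    return f"{sql} FILTER (WHERE {where})" if where is not None else sql


assert filtered_agg("sum", "x") == "sum(x)"
assert filtered_agg("sum", "x", where="x > 0") == "sum(x) FILTER (WHERE x > 0)"
```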
22 changes: 18 additions & 4 deletions ibis/backends/datafusion/tests/conftest.py
@@ -16,22 +16,36 @@ class TestConf(BackendTest):
supports_structs = False
supports_json = False
supports_arrays = True
supports_tpch = True
stateful = False
deps = ("datafusion",)
# Query 1 seems to require a bit more room here
tpch_absolute_tolerance = 0.11

def _load_data(self, **_: Any) -> None:
con = self.connection
for table_name in TEST_TABLES:
path = self.data_dir / "parquet" / f"{table_name}.parquet"
con.register(path, table_name=table_name)
con.register(array_types, table_name="array_types")
con.register(win, table_name="win")
con.register(topk, table_name="topk")
with pytest.warns(FutureWarning, match="v9.1"):
con.register(path, table_name=table_name)
# TODO: remove warnings and replace register when implementing 8858
with pytest.warns(FutureWarning, match="v9.1"):
con.register(array_types, table_name="array_types")
con.register(win, table_name="win")
con.register(topk, table_name="topk")

@staticmethod
def connect(*, tmpdir, worker_id, **kw):
return ibis.datafusion.connect(**kw)

def load_tpch(self) -> None:
con = self.connection
for path in self.data_dir.joinpath("tpch", "sf=0.17", "parquet").glob(
"*.parquet"
):
table_name = path.with_suffix("").name
con.read_parquet(path, table_name=table_name)


@pytest.fixture(scope="session")
def con(data_dir, tmp_path_factory, worker_id):
8 changes: 6 additions & 2 deletions ibis/backends/datafusion/tests/test_connect.py
@@ -25,13 +25,17 @@ def test_none_config():

def test_str_config(name_to_path):
config = {name: str(path) for name, path in name_to_path.items()}
conn = ibis.datafusion.connect(config)
# if path.endswith((".parquet", ".csv", ".csv.gz")) connect triggers register
with pytest.warns(FutureWarning, match="v9.1"):
conn = ibis.datafusion.connect(config)
assert sorted(conn.list_tables()) == sorted(name_to_path)


def test_path_config(name_to_path):
config = name_to_path
conn = ibis.datafusion.connect(config)
# if path.endswith((".parquet", ".csv", ".csv.gz")) connect triggers register
with pytest.warns(FutureWarning, match="v9.1"):
conn = ibis.datafusion.connect(config)
assert sorted(conn.list_tables()) == sorted(name_to_path)


24 changes: 24 additions & 0 deletions ibis/backends/datafusion/tests/test_datatypes.py
@@ -0,0 +1,24 @@
from __future__ import annotations

import hypothesis as h

import ibis.tests.strategies as its
from ibis.backends.datafusion import as_nullable


def is_nullable(dtype):
if dtype.is_struct():
return all(map(is_nullable, dtype.values()))
elif dtype.is_array():
return is_nullable(dtype.value_type)
elif dtype.is_map():
return is_nullable(dtype.key_type) and is_nullable(dtype.value_type)
else:
return dtype.nullable is True


@h.given(its.all_dtypes())
def test_as_nullable(dtype):
nullable_dtype = as_nullable(dtype)
assert nullable_dtype.nullable is True
assert is_nullable(nullable_dtype)
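Concretely, the property checked above says that `as_nullable` flips `nullable` to `True` at every level of a nested type. A quick hedged example, assuming the usual ibis datatype constructors:

```python
import ibis.expr.datatypes as dt

from ibis.backends.datafusion import as_nullable

non_nullable = dt.Array(dt.Int64(nullable=False), nullable=False)
print(non_nullable.nullable)                            # False
nullable = as_nullable(non_nullable)
print(nullable.nullable, nullable.value_type.nullable)  # True True
```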
18 changes: 11 additions & 7 deletions ibis/backends/datafusion/tests/test_register.py
@@ -25,24 +25,28 @@ def test_read_parquet(conn, data_dir):

def test_register_table(conn):
tab = pa.table({"x": [1, 2, 3]})
conn.register(tab, "my_table")
with pytest.warns(FutureWarning, match="v9.1"):
conn.register(tab, "my_table")
assert conn.table("my_table").x.sum().execute() == 6


def test_register_pandas(conn):
df = pd.DataFrame({"x": [1, 2, 3]})
conn.register(df, "my_table")
assert conn.table("my_table").x.sum().execute() == 6
with pytest.warns(FutureWarning, match="v9.1"):
conn.register(df, "my_table")
assert conn.table("my_table").x.sum().execute() == 6


def test_register_batches(conn):
batch = pa.record_batch([pa.array([1, 2, 3])], names=["x"])
conn.register(batch, "my_table")
assert conn.table("my_table").x.sum().execute() == 6
with pytest.warns(FutureWarning, match="v9.1"):
conn.register(batch, "my_table")
assert conn.table("my_table").x.sum().execute() == 6


def test_register_dataset(conn):
tab = pa.table({"x": [1, 2, 3]})
dataset = ds.InMemoryDataset(tab)
conn.register(dataset, "my_table")
assert conn.table("my_table").x.sum().execute() == 6
with pytest.warns(FutureWarning, match="v9.1"):
conn.register(dataset, "my_table")
assert conn.table("my_table").x.sum().execute() == 6
7 changes: 6 additions & 1 deletion ibis/backends/datafusion/tests/test_udf.py
@@ -2,13 +2,14 @@

import pandas.testing as tm
import pytest
from packaging.version import parse as vparse

import ibis.expr.datatypes as dt
import ibis.expr.types as ir
from ibis import udf
from ibis.legacy.udf.vectorized import elementwise, reduction

pytest.importorskip("datafusion")
datafusion = pytest.importorskip("datafusion")
pc = pytest.importorskip("pyarrow.compute")

with pytest.warns(FutureWarning, match="v9.0"):
@@ -68,6 +69,10 @@ def median(a: float) -> float:
assert result == con.tables.batting.G.execute().median()


@pytest.mark.xfail(
condition=vparse(datafusion.__version__) == vparse("38.0.1"),
reason="internal error about MEDIAN(G) naming",
)
def test_builtin_agg_udf_filtered(con):
@udf.agg.builtin
def median(a: float, where: bool = True) -> float:
116 changes: 49 additions & 67 deletions ibis/backends/druid/compiler.py
@@ -6,77 +6,65 @@

import ibis.expr.datatypes as dt
import ibis.expr.operations as ops
from ibis.backends.sql.compiler import NULL, SQLGlotCompiler
from ibis.backends.sql.compiler import NULL, AggGen, SQLGlotCompiler
from ibis.backends.sql.datatypes import DruidType
from ibis.backends.sql.dialects import Druid
from ibis.backends.sql.rewrites import (
rewrite_capitalize,
rewrite_sample_as_filter,
)
from ibis.expr.rewrites import rewrite_stringslice


class DruidCompiler(SQLGlotCompiler):
__slots__ = ()

dialect = Druid
type_mapper = DruidType
rewrites = (
rewrite_sample_as_filter,
rewrite_stringslice,
*(
rewrite
for rewrite in SQLGlotCompiler.rewrites
if rewrite is not rewrite_capitalize
),
)

UNSUPPORTED_OPERATIONS = frozenset(
(
ops.ApproxMedian,
ops.ArgMax,
ops.ArgMin,
ops.ArrayCollect,
ops.ArrayDistinct,
ops.ArrayFilter,
ops.ArrayFlatten,
ops.ArrayIntersect,
ops.ArrayMap,
ops.ArraySort,
ops.ArrayUnion,
ops.ArrayZip,
ops.CountDistinctStar,
ops.Covariance,
ops.DateDelta,
ops.DayOfWeekIndex,
ops.DayOfWeekName,
ops.First,
ops.IntervalFromInteger,
ops.IsNan,
ops.IsInf,
ops.Last,
ops.Levenshtein,
ops.Median,
ops.MultiQuantile,
ops.Quantile,
ops.RandomUUID,
ops.RegexReplace,
ops.RegexSplit,
ops.RowID,
ops.StandardDev,
ops.Strftime,
ops.StringAscii,
ops.StringSplit,
ops.StringToDate,
ops.StringToTimestamp,
ops.TimeDelta,
ops.TimestampBucket,
ops.TimestampDelta,
ops.Translate,
ops.TypeOf,
ops.Unnest,
ops.Variance,
)
agg = AggGen(supports_filter=True)

LOWERED_OPS = {ops.Capitalize: None}

UNSUPPORTED_OPS = (
ops.ApproxMedian,
ops.ArgMax,
ops.ArgMin,
ops.ArrayCollect,
ops.ArrayDistinct,
ops.ArrayFilter,
ops.ArrayFlatten,
ops.ArrayIntersect,
ops.ArrayMap,
ops.ArraySort,
ops.ArrayUnion,
ops.ArrayZip,
ops.CountDistinctStar,
ops.Covariance,
ops.DateDelta,
ops.DayOfWeekIndex,
ops.DayOfWeekName,
ops.First,
ops.IntervalFromInteger,
ops.IsNan,
ops.IsInf,
ops.Last,
ops.Levenshtein,
ops.Median,
ops.MultiQuantile,
ops.Quantile,
ops.RandomUUID,
ops.RegexReplace,
ops.RegexSplit,
ops.RowID,
ops.StandardDev,
ops.Strftime,
ops.StringAscii,
ops.StringSplit,
ops.StringToDate,
ops.StringToTimestamp,
ops.TimeDelta,
ops.TimestampBucket,
ops.TimestampDelta,
ops.Translate,
ops.TypeOf,
ops.Unnest,
ops.Variance,
)

SIMPLE_OPS = {
@@ -94,12 +82,6 @@ class DruidCompiler(SQLGlotCompiler):
ops.StringContains: "contains_string",
}

def _aggregate(self, funcname: str, *args, where):
expr = self.f[funcname](*args)
if where is not None:
return sg.exp.Filter(this=expr, expression=sg.exp.Where(this=where))
return expr

def visit_Modulus(self, op, *, left, right):
return self.f.anon.mod(left, right)

183 changes: 124 additions & 59 deletions ibis/backends/duckdb/__init__.py
@@ -27,24 +27,20 @@
from ibis.backends.duckdb.compiler import DuckDBCompiler
from ibis.backends.duckdb.converter import DuckDBPandasData
from ibis.backends.sql import SQLBackend
from ibis.backends.sql.compiler import STAR, C, F
from ibis.backends.sql.compiler import STAR, C
from ibis.common.dispatch import lazy_singledispatch
from ibis.expr.operations.udf import InputType
from ibis.util import deprecated

if TYPE_CHECKING:
from collections.abc import Iterable, Mapping, MutableMapping, Sequence

import pandas as pd
import polars as pl
import torch
from fsspec import AbstractFileSystem


def normalize_filenames(source_list):
# Promote to list
source_list = util.promote_list(source_list)

return list(map(util.normalize_filename, source_list))


_UDF_INPUT_TYPE_MAPPING = {
InputType.PYARROW: duckdb.functional.ARROW,
InputType.PYTHON: duckdb.functional.NATIVE,
Expand All @@ -57,21 +53,17 @@ def __init__(self, con: duckdb.DuckDBPyConnection) -> None:

def __getitem__(self, key: str) -> Any:
maybe_value = self.con.execute(
f"select value from duckdb_settings() where name = '{key}'"
"select value from duckdb_settings() where name = $1", [key]
).fetchone()
if maybe_value is not None:
return maybe_value[0]
raise KeyError(key)

def __setitem__(self, key, value):
self.con.execute(f"SET {key} = '{value}'")
self.con.execute(f"SET {key} = {str(value)!r}")

def __repr__(self):
((kv,),) = self.con.execute(
"select map(array_agg(name), array_agg(value)) from duckdb_settings()"
).fetch()

return repr(dict(zip(kv["key"], kv["value"])))
return repr(self.con.sql("from duckdb_settings()"))


class Backend(SQLBackend, CanCreateDatabase, CanCreateSchema, UrlFromPath):
@@ -132,7 +124,12 @@ def _to_sqlglot(
def create_table(
self,
name: str,
obj: pd.DataFrame | pa.Table | ir.Table | None = None,
obj: ir.Table
| pd.DataFrame
| pa.Table
| pl.DataFrame
| pl.LazyFrame
| None = None,
*,
schema: ibis.Schema | None = None,
database: str | None = None,
@@ -154,24 +151,46 @@ def create_table(
database
The name of the database in which to create the table; if not
passed, the current database is used.
For multi-level table hierarchies, you can pass in a dotted string
path like `"catalog.database"` or a tuple of strings like
`("catalog", "database")`.
temp
Create a temporary table
overwrite
If `True`, replace the table if it already exists, otherwise fail
if the table exists
"""
table_loc = self._to_sqlglot_table(database)

if getattr(table_loc, "catalog", False) and temp:
raise exc.UnsupportedArgumentError(
"DuckDB can only create temporary tables in the `temp` catalog. "
"Don't specify a catalog to enable temp table creation."
)

catalog = self.current_catalog
database = self.current_database
if table_loc is not None:
catalog = table_loc.catalog or catalog
database = table_loc.db or database

if obj is None and schema is None:
raise ValueError("Either `obj` or `schema` must be specified")

properties = []

if temp:
properties.append(sge.TemporaryProperty())
catalog = "temp"

temp_memtable_view = None

if obj is not None:
if not isinstance(obj, ir.Expr):
table = ibis.memtable(obj)
temp_memtable_view = table.op().name
else:
table = obj

@@ -199,8 +218,10 @@ def create_table(
else:
temp_name = name

initial_table = sg.table(
temp_name, catalog=database, quoted=self.compiler.quoted
initial_table = sge.Table(
this=sg.to_identifier(temp_name, quoted=self.compiler.quoted),
catalog=catalog,
db=database,
)
target = sge.Schema(this=initial_table, expressions=column_defs)

@@ -211,7 +232,11 @@ def create_table(
)

# This is the same table as initial_table unless overwrite == True
final_table = sg.table(name, catalog=database, quoted=self.compiler.quoted)
final_table = sge.Table(
this=sg.to_identifier(name, quoted=self.compiler.quoted),
catalog=catalog,
db=database,
)
with self._safe_raw_sql(create_stmt) as cur:
if query is not None:
insert_stmt = sge.insert(query, into=initial_table).sql(self.name)
@@ -248,7 +273,10 @@ def create_table(
).sql(self.name)
)

return self.table(name, database=database)
if temp_memtable_view is not None:
self.con.unregister(temp_memtable_view)

return self.table(name, database=(catalog, database))

def _load_into_cache(self, name, expr):
self.create_table(name, expr, schema=expr.schema(), temp=True)
@@ -434,10 +462,8 @@ def do_connect(
<ibis.backends.duckdb.Backend object at ...>
"""
if (
not isinstance(database, Path)
and database != ":memory:"
and not database.startswith(("md:", "motherduck:"))
if not isinstance(database, Path) and not database.startswith(
("md:", "motherduck:", ":memory:")
):
database = Path(database).absolute()

@@ -457,9 +483,8 @@ def do_connect(
if extensions is not None:
self._load_extensions(extensions)

# Default timezone
with self._safe_raw_sql("SET TimeZone = 'UTC'"):
pass
# Default timezone, can't be set with `config`
self.settings["timezone"] = "UTC"

self._record_batch_readers_consumed = {}

@@ -468,7 +493,7 @@ def _load_extensions(
) -> None:
f = self.compiler.f
query = (
sg.select(f.unnest(f.list_append(C.aliases, C.extension_name)))
sg.select(f.anon.unnest(f.list_append(C.aliases, C.extension_name)))
.from_(f.duckdb_extensions())
.where(sg.and_(C.installed, C.loaded))
)
@@ -517,6 +542,10 @@ def drop_database(
with self._safe_raw_sql(sge.Drop(this=name, kind="SCHEMA", replace=force)):
pass

@deprecated(
as_of="9.1",
instead="use the explicit `read_*` method for the filetype you are trying to read, e.g., read_parquet, read_csv, etc.",
)
def register(
self,
source: str | Path | Any,
@@ -623,7 +652,7 @@ def read_json(
table_name,
sg.select(STAR).from_(
self.compiler.f.read_json_auto(
normalize_filenames(source_list), *options
util.normalize_filenames(source_list), *options
)
),
)
@@ -656,7 +685,7 @@ def read_csv(
The just-registered table
"""
source_list = normalize_filenames(source_list)
source_list = util.normalize_filenames(source_list)

if not table_name:
table_name = util.gen_name("read_csv")
@@ -781,7 +810,7 @@ def read_parquet(
The just-registered table
"""
source_list = normalize_filenames(source_list)
source_list = util.normalize_filenames(source_list)

table_name = table_name or util.gen_name("read_parquet")

@@ -829,11 +858,19 @@ def _read_parquet_pyarrow_dataset(
# explicitly.

def read_in_memory(
# TODO: deprecate this in favor of `create_table`
self,
source: pd.DataFrame | pa.Table | pa.ipc.RecordBatchReader,
source: pd.DataFrame
| pa.Table
| pa.RecordBatchReader
| pl.DataFrame
| pl.LazyFrame,
table_name: str | None = None,
) -> ir.Table:
"""Register a Pandas DataFrame or pyarrow object as a table in the current database.
"""Register an in-memory table object in the current database.
Supported objects include a pandas DataFrame, a Polars
DataFrame/LazyFrame, or a PyArrow Table or RecordBatchReader.
Parameters
----------
@@ -850,13 +887,7 @@ def read_in_memory(
"""
table_name = table_name or util.gen_name("read_in_memory")
self.con.register(table_name, source)

if isinstance(source, pa.ipc.RecordBatchReader):
# Ensure the reader isn't marked as started, in case the name is
# being overwritten.
self._record_batch_readers_consumed[table_name] = False

_read_in_memory(source, table_name, self)
return self.table(table_name)

def read_delta(
@@ -884,7 +915,7 @@ def read_delta(
The just-registered table.
"""
source_table = normalize_filenames(source_table)[0]
source_table = util.normalize_filenames(source_table)[0]

table_name = table_name or util.gen_name("read_delta")

@@ -911,6 +942,17 @@ def list_tables(
) -> list[str]:
"""List tables and views.
::: {.callout-note}
## Ibis does not use the word `schema` to refer to database hierarchy.
A collection of tables is referred to as a `database`.
A collection of `database` is referred to as a `catalog`.
These terms are mapped onto the corresponding features in each
backend (where available), regardless of whether the backend itself
uses the same terminology.
:::
Parameters
----------
like
@@ -924,17 +966,6 @@ def list_tables(
To specify a table in a separate catalog, you can pass in the
catalog and database as a string `"catalog.database"`, or as a tuple of
strings `("catalog", "database")`.
::: {.callout-note}
## Ibis does not use the word `schema` to refer to database hierarchy.
A collection of tables is referred to as a `database`.
A collection of `database` is referred to as a `catalog`.
These terms are mapped onto the corresponding features in each
backend (where available), regardless of whether the backend itself
uses the same terminology.
:::
schema
[deprecated] Schema name. If not passed, uses the current schema.
Expand Down Expand Up @@ -965,8 +996,8 @@ def list_tables(
"""
table_loc = self._warn_and_create_table_loc(database, schema)

catalog = F.current_database()
database = F.current_schema()
catalog = self.current_catalog
database = self.current_database
if table_loc is not None:
catalog = table_loc.catalog or catalog
database = table_loc.db or database
@@ -977,12 +1008,10 @@ def list_tables(
.from_(sg.table("tables", db="information_schema"))
.distinct()
.where(
C.table_catalog.eq(catalog).or_(
C.table_catalog.eq(sge.convert("temp"))
),
C.table_schema.eq(database),
C.table_catalog.isin(sge.convert(catalog), sge.convert("temp")),
C.table_schema.eq(sge.convert(database)),
)
.sql(self.name, pretty=True)
.sql(self.dialect)
)
out = self.con.execute(sql).fetch_arrow_table()

@@ -993,6 +1022,17 @@ def read_postgres(
) -> ir.Table:
"""Register a table from a postgres instance into a DuckDB table.
::: {.callout-note}
## Ibis does not use the word `schema` to refer to database hierarchy.
A collection of `table` is referred to as a `database`.
A collection of `database` is referred to as a `catalog`.
These terms are mapped onto the corresponding features in each
backend (where available), regardless of whether the backend itself
uses the same terminology.
:::
Parameters
----------
uri
@@ -1570,3 +1610,28 @@ def _get_temp_view_definition(self, name: str, definition: str) -> str:
def _create_temp_view(self, table_name, source):
with self._safe_raw_sql(self._get_temp_view_definition(table_name, source)):
pass


@lazy_singledispatch
def _read_in_memory(source: Any, table_name: str, _conn: Backend, **kwargs: Any):
raise NotImplementedError(
f"The `{_conn.name}` backend currently does not support "
f"reading data of {type(source)!r}"
)


@_read_in_memory.register("polars.DataFrame")
@_read_in_memory.register("polars.LazyFrame")
@_read_in_memory.register("pyarrow.Table")
@_read_in_memory.register("pandas.DataFrame")
@_read_in_memory.register("pyarrow.dataset.Dataset")
def _default(source, table_name, _conn, **kwargs: Any):
_conn.con.register(table_name, source)


@_read_in_memory.register("pyarrow.RecordBatchReader")
def _pyarrow_rbr(source, table_name, _conn, **kwargs: Any):
_conn.con.register(table_name, source)
# Ensure the reader isn't marked as started, in case the name is
# being overwritten.
_conn._record_batch_readers_consumed[table_name] = False
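The `_Settings` changes above switch the lookup to a bound parameter (`$1`) instead of string interpolation and set the default time zone through the same settings object. A small standalone sketch of the lookup pattern using plain `duckdb`, outside ibis:

```python
import duckdb

con = duckdb.connect()
row = con.execute(
    "select value from duckdb_settings() where name = $1", ["TimeZone"]
).fetchone()
print(row)  # e.g. ('UTC',), depending on configuration
```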
65 changes: 50 additions & 15 deletions ibis/backends/duckdb/compiler.py
@@ -11,8 +11,9 @@
import ibis.common.exceptions as com
import ibis.expr.datatypes as dt
import ibis.expr.operations as ops
from ibis.backends.sql.compiler import NULL, STAR, SQLGlotCompiler
from ibis.backends.sql.compiler import NULL, STAR, AggGen, SQLGlotCompiler
from ibis.backends.sql.datatypes import DuckDBType
from ibis.backends.sql.rewrites import exclude_nulls_from_array_collect

_INTERVAL_SUFFIXES = {
"ms": "milliseconds",
@@ -33,20 +34,31 @@ class DuckDBCompiler(SQLGlotCompiler):
dialect = DuckDB
type_mapper = DuckDBType

agg = AggGen(supports_filter=True)

rewrites = (
exclude_nulls_from_array_collect,
*SQLGlotCompiler.rewrites,
)

LOWERED_OPS = {
ops.Sample: None,
ops.StringSlice: None,
}

SIMPLE_OPS = {
ops.Arbitrary: "any_value",
ops.ArrayPosition: "list_indexof",
ops.BitAnd: "bit_and",
ops.BitOr: "bit_or",
ops.BitXor: "bit_xor",
ops.EndsWith: "suffix",
ops.Hash: "hash",
ops.ExtractIsoYear: "isoyear",
ops.IntegerRange: "range",
ops.TimestampRange: "range",
ops.MapLength: "cardinality",
ops.Mode: "mode",
ops.TimeFromHMS: "make_time",
ops.TypeOf: "typeof",
ops.GeoPoint: "st_point",
ops.GeoAsText: "st_astext",
ops.GeoArea: "st_area",
@@ -80,12 +92,6 @@ class DuckDBCompiler(SQLGlotCompiler):
ops.GeoY: "st_y",
}

def _aggregate(self, funcname: str, *args, where):
expr = self.f[funcname](*args)
if where is not None:
return sge.Filter(this=expr, expression=sge.Where(this=where))
return expr

def visit_StructColumn(self, op, *, names, values):
return sge.Struct.from_arg_list(
[
@@ -238,40 +244,48 @@ def visit_MapMerge(self, op, *, left, right):

def visit_ToJSONMap(self, op, *, arg):
return self.if_(
self.f.json_type(arg).eq("OBJECT"),
self.f.json_type(arg).eq(sge.convert("OBJECT")),
self.cast(self.cast(arg, dt.json), op.dtype),
NULL,
)

def visit_ToJSONArray(self, op, *, arg):
return self.if_(
self.f.json_type(arg).eq("ARRAY"),
self.f.json_type(arg).eq(sge.convert("ARRAY")),
self.cast(self.cast(arg, dt.json), op.dtype),
NULL,
)

def visit_UnwrapJSONString(self, op, *, arg):
return self.if_(
self.f.json_type(arg).eq("VARCHAR"),
self.f.json_type(arg).eq(sge.convert("VARCHAR")),
self.f.json_extract_string(arg, "$"),
NULL,
)

def visit_UnwrapJSONInt64(self, op, *, arg):
arg_type = self.f.json_type(arg)
return self.if_(
arg_type.isin("UBIGINT", "BIGINT"), self.cast(arg, op.dtype), NULL
arg_type.isin(sge.convert("UBIGINT"), sge.convert("BIGINT")),
self.cast(arg, op.dtype),
NULL,
)

def visit_UnwrapJSONFloat64(self, op, *, arg):
arg_type = self.f.json_type(arg)
return self.if_(
arg_type.isin("UBIGINT", "BIGINT", "DOUBLE"), self.cast(arg, op.dtype), NULL
arg_type.isin(
sge.convert("UBIGINT"), sge.convert("BIGINT"), sge.convert("DOUBLE")
),
self.cast(arg, op.dtype),
NULL,
)

def visit_UnwrapJSONBoolean(self, op, *, arg):
return self.if_(
self.f.json_type(arg).eq("BOOLEAN"), self.cast(arg, op.dtype), NULL
self.f.json_type(arg).eq(sge.convert("BOOLEAN")),
self.cast(arg, op.dtype),
NULL,
)
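The repeated change in these JSON visitors is wrapping the comparison strings in sge.convert, which builds a sqlglot string literal so the generated SQL compares against 'OBJECT', 'ARRAY', and so on rather than a bare token. A quick illustration (the column name is made up):

# Sketch: sge.convert turns a Python value into a sqlglot Literal node.
import sqlglot.expressions as sge

lit = sge.convert("OBJECT")
print(lit.sql())  # 'OBJECT'
print(sge.column("json_type").eq(lit).sql())  # json_type = 'OBJECT'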

def visit_ArrayConcat(self, op, *, arg):
@@ -447,6 +461,16 @@ def visit_RegexReplace(self, op, *, arg, pattern, replacement):
arg, pattern, replacement, "g", dialect=self.dialect
)

def visit_First(self, op, *, arg, where):
cond = arg.is_(sg.not_(NULL, copy=False))
where = cond if where is None else sge.And(this=cond, expression=where)
return self.agg.first(arg, where=where)

def visit_Last(self, op, *, arg, where):
cond = arg.is_(sg.not_(NULL, copy=False))
where = cond if where is None else sge.And(this=cond, expression=where)
return self.agg.last(arg, where=where)

def visit_Quantile(self, op, *, arg, quantile, where):
suffix = "cont" if op.arg.dtype.is_numeric() else "disc"
funcname = f"percentile_{suffix}"
@@ -461,6 +485,14 @@ def visit_HexDigest(self, op, *, arg, how):
else:
raise NotImplementedError(f"No available hashing function for {how}")

def visit_Hash(self, op, *, arg):
# duckdb's hash() returns a uint64, but ops.Hash is supposed to be int64
# So do HASH(x)::BITSTRING::BIGINT
raw = self.f.hash(arg)
bitstring = sg.cast(sge.convert(raw), to=sge.DataType.Type.BIT, copy=False)
int64 = sg.cast(bitstring, to=sge.DataType.Type.BIGINT, copy=False)
return int64
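A quick check of the cast chain above, run against DuckDB directly: a plain UBIGINT-to-BIGINT cast can raise out-of-range errors for hashes at or above 2**63, while the BIT round trip reinterprets the same 64 bits as a signed value (the hashed value is illustrative):

# Sketch: HASH(x)::BIT::BIGINT keeps the bit pattern, flips the sign interpretation.
import duckdb

raw, as_int64 = duckdb.sql(
    "SELECT hash(42) AS raw, CAST(CAST(hash(42) AS BIT) AS BIGINT) AS as_int64"
).fetchone()
print(raw, as_int64)
assert as_int64 % 2**64 == raw  # same 64 bits either way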

def visit_StringConcat(self, op, *, arg):
return reduce(lambda x, y: sge.DPipe(this=x, expression=y), arg)

@@ -486,3 +518,6 @@ def visit_RandomScalar(self, op, **kwargs):

def visit_RandomUUID(self, op, **kwargs):
return self.f.uuid()

def visit_TypeOf(self, op, *, arg):
return self.f.coalesce(self.f.nullif(self.f.typeof(arg), '"NULL"'), "NULL")
11 changes: 8 additions & 3 deletions ibis/backends/duckdb/tests/conftest.py
@@ -108,9 +108,14 @@ def connect(*, tmpdir, worker_id, **kw) -> BaseBackend:
return ibis.duckdb.connect(**kw)

def load_tpch(self) -> None:
"""Load the TPC-H dataset."""
with self.connection._safe_raw_sql("CALL dbgen(sf=0.17)"):
pass
"""Load TPC-H data."""
con = self.connection
for path in self.data_dir.joinpath("tpch", "sf=0.17", "parquet").glob(
"*.parquet"
):
table_name = path.with_suffix("").name
# duckdb automatically infers the sf=0.17 path segment as a hive partition
con.read_parquet(path, table_name=table_name, hive_partitioning=False)
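The hive_partitioning=False above is needed because DuckDB can treat key=value path segments (here sf=0.17) as Hive partition columns, which would add an sf column to every table. A small self-contained illustration with the duckdb package (the paths and column name are invented for the example):

# Sketch: hive partition inference from an `sf=0.17` path segment.
import os
import duckdb

os.makedirs("demo/sf=0.17", exist_ok=True)
duckdb.sql("COPY (SELECT 1 AS o_orderkey) TO 'demo/sf=0.17/orders.parquet' (FORMAT PARQUET)")
with_hive = duckdb.sql(
    "SELECT * FROM read_parquet('demo/sf=0.17/orders.parquet', hive_partitioning=true)"
)
without_hive = duckdb.sql(
    "SELECT * FROM read_parquet('demo/sf=0.17/orders.parquet', hive_partitioning=false)"
)
print(with_hive.columns)     # ['o_orderkey', 'sf']
print(without_hive.columns)  # ['o_orderkey']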


@pytest.fixture(scope="session")
73 changes: 73 additions & 0 deletions ibis/backends/duckdb/tests/test_catalog.py
@@ -0,0 +1,73 @@
from __future__ import annotations

import pandas as pd
import pandas.testing as tm
import pytest

import ibis
import ibis.common.exceptions as exc


@pytest.fixture(scope="session")
def external_duckdb_file(tmpdir_factory): # pragma: no cover
ddb_path = str(tmpdir_factory.mktemp("data") / "starwars.ddb")
con = ibis.duckdb.connect(ddb_path)

starwars_df = pd.DataFrame(
{
"name": ["Luke Skywalker", "C-3PO", "R2-D2"],
"height": [172, 167, 96],
"mass": [77.0, 75.0, 32.0],
}
)
con.create_table("starwars", obj=starwars_df)
con.disconnect()

return ddb_path, starwars_df


def test_read_write_external_catalog(con, external_duckdb_file, monkeypatch):
monkeypatch.setattr(ibis.options, "default_backend", con)

ddb_path, starwars_df = external_duckdb_file
con.attach(ddb_path, name="ext")

# Read from catalog
assert "ext" in con.list_catalogs()
assert "main" in con.list_databases(catalog="ext")

assert "starwars" in con.list_tables(database="ext.main")
assert "starwars" not in con.list_tables()

starwars = con.table("starwars", database="ext.main")
tm.assert_frame_equal(starwars.to_pandas(), starwars_df)

# Write to catalog
t = ibis.memtable([{"a": 1, "b": "foo"}, {"a": 2, "b": "baz"}])

_ = con.create_table("t2", obj=t, database="ext.main")

assert "t2" in con.list_tables(database="ext.main")
assert "t2" not in con.list_tables()

table = con.table("t2", database="ext.main")

tm.assert_frame_equal(t.to_pandas(), table.to_pandas())

# Overwrite table in catalog

t_overwrite = ibis.memtable([{"a": 8, "b": "bing"}, {"a": 9, "b": "bong"}])

_ = con.create_table("t2", obj=t_overwrite, database="ext.main", overwrite=True)

assert "t2" in con.list_tables(database="ext.main")
assert "t2" not in con.list_tables()

table = con.table("t2", database="ext.main")

tm.assert_frame_equal(t_overwrite.to_pandas(), table.to_pandas())


def test_raise_if_catalog_and_temp(con):
with pytest.raises(exc.UnsupportedArgumentError):
con.create_table("some_table", obj="hi", temp=True, database="ext.main")
27 changes: 26 additions & 1 deletion ibis/backends/duckdb/tests/test_client.py
@@ -268,7 +268,9 @@ def test_connect_duckdb(url, tmp_path):
)
def test_connect_local_file(out_method, extension, test_employee_data_1, tmp_path):
getattr(test_employee_data_1, out_method)(tmp_path / f"out.{extension}")
con = ibis.connect(tmp_path / f"out.{extension}")
with pytest.warns(FutureWarning, match="v9.1"):
# ibis.connect uses con.register
con = ibis.connect(tmp_path / f"out.{extension}")
t = next(iter(con.tables.values()))
assert not t.head().execute().empty

@@ -297,3 +299,26 @@ def test_list_tables_schema_warning_refactor(con):

assert con.list_tables(database="shops") == icecream_table
assert con.list_tables(database=("shops",)) == icecream_table


def test_settings_repr():
con = ibis.duckdb.connect()
view = repr(con.settings)
assert "name" in view
assert "value" in view


def test_connect_named_in_memory_db():
con_named_db = ibis.duckdb.connect(":memory:mydb")

con_named_db.create_table("ork", schema=ibis.schema(dict(bork="int32")))
assert "ork" in con_named_db.list_tables()

con_named_db_2 = ibis.duckdb.connect(":memory:mydb")
assert "ork" in con_named_db_2.list_tables()

unnamed_memory_db = ibis.duckdb.connect(":memory:")
assert "ork" not in unnamed_memory_db.list_tables()

default_memory_db = ibis.duckdb.connect()
assert "ork" not in default_memory_db.list_tables()
38 changes: 12 additions & 26 deletions ibis/backends/duckdb/tests/test_register.py
@@ -16,7 +16,7 @@

import ibis
import ibis.expr.datatypes as dt
from ibis.conftest import LINUX, SANDBOXED
from ibis.conftest import ARM64, LINUX, MACOS, SANDBOXED


def test_read_csv(con, data_dir):
@@ -257,10 +257,10 @@ def test_read_sqlite_no_table_name(con, tmp_path):
)
def test_register_sqlite(con, tmp_path):
path = tmp_path / "test.db"

sqlite_con = sqlite3.connect(str(path))
sqlite_con.execute("CREATE TABLE t AS SELECT 1 a UNION SELECT 2 UNION SELECT 3")
ft = con.register(f"sqlite://{path}", "t")
with pytest.warns(FutureWarning, match="v9.1"):
ft = con.register(f"sqlite://{path}", "t")
assert ft.count().execute()


@@ -311,16 +311,6 @@ def test_attach_sqlite(data_dir, tmp_path):
assert dt.String(nullable=True) in set(types)


def test_read_in_memory(con):
df_arrow = pa.table({"a": ["a"], "b": [1]})
df_pandas = pd.DataFrame({"a": ["a"], "b": [1]})
con.read_in_memory(df_arrow, table_name="df_arrow")
con.read_in_memory(df_pandas, table_name="df_pandas")

assert "df_arrow" in con.list_tables()
assert "df_pandas" in con.list_tables()


def test_re_read_in_memory_overwrite(con):
df_pandas_1 = pd.DataFrame({"a": ["a"], "b": [1], "d": ["hi"]})
df_pandas_2 = pd.DataFrame({"a": [1], "c": [1.4]})
@@ -390,37 +380,33 @@ def test_set_temp_dir(tmp_path):


@pytest.mark.xfail(
LINUX and SANDBOXED,
SANDBOXED and LINUX,
reason=(
"nix on linux cannot download duckdb extensions or data due to sandboxing; "
"duckdb will try to automatically install and load read_parquet"
),
raises=(duckdb.Error, duckdb.IOException),
)
@pytest.mark.skipif(
SANDBOXED and MACOS and ARM64, reason="raises a RuntimeError on nix macos arm64"
)
def test_s3_403_fallback(con, httpserver, monkeypatch):
# monkeypatch to avoid downloading extensions in tests
monkeypatch.setattr(con, "_load_extensions", lambda _: True)

# Throw a 403 to trigger fallback to pyarrow.dataset
httpserver.expect_request("/myfile").respond_with_data(
"Forbidden", status=403, content_type="text/plain"
path = "/invalid.parquet"
httpserver.expect_request(path).respond_with_data(
status=403, content_type="application/vnd.apache.parquet"
)

# Since the URI is nonsense to pyarrow, expect an error, but raises from
# pyarrow, which indicates the fallback worked
url = httpserver.url_for(path)
with pytest.raises(pa.lib.ArrowInvalid):
con.read_parquet(httpserver.url_for("/myfile"))
con.read_parquet(url)


@pytest.mark.xfail_version(
duckdb=["duckdb<=0.7.1"],
reason="""
the fix for this (issue #5879) caused a serious performance regression in the repr.
added this xfail in #5959, which also reverted the bugfix that caused the regression.
the issue was fixed upstream in duckdb in https://github.com/duckdb/duckdb/pull/6978
""",
)
def test_register_numpy_str(con):
data = pd.DataFrame({"a": [np.str_("xyz"), None]})
result = con.read_in_memory(data)
18 changes: 15 additions & 3 deletions ibis/backends/exasol/__init__.py
@@ -27,6 +27,7 @@
from collections.abc import Iterable, Mapping

import pandas as pd
import polars as pl
import pyarrow as pa

from ibis.backends import BaseBackend
@@ -49,7 +50,7 @@ def version(self) -> str:
query = (
sg.select("param_value")
.from_(sg.table("EXA_METADATA", catalog="SYS"))
.where(C.param_name.eq("databaseProductVersion"))
.where(C.param_name.eq(sge.convert("databaseProductVersion")))
)
with self._safe_raw_sql(query) as result:
[(version,)] = result.fetchall()
@@ -279,14 +280,19 @@ def process_item(item: Any):

def _clean_up_tmp_table(self, ident: sge.Identifier) -> None:
with self._safe_raw_sql(
sge.Drop(kind="TABLE", this=ident, force=True, cascade=True)
sge.Drop(kind="TABLE", this=ident, exists=True, cascade=True)
):
pass

def create_table(
self,
name: str,
obj: pd.DataFrame | pa.Table | ir.Table | None = None,
obj: ir.Table
| pd.DataFrame
| pa.Table
| pl.DataFrame
| pl.LazyFrame
| None = None,
*,
schema: sch.Schema | None = None,
database: str | None = None,
Expand Down Expand Up @@ -331,9 +337,11 @@ def create_table(

quoted = self.compiler.quoted

temp_memtable_view = None
if obj is not None:
if not isinstance(obj, ir.Expr):
table = ibis.memtable(obj)
temp_memtable_view = table.op().name
else:
table = obj

@@ -383,6 +391,10 @@ def create_table(
)

if schema is None:
# Clean up temporary memtable if we've created one
# for in-memory reads
if temp_memtable_view is not None:
self.drop_table(temp_memtable_view)
return self.table(name, database=database)

# preserve the input schema if it was provided
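On the _clean_up_tmp_table change earlier in this file: switching force=True to exists=True is what makes sqlglot emit an IF EXISTS guard on the DROP. A hedged sketch of the rendered statement (the table name is invented):

# Sketch: sqlglot's Drop expression with exists/cascade set.
import sqlglot as sg
import sqlglot.expressions as sge

stmt = sge.Drop(kind="TABLE", this=sg.table("tmp_table"), exists=True, cascade=True)
print(stmt.sql())  # roughly: DROP TABLE IF EXISTS tmp_table CASCADE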
123 changes: 57 additions & 66 deletions ibis/backends/exasol/compiler.py
@@ -14,9 +14,7 @@
exclude_unsupported_window_frame_from_rank,
exclude_unsupported_window_frame_from_row_number,
rewrite_empty_order_by_window,
rewrite_sample_as_filter,
)
from ibis.expr.rewrites import rewrite_stringslice


class ExasolCompiler(SQLGlotCompiler):
@@ -25,77 +23,73 @@ class ExasolCompiler(SQLGlotCompiler):
dialect = Exasol
type_mapper = ExasolType
rewrites = (
rewrite_sample_as_filter,
exclude_unsupported_window_frame_from_ops,
exclude_unsupported_window_frame_from_rank,
exclude_unsupported_window_frame_from_row_number,
rewrite_empty_order_by_window,
rewrite_stringslice,
*SQLGlotCompiler.rewrites,
)

UNSUPPORTED_OPERATIONS = frozenset(
(
ops.AnalyticVectorizedUDF,
ops.ApproxMedian,
ops.ArgMax,
ops.ArgMin,
ops.ArrayCollect,
ops.ArrayDistinct,
ops.ArrayFilter,
ops.ArrayFlatten,
ops.ArrayIntersect,
ops.ArrayMap,
ops.ArraySort,
ops.ArrayStringJoin,
ops.ArrayUnion,
ops.ArrayZip,
ops.BitwiseNot,
ops.Covariance,
ops.CumeDist,
ops.DateAdd,
ops.DateSub,
ops.DateFromYMD,
ops.DayOfWeekIndex,
ops.ElementWiseVectorizedUDF,
ops.First,
ops.IntervalFromInteger,
ops.IsInf,
ops.IsNan,
ops.Last,
ops.Levenshtein,
ops.Median,
ops.MultiQuantile,
ops.Quantile,
ops.RandomUUID,
ops.ReductionVectorizedUDF,
ops.RegexExtract,
ops.RegexReplace,
ops.RegexSearch,
ops.RegexSplit,
ops.RowID,
ops.StandardDev,
ops.Strftime,
ops.StringJoin,
ops.StringSplit,
ops.StringToDate,
ops.StringToTimestamp,
ops.TimeDelta,
ops.TimestampAdd,
ops.TimestampBucket,
ops.TimestampDelta,
ops.TimestampDiff,
ops.TimestampSub,
ops.TypeOf,
ops.Unnest,
ops.Variance,
)
UNSUPPORTED_OPS = (
ops.AnalyticVectorizedUDF,
ops.ApproxMedian,
ops.ArgMax,
ops.ArgMin,
ops.ArrayCollect,
ops.ArrayDistinct,
ops.ArrayFilter,
ops.ArrayFlatten,
ops.ArrayIntersect,
ops.ArrayMap,
ops.ArraySort,
ops.ArrayStringJoin,
ops.ArrayUnion,
ops.ArrayZip,
ops.BitwiseNot,
ops.Covariance,
ops.CumeDist,
ops.DateAdd,
ops.DateSub,
ops.DateFromYMD,
ops.DayOfWeekIndex,
ops.ElementWiseVectorizedUDF,
ops.IntervalFromInteger,
ops.IsInf,
ops.IsNan,
ops.Levenshtein,
ops.Median,
ops.MultiQuantile,
ops.Quantile,
ops.RandomUUID,
ops.ReductionVectorizedUDF,
ops.RegexExtract,
ops.RegexReplace,
ops.RegexSearch,
ops.RegexSplit,
ops.RowID,
ops.StandardDev,
ops.Strftime,
ops.StringJoin,
ops.StringSplit,
ops.StringToDate,
ops.StringToTimestamp,
ops.TimeDelta,
ops.TimestampAdd,
ops.TimestampBucket,
ops.TimestampDelta,
ops.TimestampDiff,
ops.TimestampSub,
ops.TypeOf,
ops.Unnest,
ops.Variance,
)

SIMPLE_OPS = {
ops.Log10: "log10",
ops.All: "min",
ops.Any: "max",
ops.First: "first_value",
ops.Last: "last_value",
}

@staticmethod
@@ -109,12 +103,6 @@ def _minimize_spec(start, end, spec):
return None
return spec

def _aggregate(self, funcname: str, *args, where):
func = self.f[funcname]
if where is not None:
args = tuple(self.if_(where, arg, NULL) for arg in args)
return func(*args)

@staticmethod
def _gen_valid_name(name: str) -> str:
"""Exasol does not allow dots in quoted column names."""
@@ -212,6 +200,9 @@ def visit_ExtractDayOfYear(self, op, *, arg):
def visit_ExtractWeekOfYear(self, op, *, arg):
return self.cast(self.f.to_char(arg, "IW"), op.dtype)

def visit_ExtractIsoYear(self, op, *, arg):
return self.cast(self.f.to_char(arg, "IYYY"), op.dtype)

def visit_DayOfWeekName(self, op, *, arg):
return self.f.concat(
self.f.substr(self.f.to_char(arg, "DAY"), 0, 1),
2 changes: 1 addition & 1 deletion ibis/backends/exasol/tests/conftest.py
@@ -67,7 +67,7 @@ def _exaplus(self) -> str:
"exec",
self.service_name,
"find",
"/usr",
"/opt",
"-name",
"exaplus",
"-type",
121 changes: 59 additions & 62 deletions ibis/backends/flink/compiler.py
@@ -8,60 +8,85 @@
import ibis.common.exceptions as com
import ibis.expr.datatypes as dt
import ibis.expr.operations as ops
from ibis.backends.sql.compiler import NULL, STAR, SQLGlotCompiler
from ibis.backends.sql.compiler import NULL, STAR, AggGen, SQLGlotCompiler
from ibis.backends.sql.datatypes import FlinkType
from ibis.backends.sql.dialects import Flink
from ibis.backends.sql.rewrites import (
exclude_unsupported_window_frame_from_ops,
exclude_unsupported_window_frame_from_rank,
exclude_unsupported_window_frame_from_row_number,
rewrite_sample_as_filter,
)
from ibis.expr.rewrites import rewrite_stringslice


class FlinkAggGen(AggGen):
def aggregate(self, compiler, name, *args, where=None):
func = compiler.f[name]
if where is not None:
# Flink does support FILTER, but it's broken for:
#
# 1. certain aggregates: std/var don't return the right result
# 2. certain kinds of predicates: x IN y doesn't filter the right
# values out
# 3. certain aggregates AND predicates STD(w) FILTER (WHERE x IN Y)
# returns an incorrect result
#
# One solution is to try `IF(predicate, arg, NULL)`.
#
# Unfortunately that won't work without casting the NULL to a
# specific type.
#
# At this point in the Ibis compiler we don't have any of the Ibis
# operation's type information because we've thrown it away. In every
# other engine Ibis supports, the type of a NULL literal is inferred
# by the engine.
#
# Using a CASE statement and leaving out the explicit NULL does the
# trick for Flink.
args = tuple(sge.Case(ifs=[sge.If(this=where, true=arg)]) for arg in args)
return func(*args)
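A small sanity check of the rewrite described in the comments above: an aggregate over CASE WHEN pred THEN arg END (with no explicit ELSE NULL) matches the FILTER form, because aggregates skip the NULLs produced for non-matching rows. Checked against DuckDB for convenience since it supports both spellings; this is not a Flink run, and the table is invented:

# Sketch: CASE-without-ELSE as a stand-in for FILTER (WHERE ...).
import duckdb

con = duckdb.connect()
con.sql(
    "CREATE TABLE t AS SELECT * FROM (VALUES (1.0, true), (2.0, true), (9.0, false)) AS v(x, flag)"
)
case_form = con.sql("SELECT avg(CASE WHEN flag THEN x END) FROM t").fetchone()
filter_form = con.sql("SELECT avg(x) FILTER (WHERE flag) FROM t").fetchone()
assert case_form == filter_form == (1.5,)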


class FlinkCompiler(SQLGlotCompiler):
quoted = True
dialect = Flink
type_mapper = FlinkType

agg = FlinkAggGen()

rewrites = (
rewrite_sample_as_filter,
exclude_unsupported_window_frame_from_row_number,
exclude_unsupported_window_frame_from_ops,
exclude_unsupported_window_frame_from_rank,
rewrite_stringslice,
*SQLGlotCompiler.rewrites,
)

UNSUPPORTED_OPERATIONS = frozenset(
(
ops.AnalyticVectorizedUDF,
ops.ApproxMedian,
ops.ArgMax,
ops.ArgMin,
ops.ArrayCollect,
ops.ArrayFlatten,
ops.ArraySort,
ops.ArrayStringJoin,
ops.Correlation,
ops.CountDistinctStar,
ops.Covariance,
ops.DateDiff,
ops.ExtractURLField,
ops.FindInSet,
ops.IsInf,
ops.IsNan,
ops.Levenshtein,
ops.Median,
ops.MultiQuantile,
ops.NthValue,
ops.Quantile,
ops.ReductionVectorizedUDF,
ops.RegexSplit,
ops.RowID,
ops.StringSplit,
ops.Translate,
)
UNSUPPORTED_OPS = (
ops.AnalyticVectorizedUDF,
ops.ApproxMedian,
ops.ArgMax,
ops.ArgMin,
ops.ArrayCollect,
ops.ArrayFlatten,
ops.ArraySort,
ops.ArrayStringJoin,
ops.Correlation,
ops.CountDistinctStar,
ops.Covariance,
ops.DateDiff,
ops.ExtractURLField,
ops.FindInSet,
ops.IsInf,
ops.IsNan,
ops.Levenshtein,
ops.Median,
ops.MultiQuantile,
ops.NthValue,
ops.Quantile,
ops.ReductionVectorizedUDF,
ops.RegexSplit,
ops.RowID,
ops.StringSplit,
ops.Translate,
)

SIMPLE_OPS = {
@@ -102,34 +127,6 @@ def POS_INF(self):
def _generate_groups(groups):
return groups

def _aggregate(self, funcname: str, *args, where):
func = self.f[funcname]
if where is not None:
# FILTER (WHERE ) is broken for one or both of:
#
# 1. certain aggregates: std/var doesn't return the right result
# 2. certain kinds of predicates: x IN y doesn't filter the right
# values out
# 3. certain aggregates AND predicates STD(w) FILTER (WHERE x IN Y)
# returns an incorrect result
#
# One solution is to try `IF(predicate, arg, NULL)`.
#
# Unfortunately that won't work without casting the NULL to a
# specific type.
#
# At this point in the Ibis compiler we don't have any of the Ibis
# operation's type information because we thrown it away. In every
# other engine Ibis supports the type of a NULL literal is inferred
# by the engine.
#
# Using a CASE statement and leaving out the explicit NULL does the
# trick for Flink.
#
# Le sigh.
args = tuple(sge.Case(ifs=[sge.If(this=where, true=arg)]) for arg in args)
return func(*args)

@staticmethod
def _minimize_spec(start, end, spec):
if (
10 changes: 8 additions & 2 deletions ibis/backends/impala/__init__.py
@@ -46,6 +46,7 @@
from pathlib import Path

import pandas as pd
import polars as pl
import pyarrow as pa

import ibis.expr.operations as ops
@@ -447,7 +448,12 @@ def table(self, name: str, database: str | None = None, **kwargs: Any) -> ir.Tab
def create_table(
self,
name: str,
obj: ir.Table | None = None,
obj: ir.Table
| pd.DataFrame
| pa.Table
| pl.DataFrame
| pl.LazyFrame
| None = None,
*,
schema=None,
database=None,
@@ -459,7 +465,7 @@ def create_table(
partition=None,
like_parquet=None,
) -> ir.Table:
"""Create a new table in Impala using an Ibis table expression.
"""Create a new table using an Ibis table expression or in-memory data.
Parameters
----------
5 changes: 5 additions & 0 deletions ibis/backends/impala/client.py
@@ -109,6 +109,11 @@ def insert(
if not isinstance(obj, ir.Table):
obj = ibis.memtable(obj)

if not set(self.columns).difference(obj.columns):
# project using the parent table's column order
# if the column names match
obj = obj.select(self.columns)

self._client._run_pre_execute_hooks(obj)

expr = obj
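A hedged sketch of what the added reordering step does, using only public ibis APIs; the column names and target order are invented for illustration:

# Sketch: reorder incoming columns to the target table's order by name.
import ibis

target_columns = ["id", "name", "amount"]  # stand-in for self.columns on the target
obj = ibis.memtable({"amount": [1.5], "id": [1], "name": ["a"]})
if not set(target_columns).difference(obj.columns):
    # same names (order may differ), so project them in the target's order
    obj = obj.select(target_columns)
print(list(obj.columns))  # ['id', 'name', 'amount']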
64 changes: 26 additions & 38 deletions ibis/backends/impala/compiler.py
@@ -12,9 +12,7 @@
from ibis.backends.sql.dialects import Impala
from ibis.backends.sql.rewrites import (
rewrite_empty_order_by_window,
rewrite_sample_as_filter,
)
from ibis.expr.rewrites import rewrite_stringslice


class ImpalaCompiler(SQLGlotCompiler):
@@ -23,40 +21,36 @@ class ImpalaCompiler(SQLGlotCompiler):
dialect = Impala
type_mapper = ImpalaType
rewrites = (
rewrite_sample_as_filter,
rewrite_empty_order_by_window,
rewrite_stringslice,
*SQLGlotCompiler.rewrites,
)

UNSUPPORTED_OPERATIONS = frozenset(
(
ops.ArgMax,
ops.ArgMin,
ops.ArrayCollect,
ops.ArrayPosition,
ops.Array,
ops.Covariance,
ops.DateDelta,
ops.ExtractDayOfYear,
ops.First,
ops.Last,
ops.Levenshtein,
ops.Map,
ops.Median,
ops.MultiQuantile,
ops.NthValue,
ops.Quantile,
ops.RegexSplit,
ops.RowID,
ops.StringSplit,
ops.StructColumn,
ops.Time,
ops.TimeDelta,
ops.TimestampBucket,
ops.TimestampDelta,
ops.Unnest,
)
UNSUPPORTED_OPS = (
ops.ArgMax,
ops.ArgMin,
ops.ArrayCollect,
ops.ArrayPosition,
ops.Array,
ops.Covariance,
ops.DateDelta,
ops.ExtractDayOfYear,
ops.First,
ops.Last,
ops.Levenshtein,
ops.Map,
ops.Median,
ops.MultiQuantile,
ops.NthValue,
ops.Quantile,
ops.RegexSplit,
ops.RowID,
ops.StringSplit,
ops.StructColumn,
ops.Time,
ops.TimeDelta,
ops.TimestampBucket,
ops.TimestampDelta,
ops.Unnest,
)

SIMPLE_OPS = {
@@ -81,12 +75,6 @@ class ImpalaCompiler(SQLGlotCompiler):
ops.TypeOf: "typeof",
}

def _aggregate(self, funcname: str, *args, where):
if where is not None:
args = tuple(self.if_(where, arg, NULL) for arg in args)

return self.f[funcname](*args, dialect=self.dialect)

@staticmethod
def _minimize_spec(start, end, spec):
# start is None means unbounded preceding
4 changes: 2 additions & 2 deletions ibis/backends/impala/tests/conftest.py
@@ -32,7 +32,7 @@ class TestConf(BackendTest):

@property
def test_files(self) -> Iterable[Path]:
return [self.data_dir.joinpath("impala")]
return [self.data_dir.joinpath("directory")]

def preload(self):
env = IbisTestEnv()
@@ -79,7 +79,7 @@ def _load_data(self, **_: Any) -> None:
(parquet,) = self.test_files

# container path to data
prefix = "/user/hive/warehouse/impala/parquet"
prefix = "/user/hive/warehouse/directory/parquet"
for dir in parquet.joinpath("parquet").glob("*"):
con.drop_table(dir.name, database=database, force=True)

File renamed without changes.
File renamed without changes.
File renamed without changes.
11 changes: 6 additions & 5 deletions ibis/backends/impala/tests/test_case_exprs.py
@@ -76,16 +76,17 @@ def test_nullif_ifnull(tpch_lineitem, expr_fn, snapshot):
@pytest.mark.parametrize(
"expr_fn",
[
pytest.param(lambda t: t.l_quantity.fillna(0), id="fillna_l_quantity"),
pytest.param(lambda t: t.l_quantity.fill_null(0), id="fill_null_l_quantity"),
pytest.param(
lambda t: t.l_extendedprice.fillna(0), id="fillna_l_extendedprice"
lambda t: t.l_extendedprice.fill_null(0), id="fill_null_l_extendedprice"
),
pytest.param(
lambda t: t.l_extendedprice.fillna(0.0), id="fillna_l_extendedprice_double"
lambda t: t.l_extendedprice.fill_null(0.0),
id="fill_null_l_extendedprice_double",
),
],
)
def test_decimal_fillna_cast_arg(tpch_lineitem, expr_fn, snapshot):
def test_decimal_fill_null_cast_arg(tpch_lineitem, expr_fn, snapshot):
expr = expr_fn(tpch_lineitem)
result = translate(expr)
snapshot.assert_match(result, "out.sql")
@@ -99,6 +100,6 @@ def test_identical_to(mockcon, snapshot):


def test_identical_to_special_case(snapshot):
expr = ibis.NA.cast("int64").identical_to(ibis.NA.cast("int64")).name("tmp")
expr = ibis.null().cast("int64").identical_to(ibis.null().cast("int64")).name("tmp")
result = ibis.to_sql(expr, dialect="impala")
snapshot.assert_match(result, "out.sql")
2 changes: 1 addition & 1 deletion ibis/backends/impala/tests/test_ddl.py
@@ -248,7 +248,7 @@ def test_change_format(table):


def test_query_avro(con, test_data_dir):
hdfs_path = pjoin(test_data_dir, "impala/avro/tpch/region")
hdfs_path = pjoin(test_data_dir, "directory/avro/tpch/region")

avro_schema = {
"fields": [
24 changes: 12 additions & 12 deletions ibis/backends/impala/tests/test_exprs.py
@@ -52,9 +52,9 @@ def test_builtins(con, alltypes):
i4 % 10,
20 % i1,
d % 5,
i1.fillna(0),
i4.fillna(0),
i8.fillna(0),
i1.fill_null(0),
i4.fill_null(0),
i8.fill_null(0),
i4.to_timestamp("s"),
i4.to_timestamp("ms"),
i4.to_timestamp("us"),
@@ -65,7 +65,7 @@
d.ceil(),
d.exp(),
d.isnull(),
d.fillna(0),
d.fill_null(0),
d.floor(),
d.log(),
d.ln(),
@@ -164,7 +164,7 @@ def _check_impala_output_types_match(con, table):
(5 / L(50).nullif(0), 0.1),
(5 / L(50).nullif(L(50000)), 0.1),
(5 / L(50000).nullif(0), 0.0001),
(L(50000).fillna(0), 50000),
(L(50000).fill_null(0), 50000),
],
)
def test_int_builtins(con, expr, expected):
@@ -257,13 +257,13 @@ def approx_equal(a, b, eps):
[
pytest.param(lambda dc: dc, "5.245", id="id"),
pytest.param(lambda dc: dc % 5, "0.245", id="mod"),
pytest.param(lambda dc: dc.fillna(0), "5.245", id="fillna"),
pytest.param(lambda dc: dc.fill_null(0), "5.245", id="fill_null"),
pytest.param(lambda dc: dc.exp(), "189.6158", id="exp"),
pytest.param(lambda dc: dc.log(), "1.65728", id="log"),
pytest.param(lambda dc: dc.log2(), "2.39094", id="log2"),
pytest.param(lambda dc: dc.log10(), "0.71975", id="log10"),
pytest.param(lambda dc: dc.sqrt(), "2.29019", id="sqrt"),
pytest.param(lambda dc: dc.fillna(0), "5.245", id="zero_ifnull"),
pytest.param(lambda dc: dc.fill_null(0), "5.245", id="zero_ifnull"),
pytest.param(lambda dc: -dc, "-5.245", id="neg"),
],
)
@@ -384,8 +384,8 @@ def test_decimal_timestamp_builtins(con):
dc * 2,
dc**2,
dc.cast("double"),
api.ifelse(table.l_discount > 0, dc * table.l_discount, api.NA),
dc.fillna(0),
api.ifelse(table.l_discount > 0, dc * table.l_discount, api.null()),
dc.fill_null(0),
ts < (ibis.now() + ibis.interval(months=3)),
ts < (ibis.timestamp("2005-01-01") + ibis.interval(months=3)),
# hashing
@@ -632,10 +632,10 @@ def test_unions_with_ctes(con, alltypes):
@pytest.mark.parametrize(
("left", "right", "expected"),
[
(ibis.NA.cast("int64"), ibis.NA.cast("int64"), True),
(ibis.null().cast("int64"), ibis.null().cast("int64"), True),
(L(1), L(1), True),
(ibis.NA.cast("int64"), L(1), False),
(L(1), ibis.NA.cast("int64"), False),
(ibis.null().cast("int64"), L(1), False),
(L(1), ibis.null().cast("int64"), False),
(L(0), L(1), False),
(L(1), L(0), False),
],
10 changes: 5 additions & 5 deletions ibis/backends/impala/tests/test_parquet_ddl.py
@@ -13,7 +13,7 @@


def test_parquet_file_with_name(con, test_data_dir, temp_table):
hdfs_path = pjoin(test_data_dir, "impala/parquet/region")
hdfs_path = pjoin(test_data_dir, "directory/parquet/region")

name = temp_table
schema = ibis.schema(
@@ -30,7 +30,7 @@ def test_parquet_file_with_name(con, test_data_dir, temp_table):


def test_query_parquet_file_with_schema(con, test_data_dir):
hdfs_path = pjoin(test_data_dir, "impala/parquet/region")
hdfs_path = pjoin(test_data_dir, "directory/parquet/region")

schema = ibis.schema(
[
@@ -54,7 +54,7 @@ def test_query_parquet_file_with_schema(con, test_data_dir):


def test_query_parquet_file_like_table(con, test_data_dir):
hdfs_path = pjoin(test_data_dir, "impala/parquet/region")
hdfs_path = pjoin(test_data_dir, "directory/parquet/region")

ex_schema = ibis.schema(
[
@@ -70,7 +70,7 @@ def test_query_parquet_file_like_table(con, test_data_dir):


def test_query_parquet_infer_schema(con, test_data_dir):
hdfs_path = pjoin(test_data_dir, "impala/parquet/region")
hdfs_path = pjoin(test_data_dir, "directory/parquet/region")
table = con.parquet_file(hdfs_path, like_table="region")

# NOTE: the actual schema should have an int16, but bc this is being
@@ -88,7 +88,7 @@ def test_query_parquet_infer_schema(con, test_data_dir):


def test_create_table_persist_fails_if_called_twice(con, temp_table, test_data_dir):
hdfs_path = pjoin(test_data_dir, "impala/parquet/region")
hdfs_path = pjoin(test_data_dir, "directory/parquet/region")
con.parquet_file(hdfs_path, like_table="region", name=temp_table)

with pytest.raises(HiveServer2Error):
2 changes: 1 addition & 1 deletion ibis/backends/impala/tests/test_unary_builtins.py
@@ -29,7 +29,7 @@ def table(mockcon):
param(lambda x: x.log2(), id="log2"),
param(lambda x: x.log10(), id="log10"),
param(lambda x: x.nullif(0), id="nullif_zero"),
param(lambda x: x.fillna(0), id="zero_ifnull"),
param(lambda x: x.fill_null(0), id="zero_ifnull"),
],
)
@pytest.mark.parametrize("cname", ["double_col", "int_col"])
36 changes: 24 additions & 12 deletions ibis/backends/mssql/__init__.py
@@ -31,6 +31,7 @@
from collections.abc import Iterable, Mapping

import pandas as pd
import polars as pl
import pyarrow as pa


@@ -365,6 +366,17 @@ def list_tables(
) -> list[str]:
"""List the tables in the database.

::: {.callout-note}
## Ibis does not use the word `schema` to refer to database hierarchy.

A collection of tables is referred to as a `database`.
A collection of `database` is referred to as a `catalog`.

These terms are mapped onto the corresponding features in each
backend (where available), regardless of whether the backend itself
uses the same terminology.
:::

Parameters
----------
like
@@ -375,17 +387,6 @@
To specify a table in a separate catalog, you can pass in the
catalog and database as a string `"catalog.database"`, or as a tuple of
strings `("catalog", "database")`.

::: {.callout-note}
## Ibis does not use the word `schema` to refer to database hierarchy.

A collection of tables is referred to as a `database`.
A collection of `database` is referred to as a `catalog`.

These terms are mapped onto the corresponding features in each
backend (where available), regardless of whether the backend itself
uses the same terminology.
:::
schema
[deprecated] The schema inside `database` to perform the list against.
"""
@@ -435,7 +436,12 @@ def list_databases(
def create_table(
self,
name: str,
obj: pd.DataFrame | pa.Table | ir.Table | None = None,
obj: ir.Table
| pd.DataFrame
| pa.Table
| pl.DataFrame
| pl.LazyFrame
| None = None,
*,
schema: sch.Schema | None = None,
database: str | None = None,
@@ -457,9 +463,11 @@ def create_table(
if temp:
properties.append(sge.TemporaryProperty())

temp_memtable_view = None
if obj is not None:
if not isinstance(obj, ir.Expr):
table = ibis.memtable(obj)
temp_memtable_view = table.op().name
else:
table = obj

@@ -513,6 +521,10 @@ def create_table(
cur.execute(f"EXEC sp_rename '{old}', '{new}'")

if schema is None:
# Clean up temporary memtable if we've created one
# for in-memory reads
if temp_memtable_view is not None:
self.drop_table(temp_memtable_view)
return self.table(name, database=database)

# preserve the input schema if it was provided
114 changes: 51 additions & 63 deletions ibis/backends/mssql/compiler.py
@@ -23,10 +23,8 @@
exclude_unsupported_window_frame_from_row_number,
p,
replace,
rewrite_sample_as_filter,
)
from ibis.common.deferred import var
from ibis.expr.rewrites import rewrite_stringslice

y = var("y")
start = var("start")
@@ -59,68 +57,64 @@ class MSSQLCompiler(SQLGlotCompiler):
dialect = MSSQL
type_mapper = MSSQLType
rewrites = (
rewrite_sample_as_filter,
exclude_unsupported_window_frame_from_ops,
exclude_unsupported_window_frame_from_row_number,
rewrite_rows_range_order_by_window,
rewrite_stringslice,
*SQLGlotCompiler.rewrites,
)
copy_func_args = True

UNSUPPORTED_OPERATIONS = frozenset(
(
ops.ApproxMedian,
ops.ArgMax,
ops.ArgMin,
ops.ArrayCollect,
ops.Array,
ops.ArrayDistinct,
ops.ArrayFlatten,
ops.ArrayMap,
ops.ArraySort,
ops.ArrayUnion,
ops.BitAnd,
ops.BitOr,
ops.BitXor,
ops.Covariance,
ops.CountDistinctStar,
ops.DateAdd,
ops.DateDiff,
ops.DateSub,
ops.EndsWith,
ops.First,
ops.IntervalAdd,
ops.IntervalFromInteger,
ops.IntervalMultiply,
ops.IntervalSubtract,
ops.IsInf,
ops.IsNan,
ops.Last,
ops.LPad,
ops.Levenshtein,
ops.Map,
ops.Median,
ops.Mode,
ops.MultiQuantile,
ops.NthValue,
ops.Quantile,
ops.RegexExtract,
ops.RegexReplace,
ops.RegexSearch,
ops.RegexSplit,
ops.RowID,
ops.RPad,
ops.StartsWith,
ops.StringSplit,
ops.StringToDate,
ops.StringToTimestamp,
ops.StructColumn,
ops.TimestampAdd,
ops.TimestampDiff,
ops.TimestampSub,
ops.Unnest,
)
UNSUPPORTED_OPS = (
ops.ApproxMedian,
ops.ArgMax,
ops.ArgMin,
ops.ArrayCollect,
ops.Array,
ops.ArrayDistinct,
ops.ArrayFlatten,
ops.ArrayMap,
ops.ArraySort,
ops.ArrayUnion,
ops.BitAnd,
ops.BitOr,
ops.BitXor,
ops.Covariance,
ops.CountDistinctStar,
ops.DateAdd,
ops.DateDiff,
ops.DateSub,
ops.EndsWith,
ops.First,
ops.IntervalAdd,
ops.IntervalFromInteger,
ops.IntervalMultiply,
ops.IntervalSubtract,
ops.IsInf,
ops.IsNan,
ops.Last,
ops.LPad,
ops.Levenshtein,
ops.Map,
ops.Median,
ops.Mode,
ops.MultiQuantile,
ops.NthValue,
ops.Quantile,
ops.RegexExtract,
ops.RegexReplace,
ops.RegexSearch,
ops.RegexSplit,
ops.RowID,
ops.RPad,
ops.StartsWith,
ops.StringSplit,
ops.StringToDate,
ops.StringToTimestamp,
ops.StructColumn,
ops.TimestampAdd,
ops.TimestampDiff,
ops.TimestampSub,
ops.Unnest,
)

SIMPLE_OPS = {
@@ -150,12 +144,6 @@ def POS_INF(self):
def NEG_INF(self):
return self.f.double("-Infinity")

def _aggregate(self, funcname: str, *args, where):
func = self.f[funcname]
if where is not None:
args = tuple(self.if_(where, arg, NULL) for arg in args)
return func(*args)

@staticmethod
def _generate_groups(groups):
return groups
2 changes: 1 addition & 1 deletion ibis/backends/mssql/tests/test_client.py
@@ -19,7 +19,7 @@
("NUMERIC(14,3)", dt.Decimal(14, 3)),
("SMALLINT", dt.int16),
("SMALLMONEY", dt.Decimal(10, 4)),
("TINYINT", dt.int8),
("TINYINT", dt.uint8),
# Approximate numerics
("REAL", dt.float32),
("FLOAT", dt.float64),