78 changes: 49 additions & 29 deletions ibis/backends/clickhouse/compiler/values.py
@@ -7,6 +7,7 @@
from typing import Any, Literal, Mapping

import sqlglot as sg
from toolz import flip

import ibis
import ibis.common.exceptions as com
@@ -49,9 +50,11 @@ def _column(op, *, aliases, **_):


@translate_val.register(ops.Alias)
def _alias(op, **kw):
val = translate_val(op.arg, **kw)
return sg.alias(val, op.name, dialect="clickhouse")
def _alias(op, render_aliases: bool = True, **kw):
val = translate_val(op.arg, render_aliases=render_aliases, **kw)
if render_aliases:
return sg.alias(val, op.name, dialect="clickhouse")
return val


_interval_cast_suffixes = {
@@ -75,7 +78,10 @@ def _cast(op, **kw):
return f"toInterval{suffix}({arg})"

to = translate_val(op.to, **kw)
return f"CAST({arg} AS {to})"
result = f"CAST({arg} AS {to})"
if (timezone := getattr(op.to, "timezone", None)) is not None:
return f"toTimeZone({result}, {timezone!r})"
return result


@translate_val.register(ops.Between)
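A minimal sketch of the behaviour the _cast hunk above adds, as a standalone helper rather than the real translator (helper name and inputs are illustrative): a cast to a timezone-aware type is wrapped in toTimeZone so the zone survives the CAST.

def cast_sql(arg, target, timezone=None):
    # mirrors the new branch: plain CAST unless the target type carries a timezone
    result = f"CAST({arg} AS {target})"
    if timezone is not None:
        return f"toTimeZone({result}, {timezone!r})"
    return result

assert cast_sql("ts", "DateTime") == "CAST(ts AS DateTime)"
assert cast_sql("ts", "DateTime", "UTC") == "toTimeZone(CAST(ts AS DateTime), 'UTC')"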
@@ -286,6 +292,13 @@ def _string_find(op, **kw):
return f"locate({arg}, {substr}) - 1"


@translate_val.register(ops.RegexSearch)
def _regex_search(op, **kw):
arg = translate_val(op.arg, **kw)
pattern = translate_val(op.pattern, **kw)
return f"multiMatchAny({arg}, [{pattern}])"


@translate_val.register(ops.RegexExtract)
def _regex_extract(op, **kw):
arg = translate_val(op.arg, **kw)
@@ -306,7 +319,7 @@ def _regex_extract(op, **kw):
# return the Nth match group
# else
# return null
does_match = f"match({arg}, {pattern})"
does_match = f"multiMatchAny({arg}, [{pattern}])"
idx = f"CAST(nullIf({index}, 0) AS Nullable(Int64))"
then = f"if({idx} IS NULL, {arg}, {extracted}[{idx}])"
return f"if({does_match}, {then}, NULL)"
@@ -573,11 +586,11 @@ def _date_from_ymd(op, **kw):
m = translate_val(op.month, **kw)
d = translate_val(op.day, **kw)
return (
f"toDate(concat("
"toDate(concat("
f"toString({y}), '-', "
f"leftPad(toString({m}), 2, '0'), '-', "
f"leftPad(toString({d}), 2, '0')"
f"))"
"))"
)


@@ -589,20 +602,20 @@ def _timestamp_from_ymdhms(op, **kw):
h = translate_val(op.hours, **kw)
min = translate_val(op.minutes, **kw)
s = translate_val(op.seconds, **kw)
timezone_arg = ''
if timezone := op.output_dtype.timezone:
timezone_arg = f', {timezone}'

return (
f"toDateTime("
to_datetime = (
"toDateTime("
f"concat(toString({y}), '-', "
f"leftPad(toString({m}), 2, '0'), '-', "
f"leftPad(toString({d}), 2, '0'), ' ', "
f"leftPad(toString({h}), 2, '0'), ':', "
f"leftPad(toString({min}), 2, '0'), ':', "
f"leftPad(toString({s}), 2, '0')"
f"), {timezone_arg})"
"))"
)
if timezone := op.output_dtype.timezone:
return f"toTimeZone({to_datetime}, {timezone})"
return to_datetime


@translate_val.register(ops.ExistsSubquery)
@@ -1024,7 +1037,6 @@ def formatter(op, **kw):
ops.LStrip: "trimLeft",
ops.RStrip: "trimRight",
ops.Strip: "trimBoth",
ops.RegexSearch: "match",
ops.RegexReplace: "replaceRegexpAll",
ops.StringAscii: "ascii",
# Temporal operations
@@ -1065,6 +1077,12 @@ def formatter(op, **kw):
ops.BitwiseLeftShift: "bitShiftLeft",
ops.BitwiseRightShift: "bitShiftRight",
ops.BitwiseNot: "bitNot",
ops.ArrayDistinct: "arrayDistinct",
ops.ArraySort: "arraySort",
ops.ArrayContains: "has",
ops.FirstValue: "first_value",
ops.LastValue: "last_value",
ops.NTile: "ntile",
}


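The entries just added extend the name-only registry, where an op type maps straight to a ClickHouse function name; a toy version of that generic path (the dict and function below are illustrative stand-ins, not the real formatter):

_simple_ops = {"ArrayDistinct": "arrayDistinct", "FirstValue": "first_value", "NTile": "ntile"}

def render_simple(op_name, *args):
    # ops listed by name need no bespoke translator: render func(arg, ...)
    return f"{_simple_ops[op_name]}({', '.join(args)})"

assert render_simple("NTile", "3") == "ntile(3)"
assert render_simple("FirstValue", "t0.x") == "first_value(t0.x)"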
@@ -1257,11 +1275,6 @@ def formatter(op, **kw):
shift_like(ops.Lead, "leadInFrame")


@translate_val.register(ops.NTile)
def _ntile(op, **kw):
return f'ntile({translate_val(op.buckets, **kw)})'


@translate_val.register(ops.RowNumber)
def _row_number(_, **kw):
return "row_number()"
@@ -1277,16 +1290,6 @@ def _rank(_, **kw):
return "rank()"


@translate_val.register(ops.FirstValue)
def _first_value(op, **kw):
return f"first_value({translate_val(op.arg, **kw)})"


@translate_val.register(ops.LastValue)
def _last_value(op, **kw):
return f"last_value({translate_val(op.arg, **kw)})"


@translate_val.register(ops.ExtractProtocol)
def _extract_protocol(op, **kw):
arg = translate_val(op.arg, **kw)
Expand Down Expand Up @@ -1357,3 +1360,20 @@ def _array_filter(op, **kw):
arg = translate_val(op.arg, **kw)
result = translate_val(op.result, **kw)
return f"arrayFilter(({op.parameter}) -> {result}, {arg})"


@translate_val.register(ops.ArrayPosition)
def _array_position(op, **kw):
arg = translate_val(op.arg, **kw)
el = translate_val(op.other, **kw)
return f"indexOf({arg}, {el}) - 1"


@translate_val.register(ops.ArrayRemove)
def _array_remove(op, **kw):
return translate_val(ops.ArrayFilter(op.arg, flip(ops.NotEquals, op.other)), **kw)


@translate_val.register(ops.ArrayUnion)
def _array_union(op, **kw):
return translate_val(ops.ArrayDistinct(ops.ArrayConcat(op.left, op.right)), **kw)
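A plain-Python sketch of the semantics the three new array helpers target (list operations, not the generated SQL): ArrayPosition is 0-based on top of ClickHouse's 1-based indexOf, ArrayRemove is rewritten to an ArrayFilter with a != predicate, and ArrayUnion is ArrayDistinct over ArrayConcat.

def array_position(xs, el):
    # indexOf(xs, el) - 1: missing elements come back as -1
    return xs.index(el) if el in xs else -1

def array_remove(xs, el):
    return [x for x in xs if x != el]

def array_union(left, right):
    return list(dict.fromkeys(left + right))

assert array_position([10, 20, 30], 30) == 2
assert array_remove([1, 2, 1, 3], 1) == [2, 3]
assert array_union([1, 2], [2, 3]) == [1, 2, 3]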
@@ -1 +1 @@
if(match(CAST(string_col AS String), '[\\d]+'), if(CAST(nullIf(3, 0) AS Nullable(Int64)) IS NULL, CAST(string_col AS String), CAST(extractAll(CAST(string_col AS String), '[\\d]+') AS Array(Nullable(String)))[CAST(nullIf(3, 0) AS Nullable(Int64))]), NULL)
if(multiMatchAny(CAST(string_col AS String), ['[\\d]+']), if(CAST(nullIf(3, 0) AS Nullable(Int64)) IS NULL, CAST(string_col AS String), CAST(extractAll(CAST(string_col AS String), '[\\d]+') AS Array(Nullable(String)))[CAST(nullIf(3, 0) AS Nullable(Int64))]), NULL)
@@ -4,7 +4,7 @@ FROM (
SELECT
t0.string_col,
COUNT(*) AS count
FROM functional_alltypes AS t0
FROM ibis_testing.functional_alltypes AS t0
GROUP BY
1
) AS t1
@@ -1,5 +1,5 @@
SELECT
*
FROM functional_alltypes AS t0
FROM ibis_testing.functional_alltypes AS t0
WHERE
t0.string_col IN ('foo', 'bar')
@@ -1,5 +1,5 @@
SELECT
*
FROM functional_alltypes AS t0
FROM ibis_testing.functional_alltypes AS t0
WHERE
NOT t0.string_col IN ('foo', 'bar')
@@ -1,3 +1,3 @@
SELECT
SUM(CASE WHEN isNull(t0.string_col) THEN 1 ELSE 0 END) AS "Sum(Where(IsNull(string_col), 1, 0))"
FROM functional_alltypes AS t0
FROM ibis_testing.functional_alltypes AS t0
@@ -1,5 +1,5 @@
SELECT
t0.*
FROM functional_alltypes AS t0
INNER JOIN functional_alltypes AS t1
FROM ibis_testing.functional_alltypes AS t0
INNER JOIN ibis_testing.functional_alltypes AS t1
ON t0.id = t1.id
@@ -1,3 +1,3 @@
SELECT
*
FROM functional_alltypes
FROM ibis_testing.functional_alltypes
@@ -1,3 +1,3 @@
SELECT
*
FROM functional_alltypes AS t0
FROM ibis_testing.functional_alltypes AS t0
@@ -1,5 +1,5 @@
SELECT
t0.*
FROM batting AS t0
ANY JOIN awards_players AS t1
FROM ibis_testing.batting AS t0
ANY JOIN ibis_testing.awards_players AS t1
ON t0.playerID = t1.awardID
@@ -1,5 +1,5 @@
SELECT
t0.*
FROM batting AS t0
LEFT ANY JOIN awards_players AS t1
FROM ibis_testing.batting AS t0
LEFT ANY JOIN ibis_testing.awards_players AS t1
ON t0.playerID = t1.awardID
@@ -1,5 +1,5 @@
SELECT
t0.*
FROM batting AS t0
INNER JOIN awards_players AS t1
FROM ibis_testing.batting AS t0
INNER JOIN ibis_testing.awards_players AS t1
ON t0.playerID = t1.awardID
@@ -1,5 +1,5 @@
SELECT
t0.*
FROM batting AS t0
LEFT OUTER JOIN awards_players AS t1
FROM ibis_testing.batting AS t0
LEFT OUTER JOIN ibis_testing.awards_players AS t1
ON t0.playerID = t1.awardID
@@ -1,5 +1,5 @@
SELECT
t0.*
FROM batting AS t0
ANY JOIN awards_players AS t1
FROM ibis_testing.batting AS t0
ANY JOIN ibis_testing.awards_players AS t1
ON t0.playerID = t1.playerID
@@ -1,5 +1,5 @@
SELECT
t0.*
FROM batting AS t0
LEFT ANY JOIN awards_players AS t1
FROM ibis_testing.batting AS t0
LEFT ANY JOIN ibis_testing.awards_players AS t1
ON t0.playerID = t1.playerID
@@ -1,5 +1,5 @@
SELECT
t0.*
FROM batting AS t0
INNER JOIN awards_players AS t1
FROM ibis_testing.batting AS t0
INNER JOIN ibis_testing.awards_players AS t1
ON t0.playerID = t1.playerID
@@ -1,5 +1,5 @@
SELECT
t0.*
FROM batting AS t0
LEFT OUTER JOIN awards_players AS t1
FROM ibis_testing.batting AS t0
LEFT OUTER JOIN ibis_testing.awards_players AS t1
ON t0.playerID = t1.playerID
@@ -1,5 +1,5 @@
SELECT
SUM(t0.float_col) AS "Sum(float_col)"
FROM functional_alltypes AS t0
FROM ibis_testing.functional_alltypes AS t0
WHERE
t0.int_col > 0
@@ -4,7 +4,7 @@ FROM (
SELECT
t0.string_col,
SUM(t0.float_col) AS total
FROM functional_alltypes AS t0
FROM ibis_testing.functional_alltypes AS t0
WHERE
t0.int_col > 0
GROUP BY
@@ -5,4 +5,4 @@ SELECT
toHour(t0.timestamp_col) AS hour,
toMinute(t0.timestamp_col) AS minute,
toSecond(t0.timestamp_col) AS second
FROM functional_alltypes AS t0
FROM ibis_testing.functional_alltypes AS t0
@@ -1,6 +1,6 @@
SELECT
*
FROM functional_alltypes AS t0
FROM ibis_testing.functional_alltypes AS t0
WHERE
t0.float_col > 0 AND t0.int_col < (
t0.float_col * 2
@@ -1,5 +1,5 @@
SELECT
*
FROM functional_alltypes AS t0
FROM ibis_testing.functional_alltypes AS t0
WHERE
t0.int_col > 0 AND t0.float_col BETWEEN 0 AND 1
@@ -1,6 +1,6 @@
SELECT
t0.uuid,
minIf(t0.ts, t0.search_level = 1) AS min_date
minIf(t0.ts, search_level = 1) AS min_date
FROM t AS t0
GROUP BY
1
83 changes: 83 additions & 0 deletions ibis/backends/clickhouse/tests/test_client.py
@@ -1,6 +1,8 @@
import pandas as pd
import pandas.testing as tm
import pytest
from clickhouse_driver.dbapi import OperationalError
from pytest import param

import ibis
import ibis.expr.datatypes as dt
@@ -13,6 +15,8 @@
CLICKHOUSE_USER,
IBIS_TEST_CLICKHOUSE_DB,
)
from ibis.common.exceptions import IbisError
from ibis.util import gen_name

pytest.importorskip("clickhouse_driver")

@@ -204,3 +208,82 @@ def test_list_tables_empty_database(con, worker_id):
assert not con.list_tables(database=dbname)
finally:
con.raw_sql(f"DROP DATABASE IF EXISTS {dbname}")


@pytest.mark.parametrize(
"temp",
[
param(
True,
marks=pytest.mark.xfail(
reason="Ibis is likely making incorrect assumptions about object lifetime and cursors",
raises=IbisError,
),
),
False,
],
ids=["temp", "no_temp"],
)
def test_create_table_no_data(con, temp):
name = gen_name("clickhouse_create_table_no_data")
schema = ibis.schema(dict(a="!int", b="string"))
t = con.create_table(
name, schema=schema, temp=temp, engine="Memory", database="tmptables"
)
try:
assert t.execute().empty
finally:
con.drop_table(name, force=True, database="tmptables")
assert name not in con.list_tables(database="tmptables")


@pytest.mark.parametrize(
"data",
[
{"a": [1, 2, 3], "b": [None, "b", "c"]},
pd.DataFrame({"a": [1, 2, 3], "b": [None, "b", "c"]}),
],
ids=["dict", "dataframe"],
)
@pytest.mark.parametrize(
"engine",
["File(Native)", "File(Parquet)", "Memory"],
ids=["native", "mem", "parquet"],
)
def test_create_table_data(con, data, engine):
name = gen_name("clickhouse_create_table_data")
schema = ibis.schema(dict(a="!int", b="string"))
t = con.create_table(
name, obj=data, schema=schema, engine=engine, database="tmptables"
)
try:
assert len(t.execute()) == 3
finally:
con.drop_table(name, force=True, database="tmptables")
assert name not in con.list_tables(database="tmptables")


@pytest.mark.parametrize(
"engine",
[
"File(Native)",
param(
"File(Parquet)",
marks=pytest.mark.xfail(
reason="Parquet file size is 0 bytes", raises=OperationalError
),
),
"Memory",
],
ids=["native", "mem", "parquet"],
)
def test_truncate_table(con, engine):
name = gen_name("clickhouse_create_table_data")
t = con.create_table(name, obj={"a": [1]}, engine=engine, database="tmptables")
try:
assert len(t.execute()) == 1
con.truncate_table(name, database="tmptables")
assert len(t.execute()) == 0
finally:
con.drop_table(name, force=True, database="tmptables")
assert name not in con.list_tables(database="tmptables")
2 changes: 1 addition & 1 deletion ibis/backends/clickhouse/tests/test_operators.py
@@ -186,7 +186,7 @@ def test_negate(con, alltypes, translate, column, operator):
)
def test_negate_non_boolean(alltypes, field, df):
t = alltypes.limit(10)
expr = t.projection([(-t[field]).name(field)])
expr = t.select((-t[field]).name(field))
result = expr.execute()[field]
expected = -df.head(10)[field]
tm.assert_series_equal(result, expected)
2 changes: 1 addition & 1 deletion ibis/backends/clickhouse/tests/test_select.py
@@ -283,7 +283,7 @@ def test_filter_predicates(diamonds):

expr = diamonds
for pred in predicates:
expr = expr[pred(expr)].projection([expr])
expr = expr[pred(expr)].select(expr)

expr.execute()

114 changes: 74 additions & 40 deletions ibis/backends/conftest.py
@@ -4,13 +4,15 @@
import importlib
import importlib.metadata
import itertools
import os
import platform
import sys
from functools import lru_cache
from pathlib import Path
from typing import Any, TextIO

import _pytest
import numpy as np
import pandas as pd
import pytest
import sqlalchemy as sa
@@ -21,6 +23,14 @@
from ibis import util
from ibis.backends.base import _get_backend_names

SANDBOXED = (
any(key.startswith("NIX_") for key in os.environ)
and os.environ.get("IN_NIX_SHELL") != "impure"
)
LINUX = platform.system() == "Linux"
MACOS = platform.system() == "Darwin"
WINDOWS = platform.system() == "Windows"
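The SANDBOXED flag added above reduces to a small predicate over the environment; a standalone sketch of that check (function name is illustrative):

import os

def is_sandboxed(environ=None):
    # any NIX_* variable present, and not an explicitly impure nix shell
    environ = os.environ if environ is None else environ
    return (
        any(key.startswith("NIX_") for key in environ)
        and environ.get("IN_NIX_SHELL") != "impure"
    )

assert is_sandboxed({"NIX_STORE": "/nix/store"}) is True
assert is_sandboxed({"NIX_STORE": "/nix/store", "IN_NIX_SHELL": "impure"}) is False
assert is_sandboxed({"PATH": "/usr/bin"}) is False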

TEST_TABLES = {
"functional_alltypes": ibis.schema(
{
Expand Down Expand Up @@ -101,13 +111,16 @@
# by improving all tests file by file. All files that have already been improved are
# added to this list to prevent regression.
FIlES_WITH_STRICT_EXCEPTION_CHECK = [
'ibis/backends/tests/test_api.py',
'ibis/backends/tests/test_array.py',
'ibis/backends/tests/test_aggregation.py',
'ibis/backends/tests/test_binary.py',
'ibis/backends/tests/test_numeric.py',
'ibis/backends/tests/test_column.py',
'ibis/backends/tests/test_string.py',
'ibis/backends/tests/test_temporal.py',
'ibis/backends/tests/test_uuid.py',
'ibis/backends/tests/test_window.py',
]


@@ -273,6 +286,11 @@ def pytest_collection_modifyitems(session, config, items):
all_backends = _get_backend_names()
additional_markers = []

try:
import pyspark
except ImportError:
pyspark = None

for item in items:
parts = item.path.parts
backend = _get_backend_from_parts(parts)
@@ -283,34 +301,40 @@
itertools.chain(
*(item.iter_markers(name=name) for name in all_backends),
item.iter_markers(name="backend"),
item.iter_markers(name="backend_nodata"),
)
):
# anything else is a "core" test and is run by default
if not any(item.iter_markers(name="benchmark")):
item.add_marker(pytest.mark.core)

for name in ("duckdb", "sqlite"):
# build a list of markers so we don't invalidate the item's
# marker iterator
for _ in item.iter_markers(name=name):
additional_markers.append((item, pytest.mark.xdist_group(name=name)))

for _ in item.iter_markers(name="pyspark"):
additional_markers.append(
(
item,
pytest.mark.xfail(
(
sys.version_info >= (3, 11)
and not isinstance(item, pytest.DoctestItem)
),
reason="PySpark doesn't support Python 3.11",
),
if not isinstance(item, pytest.DoctestItem):
additional_markers.append(
(
item,
[
pytest.mark.xfail(
sys.version_info >= (3, 11),
reason="PySpark doesn't support Python 3.11",
),
pytest.mark.xfail(
vparse(pd.__version__) >= vparse("2"),
reason="PySpark doesn't support pandas>=2",
),
pytest.mark.skipif(
pyspark is not None
and vparse(pyspark.__version__) < vparse("3.3.3")
and vparse(np.__version__) >= vparse("1.24"),
reason="PySpark doesn't support numpy >= 1.24",
),
],
)
)
)

for item, marker in additional_markers:
item.add_marker(marker)
for item, markers in additional_markers:
for marker in markers:
item.add_marker(marker)


@lru_cache(maxsize=None)
Expand Down Expand Up @@ -342,7 +366,7 @@ def pytest_runtest_call(item):
backend = [
backend.name()
for key, backend in item.funcargs.items()
if key.endswith("backend")
if key.endswith(("backend", "backend_nodata"))
]
if len(backend) > 1:
raise ValueError(
@@ -367,7 +391,11 @@ def pytest_runtest_call(item):
funcargs = item.funcargs
con = funcargs.get(
"con",
getattr(funcargs.get("backend"), "connection", None),
getattr(
funcargs.get("backend", funcargs.get("backend_nodata")),
"connection",
None,
),
)

if con is None:
@@ -500,6 +528,20 @@ def con(backend):
return backend.connection


@pytest.fixture(params=_get_backends_to_test(), scope='session')
def backend_nodata(request, data_directory):
"""Return an instance of BackendTest, loaded with data."""

cls = _get_backend_conf(request.param)
return cls(data_directory)


@pytest.fixture(scope="session")
def con_nodata(backend_nodata):
"""Instance of a backend client."""
return backend_nodata.connection


def _setup_backend(
request, data_directory, script_directory, tmp_path_factory, worker_id
):
@@ -661,11 +703,9 @@ def alchemy_temp_table(alchemy_con) -> str:
Random table name for a temporary usage.
"""
name = _random_identifier('table')
try:
yield name
finally:
with contextlib.suppress(NotImplementedError):
alchemy_con.drop_table(name, force=True)
yield name
with contextlib.suppress(NotImplementedError):
alchemy_con.drop_table(name, force=True)


@pytest.fixture
@@ -682,11 +722,9 @@ def temp_table(con) -> str:
Random table name for a temporary usage.
"""
name = _random_identifier('table')
try:
yield name
finally:
with contextlib.suppress(NotImplementedError):
con.drop_table(name, force=True)
yield name
with contextlib.suppress(NotImplementedError):
con.drop_table(name, force=True)


@pytest.fixture
@@ -703,11 +741,9 @@ def temp_view(ddl_con) -> str:
Random view name for a temporary usage.
"""
name = _random_identifier('view')
try:
yield name
finally:
with contextlib.suppress(NotImplementedError):
ddl_con.drop_view(name, force=True)
yield name
with contextlib.suppress(NotImplementedError):
ddl_con.drop_view(name, force=True)


@pytest.fixture(scope='session')
@@ -734,10 +770,8 @@ def alternate_current_database(ddl_con, ddl_backend) -> str:
ddl_con.create_database(name)
except NotImplementedError:
pytest.skip(f"{ddl_backend.name()} doesn't have create_database method.")
try:
yield name
finally:
ddl_con.drop_database(name, force=True)
yield name
ddl_con.drop_database(name, force=True)


@pytest.fixture
5 changes: 1 addition & 4 deletions ibis/backends/dask/__init__.py
@@ -19,7 +19,7 @@
from ibis.backends.pandas.core import _apply_schema

# Make sure that the pandas backend options have been loaded
ibis.pandas
ibis.pandas # noqa: B018


class Backend(BasePandasBackend):
@@ -129,6 +129,3 @@ def _convert_object(cls, obj: dd.DataFrame) -> dd.DataFrame:

def _load_into_cache(self, name, expr):
self.create_table(name, self.compile(expr).persist())

def _clean_up_cached_table(self, op):
del self.dictionary[op.name]
33 changes: 23 additions & 10 deletions ibis/backends/dask/client.py
@@ -4,18 +4,15 @@

import dask.dataframe as dd
import numpy as np
import pandas as pd
from dateutil.parser import parse as date_parse
from pandas.api.types import DatetimeTZDtype

import ibis.expr.datatypes as dt
import ibis.expr.operations as ops
import ibis.expr.schema as sch
from ibis.backends.base import Database
from ibis.backends.pandas.client import (
PANDAS_DATE_TYPES,
PANDAS_STRING_TYPES,
ibis_dtype_to_pandas,
ibis_schema_to_pandas,
)
from ibis.backends.pandas.client import ibis_dtype_to_pandas, ibis_schema_to_pandas


@sch.schema.register(dd.Series)
@@ -54,15 +51,31 @@ def infer_dask_schema(df, schema=None):


@sch.convert.register(DatetimeTZDtype, dt.Timestamp, dd.Series)
def convert_datetimetz_to_timestamp(in_dtype, out_dtype, column):
def convert_datetimetz_to_timestamp(_, out_dtype, column):
output_timezone = out_dtype.timezone
if output_timezone is not None:
return column.dt.tz_convert(output_timezone)
return column.astype(out_dtype.to_dask())
else:
return column.dt.tz_localize(None)


DASK_STRING_TYPES = PANDAS_STRING_TYPES
DASK_DATE_TYPES = PANDAS_DATE_TYPES
@sch.convert.register(np.dtype, dt.Timestamp, dd.Series)
def convert_any_to_timestamp(_, out_dtype, column):
if isinstance(dtype := out_dtype.to_dask(), DatetimeTZDtype):
column = dd.to_datetime(column)
timezone = out_dtype.timezone
if getattr(column.dtype, "tz", None) is not None:
return column.dt.tz_convert(timezone)
else:
return column.dt.tz_localize(timezone)
else:
try:
return column.astype(dtype)
except pd.errors.OutOfBoundsDatetime:
try:
return column.map(date_parse)
except TypeError:
return column


@sch.convert.register(np.dtype, dt.Interval, dd.Series)
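The new convert_any_to_timestamp above follows the usual pandas rule in its timezone branch: tz-aware data is converted, naive data is localized. A pandas-only sketch of that decision (no dask involved, helper name is illustrative):

import pandas as pd

def coerce_tz(column, timezone):
    # convert if the series is already tz-aware, otherwise localize
    column = pd.to_datetime(column)
    if getattr(column.dtype, "tz", None) is not None:
        return column.dt.tz_convert(timezone)
    return column.dt.tz_localize(timezone)

s = pd.Series(["2023-01-01 00:00:00"])
assert str(coerce_tz(s, "UTC").dtype) == "datetime64[ns, UTC]"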
23 changes: 13 additions & 10 deletions ibis/backends/dask/execution/generic.py
@@ -271,7 +271,7 @@ def execute_cast_scalar_timestamp(op, data, type, **kwargs):

def cast_series_to_timestamp(data, tz):
if pd.api.types.is_string_dtype(data):
timestamps = to_datetime(data, infer_datetime_format=True)
timestamps = to_datetime(data)
else:
timestamps = to_datetime(data, unit="s")
if getattr(timestamps.dtype, "tz", None) is not None:
@@ -290,10 +290,17 @@ def execute_cast_series_timestamp(op, data, type, **kwargs):
tz = type.timezone
dtype = 'M8[ns]' if tz is None else DatetimeTZDtype('ns', tz)

if from_type.is_timestamp() or from_type.is_date():
return data.astype(dtype)

if from_type.is_string() or from_type.is_integer():
if from_type.is_timestamp():
from_tz = from_type.timezone
if tz is None and from_tz is None:
return data
elif tz is None or from_tz is None:
return data.dt.tz_localize(tz)
elif tz is not None and from_tz is not None:
return data.dt.tz_convert(tz)
elif from_type.is_date():
return data if tz is None else data.dt.tz_localize(tz)
elif from_type.is_string() or from_type.is_integer():
return data.map_partitions(
cast_series_to_timestamp,
tz,
@@ -319,11 +326,7 @@ def execute_cast_series_date(op, data, type, **kwargs):

if from_type.equals(dt.string):
# TODO - this is broken
datetimes = data.map_partitions(
to_datetime,
infer_datetime_format=True,
meta=(data.name, 'datetime64[ns]'),
)
datetimes = data.map_partitions(to_datetime, meta=(data.name, 'datetime64[ns]'))

# TODO - we are getting rid of the index here
return datetimes.dt.normalize()
11 changes: 7 additions & 4 deletions ibis/backends/dask/execution/temporal.py
@@ -152,10 +152,13 @@
(dd.Series, str, datetime.time),
)
def execute_between_time(op, data, lower, upper, **kwargs):
# TODO - Can this be done better?
indexer = (
(data.dt.time.astype(str) >= lower) & (data.dt.time.astype(str) <= upper)
).to_dask_array(True)
if getattr(data.dtype, "tz", None) is not None:
localized = data.dt.tz_convert("UTC").dt.tz_localize(None)
else:
localized = data

time = localized.dt.time.astype(str)
indexer = ((time >= lower) & (time <= upper)).to_dask_array(True)

result = da.zeros(len(data), dtype=np.bool_)
result[indexer] = True
18 changes: 10 additions & 8 deletions ibis/backends/dask/tests/execution/test_arrays.py
@@ -3,6 +3,7 @@
import numpy as np
import pandas as pd
import pytest
from pytest import param

import ibis

@@ -11,12 +12,10 @@


def test_array_length(t):
expr = t.projection(
[
t.array_of_float64.length().name('array_of_float64_length'),
t.array_of_int64.length().name('array_of_int64_length'),
t.array_of_strings.length().name('array_of_strings_length'),
]
expr = t.select(
t.array_of_float64.length().name('array_of_float64_length'),
t.array_of_int64.length().name('array_of_int64_length'),
t.array_of_strings.length().name('array_of_strings_length'),
)
result = expr.compile()
expected = dd.from_pandas(
@@ -134,7 +133,10 @@ def test_array_slice_scalar(client, start, stop):
assert np.array_equal(result, expected)


@pytest.mark.parametrize('index', [1, 3, 4, 11, -11])
@pytest.mark.parametrize(
'index',
[param(1, marks=pytest.mark.xfail_version(dask=["pandas>=2"])), 3, 4, 11, -11],
)
def test_array_index(t, df, index):
expr = t[t.array_of_float64[index].name('indexed')]
result = expr.compile()
@@ -169,7 +171,7 @@ def test_array_index_scalar(client, index):
@pytest.mark.parametrize('n', [1, 3, 4, 7, -2]) # negative returns empty list
@pytest.mark.parametrize('mul', [lambda x, n: x * n, lambda x, n: n * x])
def test_array_repeat(t, df, n, mul):
expr = t.projection([mul(t.array_of_strings, n).name('repeated')])
expr = t.select(repeated=mul(t.array_of_strings, n))
result = expr.execute()
expected = pd.DataFrame({'repeated': df.array_of_strings * n})
tm.assert_frame_equal(result, expected)
10 changes: 6 additions & 4 deletions ibis/backends/dask/tests/execution/test_operations.py
@@ -548,10 +548,10 @@ def test_value_counts(t, df):
expected = (
df.compute()
.dup_strings.value_counts()
.reset_index()
.rename(columns={'dup_strings': 'dup_strings_count'})
.rename(columns={'index': 'dup_strings'})
.sort_values(['dup_strings'])
.rename("dup_strings")
.reset_index(name="dup_strings_count")
.rename(columns={"index": "dup_strings"})
.sort_values(["dup_strings"])
.reset_index(drop=True)
)
tm.assert_frame_equal(
@@ -861,6 +861,7 @@ def test_summary_numeric(batting, batting_df):
assert dict(result.iloc[0]) == expected


@pytest.mark.xfail_version(dask=["pandas>=2"])
def test_summary_numeric_group_by(batting, batting_df):
with pytest.warns(FutureWarning, match="is deprecated"):
expr = batting.group_by('teamID').G.summary()
@@ -900,6 +901,7 @@ def test_summary_non_numeric(batting, batting_df):
assert dict(result.iloc[0]) == expected


@pytest.mark.xfail_version(dask=["pandas>=2"])
def test_summary_non_numeric_group_by(batting, batting_df):
with pytest.warns(FutureWarning, match="is deprecated"):
expr = batting.group_by('teamID').playerID.summary()
23 changes: 12 additions & 11 deletions ibis/backends/dask/tests/execution/test_temporal.py
@@ -68,10 +68,7 @@ def test_cast_datetime_strings_to_date(t, df, column):
result = expr.compile()
df_computed = df.compute()
expected = dd.from_pandas(
pd.to_datetime(
df_computed[column],
infer_datetime_format=True,
).dt.normalize(),
pd.to_datetime(df_computed[column]).dt.normalize(),
npartitions=1,
)
tm.assert_series_equal(
@@ -88,10 +85,7 @@ def test_cast_datetime_strings_to_timestamp(t, df, column):
expr = t[column].cast('timestamp')
result = expr.compile()
df_computed = df.compute()
expected = dd.from_pandas(
pd.to_datetime(df_computed[column], infer_datetime_format=True),
npartitions=1,
)
expected = dd.from_pandas(pd.to_datetime(df_computed[column]), npartitions=1)
if getattr(expected.dtype, 'tz', None) is not None:
expected = expected.dt.tz_convert(None)
tm.assert_series_equal(
@@ -158,10 +152,17 @@ def test_times_ops(t, df):


@pytest.mark.parametrize(
('tz', 'rconstruct'),
[('US/Eastern', np.zeros), ('UTC', np.ones), (None, np.ones)],
('tz', 'rconstruct', 'column'),
[
('US/Eastern', np.ones, 'plain_datetimes_utc'),
('US/Eastern', np.zeros, 'plain_datetimes_naive'),
('UTC', np.ones, 'plain_datetimes_utc'),
('UTC', np.ones, 'plain_datetimes_naive'),
(None, np.ones, 'plain_datetimes_utc'),
(None, np.ones, 'plain_datetimes_naive'),
],
ids=lambda x: str(getattr(x, "__name__", x)).lower().replace("/", "_"),
)
@pytest.mark.parametrize('column', ['plain_datetimes_utc', 'plain_datetimes_naive'])
def test_times_ops_with_tz(t, df, tz, rconstruct, column):
expected = dd.from_array(
rconstruct(len(df), dtype=bool),
2 changes: 1 addition & 1 deletion ibis/backends/dask/tests/test_schema.py
@@ -157,7 +157,7 @@ def test_infer_exhaustive_dataframe(npartitions):
def test_apply_to_schema_with_timezone(npartitions):
data = {'time': pd.date_range('2018-01-01', '2018-01-02', freq='H')}
df = dd.from_pandas(pd.DataFrame(data), npartitions=npartitions)
expected = df.assign(time=df.time.astype('datetime64[ns, EST]'))
expected = df.assign(time=df.time.dt.tz_localize("EST"))
desired_schema = ibis.schema([('time', 'timestamp("EST")')])
result = desired_schema.apply_to(df.copy())
tm.assert_frame_equal(result.compute(), expected.compute())
4 changes: 2 additions & 2 deletions ibis/backends/dask/trace.py
@@ -75,7 +75,7 @@ def enable():
"""Enable tracing."""
if options.dask is None:
# dask options haven't been registered yet - force module __getattr__
ibis.dask
ibis.dask # noqa: B018

options.dask.enable_trace = True
logging.getLogger('ibis.dask.trace').setLevel(logging.DEBUG)
@@ -119,7 +119,7 @@ def traced_func(*args, **kwargs):
# Similar to the pandas backend, it is possible to call this function
# without having initialized the configuration option. This can happen
# when tests are distributed across multiple processes, for example.
ibis.dask
ibis.dask # noqa: B018

if not options.dask.enable_trace:
return func(*args, **kwargs)
12 changes: 3 additions & 9 deletions ibis/backends/datafusion/__init__.py
@@ -1,6 +1,5 @@
from __future__ import annotations

import itertools
import re
from functools import lru_cache
from pathlib import Path
@@ -15,7 +14,7 @@
import ibis.expr.types as ir
from ibis.backends.base import BaseBackend
from ibis.backends.datafusion.compiler import translate
from ibis.util import normalize_filename
from ibis.util import gen_name, normalize_filename

try:
from datafusion import ExecutionContext as SessionContext
@@ -24,11 +23,6 @@

import datafusion

# counters for in-memory, parquet, and csv reads
# used if no table name is specified
pa_n = itertools.count(0)
csv_n = itertools.count(0)


class Backend(BaseBackend):
name = 'datafusion'
@@ -169,7 +163,7 @@ def read_csv(
The just-registered table
"""
path = normalize_filename(path)
table_name = table_name or f"ibis_read_csv_{next(csv_n)}"
table_name = table_name or gen_name("read_csv")
# Our other backends support overwriting views / tables when reregistering
self._context.deregister_table(table_name)
self._context.register_csv(table_name, path, **kwargs)
@@ -196,7 +190,7 @@ def read_parquet(
The just-registered table
"""
path = normalize_filename(path)
table_name = table_name or f"ibis_read_parquet_{next(pa_n)}"
table_name = table_name or gen_name("read_parquet")
# Our other backends support overwriting views / tables when reregistering
self._context.deregister_table(table_name)
self._context.register_parquet(table_name, path, **kwargs)
1 change: 1 addition & 0 deletions ibis/backends/druid/registry.py
@@ -43,6 +43,7 @@ def _join(t, op):
ops.Log10: fixed_arity(sa.func.log10, 1),
ops.Sign: _sign,
ops.StringJoin: _join,
ops.RegexSearch: fixed_arity(sa.func.regexp_like, 2),
}
)

239 changes: 146 additions & 93 deletions ibis/backends/duckdb/__init__.py

Large diffs are not rendered by default.

29 changes: 1 addition & 28 deletions ibis/backends/duckdb/compiler.py
@@ -3,13 +3,8 @@
from sqlalchemy.ext.compiler import compiles

import ibis.backends.base.sql.alchemy.datatypes as sat
import ibis.expr.datatypes as dt
import ibis.expr.operations as ops
from ibis.backends.base.sql.alchemy import (
AlchemyCompiler,
AlchemyExprTranslator,
to_sqla_type,
)
from ibis.backends.base.sql.alchemy import AlchemyCompiler, AlchemyExprTranslator
from ibis.backends.duckdb.registry import operation_registry


@@ -45,28 +40,6 @@ def compile_array(element, compiler, **kw):
return f"{compiler.process(element.value_type, **kw)}[]"


try:
import duckdb_engine
except ImportError:
pass
else:

@dt.dtype.register(duckdb_engine.Dialect, sat.UInt64)
@dt.dtype.register(duckdb_engine.Dialect, sat.UInt32)
@dt.dtype.register(duckdb_engine.Dialect, sat.UInt16)
@dt.dtype.register(duckdb_engine.Dialect, sat.UInt8)
def dtype_uint(_, satype, nullable=True):
return getattr(dt, satype.__class__.__name__)(nullable=nullable)

@dt.dtype.register(duckdb_engine.Dialect, sat.ArrayType)
def _(dialect, satype, nullable=True):
return dt.Array(dt.dtype(dialect, satype.value_type), nullable=nullable)

@to_sqla_type.register(duckdb_engine.Dialect, dt.Array)
def _(dialect, itype):
return sat.ArrayType(to_sqla_type(dialect, itype.value_type))


rewrites = DuckDBSQLExprTranslator.rewrites


49 changes: 42 additions & 7 deletions ibis/backends/duckdb/datatypes.py
@@ -3,9 +3,9 @@
import parsy
import sqlalchemy as sa
import toolz
from duckdb_engine import Dialect as DuckDBDialect
from sqlalchemy.dialects import postgresql

import ibis.backends.base.sql.alchemy.datatypes as sat
import ibis.expr.datatypes as dt
from ibis.backends.base.sql.alchemy import to_sqla_type
from ibis.common.parsing import (
@@ -92,11 +92,46 @@ def parse(text: str, default_decimal_parameters=(18, 3)) -> dt.DataType:
return ty.parse(text)


@to_sqla_type.register(DuckDBDialect, dt.UUID)
def sa_duckdb_uuid(*_):
return postgresql.UUID(as_uuid=True)
try:
from duckdb_engine import Dialect as DuckDBDialect
except ImportError:
pass
else:

@dt.dtype.register(DuckDBDialect, sat.UInt64)
@dt.dtype.register(DuckDBDialect, sat.UInt32)
@dt.dtype.register(DuckDBDialect, sat.UInt16)
@dt.dtype.register(DuckDBDialect, sat.UInt8)
def dtype_uint(_, satype, nullable=True):
return getattr(dt, satype.__class__.__name__)(nullable=nullable)

@to_sqla_type.register(DuckDBDialect, (dt.MACADDR, dt.INET))
def sa_duckdb_macaddr(*_):
return sa.TEXT()
@dt.dtype.register(DuckDBDialect, sat.ArrayType)
def _(dialect, satype, nullable=True):
return dt.Array(dt.dtype(dialect, satype.value_type), nullable=nullable)

@dt.dtype.register(DuckDBDialect, sat.MapType)
def _(dialect, satype, nullable=True):
return dt.Map(
dt.dtype(dialect, satype.key_type),
dt.dtype(dialect, satype.value_type),
nullable=nullable,
)

@to_sqla_type.register(DuckDBDialect, dt.UUID)
def sa_duckdb_uuid(*_):
return postgresql.UUID()

@to_sqla_type.register(DuckDBDialect, (dt.MACADDR, dt.INET))
def sa_duckdb_macaddr(*_):
return sa.TEXT()

@to_sqla_type.register(DuckDBDialect, dt.Map)
def sa_duckdb_map(dialect, itype):
return sat.MapType(
to_sqla_type(dialect, itype.key_type),
to_sqla_type(dialect, itype.value_type),
)

@to_sqla_type.register(DuckDBDialect, dt.Array)
def _(dialect, itype):
return sat.ArrayType(to_sqla_type(dialect, itype.value_type))
32 changes: 0 additions & 32 deletions ibis/backends/duckdb/pyarrow.py

This file was deleted.

92 changes: 71 additions & 21 deletions ibis/backends/duckdb/registry.py
@@ -8,6 +8,7 @@
import sqlalchemy as sa
from sqlalchemy.ext.compiler import compiles
from sqlalchemy.sql.functions import GenericFunction
from toolz.curried import flip

import ibis.expr.operations as ops
from ibis.backends.base.sql import alchemy
@@ -53,23 +54,24 @@ def _round(t, op):
}


def _generic_log(arg, base):
return sa.func.ln(arg) / sa.func.ln(base)
def _generic_log(arg, base, *, type_):
return sa.func.ln(arg, type_=type_) / sa.func.ln(base, type_=type_)


def _log(t, op):
arg, base = op.args
sqla_type = t.get_sqla_type(op.output_dtype)
sa_arg = t.translate(arg)
if base is not None:
sa_base = t.translate(base)
try:
base_value = sa_base.value
except AttributeError:
return _generic_log(sa_arg, sa_base)
return _generic_log(sa_arg, sa_base, type_=sqla_type)
else:
func = _LOG_BASE_FUNCS.get(base_value, _generic_log)
return func(sa_arg)
return sa.func.ln(sa_arg)
return func(sa_arg, type_=sqla_type)
return sa.func.ln(sa_arg, type_=sqla_type)


def _timestamp_from_unix(t, op):
@@ -135,17 +137,17 @@ def _literal(t, op):
elif dtype.is_string():
return sa.literal(value)
elif dtype.is_map():
raise NotImplementedError(
f"Ibis dtype `{dtype}` with mapping type "
f"`{type(value).__name__}` isn't yet supported with the duckdb "
"backend"
return sa.func.map(
sa.func.list_value(*value.keys()), sa.func.list_value(*value.values())
)
else:
return sa.cast(sa.literal(value), sqla_type)


if_ = getattr(sa.func, "if")


def _neg_idx_to_pos(array, idx):
if_ = getattr(sa.func, "if")
arg_length = sa.func.array_length(array)
return if_(idx < 0, arg_length + sa.func.greatest(idx, -arg_length), idx)

@@ -258,6 +260,32 @@ def _array_filter(t, op):
)


def _map_keys(t, op):
m = t.translate(op.arg)
return sa.cast(
sa.func.json_keys(sa.func.to_json(m)), t.get_sqla_type(op.output_dtype)
)


def _map_values(t, op):
m_json = sa.func.to_json(t.translate(op.arg))
return sa.cast(
sa.func.json_extract_string(m_json, sa.func.json_keys(m_json)),
t.get_sqla_type(op.output_dtype),
)


def _map_merge(t, op):
left = sa.func.to_json(t.translate(op.left))
right = sa.func.to_json(t.translate(op.right))
pairs = sa.func.json_merge_patch(left, right)
keys = sa.func.json_keys(pairs)
return sa.cast(
sa.func.map(keys, sa.func.json_extract_string(pairs, keys)),
t.get_sqla_type(op.output_dtype),
)


operation_registry.update(
{
ops.ArrayColumn: (
@@ -284,13 +312,26 @@ def _array_filter(t, op):
ops.ArrayIndex: _array_index(
index_converter=_neg_idx_to_pos, func=sa.func.list_extract
),
ops.ArrayMap: _array_map,
ops.ArrayFilter: _array_filter,
ops.ArrayContains: fixed_arity(sa.func.list_has, 2),
ops.ArrayPosition: fixed_arity(
lambda lst, el: sa.func.list_indexof(lst, el) - 1, 2
),
ops.ArrayDistinct: fixed_arity(sa.func.list_distinct, 1),
ops.ArraySort: fixed_arity(sa.func.list_sort, 1),
ops.ArrayRemove: lambda t, op: _array_filter(
t, ops.ArrayFilter(op.arg, flip(ops.NotEquals, op.other))
),
ops.ArrayUnion: fixed_arity(
lambda left, right: sa.func.list_distinct(sa.func.list_cat(left, right)), 2
),
ops.DayOfWeekName: unary(sa.func.dayname),
ops.Literal: _literal,
ops.Log2: unary(sa.func.log2),
ops.Ln: unary(sa.func.ln),
ops.Log: _log,
ops.IsNan: unary(sa.func.isnan),
# TODO: map operations, but DuckDB's maps are multimaps
ops.Modulus: fixed_arity(operator.mod, 2),
ops.Round: _round,
ops.StructField: (
@@ -312,6 +353,7 @@ def _array_filter(t, op):
ops.RegexReplace: fixed_arity(
lambda *args: sa.func.regexp_replace(*args, sa.text("'g'")), 3
),
ops.RegexSearch: fixed_arity(lambda x, y: x.op("SIMILAR TO")(y), 2),
ops.StringContains: fixed_arity(sa.func.contains, 2),
ops.ApproxMedian: reduction(
# without inline text, duckdb fails with
@@ -343,9 +385,22 @@ def _array_filter(t, op):
ops.SimpleCase: _simple_case,
ops.StartsWith: fixed_arity(sa.func.prefix, 2),
ops.EndsWith: fixed_arity(sa.func.suffix, 2),
ops.ArrayMap: _array_map,
ops.ArrayFilter: _array_filter,
ops.Argument: lambda _, op: sa.literal_column(op.name),
ops.Unnest: unary(sa.func.unnest),
ops.MapGet: fixed_arity(
lambda arg, key, default: sa.func.coalesce(
sa.func.list_extract(sa.func.element_at(arg, key), 1), default
),
3,
),
ops.Map: fixed_arity(sa.func.map, 2),
ops.MapContains: fixed_arity(
lambda arg, key: sa.func.array_length(sa.func.element_at(arg, key)) != 0, 2
),
ops.MapLength: unary(sa.func.cardinality),
ops.MapKeys: _map_keys,
ops.MapValues: _map_values,
ops.MapMerge: _map_merge,
}
)

@@ -358,14 +413,9 @@ def _array_filter(t, op):
ops.NTile,
# ibis.expr.operations.strings
ops.Translate,
# ibis.expr.operations.maps
ops.MapGet,
ops.MapContains,
ops.MapKeys,
ops.MapValues,
ops.MapMerge,
ops.MapLength,
ops.Map,
# ibis.expr.operations.json
ops.ToJSONMap,
ops.ToJSONArray,
}

operation_registry = {
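A plain-Python sketch of the map semantics the new DuckDB translations aim for (dict operations standing in for the generated SQL; the json_merge_patch-style merge lets right-hand keys win):

def map_get(mapping, key, default=None):
    # coalesce(list_extract(element_at(m, k), 1), default)
    return mapping.get(key, default)

def map_contains(mapping, key):
    # array_length(element_at(m, k)) != 0
    return key in mapping

def map_merge(left, right):
    # json_merge_patch: right-hand side wins on key collisions
    return {**left, **right}

assert map_get({"a": 1}, "b", 0) == 0
assert map_contains({"a": 1}, "a")
assert map_merge({"a": 1}, {"a": 2, "b": 3}) == {"a": 2, "b": 3}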
61 changes: 33 additions & 28 deletions ibis/backends/duckdb/tests/conftest.py
@@ -1,55 +1,60 @@
from __future__ import annotations

import functools
from pathlib import Path
from typing import TYPE_CHECKING, Any

import pytest

import ibis
from ibis.backends.conftest import TEST_TABLES
from ibis.backends.conftest import SANDBOXED, TEST_TABLES
from ibis.backends.tests.base import BackendTest, RoundAwayFromZero

if TYPE_CHECKING:
from ibis.backends.base import BaseBackend


class TestConf(BackendTest, RoundAwayFromZero):
def __init__(self, data_directory: Path) -> None:
self.connection = self.connect(data_directory)
supports_map = True

def __init__(self, data_directory: Path, **kwargs: Any) -> None:
self.connection = self.connect(data_directory, **kwargs)

script_dir = data_directory.parent

schema = (script_dir / 'schema' / 'duckdb.sql').read_text()

if not SANDBOXED:
self.connection._load_extensions(
["httpfs", "postgres_scanner", "sqlite_scanner"]
)

with self.connection.begin() as con:
for stmt in filter(None, map(str.strip, schema.split(';'))):
con.exec_driver_sql(stmt)

for table in TEST_TABLES:
src = data_directory / f'{table}.csv'
con.exec_driver_sql(
f"COPY {table} FROM {str(src)!r} (DELIMITER ',', HEADER, SAMPLE_SIZE 1)"
)

@staticmethod
def _load_data(
data_dir,
script_dir,
database: str = "ibis_testing",
**_: Any,
) -> None:
def _load_data(data_dir, script_dir, **_: Any) -> None:
"""Load test data into a DuckDB backend instance.
Parameters
----------
data_dir
Location of test data
script_dir
Location of scripts defining schemas
"""
duckdb = pytest.importorskip("duckdb")

schema = (script_dir / 'schema' / 'duckdb.sql').read_text()
return TestConf(data_directory=data_dir)

conn = duckdb.connect(str(data_dir / f"{database}.ddb"))
for stmt in filter(None, map(str.strip, schema.split(';'))):
conn.execute(stmt)
@staticmethod
def connect(data_directory: Path, **kwargs: Any) -> BaseBackend:
pytest.importorskip("duckdb")
return ibis.duckdb.connect(**kwargs) # type: ignore

for table in TEST_TABLES:
src = data_dir / f'{table}.csv'
conn.execute(
f"COPY {table} FROM {str(src)!r} (DELIMITER ',', HEADER, SAMPLE_SIZE 1)"
)

@staticmethod
@functools.lru_cache(maxsize=None)
def connect(data_directory: Path) -> BaseBackend:
path = data_directory / "ibis_testing.ddb"
return ibis.duckdb.connect(str(path)) # type: ignore
@pytest.fixture
def con(data_directory, tmp_path: Path):
return TestConf(data_directory, extension_directory=str(tmp_path)).connection
74 changes: 55 additions & 19 deletions ibis/backends/duckdb/tests/test_register.py
@@ -1,8 +1,11 @@
import os
import sqlite3
import tempfile
from pathlib import Path

import numpy as np
import pandas as pd
import pandas.testing as tm
import pyarrow as pa
import pytest

@@ -78,43 +81,48 @@ def pgurl(): # pragma: no cover
@pytest.mark.skipif(
os.environ.get("DUCKDB_POSTGRES") is None, reason="avoiding CI shenanigans"
)
def test_read_postgres(pgurl): # pragma: no cover
con = ibis.duckdb.connect()
def test_read_postgres(con, pgurl): # pragma: no cover
table = con.read_postgres(
f"postgres://{pgurl.username}:{pgurl.password}@{pgurl.host}:{pgurl.port}",
table_name="duckdb_test",
)
assert table.count().execute()


def test_read_sqlite(data_directory):
con = ibis.duckdb.connect()
path = data_directory / "ibis_testing.db"
ft = con.read_sqlite(path, table_name="functional_alltypes")
def test_read_sqlite(con, tmp_path):
path = tmp_path / "test.db"

sqlite_con = sqlite3.connect(str(path))
sqlite_con.execute("CREATE TABLE t AS SELECT 1 a UNION SELECT 2 UNION SELECT 3")

ft = con.read_sqlite(path, table_name="t")
assert ft.count().execute()

with pytest.raises(ValueError):
con.read_sqlite(path)


def test_read_sqlite_no_table_name(data_directory):
con = ibis.duckdb.connect()
path = data_directory / "ibis_testing.db"
def test_read_sqlite_no_table_name(con, tmp_path):
path = tmp_path / "test.db"

sqlite3.connect(str(path))

assert path.exists()

with pytest.raises(ValueError):
con.read_sqlite(path)


def test_register_sqlite(data_directory):
con = ibis.duckdb.connect()
path = data_directory / "ibis_testing.db"
ft = con.register(f"sqlite://{path}", "functional_alltypes")
assert ft.count().execute()
def test_register_sqlite(con, tmp_path):
path = tmp_path / "test.db"

sqlite_con = sqlite3.connect(str(path))
sqlite_con.execute("CREATE TABLE t AS SELECT 1 a UNION SELECT 2 UNION SELECT 3")
ft = con.register(f"sqlite://{path}", "t")
assert ft.count().execute()

def test_read_in_memory():
con = ibis.duckdb.connect()

def test_read_in_memory(con):
df_arrow = pa.table({"a": ["a"], "b": [1]})
df_pandas = pd.DataFrame({"a": ["a"], "b": [1]})
con.read_in_memory(df_arrow, table_name="df_arrow")
@@ -124,9 +132,7 @@ def test_read_in_memory():
assert "df_pandas" in con.list_tables()


def test_re_read_in_memory_overwrite():
con = ibis.duckdb.connect()

def test_re_read_in_memory_overwrite(con):
df_pandas_1 = pd.DataFrame({"a": ["a"], "b": [1], "d": ["hi"]})
df_pandas_2 = pd.DataFrame({"a": [1], "c": [1.4]})

@@ -192,3 +198,33 @@ def test_set_temp_dir(tmp_path):
path = tmp_path / "foo" / "bar"
ibis.duckdb.connect(temp_directory=path)
assert path.exists()


def test_s3_403_fallback(con, httpserver, monkeypatch):
# monkeypatch to avoid downloading extensions in tests
monkeypatch.setattr(con, "_load_extensions", lambda x: True)

# Throw a 403 to trigger fallback to pyarrow.dataset
httpserver.expect_request("/myfile").respond_with_data(
"Forbidden", status=403, content_type="text/plain"
)

# Since the URI is nonsense to pyarrow, expect an error, but raises from
# pyarrow, which indicates the fallback worked
with pytest.raises(pa.lib.ArrowInvalid):
con.read_parquet(httpserver.url_for("/myfile"))


@pytest.mark.xfail_version(
duckdb=["duckdb<=0.7.1"],
reason="""
the fix for this (issue #5879) caused a serious performance regression in the repr.
added this xfail in #5959, which also reverted the bugfix that caused the regression.
the issue was fixed upstream in duckdb in https://github.com/duckdb/duckdb/pull/6978
""",
)
def test_register_numpy_str(con):
data = pd.DataFrame({"a": [np.str_("xyz"), None]})
result = con.read_in_memory(data)
tm.assert_frame_equal(result.execute(), data)
2 changes: 1 addition & 1 deletion ibis/backends/impala/client.py
@@ -310,7 +310,7 @@ def insert(
if partition is not None:
partition_schema = self.partition_schema()
partition_schema_names = frozenset(partition_schema.names)
expr = expr.projection(
expr = expr.select(
[
column
for column in expr.columns
6 changes: 3 additions & 3 deletions ibis/backends/impala/metadata.py
@@ -147,7 +147,7 @@ def _parse_schema(self):
schema = []
while True:
tup = self._next_tuple()
if tup[0].strip() == '':
if not tup[0].strip():
break
schema.append((tup[0], tup[1]))

@@ -160,7 +160,7 @@ def _parse_info(self):
orig_key = tup[0].strip(':')
key = _clean_param_name(tup[0])

if key == '' or key.startswith('#'):
if not key or key.startswith('#'):
# section is done
break

@@ -214,7 +214,7 @@ def _parse_storage_info(self):
orig_key = tup[0].strip(':')
key = _clean_param_name(tup[0])

if key == '' or key.startswith('#'):
if not key or key.startswith('#'):
# section is done
break

133 changes: 63 additions & 70 deletions ibis/backends/impala/tests/conftest.py
@@ -44,6 +44,20 @@ def _load_data(data_dir: Path, script_dir: Path, **_: Any) -> None:
Location of scripts defining schemas
"""
fsspec = pytest.importorskip("fsspec")
fs = fsspec.filesystem("file")

data_files = {
data_file
for data_file in fs.find(data_dir)
# ignore sqlite databases and markdown files
if not data_file.endswith((".db", ".md"))
# ignore files in the test data .git directory
if (
# ignore .git
os.path.relpath(data_file, data_dir).split(os.sep, 1)[0]
!= ".git"
)
}

# without setting the pool size
# connections are dropped from the urllib3
@@ -52,79 +66,58 @@ def _load_data(data_dir: Path, script_dir: Path, **_: Any) -> None:
URLLIB_DEFAULT_POOL_SIZE = 10

env = IbisTestEnv()
con = ibis.impala.connect(
host=env.impala_host,
port=env.impala_port,
hdfs_client=fsspec.filesystem(
env.hdfs_protocol,
host=env.nn_host,
port=env.hdfs_port,
user=env.hdfs_user,
),
pool_size=URLLIB_DEFAULT_POOL_SIZE,
)

try:
fs = fsspec.filesystem("file")

data_files = {
data_file
for data_file in fs.find(data_dir)
# ignore sqlite databases and markdown files
if not data_file.endswith((".db", ".md"))
# ignore files in the test data .git directory
if (
# ignore .git
os.path.relpath(data_file, data_dir).split(os.sep, 1)[0]
!= ".git"
)
}

with contextlib.closing(
ibis.impala.connect(
host=env.impala_host,
port=env.impala_port,
hdfs_client=fsspec.filesystem(
env.hdfs_protocol,
host=env.nn_host,
port=env.hdfs_port,
user=env.hdfs_user,
),
pool_size=URLLIB_DEFAULT_POOL_SIZE,
)
) as con, concurrent.futures.ThreadPoolExecutor(
max_workers=int(
os.environ.get("IBIS_DATA_MAX_WORKERS", URLLIB_DEFAULT_POOL_SIZE)
)
) as executor:
hdfs = con.hdfs
with concurrent.futures.ThreadPoolExecutor(
max_workers=int(
os.environ.get(
"IBIS_DATA_MAX_WORKERS",
URLLIB_DEFAULT_POOL_SIZE,
tasks = {
# make the database
executor.submit(impala_create_test_database, con, env),
# build and upload UDFs
*itertools.starmap(
executor.submit,
impala_build_and_upload_udfs(hdfs, env, fs=fs),
),
# upload data files
*(
executor.submit(
hdfs_make_dir_and_put_file,
hdfs,
data_file,
os.path.join(
env.test_data_dir,
os.path.relpath(data_file, data_dir),
),
)
for data_file in data_files
),
}

for future in concurrent.futures.as_completed(tasks):
future.result()

# create the tables and compute stats
for future in concurrent.futures.as_completed(
executor.submit(table_future.result().compute_stats)
for table_future in concurrent.futures.as_completed(
impala_create_tables(con, env, executor=executor)
)
) as executor:
tasks = {
# make the database
executor.submit(impala_create_test_database, con, env),
# build and upload UDFs
*itertools.starmap(
executor.submit,
impala_build_and_upload_udfs(hdfs, env, fs=fs),
),
# upload data files
*(
executor.submit(
hdfs_make_dir_and_put_file,
hdfs,
data_file,
os.path.join(
env.test_data_dir,
os.path.relpath(data_file, data_dir),
),
)
for data_file in data_files
),
}

for future in concurrent.futures.as_completed(tasks):
future.result()

# create the tables and compute stats
for future in concurrent.futures.as_completed(
executor.submit(table_future.result().compute_stats)
for table_future in concurrent.futures.as_completed(
impala_create_tables(con, env, executor=executor)
)
):
future.result()
finally:
con.close()
):
future.result()

@staticmethod
def connect(
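The conftest refactor above swaps the try/finally for contextlib.closing around the connection while keeping the thread pool; the shape of that pattern with stand-in objects (everything below is illustrative):

import contextlib
from concurrent.futures import ThreadPoolExecutor, as_completed

class FakeConn:
    def close(self):
        pass

def do_work(n):
    return n * n

with contextlib.closing(FakeConn()) as con, ThreadPoolExecutor(max_workers=4) as executor:
    tasks = {executor.submit(do_work, n) for n in range(5)}
    results = sorted(f.result() for f in as_completed(tasks))
    # con.close() runs automatically when the with block exits

assert results == [0, 1, 4, 9, 16]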

This file was deleted.

This file was deleted.

This file was deleted.

2 changes: 1 addition & 1 deletion ibis/backends/impala/tests/test_client.py
@@ -30,7 +30,7 @@ def test_cursor_garbage_collection(con):
def test_raise_ibis_error_no_hdfs(con_no_hdfs):
# GH299
with pytest.raises(com.IbisError):
con_no_hdfs.hdfs
con_no_hdfs.hdfs # noqa: B018


def test_get_table_ref(db):
2 changes: 1 addition & 1 deletion ibis/backends/impala/tests/test_ddl.py
@@ -401,4 +401,4 @@ def test_kudu_property_raises_useful_error(con):
NotImplementedError,
match="kudu support using kudu-python",
):
con.kudu
con.kudu # noqa: B018
7 changes: 3 additions & 4 deletions ibis/backends/impala/tests/test_exprs.py
@@ -21,8 +21,7 @@ def test_embedded_identifier_quoting(alltypes):
def test_summary_execute(alltypes):
table = alltypes

# also test set_column while we're at it
table = table.set_column('double_col', table.double_col * 2)
table = table.mutate(double_col=table.double_col * 2)

with pytest.warns(FutureWarning, match="is deprecated"):
metrics = table.double_col.summary()
@@ -373,7 +372,7 @@ def test_filter_predicates(con):

expr = t
for pred in predicates:
expr = expr[pred(expr)].projection([expr])
expr = expr[pred(expr)].select(expr)

expr.execute()

@@ -672,7 +671,7 @@ def test_identical_to(con, left, right, expected):

def test_not(alltypes):
t = alltypes.limit(10)
expr = t.projection([(~t.double_col.isnull()).name('double_col')])
expr = t.select(double_col=~t.double_col.isnull())
result = expr.execute().double_col
expected = ~t.execute().double_col.isnull()
tm.assert_series_equal(result, expected)
10 changes: 5 additions & 5 deletions ibis/backends/impala/tests/test_sql.py
@@ -27,7 +27,7 @@ def limit_cte_extract(con):
alltypes = con.table('functional_alltypes')
t = alltypes.limit(100)
t2 = t.view()
return t.join(t2).projection(t)
return t.join(t2).select(t)


@pytest.mark.parametrize(
@@ -52,7 +52,7 @@ def test_nested_join_base(snapshot):
t = ibis.table(dict(uuid='string', ts='timestamp'), name='t')
counts = t.group_by('uuid').size()
max_counts = counts.group_by('uuid').aggregate(max_count=lambda x: x['count'].max())
result = max_counts.left_join(counts, 'uuid').projection([counts])
result = max_counts.left_join(counts, 'uuid').select(counts)
compiled_result = ImpalaCompiler.to_sql(result)
snapshot.assert_match(compiled_result, "out.sql")

@@ -68,10 +68,10 @@ def test_nested_joins_single_cte(snapshot):

main_kw = max_counts.left_join(
counts, ['uuid', max_counts.max_count == counts['count']]
).projection([counts])
).select(counts)

result = main_kw.left_join(last_visit, 'uuid').projection(
[main_kw, last_visit.last_visit]
result = main_kw.left_join(last_visit, 'uuid').select(
main_kw, last_visit.last_visit
)
compiled_result = ImpalaCompiler.to_sql(result)
snapshot.assert_match(compiled_result, "out.sql")
14 changes: 7 additions & 7 deletions ibis/backends/impala/tests/test_window.py
@@ -59,7 +59,7 @@ def test_window_frame_specs(alltypes, window, snapshot):
t = alltypes

w2 = window.order_by(t.f)
expr = t.projection([t.d.sum().over(w2).name('foo')])
expr = t.select(foo=t.d.sum().over(w2))
assert_sql_equal(expr, snapshot)


@@ -83,8 +83,8 @@ def test_cumulative_functions(alltypes, name, snapshot):
expr = cumfunc().over(w).name("foo")
expected = func().over(ibis.cumulative_window(order_by=t.d)).name("foo")

expr1 = t.projection(expr)
expr2 = t.projection(expected)
expr1 = t.select(expr)
expr2 = t.select(expected)

assert_sql_equal(expr1, snapshot, "out1.sql")
assert_sql_equal(expr2, snapshot, "out2.sql")
@@ -95,7 +95,7 @@ def test_nested_analytic_function(alltypes, snapshot):

w = window(order_by=t.f)
expr = (t.f - t.f.lag()).lag().over(w).name('foo')
result = t.projection([expr])
result = t.select(expr)
assert_sql_equal(result, snapshot)


@@ -112,7 +112,7 @@ def test_multiple_windows(alltypes, snapshot):
w = window(group_by=t.g)

expr = t.f.sum().over(w) - t.f.sum()
proj = t.projection([t.g, expr.name('result')])
proj = t.select(t.g, result=expr)

assert_sql_equal(proj, snapshot)

Expand Down Expand Up @@ -154,7 +154,7 @@ def test_unsupported_aggregate_functions(alltypes, column, op):
t = alltypes
w = ibis.window(order_by=t.d)
expr = getattr(t[column], op)()
proj = t.projection([expr.over(w).name('foo')])
proj = t.select(foo=expr.over(w))
with pytest.raises(com.TranslationError):
ImpalaCompiler.to_sql(proj)

Expand All @@ -172,5 +172,5 @@ def test_propagate_nested_windows(alltypes, snapshot):
ex_expr = (t.f - t.f.lag().over(w)).lag().over(w)
assert_equal(result, ex_expr)

expr = t.projection(col.over(w).name('foo'))
expr = t.select(col.over(w).name('foo'))
assert_sql_equal(expr, snapshot)
7 changes: 6 additions & 1 deletion ibis/backends/mssql/compiler.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from __future__ import annotations

from sqlalchemy.dialects.mssql import DATETIME2

import ibis.expr.operations as ops
from ibis.backends.base.sql.alchemy import (
AlchemyCompiler,
Expand All @@ -12,7 +14,10 @@ class MsSqlExprTranslator(AlchemyExprTranslator):
_registry = operation_registry
_rewrites = AlchemyExprTranslator._rewrites.copy()
_bool_aggs_need_cast_to_int32 = True
integer_to_timestamp = staticmethod(_timestamp_from_unix)

_timestamp_type = DATETIME2
_integer_to_timestamp = staticmethod(_timestamp_from_unix)

native_json_type = False

_forbids_frame_clause = AlchemyExprTranslator._forbids_frame_clause + (
Expand Down
18 changes: 15 additions & 3 deletions ibis/backends/mssql/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,16 +12,19 @@
sqlalchemy_window_functions_registry,
unary,
)
from ibis.backends.base.sql.alchemy.registry import substr
from ibis.backends.base.sql.alchemy.registry import substr, variance_reduction


def _reduction(func, cast_type='int32'):
def reduction_compiler(t, op):
arg, where = op.args

if arg.output_dtype.is_boolean():
nullable = arg.output_dtype.nullable
arg = ops.Cast(arg, dt.dtype(cast_type)(nullable=nullable))
if isinstance(arg, ops.TableColumn):
nullable = arg.output_dtype.nullable
arg = ops.Cast(arg, dt.dtype(cast_type)(nullable=nullable))
else:
arg = ops.Where(arg, 1, 0)

if where is not None:
arg = ops.Where(where, arg, None)
Expand Down Expand Up @@ -155,6 +158,8 @@ def _timestamp_truncate(t, op):
ops.Log: fixed_arity(lambda x, p: sa.func.log(x, p), 2),
ops.Log2: fixed_arity(lambda x: sa.func.log(x, 2), 1),
ops.Log10: fixed_arity(lambda x: sa.func.log(x, 10), 1),
ops.StandardDev: variance_reduction('stdev', {'sample': '', 'pop': 'p'}),
ops.Variance: variance_reduction('var', {'sample': '', 'pop': 'p'}),
# timestamp methods
ops.TimestampNow: fixed_arity(sa.func.GETDATE, 0),
ops.ExtractYear: _extract('year'),
Expand Down Expand Up @@ -195,6 +200,13 @@ def _timestamp_truncate(t, op):
# ibis.expr.operations.strings
ops.RPad,
ops.LPad,
# ibis.expr.operations.reductions
ops.BitAnd,
ops.BitOr,
ops.BitXor,
ops.GroupConcat,
# ibis.expr.operations.window
ops.NthValue,
}

operation_registry = {
Expand Down
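A sketch of what the widened boolean branch in _reduction above handles (the table here is hypothetical): a plain boolean column is still cast to an integer type before aggregating, while a derived boolean expression, which MSSQL will not cast inside an aggregate, is rewritten through ops.Where into a 1/0 expression.

import ibis

t = ibis.table(dict(flag="boolean", amount="float64"), name="t")  # hypothetical table

t.flag.sum()          # TableColumn: aggregated as SUM(CAST(flag AS <int>))
(t.amount > 0).sum()  # derived boolean: aggregated via CASE WHEN amount > 0 THEN 1 ELSE 0 END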
2 changes: 1 addition & 1 deletion ibis/backends/mysql/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ class MySQLExprTranslator(AlchemyExprTranslator):
# https://dev.mysql.com/doc/refman/8.0/en/spatial-function-reference.html
_registry = operation_registry.copy()
_rewrites = AlchemyExprTranslator._rewrites.copy()
integer_to_timestamp = sa.func.from_unixtime
_integer_to_timestamp = sa.func.from_unixtime
native_json_type = False
_dialect_name = "mysql"

Expand Down
11 changes: 5 additions & 6 deletions ibis/backends/pandas/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,6 +238,9 @@ def has_operation(cls, operation: type[ops.Value]) -> bool:
issubclass(operation, op_impl) for op_impl in op_classes
)

def _clean_up_cached_table(self, op):
del self.dictionary[op.name]


class Backend(BasePandasBackend):
name = 'pandas'
Expand Down Expand Up @@ -303,9 +306,5 @@ def execute(self, query, params=None, limit='default', **kwargs):

return execute_and_reset(node, params=params, **kwargs)

def _cached(self, expr):
"""No-op. The expression is already in memory."""
return ir.CachedTable(expr.op())

def _release_cached(self, _):
"""No-op."""
def _load_into_cache(self, name, expr):
self.create_table(name, expr.execute())
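A minimal sketch of how these hooks are exercised, assuming the public Table.cache() API routes through _load_into_cache and that releasing the cached table ends up in _clean_up_cached_table (the data and names are made up):

import pandas as pd

import ibis

con = ibis.pandas.connect({"t": pd.DataFrame({"a": [1, 2, 3]})})
t = con.table("t")

cached = t.mutate(b=t.a + 1).cache()  # materializes the result via _load_into_cache
cached.execute()
del cached  # once released, the backend drops the entry via _clean_up_cached_table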
55 changes: 27 additions & 28 deletions ibis/backends/pandas/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,10 @@

import ibis.expr.datatypes as dt
import ibis.expr.operations as ops
import ibis.expr.rules as rlz
import ibis.expr.schema as sch
from ibis import util
from ibis.backends.base import Database
from ibis.common.grounds import Immutable
from ibis.expr.operations.relations import TableProxy

_ibis_dtypes = toolz.valmap(
np.dtype,
Expand Down Expand Up @@ -128,11 +127,8 @@ def convert_datetimetz_to_timestamp(_, out_dtype, column):
output_timezone = out_dtype.timezone
if output_timezone is not None:
return column.dt.tz_convert(output_timezone)
return column.astype(out_dtype.to_pandas(), errors='ignore')


PANDAS_STRING_TYPES = {'string', 'unicode', 'bytes'}
PANDAS_DATE_TYPES = {'datetime', 'datetime64', 'date'}
else:
return column.dt.tz_localize(None)


@sch.convert.register(np.dtype, dt.Interval, pd.Series)
Expand Down Expand Up @@ -172,15 +168,28 @@ def convert_timestamp_to_date(in_dtype, out_dtype, column):

@sch.convert.register(object, dt.DataType, pd.Series)
def convert_any_to_any(_, out_dtype, column):
try:
return column.astype(out_dtype.to_pandas())
except Exception: # noqa: BLE001
return column


@sch.convert.register(np.dtype, dt.Timestamp, pd.Series)
def convert_any_to_timestamp(_, out_dtype, column):
try:
return column.astype(out_dtype.to_pandas())
except pd.errors.OutOfBoundsDatetime:
try:
return column.map(date_parse)
except TypeError:
return column
except Exception: # noqa: BLE001
return column
except TypeError:
column = pd.to_datetime(column)
timezone = out_dtype.timezone
try:
return column.dt.tz_convert(timezone)
except TypeError:
return column.dt.tz_localize(timezone)


@sch.convert.register(object, dt.Struct, pd.Series)
Expand All @@ -198,6 +207,11 @@ def convert_array_to_series(in_dtype, out_dtype, column):
return column.map(lambda x: list(x) if util.is_iterable(x) else x)


@sch.convert.register(np.dtype, dt.Map, pd.Series)
def convert_map_to_series(in_dtype, out_dtype, column):
return column.map(lambda x: dict(x) if util.is_iterable(x) else x)


@sch.convert.register(np.dtype, dt.JSON, pd.Series)
def convert_json_to_series(in_, out, col: pd.Series):
def try_json(x):
Expand All @@ -211,33 +225,18 @@ def try_json(x):
return pd.Series(list(map(try_json, col)), dtype="object")


class DataFrameProxy(Immutable, util.ToFrame):
__slots__ = ('_df', '_hash')

def __init__(self, df: pd.DataFrame) -> None:
object.__setattr__(self, "_df", df)
object.__setattr__(self, "_hash", hash((type(df), id(df))))

def __hash__(self) -> int:
return self._hash

def __repr__(self) -> str:
df_repr = util.indent(repr(self._df), spaces=2)
return f"{self.__class__.__name__}:\n{df_repr}"
class DataFrameProxy(TableProxy):
__slots__ = ()

def to_frame(self) -> pd.DataFrame:
return self._df
return self._data

def to_pyarrow(self, schema: sch.Schema) -> pa.Table:
import pyarrow as pa

from ibis.backends.pyarrow.datatypes import ibis_to_pyarrow_schema

return pa.Table.from_pandas(self._df, schema=ibis_to_pyarrow_schema(schema))


class PandasInMemoryTable(ops.InMemoryTable):
data = rlz.instance_of(DataFrameProxy)
return pa.Table.from_pandas(self._data, schema=ibis_to_pyarrow_schema(schema))


class PandasTable(ops.DatabaseTable):
Expand Down
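The timezone fallback in convert_any_to_timestamp above leans on a standard pandas idiom: tz_convert raises TypeError on tz-naive data, in which case tz_localize is the right call. Roughly:

import pandas as pd

col = pd.to_datetime(pd.Series(["2021-01-01 00:00:00", "2021-06-01 12:30:00"]))
try:
    col = col.dt.tz_convert("UTC")   # only valid if the data is already tz-aware
except TypeError:
    col = col.dt.tz_localize("UTC")  # naive data gets localized instead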
22 changes: 17 additions & 5 deletions ibis/backends/pandas/execution/generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
import pandas as pd
import pytz
import toolz
from pandas.api.types import DatetimeTZDtype
from pandas.core.groupby import DataFrameGroupBy, SeriesGroupBy

import ibis.common.exceptions as com
Expand Down Expand Up @@ -152,14 +151,22 @@ def execute_cast_series_timestamp(op, data, type, **kwargs):

tz = type.timezone

if from_type.is_timestamp() or from_type.is_date():
return data.astype('M8[ns]' if tz is None else DatetimeTZDtype('ns', tz))
if from_type.is_timestamp():
from_tz = from_type.timezone
if tz is None and from_tz is None:
return data
elif tz is None or from_tz is None:
return data.dt.tz_localize(tz)
elif tz is not None and from_tz is not None:
return data.dt.tz_convert(tz)
elif from_type.is_date():
return data if tz is None else data.dt.tz_localize(tz)

if from_type.is_string() or from_type.is_integer():
if from_type.is_integer():
timestamps = pd.to_datetime(data.values, unit="s")
else:
timestamps = pd.to_datetime(data.values, infer_datetime_format=True)
timestamps = pd.to_datetime(data.values)
if getattr(timestamps.dtype, "tz", None) is not None:
method_name = "tz_convert"
else:
Expand Down Expand Up @@ -192,7 +199,7 @@ def execute_cast_series_date(op, data, type, **kwargs):
# TODO: remove String as subclass of JSON
if from_type.is_string() and not from_type.is_json():
values = data.values
datetimes = pd.to_datetime(values, infer_datetime_format=True)
datetimes = pd.to_datetime(values)
with contextlib.suppress(TypeError):
datetimes = datetimes.tz_convert(None)
dates = _normalize(datetimes, data.index, data.name)
Expand Down Expand Up @@ -1435,6 +1442,11 @@ def execute_zero_if_null_series(op, data, **kwargs):
return data.replace({np.nan: zero, None: zero, pd.NA: zero})


@execute_node.register(ops.InMemoryTable)
def execute_in_memory_table(op, **kwargs):
return op.data.to_frame()


@execute_node.register(
ops.ZeroIfNull,
(type(None), type(pd.NA), numbers.Real, np.integer, np.floating),
Expand Down
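For reference, the branches in execute_cast_series_timestamp above map onto these pandas operations (a rough sketch):

import pandas as pd

naive = pd.Series(pd.to_datetime(["2021-01-01", "2021-01-02"]))
aware = naive.dt.tz_localize("UTC")

naive.dt.tz_localize("US/Eastern")  # naive -> tz-aware: localize to the target zone
aware.dt.tz_localize(None)          # tz-aware -> naive: drop the zone
aware.dt.tz_convert("US/Eastern")   # tz-aware -> tz-aware: convert between zones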
5 changes: 4 additions & 1 deletion ibis/backends/pandas/execution/temporal.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,10 @@ def execute_epoch_seconds(op, data, **kwargs):
(pd.Series, str, datetime.time),
)
def execute_between_time(op, data, lower, upper, **kwargs):
indexer = pd.DatetimeIndex(data).indexer_between_time(lower, upper)
idx = pd.DatetimeIndex(data)
if idx.tz is not None:
idx = idx.tz_convert(None) # make naive because times are naive
indexer = idx.indexer_between_time(lower, upper)
result = np.zeros(len(data), dtype=np.bool_)
result[indexer] = True
return pd.Series(result)
Expand Down
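The tz_convert(None) guard above exists because indexer_between_time compares wall-clock times against naive bounds; a condensed sketch of the same logic:

import numpy as np
import pandas as pd

data = pd.Series(pd.date_range("2021-01-01", periods=6, freq="4H", tz="UTC"))
idx = pd.DatetimeIndex(data)
if idx.tz is not None:
    idx = idx.tz_convert(None)  # drop the zone so the naive lower/upper bounds compare cleanly
indexer = idx.indexer_between_time("08:00", "16:00")
mask = np.zeros(len(data), dtype=np.bool_)
mask[indexer] = True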
10 changes: 4 additions & 6 deletions ibis/backends/pandas/tests/execution/test_arrays.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,10 @@ def test_array_literal(client, arr, create_arr_expr):


def test_array_length(t):
expr = t.projection(
[
t.array_of_float64.length().name('array_of_float64_length'),
t.array_of_int64.length().name('array_of_int64_length'),
t.array_of_strings.length().name('array_of_strings_length'),
]
expr = t.select(
t.array_of_float64.length().name('array_of_float64_length'),
t.array_of_int64.length().name('array_of_int64_length'),
t.array_of_strings.length().name('array_of_strings_length'),
)
result = expr.execute()
expected = pd.DataFrame(
Expand Down
8 changes: 4 additions & 4 deletions ibis/backends/pandas/tests/execution/test_operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -392,10 +392,10 @@ def test_value_counts(t, df):
result = expr.execute()
expected = (
df.dup_strings.value_counts()
.reset_index()
.rename(columns={'dup_strings': 'dup_strings_count'})
.rename(columns={'index': 'dup_strings'})
.sort_values(['dup_strings'])
.rename("dup_strings")
.reset_index(name="dup_strings_count")
.rename(columns={"index": "dup_strings"})
.sort_values(["dup_strings"])
.reset_index(drop=True)
)
tm.assert_frame_equal(result[expected.columns], expected)
Expand Down
21 changes: 12 additions & 9 deletions ibis/backends/pandas/tests/execution/test_temporal.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,11 +63,7 @@ def test_timestamp_functions(case_func, expected_func):
def test_cast_datetime_strings_to_date(t, df, column):
expr = t[column].cast('date')
result = expr.execute()
expected = (
pd.to_datetime(df[column], infer_datetime_format=True)
.dt.normalize()
.dt.tz_localize(None)
)
expected = pd.to_datetime(df[column]).dt.normalize().dt.tz_localize(None)
tm.assert_series_equal(result, expected)


Expand All @@ -78,7 +74,7 @@ def test_cast_datetime_strings_to_date(t, df, column):
def test_cast_datetime_strings_to_timestamp(t, df, column):
expr = t[column].cast('timestamp')
result = expr.execute()
expected = pd.to_datetime(df[column], infer_datetime_format=True)
expected = pd.to_datetime(df[column])
if getattr(expected.dtype, 'tz', None) is not None:
expected = expected.dt.tz_convert(None)
tm.assert_series_equal(result, expected)
Expand Down Expand Up @@ -122,10 +118,17 @@ def test_times_ops(t, df):


@pytest.mark.parametrize(
('tz', 'rconstruct'),
[('US/Eastern', np.zeros), ('UTC', np.ones), (None, np.ones)],
('tz', 'rconstruct', 'column'),
[
('US/Eastern', np.ones, 'plain_datetimes_utc'),
('US/Eastern', np.zeros, 'plain_datetimes_naive'),
('UTC', np.ones, 'plain_datetimes_utc'),
('UTC', np.ones, 'plain_datetimes_naive'),
(None, np.ones, 'plain_datetimes_utc'),
(None, np.ones, 'plain_datetimes_naive'),
],
ids=lambda x: str(getattr(x, "__name__", x)).lower().replace("/", "_"),
)
@pytest.mark.parametrize('column', ['plain_datetimes_utc', 'plain_datetimes_naive'])
def test_times_ops_with_tz(t, df, tz, rconstruct, column):
expected = pd.Series(rconstruct(len(df), dtype=bool))
time = t[column].time()
Expand Down
7 changes: 4 additions & 3 deletions ibis/backends/pandas/tests/test_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,9 +178,10 @@ def test_infer_array():

def test_apply_to_schema_with_timezone():
data = {'time': pd.date_range('2018-01-01', '2018-01-02', freq='H')}
df = pd.DataFrame(data)
expected = df.assign(time=df.time.astype('datetime64[ns, EST]'))
desired_schema = ibis.schema([('time', 'timestamp("EST")')])
df = expected = pd.DataFrame(data).assign(
time=lambda df: df.time.dt.tz_localize("EST")
)
desired_schema = ibis.schema(dict(time='timestamp("EST")'))
result = desired_schema.apply_to(df.copy())
tm.assert_frame_equal(expected, result)

Expand Down
4 changes: 2 additions & 2 deletions ibis/backends/pandas/trace.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ def enable():
"""Enable tracing."""
if options.pandas is None:
# pandas options haven't been registered yet - force module __getattr__
ibis.pandas
ibis.pandas # noqa: B018
options.pandas.enable_trace = True
logging.getLogger('ibis.backends.pandas.trace').setLevel(logging.DEBUG)

Expand Down Expand Up @@ -130,7 +130,7 @@ def traced_func(*args, **kwargs):
# the pandas attribute here forces the option initialization
import ibis

ibis.pandas
ibis.pandas # noqa: B018

if not options.pandas.enable_trace:
return func(*args, **kwargs)
Expand Down
15 changes: 4 additions & 11 deletions ibis/backends/polars/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from __future__ import annotations

import itertools
from functools import lru_cache
from pathlib import Path
from typing import TYPE_CHECKING, Any, Mapping, MutableMapping
Expand All @@ -15,17 +14,11 @@
import ibis.expr.types as ir
from ibis.backends.base import BaseBackend
from ibis.backends.polars.compiler import translate
from ibis.util import deprecated, normalize_filename
from ibis.util import deprecated, gen_name, normalize_filename

if TYPE_CHECKING:
import pandas as pd

# counters for in-memory, parquet, and csv reads
# used if no table name is specified
pd_n = itertools.count(0)
pa_n = itertools.count(0)
csv_n = itertools.count(0)


class Backend(BaseBackend):
name = "polars"
Expand Down Expand Up @@ -154,7 +147,7 @@ def read_csv(
The just-registered table
"""
path = normalize_filename(path)
table_name = table_name or f"ibis_read_csv_{next(csv_n)}"
table_name = table_name or gen_name("read_csv")
try:
self._tables[table_name] = pl.scan_csv(path, **kwargs)
except pl.exceptions.ComputeError:
Expand Down Expand Up @@ -184,7 +177,7 @@ def read_pandas(
ir.Table
The just-registered table
"""
table_name = table_name or f"ibis_read_in_memory_{next(pd_n)}"
table_name = table_name or gen_name("read_in_memory")
self._tables[table_name] = pl.from_pandas(source, **kwargs).lazy()
return self.table(table_name)

Expand All @@ -211,7 +204,7 @@ def read_parquet(
The just-registered table
"""
path = normalize_filename(path)
table_name = table_name or f"ibis_read_parquet_{next(pa_n)}"
table_name = table_name or gen_name("read_parquet")
self._tables[table_name] = pl.scan_parquet(path, **kwargs)
return self.table(table_name)

Expand Down
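With the module-level itertools counters gone, unnamed reads are registered under unique generated names instead; a hedged sketch (the file path is made up):

import ibis

con = ibis.polars.connect()
t = con.read_parquet("/tmp/data.parquet")  # hypothetical file; with no table_name the
                                           # table is registered under a gen_name("read_parquet") name
print(t.get_name())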
1 change: 1 addition & 0 deletions ibis/backends/polars/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -637,6 +637,7 @@ def struct_column(op):
ops.StandardDev: 'std',
ops.Sum: 'sum',
ops.Variance: 'var',
ops.CountDistinct: 'n_unique',
}

for reduction in _reductions.keys():
Expand Down
12 changes: 10 additions & 2 deletions ibis/backends/polars/datatypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,12 +75,20 @@ def to_ibis_dtype(typ):

@to_ibis_dtype.register(pl.Datetime)
def from_polars_datetime(typ):
return dt.Timestamp(timezone=typ.tz)
try:
timezone = typ.time_zone
except AttributeError: # pragma: no cover
timezone = typ.tz # pragma: no cover
return dt.Timestamp(timezone=timezone)


@to_ibis_dtype.register(pl.Duration)
def from_polars_duration(typ):
return dt.Interval(unit=typ.tu)
try:
time_unit = typ.time_unit
except AttributeError: # pragma: no cover
time_unit = typ.tu # pragma: no cover
return dt.Interval(unit=time_unit)


@to_ibis_dtype.register(pl.List)
Expand Down
56 changes: 55 additions & 1 deletion ibis/backends/postgres/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from __future__ import annotations

import warnings
from typing import TYPE_CHECKING, Iterable, Literal

import sqlalchemy as sa
Expand All @@ -16,6 +17,35 @@
import ibis.expr.datatypes as dt


# adapted from https://wiki.postgresql.org/wiki/First/last_%28aggregate%29
_CREATE_FIRST_LAST_AGGS_SQL = """\
CREATE OR REPLACE FUNCTION public._ibis_first_agg (anyelement, anyelement)
RETURNS anyelement
LANGUAGE sql IMMUTABLE STRICT PARALLEL SAFE AS
'SELECT $1';

CREATE OR REPLACE AGGREGATE public._ibis_first (anyelement) (
SFUNC = public._ibis_first_agg,
STYPE = anyelement,
PARALLEL = safe
);

CREATE OR REPLACE FUNCTION public._ibis_last_agg (anyelement, anyelement)
RETURNS anyelement
LANGUAGE sql IMMUTABLE STRICT PARALLEL SAFE AS
'SELECT $2';

CREATE OR REPLACE AGGREGATE public._ibis_last (anyelement) (
SFUNC = public._ibis_last_agg,
STYPE = anyelement,
PARALLEL = safe
);"""

_DROP_FIRST_LAST_AGGS_SQL = """\
DROP AGGREGATE IF EXISTS public._ibis_first(anyelement), public._ibis_last(anyelement);
DROP FUNCTION IF EXISTS public._ibis_first_agg(anyelement, anyelement), public._ibis_last_agg(anyelement, anyelement);"""


class Backend(BaseAlchemyBackend):
name = 'postgres'
compiler = PostgreSQLCompiler
Expand Down Expand Up @@ -120,6 +150,28 @@ def connect(dbapi_connection, connection_record):
with dbapi_connection.cursor() as cur:
cur.execute("SET TIMEZONE = UTC")

@sa.event.listens_for(engine, "before_execute")
def receive_before_execute(
conn, clauseelement, multiparams, params, execution_options
):
with conn.connection.cursor() as cur:
try:
cur.execute(_CREATE_FIRST_LAST_AGGS_SQL)
except Exception as e: # noqa: BLE001
# a user may not have permissions to create functions and/or aggregates
warnings.warn(f"Unable to create first/last aggregates: {e}")

@sa.event.listens_for(engine, "after_execute")
def receive_after_execute(
conn, clauseelement, multiparams, params, execution_options, result
):
with conn.connection.cursor() as cur:
try:
cur.execute(_DROP_FIRST_LAST_AGGS_SQL)
except Exception as e: # noqa: BLE001
# a user may not have permissions to drop functions and/or aggregates
warnings.warn(f"Unable to drop first/last aggregates: {e}")

super().do_connect(engine)

def list_databases(self, like=None):
Expand Down Expand Up @@ -194,6 +246,8 @@ def _metadata(self, query: str) -> Iterable[tuple[str, dt.DataType]]:
AND attnum > 0
AND NOT attisdropped
ORDER BY attnum"""
if self.inspector.has_table(query):
query = f"TABLE {query}"
with self.begin() as con:
con.exec_driver_sql(f"CREATE TEMPORARY VIEW {name} AS {query}")
type_info = con.execute(
Expand All @@ -205,4 +259,4 @@ def _metadata(self, query: str) -> Iterable[tuple[str, dt.DataType]]:
def _get_temp_view_definition(
self, name: str, definition: sa.sql.compiler.Compiled
) -> str:
yield f"CREATE OR REPLACE VIEW {name} AS {definition}"
yield f"CREATE OR REPLACE TEMPORARY VIEW {name} AS {definition}"
3 changes: 3 additions & 0 deletions ibis/backends/postgres/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,9 @@ class PostgreSQLExprTranslator(AlchemyExprTranslator):
_has_reduction_filter_syntax = True
_dialect_name = "postgresql"

# postgres does support unnest in SELECT, but we disable it so that pivot support keeps working
supports_unnest_in_select = False


rewrites = PostgreSQLExprTranslator.rewrites

Expand Down
8 changes: 6 additions & 2 deletions ibis/backends/postgres/datatypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,11 @@ def _get_type(typestr: str) -> dt.DataType:
is_array = typestr.endswith(_BRACKETS)
if (typ := _type_mapping.get(typestr.replace(_BRACKETS, ""))) is not None:
return dt.Array(typ) if is_array else typ
return _parse_numeric(typestr)
try:
return _parse_numeric(typestr)
except parsy.ParseError:
# postgres can have arbitrary types unknown to ibis
return dt.unknown


_type_mapping = {
Expand Down Expand Up @@ -174,4 +178,4 @@ def sa_pg_array(dialect, satype, nullable=True):

@dt.dtype.register(PGDialect, postgresql.TSVECTOR)
def sa_postgres_tsvector(_, satype, nullable=True):
return dt.String(nullable=nullable)
return dt.Unknown(nullable=nullable)
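With these two changes, Postgres types that ibis cannot parse no longer raise and surface as unknown instead; a rough sketch against the private helper shown above:

from ibis.backends.postgres.datatypes import _get_type

_get_type("numeric(10, 2)")  # still parsed into a decimal type as before
_get_type("tsvector")        # unmapped and non-numeric: falls back to dt.unknown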
70 changes: 69 additions & 1 deletion ibis/backends/postgres/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@

import sqlalchemy as sa
from sqlalchemy.dialects import postgresql as pg
from sqlalchemy.ext.compiler import compiles
from sqlalchemy.sql.functions import GenericFunction

import ibis.backends.base.sql.registry.geospatial as geo
import ibis.common.exceptions as com
Expand Down Expand Up @@ -479,6 +481,69 @@ def translate(t, op: ops.ArgMin | ops.ArgMax) -> str:
return translate


def _arbitrary(t, op):
if (how := op.how) == "heavy":
raise com.UnsupportedOperationError(
f"postgres backend doesn't support how={how!r} for the arbitrary() aggregate"
)
func = getattr(sa.func.public, f"_ibis_{op.how}")
return t._reduction(func, op)


class struct_field(GenericFunction):
inherit_cache = True


@compiles(struct_field)
def compile_struct_field_postgresql(element, compiler, **kw):
arg, field = element.clauses
return f"({compiler.process(arg, **kw)}).{field.name}"


def _struct_field(t, op):
arg = op.arg
idx = arg.output_dtype.names.index(op.field) + 1
field_name = sa.literal_column(f"f{idx:d}")
return struct_field(
t.translate(arg), field_name, type_=t.get_sqla_type(op.output_dtype)
)


def _struct_column(t, op):
types = op.output_dtype.types
return sa.func.row(
# we have to cast here, otherwise postgres refuses to allow the statement
*map(t.translate, map(ops.Cast, op.values, types)),
type_=t.get_sqla_type(
dt.Struct({f"f{i:d}": typ for i, typ in enumerate(types, start=1)})
),
)


def _unnest(t, op):
arg = op.arg
row_type = arg.output_dtype.value_type

types = getattr(row_type, "types", (row_type,))

is_struct = row_type.is_struct()
derived = (
sa.func.unnest(t.translate(arg))
.table_valued(
*(
sa.column(f"f{i:d}", stype)
for i, stype in enumerate(map(t.get_sqla_type, types), start=1)
)
)
.render_derived(with_types=is_struct)
)

# wrap in a row column so that we can return a single column from this rule
if not is_struct:
return derived.c[0]
return sa.func.row(*derived.c)


operation_registry.update(
{
ops.Literal: _literal,
Expand Down Expand Up @@ -585,7 +650,7 @@ def translate(t, op: ops.ArgMin | ops.ArgMax) -> str:
),
ops.ArrayConcat: fixed_arity(sa.sql.expression.ColumnElement.concat, 2),
ops.ArrayRepeat: _array_repeat,
ops.Unnest: unary(sa.func.unnest),
ops.Unnest: _unnest,
ops.Covariance: _covar,
ops.Correlation: _corr,
ops.BitwiseXor: _bitwise_op("#"),
Expand Down Expand Up @@ -629,5 +694,8 @@ def translate(t, op: ops.ArgMin | ops.ArgMax) -> str:
ops.LStrip: unary(lambda arg: sa.func.ltrim(arg, string.whitespace)),
ops.RStrip: unary(lambda arg: sa.func.rtrim(arg, string.whitespace)),
ops.StartsWith: fixed_arity(lambda arg, prefix: arg.op("^@")(prefix), 2),
ops.Arbitrary: _arbitrary,
ops.StructColumn: _struct_column,
ops.StructField: _struct_field,
}
)
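A combined sketch of the new struct and unnest rules (the table and its schema are hypothetical):

import ibis

t = ibis.table(dict(s="struct<a: int64, b: string>", xs="array<int64>"), name="t")  # hypothetical

t.s["a"]       # struct field access compiles through the struct_field rule as (s).f1
t.xs.unnest()  # arrays are expanded through the table-valued _unnest rule above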
Original file line number Diff line number Diff line change
@@ -1 +1 @@
WITH anon_1 AS (SELECT t0.string_col AS string_col, sum(t0.double_col) AS metric FROM functional_alltypes AS t0 GROUP BY 1), anon_2 AS (SELECT t0.string_col AS string_col, sum(t0.double_col) AS metric FROM functional_alltypes AS t0 GROUP BY 1), anon_3 AS (SELECT t0.string_col AS string_col, sum(t0.double_col) AS metric FROM functional_alltypes AS t0 GROUP BY 1) SELECT anon_1.string_col, anon_1.metric FROM anon_1 UNION ALL SELECT anon_2.string_col, anon_2.metric FROM anon_2 UNION ALL SELECT anon_3.string_col, anon_3.metric FROM anon_3
WITH anon_2 AS (SELECT t2.string_col AS string_col, sum(t2.double_col) AS metric FROM functional_alltypes AS t2 GROUP BY 1), anon_3 AS (SELECT t3.string_col AS string_col, sum(t3.double_col) AS metric FROM functional_alltypes AS t3 GROUP BY 1), anon_1 AS (SELECT t2.string_col AS string_col, t2.metric AS metric FROM (SELECT anon_2.string_col AS string_col, anon_2.metric AS metric FROM anon_2 UNION ALL SELECT anon_3.string_col AS string_col, anon_3.metric AS metric FROM anon_3) AS t2), anon_4 AS (SELECT t3.string_col AS string_col, sum(t3.double_col) AS metric FROM functional_alltypes AS t3 GROUP BY 1) SELECT t1.string_col, t1.metric FROM (SELECT anon_1.string_col AS string_col, anon_1.metric AS metric FROM anon_1 UNION ALL SELECT anon_4.string_col AS string_col, anon_4.metric AS metric FROM anon_4) AS t1
Original file line number Diff line number Diff line change
@@ -1 +1 @@
WITH anon_1 AS (SELECT t0.string_col AS string_col, sum(t0.double_col) AS metric FROM functional_alltypes AS t0 GROUP BY 1), anon_2 AS (SELECT t0.string_col AS string_col, sum(t0.double_col) AS metric FROM functional_alltypes AS t0 GROUP BY 1), anon_3 AS (SELECT t0.string_col AS string_col, sum(t0.double_col) AS metric FROM functional_alltypes AS t0 GROUP BY 1) SELECT anon_1.string_col, anon_1.metric FROM anon_1 UNION SELECT anon_2.string_col, anon_2.metric FROM anon_2 UNION SELECT anon_3.string_col, anon_3.metric FROM anon_3
WITH anon_2 AS (SELECT t2.string_col AS string_col, sum(t2.double_col) AS metric FROM functional_alltypes AS t2 GROUP BY 1), anon_3 AS (SELECT t3.string_col AS string_col, sum(t3.double_col) AS metric FROM functional_alltypes AS t3 GROUP BY 1), anon_1 AS (SELECT t2.string_col AS string_col, t2.metric AS metric FROM (SELECT anon_2.string_col AS string_col, anon_2.metric AS metric FROM anon_2 UNION SELECT anon_3.string_col AS string_col, anon_3.metric AS metric FROM anon_3) AS t2), anon_4 AS (SELECT t3.string_col AS string_col, sum(t3.double_col) AS metric FROM functional_alltypes AS t3 GROUP BY 1) SELECT t1.string_col, t1.metric FROM (SELECT anon_1.string_col AS string_col, anon_1.metric AS metric FROM anon_1 UNION SELECT anon_4.string_col AS string_col, anon_4.metric AS metric FROM anon_4) AS t1
10 changes: 8 additions & 2 deletions ibis/backends/postgres/tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ def test_schema_type_conversion():
def test_interval_films_schema(con):
t = con.table("films")
assert t.len.type() == dt.Interval(unit="m")
assert t.len.execute().dtype == np.dtype("timedelta64[ns]")
assert issubclass(t.len.execute().dtype.type, np.timedelta64)


@pytest.mark.parametrize(
Expand All @@ -129,7 +129,7 @@ def test_all_interval_types_execute(intervals, column, expected_dtype):
assert expr.type() == expected_dtype

series = expr.execute()
assert series.dtype == np.dtype("timedelta64[ns]")
assert issubclass(series.dtype.type, np.timedelta64)


@pytest.mark.xfail(
Expand Down Expand Up @@ -207,3 +207,9 @@ def test_get_schema_from_query(con, pg_type, expected_type):
expected_schema = ibis.schema(dict(x=expected_type, y=dt.Array(expected_type)))
result_schema = con._get_schema_using_query(f"SELECT x, y FROM {name}")
assert result_schema == expected_schema


@pytest.mark.parametrize("col", ["search", "simvec"])
def test_unknown_column_type(con, col):
awards_players = con.table("awards_players")
assert awards_players[col].type().is_unknown()
40 changes: 18 additions & 22 deletions ibis/backends/postgres/tests/test_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ def test_timestamp_cast_noop(alltypes, at, translate):
assert isinstance(result2, ir.TimestampColumn)

expected1 = at.c.timestamp_col
expected2 = sa.func.to_timestamp(at.c.int_col)
expected2 = sa.cast(sa.func.to_timestamp(at.c.int_col), sa.TIMESTAMP())

assert str(translate(result1.op())) == str(expected1)
assert str(translate(result2.op())) == str(expected2)
Expand Down Expand Up @@ -743,9 +743,7 @@ def test_simple_window(alltypes, func, df):
t = alltypes
f = getattr(t.double_col, func)
df_f = getattr(df.double_col, func)
result = (
t.projection([(t.double_col - f()).name('double_col')]).execute().double_col
)
result = t.select((t.double_col - f()).name('double_col')).execute().double_col
expected = df.double_col - df_f()
tm.assert_series_equal(result, expected)

Expand All @@ -761,7 +759,7 @@ def test_rolling_window(alltypes, func, df):
window = ibis.window(order_by=t.timestamp_col, preceding=6, following=0)
f = getattr(t.double_col, func)
df_f = getattr(df.double_col.rolling(7, min_periods=0), func)
result = t.projection([f().over(window).name('double_col')]).execute().double_col
result = t.select(f().over(window).name('double_col')).execute().double_col
expected = df_f()
tm.assert_series_equal(result, expected)

Expand Down Expand Up @@ -797,7 +795,7 @@ def rolled(df):

f = getattr(t.double_col, func)
expr = f().over(window).name('double_col')
result = t.projection([expr]).execute().double_col
result = t.select(expr).execute().double_col
expected = df.groupby('string_col').apply(roller(func)).reset_index(drop=True)
tm.assert_series_equal(result, expected)

Expand All @@ -807,7 +805,7 @@ def test_cumulative_simple_window(alltypes, func, df):
t = alltypes
f = getattr(t.double_col, func)
col = t.double_col - f().over(ibis.cumulative_window())
expr = t.projection([col.name('double_col')])
expr = t.select(col.name('double_col'))
result = expr.execute().double_col
expected = df.double_col - getattr(df.double_col, 'cum%s' % func)()
tm.assert_series_equal(result, expected)
Expand All @@ -819,7 +817,7 @@ def test_cumulative_partitioned_window(alltypes, func, df):
df = df.sort_values('string_col').reset_index(drop=True)
window = ibis.cumulative_window(group_by=t.string_col)
f = getattr(t.double_col, func)
expr = t.projection([(t.double_col - f().over(window)).name('double_col')])
expr = t.select((t.double_col - f().over(window)).name('double_col'))
result = expr.execute().double_col
expected = df.groupby(df.string_col).double_col.transform(
lambda c: c - getattr(c, 'cum%s' % func)()
Expand All @@ -833,7 +831,7 @@ def test_cumulative_ordered_window(alltypes, func, df):
df = df.sort_values('timestamp_col').reset_index(drop=True)
window = ibis.cumulative_window(order_by=t.timestamp_col)
f = getattr(t.double_col, func)
expr = t.projection([(t.double_col - f().over(window)).name('double_col')])
expr = t.select((t.double_col - f().over(window)).name('double_col'))
result = expr.execute().double_col
expected = df.double_col - getattr(df.double_col, 'cum%s' % func)()
tm.assert_series_equal(result, expected)
Expand All @@ -845,7 +843,7 @@ def test_cumulative_partitioned_ordered_window(alltypes, func, df):
df = df.sort_values(['string_col', 'timestamp_col']).reset_index(drop=True)
window = ibis.cumulative_window(order_by=t.timestamp_col, group_by=t.string_col)
f = getattr(t.double_col, func)
expr = t.projection([(t.double_col - f().over(window)).name('double_col')])
expr = t.select((t.double_col - f().over(window)).name('double_col'))
result = expr.execute().double_col
method = operator.methodcaller(f'cum{func}')
expected = df.groupby(df.string_col).double_col.transform(lambda c: c - method(c))
Expand Down Expand Up @@ -931,12 +929,10 @@ def array_types(con):


def test_array_length(array_types):
expr = array_types.projection(
[
array_types.x.length().name('x_length'),
array_types.y.length().name('y_length'),
array_types.z.length().name('z_length'),
]
expr = array_types.select(
array_types.x.length().name('x_length'),
array_types.y.length().name('y_length'),
array_types.z.length().name('z_length'),
)
result = expr.execute()
expected = pd.DataFrame(
Expand Down Expand Up @@ -995,7 +991,7 @@ def test_array_index(array_types, index):
],
)
def test_array_repeat(array_types, n, mul):
expr = array_types.projection([mul(array_types.x, n).name('repeated')])
expr = array_types.select(mul(array_types.x, n).name('repeated'))
result = expr.execute()
expected = pd.DataFrame(
{'repeated': array_types.x.execute().map(lambda x, n=n: mul(x, n))}
Expand All @@ -1013,9 +1009,9 @@ def test_array_repeat(array_types, n, mul):
def test_array_concat(array_types, catop):
t = array_types
x, y = t.x.cast('array<string>').name('x'), t.y
expr = t.projection([catop(x, y).name('catted')])
expr = t.select(catop(x, y).name('catted'))
result = expr.execute()
tuples = t.projection([x, y]).execute().itertuples(index=False)
tuples = t.select(x, y).execute().itertuples(index=False)
expected = pd.DataFrame({'catted': [catop(i, j) for i, j in tuples]})
tm.assert_frame_equal(result, expected)

Expand Down Expand Up @@ -1159,7 +1155,7 @@ def test_ntile(con):
def test_not_and_negate_bool(con, opname, df):
op = getattr(operator, opname)
t = con.table('functional_alltypes').limit(10)
expr = t.projection([op(t.bool_col).name('bool_col')])
expr = t.select(op(t.bool_col).name('bool_col'))
result = expr.execute().bool_col
expected = op(df.head(10).bool_col)
tm.assert_series_equal(result, expected)
Expand All @@ -1180,15 +1176,15 @@ def test_not_and_negate_bool(con, opname, df):
)
def test_negate_non_boolean(con, field, df):
t = con.table('functional_alltypes').limit(10)
expr = t.projection([(-t[field]).name(field)])
expr = t.select((-t[field]).name(field))
result = expr.execute()[field]
expected = -df.head(10)[field]
tm.assert_series_equal(result, expected)


def test_negate_boolean(con, df):
t = con.table('functional_alltypes').limit(10)
expr = t.projection([(-t.bool_col).name('bool_col')])
expr = t.select((-t.bool_col).name('bool_col'))
result = expr.execute().bool_col
expected = -df.head(10).bool_col
tm.assert_series_equal(result, expected)
Expand Down
2 changes: 1 addition & 1 deletion ibis/backends/postgres/tests/test_geospatial.py
Original file line number Diff line number Diff line change
Expand Up @@ -518,7 +518,7 @@ def test_geo_literals_smoke(con, shape, value, modifier, expected):
)
def test_geo_ops_smoke(geotable, fn_expr):
"""Smoke tests for geo spatial operations."""
assert fn_expr(geotable).compile() != ''
assert str(fn_expr(geotable).compile())


def test_geo_equals(geotable):
Expand Down
7 changes: 0 additions & 7 deletions ibis/backends/postgres/tests/test_string.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
from pytest import param

import ibis
import ibis.expr.datatypes as dt


@pytest.mark.parametrize(
Expand All @@ -17,9 +16,3 @@ def test_special_strings(alltypes, data, data_type):
expr = alltypes[[alltypes.id, lit]].head(1)
df = expr.execute()
assert df['tmp'].iloc[0] == uuid.UUID(data)


def test_load_tsvector_table(con):
awards_players = con.table("awards_players")
assert "search" in awards_players.columns
assert awards_players.schema()["search"] == dt.String(nullable=True)
28 changes: 28 additions & 0 deletions ibis/backends/pyarrow/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
from __future__ import annotations

from typing import TYPE_CHECKING

import pyarrow as pa

import ibis.expr.schema as sch
from ibis.expr.operations.relations import TableProxy

if TYPE_CHECKING:
import pandas as pd


class PyArrowTableProxy(TableProxy):
__slots__ = ()

def to_frame(self) -> pd.DataFrame:
return self._data.to_pandas()

def to_pyarrow(self, _: sch.Schema) -> pa.Table:
return self._data


@sch.infer.register(pa.Table)
def infer_pyarrow_table_schema(t: pa.Table, schema=None):
import ibis.backends.pyarrow.datatypes # noqa: F401

return sch.schema(schema if schema is not None else t.schema)
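A quick sketch of the new inference hook:

import pyarrow as pa

import ibis.backends.pyarrow  # noqa: F401  (registers the pa.Table inference shown above)
import ibis.expr.schema as sch

t = pa.table({"a": [1, 2, 3], "b": ["x", "y", "z"]})
sch.infer(t)  # -> ibis schema with a: int64, b: string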
2 changes: 2 additions & 0 deletions ibis/backends/pyarrow/datatypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@
dt.Date: pa.date64(),
dt.JSON: pa.string(),
dt.Null: pa.null(),
# assume unknown types can be converted into strings
dt.Unknown: pa.string(),
}


Expand Down