107 changes: 48 additions & 59 deletions ibis/backends/base/sql/compiler/translator.py
@@ -1,11 +1,12 @@
from __future__ import annotations

import itertools
import operator
from typing import Callable, Dict
from typing import Callable, Iterator

import ibis
import ibis.common.exceptions as com
import ibis.expr.analytics as analytics
import ibis.expr.datatypes as dt
import ibis.expr.format as fmt
import ibis.expr.operations as ops
import ibis.expr.types as ir
from ibis.backends.base.sql.registry import (
@@ -22,22 +23,15 @@ class QueryContext:
parameters are tracked here.
"""

def __init__(
self, compiler, indent=2, parent=None, memo=None, params=None
):
def __init__(self, compiler, indent=2, parent=None, params=None):
self.compiler = compiler
self._table_refs = {}
self.table_refs = {}
self.extracted_subexprs = set()
self.subquery_memo = {}
self.indent = indent
self.parent = parent

self.always_alias = False

self.query = None

self._table_key_memo = {}
self.memo = memo or fmt.FormatMemo()
self.params = params if params is not None else {}

def _compile_subquery(self, expr):
@@ -74,8 +68,10 @@ def get_compiled_expr(self, expr):
this = self.top_context

key = self._get_table_key(expr)
if key in this.subquery_memo:
try:
return this.subquery_memo[key]
except KeyError:
pass

op = expr.op()
if isinstance(op, ops.SQLQueryResult):
@@ -87,46 +83,64 @@ def get_compiled_expr(self, expr):
return result

def make_alias(self, expr):
i = len(self._table_refs)
i = len(self.table_refs)

key = self._get_table_key(expr)

# Get total number of aliases up and down the tree at this point; if we
# find the table prior-aliased along the way, however, we reuse that
# alias
ctx = self
while ctx.parent is not None:
ctx = ctx.parent

if key in ctx._table_refs:
alias = ctx._table_refs[key]
for ctx in itertools.islice(self._contexts(), 1, None):
try:
alias = ctx.table_refs[key]
except KeyError:
pass
else:
self.set_ref(expr, alias)
return

i += len(ctx._table_refs)
i += len(ctx.table_refs)

alias = f't{i:d}'
self.set_ref(expr, alias)

def need_aliases(self, expr=None):
return self.always_alias or len(self._table_refs) > 1
return self.always_alias or len(self.table_refs) > 1

def _contexts(
self,
*,
parents: bool = True,
) -> Iterator[QueryContext]:
ctx = self
yield ctx
while parents and ctx.parent is not None:
ctx = ctx.parent
yield ctx
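For reference, the alias bookkeeping above can be read in isolation: `make_alias` walks the chain of contexts that `_contexts` yields (the current context first, then each parent), reuses an alias if any ancestor already registered the table, and otherwise numbers the new alias by the total count of refs along the chain. A minimal standalone sketch of that pattern, using toy classes rather than the real `QueryContext`:

import itertools

class Ctx:
    def __init__(self, parent=None):
        self.parent = parent
        self.table_refs = {}

    def _contexts(self, *, parents=True):
        ctx = self
        yield ctx
        while parents and ctx.parent is not None:
            ctx = ctx.parent
            yield ctx

    def make_alias(self, key):
        i = len(self.table_refs)
        # skip `self`, then look for a prior alias in the ancestors
        for ctx in itertools.islice(self._contexts(), 1, None):
            if key in ctx.table_refs:
                self.table_refs[key] = ctx.table_refs[key]
                return
            i += len(ctx.table_refs)
        self.table_refs[key] = f't{i:d}'

root = Ctx()
child = Ctx(parent=root)
root.make_alias('a')   # root registers 'a' as t0
child.make_alias('a')  # child finds 'a' in its parent and reuses t0
child.make_alias('b')  # counts refs in child (1) and root (1), so 'b' becomes t2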

def has_ref(self, expr, parent_contexts=False):
key = self._get_table_key(expr)
return self._key_in(
key, '_table_refs', parent_contexts=parent_contexts
return any(
key in ctx.table_refs
for ctx in self._contexts(parents=parent_contexts)
)

def set_ref(self, expr, alias):
key = self._get_table_key(expr)
self._table_refs[key] = alias
self.table_refs[key] = alias

def get_ref(self, expr):
"""
Get the alias being used throughout a query to refer to a particular
table or inline view
"""
return self._get_table_item('_table_refs', expr)
key = self._get_table_key(expr)
top = self.top_context

if self.is_extracted(expr):
return top.table_refs.get(key)

return self.table_refs.get(key)

def is_extracted(self, expr):
key = self._get_table_key(expr)
@@ -138,7 +152,7 @@ def set_extracted(self, expr):
self.make_alias(expr)

def subcontext(self):
return type(self)(
return self.__class__(
compiler=self.compiler,
indent=self.indent,
parent=self,
@@ -162,37 +176,12 @@ def is_foreign_expr(self, expr):
validator = ExprValidator(exprs)
return not validator.validate(expr)

def _get_table_item(self, item, expr):
key = self._get_table_key(expr)
top = self.top_context

if self.is_extracted(expr):
return getattr(top, item).get(key)

return getattr(self, item).get(key)

def _get_table_key(self, table):
if isinstance(table, ir.TableExpr):
table = table.op()

try:
return self._table_key_memo[table]
except KeyError:
val = table._repr()
self._table_key_memo[table] = val
return val

def _key_in(self, key, memo_attr, parent_contexts=False):
if key in getattr(self, memo_attr):
return True

ctx = self
while parent_contexts and ctx.parent is not None:
ctx = ctx.parent
if key in getattr(ctx, memo_attr):
return True

return False
return table.op()
elif isinstance(table, ops.TableNode):
return table
raise TypeError(f"invalid table expression: {type(table)}")


class ExprTranslator:
@@ -202,7 +191,7 @@ class ExprTranslator:
"""

_registry = operation_registry
_rewrites: Dict[ops.Node, Callable] = {}
_rewrites: dict[ops.Node, Callable] = {}

def __init__(self, expr, context, named=False, permit_subquery=False):
self.expr = expr
@@ -298,7 +287,7 @@ def decorator(f):
rewrites = ExprTranslator.rewrites


@rewrites(analytics.Bucket)
@rewrites(ops.Bucket)
def _bucket(expr):
op = expr.op()
stmt = ibis.case()
@@ -345,7 +334,7 @@ def _bucket(expr):
return stmt.end().name(expr._name)


@rewrites(analytics.CategoryLabel)
@rewrites(ops.CategoryLabel)
def _category_label(expr):
op = expr.op()

5 changes: 1 addition & 4 deletions ibis/backends/base/sql/registry/window.py
@@ -7,7 +7,6 @@
import ibis.expr.datatypes as dt
import ibis.expr.operations as ops
import ibis.expr.types as ir
from ibis.expr.signature import Argument

_map_interval_to_microseconds = {
'W': 604800000000,
@@ -42,9 +41,7 @@
}


def _replace_interval_with_scalar(
expr: Union[ir.Expr, dt.Interval, float]
) -> Union[ir.Expr, float, Argument]:
def _replace_interval_with_scalar(expr: Union[ir.Expr, dt.Interval, float]):
"""
Good old Depth-First Search to identify the Interval and IntervalValue
components of the expression and return a comparable scalar expression.
6 changes: 2 additions & 4 deletions ibis/backends/clickhouse/__init__.py
@@ -29,7 +29,7 @@ class Backend(BaseSQLBackend):
table_expr_class = ClickhouseTable
compiler = ClickhouseCompiler

def connect(
def do_connect(
self,
host='localhost',
port=9000,
@@ -80,8 +80,7 @@ def connect(
-------
ClickhouseClient
"""
new_backend = self.__class__()
new_backend.con = _DriverClient(
self.con = _DriverClient(
host=host,
port=port,
database=database,
@@ -90,7 +89,6 @@ def connect(
client_name=client_name,
compression=compression,
)
return new_backend
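This `connect` -> `do_connect` split (the same change is made for the Dask and Impala backends below) moves instance construction out of the individual backends. A hedged sketch of the base-class pattern the refactor relies on — not the exact ibis implementation — looks like:

class BaseBackendSketch:
    """Toy stand-in for ibis's BaseBackend, for illustration only."""

    def do_connect(self, *args, **kwargs):
        raise NotImplementedError

    def connect(self, *args, **kwargs):
        # the user-facing entry point builds a fresh backend instance and
        # delegates the backend-specific setup to do_connect, so calls like
        # ibis.clickhouse.connect(...) keep working unchanged
        new_backend = self.__class__()
        new_backend.do_connect(*args, **kwargs)
        return new_backend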

def register_options(self):
ibis.config.register_option(
28 changes: 28 additions & 0 deletions ibis/backends/clickhouse/compiler.py
@@ -78,6 +78,34 @@ def _floor_divide(expr):
return left.div(right).floor()


@rewrites(ops.DayOfWeekName)
def day_of_week_name(expr):
# ClickHouse 20 doesn't support dateName
#
# ClickHouse 21 supports dateName, but it is broken for regexen:
# https://github.com/ClickHouse/ClickHouse/issues/32777
#
# ClickHouse 20 and 21 also have a broken case statement, hence the nullif:
# https://github.com/ClickHouse/ClickHouse/issues/32849
#
# We test against 20 in CI, so we implement day_of_week_name as follows
return (
expr.op()
.arg.day_of_week.index()
.case()
.when(0, "Monday")
.when(1, "Tuesday")
.when(2, "Wednesday")
.when(3, "Thursday")
.when(4, "Friday")
.when(5, "Saturday")
.when(6, "Sunday")
.else_("")
.end()
.nullif("")
)
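As a usage note, this rewrite should kick in for the user-facing day-of-week name API (presumably `.day_of_week.full_name()`, which produces `ops.DayOfWeekName`); the connection details below are illustrative only:

import ibis

con = ibis.clickhouse.connect(host='localhost')  # illustrative connection
t = con.table('functional_alltypes')
expr = t[t.timestamp_col.day_of_week.full_name().name('dow')]
# compiles to a CASE expression over the day-of-week index instead of dateName()
print(ibis.clickhouse.compile(expr))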


class ClickhouseCompiler(Compiler):
translator_class = ClickhouseExprTranslator
table_set_formatter_class = ClickhouseTableSetFormatter
8 changes: 8 additions & 0 deletions ibis/backends/clickhouse/registry.py
@@ -671,6 +671,13 @@ def _zero_if_null(translator, expr):
return f'ifNull({arg_}, 0)'


def _day_of_week_index(translator, expr):
(arg,) = expr.op().args
weekdays = 7
offset = f"toDayOfWeek({translator.translate(arg)})"
return f"((({offset} - 1) % {weekdays:d}) + {weekdays:d}) % {weekdays:d}"
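The double modulo bridges two conventions: ClickHouse's `toDayOfWeek` is 1-based with Monday == 1, while the rewrite above (and pandas) treat Monday as index 0. A quick standalone check of the arithmetic, no ClickHouse needed:

weekdays = 7
for clickhouse_dow in range(1, 8):  # toDayOfWeek: Monday=1 .. Sunday=7
    ibis_index = (((clickhouse_dow - 1) % weekdays) + weekdays) % weekdays
    print(clickhouse_dow, ibis_index)  # prints 1 0, 2 1, ..., 7 6 (Monday=0 .. Sunday=6)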


_undocumented_operations = {
ops.NullLiteral: _null_literal, # undocumented
ops.IsNull: _unary('isNull'),
@@ -680,6 +687,7 @@ def _zero_if_null(translator, expr):
ops.Coalesce: _varargs('coalesce'),
ops.NullIfZero: _null_if_zero,
ops.ZeroIfNull: _zero_if_null,
ops.DayOfWeekIndex: _day_of_week_index,
}


4 changes: 3 additions & 1 deletion ibis/backends/clickhouse/tests/test_aggregations.py
@@ -180,7 +180,9 @@ def test_anonymus_aggregate(alltypes, df, translate):


def test_boolean_summary(alltypes):
expr = alltypes.bool_col.summary()
bool_col_summary = alltypes.bool_col.summary()
expr = alltypes.aggregate(bool_col_summary)

result = expr.execute()
expected = pd.DataFrame(
[[7300, 0, 0, 1, 3650, 0.5, 2]],
16 changes: 0 additions & 16 deletions ibis/backends/clickhouse/tests/test_client.py
@@ -4,10 +4,8 @@
import pandas.testing as tm
import pytest

import ibis
import ibis.config as config
import ibis.expr.types as ir
from ibis import literal as L


def test_get_table_ref(db):
@@ -135,20 +133,6 @@ def test_table_info(alltypes):
assert buf.getvalue() is not None


def test_execute_exprs_no_table_ref(con):
cases = [(L(1) + L(2), 3)]

for expr, expected in cases:
result = con.execute(expr)
assert result == expected

# ExprList
exlist = ibis.api.expr_list(
[L(1).name('a'), ibis.now().name('b'), L(2).log().name('c')]
)
con.execute(exlist)


@pytest.mark.skip(reason="FIXME: it is raising KeyError: 'Unnamed: 0'")
def test_insert(con, alltypes, df):
drop = 'DROP TABLE IF EXISTS temporary_alltypes'
43 changes: 28 additions & 15 deletions ibis/backends/clickhouse/tests/test_functions.py
@@ -398,21 +398,34 @@ def test_translate_math_functions(con, alltypes, translate, call, expected):
@pytest.mark.parametrize(
('expr', 'expected'),
[
(L(-5).abs(), 5),
(L(5).abs(), 5),
(L(5.5).round(), 6.0),
(L(5.556).round(2), 5.56),
(L(5.556).ceil(), 6.0),
(L(5.556).floor(), 5.0),
(L(5.556).exp(), math.exp(5.556)),
(L(5.556).sign(), 1),
(L(-5.556).sign(), -1),
(L(0).sign(), 0),
(L(5.556).sqrt(), math.sqrt(5.556)),
(L(5.556).log(2), math.log(5.556, 2)),
(L(5.556).ln(), math.log(5.556)),
(L(5.556).log2(), math.log(5.556, 2)),
(L(5.556).log10(), math.log10(5.556)),
pytest.param(L(-5).abs(), 5, id="abs_neg"),
pytest.param(L(5).abs(), 5, id="abs"),
pytest.param(L(5.5).round(), 6.0, id="round"),
pytest.param(L(5.556).round(2), 5.56, id="round_places"),
pytest.param(L(5.556).ceil(), 6.0, id="ceil"),
pytest.param(L(5.556).floor(), 5.0, id="floor"),
pytest.param(L(5.556).sign(), 1, id="sign"),
pytest.param(L(-5.556).sign(), -1, id="sign_neg"),
pytest.param(L(0).sign(), 0, id="sign_zero"),
pytest.param(L(5.556).sqrt(), math.sqrt(5.556), id="sqrt"),
pytest.param(L(5.556).log(2), math.log(5.556, 2), id="log2_arg"),
pytest.param(L(5.556).log2(), math.log(5.556, 2), id="log2"),
pytest.param(L(5.556).log10(), math.log10(5.556), id="log10"),
# clickhouse has different functions for exp/ln that are faster
# than the defaults, but less precise
#
# we can't use the e() function as it still gives different results
# from `math.exp`
pytest.param(
L(5.556).exp().round(8),
round(math.exp(5.556), 8),
id="exp",
),
pytest.param(
L(5.556).ln().round(7),
round(math.log(5.556), 7),
id="ln",
),
],
)
def test_math_functions(con, expr, expected, translate):
5 changes: 3 additions & 2 deletions ibis/backends/clickhouse/tests/test_operators.py
@@ -162,8 +162,9 @@ def test_string_temporal_compare_between_datetimes(con, left, right):

@pytest.mark.parametrize('container', [list, tuple, set])
def test_field_in_literals(con, alltypes, translate, container):
foobar = container(['foo', 'bar', 'baz'])
expected = tuple(set(foobar))
values = {'foo', 'bar', 'baz'}
foobar = container(values)
expected = tuple(values)

expr = alltypes.string_col.isin(foobar)
assert translate(expr) == f"`string_col` IN {expected}"
20 changes: 3 additions & 17 deletions ibis/backends/clickhouse/tests/test_select.py
@@ -48,20 +48,20 @@ def test_timestamp_extract_field(con, db, alltypes):


def test_isin_notin_in_select(con, db, alltypes, translate):
values = ['foo', 'bar']
values = {'foo', 'bar'}
filtered = alltypes[alltypes.string_col.isin(values)]
result = ibis.clickhouse.compile(filtered)
expected = """SELECT *
FROM {}.`functional_alltypes`
WHERE `string_col` IN {}"""
assert result == expected.format(db.name, tuple(set(values)))
assert result == expected.format(db.name, tuple(values))

filtered = alltypes[alltypes.string_col.notin(values)]
result = ibis.clickhouse.compile(filtered)
expected = """SELECT *
FROM {}.`functional_alltypes`
WHERE `string_col` NOT IN {}"""
assert result == expected.format(db.name, tuple(set(values)))
assert result == expected.format(db.name, tuple(values))


def test_head(alltypes):
@@ -211,20 +211,6 @@ def test_scalar_exprs_no_table_refs(expr, expected):
assert ibis.clickhouse.compile(expr) == expected


def test_expr_list_no_table_refs():
exlist = ibis.api.expr_list(
[
ibis.literal(1).name('a'),
ibis.now().name('b'),
ibis.literal(2).log().name('c'),
]
)
result = ibis.clickhouse.compile(exlist)
expected = """\
SELECT 1 AS `a`, now() AS `b`, log(2) AS `c`"""
assert result == expected


# TODO: use alltypes
def test_isnull_case_expr_rewrite_failure(db, alltypes):
# #172, case expression that was not being properly converted into an
10 changes: 6 additions & 4 deletions ibis/backends/dask/__init__.py
@@ -6,6 +6,9 @@
import toolz
from dask.base import DaskMethodsMixin

# import the pandas execution module to register dispatched implementations of
# execute_node that the dask backend will later override
import ibis.backends.pandas.execution # noqa: F401
import ibis.common.exceptions as com
import ibis.config
import ibis.expr.schema as sch
@@ -15,8 +18,7 @@
from .client import DaskDatabase, DaskTable, ibis_schema_to_dask
from .core import execute_and_reset

# Make sure that the pandas backend is loaded, dispatching has been
# executed, and options have been loaded
# Make sure that the pandas backend options have been loaded
ibis.pandas


@@ -25,11 +27,11 @@ class Backend(BasePandasBackend):
database_class = DaskDatabase
table_class = DaskTable

def connect(self, dictionary):
def do_connect(self, dictionary):
# register dispatchers
from . import udf # noqa: F401

return super().connect(dictionary)
super().do_connect(dictionary)

@property
def version(self):
30 changes: 0 additions & 30 deletions ibis/backends/dask/client.py
@@ -1,8 +1,6 @@
"""The dask client implementation."""
from functools import partial

import dask.dataframe as dd
import dateutil.parser
import numpy as np
import pandas as pd
from pandas.api.types import DatetimeTZDtype
@@ -15,7 +13,6 @@
PANDAS_DATE_TYPES,
PANDAS_STRING_TYPES,
_inferable_pandas_dtypes,
convert_timezone,
ibis_dtype_to_pandas,
ibis_schema_to_pandas,
)
@@ -84,33 +81,6 @@ def convert_datetimetz_to_timestamp(in_dtype, out_dtype, column):
DASK_DATE_TYPES = PANDAS_DATE_TYPES


@sch.convert.register(np.dtype, dt.Timestamp, dd.Series)
def convert_datetime64_to_timestamp(in_dtype, out_dtype, column):
if in_dtype.type == np.datetime64:
return column.astype(out_dtype.to_dask())
try:
# TODO - check this?
series = pd.to_datetime(column, utc=True)
except pd.errors.OutOfBoundsDatetime:
inferred_dtype = infer_dask_dtype(column, skipna=True)
if inferred_dtype in DASK_DATE_TYPES:
# not great, but not really any other option
return column.map(
partial(convert_timezone, timezone=out_dtype.timezone)
)
if inferred_dtype not in DASK_STRING_TYPES:
raise TypeError(
(
'Conversion to timestamp not supported for Series of type '
'{!r}'
).format(inferred_dtype)
)
return column.map(dateutil.parser.parse)
else:
utc_dtype = DatetimeTZDtype('ns', 'UTC')
return series.astype(utc_dtype).dt.tz_convert(out_dtype.timezone)


@sch.convert.register(np.dtype, dt.Interval, dd.Series)
def convert_any_to_interval(_, out_dtype, column):
return column.values.astype(out_dtype.to_dask())
2 changes: 1 addition & 1 deletion ibis/backends/dask/core.py
@@ -318,7 +318,7 @@ def execute_until_in_scope(
new_scope.get_value(arg.op(), timecontext)
if hasattr(arg, 'op')
else arg
for arg in computable_args
for (arg, timecontext) in zip(computable_args, arg_timecontexts)
]

result = execute_node(
8 changes: 8 additions & 0 deletions ibis/backends/dask/execution/generic.py
@@ -37,6 +37,9 @@
execute_isinf,
execute_isnan,
execute_node_contains_series_sequence,
execute_node_dropna_dataframe,
execute_node_fillna_dataframe_dict,
execute_node_fillna_dataframe_scalar,
execute_node_ifnull_series,
execute_node_not_contains_series_sequence,
execute_node_nullif_series,
@@ -110,6 +113,11 @@
ops.Difference: [
((dd.DataFrame, dd.DataFrame), execute_difference_dataframe_dataframe)
],
ops.DropNa: [((dd.DataFrame,), execute_node_dropna_dataframe)],
ops.FillNa: [
((dd.DataFrame, simple_types), execute_node_fillna_dataframe_scalar),
((dd.DataFrame,), execute_node_fillna_dataframe_dict),
],
ops.IsNull: [((dd.Series,), execute_series_isnull)],
ops.NotNull: [((dd.Series,), execute_series_notnnull)],
ops.IsNan: [((dd.Series,), execute_isnan)],
9 changes: 7 additions & 2 deletions ibis/backends/dask/execution/join.py
@@ -15,6 +15,11 @@
from ..execution import constants


@execute_node.register(ops.MaterializedJoin, dd.DataFrame)
def execute_materialized_join(op, df, **kwargs):
return df


@execute_node.register(
ops.AsOfJoin, dd.DataFrame, dd.DataFrame, (Timedelta, type(None))
)
@@ -69,9 +74,9 @@ def execute_cross_join(op, left, right, **kwargs):
return result


# TODO - execute_materialized_join - #2553
# TODO - execute_join - #2553
@execute_node.register(ops.Join, dd.DataFrame, dd.DataFrame)
def execute_materialized_join(op, left, right, **kwargs):
def execute_join(op, left, right, **kwargs):
op_type = type(op)

try:
17 changes: 0 additions & 17 deletions ibis/backends/dask/execution/reductions.py
@@ -3,7 +3,6 @@
NOTE: This file overwrite the pandas backend registered handlers for:
- execute_node_expr_list,
- execute_node_greatest_list,
- execute_node_least_list
@@ -20,18 +19,14 @@
import dask.dataframe as dd
import dask.dataframe.groupby as ddgb
import numpy as np
import pandas as pd
import toolz

import ibis
import ibis.expr.operations as ops
from ibis.backends.pandas.execution.generic import (
execute_node_expr_list,
execute_node_greatest_list,
execute_node_least_list,
)

from ..core import execute
from ..dispatch import execute_node
from .util import make_selected_obj

@@ -178,15 +173,3 @@ def execute_standard_dev_series(op, data, mask, aggcontext=None, **kwargs):
'std',
ddof=variance_ddof[op.how],
)


@execute_node.register(ops.ExpressionList, collections.abc.Sequence)
def dask_execute_node_expr_list(op, sequence, **kwargs):
if all(type(s) != dd.Series for s in sequence):
execute_node_expr_list(op, sequence, **kwargs)
columns = [e.get_name() for e in op.exprs]
schema = ibis.schema(list(zip(columns, (e.type() for e in op.exprs))))
data = {col: [execute(el, **kwargs)] for col, el in zip(columns, sequence)}
return schema.apply_to(
dd.from_pandas(pd.DataFrame(data, columns=columns), npartitions=1)
)
3 changes: 0 additions & 3 deletions ibis/backends/dask/execution/selection.py
@@ -163,9 +163,6 @@ def execute_selection_dataframe(
op.table.op(), predicates, data, scope, timecontext, **kwargs
)
predicate = functools.reduce(operator.and_, predicates)
assert len(predicate) == len(
result
), 'Selection predicate length does not match underlying table'
result = result.loc[predicate]

if sort_keys:
38 changes: 36 additions & 2 deletions ibis/backends/dask/tests/execution/test_join.py
@@ -237,7 +237,7 @@ def test_multi_join_with_post_expression_filter(how, left, df1):
)


@pytest.mark.xfail(reason="TODO - execute_materialized_join - #2553")
@pytest.mark.xfail(reason="TODO - execute_join - #2553")
@join_type
def test_join_with_non_trivial_key(how, left, right, df1, df2):
# also test that the order of operands in the predicate doesn't matter
@@ -261,7 +261,7 @@ def test_join_with_non_trivial_key(how, left, right, df1, df2):
)


@pytest.mark.xfail(reason="TODO - execute_materialized_join - #2553")
@pytest.mark.xfail(reason="TODO - execute_join - #2553")
@join_type
def test_join_with_non_trivial_key_project_table(how, left, right, df1, df2):
# also test that the order of operands in the predicate doesn't matter
@@ -496,3 +496,37 @@ def test_select_on_unambiguous_asof_join(func, npartitions):
result.compute(scheduler='single-threaded'),
expected.compute(scheduler='single-threaded'),
)


def test_materialized_join(npartitions):
df = dd.from_pandas(
pd.DataFrame({"test": [1, 2, 3], "name": ["a", "b", "c"]}),
npartitions=npartitions,
)
df_2 = dd.from_pandas(
pd.DataFrame({"test_2": [1, 5, 6], "name_2": ["d", "e", "f"]}),
npartitions=npartitions,
)

conn = ibis.dask.connect({"df": df, "df_2": df_2})

ibis_table_1 = conn.table("df")
ibis_table_2 = conn.table("df_2")

joined = ibis_table_1.outer_join(
ibis_table_2,
predicates=ibis_table_1["test"] == ibis_table_2["test_2"],
)
joined = joined.materialize()
result = joined.compile()
expected = dd.merge(
df,
df_2,
left_on="test",
right_on="test_2",
how="outer",
)
tm.assert_frame_equal(
result.compute(scheduler='single-threaded'),
expected.compute(scheduler='single-threaded'),
)
101 changes: 100 additions & 1 deletion ibis/backends/dask/tests/execution/test_timecontext.py
@@ -1,11 +1,25 @@
from typing import Optional

import dask.dataframe as dd
import pytest
from dask.dataframe.utils import tm
from pandas import Timedelta, Timestamp

import ibis
import ibis.common.exceptions as com
from ibis.expr.timecontext import TimeContextRelation, compare_timecontext
import ibis.expr.operations as ops
from ibis.backends.pandas.execution import execute
from ibis.expr.scope import Scope
from ibis.expr.timecontext import (
TimeContextRelation,
adjust_context,
compare_timecontext,
)
from ibis.expr.types import TimeContext


class CustomAsOfJoin(ops.AsOfJoin):
pass


def test_execute_with_timecontext(time_table):
@@ -218,3 +232,88 @@ def test_context_adjustment_window_groupby_id(time_table, time_df3):
# result should adjust time context accordingly
result = expr.execute(timecontext=context)
tm.assert_series_equal(result, expected)


def test_adjust_context_scope(time_keyed_left, time_keyed_right):
"""Test that `adjust_context` has access to `scope` by default."""

@adjust_context.register(CustomAsOfJoin)
def adjust_context_custom_asof_join(
op: ops.AsOfJoin,
timecontext: TimeContext,
scope: Optional[Scope] = None,
) -> TimeContext:
"""Confirms that `scope` is passed in."""
assert scope is not None
return timecontext

expr = CustomAsOfJoin(
left=time_keyed_left,
right=time_keyed_right,
predicates='time',
by='key',
tolerance=ibis.interval(days=4),
).to_expr()
expr = expr[time_keyed_left, time_keyed_right.other_value]
context = (Timestamp('20170105'), Timestamp('20170111'))
expr.execute(timecontext=context)


def test_adjust_context_complete_shift(
time_keyed_left,
time_keyed_right,
time_keyed_df1,
time_keyed_df2,
):
"""Test `adjust_context` function that completely shifts the context.
This results in an adjusted context that is NOT a subset of the
original context. This is unlike an `adjust_context` function
that only expands the context.
See #3104
"""

# Create a contrived `adjust_context` function for
# CustomAsOfJoin to mock this.

@adjust_context.register(CustomAsOfJoin)
def adjust_context_custom_asof_join(
op: ops.AsOfJoin,
timecontext: TimeContext,
scope: Optional[Scope] = None,
) -> TimeContext:
"""Shifts both the begin and end in the same direction."""
begin, end = timecontext
timedelta = execute(op.tolerance)
return (begin - timedelta, end - timedelta)

expr = CustomAsOfJoin(
left=time_keyed_left,
right=time_keyed_right,
predicates='time',
by='key',
tolerance=ibis.interval(days=4),
).to_expr()
expr = expr[time_keyed_left, time_keyed_right.other_value]
context = (Timestamp('20170101'), Timestamp('20170111'))
result = expr.execute(timecontext=context)

# Compare with asof_join of manually trimmed tables
# Left table: No shift for context
# Right table: Shift both begin and end of context by 4 days
trimmed_df1 = time_keyed_df1[time_keyed_df1['time'] >= context[0]][
time_keyed_df1['time'] < context[1]
]
trimmed_df2 = time_keyed_df2[
time_keyed_df2['time'] >= context[0] - Timedelta(days=4)
][time_keyed_df2['time'] < context[1] - Timedelta(days=4)]
expected = dd.merge_asof(
trimmed_df1,
trimmed_df2,
on='time',
by='key',
tolerance=Timedelta('4D'),
).compute()

tm.assert_frame_equal(result, expected)
6 changes: 0 additions & 6 deletions ibis/backends/dask/tests/test_udf.py
@@ -175,12 +175,6 @@ def test_udf(t, df):
tm.assert_series_equal(result, expected, check_names=False)


def test_elementwise_udf_with_non_vectors(con):
expr = my_add(1.0, 2.0)
result = con.execute(expr)
assert result == 3.0


def test_multiple_argument_udf(con, t, df):
expr = my_add(t.b, t.c)

2 changes: 1 addition & 1 deletion ibis/backends/dask/trace.py
@@ -18,7 +18,7 @@
import ibis.expr.datatypes as dt
import ibis.dask
from ibis.udf.vectorized import elementwise
from ibis.dask import trace
from ibis.backends.dask import trace
logging.basicConfig()
trace.enable()
df = dd.from_pandas(
24 changes: 11 additions & 13 deletions ibis/backends/dask/udf.py
@@ -26,8 +26,8 @@ def make_struct_op_meta(op: ir.Expr) -> List[Tuple[str, np.dtype]]:
"""Unpacks a dt.Struct into a DataFrame meta"""
return list(
zip(
op._output_type.names,
[x.to_dask() for x in op._output_type.types],
op.return_type.names,
[x.to_dask() for x in op.return_type.types],
)
)

@@ -72,16 +72,14 @@ def execute_udf_node(op, *args, **kwargs):
# kwargs here. This is true for all udf execution in this
# file.
# See ibis.udf.vectorized.UserDefinedFunction
if isinstance(op._output_type, dt.Struct):
if isinstance(op.return_type, dt.Struct):
meta = make_struct_op_meta(op)

df = dd.map_partitions(op.func, *args, meta=meta)
return df
else:
name = args[0].name if len(args) == 1 else None
meta = pandas.Series(
[], name=name, dtype=op._output_type.to_dask()
)
meta = pandas.Series([], name=name, dtype=op.return_type.to_dask())
df = dd.map_partitions(op.func, *args, meta=meta)

return df
@@ -124,11 +122,11 @@ def lazy_agg(*series: pandas.Series):
# Depending on the type of operation, lazy_result is a Delayed that
# could become a dd.Series or a dd.core.Scalar
if isinstance(op, ops.AnalyticVectorizedUDF):
if isinstance(op._output_type, dt.Struct):
if isinstance(op.return_type, dt.Struct):
meta = make_struct_op_meta(op)
else:
meta = make_meta_series(
dtype=op._output_type.to_dask(),
dtype=op.return_type.to_dask(),
name=args[0].name,
)
result = dd.from_delayed(lazy_result, meta=meta)
@@ -151,13 +149,13 @@
result = result.repartition(divisions=original_divisions)
else:
# lazy_result is a dd.core.Scalar from an ungrouped reduction
if isinstance(op._output_type, (dt.Array, dt.Struct)):
if isinstance(op.return_type, (dt.Array, dt.Struct)):
# we're outputting a dt.Struct that will need to be destructured
# or an array of an unknown size.
# we compute so we can work with items inside downstream.
result = lazy_result.compute()
else:
output_meta = safe_scalar_type(op._output_type.to_dask())
output_meta = safe_scalar_type(op.return_type.to_dask())
result = dd.from_delayed(
lazy_result, meta=output_meta, verify_meta=False
)
@@ -181,7 +179,7 @@ def execute_reduction_node_groupby(op, *args, aggcontext, **kwargs):
func = op.func
groupings = args[0].index
parent_df = args[0].obj
out_type = op._output_type.to_dask()
out_type = op.return_type.to_dask()

grouped_df = parent_df.groupby(groupings)
col_names = [col._meta._selected_obj.name for col in args]
@@ -223,7 +221,7 @@ def execute_analytic_node_groupby(op, *args, aggcontext, **kwargs):
func = op.func
groupings = args[0].index
parent_df = args[0].obj
out_type = op._output_type.to_dask()
out_type = op.return_type.to_dask()

grouped_df = parent_df.groupby(groupings)
col_names = [col._meta._selected_obj.name for col in args]
@@ -232,7 +230,7 @@ def apply_wrapper(df, apply_func, col_names):
cols = (df[col] for col in col_names)
return apply_func(*cols)

if isinstance(op._output_type, dt.Struct):
if isinstance(op.return_type, dt.Struct):
# with struct output we destruct to a dataframe directly
meta = dd.utils.make_meta(make_struct_op_meta(op))
meta.index.name = parent_df.index.name
173 changes: 173 additions & 0 deletions ibis/backends/datafusion/__init__.py
@@ -0,0 +1,173 @@
from __future__ import annotations

import re
from typing import Mapping

import datafusion as df
import pyarrow as pa

import ibis.common.exceptions as com
import ibis.expr.schema as sch
import ibis.expr.types as ir
from ibis.backends.base import BaseBackend

from .compiler import translate


def _to_pyarrow_table(frame):
batches = frame.collect()
if batches:
return pa.Table.from_batches(batches)
else:
# TODO(kszucs): file a bug to datafusion because the fields'
# nullability from frame.schema() is not always consistent
# with the first record batch's schema
return pa.Table.from_batches(batches, schema=frame.schema())


class Backend(BaseBackend):
name = 'datafusion'
builder = None

@property
def version(self):
try:
import importlib.metadata as importlib_metadata
except ImportError:
# TODO: remove this when Python 3.7 support is dropped
import importlib_metadata
return importlib_metadata.version("datafusion")

def do_connect(self, config):
"""
Create a DataFusionClient for use with Ibis
Parameters
----------
config : DataFusionContext or dict
Returns
-------
DataFusionClient
"""
if isinstance(config, df.ExecutionContext):
self._context = config
else:
self._context = df.ExecutionContext()

for name, path in config.items():
strpath = str(path)
if strpath.endswith('.csv'):
self.register_csv(name, path)
elif strpath.endswith('.parquet'):
self.register_parquet(name, path)
else:
raise ValueError(
"Currently the DataFusion backend only supports CSV "
"files with the extension .csv and Parquet files with "
"the .parquet extension."
)

def current_database(self):
raise NotImplementedError()

def list_databases(self, like: str = None) -> list[str]:
raise NotImplementedError()

def list_tables(self, like: str = None, database: str = None) -> list[str]:
"""List the available tables."""
tables = list(self._context.tables())
if like is not None:
pattern = re.compile(like)
return list(filter(lambda t: pattern.findall(t), tables))
return tables

def table(self, name, schema=None):
"""Get an ibis expression representing a DataFusion table.
Parameters
---------
name
The name of the table to retrieve
schema
An optional schema
Returns
-------
ibis.expr.types.TableExpr
A table expression
"""
catalog = self._context.catalog()
database = catalog.database('public')
table = database.table(name)
schema = sch.infer(table.schema)
return self.table_class(name, schema, self).to_expr()

def register_csv(self, name, path, schema=None):
"""Register a CSV file with with `name` located at `path`.
Parameters
----------
name
The name of the table
path
The path to the CSV file
schema
An optional schema
"""
self._context.register_csv(name, path, schema=schema)

def register_parquet(self, name, path, schema=None):
"""Register a parquet file with with `name` located at `path`.
Parameters
----------
name
The name of the table
path
The path to the parquet file
schema
An optional schema
"""
self._context.register_parquet(name, path, schema=schema)

def execute(
self,
expr: ir.Expr,
params: Mapping[ir.Expr, object] = None,
limit: str = 'default',
**kwargs,
):
if isinstance(expr, ir.TableExpr):
frame = self.compile(expr, params, **kwargs)
table = _to_pyarrow_table(frame)
return table.to_pandas()
elif isinstance(expr, ir.ColumnExpr):
# expression must be named for the projection
expr = expr.name('tmp').to_projection()
frame = self.compile(expr, params, **kwargs)
table = _to_pyarrow_table(frame)
return table['tmp'].to_pandas()
elif isinstance(expr, ir.ScalarExpr):
if expr.op().root_tables():
# there are associated datafusion tables so convert the expr
# to a selection which we can directly convert to a datafusion
# plan
expr = expr.name('tmp').to_projection()
frame = self.compile(expr, params, **kwargs)
else:
# doesn't have any tables associated so create a plan from a
# dummy datafusion table
compiled = self.compile(expr, params, **kwargs)
frame = self._context.empty_table().select(compiled)
table = _to_pyarrow_table(frame)
return table[0][0].as_py()
else:
raise com.IbisError(
f"Cannot execute expression of type: {type(expr)}"
)

def compile(
self, expr: ir.Expr, params: Mapping[ir.Expr, object] = None, **kwargs
):
return translate(expr)
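A short usage sketch of the new backend, mirroring the test fixtures further down; the paths are illustrative and assume CSV files with the extensions that `do_connect` dispatches on:

import ibis

# each entry is registered by extension: .csv via register_csv,
# .parquet via register_parquet
con = ibis.datafusion.connect(
    {
        'batting': 'testing-data/batting.csv',
        'functional_alltypes': 'testing-data/functional_alltypes.csv',
    }
)

t = con.table('functional_alltypes')
expr = t[t.float_col > 0][['int_col', 'float_col']]
df = expr.execute()  # compiled to a DataFusion plan and collected into pandas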
396 changes: 396 additions & 0 deletions ibis/backends/datafusion/compiler.py
@@ -0,0 +1,396 @@
import functools
import operator

import datafusion as df
import datafusion.functions
import pyarrow as pa

import ibis.common.exceptions as com
import ibis.expr.operations as ops
import ibis.expr.types as ir

from .datatypes import to_pyarrow_type


@functools.singledispatch
def translate(expr):
raise NotImplementedError(expr)


@translate.register(ir.Expr)
def expression(expr):
return translate(expr.op(), expr)


@translate.register(ops.Node)
def operation(op, expr):
raise com.OperationNotDefinedError(f'No translation rule for {type(op)}')
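The two registrations above set up the dispatch chain used by every rule in this module: `translate(expr)` unwraps an expression into its op and re-dispatches on the op type, while ops without a rule fall through to `OperationNotDefinedError`. A stripped-down, standalone sketch of the same two-level singledispatch pattern (toy classes, a single argument instead of `(op, expr)`):

import functools

@functools.singledispatch
def render(node):
    raise NotImplementedError(type(node))

class Wrapper:  # plays the role of ir.Expr
    def __init__(self, op):
        self._op = op

    def op(self):
        return self._op

@render.register(Wrapper)
def _render_wrapper(expr):
    return render(expr.op())  # second dispatch, now on the op type

class Add:  # plays the role of an ops.* node
    pass

@render.register(Add)
def _render_add(op):
    return 'left + right'

print(render(Wrapper(Add())))  # -> left + right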


@translate.register(ops.DatabaseTable)
def table(op, expr):
name, _, client = op.args
return client._context.table(name)


@translate.register(ops.Literal)
def literal(op, expr):
if isinstance(op.value, (set, frozenset)):
value = list(op.value)
else:
value = op.value

arrow_type = to_pyarrow_type(op.dtype)
arrow_scalar = pa.scalar(value, type=arrow_type)

return df.literal(arrow_scalar)


@translate.register(ops.Cast)
def cast(op, expr):
arg = translate(op.arg)
typ = to_pyarrow_type(op.to)
return arg.cast(to=typ)


@translate.register(ops.TableColumn)
def column(op, expr):
table_op = op.table.op()

if hasattr(table_op, "name"):
return df.column(f'{table_op.name}."{op.name}"')
else:
return df.column(op.name)


@translate.register(ops.SortKey)
def sort_key(op, expr):
arg = translate(op.expr)
return arg.sort(ascending=op.ascending)


@translate.register(ops.Selection)
def selection(op, expr):
plan = translate(op.table)

selections = []
for expr in op.selections or [op.table]:
# TODO(kszucs) it would be nice if we wouldn't need to handle the
# specific cases in the backend implementations, we could add a
# new operator which retrieves all of the TableExpr columns
# (e.g. Asterisk) so the translate() would handle this
# automatically
if isinstance(expr, ir.TableExpr):
for name in expr.columns:
column = expr.get_column(name)
field = translate(column)
if column.has_name():
field = field.alias(column.get_name())
selections.append(field)
elif isinstance(expr, ir.ValueExpr):
field = translate(expr)
if expr.has_name():
field = field.alias(expr.get_name())
selections.append(field)
else:
raise com.TranslationError(
"DataFusion backend is unable to compile selection with "
f"expression type of {type(expr)}"
)

plan = plan.select(*selections)

if op.predicates:
predicates = map(translate, op.predicates)
predicate = functools.reduce(operator.and_, predicates)
plan = plan.filter(predicate)

if op.sort_keys:
sort_keys = map(translate, op.sort_keys)
plan = plan.sort(*sort_keys)

return plan


@translate.register(ops.Aggregation)
def aggregation(op, expr):
table = translate(op.table)
group_by = [translate(expr) for expr in op.by]

metrics = []
for expr in op.metrics:
agg = translate(expr)
if expr.has_name():
agg = agg.alias(expr.get_name())
metrics.append(agg)

return table.aggregate(group_by, metrics)


@translate.register(ops.Not)
def invert(op, expr):
arg = translate(op.arg)
return ~arg


@translate.register(ops.Abs)
def abs(op, expr):
arg = translate(op.arg)
return df.functions.abs(arg)


@translate.register(ops.Ceil)
def ceil(op, expr):
arg = translate(op.arg)
return df.functions.ceil(arg).cast(pa.int64())


@translate.register(ops.Floor)
def floor(op, expr):
arg = translate(op.arg)
return df.functions.floor(arg).cast(pa.int64())


@translate.register(ops.Round)
def round(op, expr):
arg = translate(op.arg)
if op.digits is not None:
raise com.UnsupportedOperationError(
'Rounding to specific digits is not supported in datafusion'
)
return df.functions.round(arg).cast(pa.int64())


@translate.register(ops.Ln)
def ln(op, expr):
arg = translate(op.arg)
return df.functions.ln(arg)


@translate.register(ops.Log2)
def log2(op, expr):
arg = translate(op.arg)
return df.functions.log2(arg)


@translate.register(ops.Log10)
def log10(op, expr):
arg = translate(op.arg)
return df.functions.log10(arg)


@translate.register(ops.Sqrt)
def sqrt(op, expr):
arg = translate(op.arg)
return df.functions.sqrt(arg)


@translate.register(ops.Strip)
def strip(op, expr):
arg = translate(op.arg)
return df.functions.trim(arg)


@translate.register(ops.LStrip)
def lstrip(op, expr):
arg = translate(op.arg)
return df.functions.ltrim(arg)


@translate.register(ops.RStrip)
def rstrip(op, expr):
arg = translate(op.arg)
return df.functions.rtrim(arg)


@translate.register(ops.Lowercase)
def lower(op, expr):
arg = translate(op.arg)
return df.functions.lower(arg)


@translate.register(ops.Uppercase)
def upper(op, expr):
arg = translate(op.arg)
return df.functions.upper(arg)


@translate.register(ops.Reverse)
def reverse(op, expr):
arg = translate(op.arg)
return df.functions.reverse(arg)


@translate.register(ops.StringLength)
def strlen(op, expr):
arg = translate(op.arg)
return df.functions.character_length(arg)


@translate.register(ops.Capitalize)
def capitalize(op, expr):
arg = translate(op.arg)
return df.functions.initcap(arg)


@translate.register(ops.Substring)
def substring(op, expr):
arg = translate(op.arg)
start = translate(op.start + 1)
length = translate(op.length)
return df.functions.substr(arg, start, length)


@translate.register(ops.RegexExtract)
def regex_extract(op, expr):
arg = translate(op.arg)
pattern = translate(op.pattern)
return df.functions.regexp_match(arg, pattern)


@translate.register(ops.Repeat)
def repeat(op, expr):
arg = translate(op.arg)
times = translate(op.times)
return df.functions.repeat(arg, times)


@translate.register(ops.LPad)
def lpad(op, expr):
arg = translate(op.arg)
length = translate(op.length)
pad = translate(op.pad)
return df.functions.lpad(arg, length, pad)


@translate.register(ops.RPad)
def rpad(op, expr):
arg = translate(op.arg)
length = translate(op.length)
pad = translate(op.pad)
return df.functions.rpad(arg, length, pad)


@translate.register(ops.GreaterEqual)
def ge(op, expr):
return translate(op.left) >= translate(op.right)


@translate.register(ops.LessEqual)
def le(op, expr):
return translate(op.left) <= translate(op.right)


@translate.register(ops.Greater)
def gt(op, expr):
return translate(op.left) > translate(op.right)


@translate.register(ops.Less)
def lt(op, expr):
return translate(op.left) < translate(op.right)


@translate.register(ops.Equals)
def eq(op, expr):
return translate(op.left) == translate(op.right)


@translate.register(ops.NotEquals)
def ne(op, expr):
return translate(op.left) != translate(op.right)


@translate.register(ops.Add)
def add(op, expr):
return translate(op.left) + translate(op.right)


@translate.register(ops.Subtract)
def sub(op, expr):
return translate(op.left) - translate(op.right)


@translate.register(ops.Multiply)
def mul(op, expr):
return translate(op.left) * translate(op.right)


@translate.register(ops.Divide)
def div(op, expr):
return translate(op.left) / translate(op.right)


@translate.register(ops.FloorDivide)
def floordiv(op, expr):
return df.functions.floor(translate(op.left) / translate(op.right))


@translate.register(ops.Modulus)
def mod(op, expr):
return translate(op.left) % translate(op.right)


@translate.register(ops.Sum)
def sum(op, expr):
arg = translate(op.arg)
return df.functions.sum(arg)


@translate.register(ops.Min)
def min(op, expr):
arg = translate(op.arg)
return df.functions.min(arg)


@translate.register(ops.Max)
def max(op, expr):
arg = translate(op.arg)
return df.functions.max(arg)


@translate.register(ops.Mean)
def mean(op, expr):
arg = translate(op.arg)
return df.functions.avg(arg)


def _prepare_contains_options(options):
if isinstance(options, ir.AnyScalar):
# TODO(kszucs): it would be better if we could pass an arrow
# ListScalar to datafusion's in_list function
return [df.literal(v) for v in options.op().value]
else:
return translate(options)


@translate.register(ops.ValueList)
def value_list(op, expr):
return list(map(translate, op.values))


@translate.register(ops.Contains)
def contains(op, expr):
value = translate(op.value)
options = _prepare_contains_options(op.options)
return df.functions.in_list(value, options, negated=False)


@translate.register(ops.NotContains)
def not_contains(op, expr):
value = translate(op.value)
options = _prepare_contains_options(op.options)
return df.functions.in_list(value, options, negated=True)


@translate.register(ops.ElementWiseVectorizedUDF)
def elementwise_udf(op, expr):
udf = df.udf(
op.func,
input_types=list(map(to_pyarrow_type, op.input_type)),
return_type=to_pyarrow_type(op.return_type),
volatility="volatile",
)
args = map(translate, op.func_args)

return udf(*args)
90 changes: 90 additions & 0 deletions ibis/backends/datafusion/datatypes.py
@@ -0,0 +1,90 @@
import functools

import pyarrow as pa

import ibis.common.exceptions as com
import ibis.expr.datatypes as dt
import ibis.expr.schema as sch

# TODO(kszucs): the following conversions are really rudimentary
# we should have a pyarrow backend which would be responsible
# for conversions between ibis types to pyarrow types

# TODO(kszucs): support nested and parametric types
# consolidate with the logic from the parquet backend


_to_ibis_dtypes = {
pa.int8(): dt.Int8,
pa.int16(): dt.Int16,
pa.int32(): dt.Int32,
pa.int64(): dt.Int64,
pa.uint8(): dt.UInt8,
pa.uint16(): dt.UInt16,
pa.uint32(): dt.UInt32,
pa.uint64(): dt.UInt64,
pa.float16(): dt.Float16,
pa.float32(): dt.Float32,
pa.float64(): dt.Float64,
pa.string(): dt.String,
pa.binary(): dt.Binary,
pa.bool_(): dt.Boolean,
}


@dt.dtype.register(pa.DataType)
def from_pyarrow_primitive(arrow_type, nullable=True):
return _to_ibis_dtypes[arrow_type](nullable=nullable)


@dt.dtype.register(pa.TimestampType)
def from_pyarrow_timestamp(arrow_type, nullable=True):
return dt.TimestampType(timezone=arrow_type.tz)


@sch.infer.register(pa.Schema)
def infer_pyarrow_schema(schema):
fields = [(f.name, dt.dtype(f.type, nullable=f.nullable)) for f in schema]
return sch.schema(fields)


_to_pyarrow_types = {
dt.Int8: pa.int8(),
dt.Int16: pa.int16(),
dt.Int32: pa.int32(),
dt.Int64: pa.int64(),
dt.UInt8: pa.uint8(),
dt.UInt16: pa.uint16(),
dt.UInt32: pa.uint32(),
dt.UInt64: pa.uint64(),
dt.Float16: pa.float16(),
dt.Float32: pa.float32(),
dt.Float64: pa.float64(),
dt.String: pa.string(),
dt.Binary: pa.binary(),
dt.Boolean: pa.bool_(),
dt.Timestamp: pa.timestamp('ns'),
}


@functools.singledispatch
def to_pyarrow_type(dtype):
return _to_pyarrow_types[dtype.__class__]


@to_pyarrow_type.register(dt.Array)
def from_ibis_array(dtype):
return pa.list_(to_pyarrow_type(dtype.value_type))


@to_pyarrow_type.register(dt.Set)
def from_ibis_set(dtype):
return pa.list_(to_pyarrow_type(dtype.value_type))


@to_pyarrow_type.register(dt.Interval)
def from_ibis_interval(dtype):
try:
return pa.duration(dtype.unit)
except ValueError:
raise com.IbisTypeError(f"Unsupported interval unit: {dtype.unit}")
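A brief sketch of how the two conversion directions above are intended to be used; the dtypes are chosen for illustration:

import pyarrow as pa

import ibis.expr.datatypes as dt
import ibis.expr.schema as sch
from ibis.backends.datafusion.datatypes import to_pyarrow_type

# ibis -> pyarrow: dispatches on the dtype class, recursing into arrays/sets
assert to_pyarrow_type(dt.int64) == pa.int64()
assert to_pyarrow_type(dt.Array(dt.string)) == pa.list_(pa.string())

# pyarrow -> ibis: the dt.dtype/sch.infer registrations turn a pyarrow schema
# into an ibis schema
arrow_schema = pa.schema([('x', pa.int64()), ('ts', pa.timestamp('ns', tz='UTC'))])
print(sch.infer(arrow_schema))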
Empty file.
77 changes: 77 additions & 0 deletions ibis/backends/datafusion/tests/conftest.py
@@ -0,0 +1,77 @@
from pathlib import Path

import pyarrow as pa
import pytest

import ibis
import ibis.expr.types as ir
from ibis.backends.tests.base import BackendTest, RoundAwayFromZero


class TestConf(BackendTest, RoundAwayFromZero):
# check_names = False
# additional_skipped_operations = frozenset({ops.StringSQLLike})
# supports_divide_by_zero = True
# returned_timestamp_unit = 'ns'
bool_is_int = True

@staticmethod
def connect(data_directory: Path):
# can be various types:
# pyarrow.RecordBatch
# parquet file path
# csv file path
client = ibis.datafusion.connect({})
client.register_csv(
name='functional_alltypes',
path=data_directory / 'functional_alltypes.csv',
schema=pa.schema(
[
('index', 'int64'),
('Unnamed 0', 'int64'),
('id', 'int64'),
('bool_col', 'int8'),
('tinyint_col', 'int8'),
('smallint_col', 'int16'),
('int_col', 'int32'),
('bigint_col', 'int64'),
('float_col', 'float32'),
('double_col', 'float64'),
('date_string_col', 'string'),
('string_col', 'string'),
('timestamp_col', 'string'),
('year', 'int64'),
('month', 'int64'),
]
),
)
client.register_csv(
name='batting', path=data_directory / 'batting.csv'
)
client.register_csv(
name='awards_players', path=data_directory / 'awards_players.csv'
)
return client

@property
def functional_alltypes(self) -> ir.TableExpr:
t = self.connection.table('functional_alltypes')
return t.mutate(
bool_col=t.bool_col == 1,
timestamp_col=t.timestamp_col.cast('timestamp'),
)


@pytest.fixture(scope='session')
def client(data_directory):
return TestConf.connect(data_directory)


@pytest.fixture(scope='session')
def alltypes(client):
return client.table("functional_alltypes")


@pytest.fixture(scope='session')
def alltypes_df(alltypes):
return alltypes.execute()
6 changes: 6 additions & 0 deletions ibis/backends/datafusion/tests/test_client.py
@@ -0,0 +1,6 @@
def test_list_tables(client):
assert set(client.list_tables()) == {
'awards_players',
'batting',
'functional_alltypes',
}
20 changes: 20 additions & 0 deletions ibis/backends/datafusion/tests/test_select.py
@@ -0,0 +1,20 @@
from .conftest import BackendTest


def test_where_multiple_conditions(alltypes, alltypes_df):
expr = alltypes.filter(
[
alltypes.float_col > 0,
alltypes.smallint_col == 9,
alltypes.int_col < alltypes.float_col * 2,
]
)
result = expr.execute()

expected = alltypes_df[
(alltypes_df['float_col'] > 0)
& (alltypes_df['smallint_col'] == 9)
& (alltypes_df['int_col'] < alltypes_df['float_col'] * 2)
]

BackendTest.assert_frame_equal(result, expected)
43 changes: 43 additions & 0 deletions ibis/backends/datafusion/tests/test_udf.py
@@ -0,0 +1,43 @@
import pandas.testing as tm
import pyarrow.compute as pc

import ibis.expr.datatypes as dt
import ibis.expr.types as ir
from ibis.udf.vectorized import elementwise, reduction


@elementwise(input_type=['string'], output_type='int64')
def my_string_length(arr, **kwargs):
# arr is a pyarrow.StringArray
return pc.cast(pc.multiply(pc.utf8_length(arr), 2), target_type='int64')


@elementwise(input_type=[dt.int64, dt.int64], output_type=dt.int64)
def my_add(arr1, arr2, **kwargs):
return pc.add(arr1, arr2)


@reduction(input_type=[dt.float64], output_type=dt.float64)
def my_mean(arr):
return pc.mean(arr)


def test_udf(alltypes):
data_string_col = alltypes.date_string_col.execute()
expected = data_string_col.str.len() * 2

expr = my_string_length(alltypes.date_string_col)
assert isinstance(expr, ir.ColumnExpr)

result = expr.execute()
tm.assert_series_equal(result, expected, check_names=False)


def test_multiple_argument_udf(alltypes):
expr = my_add(alltypes.smallint_col, alltypes.int_col)
result = expr.execute()

df = alltypes[['smallint_col', 'int_col']].execute()
expected = (df.smallint_col + df.int_col).astype('int64')

tm.assert_series_equal(result, expected.rename('tmp'))
12 changes: 5 additions & 7 deletions ibis/backends/hdf5/__init__.py
@@ -1,11 +1,10 @@
import warnings

import pandas as pd

import ibis.expr.operations as ops
import ibis.expr.schema as sch
from ibis.backends.base.file import BaseFileBackend
from ibis.backends.pandas.core import execute, execute_node
from ibis.util import warn_deprecated


class HDFTable(ops.DatabaseTable):
@@ -50,11 +49,10 @@ def list_databases(self, path=None, like=None):
if path is None:
path = self.path
else:
warnings.warn(
'The `path` argument of `list_databases` is deprecated and '
'will be removed in a future version of Ibis. Connect to a '
'different path with the `connect()` method instead.',
FutureWarning,
warn_deprecated(
'The `path` argument of `list_databases`',
version='2.0',
instead='`connect()` with a different path',
)
databases = self._list_databases_dirs_or_files(path)
return self._filter_with_like(databases, like)
15 changes: 3 additions & 12 deletions ibis/backends/impala/__init__.py
@@ -3,7 +3,6 @@
import io
import operator
import re
import warnings
import weakref
from posixpath import join as pjoin

@@ -170,8 +169,8 @@ class Backend(BaseSQLBackend):
def hdfs_connect(self, *args, **kwargs):
return hdfs_connect(*args, **kwargs)

def connect(
self,
def do_connect(
new_backend,
host='localhost',
port=21050,
database='default',
@@ -240,7 +239,6 @@ def connect(
"""
import hdfs

new_backend = self.__class__()
new_backend._kudu = None
new_backend._temp_objects = set()

@@ -267,8 +265,6 @@

new_backend._ensure_temp_db_exists()

return new_backend

@property
def version(self):
cursor = self.raw_sql('select version()')
@@ -374,16 +370,11 @@ def _get_list(self, cur):
tuples = cur.fetchall()
return list(map(operator.itemgetter(0), tuples))

@util.deprecated(version='2.0', instead='a new connection to database')
def set_database(self, name):
# XXX The parent `Client` has a generic method that calls this same
# method in the backend. But for whatever reason calling this code from
# that method doesn't seem to work. Maybe `con` is a copy?
warnings.warn(
'`set_database` is deprecated and will be removed in a future '
'version of Ibis. Create a new connection to the desired database '
'instead',
FutureWarning,
)
self.con.set_database(name)

@property
Expand Down
11 changes: 2 additions & 9 deletions ibis/backends/impala/hdfs.py
Original file line number Diff line number Diff line change
Expand Up @@ -689,19 +689,12 @@ def hdfs_connect(
session = requests.Session()
session.verify = verify
if auth_mechanism in ('GSSAPI', 'LDAP'):
from hdfs.ext.kerberos import KerberosClient

if use_https == 'default':
prefix = 'https'
else:
prefix = 'https' if use_https else 'http'
try:
import requests_kerberos # noqa: F401
except ImportError:
raise com.IbisError(
"Unable to import requests-kerberos, which is required for "
"Kerberos HDFS support. Install it by executing `pip install "
"requests-kerberos` or `pip install hdfs[kerberos]`."
)
from hdfs.ext.kerberos import KerberosClient

# note SSL
url = f'{prefix}://{host}:{port}'
11 changes: 6 additions & 5 deletions ibis/backends/impala/tests/test_client.py
@@ -1,5 +1,4 @@
import datetime
import time

import pandas as pd
import pytest
@@ -363,11 +362,13 @@ def test_day_of_week(con):
assert result == 'Sunday'


def test_time_to_int_cast(con):
now = pytz.utc.localize(datetime.datetime.now())
d = ibis.literal(now)
def test_datetime_to_int_cast(con):
timestamp = pytz.utc.localize(
datetime.datetime(2021, 9, 12, 14, 45, 33, 0)
)
d = ibis.literal(timestamp)
result = con.execute(d.cast('int64'))
assert result == int(time.mktime(now.timetuple())) * 1000000
assert result == pd.Timestamp(timestamp).value // 1000
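The expected value works out as follows: `pd.Timestamp.value` is nanoseconds since the UNIX epoch, and the test expects the Impala cast to produce microseconds, hence the integer division by 1000. A standalone check of the arithmetic, no Impala connection required:

import datetime

import pandas as pd
import pytz

ts = pytz.utc.localize(datetime.datetime(2021, 9, 12, 14, 45, 33, 0))
micros = pd.Timestamp(ts).value // 1000
assert micros == 1_631_457_933_000_000  # 2021-09-12 14:45:33 UTC in microseconds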


def test_set_option_with_dot(con):
98 changes: 43 additions & 55 deletions ibis/backends/impala/tests/test_exprs.py
@@ -10,13 +10,12 @@
import ibis.expr.api as api
import ibis.expr.types as ir
from ibis import literal as L
from ibis.backends.impala.compiler import ImpalaCompiler, ImpalaExprTranslator
from ibis.common.exceptions import RelationError
from ibis.expr.datatypes import Category
from ibis.tests.expr.mocks import MockBackend
from ibis.tests.sql.test_compiler import ExprTestCases

from ..compiler import ImpalaCompiler, ImpalaExprTranslator


def approx_equal(a, b, eps):
assert abs(a - b) < eps
@@ -560,27 +559,27 @@ def test_bucket_to_case(self):
expr1 = self.table.f.bucket(buckets)
expected1 = """\
CASE
WHEN (`f` >= 0) AND (`f` < 10) THEN 0
WHEN (`f` >= 10) AND (`f` < 25) THEN 1
WHEN (`f` >= 25) AND (`f` <= 50) THEN 2
WHEN (0 <= `f`) AND (`f` < 10) THEN 0
WHEN (10 <= `f`) AND (`f` < 25) THEN 1
WHEN (25 <= `f`) AND (`f` <= 50) THEN 2
ELSE CAST(NULL AS tinyint)
END"""

expr2 = self.table.f.bucket(buckets, close_extreme=False)
expected2 = """\
CASE
WHEN (`f` >= 0) AND (`f` < 10) THEN 0
WHEN (`f` >= 10) AND (`f` < 25) THEN 1
WHEN (`f` >= 25) AND (`f` < 50) THEN 2
WHEN (0 <= `f`) AND (`f` < 10) THEN 0
WHEN (10 <= `f`) AND (`f` < 25) THEN 1
WHEN (25 <= `f`) AND (`f` < 50) THEN 2
ELSE CAST(NULL AS tinyint)
END"""

expr3 = self.table.f.bucket(buckets, closed='right')
expected3 = """\
CASE
WHEN (`f` >= 0) AND (`f` <= 10) THEN 0
WHEN (`f` > 10) AND (`f` <= 25) THEN 1
WHEN (`f` > 25) AND (`f` <= 50) THEN 2
WHEN (0 <= `f`) AND (`f` <= 10) THEN 0
WHEN (10 < `f`) AND (`f` <= 25) THEN 1
WHEN (25 < `f`) AND (`f` <= 50) THEN 2
ELSE CAST(NULL AS tinyint)
END"""

@@ -589,19 +588,19 @@ def test_bucket_to_case(self):
)
expected4 = """\
CASE
WHEN (`f` > 0) AND (`f` <= 10) THEN 0
WHEN (`f` > 10) AND (`f` <= 25) THEN 1
WHEN (`f` > 25) AND (`f` <= 50) THEN 2
WHEN (0 < `f`) AND (`f` <= 10) THEN 0
WHEN (10 < `f`) AND (`f` <= 25) THEN 1
WHEN (25 < `f`) AND (`f` <= 50) THEN 2
ELSE CAST(NULL AS tinyint)
END"""

expr5 = self.table.f.bucket(buckets, include_under=True)
expected5 = """\
CASE
WHEN `f` < 0 THEN 0
WHEN (`f` >= 0) AND (`f` < 10) THEN 1
WHEN (`f` >= 10) AND (`f` < 25) THEN 2
WHEN (`f` >= 25) AND (`f` <= 50) THEN 3
WHEN (0 <= `f`) AND (`f` < 10) THEN 1
WHEN (10 <= `f`) AND (`f` < 25) THEN 2
WHEN (25 <= `f`) AND (`f` <= 50) THEN 3
ELSE CAST(NULL AS tinyint)
END"""

@@ -611,10 +610,10 @@ def test_bucket_to_case(self):
expected6 = """\
CASE
WHEN `f` < 0 THEN 0
WHEN (`f` >= 0) AND (`f` < 10) THEN 1
WHEN (`f` >= 10) AND (`f` < 25) THEN 2
WHEN (`f` >= 25) AND (`f` <= 50) THEN 3
WHEN `f` > 50 THEN 4
WHEN (0 <= `f`) AND (`f` < 10) THEN 1
WHEN (10 <= `f`) AND (`f` < 25) THEN 2
WHEN (25 <= `f`) AND (`f` <= 50) THEN 3
WHEN 50 < `f` THEN 4
ELSE CAST(NULL AS tinyint)
END"""

@@ -624,10 +623,10 @@ def test_bucket_to_case(self):
expected7 = """\
CASE
WHEN `f` < 0 THEN 0
WHEN (`f` >= 0) AND (`f` < 10) THEN 1
WHEN (`f` >= 10) AND (`f` < 25) THEN 2
WHEN (`f` >= 25) AND (`f` < 50) THEN 3
WHEN `f` >= 50 THEN 4
WHEN (0 <= `f`) AND (`f` < 10) THEN 1
WHEN (10 <= `f`) AND (`f` < 25) THEN 2
WHEN (25 <= `f`) AND (`f` < 50) THEN 3
WHEN 50 <= `f` THEN 4
ELSE CAST(NULL AS tinyint)
END"""

@@ -637,9 +636,9 @@ def test_bucket_to_case(self):
expected8 = """\
CASE
WHEN `f` <= 0 THEN 0
WHEN (`f` > 0) AND (`f` <= 10) THEN 1
WHEN (`f` > 10) AND (`f` <= 25) THEN 2
WHEN (`f` > 25) AND (`f` <= 50) THEN 3
WHEN (0 < `f`) AND (`f` <= 10) THEN 1
WHEN (10 < `f`) AND (`f` <= 25) THEN 2
WHEN (25 < `f`) AND (`f` <= 50) THEN 3
ELSE CAST(NULL AS tinyint)
END"""

Expand All @@ -649,7 +648,7 @@ def test_bucket_to_case(self):
expected9 = """\
CASE
WHEN `f` <= 10 THEN 0
WHEN `f` > 10 THEN 1
WHEN 10 < `f` THEN 1
ELSE CAST(NULL AS tinyint)
END"""

Expand All @@ -659,7 +658,7 @@ def test_bucket_to_case(self):
expected10 = """\
CASE
WHEN `f` < 10 THEN 0
WHEN `f` >= 10 THEN 1
WHEN 10 <= `f` THEN 1
ELSE CAST(NULL AS tinyint)
END"""

Expand Down Expand Up @@ -687,7 +686,7 @@ def test_cast_category_to_int_noop(self):
expected = """\
CASE
WHEN `f` < 10 THEN 0
WHEN `f` >= 10 THEN 1
WHEN 10 <= `f` THEN 1
ELSE CAST(NULL AS tinyint)
END"""

Expand All @@ -698,7 +697,7 @@ def test_cast_category_to_int_noop(self):
expected2 = """\
CAST(CASE
WHEN `f` < 10 THEN 0
WHEN `f` >= 10 THEN 1
WHEN 10 <= `f` THEN 1
ELSE CAST(NULL AS tinyint)
END AS double)"""

Expand Down Expand Up @@ -727,9 +726,9 @@ def test_bucket_assign_labels(self):
SELECT
CASE
WHEN `f` < 0 THEN 0
WHEN (`f` >= 0) AND (`f` < 10) THEN 1
WHEN (`f` >= 10) AND (`f` < 25) THEN 2
WHEN (`f` >= 25) AND (`f` <= 50) THEN 3
WHEN (0 <= `f`) AND (`f` < 10) THEN 1
WHEN (10 <= `f`) AND (`f` < 25) THEN 2
WHEN (25 <= `f`) AND (`f` <= 50) THEN 3
ELSE CAST(NULL AS tinyint)
END AS `tier`, count(*) AS `count`
FROM alltypes
Expand All @@ -752,8 +751,8 @@ def setUp(self):
self.table = self.con.table('alltypes')

def test_field_in_literals(self):
values = ['foo', 'bar', 'baz']
values_formatted = tuple(set(values))
values = {'foo', 'bar', 'baz'}
values_formatted = tuple(values)
cases = [
(self.table.g.isin(values), f"`g` IN {values_formatted}"),
(
Expand All @@ -777,8 +776,8 @@ def test_literal_in_list(self):
self._check_expr_cases(cases)

def test_isin_notin_in_select(self):
values = ['foo', 'bar']
values_formatted = tuple(set(values))
values = {'foo', 'bar'}
values_formatted = tuple(values)

filtered = self.table[self.table.g.isin(values)]
result = ImpalaCompiler.to_sql(filtered)
Expand Down Expand Up @@ -1010,35 +1009,24 @@ def test_table_info(alltypes):
assert buf.getvalue() is not None


@pytest.mark.parametrize(('expr', 'expected'), [(L(1) + L(2), 3)])
def test_execute_exprs_no_table_ref(con, expr, expected):
result = con.execute(expr)
assert result == expected

# ExprList
exlist = ibis.api.expr_list(
[L(1).name('a'), ibis.now().name('b'), L(2).log().name('c')]
)
con.execute(exlist)


def test_summary_execute(alltypes):
table = alltypes

# also test set_column while we're at it
table = table.set_column('double_col', table.double_col * 2)

expr = table.double_col.summary()
metrics = table.double_col.summary()
expr = table.aggregate(metrics)
repr(expr)

result = expr.execute()
assert isinstance(result, pd.DataFrame)

expr = table.group_by('string_col').aggregate(
[
table.double_col.summary().prefix('double_'),
table.float_col.summary().prefix('float_'),
table.string_col.summary().suffix('_string'),
table.double_col.summary(prefix='double_'),
table.float_col.summary(prefix='float_'),
table.string_col.summary(suffix='_string'),
]
)
result = expr.execute()
Expand Down
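
A note on the test updates above: they reflect two API changes — `summary()` now takes `prefix=`/`suffix=` keyword arguments instead of the removed `.prefix()`/`.suffix()` methods, and the summary metrics must be wrapped in an explicit `aggregate()` call. A minimal usage sketch of the new convention (assuming `table` is any ibis table with a `double_col` and a `string_col`, as in the test fixture):

    metrics = table.double_col.summary(prefix='double_')      # keyword replaces .prefix()
    expr = table.group_by('string_col').aggregate(metrics)
    df = expr.execute()                                        # pandas DataFrame of summary stats
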
73 changes: 21 additions & 52 deletions ibis/backends/impala/udf.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import abc
import re

import ibis.common.exceptions as com
import ibis.expr.datatypes as dt
import ibis.expr.operations as ops
import ibis.expr.rules as rlz
import ibis.expr.signature as sig
import ibis.udf.validate as v
import ibis.util as util
from ibis.backends.base.sql.registry import fixed_arity, sql_type_names
Expand All @@ -33,17 +33,16 @@
]


class Function:
class Function(metaclass=abc.ABCMeta):
def __init__(self, inputs, output, name):
self.inputs = tuple(map(dt.dtype, inputs))
self.output = dt.dtype(output)
self.name = name
self._klass = self._create_operation(name)
self.name = name or util.guid()
self._klass = self._create_operation_class()

def _create_operation(self, name):
class_name = self._get_class_name(name)
input_type, output_type = self._type_signature()
return _create_operation_class(class_name, input_type, output_type)
@abc.abstractmethod
def _create_operation_class(self):
pass

def __repr__(self):
klass = type(self).__name__
Expand All @@ -68,35 +67,22 @@ def register(self, name, database):


class ScalarFunction(Function):
def _get_class_name(self, name):
if name is None:
name = util.guid()
return f'UDF_{name}'

def _type_signature(self):
input_type = _ibis_signature(self.inputs)
output_type = rlz.shape_like('args', dt.dtype(self.output))
return input_type, output_type
def _create_operation_class(self):
fields = {
f'_{i}': rlz.value(dtype) for i, dtype in enumerate(self.inputs)
}
fields['output_type'] = rlz.shape_like('args', self.output)
return type(f"UDF_{self.name}", (ops.ValueOp,), fields)


class AggregateFunction(Function):
def _create_operation(self, name):
klass = super()._create_operation(name)
klass._reduction = True
return klass

def _get_class_name(self, name):
if name is None:
name = util.guid()
return f'UDA_{name}'

def _type_signature(self):
def output_type(op):
return dt.dtype(self.output).scalar_type()

input_type = _ibis_signature(self.inputs)

return input_type, output_type
def _create_operation_class(self):
fields = {
f'_{i}': rlz.value(dtype) for i, dtype in enumerate(self.inputs)
}
fields['output_type'] = lambda op: self.output.scalar_type()
fields['_reduction'] = True
return type(f"UDA_{self.name}", (ops.ValueOp,), fields)


class ImpalaFunction:
Expand Down Expand Up @@ -287,23 +273,6 @@ def aggregate_function(inputs, output, name=None):
return AggregateFunction(inputs, output, name=name)


def _ibis_signature(inputs):
if isinstance(inputs, sig.TypeSignature):
return inputs

arguments = [
(f'_{i}', sig.Argument(rlz.value(dtype)))
for i, dtype in enumerate(inputs)
]
return sig.TypeSignature(arguments)


def _create_operation_class(name, input_type, output_type):
func_dict = {'signature': input_type, 'output_type': output_type}
klass = type(name, (ops.ValueOp,), func_dict)
return klass


def add_operation(op, func_name, db):
"""
Registers the given operation within the Ibis SQL translation toolchain
Expand All @@ -319,7 +288,7 @@ def add_operation(op, func_name, db):
# if op.input_type is rlz.listof:
# translator = comp.varargs(full_name)
# else:
arity = len(op.signature)
arity = len(op.__signature__.parameters)
translator = fixed_arity(full_name, arity)

ImpalaExprTranslator._registry[op] = translator
Expand Down
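
The refactor above collapses `_get_class_name`, `_type_signature`, and the module-level helpers into a single abstract `_create_operation_class`, which builds the operation class directly with `type()`: one `rlz.value(...)` field per input plus an `output_type` rule. A rough sketch of that dynamic-class pattern, using a hypothetical two-argument signature:

    import ibis.expr.datatypes as dt
    import ibis.expr.operations as ops
    import ibis.expr.rules as rlz

    inputs = (dt.double, dt.string)                      # hypothetical UDF inputs
    fields = {f'_{i}': rlz.value(dtype) for i, dtype in enumerate(inputs)}
    fields['output_type'] = rlz.shape_like('args', dt.double)
    UDF_example = type('UDF_example', (ops.ValueOp,), fields)   # new ValueOp subclass, as in ScalarFunction
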
8 changes: 3 additions & 5 deletions ibis/backends/mysql/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ class Backend(BaseAlchemyBackend):
name = 'mysql'
compiler = MySQLCompiler

def connect(
def do_connect(
self,
host='localhost',
user=None,
Expand Down Expand Up @@ -96,10 +96,8 @@ def connect(
driver=f'mysql+{driver}',
)

new_backend = super().connect(sqlalchemy.create_engine(alchemy_url))
new_backend.database_name = alchemy_url.database

return new_backend
self.database_name = alchemy_url.database
super().do_connect(sqlalchemy.create_engine(alchemy_url))

@contextlib.contextmanager
def begin(self):
Expand Down
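
As with the other backends in this PR, the MySQL entry point moves from `connect` to `do_connect`, which configures `self` in place instead of returning a new backend. The user-facing call is presumably unchanged — the base class is assumed to instantiate the backend and delegate to `do_connect` — so connecting still looks roughly like:

    # hedged sketch; credentials and database name are hypothetical
    con = ibis.mysql.connect(
        host='localhost',
        user='ibis',
        password='ibis',
        database='ibis_testing',
    )
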
14 changes: 14 additions & 0 deletions ibis/backends/mysql/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,6 +212,18 @@ def _group_concat(t, expr):
return sa.func.group_concat(arg.op('SEPARATOR')(t.translate(sep)))


def _day_of_week_index(t, expr):
(arg,) = expr.op().args
left = sa.func.dayofweek(t.translate(arg)) - 2
right = 7
return ((left % right) + right) % right


def _day_of_week_name(t, expr):
(arg,) = expr.op().args
return sa.func.dayname(t.translate(arg))


operation_registry.update(
{
ops.Literal: _literal,
Expand Down Expand Up @@ -259,5 +271,7 @@ def _group_concat(t, expr):
ops.TimestampNow: fixed_arity(sa.func.now, 0),
# others
ops.GroupConcat: _group_concat,
ops.DayOfWeekIndex: _day_of_week_index,
ops.DayOfWeekName: _day_of_week_name,
}
)
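
The new `_day_of_week_index` maps MySQL's `DAYOFWEEK()` (1 = Sunday … 7 = Saturday) onto ibis's Monday-based index (0 = Monday … 6 = Sunday, inferred from the `- 2` offset); the double modulo keeps the SQL result non-negative when the shift goes below zero. The same arithmetic, checked in plain Python:

    # hypothetical check of the index arithmetic used above
    def day_of_week_index(mysql_dayofweek: int) -> int:
        left = mysql_dayofweek - 2           # Monday (2) becomes 0
        return ((left % 7) + 7) % 7          # Sunday (1) wraps to 6 instead of -1

    assert day_of_week_index(2) == 0         # Monday
    assert day_of_week_index(1) == 6         # Sunday
    assert day_of_week_index(7) == 5         # Saturday
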
6 changes: 2 additions & 4 deletions ibis/backends/pandas/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ class BasePandasBackend(BaseBackend):
Base class for backends based on pandas.
"""

def connect(self, dictionary):
def do_connect(self, dictionary):
"""Construct a client from a dictionary of DataFrames.
Parameters
Expand All @@ -30,9 +30,7 @@ def connect(self, dictionary):
from . import execution # noqa F401
from . import udf # noqa F401

new_backend = self.__class__()
new_backend.dictionary = dictionary
return new_backend
self.dictionary = dictionary

def from_dataframe(self, df, name='df', client=None):
"""
Expand Down
199 changes: 86 additions & 113 deletions ibis/backends/pandas/aggcontext.py
Original file line number Diff line number Diff line change
Expand Up @@ -574,121 +574,94 @@ def agg(
group_by = self.group_by
order_by = self.order_by

# if we don't have a grouping key, just call into pandas
if not group_by and not order_by:
# the result of calling .rolling(...) in pandas
windowed = self.construct_window(grouped_data)

# if we're a UD(A)F or a function that isn't a string (like the
# collect implementation) then call apply
if callable(function):
return windowed.apply(
wrap_for_apply(function, args, kwargs), raw=True
)
else:
# otherwise we're a string and probably faster
assert isinstance(function, str)
method = getattr(windowed, function, None)
if method is not None:
return method(*args, **kwargs)

# handle the case where we pulled out a name from an operation
# but it doesn't actually exist
return windowed.apply(
wrap_for_apply(
operator.methodcaller(function, *args, **kwargs)
),
raw=True,
)
assert group_by or order_by

# Get the DataFrame from which the operand originated
# (passed in when constructing this context object in
# execute_node(ops.WindowOp))
parent = self.parent
frame = getattr(parent, 'obj', parent)
obj = getattr(grouped_data, 'obj', grouped_data)
name = obj.name
if frame[name] is not obj or name in group_by or name in order_by:
name = f"{name}_{ibis.util.guid()}"
frame = frame.assign(**{name: obj})

# set the index to our order_by keys and append it to the existing
# index
# TODO: see if we can do this in the caller, when the context
# is constructed rather than pulling out the data
columns = group_by + order_by + [name]
# Create a new frame to avoid mutating the original one
indexed_by_ordering = frame[columns].copy()
# placeholder column to compute window_sizes below
indexed_by_ordering['_placeholder'] = 0
indexed_by_ordering = indexed_by_ordering.set_index(order_by)

# regroup if needed
if group_by:
grouped_frame = indexed_by_ordering.groupby(group_by)
else:
# Get the DataFrame from which the operand originated
# (passed in when constructing this context object in
# execute_node(ops.WindowOp))
parent = self.parent
frame = getattr(parent, 'obj', parent)
obj = getattr(grouped_data, 'obj', grouped_data)
name = obj.name
if frame[name] is not obj or name in group_by or name in order_by:
name = f"{name}_{ibis.util.guid()}"
frame = frame.assign(**{name: obj})

# set the index to our order_by keys and append it to the existing
# index
# TODO: see if we can do this in the caller, when the context
# is constructed rather than pulling out the data
columns = group_by + order_by + [name]
# Create a new frame to avoid mutating the original one
indexed_by_ordering = frame[columns].copy()
# placeholder column to compute window_sizes below
indexed_by_ordering['_placeholder'] = 0
indexed_by_ordering = indexed_by_ordering.set_index(order_by)

# regroup if needed
if group_by:
grouped_frame = indexed_by_ordering.groupby(group_by)
else:
grouped_frame = indexed_by_ordering
grouped = grouped_frame[name]

if callable(function):
# To compute the window_size, we need to construct a
# RollingGroupby and compute count using construct_window.
# However, if the RollingGroupby is not numeric, e.g.,
# we are calling window UDF on a timestamp column, we
# cannot compute rolling count directly because:
# (1) windowed.count() will exclude NaN observations
# , which results in incorrect window sizes.
# (2) windowed.apply(len, raw=True) will include NaN
# observations, but doesn't work on non-numeric types.
# https://github.com/pandas-dev/pandas/issues/23002
# To deal with this, we create a _placeholder column

windowed_frame = self.construct_window(grouped_frame)
window_sizes = (
windowed_frame['_placeholder']
.count()
.reset_index(drop=True)
)
mask = ~(window_sizes.isna())
window_upper_indices = pd.Series(range(len(window_sizes))) + 1
window_lower_indices = window_upper_indices - window_sizes
# The result Series of udf may need to be trimmed by
# timecontext. In order to do so, 'time' must be added
# as an index to the Series, if present. Here We extract
# time column from the parent Dataframe `frame`.
if get_time_col() in frame:
result_index = construct_time_context_aware_series(
obj, frame
).index
else:
result_index = obj.index
result = window_agg_udf(
grouped_data,
function,
window_lower_indices,
window_upper_indices,
mask,
result_index,
self.dtype,
self.max_lookback,
*args,
**kwargs,
)
grouped_frame = indexed_by_ordering
grouped = grouped_frame[name]

if callable(function):
# To compute the window_size, we need to construct a
# RollingGroupby and compute count using construct_window.
# However, if the RollingGroupby is not numeric, e.g.,
# we are calling window UDF on a timestamp column, we
# cannot compute rolling count directly because:
# (1) windowed.count() will exclude NaN observations
# , which results in incorrect window sizes.
# (2) windowed.apply(len, raw=True) will include NaN
# observations, but doesn't work on non-numeric types.
# https://github.com/pandas-dev/pandas/issues/23002
# To deal with this, we create a _placeholder column

windowed_frame = self.construct_window(grouped_frame)
window_sizes = (
windowed_frame['_placeholder'].count().reset_index(drop=True)
)
mask = ~(window_sizes.isna())
window_upper_indices = pd.Series(range(len(window_sizes))) + 1
window_lower_indices = window_upper_indices - window_sizes
# The result Series of udf may need to be trimmed by
# timecontext. In order to do so, 'time' must be added
# as an index to the Series, if present. Here we extract
# time column from the parent Dataframe `frame`.
if get_time_col() in frame:
result_index = construct_time_context_aware_series(
obj, frame
).index
else:
# perform the per-group rolling operation
windowed = self.construct_window(grouped)
result = window_agg_built_in(
frame,
windowed,
function,
self.max_lookback,
*args,
**kwargs,
)
try:
return result.astype(self.dtype, copy=False)
except (TypeError, ValueError):
return result
result_index = obj.index
result = window_agg_udf(
grouped_data,
function,
window_lower_indices,
window_upper_indices,
mask,
result_index,
self.dtype,
self.max_lookback,
*args,
**kwargs,
)
else:
# perform the per-group rolling operation
windowed = self.construct_window(grouped)
result = window_agg_built_in(
frame,
windowed,
function,
self.max_lookback,
*args,
**kwargs,
)
try:
return result.astype(self.dtype, copy=False)
except (TypeError, ValueError):
return result


class Cumulative(Window):
Expand Down
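
The `_placeholder` bookkeeping above exists because a rolling `count()` over a constant column is the only reliable way to get per-row window sizes for the UDF path (a plain `count()` drops NaNs and `apply(len, raw=True)` fails on non-numeric data). A condensed pandas-only illustration of how the lower/upper indices fall out of that count:

    import pandas as pd

    s = pd.Series([10.0, 20.0, 30.0, 40.0])
    placeholder = pd.Series(0, index=s.index)

    window_sizes = placeholder.rolling(2, min_periods=1).count().reset_index(drop=True)
    window_upper_indices = pd.Series(range(len(window_sizes))) + 1
    window_lower_indices = window_upper_indices - window_sizes
    # each UDF invocation then sees s.values[int(lo):int(up)] for one output row
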
45 changes: 0 additions & 45 deletions ibis/backends/pandas/client.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,7 @@
"""The pandas client implementation."""
from functools import partial

import dateutil.parser
import numpy as np
import pandas as pd
import pytz
import toolz
from pandas.api.types import CategoricalDtype, DatetimeTZDtype

Expand Down Expand Up @@ -248,52 +245,10 @@ def convert_datetimetz_to_timestamp(in_dtype, out_dtype, column):
return column.astype(out_dtype.to_pandas(), errors='ignore')


def convert_timezone(obj, timezone):
"""Convert `obj` to the timezone `timezone`.
Parameters
----------
obj : datetime.date or datetime.datetime
Returns
-------
type(obj)
"""
if timezone is None:
return obj.replace(tzinfo=None)
return pytz.timezone(timezone).localize(obj)


PANDAS_STRING_TYPES = {'string', 'unicode', 'bytes'}
PANDAS_DATE_TYPES = {'datetime', 'datetime64', 'date'}


@sch.convert.register(np.dtype, dt.Timestamp, pd.Series)
def convert_datetime64_to_timestamp(in_dtype, out_dtype, column):
if in_dtype.type == np.datetime64:
return column.astype(out_dtype.to_pandas(), errors='ignore')
try:
series = pd.to_datetime(column, utc=True)
except pd.errors.OutOfBoundsDatetime:
inferred_dtype = infer_pandas_dtype(column, skipna=True)
if inferred_dtype in PANDAS_DATE_TYPES:
# not great, but not really any other option
return column.map(
partial(convert_timezone, timezone=out_dtype.timezone)
)
if inferred_dtype not in PANDAS_STRING_TYPES:
raise TypeError(
(
'Conversion to timestamp not supported for Series of type '
'{!r}'
).format(inferred_dtype)
)
return column.map(dateutil.parser.parse)
else:
utc_dtype = DatetimeTZDtype('ns', 'UTC')
return series.astype(utc_dtype).dt.tz_convert(out_dtype.timezone)


@sch.convert.register(np.dtype, dt.Interval, pd.Series)
def convert_any_to_interval(_, out_dtype, column):
return column.values.astype(out_dtype.to_pandas())
Expand Down
2 changes: 1 addition & 1 deletion ibis/backends/pandas/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -354,7 +354,7 @@ def execute_until_in_scope(
new_scope.get_value(arg.op(), timecontext)
if hasattr(arg, 'op')
else arg
for arg in computable_args
for (arg, timecontext) in zip(computable_args, arg_timecontexts)
]
result = execute_node(
op,
Expand Down
26 changes: 16 additions & 10 deletions ibis/backends/pandas/execution/generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
from pandas.api.types import DatetimeTZDtype
from pandas.core.groupby import DataFrameGroupBy, SeriesGroupBy

import ibis
import ibis.common.exceptions as com
import ibis.expr.datatypes as dt
import ibis.expr.operations as ops
Expand Down Expand Up @@ -964,6 +963,22 @@ def execute_node_log_number_number(op, value, base, **kwargs):
return math.log(value, base)


@execute_node.register(ops.DropNa, pd.DataFrame)
def execute_node_dropna_dataframe(op, df, **kwargs):
subset = [col.get_name() for col in op.subset] if op.subset else None
return df.dropna(how=op.how, subset=subset)


@execute_node.register(ops.FillNa, pd.DataFrame, simple_types)
def execute_node_fillna_dataframe_scalar(op, df, replacements, **kwargs):
return df.fillna(replacements)


@execute_node.register(ops.FillNa, pd.DataFrame)
def execute_node_fillna_dataframe_dict(op, df, **kwargs):
return df.fillna(op.replacements)


@execute_node.register(ops.IfNull, pd.Series, simple_types)
@execute_node.register(ops.IfNull, pd.Series, pd.Series)
def execute_node_ifnull_series(op, value, replacement, **kwargs):
Expand Down Expand Up @@ -1040,15 +1055,6 @@ def execute_node_coalesce(op, values, **kwargs):
return compute_row_reduction(coalesce, values)


@execute_node.register(ops.ExpressionList, collections.abc.Sequence)
def execute_node_expr_list(op, sequence, **kwargs):
# TODO: no true approx count distinct for pandas, so we use exact for now
columns = [e.get_name() for e in op.exprs]
schema = ibis.schema(list(zip(columns, (e.type() for e in op.exprs))))
data = {col: [execute(el, **kwargs)] for col, el in zip(columns, sequence)}
return schema.apply_to(pd.DataFrame(data, columns=columns))


def wrap_case_result(raw, expr):
"""Wrap a CASE statement result in a Series and handle returning scalars.
Expand Down
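
The new `execute_node` registrations translate `ops.DropNa` and `ops.FillNa` straight into `DataFrame.dropna`/`DataFrame.fillna`. A hedged usage sketch against the pandas backend (this assumes the table-level `dropna`/`fillna` expression methods introduced alongside these ops):

    import pandas as pd
    import ibis

    df = pd.DataFrame({'a': [1.0, None, 3.0], 'b': ['x', 'y', None]})
    con = ibis.pandas.connect({'t': df})
    t = con.table('t')

    dropped = t.dropna(how='any').execute()   # rows containing any NULL removed
    filled = t.fillna(0).execute()            # scalar fill, matching the simple_types registration
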
7 changes: 6 additions & 1 deletion ibis/backends/pandas/execution/join.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ def execute_cross_join(op, left, right, **kwargs):


@execute_node.register(ops.Join, pd.DataFrame, pd.DataFrame)
def execute_materialized_join(op, left, right, **kwargs):
def execute_join(op, left, right, **kwargs):
op_type = type(op)

try:
Expand Down Expand Up @@ -94,6 +94,11 @@ def execute_materialized_join(op, left, right, **kwargs):
return df


@execute_node.register(ops.MaterializedJoin, pd.DataFrame)
def execute_materialized_join(op, df, **kwargs):
return df


@execute_node.register(
ops.AsOfJoin, pd.DataFrame, pd.DataFrame, (pd.Timedelta, type(None))
)
Expand Down
25 changes: 18 additions & 7 deletions ibis/backends/pandas/tests/execution/test_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,10 +216,16 @@ def test_execute_with_same_hash_value_in_scope(
def my_func(x, y):
return x

expr = my_func(left, right)
df = pd.DataFrame({"left": [left], "right": [right]})
table = ibis.pandas.from_dataframe(df)

expr = my_func(table.left, table.right)
result = execute(expr)
assert type(result) is expected_type
assert result == expected_value
assert isinstance(result, pd.Series)

result = result.tolist()
assert result == [expected_value]
assert type(result[0]) is expected_type


def test_ifelse_returning_bool():
Expand Down Expand Up @@ -248,7 +254,12 @@ def test_signature_does_not_match_input_type(dtype, value):
def func(x):
return x

expr = func(value)
result = execute(expr)
assert type(result) == type(value)
assert result == value
df = pd.DataFrame({"col": [value]})
table = ibis.pandas.from_dataframe(df)

result = execute(table.col)
assert isinstance(result, pd.Series)

result = result.tolist()
assert result == [value]
assert type(result[0]) is type(value)
25 changes: 25 additions & 0 deletions ibis/backends/pandas/tests/execution/test_join.py
Original file line number Diff line number Diff line change
Expand Up @@ -413,3 +413,28 @@ def test_select_on_unambiguous_asof_join(func):
expr = func(join)
result = expr.execute()
tm.assert_frame_equal(result, expected)


def test_materialized_join():
df = pd.DataFrame({"test": [1, 2, 3], "name": ["a", "b", "c"]})
df_2 = pd.DataFrame({"test_2": [1, 5, 6], "name_2": ["d", "e", "f"]})

conn = ibis.pandas.connect({"df": df, "df_2": df_2})

ibis_table_1 = conn.table("df")
ibis_table_2 = conn.table("df_2")

joined = ibis_table_1.outer_join(
ibis_table_2,
predicates=ibis_table_1["test"] == ibis_table_2["test_2"],
)
joined = joined.materialize()
result = joined.execute()
expected = pd.merge(
df,
df_2,
left_on="test",
right_on="test_2",
how="outer",
)
tm.assert_frame_equal(result, expected)
13 changes: 13 additions & 0 deletions ibis/backends/pandas/tests/execution/test_operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -668,6 +668,19 @@ def q_fun(x, quantile, interpolation):
tm.assert_series_equal(result, expected)


def test_summary_execute(t):
expr = t.group_by('plain_strings').aggregate(
[
t.plain_int64.summary(prefix='int64_'),
t.plain_int64.summary(suffix='_int64'),
t.plain_datetimes_utc.summary(prefix='datetime_'),
t.plain_datetimes_utc.summary(suffix='_datetime'),
]
)
result = expr.execute()
assert isinstance(result, pd.DataFrame)


def test_summary_numeric(batting, batting_df):
expr = batting.G.summary()
result = expr.execute()
Expand Down
64 changes: 63 additions & 1 deletion ibis/backends/pandas/tests/execution/test_timecontext.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import ibis
import ibis.common.exceptions as com
import ibis.expr.operations as ops
from ibis.backends.pandas.execution import execute
from ibis.backends.pandas.execution.window import trim_window_result
from ibis.expr.scope import Scope
from ibis.expr.timecontext import (
Expand Down Expand Up @@ -289,13 +290,74 @@ def adjust_context_custom_asof_join(
right=time_keyed_right,
predicates='time',
by='key',
tolerance=4 * ibis.interval(days=1),
tolerance=ibis.interval(days=4),
).to_expr()
expr = expr[time_keyed_left, time_keyed_right.other_value]
context = (pd.Timestamp('20170105'), pd.Timestamp('20170111'))
expr.execute(timecontext=context)


def test_adjust_context_complete_shift(
time_keyed_left,
time_keyed_right,
time_keyed_df1,
time_keyed_df2,
):
"""Test `adjust_context` function that completely shifts the context.
This results in an adjusted context that is NOT a subset of the
original context. This is unlike an `adjust_context` function
that only expands the context.
See #3104
"""

# Create a contrived `adjust_context` function for
# CustomAsOfJoin to mock this.

@adjust_context.register(CustomAsOfJoin)
def adjust_context_custom_asof_join(
op: ops.AsOfJoin,
timecontext: TimeContext,
scope: Optional[Scope] = None,
) -> TimeContext:
"""Shifts both the begin and end in the same direction."""

begin, end = timecontext
timedelta = execute(op.tolerance)
return (begin - timedelta, end - timedelta)

expr = CustomAsOfJoin(
left=time_keyed_left,
right=time_keyed_right,
predicates='time',
by='key',
tolerance=ibis.interval(days=4),
).to_expr()
expr = expr[time_keyed_left, time_keyed_right.other_value]
context = (pd.Timestamp('20170101'), pd.Timestamp('20170111'))
result = expr.execute(timecontext=context)

# Compare with asof_join of manually trimmed tables
# Left table: No shift for context
# Right table: Shift both begin and end of context by 4 days
trimmed_df1 = time_keyed_df1[time_keyed_df1['time'] >= context[0]][
time_keyed_df1['time'] < context[1]
]
trimmed_df2 = time_keyed_df2[
time_keyed_df2['time'] >= context[0] - pd.Timedelta(days=4)
][time_keyed_df2['time'] < context[1] - pd.Timedelta(days=4)]
expected = pd.merge_asof(
trimmed_df1,
trimmed_df2,
on='time',
by='key',
tolerance=pd.Timedelta('4D'),
)

tm.assert_frame_equal(result, expected)


def test_construct_time_context_aware_series(time_df3):
"""Unit test for `construct_time_context_aware_series`"""
# Series without 'time' index will result in a MultiIndex with 'time'
Expand Down
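
The shifted-context test added above relies on `adjust_context` moving both bounds back by the join tolerance, which is why the right table ends up trimmed over a window that is not a subset of the original context. The shift itself is just timestamp arithmetic:

    import pandas as pd

    begin, end = pd.Timestamp('20170101'), pd.Timestamp('20170111')
    delta = pd.Timedelta(days=4)
    shifted = (begin - delta, end - delta)    # (2016-12-28, 2017-01-07), partly outside the original window
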
10 changes: 0 additions & 10 deletions ibis/backends/pandas/tests/test_datatypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
import ibis
import ibis.expr.datatypes as dt
import ibis.expr.schema as sch
import ibis.expr.types as ir


def test_no_infer_ambiguities():
Expand Down Expand Up @@ -139,15 +138,6 @@ def test_pandas_dtype(pandas_dtype, ibis_dtype):
assert dt.dtype(pandas_dtype) == ibis_dtype


def test_series_to_ibis_literal():
values = [1, 2, 3, 4]
s = pd.Series(values)

expr = ir.as_value_expr(s)
expected = ir.sequence(list(s))
assert expr.equals(expected)


@pytest.mark.parametrize(
('col_data', 'schema_type'),
[
Expand Down
6 changes: 0 additions & 6 deletions ibis/backends/pandas/tests/test_udf.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,12 +111,6 @@ def test_udf(t, df):
tm.assert_series_equal(result, expected)


def test_elementwise_udf_with_non_vectors(con):
expr = my_add(1.0, 2.0)
result = con.execute(expr)
assert result == 3.0


def test_multiple_argument_udf(con, t, df):
expr = my_add(t.b, t.c)

Expand Down
2 changes: 1 addition & 1 deletion ibis/backends/pandas/trace.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
}
)
con = ibis.backends.pandas.connect({"table1": df})
con = ibis.pandas.connect({"table1": df})
@elementwise(
input_type=[dt.double],
Expand Down
7 changes: 3 additions & 4 deletions ibis/backends/postgres/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ class Backend(BaseAlchemyBackend):
name = 'postgres'
compiler = PostgreSQLCompiler

def connect(
def do_connect(
self,
host='localhost',
user=None,
Expand Down Expand Up @@ -93,9 +93,8 @@ def connect(
database=database,
driver=f'postgresql+{driver}',
)
new_backend = super().connect(sqlalchemy.create_engine(alchemy_url))
new_backend.database_name = alchemy_url.database
return new_backend
self.database_name = alchemy_url.database
super().do_connect(sqlalchemy.create_engine(alchemy_url))

def list_databases(self, like=None):
# http://dba.stackexchange.com/a/1304/58517
Expand Down
4 changes: 3 additions & 1 deletion ibis/backends/postgres/tests/test_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -1389,7 +1389,9 @@ def test_boolean_reduction(alltypes, opname, df):


def test_boolean_summary(alltypes):
expr = alltypes.bool_col.summary()
bool_col_summary = alltypes.bool_col.summary()
expr = alltypes.aggregate(bool_col_summary)

result = expr.execute()
expected = pd.DataFrame(
[[7300, 0, 0, 1, 3650, 0.5, 2]],
Expand Down
23 changes: 7 additions & 16 deletions ibis/backends/pyspark/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
import warnings

import pandas as pd
import pyspark
import pyspark as ps
Expand All @@ -8,6 +6,7 @@
import ibis.common.exceptions as com
import ibis.expr.schema as sch
import ibis.expr.types as types
import ibis.util as util
from ibis.backends.base.sql import BaseSQLBackend
from ibis.backends.base.sql.ddl import (
CreateDatabase,
Expand Down Expand Up @@ -81,38 +80,30 @@ class Backend(BaseSQLBackend):
table_class = PySparkDatabaseTable
table_expr_class = PySparkTable

def connect(self, session):
def do_connect(self, session):
"""
Create a pyspark `Backend` for use with Ibis.

Pipes `**kwargs` into Backend, which pipes them into SparkContext.
See documentation for SparkContext:
https://spark.apache.org/docs/latest/api/python/_modules/pyspark/context.html#SparkContext
"""
new_backend = self.__class__()
new_backend._context = session.sparkContext
new_backend._session = session
new_backend._catalog = session.catalog
self._context = session.sparkContext
self._session = session
self._catalog = session.catalog

# Spark internally stores timestamps as UTC values, and timestamp data
# that is brought in without a specified time zone is converted as
# local time to UTC with microsecond resolution.
# https://spark.apache.org/docs/latest/sql-pyspark-pandas-with-arrow.html#timestamp-with-time-zone-semantics
new_backend._session.conf.set('spark.sql.session.timeZone', 'UTC')

return new_backend
self._session.conf.set('spark.sql.session.timeZone', 'UTC')

@property
def version(self):
return pyspark.__version__

@util.deprecated(version='2.0', instead='a new connection to database')
def set_database(self, name):
warnings.warn(
'`set_database` is deprecated and will be removed in a future '
'version of Ibis. Create a new connection to the desired database '
'instead',
FutureWarning,
)
self._catalog.setCurrentDatabase(name)

@property
Expand Down
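
The PySpark backend follows the same `do_connect` pattern: session state is stored on `self`, and the hand-rolled deprecation warning in `set_database` is replaced by the shared `util.deprecated` decorator. Connecting is assumed to still go through the public `connect` entry point with an existing SparkSession:

    # hedged sketch; builder options are illustrative only
    from pyspark.sql import SparkSession

    session = SparkSession.builder.master('local[1]').getOrCreate()
    con = ibis.pyspark.connect(session)   # sets spark.sql.session.timeZone to UTC internally
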
24 changes: 22 additions & 2 deletions ibis/backends/pyspark/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -1772,13 +1772,33 @@ def compile_not_null(t, expr, scope, timecontext, **kwargs):
return ~F.isnull(col) & ~F.isnan(col)


@compiles(ops.DropNa)
def compile_dropna_table(t, expr, scope, timecontext, **kwargs):
op = expr.op()
table = t.translate(op.table, scope, timecontext)
subset = [col.get_name() for col in op.subset] if op.subset else None
return table.dropna(how=op.how, subset=subset)


@compiles(ops.FillNa)
def compile_fillna_table(t, expr, scope, timecontext, **kwargs):
op = expr.op()
table = t.translate(op.table, scope, timecontext)
replacements = (
op.replacements.op().value
if hasattr(op.replacements, 'op')
else op.replacements
)
return table.fillna(replacements)


# ------------------------- User defined function ------------------------


@compiles(ops.ElementWiseVectorizedUDF)
def compile_elementwise_udf(t, expr, scope, timecontext, **kwargs):
op = expr.op()
spark_output_type = spark_dtype(op._output_type)
spark_output_type = spark_dtype(op.return_type)
func = op.func
spark_udf = pandas_udf(func, spark_output_type, PandasUDFType.SCALAR)
func_args = (t.translate(arg, scope, timecontext) for arg in op.func_args)
Expand All @@ -1789,7 +1809,7 @@ def compile_elementwise_udf(t, expr, scope, timecontext, **kwargs):
def compile_reduction_udf(t, expr, scope, timecontext, context=None, **kwargs):
op = expr.op()

spark_output_type = spark_dtype(op._output_type)
spark_output_type = spark_dtype(op.return_type)
spark_udf = pandas_udf(
op.func, spark_output_type, PandasUDFType.GROUPED_AGG
)
Expand Down