26 changes: 19 additions & 7 deletions ibis/backends/dask/tests/execution/test_operations.py
@@ -743,14 +743,26 @@ def test_left_binary_op_gb(t, df, op, argfunc):
tm.assert_frame_equal(result.compute(), expected.compute())


def test_where_series(t, df):
@pytest.mark.parametrize(
"left_f", [lambda e: e - 1, lambda e: 0.0, lambda e: None]
)
@pytest.mark.parametrize(
"right_f", [lambda e: e + 1, lambda e: 1.0, lambda e: None]
)
def test_where_series(t, df, left_f, right_f):
col_expr = t['plain_int64']
result = ibis.where(col_expr > col_expr.mean(), col_expr, 0.0).compile()

ser = df['plain_int64']
expected = ser.where(ser > ser.mean(), other=0.0)

tm.assert_series_equal(result.compute(), expected.compute())
result = ibis.where(
col_expr > col_expr.mean(), left_f(col_expr), right_f(col_expr)
).execute()

ser = df['plain_int64'].compute()
cond = ser > ser.mean()
left = left_f(ser)
if not isinstance(left, pd.Series):
left = pd.Series(np.repeat(left, len(cond)), name=cond.name)
expected = left.where(cond, right_f(ser))

tm.assert_series_equal(result, expected)


@pytest.mark.parametrize(
100 changes: 69 additions & 31 deletions ibis/backends/duckdb/__init__.py
@@ -2,10 +2,12 @@

from __future__ import annotations

import ast
import itertools
import os
import warnings
from pathlib import Path
from typing import TYPE_CHECKING, Any, Iterator, NamedTuple
from typing import TYPE_CHECKING, Any, Iterator, MutableMapping

import sqlalchemy as sa
import toolz
@@ -23,15 +25,11 @@
from ibis.backends.duckdb.datatypes import parse
from ibis.common.dispatch import RegexDispatcher


class _ColumnMetadata(NamedTuple):
name: str
type: dt.DataType


_generate_view_code = RegexDispatcher("_register")
_dialect = sa.dialects.postgresql.dialect()

_gen_table_names = (f"registered_table{i:d}" for i in itertools.count())


def _name_from_path(path: Path) -> str:
base, *_ = path.name.partition(os.extsep)
@@ -90,6 +88,16 @@ class Backend(BaseAlchemyBackend):
def current_database(self) -> str:
return "main"

@staticmethod
def _convert_kwargs(kwargs: MutableMapping) -> None:
read_only = kwargs.pop("read_only", "False").capitalize()
try:
kwargs["read_only"] = ast.literal_eval(read_only)
except ValueError as e:
raise ValueError(
f"invalid value passed to ast.literal_eval: {read_only!r}"
) from e
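
For illustration, a minimal sketch of what this normalization does when `read_only` arrives as text from a connection URL (the values here are assumptions, not part of the diff):

    from ibis.backends.duckdb import Backend

    # illustrative only: a textual "true"/"false" is coerced to a real bool
    # before being handed to duckdb_engine
    kwargs = {"read_only": "true"}
    Backend._convert_kwargs(kwargs)
    assert kwargs["read_only"] is True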

@property
def version(self) -> str:
# TODO: there is a `PRAGMA version` we could use instead
@@ -99,76 +107,106 @@ def version(self) -> str:

def do_connect(
self,
path: str | Path = ":memory:",
database: str | Path = ":memory:",
path: str | Path = None,
read_only: bool = False,
**config: Any,
) -> None:
"""Create an Ibis client connected to a DuckDB database.
Parameters
----------
path
Path to a duckdb database
database
Path to a duckdb database.
read_only
Whether the database is read-only
Whether the database is read-only.
config
DuckDB configuration parameters. See the [DuckDB configuration
documentation](https://duckdb.org/docs/sql/configuration) for
possible configuration values.
Examples
--------
>>> import ibis
>>> ibis.duckdb.connect("database.ddb", threads=4, memory_limit="1GB")
"""
if path != ":memory:":
path = Path(path).absolute()
if path is not None:
warnings.warn(
"The `path` argument is deprecated in 4.0. Use `database=...` "
"instead."
)
database = path
if not (in_memory := database == ":memory:"):
database = Path(database).absolute()
super().do_connect(
sa.create_engine(
f"duckdb:///{path}",
connect_args={"read_only": read_only},
f"duckdb:///{database}",
connect_args=dict(read_only=read_only, config=config),
poolclass=sa.pool.SingletonThreadPool if in_memory else None,
)
)
self._meta = sa.MetaData(bind=self.con)
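
As a usage sketch of the new signature (the file name and config values below are assumptions):

    import ibis

    # in-memory database, now served from a SingletonThreadPool
    mem = ibis.duckdb.connect()

    # file-backed database; extra keyword arguments flow into DuckDB's config
    con = ibis.duckdb.connect(
        database="analytics.ddb", read_only=False, threads=4, memory_limit="1GB"
    )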

def register(
self,
path: str | Path,
source: str | Path | Any,
table_name: str | None = None,
) -> ir.Table:
"""Register an external file as a table in the current connection
database
"""Register a data source as a table in the current database.
Parameters
----------
path
Name of the parquet or CSV file
source
The data source. May be a path to a file or directory of
parquet/csv files, a pandas dataframe, or a pyarrow table or
dataset.
table_name
Name for the created table. Defaults to filename if not given.
Any dashes in a user-provided or generated name will be
replaced with underscores.
An optional name to use for the created table. This defaults to the
filename if a path (with hyphens replaced with underscores), or
sequentially generated name otherwise.
Returns
-------
ir.Table
The just-registered table
"""
view, table_name = _generate_view_code(path, table_name=table_name)
self.con.execute(view)
if isinstance(source, (str, Path)):
sql, table_name = _generate_view_code(
source, table_name=table_name
)
self.con.execute(sql)
else:
if table_name is None:
table_name = next(_gen_table_names)
self.con.execute("register", (table_name, source))

return self.table(table_name)
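
A rough sketch of how the widened `register` signature can be used (paths and names below are hypothetical):

    import pandas as pd

    import ibis

    con = ibis.duckdb.connect()
    prices = con.register("prices.parquet", table_name="prices")  # file path
    frame = con.register(pd.DataFrame({"x": [1, 2, 3]}))  # sequentially named table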

def fetch_from_cursor(
self,
cursor: duckdb.DuckDBPyConnection,
schema: sch.Schema,
):
df = cursor.cursor.fetch_df()
table = cursor.cursor.fetch_arrow_table()
df = table.to_pandas(timestamp_as_object=True)
return schema.apply_to(df)

def _metadata(self, query: str) -> Iterator[_ColumnMetadata]:
def _metadata(self, query: str) -> Iterator[tuple[str, dt.DataType]]:
for name, type, null in toolz.pluck(
["column_name", "column_type", "null"],
self.con.execute(f"DESCRIBE {query}"),
):
yield _ColumnMetadata(
name=name,
type=parse(type)(nullable=null.lower() == "yes"),
)
ibis_type = parse(type)
yield name, ibis_type(nullable=null.lower() == "yes")

def _get_schema_using_query(self, query: str) -> sch.Schema:
"""Return an ibis Schema from a DuckDB SQL string."""
return sch.Schema.from_tuples(self._metadata(query))

def _register_in_memory_table(self, table_op):
df = table_op.data.to_frame()
self.con.execute("register", (table_op.name, df))

def _get_sqla_table(
self,
name: str,
29 changes: 29 additions & 0 deletions ibis/backends/duckdb/compiler.py
@@ -1,3 +1,9 @@
from __future__ import annotations

from sqlalchemy.ext.compiler import compiles

import ibis.backends.base.sql.alchemy.datatypes as sat
import ibis.expr.datatypes as dt
import ibis.expr.operations as ops
from ibis.backends.base.sql.alchemy import (
AlchemyCompiler,
@@ -16,6 +22,28 @@ class DuckDBSQLExprTranslator(AlchemyExprTranslator):
_has_reduction_filter_syntax = True


@compiles(sat.UInt64, "duckdb")
@compiles(sat.UInt32, "duckdb")
@compiles(sat.UInt16, "duckdb")
@compiles(sat.UInt8, "duckdb")
def compile_uint(element, compiler, **kw):
return element.__class__.__name__.upper()


try:
import duckdb_engine
except ImportError:
pass
else:

@dt.dtype.register(duckdb_engine.Dialect, sat.UInt64)
@dt.dtype.register(duckdb_engine.Dialect, sat.UInt32)
@dt.dtype.register(duckdb_engine.Dialect, sat.UInt16)
@dt.dtype.register(duckdb_engine.Dialect, sat.UInt8)
def dtype_uint(_, satype, nullable=True):
return getattr(dt, satype.__class__.__name__)(nullable=nullable)


rewrites = DuckDBSQLExprTranslator.rewrites


@@ -29,4 +57,5 @@ def _no_op(expr):


class DuckDBSQLCompiler(AlchemyCompiler):
cheap_in_memory_tables = True
translator_class = DuckDBSQLExprTranslator
13 changes: 10 additions & 3 deletions ibis/backends/duckdb/datatypes.py
@@ -36,6 +36,7 @@
int16,
int32,
int64,
json,
spaceless,
spaceless_string,
string,
@@ -63,9 +64,14 @@ def parse(text: str, default_decimal_parameters=(18, 3)) -> DataType:
| spaceless_string("double", "float8").result(float64)
| spaceless_string("real", "float4", "float").result(float32)
| spaceless_string("smallint", "int2", "short").result(int16)
| spaceless_string("timestamp", "datetime").result(
Timestamp(timezone="UTC")
)
| spaceless_string(
"timestamp_tz",
"timestamp_sec",
"timestamp_ms",
"timestamp_ns",
"timestamp",
"datetime",
).result(Timestamp(timezone="UTC"))
| spaceless_string("date").result(date)
| spaceless_string("time").result(time)
| spaceless_string("tinyint", "int1").result(int8)
@@ -82,6 +88,7 @@ def parse(text: str, default_decimal_parameters=(18, 3)) -> DataType:
"text",
"string",
).result(string)
| spaceless_string("json").result(json)
)
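
Hedged examples of what the extended parser accepts, mirroring the test additions further down in this diff:

    from ibis.backends.duckdb.datatypes import parse

    parse("timestamp_ns")  # -> Timestamp(timezone='UTC')
    parse("json")          # -> dt.json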

@p.generate
14 changes: 12 additions & 2 deletions ibis/backends/duckdb/registry.py
@@ -1,4 +1,5 @@
import collections
import numbers
import operator

import numpy as np
@@ -86,17 +87,23 @@ def _literal(_, expr):
return sa.cast(sa.func.list_value(*value), sqla_type)
elif isinstance(value, np.ndarray):
return sa.cast(sa.func.list_value(*value.tolist()), sqla_type)
elif isinstance(value, (numbers.Real, np.floating)) and np.isnan(value):
return sa.cast(sa.literal("NaN"), sqla_type)
elif isinstance(value, collections.abc.Mapping):
if isinstance(dtype, dt.Struct):
placeholders = ", ".join(
f"{key!r}: :v{i}" for i, key in enumerate(value.keys())
f"{key} := :v{i}" for i, key in enumerate(value.keys())
)
return sa.text(f"{{{placeholders}}}").bindparams(
text = sa.text(f"struct_pack({placeholders})")
bound_text = text.bindparams(
*(
sa.bindparam(f"v{i:d}", val)
for i, val in enumerate(value.values())
)
)
name = expr.get_name() if expr.has_name() else "tmp"
params = {name: to_sqla_type(dtype)}
return bound_text.columns(**params).scalar_subquery()
raise NotImplementedError(
f"Ibis dtype `{dtype}` with mapping type "
f"`{type(value).__name__}` isn't yet supported with the duckdb "
@@ -225,6 +232,9 @@ def _struct_column(t, expr):
ops.Arbitrary: _arbitrary,
ops.GroupConcat: _string_agg,
ops.StructColumn: _struct_column,
ops.ArgMin: reduction(sa.func.min_by),
ops.ArgMax: reduction(sa.func.max_by),
ops.BitwiseXor: fixed_arity(sa.func.xor, 2),
}
)

10 changes: 10 additions & 0 deletions ibis/backends/duckdb/tests/test_datatypes.py
@@ -58,6 +58,11 @@
),
T=dt.Array(dt.Array(dt.int32)),
U=dt.Array(dt.Array(dt.int32)),
V=dt.Timestamp("UTC"),
W=dt.Timestamp("UTC"),
X=dt.Timestamp("UTC"),
Y=dt.Timestamp("UTC"),
Z=dt.json,
)


@@ -113,6 +118,11 @@
("S", "STRUCT(a INT, b TEXT, c LIST<MAP<TEXT, LIST<FLOAT8>>>)"),
("T", "LIST<LIST<INTEGER>>"),
("U", "INTEGER[][]"),
("V", "TIMESTAMP_TZ"),
("W", "TIMESTAMP_SEC"),
("X", "TIMESTAMP_MS"),
("Y", "TIMESTAMP_NS"),
("Z", "JSON"),
]
],
)
24 changes: 24 additions & 0 deletions ibis/backends/duckdb/tests/test_register.py
@@ -107,3 +107,27 @@ def test_register_parquet(

table = con.table(out_table_name)
assert table.count().execute()


def test_register_pandas():
pd = pytest.importorskip("pandas")
df = pd.DataFrame({"x": [1, 2, 3], "y": ["a", "b", "c"]})

con = ibis.duckdb.connect()

t = con.register(df)
assert t.x.sum().execute() == 6

t = con.register(df, "my_table")
assert t.op().name == "my_table"
assert t.x.sum().execute() == 6


def test_register_pyarrow_tables():
pa = pytest.importorskip("pyarrow")
pa_t = pa.Table.from_pydict({"x": [1, 2, 3], "y": ["a", "b", "c"]})

con = ibis.duckdb.connect()

t = con.register(pa_t)
assert t.x.sum().execute() == 6
10 changes: 5 additions & 5 deletions ibis/backends/impala/__init__.py
@@ -67,6 +67,8 @@
'BINARY': 'object',
'VARCHAR': 'object',
'CHAR': 'object',
'DATE': 'datetime64[ns]',
'VOID': None,
}


@@ -983,8 +985,8 @@ def cache_table(self, table_name, database=None, pool='default'):
>>> table = 'my_table'
>>> db = 'operations'
>>> pool = 'op_4GB_pool'
>>> con.cache_table('my_table', database=db, pool=pool) # noqa: E501 # doctest: +SKIP
"""
>>> con.cache_table('my_table', database=db, pool=pool) # doctest: +SKIP
""" # noqa: E501
statement = ddl.CacheTable(table_name, database=database, pool=pool)
self.raw_sql(statement)

@@ -1312,9 +1314,7 @@ def column_stats(self, name, database=None):
return self._exec_statement(stmt)

def _exec_statement(self, stmt):
return self.fetch_from_cursor(
self.raw_sql(stmt, results=True), schema=None
)
return self.fetch_from_cursor(self.raw_sql(stmt), schema=None)

def _table_command(self, cmd, name, database=None):
qualified_name = self._fully_qualified_name(name, database)
4 changes: 2 additions & 2 deletions ibis/backends/impala/tests/test_case_exprs.py
@@ -37,12 +37,12 @@ def test_isnull_1_0(table):
expr = table.g.isnull().ifelse(1, 0)

result = translate(expr)
expected = 'CASE WHEN `g` IS NULL THEN 1 ELSE 0 END'
expected = 'if(`g` IS NULL, 1, 0)'
assert result == expected

# inside some other function
result = translate(expr.sum())
expected = 'sum(CASE WHEN `g` IS NULL THEN 1 ELSE 0 END)'
expected = 'sum(if(`g` IS NULL, 1, 0))'
assert result == expected


14 changes: 7 additions & 7 deletions ibis/backends/impala/tests/test_client.py
@@ -37,8 +37,8 @@ def test_execute_exprs_default_backend(con_no_hdfs):

def test_cursor_garbage_collection(con):
for i in range(5):
con.raw_sql('select 1', True).fetchall()
con.raw_sql('select 1', True).fetchone()
con.raw_sql('select 1').fetchall()
con.raw_sql('select 1').fetchone()


def test_raise_ibis_error_no_hdfs(con_no_hdfs):
@@ -72,7 +72,7 @@ def test_sql_with_limit(con):

def test_raw_sql(con):
query = 'SELECT * from functional_alltypes limit 10'
cur = con.raw_sql(query, results=True)
cur = con.raw_sql(query)
rows = cur.fetchall()
cur.release()
assert len(rows) == 10
@@ -298,17 +298,17 @@ def con2(env):
def test_rerelease_cursor(con2):
# we use a separate `con2` fixture here because any connection pool
# manipulation we want to happen independently of `con`
with con2.raw_sql('select 1', True) as cur1:
with con2.raw_sql('select 1') as cur1:
pass

cur1.release()

with con2.raw_sql('select 1', True) as cur2:
with con2.raw_sql('select 1') as cur2:
pass

cur2.release()

with con2.raw_sql('select 1', True) as cur3:
with con2.raw_sql('select 1') as cur3:
pass

assert cur1 == cur2
@@ -337,7 +337,7 @@ def test_datetime_to_int_cast(con):

def test_set_option_with_dot(con):
con.set_options({'request_pool': 'baz.quux'})
result = dict(row[:2] for row in con.raw_sql('set', True).fetchall())
result = dict(row[:2] for row in con.raw_sql('set').fetchall())
assert result['REQUEST_POOL'] == 'baz.quux'


5 changes: 2 additions & 3 deletions ibis/backends/impala/tests/test_exprs.py
@@ -704,8 +704,7 @@ def test_where_with_timestamp():
)
result = ibis.impala.compile(expr)
expected = """\
SELECT `uuid`,
min(CASE WHEN `search_level` = 1 THEN `ts` ELSE NULL END) AS `min_date`
SELECT `uuid`, min(if(`search_level` = 1, `ts`, NULL)) AS `min_date`
FROM t
GROUP BY 1"""
assert result == expected
@@ -766,7 +765,7 @@ def test_nunique_where():
t = ibis.table([('key', 'string'), ('value', 'double')], name='t0')
expr = t.key.nunique(where=t.value >= 1.0)
expected = """\
SELECT count(DISTINCT CASE WHEN `value` >= 1.0 THEN `key` ELSE NULL END) AS `nunique`
SELECT count(DISTINCT if(`value` >= 1.0, `key`, NULL)) AS `nunique`
FROM t0""" # noqa: E501
result = ibis.impala.compile(expr)
assert result == expected
8 changes: 4 additions & 4 deletions ibis/backends/impala/tests/test_patched.py
@@ -39,24 +39,24 @@ def test_refresh(con, spy, qname):
def test_describe_formatted(con, spy, qname):
t = con.table('functional_alltypes')
desc = t.describe_formatted()
spy.assert_called_with(f'DESCRIBE FORMATTED {qname}', results=True)
spy.assert_called_with(f'DESCRIBE FORMATTED {qname}')
assert isinstance(desc, metadata.TableMetadata)


def test_show_files(con, spy, qname):
t = con.table('functional_alltypes')
desc = t.files()
spy.assert_called_with(f'SHOW FILES IN {qname}', results=True)
spy.assert_called_with(f'SHOW FILES IN {qname}')
assert isinstance(desc, pd.DataFrame)


def test_table_column_stats(con, spy, qname):
t = con.table('functional_alltypes')

desc = t.stats()
spy.assert_called_with(f'SHOW TABLE STATS {qname}', results=True)
spy.assert_called_with(f'SHOW TABLE STATS {qname}')
assert isinstance(desc, pd.DataFrame)

desc = t.column_stats()
spy.assert_called_with(f'SHOW COLUMN STATS {qname}', results=True)
spy.assert_called_with(f'SHOW COLUMN STATS {qname}')
assert isinstance(desc, pd.DataFrame)
6 changes: 1 addition & 5 deletions ibis/backends/impala/tests/test_unary_builtins.py
@@ -161,11 +161,7 @@ def test_hash(table):
def test_reduction_where(table, expr_fn, func_name):
expr = expr_fn(table)
result = translate(expr)
expected = (
f'{func_name}'
'(CASE WHEN `bigint_col` < 70 THEN `double_col` '
'ELSE NULL END)'
)
expected = f'{func_name}(if(`bigint_col` < 70, `double_col`, NULL))'
assert result == expected


2 changes: 2 additions & 0 deletions ibis/backends/impala/udf.py
@@ -346,4 +346,6 @@ def _ibis_string_to_impala(tval):
'char': 'string',
'timestamp': 'timestamp',
'decimal': 'decimal',
'date': 'date',
'void': 'null',
}
17 changes: 13 additions & 4 deletions ibis/backends/mysql/tests/test_client.py
@@ -62,9 +62,18 @@ def test_get_schema_from_query(con, mysql_type, expected_type):
name = con.con.dialect.identifier_preparer.quote_identifier(raw_name)
# temporary tables get cleaned up by the db when the session ends, so we
# don't need to explicitly drop the table
con.raw_sql(
f"CREATE TEMPORARY TABLE {name} (x {mysql_type}, y {mysql_type})"
)
expected_schema = ibis.schema(dict(x=expected_type, y=expected_type))
con.raw_sql(f"CREATE TEMPORARY TABLE {name} (x {mysql_type})")
expected_schema = ibis.schema(dict(x=expected_type))
result_schema = con._get_schema_using_query(f"SELECT * FROM {name}")
assert result_schema == expected_schema


@pytest.mark.parametrize(
"coltype",
["TINYBLOB", "MEDIUMBLOB", "BLOB", "LONGBLOB"],
)
def test_blob_type(con, coltype):
tmp = f"tmp_{ibis.util.guid()}"
con.raw_sql(f"CREATE TEMPORARY TABLE {tmp} (a {coltype})")
t = con.table(tmp)
assert t.schema() == ibis.schema({"a": dt.binary})
5 changes: 3 additions & 2 deletions ibis/backends/pandas/__init__.py
@@ -49,10 +49,11 @@ def do_connect(
--------
>>> import ibis
>>> ibis.pandas.connect({"t": pd.DataFrame({"a": [1, 2, 3]})})
<ibis.backends.pandas.Backend at 0x...>
"""
# register dispatchers
from ibis.backends.pandas import execution # noqa F401
from ibis.backends.pandas import udf # noqa F401
from ibis.backends.pandas import execution # noqa: F401
from ibis.backends.pandas import udf # noqa: F401

self.dictionary = dictionary
self.schemas: MutableMapping[str, sch.Schema] = {}
39 changes: 37 additions & 2 deletions ibis/backends/pandas/client.py
@@ -9,8 +9,11 @@

import ibis.expr.datatypes as dt
import ibis.expr.operations as ops
import ibis.expr.rules as rlz
import ibis.expr.schema as sch
from ibis import util
from ibis.backends.base import Database
from ibis.common.grounds import Immutable

infer_pandas_dtype = pd.api.types.infer_dtype

@@ -110,10 +113,15 @@ def from_pandas_tzdtype(value):


@dt.dtype.register(CategoricalDtype)
def from_pandas_categorical(value):
def from_pandas_categorical(_):
return dt.Category()


@dt.dtype.register(pd.core.arrays.string_.StringDtype)
def from_pandas_string(_):
return dt.String()


@dt.infer.register(np.generic)
def infer_numpy_scalar(value):
return dt.dtype(value.dtype)
@@ -203,7 +211,7 @@ def infer_pandas_schema(df, schema=None):
schema = schema if schema is not None else {}

pairs = []
for column_name, pandas_dtype in df.dtypes.iteritems():
for column_name in df.dtypes.keys():
if not isinstance(column_name, str):
raise TypeError(
'Column names must be strings to use the pandas backend'
@@ -289,10 +297,37 @@ def convert_element(values, names=out_dtype.names):
return column.map(convert_element)


@sch.convert.register(np.dtype, dt.Array, pd.Series)
def convert_array_to_series(in_dtype, out_dtype, column):
return column.map(lambda x: x if x is None else list(x))


dt.DataType.to_pandas = ibis_dtype_to_pandas # type: ignore
sch.Schema.to_pandas = ibis_schema_to_pandas # type: ignore


class DataFrameProxy(Immutable, util.ToFrame):
__slots__ = ('_df', '_hash')

def __init__(self, df):
object.__setattr__(self, "_df", df)
object.__setattr__(self, "_hash", hash((type(df), id(df))))

def __hash__(self):
return self._hash

def __repr__(self):
df_repr = util.indent(repr(self._df), spaces=2)
return f"{self.__class__.__name__}:\n{df_repr}"

def to_frame(self):
return self._df


class PandasInMemoryTable(ops.InMemoryTable):
data = rlz.instance_of(DataFrameProxy)
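
For context, a small sketch (an assumption, not part of the diff) of how the proxy wraps a frame so it can sit inside an in-memory table op:

    import pandas as pd

    from ibis.backends.pandas.client import DataFrameProxy

    df = pd.DataFrame({"a": [1, 2, 3]})
    proxy = DataFrameProxy(df)
    assert proxy.to_frame() is df  # the wrapped frame is handed back untouched
    # hashing keys off the frame's identity, not its (mutable) contents
    assert hash(proxy) == hash((type(df), id(df)))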


class PandasTable(ops.DatabaseTable):
pass

2 changes: 1 addition & 1 deletion ibis/backends/pandas/core.py
@@ -205,7 +205,7 @@ def execute_with_scope(
# computing anything *and* before associating leaf nodes with data. This
# allows clients to provide their own data for each leaf.
if clients is None:
clients = expr._find_backends()
clients, _ = expr._find_backends()

if aggcontext is None:
aggcontext = agg_ctx.Summarize()
218 changes: 152 additions & 66 deletions ibis/backends/pandas/execution/generic.py
@@ -25,6 +25,7 @@
from ibis.backends.pandas.client import PandasTable
from ibis.backends.pandas.core import (
boolean_types,
date_types,
execute,
fixed_width_types,
floating_types,
@@ -584,6 +585,23 @@ def execute_arbitrary_series_groupby(op, data, _, aggcontext=None, **kwargs):
return aggcontext.agg(data, how)


@execute_node.register(
(ops.ArgMin, ops.ArgMax),
SeriesGroupBy,
SeriesGroupBy,
type(None),
)
def execute_reduction_series_groupby_argidx(
op, data, key, _, aggcontext=None, **kwargs
):
method = operator.methodcaller(op.__class__.__name__.lower())

def reduce(data, key=key.obj, method=method):
return data.iloc[method(key.loc[data.index])]

return aggcontext.agg(data, reduce)


def _filtered_reduction(mask, method, data):
return method(data[mask[data.index]])

@@ -757,11 +775,37 @@ def execute_bit_xor_series(_, data, mask, aggcontext=None, **kwargs):
)


@execute_node.register(
(ops.ArgMin, ops.ArgMax),
pd.Series,
pd.Series,
(pd.Series, type(None)),
)
def execute_argmin_series_mask(op, data, key, mask, aggcontext=None, **kwargs):
method_name = op.__class__.__name__.lower()
masked_key = key[mask] if mask is not None else key
idx = aggcontext.agg(masked_key, method_name)
masked = data[mask] if mask is not None else data
return masked.iloc[idx]


@execute_node.register((ops.Not, ops.Negate), (bool, np.bool_))
def execute_not_bool(_, data, **kwargs):
return not data


def _execute_binary_op_impl(op, left, right, **_):
op_type = type(op)
try:
operation = constants.BINARY_OPERATIONS[op_type]
except KeyError:
raise NotImplementedError(
f'Binary operation {op_type.__name__} not implemented'
)
else:
return operation(left, right)


@execute_node.register(ops.Binary, pd.Series, pd.Series)
@execute_node.register(
(ops.NumericBinary, ops.LogicalBinary, ops.Comparison),
Expand All @@ -784,17 +828,16 @@ def execute_not_bool(_, data, **kwargs):
@execute_node.register(ops.Multiply, integer_types, str)
@execute_node.register(ops.Multiply, str, integer_types)
@execute_node.register(ops.Comparison, pd.Series, timestamp_types)
@execute_node.register(ops.Comparison, timestamp_types, pd.Series)
@execute_node.register(ops.Comparison, timedelta_types, pd.Series)
def execute_binary_op(op, left, right, **kwargs):
op_type = type(op)
try:
operation = constants.BINARY_OPERATIONS[op_type]
except KeyError:
raise NotImplementedError(
f'Binary operation {op_type.__name__} not implemented'
)
else:
return operation(left, right)
return _execute_binary_op_impl(op, left, right, **kwargs)


@execute_node.register(ops.Comparison, pd.Series, date_types)
def execute_binary_op_date(op, left, right, **kwargs):
return _execute_binary_op_impl(
op, pd.to_datetime(left), pd.to_datetime(right), **kwargs
)


@execute_node.register(ops.Binary, SeriesGroupBy, SeriesGroupBy)
@@ -879,22 +922,38 @@ def execute_union_dataframe_dataframe(
return result.drop_duplicates() if distinct else result


@execute_node.register(ops.Intersection, pd.DataFrame, pd.DataFrame)
@execute_node.register(ops.Intersection, pd.DataFrame, pd.DataFrame, bool)
def execute_intersection_dataframe_dataframe(
op, left: pd.DataFrame, right: pd.DataFrame, **kwargs
op,
left: pd.DataFrame,
right: pd.DataFrame,
distinct: bool,
**kwargs,
):
if not distinct:
raise NotImplementedError(
"`distinct=False` is not supported by the pandas backend"
)
result = left.merge(right, on=list(left.columns), how="inner")
return result


@execute_node.register(ops.Difference, pd.DataFrame, pd.DataFrame)
@execute_node.register(ops.Difference, pd.DataFrame, pd.DataFrame, bool)
def execute_difference_dataframe_dataframe(
op, left: pd.DataFrame, right: pd.DataFrame, **kwargs
op,
left: pd.DataFrame,
right: pd.DataFrame,
distinct: bool,
**kwargs,
):
if not distinct:
raise NotImplementedError(
"`distinct=False` is not supported by the pandas backend"
)
merged = left.merge(
right, on=list(left.columns), how='outer', indicator=True
right, on=list(left.columns), how="outer", indicator=True
)
result = merged[merged["_merge"] != "both"].drop("_merge", axis=1)
result = merged[merged["_merge"] == "left_only"].drop("_merge", axis=1)
return result


@@ -910,7 +969,13 @@ def execute_series_notnnull(op, data, **kwargs):

@execute_node.register(ops.IsNan, (pd.Series, floating_types))
def execute_isnan(op, data, **kwargs):
return np.isnan(data)
try:
return np.isnan(data)
except (TypeError, ValueError):
# if `data` contains `None` np.isnan will complain
# so we take advantage of NaN not equaling itself
# to do the correct thing
return data != data


@execute_node.register(ops.IsInf, (pd.Series, floating_types))
@@ -924,10 +989,10 @@ def execute_node_self_reference_dataframe(op, data, **kwargs):


@execute_node.register(ops.Alias, object)
def execute_alias(op, _, **kwargs):
# just compile the underlying argument because the naming is handled
def execute_alias(op, data, **kwargs):
# just return the underlying argument because the naming is handled
# by the translator for the top level expression
return execute(op.arg, **kwargs)
return data


@execute_node.register(ops.ValueList, collections.abc.Sequence)
@@ -954,6 +1019,17 @@ def execute_node_contains_series_sequence(op, data, elements, **kwargs):
return data.isin(elements)


@execute_node.register(
ops.Contains,
SeriesGroupBy,
(collections.abc.Sequence, collections.abc.Set, pd.Series),
)
def execute_node_contains_series_group_by_sequence(
op, data, elements, **kwargs
):
return data.obj.isin(elements).groupby(data.grouper.groupings)


@execute_node.register(
ops.NotContains,
pd.Series,
@@ -963,57 +1039,65 @@ def execute_node_not_contains_series_sequence(op, data, elements, **kwargs):
return ~(data.isin(elements))


# Series, Series, Series
# Series, Series, scalar
@execute_node.register(ops.Where, pd.Series, pd.Series, pd.Series)
@execute_node.register(ops.Where, pd.Series, pd.Series, scalar_types)
def execute_node_where_series_series_series(op, cond, true, false, **kwargs):
# No need to turn false into a series, pandas will broadcast it
return true.where(cond, other=false)


# Series, scalar, Series
def execute_node_where_series_scalar_scalar(op, cond, true, false, **kwargs):
return pd.Series(np.repeat(true, len(cond))).where(cond, other=false)


# Series, scalar, scalar
for scalar_type in scalar_types:
execute_node_where_series_scalar_scalar = execute_node.register(
ops.Where, pd.Series, scalar_type, scalar_type
)(execute_node_where_series_scalar_scalar)


# scalar, Series, Series
@execute_node.register(ops.Where, boolean_types, pd.Series, pd.Series)
def execute_node_where_scalar_scalar_scalar(op, cond, true, false, **kwargs):
# Note that it is not necessary to check that true and false are also
# scalars. This allows users to do things like:
# ibis.where(even_or_odd_bool, [2, 4, 6], [1, 3, 5])
return true if cond else false
@execute_node.register(
ops.NotContains,
SeriesGroupBy,
(collections.abc.Sequence, collections.abc.Set, pd.Series),
)
def execute_node_not_contains_series_group_by_sequence(
op, data, elements, **kwargs
):
return (~data.obj.isin(elements)).groupby(data.grouper.groupings)


# scalar, scalar, scalar
for scalar_type in scalar_types:
execute_node_where_scalar_scalar_scalar = execute_node.register(
ops.Where, boolean_types, scalar_type, scalar_type
)(execute_node_where_scalar_scalar_scalar)
def pd_where(cond, true, false):
"""Execute `where` following ibis's intended semantics"""
if isinstance(cond, pd.Series):
if not isinstance(true, pd.Series):
true = pd.Series(
np.repeat(true, len(cond)), name=cond.name, index=cond.index
)
return true.where(cond, other=false)
if cond:
if isinstance(false, pd.Series) and not isinstance(true, pd.Series):
return pd.Series(np.repeat(true, len(false)))
return true
else:
if isinstance(true, pd.Series) and not isinstance(false, pd.Series):
return pd.Series(np.repeat(false, len(true)), index=true.index)
return false
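
A brief illustration of the broadcasting rules `pd_where` implements (a standalone sketch, not taken from the diff):

    import pandas as pd

    from ibis.backends.pandas.execution.generic import pd_where

    cond = pd.Series([True, False, True])
    pd_where(cond, 1.0, 0.0)  # scalar branches broadcast against a Series condition
    pd_where(True, 2.0, pd.Series([9.0, 9.0]))  # scalar result broadcast to the Series branch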


# scalar, Series, scalar
@execute_node.register(ops.Where, boolean_types, pd.Series, scalar_types)
def execute_node_where_scalar_series_scalar(op, cond, true, false, **kwargs):
return (
true
if cond
else pd.Series(np.repeat(false, len(true)), index=true.index)
)
@execute_node.register(
ops.Where, (pd.Series, *boolean_types), pd.Series, pd.Series
)
@execute_node.register(
ops.Where, (pd.Series, *boolean_types), pd.Series, simple_types
)
@execute_node.register(
ops.Where, (pd.Series, *boolean_types), simple_types, pd.Series
)
@execute_node.register(
ops.Where, (pd.Series, *boolean_types), type(None), type(None)
)
def execute_node_where(op, cond, true, false, **kwargs):
return pd_where(cond, true, false)


# scalar, scalar, Series
@execute_node.register(ops.Where, boolean_types, scalar_types, pd.Series)
def execute_node_where_scalar_scalar_series(op, cond, true, false, **kwargs):
return pd.Series(np.repeat(true, len(false))) if cond else false
# For true/false as scalars, we only support identical type pairs + None to
# limit the size of the dispatch table and not have to worry about type
# promotion.
for typ in (str, *scalar_types):
for cond_typ in (pd.Series, *boolean_types):
execute_node.register(ops.Where, cond_typ, typ, typ)(
execute_node_where
)
execute_node.register(ops.Where, cond_typ, type(None), typ)(
execute_node_where
)
execute_node.register(ops.Where, cond_typ, typ, type(None))(
execute_node_where
)


@execute_node.register(PandasTable, PandasBackend)
@@ -1060,9 +1144,11 @@ def execute_node_log_number_number(op, value, base, **kwargs):
return math.log(value, base)


@execute_node.register(ops.DropNa, pd.DataFrame, type(None))
@execute_node.register(ops.DropNa, pd.DataFrame, tuple)
def execute_node_dropna_dataframe(op, df, subset, **kwargs):
subset = [col.get_name() for col in subset] if subset else None
if subset is not None:
subset = [col.get_name() for col in subset]
return df.dropna(how=op.how, subset=subset)


16 changes: 6 additions & 10 deletions ibis/backends/pandas/execution/join.py
@@ -1,4 +1,3 @@
import functools
import operator

import pandas as pd
@@ -47,16 +46,13 @@ def _get_semi_anti_join_filter(op, left, right, predicates, **kwargs):
predicates,
**kwargs,
)
inner = pd.merge(
left,
right,
how="inner",
left_on=left_on,
right_on=right_on,
suffixes=constants.JOIN_SUFFIXES,
inner = left.merge(
right[right_on].drop_duplicates(),
on=left_on,
how="left",
indicator=True,
)
predicates = [left.loc[:, key].isin(inner.loc[:, key]) for key in left_on]
return functools.reduce(operator.and_, predicates)
return (inner["_merge"] == "both").values


@execute_node.register(ops.LeftSemiJoin, pd.DataFrame, pd.DataFrame, tuple)
9 changes: 9 additions & 0 deletions ibis/backends/pandas/execution/strings.py
@@ -237,6 +237,8 @@ def execute_group_concat_series_gb_mask(
op, data, sep, mask, aggcontext=None, **kwargs
):
def method(series, sep=sep):
if series.empty:
return pd.NA
return sep.join(series.values.astype(str))

return aggcontext.agg(
@@ -355,6 +357,13 @@ def execute_series_right_gb(op, data, nchars, **kwargs):
)


@execute_node.register(
ops.StringReplace, pd.Series, (pd.Series, str), (pd.Series, str)
)
def execute_series_string_replace(_, data, needle, replacement, **kwargs):
return data.str.replace(needle, replacement)


@execute_node.register(ops.StringJoin, (pd.Series, str), list)
def execute_series_join_scalar_sep(op, sep, data, **kwargs):
return reduce(lambda x, y: x + sep + y, data)
18 changes: 15 additions & 3 deletions ibis/backends/pandas/tests/execution/test_operations.py
@@ -594,12 +594,24 @@ def test_left_binary_op_gb(t, df, op, argfunc):
tm.assert_frame_equal(result, expected)


def test_where_series(t, df):
@pytest.mark.parametrize(
"left_f", [lambda e: e - 1, lambda e: 0.0, lambda e: None]
)
@pytest.mark.parametrize(
"right_f", [lambda e: e + 1, lambda e: 1.0, lambda e: None]
)
def test_where_series(t, df, left_f, right_f):
col_expr = t['plain_int64']
result = ibis.where(col_expr > col_expr.mean(), col_expr, 0.0).execute()
result = ibis.where(
col_expr > col_expr.mean(), left_f(col_expr), right_f(col_expr)
).execute()

ser = df['plain_int64']
expected = ser.where(ser > ser.mean(), other=0.0)
cond = ser > ser.mean()
left = left_f(ser)
if not isinstance(left, pd.Series):
left = pd.Series(np.repeat(left, len(cond)), name=cond.name)
expected = left.where(cond, right_f(ser))

tm.assert_series_equal(result, expected)

5 changes: 1 addition & 4 deletions ibis/backends/pandas/tests/execution/test_window.py
@@ -612,10 +612,7 @@ def test_pre_execute(op, client, **kwargs):
# once in window op at the top to pickup any scope changes before computing
# twice in window op when calling execute on the ops.Lag node at the
# beginning of execute and once before the actual computation
#
# this process happens twice because of the pre_execute call on the Alias
# operation
assert called[0] == 3 + 3
assert called[0] == 3


def test_window_grouping_key_has_scope(t, df):
9 changes: 9 additions & 0 deletions ibis/backends/pandas/tests/test_datatypes.py
@@ -130,6 +130,7 @@ def test_numpy_dtype(numpy_dtype, ibis_dtype):
dt.Timestamp('US/Eastern'),
),
(CategoricalDtype(), dt.Category()),
(pd.Series([], dtype="string").dtype, dt.String()),
],
)
def test_pandas_dtype(pandas_dtype, ibis_dtype):
@@ -206,6 +207,7 @@ def test_pandas_dtype(pandas_dtype, ibis_dtype):
(pd.Series([b'1', '2', 3.0]), dt.binary),
# empty
(pd.Series([], dtype='object'), dt.binary),
(pd.Series([], dtype="string"), dt.string),
],
)
def test_schema_infer(col_data, schema_type):
@@ -214,3 +216,10 @@ def test_schema_infer(col_data, schema_type):
inferred = sch.infer(df)
expected = ibis.schema([('col', schema_type)])
assert inferred == expected


def test_pyarrow_string():
pytest.importorskip("pyarrow")

s = pd.Series([], dtype="string[pyarrow]")
assert dt.dtype(s.dtype) == dt.String()
15 changes: 15 additions & 0 deletions ibis/backends/pandas/tests/test_udf.py
@@ -237,6 +237,21 @@ def error_udf(s):
error_udf(t.c).execute()


def test_udf_no_reexecution(t2):
execution_count = 0

@udf.elementwise(input_type=[dt.double], output_type=dt.double)
def times_two_count_executions(x):
nonlocal execution_count
execution_count += 1
return x * 2.0

expr = t2.mutate(doubled=times_two_count_executions(t2.a))
expr.execute()

assert execution_count == 1


def test_compose_udfs(t2, df2):
expr = times_two(add_one(t2.a))
result = expr.execute()
6 changes: 5 additions & 1 deletion ibis/backends/postgres/registry.py
@@ -25,7 +25,10 @@
unary,
)
from ibis.backends.base.sql.alchemy.datatypes import to_sqla_type
from ibis.backends.base.sql.alchemy.registry import get_col_or_deferred_col
from ibis.backends.base.sql.alchemy.registry import (
_bitwise_op,
get_col_or_deferred_col,
)

operation_registry = sqlalchemy_operation_registry.copy()
operation_registry.update(sqlalchemy_window_functions_registry)
@@ -617,5 +620,6 @@ def variance_compiler(t, expr):
ops.Unnest: unary(sa.func.unnest),
ops.Covariance: _covar,
ops.Correlation: _corr,
ops.BitwiseXor: _bitwise_op("#"),
}
)
56 changes: 24 additions & 32 deletions ibis/backends/postgres/tests/test_functions.py
@@ -544,45 +544,37 @@ def test_category_label(alltypes, df):


@pytest.mark.parametrize(
('distinct1', 'distinct2', 'expected1', 'expected2'),
[
(True, True, 'UNION', 'UNION'),
(True, False, 'UNION', 'UNION ALL'),
(False, True, 'UNION ALL', 'UNION'),
(False, False, 'UNION ALL', 'UNION ALL'),
],
('distinct', 'union'),
[(True, 'UNION'), (False, 'UNION ALL')],
)
def test_union_cte(alltypes, distinct1, distinct2, expected1, expected2):
def test_union_cte(alltypes, distinct, union):
t = alltypes
expr1 = t.group_by(t.string_col).aggregate(metric=t.double_col.sum())
expr2 = expr1.view()
expr3 = expr1.view()
expr = expr1.union(expr2, distinct=distinct1).union(
expr3, distinct=distinct2
expr = expr1.union(expr2, distinct=distinct).union(
expr3, distinct=distinct
)
result = '\n'.join(
map(
lambda line: line.rstrip(), # strip trailing whitespace
str(
expr.compile().compile(compile_kwargs={'literal_binds': True})
).splitlines(),
)
result = ' '.join(
line.strip()
for line in str(
expr.compile().compile(compile_kwargs={'literal_binds': True})
).splitlines()
)
expected = """\
WITH anon_1 AS
(SELECT t0.string_col AS string_col, sum(t0.double_col) AS metric
FROM functional_alltypes AS t0 GROUP BY t0.string_col),
anon_2 AS
(SELECT t0.string_col AS string_col, sum(t0.double_col) AS metric
FROM functional_alltypes AS t0 GROUP BY t0.string_col),
anon_3 AS
(SELECT t0.string_col AS string_col, sum(t0.double_col) AS metric
FROM functional_alltypes AS t0 GROUP BY t0.string_col)
(SELECT anon_1.string_col, anon_1.metric
FROM anon_1 {} SELECT anon_2.string_col, anon_2.metric
FROM anon_2) {} SELECT anon_3.string_col, anon_3.metric
FROM anon_3""".format(
expected1, expected2
expected = (
"WITH anon_1 AS "
"(SELECT t0.string_col AS string_col, sum(t0.double_col) AS metric "
"FROM functional_alltypes AS t0 GROUP BY t0.string_col), "
"anon_2 AS "
"(SELECT t0.string_col AS string_col, sum(t0.double_col) AS metric "
"FROM functional_alltypes AS t0 GROUP BY t0.string_col), "
"anon_3 AS "
"(SELECT t0.string_col AS string_col, sum(t0.double_col) AS metric "
"FROM functional_alltypes AS t0 GROUP BY t0.string_col) "
"SELECT anon_1.string_col, anon_1.metric "
f"FROM anon_1 {union} SELECT anon_2.string_col, anon_2.metric "
f"FROM anon_2 {union} SELECT anon_3.string_col, anon_3.metric "
"FROM anon_3"
)
assert str(result) == expected

70 changes: 60 additions & 10 deletions ibis/backends/pyspark/__init__.py
@@ -1,22 +1,25 @@
from __future__ import annotations

from pathlib import Path
from typing import TYPE_CHECKING, Any, Mapping
from typing import Any, Mapping

import pandas as pd
import pyspark
from pyspark.sql import DataFrame
import sqlalchemy as sa
from pydantic import Field
from pyspark import SparkConf
from pyspark.sql import DataFrame, SparkSession
from pyspark.sql.column import Column

if TYPE_CHECKING:
import ibis.expr.types as ir
import ibis.expr.operations as ops

import ibis.common.exceptions as com
import ibis.config
import ibis.expr.operations as ops
import ibis.expr.schema as sch
import ibis.expr.types as types
import ibis.expr.types as ir
import ibis.util as util
from ibis.backends.base.sql import BaseSQLBackend
from ibis.backends.base.sql.compiler import Compiler, TableSetFormatter
from ibis.backends.base.sql.ddl import (
CreateDatabase,
DropTable,
@@ -83,12 +86,49 @@ def __exit__(self, exc_type, exc_value, traceback):
"""No-op for compatibility."""


class PySparkTableSetFormatter(TableSetFormatter):
def _format_in_memory_table(self, op):
# we don't need to compile the table to a VALUES statement because the
# table has been registered already by createOrReplaceTempView.
#
# The only place where the SQL API is currently used is DDL operations
return op.name


class PySparkCompiler(Compiler):
cheap_in_memory_tables = True
table_set_formatter_class = PySparkTableSetFormatter


class Backend(BaseSQLBackend):
compiler = PySparkCompiler
name = 'pyspark'
table_class = PySparkDatabaseTable
table_expr_class = PySparkTable

def do_connect(self, session: pyspark.sql.SparkSession) -> None:
class Options(ibis.config.BaseModel):
treat_nan_as_null: bool = Field(
default=False,
description="Treat NaNs in floating point expressions as NULL.",
)
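
The new option can be toggled like any other ibis option, mirroring the fixture added in ibis/backends/pyspark/tests/test_aggregation.py below:

    import ibis

    # treat float NaNs as NULL in PySpark aggregations
    ibis.options.pyspark.treat_nan_as_null = True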

def _from_url(self, url: str) -> Backend:
"""Construct a PySpark backend from a URL `url`."""
url = sa.engine.make_url(url)

conf = SparkConf().setAll(url.query.items())

if database := url.database:
conf = conf.set(
"spark.sql.warehouse.dir",
str(Path(database).absolute()),
)

builder = SparkSession.builder.config(conf=conf)
session = builder.getOrCreate()
return self.connect(session)
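
A hedged sketch of the URL form this helper appears to accept (the warehouse path and Spark setting below are assumptions):

    from ibis.backends.pyspark import Backend

    # query-string entries become SparkConf settings; a database portion,
    # if present, is used as spark.sql.warehouse.dir
    con = Backend()._from_url(
        "pyspark:///tmp/warehouse?spark.sql.shuffle.partitions=8"
    )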

def do_connect(self, session: SparkSession) -> None:
"""Create a PySpark `Backend` for use with Ibis.
Parameters
@@ -99,9 +139,10 @@ def do_connect(self, session: pyspark.sql.SparkSession) -> None:
Examples
--------
>>> import ibis
>>> import pyspark
>>> session = pyspark.sql.SparkSession.builder.getOrCreate()
>>> from pyspark.sql import SparkSession
>>> session = SparkSession.builder.getOrCreate()
>>> ibis.pyspark.connect(session)
<ibis.backends.pyspark.Backend at 0x...>
"""
self._context = session.sparkContext
self._session = session
@@ -164,7 +205,10 @@ def compile(self, expr, timecontext=None, params=None, *args, **kwargs):
timecontext,
)
return PySparkExprTranslator().translate(
expr, scope=scope, timecontext=timecontext
expr,
scope=scope,
timecontext=timecontext,
session=self._session,
)

def execute(
@@ -417,6 +461,8 @@ def create_table(
table_name, format=format, mode=mode
)
return
else:
self._register_in_memory_tables(obj)

ast = self.compiler.to_ast(obj)
select = ast.queries[0]
@@ -441,6 +487,10 @@

return self.raw_sql(statement.compile())

def _register_in_memory_table(self, table_op):
spark_df = self.compile(table_op.to_expr())
spark_df.createOrReplaceTempView(table_op.name)

def create_view(
self,
name: str,
464 changes: 289 additions & 175 deletions ibis/backends/pyspark/compiler.py

Large diffs are not rendered by default.

16 changes: 8 additions & 8 deletions ibis/backends/pyspark/tests/conftest.py
@@ -233,8 +233,7 @@ def get_common_spark_testing_client(data_directory, connect):

def get_pyspark_testing_client(data_directory):
return get_common_spark_testing_client(
data_directory,
lambda session: ibis.backends.pyspark.Backend().connect(session),
data_directory, ibis.pyspark.connect
)


@@ -254,15 +253,16 @@ def client(data_directory):
df = df.withColumn("str_col", F.lit('value'))
df.createTempView('basic_table')

df_nans = client._session.createDataFrame(
df_nulls = client._session.createDataFrame(
[
[np.NaN, 'Alfred', None],
[27.0, 'Batman', 'motocycle'],
[3.0, None, 'joker'],
['k1', np.NaN, 'Alfred', None],
['k1', 3.0, None, 'joker'],
['k2', 27.0, 'Batman', 'batmobile'],
['k2', None, 'Catwoman', 'motorcycle'],
],
['age', 'user', 'toy'],
['key', 'age', 'user', 'toy'],
)
df_nans.createTempView('nan_table')
df_nulls.createTempView('null_table')

df_dates = client._session.createDataFrame(
[['2018-01-02'], ['2018-01-03'], ['2018-01-04']], ['date_str']
47 changes: 47 additions & 0 deletions ibis/backends/pyspark/tests/test_aggregation.py
@@ -0,0 +1,47 @@
import pytest
from pytest import param

import ibis

pytest.importorskip("pyspark")


@pytest.fixture
def treat_nan_as_null():
treat_nan_as_null = ibis.options.pyspark.treat_nan_as_null
ibis.options.pyspark.treat_nan_as_null = True
try:
yield
finally:
ibis.options.pyspark.treat_nan_as_null = treat_nan_as_null


@pytest.mark.parametrize(
('result_fn', 'expected_fn'),
[
param(
lambda t: t.age.count(),
lambda t: len(t.age.dropna()),
id='count',
),
param(
lambda t: t.age.sum(),
lambda t: t.age.sum(),
id='sum',
),
],
)
def test_aggregation_float_nulls(
client,
result_fn,
expected_fn,
treat_nan_as_null,
):
table = client.table('null_table')
df = table.compile().toPandas()

expr = result_fn(table)
result = expr.execute()

expected = expected_fn(df)
assert pytest.approx(expected) == result
4 changes: 2 additions & 2 deletions ibis/backends/pyspark/tests/test_null.py
@@ -5,7 +5,7 @@


def test_isnull(client):
table = client.table('nan_table')
table = client.table('null_table')
table_pandas = table.compile().toPandas()

for (col, _) in table_pandas.iteritems():
@@ -22,7 +22,7 @@ def test_notnull(client):


def test_notnull(client):
table = client.table('nan_table')
table = client.table('null_table')
table_pandas = table.compile().toPandas()

for (col, _) in table_pandas.iteritems():
21 changes: 16 additions & 5 deletions ibis/backends/sqlite/__init__.py
@@ -15,6 +15,7 @@
from __future__ import annotations

import sqlite3
import warnings
from pathlib import Path
from typing import TYPE_CHECKING

@@ -49,15 +50,19 @@ def __getstate__(self) -> dict:
)
return r

def do_connect(self, path: str | Path | None = None) -> None:
def do_connect(
self,
database: str | Path | None = None,
path: str | Path | None = None,
) -> None:
"""Create an Ibis client connected to a SQLite database.
Multiple database files can be created using the `attach()` method
Multiple database files can be accessed using the `attach()` method.
Parameters
----------
path
File path to the SQLite database file. If None, creates an
database
File path to the SQLite database file. If `None`, creates an
in-memory transient database and you can use attach() to add more
files
@@ -66,10 +71,16 @@ def do_connect(self, path: str | Path | None = None) -> None:
>>> import ibis
>>> ibis.sqlite.connect("path/to/my/sqlite.db")
"""
if path is not None:
warnings.warn(
"The `path` argument is deprecated in 4.0. Use `database=...`"
)
database = path

self.database_name = "main"

engine = sa.create_engine(
f"sqlite:///{path if path is not None else ':memory:'}"
f"sqlite:///{database if database is not None else ':memory:'}"
)

sqlite3.register_adapter(pd.Timestamp, lambda value: value.isoformat())
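
For illustration, the renamed argument in use (the old `path=` spelling still works but now emits a deprecation warning):

    import ibis

    con = ibis.sqlite.connect(database="path/to/my/sqlite.db")
    mem = ibis.sqlite.connect()  # in-memory transient database
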
3 changes: 3 additions & 0 deletions ibis/backends/sqlite/registry.py
@@ -359,5 +359,8 @@ def _time_from_hms(t, expr):
ops.Degrees: unary(sa.func._ibis_sqlite_degrees),
ops.Radians: unary(sa.func._ibis_sqlite_radians),
ops.Clip: _clip(min_func=sa.func.min, max_func=sa.func.max),
# sqlite doesn't implement a native xor operator
ops.BitwiseXor: fixed_arity(sa.func._ibis_sqlite_xor, 2),
ops.BitwiseNot: unary(sa.func._ibis_sqlite_inv),
}
)
10 changes: 10 additions & 0 deletions ibis/backends/sqlite/udf.py
@@ -291,6 +291,16 @@ def _ibis_sqlite_radians(x):
return None if x is None else math.radians(x)


@udf
def _ibis_sqlite_xor(x, y):
return None if x is None or y is None else x ^ y


@udf
def _ibis_sqlite_inv(x):
return None if x is None else ~x


class _ibis_sqlite_var:
def __init__(self, offset):
self.mean = 0.0
98 changes: 71 additions & 27 deletions ibis/backends/tests/test_aggregation.py
@@ -17,14 +17,12 @@ def mean_udf(s):
aggregate_test_params = [
param(
lambda t: t.double_col.mean(),
lambda s: s.mean(),
'double_col',
lambda t: t.double_col.mean(),
id='mean',
),
param(
lambda t: mean_udf(t.double_col),
lambda s: s.mean(),
'double_col',
lambda t: t.double_col.mean(),
id='mean_udf',
marks=[
pytest.mark.notimpl(
@@ -35,55 +33,75 @@ def mean_udf(s):
),
param(
lambda t: t.double_col.min(),
lambda s: s.min(),
'double_col',
lambda t: t.double_col.min(),
id='min',
),
param(
lambda t: t.double_col.max(),
lambda s: s.max(),
'double_col',
lambda t: t.double_col.max(),
id='max',
),
param(
lambda t: (t.double_col + 5).sum(),
lambda s: (s + 5).sum(),
'double_col',
lambda t: (t.double_col + 5).sum(),
id='complex_sum',
),
param(
lambda t: t.timestamp_col.max(),
lambda s: s.max(),
'timestamp_col',
lambda t: t.timestamp_col.max(),
id='timestamp_max',
),
]

argidx_not_grouped_marks = [
"datafusion",
"impala",
"mysql",
"postgres",
"pyspark",
"sqlite",
]
argidx_grouped_marks = ["dask"] + argidx_not_grouped_marks


def make_argidx_params(marks):
marks = pytest.mark.notyet(marks)
return [
param(
lambda t: t.timestamp_col.argmin(t.int_col),
lambda s: s.timestamp_col.iloc[s.int_col.argmin()],
id='argmin',
marks=marks,
),
param(
lambda t: t.double_col.argmax(t.int_col),
lambda s: s.double_col.iloc[s.int_col.argmax()],
id='argmax',
marks=marks,
),
]


@pytest.mark.parametrize(
('result_fn', 'expected_fn', 'expected_col'),
aggregate_test_params,
('result_fn', 'expected_fn'),
aggregate_test_params + make_argidx_params(argidx_not_grouped_marks),
)
def test_aggregate(
backend, alltypes, df, result_fn, expected_fn, expected_col
):
def test_aggregate(backend, alltypes, df, result_fn, expected_fn):
expr = alltypes.aggregate(tmp=result_fn)
result = expr.execute()

# Create a single-row single-column dataframe with the Pandas `agg` result
# (to match the output format of Ibis `aggregate`)
expected = pd.DataFrame({'tmp': [df[expected_col].agg(expected_fn)]})
expected = pd.DataFrame({'tmp': [expected_fn(df)]})

backend.assert_frame_equal(result, expected)


@pytest.mark.parametrize(
('result_fn', 'expected_fn', 'expected_col'),
aggregate_test_params,
('result_fn', 'expected_fn'),
aggregate_test_params + make_argidx_params(argidx_grouped_marks),
)
def test_aggregate_grouped(
backend, alltypes, df, result_fn, expected_fn, expected_col
):
def test_aggregate_grouped(backend, alltypes, df, result_fn, expected_fn):
grouping_key_col = 'bigint_col'

# Two (equivalent) variations:
@@ -96,8 +114,8 @@ def test_aggregate_grouped(

# Note: Using `reset_index` to get the grouping key as a column
expected = (
df.groupby(grouping_key_col)[expected_col]
.agg(expected_fn)
df.groupby(grouping_key_col)
.apply(expected_fn)
.rename('tmp')
.reset_index()
)
@@ -217,6 +235,26 @@ def mean_and_std(v):
lambda t, where: t.double_col[where].max(),
id='max',
),
param(
lambda t, where: t.double_col.argmin(t.int_col, where=where),
lambda t, where: t.double_col[where].iloc[
t.int_col[where].argmin()
],
id='argmin',
marks=pytest.mark.notyet(
["impala", "mysql", "postgres", "pyspark", "sqlite"]
),
),
param(
lambda t, where: t.double_col.argmax(t.int_col, where=where),
lambda t, where: t.double_col[where].iloc[
t.int_col[where].argmax()
],
id='argmax',
marks=pytest.mark.notyet(
["impala", "mysql", "postgres", "pyspark", "sqlite"]
),
),
param(
lambda t, where: t.double_col.std(how='sample', where=where),
lambda t, where: t.double_col[where].std(ddof=1),
@@ -495,7 +533,7 @@ def test_approx_median(alltypes):
L(":") + ":",
"::",
id="expr",
marks=mark.notyet(["duckdb", "impala", "mysql", "pyspark"]),
marks=mark.notyet(["duckdb", "mysql", "pyspark"]),
),
],
)
@@ -506,9 +544,15 @@ def test_approx_median(alltypes):
param(
lambda t: t.string_col.isin(['1', '7']),
lambda t: t.string_col.isin(['1', '7']),
marks=mark.notimpl(["dask", "pandas"]),
marks=mark.notimpl(["dask"]),
id='is_in',
),
param(
lambda t: t.string_col.notin(['1', '7']),
lambda t: ~t.string_col.isin(['1', '7']),
marks=mark.notimpl(["dask"]),
id='not_in',
),
],
)
@mark.notimpl(["datafusion"])
Expand Down
16 changes: 16 additions & 0 deletions ibis/backends/tests/test_api.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import pytest
from pytest import param

import ibis
import ibis.common.exceptions as exc


def test_backend_name(backend):
@@ -79,3 +81,17 @@ def test_tables_accessor_tab_completion(con):

keys = con.tables._ipython_key_completions_()
assert 'functional_alltypes' in keys


@pytest.mark.notimpl(["datafusion"], raises=exc.OperationNotDefinedError)
@pytest.mark.parametrize(
"expr_fn",
[
param(lambda t: t.limit(5).limit(10), id="small_big"),
param(lambda t: t.limit(10).limit(5), id="big_small"),
],
)
def test_limit_chain(alltypes, expr_fn):
expr = expr_fn(alltypes)
result = expr.execute()
assert len(result) == 5
388 changes: 378 additions & 10 deletions ibis/backends/tests/test_client.py

Large diffs are not rendered by default.

409 changes: 290 additions & 119 deletions ibis/backends/tests/test_generic.py

Large diffs are not rendered by default.

9 changes: 2 additions & 7 deletions ibis/backends/tests/test_join.py
@@ -15,10 +15,8 @@ def _pandas_semi_join(left, right, on, **_):


def _pandas_anti_join(left, right, on, **_):
assert len(on) == 1, str(on)
inner = pd.merge(left, right, how="inner", on=on)
filt = left.loc[:, on[0]].isin(inner.loc[:, on[0]])
return left.loc[~filt, :]
inner = pd.merge(left, right, how="left", indicator=True, on=on)
return inner[inner["_merge"] == "left_only"]


def _merge(
@@ -165,9 +163,6 @@ def test_filtering_join(backend, batting, awards_players, how):
backend.assert_frame_equal(result, expected, check_like=True)


@pytest.mark.skip_backends(
["dask", "pandas"], reason="insane memory explosion"
)
@pytest.mark.notyet(
["pyspark"],
reason="pyspark doesn't support joining on differing column names",
24 changes: 8 additions & 16 deletions ibis/backends/tests/test_numeric.py
@@ -203,7 +203,7 @@ def test_trig_functions_literals(con, expr, expected):
param(_.dc.atan(), np.arctan, id="atan"),
param(_.dc.atan2(_.dc), lambda c: np.arctan2(c, c), id="atan2"),
param(_.dc.cos(), np.cos, id="cos"),
param(_.dc.cot(), lambda c: np.cos(c) / np.sin(c), id="cot"),
param(_.dc.cot(), lambda c: 1.0 / np.tan(c), id="cot"),
param(_.dc.sin(), np.sin, id="sin"),
param(_.dc.tan(), np.tan, id="tan"),
],
@@ -217,12 +217,10 @@ def test_trig_functions_literals(con, expr, expected):
)
def test_trig_functions_columns(backend, expr, alltypes, df, expected_fn):
dc_max = df.double_col.max()
result = (
alltypes.mutate(dc=(_.double_col / dc_max).nullifzero())
.select([expr.name("tmp")])
.execute()
.tmp
expr = alltypes.mutate(dc=(_.double_col / dc_max).nullifzero()).select(
tmp=expr
)
result = expr.tmp.execute()
expected = expected_fn(
(df.double_col / dc_max).replace(0.0, np.nan)
).rename("tmp")
@@ -518,16 +516,10 @@ def test_sa_default_numeric_precision_and_scale(
con.drop_table(table_name, force=True)


@pytest.mark.notimpl(
[
"clickhouse",
"dask",
"datafusion",
"impala",
"pandas",
"pyspark",
"sqlite",
]
@pytest.mark.notimpl(["dask", "datafusion", "impala", "pandas", "sqlite"])
@pytest.mark.notyet(
["clickhouse"],
reason="clickhouse doesn't implement a [0.0, 1.0) random function",
)
def test_random(con):
expr = ibis.random()
93 changes: 93 additions & 0 deletions ibis/backends/tests/test_pretty.py
@@ -0,0 +1,93 @@
import io

import pytest
from pytest import mark, param

import ibis
import ibis.common.exceptions as exc
from ibis import _

sa = pytest.importorskip("sqlalchemy")
pytest.importorskip("sqlglot")


@mark.never(
["dask", "pandas"],
reason="Dask and Pandas are not SQL backends",
raises=(NotImplementedError, AssertionError),
)
@mark.notimpl(
["datafusion", "pyspark"],
reason="Not clear how to extract SQL from the backend",
raises=(exc.OperationNotDefinedError, NotImplementedError, AssertionError),
)
def test_table(con):
expr = con.tables.functional_alltypes.select(c=_.int_col + 1)
buf = io.StringIO()
ibis.show_sql(expr, file=buf)
assert buf.getvalue()


simple_literal = param(
ibis.literal(1),
id="simple_literal",
)
array_literal = param(
ibis.array([1]),
marks=[
mark.never(
["mysql", "sqlite"],
raises=sa.exc.CompileError,
reason="arrays not supported in the backend",
),
mark.notyet(
["impala"],
raises=NotImplementedError,
reason="Impala hasn't implemented array literals",
),
mark.notimpl(
["postgres"],
reason="array literals are not yet implemented",
raises=NotImplementedError,
),
],
id="array_literal",
)
no_structs = mark.never(
["impala", "mysql", "sqlite"],
raises=(NotImplementedError, sa.exc.CompileError),
reason="structs not supported in the backend",
)
no_struct_literals = mark.notimpl(
["postgres"],
reason="struct literals are not yet implemented",
)
not_sql = mark.never(
["pandas", "dask"],
raises=(exc.IbisError, NotImplementedError, AssertionError),
reason="Not a SQL backend",
)
no_sql_extraction = mark.notimpl(
["datafusion", "pyspark"],
reason="Not clear how to extract SQL from the backend",
)


@mark.parametrize(
"expr",
[
simple_literal,
array_literal,
param(
ibis.struct(dict(a=1)),
marks=[no_structs, no_struct_literals],
id="struct_literal",
),
],
)
@not_sql
@no_sql_extraction
def test_literal(backend, expr):
buf = io.StringIO()
ibis.show_sql(expr, dialect=backend.name(), file=buf)
assert buf.getvalue()
130 changes: 130 additions & 0 deletions ibis/backends/tests/test_set_ops.py
@@ -0,0 +1,130 @@
import pandas as pd
import pytest
from pytest import param

import ibis
from ibis import _


@pytest.fixture
def union_subsets(alltypes, df):
a = alltypes.filter((5200 <= _.id) & (_.id <= 5210))
b = alltypes.filter((5205 <= _.id) & (_.id <= 5215))
c = alltypes.filter((5213 <= _.id) & (_.id <= 5220))

da = df[(5200 <= df.id) & (df.id <= 5210)]
db = df[(5205 <= df.id) & (df.id <= 5215)]
dc = df[(5213 <= df.id) & (df.id <= 5220)]

return (a, b, c), (da, db, dc)


@pytest.mark.parametrize(
"distinct",
[param(False, id="all"), param(True, id="distinct")],
)
@pytest.mark.notimpl(["datafusion"])
def test_union(backend, union_subsets, distinct):
(a, b, c), (da, db, dc) = union_subsets

expr = ibis.union(a, b, c, distinct=distinct).sort_by("id")
result = expr.execute()

expected = (
pd.concat([da, db, dc], axis=0)
.sort_values("id")
.reset_index(drop=True)
)
if distinct:
expected = expected.drop_duplicates("id")

backend.assert_frame_equal(result, expected)


@pytest.mark.notimpl(["datafusion"])
def test_union_mixed_distinct(backend, union_subsets):
(a, b, c), (da, db, dc) = union_subsets

expr = a.union(b, distinct=True).union(c, distinct=False).sort_by("id")
result = expr.execute()
expected = pd.concat(
[pd.concat([da, db], axis=0).drop_duplicates("id"), dc], axis=0
).sort_values("id")

backend.assert_frame_equal(result, expected)


@pytest.mark.parametrize(
"distinct",
[
param(
False,
marks=pytest.mark.notyet(
["clickhouse", "dask", "pandas", "sqlite"],
reason="backend doesn't support INTERSECT ALL",
),
id="all",
),
param(True, id="distinct"),
],
)
@pytest.mark.notimpl(["datafusion"])
@pytest.mark.notyet(["impala"])
def test_intersect(backend, alltypes, df, distinct):
a = alltypes.filter((5200 <= _.id) & (_.id <= 5210))
b = alltypes.filter((5205 <= _.id) & (_.id <= 5215))
c = alltypes.filter((5195 <= _.id) & (_.id <= 5208))

# Reset index to ensure simple RangeIndex, needed for computing `expected`
df = df.reset_index(drop=True)
da = df[(5200 <= df.id) & (df.id <= 5210)]
db = df[(5205 <= df.id) & (df.id <= 5215)]
dc = df[(5195 <= df.id) & (df.id <= 5208)]

expr = ibis.intersect(a, b, c, distinct=distinct).sort_by("id")
result = expr.execute()

index = da.index.intersection(db.index).intersection(dc.index)
expected = df.iloc[index].sort_values("id").reset_index(drop=True)
if distinct:
expected = expected.drop_duplicates()

backend.assert_frame_equal(result, expected)


@pytest.mark.parametrize(
"distinct",
[
param(
False,
marks=pytest.mark.notyet(
["clickhouse", "dask", "pandas", "sqlite"],
reason="backend doesn't support EXCEPT ALL",
),
id="all",
),
param(True, id="distinct"),
],
)
@pytest.mark.notimpl(["datafusion"])
@pytest.mark.notyet(["impala"])
def test_difference(backend, alltypes, df, distinct):
a = alltypes.filter((5200 <= _.id) & (_.id <= 5210))
b = alltypes.filter((5205 <= _.id) & (_.id <= 5215))
c = alltypes.filter((5195 <= _.id) & (_.id <= 5202))

# Reset index to ensure simple RangeIndex, needed for computing `expected`
df = df.reset_index(drop=True)
da = df[(5200 <= df.id) & (df.id <= 5210)]
db = df[(5205 <= df.id) & (df.id <= 5215)]
dc = df[(5195 <= df.id) & (df.id <= 5202)]

expr = ibis.difference(a, b, c, distinct=distinct).sort_by("id")
result = expr.execute()

index = da.index.difference(db.index).difference(dc.index)
expected = df.iloc[index].sort_values("id").reset_index(drop=True)
if distinct:
expected = expected.drop_duplicates()

backend.assert_frame_equal(result, expected)
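
For readers skimming the new set-operation tests: ibis.union, ibis.intersect, and ibis.difference are the variadic top-level forms exercised above, and each takes distinct= to toggle the ALL variants. A rough expression-building sketch; the table schema and filters are invented, and executing the expressions still requires a backend:

import ibis

t = ibis.table(dict(id="int64", s="string"), name="t")
a = t.filter(t.id <= 10)
b = t.filter(t.id >= 5)

u = ibis.union(a, b, distinct=False)       # UNION ALL
i = ibis.intersect(a, b, distinct=True)    # INTERSECT DISTINCT
d = ibis.difference(a, b, distinct=True)   # EXCEPT DISTINCT
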
12 changes: 9 additions & 3 deletions ibis/backends/tests/test_string.py
@@ -264,19 +264,25 @@ def test_string_col_is_unicode(alltypes, df):
lambda t: t.string_col + t.date_string_col,
lambda t: t.string_col + t.date_string_col,
id='concat_columns',
marks=pytest.mark.notimpl(["datafusion", "impala"]),
marks=pytest.mark.notimpl(["datafusion"]),
),
param(
lambda t: t.string_col + 'a',
lambda t: t.string_col + 'a',
id='concat_column_scalar',
marks=pytest.mark.notimpl(["datafusion", "impala"]),
marks=pytest.mark.notimpl(["datafusion"]),
),
param(
lambda t: 'a' + t.string_col,
lambda t: 'a' + t.string_col,
id='concat_scalar_column',
marks=pytest.mark.notimpl(["datafusion", "impala"]),
marks=pytest.mark.notimpl(["datafusion"]),
),
param(
lambda t: t.string_col.replace("1", "42"),
lambda t: t.string_col.str.replace("1", "42"),
id="replace",
marks=pytest.mark.notimpl(["datafusion"]),
),
],
)
3 changes: 2 additions & 1 deletion ibis/backends/tests/test_struct.py
@@ -20,7 +20,8 @@
@pytest.mark.notimpl(["dask"])
@fields
def test_single_field(backend, struct, struct_df, field):
result = struct.abc[field].execute()
expr = struct.abc[field]
result = expr.execute()
expected = struct_df.abc.map(
lambda value: value[field] if isinstance(value, dict) else value
).rename(field)
49 changes: 47 additions & 2 deletions ibis/backends/tests/test_temporal.py
@@ -857,8 +857,8 @@ def test_integer_cast_to_timestamp(backend, alltypes, df):
),
)
@pytest.mark.notimpl(
["datafusion", "duckdb"],
reason="DataFusion and DuckDB backends assume ns resolution timestamps",
["datafusion"],
reason="DataFusion backend assumes ns resolution timestamps",
)
@pytest.mark.notyet(
["pyspark"],
@@ -870,3 +870,48 @@ def test_big_timestamp(con):
result = con.execute(value)
expected = datetime.datetime(2419, 10, 11, 10, 10, 25)
assert result == expected


DATE = datetime.date(2010, 11, 1)


def build_date_col(t):
return (
t.year.cast("string")
+ "-"
+ t.month.cast("string").lpad(2, "0")
+ "-"
+ (t.int_col + 1).cast("string").lpad(2, "0")
).cast("date")


@pytest.mark.notimpl(["datafusion"])
@pytest.mark.notyet(["impala"], reason="impala doesn't support dates")
@pytest.mark.parametrize(
("left_fn", "right_fn"),
[
param(build_date_col, lambda _: DATE, id="column_date"),
param(lambda _: DATE, build_date_col, id="date_column"),
],
)
def test_timestamp_date_comparison(backend, alltypes, df, left_fn, right_fn):
left = left_fn(alltypes)
right = right_fn(alltypes)
expr = left == right
result = expr.execute().rename("result")
expected = (
pd.to_datetime(
(
df.year.astype(str)
.add("-")
.add(df.month.astype(str).str.rjust(2, "0"))
.add("-")
.add(df.int_col.add(1).astype(str).str.rjust(2, "0"))
),
format="%Y-%m-%d",
exact=True,
)
.eq(pd.Timestamp(DATE))
.rename("result")
)
backend.assert_series_equal(result, expected)
2 changes: 1 addition & 1 deletion ibis/backends/tests/test_timecontext.py
@@ -55,7 +55,7 @@ def ctx_col():


@pytest.mark.notimpl(["dask", "duckdb"])
@pytest.mark.min_spark_version('3.1')
@pytest.mark.min_version(pyspark="3.1")
@pytest.mark.parametrize(
'window',
[
37 changes: 0 additions & 37 deletions ibis/backends/tests/test_union.py

This file was deleted.

2 changes: 1 addition & 1 deletion ibis/backends/tests/test_vectorized_udf.py
@@ -496,7 +496,7 @@ def test_elementwise_udf_overwrite_destruct_and_assign(
udf_backend.assert_frame_equal(result, expected, check_like=True)


@pytest.mark.min_spark_version('3.1')
@pytest.mark.min_version(pyspark="3.1")
def test_elementwise_udf_destruct_exact_once(udf_backend, udf_alltypes):
with tempfile.TemporaryDirectory() as tempdir:

112 changes: 94 additions & 18 deletions ibis/backends/tests/test_window.py
@@ -30,16 +30,22 @@ def calc_zscore(s):
lambda t, win: t.float_col.lead().over(win),
lambda t: t.float_col.shift(-1),
id='lead',
marks=pytest.mark.broken(
["clickhouse"],
reason="upstream is broken; returns all nulls",
),
),
param(
lambda t, win: t.id.rank().over(win),
lambda t: t.id.rank(method='min').astype('int64') - 1,
id='rank',
marks=pytest.mark.min_server_version(clickhouse="22.8"),
),
param(
lambda t, win: t.id.dense_rank().over(win),
lambda t: t.id.rank(method='dense').astype('int64') - 1,
id='dense_rank',
marks=pytest.mark.min_server_version(clickhouse="22.8"),
),
param(
lambda t, win: t.id.percent_rank().over(win),
@@ -52,12 +58,19 @@
)
).reset_index(drop=True, level=[0]),
id='percent_rank',
marks=pytest.mark.notyet(
["clickhouse"],
reason="clickhouse doesn't implement percent_rank",
),
),
param(
lambda t, win: t.id.cume_dist().over(win),
lambda t: t.id.rank(method='min') / t.id.transform(len),
id='cume_dist',
marks=pytest.mark.notimpl(["pyspark"]),
marks=[
pytest.mark.notimpl(["pyspark"]),
pytest.mark.notyet(["clickhouse"]),
],
),
param(
lambda t, win: t.float_col.ntile(buckets=7).over(win),
@@ -99,7 +112,10 @@ def calc_zscore(s):
lambda _, win: ibis.row_number().over(win),
lambda t: t.cumcount(),
id='row_number',
marks=pytest.mark.notimpl(["pandas"]),
marks=[
pytest.mark.notimpl(["pandas"]),
pytest.mark.min_server_version(clickhouse="22.8"),
],
),
param(
lambda t, win: t.double_col.cumsum().over(win),
@@ -143,7 +159,14 @@ def calc_zscore(s):
),
id='cumnotany',
marks=pytest.mark.notyet(
("duckdb", 'impala', 'postgres', 'mysql', 'sqlite'),
(
"clickhouse",
"duckdb",
'impala',
'postgres',
'mysql',
'sqlite',
),
reason="notany() over window not supported",
),
),
@@ -167,7 +190,14 @@ def calc_zscore(s):
),
id='cumnotall',
marks=pytest.mark.notyet(
("duckdb", 'impala', 'postgres', 'mysql', 'sqlite'),
(
"clickhouse",
"duckdb",
'impala',
'postgres',
'mysql',
'sqlite',
),
reason="notall() over window not supported",
),
),
@@ -204,7 +234,7 @@ def calc_zscore(s):
),
],
)
@pytest.mark.notimpl(["clickhouse", "dask", "datafusion"])
@pytest.mark.notimpl(["dask", "datafusion"])
def test_grouped_bounded_expanding_window(
backend, alltypes, df, result_fn, expected_fn
):
@@ -244,14 +274,21 @@ def test_grouped_bounded_expanding_window(
id='mean_udf',
marks=[
pytest.mark.notimpl(
["duckdb", "impala", "mysql", "postgres", "sqlite"]
[
"clickhouse",
"duckdb",
"impala",
"mysql",
"postgres",
"sqlite",
]
)
],
),
],
)
# Some backends do not support non-grouped window specs
@pytest.mark.notimpl(["clickhouse", "dask", "datafusion"])
@pytest.mark.notimpl(["dask", "datafusion"])
def test_ungrouped_bounded_expanding_window(
backend, alltypes, df, result_fn, expected_fn
):
@@ -271,7 +308,7 @@ def test_ungrouped_bounded_expanding_window(
backend.assert_series_equal(left, right)


@pytest.mark.notimpl(["clickhouse", "dask", "datafusion", "pandas"])
@pytest.mark.notimpl(["dask", "datafusion", "pandas"])
def test_grouped_bounded_following_window(backend, alltypes, df):
window = ibis.window(
preceding=0,
@@ -326,7 +363,7 @@ def test_grouped_bounded_following_window(backend, alltypes, df):
),
],
)
@pytest.mark.notimpl(["clickhouse", "dask", "datafusion"])
@pytest.mark.notimpl(["dask", "datafusion"])
def test_grouped_bounded_preceding_window(backend, alltypes, df, window_fn):
window = window_fn(alltypes)

@@ -363,7 +400,15 @@ def test_grouped_bounded_preceding_window(backend, alltypes, df, window_fn):
lambda gb: (gb.double_col.transform('mean')),
id='mean_udf',
marks=pytest.mark.notimpl(
["dask", "duckdb", "impala", "mysql", "postgres", "sqlite"]
[
"clickhouse",
"dask",
"duckdb",
"impala",
"mysql",
"postgres",
"sqlite",
]
),
),
],
@@ -377,7 +422,7 @@ def test_grouped_bounded_preceding_window(backend, alltypes, df, window_fn):
param(False, id='unordered'),
],
)
@pytest.mark.notimpl(["clickhouse", "datafusion"])
@pytest.mark.notimpl(["datafusion"])
def test_grouped_unbounded_window(
backend, alltypes, df, result_fn, expected_fn, ordered
):
@@ -417,7 +462,12 @@ def test_grouped_unbounded_window(
lambda df: pd.Series([df.double_col.mean()] * len(df.double_col)),
True,
id='ordered-mean',
marks=pytest.mark.notimpl(["dask", "impala", "pandas"]),
marks=[
pytest.mark.notimpl(["dask", "impala", "pandas"]),
pytest.mark.broken(
["clickhouse"], reason="upstream appears broken"
),
],
),
param(
lambda t, win: t.double_col.mean().over(win),
@@ -432,6 +482,7 @@ def test_grouped_unbounded_window(
id='ordered-mean_udf',
marks=pytest.mark.notimpl(
[
"clickhouse",
"dask",
"duckdb",
"impala",
@@ -448,7 +499,14 @@ def test_grouped_unbounded_window(
False,
id='unordered-mean_udf',
marks=pytest.mark.notimpl(
["duckdb", "impala", "mysql", "postgres", "sqlite"]
[
"clickhouse",
"duckdb",
"impala",
"mysql",
"postgres",
"sqlite",
]
),
),
# Analytic ops
@@ -471,14 +529,16 @@ def test_grouped_unbounded_window(
lambda df: df.float_col.shift(-1),
True,
id='ordered-lead',
marks=pytest.mark.notimpl(["dask"]),
marks=pytest.mark.notimpl(["clickhouse", "dask"]),
),
param(
lambda t, win: t.float_col.lead().over(win),
lambda df: df.float_col.shift(-1),
False,
id='unordered-lead',
marks=pytest.mark.notimpl(["dask", "mysql", "pyspark"]),
marks=pytest.mark.notimpl(
["clickhouse", "dask", "mysql", "pyspark"]
),
),
param(
lambda t, win: calc_zscore(t.double_col).over(win),
@@ -487,6 +547,7 @@ def test_grouped_unbounded_window(
id='ordered-zscore_udf',
marks=pytest.mark.notimpl(
[
"clickhouse",
"dask",
"duckdb",
"impala",
@@ -504,13 +565,21 @@ def test_grouped_unbounded_window(
False,
id='unordered-zscore_udf',
marks=pytest.mark.notimpl(
["duckdb", "impala", "mysql", "postgres", "pyspark", "sqlite"]
[
"clickhouse",
"duckdb",
"impala",
"mysql",
"postgres",
"pyspark",
"sqlite",
]
),
),
],
)
# Some backends do not support non-grouped window specs
@pytest.mark.notimpl(["clickhouse", "datafusion"])
@pytest.mark.notimpl(["datafusion"])
def test_ungrouped_unbounded_window(
backend, alltypes, df, con, result_fn, expected_fn, ordered
):
@@ -541,7 +610,11 @@ def test_ungrouped_unbounded_window(
backend.assert_series_equal(left, right)


@pytest.mark.notimpl(["clickhouse", "dask", "datafusion", "impala", "pandas"])
@pytest.mark.notimpl(["dask", "datafusion", "impala", "pandas"])
@pytest.mark.notyet(
["clickhouse"],
reason="RANGE OFFSET frame for 'DB::ColumnNullable' ORDER BY column is not implemented", # noqa: E501
)
def test_grouped_bounded_range_window(backend, alltypes, df):
# Explanation of the range window spec below:
#
@@ -591,6 +664,9 @@ def gb_fn(df):


@pytest.mark.notimpl(["clickhouse", "dask", "datafusion", "pyspark"])
@pytest.mark.notyet(
["clickhouse"], reason="clickhouse doesn't implement percent_rank"
)
def test_percent_rank_whole_table_no_order_by(backend, alltypes, df):
expr = alltypes.mutate(val=lambda t: t.id.percent_rank())

12 changes: 11 additions & 1 deletion ibis/common/dispatch.py
@@ -87,7 +87,17 @@ def dispatch(self, s: str) -> Any:
for r, func in self.funcs.items()
if (match := r.match(s)) is not None
)
return max(funcs, key=lambda pair: self.priorities.get(pair[0]))
priorities = self.priorities
value = max(
funcs,
key=lambda pair: priorities.get(pair[0]),
default=None,
)
if value is None:
raise NotImplementedError(
f"no pattern for `{self.name}` matches input string: {s!r}"
)
return value

def __call__(self, s: str, *args: Any, **kwargs: Any) -> Any:
func, match = self.dispatch(s)
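
The default=None change in dispatch() above exists because max() over an empty sequence raises a bare ValueError; converting that into a NotImplementedError names the dispatcher and the offending input. A generic illustration of the same pattern, with invented patterns and handler names:

import re

patterns = {
    re.compile(r"\.csv$"): "read_csv",
    re.compile(r"\.parquet$"): "read_parquet",
}

def dispatch(s: str) -> str:
    # pick a matching handler, falling back to None when nothing matches
    match = max(
        (name for pattern, name in patterns.items() if pattern.search(s)),
        default=None,
    )
    if match is None:
        raise NotImplementedError(f"no pattern matches input string: {s!r}")
    return match

print(dispatch("data.csv"))   # read_csv
# dispatch("data.json") raises NotImplementedError instead of a bare ValueError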
23 changes: 16 additions & 7 deletions ibis/common/grounds.py
@@ -5,12 +5,16 @@
from typing import Any, Hashable
from weakref import WeakValueDictionary

from rich.console import Console

from ibis.common.caching import WeakCache
from ibis.common.validators import ImmutableProperty, Optional, Validator
from ibis.util import frozendict

EMPTY = inspect.Parameter.empty # marker for missing argument

console = Console()


class BaseMeta(ABCMeta):

@@ -62,6 +66,17 @@ def validate(self, this, arg):
return self.validator(arg, this=this)


class Immutable(Hashable):

__slots__ = ()

def __setattr__(self, name: str, _: Any) -> None:
raise TypeError(
f"Attribute {name!r} cannot be assigned to immutable instance of "
f"type {type(self)}"
)
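
The Immutable mixin above hoists the __setattr__ override that previously lived on Annotable (removed further down in this diff) into a reusable base class. A tiny sketch of how the pattern behaves, using a throwaway subclass that is not part of this PR:

from typing import Any, Hashable


class Immutable(Hashable):
    __slots__ = ()

    def __setattr__(self, name: str, _: Any) -> None:
        raise TypeError(f"Attribute {name!r} cannot be assigned to immutable instance")


class Point(Immutable):
    __slots__ = ("x", "y")

    def __init__(self, x: int, y: int) -> None:
        # sidestep our own __setattr__ during construction
        object.__setattr__(self, "x", x)
        object.__setattr__(self, "y", y)

    def __hash__(self) -> int:
        return hash((self.x, self.y))


p = Point(1, 2)
# p.x = 3  # raises TypeError: Attribute 'x' cannot be assigned ...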


class AnnotableMeta(BaseMeta):
"""
Metaclass to turn class annotations into a validatable function signature.
@@ -128,7 +143,7 @@ def __new__(metacls, clsname, bases, dct):
return super().__new__(metacls, clsname, bases, attribs)


class Annotable(Base, Hashable, metaclass=AnnotableMeta):
class Annotable(Base, Immutable, metaclass=AnnotableMeta):
"""Base class for objects with custom validation rules."""

__slots__ = ("args", "_hash")
@@ -178,12 +193,6 @@ def __hash__(self):
def __eq__(self, other):
return super().__eq__(other)

def __setattr__(self, name: str, _: Any) -> None:
raise TypeError(
f"Attribute {name!r} cannot be assigned to immutable instance of "
f"type {type(self)}"
)

def __repr__(self) -> str:
args = ", ".join(
f"{name}={value!r}"
108 changes: 108 additions & 0 deletions ibis/common/pretty.py
@@ -0,0 +1,108 @@
from __future__ import annotations

from typing import IO

import ibis
import ibis.common.exceptions as com
import ibis.expr.types as ir

_IBIS_TO_SQLGLOT_NAME_MAP = {
# not 100% accurate, but very close
"impala": "hive",
# for now map clickhouse to Hive so that _something_ works
"clickhouse": "mysql",
}


def show_sql(
expr: ir.Expr,
dialect: str | None = None,
file: IO[str] | None = None,
) -> None:
"""Pretty-print the compiled SQL string of an expression.
If a dialect cannot be inferred and one was not passed, duckdb
will be used as the dialect
Parameters
----------
expr
Ibis expression whose SQL will be printed
dialect
String dialect. This is typically not required, but can be useful if
ibis cannot infer the backend dialect.
file
        File to write output to

Examples
--------
>>> import ibis
>>> from ibis import _
>>> t = ibis.table(dict(a="int"), name="t")
>>> expr = t.select(c=_.a * 2)
>>> ibis.show_sql(expr) # duckdb dialect by default
SELECT
t0.a * CAST(2 AS SMALLINT) AS c
FROM t AS t0
>>> ibis.show_sql(expr, dialect="mysql")
SELECT
t0.a * 2 AS c
FROM t AS t0
"""
print(to_sql(expr, dialect=dialect), file=file)


def to_sql(expr: ir.Expr, dialect: str | None = None) -> str:
"""Return the formatted SQL string for an expression.
Parameters
----------
expr
Ibis expression.
dialect
        SQL dialect to use for compilation.

Returns
-------
str
Formatted SQL string
"""
import sqlglot

# try to infer from a non-str expression or if not possible fallback to
# the default pretty dialect for expressions
if dialect is None:
try:
backend = expr._find_backend()
except com.IbisError:
# default to duckdb for sqlalchemy compilation because it supports
# the widest array of ibis features for SQL backends
read = "duckdb"
write = ibis.options.sql.default_dialect
else:
read = write = backend.name
else:
read = write = dialect

write = _IBIS_TO_SQLGLOT_NAME_MAP.get(write, write)

try:
compiled = expr.compile()
except com.IbisError:
backend = getattr(ibis, read)
compiled = backend.compile(expr)
try:
sql = str(compiled.compile(compile_kwargs={"literal_binds": True}))
except (AttributeError, TypeError):
sql = compiled

assert isinstance(
sql, str
), f"expected `str`, got `{sql.__class__.__name__}`"
(pretty,) = sqlglot.transpile(
sql,
read=_IBIS_TO_SQLGLOT_NAME_MAP.get(read, read),
write=write,
pretty=True,
)
return pretty
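
Taken together with the tests in test_pretty.py, the intended usage of the new helpers is small; a short sketch consolidating the docstring example above (show_sql prints, while to_sql in this module returns the string):

import io

import ibis
from ibis import _

t = ibis.table(dict(a="int"), name="t")
expr = t.select(c=_.a * 2)

ibis.show_sql(expr)                    # duckdb dialect by default
ibis.show_sql(expr, dialect="mysql")   # transpiled via sqlglot

buf = io.StringIO()
ibis.show_sql(expr, file=buf)          # capture instead of printing
print(buf.getvalue())
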
40 changes: 39 additions & 1 deletion ibis/config.py
@@ -35,6 +35,13 @@ class SQL(BaseModel):
"explicit limit. [`None`][None] means no limit."
),
)
default_dialect: str = Field(
default="duckdb",
description=(
"Dialect to use for printing SQL when the backend cannot be "
"determined."
),
)


class Repr(BaseModel):
@@ -79,6 +86,31 @@ def query_text_length_ge_zero(cls, query_text_length: int) -> int:
return query_text_length


_HAS_DUCKDB = True
_DUCKDB_CON = None


def _default_backend() -> Any:
global _HAS_DUCKDB, _DUCKDB_CON

if not _HAS_DUCKDB:
return None

if _DUCKDB_CON is not None:
return _DUCKDB_CON

try:
import duckdb as _ # noqa: F401
except ImportError:
_HAS_DUCKDB = False
return None

import ibis

_DUCKDB_CON = ibis.duckdb.connect(":memory:")
return _DUCKDB_CON


class Options(BaseSettings):
"""Ibis configuration options."""

@@ -99,10 +131,15 @@ class Options(BaseSettings):
default=False,
description="Render expressions as GraphViz PNGs when running in a Jupyter notebook.", # noqa: E501
)

default_backend: Any = Field(
default=None,
description="The default backend to use for execution.",
description=(
"The default backend to use for execution. "
"Defaults to DuckDB if not set."
),
)

context_adjustment: ContextAdjustment = Field(
default=ContextAdjustment(),
description=ContextAdjustment.__doc__,
@@ -113,6 +150,7 @@ class Options(BaseSettings):
dask: Optional[BaseModel] = None
impala: Optional[BaseModel] = None
pandas: Optional[BaseModel] = None
pyspark: Optional[BaseModel] = None

class Config:
validate_assignment = True
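
Both new knobs in this file are ordinary settings; a brief sketch of adjusting them at runtime (the mysql choice and the explicit in-memory connection are arbitrary examples, not defaults):

import ibis

# dialect used by SQL pretty-printing when no backend can be inferred
ibis.options.sql.default_dialect = "mysql"

# pin an explicit default backend instead of relying on the lazy DuckDB fallback
ibis.options.default_backend = ibis.duckdb.connect(":memory:")
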
35 changes: 27 additions & 8 deletions ibis/expr/analysis.py
@@ -12,7 +12,11 @@
import ibis.expr.operations as ops
import ibis.expr.types as ir
from ibis import util
from ibis.common.exceptions import ExpressionError, IbisTypeError
from ibis.common.exceptions import (
ExpressionError,
IbisTypeError,
IntegrityError,
)
from ibis.expr.window import window

# ---------------------------------------------------------------------
@@ -343,21 +347,36 @@ def _filter_selection(expr, predicates):
# the parent tables in the join being projected

op = expr.op()
if not op.blocks():
# Potential fusion opportunity. The predicates may need to be
# rewritten in terms of the child table. This prevents the broken
# ref issue (described in more detail in #59)
# Potential fusion opportunity. The predicates may need to be
# rewritten in terms of the child table. This prevents the broken
# ref issue (described in more detail in #59)
try:
simplified_predicates = tuple(
sub_for(predicate, [(expr, op.table)])
if not is_reduction(predicate)
else predicate
for predicate in predicates
)

if shares_all_roots(simplified_predicates, op.table):
except IntegrityError:
pass
else:
if shares_all_roots(simplified_predicates, op.table) and not any(
# we can't push down filters on unnest because unnest changes the
# shape and potential values of the data: unnest can potentially
# produce NULLs
#
# the getattr shenanigans is to handle Alias
isinstance(
child_op.arg.op()
if isinstance(child_op := sel.op(), ops.Alias)
else child_op,
ops.Unnest,
)
for sel in op.selections
):
result = ops.Selection(
op.table,
[],
selections=op.selections,
predicates=op.predicates + simplified_predicates,
sort_keys=op.sort_keys,
)