232 changes: 163 additions & 69 deletions ibis/backends/base/sql/alchemy/registry.py
@@ -3,7 +3,6 @@
from typing import Any, Dict

import sqlalchemy as sa
import sqlalchemy.sql as sql

import ibis
import ibis.common.exceptions as com
@@ -13,6 +12,7 @@
import ibis.expr.types as ir
import ibis.expr.window as W
from ibis.backends.base.sql.alchemy.database import AlchemyTable
from ibis.backends.base.sql.alchemy.datatypes import to_sqla_type
from ibis.backends.base.sql.alchemy.geospatial import geospatial_supported


@@ -99,7 +99,7 @@ def get_sqla_table(ctx, table):

def get_col_or_deferred_col(sa_table, colname):
"""
Get a `ColumnExpr`, or create a "deferred" column.
Get a `Column`, or create a "deferred" column.
This is to handle the case when selecting a column from a join, which
    happens when a join expression is cached during join traversal.
@@ -165,39 +165,51 @@ def _exists_subquery(t, expr):


def _cast(t, expr):
op = expr.op()
arg, target_type = op.args
arg, typ = expr.op().args

sa_arg = t.translate(arg)
sa_type = t.get_sqla_type(target_type)
sa_type = t.get_sqla_type(typ)

if isinstance(arg, ir.CategoryValue) and target_type == 'int32':
if isinstance(arg, ir.CategoryValue) and typ == dt.int32:
return sa_arg
else:
return sa.cast(sa_arg, sa_type)

# specialize going from an integer type to a timestamp
if isinstance(arg.type(), dt.Integer) and isinstance(sa_type, sa.DateTime):
return t.integer_to_timestamp(sa_arg)

def _contains(t, expr):
op = expr.op()
if arg.type().equals(dt.binary) and typ.equals(dt.string):
return sa.func.encode(sa_arg, 'escape')

left, right = (t.translate(arg) for arg in op.args)
if typ.equals(dt.binary):
# decode yields a column of memoryview which is annoying to deal with
# in pandas. CAST(expr AS BYTEA) is correct and returns byte strings.
return sa.cast(sa_arg, sa.LargeBinary())

return left.in_(right)
return sa.cast(sa_arg, sa_type)


def _not_contains(t, expr):
return sa.not_(_contains(t, expr))


def reduction(sa_func):
def formatter(t, expr):
def _contains(func):
def translate(t, expr):
op = expr.op()
if op.where is not None:
arg = t.translate(op.where.ifelse(op.arg, ibis.NA))

left = t.translate(op.value)
right = t.translate(op.options)

if (
# not a list expr
not isinstance(op.options, ir.ValueList)
# but still a column expr
and isinstance(op.options, ir.ColumnExpr)
# wasn't already compiled into a select statement
and not isinstance(right, sa.sql.Selectable)
):
right = sa.select(right)
else:
arg = t.translate(op.arg)
return sa_func(arg)
right = t.translate(op.options)

return formatter
return func(left, right)

return translate
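
For illustration only (not part of the diff): a minimal sketch of the SQL constructs that handlers built from this factory produce, using plain SQLAlchemy 1.4-style objects; the table and column names are made up.

import sqlalchemy as sa

t = sa.table("t", sa.column("x"))
s = sa.table("s", sa.column("y"))

in_list = t.c.x.in_([1, 2, 3])             # Contains against a literal value list
in_subquery = t.c.x.in_(sa.select(s.c.y))  # a column operand gets wrapped in a SELECT
print(in_list)       # renders roughly as: t.x IN (...)
print(in_subquery)   # renders roughly as: t.x IN (SELECT s.y FROM s)
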


def _group_concat(t, expr):
@@ -217,7 +229,7 @@ def _alias(t, expr):
return t.translate(op.arg)


def _literal(t, expr):
def _literal(_, expr):
dtype = expr.type()
value = expr.op().value

@@ -260,17 +272,6 @@ def _floor_divide(t, expr):
return sa.func.floor(left / right)


def _count_distinct(t, expr):
arg, where = expr.op().args

if where is not None:
sa_arg = t.translate(where.ifelse(arg, None))
else:
sa_arg = t.translate(arg)

return sa.func.count(sa_arg.distinct())


def _simple_case(t, expr):
op = expr.op()

@@ -360,6 +361,7 @@ def _window(t, expr):
ops.MinRank,
ops.NTile,
ops.PercentRank,
ops.CumeDist,
)

if isinstance(window_op, ops.CumulativeOp):
@@ -388,6 +390,7 @@ def _window(t, expr):
ops.MinRank,
ops.NTile,
ops.PercentRank,
ops.CumeDist,
ops.RowNumber,
)

@@ -437,7 +440,7 @@ def _lead(t, expr):
def _ntile(t, expr):
op = expr.op()
args = op.args
arg, buckets = map(t.translate, args)
_, buckets = map(t.translate, args)
return sa.func.ntile(buckets)


@@ -453,23 +456,97 @@ def _string_join(t, expr):
return sa.func.concat_ws(t.translate(sep), *map(t.translate, elements))


def reduction(sa_func):
def compile_expr(t, expr):
return t._reduction(sa_func, expr)

return compile_expr


def _zero_if_null(t, expr):
op = expr.op()
arg = op.arg
sa_arg = t.translate(op.arg)
return sa.case(
[(sa_arg.is_(None), sa.cast(0, to_sqla_type(arg.type())))],
else_=sa_arg,
)
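
A rough sketch (not part of the diff) of the CASE expression _zero_if_null builds, with a concrete SQLAlchemy type standing in for to_sqla_type(arg.type()) and a made-up column name:

import sqlalchemy as sa

x = sa.column("x")
zero_if_null = sa.case([(x.is_(None), sa.cast(0, sa.BigInteger()))], else_=x)
# renders roughly as: CASE WHEN x IS NULL THEN CAST(0 AS BIGINT) ELSE x END
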


def _substring(t, expr):
op = expr.op()

args = t.translate(op.arg), t.translate(op.start) + 1

if (length := op.length) is not None:
args += (t.translate(length),)

return sa.func.substr(*args)


def _gen_string_find(func):
def string_find(t, expr):
op = expr.op()

if op.start is not None:
raise NotImplementedError("`start` not yet implemented")

if op.end is not None:
raise NotImplementedError("`end` not yet implemented")

return func(t.translate(op.arg), t.translate(op.substr)) - 1

return string_find


def _nth_value(t, expr):
op = expr.op()
return sa.func.nth_value(t.translate(op.arg), t.translate(op.nth) + 1)


def _clip(*, min_func, max_func):
def translate(t, expr):
op = expr.op()
arg = t.translate(op.arg)

if (upper := op.upper) is not None:
arg = min_func(t.translate(upper), arg)

if (lower := op.lower) is not None:
arg = max_func(t.translate(lower), arg)

return arg

return translate
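
A small illustrative sketch of what this helper yields when both bounds are supplied and min_func/max_func are least/greatest (as registered below); the column name is made up:

import sqlalchemy as sa

x = sa.column("x")
clipped = sa.func.greatest(0, sa.func.least(10, x))  # x.clip(lower=0, upper=10)
# renders roughly as: greatest(0, least(10, x)) with bound parameters
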


sqlalchemy_operation_registry: Dict[Any, Any] = {
ops.Alias: _alias,
ops.And: fixed_arity(sql.and_, 2),
ops.Or: fixed_arity(sql.or_, 2),
ops.And: fixed_arity(operator.and_, 2),
ops.Or: fixed_arity(operator.or_, 2),
ops.Xor: fixed_arity(lambda x, y: (x | y) & ~(x & y), 2),
ops.Not: unary(sa.not_),
ops.Abs: unary(sa.func.abs),
ops.Cast: _cast,
ops.Coalesce: varargs(sa.func.coalesce),
ops.NullIf: fixed_arity(sa.func.nullif, 2),
ops.Contains: _contains,
ops.NotContains: _not_contains,
ops.Contains: _contains(lambda left, right: left.in_(right)),
ops.NotContains: _contains(lambda left, right: left.notin_(right)),
ops.Count: reduction(sa.func.count),
ops.Sum: reduction(sa.func.sum),
ops.Mean: reduction(sa.func.avg),
ops.Min: reduction(sa.func.min),
ops.Max: reduction(sa.func.max),
ops.CountDistinct: _count_distinct,
ops.Variance: variance_reduction("var"),
ops.StandardDev: variance_reduction("stddev"),
ops.BitAnd: reduction(sa.func.bit_and),
ops.BitOr: reduction(sa.func.bit_or),
ops.BitXor: reduction(sa.func.bit_xor),
ops.CountDistinct: reduction(lambda arg: sa.func.count(arg.distinct())),
ops.HLLCardinality: reduction(lambda arg: sa.func.count(arg.distinct())),
ops.ApproxCountDistinct: reduction(
lambda arg: sa.func.count(arg.distinct())
),
ops.GroupConcat: _group_concat,
ops.Between: fixed_arity(sa.between, 3),
ops.IsNull: _is_null,
@@ -501,6 +578,7 @@ def _string_join(t, expr):
ops.Lowercase: unary(sa.func.lower),
ops.Uppercase: unary(sa.func.upper),
ops.StringAscii: unary(sa.func.ascii),
ops.StringFind: _gen_string_find(sa.func.strpos),
ops.StringLength: unary(sa.func.length),
ops.StringJoin: _string_join,
ops.StringReplace: fixed_arity(sa.func.replace, 3),
@@ -509,6 +587,7 @@ def _string_join(t, expr):
ops.StartsWith: _startswith,
ops.EndsWith: _endswith,
ops.StringConcat: varargs(sa.func.concat),
ops.Substring: _substring,
# math
ops.Ln: unary(sa.func.ln),
ops.Exp: unary(sa.func.exp),
@@ -518,6 +597,16 @@ def _string_join(t, expr):
ops.Floor: unary(sa.func.floor),
ops.Power: fixed_arity(sa.func.pow, 2),
ops.FloorDivide: _floor_divide,
ops.Acos: unary(sa.func.acos),
ops.Asin: unary(sa.func.asin),
ops.Atan: unary(sa.func.atan),
ops.Atan2: fixed_arity(sa.func.atan2, 2),
ops.Cos: unary(sa.func.cos),
ops.Sin: unary(sa.func.sin),
ops.Tan: unary(sa.func.tan),
ops.Cot: unary(sa.func.cot),
ops.Pi: fixed_arity(sa.func.pi, 0),
ops.E: fixed_arity(lambda: sa.func.exp(1), 0),
# other
ops.SortKey: _sort_key,
ops.Date: unary(lambda arg: sa.cast(arg, sa.DATE)),
@@ -526,29 +615,36 @@ def _string_join(t, expr):
ops.TimestampFromYMDHMS: lambda t, expr: sa.func.make_timestamp(
*map(t.translate, expr.op().args[:6]) # ignore timezone
),
}


# TODO: unit tests for each of these
_binary_ops = {
ops.Degrees: unary(sa.func.degrees),
ops.Radians: unary(sa.func.radians),
ops.ZeroIfNull: _zero_if_null,
ops.RandomScalar: fixed_arity(sa.func.random, 0),
# Binary arithmetic
ops.Add: operator.add,
ops.Subtract: operator.sub,
ops.Multiply: operator.mul,
ops.Add: fixed_arity(operator.add, 2),
ops.Subtract: fixed_arity(operator.sub, 2),
ops.Multiply: fixed_arity(operator.mul, 2),
# XXX `ops.Divide` is overwritten in `translator.py` with a custom
# function `_true_divide`, but for some reason both are required
ops.Divide: operator.truediv,
ops.Modulus: operator.mod,
ops.Divide: fixed_arity(operator.truediv, 2),
ops.Modulus: fixed_arity(operator.mod, 2),
# Comparisons
ops.Equals: operator.eq,
ops.NotEquals: operator.ne,
ops.Less: operator.lt,
ops.LessEqual: operator.le,
ops.Greater: operator.gt,
ops.GreaterEqual: operator.ge,
ops.IdenticalTo: lambda x, y: x.op('IS NOT DISTINCT FROM')(y),
# Boolean comparisons
# TODO
ops.Equals: fixed_arity(operator.eq, 2),
ops.NotEquals: fixed_arity(operator.ne, 2),
ops.Less: fixed_arity(operator.lt, 2),
ops.LessEqual: fixed_arity(operator.le, 2),
ops.Greater: fixed_arity(operator.gt, 2),
ops.GreaterEqual: fixed_arity(operator.ge, 2),
ops.IdenticalTo: fixed_arity(
sa.sql.expression.ColumnElement.is_not_distinct_from, 2
),
ops.Clip: _clip(min_func=sa.func.least, max_func=sa.func.greatest),
ops.Where: fixed_arity(
lambda predicate, value_if_true, value_if_false: sa.case(
[(predicate, value_if_true)],
else_=value_if_false,
),
3,
),
}


@@ -558,11 +654,13 @@ def _string_join(t, expr):
ops.NTile: _ntile,
ops.FirstValue: unary(sa.func.first_value),
ops.LastValue: unary(sa.func.last_value),
ops.RowNumber: fixed_arity(lambda: sa.func.row_number(), 0),
ops.DenseRank: unary(lambda arg: sa.func.dense_rank()),
ops.MinRank: unary(lambda arg: sa.func.rank()),
ops.PercentRank: unary(lambda arg: sa.func.percent_rank()),
ops.WindowOp: _window,
ops.RowNumber: fixed_arity(sa.func.row_number, 0),
ops.DenseRank: unary(lambda _: sa.func.dense_rank()),
ops.MinRank: unary(lambda _: sa.func.rank()),
ops.PercentRank: unary(lambda _: sa.func.percent_rank()),
ops.CumeDist: unary(lambda _: sa.func.cume_dist()),
ops.NthValue: _nth_value,
ops.Window: _window,
ops.CumulativeOp: _window,
ops.CumulativeMax: unary(sa.func.max),
ops.CumulativeMin: unary(sa.func.min),
@@ -634,7 +732,3 @@ def _string_join(t, expr):
}
else:
_geospatial_functions = {}


for _k, _v in _binary_ops.items():
sqlalchemy_operation_registry[_k] = fixed_arity(_v, 2)
35 changes: 35 additions & 0 deletions ibis/backends/base/sql/alchemy/translator.py
@@ -1,4 +1,9 @@
from __future__ import annotations

import sqlalchemy as sa

import ibis
import ibis.expr.datatypes as dt
import ibis.expr.operations as ops
import ibis.expr.types as ir
from ibis import util
@@ -40,12 +45,42 @@ class AlchemyExprTranslator(ExprTranslator):

context_class = AlchemyContext

_bool_aggs_need_cast_to_int32 = True
_has_reduction_filter_syntax = False

integer_to_timestamp = sa.func.to_timestamp

def name(self, translated, name, force=True):
return translated.label(name)

def get_sqla_type(self, data_type):
return to_sqla_type(data_type, type_map=self._type_map)

def _reduction(self, sa_func, expr):
op = expr.op()
arg = op.arg
if (
self._bool_aggs_need_cast_to_int32
and isinstance(op, (ops.Sum, ops.Mean, ops.Min, ops.Max))
and isinstance(
type := arg.type(),
dt.Boolean,
)
):
arg = arg.cast(dt.Int32(nullable=type.nullable))

if (where := op.where) is not None:
if self._has_reduction_filter_syntax:
return sa_func(self.translate(arg)).filter(
self.translate(where)
)
else:
sa_arg = self.translate(where.ifelse(arg, None))
else:
sa_arg = self.translate(arg)

return sa_func(sa_arg)
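
To illustrate the two branches (a sketch, not part of the diff), assuming plain SQLAlchemy columns `x` and `keep`:

import sqlalchemy as sa

x, keep = sa.column("x"), sa.column("keep")

# _has_reduction_filter_syntax = True: SUM(x) FILTER (WHERE keep)
filtered = sa.func.sum(x).filter(keep)

# fallback: fold the predicate into the argument, SUM(CASE WHEN keep THEN x END)
folded = sa.func.sum(sa.case([(keep, x)], else_=sa.null()))
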


rewrites = AlchemyExprTranslator.rewrites

4 changes: 4 additions & 0 deletions ibis/backends/base/sql/compiler/__init__.py
@@ -1,6 +1,8 @@
from ibis.backends.base.sql.compiler.base import DDL, DML
from ibis.backends.base.sql.compiler.query_builder import (
Compiler,
Difference,
Intersection,
Select,
SelectBuilder,
TableSetFormatter,
@@ -16,6 +18,8 @@
'Select',
'SelectBuilder',
'Union',
'Intersection',
'Difference',
'TableSetFormatter',
'ExprTranslator',
'QueryContext',
15 changes: 11 additions & 4 deletions ibis/backends/base/sql/compiler/base.py
@@ -3,10 +3,8 @@

import toolz

import ibis.expr.analysis as an
import ibis.util as util
from ibis.backends.base.sql.compiler.extract_subqueries import (
ExtractSubqueries,
)


class DML(abc.ABC):
@@ -62,7 +60,9 @@ def __init__(self, tables, expr, context):
self.filters = []

def _extract_subqueries(self):
self.subqueries = ExtractSubqueries.extract(self)
self.subqueries = _extract_common_table_expressions(
[self.table_set, *self.filters]
)
for subquery in self.subqueries:
self.context.set_extracted(subquery)

@@ -106,3 +106,10 @@ def compile(self):
)
)
return '\n'.join(buf)


def _extract_common_table_expressions(exprs):
counts = an.find_subqueries(exprs)
duplicates = [op.to_expr() for op, count in counts.items() if count > 1]
duplicates.reverse()
return duplicates
133 changes: 0 additions & 133 deletions ibis/backends/base/sql/compiler/extract_subqueries.py

This file was deleted.

54 changes: 36 additions & 18 deletions ibis/backends/base/sql/compiler/query_builder.py
@@ -318,9 +318,9 @@ def format_select_set(self):
context = self.context
formatted = []
for expr in self.select_set:
if isinstance(expr, ir.ValueExpr):
if isinstance(expr, ir.Value):
expr_str = self._translate(expr, named=True)
elif isinstance(expr, ir.TableExpr):
elif isinstance(expr, ir.Table):
# A * selection, possibly prefixed
if context.need_aliases(expr):
alias = context.get_ref(expr)
@@ -468,25 +468,29 @@ def _get_keyword_list(self):


class Intersection(SetOp):
_keyword = "INTERSECT"

def _get_keyword_list(self):
return ["INTERSECT"] * (len(self.tables) - 1)
return [self._keyword] * (len(self.tables) - 1)


class Difference(SetOp):
_keyword = "EXCEPT"

def _get_keyword_list(self):
return ["EXCEPT"] * (len(self.tables) - 1)
return [self._keyword] * (len(self.tables) - 1)


def flatten_union(table: ir.TableExpr):
def flatten_union(table: ir.Table):
"""Extract all union queries from `table`.
Parameters
----------
table : TableExpr
table : Table
Returns
-------
Iterable[Union[TableExpr, bool]]
Iterable[Union[Table, bool]]
"""
op = table.op()
if isinstance(op, ops.Union):
@@ -501,16 +505,16 @@ def flatten_union(table: ir.TableExpr):
return [table]


def flatten(table: ir.TableExpr):
def flatten(table: ir.Table):
"""Extract all intersection or difference queries from `table`.
Parameters
----------
table : TableExpr
table : Table
Returns
-------
Iterable[Union[TableExpr]]
Iterable[Union[Table]]
"""
op = table.op()
return list(toolz.concatv(flatten_union(op.left), flatten_union(op.right)))
@@ -523,6 +527,8 @@ class Compiler:
table_set_formatter_class = TableSetFormatter
select_class = Select
union_class = Union
intersect_class = Intersection
difference_class = Difference

@classmethod
def make_context(cls, params=None):
@@ -551,11 +557,11 @@ def to_ast(cls, expr, context=None):
# TODO: any setup / teardown DDL statements will need to be done prior
# to building the result set-generating statements.
if isinstance(op, ops.Union):
query = cls._make_union(cls.union_class, expr, context)
query = cls._make_union(expr, context)
elif isinstance(op, ops.Intersection):
query = Intersection(flatten(expr), expr, context=context)
query = cls._make_intersect(expr, context)
elif isinstance(op, ops.Difference):
query = Difference(flatten(expr), expr, context=context)
query = cls._make_difference(expr, context)
else:
query = cls.select_builder_class().to_select(
select_class=cls.select_class,
@@ -582,7 +588,7 @@ def to_ast_ensure_limit(cls, expr, limit=None, params=None):
for query in reversed(query_ast.queries):
if (
isinstance(query, Select)
and not isinstance(expr, ir.ScalarExpr)
and not isinstance(expr, ir.Scalar)
and query.table_set is not None
):
if query.limit is None:
@@ -611,8 +617,8 @@ def _generate_setup_queries(expr, context):
def _generate_teardown_queries(expr, context):
return []

@staticmethod
def _make_union(union_class, expr, context):
@classmethod
def _make_union(cls, expr, context):
# flatten unions so that we can codegen them all at once
union_info = list(flatten_union(expr))

@@ -624,10 +630,22 @@ def _make_union(union_class, expr, context):
npieces = len(union_info)
assert npieces >= 3 and npieces % 2 != 0, 'Invalid union expression'

# 1. every other object starting from 0 is a TableExpr instance
# 1. every other object starting from 0 is a Table instance
# 2. every other object starting from 1 is a bool indicating the type
# of union (distinct or not distinct)
table_exprs, distincts = union_info[::2], union_info[1::2]
return union_class(
return cls.union_class(
table_exprs, expr, distincts=distincts, context=context
)

@classmethod
def _make_intersect(cls, expr, context):
# flatten intersections so that we can codegen them all at once
table_exprs = list(flatten(expr))
return cls.intersect_class(table_exprs, expr, context=context)

@classmethod
def _make_difference(cls, expr, context):
# flatten differences so that we can codegen them all at once
table_exprs = list(flatten(expr))
return cls.difference_class(table_exprs, expr, context=context)
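
A usage sketch (table names are illustrative): set operations are now routed through overridable classmethods, so backends can swap in their own intersect_class or difference_class.

import ibis

t1 = ibis.table([("a", "int64")], name="t1")
t2 = ibis.table([("a", "int64")], name="t2")

expr = t1.intersect(t2)
# Compiler.to_ast() dispatches ops.Intersection to _make_intersect and
# ops.Difference to _make_difference before building the final query.
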
413 changes: 75 additions & 338 deletions ibis/backends/base/sql/compiler/select_builder.py

Large diffs are not rendered by default.

9 changes: 4 additions & 5 deletions ibis/backends/base/sql/compiler/translator.py
@@ -164,19 +164,18 @@ def set_query(self, query):
self.query = query

def is_foreign_expr(self, expr):
from ibis.expr.analysis import ExprValidator
from ibis.expr.analysis import shares_all_roots

# The expression isn't foreign to us. For example, the parent table set
# in a correlated WHERE subquery
if self.has_ref(expr, parent_contexts=True):
return False

exprs = [self.query.table_set] + self.query.select_set
validator = ExprValidator(exprs)
return not validator.validate(expr)
parents = [self.query.table_set] + self.query.select_set
return not shares_all_roots(expr, parents)

def _get_table_key(self, table):
if isinstance(table, ir.TableExpr):
if isinstance(table, ir.Table):
return table.op()
elif isinstance(table, ops.TableNode):
return table
49 changes: 49 additions & 0 deletions ibis/backends/base/sql/registry/binary_infix.py
@@ -1,3 +1,8 @@
from __future__ import annotations

from typing import Literal

import ibis.expr.types as ir
from ibis.backends.base.sql.registry import helpers


@@ -50,3 +55,47 @@ def xor(translator, expr):
right_arg = helpers.parenthesize(right_arg)

return '({0} OR {1}) AND NOT ({0} AND {1})'.format(left_arg, right_arg)


def contains(op_string: Literal["IN", "NOT IN"]) -> str:
def translate(translator, expr):
from ibis.backends.base.sql.registry.main import table_array_view

op = expr.op()

left, right = op.args
if isinstance(right, ir.ValueList) and not right:
return {"NOT IN": "TRUE", "IN": "FALSE"}[op_string]

left_arg = translator.translate(left)
if helpers.needs_parens(left):
left_arg = helpers.parenthesize(left_arg)

ctx = translator.context

# special case non-foreign isin/notin expressions
if (
not isinstance(right, ir.ValueList)
and isinstance(right, ir.ColumnExpr)
# foreign refs are already been compiled correctly during
# TableColumn compilation
and not any(
ctx.is_foreign_expr(leaf.to_expr())
for leaf in right.op().root_tables()
)
):
if not right.has_name():
right = right.name("tmp")
right_arg = table_array_view(
translator,
right.to_projection().to_array(),
)
else:
right_arg = translator.translate(right)

# we explicitly do NOT parenthesize the right side because it doesn't
# make sense to do so for ValueList operations

return f"{left_arg} {op_string} {right_arg}"

return translate
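
A hedged usage sketch of the empty-list special case above; the table and column are made up:

import ibis

t = ibis.table([("x", "int64")], name="t")

empty_in = t.x.isin([])       # compiles to the literal FALSE
empty_not_in = t.x.notin([])  # compiles to the literal TRUE
# `x IN ()` is not valid SQL, hence the literal boolean shortcut.
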
54 changes: 28 additions & 26 deletions ibis/backends/base/sql/registry/helpers.py
@@ -22,32 +22,34 @@ def quote_identifier(name, quotechar='`', force=False):
return name


def needs_parens(op):
if isinstance(op, ir.Expr):
op = op.op()
op_klass = type(op)
# function calls don't need parens
return op_klass in {
ops.Negate,
ops.IsNull,
ops.NotNull,
ops.Add,
ops.Subtract,
ops.Multiply,
ops.Divide,
ops.Power,
ops.Modulus,
ops.Equals,
ops.NotEquals,
ops.GreaterEqual,
ops.Greater,
ops.LessEqual,
ops.Less,
ops.IdenticalTo,
ops.And,
ops.Or,
ops.Xor,
}
_NEEDS_PARENS_OPS = (
ops.Negate,
ops.IsNull,
ops.NotNull,
ops.Add,
ops.Subtract,
ops.Multiply,
ops.Divide,
ops.Power,
ops.Modulus,
ops.Equals,
ops.NotEquals,
ops.GreaterEqual,
ops.Greater,
ops.LessEqual,
ops.Less,
ops.IdenticalTo,
ops.And,
ops.Or,
ops.Xor,
)


def needs_parens(expr: ir.Expr):
op = expr.op()
if isinstance(op, ops.Alias):
op = op.arg.op()
return isinstance(op, _NEEDS_PARENS_OPS)


parenthesize = '({})'.format
2 changes: 1 addition & 1 deletion ibis/backends/base/sql/registry/literal.py
@@ -60,7 +60,7 @@ def _date_literal_format(translator, expr):
def _timestamp_literal_format(translator, expr):
value = expr.op().value
if isinstance(value, datetime.datetime):
value = value.strftime('%Y-%m-%d %H:%M:%S')
value = value.isoformat()

return repr(value)

20 changes: 16 additions & 4 deletions ibis/backends/base/sql/registry/main.py
@@ -277,11 +277,23 @@ def hash(translator, expr):
ops.Ln: unary('ln'),
ops.Log2: unary('log2'),
ops.Log10: unary('log10'),
ops.Acos: unary("acos"),
ops.Asin: unary("asin"),
ops.Atan: unary("atan"),
ops.Atan2: fixed_arity("atan2", 2),
ops.Cos: unary("cos"),
ops.Cot: unary("cot"),
ops.Sin: unary("sin"),
ops.Tan: unary("tan"),
ops.Pi: fixed_arity("pi", 0),
ops.E: fixed_arity("exp(1)", 0),
ops.DecimalPrecision: unary('precision'),
ops.DecimalScale: unary('scale'),
# Unary aggregates
ops.CMSMedian: aggregate.reduction('appx_median'),
ops.HLLCardinality: aggregate.reduction('ndv'),
ops.ApproxMedian: aggregate.reduction('appx_median'),
ops.ApproxCountDistinct: aggregate.reduction('ndv'),
ops.Mean: aggregate.reduction('avg'),
ops.Sum: aggregate.reduction('sum'),
ops.Max: aggregate.reduction('max'),
@@ -345,8 +357,8 @@ def hash(translator, expr):
ops.Least: varargs('least'),
ops.Where: fixed_arity('if', 3),
ops.Between: between,
ops.Contains: binary_infix.binary_infix_op('IN'),
ops.NotContains: binary_infix.binary_infix_op('NOT IN'),
ops.Contains: binary_infix.contains("IN"),
ops.NotContains: binary_infix.contains("NOT IN"),
ops.SimpleCase: case.simple_case,
ops.SearchedCase: case.searched_case,
ops.TableColumn: table_column,
@@ -365,12 +377,12 @@ def hash(translator, expr):
ops.DenseRank: lambda *args: 'dense_rank()',
ops.MinRank: lambda *args: 'rank()',
ops.PercentRank: lambda *args: 'percent_rank()',
ops.CumeDist: lambda *args: 'cume_dist()',
ops.FirstValue: unary('first_value'),
ops.LastValue: unary('last_value'),
ops.NthValue: window.nth_value,
ops.Lag: window.shift_like('lag'),
ops.Lead: window.shift_like('lead'),
ops.WindowOp: window.window,
ops.Window: window.window,
ops.NTile: window.ntile,
ops.DayOfWeekIndex: timestamp.day_of_week_index,
ops.DayOfWeekName: timestamp.day_of_week_name,
16 changes: 4 additions & 12 deletions ibis/backends/base/sql/registry/window.py
@@ -183,6 +183,7 @@ def _foll(f: Optional[int]) -> str:
ops.MinRank,
ops.NTile,
ops.PercentRank,
ops.CumeDist,
ops.RowNumber,
)

@@ -250,13 +251,14 @@ def window(translator, expr):
ops.FirstValue,
ops.LastValue,
ops.PercentRank,
ops.CumeDist,
ops.NTile,
)

_unsupported_reductions = (
ops.CMSMedian,
ops.ApproxMedian,
ops.GroupConcat,
ops.HLLCardinality,
ops.ApproxCountDistinct,
)

if isinstance(window_op, _unsupported_reductions):
@@ -291,16 +293,6 @@ def window(translator, expr):
return result


def nth_value(translator, expr):
op = expr.op()
arg, rank = op.args

arg_formatted = translator.translate(arg)
rank_formatted = translator.translate(rank - 1)

return f'first_value(lag({arg_formatted}, {rank_formatted}))'


def shift_like(name):
def formatter(translator, expr):
op = expr.op()
37 changes: 14 additions & 23 deletions ibis/backends/clickhouse/__init__.py
@@ -11,12 +11,9 @@
import ibis.config
import ibis.expr.schema as sch
from ibis.backends.base.sql import BaseSQLBackend
from ibis.backends.clickhouse.client import (
ClickhouseDataType,
ClickhouseTable,
fully_qualified_re,
)
from ibis.backends.clickhouse.client import ClickhouseTable, fully_qualified_re
from ibis.backends.clickhouse.compiler import ClickhouseCompiler
from ibis.backends.clickhouse.datatypes import parse, serialize
from ibis.config import options

_default_compression: str | bool
@@ -52,6 +49,7 @@ def do_connect(
compression: (
Literal["lz4", "lz4hc", "quicklz", "zstd"] | bool
) = _default_compression,
**kwargs: Any,
):
"""Create a ClickHouse client for use with Ibis.
@@ -92,6 +90,7 @@ def do_connect(
password=password,
client_name=client_name,
compression=compression,
**kwargs,
)

@property
@@ -109,12 +108,12 @@ def current_database(self):
return self.con.connection.database

def list_databases(self, like=None):
data, schema = self.raw_sql('SELECT name FROM system.databases')
data, _ = self.raw_sql('SELECT name FROM system.databases')
databases = list(data[0])
return self._filter_with_like(databases, like)

def list_tables(self, like=None, database=None):
data, schema = self.raw_sql('SHOW TABLES')
data, _ = self.raw_sql('SHOW TABLES')
databases = list(data[0])
return self._filter_with_like(databases, like)

@@ -152,13 +151,7 @@ def raw_sql(
'name': name,
'data': df.to_dict('records'),
'structure': list(
zip(
schema.names,
[
str(ClickhouseDataType.from_ibis(t))
for t in schema.types
],
)
zip(schema.names, map(serialize, schema.types))
),
}
)
@@ -175,11 +168,10 @@ def raw_sql(
def fetch_from_cursor(self, cursor, schema):
data, _ = cursor
names = schema.names
if not len(data):
# handle empty resultset
return pd.DataFrame([], columns=names)

df = pd.DataFrame.from_dict(dict(zip(names, data)))
if not data:
df = pd.DataFrame([], columns=names)
else:
df = pd.DataFrame.from_dict(dict(zip(names, data)))
return schema.apply_to(df)

def close(self):
@@ -216,9 +208,8 @@ def get_schema(
(column_names, types, *_), *_ = self.raw_sql(
f"DESCRIBE {qualified_name}"
)
return sch.Schema.from_tuples(
zip(column_names, map(ClickhouseDataType.parse, types))
)

return sch.Schema.from_tuples(zip(column_names, map(parse, types)))

def set_options(self, options):
self.con.set_options(options)
@@ -238,7 +229,7 @@ def _get_schema_using_query(self, query: str) -> sch.Schema:
)
[plan] = json.loads(raw_plans)
fields = [
(field["Name"], ClickhouseDataType.parse(field["Type"]))
(field["Name"], parse(field["Type"]))
for field in plan["Plan"]["Header"]
]
return sch.Schema.from_tuples(fields)
142 changes: 13 additions & 129 deletions ibis/backends/clickhouse/client.py
@@ -4,142 +4,22 @@
import numpy as np
import pandas as pd

import ibis.common.exceptions as com
import ibis.expr.datatypes as dt
import ibis.expr.types as ir

fully_qualified_re = re.compile(r"(.*)\.(?:`(.*)`|(.*))")
base_typename_re = re.compile(r"(\w+)")


_clickhouse_dtypes = {
'Null': dt.Null,
'Nothing': dt.Null,
'UInt8': dt.UInt8,
'UInt16': dt.UInt16,
'UInt32': dt.UInt32,
'UInt64': dt.UInt64,
'Int8': dt.Int8,
'Int16': dt.Int16,
'Int32': dt.Int32,
'Int64': dt.Int64,
'Float32': dt.Float32,
'Float64': dt.Float64,
'String': dt.String,
'FixedString': dt.String,
'Date': dt.Date,
'DateTime': dt.Timestamp,
'DateTime64': dt.Timestamp,
'Array': dt.Array,
}
_ibis_dtypes = {v: k for k, v in _clickhouse_dtypes.items()}
_ibis_dtypes[dt.String] = 'String'
_ibis_dtypes[dt.Timestamp] = 'DateTime'


class ClickhouseDataType:

__slots__ = 'typename', 'base_typename', 'nullable'

def __init__(self, typename, nullable=False):
m = base_typename_re.match(typename)
self.base_typename = m.groups()[0]
if self.base_typename not in _clickhouse_dtypes:
raise com.UnsupportedBackendType(typename)
self.typename = self.base_typename
self.nullable = nullable

if self.base_typename == 'Array':
self.typename = typename

def __str__(self):
if self.nullable:
return f'Nullable({self.typename})'
else:
return self.typename

def __repr__(self):
return f'<Clickhouse {str(self)}>'

@classmethod
def parse(cls, spec):
# TODO(kszucs): spare parsing, depends on clickhouse-driver#22
if spec.startswith('Nullable'):
return cls(spec[9:-1], nullable=True)
else:
return cls(spec)

def to_ibis(self):
if self.base_typename != 'Array':
return _clickhouse_dtypes[self.typename](nullable=self.nullable)

sub_type = ClickhouseDataType(
self.get_subname(self.typename)
).to_ibis()
return dt.Array(value_type=sub_type)

@staticmethod
def get_subname(name: str) -> str:
lbracket_pos = name.find('(')
rbracket_pos = name.rfind(')')

if lbracket_pos == -1 or rbracket_pos == -1:
return ''

subname = name[lbracket_pos + 1 : rbracket_pos]
return subname

@staticmethod
def get_typename_from_ibis_dtype(dtype):
if not isinstance(dtype, dt.Array):
return _ibis_dtypes[type(dtype)]

return 'Array({})'.format(
ClickhouseDataType.get_typename_from_ibis_dtype(dtype.value_type)
)

@classmethod
def from_ibis(cls, dtype, nullable=None):
typename = ClickhouseDataType.get_typename_from_ibis_dtype(dtype)
if nullable is None:
nullable = dtype.nullable
return cls(typename, nullable=nullable)


@dt.dtype.register(ClickhouseDataType)
def clickhouse_to_ibis_dtype(clickhouse_dtype):
return clickhouse_dtype.to_ibis()


class ClickhouseTable(ir.TableExpr):
class ClickhouseTable(ir.Table):
"""References a physical table in Clickhouse"""

@property
def _qualified_name(self):
return self.op().args[0]

@property
def _unqualified_name(self):
return self._match_name()[1]
return self.op().name

@property
def _client(self):
return self.op().args[2]

def _match_name(self):
m = fully_qualified_re.match(self._qualified_name)
if not m:
raise com.IbisError(
'Cannot determine database name from {}'.format(
self._qualified_name
)
)
db, quoted, unquoted = m.groups()
return db, quoted or unquoted

@property
def _database(self):
return self._match_name()[0]
return self.op().source

def invalidate_metadata(self):
self._client.invalidate_metadata(self._qualified_name)
@@ -168,10 +48,8 @@ def insert(self, obj, **kwargs):
assert isinstance(obj, pd.DataFrame)
assert set(schema.names) >= set(obj.columns)

columns = ', '.join(map(quote_identifier, obj.columns))
query = 'INSERT INTO {table} ({columns}) VALUES'.format(
table=self._qualified_name, columns=columns
)
columns = ", ".join(map(quote_identifier, obj.columns))
query = f"INSERT INTO {self._qualified_name} ({columns}) VALUES"

# convert data columns with datetime64 pandas dtype to native date
# because clickhouse-driver 0.0.10 does arithmetic operations on it
@@ -180,5 +58,11 @@ def insert(self, obj, **kwargs):
if isinstance(schema[col], dt.Date):
obj[col] = obj[col].dt.date

data = obj.to_dict('records')
return self._client.con.execute(query, data, **kwargs)
settings = kwargs.pop("settings", {})
settings["use_numpy"] = True
return self._client.con.insert_dataframe(
query,
obj,
settings=settings,
**kwargs,
)
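
For reference, a rough sketch of the driver call insert() now makes, assuming a local ClickHouse server, clickhouse-driver installed with numpy support, and a table `t` whose columns match the frame:

import clickhouse_driver
import pandas as pd

client = clickhouse_driver.Client(host="localhost")
df = pd.DataFrame({"a": [1, 2], "b": ["x", "y"]})
client.insert_dataframe(
    "INSERT INTO t (`a`, `b`) VALUES",
    df,
    settings={"use_numpy": True},
)
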
220 changes: 220 additions & 0 deletions ibis/backends/clickhouse/datatypes.py
@@ -0,0 +1,220 @@
from __future__ import annotations

import functools
from typing import TYPE_CHECKING

import parsy as p

if TYPE_CHECKING:
from ibis.expr.datatypes import DataType

import ibis.expr.datatypes as dt


def parse(text: str) -> DataType:
@p.generate
def datetime():
yield dt.spaceless_string("datetime64", "datetime")
timezone = yield parened_string.optional()
return dt.Timestamp(timezone=timezone, nullable=False)

primitive = (
datetime
| dt.spaceless_string("null", "nothing").result(dt.null)
| dt.spaceless_string("bigint", "int64").result(
dt.Int64(nullable=False)
)
| dt.spaceless_string("double", "float64").result(
dt.Float64(nullable=False)
)
| dt.spaceless_string("float32", "float").result(
dt.Float32(nullable=False)
)
| dt.spaceless_string("smallint", "int16", "int2").result(
dt.Int16(nullable=False)
)
| dt.spaceless_string("date32", "date").result(dt.Date(nullable=False))
| dt.spaceless_string("time").result(dt.Time(nullable=False))
| dt.spaceless_string(
"tinyint", "int8", "int1", "boolean", "bool"
).result(dt.Int8(nullable=False))
| dt.spaceless_string("integer", "int32", "int4", "int").result(
dt.Int32(nullable=False)
)
| dt.spaceless_string("uint64").result(dt.UInt64(nullable=False))
| dt.spaceless_string("uint32").result(dt.UInt32(nullable=False))
| dt.spaceless_string("uint16").result(dt.UInt16(nullable=False))
| dt.spaceless_string("uint8").result(dt.UInt8(nullable=False))
| dt.spaceless_string("uuid").result(dt.UUID(nullable=False))
| dt.spaceless_string(
"longtext",
"mediumtext",
"tinytext",
"text",
"longblob",
"mediumblob",
"tinyblob",
"blob",
"varchar",
"char",
"string",
).result(dt.String(nullable=False))
)

@p.generate
def parened_string():
yield dt.LPAREN
s = yield dt.RAW_STRING
yield dt.RPAREN
return s

@p.generate
def nullable():
yield dt.spaceless_string("nullable")
yield dt.LPAREN
parsed_ty = yield ty
yield dt.RPAREN
return parsed_ty(nullable=True)

@p.generate
def fixed_string():
yield dt.spaceless_string("fixedstring")
yield dt.LPAREN
yield dt.NUMBER
yield dt.RPAREN
return dt.String(nullable=False)

@p.generate
def decimal():
yield dt.spaceless_string("decimal", "numeric")
precision, scale = yield dt.LPAREN.then(
p.seq(dt.PRECISION.skip(dt.COMMA), dt.SCALE)
).skip(dt.RPAREN)
return dt.Decimal(precision, scale, nullable=False)

@p.generate
def paren_type():
yield dt.LPAREN
value_type = yield ty
yield dt.RPAREN
return value_type

@p.generate
def array():
yield dt.spaceless_string("array")
value_type = yield paren_type
return dt.Array(value_type, nullable=False)

@p.generate
def map():
yield dt.spaceless_string("map")
yield dt.LPAREN
key_type = yield ty
yield dt.COMMA
value_type = yield ty
yield dt.RPAREN
return dt.Map(key_type, value_type, nullable=False)

at_least_one_space = p.regex(r"\s+")

@p.generate
def nested():
yield dt.spaceless_string("nested")
yield dt.LPAREN

field_names_types = yield (
p.seq(dt.SPACES.then(dt.FIELD.skip(at_least_one_space)), ty)
.combine(lambda field, ty: (field, dt.Array(ty, nullable=False)))
.sep_by(dt.COMMA)
)
yield dt.RPAREN
return dt.Struct.from_tuples(field_names_types, nullable=False)

@p.generate
def struct():
yield dt.spaceless_string("tuple")
yield dt.LPAREN
field_names_types = yield (
p.seq(
dt.SPACES.then(dt.FIELD.skip(at_least_one_space).optional()),
ty,
)
.combine(lambda field, ty: (field, ty))
.sep_by(dt.COMMA)
)
yield dt.RPAREN
return dt.Struct.from_tuples(
[
(field_name if field_name is not None else f"f{i:d}", typ)
for i, (field_name, typ) in enumerate(field_names_types)
],
nullable=False,
)

ty = (
nullable
| nested
| primitive
| fixed_string
| decimal
| array
| map
| struct
)
return ty.parse(text)


@functools.singledispatch
def serialize(ty) -> str:
raise NotImplementedError(
f"{ty} not serializable to clickhouse type string"
)


@serialize.register(dt.DataType)
def _(ty: dt.DataType) -> str:
ser_ty = serialize_raw(ty)
if ty.nullable:
return f"Nullable({ser_ty})"
return ser_ty


@functools.singledispatch
def serialize_raw(ty: dt.DataType) -> str:
raise NotImplementedError(
f"{ty} not serializable to clickhouse type string"
)


@serialize_raw.register(dt.DataType)
def _(ty: dt.DataType) -> str:
return type(ty).__name__.capitalize()


@serialize_raw.register(dt.Array)
def _(ty: dt.Array) -> str:
return f"Array({serialize(ty.value_type)})"


@serialize_raw.register(dt.Map)
def _(ty: dt.Map) -> str:
key_type = serialize(ty.key_type)
value_type = serialize(ty.value_type)
return f"Map({key_type}, {value_type})"


@serialize_raw.register(dt.Struct)
def _(ty: dt.Struct) -> str:
fields = ", ".join(
f"{name} {serialize(field_ty)}" for name, field_ty in ty.pairs.items()
)
return f"Tuple({fields})"


@serialize_raw.register(dt.Timestamp)
def _(ty: dt.Timestamp) -> str:
return (
"DateTime64(6)"
if ty.timezone is None
else f"DateTime64(6, {ty.timezone!r})"
)
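
Assuming a checkout that includes this module, a quick round trip might look like the following (expected results follow from the parser and serializers above):

import ibis.expr.datatypes as dt
from ibis.backends.clickhouse.datatypes import parse, serialize

parse("Nullable(Int64)")    # Int64(nullable=True)
parse("Array(String)")      # Array(String(nullable=False), nullable=False)
serialize(dt.Timestamp(timezone="UTC", nullable=False))  # "DateTime64(6, 'UTC')"
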
235 changes: 169 additions & 66 deletions ibis/backends/clickhouse/registry.py
@@ -1,11 +1,13 @@
from datetime import date, datetime
from io import StringIO

import ibis
import ibis.common.exceptions as com
import ibis.expr.datatypes as dt
import ibis.expr.operations as ops
import ibis.expr.types as ir
import ibis.util as util
from ibis.backends.base.sql.registry import binary_infix
from ibis.backends.clickhouse.datatypes import serialize
from ibis.backends.clickhouse.identifiers import quote_identifier

# TODO(kszucs): should inherit operation registry from the base compiler
@@ -19,12 +21,11 @@ def _alias(translator, expr):


def _cast(translator, expr):
from ibis.backends.clickhouse.client import ClickhouseDataType

op = expr.op()
arg, target = op.args
arg = op.arg
target = op.to
arg_ = translator.translate(arg)
type_ = str(ClickhouseDataType.from_ibis(target, nullable=False))
type_ = serialize(target)

return f'CAST({arg_!s} AS {type_!s})'
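
A usage sketch matching the updated tests further below: the target type (including nullability) is now serialized with serialize(); the table mirrors the test suite's functional_alltypes.

import ibis

t = ibis.table([("double_col", "float64")], name="functional_alltypes")
expr = t.double_col.cast("int8")
# the translator above emits: CAST(`double_col` AS Nullable(Int8))
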

@@ -36,27 +37,20 @@ def _between(translator, expr):


def _negate(translator, expr):
arg = expr.op().args[0]
if isinstance(expr, ir.BooleanValue):
arg_ = translator.translate(arg)
return f'NOT {arg_!s}'
else:
arg_ = _parenthesize(translator, arg)
return f'-{arg_!s}'
return f"-{_parenthesize(translator, expr.op().arg)}"


def _not(translator, expr):
return 'NOT {}'.format(*map(translator.translate, expr.op().args))
return f"NOT {_parenthesize(translator, expr.op().arg)}"


def _parenthesize(translator, expr):
op = expr.op()
op_klass = type(op)

# function calls don't need parens
what_ = translator.translate(expr)
if (op_klass in _binary_infix_ops) or (op_klass in _unary_ops):
return f'({what_!s})'
if isinstance(op, (*_binary_infix_ops.keys(), *_unary_ops.keys())):
return f"({what_})"
else:
return what_

@@ -136,7 +130,12 @@ def _array_slice_op(translator, expr):

def _agg(func):
def formatter(translator, expr):
return _aggregate(translator, func, *expr.op().args)
op = expr.op()
where = getattr(op, "where", None)
args = tuple(
arg for arg in op.args if arg is not None and arg is not where
)
return _aggregate(translator, func, *args, where=where)

return formatter

@@ -145,35 +144,31 @@ def _agg_variance_like(func):
variants = {'sample': f'{func}Samp', 'pop': f'{func}Pop'}

def formatter(translator, expr):
arg, how, where = expr.op().args
return _aggregate(translator, variants[how], arg, where)
*args, how, where = expr.op().args
return _aggregate(translator, variants[how], *args, where=where)

return formatter


def _binary_infix_op(infix_sym):
def formatter(translator, expr):
op = expr.op()

left, right = op.args
left_ = _parenthesize(translator, left)
right_ = _parenthesize(translator, right)

return f'{left_!s} {infix_sym!s} {right_!s}'

return formatter
def _corr(translator, expr):
op = expr.op()
if op.how == "pop":
raise ValueError(
"ClickHouse only implements `sample` correlation coefficient"
)
return _aggregate(translator, "corr", op.left, op.right, where=op.where)


def _call(translator, func, *args):
args_ = ', '.join(map(translator.translate, args))
return f'{func!s}({args_!s})'


def _aggregate(translator, func, arg, where=None):
def _aggregate(translator, func, *args, where=None):
if where is not None:
return _call(translator, func + 'If', arg, where)
return _call(translator, f"{func}If", *args, where)
else:
return _call(translator, func, arg)
return _call(translator, func, *args)


def _xor(translator, expr):
@@ -375,12 +370,30 @@ def _literal(translator, expr):
elif isinstance(expr, ir.IntervalValue):
return _interval_format(translator, expr)
elif isinstance(expr, ir.TimestampValue):
func = "toDateTime"
args = []

if isinstance(value, datetime):
if value.microsecond != 0:
msg = 'Unsupported subsecond accuracy {}'
raise ValueError(msg.format(value))
value = value.strftime('%Y-%m-%d %H:%M:%S')
return f"toDateTime('{value!s}')"
fmt = "%Y-%m-%dT%H:%M:%S"

if micros := value.microsecond:
func = "toDateTime64"
fmt += ".%f"

args.append(value.strftime(fmt))
if micros % 1000:
args.append(6)
elif micros // 1000:
args.append(3)
else:
args.append(str(value))

if (timezone := expr.type().timezone) is not None:
args.append(timezone)

joined_args = ", ".join(map(repr, args))
return f"{func}({joined_args})"

elif isinstance(expr, ir.DateValue):
if isinstance(value, date):
value = value.strftime('%Y-%m-%d')
@@ -589,12 +602,21 @@ def _string_ilike(translator, expr):


def _group_concat(translator, expr):
arg, sep, where = expr.op().args
if where is not None:
arg = where.ifelse(arg, ibis.NA)
return 'arrayStringConcat(groupArray({}), {})'.format(
*map(translator.translate, (arg, sep))
)
op = expr.op()

arg = translator.translate(op.arg)
sep = translator.translate(op.sep)

translated_args = [arg]
func = "groupArray"

if (where := op.where) is not None:
func += "If"
translated_args.append(translator.translate(where))

call = f"{func}({', '.join(translated_args)})"
expr = f"arrayStringConcat({call}, {sep})"
return f"CASE WHEN empty({call}) THEN NULL ELSE {expr} END"


def _string_right(translator, expr):
@@ -604,27 +626,81 @@ def _string_right(translator, expr):
return f"substring({arg}, -({nchars}))"


def _cotangent(translator, expr):
op = expr.op()
arg = translator.translate(op.arg)
return f"cos({arg}) / sin({arg})"


def _bit_agg(func):
def compile(translator, expr):
op = expr.op()
raw_arg = op.arg
arg = translator.translate(raw_arg)
if not isinstance((type := raw_arg.type()), dt.UnsignedInteger):
nbits = type._nbytes * 8
arg = f"reinterpretAsUInt{nbits}({arg})"

if (where := op.where) is not None:
return f"{func}If({arg}, {translator.translate(where)})"
else:
return f"{func}({arg})"

return compile


def _array_column(translator, expr):
args = ", ".join(map(translator.translate, expr.op().cols))
return f"[{args}]"


def _struct_column(translator, expr):
args = ", ".join(map(translator.translate, expr.op().values))
# ClickHouse struct types cannot be nullable
# (non-nested fields can be nullable)
struct_type = serialize(expr.type()(nullable=False))
return f"CAST(({args}) AS {struct_type})"


def _clip(translator, expr):
op = expr.op()
arg = translator.translate(op.arg)

if (upper := op.upper) is not None:
arg = f"least({translator.translate(upper)}, {arg})"

if (lower := op.lower) is not None:
arg = f"greatest({translator.translate(lower)}, {arg})"

return arg


def _struct_field(translator, expr):
op = expr.op()
return f"{translator.translate(op.arg)}.`{op.field}`"


# TODO: clickhouse uses different string functions
# for ascii and utf-8 encodings,

_binary_infix_ops = {
# Binary operations
ops.Add: _binary_infix_op('+'),
ops.Subtract: _binary_infix_op('-'),
ops.Multiply: _binary_infix_op('*'),
ops.Divide: _binary_infix_op('/'),
ops.Add: binary_infix.binary_infix_op('+'),
ops.Subtract: binary_infix.binary_infix_op('-'),
ops.Multiply: binary_infix.binary_infix_op('*'),
ops.Divide: binary_infix.binary_infix_op('/'),
ops.Power: _fixed_arity('pow', 2),
ops.Modulus: _binary_infix_op('%'),
ops.Modulus: binary_infix.binary_infix_op('%'),
# Comparisons
ops.Equals: _binary_infix_op('='),
ops.NotEquals: _binary_infix_op('!='),
ops.GreaterEqual: _binary_infix_op('>='),
ops.Greater: _binary_infix_op('>'),
ops.LessEqual: _binary_infix_op('<='),
ops.Less: _binary_infix_op('<'),
ops.Equals: binary_infix.binary_infix_op('='),
ops.NotEquals: binary_infix.binary_infix_op('!='),
ops.GreaterEqual: binary_infix.binary_infix_op('>='),
ops.Greater: binary_infix.binary_infix_op('>'),
ops.LessEqual: binary_infix.binary_infix_op('<='),
ops.Less: binary_infix.binary_infix_op('<'),
# Boolean comparisons
ops.And: _binary_infix_op('AND'),
ops.Or: _binary_infix_op('OR'),
ops.And: binary_infix.binary_infix_op('AND'),
ops.Or: binary_infix.binary_infix_op('OR'),
ops.Xor: _xor,
}

@@ -649,17 +725,32 @@ def _string_right(translator, expr):
ops.Ln: _unary('log'),
ops.Log2: _unary('log2'),
ops.Log10: _unary('log10'),
ops.Acos: _unary("acos"),
ops.Asin: _unary("asin"),
ops.Atan: _unary("atan"),
ops.Atan2: _fixed_arity("atan2", 2),
ops.Cos: _unary("cos"),
ops.Cot: _cotangent,
ops.Sin: _unary("sin"),
ops.Tan: _unary("tan"),
ops.Pi: _fixed_arity("pi", 0),
ops.E: _fixed_arity("e", 0),
# Unary aggregates
ops.CMSMedian: _agg('median'),
ops.ApproxMedian: _agg('median'),
# TODO: there is also a `uniq` function which is the
# recommended way to approximate cardinality
ops.HLLCardinality: _agg('uniqHLL12'),
ops.ApproxCountDistinct: _agg('uniqHLL12'),
ops.Mean: _agg('avg'),
ops.Sum: _agg('sum'),
ops.Max: _agg('max'),
ops.Min: _agg('min'),
ops.ArrayCollect: _agg('groupArray'),
ops.StandardDev: _agg_variance_like('stddev'),
ops.Variance: _agg_variance_like('var'),
ops.Covariance: _agg_variance_like('covar'),
ops.Correlation: _corr,
ops.GroupConcat: _group_concat,
ops.Count: _agg('count'),
ops.CountDistinct: _agg('uniq'),
@@ -719,18 +810,18 @@ def _string_right(translator, expr):
ops.Least: _varargs('least'),
ops.Where: _fixed_arity('if', 3),
ops.Between: _between,
ops.Contains: _binary_infix_op('IN'),
ops.NotContains: _binary_infix_op('NOT IN'),
ops.SimpleCase: _simple_case,
ops.SearchedCase: _searched_case,
ops.TableColumn: _table_column,
ops.TableArrayView: _table_array_view,
ops.DateAdd: _binary_infix_op('+'),
ops.DateSub: _binary_infix_op('-'),
ops.DateDiff: _binary_infix_op('-'),
ops.TimestampAdd: _binary_infix_op('+'),
ops.TimestampSub: _binary_infix_op('-'),
ops.TimestampDiff: _binary_infix_op('-'),
ops.DateAdd: binary_infix.binary_infix_op('+'),
ops.DateSub: binary_infix.binary_infix_op('-'),
ops.DateDiff: binary_infix.binary_infix_op('-'),
ops.Contains: binary_infix.contains("IN"),
ops.NotContains: binary_infix.contains("NOT IN"),
ops.TimestampAdd: binary_infix.binary_infix_op('+'),
ops.TimestampSub: binary_infix.binary_infix_op('-'),
ops.TimestampDiff: binary_infix.binary_infix_op('-'),
ops.TimestampFromUNIX: _timestamp_from_unix,
ops.ExistsSubquery: _exists_subquery,
ops.NotExistsSubquery: _exists_subquery,
@@ -739,6 +830,17 @@ def _string_right(translator, expr):
ops.ArrayConcat: _fixed_arity('arrayConcat', 2),
ops.ArrayRepeat: _array_repeat_op,
ops.ArraySlice: _array_slice_op,
ops.Unnest: _unary("arrayJoin"),
ops.BitAnd: _bit_agg("groupBitAnd"),
ops.BitOr: _bit_agg("groupBitOr"),
ops.BitXor: _bit_agg("groupBitXor"),
ops.Degrees: _unary("degrees"),
ops.Radians: _unary("radians"),
ops.Strftime: _fixed_arity("formatDateTime", 2),
ops.ArrayColumn: _array_column,
ops.Clip: _clip,
ops.StructField: _struct_field,
ops.StructColumn: _struct_column,
}


@@ -787,10 +889,11 @@ def _day_of_week_index(translator, expr):


_unsupported_ops_list = [
ops.WindowOp,
ops.Window,
ops.DecimalPrecision,
ops.DecimalScale,
ops.BaseConvert,
ops.CumeDist,
ops.CumulativeSum,
ops.CumulativeMin,
ops.CumulativeMax,
86 changes: 60 additions & 26 deletions ibis/backends/clickhouse/tests/conftest.py
@@ -6,6 +6,7 @@

import ibis
import ibis.expr.types as ir
from ibis.backends.conftest import TEST_TABLES, read_tables
from ibis.backends.tests.base import (
BackendTest,
RoundHalfToEven,
@@ -28,37 +29,71 @@ class TestConf(UnorderedComparator, BackendTest, RoundHalfToEven):
bool_is_int = True

@staticmethod
def connect(data_directory: Path):
pytest.importorskip("clickhouse_driver")
host = os.environ.get('IBIS_TEST_CLICKHOUSE_HOST', 'localhost')
port = int(os.environ.get('IBIS_TEST_CLICKHOUSE_PORT', 9000))
user = os.environ.get('IBIS_TEST_CLICKHOUSE_USER', 'default')
password = os.environ.get('IBIS_TEST_CLICKHOUSE_PASSWORD', '')
database = os.environ.get(
'IBIS_TEST_CLICKHOUSE_DATABASE', 'ibis_testing'
)
return ibis.clickhouse.connect(
def _load_data(
data_dir: Path,
script_dir: Path,
host: str = CLICKHOUSE_HOST,
port: int = CLICKHOUSE_PORT,
user: str = CLICKHOUSE_USER,
password: str = CLICKHOUSE_PASS,
database: str = IBIS_TEST_CLICKHOUSE_DB,
**_,
) -> None:
"""Load test data into a ClickHouse backend instance.
Parameters
----------
data_dir
Location of test data
script_dir
Location of scripts defining schemas
"""
clickhouse_driver = pytest.importorskip("clickhouse_driver")

client = clickhouse_driver.Client(
host=host,
port=port,
password=password,
database=database,
user=user,
password=password,
)

client.execute(f"DROP DATABASE IF EXISTS {database}")
client.execute(f"CREATE DATABASE {database}")
client.execute(f"USE {database}")

with open(script_dir / 'schema' / 'clickhouse.sql') as schema:
for stmt in filter(None, map(str.strip, schema.read().split(";"))):
client.execute(stmt)

for table, df in read_tables(TEST_TABLES, data_dir):
query = f"INSERT INTO {table} VALUES"
client.insert_dataframe(
query,
df.to_pandas(),
settings={"use_numpy": True},
)

@staticmethod
def connect(data_directory: Path):
pytest.importorskip("clickhouse_driver")
return ibis.clickhouse.connect(
host=CLICKHOUSE_HOST,
port=CLICKHOUSE_PORT,
password=CLICKHOUSE_PASS,
database=IBIS_TEST_CLICKHOUSE_DB,
user=CLICKHOUSE_USER,
)

@staticmethod
def greatest(
f: Callable[..., ir.ValueExpr], *args: ir.ValueExpr
) -> ir.ValueExpr:
def greatest(f: Callable[..., ir.Value], *args: ir.Value) -> ir.Value:
if len(args) > 2:
raise NotImplementedError(
'Clickhouse does not support more than 2 arguments to greatest'
)
return f(*args)

@staticmethod
def least(
f: Callable[..., ir.ValueExpr], *args: ir.ValueExpr
) -> ir.ValueExpr:
def least(f: Callable[..., ir.Value], *args: ir.Value) -> ir.Value:
if len(args) > 2:
raise NotImplementedError(
'Clickhouse does not support more than 2 arguments to least'
@@ -67,14 +102,13 @@ def least(


@pytest.fixture(scope='module')
def con():
return ibis.clickhouse.connect(
host=CLICKHOUSE_HOST,
port=CLICKHOUSE_PORT,
user=CLICKHOUSE_USER,
password=CLICKHOUSE_PASS,
database=IBIS_TEST_CLICKHOUSE_DB,
)
def con(tmp_path_factory, data_directory, script_directory, worker_id):
return TestConf.load_data(
data_directory,
script_directory,
tmp_path_factory,
worker_id,
).connect(data_directory)


@pytest.fixture(scope='module')
14 changes: 7 additions & 7 deletions ibis/backends/clickhouse/tests/test_client.py
@@ -16,7 +16,7 @@ def test_run_sql(con):
table = con.sql(query)

fa = con.table('functional_alltypes')
assert isinstance(table, ir.TableExpr)
assert isinstance(table, ir.Table)
assert table.schema() == fa.schema()

expr = table.limit(10)
@@ -78,23 +78,23 @@ def test_sql_query_limits(alltypes):
with config.option_context('sql.default_limit', 100000):
# table has 25 rows
assert len(table.execute()) == 7300
# comply with limit arg for TableExpr
# comply with limit arg for Table
assert len(table.execute(limit=10)) == 10
# state hasn't changed
assert len(table.execute()) == 7300
# non-TableExpr ignores default_limit
# non-Table ignores default_limit
assert table.count().execute() == 7300
# non-TableExpr doesn't observe limit arg
# non-Table doesn't observe limit arg
assert table.count().execute(limit=10) == 7300
with config.option_context('sql.default_limit', 20):
# TableExpr observes default limit setting
# Table observes default limit setting
assert len(table.execute()) == 20
# explicit limit= overrides default
assert len(table.execute(limit=15)) == 15
assert len(table.execute(limit=23)) == 23
# non-TableExpr ignores default_limit
# non-Table ignores default_limit
assert table.count().execute() == 7300
# non-TableExpr doesn't observe limit arg
# non-Table doesn't observe limit arg
assert table.count().execute(limit=10) == 7300
# eliminating default_limit doesn't break anything
with config.option_context('sql.default_limit', None):
91 changes: 43 additions & 48 deletions ibis/backends/clickhouse/tests/test_functions.py
@@ -18,9 +18,11 @@
@pytest.mark.parametrize(
('to_type', 'expected'),
[
param('int8', 'CAST(`double_col` AS Int8)', id="int8"),
param('int16', 'CAST(`double_col` AS Int16)', id="int16"),
param('float32', 'CAST(`double_col` AS Float32)', id="float32"),
param('int8', 'CAST(`double_col` AS Nullable(Int8))', id="int8"),
param('int16', 'CAST(`double_col` AS Nullable(Int16))', id="int16"),
param(
'float32', 'CAST(`double_col` AS Nullable(Float32))', id="float32"
),
param('float', '`double_col`', id="float"),
# alltypes.double_col is non-nullable
param(
@@ -38,11 +40,22 @@ def test_cast_double_col(alltypes, translate, to_type, expected):
@pytest.mark.parametrize(
('to_type', 'expected'),
[
('int8', 'CAST(`string_col` AS Int8)'),
('int16', 'CAST(`string_col` AS Int16)'),
('int8', 'CAST(`string_col` AS Nullable(Int8))'),
('int16', 'CAST(`string_col` AS Nullable(Int16))'),
(dt.String(nullable=False), 'CAST(`string_col` AS String)'),
('timestamp', 'CAST(`string_col` AS DateTime)'),
('date', 'CAST(`string_col` AS Date)'),
('timestamp', 'CAST(`string_col` AS Nullable(DateTime64(6)))'),
('date', 'CAST(`string_col` AS Nullable(Date))'),
(
'!map<string, int64>',
'CAST(`string_col` AS Map(Nullable(String), Nullable(Int64)))',
),
(
'!struct<a: string, b: int64>',
(
'CAST(`string_col` AS '
'Tuple(a Nullable(String), b Nullable(Int64)))'
),
),
],
)
def test_cast_string_col(alltypes, translate, to_type, expected):
@@ -85,15 +98,13 @@ def test_timestamp_cast(alltypes, translate):
assert isinstance(result1, ir.TimestampColumn)
assert isinstance(result2, ir.TimestampColumn)

assert translate(result1) == 'CAST(`timestamp_col` AS DateTime)'
assert translate(result2) == 'CAST(`int_col` AS DateTime)'
assert translate(result1) == 'CAST(`timestamp_col` AS DateTime64(6))'
assert translate(result2) == 'CAST(`int_col` AS DateTime64(6))'


def test_timestamp_now(con, translate):
def test_timestamp_now(translate):
expr = ibis.now()
# now = datetime.now().strftime('%Y-%m-%d %H:%M:%S')
assert translate(expr) == 'now()'
# assert con.execute(expr) == now


@pytest.mark.parametrize(
@@ -107,32 +118,12 @@ def test_timestamp_now(con, translate):
('minute', '2009-05-17 12:34:00'),
],
)
def test_timestamp_truncate(con, translate, unit, expected):
def test_timestamp_truncate(con, unit, expected):
stamp = ibis.timestamp('2009-05-17 12:34:56')
expr = stamp.truncate(unit)
assert con.execute(expr) == pd.Timestamp(expected)


@pytest.mark.parametrize(
('func', 'expected'),
[
(methodcaller('year'), 2015),
(methodcaller('month'), 9),
(methodcaller('day'), 1),
(methodcaller('hour'), 14),
(methodcaller('minute'), 48),
(methodcaller('second'), 5),
],
)
def test_simple_datetime_operations(con, func, expected):
value = ibis.timestamp('2015-09-01 14:48:05.359')
with pytest.raises(ValueError):
con.execute(func(value))

value = ibis.timestamp('2015-09-01 14:48:05')
con.execute(func(value)) == expected


@pytest.mark.parametrize(('value', 'expected'), [(0, None), (5.5, 5.5)])
def test_nullifzero(con, value, expected):
result = con.execute(L(value).nullifzero())
@@ -268,7 +259,7 @@ def test_string_contains(con, op, value, expected):


# TODO: clickhouse-driver escaping bug
def test_re_replace(con, translate):
def test_re_replace(con):
expr1 = L('Hello, World!').re_replace('.', '\\\\0\\\\0')
expr2 = L('Hello, World!').re_replace('^', 'here: ')

Expand All @@ -280,7 +271,7 @@ def test_re_replace(con, translate):
('value', 'expected'),
[(L('a'), 0), (L('b'), 1), (L('d'), -1)], # TODO: what's the expected?
)
def test_find_in_set(con, value, expected, translate):
def test_find_in_set(con, value, expected):
vals = list('abc')
expr = value.find_in_set(vals)
assert con.execute(expr) == expected
@@ -312,12 +303,12 @@ def test_string_column_find_in_set(con, alltypes, translate):
),
],
)
def test_parse_url(con, translate, url, extract, expected):
def test_parse_url(con, url, extract, expected):
expr = url.parse_url(extract)
assert con.execute(expr) == expected


def test_parse_url_query_parameter(con, translate):
def test_parse_url_query_parameter(con):
url = L('https://www.youtube.com/watch?v=kEuEcWfewf8&t=10')
expr = url.parse_url('QUERY', 't')
assert con.execute(expr) == '10'
Expand Down Expand Up @@ -427,7 +418,7 @@ def test_translate_math_functions(con, alltypes, translate, call, expected):
),
],
)
def test_math_functions(con, expr, expected, translate):
def test_math_functions(con, expr, expected):
assert con.execute(expr) == expected


Expand Down Expand Up @@ -476,7 +467,7 @@ def test_regexp(con, expr, expected):
# (L('abcd').re_extract('abcd', 3), None),
],
)
def test_regexp_extract(con, expr, expected, translate):
def test_regexp_extract(con, expr, expected):
assert con.execute(expr) == expected


Expand All @@ -496,14 +487,14 @@ def test_column_regexp_replace(con, alltypes, translate):
assert len(con.execute(expr))


def test_numeric_builtins_work(con, alltypes, df, translate):
def test_numeric_builtins_work(alltypes, df):
expr = alltypes.double_col
result = expr.execute()
expected = df.double_col.fillna(0)
tm.assert_series_equal(result, expected)


def test_null_column(alltypes, translate):
def test_null_column(alltypes):
t = alltypes
nrows = t.count().execute()
expr = t.mutate(na_column=ibis.NA).na_column
Expand Down Expand Up @@ -546,16 +537,20 @@ def test_count_distinct_with_filter(alltypes):
@pytest.mark.parametrize(
('sep', 'where_case', 'expected'),
[
(',', None, "arrayStringConcat(groupArray(`string_col`), ',')"),
('-', None, "arrayStringConcat(groupArray(`string_col`), '-')"),
(
',',
None,
"CASE WHEN empty(groupArray(`string_col`)) THEN NULL ELSE arrayStringConcat(groupArray(`string_col`), ',') END", # noqa: E501
),
(
'-',
None,
"CASE WHEN empty(groupArray(`string_col`)) THEN NULL ELSE arrayStringConcat(groupArray(`string_col`), '-') END", # noqa: E501
),
pytest.param(
',',
0,
(
"arrayStringConcat(groupArray("
"CASE WHEN `bool_col` = 0 THEN "
"`string_col` ELSE Null END), ',')"
),
"CASE WHEN empty(groupArrayIf(`string_col`, `bool_col` = 0)) THEN NULL ELSE arrayStringConcat(groupArrayIf(`string_col`, `bool_col` = 0), ',') END", # noqa: E501
),
],
)
Expand Down
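The new expected strings above reflect a behaviour change, not just a formatting tweak: an empty group now produces NULL rather than an empty string, because the compiled SQL guards arrayStringConcat with CASE WHEN empty(...). A minimal sketch of the expression being compiled follows; the table and column names are assumptions based on the usual test fixtures, not part of this diff.

import ibis

# stand-in for the `alltypes` fixture used by the ClickHouse tests
t = ibis.table(
    [('string_col', 'string'), ('bool_col', 'int8')],
    name='functional_alltypes',
)

expr = t.string_col.group_concat(sep=',', where=t.bool_col == 0)
# Expected compiled form (per the parametrization above):
#   CASE WHEN empty(groupArrayIf(`string_col`, `bool_col` = 0)) THEN NULL
#   ELSE arrayStringConcat(groupArrayIf(`string_col`, `bool_col` = 0), ',') END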
33 changes: 32 additions & 1 deletion ibis/backends/clickhouse/tests/test_literals.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import pytest
from pandas import Timestamp
from pytest import param

import ibis
from ibis import literal as L
Expand All @@ -16,12 +17,42 @@
],
)
def test_timestamp_literals(con, translate, expr):
expected = "toDateTime('2015-01-01 12:34:56')"
expected = "toDateTime('2015-01-01T12:34:56')"

assert translate(expr) == expected
assert con.execute(expr) == Timestamp('2015-01-01 12:34:56')


@pytest.mark.parametrize(
("expr", "expected"),
[
param(
ibis.timestamp('2015-01-01 12:34:56.789'),
"toDateTime64('2015-01-01T12:34:56.789000', 3)",
id="millis",
),
param(
ibis.timestamp('2015-01-01 12:34:56.789321'),
"toDateTime64('2015-01-01T12:34:56.789321', 6)",
id="micros",
),
param(
ibis.timestamp('2015-01-01 12:34:56.789 UTC'),
"toDateTime64('2015-01-01T12:34:56.789000', 3, 'UTC')",
id="millis_tz",
),
param(
ibis.timestamp('2015-01-01 12:34:56.789321 UTC'),
"toDateTime64('2015-01-01T12:34:56.789321', 6, 'UTC')",
id="micros_tz",
),
],
)
def test_subsecond_timestamp_literals(con, translate, expr, expected):
assert translate(expr) == expected
assert con.execute(expr) == expr.op().value


@pytest.mark.parametrize(
('value', 'expected'),
[
Expand Down
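The added test ties a literal's sub-second precision to the DateTime64 scale. A quick illustration drawn directly from the cases above (nothing new is claimed here; `translate` refers to the test fixture):

import ibis

millis = ibis.timestamp('2015-01-01 12:34:56.789')     # 3 fractional digits -> scale 3
micros = ibis.timestamp('2015-01-01 12:34:56.789321')  # 6 fractional digits -> scale 6
# translate(millis) == "toDateTime64('2015-01-01T12:34:56.789000', 3)"
# translate(micros) == "toDateTime64('2015-01-01T12:34:56.789321', 6)"
# A trailing 'UTC' in the literal surfaces as a third timezone argument.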
25 changes: 2 additions & 23 deletions ibis/backends/clickhouse/tests/test_operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,27 +50,6 @@ def test_string_temporal_compare(con, op, left, right, type):
assert result == expected


@pytest.mark.parametrize(
('func', 'left', 'right', 'expected'),
[
(operator.add, L(3), L(4), 7),
(operator.sub, L(3), L(4), -1),
(operator.mul, L(3), L(4), 12),
(operator.truediv, L(12), L(4), 3),
(operator.pow, L(12), L(2), 144),
(operator.mod, L(12), L(5), 2),
(operator.truediv, L(7), L(2), 3.5),
(operator.floordiv, L(7), L(2), 3),
(lambda x, y: x.floordiv(y), L(7), 2, 3),
(lambda x, y: x.rfloordiv(y), L(2), 7, 3),
],
)
def test_binary_arithmetic(con, func, left, right, expected):
expr = func(left, right)
result = con.execute(expr)
assert result == expected


@pytest.mark.parametrize(
('op', 'expected'),
[
Expand Down Expand Up @@ -312,10 +291,10 @@ def test_array_index(con, arr, ids):
],
)
def test_array_concat(con, arrays):
expr = L([]).cast(dt.Array(dt.int8))
expr = L([]).cast("!array<int8>")
expected = sum(arrays, [])
for arr in arrays:
expr += L(arr)
expr += L(arr, type="!array<int8>")

assert con.execute(expr) == expected

Expand Down
2 changes: 1 addition & 1 deletion ibis/backends/clickhouse/tests/test_select.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,7 +193,7 @@ def test_complex_array_expr_projection(db, alltypes):

query = ibis.clickhouse.compile(expr2)
name = expr2.get_name()
expected = f"""SELECT CAST(`string_col` AS Float64) AS `{name}`
expected = f"""SELECT CAST(`string_col` AS Nullable(Float64)) AS `{name}`
FROM (
SELECT `string_col`, count(*) AS `count`
FROM {db.name}.`functional_alltypes`
Expand Down
145 changes: 112 additions & 33 deletions ibis/backends/clickhouse/tests/test_types.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
import pytest
from pkg_resources import parse_version

import ibis.expr.datatypes as dt
from ibis.backends.clickhouse.client import ClickhouseDataType
from ibis.backends.clickhouse.datatypes import parse

pytest.importorskip("clickhouse_driver")

Expand All @@ -19,48 +18,128 @@ def test_column_types(alltypes):


def test_columns_types_with_additional_argument(con):
sql_types = ["toFixedString('foo', 8) AS fixedstring_col"]
if parse_version(con.version).base_version >= '1.1.54337':
sql_types.append(
"toDateTime('2018-07-02 00:00:00', 'UTC') AS datetime_col"
)
sql = 'SELECT {}'.format(', '.join(sql_types))
df = con.sql(sql).execute()
sql_types = [
"toFixedString('foo', 8) AS fixedstring_col",
"toDateTime('2018-07-02 00:00:00', 'UTC') AS datetime_col",
]
df = con.sql(f"SELECT {', '.join(sql_types)}").execute()
assert df.fixedstring_col.dtype.name == 'object'
if parse_version(con.version).base_version >= '1.1.54337':
assert df.datetime_col.dtype.name == 'datetime64[ns]'
assert df.datetime_col.dtype.name == 'datetime64[ns, UTC]'


@pytest.mark.parametrize(
('ch_type', 'ibis_type'),
[
('Array(Int8)', dt.Array(dt.Int8(nullable=False))),
('Array(Int16)', dt.Array(dt.Int16(nullable=False))),
('Array(Int32)', dt.Array(dt.Int32(nullable=False))),
('Array(Int64)', dt.Array(dt.Int64(nullable=False))),
('Array(UInt8)', dt.Array(dt.UInt8(nullable=False))),
('Array(UInt16)', dt.Array(dt.UInt16(nullable=False))),
('Array(UInt32)', dt.Array(dt.UInt32(nullable=False))),
('Array(UInt64)', dt.Array(dt.UInt64(nullable=False))),
('Array(Float32)', dt.Array(dt.Float32(nullable=False))),
('Array(Float64)', dt.Array(dt.Float64(nullable=False))),
('Array(String)', dt.Array(dt.String(nullable=False))),
('Array(FixedString(32))', dt.Array(dt.String(nullable=False))),
('Array(Date)', dt.Array(dt.Date(nullable=False))),
('Array(DateTime)', dt.Array(dt.Timestamp(nullable=False))),
('Array(DateTime64)', dt.Array(dt.Timestamp(nullable=False))),
('Array(Nothing)', dt.Array(dt.Null(nullable=False))),
('Array(Null)', dt.Array(dt.Null(nullable=False))),
('Array(Array(Int8))', dt.Array(dt.Array(dt.Int8(nullable=False)))),
('Array(Int8)', dt.Array(dt.Int8(nullable=False), nullable=False)),
('Array(Int16)', dt.Array(dt.Int16(nullable=False), nullable=False)),
('Array(Int32)', dt.Array(dt.Int32(nullable=False), nullable=False)),
('Array(Int64)', dt.Array(dt.Int64(nullable=False), nullable=False)),
('Array(UInt8)', dt.Array(dt.UInt8(nullable=False), nullable=False)),
('Array(UInt16)', dt.Array(dt.UInt16(nullable=False), nullable=False)),
('Array(UInt32)', dt.Array(dt.UInt32(nullable=False), nullable=False)),
('Array(UInt64)', dt.Array(dt.UInt64(nullable=False), nullable=False)),
(
'Array(Float32)',
dt.Array(dt.Float32(nullable=False), nullable=False),
),
(
'Array(Float64)',
dt.Array(dt.Float64(nullable=False), nullable=False),
),
('Array(String)', dt.Array(dt.String(nullable=False), nullable=False)),
(
'Array(FixedString(32))',
dt.Array(dt.String(nullable=False), nullable=False),
),
('Array(Date)', dt.Array(dt.Date(nullable=False), nullable=False)),
(
'Array(DateTime)',
dt.Array(dt.Timestamp(nullable=False), nullable=False),
),
(
'Array(DateTime64)',
dt.Array(dt.Timestamp(nullable=False), nullable=False),
),
('Array(Nothing)', dt.Array(dt.null, nullable=False)),
('Array(Null)', dt.Array(dt.null, nullable=False)),
(
'Array(Array(Int8))',
dt.Array(
dt.Array(dt.Int8(nullable=False), nullable=False),
nullable=False,
),
),
(
'Array(Array(Array(Int8)))',
dt.Array(dt.Array(dt.Array(dt.Int8(nullable=False)))),
dt.Array(
dt.Array(
dt.Array(dt.Int8(nullable=False), nullable=False),
nullable=False,
),
nullable=False,
),
),
(
'Array(Array(Array(Array(Int8))))',
dt.Array(dt.Array(dt.Array(dt.Array(dt.Int8(nullable=False))))),
dt.Array(
dt.Array(
dt.Array(
dt.Array(dt.Int8(nullable=False), nullable=False),
nullable=False,
),
nullable=False,
),
nullable=False,
),
),
(
"Map(Nullable(String), Nullable(UInt64))",
dt.Map(dt.string, dt.uint64, nullable=False),
),
("Decimal(10, 3)", dt.Decimal(10, 3, nullable=False)),
(
"Tuple(a String, b Array(Nullable(Float64)))",
dt.Struct.from_dict(
dict(
a=dt.String(nullable=False),
b=dt.Array(dt.float64, nullable=False),
),
nullable=False,
),
),
(
"Tuple(String, Array(Nullable(Float64)))",
dt.Struct.from_dict(
dict(
f0=dt.String(nullable=False),
f1=dt.Array(dt.float64, nullable=False),
),
nullable=False,
),
),
(
"Tuple(a String, Array(Nullable(Float64)))",
dt.Struct.from_dict(
dict(
a=dt.String(nullable=False),
f1=dt.Array(dt.float64, nullable=False),
),
nullable=False,
),
),
(
"Nested(a String, b Array(Nullable(Float64)))",
dt.Struct.from_dict(
dict(
a=dt.Array(dt.String(nullable=False), nullable=False),
b=dt.Array(
dt.Array(dt.float64, nullable=False), nullable=False
),
),
nullable=False,
),
),
],
)
def test_array_type(ch_type, ibis_type):
assert ClickhouseDataType(ch_type).to_ibis() == ibis_type
def test_parse_type(ch_type, ibis_type):
assert parse(ch_type) == ibis_type
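For readers skimming the switch from ClickhouseDataType(...).to_ibis() to the new parse helper, two lines lifted straight from the parametrization above show the shape of the API (assuming the module layout this diff introduces):

import ibis.expr.datatypes as dt
from ibis.backends.clickhouse.datatypes import parse

assert parse('Array(Int8)') == dt.Array(dt.Int8(nullable=False), nullable=False)
assert parse('Decimal(10, 3)') == dt.Decimal(10, 3, nullable=False)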
300 changes: 256 additions & 44 deletions ibis/backends/conftest.py

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion ibis/backends/dask/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ def execute(
limit: str = 'default',
**kwargs,
):
if limit != 'default':
if limit != 'default' and limit is not None:
raise ValueError(
'limit parameter to execute is not yet implemented in the '
'dask backend'
Expand Down
4 changes: 2 additions & 2 deletions ibis/backends/dask/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@
First, at the beginning of the main execution loop, ``compute_time_context`` is
called. This function computes time contexts and passes them to all children of
the current node. These time contexts could be used in later steps to get data.
This is essential for time series TableExpr, and related operations that adjust
This is essential for time series Table, and related operations that adjust
time context, such as window, asof_join, etc.
By default, this function simply passes the unchanged time context to all
Expand Down Expand Up @@ -85,7 +85,7 @@
data in the result for a few reasons, one of which is that it would break the
contract of window functions: given N rows of input there are N rows of output.
Defining a ``post_execute`` rule for :class:`~ibis.expr.operations.WindowOp`
Defining a ``post_execute`` rule for :class:`~ibis.expr.operations.Window`
allows you to encode such logic. One might want to implement this using
:class:`~ibis.expr.operations.ScalarParameter`, in which case the ``scope``
passed to ``post_execute`` would be the bound values passed in at the time the
Expand Down
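The docstring above describes the post_execute hook only in prose. A hedged sketch of what such a rule can look like — the dispatch import path and signature here are assumptions modelled on the pandas backend's dispatch module, not something this diff adds:

import dask.dataframe as dd

import ibis.expr.operations as ops
from ibis.backends.dask.dispatch import post_execute  # assumed location


@post_execute.register(ops.Window, dd.Series)
def post_execute_window(op, data, **kwargs):
    # e.g. re-align `data` with the input index so N input rows yield N output rows
    return data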
42 changes: 28 additions & 14 deletions ibis/backends/dask/execution/generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@
execute_series_notnnull,
execute_sort_key_series_bool,
execute_table_column_df_or_df_groupby,
execute_zero_if_null_series,
)

# Many dask and pandas functions are functionally equivalent, so we just add
Expand Down Expand Up @@ -125,13 +126,19 @@
],
ops.Contains: [
(
(dd.Series, (collections.abc.Sequence, collections.abc.Set)),
(
dd.Series,
(collections.abc.Sequence, collections.abc.Set, dd.Series),
),
execute_node_contains_series_sequence,
)
],
ops.NotContains: [
(
(dd.Series, (collections.abc.Sequence, collections.abc.Set)),
(
dd.Series,
(collections.abc.Sequence, collections.abc.Set, dd.Series),
),
execute_node_not_contains_series_sequence,
)
],
Expand All @@ -144,6 +151,13 @@
((dd.Series, simple_types), execute_node_nullif_series_scalar),
],
ops.Distinct: [((dd.DataFrame,), execute_distinct_dataframe)],
ops.ZeroIfNull: [
((dd.Series,), execute_zero_if_null_series),
(
(type(None), type(pd.NA), np.floating, float),
execute_zero_if_null_series,
),
],
}

register_types_to_dispatcher(execute_node, DASK_DISPATCH_TYPES)
Expand Down Expand Up @@ -279,16 +293,16 @@ def execute_not_scalar_or_series(op, data, **kwargs):
return ~data


@execute_node.register(ops.BinaryOp, dd.Series, dd.Series)
@execute_node.register(ops.BinaryOp, dd.Series, dd.core.Scalar)
@execute_node.register(ops.BinaryOp, dd.core.Scalar, dd.Series)
@execute_node.register(ops.Binary, dd.Series, dd.Series)
@execute_node.register(ops.Binary, dd.Series, dd.core.Scalar)
@execute_node.register(ops.Binary, dd.core.Scalar, dd.Series)
@execute_node.register(
(ops.NumericBinaryOp, ops.LogicalBinaryOp, ops.Comparison),
(ops.NumericBinary, ops.LogicalBinary, ops.Comparison),
numeric_types,
dd.Series,
)
@execute_node.register(
(ops.NumericBinaryOp, ops.LogicalBinaryOp, ops.Comparison),
(ops.NumericBinary, ops.LogicalBinary, ops.Comparison),
dd.Series,
numeric_types,
)
Expand All @@ -308,7 +322,7 @@ def execute_binary_op(op, left, right, **kwargs):
return operation(left, right)


@execute_node.register(ops.BinaryOp, ddgb.SeriesGroupBy, ddgb.SeriesGroupBy)
@execute_node.register(ops.Binary, ddgb.SeriesGroupBy, ddgb.SeriesGroupBy)
def execute_binary_op_series_group_by(op, left, right, **kwargs):
if left.index != right.index:
raise ValueError(
Expand All @@ -321,19 +335,19 @@ def execute_binary_op_series_group_by(op, left, right, **kwargs):
return result.groupby(left.index)


@execute_node.register(ops.BinaryOp, ddgb.SeriesGroupBy, simple_types)
@execute_node.register(ops.Binary, ddgb.SeriesGroupBy, simple_types)
def execute_binary_op_series_gb_simple(op, left, right, **kwargs):
result = execute_binary_op(op, make_selected_obj(left), right, **kwargs)
return result.groupby(left.index)


@execute_node.register(ops.BinaryOp, simple_types, ddgb.SeriesGroupBy)
@execute_node.register(ops.Binary, simple_types, ddgb.SeriesGroupBy)
def execute_binary_op_simple_series_gb(op, left, right, **kwargs):
result = execute_binary_op(op, left, make_selected_obj(right), **kwargs)
return result.groupby(right.index)


@execute_node.register(ops.UnaryOp, ddgb.SeriesGroupBy)
@execute_node.register(ops.Unary, ddgb.SeriesGroupBy)
def execute_unary_op_series_gb(op, operand, **kwargs):
result = execute_node(op, make_selected_obj(operand), **kwargs)
return result.groupby(operand.index)
Expand Down Expand Up @@ -385,14 +399,14 @@ def execute_node_nullif_scalar_series(op, value, series, **kwargs):
return dd.from_array(da.where(series.eq(value).values, np.nan, value))


def wrap_case_result(raw: np.ndarray, expr: ir.ValueExpr):
def wrap_case_result(raw: np.ndarray, expr: ir.Value):
"""Wrap a CASE statement result in a Series and handle returning scalars.
Parameters
----------
raw : ndarray[T]
The raw results of executing the ``CASE`` expression
expr : ValueExpr
expr : Value
The expression from the which `raw` was computed
Returns
Expand All @@ -407,7 +421,7 @@ def wrap_case_result(raw: np.ndarray, expr: ir.ValueExpr):
raw_1d.astype(constants.IBIS_TYPE_TO_PANDAS_TYPE[expr.type()])
)
# TODO - we force computation here
if isinstance(expr, ir.ScalarExpr) and result.size.compute() == 1:
if isinstance(expr, ir.Scalar) and result.size.compute() == 1:
return result.head().item()
return result

Expand Down
33 changes: 30 additions & 3 deletions ibis/backends/dask/execution/numeric.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,8 @@


@execute_node.register(ops.Negate, dd.Series)
def execute_series_negate(op, data, **kwargs):
return data.mul(-1)
def execute_series_negate(_, data, **kwargs):
return -data


@execute_node.register(ops.Negate, ddgb.SeriesGroupBy)
Expand All @@ -34,12 +34,39 @@ def call_numpy_ufunc(func, op, data, **kwargs):
return func(data)


@execute_node.register(ops.UnaryOp, dd.Series)
@execute_node.register(ops.Unary, dd.Series)
def execute_series_unary_op(op, data, **kwargs):
function = getattr(np, type(op).__name__.lower())
return call_numpy_ufunc(function, op, data, **kwargs)


@execute_node.register(ops.Acos, dd.Series)
def execute_series_acos(_, data, **kwargs):
return np.arccos(data)


@execute_node.register(ops.Asin, dd.Series)
def execute_series_asin(_, data, **kwargs):
return np.arcsin(data)


@execute_node.register(ops.Atan, dd.Series)
def execute_series_atan(_, data, **kwargs):
return np.arctan(data)


@execute_node.register(ops.Cot, dd.Series)
def execute_series_cot(_, data, **kwargs):
return np.cos(data) / np.sin(data)


@execute_node.register(ops.Atan2, dd.Series, dd.Series)
@execute_node.register(ops.Atan2, numeric_types, dd.Series)
@execute_node.register(ops.Atan2, dd.Series, numeric_types)
def execute_series_atan2(_, y, x, **kwargs):
return np.arctan2(y, x)


@execute_node.register((ops.Ceil, ops.Floor), dd.Series)
def execute_series_ceil(op, data, **kwargs):
return_type = np.object_ if data.dtype == np.object_ else np.int64
Expand Down
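The reason these trigonometric ops get explicit registrations rather than riding on the generic ops.Unary handler above: that handler resolves getattr(np, op_name.lower()), and numpy (at least the versions this backend targets) spells the inverse functions arccos/arcsin/arctan and ships no cot ufunc at all. A small illustrative check, not taken from the diff:

import dask.dataframe as dd
import numpy as np
import pandas as pd

s = dd.from_pandas(pd.Series([0.1, 0.5, 1.0]), npartitions=1)
np.arccos(s).compute()             # what execute_series_acos returns for the same input
(np.cos(s) / np.sin(s)).compute()  # the Cot rule, via the cos/sin identity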
10 changes: 7 additions & 3 deletions ibis/backends/dask/execution/reductions.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,9 @@ def execute_reduction_series_mask(op, data, mask, aggcontext=None, **kwargs):


@execute_node.register(
(ops.CountDistinct, ops.HLLCardinality), ddgb.SeriesGroupBy, type(None)
(ops.CountDistinct, ops.ApproxCountDistinct),
ddgb.SeriesGroupBy,
type(None),
)
def execute_count_distinct_series_groupby(
op, data, _, aggcontext=None, **kwargs
Expand All @@ -103,7 +105,7 @@ def execute_count_distinct_series_groupby(


@execute_node.register(
(ops.CountDistinct, ops.HLLCardinality),
(ops.CountDistinct, ops.ApproxCountDistinct),
ddgb.SeriesGroupBy,
ddgb.SeriesGroupBy,
)
Expand All @@ -115,7 +117,9 @@ def execute_count_distinct_series_groupby_mask(


@execute_node.register(
(ops.CountDistinct, ops.HLLCardinality), dd.Series, (dd.Series, type(None))
(ops.CountDistinct, ops.ApproxCountDistinct),
dd.Series,
(dd.Series, type(None)),
)
def execute_count_distinct_series_mask(
op, data, mask, aggcontext=None, **kwargs
Expand Down
8 changes: 4 additions & 4 deletions ibis/backends/dask/execution/selection.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
from ibis.expr.typing import TimeContext


@compute_projection.register(ir.ScalarExpr, ops.Selection, dd.DataFrame)
@compute_projection.register(ir.Scalar, ops.Selection, dd.DataFrame)
def compute_projection_scalar_expr(
expr,
parent,
Expand Down Expand Up @@ -66,7 +66,7 @@ def compute_projection_scalar_expr(
return data.assign(**{name: scalar})[name]


@compute_projection.register(ir.ColumnExpr, ops.Selection, dd.DataFrame)
@compute_projection.register(ir.Column, ops.Selection, dd.DataFrame)
def compute_projection_column_expr(
expr,
parent,
Expand Down Expand Up @@ -113,7 +113,7 @@ def compute_projection_column_expr(
return result


compute_projection.register(ir.TableExpr, ops.Selection, dd.DataFrame)(
compute_projection.register(ir.Table, ops.Selection, dd.DataFrame)(
compute_projection_table_expr
)

Expand Down Expand Up @@ -245,7 +245,7 @@ def _compute_predicates(
Parameters
----------
table_op : TableNode
predicates : List[ir.ColumnExpr]
predicates : List[ir.Column]
data : pd.DataFrame
scope : Scope
timecontext: Optional[TimeContext]
Expand Down
35 changes: 27 additions & 8 deletions ibis/backends/dask/execution/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import dask.dataframe as dd
import dask.delayed
import numpy as np
import pandas as pd
from dask.dataframe.groupby import SeriesGroupBy

Expand Down Expand Up @@ -37,10 +38,26 @@ def register_types_to_dispatcher(
dispatcher.register(ibis_op, *types_to_register)(fn)


def make_meta_series(dtype, name=None, index_name=None):
def make_meta_series(
dtype: np.dtype,
name: Optional[str] = None,
meta_index: Optional[pd.Index] = None,
):
if isinstance(meta_index, pd.MultiIndex):
index_names = meta_index.names
series_index = pd.MultiIndex(
levels=[[]] * len(index_names),
codes=[[]] * len(index_names),
names=index_names,
)
elif isinstance(meta_index, pd.Index):
series_index = pd.Index([], name=meta_index.name)
else:
series_index = pd.Index([])

return pd.Series(
[],
index=pd.Index([], name=index_name),
index=series_index,
dtype=dtype,
name=name,
)
Expand Down Expand Up @@ -123,7 +140,7 @@ def coerce_to_output(
elif isinstance(result, (pd.Series, dd.Series)):
# Series from https://github.com/ibis-project/ibis/issues/2711
return result.rename(result_name)
elif isinstance(expr, ir.ScalarExpr):
elif isinstance(expr, ir.Scalar):
if isinstance(result, dd.core.Scalar):
# wrap the scalar in a series
out_dtype = _pandas_dtype_from_dd_scalar(result)
Expand Down Expand Up @@ -209,15 +226,17 @@ def _coerce_to_dataframe(
# NOTE - We add a detailed meta here so we do not drop the key index
# downstream. This seems to be fixed in versions of dask > 2020.12.0
dtypes = map(ibis_dtype_to_pandas, types)

series = [
data.apply(
_select_item_in_iter,
selection=i,
meta=make_meta_series(dtype, index_name=data.index.name),
meta=make_meta_series(
dtype, meta_index=data._meta_nonempty.index
),
)
for i, dtype in enumerate(dtypes)
]

result = dd.concat(series, axis=1)

elif isinstance(data, (tuple, list)):
Expand Down Expand Up @@ -254,14 +273,14 @@ def safe_concat(dfs: List[Union[dd.Series, dd.DataFrame]]) -> dd.DataFrame:
operate on objects with equal lengths, otherwise it will raise a
ValueError in `concat_and_check`.
See https://github.com/dask/dask/blob/2c2e837674895cafdb0612be81250ef2657d947e/dask/dataframe/multi.py#L907 # noqa
See https://github.com/dask/dask/blob/2c2e837674895cafdb0612be81250ef2657d947e/dask/dataframe/multi.py#L907.
Note - Repeatedly joining dataframes is likely to be quite slow, but this
should be hit rarely in real usage. A situation that triggers this slow
path is a group-by aggregation whose individual aggregations return different numbers of rows
(see `test_aggregation_group_by` for a specific example).
TODO - performance.
"""
""" # noqa: E501
if len(dfs) == 1:
maybe_df = dfs[0]
if isinstance(maybe_df, dd.Series):
Expand Down Expand Up @@ -447,7 +466,7 @@ def is_row_order_preserving(exprs) -> bool:
"""

def _is_row_order_preserving(expr: ir.Expr):
if isinstance(expr.op(), (ops.Reduction, ops.WindowOp)):
if isinstance(expr.op(), (ops.Reduction, ops.Window)):
return (lin.halt, False)
else:
return (lin.proceed, True)
Expand Down
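The NOTE inside _coerce_to_dataframe above is easier to see with a toy example: dd.Series.apply cannot infer the output schema, so the meta built by make_meta_series has to carry both the dtype and the (possibly named, possibly MultiIndex) index of the real result, otherwise the key index gets dropped downstream. A hedged, illustrative sketch only — the data and names are mine:

import dask.dataframe as dd
import numpy as np
import pandas as pd

pdf = pd.Series([1, 2, 3], index=pd.Index(['a', 'b', 'c'], name='key'))
ds = dd.from_pandas(pdf, npartitions=2)

# empty Series standing in for the result: same dtype, same named index
meta = pd.Series([], dtype=np.float64, index=pd.Index([], name='key'))
result = ds.apply(lambda x: x * 1.5, meta=meta)
assert result.compute().index.name == 'key'  # the named index survives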
2 changes: 1 addition & 1 deletion ibis/backends/dask/execution/window.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ def _post_process_empty(
return parent.apply(lambda row: result, meta=(None, 'object'))


@execute_node.register(ops.WindowOp, dd.Series, win.Window)
@execute_node.register(ops.Window, dd.Series, win.Window)
def execute_window_op(
op,
data,
Expand Down
2 changes: 2 additions & 0 deletions ibis/backends/dask/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@ def npartitions():


class TestConf(PandasTest):
supports_structs = False

@staticmethod
def connect(data_directory: Path):
# Note - we use `dd.from_pandas(pd.read_csv(...))` instead of
Expand Down
2 changes: 1 addition & 1 deletion ibis/backends/dask/tests/execution/test_temporal.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import numpy as np
import pandas as pd
import pytest
from pkg_resources import parse_version
from packaging.version import parse as parse_version
from pytest import param

import ibis
Expand Down
14 changes: 7 additions & 7 deletions ibis/backends/dask/tests/test_udf.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,7 @@ def quantiles(series, *, quantiles):
def test_udf(t, df):
expr = my_string_length(t.a)

assert isinstance(expr, ir.ColumnExpr)
assert isinstance(expr, ir.Column)

result = expr.execute()
expected = df.a.str.len().mul(2).compute()
Expand All @@ -173,7 +173,7 @@ def test_udf(t, df):
def test_multiple_argument_udf(con, t, df):
expr = my_add(t.b, t.c)

assert isinstance(expr, ir.ColumnExpr)
assert isinstance(expr, ir.Column)
assert isinstance(expr, ir.NumericColumn)
assert isinstance(expr, ir.FloatingColumn)

Expand All @@ -185,8 +185,8 @@ def test_multiple_argument_udf(con, t, df):
def test_multiple_argument_udf_group_by(con, t, df):
expr = t.groupby(t.key).aggregate(my_add=my_add(t.b, t.c).sum())

assert isinstance(expr, ir.TableExpr)
assert isinstance(expr.my_add, ir.ColumnExpr)
assert isinstance(expr, ir.Table)
assert isinstance(expr.my_add, ir.Column)
assert isinstance(expr.my_add, ir.NumericColumn)
assert isinstance(expr.my_add, ir.FloatingColumn)

Expand All @@ -200,7 +200,7 @@ def test_multiple_argument_udf_group_by(con, t, df):
def test_udaf(con, t, df):
expr = my_string_length_sum(t.a)

assert isinstance(expr, ir.ScalarExpr)
assert isinstance(expr, ir.Scalar)

result = expr.execute()
expected = t.a.execute().str.len().mul(2).sum()
Expand Down Expand Up @@ -228,7 +228,7 @@ def test_udaf_elementwise_tzcol(con, t_timestamp, df_timestamp):
def test_udaf_analytic(con, t, df):
expr = zscore(t.c)

assert isinstance(expr, ir.ColumnExpr)
assert isinstance(expr, ir.Column)

result = expr.execute()

Expand All @@ -242,7 +242,7 @@ def f(s):
def test_udaf_analytic_groupby(con, t, df):
expr = zscore(t.c).over(ibis.window(group_by=t.key))

assert isinstance(expr, ir.ColumnExpr)
assert isinstance(expr, ir.Column)

result = expr.execute()

Expand Down
5 changes: 5 additions & 0 deletions ibis/backends/dask/trace.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import traceback
from datetime import datetime

import ibis
from ibis.backends.pandas.dispatcher import TwoLevelDispatcher
from ibis.config import options
from ibis.expr import types as ir
Expand Down Expand Up @@ -69,6 +70,10 @@ def add_one(v):

def enable():
"""Enable tracing."""
if options.dask is None:
# dask options haven't been registered yet - force module __getattr__
ibis.dask

options.dask.enable_trace = True
logging.getLogger('ibis.dask.trace').setLevel(logging.DEBUG)

Expand Down
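A hedged usage sketch for the guard added to enable(): merely importing ibis does not register the dask option namespace, so the attribute access on ibis.dask forces the lazy registration before the option is set. The option path comes from the function body above; everything else is an assumption about typical usage:

import ibis
from ibis.backends.dask import trace

trace.enable()  # safe even before any dask connection has been created
assert ibis.options.dask.enable_trace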