336 changes: 336 additions & 0 deletions ibis/backends/base/sql/alchemy/query_builder.py
@@ -0,0 +1,336 @@
import functools

import sqlalchemy as sa
import sqlalchemy.sql as sql

import ibis.expr.operations as ops
import ibis.expr.types as ir
from ibis.backends.base.sql.compiler import (
Compiler,
Select,
SelectBuilder,
TableSetFormatter,
Union,
)

from .database import AlchemyTable
from .datatypes import to_sqla_type
from .translator import AlchemyContext, AlchemyExprTranslator


class _AlchemyTableSetFormatter(TableSetFormatter):
def get_result(self):
# We have to unravel the join stack; the nesting order can be
# arbitrary, so we do a depth-first search, push the join tokens
# and predicates onto a flat list, and then format them
op = self.expr.op()

if isinstance(op, ops.Join):
self._walk_join_tree(op)
else:
self.join_tables.append(self._format_table(self.expr))

result = self.join_tables[0]
for jtype, table, preds in zip(
self.join_types, self.join_tables[1:], self.join_predicates
):
if len(preds):
sqla_preds = [self._translate(pred) for pred in preds]
onclause = functools.reduce(sql.and_, sqla_preds)
else:
onclause = None

if jtype in (ops.InnerJoin, ops.CrossJoin):
result = result.join(table, onclause)
elif jtype is ops.LeftJoin:
result = result.join(table, onclause, isouter=True)
elif jtype is ops.RightJoin:
result = table.join(result, onclause, isouter=True)
elif jtype is ops.OuterJoin:
result = result.outerjoin(table, onclause, full=True)
elif jtype is ops.LeftSemiJoin:
result = sa.select([result]).where(
sa.exists(sa.select([1]).where(onclause))
)
elif jtype is ops.LeftAntiJoin:
result = sa.select([result]).where(
~(sa.exists(sa.select([1]).where(onclause)))
)
else:
raise NotImplementedError(jtype)

return result

def _get_join_type(self, op):
return type(op)

def _format_table(self, expr):
ctx = self.context
ref_expr = expr
op = ref_op = expr.op()

if isinstance(op, ops.SelfReference):
ref_expr = op.table
ref_op = ref_expr.op()

alias = ctx.get_ref(expr)

if isinstance(ref_op, AlchemyTable):
result = ref_op.sqla_table
elif isinstance(ref_op, ops.UnboundTable):
# use SQLAlchemy's TableClause and ColumnClause for unbound tables
schema = ref_op.schema
result = sa.table(
ref_op.name if ref_op.name is not None else ctx.get_ref(expr),
*(
sa.column(n, to_sqla_type(t))
for n, t in zip(schema.names, schema.types)
),
)
else:
# A subquery
if ctx.is_extracted(ref_expr):
# It was extracted elsewhere, e.g. into a WITH block; we just
# need to grab its alias
alias = ctx.get_ref(expr)

# hack
if isinstance(op, ops.SelfReference):
table = ctx.get_table(ref_expr)
self_ref = table.alias(alias)
ctx.set_table(expr, self_ref)
return self_ref
else:
return ctx.get_table(expr)

result = ctx.get_compiled_expr(expr)
alias = ctx.get_ref(expr)

result = result.alias(alias)
ctx.set_table(expr, result)
return result


def _can_lower_sort_column(table_set, expr):
# TODO(wesm): This code is pending removal through cleaner internal
# semantics

# we can currently sort by just-appeared aggregate metrics, but the way
# these are referenced in the expression DSL is as a SortBy (a blocking
# table operation) on an aggregation. There's a hack in _collect_SortBy
# in the generic SQL compiler that "fuses" the sort with the
# aggregation so they appear in the same query. It's generally for
# cosmetics and doesn't really affect query semantics.
bases = ops.find_all_base_tables(expr)
if len(bases) > 1:
return False

base = list(bases.values())[0]
base_op = base.op()

if isinstance(base_op, ops.Aggregation):
return base_op.table.equals(table_set)
elif isinstance(base_op, ops.Selection):
return base.equals(table_set)
else:
return False


class AlchemySelect(Select):
def __init__(self, *args, **kwargs):
self.exists = kwargs.pop('exists', False)
super().__init__(*args, **kwargs)

def compile(self):
# Can't tell if this is a hack or not. Revisit later
self.context.set_query(self)

self._compile_subqueries()

frag = self._compile_table_set()
steps = [
self._add_select,
self._add_groupby,
self._add_where,
self._add_order_by,
self._add_limit,
]

for step in steps:
frag = step(frag)

return frag

def _compile_subqueries(self):
if not self.subqueries:
return

for expr in self.subqueries:
result = self.context.get_compiled_expr(expr)
alias = self.context.get_ref(expr)
result = result.cte(alias)
self.context.set_table(expr, result)

def _compile_table_set(self):
if self.table_set is not None:
helper = _AlchemyTableSetFormatter(self, self.table_set)
result = helper.get_result()
if isinstance(result, sql.selectable.Select) and hasattr(
result, 'subquery'
):
return result.subquery()
return result
else:
return None

def _add_select(self, table_set):
to_select = []

has_select_star = False
for expr in self.select_set:
if isinstance(expr, ir.ValueExpr):
arg = self._translate(expr, named=True)
elif isinstance(expr, ir.TableExpr):
if expr.equals(self.table_set):
cached_table = self.context.get_table(expr)
if cached_table is None:
# the select * case from materialized join
has_select_star = True
continue
else:
arg = table_set
else:
arg = self.context.get_table(expr)
if arg is None:
raise ValueError(expr)

to_select.append(arg)

if has_select_star:
if table_set is None:
raise ValueError('table_set cannot be None here')

clauses = [table_set] + to_select
else:
clauses = to_select

if self.exists:
result = sa.exists(clauses)
else:
result = sa.select(clauses)

if self.distinct:
result = result.distinct()

if not has_select_star:
if table_set is not None:
return result.select_from(table_set)
else:
return result
else:
return result

def _add_groupby(self, fragment):
# GROUP BY and HAVING
if not len(self.group_by):
return fragment

group_keys = [self._translate(arg) for arg in self.group_by]
fragment = fragment.group_by(*group_keys)

if len(self.having) > 0:
having_args = [self._translate(arg) for arg in self.having]
having_clause = functools.reduce(sql.and_, having_args)
fragment = fragment.having(having_clause)

return fragment

def _add_where(self, fragment):
if not len(self.where):
return fragment

args = [
self._translate(pred, permit_subquery=True) for pred in self.where
]
clause = functools.reduce(sql.and_, args)
return fragment.where(clause)

def _add_order_by(self, fragment):
if not len(self.order_by):
return fragment

clauses = []
for expr in self.order_by:
key = expr.op()
sort_expr = key.expr

# here we have to determine whether key.expr is in the select set (as it
# will be in the case of an order_by fused with an aggregation)
if _can_lower_sort_column(self.table_set, sort_expr):
arg = sort_expr.get_name()
else:
arg = self._translate(sort_expr)

if not key.ascending:
arg = sa.desc(arg)

clauses.append(arg)

return fragment.order_by(*clauses)

def _among_select_set(self, expr):
for other in self.select_set:
if expr.equals(other):
return True
return False

def _add_limit(self, fragment):
if self.limit is None:
return fragment

n, offset = self.limit['n'], self.limit['offset']
fragment = fragment.limit(n)
if offset is not None and offset != 0:
fragment = fragment.offset(offset)

return fragment


class AlchemySelectBuilder(SelectBuilder):
def _convert_group_by(self, exprs):
return exprs


class AlchemyUnion(Union):
def compile(self):
def reduce_union(left, right, distincts=iter(self.distincts)):
distinct = next(distincts)
sa_func = sa.union if distinct else sa.union_all
return sa_func(left, right)

context = self.context
selects = []

for table in self.tables:
table_set = context.get_compiled_expr(table)
selects.append(table_set.cte().select())

return functools.reduce(reduce_union, selects)


class AlchemyCompiler(Compiler):
translator_class = AlchemyExprTranslator
context_class = AlchemyContext
table_set_formatter_class = _AlchemyTableSetFormatter
select_builder_class = AlchemySelectBuilder
select_class = AlchemySelect
union_class = AlchemyUnion

@classmethod
def to_sql(cls, expr, context=None, params=None, exists=False):
if context is None:
context = cls.make_context(params=params)
query = cls.to_ast(expr, context).queries[0]
if exists:
query.exists = True
return query.compile()
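For orientation, here is a minimal usage sketch of the compiler assembled above (illustrative only, not part of the changeset); it assumes an unbound ibis table, which _format_table compiles into an sa.table(...) clause:

import ibis
from ibis.backends.base.sql.alchemy.query_builder import AlchemyCompiler

t = ibis.table([('key', 'string'), ('value', 'double')], name='t')
filtered = t[t.value > 0]
expr = filtered.group_by('key').aggregate(total=filtered.value.sum())

# to_sql builds the AST, takes its single query, and compiles it into a
# SQLAlchemy selectable via AlchemySelect.compile()
sqla_query = AlchemyCompiler.to_sql(expr)
print(sqla_query)  # rendered by SQLAlchemy's default dialect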
583 changes: 583 additions & 0 deletions ibis/backends/base/sql/alchemy/registry.py

Large diffs are not rendered by default.

92 changes: 92 additions & 0 deletions ibis/backends/base/sql/alchemy/translator.py
@@ -0,0 +1,92 @@
import ibis
import ibis.expr.operations as ops
import ibis.expr.types as ir
from ibis import util
from ibis.backends.base.sql.compiler import ExprTranslator, QueryContext

from .datatypes import ibis_type_to_sqla, to_sqla_type
from .registry import fixed_arity, sqlalchemy_operation_registry


class AlchemyContext(QueryContext):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self._table_objects = {}

def collapse(self, queries):
if isinstance(queries, str):
return queries

if len(queries) > 1:
raise NotImplementedError(
'Only a single query is supported for SQLAlchemy backends'
)
return queries[0]

def subcontext(self):
return type(self)(
compiler=self.compiler, parent=self, params=self.params
)

def _compile_subquery(self, expr):
sub_ctx = self.subcontext()
return self._to_sql(expr, sub_ctx)

def has_table(self, expr, parent_contexts=False):
key = self._get_table_key(expr)
return self._key_in(
key, '_table_objects', parent_contexts=parent_contexts
)

def set_table(self, expr, obj):
key = self._get_table_key(expr)
self._table_objects[key] = obj

def get_table(self, expr):
"""
Get the memoized SQLAlchemy expression object
"""
return self._get_table_item('_table_objects', expr)


class AlchemyExprTranslator(ExprTranslator):

_registry = sqlalchemy_operation_registry
_rewrites = ExprTranslator._rewrites.copy()
_type_map = ibis_type_to_sqla

context_class = AlchemyContext

def name(self, translated, name, force=True):
if hasattr(translated, 'label'):
return translated.label(name)
return translated

def get_sqla_type(self, data_type):
return to_sqla_type(data_type, type_map=self._type_map)


rewrites = AlchemyExprTranslator.rewrites


@rewrites(ops.NullIfZero)
def _nullifzero(expr):
arg = expr.op().args[0]
return (arg == 0).ifelse(ibis.NA, arg)


# TODO This was previously implemented with the legacy `@compiles` decorator.
# This definition should now be in the registry, but there is some magic going
# on such that things fail if it's not defined here (in the registry,
# `operator.truediv` is used).
def _true_divide(t, expr):
op = expr.op()
left, right = args = op.args

if util.all_of(args, ir.IntegerValue):
return t.translate(left.div(right.cast('double')))

return fixed_arity(lambda x, y: x / y, 2)(t, expr)


AlchemyExprTranslator._registry[ops.Divide] = _true_divide
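A hedged illustration of the rewrite hook used above (not part of the changeset): a rewrite replaces an operation with an equivalent ibis expression before translation, so the NullIfZero rule effectively expands col.nullifzero() into a conditional:

import ibis

t = ibis.table([('x', 'double')], name='t')  # hypothetical table
rewritten = (t.x == 0).ifelse(ibis.NA, t.x)  # what _nullifzero produces for t.x.nullifzero()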
21 changes: 21 additions & 0 deletions ibis/backends/base/sql/compiler/__init__.py
@@ -0,0 +1,21 @@
from .base import DDL, DML
from .query_builder import (
Compiler,
Select,
SelectBuilder,
TableSetFormatter,
Union,
)
from .translator import ExprTranslator, QueryContext

__all__ = (
'Compiler',
'Select',
'SelectBuilder',
'Union',
'TableSetFormatter',
'ExprTranslator',
'QueryContext',
'DML',
'DDL',
)
107 changes: 107 additions & 0 deletions ibis/backends/base/sql/compiler/base.py
@@ -0,0 +1,107 @@
import abc
from itertools import chain

import toolz

import ibis.util as util

from .extract_subqueries import ExtractSubqueries


class DML(abc.ABC):
@abc.abstractmethod
def compile(self):
pass


class DDL(abc.ABC):
@abc.abstractmethod
def compile(self):
pass


class QueryAST:

__slots__ = 'context', 'dml', 'setup_queries', 'teardown_queries'

def __init__(
self, context, dml, setup_queries=None, teardown_queries=None
):
self.context = context
self.dml = dml
self.setup_queries = setup_queries
self.teardown_queries = teardown_queries

@property
def queries(self):
return [self.dml]

def compile(self):
compiled_setup_queries = [q.compile() for q in self.setup_queries]
compiled_queries = [q.compile() for q in self.queries]
compiled_teardown_queries = [
q.compile() for q in self.teardown_queries
]
return self.context.collapse(
list(
chain(
compiled_setup_queries,
compiled_queries,
compiled_teardown_queries,
)
)
)


class SetOp(DML):
def __init__(self, tables, expr, context):
self.context = context
self.tables = tables
self.table_set = expr
self.filters = []

def _extract_subqueries(self):
self.subqueries = ExtractSubqueries.extract(self)
for subquery in self.subqueries:
self.context.set_extracted(subquery)

def format_subqueries(self):
context = self.context
subqueries = self.subqueries

return ',\n'.join(
'{} AS (\n{}\n)'.format(
context.get_ref(expr),
util.indent(context.get_compiled_expr(expr), 2),
)
for expr in subqueries
)

def format_relation(self, expr):
ref = self.context.get_ref(expr)
if ref is not None:
return f'SELECT *\nFROM {ref}'
return self.context.get_compiled_expr(expr)

def _get_keyword_list(self):
raise NotImplementedError("Need objects to interleave")

def compile(self):
self._extract_subqueries()

extracted = self.format_subqueries()

buf = []

if extracted:
buf.append(f'WITH {extracted}')

buf.extend(
toolz.interleave(
(
map(self.format_relation, self.tables),
self._get_keyword_list(),
)
)
)
return '\n'.join(buf)
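A small sketch (not part of the changeset) of what the toolz.interleave step in SetOp.compile yields for a two-relation set operation; the relation strings and keyword list are illustrative:

import toolz

relations = ['SELECT *\nFROM t0', 'SELECT *\nFROM t1']
keywords = ['UNION ALL']  # what a Union subclass's _get_keyword_list() might return
print('\n'.join(toolz.interleave((relations, keywords))))
# SELECT *
# FROM t0
# UNION ALL
# SELECT *
# FROM t1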
129 changes: 129 additions & 0 deletions ibis/backends/base/sql/compiler/extract_subqueries.py
@@ -0,0 +1,129 @@
from collections import OrderedDict

import ibis.expr.operations as ops
import ibis.expr.types as ir


class ExtractSubqueries:
def __init__(self, query, greedy=False):
self.query = query
self.greedy = greedy
self.expr_counts = OrderedDict()
self.node_to_expr = {}

@classmethod
def extract(cls, select_stmt):
helper = cls(select_stmt)
return helper.get_result()

def get_result(self):
if self.query.table_set is not None:
self.visit(self.query.table_set)

for clause in self.query.filters:
self.visit(clause)

expr_counts = self.expr_counts

if self.greedy:
to_extract = list(expr_counts.keys())
else:
to_extract = [op for op, count in expr_counts.items() if count > 1]

node_to_expr = self.node_to_expr
return [node_to_expr[op] for op in to_extract]

def observe(self, expr):
key = expr.op()

if key not in self.node_to_expr:
self.node_to_expr[key] = expr

assert self.node_to_expr[key].equals(expr)
self.expr_counts[key] = self.expr_counts.setdefault(key, 0) + 1

def seen(self, expr):
return expr.op() in self.expr_counts

def visit(self, expr):
node = expr.op()
method = f'visit_{type(node).__name__}'

if hasattr(self, method):
f = getattr(self, method)
f(expr)
elif isinstance(node, ops.Join):
self.visit_join(expr)
elif isinstance(node, ops.PhysicalTable):
self.visit_physical_table(expr)
elif isinstance(node, ops.ValueOp):
for arg in node.flat_args():
if not isinstance(arg, ir.Expr):
continue
self.visit(arg)
else:
raise NotImplementedError(type(node))

def visit_join(self, expr):
node = expr.op()
self.visit(node.left)
self.visit(node.right)

def visit_physical_table(self, _):
return

def visit_Exists(self, expr):
node = expr.op()
self.visit(node.foreign_table)
for pred in node.predicates:
self.visit(pred)

visit_NotExistsSubquery = visit_ExistsSubquery = visit_Exists

def visit_Aggregation(self, expr):
self.visit(expr.op().table)
self.observe(expr)

def visit_Distinct(self, expr):
self.observe(expr)

def visit_Limit(self, expr):
self.visit(expr.op().table)
self.observe(expr)

def visit_Union(self, expr):
op = expr.op()
self.visit(op.left)
self.visit(op.right)
self.observe(expr)

def visit_Intersection(self, expr):
op = expr.op()
self.visit(op.left)
self.visit(op.right)
self.observe(expr)

def visit_Difference(self, expr):
op = expr.op()
self.visit(op.left)
self.visit(op.right)
self.observe(expr)

def visit_MaterializedJoin(self, expr):
self.visit(expr.op().join)
self.observe(expr)

def visit_Selection(self, expr):
self.visit(expr.op().table)
self.observe(expr)

def visit_SQLQueryResult(self, expr):
self.observe(expr)

def visit_TableColumn(self, expr):
table = expr.op().table
if not self.seen(table):
self.visit(table)

def visit_SelfReference(self, expr):
self.visit(expr.op().table)
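To make the extraction rule concrete, a hedged example (not part of the changeset): an aggregation joined back to itself is observed twice, so it is extracted once and later referenced by alias (a WITH clause in the string compiler, a CTE in the SQLAlchemy one):

import ibis

t = ibis.table([('g', 'string'), ('v', 'double')], name='t')
agg = t.group_by('g').aggregate(total=t.v.sum())
other = agg.view()                        # SelfReference back to the same aggregation
expr = agg.join(other, agg.g == other.g)
# when a query over expr is compiled, the visitor above counts agg twice,
# so ExtractSubqueries.extract pulls it out exactly once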
649 changes: 649 additions & 0 deletions ibis/backends/base/sql/compiler/query_builder.py

Large diffs are not rendered by default.

854 changes: 854 additions & 0 deletions ibis/backends/base/sql/compiler/select_builder.py

Large diffs are not rendered by default.

391 changes: 391 additions & 0 deletions ibis/backends/base/sql/compiler/translator.py
@@ -0,0 +1,391 @@
import operator
from typing import Callable, Dict

import ibis
import ibis.common.exceptions as com
import ibis.expr.analytics as analytics
import ibis.expr.datatypes as dt
import ibis.expr.format as fmt
import ibis.expr.operations as ops
import ibis.expr.types as ir
from ibis.backends.base.sql.registry import (
operation_registry,
quote_identifier,
)


class QueryContext:

"""Records bits of information used during ibis AST to SQL translation.
Notably, table aliases (for subquery construction) and scalar query
parameters are tracked here.
"""

def __init__(
self, compiler, indent=2, parent=None, memo=None, params=None
):
self.compiler = compiler
self._table_refs = {}
self.extracted_subexprs = set()
self.subquery_memo = {}
self.indent = indent
self.parent = parent

self.always_alias = False

self.query = None

self._table_key_memo = {}
self.memo = memo or fmt.FormatMemo()
self.params = params if params is not None else {}

def _compile_subquery(self, expr):
sub_ctx = self.subcontext()
return self._to_sql(expr, sub_ctx)

def _to_sql(self, expr, ctx):
return self.compiler.to_sql(expr, ctx)

def collapse(self, queries):
"""Turn a sequence of queries into something executable.
Parameters
----------
queries : List[str]
Returns
-------
query : str
"""
return '\n\n'.join(queries)

@property
def top_context(self):
if self.parent is None:
return self
else:
return self.parent.top_context

def set_always_alias(self):
self.always_alias = True

def get_compiled_expr(self, expr):
this = self.top_context

key = self._get_table_key(expr)
if key in this.subquery_memo:
return this.subquery_memo[key]

op = expr.op()
if isinstance(op, ops.SQLQueryResult):
result = op.query
else:
result = self._compile_subquery(expr)

this.subquery_memo[key] = result
return result

def make_alias(self, expr):
i = len(self._table_refs)

key = self._get_table_key(expr)

# Get total number of aliases up and down the tree at this point; if we
# find the table prior-aliased along the way, however, we reuse that
# alias
ctx = self
while ctx.parent is not None:
ctx = ctx.parent

if key in ctx._table_refs:
alias = ctx._table_refs[key]
self.set_ref(expr, alias)
return

i += len(ctx._table_refs)

alias = f't{i:d}'
self.set_ref(expr, alias)

def need_aliases(self, expr=None):
return self.always_alias or len(self._table_refs) > 1

def has_ref(self, expr, parent_contexts=False):
key = self._get_table_key(expr)
return self._key_in(
key, '_table_refs', parent_contexts=parent_contexts
)

def set_ref(self, expr, alias):
key = self._get_table_key(expr)
self._table_refs[key] = alias

def get_ref(self, expr):
"""
Get the alias being used throughout a query to refer to a particular
table or inline view
"""
return self._get_table_item('_table_refs', expr)

def is_extracted(self, expr):
key = self._get_table_key(expr)
return key in self.top_context.extracted_subexprs

def set_extracted(self, expr):
key = self._get_table_key(expr)
self.extracted_subexprs.add(key)
self.make_alias(expr)

def subcontext(self):
return type(self)(
compiler=self.compiler,
indent=self.indent,
parent=self,
params=self.params,
)

# Maybe temporary hacks for correlated / uncorrelated subqueries

def set_query(self, query):
self.query = query

def is_foreign_expr(self, expr):
from ibis.expr.analysis import ExprValidator

# The expression isn't foreign to us. For example, the parent table set
# in a correlated WHERE subquery
if self.has_ref(expr, parent_contexts=True):
return False

exprs = [self.query.table_set] + self.query.select_set
validator = ExprValidator(exprs)
return not validator.validate(expr)

def _get_table_item(self, item, expr):
key = self._get_table_key(expr)
top = self.top_context

if self.is_extracted(expr):
return getattr(top, item).get(key)

return getattr(self, item).get(key)

def _get_table_key(self, table):
if isinstance(table, ir.TableExpr):
table = table.op()

try:
return self._table_key_memo[table]
except KeyError:
val = table._repr()
self._table_key_memo[table] = val
return val

def _key_in(self, key, memo_attr, parent_contexts=False):
if key in getattr(self, memo_attr):
return True

ctx = self
while parent_contexts and ctx.parent is not None:
ctx = ctx.parent
if key in getattr(ctx, memo_attr):
return True

return False


class ExprTranslator:

"""Class that performs translation of ibis expressions into executable
SQL.
"""

_registry = operation_registry
_rewrites: Dict[ops.Node, Callable] = {}

def __init__(self, expr, context, named=False, permit_subquery=False):
self.expr = expr
self.permit_subquery = permit_subquery

assert context is not None, 'context is None in {}'.format(
type(self).__name__
)
self.context = context

# For now, governing whether the result will have a name
self.named = named

def get_result(self):
"""
Build compiled SQL expression from the bottom up and return as a string
"""
translated = self.translate(self.expr)
if self._needs_name(self.expr):
# TODO: this could fail in various ways
name = self.expr.get_name()
translated = self.name(translated, name)
return translated

@classmethod
def add_operation(cls, operation, translate_function):
"""
Add an operation to the operation registry. In general, operations
should be defined directly in the registry, in `registry.py`, but
there are a couple of exceptions for which this method is needed:
operations defined by Ibis users (not Ibis or backend developers),
and UDFs, which are added dynamically.
"""
cls._registry[operation] = translate_function

def _needs_name(self, expr):
if not self.named:
return False

op = expr.op()
if isinstance(op, ops.TableColumn):
# This column has been given an explicitly different name
if expr.get_name() != op.name:
return True
return False

if expr.get_name() is ir.unnamed:
return False

return True

def name(self, translated, name, force=True):
return '{} AS {}'.format(
translated, quote_identifier(name, force=force)
)

def translate(self, expr):
# The operation node type the typed expression wraps
op = expr.op()

if type(op) in self._rewrites: # even if type(op) is in self._registry
expr = self._rewrites[type(op)](expr)
op = expr.op()

# TODO: use op MRO for subclasses instead of this isinstance spaghetti
if isinstance(op, ops.ScalarParameter):
return self._trans_param(expr)
elif isinstance(op, ops.TableNode):
# HACK/TODO: revisit for more complex cases
return '*'
elif type(op) in self._registry:
formatter = self._registry[type(op)]
return formatter(self, expr)
else:
raise com.OperationNotDefinedError(
f'No translation rule for {type(op)}'
)

def _trans_param(self, expr):
raw_value = self.context.params[expr.op()]
literal = ibis.literal(raw_value, type=expr.type())
return self.translate(literal)

@classmethod
def rewrites(cls, klass):
def decorator(f):
cls._rewrites[klass] = f
return f

return decorator


rewrites = ExprTranslator.rewrites


@rewrites(analytics.Bucket)
def _bucket(expr):
op = expr.op()
stmt = ibis.case()

if op.closed == 'left':
l_cmp = operator.le
r_cmp = operator.lt
else:
l_cmp = operator.lt
r_cmp = operator.le

user_num_buckets = len(op.buckets) - 1

bucket_id = 0
if op.include_under:
if user_num_buckets > 0:
cmp = operator.lt if op.close_extreme else r_cmp
else:
cmp = operator.le if op.closed == 'right' else operator.lt
stmt = stmt.when(cmp(op.arg, op.buckets[0]), bucket_id)
bucket_id += 1

for j, (lower, upper) in enumerate(zip(op.buckets, op.buckets[1:])):
if op.close_extreme and (
(op.closed == 'right' and j == 0)
or (op.closed == 'left' and j == (user_num_buckets - 1))
):
stmt = stmt.when((lower <= op.arg) & (op.arg <= upper), bucket_id)
else:
stmt = stmt.when(
l_cmp(lower, op.arg) & r_cmp(op.arg, upper), bucket_id
)
bucket_id += 1

if op.include_over:
if user_num_buckets > 0:
cmp = operator.lt if op.close_extreme else l_cmp
else:
cmp = operator.lt if op.closed == 'right' else operator.le

stmt = stmt.when(cmp(op.buckets[-1], op.arg), bucket_id)
bucket_id += 1

return stmt.end().name(expr._name)


@rewrites(analytics.CategoryLabel)
def _category_label(expr):
op = expr.op()

stmt = op.args[0].case()
for i, label in enumerate(op.labels):
stmt = stmt.when(i, label)

if op.nulls is not None:
stmt = stmt.else_(op.nulls)

return stmt.end().name(expr._name)


@rewrites(ops.Any)
def _any_expand(expr):
arg = expr.op().args[0]
return arg.max()


@rewrites(ops.NotAny)
def _notany_expand(expr):
arg = expr.op().args[0]
return arg.max() == 0


@rewrites(ops.All)
def _all_expand(expr):
arg = expr.op().args[0]
return arg.min()


@rewrites(ops.NotAll)
def _notall_expand(expr):
arg = expr.op().args[0]
return arg.min() == 0


@rewrites(ops.Cast)
def _rewrite_cast(expr):
arg, to = expr.op().args
if isinstance(to, dt.Interval) and isinstance(arg.type(), dt.Integer):
return arg.to_interval(unit=to.unit)
return expr
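A short sketch of the add_operation hook documented above (illustrative, not part of the changeset): it attaches a formatter to the shared registry; a user-defined operation would be registered the same way:

import ibis.expr.operations as ops
from ibis.backends.base.sql.compiler import ExprTranslator

def _reverse_formatter(translator, expr):
    (arg,) = expr.op().args
    return f'reverse({translator.translate(arg)})'

ExprTranslator.add_operation(ops.Reverse, _reverse_formatter)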
103 changes: 64 additions & 39 deletions ibis/backends/base_sql/ddl.py → ibis/backends/base/sql/ddl.py
@@ -2,8 +2,11 @@

import ibis.expr.datatypes as dt
import ibis.expr.schema as sch
from ibis.backends.base_sql import quote_identifier, type_to_sql_string
from ibis.backends.base_sqlalchemy.compiler import DDL, DML
from ibis.backends.base.sql.compiler import DDL, DML
from ibis.backends.base.sql.registry import (
quote_identifier,
type_to_sql_string,
)

fully_qualified_re = re.compile(r"(.*)\.(?:`(.*)`|(.*))")
_format_aliases = {'TEXT': 'TEXTFILE'}
@@ -15,7 +18,7 @@ def _sanitize_format(format):
format = format.upper()
format = _format_aliases.get(format, format)
if format not in ('PARQUET', 'AVRO', 'TEXTFILE'):
raise ValueError('Invalid format: {!r}'.format(format))
raise ValueError(f'Invalid format: {format!r}')

return format

@@ -40,17 +43,18 @@ def format_schema(schema):

def _format_schema_element(name, t):
return '{} {}'.format(
quote_identifier(name, force=True), type_to_sql_string(t),
quote_identifier(name, force=True),
type_to_sql_string(t),
)


def _format_partition_kv(k, v, type):
if type == dt.string:
value_formatted = '"{}"'.format(v)
value_formatted = f'"{v}"'
else:
value_formatted = str(v)

return '{}={}'.format(k, value_formatted)
return f'{k}={value_formatted}'


def format_partition(partition, partition_schema):
@@ -73,53 +77,53 @@ def format_partition(partition, partition_schema):
return 'PARTITION ({})'.format(', '.join(tokens))


def format_properties(props):
def _format_properties(props):
tokens = []
for k, v in sorted(props.items()):
tokens.append(" '{}'='{}'".format(k, v))
tokens.append(f" '{k}'='{v}'")

return '(\n{}\n)'.format(',\n'.join(tokens))


def format_tblproperties(props):
formatted_props = format_properties(props)
return 'TBLPROPERTIES {}'.format(formatted_props)
formatted_props = _format_properties(props)
return f'TBLPROPERTIES {formatted_props}'


def _serdeproperties(props):
formatted_props = format_properties(props)
return 'SERDEPROPERTIES {}'.format(formatted_props)
formatted_props = _format_properties(props)
return f'SERDEPROPERTIES {formatted_props}'


class BaseQualifiedSQLStatement:
class _BaseQualifiedSQLStatement:
def _get_scoped_name(self, obj_name, database):
if database:
scoped_name = '{}.`{}`'.format(database, obj_name)
scoped_name = f'{database}.`{obj_name}`'
else:
if not is_fully_qualified(obj_name):
if _is_quoted(obj_name):
return obj_name
else:
return '`{}`'.format(obj_name)
return f'`{obj_name}`'
else:
return obj_name
return scoped_name


class BaseDDL(DDL, BaseQualifiedSQLStatement):
class BaseDDL(DDL, _BaseQualifiedSQLStatement):
pass


class BaseDML(DML, BaseQualifiedSQLStatement):
class _BaseDML(DML, _BaseQualifiedSQLStatement):
pass


class CreateDDL(BaseDDL):
class _CreateDDL(BaseDDL):
def _if_exists(self):
return 'IF NOT EXISTS ' if self.can_exist else ''


class CreateTable(CreateDDL):
class CreateTable(_CreateDDL):

"""
@@ -158,20 +162,19 @@ def _prefix(self):

def _create_line(self):
scoped_name = self._get_scoped_name(self.table_name, self.database)
return '{} {}{}'.format(self._prefix, self._if_exists(), scoped_name)
return f'{self._prefix} {self._if_exists()}{scoped_name}'

def _location(self):
return "LOCATION '{}'".format(self.path) if self.path else None
return f"LOCATION '{self.path}'" if self.path else None

def _storage(self):
# By the time we're here, we have a valid format
return 'STORED AS {}'.format(self.format)
return f'STORED AS {self.format}'

@property
def pieces(self):
yield self._create_line()
for piece in filter(None, self._pieces):
yield piece
yield from filter(None, self._pieces)

def compile(self):
return '\n'.join(self.pieces)
@@ -267,7 +270,7 @@ def _pieces(self):
main_schema = main_schema.delete(to_delete)

yield format_schema(main_schema)
yield 'PARTITIONED BY {}'.format(format_schema(part_schema))
yield f'PARTITIONED BY {format_schema(part_schema)}'
else:
yield format_schema(self.schema)

@@ -279,7 +282,7 @@ def _pieces(self):
yield self._location()


class CreateDatabase(CreateDDL):
class CreateDatabase(_CreateDDL):
def __init__(self, name, path=None, can_exist=False):
self.name = name
self.path = path
@@ -289,9 +292,9 @@ def compile(self):
name = quote_identifier(self.name)

create_decl = 'CREATE DATABASE'
create_line = '{} {}{}'.format(create_decl, self._if_exists(), name)
create_line = f'{create_decl} {self._if_exists()}{name}'
if self.path is not None:
create_line += "\nLOCATION '{}'".format(self.path)
create_line += f"\nLOCATION '{self.path}'"

return create_line

@@ -303,7 +306,7 @@ def compile(self):
def compile(self):
if_exists = '' if self.must_exist else 'IF EXISTS '
object_name = self._object_name()
return 'DROP {} {}{}'.format(self._object_type, if_exists, object_name)
return f'DROP {self._object_type} {if_exists}{object_name}'


class DropDatabase(DropObject):
@@ -346,10 +349,10 @@ def __init__(self, table_name, database=None):

def compile(self):
name = self._get_scoped_name(self.table_name, self.database)
return 'TRUNCATE TABLE {}'.format(name)
return f'TRUNCATE TABLE {name}'


class InsertSelect(BaseDML):
class InsertSelect(_BaseDML):
def __init__(
self,
table_name,
@@ -376,15 +379,13 @@ def compile(self):

if self.partition is not None:
part = format_partition(self.partition, self.partition_schema)
partition = ' {} '.format(part)
partition = f' {part} '
else:
partition = ''

select_query = self.select.compile()
scoped_name = self._get_scoped_name(self.table_name, self.database)
return '{0} {1}{2}\n{3}'.format(
cmd, scoped_name, partition, select_query
)
return f'{cmd} {scoped_name}{partition}\n{select_query}'


class AlterTable(BaseDDL):
@@ -403,16 +404,16 @@ def __init__(
self.serde_properties = serde_properties

def _wrap_command(self, cmd):
return 'ALTER TABLE {}'.format(cmd)
return f'ALTER TABLE {cmd}'

def _format_properties(self, prefix=''):
tokens = []

if self.location is not None:
tokens.append("LOCATION '{}'".format(self.location))
tokens.append(f"LOCATION '{self.location}'")

if self.format is not None:
tokens.append("FILEFORMAT {}".format(self.format))
tokens.append(f"FILEFORMAT {self.format}")

if self.tbl_properties is not None:
tokens.append(format_tblproperties(self.tbl_properties))
@@ -427,7 +428,7 @@ def _format_properties(self, prefix=''):

def compile(self):
props = self._format_properties()
action = '{} SET {}'.format(self.table, props)
action = f'{self.table} SET {props}'
return self._wrap_command(action)


@@ -483,3 +484,27 @@ def compile(self):
self.old_qualified_name, self.new_qualified_name
)
return self._wrap_command(cmd)


__all__ = (
'fully_qualified_re',
'is_fully_qualified',
'format_schema',
'format_partition',
'format_tblproperties',
'BaseDDL',
'CreateTable',
'CTAS',
'CreateView',
'CreateTableWithSchema',
'CreateDatabase',
'DropObject',
'DropDatabase',
'DropTable',
'DropView',
'TruncateTable',
'InsertSelect',
'AlterTable',
'DropFunction',
'RenameTable',
)
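A quick, hedged illustration of the DDL helpers in use (the database and table names are made up; output follows the compile methods above):

from ibis.backends.base.sql.ddl import CreateDatabase, TruncateTable

print(CreateDatabase('analytics', can_exist=True).compile())
# CREATE DATABASE IF NOT EXISTS analytics
print(TruncateTable('events', database='analytics').compile())
# TRUNCATE TABLE analytics.`events`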
25 changes: 25 additions & 0 deletions ibis/backends/base/sql/registry/__init__.py
@@ -0,0 +1,25 @@
from .aggregate import reduction
from .helpers import quote_identifier, sql_type_names, type_to_sql_string
from .literal import literal, literal_formatters
from .main import binary_infix_ops, fixed_arity, operation_registry, unary
from .window import (
cumulative_to_window,
format_window,
time_range_to_range_window,
)

__all__ = (
'quote_identifier',
'operation_registry',
'binary_infix_ops',
'fixed_arity',
'literal',
'literal_formatters',
'sql_type_names',
'type_to_sql_string',
'reduction',
'unary',
'cumulative_to_window',
'format_window',
'time_range_to_range_window',
)
45 changes: 45 additions & 0 deletions ibis/backends/base/sql/registry/aggregate.py
@@ -0,0 +1,45 @@
import itertools

import ibis


def _reduction_format(translator, func_name, where, arg, *args):
if where is not None:
arg = where.ifelse(arg, ibis.NA)

return '{}({})'.format(
func_name,
', '.join(map(translator.translate, itertools.chain([arg], args))),
)


def reduction(func_name):
def formatter(translator, expr):
op = expr.op()
*args, where = op.args
return _reduction_format(translator, func_name, where, *args)

return formatter


def variance_like(func_name):
func_names = {
'sample': f'{func_name}_samp',
'pop': f'{func_name}_pop',
}

def formatter(translator, expr):
arg, how, where = expr.op().args
return _reduction_format(translator, func_names[how], where, arg)

return formatter


def count_distinct(translator, expr):
arg, where = expr.op().args

if where is not None:
arg_formatted = translator.translate(where.ifelse(arg, None))
else:
arg_formatted = translator.translate(arg)
return f'count(DISTINCT {arg_formatted})'
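For reference, a hedged example of what reduction renders when a filter is supplied: the predicate is folded into the aggregate's argument via where.ifelse(arg, ibis.NA), which the base registry then formats with if():

import ibis

t = ibis.table([('a', 'double'), ('b', 'int64')], name='t')  # hypothetical table
expr = t.a.sum(where=t.b > 0)
# renders roughly as: sum(if(`b` > 0, `a`, NULL))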
52 changes: 52 additions & 0 deletions ibis/backends/base/sql/registry/binary_infix.py
@@ -0,0 +1,52 @@
from . import helpers


def binary_infix_op(infix_sym):
def formatter(translator, expr):
op = expr.op()

left, right = op.args

left_arg = translator.translate(left)
right_arg = translator.translate(right)
if helpers.needs_parens(left):
left_arg = helpers.parenthesize(left_arg)

if helpers.needs_parens(right):
right_arg = helpers.parenthesize(right_arg)

return f'{left_arg} {infix_sym} {right_arg}'

return formatter


def identical_to(translator, expr):
op = expr.op()
if op.args[0].equals(op.args[1]):
return 'TRUE'

left_expr = op.left
right_expr = op.right
left = translator.translate(left_expr)
right = translator.translate(right_expr)

if helpers.needs_parens(left_expr):
left = helpers.parenthesize(left)
if helpers.needs_parens(right_expr):
right = helpers.parenthesize(right)
return f'{left} IS NOT DISTINCT FROM {right}'


def xor(translator, expr):
op = expr.op()

left_arg = translator.translate(op.left)
right_arg = translator.translate(op.right)

if helpers.needs_parens(op.left):
left_arg = helpers.parenthesize(left_arg)

if helpers.needs_parens(op.right):
right_arg = helpers.parenthesize(right_arg)

return '({0} OR {1}) AND NOT ({0} AND {1})'.format(left_arg, right_arg)
66 changes: 66 additions & 0 deletions ibis/backends/base/sql/registry/case.py
@@ -0,0 +1,66 @@
from io import StringIO


class _CaseFormatter:
def __init__(self, translator, base, cases, results, default):
self.translator = translator
self.base = base
self.cases = cases
self.results = results
self.default = default

# HACK
self.indent = 2
self.multiline = len(cases) > 1
self.buf = StringIO()

def _trans(self, expr):
return self.translator.translate(expr)

def get_result(self):
self.buf.seek(0)

self.buf.write('CASE')
if self.base is not None:
base_str = self._trans(self.base)
self.buf.write(f' {base_str}')

for case, result in zip(self.cases, self.results):
self._next_case()
case_str = self._trans(case)
result_str = self._trans(result)
self.buf.write(f'WHEN {case_str} THEN {result_str}')

if self.default is not None:
self._next_case()
default_str = self._trans(self.default)
self.buf.write(f'ELSE {default_str}')

if self.multiline:
self.buf.write('\nEND')
else:
self.buf.write(' END')

return self.buf.getvalue()

def _next_case(self):
if self.multiline:
self.buf.write('\n{}'.format(' ' * self.indent))
else:
self.buf.write(' ')


def simple_case(translator, expr):
op = expr.op()
formatter = _CaseFormatter(
translator, op.base, op.cases, op.results, op.default
)
return formatter.get_result()


def searched_case(translator, expr):
op = expr.op()
formatter = _CaseFormatter(
translator, None, op.cases, op.results, op.default
)
return formatter.get_result()
80 changes: 80 additions & 0 deletions ibis/backends/base/sql/registry/helpers.py
@@ -0,0 +1,80 @@
import ibis.common.exceptions as com
import ibis.expr.datatypes as dt
import ibis.expr.operations as ops
import ibis.expr.types as ir

from . import identifiers


def format_call(translator, func, *args):
formatted_args = []
for arg in args:
fmt_arg = translator.translate(arg)
formatted_args.append(fmt_arg)

return '{}({})'.format(func, ', '.join(formatted_args))


def quote_identifier(name, quotechar='`', force=False):
"""Add quotes to the `name` identifier if needed."""
if force or name.count(' ') or name in identifiers.base_identifiers:
return '{0}{1}{0}'.format(quotechar, name)
else:
return name


def needs_parens(op):
if isinstance(op, ir.Expr):
op = op.op()
op_klass = type(op)
# function calls don't need parens
return op_klass in {
ops.Negate,
ops.IsNull,
ops.NotNull,
ops.Add,
ops.Subtract,
ops.Multiply,
ops.Divide,
ops.Power,
ops.Modulus,
ops.Equals,
ops.NotEquals,
ops.GreaterEqual,
ops.Greater,
ops.LessEqual,
ops.Less,
ops.IdenticalTo,
ops.And,
ops.Or,
ops.Xor,
}


parenthesize = '({})'.format


sql_type_names = {
'int8': 'tinyint',
'int16': 'smallint',
'int32': 'int',
'int64': 'bigint',
'float': 'float',
'float32': 'float',
'double': 'double',
'float64': 'double',
'string': 'string',
'boolean': 'boolean',
'timestamp': 'timestamp',
'decimal': 'decimal',
}


def type_to_sql_string(tval):
if isinstance(tval, dt.Decimal):
return f'decimal({tval.precision}, {tval.scale})'
name = tval.name.lower()
try:
return sql_type_names[name]
except KeyError:
raise com.UnsupportedBackendType(name)
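Behavior of quote_identifier for a few representative names (illustrative; the reserved-word check relies on identifiers.base_identifiers):

from ibis.backends.base.sql.registry import quote_identifier

quote_identifier('amount')              # 'amount'         -> left as-is
quote_identifier('order total')         # '`order total`'  -> contains a space
quote_identifier('select')              # '`select`'       -> assumed reserved identifier
quote_identifier('amount', force=True)  # '`amount`'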
File renamed without changes.
102 changes: 102 additions & 0 deletions ibis/backends/base/sql/registry/literal.py
@@ -0,0 +1,102 @@
import datetime
import math

import ibis.expr.types as ir


def _set_literal_format(translator, expr):
value_type = expr.type().value_type

formatted = [
translator.translate(ir.literal(x, type=value_type))
for x in expr.op().value
]

return '(' + ', '.join(formatted) + ')'


def _boolean_literal_format(translator, expr):
value = expr.op().value
return 'TRUE' if value else 'FALSE'


def _string_literal_format(translator, expr):
value = expr.op().value
return "'{}'".format(value.replace("'", "\\'"))


def _number_literal_format(translator, expr):
value = expr.op().value

if math.isfinite(value):
formatted = repr(value)
else:
if math.isnan(value):
formatted_val = 'NaN'
elif math.isinf(value):
if value > 0:
formatted_val = 'Infinity'
else:
formatted_val = '-Infinity'
formatted = f"CAST({formatted_val!r} AS DOUBLE)"

return formatted


def _interval_literal_format(translator, expr):
return 'INTERVAL {} {}'.format(
expr.op().value, expr.type().resolution.upper()
)


def _date_literal_format(translator, expr):
value = expr.op().value
if isinstance(value, datetime.date):
value = value.strftime('%Y-%m-%d')

return repr(value)


def _timestamp_literal_format(translator, expr):
value = expr.op().value
if isinstance(value, datetime.datetime):
value = value.strftime('%Y-%m-%d %H:%M:%S')

return repr(value)


literal_formatters = {
'boolean': _boolean_literal_format,
'number': _number_literal_format,
'string': _string_literal_format,
'interval': _interval_literal_format,
'timestamp': _timestamp_literal_format,
'date': _date_literal_format,
'set': _set_literal_format,
}


def literal(translator, expr):
"""Return the expression as its literal value."""
if isinstance(expr, ir.BooleanValue):
typeclass = 'boolean'
elif isinstance(expr, ir.StringValue):
typeclass = 'string'
elif isinstance(expr, ir.NumericValue):
typeclass = 'number'
elif isinstance(expr, ir.DateValue):
typeclass = 'date'
elif isinstance(expr, ir.TimestampValue):
typeclass = 'timestamp'
elif isinstance(expr, ir.IntervalValue):
typeclass = 'interval'
elif isinstance(expr, ir.SetValue):
typeclass = 'set'
else:
raise NotImplementedError

return literal_formatters[typeclass](translator, expr)


def null_literal(translator, expr):
return 'NULL'
362 changes: 362 additions & 0 deletions ibis/backends/base/sql/registry/main.py
@@ -0,0 +1,362 @@
import ibis.common.exceptions as com
import ibis.expr.datatypes as dt
import ibis.expr.operations as ops
import ibis.expr.types as ir
import ibis.util as util

from . import aggregate, binary_infix, case, helpers, string, timestamp, window
from .literal import literal, null_literal


def fixed_arity(func_name, arity):
def formatter(translator, expr):
op = expr.op()
if arity != len(op.args):
raise com.IbisError('incorrect number of args')
return helpers.format_call(translator, func_name, *op.args)

return formatter


def unary(func_name):
return fixed_arity(func_name, 1)


def not_null(translator, expr):
formatted_arg = translator.translate(expr.op().args[0])
return f'{formatted_arg} IS NOT NULL'


def is_null(translator, expr):
formatted_arg = translator.translate(expr.op().args[0])
return f'{formatted_arg} IS NULL'


def not_(translator, expr):
(arg,) = expr.op().args
formatted_arg = translator.translate(arg)
if helpers.needs_parens(arg):
formatted_arg = helpers.parenthesize(formatted_arg)
return f'NOT {formatted_arg}'


def negate(translator, expr):
arg = expr.op().args[0]
formatted_arg = translator.translate(arg)
if isinstance(expr, ir.BooleanValue):
return not_(translator, expr)
else:
if helpers.needs_parens(arg):
formatted_arg = helpers.parenthesize(formatted_arg)
return f'-{formatted_arg}'


def ifnull_workaround(translator, expr):
op = expr.op()
a, b = op.args

# work around per #345, #360
if isinstance(a, ir.DecimalValue) and isinstance(b, ir.IntegerValue):
b = b.cast(a.type())

return helpers.format_call(translator, 'isnull', a, b)


def sign(translator, expr):
(arg,) = expr.op().args
translated_arg = translator.translate(arg)
translated_type = helpers.type_to_sql_string(expr.type())
if expr.type() != dt.float:
return f'CAST(sign({translated_arg}) AS {translated_type})'
return f'sign({translated_arg})'


def hashbytes(translator, expr):
op = expr.op()
arg, how = op.args

arg_formatted = translator.translate(arg)

if how == 'md5':
return f'md5({arg_formatted})'
elif how == 'sha1':
return f'sha1({arg_formatted})'
elif how == 'sha256':
return f'sha256({arg_formatted})'
elif how == 'sha512':
return f'sha512({arg_formatted})'
else:
raise NotImplementedError(how)


def log(translator, expr):
op = expr.op()
arg, base = op.args
arg_formatted = translator.translate(arg)

if base is None:
return f'ln({arg_formatted})'

base_formatted = translator.translate(base)
return f'log({base_formatted}, {arg_formatted})'


def value_list(translator, expr):
op = expr.op()
formatted = [translator.translate(x) for x in op.values]
return helpers.parenthesize(', '.join(formatted))


def cast(translator, expr):
op = expr.op()
arg, target_type = op.args
arg_formatted = translator.translate(arg)

if isinstance(arg, ir.CategoryValue) and target_type == dt.int32:
return arg_formatted
if isinstance(arg, ir.TemporalValue) and target_type == dt.int64:
return f'1000000 * unix_timestamp({arg_formatted})'
else:
sql_type = helpers.type_to_sql_string(target_type)
return f'CAST({arg_formatted} AS {sql_type})'


def varargs(func_name):
def varargs_formatter(translator, expr):
op = expr.op()
return helpers.format_call(translator, func_name, *op.arg)

return varargs_formatter


def between(translator, expr):
op = expr.op()
comp, lower, upper = (translator.translate(x) for x in op.args)
return f'{comp} BETWEEN {lower} AND {upper}'


def table_array_view(translator, expr):
ctx = translator.context
table = expr.op().table
query = ctx.get_compiled_expr(table)
return f'(\n{util.indent(query, ctx.indent)}\n)'


def table_column(translator, expr):
op = expr.op()
field_name = op.name
quoted_name = helpers.quote_identifier(field_name, force=True)

table = op.table
ctx = translator.context

# If the column does not originate from the table set in the current SELECT
# context, we should format as a subquery
if translator.permit_subquery and ctx.is_foreign_expr(table):
proj_expr = table.projection([field_name]).to_array()
return table_array_view(translator, proj_expr)

if ctx.need_aliases():
alias = ctx.get_ref(table)
if alias is not None:
quoted_name = f'{alias}.{quoted_name}'

return quoted_name


def exists_subquery(translator, expr):
op = expr.op()
ctx = translator.context

dummy = ir.literal(1).name(ir.unnamed)

filtered = op.foreign_table.filter(op.predicates)
expr = filtered.projection([dummy])

subquery = ctx.get_compiled_expr(expr)

if isinstance(op, ops.ExistsSubquery):
key = 'EXISTS'
elif isinstance(op, ops.NotExistsSubquery):
key = 'NOT EXISTS'
else:
raise NotImplementedError

return f'{key} (\n{util.indent(subquery, ctx.indent)}\n)'


# XXX this is not added to operation_registry, but it looks like the Impala
# backend uses it in the tests, and it works even though it's not imported
# anywhere
def round(translator, expr):
op = expr.op()
arg, digits = op.args

arg_formatted = translator.translate(arg)

if digits is not None:
digits_formatted = translator.translate(digits)
return f'round({arg_formatted}, {digits_formatted})'
return f'round({arg_formatted})'


# XXX this is not added to operation_registry, but it looks like the Impala
# backend uses it in the tests, and it works even though it's not imported
# anywhere
def hash(translator, expr):
op = expr.op()
arg, how = op.args

arg_formatted = translator.translate(arg)

if how == 'fnv':
return f'fnv_hash({arg_formatted})'
else:
raise NotImplementedError(how)


binary_infix_ops = {
# Binary operations
ops.Add: binary_infix.binary_infix_op('+'),
ops.Subtract: binary_infix.binary_infix_op('-'),
ops.Multiply: binary_infix.binary_infix_op('*'),
ops.Divide: binary_infix.binary_infix_op('/'),
ops.Power: fixed_arity('pow', 2),
ops.Modulus: binary_infix.binary_infix_op('%'),
# Comparisons
ops.Equals: binary_infix.binary_infix_op('='),
ops.NotEquals: binary_infix.binary_infix_op('!='),
ops.GreaterEqual: binary_infix.binary_infix_op('>='),
ops.Greater: binary_infix.binary_infix_op('>'),
ops.LessEqual: binary_infix.binary_infix_op('<='),
ops.Less: binary_infix.binary_infix_op('<'),
ops.IdenticalTo: binary_infix.identical_to,
# Boolean comparisons
ops.And: binary_infix.binary_infix_op('AND'),
ops.Or: binary_infix.binary_infix_op('OR'),
ops.Xor: binary_infix.xor,
}


operation_registry = {
# Unary operations
ops.NotNull: not_null,
ops.IsNull: is_null,
ops.Negate: negate,
ops.Not: not_,
ops.IsNan: unary('is_nan'),
ops.IsInf: unary('is_inf'),
ops.IfNull: ifnull_workaround,
ops.NullIf: fixed_arity('nullif', 2),
ops.ZeroIfNull: unary('zeroifnull'),
ops.NullIfZero: unary('nullifzero'),
ops.Abs: unary('abs'),
ops.BaseConvert: fixed_arity('conv', 3),
ops.Ceil: unary('ceil'),
ops.Floor: unary('floor'),
ops.Exp: unary('exp'),
ops.Round: round,
ops.Sign: sign,
ops.Sqrt: unary('sqrt'),
ops.Hash: hash,
ops.HashBytes: hashbytes,
ops.Log: log,
ops.Ln: unary('ln'),
ops.Log2: unary('log2'),
ops.Log10: unary('log10'),
ops.DecimalPrecision: unary('precision'),
ops.DecimalScale: unary('scale'),
# Unary aggregates
ops.CMSMedian: aggregate.reduction('appx_median'),
ops.HLLCardinality: aggregate.reduction('ndv'),
ops.Mean: aggregate.reduction('avg'),
ops.Sum: aggregate.reduction('sum'),
ops.Max: aggregate.reduction('max'),
ops.Min: aggregate.reduction('min'),
ops.StandardDev: aggregate.variance_like('stddev'),
ops.Variance: aggregate.variance_like('var'),
ops.GroupConcat: aggregate.reduction('group_concat'),
ops.Count: aggregate.reduction('count'),
ops.CountDistinct: aggregate.count_distinct,
# string operations
ops.StringLength: unary('length'),
ops.StringAscii: unary('ascii'),
ops.Lowercase: unary('lower'),
ops.Uppercase: unary('upper'),
ops.Reverse: unary('reverse'),
ops.Strip: unary('trim'),
ops.LStrip: unary('ltrim'),
ops.RStrip: unary('rtrim'),
ops.Capitalize: unary('initcap'),
ops.Substring: string.substring,
ops.StrRight: fixed_arity('strright', 2),
ops.Repeat: fixed_arity('repeat', 2),
ops.StringFind: string.string_find,
ops.Translate: fixed_arity('translate', 3),
ops.FindInSet: string.find_in_set,
ops.LPad: fixed_arity('lpad', 3),
ops.RPad: fixed_arity('rpad', 3),
ops.StringJoin: string.string_join,
ops.StringSQLLike: string.string_like,
ops.RegexSearch: fixed_arity('regexp_like', 2),
ops.RegexExtract: fixed_arity('regexp_extract', 3),
ops.RegexReplace: fixed_arity('regexp_replace', 3),
ops.ParseURL: string.parse_url,
ops.StartsWith: string.startswith,
ops.EndsWith: string.endswith,
# Timestamp operations
ops.Date: unary('to_date'),
ops.TimestampNow: lambda *args: 'now()',
ops.ExtractYear: timestamp.extract_field('year'),
ops.ExtractMonth: timestamp.extract_field('month'),
ops.ExtractDay: timestamp.extract_field('day'),
ops.ExtractQuarter: timestamp.extract_field('quarter'),
ops.ExtractEpochSeconds: timestamp.extract_epoch_seconds,
ops.ExtractWeekOfYear: fixed_arity('weekofyear', 1),
ops.ExtractHour: timestamp.extract_field('hour'),
ops.ExtractMinute: timestamp.extract_field('minute'),
ops.ExtractSecond: timestamp.extract_field('second'),
ops.ExtractMillisecond: timestamp.extract_field('millisecond'),
ops.TimestampTruncate: timestamp.truncate,
ops.DateTruncate: timestamp.truncate,
ops.IntervalFromInteger: timestamp.interval_from_integer,
# Other operations
ops.E: lambda *args: 'e()',
ops.Literal: literal,
ops.NullLiteral: null_literal,
ops.ValueList: value_list,
ops.Cast: cast,
ops.Coalesce: varargs('coalesce'),
ops.Greatest: varargs('greatest'),
ops.Least: varargs('least'),
ops.Where: fixed_arity('if', 3),
ops.Between: between,
ops.Contains: binary_infix.binary_infix_op('IN'),
ops.NotContains: binary_infix.binary_infix_op('NOT IN'),
ops.SimpleCase: case.simple_case,
ops.SearchedCase: case.searched_case,
ops.TableColumn: table_column,
ops.TableArrayView: table_array_view,
ops.DateAdd: timestamp.timestamp_op('date_add'),
ops.DateSub: timestamp.timestamp_op('date_sub'),
ops.DateDiff: timestamp.timestamp_op('datediff'),
ops.TimestampAdd: timestamp.timestamp_op('date_add'),
ops.TimestampSub: timestamp.timestamp_op('date_sub'),
ops.TimestampDiff: timestamp.timestamp_diff,
ops.TimestampFromUNIX: timestamp.timestamp_from_unix,
ops.ExistsSubquery: exists_subquery,
ops.NotExistsSubquery: exists_subquery,
# RowNumber and the rank functions start at 0 in Ibis-land
ops.RowNumber: lambda *args: 'row_number()',
ops.DenseRank: lambda *args: 'dense_rank()',
ops.MinRank: lambda *args: 'rank()',
ops.PercentRank: lambda *args: 'percent_rank()',
ops.FirstValue: unary('first_value'),
ops.LastValue: unary('last_value'),
ops.NthValue: window.nth_value,
ops.Lag: window.shift_like('lag'),
ops.Lead: window.shift_like('lead'),
ops.WindowOp: window.window,
ops.NTile: window.ntile,
ops.DayOfWeekIndex: timestamp.day_of_week_index,
ops.DayOfWeekName: timestamp.day_of_week_name,
**binary_infix_ops,
}
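A sketch of the extension pattern used throughout this registry (not part of the changeset): fixed_arity returns a formatter that renders a plain function call with a fixed argument count, and unary is just fixed_arity(name, 1):

import ibis.expr.operations as ops
from ibis.backends.base.sql.registry import fixed_arity, unary

extra_ops = {
    ops.Power: fixed_arity('pow', 2),  # renders pow(<left>, <right>)
    ops.Ln: unary('ln'),               # renders ln(<arg>)
}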
100 changes: 100 additions & 0 deletions ibis/backends/base/sql/registry/string.py
@@ -0,0 +1,100 @@
import ibis.expr.operations as ops

from . import helpers


def substring(translator, expr):
op = expr.op()
arg, start, length = op.args
arg_formatted = translator.translate(arg)
start_formatted = translator.translate(start)

# Impala is 1-indexed
if length is None or isinstance(length.op(), ops.Literal):
lvalue = length.op().value if length is not None else None
if lvalue:
return 'substr({}, {} + 1, {})'.format(
arg_formatted, start_formatted, lvalue
)
else:
return f'substr({arg_formatted}, {start_formatted} + 1)'
else:
length_formatted = translator.translate(length)
return 'substr({}, {} + 1, {})'.format(
arg_formatted, start_formatted, length_formatted
)


def string_find(translator, expr):
op = expr.op()
arg, substr, start, _ = op.args
arg_formatted = translator.translate(arg)
substr_formatted = translator.translate(substr)

if start is not None and not isinstance(start.op(), ops.Literal):
start_fmt = translator.translate(start)
return 'locate({}, {}, {} + 1) - 1'.format(
substr_formatted, arg_formatted, start_fmt
)
elif start is not None and start.op().value:
sval = start.op().value
return 'locate({}, {}, {}) - 1'.format(
substr_formatted, arg_formatted, sval + 1
)
else:
return f'locate({substr_formatted}, {arg_formatted}) - 1'


def find_in_set(translator, expr):
op = expr.op()

arg, str_list = op.args
arg_formatted = translator.translate(arg)
str_formatted = ','.join([x._arg.value for x in str_list])
return f"find_in_set({arg_formatted}, '{str_formatted}') - 1"


def string_join(translator, expr):
op = expr.op()
arg, strings = op.args
return helpers.format_call(translator, 'concat_ws', arg, *strings)


def string_like(translator, expr):
arg, pattern, _ = expr.op().args
return '{} LIKE {}'.format(
translator.translate(arg), translator.translate(pattern)
)


def parse_url(translator, expr):
op = expr.op()

arg, extract, key = op.args
arg_formatted = translator.translate(arg)

if key is None:
return f"parse_url({arg_formatted}, '{extract}')"
else:
key_fmt = translator.translate(key)
return "parse_url({}, '{}', {})".format(
arg_formatted, extract, key_fmt
)


def startswith(translator, expr):
arg, start = expr.op().args

arg_formatted = translator.translate(arg)
start_formatted = translator.translate(start)

return f"{arg_formatted} like concat({start_formatted}, '%')"


def endswith(translator, expr):
arg, start = expr.op().args

arg_formatted = translator.translate(arg)
end_formatted = translator.translate(start)

return f"{arg_formatted} like concat('%', {end_formatted})"
107 changes: 107 additions & 0 deletions ibis/backends/base/sql/registry/timestamp.py
@@ -0,0 +1,107 @@
import ibis.common.exceptions as com
import ibis.expr.types as ir
import ibis.util as util


def extract_field(sql_attr):
def extract_field_formatter(translator, expr):
op = expr.op()
arg = translator.translate(op.args[0])

# This is the pre-2.0 Impala style, which did not support the
# SQL-99 format extract($FIELD from expr)
return f"extract({arg}, '{sql_attr}')"

return extract_field_formatter


def extract_epoch_seconds(t, expr):
(arg,) = expr.op().args
return f'unix_timestamp({t.translate(arg)})'


def truncate(translator, expr):
base_unit_names = {
'Y': 'Y',
'Q': 'Q',
'M': 'MONTH',
'W': 'W',
'D': 'J',
'h': 'HH',
'm': 'MI',
}
op = expr.op()
arg, unit = op.args

arg_formatted = translator.translate(arg)
try:
unit = base_unit_names[unit]
except KeyError:
raise com.UnsupportedOperationError(
f'{unit!r} unit is not supported in timestamp truncate'
)

return f"trunc({arg_formatted}, '{unit}')"


def interval_from_integer(translator, expr):
# an interval cannot be selected directly from Impala
op = expr.op()
arg, unit = op.args
arg_formatted = translator.translate(arg)

return 'INTERVAL {} {}'.format(
arg_formatted, expr.type().resolution.upper()
)


def timestamp_op(func):
def _formatter(translator, expr):
op = expr.op()
left, right = op.args
formatted_left = translator.translate(left)
formatted_right = translator.translate(right)

if isinstance(left, (ir.TimestampScalar, ir.DateValue)):
formatted_left = f'cast({formatted_left} as timestamp)'

if isinstance(right, (ir.TimestampScalar, ir.DateValue)):
formatted_right = f'cast({formatted_right} as timestamp)'

return f'{func}({formatted_left}, {formatted_right})'

return _formatter


def timestamp_diff(translator, expr):
op = expr.op()
left, right = op.args

return 'unix_timestamp({}) - unix_timestamp({})'.format(
translator.translate(left), translator.translate(right)
)


def _from_unixtime(translator, expr):
arg = translator.translate(expr)
return f'from_unixtime({arg}, "yyyy-MM-dd HH:mm:ss")'


def timestamp_from_unix(translator, expr):
op = expr.op()

val, unit = op.args
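    # Normalize the epoch value to integer seconds before formatting it
    # with from_unixtime.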
val = util.convert_unit(val, unit, 's').cast('int32')

arg = _from_unixtime(translator, val)
return f'CAST({arg} AS timestamp)'


def day_of_week_index(t, expr):
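    # dayofweek() returns 1 (Sunday) through 7 (Saturday); shift and wrap
    # so that Monday maps to 0 and Sunday to 6.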
(arg,) = expr.op().args
return f'pmod(dayofweek({t.translate(arg)}) - 2, 7)'


def day_of_week_name(t, expr):
(arg,) = expr.op().args
return f'dayname({t.translate(arg)})'
335 changes: 335 additions & 0 deletions ibis/backends/base/sql/registry/window.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,335 @@
from operator import add, mul, sub
from typing import Optional, Union

import ibis
import ibis.common.exceptions as com
import ibis.expr.analysis as L
import ibis.expr.datatypes as dt
import ibis.expr.operations as ops
import ibis.expr.types as ir
from ibis.expr.signature import Argument

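# Multipliers for lowering each interval unit to microseconds when a
# time-based range window is rewritten as an integer RANGE frame.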
_map_interval_to_microseconds = {
'W': 604800000000,
'D': 86400000000,
'h': 3600000000,
'm': 60000000,
's': 1000000,
'ms': 1000,
'us': 1,
'ns': 0.001,
}


_map_interval_op_to_op = {
# Literal Intervals have two args, i.e.
# Literal(1, Interval(value_type=int8, unit='D', nullable=True))
    # Parse both args and multiply, e.g. 1 * _map_interval_to_microseconds['D']
ops.Literal: mul,
ops.IntervalMultiply: mul,
ops.IntervalAdd: add,
ops.IntervalSubtract: sub,
}


_cumulative_to_reduction = {
ops.CumulativeSum: ops.Sum,
ops.CumulativeMin: ops.Min,
ops.CumulativeMax: ops.Max,
ops.CumulativeMean: ops.Mean,
ops.CumulativeAny: ops.Any,
ops.CumulativeAll: ops.All,
}


def _replace_interval_with_scalar(
expr: Union[ir.Expr, dt.Interval, float]
) -> Union[ir.Expr, float, Argument]:
"""
Good old Depth-First Search to identify the Interval and IntervalValue
components of the expression and return a comparable scalar expression.
Parameters
----------
expr : float or expression of intervals
For example, ``ibis.interval(days=1) + ibis.interval(hours=5)``
Returns
-------
preceding : float or ir.FloatingScalar, depending upon the expr
"""
if isinstance(expr, ir.Expr):
expr_op = expr.op()
else:
expr_op = None

if not isinstance(expr, (dt.Interval, ir.IntervalValue)):
# Literal expressions have op method but native types do not.
if isinstance(expr_op, ops.Literal):
return expr_op.value
else:
return expr
elif isinstance(expr, dt.Interval):
try:
microseconds = _map_interval_to_microseconds[expr.unit]
return microseconds
except KeyError:
raise ValueError(
"Expected preceding values of week(), "
+ "day(), hour(), minute(), second(), millisecond(), "
+ f"microseconds(), nanoseconds(); got {expr}"
)
elif expr_op.args and isinstance(expr, ir.IntervalValue):
if len(expr_op.args) > 2:
raise NotImplementedError("'preceding' argument cannot be parsed.")
left_arg = _replace_interval_with_scalar(expr_op.args[0])
right_arg = _replace_interval_with_scalar(expr_op.args[1])
method = _map_interval_op_to_op[type(expr_op)]
return method(left_arg, right_arg)
else:
raise TypeError(f'expr has unknown type {type(expr).__name__}')


def cumulative_to_window(translator, expr, window):
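    # Rewrite the cumulative op as its plain reduction counterpart evaluated
    # over a cumulative window (unbounded preceding to the current row).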
win = ibis.cumulative_window()
win = win.group_by(window._group_by).order_by(window._order_by)

op = expr.op()

klass = _cumulative_to_reduction[type(op)]
new_op = klass(*op.args)
new_expr = expr._factory(new_op, name=expr._name)

if type(new_op) in translator._rewrites:
new_expr = translator._rewrites[type(new_op)](new_expr)

new_expr = L.windowize_function(new_expr, win)
return new_expr


def time_range_to_range_window(translator, window):
# Check that ORDER BY column is a single time column:
order_by_vars = [x.op().args[0] for x in window._order_by]
if len(order_by_vars) > 1:
raise com.IbisInputError(
f"Expected 1 order-by variable, got {len(order_by_vars)}"
)

order_var = window._order_by[0].op().args[0]
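    # Order by the time column cast to int64 so the frame bounds can be
    # compared as plain integers.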
timestamp_order_var = order_var.cast('int64')
window = window._replace(order_by=timestamp_order_var, how='range')

# Need to change preceding interval expression to scalars
preceding = window.preceding
if isinstance(preceding, ir.IntervalScalar):
new_preceding = _replace_interval_with_scalar(preceding)
window = window._replace(preceding=new_preceding)

return window


def format_window(translator, op, window):
components = []

if window.max_lookback is not None:
raise NotImplementedError(
'Rows with max lookback is not implemented '
'for Impala-based backends.'
)

if len(window._group_by) > 0:
partition_args = [translator.translate(x) for x in window._group_by]
components.append('PARTITION BY {}'.format(', '.join(partition_args)))

if len(window._order_by) > 0:
order_args = []
for expr in window._order_by:
key = expr.op()
translated = translator.translate(key.expr)
if not key.ascending:
translated += ' DESC'
order_args.append(translated)

components.append('ORDER BY {}'.format(', '.join(order_args)))

p, f = window.preceding, window.following

def _prec(p: Optional[int]) -> str:
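        # None renders as UNBOUNDED PRECEDING, zero collapses to CURRENT ROW,
        # and any other non-negative offset becomes '<n> PRECEDING'.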
assert p is None or p >= 0

if p is None:
prefix = 'UNBOUNDED'
else:
if not p:
return 'CURRENT ROW'
prefix = str(p)
return f'{prefix} PRECEDING'

def _foll(f: Optional[int]) -> str:
assert f is None or f >= 0

if f is None:
prefix = 'UNBOUNDED'
else:
if not f:
return 'CURRENT ROW'
prefix = str(f)

return f'{prefix} FOLLOWING'

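    # These analytic functions do not accept an explicit frame clause, so
    # no frame is emitted when one of them is the windowed expression.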
frame_clause_not_allowed = (
ops.Lag,
ops.Lead,
ops.DenseRank,
ops.MinRank,
ops.NTile,
ops.PercentRank,
ops.RowNumber,
)

if isinstance(op.expr.op(), frame_clause_not_allowed):
frame = None
elif p is not None and f is not None:
frame = '{} BETWEEN {} AND {}'.format(
window.how.upper(), _prec(p), _foll(f)
)

elif p is not None:
if isinstance(p, tuple):
start, end = p
frame = '{} BETWEEN {} AND {}'.format(
window.how.upper(), _prec(start), _prec(end)
)
else:
kind = 'ROWS' if p > 0 else 'RANGE'
frame = '{} BETWEEN {} AND UNBOUNDED FOLLOWING'.format(
kind, _prec(p)
)
elif f is not None:
if isinstance(f, tuple):
start, end = f
frame = '{} BETWEEN {} AND {}'.format(
window.how.upper(), _foll(start), _foll(end)
)
else:
kind = 'ROWS' if f > 0 else 'RANGE'
frame = '{} BETWEEN UNBOUNDED PRECEDING AND {}'.format(
kind, _foll(f)
)
else:
        # no explicit bounds were given; emit no frame clause and use the
        # engine's default frame
frame = None

if frame is not None:
components.append(frame)

return 'OVER ({})'.format(' '.join(components))


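# SQL ranking functions (and ntile) are 1-based, while the corresponding
# ibis operations are 0-based, so one is subtracted from their results.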
_subtract_one = '({} - 1)'.format


_expr_transforms = {
ops.RowNumber: _subtract_one,
ops.DenseRank: _subtract_one,
ops.MinRank: _subtract_one,
ops.NTile: _subtract_one,
}


def window(translator, expr):
op = expr.op()

arg, window = op.args
window_op = arg.op()

_require_order_by = (
ops.Lag,
ops.Lead,
ops.DenseRank,
ops.MinRank,
ops.FirstValue,
ops.LastValue,
ops.PercentRank,
ops.NTile,
)

_unsupported_reductions = (
ops.CMSMedian,
ops.GroupConcat,
ops.HLLCardinality,
)

if isinstance(window_op, _unsupported_reductions):
raise com.UnsupportedOperationError(
f'{type(window_op)} is not supported in window functions'
)

if isinstance(window_op, ops.CumulativeOp):
arg = cumulative_to_window(translator, arg, window)
return translator.translate(arg)

# Some analytic functions need to have the expression of interest in
# the ORDER BY part of the window clause
if isinstance(window_op, _require_order_by) and len(window._order_by) == 0:
window = window.order_by(window_op.args[0])

# Time ranges need to be converted to microseconds.
if window.how == 'range':
order_by_types = [type(x.op().args[0]) for x in window._order_by]
time_range_types = (ir.TimeColumn, ir.DateColumn, ir.TimestampColumn)
if any(col_type in time_range_types for col_type in order_by_types):
window = time_range_to_range_window(translator, window)

window_formatted = format_window(translator, op, window)

arg_formatted = translator.translate(arg)
result = f'{arg_formatted} {window_formatted}'

if type(window_op) in _expr_transforms:
return _expr_transforms[type(window_op)](result)
else:
return result


def nth_value(translator, expr):
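    # Emulate an nth_value lookup: lag the column by `rank` positions, then
    # take the first value of the shifted column over the window.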
op = expr.op()
arg, rank = op.args

arg_formatted = translator.translate(arg)
rank_formatted = translator.translate(rank - 1)

return f'first_value(lag({arg_formatted}, {rank_formatted}))'


def shift_like(name):
def formatter(translator, expr):
op = expr.op()
arg, offset, default = op.args

arg_formatted = translator.translate(arg)

if default is not None:
if offset is None:
offset_formatted = '1'
else:
offset_formatted = translator.translate(offset)

default_formatted = translator.translate(default)

return '{}({}, {}, {})'.format(
name, arg_formatted, offset_formatted, default_formatted
)
elif offset is not None:
offset_formatted = translator.translate(offset)
return f'{name}({arg_formatted}, {offset_formatted})'
else:
return f'{name}({arg_formatted})'

return formatter


def ntile(translator, expr):
op = expr.op()
arg, buckets = map(translator.translate, op.args)
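    # Only the bucket count appears in the ntile() call; the argument column
    # is implied by the enclosing OVER clause.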
return f'ntile({buckets})'
127 changes: 0 additions & 127 deletions ibis/backends/base_file/__init__.py

This file was deleted.
