| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,336 @@ | ||
| import functools | ||
|
|
||
| import sqlalchemy as sa | ||
| import sqlalchemy.sql as sql | ||
|
|
||
| import ibis.expr.operations as ops | ||
| import ibis.expr.types as ir | ||
| from ibis.backends.base.sql.compiler import ( | ||
| Compiler, | ||
| Select, | ||
| SelectBuilder, | ||
| TableSetFormatter, | ||
| Union, | ||
| ) | ||
|
|
||
| from .database import AlchemyTable | ||
| from .datatypes import to_sqla_type | ||
| from .translator import AlchemyContext, AlchemyExprTranslator | ||
|
|
||
|
|
||
class _AlchemyTableSetFormatter(TableSetFormatter):
    """Format an ibis table set (a single table or a join tree) into a
    SQLAlchemy selectable/join construct."""

    def get_result(self):
        # Got to unravel the join stack; the nesting order could be
        # arbitrary, so we do a depth first search and push the join tokens
        # and predicates onto a flat list, then format them
        op = self.expr.op()

        if isinstance(op, ops.Join):
            # _walk_join_tree populates join_tables/join_types/join_predicates
            self._walk_join_tree(op)
        else:
            self.join_tables.append(self._format_table(self.expr))

        # Fold the flattened join list left-to-right into one SQLAlchemy
        # construct; join_types/join_predicates are parallel to the tables
        # after the first one.
        result = self.join_tables[0]
        for jtype, table, preds in zip(
            self.join_types, self.join_tables[1:], self.join_predicates
        ):
            if len(preds):
                sqla_preds = [self._translate(pred) for pred in preds]
                # Combine multiple predicates into a single AND-ed ON clause
                onclause = functools.reduce(sql.and_, sqla_preds)
            else:
                onclause = None

            if jtype in (ops.InnerJoin, ops.CrossJoin):
                result = result.join(table, onclause)
            elif jtype is ops.LeftJoin:
                result = result.join(table, onclause, isouter=True)
            elif jtype is ops.RightJoin:
                # RIGHT JOIN is expressed as a LEFT JOIN with the operands
                # swapped
                result = table.join(result, onclause, isouter=True)
            elif jtype is ops.OuterJoin:
                result = result.outerjoin(table, onclause, full=True)
            elif jtype is ops.LeftSemiJoin:
                # Semi/anti joins are expressed via (NOT) EXISTS subqueries
                result = sa.select([result]).where(
                    sa.exists(sa.select([1]).where(onclause))
                )
            elif jtype is ops.LeftAntiJoin:
                result = sa.select([result]).where(
                    ~(sa.exists(sa.select([1]).where(onclause)))
                )
            else:
                raise NotImplementedError(jtype)

        return result

    def _get_join_type(self, op):
        # The op class itself is used as the join-type token above
        return type(op)

    def _format_table(self, expr):
        """Return the SQLAlchemy object (table, alias, or compiled subquery)
        for *expr*, memoizing it in the context."""
        ctx = self.context
        ref_expr = expr
        op = ref_op = expr.op()

        # A SelfReference wraps another table; resolve to the wrapped table
        if isinstance(op, ops.SelfReference):
            ref_expr = op.table
            ref_op = ref_expr.op()

        alias = ctx.get_ref(expr)

        if isinstance(ref_op, AlchemyTable):
            result = ref_op.sqla_table
        elif isinstance(ref_op, ops.UnboundTable):
            # use SQLAlchemy's TableClause and ColumnClause for unbound tables
            schema = ref_op.schema
            result = sa.table(
                ref_op.name if ref_op.name is not None else ctx.get_ref(expr),
                *(
                    sa.column(n, to_sqla_type(t))
                    for n, t in zip(schema.names, schema.types)
                ),
            )
        else:
            # A subquery
            if ctx.is_extracted(ref_expr):
                # Was put elsewhere, e.g. WITH block, we just need to grab
                # its alias
                alias = ctx.get_ref(expr)

                # hack
                if isinstance(op, ops.SelfReference):
                    # Self-joins need a distinct alias of the extracted table
                    table = ctx.get_table(ref_expr)
                    self_ref = table.alias(alias)
                    ctx.set_table(expr, self_ref)
                    return self_ref
                else:
                    return ctx.get_table(expr)

            result = ctx.get_compiled_expr(expr)
            alias = ctx.get_ref(expr)

        result = result.alias(alias)
        ctx.set_table(expr, result)
        return result
|
|
||
|
|
||
def _can_lower_sort_column(table_set, expr):
    """Return True when *expr* can be referenced by name in ORDER BY.

    Sorting by a just-computed aggregate metric is represented as a sort on
    top of an Aggregation; the generic SQL compiler fuses the two into one
    query, so the sort key can be referenced by its output name rather than
    re-translated.
    """
    # TODO(wesm): This code is pending removal through cleaner internal
    # semantics
    base_tables = ops.find_all_base_tables(expr)
    if len(base_tables) > 1:
        return False

    base = list(base_tables.values())[0]
    root_op = base.op()

    if isinstance(root_op, ops.Aggregation):
        return root_op.table.equals(table_set)
    if isinstance(root_op, ops.Selection):
        return base.equals(table_set)
    return False
|
|
||
|
|
||
class AlchemySelect(Select):
    """A SELECT statement compiled to a SQLAlchemy construct instead of a
    SQL string."""

    def __init__(self, *args, **kwargs):
        # When True, the query is compiled as an EXISTS(...) subquery
        self.exists = kwargs.pop('exists', False)
        super().__init__(*args, **kwargs)

    def compile(self):
        """Build the SQLAlchemy construct by threading the FROM clause
        through each clause-adding step in order."""
        # Can't tell if this is a hack or not. Revisit later
        self.context.set_query(self)

        self._compile_subqueries()

        frag = self._compile_table_set()
        steps = [
            self._add_select,
            self._add_groupby,
            self._add_where,
            self._add_order_by,
            self._add_limit,
        ]

        for step in steps:
            frag = step(frag)

        return frag

    def _compile_subqueries(self):
        # Compile each extracted subquery to a CTE and memoize it so the
        # main query can reference it by alias
        if not self.subqueries:
            return

        for expr in self.subqueries:
            result = self.context.get_compiled_expr(expr)
            alias = self.context.get_ref(expr)
            result = result.cte(alias)
            self.context.set_table(expr, result)

    def _compile_table_set(self):
        """Return the FROM clause construct, or None when there is no
        table set."""
        if self.table_set is not None:
            helper = _AlchemyTableSetFormatter(self, self.table_set)
            result = helper.get_result()
            # A bare Select used as a FROM clause must be wrapped as a
            # subquery (hasattr guards older SQLAlchemy versions --
            # NOTE(review): presumably; verify against supported versions)
            if isinstance(result, sql.selectable.Select) and hasattr(
                result, 'subquery'
            ):
                return result.subquery()
            return result
        else:
            return None

    def _add_select(self, table_set):
        """Build the column list / SELECT construct."""
        to_select = []

        has_select_star = False
        for expr in self.select_set:
            if isinstance(expr, ir.ValueExpr):
                arg = self._translate(expr, named=True)
            elif isinstance(expr, ir.TableExpr):
                if expr.equals(self.table_set):
                    cached_table = self.context.get_table(expr)
                    if cached_table is None:
                        # the select * case from materialized join
                        has_select_star = True
                        continue
                    else:
                        arg = table_set
                else:
                    arg = self.context.get_table(expr)
                    if arg is None:
                        raise ValueError(expr)
            # NOTE(review): an expr that is neither ValueExpr nor TableExpr
            # would leave `arg` unbound here -- presumably select_set only
            # ever contains those two kinds

            to_select.append(arg)

        if has_select_star:
            if table_set is None:
                raise ValueError('table_set cannot be None here')

            clauses = [table_set] + to_select
        else:
            clauses = to_select

        if self.exists:
            result = sa.exists(clauses)
        else:
            result = sa.select(clauses)

        if self.distinct:
            result = result.distinct()

        if not has_select_star:
            # Attach the FROM clause explicitly; with select-star the
            # table_set is already part of `clauses`
            if table_set is not None:
                return result.select_from(table_set)
            else:
                return result
        else:
            return result

    def _add_groupby(self, fragment):
        # GROUP BY and HAVING
        if not len(self.group_by):
            return fragment

        group_keys = [self._translate(arg) for arg in self.group_by]
        fragment = fragment.group_by(*group_keys)

        if len(self.having) > 0:
            having_args = [self._translate(arg) for arg in self.having]
            # Multiple HAVING predicates are AND-ed together
            having_clause = functools.reduce(sql.and_, having_args)
            fragment = fragment.having(having_clause)

        return fragment

    def _add_where(self, fragment):
        if not len(self.where):
            return fragment

        args = [
            self._translate(pred, permit_subquery=True) for pred in self.where
        ]
        # Multiple WHERE predicates are AND-ed together
        clause = functools.reduce(sql.and_, args)
        return fragment.where(clause)

    def _add_order_by(self, fragment):
        if not len(self.order_by):
            return fragment

        clauses = []
        for expr in self.order_by:
            key = expr.op()
            sort_expr = key.expr

            # here we have to determine if key.expr is in the select set (as it
            # will be in the case of order_by fused with an aggregation
            if _can_lower_sort_column(self.table_set, sort_expr):
                # Reference the output column by name
                arg = sort_expr.get_name()
            else:
                arg = self._translate(sort_expr)

            if not key.ascending:
                arg = sa.desc(arg)

            clauses.append(arg)

        return fragment.order_by(*clauses)

    def _among_select_set(self, expr):
        # True if expr structurally equals any member of the select set
        for other in self.select_set:
            if expr.equals(other):
                return True
        return False

    def _add_limit(self, fragment):
        if self.limit is None:
            return fragment

        n, offset = self.limit['n'], self.limit['offset']
        fragment = fragment.limit(n)
        if offset is not None and offset != 0:
            fragment = fragment.offset(offset)

        return fragment
|
|
||
|
|
||
class AlchemySelectBuilder(SelectBuilder):
    """SelectBuilder variant for SQLAlchemy compilation.

    Group-by expressions are passed through unchanged; the SQLAlchemy
    translator handles them directly.
    """

    def _convert_group_by(self, exprs):
        return exprs
|
|
||
|
|
||
class AlchemyUnion(Union):
    """Compile a chain of UNION / UNION ALL operations to SQLAlchemy."""

    def compile(self):
        def reduce_union(left, right, distincts=iter(self.distincts)):
            # NOTE: the iterator default argument is evaluated once, so the
            # successive pairwise calls made by functools.reduce each
            # consume the next distinct flag in order.
            distinct = next(distincts)
            sa_func = sa.union if distinct else sa.union_all
            return sa_func(left, right)

        context = self.context
        selects = []

        for table in self.tables:
            table_set = context.get_compiled_expr(table)
            # Wrap each operand as a CTE selection so the union members are
            # uniform selectables
            selects.append(table_set.cte().select())

        return functools.reduce(reduce_union, selects)
|
|
||
|
|
||
class AlchemyCompiler(Compiler):
    """Compiler wiring together the SQLAlchemy-specific components."""

    translator_class = AlchemyExprTranslator
    context_class = AlchemyContext
    table_set_formatter_class = _AlchemyTableSetFormatter
    select_builder_class = AlchemySelectBuilder
    select_class = AlchemySelect
    union_class = AlchemyUnion

    @classmethod
    def to_sql(cls, expr, context=None, params=None, exists=False):
        """Compile *expr* to a SQLAlchemy construct.

        When *exists* is true, the query is compiled as an EXISTS subquery
        (see AlchemySelect.exists).
        """
        if context is None:
            context = cls.make_context(params=params)
        query = cls.to_ast(expr, context).queries[0]
        if exists:
            query.exists = True
        return query.compile()
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,92 @@ | ||
| import ibis | ||
| import ibis.expr.operations as ops | ||
| import ibis.expr.types as ir | ||
| from ibis import util | ||
| from ibis.backends.base.sql.compiler import ExprTranslator, QueryContext | ||
|
|
||
| from .datatypes import ibis_type_to_sqla, to_sqla_type | ||
| from .registry import fixed_arity, sqlalchemy_operation_registry | ||
|
|
||
|
|
||
class AlchemyContext(QueryContext):
    """QueryContext that additionally memoizes compiled SQLAlchemy table
    objects (not just string aliases)."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # table-key -> SQLAlchemy object memo, parallel to _table_refs
        self._table_objects = {}

    def collapse(self, queries):
        # A single SQLAlchemy construct may be passed through unchanged
        if isinstance(queries, str):
            return queries

        if len(queries) > 1:
            raise NotImplementedError(
                'Only a single query is supported for SQLAlchemy backends'
            )
        return queries[0]

    def subcontext(self):
        return type(self)(
            compiler=self.compiler, parent=self, params=self.params
        )

    def _compile_subquery(self, expr):
        sub_ctx = self.subcontext()
        return self._to_sql(expr, sub_ctx)

    def has_table(self, expr, parent_contexts=False):
        """Return True if a SQLAlchemy object is memoized for *expr*,
        optionally searching parent contexts too."""
        key = self._get_table_key(expr)
        return self._key_in(
            key, '_table_objects', parent_contexts=parent_contexts
        )

    def set_table(self, expr, obj):
        # Memoize the SQLAlchemy object for this table expression
        key = self._get_table_key(expr)
        self._table_objects[key] = obj

    def get_table(self, expr):
        """
        Get the memoized SQLAlchemy expression object
        """
        return self._get_table_item('_table_objects', expr)
|
|
||
|
|
||
class AlchemyExprTranslator(ExprTranslator):
    """Expression translator producing SQLAlchemy constructs."""

    _registry = sqlalchemy_operation_registry
    _rewrites = ExprTranslator._rewrites.copy()
    _type_map = ibis_type_to_sqla

    context_class = AlchemyContext

    def name(self, translated, name, force=True):
        """Attach *name* as a label when the construct supports labeling."""
        if not hasattr(translated, 'label'):
            return translated
        return translated.label(name)

    def get_sqla_type(self, data_type):
        """Map an ibis data type to its SQLAlchemy equivalent."""
        return to_sqla_type(data_type, type_map=self._type_map)
|
|
||
|
|
||
rewrites = AlchemyExprTranslator.rewrites


@rewrites(ops.NullIfZero)
def _nullifzero(expr):
    """Rewrite NullIfZero as CASE WHEN arg = 0 THEN NULL ELSE arg END."""
    op = expr.op()
    value = op.args[0]
    return (value == 0).ifelse(ibis.NA, value)
|
|
||
|
|
||
# TODO This was previously implemented with the legacy `@compiles` decorator.
# This definition should now be in the registry, but there is some magic going
# on that things fail if it's not defined here (and in the registry
# `operator.truediv` is used.
def _true_divide(t, expr):
    """Compile Divide so integer / integer performs true division.

    Without the cast, integer operands would use the database's integer
    division semantics.
    """
    operands = expr.op().args
    numerator, denominator = operands

    if util.all_of(operands, ir.IntegerValue):
        return t.translate(numerator.div(denominator.cast('double')))

    return fixed_arity(lambda x, y: x / y, 2)(t, expr)


AlchemyExprTranslator._registry[ops.Divide] = _true_divide
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,21 @@ | ||
| from .base import DDL, DML | ||
| from .query_builder import ( | ||
| Compiler, | ||
| Select, | ||
| SelectBuilder, | ||
| TableSetFormatter, | ||
| Union, | ||
| ) | ||
| from .translator import ExprTranslator, QueryContext | ||
|
|
||
| __all__ = ( | ||
| 'Compiler', | ||
| 'Select', | ||
| 'SelectBuilder', | ||
| 'Union', | ||
| 'TableSetFormatter', | ||
| 'ExprTranslator', | ||
| 'QueryContext', | ||
| 'DML', | ||
| 'DDL', | ||
| ) |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,107 @@ | ||
| import abc | ||
| from itertools import chain | ||
|
|
||
| import toolz | ||
|
|
||
| import ibis.util as util | ||
|
|
||
| from .extract_subqueries import ExtractSubqueries | ||
|
|
||
|
|
||
class DML(abc.ABC):
    """Abstract base for compilable data-manipulation statements."""

    @abc.abstractmethod
    def compile(self):
        """Return the compiled form of this statement."""
        pass
|
|
||
|
|
||
class DDL(abc.ABC):
    """Abstract base for compilable data-definition statements."""

    @abc.abstractmethod
    def compile(self):
        """Return the compiled form of this statement."""
        pass
|
|
||
|
|
||
class QueryAST:
    """Container for a compiled query plus optional setup/teardown queries.

    Parameters
    ----------
    context : QueryContext
        Used to collapse the compiled query strings into one executable.
    dml : DML
        The main compiled statement.
    setup_queries, teardown_queries : list, optional
        Statements compiled and run before/after the main query.
    """

    __slots__ = 'context', 'dml', 'setup_queries', 'teardown_queries'

    def __init__(
        self, context, dml, setup_queries=None, teardown_queries=None
    ):
        self.context = context
        self.dml = dml
        self.setup_queries = setup_queries
        self.teardown_queries = teardown_queries

    @property
    def queries(self):
        # Currently only the single main DML statement
        return [self.dml]

    def compile(self):
        """Compile all queries and collapse them via the context.

        BUG FIX: setup_queries/teardown_queries default to None, and
        iterating None raised TypeError; treat None as "no queries".
        """
        compiled_setup_queries = [
            q.compile() for q in self.setup_queries or ()
        ]
        compiled_queries = [q.compile() for q in self.queries]
        compiled_teardown_queries = [
            q.compile() for q in self.teardown_queries or ()
        ]
        return self.context.collapse(
            list(
                chain(
                    compiled_setup_queries,
                    compiled_queries,
                    compiled_teardown_queries,
                )
            )
        )
|
|
||
|
|
||
class SetOp(DML):
    """Base class for compiling set operations (UNION/INTERSECT/EXCEPT)
    over a sequence of tables into a SQL string."""

    def __init__(self, tables, expr, context):
        self.context = context
        self.tables = tables
        self.table_set = expr
        self.filters = []

    def _extract_subqueries(self):
        # Pull out repeated subexpressions so they can be emitted once in a
        # WITH block
        self.subqueries = ExtractSubqueries.extract(self)
        for subquery in self.subqueries:
            self.context.set_extracted(subquery)

    def format_subqueries(self):
        """Render the extracted subqueries as `alias AS (...)` clauses."""
        context = self.context
        subqueries = self.subqueries

        return ',\n'.join(
            '{} AS (\n{}\n)'.format(
                context.get_ref(expr),
                util.indent(context.get_compiled_expr(expr), 2),
            )
            for expr in subqueries
        )

    def format_relation(self, expr):
        """Render one operand, referencing its alias if it has one."""
        ref = self.context.get_ref(expr)
        if ref is not None:
            return f'SELECT *\nFROM {ref}'
        return self.context.get_compiled_expr(expr)

    def _get_keyword_list(self):
        # Subclasses supply the keywords (e.g. UNION) placed between
        # operands
        raise NotImplementedError("Need objects to interleave")

    def compile(self):
        """Return the full SQL string: optional WITH block, then operands
        interleaved with the set-operation keywords."""
        self._extract_subqueries()

        extracted = self.format_subqueries()

        buf = []

        if extracted:
            buf.append(f'WITH {extracted}')

        # n tables interleaved with n-1 keywords
        buf.extend(
            toolz.interleave(
                (
                    map(self.format_relation, self.tables),
                    self._get_keyword_list(),
                )
            )
        )
        return '\n'.join(buf)
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,129 @@ | ||
| from collections import OrderedDict | ||
|
|
||
| import ibis.expr.operations as ops | ||
| import ibis.expr.types as ir | ||
|
|
||
|
|
||
class ExtractSubqueries:
    """Walk a query's table set and filters, counting table subexpressions
    so that repeated ones (or all of them, when greedy) can be extracted
    into a WITH block."""

    def __init__(self, query, greedy=False):
        self.query = query
        # When True, every observed subexpression is extracted, not just
        # the repeated ones
        self.greedy = greedy
        # op node -> number of times observed (insertion-ordered)
        self.expr_counts = OrderedDict()
        # op node -> the expression that produced it
        self.node_to_expr = {}

    @classmethod
    def extract(cls, select_stmt):
        helper = cls(select_stmt)
        return helper.get_result()

    def get_result(self):
        """Return the list of expressions to extract, in visit order."""
        if self.query.table_set is not None:
            self.visit(self.query.table_set)

        for clause in self.query.filters:
            self.visit(clause)

        expr_counts = self.expr_counts

        if self.greedy:
            to_extract = list(expr_counts.keys())
        else:
            # Only expressions seen more than once are worth extracting
            to_extract = [op for op, count in expr_counts.items() if count > 1]

        node_to_expr = self.node_to_expr
        return [node_to_expr[op] for op in to_extract]

    def observe(self, expr):
        # Record one occurrence of this expression's op node
        key = expr.op()

        if key not in self.node_to_expr:
            self.node_to_expr[key] = expr

        # Sanity check: the same op node must always map to an equal expr
        assert self.node_to_expr[key].equals(expr)
        self.expr_counts[key] = self.expr_counts.setdefault(key, 0) + 1

    def seen(self, expr):
        return expr.op() in self.expr_counts

    def visit(self, expr):
        """Dispatch on the op's class name, falling back to structural
        isinstance checks."""
        node = expr.op()
        method = f'visit_{type(node).__name__}'

        if hasattr(self, method):
            f = getattr(self, method)
            f(expr)
        elif isinstance(node, ops.Join):
            self.visit_join(expr)
        elif isinstance(node, ops.PhysicalTable):
            self.visit_physical_table(expr)
        elif isinstance(node, ops.ValueOp):
            # Recurse into any expression-valued arguments
            for arg in node.flat_args():
                if not isinstance(arg, ir.Expr):
                    continue
                self.visit(arg)
        else:
            raise NotImplementedError(type(node))

    def visit_join(self, expr):
        # Joins themselves are not observed; only their operands are
        node = expr.op()
        self.visit(node.left)
        self.visit(node.right)

    def visit_physical_table(self, _):
        # Physical tables are never extracted
        return

    def visit_Exists(self, expr):
        node = expr.op()
        self.visit(node.foreign_table)
        for pred in node.predicates:
            self.visit(pred)

    visit_NotExistsSubquery = visit_ExistsSubquery = visit_Exists

    def visit_Aggregation(self, expr):
        self.visit(expr.op().table)
        self.observe(expr)

    def visit_Distinct(self, expr):
        self.observe(expr)

    def visit_Limit(self, expr):
        self.visit(expr.op().table)
        self.observe(expr)

    def visit_Union(self, expr):
        op = expr.op()
        self.visit(op.left)
        self.visit(op.right)
        self.observe(expr)

    def visit_Intersection(self, expr):
        op = expr.op()
        self.visit(op.left)
        self.visit(op.right)
        self.observe(expr)

    def visit_Difference(self, expr):
        op = expr.op()
        self.visit(op.left)
        self.visit(op.right)
        self.observe(expr)

    def visit_MaterializedJoin(self, expr):
        self.visit(expr.op().join)
        self.observe(expr)

    def visit_Selection(self, expr):
        self.visit(expr.op().table)
        self.observe(expr)

    def visit_SQLQueryResult(self, expr):
        self.observe(expr)

    def visit_TableColumn(self, expr):
        # Visit the column's table only once; repeat references would
        # otherwise inflate its count
        table = expr.op().table
        if not self.seen(table):
            self.visit(table)

    def visit_SelfReference(self, expr):
        self.visit(expr.op().table)
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,391 @@ | ||
| import operator | ||
| from typing import Callable, Dict | ||
|
|
||
| import ibis | ||
| import ibis.common.exceptions as com | ||
| import ibis.expr.analytics as analytics | ||
| import ibis.expr.datatypes as dt | ||
| import ibis.expr.format as fmt | ||
| import ibis.expr.operations as ops | ||
| import ibis.expr.types as ir | ||
| from ibis.backends.base.sql.registry import ( | ||
| operation_registry, | ||
| quote_identifier, | ||
| ) | ||
|
|
||
|
|
||
class QueryContext:

    """Records bits of information used during ibis AST to SQL translation.
    Notably, table aliases (for subquery construction) and scalar query
    parameters are tracked here.
    """

    def __init__(
        self, compiler, indent=2, parent=None, memo=None, params=None
    ):
        self.compiler = compiler
        # table-key -> alias string
        self._table_refs = {}
        # keys of expressions extracted into WITH blocks
        self.extracted_subexprs = set()
        # table-key -> compiled subquery text, shared via top_context
        self.subquery_memo = {}
        self.indent = indent
        # Parent context for nested subqueries; None at the top level
        self.parent = parent

        self.always_alias = False

        # The query currently being compiled (set via set_query)
        self.query = None

        # op node -> repr string, caching the expensive _repr() call
        self._table_key_memo = {}
        self.memo = memo or fmt.FormatMemo()
        # Scalar parameter values keyed by parameter op
        self.params = params if params is not None else {}

    def _compile_subquery(self, expr):
        # Subqueries compile in a child context so their aliases nest
        sub_ctx = self.subcontext()
        return self._to_sql(expr, sub_ctx)

    def _to_sql(self, expr, ctx):
        return self.compiler.to_sql(expr, ctx)

    def collapse(self, queries):
        """Turn a sequence of queries into something executable.
        Parameters
        ----------
        queries : List[str]
        Returns
        -------
        query : str
        """
        return '\n\n'.join(queries)

    @property
    def top_context(self):
        # The root of the parent chain; shared memos live there
        if self.parent is None:
            return self
        else:
            return self.parent.top_context

    def set_always_alias(self):
        self.always_alias = True

    def get_compiled_expr(self, expr):
        """Compile *expr* (or return its cached compilation), memoized at
        the top context."""
        this = self.top_context

        key = self._get_table_key(expr)
        if key in this.subquery_memo:
            return this.subquery_memo[key]

        op = expr.op()
        if isinstance(op, ops.SQLQueryResult):
            # Raw SQL passes through unchanged
            result = op.query
        else:
            result = self._compile_subquery(expr)

        this.subquery_memo[key] = result
        return result

    def make_alias(self, expr):
        """Assign a fresh alias (t0, t1, ...) to *expr*, reusing any alias
        already assigned in an ancestor context."""
        i = len(self._table_refs)

        key = self._get_table_key(expr)

        # Get total number of aliases up and down the tree at this point; if we
        # find the table prior-aliased along the way, however, we reuse that
        # alias
        ctx = self
        while ctx.parent is not None:
            ctx = ctx.parent

            if key in ctx._table_refs:
                alias = ctx._table_refs[key]
                self.set_ref(expr, alias)
                return

            i += len(ctx._table_refs)

        alias = f't{i:d}'
        self.set_ref(expr, alias)

    def need_aliases(self, expr=None):
        # Aliases are needed once more than one table participates
        return self.always_alias or len(self._table_refs) > 1

    def has_ref(self, expr, parent_contexts=False):
        key = self._get_table_key(expr)
        return self._key_in(
            key, '_table_refs', parent_contexts=parent_contexts
        )

    def set_ref(self, expr, alias):
        key = self._get_table_key(expr)
        self._table_refs[key] = alias

    def get_ref(self, expr):
        """
        Get the alias being used throughout a query to refer to a particular
        table or inline view
        """
        return self._get_table_item('_table_refs', expr)

    def is_extracted(self, expr):
        key = self._get_table_key(expr)
        return key in self.top_context.extracted_subexprs

    def set_extracted(self, expr):
        # Mark expr as hoisted into a WITH block and give it an alias
        key = self._get_table_key(expr)
        self.extracted_subexprs.add(key)
        self.make_alias(expr)

    def subcontext(self):
        return type(self)(
            compiler=self.compiler,
            indent=self.indent,
            parent=self,
            params=self.params,
        )

    # Maybe temporary hacks for correlated / uncorrelated subqueries

    def set_query(self, query):
        self.query = query

    def is_foreign_expr(self, expr):
        """Return True if *expr* does not belong to the query being
        compiled in this context (i.e. it is correlated with an outer
        query)."""
        from ibis.expr.analysis import ExprValidator

        # The expression isn't foreign to us. For example, the parent table set
        # in a correlated WHERE subquery
        if self.has_ref(expr, parent_contexts=True):
            return False

        exprs = [self.query.table_set] + self.query.select_set
        validator = ExprValidator(exprs)
        return not validator.validate(expr)

    def _get_table_item(self, item, expr):
        # Look the key up on this context, or on the top context when the
        # expression was extracted into a WITH block
        key = self._get_table_key(expr)
        top = self.top_context

        if self.is_extracted(expr):
            return getattr(top, item).get(key)

        return getattr(self, item).get(key)

    def _get_table_key(self, table):
        """Return the memoized repr-string key for a table op."""
        if isinstance(table, ir.TableExpr):
            table = table.op()

        try:
            return self._table_key_memo[table]
        except KeyError:
            val = table._repr()
            self._table_key_memo[table] = val
            return val

    def _key_in(self, key, memo_attr, parent_contexts=False):
        # Membership test on this context, optionally walking up parents
        if key in getattr(self, memo_attr):
            return True

        ctx = self
        while parent_contexts and ctx.parent is not None:
            ctx = ctx.parent
            if key in getattr(ctx, memo_attr):
                return True

        return False
|
|
||
|
|
||
class ExprTranslator:

    """Class that performs translation of ibis expressions into executable
    SQL.
    """

    # op class -> formatter(translator, expr); shared across instances
    _registry = operation_registry
    # op class -> rewrite(expr) applied before registry lookup
    _rewrites: Dict[ops.Node, Callable] = {}

    def __init__(self, expr, context, named=False, permit_subquery=False):
        self.expr = expr
        self.permit_subquery = permit_subquery

        assert context is not None, 'context is None in {}'.format(
            type(self).__name__
        )
        self.context = context

        # For now, governing whether the result will have a name
        self.named = named

    def get_result(self):
        """
        Build compiled SQL expression from the bottom up and return as a string
        """
        translated = self.translate(self.expr)
        if self._needs_name(self.expr):
            # TODO: this could fail in various ways
            name = self.expr.get_name()
            translated = self.name(translated, name)
        return translated

    @classmethod
    def add_operation(cls, operation, translate_function):
        """
        Adds an operation to the operation registry. In general, operations
        should be defined directly in the registry, in `registry.py`. But
        there are couple of exceptions why this is needed. Operations defined
        by Ibis users (not Ibis or backend developers). and UDF, which are
        added dynamically.
        """
        cls._registry[operation] = translate_function

    def _needs_name(self, expr):
        """Return True if the translated result should carry an AS alias."""
        if not self.named:
            return False

        op = expr.op()
        if isinstance(op, ops.TableColumn):
            # This column has been given an explicitly different name
            if expr.get_name() != op.name:
                return True
            return False

        if expr.get_name() is ir.unnamed:
            return False

        return True

    def name(self, translated, name, force=True):
        """Render `translated AS "name"` (quoting per `force`)."""
        return '{} AS {}'.format(
            translated, quote_identifier(name, force=force)
        )

    def translate(self, expr):
        """Translate *expr*: apply any rewrite for its op class, then
        dispatch to parameters, table nodes, or the registry."""
        # The operation node type the typed expression wraps
        op = expr.op()

        if type(op) in self._rewrites:  # even if type(op) is in self._registry
            expr = self._rewrites[type(op)](expr)
            op = expr.op()

        # TODO: use op MRO for subclasses instead of this isinstance spaghetti
        if isinstance(op, ops.ScalarParameter):
            return self._trans_param(expr)
        elif isinstance(op, ops.TableNode):
            # HACK/TODO: revisit for more complex cases
            return '*'
        elif type(op) in self._registry:
            formatter = self._registry[type(op)]
            return formatter(self, expr)
        else:
            raise com.OperationNotDefinedError(
                f'No translation rule for {type(op)}'
            )

    def _trans_param(self, expr):
        # Substitute the bound parameter value as a typed literal
        raw_value = self.context.params[expr.op()]
        literal = ibis.literal(raw_value, type=expr.type())
        return self.translate(literal)

    @classmethod
    def rewrites(cls, klass):
        """Class decorator factory registering a rewrite for op class
        *klass*."""
        def decorator(f):
            cls._rewrites[klass] = f
            return f

        return decorator
|
|
||
|
|
||
rewrites = ExprTranslator.rewrites


@rewrites(analytics.Bucket)
def _bucket(expr):
    """Rewrite Bucket as a CASE expression assigning successive integer
    bucket ids across the configured bucket edges."""
    op = expr.op()
    stmt = ibis.case()

    # Which side of each interval is closed determines the comparison
    # operators used for the lower (l_cmp) and upper (r_cmp) edge
    if op.closed == 'left':
        l_cmp = operator.le
        r_cmp = operator.lt
    else:
        l_cmp = operator.lt
        r_cmp = operator.le

    # Number of finite buckets implied by the edge list
    user_num_buckets = len(op.buckets) - 1

    bucket_id = 0
    if op.include_under:
        # Extra bucket for values below the first edge
        if user_num_buckets > 0:
            cmp = operator.lt if op.close_extreme else r_cmp
        else:
            cmp = operator.le if op.closed == 'right' else operator.lt
        stmt = stmt.when(cmp(op.arg, op.buckets[0]), bucket_id)
        bucket_id += 1

    for j, (lower, upper) in enumerate(zip(op.buckets, op.buckets[1:])):
        if op.close_extreme and (
            (op.closed == 'right' and j == 0)
            or (op.closed == 'left' and j == (user_num_buckets - 1))
        ):
            # The extreme bucket is closed on both ends
            stmt = stmt.when((lower <= op.arg) & (op.arg <= upper), bucket_id)
        else:
            stmt = stmt.when(
                l_cmp(lower, op.arg) & r_cmp(op.arg, upper), bucket_id
            )
        bucket_id += 1

    if op.include_over:
        # Extra bucket for values above the last edge
        if user_num_buckets > 0:
            cmp = operator.lt if op.close_extreme else l_cmp
        else:
            cmp = operator.lt if op.closed == 'right' else operator.le

        stmt = stmt.when(cmp(op.buckets[-1], op.arg), bucket_id)
        bucket_id += 1

    return stmt.end().name(expr._name)
|
|
||
|
|
||
@rewrites(analytics.CategoryLabel)
def _category_label(expr):
    """Rewrite CategoryLabel as a CASE mapping category codes to labels."""
    op = expr.op()

    cases = op.args[0].case()
    for code, label in enumerate(op.labels):
        cases = cases.when(code, label)

    if op.nulls is not None:
        cases = cases.else_(op.nulls)

    return cases.end().name(expr._name)
|
|
||
|
|
||
@rewrites(ops.Any)
def _any_expand(expr):
    """Rewrite an ``Any`` reduction as ``max`` of its argument."""
    return expr.op().args[0].max()
|
|
||
|
|
||
@rewrites(ops.NotAny)
def _notany_expand(expr):
    """Rewrite ``NotAny`` as ``max(arg) == 0``."""
    return expr.op().args[0].max() == 0
|
|
||
|
|
||
@rewrites(ops.All)
def _all_expand(expr):
    """Rewrite an ``All`` reduction as ``min`` of its argument."""
    return expr.op().args[0].min()
|
|
||
|
|
||
@rewrites(ops.NotAll)
def _notall_expand(expr):
    """Rewrite ``NotAll`` as ``min(arg) == 0``."""
    return expr.op().args[0].min() == 0
|
|
||
|
|
||
@rewrites(ops.Cast)
def _rewrite_cast(expr):
    """Rewrite integer-to-interval casts as ``to_interval`` calls."""
    arg, target_type = expr.op().args
    int_to_interval = isinstance(target_type, dt.Interval) and isinstance(
        arg.type(), dt.Integer
    )
    if int_to_interval:
        return arg.to_interval(unit=target_type.unit)
    return expr
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,25 @@ | ||
| from .aggregate import reduction | ||
| from .helpers import quote_identifier, sql_type_names, type_to_sql_string | ||
| from .literal import literal, literal_formatters | ||
| from .main import binary_infix_ops, fixed_arity, operation_registry, unary | ||
| from .window import ( | ||
| cumulative_to_window, | ||
| format_window, | ||
| time_range_to_range_window, | ||
| ) | ||
|
|
||
# Public API of the Impala SQL operation-registry package.
__all__ = (
    'quote_identifier',
    'operation_registry',
    'binary_infix_ops',
    'fixed_arity',
    'literal',
    'literal_formatters',
    'sql_type_names',
    'type_to_sql_string',
    'reduction',
    'unary',
    'cumulative_to_window',
    'format_window',
    'time_range_to_range_window',
)
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,45 @@ | ||
| import itertools | ||
|
|
||
| import ibis | ||
|
|
||
|
|
||
| def _reduction_format(translator, func_name, where, arg, *args): | ||
| if where is not None: | ||
| arg = where.ifelse(arg, ibis.NA) | ||
|
|
||
| return '{}({})'.format( | ||
| func_name, | ||
| ', '.join(map(translator.translate, itertools.chain([arg], args))), | ||
| ) | ||
|
|
||
|
|
||
def reduction(func_name):
    """Return a formatter for a reduction whose trailing op arg is ``where``."""

    def formatter(translator, expr):
        # By convention the last argument of reduction ops is the filter.
        *args, where = expr.op().args
        return _reduction_format(translator, func_name, where, *args)

    return formatter
|
|
||
|
|
||
def variance_like(func_name):
    """Return a formatter dispatching on ``how`` ('sample'/'pop')."""
    func_names = {
        how: f'{func_name}_{suffix}'
        for how, suffix in (('sample', 'samp'), ('pop', 'pop'))
    }

    def formatter(translator, expr):
        arg, how, where = expr.op().args
        return _reduction_format(translator, func_names[how], where, arg)

    return formatter
|
|
||
|
|
||
def count_distinct(translator, expr):
    """Format COUNT(DISTINCT ...), honoring an optional filter."""
    arg, where = expr.op().args

    # Filtered rows are nulled out so COUNT skips them.
    target = arg if where is None else where.ifelse(arg, None)
    arg_formatted = translator.translate(target)
    return f'count(DISTINCT {arg_formatted})'
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,52 @@ | ||
| from . import helpers | ||
|
|
||
|
|
||
def binary_infix_op(infix_sym):
    """Return a formatter rendering ``left <sym> right``, adding parens."""

    def formatter(translator, expr):
        left, right = expr.op().args

        def render(operand):
            text = translator.translate(operand)
            if helpers.needs_parens(operand):
                text = helpers.parenthesize(text)
            return text

        left_arg = render(left)
        right_arg = render(right)
        return f'{left_arg} {infix_sym} {right_arg}'

    return formatter
|
|
||
|
|
||
def identical_to(translator, expr):
    """Render IS NOT DISTINCT FROM, short-circuiting identical operands."""
    op = expr.op()
    left_expr = op.left
    right_expr = op.right

    # Comparing an expression to itself is trivially true.
    if left_expr.equals(right_expr):
        return 'TRUE'

    left = translator.translate(left_expr)
    right = translator.translate(right_expr)

    if helpers.needs_parens(left_expr):
        left = helpers.parenthesize(left)
    if helpers.needs_parens(right_expr):
        right = helpers.parenthesize(right)
    return f'{left} IS NOT DISTINCT FROM {right}'
|
|
||
|
|
||
def xor(translator, expr):
    """Render XOR as ``(a OR b) AND NOT (a AND b)``."""
    op = expr.op()

    rendered = []
    for side in (op.left, op.right):
        text = translator.translate(side)
        if helpers.needs_parens(side):
            text = helpers.parenthesize(text)
        rendered.append(text)

    return '({0} OR {1}) AND NOT ({0} AND {1})'.format(*rendered)
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,66 @@ | ||
| from io import StringIO | ||
|
|
||
|
|
||
| class _CaseFormatter: | ||
| def __init__(self, translator, base, cases, results, default): | ||
| self.translator = translator | ||
| self.base = base | ||
| self.cases = cases | ||
| self.results = results | ||
| self.default = default | ||
|
|
||
| # HACK | ||
| self.indent = 2 | ||
| self.multiline = len(cases) > 1 | ||
| self.buf = StringIO() | ||
|
|
||
| def _trans(self, expr): | ||
| return self.translator.translate(expr) | ||
|
|
||
| def get_result(self): | ||
| self.buf.seek(0) | ||
|
|
||
| self.buf.write('CASE') | ||
| if self.base is not None: | ||
| base_str = self._trans(self.base) | ||
| self.buf.write(f' {base_str}') | ||
|
|
||
| for case, result in zip(self.cases, self.results): | ||
| self._next_case() | ||
| case_str = self._trans(case) | ||
| result_str = self._trans(result) | ||
| self.buf.write(f'WHEN {case_str} THEN {result_str}') | ||
|
|
||
| if self.default is not None: | ||
| self._next_case() | ||
| default_str = self._trans(self.default) | ||
| self.buf.write(f'ELSE {default_str}') | ||
|
|
||
| if self.multiline: | ||
| self.buf.write('\nEND') | ||
| else: | ||
| self.buf.write(' END') | ||
|
|
||
| return self.buf.getvalue() | ||
|
|
||
| def _next_case(self): | ||
| if self.multiline: | ||
| self.buf.write('\n{}'.format(' ' * self.indent)) | ||
| else: | ||
| self.buf.write(' ') | ||
|
|
||
|
|
||
def simple_case(translator, expr):
    """Format a ``CASE <base> WHEN ...`` expression."""
    op = expr.op()
    return _CaseFormatter(
        translator, op.base, op.cases, op.results, op.default
    ).get_result()
|
|
||
|
|
||
def searched_case(translator, expr):
    """Format a base-less ``CASE WHEN <predicate> ...`` expression."""
    op = expr.op()
    return _CaseFormatter(
        translator, None, op.cases, op.results, op.default
    ).get_result()
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,80 @@ | ||
| import ibis.common.exceptions as com | ||
| import ibis.expr.datatypes as dt | ||
| import ibis.expr.operations as ops | ||
| import ibis.expr.types as ir | ||
|
|
||
| from . import identifiers | ||
|
|
||
|
|
||
def format_call(translator, func, *args):
    """Render ``func(arg1, arg2, ...)`` with each argument translated."""
    joined_args = ', '.join(translator.translate(arg) for arg in args)
    return f'{func}({joined_args})'
|
|
||
|
|
||
def quote_identifier(name, quotechar='`', force=False):
    """Add quotes to the `name` identifier if needed.

    Quoting applies when forced, when the name contains a space, or when it
    collides with a reserved word in ``identifiers.base_identifiers``.
    """
    # `' ' in name` states the intent directly; counting occurrences only
    # to test truthiness was an O(n) detour.
    if force or ' ' in name or name in identifiers.base_identifiers:
        return '{0}{1}{0}'.format(quotechar, name)
    return name
|
|
||
|
|
||
def needs_parens(op):
    """Whether ``op`` must be parenthesized inside a larger expression.

    Prefix/infix operators need parens; function calls never do.
    """
    node = op.op() if isinstance(op, ir.Expr) else op
    return type(node) in {
        ops.Negate,
        ops.IsNull,
        ops.NotNull,
        ops.Add,
        ops.Subtract,
        ops.Multiply,
        ops.Divide,
        ops.Power,
        ops.Modulus,
        ops.Equals,
        ops.NotEquals,
        ops.GreaterEqual,
        ops.Greater,
        ops.LessEqual,
        ops.Less,
        ops.IdenticalTo,
        ops.And,
        ops.Or,
        ops.Xor,
    }
|
|
||
|
|
||
# Wrap an already-formatted SQL fragment in parens: parenthesize('x') -> '(x)'.
parenthesize = '({})'.format
|
|
||
|
|
||
# Mapping from ibis type names to their Impala SQL spellings.
sql_type_names = {
    'int8': 'tinyint',
    'int16': 'smallint',
    'int32': 'int',
    'int64': 'bigint',
    'float': 'float',
    'float32': 'float',
    'double': 'double',
    'float64': 'double',
    'string': 'string',
    'boolean': 'boolean',
    'timestamp': 'timestamp',
    'decimal': 'decimal',
}
|
|
||
|
|
||
def type_to_sql_string(tval):
    """Return the Impala SQL spelling of the ibis type ``tval``."""
    # Decimals carry precision/scale and are formatted specially.
    if isinstance(tval, dt.Decimal):
        return f'decimal({tval.precision}, {tval.scale})'
    type_name = tval.name.lower()
    try:
        return sql_type_names[type_name]
    except KeyError:
        raise com.UnsupportedBackendType(type_name)
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,102 @@ | ||
| import datetime | ||
| import math | ||
|
|
||
| import ibis.expr.types as ir | ||
|
|
||
|
|
||
def _set_literal_format(translator, expr):
    """Format a set literal as a parenthesized, comma-separated list."""
    value_type = expr.type().value_type

    joined = ', '.join(
        translator.translate(ir.literal(element, type=value_type))
        for element in expr.op().value
    )
    return f'({joined})'
|
|
||
|
|
||
| def _boolean_literal_format(translator, expr): | ||
| value = expr.op().value | ||
| return 'TRUE' if value else 'FALSE' | ||
|
|
||
|
|
||
| def _string_literal_format(translator, expr): | ||
| value = expr.op().value | ||
| return "'{}'".format(value.replace("'", "\\'")) | ||
|
|
||
|
|
||
| def _number_literal_format(translator, expr): | ||
| value = expr.op().value | ||
|
|
||
| if math.isfinite(value): | ||
| formatted = repr(value) | ||
| else: | ||
| if math.isnan(value): | ||
| formatted_val = 'NaN' | ||
| elif math.isinf(value): | ||
| if value > 0: | ||
| formatted_val = 'Infinity' | ||
| else: | ||
| formatted_val = '-Infinity' | ||
| formatted = f"CAST({formatted_val!r} AS DOUBLE)" | ||
|
|
||
| return formatted | ||
|
|
||
|
|
||
| def _interval_literal_format(translator, expr): | ||
| return 'INTERVAL {} {}'.format( | ||
| expr.op().value, expr.type().resolution.upper() | ||
| ) | ||
|
|
||
|
|
||
| def _date_literal_format(translator, expr): | ||
| value = expr.op().value | ||
| if isinstance(value, datetime.date): | ||
| value = value.strftime('%Y-%m-%d') | ||
|
|
||
| return repr(value) | ||
|
|
||
|
|
||
| def _timestamp_literal_format(translator, expr): | ||
| value = expr.op().value | ||
| if isinstance(value, datetime.datetime): | ||
| value = value.strftime('%Y-%m-%d %H:%M:%S') | ||
|
|
||
| return repr(value) | ||
|
|
||
|
|
||
# Dispatch table used by `literal`: value-class tag -> formatter function.
literal_formatters = {
    'boolean': _boolean_literal_format,
    'number': _number_literal_format,
    'string': _string_literal_format,
    'interval': _interval_literal_format,
    'timestamp': _timestamp_literal_format,
    'date': _date_literal_format,
    'set': _set_literal_format,
}
|
|
||
|
|
||
def literal(translator, expr):
    """Return the expression as its literal value.

    Dispatches on the expression's value class to the matching entry in
    ``literal_formatters``.

    Raises
    ------
    NotImplementedError
        If the expression's value class has no registered formatter.
    """
    # Order matters: BooleanValue must be tested before NumericValue.
    if isinstance(expr, ir.BooleanValue):
        typeclass = 'boolean'
    elif isinstance(expr, ir.StringValue):
        typeclass = 'string'
    elif isinstance(expr, ir.NumericValue):
        typeclass = 'number'
    elif isinstance(expr, ir.DateValue):
        typeclass = 'date'
    elif isinstance(expr, ir.TimestampValue):
        typeclass = 'timestamp'
    elif isinstance(expr, ir.IntervalValue):
        typeclass = 'interval'
    elif isinstance(expr, ir.SetValue):
        typeclass = 'set'
    else:
        # Name the offending type so unsupported literals are easy to debug;
        # the bare NotImplementedError gave no clue what failed.
        raise NotImplementedError(type(expr).__name__)

    return literal_formatters[typeclass](translator, expr)
|
|
||
|
|
||
def null_literal(translator, expr):
    """NULL renders identically regardless of the expression's type."""
    return 'NULL'
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,362 @@ | ||
| import ibis.common.exceptions as com | ||
| import ibis.expr.datatypes as dt | ||
| import ibis.expr.operations as ops | ||
| import ibis.expr.types as ir | ||
| import ibis.util as util | ||
|
|
||
| from . import aggregate, binary_infix, case, helpers, string, timestamp, window | ||
| from .literal import literal, null_literal | ||
|
|
||
|
|
||
def fixed_arity(func_name, arity):
    """Return a formatter for a call taking exactly ``arity`` arguments."""

    def formatter(translator, expr):
        op = expr.op()
        if len(op.args) != arity:
            raise com.IbisError('incorrect number of args')
        return helpers.format_call(translator, func_name, *op.args)

    return formatter
|
|
||
|
|
||
def unary(func_name):
    """Return a formatter for a one-argument function call ``func_name(x)``."""
    return fixed_arity(func_name, 1)
|
|
||
|
|
||
def not_null(translator, expr):
    """Render an IS NOT NULL test over the operand."""
    operand = translator.translate(expr.op().args[0])
    return f'{operand} IS NOT NULL'
|
|
||
|
|
||
def is_null(translator, expr):
    """Render an IS NULL test over the operand."""
    operand = translator.translate(expr.op().args[0])
    return f'{operand} IS NULL'
|
|
||
|
|
||
def not_(translator, expr):
    """Render logical NOT, parenthesizing compound operands."""
    (arg,) = expr.op().args
    rendered = translator.translate(arg)
    if helpers.needs_parens(arg):
        rendered = helpers.parenthesize(rendered)
    return f'NOT {rendered}'
|
|
||
|
|
||
def negate(translator, expr):
    """Render negation: logical NOT for booleans, arithmetic ``-`` otherwise.

    Booleans are delegated to ``not_``; numeric operands are prefixed with
    ``-`` and parenthesized when compound.
    """
    # Dispatch before translating: the previous version translated the
    # argument first and then discarded that work on the boolean path
    # (not_ translates it again).
    if isinstance(expr, ir.BooleanValue):
        return not_(translator, expr)

    arg = expr.op().args[0]
    formatted_arg = translator.translate(arg)
    if helpers.needs_parens(arg):
        formatted_arg = helpers.parenthesize(formatted_arg)
    return f'-{formatted_arg}'
|
|
||
|
|
||
def ifnull_workaround(translator, expr):
    """Format isnull(a, b), casting integer fallbacks to decimal."""
    first, fallback = expr.op().args

    # work around per #345, #360
    if isinstance(first, ir.DecimalValue) and isinstance(
        fallback, ir.IntegerValue
    ):
        fallback = fallback.cast(first.type())

    return helpers.format_call(translator, 'isnull', first, fallback)
|
|
||
|
|
||
def sign(translator, expr):
    """Render sign(); Impala's sign() returns FLOAT, so cast non-float types."""
    (arg,) = expr.op().args
    inner = f'sign({translator.translate(arg)})'
    target = helpers.type_to_sql_string(expr.type())
    if expr.type() != dt.float:
        return f'CAST({inner} AS {target})'
    return inner
|
|
||
|
|
||
def hashbytes(translator, expr):
    """Render a byte-hash call; only md5/sha1/sha256/sha512 are supported."""
    arg, how = expr.op().args

    arg_formatted = translator.translate(arg)

    # The SQL function name matches the algorithm name exactly.
    if how in ('md5', 'sha1', 'sha256', 'sha512'):
        return f'{how}({arg_formatted})'
    raise NotImplementedError(how)
|
|
||
|
|
||
def log(translator, expr):
    """Render log: natural log without a base, else ``log(base, arg)``."""
    arg, base = expr.op().args
    arg_formatted = translator.translate(arg)

    if base is None:
        return f'ln({arg_formatted})'
    return f'log({translator.translate(base)}, {arg_formatted})'
|
|
||
|
|
||
def value_list(translator, expr):
    """Render a value list as a parenthesized, comma-separated tuple."""
    rendered = ', '.join(
        translator.translate(value) for value in expr.op().values
    )
    return helpers.parenthesize(rendered)
|
|
||
|
|
||
def cast(translator, expr):
    """Render a CAST, special-casing category and temporal sources."""
    arg, target_type = expr.op().args
    arg_formatted = translator.translate(arg)

    # Categories are already int32 under the hood; no SQL cast needed.
    if isinstance(arg, ir.CategoryValue) and target_type == dt.int32:
        return arg_formatted
    # Temporal -> int64 means microseconds since the UNIX epoch.
    if isinstance(arg, ir.TemporalValue) and target_type == dt.int64:
        return f'1000000 * unix_timestamp({arg_formatted})'
    sql_type = helpers.type_to_sql_string(target_type)
    return f'CAST({arg_formatted} AS {sql_type})'
|
|
||
|
|
||
def varargs(func_name):
    """Return a formatter passing every element of ``op.arg`` to the call."""

    def varargs_formatter(translator, expr):
        return helpers.format_call(translator, func_name, *expr.op().arg)

    return varargs_formatter
|
|
||
|
|
||
def between(translator, expr):
    """Render a BETWEEN predicate."""
    comp, lower, upper = map(translator.translate, expr.op().args)
    return f'{comp} BETWEEN {lower} AND {upper}'
|
|
||
|
|
||
def table_array_view(translator, expr):
    """Render a table as an indented, parenthesized subquery."""
    ctx = translator.context
    subquery = ctx.get_compiled_expr(expr.op().table)
    return f'(\n{util.indent(subquery, ctx.indent)}\n)'
|
|
||
|
|
||
def table_column(translator, expr):
    """Render a column reference, alias-qualified or as a subquery.

    Columns from a table foreign to the current SELECT context are emitted
    as a scalar subquery; otherwise the (possibly aliased) quoted name.
    """
    op = expr.op()
    field_name = op.name
    quoted_name = helpers.quote_identifier(field_name, force=True)

    table = op.table
    ctx = translator.context

    # If the column does not originate from the table set in the current
    # SELECT context, we should format as a subquery
    if translator.permit_subquery and ctx.is_foreign_expr(table):
        return table_array_view(
            translator, table.projection([field_name]).to_array()
        )

    if ctx.need_aliases():
        alias = ctx.get_ref(table)
        if alias is not None:
            return f'{alias}.{quoted_name}'

    return quoted_name
|
|
||
|
|
||
def exists_subquery(translator, expr):
    """Render EXISTS / NOT EXISTS over the op's filtered foreign table."""
    op = expr.op()
    ctx = translator.context

    # SELECT 1 is all that is needed inside an EXISTS predicate.
    dummy = ir.literal(1).name(ir.unnamed)
    filtered = op.foreign_table.filter(op.predicates)
    subquery = ctx.get_compiled_expr(filtered.projection([dummy]))

    if isinstance(op, ops.ExistsSubquery):
        keyword = 'EXISTS'
    elif isinstance(op, ops.NotExistsSubquery):
        keyword = 'NOT EXISTS'
    else:
        raise NotImplementedError

    return f'{keyword} (\n{util.indent(subquery, ctx.indent)}\n)'
|
|
||
|
|
||
# XXX this is not added to operation_registry, but looks like impala is
# using it in the tests, and it works, even if it's not imported anywhere
def round(translator, expr):
    """Render ``round(arg[, digits])``."""
    arg, digits = expr.op().args

    arg_formatted = translator.translate(arg)

    if digits is None:
        return f'round({arg_formatted})'
    return f'round({arg_formatted}, {translator.translate(digits)})'
|
|
||
|
|
||
# XXX this is not added to operation_registry, but looks like impala is
# using it in the tests, and it works, even if it's not imported anywhere
def hash(translator, expr):
    """Render a hash call; only Impala's fnv_hash is supported."""
    arg, how = expr.op().args

    if how != 'fnv':
        raise NotImplementedError(how)
    return f'fnv_hash({translator.translate(arg)})'
|
|
||
|
|
||
# Infix-operator formatters; merged into operation_registry below.
binary_infix_ops = {
    # Binary operations
    ops.Add: binary_infix.binary_infix_op('+'),
    ops.Subtract: binary_infix.binary_infix_op('-'),
    ops.Multiply: binary_infix.binary_infix_op('*'),
    ops.Divide: binary_infix.binary_infix_op('/'),
    ops.Power: fixed_arity('pow', 2),
    ops.Modulus: binary_infix.binary_infix_op('%'),
    # Comparisons
    ops.Equals: binary_infix.binary_infix_op('='),
    ops.NotEquals: binary_infix.binary_infix_op('!='),
    ops.GreaterEqual: binary_infix.binary_infix_op('>='),
    ops.Greater: binary_infix.binary_infix_op('>'),
    ops.LessEqual: binary_infix.binary_infix_op('<='),
    ops.Less: binary_infix.binary_infix_op('<'),
    ops.IdenticalTo: binary_infix.identical_to,
    # Boolean comparisons
    ops.And: binary_infix.binary_infix_op('AND'),
    ops.Or: binary_infix.binary_infix_op('OR'),
    ops.Xor: binary_infix.xor,
}
|
|
||
|
|
||
# Master dispatch table: ibis operation class -> SQL formatter function.
operation_registry = {
    # Unary operations
    ops.NotNull: not_null,
    ops.IsNull: is_null,
    ops.Negate: negate,
    ops.Not: not_,
    ops.IsNan: unary('is_nan'),
    ops.IsInf: unary('is_inf'),
    ops.IfNull: ifnull_workaround,
    ops.NullIf: fixed_arity('nullif', 2),
    ops.ZeroIfNull: unary('zeroifnull'),
    ops.NullIfZero: unary('nullifzero'),
    ops.Abs: unary('abs'),
    ops.BaseConvert: fixed_arity('conv', 3),
    ops.Ceil: unary('ceil'),
    ops.Floor: unary('floor'),
    ops.Exp: unary('exp'),
    ops.Round: round,
    ops.Sign: sign,
    ops.Sqrt: unary('sqrt'),
    ops.Hash: hash,
    ops.HashBytes: hashbytes,
    ops.Log: log,
    ops.Ln: unary('ln'),
    ops.Log2: unary('log2'),
    ops.Log10: unary('log10'),
    ops.DecimalPrecision: unary('precision'),
    ops.DecimalScale: unary('scale'),
    # Unary aggregates
    ops.CMSMedian: aggregate.reduction('appx_median'),
    ops.HLLCardinality: aggregate.reduction('ndv'),
    ops.Mean: aggregate.reduction('avg'),
    ops.Sum: aggregate.reduction('sum'),
    ops.Max: aggregate.reduction('max'),
    ops.Min: aggregate.reduction('min'),
    ops.StandardDev: aggregate.variance_like('stddev'),
    ops.Variance: aggregate.variance_like('var'),
    ops.GroupConcat: aggregate.reduction('group_concat'),
    ops.Count: aggregate.reduction('count'),
    ops.CountDistinct: aggregate.count_distinct,
    # string operations
    ops.StringLength: unary('length'),
    ops.StringAscii: unary('ascii'),
    ops.Lowercase: unary('lower'),
    ops.Uppercase: unary('upper'),
    ops.Reverse: unary('reverse'),
    ops.Strip: unary('trim'),
    ops.LStrip: unary('ltrim'),
    ops.RStrip: unary('rtrim'),
    ops.Capitalize: unary('initcap'),
    ops.Substring: string.substring,
    ops.StrRight: fixed_arity('strright', 2),
    ops.Repeat: fixed_arity('repeat', 2),
    ops.StringFind: string.string_find,
    ops.Translate: fixed_arity('translate', 3),
    ops.FindInSet: string.find_in_set,
    ops.LPad: fixed_arity('lpad', 3),
    ops.RPad: fixed_arity('rpad', 3),
    ops.StringJoin: string.string_join,
    ops.StringSQLLike: string.string_like,
    ops.RegexSearch: fixed_arity('regexp_like', 2),
    ops.RegexExtract: fixed_arity('regexp_extract', 3),
    ops.RegexReplace: fixed_arity('regexp_replace', 3),
    ops.ParseURL: string.parse_url,
    ops.StartsWith: string.startswith,
    ops.EndsWith: string.endswith,
    # Timestamp operations
    ops.Date: unary('to_date'),
    ops.TimestampNow: lambda *args: 'now()',
    ops.ExtractYear: timestamp.extract_field('year'),
    ops.ExtractMonth: timestamp.extract_field('month'),
    ops.ExtractDay: timestamp.extract_field('day'),
    ops.ExtractQuarter: timestamp.extract_field('quarter'),
    ops.ExtractEpochSeconds: timestamp.extract_epoch_seconds,
    ops.ExtractWeekOfYear: fixed_arity('weekofyear', 1),
    ops.ExtractHour: timestamp.extract_field('hour'),
    ops.ExtractMinute: timestamp.extract_field('minute'),
    ops.ExtractSecond: timestamp.extract_field('second'),
    ops.ExtractMillisecond: timestamp.extract_field('millisecond'),
    ops.TimestampTruncate: timestamp.truncate,
    ops.DateTruncate: timestamp.truncate,
    ops.IntervalFromInteger: timestamp.interval_from_integer,
    # Other operations
    ops.E: lambda *args: 'e()',
    ops.Literal: literal,
    ops.NullLiteral: null_literal,
    ops.ValueList: value_list,
    ops.Cast: cast,
    ops.Coalesce: varargs('coalesce'),
    ops.Greatest: varargs('greatest'),
    ops.Least: varargs('least'),
    ops.Where: fixed_arity('if', 3),
    ops.Between: between,
    ops.Contains: binary_infix.binary_infix_op('IN'),
    ops.NotContains: binary_infix.binary_infix_op('NOT IN'),
    ops.SimpleCase: case.simple_case,
    ops.SearchedCase: case.searched_case,
    ops.TableColumn: table_column,
    ops.TableArrayView: table_array_view,
    ops.DateAdd: timestamp.timestamp_op('date_add'),
    ops.DateSub: timestamp.timestamp_op('date_sub'),
    ops.DateDiff: timestamp.timestamp_op('datediff'),
    ops.TimestampAdd: timestamp.timestamp_op('date_add'),
    ops.TimestampSub: timestamp.timestamp_op('date_sub'),
    ops.TimestampDiff: timestamp.timestamp_diff,
    ops.TimestampFromUNIX: timestamp.timestamp_from_unix,
    ops.ExistsSubquery: exists_subquery,
    ops.NotExistsSubquery: exists_subquery,
    # RowNumber, and rank functions starts with 0 in Ibis-land
    ops.RowNumber: lambda *args: 'row_number()',
    ops.DenseRank: lambda *args: 'dense_rank()',
    ops.MinRank: lambda *args: 'rank()',
    ops.PercentRank: lambda *args: 'percent_rank()',
    ops.FirstValue: unary('first_value'),
    ops.LastValue: unary('last_value'),
    ops.NthValue: window.nth_value,
    ops.Lag: window.shift_like('lag'),
    ops.Lead: window.shift_like('lead'),
    ops.WindowOp: window.window,
    ops.NTile: window.ntile,
    ops.DayOfWeekIndex: timestamp.day_of_week_index,
    ops.DayOfWeekName: timestamp.day_of_week_name,
    **binary_infix_ops,
}
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,100 @@ | ||
| import ibis.expr.operations as ops | ||
|
|
||
| from . import helpers | ||
|
|
||
|
|
||
def substring(translator, expr):
    """Render substr(); ibis starts are 0-based, Impala's are 1-based."""
    arg, start, length = expr.op().args
    arg_formatted = translator.translate(arg)
    start_formatted = translator.translate(start)

    # Impala is 1-indexed
    if length is None or isinstance(length.op(), ops.Literal):
        # Literal (or absent) lengths can be inlined directly.
        lvalue = length.op().value if length is not None else None
        if lvalue:
            return f'substr({arg_formatted}, {start_formatted} + 1, {lvalue})'
        return f'substr({arg_formatted}, {start_formatted} + 1)'
    length_formatted = translator.translate(length)
    return (
        f'substr({arg_formatted}, {start_formatted} + 1, {length_formatted})'
    )
|
|
||
|
|
||
def string_find(translator, expr):
    """Render locate(); ibis results are 0-based, Impala's are 1-based."""
    arg, substr, start, _ = expr.op().args
    arg_formatted = translator.translate(arg)
    substr_formatted = translator.translate(substr)

    if start is not None and not isinstance(start.op(), ops.Literal):
        # Dynamic start position: shift into 1-based space in SQL.
        start_fmt = translator.translate(start)
        return (
            f'locate({substr_formatted}, {arg_formatted}, {start_fmt} + 1) - 1'
        )
    if start is not None and start.op().value:
        # Literal nonzero start: shift at compile time.
        shifted = start.op().value + 1
        return f'locate({substr_formatted}, {arg_formatted}, {shifted}) - 1'
    return f'locate({substr_formatted}, {arg_formatted}) - 1'
|
|
||
|
|
||
def find_in_set(translator, expr):
    """Render find_in_set() over a literal string list (0-based result)."""
    arg, str_list = expr.op().args
    arg_formatted = translator.translate(arg)
    joined = ','.join(item._arg.value for item in str_list)
    return f"find_in_set({arg_formatted}, '{joined}') - 1"
|
|
||
|
|
||
def string_join(translator, expr):
    """Render concat_ws(sep, s1, s2, ...)."""
    sep, strings = expr.op().args
    return helpers.format_call(translator, 'concat_ws', sep, *strings)
|
|
||
|
|
||
def string_like(translator, expr):
    """Render a LIKE predicate (the escape argument is ignored)."""
    arg, pattern, _ = expr.op().args
    return f'{translator.translate(arg)} LIKE {translator.translate(pattern)}'
|
|
||
|
|
||
def parse_url(translator, expr):
    """Render parse_url(arg, part[, key]) for URL component extraction."""
    arg, extract, key = expr.op().args
    arg_formatted = translator.translate(arg)

    if key is None:
        return f"parse_url({arg_formatted}, '{extract}')"
    key_fmt = translator.translate(key)
    return f"parse_url({arg_formatted}, '{extract}', {key_fmt})"
|
|
||
|
|
||
def startswith(translator, expr):
    """Render a prefix test via ``LIKE concat(prefix, '%')``."""
    arg, prefix = expr.op().args
    return "{} like concat({}, '%')".format(
        translator.translate(arg), translator.translate(prefix)
    )
|
|
||
|
|
||
def endswith(translator, expr):
    """Render a suffix test via ``LIKE concat('%', suffix)``."""
    arg, suffix = expr.op().args
    return "{} like concat('%', {})".format(
        translator.translate(arg), translator.translate(suffix)
    )
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,107 @@ | ||
| import ibis.common.exceptions as com | ||
| import ibis.expr.types as ir | ||
| import ibis.util as util | ||
|
|
||
|
|
||
def extract_field(sql_attr):
    """Return a formatter extracting ``sql_attr`` from a timestamp.

    This is pre-2.0 Impala-style, which did not used to support the
    SQL-99 format extract($FIELD from expr).
    """

    def extract_field_formatter(translator, expr):
        arg = translator.translate(expr.op().args[0])
        return f"extract({arg}, '{sql_attr}')"

    return extract_field_formatter
|
|
||
|
|
||
def extract_epoch_seconds(t, expr):
    """Render epoch-seconds extraction via unix_timestamp()."""
    (arg,) = expr.op().args
    return 'unix_timestamp({})'.format(t.translate(arg))
|
|
||
|
|
||
def truncate(translator, expr):
    """Render trunc(), mapping ibis units to Impala trunc() unit codes."""
    # ibis unit -> Impala trunc() code (note 'D' is Julian-day 'J').
    impala_unit_codes = {
        'Y': 'Y',
        'Q': 'Q',
        'M': 'MONTH',
        'W': 'W',
        'D': 'J',
        'h': 'HH',
        'm': 'MI',
    }
    arg, unit = expr.op().args

    arg_formatted = translator.translate(arg)
    if unit not in impala_unit_codes:
        raise com.UnsupportedOperationError(
            f'{unit!r} unit is not supported in timestamp truncate'
        )

    return f"trunc({arg_formatted}, '{impala_unit_codes[unit]}')"
|
|
||
|
|
||
def interval_from_integer(translator, expr):
    """Render an INTERVAL expression from an integer amount.

    The unit comes from the output type's resolution; the op's own unit
    argument was unpacked but never used, so it is no longer bound.
    """
    # interval cannot be selected from impala
    op = expr.op()
    arg_formatted = translator.translate(op.args[0])

    resolution = expr.type().resolution.upper()
    return f'INTERVAL {arg_formatted} {resolution}'
|
|
||
|
|
||
def timestamp_op(func):
    """Return a formatter for binary timestamp arithmetic via ``func``."""

    def _formatter(translator, expr):
        left, right = expr.op().args

        rendered = []
        for operand in (left, right):
            text = translator.translate(operand)
            # Date/timestamp scalars need an explicit cast for Impala.
            if isinstance(operand, (ir.TimestampScalar, ir.DateValue)):
                text = f'cast({text} as timestamp)'
            rendered.append(text)

        return '{}({}, {})'.format(func, *rendered)

    return _formatter
|
|
||
|
|
||
def timestamp_diff(translator, expr):
    """Render a timestamp difference as a difference of epoch seconds."""
    left, right = expr.op().args
    left_fmt = translator.translate(left)
    right_fmt = translator.translate(right)
    return f'unix_timestamp({left_fmt}) - unix_timestamp({right_fmt})'
|
|
||
|
|
||
| def _from_unixtime(translator, expr): | ||
| arg = translator.translate(expr) | ||
| return f'from_unixtime({arg}, "yyyy-MM-dd HH:mm:ss")' | ||
|
|
||
|
|
||
def timestamp_from_unix(translator, expr):
    """Render a UNIX-epoch value as a timestamp, normalizing to seconds."""
    val, unit = expr.op().args
    seconds = util.convert_unit(val, unit, 's').cast('int32')

    return f'CAST({_from_unixtime(translator, seconds)} AS timestamp)'
|
|
||
|
|
||
def day_of_week_index(t, expr):
    """Render day-of-week with Monday=0; Impala's dayofweek has Sunday=1."""
    (arg,) = expr.op().args
    return 'pmod(dayofweek({}) - 2, 7)'.format(t.translate(arg))
|
|
||
|
|
||
def day_of_week_name(t, expr):
    """Render the weekday name via dayname()."""
    (arg,) = expr.op().args
    return 'dayname({})'.format(t.translate(arg))
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,335 @@ | ||
| from operator import add, mul, sub | ||
| from typing import Optional, Union | ||
|
|
||
| import ibis | ||
| import ibis.common.exceptions as com | ||
| import ibis.expr.analysis as L | ||
| import ibis.expr.datatypes as dt | ||
| import ibis.expr.operations as ops | ||
| import ibis.expr.types as ir | ||
| from ibis.expr.signature import Argument | ||
|
|
||
# Microseconds per unit of each supported interval resolution.
# NOTE(review): 'ns' maps to a float (0.001), so nanosecond arithmetic
# produces floats rather than ints — presumably intentional; confirm.
_map_interval_to_microseconds = {
    'W': 604800000000,
    'D': 86400000000,
    'h': 3600000000,
    'm': 60000000,
    's': 1000000,
    'ms': 1000,
    'us': 1,
    'ns': 0.001,
}
|
|
||
|
|
||
# Interval op class -> Python operator used to fold it into a scalar.
_map_interval_op_to_op = {
    # Literal Intervals have two args, i.e.
    # Literal(1, Interval(value_type=int8, unit='D', nullable=True))
    # Parse both args and multiply 1 * _map_interval_to_microseconds['D']
    ops.Literal: mul,
    ops.IntervalMultiply: mul,
    ops.IntervalAdd: add,
    ops.IntervalSubtract: sub,
}
|
|
||
|
|
||
# Cumulative window op class -> plain reduction op used inside the window.
_cumulative_to_reduction = {
    ops.CumulativeSum: ops.Sum,
    ops.CumulativeMin: ops.Min,
    ops.CumulativeMax: ops.Max,
    ops.CumulativeMean: ops.Mean,
    ops.CumulativeAny: ops.Any,
    ops.CumulativeAll: ops.All,
}
|
|
||
|
|
||
def _replace_interval_with_scalar(
    expr: Union[ir.Expr, dt.Interval, float]
) -> Union[ir.Expr, float, Argument]:
    """
    Good old Depth-First Search to identify the Interval and IntervalValue
    components of the expression and return a comparable scalar expression.

    Parameters
    ----------
    expr : float or expression of intervals
        For example, ``ibis.interval(days=1) + ibis.interval(hours=5)``

    Returns
    -------
    preceding : float or ir.FloatingScalar, depending upon the expr
    """
    if isinstance(expr, ir.Expr):
        expr_op = expr.op()
    else:
        expr_op = None

    if not isinstance(expr, (dt.Interval, ir.IntervalValue)):
        # Literal expressions have op method but native types do not.
        if isinstance(expr_op, ops.Literal):
            return expr_op.value
        else:
            return expr
    elif isinstance(expr, dt.Interval):
        try:
            # Base case: convert the interval's unit to a microsecond factor.
            microseconds = _map_interval_to_microseconds[expr.unit]
            return microseconds
        except KeyError:
            raise ValueError(
                "Expected preceding values of week(), "
                + "day(), hour(), minute(), second(), millisecond(), "
                + f"microseconds(), nanoseconds(); got {expr}"
            )
    elif expr_op.args and isinstance(expr, ir.IntervalValue):
        # Recursive case: fold binary interval arithmetic into a scalar
        # using the operator from _map_interval_op_to_op.
        if len(expr_op.args) > 2:
            raise NotImplementedError("'preceding' argument cannot be parsed.")
        left_arg = _replace_interval_with_scalar(expr_op.args[0])
        right_arg = _replace_interval_with_scalar(expr_op.args[1])
        method = _map_interval_op_to_op[type(expr_op)]
        return method(left_arg, right_arg)
    else:
        raise TypeError(f'expr has unknown type {type(expr).__name__}')
|
|
||
|
|
||
def cumulative_to_window(translator, expr, window):
    """Rewrite a cumulative aggregation as its plain reduction wrapped in a
    cumulative window, preserving the original grouping and ordering."""
    op = expr.op()
    reduction_class = _cumulative_to_reduction[type(op)]
    reduction = reduction_class(*op.args)
    result = expr._factory(reduction, name=expr._name)

    # Apply any backend-specific rewrite registered for the reduction type.
    if type(reduction) in translator._rewrites:
        result = translator._rewrites[type(reduction)](result)

    cumulative = (
        ibis.cumulative_window()
        .group_by(window._group_by)
        .order_by(window._order_by)
    )
    return L.windowize_function(result, cumulative)
|
|
||
|
|
||
def time_range_to_range_window(translator, window):
    """Convert a time-based range window into an integer range window.

    The single time-typed ORDER BY key is reinterpreted as int64 and any
    interval-typed ``preceding`` bound is replaced with a scalar offset.
    """
    # Check that the ORDER BY clause is a single time column.
    order_vars = [key.op().args[0] for key in window._order_by]
    if len(order_vars) > 1:
        raise com.IbisInputError(
            f"Expected 1 order-by variable, got {len(order_vars)}"
        )

    as_integer = order_vars[0].cast('int64')
    window = window._replace(order_by=as_integer, how='range')

    # Interval-typed preceding bounds must become scalar offsets.
    if isinstance(window.preceding, ir.IntervalScalar):
        window = window._replace(
            preceding=_replace_interval_with_scalar(window.preceding)
        )

    return window
|
|
||
|
|
||
def format_window(translator, op, window):
    """Render *window* as a SQL ``OVER (...)`` clause string.

    Parameters
    ----------
    translator : expression translator used for PARTITION BY / ORDER BY keys
    op : the windowed operation; its inner expression decides whether a
        frame clause is allowed at all
    window : window object carrying grouping, ordering, and frame bounds

    Raises
    ------
    NotImplementedError
        If the window specifies ``max_lookback`` (unsupported here).
    """
    components = []

    if window.max_lookback is not None:
        raise NotImplementedError(
            'Rows with max lookback is not implemented '
            'for Impala-based backends.'
        )

    if len(window._group_by) > 0:
        partition_args = [translator.translate(x) for x in window._group_by]
        components.append('PARTITION BY {}'.format(', '.join(partition_args)))

    if len(window._order_by) > 0:
        order_args = []
        for expr in window._order_by:
            key = expr.op()
            translated = translator.translate(key.expr)
            # Descending keys get an explicit DESC; ascending is the default.
            if not key.ascending:
                translated += ' DESC'
            order_args.append(translated)

        components.append('ORDER BY {}'.format(', '.join(order_args)))

    p, f = window.preceding, window.following

    def _prec(p: Optional[int]) -> str:
        """Format a PRECEDING bound; None means unbounded, 0 means current row."""
        assert p is None or p >= 0

        if p is None:
            prefix = 'UNBOUNDED'
        else:
            if not p:
                return 'CURRENT ROW'
            prefix = str(p)
        return f'{prefix} PRECEDING'

    def _foll(f: Optional[int]) -> str:
        """Format a FOLLOWING bound; None means unbounded, 0 means current row."""
        assert f is None or f >= 0

        if f is None:
            prefix = 'UNBOUNDED'
        else:
            if not f:
                return 'CURRENT ROW'
            prefix = str(f)

        return f'{prefix} FOLLOWING'

    # Analytic functions for which SQL forbids an explicit frame clause.
    frame_clause_not_allowed = (
        ops.Lag,
        ops.Lead,
        ops.DenseRank,
        ops.MinRank,
        ops.NTile,
        ops.PercentRank,
        ops.RowNumber,
    )

    if isinstance(op.expr.op(), frame_clause_not_allowed):
        frame = None
    elif p is not None and f is not None:
        # Both bounds given: frame kind comes from the window itself.
        frame = '{} BETWEEN {} AND {}'.format(
            window.how.upper(), _prec(p), _foll(f)
        )

    elif p is not None:
        if isinstance(p, tuple):
            # (start, end) pair entirely on the preceding side.
            start, end = p
            frame = '{} BETWEEN {} AND {}'.format(
                window.how.upper(), _prec(start), _prec(end)
            )
        else:
            # NOTE(review): p == 0 selects RANGE rather than ROWS here —
            # presumably intentional for "current row" framing; confirm.
            kind = 'ROWS' if p > 0 else 'RANGE'
            frame = '{} BETWEEN {} AND UNBOUNDED FOLLOWING'.format(
                kind, _prec(p)
            )
    elif f is not None:
        if isinstance(f, tuple):
            # (start, end) pair entirely on the following side.
            start, end = f
            frame = '{} BETWEEN {} AND {}'.format(
                window.how.upper(), _foll(start), _foll(end)
            )
        else:
            kind = 'ROWS' if f > 0 else 'RANGE'
            frame = '{} BETWEEN UNBOUNDED PRECEDING AND {}'.format(
                kind, _foll(f)
            )
    else:
        # no-op, default is full sample
        frame = None

    if frame is not None:
        components.append(frame)

    return 'OVER ({})'.format(' '.join(components))
|
|
||
|
|
||
# Wraps an already-formatted SQL expression in "(... - 1)", shifting the
# 1-based results of SQL ranking functions to ibis's 0-based convention.
_subtract_one = '({} - 1)'.format


# Post-translation transforms applied to the formatted window expression,
# keyed by analytic operation type (see window()).
_expr_transforms = {
    ops.RowNumber: _subtract_one,
    ops.DenseRank: _subtract_one,
    ops.MinRank: _subtract_one,
    ops.NTile: _subtract_one,
}
|
|
||
|
|
||
def window(translator, expr):
    """Translate a WindowOp expression into ``<expr> OVER (...)`` SQL.

    Handles special cases in order: rejecting unsupported reductions,
    expanding cumulative ops, injecting a default ORDER BY for functions
    that require one, and converting time-based range windows.

    Raises
    ------
    com.UnsupportedOperationError
        If the windowed operation cannot be used in a window function.
    """
    op = expr.op()

    arg, window = op.args
    window_op = arg.op()

    # Functions whose SQL forms require an ORDER BY in the window clause.
    _require_order_by = (
        ops.Lag,
        ops.Lead,
        ops.DenseRank,
        ops.MinRank,
        ops.FirstValue,
        ops.LastValue,
        ops.PercentRank,
        ops.NTile,
    )

    # Reductions with no window-function equivalent in this backend.
    _unsupported_reductions = (
        ops.CMSMedian,
        ops.GroupConcat,
        ops.HLLCardinality,
    )

    if isinstance(window_op, _unsupported_reductions):
        raise com.UnsupportedOperationError(
            f'{type(window_op)} is not supported in window functions'
        )

    if isinstance(window_op, ops.CumulativeOp):
        # Cumulative ops are rewritten as plain reductions over a
        # cumulative window, then translated from scratch.
        arg = cumulative_to_window(translator, arg, window)
        return translator.translate(arg)

    # Some analytic functions need to have the expression of interest in
    # the ORDER BY part of the window clause
    if isinstance(window_op, _require_order_by) and len(window._order_by) == 0:
        window = window.order_by(window_op.args[0])

    # Time ranges need to be converted to microseconds.
    if window.how == 'range':
        order_by_types = [type(x.op().args[0]) for x in window._order_by]
        time_range_types = (ir.TimeColumn, ir.DateColumn, ir.TimestampColumn)
        if any(col_type in time_range_types for col_type in order_by_types):
            window = time_range_to_range_window(translator, window)

    window_formatted = format_window(translator, op, window)

    arg_formatted = translator.translate(arg)
    result = f'{arg_formatted} {window_formatted}'

    # Ranking functions are shifted from 1-based SQL results to 0-based.
    if type(window_op) in _expr_transforms:
        return _expr_transforms[type(window_op)](result)
    else:
        return result
|
|
||
|
|
||
def nth_value(translator, expr):
    """Compile NthValue as ``first_value(lag(...))``.

    The rank is converted to a lag offset of ``rank - 1`` so that the
    lagged first value lands on the requested position.
    """
    value, rank = expr.op().args
    value_sql = translator.translate(value)
    offset_sql = translator.translate(rank - 1)
    return f'first_value(lag({value_sql}, {offset_sql}))'
|
|
||
|
|
||
def shift_like(name):
    """Build a translation rule for lag/lead-style functions named *name*.

    The returned formatter emits the shortest call form that still carries
    the operation's offset and default arguments.
    """

    def formatter(translator, expr):
        column, offset, default = expr.op().args
        column_sql = translator.translate(column)

        if default is None:
            # No default: one- or two-argument form.
            if offset is None:
                return f'{name}({column_sql})'
            return f'{name}({column_sql}, {translator.translate(offset)})'

        # A default requires the three-argument form; a missing offset
        # falls back to 1.
        offset_sql = '1' if offset is None else translator.translate(offset)
        default_sql = translator.translate(default)
        return f'{name}({column_sql}, {offset_sql}, {default_sql})'

    return formatter
|
|
||
|
|
||
def ntile(translator, expr):
    """Compile NTile; only the bucket count appears in the generated SQL.

    Both operands are still translated (matching the original traversal),
    but the value argument's rendering is discarded.
    """
    _arg_sql, bucket_sql = (
        translator.translate(raw) for raw in expr.op().args
    )
    return f'ntile({bucket_sql})'