diff --git a/beanquery/compiler.py b/beanquery/compiler.py index 6e6c442e..ee0758ea 100644 --- a/beanquery/compiler.py +++ b/beanquery/compiler.py @@ -22,12 +22,14 @@ EvalColumn, EvalConstant, EvalCreateTable, + EvalUnion, EvalGetItem, EvalGetter, EvalInsert, EvalOr, EvalPivot, EvalQuery, + EvalSelect, EvalConstantSubquery1D, EvalRow, EvalTarget, @@ -96,6 +98,172 @@ def _compile(self, node: Optional[ast.Node]): return None raise NotImplementedError + @_compile.register + def _query(self, node: ast.Query): + set_operators = node.set_operators or [] + + if set_operators: + # SELECT ... UNION SELECT ... [LIMIT|ORDER BY|PIVOT BY] + return self._compile_union_query(node) + else: + # SELECT ... [LIMIT | ORDER BY | PIVOT BY] + return self._compile_single_select_query(node) + + def _compile_single_select_query(self, node: ast.Query): + # Compile a query that consists of a single SELECT only. + + assert not node.set_operators + + select = node.queries[0] + eval_select = self._select(select) + + # ORDER BY belongs to the enclosing Query, not the Select. + new_targets, order_spec = self._compile_order_by(node.order_by, eval_select.c_targets) + eval_select.c_targets.extend(new_targets) + + # DISTINCT with ORDER BY on columns not in SELECT produces non-deterministic + # results: when multiple rows have the same visible values but different + # ORDER BY values, which row survives DISTINCT is arbitrary. + # We allow ORDER BY f(x) if x is visible, since f(x) is computable from x. + if eval_select.distinct and new_targets: + visible_column_ids = set() + for t in eval_select.c_targets: + if t.name is not None: + visible_column_ids.update(id(c) for c in _collect_columns(t.c_expr)) + + for target in new_targets: + for col in _collect_columns(target.c_expr): + if id(col) not in visible_column_ids: + raise CompilationError( + f'When using DISTINCT, ORDER BY expressions must only ' + f'reference columns that appear in the SELECT list. ' + f'Offending ORDER BY expression: {node.order_by[0].column.text}') + + # If this is an aggregate query (it groups, see list of indexes), check that + # the set of non-aggregates match exactly the group indexes. This should + # always be the case at this point, because we have added all the necessary + # targets to the list of group-by expressions and should have resolved all + # the indexes. + if eval_select.group_indexes is not None: + non_aggregate_indexes = {i for i, t in enumerate(eval_select.c_targets) + if not t.is_aggregate} + if non_aggregate_indexes != set(eval_select.group_indexes): + missing_names = ['"{}"'.format(eval_select.c_targets[i].name) + for i in non_aggregate_indexes - set(eval_select.group_indexes)] + raise CompilationError( + 'all non-aggregates must be covered by GROUP-BY clause in aggregate query: ' + 'the following targets are missing: {}'.format(','.join(missing_names))) + + # Wrap in EvalQuery with ORDER BY and LIMIT. + eval_query = EvalQuery( + select=eval_select, + order_spec=order_spec, + limit=node.limit, + ) + + # PIVOT applies to the final sorted/paged result set. + pivots = self._compile_pivot_by(node.pivot_by, eval_select.c_targets, eval_select.group_indexes) + if pivots: + return EvalPivot(eval_query, pivots) + + return eval_query + + def _compile_union_query(self, node: ast.Query): + # UNION chain: compile each SELECT against the original table. + # Each operand is wrapped in EvalQuery for consistent interface of + # plain `SELECT ... UNION SELECT ...` vs. subqueries + # `(SELECT ...) UNION (SELECT ...)` + # + # Plain SELECTs get order_spec=[] and limit=None since the BQL + # grammar does not allow ORDER BY or LIMIT on plain UNION operands + # + # SELECT a FROM x UNION SELECT b FROM y ORDER BY 1 + # -- ORDER BY applies to the entire UNION result + # + # To apply ORDER BY to an individual operand, use a subquery: + # + # (SELECT a FROM x ORDER BY a LIMIT 10) UNION SELECT b FROM y + # -- ORDER BY and LIMIT apply only to the first operand + saved_table = self.table + compiled = [] + set_operators = node.set_operators or [] + + for select in node.queries: + self.table = saved_table + eq = self._compile(select) + if isinstance(eq, EvalSelect): + # We have a plain SELECT operand (no subquery in parentheses, then we would get + # type EvalQuery by a recursive call of this function with node.set_operators = []) + eq = EvalQuery(select=eq, order_spec=[], limit=None) + compiled.append(eq) + + # Validate UNION operands: all must have the same column count and + # compatible types. We do not auto-coerce except between int and + # Decimal + first_targets = compiled[0].c_targets + for i, eq in enumerate(compiled[1:], start=1): + if len(eq.c_targets) != len(first_targets): + raise CompilationError( + f'UNION operands must have the same number of columns: ' + f'operand 0 has {len(first_targets)} columns, ' + f'operand {i} has {len(eq.c_targets)} columns') + for j, (t1, t2) in enumerate(zip(first_targets, eq.c_targets)): + if t1.c_expr.dtype != t2.c_expr.dtype: + # Try to coerce numeric types (int ↔ Decimal) to a common type. + # We will not auto-coerce other types, as they might be hard to catch by + # the user when they debug their BQL statement. + type_pair = {t1.c_expr.dtype, t2.c_expr.dtype} + if type_pair == {int, Decimal}: + # Both are numeric, coerce to Decimal. + coerced1 = self._try_coerce_operand(t1.c_expr, Decimal) + coerced2 = self._try_coerce_operand(t2.c_expr, Decimal) + if coerced1 is not None: + first_targets[j] = EvalTarget(coerced1, t1.name, t1.is_aggregate) + if coerced2 is not None: + eq.c_targets[j] = EvalTarget(coerced2, t2.name, t2.is_aggregate) + else: + raise CompilationError( + f'UNION operands have type mismatch at position {j}: ' + f'{t1.c_expr.dtype.__name__} vs {t2.c_expr.dtype.__name__}') + + # EvalUnion has the same interface as EvalSelect, so we wrap it in + # EvalQuery which handles ORDER BY, LIMIT, and visible column extraction. + eval_union = EvalUnion( + queries=compiled, + set_operators=set_operators, + c_targets=first_targets, + ) + + # ORDER BY targets are resolved against the first operand's columns. + # Unlike single SELECT, UNION does not support invisible ORDER BY columns + # because each operand may have different source tables. ORDER BY must + # reference columns in the SELECT list (by name, alias, or position). + new_targets, order_spec = self._compile_order_by(node.order_by, first_targets) + if new_targets: + # Find the first offending ORDER BY expression (one that created a new target). + # new_targets are indexed starting at len(first_targets). + n_original = len(first_targets) + offending_indexes = {idx for idx, _ in order_spec if idx >= n_original} + offending_text = None + for i, spec in enumerate(node.order_by): + idx, _ = order_spec[i] + if idx >= n_original and hasattr(spec.column, 'text'): + offending_text = spec.column.text + break + raise CompilationError( + f'UNION queries only support ORDER BY on expressions that appear in the ' + f'SELECT list. Any column or expression in ORDER BY must be added as a ' + f'column to all SELECT clauses in the UNION. ' + f'Offending expression: {offending_text or "unknown"}') + + eval_query = EvalQuery( + select=eval_union, + order_spec=order_spec, + limit=node.limit, + ) + + return eval_query + @_compile.register def _select(self, node: ast.Select): self.stack.append(self.table) @@ -123,47 +291,26 @@ def _select(self, node: ast.Select): new_targets, group_indexes, having_index = self._compile_group_by(node.group_by, c_targets) c_targets.extend(new_targets) - # Process the ORDER-BY clause. - new_targets, order_spec = self._compile_order_by(node.order_by, c_targets) - c_targets.extend(new_targets) - - # If this is an aggregate query (it groups, see list of indexes), check that - # the set of non-aggregates match exactly the group indexes. This should - # always be the case at this point, because we have added all the necessary - # targets to the list of group-by expressions and should have resolved all - # the indexes. - if group_indexes is not None: - non_aggregate_indexes = {index for index, c_target in enumerate(c_targets) - if not c_target.is_aggregate} - if non_aggregate_indexes != set(group_indexes): - missing_names = ['"{}"'.format(c_targets[index].name) - for index in non_aggregate_indexes - set(group_indexes)] - raise CompilationError( - 'all non-aggregates must be covered by GROUP-BY clause in aggregate query: ' - 'the following targets are missing: {}'.format(','.join(missing_names))) - - query = EvalQuery(self.table, - c_targets, - c_where, - group_indexes, - having_index, - order_spec, - node.limit, - node.distinct) - - pivots = self._compile_pivot_by(node.pivot_by, c_targets, group_indexes) - if pivots: - return EvalPivot(query, pivots) + # ORDER BY and LIMIT are compiled by the enclosing _query handler, + # which also validates aggregate coverage after ORDER BY targets are added. + select = EvalSelect( + table=self.table, + c_targets=c_targets, + c_where=c_where, + group_indexes=group_indexes, + having_index=having_index, + distinct=node.distinct, + ) self.stack.pop() - return query + return select def _compile_from(self, node): if node is None: return None # Subquery. - if isinstance(node, ast.Select): + if isinstance(node, ast.Query): self.table = SubqueryTable(self._compile(node)) return None @@ -653,6 +800,34 @@ def _inop(self, node: Union[ast.In, ast.NotIn]): op = OPERATORS[type(node)][0] return op(left, right) + def _try_coerce_operand(self, operand, target_type): + """Attempt to coerce an operand to a target type. + + Args: + operand: An EvalNode to coerce. + target_type: The desired type to coerce to. + + Returns: + Coerced EvalNode if coercion is possible, None otherwise. + """ + if operand.dtype == target_type: + return operand + + # The Beancount parser does not emit int typed values, thus casting to int + # is only going to loose information. Promote to decimal. + if target_type is int: + target_type = Decimal + + name = types.MAP.get(target_type) + if name is None: + return None + + func = types.function_lookup(FUNCTIONS, name, [operand]) + if func is None: + return None + + return func(self.context, [operand]) + @_compile.register def _binaryop(self, node: ast.BinaryOp): left = self._compile(node.left) @@ -671,28 +846,16 @@ def _binaryop(self, node: ast.BinaryOp): # Implement type inference when one of the operands is not strongly typed. if left.dtype is object and right.dtype is not object: - target = right.dtype - if target is int: - # The Beancount parser does not emit int typed - # values, thus casting to int is only going to - # loose information. Promote to decimal. - target = Decimal - name = types.MAP.get(target) - if name is None: + coerced = self._try_coerce_operand(left, right.dtype) + if coerced is None: break - left = types.function_lookup(FUNCTIONS, name, [left])(self.context, [left]) + left = coerced continue if right.dtype is object and left.dtype is not object: - target = left.dtype - if target is int: - # The Beancount parser does not emit int typed - # values, thus casting to int is only going to - # loose information. Promote to decimal. - target = Decimal - name = types.MAP.get(target) - if name is None: + coerced = self._try_coerce_operand(right, left.dtype) + if coerced is None: break - right = types.function_lookup(FUNCTIONS, name, [right])(self.context, [right]) + right = coerced continue # Failure. @@ -735,7 +898,11 @@ def _print(self, node: ast.Print): self.table = self.context.tables.get('entries') expr = self._compile_from(node.from_clause) targets = [EvalTarget(EvalRow(), 'ROW(*)', False)] - return EvalQuery(self.table, targets, expr, None, None, None, None, False) + return EvalQuery( + select=EvalSelect(self.table, targets, expr, None, None, False), + order_spec=None, + limit=None, + ) @_compile.register def _create_table(self, node: ast.CreateTable): @@ -789,7 +956,7 @@ def transform_journal(journal): Returns: An instance of an uncompiled Select object. """ - cooked_select = parser.parse(""" + cooked = parser.parse(""" SELECT date, @@ -804,12 +971,15 @@ def transform_journal(journal): """.format(where=('WHERE account ~ "{}"'.format(journal.account) if journal.account else ''), - summary_func=journal.summary_func or '')) + summary_func=journal.summary_func or '')).queries[0] - return ast.Select(cooked_select.targets, - journal.from_clause, - cooked_select.where_clause, - None, None, None, None, None) + select = ast.Select( + cooked.targets, + journal.from_clause, + cooked.where_clause, + None, None) + + return ast.Query(queries=[select], set_operators=[], order_by=None, limit=None, pivot_by=None) def transform_balances(balances): @@ -826,20 +996,22 @@ def transform_balances(balances): ## the first or last sort-order value gets used, because it would simplify ## the input statement. - cooked_select = parser.parse(""" + cooked_query = parser.parse(""" SELECT account, SUM({}(position)) GROUP BY account, ACCOUNT_SORTKEY(account) ORDER BY ACCOUNT_SORTKEY(account) """.format(balances.summary_func or "")) + cooked = cooked_query.queries[0] - return ast.Select(cooked_select.targets, - balances.from_clause, - balances.where_clause, - cooked_select.group_by, - cooked_select.order_by, - None, None, None) + select = ast.Select( + cooked.targets, + balances.from_clause, + balances.where_clause, + cooked.group_by, + None) + return ast.Query(queries=[select], set_operators=[], order_by=cooked_query.order_by, limit=None, pivot_by=None) def get_target_name(target): @@ -909,5 +1081,13 @@ def is_aggregate(node): return bool(aggregates) +def _collect_columns(node): + """Recursively collect all EvalColumn nodes from an expression tree.""" + if isinstance(node, EvalColumn): + yield node + for child in node.childnodes(): + yield from _collect_columns(child) + + def compile(context, statement, parameters=None): return Compiler(context).compile(statement, parameters) diff --git a/beanquery/parser/ast.py b/beanquery/parser/ast.py index 0fc9e49e..78ffeeb8 100644 --- a/beanquery/parser/ast.py +++ b/beanquery/parser/ast.py @@ -81,11 +81,22 @@ def node(name, fields): # from_clause: An instance of 'From', or None if absent. # where_clause: A root expression node, or None if absent. # group_by: An instance of 'GroupBy', or None if absent. -# order_by: An instance of 'OrderBy', or None if absent. -# pivot_by: An instance of 'PivotBy', or None if absent. -# limit: An integer, or None is absent. # distinct: A boolean value (True), or None if absent. -Select = node('Select', 'targets from_clause where_clause group_by order_by pivot_by limit distinct') +Select = node('Select', 'targets from_clause where_clause group_by distinct') + +# The top-level query node wrapping one or more SELECT bodies. +# +# A single SELECT is the degenerate case (len(queries) == 1). +# In the future, UNION chain support will be added where len(queries) > 1. +# +# Attributes: +# queries: List of Select nodes. +# set_operators: List of set-operator names between adjacent queries, e.g. +# 'union' or 'union_all'. len == len(queries) - 1. +# order_by: Optional list of OrderBy applied to the combined result. +# limit: Optional integer limit applied to the combined result. +# pivot_by: Optional PivotBy applied to the combined result. +Query = node('Query', 'queries set_operators order_by limit pivot_by') # A select query that produces final balances for accounts. # This is equivalent to diff --git a/beanquery/parser/bql.ebnf b/beanquery/parser/bql.ebnf index 1c078b63..c6b93dda 100644 --- a/beanquery/parser/bql.ebnf +++ b/beanquery/parser/bql.ebnf @@ -15,7 +15,7 @@ bql statement = - | select + | query | balances | journal | print @@ -23,18 +23,34 @@ statement | insert ; +(* Wrapper for queries, which can be either a plain SELECT or a chain of + SELECTs combined by set operators (UNION, UNION ALL). ORDER BY, LIMIT, + and PIVOT BY after the last operand apply to the combined result set. + The list `operators` has one entry per inter-operand connector, always one + less than the number of queries. *) +query::Query + = queries+:( select | subquery ) + { ( 'UNION' 'ALL' set_operators+:`'union_all'` + | 'UNION' set_operators+:`'union'` + ) queries+:( select | subquery) + }* + ['ORDER' 'BY' order_by:','.{order}+] + ['LIMIT' limit:integer] + ['PIVOT' 'BY' pivot_by:pivotby] + ; + +(* SELECT body without ORDER BY / LIMIT / PIVOT BY so the enclosing query rule can claim + those tokens for the combined result set. *) select::Select = 'SELECT' ['DISTINCT' distinct:`True`] targets:(','.{ target }+ | asterisk) - ['FROM' from_clause:(_table | subselect | from)] + ['FROM' from_clause:(_table | subquery | from)] ['WHERE' where_clause:expression] ['GROUP' 'BY' group_by:groupby] - ['ORDER' 'BY' order_by:','.{order}+] - ['PIVOT' 'BY' pivot_by:pivotby] - ['LIMIT' limit:integer] ; -subselect - = '(' @:select ')' +(* Parenthesised sub-query; uses query so ORDER BY / LIMIT are allowed inside. *) +subquery + = '(' @:query ')' ; from::From @@ -132,12 +148,20 @@ comparison | sum ; +(* This operator is special in that it has parentheses. Avoid double parentheses +with subquerys ALL( (SELECT ...) ) by &(...) look-ahead *) any::Any - = left:sum op:op 'any' '(' right:expression ')' + = + | left:sum op:op 'any' &('(' 'SELECT') right:subquery + | left:sum op:op 'any' '(' right:expression ')' ; +(* This operator is special in that it has parentheses. Avoid double parentheses +with subquerys ALL( (SELECT ...) ) by &(...) look-ahead *) all::All - = left:sum op:op 'all' '(' right:expression ')' + = + | left:sum op:op 'all' &('(' 'SELECT') right:subquery + | left:sum op:op 'all' '(' right:expression ')' ; op @@ -282,7 +306,7 @@ subscript::Subscript atom = - | select + | subquery | function | constant | column @@ -390,7 +414,7 @@ create_table::CreateTable ( | '(' columns:','.{( identifier identifier )} ')' ['USING' using:string] | 'USING' using:string - | 'AS' query:select + | 'AS' query:query ) ; diff --git a/beanquery/parser/parser.py b/beanquery/parser/parser.py index 2ec9ee84..1a47e244 100644 --- a/beanquery/parser/parser.py +++ b/beanquery/parser/parser.py @@ -9,7 +9,7 @@ # Any changes you make to it will be overwritten the next time # the file is generated. -# ruff: noqa: C405, COM812, I001, F401, PLR1702, PLC2801, SIM117 +# ruff: noqa: RUF100, C405, COM812, I001, F401, PLR1702, PLC2801, SIM117 from __future__ import annotations @@ -25,34 +25,34 @@ KEYWORDS: set[str] = { - 'AND', + 'USING', + 'INSERT', 'AS', + 'WHERE', + 'AND', + 'TRUE', 'ASC', + 'NOT', + 'TABLE', + 'GROUP', + 'LIMIT', + 'IS', + 'SELECT', + 'IN', 'BY', + 'PRINT', + 'CREATE', 'DESC', - 'DISTINCT', + 'JOURNAL', 'FALSE', 'FROM', - 'GROUP', 'HAVING', - 'IN', - 'IS', - 'LIMIT', - 'NOT', - 'OR', 'ORDER', - 'PIVOT', - 'SELECT', - 'TRUE', - 'WHERE', - 'CREATE', - 'TABLE', - 'USING', - 'INSERT', 'INTO', + 'PIVOT', + 'OR', 'BALANCES', - 'JOURNAL', - 'PRINT', + 'DISTINCT', } @@ -60,7 +60,6 @@ class BQLBuffer(Buffer): def __init__(self, text, /, config: ParserConfig | None = None, **settings): config = ParserConfig.new( config, - owner=self, whitespace=None, nameguard=None, ignorecase=True, @@ -80,7 +79,6 @@ class BQLParser(Parser): def __init__(self, /, config: ParserConfig | None = None, **settings): config = ParserConfig.new( config, - owner=self, whitespace=None, nameguard=None, ignorecase=True, @@ -107,7 +105,7 @@ def _bql_(self): def _statement_(self): with self._choice(): with self._option(): - self._select_() + self._query_() with self._option(): self._balances_() with self._option(): @@ -121,11 +119,93 @@ def _statement_(self): self._error( 'expecting one of: ' "'BALANCES' 'CREATE' 'INSERT' 'JOURNAL'" - "'PRINT' 'SELECT' " - ' ' - ' ' ) + @tatsumasu('Query') + def _query_(self): + with self._group(): + with self._choice(): + with self._option(): + self._select_() + with self._option(): + self._subquery_() + self._error( + 'expecting one of: ' + ' ' + ) + self.add_last_node_to_name('queries') + self._define( + [], + ['queries', 'set_operators'], + ) + self._closure(block0) + with self._optional(): + self._token('ORDER') + self._token('BY') + + def sep1(): + self._token(',') + + def block2(): + self._order_() + self._positive_gather(block2, sep1) + self.name_last_node('order_by') + self._define(['order_by'], []) + with self._optional(): + self._token('LIMIT') + self._integer_() + self.name_last_node('limit') + self._define(['limit'], []) + with self._optional(): + self._token('PIVOT') + self._token('BY') + self._pivotby_() + self.name_last_node('pivot_by') + self._define(['pivot_by'], []) + self._define( + ['limit', 'order_by', 'pivot_by'], + ['queries', 'set_operators'], + ) + @tatsumasu('Select') def _select_(self): self._token('SELECT') @@ -158,12 +238,12 @@ def block1(): with self._option(): self.__table_() with self._option(): - self._subselect_() + self._subquery_() with self._option(): self._from_() self._error( 'expecting one of: ' - '<_table> ' + '<_table> ' ) self.name_last_node('from_clause') self._define(['from_clause'], []) @@ -178,35 +258,12 @@ def block1(): self._groupby_() self.name_last_node('group_by') self._define(['group_by'], []) - with self._optional(): - self._token('ORDER') - self._token('BY') - - def sep2(): - self._token(',') - - def block3(): - self._order_() - self._positive_gather(block3, sep2) - self.name_last_node('order_by') - self._define(['order_by'], []) - with self._optional(): - self._token('PIVOT') - self._token('BY') - self._pivotby_() - self.name_last_node('pivot_by') - self._define(['pivot_by'], []) - with self._optional(): - self._token('LIMIT') - self._integer_() - self.name_last_node('limit') - self._define(['limit'], []) - self._define(['distinct', 'from_clause', 'group_by', 'limit', 'order_by', 'pivot_by', 'targets', 'where_clause'], []) + self._define(['distinct', 'from_clause', 'group_by', 'targets', 'where_clause'], []) @tatsumasu() - def _subselect_(self): + def _subquery_(self): self._token('(') - self._select_() + self._query_() self.name_last_node('@') self._token(')') @@ -586,29 +643,67 @@ def _comparison_(self): @tatsumasu('Any') @nomemo def _any_(self): - self._sum_() - self.name_last_node('left') - self._op_() - self.name_last_node('op') - self._token('any') - self._token('(') - self._expression_() - self.name_last_node('right') - self._token(')') - self._define(['left', 'op', 'right'], []) + with self._choice(): + with self._option(): + self._sum_() + self.name_last_node('left') + self._op_() + self.name_last_node('op') + self._token('any') + with self._if(): + with self._group(): + self._token('(') + self._token('SELECT') + self._subquery_() + self.name_last_node('right') + self._define(['left', 'op', 'right'], []) + with self._option(): + self._sum_() + self.name_last_node('left') + self._op_() + self.name_last_node('op') + self._token('any') + self._token('(') + self._expression_() + self.name_last_node('right') + self._token(')') + self._define(['left', 'op', 'right'], []) + self._error( + 'expecting one of: ' + ' ' + ) @tatsumasu('All') def _all_(self): - self._sum_() - self.name_last_node('left') - self._op_() - self.name_last_node('op') - self._token('all') - self._token('(') - self._expression_() - self.name_last_node('right') - self._token(')') - self._define(['left', 'op', 'right'], []) + with self._choice(): + with self._option(): + self._sum_() + self.name_last_node('left') + self._op_() + self.name_last_node('op') + self._token('all') + with self._if(): + with self._group(): + self._token('(') + self._token('SELECT') + self._subquery_() + self.name_last_node('right') + self._define(['left', 'op', 'right'], []) + with self._option(): + self._sum_() + self.name_last_node('left') + self._op_() + self.name_last_node('op') + self._token('all') + self._token('(') + self._expression_() + self.name_last_node('right') + self._token(')') + self._define(['left', 'op', 'right'], []) + self._error( + 'expecting one of: ' + ' ' + ) @tatsumasu() def _op_(self): @@ -911,9 +1006,9 @@ def _primary_(self): self._atom_() self._error( 'expecting one of: ' - "'SELECT' " + "'(' " ' ' - ' ' + ' ' ) @tatsumasu('Placeholder') @@ -1227,7 +1322,7 @@ def block1(): self._define(['using'], []) with self._option(): self._token('AS') - self._select_() + self._query_() self.name_last_node('query') self._define(['query'], []) self._error( diff --git a/beanquery/parser_test.py b/beanquery/parser_test.py index 33f47fdb..96166df8 100644 --- a/beanquery/parser_test.py +++ b/beanquery/parser_test.py @@ -15,32 +15,43 @@ def Select(targets, from_clause=None, where_clause=None, **kwargs): from_clause=from_clause, where_clause=where_clause, group_by=None, - order_by=None, - pivot_by=None, - limit=None, distinct=None) defaults.update(kwargs) return ast.Select(**defaults) +def Query(select=None, queries=None, set_operators=None, order_by=None, limit=None, pivot_by=None): + """Build an ast.Query wrapping a single Select, for test assertions.""" + return ast.Query( + queries=[select] if queries is None else queries, + set_operators=set_operators or [], + order_by=order_by, + limit=limit, + pivot_by=pivot_by) + + class QueryParserTestBase(unittest.TestCase): def parse(self, query): return parser.parse(query.strip()) def assertParse(self, query, expected): + # Convenience: a bare ast.Select expected is auto-wrapped in ast.Query. + if isinstance(expected, ast.Select): + expected = ast.Query(queries=[expected], set_operators=[], order_by=None, limit=None, pivot_by=None) self.assertEqual(parser.parse(query), expected) def assertParseTarget(self, query, expected): expr = parser.parse(query) - self.assertIsInstance(expr, ast.Select) - self.assertEqual(len(expr.targets), 1) - self.assertEqual(expr.targets[0].expression, expected) + self.assertIsInstance(expr, ast.Query) + select = expr.queries[0] + self.assertEqual(len(select.targets), 1) + self.assertEqual(select.targets[0].expression, expected) def assertParseFrom(self, query, expected): expr = parser.parse(query) - self.assertIsInstance(expr, ast.Select) - self.assertEqual(expr.from_clause, expected) + self.assertIsInstance(expr, ast.Query) + self.assertEqual(expr.queries[0].from_clause, expected) class TestParseSelect(QueryParserTestBase): @@ -289,18 +300,19 @@ def test_from_select(self): SELECT a, b FROM ( SELECT * FROM date = 2014-05-02 ) WHERE c = 5 LIMIT 100;""", - Select([ - ast.Target(ast.Column('a'), None), - ast.Target(ast.Column('b'), None)], - Select( - ast.Asterisk(), - ast.From( - ast.Equal( - ast.Column('date'), - ast.Constant(datetime.date(2014, 5, 2))), - None, None, None)), - ast.Equal(ast.Column('c'), ast.Constant(5)), - limit=100)) + Query( + Select([ + ast.Target(ast.Column('a'), None), + ast.Target(ast.Column('b'), None)], + Query(Select( + ast.Asterisk(), + ast.From( + ast.Equal( + ast.Column('date'), + ast.Constant(datetime.date(2014, 5, 2))), + None, None, None))), + ast.Equal(ast.Column('c'), ast.Constant(5))), + limit=100)) class TestSelectGroupBy(QueryParserTestBase): @@ -362,41 +374,38 @@ class TestSelectOrderBy(QueryParserTestBase): def test_orderby_one(self): self.assertParse( "SELECT * ORDER BY a;", - Select(ast.Asterisk(), - order_by=[ - ast.OrderBy(ast.Column('a'), ast.Ordering.ASC)])) + Query(Select(ast.Asterisk()), + order_by=[ast.OrderBy(ast.Column('a'), ast.Ordering.ASC)])) def test_orderby_many(self): self.assertParse( "SELECT * ORDER BY a, b, c;", - Select(ast.Asterisk(), - order_by=[ - ast.OrderBy(ast.Column('a'), ast.Ordering.ASC), - ast.OrderBy(ast.Column('b'), ast.Ordering.ASC), - ast.OrderBy(ast.Column('c'), ast.Ordering.ASC)])) + Query(Select(ast.Asterisk()), + order_by=[ + ast.OrderBy(ast.Column('a'), ast.Ordering.ASC), + ast.OrderBy(ast.Column('b'), ast.Ordering.ASC), + ast.OrderBy(ast.Column('c'), ast.Ordering.ASC)])) def test_orderby_asc(self): self.assertParse( "SELECT * ORDER BY a ASC;", - Select(ast.Asterisk(), - order_by=[ - ast.OrderBy(ast.Column('a'), ast.Ordering.ASC)])) + Query(Select(ast.Asterisk()), + order_by=[ast.OrderBy(ast.Column('a'), ast.Ordering.ASC)])) def test_orderby_desc(self): self.assertParse( "SELECT * ORDER BY a DESC;", - Select(ast.Asterisk(), - order_by=[ - ast.OrderBy(ast.Column('a'), ast.Ordering.DESC)])) + Query(Select(ast.Asterisk()), + order_by=[ast.OrderBy(ast.Column('a'), ast.Ordering.DESC)])) def test_orderby_many_asc_desc(self): self.assertParse( "SELECT * ORDER BY a ASC, b DESC, c;", - Select(ast.Asterisk(), - order_by=[ - ast.OrderBy(ast.Column('a'), ast.Ordering.ASC), - ast.OrderBy(ast.Column('b'), ast.Ordering.DESC), - ast.OrderBy(ast.Column('c'), ast.Ordering.ASC)])) + Query(Select(ast.Asterisk()), + order_by=[ + ast.OrderBy(ast.Column('a'), ast.Ordering.ASC), + ast.OrderBy(ast.Column('b'), ast.Ordering.DESC), + ast.OrderBy(ast.Column('c'), ast.Ordering.ASC)])) def test_orderby_empty(self): with self.assertRaises(parser.ParseError): @@ -417,11 +426,11 @@ def test_pivotby(self): self.assertParse( "SELECT * PIVOT BY a, b", - Select(ast.Asterisk(), pivot_by=ast.PivotBy([ast.Column('a'), ast.Column('b')]))) + Query(queries=[Select(ast.Asterisk())], pivot_by=ast.PivotBy([ast.Column('a'), ast.Column('b')]))) self.assertParse( "SELECT * PIVOT BY 1, 2", - Select(ast.Asterisk(), pivot_by=ast.PivotBy([1, 2]))) + Query(queries=[Select(ast.Asterisk())], pivot_by=ast.PivotBy([1, 2]))) class TestSelectOptions(QueryParserTestBase): @@ -432,7 +441,7 @@ def test_distinct(self): def test_limit_present(self): self.assertParse( - "SELECT * LIMIT 45;", Select(ast.Asterisk(), limit=45)) + "SELECT * LIMIT 45;", Query(Select(ast.Asterisk()), limit=45, pivot_by=None)) def test_limit_empty(self): with self.assertRaises(parser.ParseError): @@ -582,21 +591,25 @@ def test_ast_node(self): def test_tosexp(self): sexp = parser.parse('SELECT a + 1 FROM #test WHERE a > 42 ORDER BY b DESC').tosexp() self.assertEqual(sexp, textwrap.dedent('''\ - (select - targets: ( - (target - expression: (add + (query + queries: ( + (select + targets: ( + (target + expression: (add + left: (column + name: 'a') + right: (constant + value: 1)))) + from-clause: (table + name: 'test') + where-clause: (greater left: (column name: 'a') right: (constant - value: 1)))) - from-clause: (table - name: 'test') - where-clause: (greater - left: (column - name: 'a') - right: (constant - value: 42)) + value: 42)))) + set-operators: ( + ) order-by: ( (orderby column: (column @@ -612,8 +625,9 @@ def test_walk(self): class TestNodeText(unittest.TestCase): def test_text(self): - select = parser.parse('SELECT date + 1') - self.assertEqual(select.text, 'SELECT date + 1') + query = parser.parse('SELECT date + 1') + select = query.queries[0] + self.assertEqual(query.text, 'SELECT date + 1') self.assertEqual(select.targets[0].expression.text, 'date + 1') self.assertEqual(select.targets[0].expression.left.text, 'date') self.assertEqual(select.targets[0].expression.right.text, '1') @@ -689,3 +703,26 @@ def test_string(self): def test_date(self): self.assertEqualEx(self.parse('1972-05-28'), datetime.date(1972, 5, 28)) + + +class TestParseQuery(QueryParserTestBase): + """Parser always wraps SELECT in an ast.Query node.""" + + def test_single_select_is_wrapped_in_query(self): + """A plain SELECT is wrapped in an ast.Query containing one ast.Select.""" + result = self.parse("SELECT 1") + self.assertIsInstance(result, ast.Query) + self.assertEqual(len(result.queries), 1) + self.assertIsInstance(result.queries[0], ast.Select) + + def test_select_with_order_by_sets_query_order_by(self): + """ORDER BY on a plain SELECT is held by the enclosing ast.Query.""" + result = self.parse("SELECT 1 AS n ORDER BY 1") + self.assertIsInstance(result, ast.Query) + self.assertIsNotNone(result.order_by) + + def test_select_with_limit_sets_query_limit(self): + """LIMIT on a plain SELECT is held by the enclosing ast.Query.""" + result = self.parse("SELECT 1 LIMIT 1") + self.assertIsInstance(result, ast.Query) + self.assertIsNotNone(result.limit) diff --git a/beanquery/query_compile.py b/beanquery/query_compile.py index eda87ec7..f2de83cb 100644 --- a/beanquery/query_compile.py +++ b/beanquery/query_compile.py @@ -13,6 +13,7 @@ import collections import dataclasses import datetime +import functools import itertools import re import operator @@ -604,7 +605,79 @@ def __call__(self, context): EvalTarget = collections.namedtuple('EvalTarget', 'c_expr name is_aggregate') -# A compiled query, ready for execution. +@dataclasses.dataclass +class EvalUnion: + """Execute a chain of SELECTs combined by set operators (UNION, UNION ALL). + + This class has the same interface as EvalSelect: __call__ returns + (result_types, rows, visible_mask). It is wrapped by EvalQuery which + handles ORDER BY, LIMIT, and visible column extraction. + + set_operators[i] is the set operator between queries[i] and queries[i+1]. + Supported values: 'union' (deduplicate), 'union_all' (keep all rows). + """ + + queries: list + set_operators: list[str] + c_targets: list + + @property + def columns(self): + return [t for t in self.c_targets if t.name is not None] + + def __call__(self): + # Accumulate rows, applying deduplication at each UNION boundary. + # Each query returns (columns, rows) with visible columns only. + _, rows = self.queries[0]() + for op, query in zip(self.set_operators, self.queries[1:]): + _, next_rows = query() + if op == 'union_all': + rows = rows + next_rows + else: + # UNION: deduplicate the entire accumulated result, preserving + # first-seen order across all rows accumulated so far. + seen = set() + deduped = [] + for r in rows + next_rows: + if r not in seen: + seen.add(r) + deduped.append(r) + rows = deduped + + # Return same interface as EvalSelect: (result_types, rows, visible_mask). + # All columns are visible (visible column extraction already done by inner queries). + result_types = tuple(cursor.Column(t.name, t.c_expr.dtype) for t in self.c_targets) + visible_mask = [True] * len(self.c_targets) + return result_types, rows, visible_mask + + +# A compiled query wrapping a SELECT or a UNION of multiple SELECTs. +# +# This mirrors ast.Query which wraps ast.Select and owns ORDER BY, LIMIT. +# +# Attributes: +# select: The inner EvalSelect or EvalUnion. +# order_spec: A list of (integer indexes, sort order) tuples. +# limit: An optional integer used to cut off the number of result rows returned. +@dataclasses.dataclass +class EvalQuery: + select: EvalSelect | EvalUnion + order_spec: list[tuple[int, ast.Ordering]] + limit: int + + @property + def columns(self): + return self.select.columns + + @property + def c_targets(self): + return self.select.c_targets + + def __call__(self): + return query_execute.execute_query(self) + + +# A compiled SELECT, ready for execution. # # Attributes: # c_targets: A list of compiled targets (instancef of EvalTarget). @@ -614,19 +687,14 @@ def __call__(self, context): # this list of indexes should always cover all non-aggregates in 'c_targets'. # And this list may well include some invisible columns if only specified in # the GROUP BY clause. -# order_spec: A list of (integer indexes, sort order) tuples. -# This list may refer to either aggregates or non-aggregates. -# limit: An optional integer used to cut off the number of result rows returned. # distinct: An optional boolean that requests we should uniquify the result rows. @dataclasses.dataclass -class EvalQuery: +class EvalSelect: table: tables.Table c_targets: list c_where: EvalNode group_indexes: list[int] having_index: int - order_spec: list[tuple[int, ast.Ordering]] - limit: int distinct: bool @property diff --git a/beanquery/query_compile_test.py b/beanquery/query_compile_test.py index 4418f45a..f7786ee0 100644 --- a/beanquery/query_compile_test.py +++ b/beanquery/query_compile_test.py @@ -204,12 +204,12 @@ def assertSelectInvariants(self, query): AssertionError: if the check fails. """ # Check that the group references cover all the simple indexes. - if query.group_indexes is not None: + if query.select.group_indexes is not None: non_aggregate_indexes = [index for index, c_target in enumerate(query.c_targets) if not compiler.is_aggregate(c_target.c_expr)] - self.assertEqual(set(non_aggregate_indexes), set(query.group_indexes), + self.assertEqual(set(non_aggregate_indexes), set(query.select.group_indexes), "Invalid indexes: {}".format(query)) def assertIndexes(self, @@ -243,7 +243,7 @@ def assertIndexes(self, self.assertEqual( set(expected_group_indexes) if expected_group_indexes is not None else None, - set(query.group_indexes) if query.group_indexes is not None else None) + set(query.select.group_indexes) if query.select.group_indexes is not None else None) self.assertEqual( set(expected_order_spec) if expected_order_spec is not None else None, @@ -281,10 +281,10 @@ def test_compile_from(self): # Test the compilation of from. query = self.compile("SELECT account FROM CLOSE;") - self.assertEqual(query.table.close, True) + self.assertEqual(query.select.table.close, True) query = self.compile("SELECT account FROM length(payee) != 0;") - self.assertTrue(isinstance(query.c_where, qc.EvalNode)) + self.assertTrue(isinstance(query.select.c_where, qc.EvalNode)) with self.assertRaises(CompilationError): query = self.compile("SELECT account FROM sum(payee) != 0;") @@ -426,16 +426,16 @@ def test_compile_group_by_implicit(self): def test_compile_group_by_coverage(self): # Non-aggregates. query = self.compile("SELECT account, length(account);") - self.assertEqual(None, query.group_indexes) + self.assertEqual(None, query.select.group_indexes) self.assertEqual(None, query.order_spec) # Aggregates only. query = self.compile("SELECT first(account), last(account);") - self.assertEqual([], query.group_indexes) + self.assertEqual([], query.select.group_indexes) # Mixed with non-aggregates in group-by clause. query = self.compile("SELECT account, sum(number) GROUP BY account;") - self.assertEqual([0], query.group_indexes) + self.assertEqual([0], query.select.group_indexes) # Mixed with non-aggregates in group-by clause with non-aggregates a # strict subset of the group-by columns. 'account' is a subset of @@ -443,7 +443,7 @@ def test_compile_group_by_coverage(self): query = self.compile(""" SELECT account, sum(number) GROUP BY account, flag; """) - self.assertEqual([0, 2], query.group_indexes) + self.assertEqual([0, 2], query.select.group_indexes) # Non-aggregates not covered by group-by clause. with self.assertRaises(CompilationError): @@ -467,7 +467,7 @@ def test_compile_group_by_coverage(self): query = self.compile(""" SELECT date, flag, account GROUP BY date, flag, account; """) - self.assertEqual([0, 1, 2], query.group_indexes) + self.assertEqual([0, 1, 2], query.select.group_indexes) def test_compile_group_by_reconcile(self): # Check that no invisible column is created if redundant. @@ -475,7 +475,7 @@ def test_compile_group_by_reconcile(self): SELECT account, length(account), sum(number) GROUP BY account, length(account); """) - self.assertEqual([0, 1], query.group_indexes) + self.assertEqual([0, 1], query.select.group_indexes) class TestCompileSelectOrderBy(CompileSelectBase): @@ -484,20 +484,20 @@ def test_compile_order_by_simple(self): query = self.compile(""" SELECT account, sum(number) GROUP BY account ORDER BY account; """) - self.assertEqual([0], query.group_indexes) + self.assertEqual([0], query.select.group_indexes) self.assertEqual([(0, False)], query.order_spec) def test_compile_order_by_simple_2(self): query = self.compile(""" SELECT account, length(narration) GROUP BY account, 2 ORDER BY 1, 2; """) - self.assertEqual([0, 1], query.group_indexes) + self.assertEqual([0, 1], query.select.group_indexes) self.assertEqual([(0, False), (1, False)], query.order_spec) query = self.compile(""" SELECT account, length(narration) as l GROUP BY account, l ORDER BY l; """) - self.assertEqual([0, 1], query.group_indexes) + self.assertEqual([0, 1], query.select.group_indexes) self.assertEqual([(1, False)], query.order_spec) def test_compile_order_by_create_non_agg(self): @@ -514,7 +514,7 @@ def test_compile_order_by_create_non_agg(self): query = self.compile(""" SELECT account, year(date) GROUP BY 1, 2 ORDER BY 2; """) - self.assertEqual([0, 1], query.group_indexes) + self.assertEqual([0, 1], query.select.group_indexes) self.assertEqual([(1, False)], query.order_spec) # We detect similarity between order-by and targets yet. @@ -543,34 +543,48 @@ def test_compile_order_by_reference_invisible(self): GROUP BY length(account) ORDER BY length(account); """) - self.assertEqual([2], query.group_indexes) + self.assertEqual([2], query.select.group_indexes) self.assertEqual([(2, False)], query.order_spec) def test_compile_order_by_aggregate(self): query = self.compile(""" SELECT account, first(narration) GROUP BY account ORDER BY 2; """) - self.assertEqual([0], query.group_indexes) + self.assertEqual([0], query.select.group_indexes) self.assertEqual([(1, False)], query.order_spec) query = self.compile(""" SELECT account, first(narration) as f GROUP BY account ORDER BY f; """) - self.assertEqual([0], query.group_indexes) + self.assertEqual([0], query.select.group_indexes) self.assertEqual([(1, False)], query.order_spec) query = self.compile(""" SELECT account, first(narration) GROUP BY account ORDER BY sum(number); """) - self.assertEqual([0], query.group_indexes) + self.assertEqual([0], query.select.group_indexes) self.assertEqual([(2, False)], query.order_spec) query = self.compile(""" SELECT account GROUP BY account ORDER BY sum(number); """) - self.assertEqual([0], query.group_indexes) + self.assertEqual([0], query.select.group_indexes) self.assertEqual([(1, False)], query.order_spec) + def test_compile_distinct_order_by_invisible(self): + # DISTINCT with ORDER BY on a column not in SELECT is rejected. + # That would result in undefined ordering. + with self.assertRaises(CompilationError) as ctx: + self.compile("SELECT DISTINCT account ORDER BY date;") + self.assertIn('DISTINCT', str(ctx.exception)) + self.assertIn('SELECT list', str(ctx.exception)) + self.assertIn('date', str(ctx.exception)) + + # ORDER BY on a visible column is allowed. + self.compile("SELECT DISTINCT account, date ORDER BY date;") + self.compile("SELECT DISTINCT account ORDER BY account;") + self.compile("SELECT DISTINCT account ORDER BY 1;") + class TestTranslationJournal(CompileSelectBase): @@ -580,7 +594,7 @@ def test_journal(self): journal = parser.parse("JOURNAL;") select = compiler.transform_journal(journal) self.assertEqual(select, - ast.Select([ + ast.Query(queries=[ast.Select([ ast.Target(ast.Column('date'), None), ast.Target(ast.Column('flag'), None), ast.Target(ast.Function('maxwidth', [ @@ -590,13 +604,13 @@ def test_journal(self): ast.Target(ast.Column('account'), None), ast.Target(ast.Column('position'), None), ast.Target(ast.Column('balance'), None), - ], - None, None, None, None, None, None, None)) + ], None, None, None, None)], + set_operators=[], order_by=None, limit=None, pivot_by=None)) def test_journal_with_account(self): journal = parser.parse("JOURNAL 'liabilities';") select = compiler.transform_journal(journal) - self.assertEqual(select, ast.Select([ + self.assertEqual(select, ast.Query(queries=[ast.Select([ ast.Target(ast.Column('date'), None), ast.Target(ast.Column('flag'), None), ast.Target(ast.Function('maxwidth', [ @@ -608,15 +622,15 @@ def test_journal_with_account(self): ast.Target(ast.Column('account'), None), ast.Target(ast.Column('position'), None), ast.Target(ast.Column('balance'), None), - ], - None, + ], None, ast.Match(ast.Column('account'), ast.Constant('liabilities')), - None, None, None, None, None)) + None, None, None)], + set_operators=[], order_by=None, limit=None, pivot_by=None)) def test_journal_with_account_and_from(self): journal = parser.parse("JOURNAL 'liabilities' FROM year = 2014;") select = compiler.transform_journal(journal) - self.assertEqual(select, ast.Select([ + self.assertEqual(select, ast.Query(queries=[ast.Select([ ast.Target(ast.Column('date'), None), ast.Target(ast.Column('flag'), None), ast.Target(ast.Function('maxwidth', [ @@ -631,12 +645,13 @@ def test_journal_with_account_and_from(self): ], ast.From(ast.Equal(ast.Column('year'), ast.Constant(2014)), None, None, None), ast.Match(ast.Column('account'), ast.Constant('liabilities')), - None, None, None, None, None)) + None, None, None)], + set_operators=[], order_by=None, limit=None, pivot_by=None)) def test_journal_with_account_func_and_from(self): journal = parser.parse("JOURNAL 'liabilities' AT cost FROM year = 2014;") select = compiler.transform_journal(journal) - self.assertEqual(select, ast.Select([ + self.assertEqual(select, ast.Query(queries=[ast.Select([ ast.Target(ast.Column('date'), None), ast.Target(ast.Column('flag'), None), ast.Target(ast.Function('maxwidth', [ @@ -651,7 +666,8 @@ def test_journal_with_account_func_and_from(self): ], ast.From(ast.Equal(ast.Column('year'), ast.Constant(2014)), None, None, None), ast.Match(ast.Column('account'), ast.Constant('liabilities')), - None, None, None, None, None)) + None, None, None)], + set_operators=[], order_by=None, limit=None, pivot_by=None)) class TestTranslationBalance(CompileSelectBase): @@ -666,31 +682,31 @@ class TestTranslationBalance(CompileSelectBase): def test_balance(self): balance = parser.parse("BALANCES;") select = compiler.transform_balances(balance) - self.assertEqual(select, ast.Select([ + self.assertEqual(select, ast.Query(queries=[ast.Select([ ast.Target(ast.Column('account'), None), ast.Target(ast.Function('sum', [ ast.Column('position') ]), None), - ], - None, None, self.group_by, self.order_by, None, None, None)) + ], None, None, self.group_by, None, None)], + set_operators=[], order_by=self.order_by, limit=None, pivot_by=None)) def test_balance_with_units(self): balance = parser.parse("BALANCES AT cost;") select = compiler.transform_balances(balance) - self.assertEqual(select, ast.Select([ + self.assertEqual(select, ast.Query(queries=[ast.Select([ ast.Target(ast.Column('account'), None), ast.Target(ast.Function('sum', [ ast.Function('cost', [ ast.Column('position') ]) ]), None) - ], - None, None, self.group_by, self.order_by, None, None, None)) + ], None, None, self.group_by, None, None)], + set_operators=[], order_by=self.order_by, limit=None, pivot_by=None)) def test_balance_with_units_and_from(self): balance = parser.parse("BALANCES AT cost FROM year = 2014;") select = compiler.transform_balances(balance) - self.assertEqual(select, ast.Select([ + self.assertEqual(select, ast.Query(queries=[ast.Select([ ast.Target(ast.Column('account'), None), ast.Target(ast.Function('sum', [ ast.Function('cost', [ @@ -699,26 +715,33 @@ def test_balance_with_units_and_from(self): ]), None), ], ast.From(ast.Equal(ast.Column('year'), ast.Constant(2014)), None, None, None), - None, self.group_by, self.order_by, None, None, None)) + None, self.group_by, None, None)], + set_operators=[], order_by=self.order_by, limit=None, pivot_by=None)) def test_print(self): self.assertCompile( qc.EvalQuery( - Table('entries'), - [qc.EvalTarget(qc.EvalRow(), 'ROW(*)', False)], - None, None, None, None, None, False), + select=qc.EvalSelect( + Table('entries'), + [qc.EvalTarget(qc.EvalRow(), 'ROW(*)', False)], + None, None, None, False), + order_spec=None, + limit=None), "PRINT;", ) def test_print_from(self): self.assertCompile( qc.EvalQuery( - Table('entries'), - [qc.EvalTarget(qc.EvalRow(), 'ROW(*)', False)], - qc.Operator(ast.Equal, [ - Column('year', int), - qc.EvalConstant(2014) ]), - None, None, None, None, False), + select=qc.EvalSelect( + Table('entries'), + [qc.EvalTarget(qc.EvalRow(), 'ROW(*)', False)], + qc.Operator(ast.Equal, [ + Column('year', int), + qc.EvalConstant(2014) ]), + None, None, False), + order_spec=None, + limit=None), "PRINT FROM year = 2014;") @@ -736,18 +759,24 @@ def compile(self, query, params): def test_named_parameters(self): query = self.compile('''SELECT %(x)s + %(y)s''', {'x': 1, 'y': 2}) self.assertEqual(query, qc.EvalQuery( - Table(''), [ - # addition of constants is optimized away - qc.EvalTarget(qc.EvalConstant(3), '%(x)s + %(y)s', False) - ], None, None, None, None, None, None)) + select=qc.EvalSelect( + Table(''), [ + # addition of constants is optimized away + qc.EvalTarget(qc.EvalConstant(3), '%(x)s + %(y)s', False) + ], None, None, None, None), + order_spec=None, + limit=None)) def test_positional_parameters(self): query = self.compile('''SELECT %s + %s''', (1, 2, )) self.assertEqual(query, qc.EvalQuery( - Table(''), [ - # addition of constants is optimized away - qc.EvalTarget(qc.EvalConstant(3), '%s + %s', False) - ], None, None, None, None, None, None)) + select=qc.EvalSelect( + Table(''), [ + # addition of constants is optimized away + qc.EvalTarget(qc.EvalConstant(3), '%s + %s', False) + ], None, None, None, None), + order_spec=None, + limit=None)) def test_mixing_parameters(self): with self.assertRaises(ProgrammingError): @@ -780,7 +809,7 @@ def compile(self, query): def test_select_from(self): query = self.compile('''SELECT x FROM foo''') - self.assertEqual(query.table, self.conn.tables['foo']) + self.assertEqual(query.select.table, self.conn.tables['foo']) def test_select_from_invalid(self): with self.assertRaisesRegex(beanquery.ProgrammingError, 'column "qux" not found in table "postings"'): @@ -788,11 +817,11 @@ def test_select_from_invalid(self): def test_select_from_column(self): query = self.compile('''SELECT account FROM date''') - self.assertEqual(query.table, self.conn.tables['postings']) + self.assertEqual(query.select.table, self.conn.tables['postings']) def test_select_from_hash(self): query = self.compile('''SELECT x FROM #foo''') - self.assertEqual(query.table, self.conn.tables['foo']) + self.assertEqual(query.select.table, self.conn.tables['foo']) def test_select_from_hash_invalid(self): with self.assertRaisesRegex(beanquery.ProgrammingError, 'table "qux" does not exist'): @@ -808,7 +837,7 @@ def cleanup(): del self.conn.tables['date'] self.addCleanup(cleanup) query = self.compile('''SELECT year FROM #date''') - self.assertEqual(query.table, self.conn.tables['date']) + self.assertEqual(query.select.table, self.conn.tables['date']) class TestQuotedIdentifiers(unittest.TestCase): @@ -825,9 +854,9 @@ def compile(self, query): def test_from_quoted(self): query = self.compile('''SELECT * FROM postings''') - self.assertIs(query.table, self.conn.tables['postings']) + self.assertIs(query.select.table, self.conn.tables['postings']) query = self.compile('''SELECT * FROM "postings"''') - self.assertIs(query.table, self.conn.tables['postings']) + self.assertIs(query.select.table, self.conn.tables['postings']) def test_quoted_target(self): query = self.compile('''SELECT date FROM postings''') @@ -850,3 +879,218 @@ def test_quoted_string_in_expression(self): # if the double quoted string is not a table name, it is a string literal ideed query = self.compile('''SELECT "a" + "b" FROM postings''') self.assertEqual(query.c_targets[0].c_expr.value, 'ab') + + +class FakeQuery: + """Minimal query stub that returns fixed (columns, rows) without any SQL.""" + + def __init__(self, columns, rows): + self._columns = columns + self._rows = rows + + @property + def columns(self): + return self._columns + + def __call__(self): + return self._columns, list(self._rows) + + +class TestEvalUnion(unittest.TestCase): + """Unit tests for EvalUnion.__call__ in isolation from the parser/compiler. + + EvalUnion returns (result_types, rows, visible_mask) like EvalSelect. + ORDER BY and LIMIT are handled by wrapping EvalUnion in EvalQuery. + """ + + COL_N = qc.EvalTarget(qc.EvalConstant(None, int), 'n', False) + COL_A = qc.EvalTarget(qc.EvalConstant(None, int), 'a', False) + COL_B = qc.EvalTarget(qc.EvalConstant(None, str), 'b', False) + + def _union(self, queries, set_operators, c_targets=None): + """Create an EvalUnion. For ORDER BY/LIMIT tests, wrap in EvalQuery.""" + if c_targets is None: + c_targets = [self.COL_N] + return qc.EvalUnion( + queries=queries, + set_operators=set_operators, + c_targets=c_targets, + ) + + def _query(self, queries, set_operators, c_targets=None, order_spec=None, limit=None): + """Create an EvalQuery wrapping an EvalUnion for ORDER BY/LIMIT tests.""" + union = self._union(queries, set_operators, c_targets) + return qc.EvalQuery( + select=union, + order_spec=order_spec or [], + limit=limit, + ) + + # -- columns property -- + + def test_columns_returns_visible_targets(self): + """EvalUnion.columns should return targets with non-None names.""" + invisible = qc.EvalTarget(qc.EvalConstant(None, int), None, False) + visible = qc.EvalTarget(qc.EvalConstant(None, int), 'n', False) + q1 = FakeQuery([visible], [(1,)]) + u = self._union([q1], [], c_targets=[invisible, visible]) + self.assertEqual(u.columns, [visible]) + + # -- row concatenation and deduplication -- + + def test_union_all_concatenates_without_deduplication(self): + """Given two queries sharing a duplicate row, + UNION ALL should return all rows from both queries + including the duplicate.""" + q1 = FakeQuery([self.COL_N], [(1,), (2,)]) + q2 = FakeQuery([self.COL_N], [(2,), (3,)]) + _, rows, _ = self._union([q1, q2], ['union_all'])() + self.assertEqual(rows, [(1,), (2,), (2,), (3,)]) + + def test_union_removes_duplicates_preserving_first_seen_order(self): + """Given two queries sharing a duplicate row, + UNION should deduplicate while keeping + the order in which rows were first encountered.""" + q1 = FakeQuery([self.COL_N], [(1,), (2,)]) + q2 = FakeQuery([self.COL_N], [(2,), (3,)]) + _, rows, _ = self._union([q1, q2], ['union'])() + self.assertEqual(rows, [(1,), (2,), (3,)]) + + # -- mixed UNION / UNION ALL chains -- + + def test_union_then_union_all_deduplicates_first_pair_only(self): + """Given A UNION B UNION ALL C where A=B=C=(1,), + the UNION between A and B should collapse them to one row, + then UNION ALL should append C's row unchanged, + yielding two rows of (1,).""" + q1 = FakeQuery([self.COL_N], [(1,)]) + q2 = FakeQuery([self.COL_N], [(1,)]) + q3 = FakeQuery([self.COL_N], [(1,)]) + _, rows, _ = self._union([q1, q2, q3], ['union', 'union_all'])() + self.assertEqual(rows, [(1,), (1,)]) + + def test_union_all_then_union_deduplicates_all_accumulated_rows(self): + """Given A UNION ALL B UNION C where A=B=(1,) and C=(2,), + UNION ALL should keep the duplicate (1,) from A and B, + then the final UNION should deduplicate the entire accumulated set, + yielding one (1,) and one (2,).""" + q1 = FakeQuery([self.COL_N], [(1,)]) + q2 = FakeQuery([self.COL_N], [(1,)]) + q3 = FakeQuery([self.COL_N], [(2,)]) + _, rows, _ = self._union([q1, q2, q3], ['union_all', 'union'])() + self.assertEqual(rows, [(1,), (2,)]) + + # -- ORDER BY (via EvalQuery wrapper) -- + + def test_order_by_asc_sorts_ascending(self): + """Given unsorted rows and order_spec [(0, ASC)], + EvalQuery wrapping EvalUnion should return rows sorted ascending.""" + q1 = FakeQuery([self.COL_N], [(3,), (1,), (2,)]) + _, rows = self._query([q1], [], order_spec=[(0, ast.Ordering.ASC)])() + self.assertEqual(rows, [(1,), (2,), (3,)]) + + def test_order_by_desc_sorts_descending(self): + """Given unsorted rows and order_spec [(0, DESC)], + EvalQuery wrapping EvalUnion should return rows sorted descending.""" + q1 = FakeQuery([self.COL_N], [(1,), (3,), (2,)]) + _, rows = self._query([q1], [], order_spec=[(0, ast.Ordering.DESC)])() + self.assertEqual(rows, [(3,), (2,), (1,)]) + + def test_order_by_non_first_column(self): + """Given two-column rows and order_spec [(1, ASC)], + EvalQuery should sort by the second column.""" + q1 = FakeQuery([self.COL_A, self.COL_B], [(1, 'b'), (2, 'a')]) + _, rows = self._query([q1], [], c_targets=[self.COL_A, self.COL_B], + order_spec=[(1, ast.Ordering.ASC)])() + self.assertEqual(rows, [(2, 'a'), (1, 'b')]) + + # -- LIMIT (via EvalQuery wrapper) -- + + def test_limit_truncates_result(self): + """Given three rows and limit=2, + EvalQuery wrapping EvalUnion should return only the first two rows.""" + q1 = FakeQuery([self.COL_N], [(1,), (2,), (3,)]) + _, rows = self._query([q1], [], limit=2)() + self.assertEqual(rows, [(1,), (2,)]) + + def test_limit_none_returns_all_rows(self): + """Given limit=None, EvalQuery should not truncate the result.""" + q1 = FakeQuery([self.COL_N], [(1,), (2,), (3,)]) + _, rows = self._query([q1], [], limit=None)() + self.assertEqual(len(rows), 3) + + def test_order_by_desc_with_limit_returns_top_n(self): + """Given three queries yielding 1, 2, 3, ORDER BY 1 DESC LIMIT 2 + should return the two largest values in descending order.""" + q1 = FakeQuery([self.COL_N], [(1,)]) + q2 = FakeQuery([self.COL_N], [(2,)]) + q3 = FakeQuery([self.COL_N], [(3,)]) + _, rows = self._query( + [q1, q2, q3], ['union_all', 'union_all'], + order_spec=[(0, ast.Ordering.DESC)], + limit=2, + )() + self.assertEqual(rows, [(3,), (2,)]) + + # -- result_types passthrough -- + + def test_call_returns_result_types_from_c_targets(self): + """__call__ should return result_types derived from c_targets.""" + q1 = FakeQuery([self.COL_A, self.COL_B], [(1, 'x')]) + q2 = FakeQuery([self.COL_A, self.COL_B], [(2, 'y')]) + result_types, _, _ = self._union([q1, q2], ['union_all'], + c_targets=[self.COL_A, self.COL_B])() + self.assertEqual(len(result_types), 2) + self.assertEqual(result_types[0].name, 'a') + self.assertEqual(result_types[1].name, 'b') + + # -- edge cases -- + + def test_empty_subquery_contributes_no_rows(self): + """Given one empty and one non-empty sub-query joined by UNION ALL, + the empty sub-query should contribute nothing to the result.""" + q1 = FakeQuery([self.COL_N], []) + q2 = FakeQuery([self.COL_N], [(1,)]) + _, rows, _ = self._union([q1, q2], ['union_all'])() + self.assertEqual(rows, [(1,)]) + + +class TestTryCoerceOperand(unittest.TestCase): + """Tests for Compiler._try_coerce_operand helper method. + + Note: The following behaviors are tested via integration tests in + query_execute_test.py and do not have dedicated unit tests here: + - int to Decimal coercion (tested in test_operators: SELECT 2.0 * 2) + - Decimal to int coercion (tested in test_operators: SELECT 2 * 2.0) + - object to Decimal coercion (tested in test_operators_type_inference) + - Value preservation through coercion (tested in test_operators) + """ + + @classmethod + def setUpClass(cls): + cls.context = Connection() + cls.context.tables['test'] = test.Table(0) + cls.compiler = compiler.Compiler(cls.context) + cls.compiler.table = cls.context.tables['test'] + + def test_same_type_returns_operand(self): + """When operand type matches target type, return operand unchanged (fast path).""" + operand = qc.EvalConstant(D('42'), D) + result = self.compiler._try_coerce_operand(operand, D) + self.assertIs(result, operand) + + def test_int_target_promotes_to_decimal(self): + """When target is int, promote to Decimal to avoid information loss.""" + operand = qc.EvalConstant(D('42'), D) + result = self.compiler._try_coerce_operand(operand, int) + self.assertIsNotNone(result) + # Should be coerced to Decimal, not int + self.assertEqual(result.dtype, D) + self.assertEqual(result(None), D('42')) + + def test_unsupported_coercion_returns_none(self): + """When coercion is not possible, return None.""" + # Try to coerce to a type that's not in types.MAP + operand = qc.EvalConstant(42, int) + result = self.compiler._try_coerce_operand(operand, object) + self.assertIsNone(result) \ No newline at end of file diff --git a/beanquery/query_env.py b/beanquery/query_env.py index 7606b193..e5198e90 100644 --- a/beanquery/query_env.py +++ b/beanquery/query_env.py @@ -684,8 +684,11 @@ def date_part(field, x): @function([str], relativedelta) def interval(x): - """Construct a relative time interval.""" - m = re.fullmatch(r'([-+]?[0-9]+)\s+(day|month|year)s?', x) + """Construct a relative time interval. Example argument: '2 weeks'. + Further options are day, month, year, century, millenium (plural s can be + appended). Use to modify dates: `date + interval(...)`""" + x = x.lower() + m = re.fullmatch(r'([-+]?[0-9]+)\s+([a-z]+?)s?', x) if not m: return None number = int(m.group(1)) @@ -702,7 +705,7 @@ def interval(x): return relativedelta(years=number * 10) if unit == 'century': return relativedelta(years=number * 100) - if unit == 'millennium': + if unit in ['millennium', 'millenia']: return relativedelta(years=number * 1000) return None diff --git a/beanquery/query_execute.py b/beanquery/query_execute.py index e301cb5c..abd2b185 100644 --- a/beanquery/query_execute.py +++ b/beanquery/query_execute.py @@ -14,15 +14,28 @@ class Unique: - def __init__(self, columns): + """Generator that yields only the first occurrence of each unique row. + + Handles non-hashable column types (e.g., Inventory) by wrapping values + into hashable representations via hashable.make(). + + Args: + columns: Column types for hashable wrapping. + key: Optional function to extract the value to hash from each row. + """ + + def __init__(self, columns, key=None): self.wrap = hashable.make(columns) + self.key = key def __call__(self, iterable): wrap = self.wrap + key = self.key seen = set() add = seen.add for obj in iterable: - h = wrap(obj) + k = key(obj) if key else obj + h = wrap(k) if h not in seen: add(h) yield obj @@ -101,35 +114,58 @@ def func(obj): return func +def execute_query(query): + """Execute a compiled query with ORDER BY and LIMIT. + + Args: + query: An instance of EvalQuery wrapping an EvalSelect. + Returns: + A pair of (result_types, result_rows). + """ + result_types, rows, visible_mask = query.select() + + # ORDER BY requires materialization. + if query.order_spec: + rows = list(rows) + for reverse, spec in itertools.groupby(reversed(query.order_spec), key=operator.itemgetter(1)): + indexes = reversed([i[0] for i in spec]) + rows.sort(key=nullitemgetter(*indexes), reverse=reverse) + + # Extract visible columns. + visible_indexes = [i for i, v in enumerate(visible_mask) if v] + result_types = tuple(result_types[i] for i in visible_indexes) + rows = (tuple(row[i] for i in visible_indexes) for row in rows) + + # Apply LIMIT. + if query.limit is not None: + rows = itertools.islice(rows, query.limit) + + return result_types, list(rows) + + def execute_select(query): """Given a compiled select statement, execute the query. Args: - query: An instance of a query_compile.Query - entries: A list of directives. - options_map: A parser's option_map. + query: An instance of EvalSelect. Returns: - A pair of: - result_types: A list of (name, data-type) item pairs. - result_rows: A list of ResultRow tuples of length and types described by - 'result_types'. + A tuple of: + result_types: A list of Column(name, dtype) for ALL columns. + result_rows: A list of tuples with ALL columns (including invisible). + visible_mask: A list of bools, True if column is visible. """ - # Figure out the result types that describe what we return. + # Figure out the result types for ALL columns. result_types = tuple(cursor.Column(target.name, target.c_expr.dtype) - for target in query.c_targets - if target.name is not None) + for target in query.c_targets) + + # Track which columns are visible (have a name). + visible_mask = [target.name is not None for target in query.c_targets] # Pre-compute lists of the expressions to evaluate. group_indexes = (set(query.group_indexes) if query.group_indexes is not None else query.group_indexes) - # Indexes of the columns for result rows and order rows. - result_indexes = [index - for index, c_target in enumerate(query.c_targets) - if c_target.name] - order_spec = query.order_spec - # Dispatch between the non-aggregated queries and aggregated queries. c_where = query.c_where rows = [] @@ -143,7 +179,7 @@ def execute_select(query): # Iterate over all the postings once. for context in query.table: if c_where is None or c_where(context): - values = [c_expr(context) for c_expr in c_target_exprs] + values = tuple(c_expr(context) for c_expr in c_target_exprs) rows.append(values) else: @@ -213,28 +249,15 @@ def create(): if not values[query.having_index]: continue - rows.append(values) - - # Apply ORDER BY. - if order_spec is not None: - # Process the order-by clauses grouped by their ordering direction. - for reverse, spec in itertools.groupby(reversed(order_spec), key=operator.itemgetter(1)): - indexes = reversed([i[0] for i in spec]) - # The rows may contain None values: nullitemgetter() - # replaces these with a special value that compares - # smaller than anything else. - rows.sort(key=nullitemgetter(*indexes), reverse=reverse) + rows.append(tuple(values)) - # Extract results set and convert into tuples. - rows = (tuple(row[i] for i in result_indexes) for row in rows) - - # Apply DISTINCT. + # DISTINCT must operate on visible columns only, ignoring columns that were + # auto-added for GROUP BY or ORDER BY. Unique returns a generator to create + # a lazy pipeline so LIMIT can cut off early. if query.distinct: - unique = Unique(result_types) + visible_indexes = [i for i, v in enumerate(visible_mask) if v] + visible_types = tuple(result_types[i] for i in visible_indexes) + unique = Unique(visible_types, key=lambda row: tuple(row[i] for i in visible_indexes)) rows = unique(rows) - # Apply LIMIT. - if query.limit is not None: - rows = itertools.islice(rows, query.limit) - - return result_types, list(rows) + return result_types, rows, visible_mask diff --git a/beanquery/query_execute_test.py b/beanquery/query_execute_test.py index 0358d632..3d8c230e 100644 --- a/beanquery/query_execute_test.py +++ b/beanquery/query_execute_test.py @@ -512,8 +512,8 @@ def compile(self, query): @staticmethod def filter_entries(query): entries = [] - expr = query.c_where - for entry in query.table: + expr = query.select.c_where + for entry in query.select.table: if expr is None or expr(entry): entries.append(entry) return entries @@ -1881,3 +1881,160 @@ def test_csv_source(self, filename): self.assertEqual(names, ['id', 'name', 'check', 'date', 'value']) types = [column.dtype for column in conn.tables['test'].columns.values()] self.assertEqual(types, [int, str, bool, datetime.date, Decimal]) + + +class TestUnion(QueryBase): + INPUT = """ + 2022-01-01 open Assets:Bank + 2022-01-01 open Expenses:Food + 2022-01-01 open Expenses:Transport + + 2022-01-15 * "Lunch" + Assets:Bank -10.00 USD + Expenses:Food 10.00 USD + + 2022-01-20 * "Dinner" + Assets:Bank -20.00 USD + Expenses:Food 20.00 USD + + 2022-02-01 * "Bus" + Assets:Bank -5.00 USD + Expenses:Transport 5.00 USD + """ + + def test_basic_union(self): + """Given three SELECTs returning the same constant value from the postings table, + UNION should deduplicate the combined result, returning only unique values.""" + curs = self.ctx.execute( + """SELECT 1 AS n + UNION + SELECT 1 AS n + UNION + SELECT 2 AS n""" + ) + self.assertEqual(curs.fetchall(), [(1,), (2,)]) + + def test_union_all(self): + """Given two SELECTs returning the same constant value from the postings table, + UNION ALL should concatenate all rows from both SELECTs without deduplication.""" + curs = self.ctx.execute( + """SELECT 1 AS n + UNION ALL + SELECT 1 AS n""" + ) + self.assertEqual(curs.fetchall(), [(1,)] * 12) + + def test_union_mixed(self): + """Given three SELECTs combined by UNION then UNION ALL, + the first UNION should deduplicate (yielding 1 row), then UNION ALL + should append all 6 rows from the third SELECT without deduplication, + resulting in 7 rows""" + curs = self.ctx.execute( + """SELECT 1 AS n + UNION + SELECT 1 AS n + UNION ALL + SELECT 1 AS n""" + ) + self.assertEqual(curs.fetchall(), [(1,)] * 7) + + def test_union_order_by(self): + """Given three SELECTs returning different constant values combined by UNION, + ORDER BY should sort the deduplicated combined result in ascending order.""" + curs = self.ctx.execute( + """SELECT 2 AS n + UNION + SELECT 1 AS n + UNION + SELECT 3 AS n + ORDER BY 1""" + ) + self.assertEqual(curs.fetchall(), [(1,), (2,), (3,)]) + + def test_union_limit(self): + """Given three SELECTs returning different constant values combined by UNION, + LIMIT should truncate the deduplicated combined result to the specified number of rows.""" + curs = self.ctx.execute( + """SELECT 1 AS n + UNION + SELECT 2 AS n + UNION SELECT 3 AS n + LIMIT 2""" + ) + self.assertEqual(curs.fetchall(), [(1,), (2,)]) + + def test_union_order_by_desc_limit(self): + """Given three SELECTs returning different constant values combined by UNION, + ORDER BY DESC should sort the deduplicated combined result in descending order, + and LIMIT should then truncate to the specified number of rows.""" + curs = self.ctx.execute( + """SELECT 1 AS n + UNION + SELECT 2 AS n + UNION + SELECT 3 AS n + ORDER BY 1 DESC + LIMIT 2""" + ) + self.assertEqual(curs.fetchall(), [(3,), (2,)]) + + def test_union_column_count_mismatch(self): + """Given two SELECTs with different column counts, + compilation should raise an error indicating column count mismatch.""" + with self.assertRaises(CompilationError) as cm: + self.ctx.execute("SELECT 1, 2 UNION SELECT 1") + self.assertIn('same number of columns', str(cm.exception)) + + def test_union_type_mismatch(self): + """Given two SELECTs with incompatible column types, + compilation should raise an error indicating type mismatch.""" + with self.assertRaises(CompilationError) as cm: + self.ctx.execute("SELECT 'a' UNION SELECT 2022-01-01") + self.assertIn('type mismatch', str(cm.exception)) + + def test_union_compatible_numeric_types(self): + """Given two SELECTs returning compatible numeric types (int and Decimal), + UNION should deduplicate the combined result, returning unique values.""" + curs = self.ctx.execute("SELECT 1 UNION SELECT 1.5") + rows = curs.fetchall() + self.assertEqual(len(rows), 2) + + def test_union_with_from(self): + """Given two SELECTs with explicit FROM clauses selecting different accounts, + UNION should combine the results, returning unique account names.""" + curs = self.ctx.execute(""" + SELECT account FROM OPEN ON 2022-01-01 WHERE account ~ 'Food' + UNION + SELECT account FROM OPEN ON 2022-01-01 WHERE account ~ 'Transport' + """) + rows = curs.fetchall() + self.assertEqual(len(rows), 2) + accounts = {row[0] for row in rows} + self.assertIn('Expenses:Food', accounts) + self.assertIn('Expenses:Transport', accounts) + + def test_union_subquery(self): + """UNION with parenthesized subqueries.""" + curs = self.ctx.execute(""" + (SELECT 3 AS n ORDER BY 1 LIMIT 1) + UNION + (SELECT 1 AS n) + """) + rows = curs.fetchall() + self.assertEqual(set(rows), {(3,), (1,)}) + + def test_union_column_names_from_first(self): + """Column names come from first query.""" + curs = self.ctx.execute("SELECT 1 AS first_name UNION SELECT 2 AS second_name") + self.assertEqual(curs.description[0].name, 'first_name') + + def test_union_order_by_invisible_column_rejected(self): + """UNION ORDER BY on expressions not in SELECT list should be rejected.""" + with self.assertRaises(CompilationError) as cm: + self.ctx.execute(""" + SELECT account FROM OPEN ON 2022-01-01 + UNION + SELECT account FROM OPEN ON 2022-01-01 + ORDER BY length(account) + """) + self.assertIn('SELECT list', str(cm.exception))