260 changes: 232 additions & 28 deletions ibis/backends/base/__init__.py
@@ -8,6 +8,7 @@
import re
import sys
import urllib.parse
from collections import Counter
from pathlib import Path
from typing import (
TYPE_CHECKING,
@@ -20,6 +21,8 @@
MutableMapping,
)

from bidict import MutableBidirectionalMapping, bidict

import ibis
import ibis.common.exceptions as exc
import ibis.config
@@ -214,12 +217,11 @@ def _ipython_key_completions_(self) -> list[str]:
return self._backend.list_tables()


# should have a better name
class ResultHandler:
class _FileIOHandler:
@staticmethod
def _import_pyarrow():
try:
import pyarrow
import pyarrow # noqa: ICN001
except ImportError:
raise ModuleNotFoundError(
"Exporting to arrow formats requires `pyarrow` but it is not installed"
Expand Down Expand Up @@ -322,8 +324,128 @@ def to_pyarrow_batches(
"""
raise NotImplementedError

def read_parquet(
self, path: str | Path, table_name: str | None = None, **kwargs: Any
) -> ir.Table:
"""Register a parquet file as a table in the current backend.
Parameters
----------
path
The data source.
table_name
An optional name to use for the created table. This defaults to
a sequentially generated name.
**kwargs
Additional keyword arguments passed to the backend loading function.
Returns
-------
ir.Table
The just-registered table
"""
raise NotImplementedError(
f"{self.name} does not support direct registration of parquet data."
)

def read_csv(
self, path: str | Path, table_name: str | None = None, **kwargs: Any
) -> ir.Table:
"""Register a CSV file as a table in the current backend.
Parameters
----------
path
The data source. A string or Path to the CSV file.
table_name
An optional name to use for the created table. This defaults to
a sequentially generated name.
**kwargs
Additional keyword arguments passed to the backend loading function.
Returns
-------
ir.Table
The just-registered table
"""
raise NotImplementedError(
f"{self.name} does not support direct registration of CSV data."
)

@util.experimental
def to_parquet(
self,
expr: ir.Table,
path: str | Path,
*,
params: Mapping[ir.Scalar, Any] | None = None,
**kwargs: Any,
) -> None:
"""Write the results of executing the given expression to a parquet file.
This method is eager and will execute the associated expression
immediately.
Parameters
----------
expr
The ibis expression to execute and persist to parquet.
path
The destination path. A string or Path to the parquet file.
params
Mapping of scalar parameter expressions to value.
**kwargs
Additional keyword arguments passed to pyarrow.parquet.ParquetWriter
https://arrow.apache.org/docs/python/generated/pyarrow.parquet.ParquetWriter.html
"""
self._import_pyarrow()
import pyarrow.parquet as pq

batch_reader = expr.to_pyarrow_batches(params=params)

with pq.ParquetWriter(path, batch_reader.schema, **kwargs) as writer:
for batch in batch_reader:
writer.write_batch(batch)

@util.experimental
def to_csv(
self,
expr: ir.Table,
path: str | Path,
*,
params: Mapping[ir.Scalar, Any] | None = None,
**kwargs: Any,
) -> None:
"""Write the results of executing the given expression to a CSV file.
This method is eager and will execute the associated expression
immediately.
Parameters
----------
expr
The ibis expression to execute and persist to CSV.
path
The destination path. A string or Path to the CSV file.
params
Mapping of scalar parameter expressions to value.
**kwargs
Additional keyword arguments passed to pyarrow.csv.CSVWriter
https://arrow.apache.org/docs/python/generated/pyarrow.csv.CSVWriter.html
"""
self._import_pyarrow()
import pyarrow.csv as pcsv

batch_reader = expr.to_pyarrow_batches(params=params)

with pcsv.CSVWriter(path, batch_reader.schema, **kwargs) as writer:
for batch in batch_reader:
writer.write_batch(batch)
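
# Hedged usage sketch of the _FileIOHandler methods above; the duckdb backend,
# the file names, and the "value" column are assumptions for illustration only.
import ibis

con = ibis.duckdb.connect()              # any backend that implements read_parquet
t = con.read_parquet("events.parquet")   # register the parquet file as a table
expr = t.filter(t.value > 0)
con.to_csv(expr, "filtered.csv")         # eager: executes and writes via pyarrow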


class BaseBackend(abc.ABC, _FileIOHandler):
"""Base backend class.
All Ibis backends must subclass this class and implement all the
@@ -337,6 +459,12 @@ class BaseBackend(abc.ABC, ResultHandler):
def __init__(self, *args, **kwargs):
self._con_args: tuple[Any] = args
self._con_kwargs: dict[str, Any] = kwargs
# expression cache
self._query_cache: MutableBidirectionalMapping[
ops.TableNode, ops.PhysicalTable
] = bidict()

self._refs = Counter()

def __getstate__(self):
return dict(
@@ -346,6 +474,9 @@ def __getstate__(self):
_con_kwargs=self._con_kwargs,
)

def __rich_repr__(self):
yield "name", self.name

def __hash__(self):
return hash(self.db_identity)

@@ -377,10 +508,10 @@ def connect(self, *args, **kwargs) -> BaseBackend:
Parameters
----------
args
*args
Mandatory connection parameters, see the docstring of `do_connect`
for details.
kwargs
**kwargs
Extra connection parameters, see the docstring of `do_connect` for
details.
@@ -586,6 +717,14 @@ def compile(
"""Compile an expression."""
return self.compiler.to_sql(expr, params=params)

def _to_sql(self, expr: ir.Expr, **kwargs) -> str:
"""Convert an expression to a SQL string.
Called by `ibis.to_sql`/`ibis.show_sql`; this gives the backend an
opportunity to generate nicer SQL for human consumption.
"""
raise NotImplementedError(f"Backend '{self.name}' does not support SQL")

def execute(self, expr: ir.Expr) -> Any:
"""Execute an expression."""

@@ -630,17 +769,19 @@ def create_database(self, name: str, force: bool = False) -> None:
f'Backend "{self.name}" does not implement "create_database"'
)

@abc.abstractmethod
def create_table(
self,
name: str,
obj: pd.DataFrame | ir.Table | None = None,
*,
schema: ibis.Schema | None = None,
database: str | None = None,
) -> None:
temp: bool = False,
overwrite: bool = False,
) -> ir.Table:
"""Create a new table.
Not all backends implement this method.
Parameters
----------
name
@@ -655,14 +796,22 @@ def create_table(
database
Name of the database where the table will be created, if not the
default.
temp
Whether a table is temporary or not
overwrite
Whether to clobber existing data
Returns
-------
Table
The table that was created.
"""
raise NotImplementedError(
f'Backend "{self.name}" does not implement "create_table"'
)

@abc.abstractmethod
def drop_table(
self,
name: str,
*,
database: str | None = None,
force: bool = False,
) -> None:
@@ -681,31 +830,38 @@ def drop_table(
f'Backend "{self.name}" does not implement "drop_table"'
)

@abc.abstractmethod
def create_view(
self,
name: str,
expr: ir.Table,
obj: ir.Table,
*,
database: str | None = None,
) -> None:
"""Create a view.
overwrite: bool = False,
) -> ir.Table:
"""Create a new view from an expression.
Parameters
----------
name
Name for the new view.
expr
An Ibis table expression that will be used to extract the query
of the view.
Name of the new view.
obj
An Ibis table expression that will be used to create the view.
database
Name of the database where the view will be created, if not the
default.
Name of the database where the view will be created; if not
provided, the default database is used.
overwrite
Whether to clobber an existing view with the same name
Returns
-------
Table
The view that was created.
"""
raise NotImplementedError(
f'Backend "{self.name}" does not implement "create_view"'
)

@abc.abstractmethod
def drop_view(
self, name: str, database: str | None = None, force: bool = False
self, name: str, *, database: str | None = None, force: bool = False
) -> None:
"""Drop a view.
@@ -718,9 +874,6 @@ def drop_view(
force
If `False`, an exception is raised if the view does not exist.
"""
raise NotImplementedError(
f'Backend "{self.name}" does not implement "drop_view"'
)

@classmethod
def has_operation(cls, operation: type[ops.Value]) -> bool:
@@ -749,6 +902,57 @@ def has_operation(cls, operation: type[ops.Value]) -> bool:
f"{cls.name} backend has not implemented `has_operation` API"
)

def _cached(self, expr):
"""Cache the provided expression.
All subsequent operations on the returned expression will be performed on the cached data.
Parameters
----------
expr
Table expression to cache
Returns
-------
Expr
Cached table
"""
op = expr.op()
if (result := self._query_cache.get(op)) is None:
name = util.generate_unique_table_name("cache")
self._load_into_cache(name, expr)
self._query_cache[op] = result = self.table(name).op()
self._refs[op] += 1
return ir.CachedTable(result)

def _release_cached(self, expr):
"""Releases the provided cached expression.
Parameters
----------
expr
Cached expression to release
"""
op = expr.op()
# we need to remove the expression representing the temp table as well
# as the expression that was used to create the temp table
#
# bidict automatically handles this for us; without it we'd have to
# do the bookkeeping ourselves with two dicts
if (key := self._query_cache.inverse.get(op)) is None:
raise exc.IbisError(
"This expression has already been released. Did you call "
"`.release()` twice on the same expression?"
)

self._refs[key] -= 1

if not self._refs[key]:
del self._query_cache[key]
del self._refs[key]
self._clean_up_cached_table(op)
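
# A minimal standalone sketch of the bookkeeping in _cached/_release_cached
# above (illustration only, not backend code): the bidict provides the reverse
# lookup needed on release, and the Counter drops the physical table only when
# the last reference goes away.
from collections import Counter

from bidict import bidict

_cache, _refs = bidict(), Counter()

def cache(key, make_table):
    if key not in _cache:
        _cache[key] = make_table()
    _refs[key] += 1
    return _cache[key]

def release(table):
    key = _cache.inverse[table]  # reverse lookup, as in _release_cached
    _refs[key] -= 1
    if not _refs[key]:
        del _cache[key]
        del _refs[key]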


@functools.lru_cache(maxsize=None)
def _get_backend_names() -> frozenset[str]:
Empty file.
30 changes: 15 additions & 15 deletions ibis/expr/scope.py → ibis/backends/base/df/scope.py
@@ -37,13 +37,14 @@
from __future__ import annotations

from collections import namedtuple
from typing import TYPE_CHECKING, Any, Iterable
from typing import Any, Iterable, Tuple

import pandas as pd

from ibis.backends.base.df.timecontext import TimeContextRelation, compare_timecontext
from ibis.expr.operations import Node
from ibis.expr.timecontext import TimeContextRelation, compare_timecontext

if TYPE_CHECKING:
from ibis.expr.typing import TimeContext
TimeContext = Tuple[pd.Timestamp, pd.Timestamp]

ScopeItem = namedtuple('ScopeItem', ['timecontext', 'value'])

@@ -153,20 +154,19 @@ def get_value(self, op: Node, timecontext: TimeContext | None = None) -> Any:
return None

def merge_scope(self, other_scope: Scope, overwrite=False) -> Scope:
"""merge items in other_scope into this scope.
"""Merge items in `other_scope` into this scope.
Parameters
----------
other_scope: Scope
other_scope
Scope to be merged with
overwrite: bool
if set to be True, force overwrite `value` if `op` already
exists.
overwrite
if `True`, force overwrite `value` if node already exists.
Returns
-------
Scope
a new Scope instance with items in two scope merged.
a new Scope instance with items in two scopes merged.
"""
result = Scope()

@@ -184,19 +184,19 @@ def merge_scope(self, other_scope: Scope, overwrite=False) -> Scope:
return result

def merge_scopes(self, other_scopes: Iterable[Scope], overwrite=False) -> Scope:
"""merge items in other_scopes into this scope.
"""Merge items in `other_scopes` into this scope.
Parameters
----------
other_scopes: Iterable[Scope]
other_scopes
scopes to be merged with
overwrite: Bool
if set to be True, force overwrite value if op already exists.
overwrite
if `True`, force overwrite value if node already exists.
Returns
-------
Scope
a new Scope instance with items in two scope merged.
a new Scope instance with items in input scopes merged.
"""
result = Scope()
for op in self:
@@ -43,19 +43,20 @@

import enum
import functools
from typing import TYPE_CHECKING, Any
from typing import TYPE_CHECKING, Any, Tuple

import numpy as np
import pandas as pd

import ibis.common.exceptions as com
import ibis.expr.operations as ops
from ibis import config

TimeContext = Tuple[pd.Timestamp, pd.Timestamp]


if TYPE_CHECKING:
import pandas as pd
from ibis.backends.base.df.scope import Scope

from ibis.expr.scope import Scope
from ibis.expr.typing import TimeContext

# In order to use time context feature, there must be a column of Timestamp
# type, and named as 'time' in Table. This TIME_COL constant will be
@@ -111,7 +112,6 @@ def canonicalize_context(
timecontext: TimeContext | None,
) -> TimeContext | None:
"""Canonicalize a timecontext with type pandas.Timestamp for its begin and end time."""
import pandas as pd

SUPPORTS_TIMESTAMP_TYPE = pd.Timestamp
if not isinstance(timecontext, tuple) or len(timecontext) != 2:
@@ -162,7 +162,7 @@ def construct_time_context_aware_series(
Examples
--------
>>> import pandas as pd
>>> from ibis.expr.timecontext import construct_time_context_aware_series
>>> from ibis.backends.base.df.timecontext import construct_time_context_aware_series
>>> df = pd.DataFrame(
... {
... 'time': pd.Series(
@@ -211,8 +211,6 @@ def construct_time_context_aware_series(
Name: value, dtype: float64
The result is unchanged for a series that already has 'time' as its index.
"""
import pandas as pd

time_col = get_time_col()
if time_col == frame.index.name:
time_index = frame.index
@@ -282,34 +280,25 @@ def adjust_context_asof_join(
return timecontext


@adjust_context.register(ops.Window)
@adjust_context.register(ops.WindowFunction)
def adjust_context_window(
op: ops.Window, scope: Scope, timecontext: TimeContext
op: ops.WindowFunction, scope: Scope, timecontext: TimeContext
) -> TimeContext:
import ibis.expr.types as ir
# TODO(kszucs): this file should really be moved to the pandas
# backend instead of the current central placement
from ibis.backends.pandas.execution import execute

# adjust time context by preceding and following
begin, end = timecontext

# TODO(kszucs): rewrite op.window.preceding to be an ops.Node
preceding = op.window.preceding
if preceding is not None:
if isinstance(preceding, ir.IntervalScalar):
# TODO(kszucs): this file should be really moved to the pandas
# backend instead of the current central placement
from ibis.backends.pandas.execution import execute

preceding = execute(preceding.op())
if preceding and not isinstance(preceding, (int, np.integer)):
begin = begin - preceding

following = op.window.following
if following is not None:
if isinstance(following, ir.IntervalScalar):
from ibis.backends.pandas.execution import execute

following = execute(following.op())
if following and not isinstance(following, (int, np.integer)):
end = end + following
if op.frame.start is not None:
value = execute(op.frame.start.value)
if value:
begin = begin - value

if op.frame.end is not None:
value = execute(op.frame.end.value)
if value:
end = end + value

return (begin, end)
24 changes: 16 additions & 8 deletions ibis/backends/base/sql/__init__.py
@@ -6,8 +6,7 @@
from functools import lru_cache
from typing import TYPE_CHECKING, Any, Iterable, Mapping

import sqlalchemy as sa

import ibis.common.exceptions as exc
import ibis.expr.analysis as an
import ibis.expr.operations as ops
import ibis.expr.schema as sch
@@ -17,10 +16,9 @@
from ibis.backends.base.sql.compiler import Compiler

if TYPE_CHECKING:
import pandas as pd
import pyarrow as pa

from ibis.expr.typing import TimeContext

__all__ = [
'BaseSQLBackend',
]
@@ -46,6 +44,8 @@ def _from_url(self, url: str) -> BaseBackend:
BaseBackend
A backend instance
"""
import sqlalchemy as sa

url = sa.engine.make_url(url)

kwargs = {}
@@ -78,6 +78,10 @@ def table(self, name: str, database: str | None = None) -> ir.Table:
Table
Table expression
"""
if database is not None and not isinstance(database, str):
raise exc.IbisTypeError(
f"`database` must be a string; got {type(database)}"
)
qualified_name = self._fully_qualified_name(name, database)
schema = self.get_schema(qualified_name)
node = self.table_class(qualified_name, schema, self)
@@ -139,6 +143,7 @@ def _cursor_batches(
limit: int | str | None = None,
chunk_size: int = 1_000_000,
) -> Iterable[list]:
self._register_in_memory_tables(expr)
query_ast = self.compiler.to_ast_ensure_limit(expr, limit, params=params)
sql = query_ast.compile()

@@ -248,10 +253,10 @@ def execute(

return result

def _register_in_memory_table(self, table_op):
raise NotImplementedError
def _register_in_memory_table(self, _: ops.InMemoryTable) -> None:
raise NotImplementedError(self.name)

def _register_in_memory_tables(self, expr):
def _register_in_memory_tables(self, expr: ir.Expr) -> None:
if self.compiler.cheap_in_memory_tables:
for memtable in an.find_memtables(expr.op()):
self._register_in_memory_table(memtable)
@@ -306,7 +311,7 @@ def compile(
expr: ir.Expr,
limit: str | None = None,
params: Mapping[ir.Expr, Any] | None = None,
timecontext: TimeContext | None = None,
timecontext: tuple[pd.Timestamp, pd.Timestamp] | None = None,
) -> Any:
"""Compile an Ibis expression.
@@ -330,6 +335,9 @@
"""
return self.compiler.to_ast_ensure_limit(expr, limit, params=params).compile()

def _to_sql(self, expr: ir.Expr, **kwargs) -> str:
return str(self.compile(expr, **kwargs))

def explain(
self,
expr: ir.Expr | str,
264 changes: 175 additions & 89 deletions ibis/backends/base/sql/alchemy/__init__.py

Large diffs are not rendered by default.

303 changes: 80 additions & 223 deletions ibis/backends/base/sql/alchemy/datatypes.py

Large diffs are not rendered by default.

89 changes: 64 additions & 25 deletions ibis/backends/base/sql/alchemy/query_builder.py
@@ -3,8 +3,10 @@
import functools

import sqlalchemy as sa
import toolz
from sqlalchemy import sql

import ibis.expr.analysis as an
import ibis.expr.operations as ops
from ibis.backends.base.sql.alchemy.database import AlchemyTable
from ibis.backends.base.sql.alchemy.translator import (
@@ -113,15 +115,7 @@ def _format_table(self, op):
backend = child_expr._find_backend()
backend._create_temp_view(view=result, definition=definition)
elif isinstance(ref_op, ops.InMemoryTable):
columns = translator._schema_to_sqlalchemy_columns(ref_op.schema)

if self.context.compiler.cheap_in_memory_tables:
result = sa.table(ref_op.name, *columns)
else:
# this has horrendous performance for medium to large tables
# should we warn?
rows = list(ref_op.data.to_frame().itertuples(index=False))
result = sa.values(*columns, name=ref_op.name).data(rows)
result = self._format_in_memory_table(op, ref_op, translator)
else:
# A subquery
if ctx.is_extracted(ref_op):
@@ -144,6 +138,26 @@
ctx.set_ref(op, result)
return result

def _format_in_memory_table(self, op, ref_op, translator):
columns = translator._schema_to_sqlalchemy_columns(ref_op.schema)
if self.context.compiler.cheap_in_memory_tables:
result = sa.table(ref_op.name, *columns)
elif self.context.compiler.support_values_syntax_in_select:
rows = list(ref_op.data.to_frame().itertuples(index=False))
result = sa.values(*columns, name=ref_op.name).data(rows)
else:
raw_rows = (
sa.select(
*(
translator.translate(ops.Literal(val, dtype=type_))
for val, type_ in zip(row, op.schema.types)
)
)
for row in op.data.to_frame().itertuples(index=False)
)
result = sa.union_all(*raw_rows).alias(ref_op.name)
return result


class AlchemySelect(Select):
def __init__(self, *args, **kwargs):
@@ -193,22 +207,22 @@ def _compile_table_set(self):
def _add_select(self, table_set):
to_select = []

context = self.context
select_set = self.select_set

has_select_star = False
for op in self.select_set:
for op in select_set:
if isinstance(op, ops.Value):
arg = self._translate(op, named=True)
elif isinstance(op, ops.TableNode):
arg = context.get_ref(op)
if op.equals(self.table_set):
cached_table = self.context.get_ref(op)
if cached_table is None:
has_select_star = True
if has_select_star := arg is None:
continue
else:
arg = table_set
else:
arg = self.context.get_ref(op)
if arg is None:
raise ValueError(op)
elif arg is None:
raise ValueError(op)
else:
raise TypeError(op)

@@ -228,38 +242,61 @@ def _add_select(self, table_set):
if self.distinct:
result = result.distinct()

# if we're SELECT *-ing or there's no table_set (e.g., SELECT 1) then
# we can return early
if has_select_star or table_set is None:
# only process unnest if the backend doesn't support SELECT UNNEST(...)
unnest_children = []
if not self.translator_class.supports_unnest_in_select:
unnest_children.extend(
map(
context.get_ref,
toolz.unique(an.find_toplevel_unnest_children(select_set)),
)
)

# if we're SELECT *-ing or there's no table_set (e.g., SELECT 1) *and*
# there are no unnest operations then we can return early
if (has_select_star or table_set is None) and not unnest_children:
return result

if unnest_children:
# get all the unnests plus the current froms of the result selection
# and build up the cross join
table_set = functools.reduce(
functools.partial(sa.sql.FromClause.join, onclause=sa.true()),
toolz.unique(toolz.concatv(unnest_children, result.get_final_froms())),
)

return result.select_from(table_set)

def _add_group_by(self, fragment):
# GROUP BY and HAVING
if not len(self.group_by):
nkeys = len(self.group_by)
if not nkeys:
return fragment

group_keys = [self._translate(arg) for arg in self.group_by]
if self.context.compiler.supports_indexed_grouping_keys:
group_keys = map(sa.literal_column, map(str, range(1, nkeys + 1)))
else:
group_keys = map(self._translate, self.group_by)

fragment = fragment.group_by(*group_keys)

if len(self.having) > 0:
if self.having:
having_args = [self._translate(arg) for arg in self.having]
having_clause = functools.reduce(sql.and_, having_args)
fragment = fragment.having(having_clause)

return fragment
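
# Hedged illustration of the indexed-grouping-keys branch above; the table and
# column names are assumptions. When the compiler supports it, grouping keys
# are emitted by ordinal position instead of repeating the key expressions.
import sqlalchemy as sa

t = sa.table("t", sa.column("region"), sa.column("amount"))
stmt = sa.select(t.c.region, sa.func.sum(t.c.amount)).group_by(sa.literal_column("1"))
# str(stmt) renders roughly as: SELECT t.region, sum(t.amount) AS sum_1 FROM t GROUP BY 1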

def _add_where(self, fragment):
if not len(self.where):
if not self.where:
return fragment

args = [self._translate(pred, permit_subquery=True) for pred in self.where]
clause = functools.reduce(sql.and_, args)
return fragment.where(clause)

def _add_order_by(self, fragment):
if not len(self.order_by):
if not self.order_by:
return fragment

clauses = []
@@ -344,6 +381,8 @@ class AlchemyCompiler(Compiler):
intersect_class = AlchemyIntersection
difference_class = AlchemyDifference

supports_indexed_grouping_keys = True

@classmethod
def to_sql(cls, expr, context=None, params=None, exists=False):
if context is None:
188 changes: 102 additions & 86 deletions ibis/backends/base/sql/alchemy/registry.py
@@ -6,16 +6,23 @@
from typing import Any

import sqlalchemy as sa
from sqlalchemy.sql.functions import FunctionElement, GenericFunction

import ibis.common.exceptions as com
import ibis.expr.analysis as an
import ibis.expr.datatypes as dt
import ibis.expr.operations as ops
import ibis.expr.types as ir
import ibis.expr.window as W
from ibis.backends.base.sql.alchemy.database import AlchemyTable


class substr(GenericFunction):
"""A generic substr function, so dialects can customize compilation."""

type = sa.types.String()
inherit_cache = True


def variance_reduction(func_name):
suffix = {'sample': 'samp', 'pop': 'pop'}

@@ -72,11 +79,10 @@ def get_sqla_table(ctx, table):
while sa_table is None and ctx_level.parent is not ctx_level:
ctx_level = ctx_level.parent
sa_table = ctx_level.get_ref(table)
elif isinstance(table, AlchemyTable):
sa_table = table.sqla_table
else:
if isinstance(table, AlchemyTable):
sa_table = table.sqla_table
else:
sa_table = ctx.get_compiled_expr(table)
sa_table = ctx.get_compiled_expr(table)

return sa_table

@@ -167,9 +173,6 @@ def _cast(t, op):

sa_arg = t.translate(arg)

if arg_dtype.is_category() and typ.is_int32():
return sa_arg

# specialize going from an integer type to a timestamp
if arg_dtype.is_integer() and typ.is_timestamp():
return t.integer_to_timestamp(sa_arg)
@@ -192,14 +195,15 @@ def _contains(func):
def translate(t, op):
left = t.translate(op.value)

if isinstance(op.options, tuple):
options = op.options
if isinstance(options, tuple):
right = [t.translate(x) for x in op.options]
elif op.options.output_shape.is_columnar():
right = t.translate(op.options)
elif options.output_shape.is_columnar():
right = t.translate(ops.TableArrayView(options.to_expr().as_table()))
if not isinstance(right, sa.sql.Selectable):
right = sa.select(right)
else:
right = t.translate(op.options)
right = t.translate(options)

return func(left, right)

@@ -216,6 +220,9 @@ def _literal(_, op):
dtype = op.output_dtype
value = op.value

if value is None:
return sa.null()

if dtype.is_set():
return list(map(sa.literal, value))
elif dtype.is_array():
@@ -253,22 +260,19 @@ def _floor_divide(t, op):


def _simple_case(t, op):
cases = [ops.Equals(op.base, case) for case in op.cases]
return _translate_case(t, cases, op.results, op.default)
return _translate_case(t, op, value=t.translate(op.base))


def _searched_case(t, op):
return _translate_case(t, op.cases, op.results, op.default)


def _translate_case(t, cases, results, default):
case_args = [t.translate(arg) for arg in cases]
result_args = [t.translate(arg) for arg in results]
return _translate_case(t, op, value=None)

whens = zip(case_args, result_args)
default = t.translate(default)

return sa.case(*whens, else_=default)
def _translate_case(t, op, *, value):
return sa.case(
*zip(map(t.translate, op.cases), map(t.translate, op.results)),
value=value,
else_=t.translate(op.default),
)
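
# Hedged illustration of the value= distinction above (column name assumed):
# passing value= produces a "simple" CASE, omitting it produces a "searched" CASE.
import sqlalchemy as sa

x = sa.column("x")
searched = sa.case((x == 1, "one"), else_="other")                # CASE WHEN x = 1 THEN 'one' ...
simple = sa.case((sa.literal(1), "one"), value=x, else_="other")  # CASE x WHEN 1 THEN 'one' ...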


def _negate(t, op):
@@ -303,69 +307,73 @@ def _endswith(t, op):
}


def _cumulative_to_window(translator, op, window):
win = W.cumulative_window()
win = win.group_by(window._group_by).order_by(window._order_by)

def _cumulative_to_window(translator, op, frame):
klass = _cumulative_to_reduction[type(op)]
new_op = klass(*op.args)
new_expr = new_op.to_expr().name(op.name)
new_frame = frame.copy(start=None, end=0)

if type(new_op) in translator._rewrites:
new_expr = translator._rewrites[type(new_op)](new_expr)

# TODO(kszucs): rewrite to receive and return an ops.Node
return an.windowize_function(new_expr, win)
return an.windowize_function(new_expr, frame=new_frame)


def _window(t, op):
arg, window = op.args
reduction = t.translate(arg)
def _translate_window_boundary(boundary):
if boundary is None:
return None

window_op = arg
if isinstance(boundary.value, ops.Literal):
if boundary.preceding:
return -boundary.value.value
else:
return boundary.value.value

if isinstance(window_op, ops.CumulativeOp):
arg = _cumulative_to_window(t, arg, window).op()
return t.translate(arg)
raise com.TranslationError("Window boundaries must be literal values")

if window.max_lookback is not None:
raise NotImplementedError(
'Rows with max lookback is not implemented '
'for SQLAlchemy-based backends.'
)

# Checks for invalid user input e.g. passing in tuple for preceding and
# non-None value for following are caught and raised in expr/window.py
# if we're here, then the input is valid, we just need to interpret it
# correctly
if isinstance(window.preceding, tuple):
start, end = (-1 * x if x is not None else None for x in window.preceding)
elif isinstance(window.following, tuple):
start, end = window.following
else:
start = -window.preceding if window.preceding is not None else window.preceding
end = window.following
def _window_function(t, window):
if isinstance(window.func, ops.CumulativeOp):
func = _cumulative_to_window(t, window.func, window.frame).op()
return t.translate(func)

reduction = t.translate(window.func)

# Some analytic functions need to have the expression of interest in
# the ORDER BY part of the window clause
if isinstance(window_op, t._require_order_by) and not window._order_by:
order_by = t.translate(window_op.args[0])
if isinstance(window.func, t._require_order_by) and not window.frame.order_by:
order_by = t.translate(window.func.arg)
else:
order_by = [t.translate(arg) for arg in window._order_by]
order_by = [t.translate(arg) for arg in window.frame.order_by]

partition_by = [t.translate(arg) for arg in window._group_by]
partition_by = [t.translate(arg) for arg in window.frame.group_by]

if isinstance(window.frame, ops.RowsWindowFrame):
if window.frame.max_lookback is not None:
raise NotImplementedError(
'Rows with max lookback is not implemented for SQLAlchemy-based '
'backends.'
)
how = 'rows'
elif isinstance(window.frame, ops.RangeWindowFrame):
how = 'range_'
else:
raise NotImplementedError(type(window.frame))

if t._forbids_frame_clause and isinstance(window.func, t._forbids_frame_clause):
# some functions on some backends don't support frame clauses
additional_params = {}
else:
start = _translate_window_boundary(window.frame.start)
end = _translate_window_boundary(window.frame.end)
additional_params = {how: (start, end)}

how = {'range': 'range_'}.get(window.how, window.how)
additional_params = (
{}
if t._forbids_frame_clause and isinstance(window_op, t._forbids_frame_clause)
else {how: (start, end)}
)
result = reduction.over(
partition_by=partition_by, order_by=order_by, **additional_params
)

if isinstance(window_op, (ops.RowNumber, ops.DenseRank, ops.MinRank, ops.NTile)):
if isinstance(window.func, (ops.RowNumber, ops.DenseRank, ops.MinRank, ops.NTile)):
return result - 1
else:
return result
@@ -417,12 +425,21 @@ def _zero_if_null(t, op):


def _substring(t, op):
args = t.translate(op.arg), t.translate(op.start) + 1

if (length := op.length) is not None:
args += (t.translate(length),)

return sa.func.substr(*args)
sa_arg = t.translate(op.arg)
sa_start = t.translate(op.start) + 1
# `start` may be negative or a non-literal expression, so we need a runtime branch
sa_arg_length = t.translate(ops.StringLength(op.arg))
if op.length is None:
return sa.case(
((sa_start >= 1), sa.func.substr(sa_arg, sa_start)),
else_=sa.func.substr(sa_arg, sa_start + sa_arg_length),
)
else:
sa_length = t.translate(op.length)
return sa.case(
((sa_start >= 1), sa.func.substr(sa_arg, sa_start, sa_length)),
else_=sa.func.substr(sa_arg, sa_start + sa_arg_length, sa_length),
)
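
# A small Python sketch of the semantics the CASE expression above emulates in
# SQL (example values assumed): a negative `start` counts from the end of the
# string, exactly like Python slicing.
assert "hello"[-3:] == "llo"    # start=-3, no length
assert "hello"[-3:-1] == "ll"   # start=-3, length=2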


def _gen_string_find(func):
@@ -442,21 +459,6 @@ def _nth_value(t, op):
return sa.func.nth_value(t.translate(op.arg), t.translate(op.nth) + 1)


def _clip(*, min_func, max_func):
def translate(t, op):
arg = t.translate(op.arg)

if (upper := op.upper) is not None:
arg = min_func(t.translate(upper), arg)

if (lower := op.lower) is not None:
arg = max_func(t.translate(lower), arg)

return arg

return translate


def _bitwise_op(operator):
def translate(t, op):
left = t.translate(op.left)
@@ -491,6 +493,14 @@ def translator(t, op: ops.Node):
return translator


class array_map(FunctionElement):
pass


class array_filter(FunctionElement):
pass


sqlalchemy_operation_registry: dict[Any, Any] = {
ops.Alias: _alias,
ops.And: fixed_arity(operator.and_, 2),
@@ -534,6 +544,11 @@ def translator(t, op: ops.Node):
ops.Least: varargs(sa.func.least),
ops.Greatest: varargs(sa.func.greatest),
# string
ops.Capitalize: unary(
lambda arg: sa.func.concat(
sa.func.upper(sa.func.substr(arg, 1, 1)), sa.func.substr(arg, 2)
)
),
ops.LPad: fixed_arity(sa.func.lpad, 3),
ops.RPad: fixed_arity(sa.func.rpad, 3),
ops.Strip: unary(sa.func.trim),
@@ -604,7 +619,6 @@ def translator(t, op: ops.Node):
ops.IdenticalTo: fixed_arity(
sa.sql.expression.ColumnElement.is_not_distinct_from, 2
),
ops.Clip: _clip(min_func=sa.func.least, max_func=sa.func.greatest),
ops.Where: fixed_arity(
lambda predicate, value_if_true, value_if_false: sa.case(
(predicate, value_if_true),
Expand All @@ -626,6 +640,7 @@ def translator(t, op: ops.Node):
ops.ExtractHour: _extract('hour'),
ops.ExtractMinute: _extract('minute'),
ops.ExtractSecond: _extract('second'),
ops.Time: fixed_arity(lambda arg: sa.cast(arg, sa.TIME), 1),
}


Expand All @@ -641,12 +656,13 @@ def translator(t, op: ops.Node):
ops.PercentRank: unary(lambda _: sa.func.percent_rank()),
ops.CumeDist: unary(lambda _: sa.func.cume_dist()),
ops.NthValue: _nth_value,
ops.Window: _window,
ops.CumulativeOp: _window,
ops.WindowFunction: _window_function,
ops.CumulativeMax: unary(sa.func.max),
ops.CumulativeMin: unary(sa.func.min),
ops.CumulativeSum: unary(sa.func.sum),
ops.CumulativeMean: unary(sa.func.avg),
ops.CumulativeAny: unary(sa.func.bool_or),
ops.CumulativeAll: unary(sa.func.bool_and),
}

geospatial_functions = {
16 changes: 9 additions & 7 deletions ibis/backends/base/sql/alchemy/translator.py
@@ -1,6 +1,7 @@
from __future__ import annotations

import functools
import operator

import sqlalchemy as sa

Expand Down Expand Up @@ -46,7 +47,7 @@ class AlchemyExprTranslator(ExprTranslator):

integer_to_timestamp = sa.func.to_timestamp
native_json_type = True
_always_quote_columns = False
_always_quote_columns = None # let the dialect decide how to quote

_require_order_by = (
ops.DenseRank,
@@ -58,6 +59,8 @@ class AlchemyExprTranslator(ExprTranslator):

_dialect_name = "default"

supports_unnest_in_select = True

@functools.cached_property
def dialect(self) -> sa.engine.interfaces.Dialect:
if (name := self._dialect_name) == "default":
@@ -110,10 +113,9 @@ def _reduction(self, sa_func, op):

@rewrites(ops.NullIfZero)
def _nullifzero(op):
# TODO(kszucs): avoid rountripping to expr then back to op
expr = op.arg.to_expr()
new_expr = (expr == 0).ifelse(ibis.NA, expr)
return new_expr.op()
arg = op.arg
condition = ops.Equals(arg, ops.Literal(0, dtype=op.arg.output_dtype))
return ops.Where(condition, ibis.NA, arg)


# TODO This was previously implemented with the legacy `@compiles` decorator.
@@ -124,10 +126,10 @@ def _true_divide(t, op):
if all(arg.output_dtype.is_integer() for arg in op.args):
# TODO(kszucs): this should be done in the rewrite phase
right, left = op.right.to_expr(), op.left.to_expr()
new_expr = left.div(right.cast('double'))
new_expr = left.div(right.cast(dt.double))
return t.translate(new_expr.op())

return fixed_arity(lambda x, y: x / y, 2)(t, op)
return fixed_arity(operator.truediv, 2)(t, op)


AlchemyExprTranslator._registry[ops.Divide] = _true_divide
17 changes: 6 additions & 11 deletions ibis/backends/base/sql/compiler/base.py
@@ -68,8 +68,12 @@ def _get_keyword_list(self):
return map(self.keyword, self.distincts)

def _extract_subqueries(self):
self.subqueries = _extract_common_table_expressions(
[self.table_set, *self.filters]
# extract any subquery to avoid generating incorrect sql when at least
# one of the set operands is invalid outside of being a subquery
#
# for example: SELECT * FROM t ORDER BY x UNION ...
self.subqueries = an.find_subqueries(
[self.table_set, *self.filters], min_dependents=1
)
for subquery in self.subqueries:
self.context.set_extracted(subquery)
@@ -111,12 +115,3 @@ def compile(self):
)
)
return '\n'.join(buf)


def _extract_common_table_expressions(nodes):
# filter out None values
nodes = list(filter(None, nodes))
counts = an.find_subqueries(nodes)
duplicates = [op for op, count in counts.items() if count > 1]
duplicates.reverse()
return duplicates
28 changes: 18 additions & 10 deletions ibis/backends/base/sql/compiler/query_builder.py
@@ -73,15 +73,22 @@ def _quote_identifier(self, name):

def _format_in_memory_table(self, op):
names = op.schema.names
raw_rows = (
", ".join(
f"{val!r} AS {self._quote_identifier(name)}"
for val, name in zip(row, names)
)
for row in op.data.to_frame().itertuples(index=False)
)
rows = ", ".join(f"({raw_row})" for raw_row in raw_rows)
return f"(VALUES {rows})"
raw_rows = []
for row in op.data.to_frame().itertuples(index=False):
raw_row = []
for val, name in zip(row, names):
lit = ops.Literal(val, dtype=op.schema[name])
raw_row.append(
f"{self._translate(lit)} AS {self._quote_identifier(name)}"
)
raw_rows.append(", ".join(raw_row))

if self.context.compiler.support_values_syntax_in_select:
rows = ", ".join(f"({raw_row})" for raw_row in raw_rows)
return f"(VALUES {rows})"
else:
rows = " UNION ALL ".join(f"(SELECT {raw_row})" for raw_row in raw_rows)
return f"({rows})"

def _format_table(self, op):
# TODO: This could probably go in a class and be significantly nicer
@@ -296,7 +303,7 @@ def format_select_set(self):
formatted = []
for node in self.select_set:
if isinstance(node, ops.Value):
expr_str = self._translate(node, named=True)
expr_str = self._translate(node, named=True, permit_subquery=True)
elif isinstance(node, ops.TableNode):
alias = context.get_ref(node)
expr_str = f'{alias}.*' if alias else '*'
@@ -488,6 +495,7 @@ class Compiler:
difference_class = Difference

cheap_in_memory_tables = False
support_values_syntax_in_select = True

@classmethod
def make_context(cls, params=None):
46 changes: 10 additions & 36 deletions ibis/backends/base/sql/compiler/select_builder.py
@@ -10,7 +10,6 @@
import ibis.expr.analysis as an
import ibis.expr.datatypes as dt
import ibis.expr.operations as ops
from ibis.backends.base.sql.compiler.base import _extract_common_table_expressions


class _LimitSpec(NamedTuple):
@@ -110,13 +109,12 @@ def _adapt_operation(node):

def _build_result_query(self):
self._collect_elements()
self._analyze_select_exprs()
self._analyze_subqueries()
self._populate_context()

return self.select_class(
self.table_set,
self.select_set,
list(self.select_set),
translator_class=self.translator_class,
table_set_formatter_class=self.table_set_formatter_class,
context=self.context,
@@ -154,35 +152,6 @@ def _make_table_aliases(self, node):
# down to child contexts so that they aren't missing any refs.
ctx.set_ref(node, ctx.top_context.get_ref(node))

# ---------------------------------------------------------------------
# Expr analysis / rewrites

def _analyze_select_exprs(self):
new_select_set = []

for op in self.select_set:
new_op = self._visit_select_expr(op)
new_select_set.append(new_op)

self.select_set = new_select_set

# TODO(kszucs): this should be rewritten using analysis.substitute()
def _visit_select_expr(self, op):
method = f'_visit_select_{type(op).__name__}'
if hasattr(self, method):
f = getattr(self, method)
return f(op)
elif isinstance(op, ops.Value):
new_args = []
for arg in op.args:
if isinstance(arg, ops.Node):
arg = self._visit_select_expr(arg)
new_args.append(arg)

return type(op)(*new_args)
else:
return op

# ---------------------------------------------------------------------
# Analysis of table set

@@ -280,15 +249,18 @@ def _collect_Limit(self, op, toplevel=False):

def _collect_Union(self, op, toplevel=False):
if toplevel:
raise NotImplementedError()
self.table_set = op
self.select_set = [op]

def _collect_Difference(self, op, toplevel=False):
if toplevel:
raise NotImplementedError()
self.table_set = op
self.select_set = [op]

def _collect_Intersection(self, op, toplevel=False):
if toplevel:
raise NotImplementedError()
self.table_set = op
self.select_set = [op]

def _collect_Aggregation(self, op, toplevel=False):
# The select set includes the grouping keys (if any), and these are
@@ -377,7 +349,9 @@ def _analyze_subqueries(self):
# want.

# Find the subqueries, and record them in the passed query context.
subqueries = _extract_common_table_expressions([self.table_set, *self.filters])
subqueries = an.find_subqueries(
[self.table_set, *self.filters], min_dependents=2
)

self.subqueries = []
for node in subqueries:
54 changes: 13 additions & 41 deletions ibis/backends/base/sql/compiler/translator.py
@@ -27,6 +27,7 @@ def __init__(self, compiler, indent=2, parent=None, params=None):
self.always_alias = True
self.query = None
self.params = params if params is not None else {}
self._alias_counter = getattr(parent, "_alias_counter", 0)

def _compile_subquery(self, op):
sub_ctx = self.subcontext()
@@ -73,11 +74,9 @@ def get_compiled_expr(self, node):
return result

def make_alias(self, node):
i = len(self.table_refs)
i = self._alias_counter

# Get total number of aliases up and down the tree at this point; if we
# find the table prior-aliased along the way, however, we reuse that
# alias
# check for existing tables that we're referencing from a parent context
for ctx in itertools.islice(self._contexts(), 1, None):
try:
alias = ctx.table_refs[node]
Expand All @@ -87,8 +86,7 @@ def make_alias(self, node):
self.set_ref(node, alias)
return

i += len(ctx.table_refs)

self._alias_counter += 1
alias = f't{i:d}'
self.set_ref(node, alias)

@@ -326,24 +324,6 @@ def _bucket(op):
return result.op()


@rewrites(ops.CategoryLabel)
def _category_label(op):
# TODO(kszucs): avoid the expression roundtrip
expr = op.to_expr()
stmt = op.args[0].to_expr().case()
for i, label in enumerate(op.labels):
stmt = stmt.when(i, label)

if op.nulls is not None:
stmt = stmt.else_(op.nulls)

result = stmt.end()
if expr.has_name():
result = result.name(expr.get_name())

return result.op()


@rewrites(ops.Any)
def _any_expand(op):
return ops.Max(op.arg)
@@ -377,22 +357,14 @@ def _rewrite_string_contains(op):
return ops.GreaterEqual(ops.StringFind(op.haystack, op.needle), 0)


NEW_EXTRACT_URL_OPERATION = {
"PROTOCOL": ops.ExtractProtocol,
"AUTHORITY": ops.ExtractAuthority,
"USERINFO": ops.ExtractUserInfo,
"HOST": ops.ExtractHost,
"FILE": ops.ExtractFile,
"PATH": ops.ExtractPath,
"REF": ops.ExtractFragment,
}
@rewrites(ops.Clip)
def _rewrite_clip(op):
arg = ops.Cast(op.arg, op.output_dtype)

if (upper := op.upper) is not None:
arg = ops.Least((arg, ops.Cast(upper, op.output_dtype)))

@rewrites(ops.ParseURL)
def _rewrite_string_contains(op):
extract = op.extract
if extract == 'QUERY':
return ops.ExtractQuery(op.arg, op.key)
if (new_op := NEW_EXTRACT_URL_OPERATION.get(extract)) is not None:
return new_op(op.arg)
raise ValueError(f"{extract!r} is not supported")
if (lower := op.lower) is not None:
arg = ops.Greatest((arg, ops.Cast(lower, op.output_dtype)))

return arg
16 changes: 7 additions & 9 deletions ibis/backends/base/sql/ddl.py
@@ -96,16 +96,14 @@ def _serdeproperties(props):
class _BaseQualifiedSQLStatement:
def _get_scoped_name(self, obj_name, database):
if database:
scoped_name = f'{database}.`{obj_name}`'
else:
if not is_fully_qualified(obj_name):
if _is_quoted(obj_name):
return obj_name
else:
return f'`{obj_name}`'
else:
return f'{database}.`{obj_name}`'
elif not is_fully_qualified(obj_name):
if _is_quoted(obj_name):
return obj_name
return scoped_name
else:
return f'`{obj_name}`'
else:
return obj_name


class BaseDDL(DDL, _BaseQualifiedSQLStatement):
4 changes: 2 additions & 2 deletions ibis/backends/base/sql/registry/__init__.py
@@ -15,7 +15,7 @@
)
from ibis.backends.base.sql.registry.window import (
cumulative_to_window,
format_window,
format_window_frame,
time_range_to_range_window,
)

@@ -31,6 +31,6 @@
'reduction',
'unary',
'cumulative_to_window',
'format_window',
'format_window_frame',
'time_range_to_range_window',
)
File renamed without changes.
3 changes: 3 additions & 0 deletions ibis/backends/base/sql/registry/literal.py
@@ -76,6 +76,9 @@ def literal(translator, op):

dtype = op.output_dtype

if op.value is None:
return "NULL"

if dtype.is_boolean():
typeclass = 'boolean'
elif dtype.is_string():
8 changes: 4 additions & 4 deletions ibis/backends/base/sql/registry/main.py
@@ -114,8 +114,6 @@ def log(translator, op):
def cast(translator, op):
arg_formatted = translator.translate(op.arg)

if op.arg.output_dtype.is_category() and op.to.is_int32():
return arg_formatted
if op.arg.output_dtype.is_temporal() and op.to.is_int64():
return f'1000000 * unix_timestamp({arg_formatted})'
else:
@@ -276,6 +274,7 @@ def count_star(translator, op):
ops.Sqrt: unary('sqrt'),
ops.Hash: hash,
ops.HashBytes: hashbytes,
ops.RandomScalar: lambda *_: 'rand(utc_to_unix_micros(utc_timestamp()))',
ops.Log: log,
ops.Ln: unary('ln'),
ops.Log2: unary('log2'),
@@ -289,7 +288,9 @@ def count_star(translator, op):
ops.Sin: unary("sin"),
ops.Tan: unary("tan"),
ops.Pi: fixed_arity("pi", 0),
ops.E: fixed_arity("exp(1)", 0),
ops.E: fixed_arity("e", 0),
ops.Degrees: lambda t, op: f"(180 * {t.translate(op.arg)} / {t.translate(ops.Pi())})",
ops.Radians: lambda t, op: f"({t.translate(ops.Pi())} * {t.translate(op.arg)} / 180)",
# Unary aggregates
ops.ApproxMedian: aggregate.reduction('appx_median'),
ops.ApproxCountDistinct: aggregate.reduction('ndv'),
@@ -356,7 +357,6 @@ def count_star(translator, op):
ops.DateTruncate: timestamp.truncate,
ops.IntervalFromInteger: timestamp.interval_from_integer,
# Other operations
ops.E: lambda *args: 'e()',
ops.Literal: literal,
ops.NullLiteral: null_literal,
ops.Cast: cast,
243 changes: 83 additions & 160 deletions ibis/backends/base/sql/registry/window.py
@@ -1,13 +1,9 @@
from __future__ import annotations

from operator import add, mul, sub

import ibis
import ibis.common.exceptions as com
import ibis.expr.analysis as an
import ibis.expr.datatypes as dt
import ibis.expr.operations as ops
import ibis.expr.types as ir

_map_interval_to_microseconds = {
'W': 604800000000,
@@ -17,18 +13,6 @@
's': 1000000,
'ms': 1000,
'us': 1,
'ns': 0.001,
}


_map_interval_op_to_op = {
# Literal Intervals have two args, i.e.
# Literal(1, Interval(value_type=int8, unit='D', nullable=True))
# Parse both args and multipy 1 * _map_interval_to_microseconds['D']
ops.Literal: mul,
ops.IntervalMultiply: mul,
ops.IntervalAdd: add,
ops.IntervalSubtract: sub,
}


@@ -42,201 +26,140 @@
}


def _replace_interval_with_scalar(op: ops.Value) -> float | ir.FloatingScalar:
"""Replace an interval type or expression with its equivalent numeric scalar.
Parameters
----------
op
float or interval expression.
For example, `ibis.interval(days=1) + ibis.interval(hours=5)`
Returns
-------
preceding
`float` or `ir.FloatingScalar`, depending on the expr.
"""
if isinstance(op, ops.Literal):
unit = getattr(op.output_dtype, "unit", "us")
try:
micros = _map_interval_to_microseconds[unit]
return op.value * micros
except KeyError:
raise ValueError(f"Unsupported unit {unit!r}")
elif op.args and isinstance(op.output_dtype, dt.Interval):
if len(op.args) > 2:
raise NotImplementedError("'preceding' argument cannot be parsed.")
left_arg = _replace_interval_with_scalar(op.args[0])
right_arg = _replace_interval_with_scalar(op.args[1])
method = _map_interval_op_to_op[type(op)]
return method(left_arg, right_arg)
def cumulative_to_window(translator, func, frame):
klass = _cumulative_to_reduction[type(func)]
func = klass(*func.args)

try:
rule = translator._rewrites[type(func)]
except KeyError:
pass
else:
raise TypeError(f'input has unknown type {type(op)}')
func = rule(func)

frame = frame.copy(start=None, end=0)
expr = an.windowize_function(func.to_expr(), frame)
return expr.op()


def cumulative_to_window(translator, op, window):
klass = _cumulative_to_reduction[type(op)]
new_op = klass(*op.args)
def interval_boundary_to_integer(boundary):
if boundary is None:
return None
elif boundary.output_dtype.is_numeric():
return boundary

value = boundary.value
try:
rule = translator._rewrites[type(new_op)]
multiplier = _map_interval_to_microseconds[value.output_dtype.unit]
except KeyError:
pass
raise com.IbisInputError(
f"Unsupported interval unit: {value.output_dtype.unit}"
)

if isinstance(value, ops.Literal):
value = ops.Literal(value.value * multiplier, dt.int64)
else:
new_op = rule(new_op)
left = ops.Cast(value, to=dt.int64)
value = ops.Multiply(left, multiplier)

win = ibis.cumulative_window().group_by(window._group_by).order_by(window._order_by)
new_expr = an.windowize_function(new_op.to_expr(), win)
return new_expr.op()
return boundary.copy(value=value)


def time_range_to_range_window(_, window):
def time_range_to_range_window(frame):
# Check that ORDER BY column is a single time column:
order_by_vars = [x.args[0] for x in window._order_by]
if len(order_by_vars) > 1:
if len(frame.order_by) > 1:
raise com.IbisInputError(
f"Expected 1 order-by variable, got {len(order_by_vars)}"
f"Expected 1 order-by variable, got {len(frame.order_by)}"
)

order_var = order_by_vars[0]
timestamp_order_var = ops.Cast(order_var, dt.int64).to_expr()
window = window._replace(order_by=timestamp_order_var, how='range')
order_by = frame.order_by[0]
order_by = order_by.copy(expr=ops.Cast(order_by.expr, dt.int64))
start = interval_boundary_to_integer(frame.start)
end = interval_boundary_to_integer(frame.end)

# Need to change preceding interval expression to scalars
preceding = window.preceding
if isinstance(preceding, ir.IntervalScalar):
new_preceding = _replace_interval_with_scalar(preceding.op())
window = window._replace(preceding=new_preceding)
return frame.copy(order_by=(order_by,), start=start, end=end)

return window

def format_window_boundary(translator, boundary):
if isinstance(boundary.value, ops.Literal) and boundary.value.value == 0:
return "CURRENT ROW"

def format_window(translator, op, window):
components = []
value = translator.translate(boundary.value)
direction = "PRECEDING" if boundary.preceding else "FOLLOWING"

if window.max_lookback is not None:
raise NotImplementedError(
'Rows with max lookback is not implemented for Impala-based backends.'
)
return f'{value} {direction}'

if window._group_by:
partition_args = ', '.join(map(translator.translate, window._group_by))
components.append(f'PARTITION BY {partition_args}')

if window._order_by:
order_args = ', '.join(map(translator.translate, window._order_by))
components.append(f'ORDER BY {order_args}')
def format_window_frame(translator, func, frame):
components = []

p, f = window.preceding, window.following
if frame.group_by:
partition_args = ', '.join(map(translator.translate, frame.group_by))
components.append(f'PARTITION BY {partition_args}')

def _prec(p: int | None) -> str:
assert p is None or p >= 0
if frame.order_by:
order_args = ', '.join(map(translator.translate, frame.order_by))
components.append(f'ORDER BY {order_args}')

if p is None:
prefix = 'UNBOUNDED'
if frame.start is None and frame.end is None:
# no-op, default is full sample
pass
elif not isinstance(func, translator._forbids_frame_clause):
if frame.start is None:
start = 'UNBOUNDED PRECEDING'
else:
if not p:
return 'CURRENT ROW'
prefix = str(p)
return f'{prefix} PRECEDING'
start = format_window_boundary(translator, frame.start)

def _foll(f: int | None) -> str:
assert f is None or f >= 0

if f is None:
prefix = 'UNBOUNDED'
if frame.end is None:
end = 'UNBOUNDED FOLLOWING'
else:
if not f:
return 'CURRENT ROW'
prefix = str(f)

return f'{prefix} FOLLOWING'

if translator._forbids_frame_clause and isinstance(
op.expr, translator._forbids_frame_clause
):
frame = None
elif p is not None and f is not None:
frame = f'{window.how.upper()} BETWEEN {_prec(p)} AND {_foll(f)}'
elif p is not None:
if isinstance(p, tuple):
start, end = p
frame = '{} BETWEEN {} AND {}'.format(
window.how.upper(), _prec(start), _prec(end)
)
else:
kind = 'ROWS' if p > 0 else 'RANGE'
frame = f'{kind} BETWEEN {_prec(p)} AND UNBOUNDED FOLLOWING'
elif f is not None:
if isinstance(f, tuple):
start, end = f
frame = '{} BETWEEN {} AND {}'.format(
window.how.upper(), _foll(start), _foll(end)
)
else:
kind = 'ROWS' if f > 0 else 'RANGE'
frame = f'{kind} BETWEEN UNBOUNDED PRECEDING AND {_foll(f)}'
else:
frame = None
end = format_window_boundary(translator, frame.end)

if frame is not None:
frame = f'{frame.how.upper()} BETWEEN {start} AND {end}'
components.append(frame)

return 'OVER ({})'.format(' '.join(components))


_subtract_one = '({} - 1)'.format


_expr_transforms = {
ops.RowNumber: _subtract_one,
ops.DenseRank: _subtract_one,
ops.MinRank: _subtract_one,
ops.NTile: _subtract_one,
}


def window(translator, op):
arg, window = op.args

_unsupported_reductions = (
ops.ApproxMedian,
ops.GroupConcat,
ops.ApproxCountDistinct,
)

if isinstance(arg, _unsupported_reductions):
if isinstance(op.func, _unsupported_reductions):
raise com.UnsupportedOperationError(
f'{type(arg)} is not supported in window functions'
f'{type(op.func)} is not supported in window functions'
)

if isinstance(arg, ops.CumulativeOp):
arg = cumulative_to_window(translator, arg, window)
if isinstance(op.func, ops.CumulativeOp):
arg = cumulative_to_window(translator, op.func, op.frame)
return translator.translate(arg)

# Some analytic functions need to have the expression of interest in
# the ORDER BY part of the window clause
if isinstance(arg, translator._require_order_by) and not window._order_by:
window = window.order_by(arg.args[0])
frame = op.frame
if isinstance(op.func, translator._require_order_by) and not frame.order_by:
frame = frame.copy(order_by=(op.func.arg,))

# Time ranges need to be converted to microseconds.
# FIXME(kszucs): avoid the expression roundtrip
if window.how == 'range':
time_range_types = (dt.Time, dt.Date, dt.Timestamp)
if any(
isinstance(c.output_dtype, time_range_types)
and c.output_shape.is_columnar()
for c in window._order_by
):
window = time_range_to_range_window(translator, window)

window_formatted = format_window(translator, op, window)

arg_formatted = translator.translate(arg)
if isinstance(frame, ops.RangeWindowFrame):
if any(c.output_dtype.is_temporal() for c in frame.order_by):
frame = time_range_to_range_window(frame)
elif isinstance(frame, ops.RowsWindowFrame):
if frame.max_lookback is not None:
raise NotImplementedError(
'Rows with max lookback is not implemented for SQL-based backends.'
)

window_formatted = format_window_frame(translator, op.func, frame)

arg_formatted = translator.translate(op.func)
result = f'{arg_formatted} {window_formatted}'

if type(arg) in _expr_transforms:
return _expr_transforms[type(arg)](result)
if isinstance(op.func, ops.RankBase):
return f'({result} - 1)'
else:
return result

66 changes: 65 additions & 1 deletion ibis/backends/bigquery/__init__.py
@@ -9,11 +9,13 @@

import google.auth.credentials
import google.cloud.bigquery as bq
import pandas as pd
import pydata_google_auth
from google.api_core.exceptions import NotFound
from pydata_google_auth import cache

import ibis
import ibis.common.exceptions as com
import ibis.expr.operations as ops
import ibis.expr.schema as sch
import ibis.expr.types as ir
@@ -24,6 +26,7 @@
BigQueryTable,
bigquery_field_to_ibis_dtype,
bigquery_param,
ibis_schema_to_bigquery_schema,
parse_project_and_dataset,
rename_partitioned_column,
)
@@ -435,6 +438,68 @@ def set_database(self, name):
def version(self):
return bq.__version__

def create_table(
self,
name: str,
obj: pd.DataFrame | ir.Table | None = None,
*,
schema: ibis.Schema | None = None,
database: str | None = None,
temp: bool | None = None,
overwrite: bool = False,
) -> ir.Table:
if obj is None and schema is None:
raise com.IbisError("The schema or obj parameter is required")
if temp is True:
raise NotImplementedError(
"BigQuery backend does not yet support temporary tables"
)
if overwrite is not False:
raise NotImplementedError(
"BigQuery backend does not yet support overwriting tables"
)
if schema is not None:
table_id = self._fully_qualified_name(name, database)
bigquery_schema = ibis_schema_to_bigquery_schema(schema)
table = bq.Table(table_id, schema=bigquery_schema)
self.client.create_table(table)
else:
project_id, dataset = self._parse_project_and_dataset(database)
if isinstance(obj, pd.DataFrame):
table = ibis.memtable(obj)
else:
table = obj
sql_select = self.compile(table)
table_ref = f"`{project_id}`.`{dataset}`.`{name}`"
self.raw_sql(f'CREATE TABLE {table_ref} AS ({sql_select})')
return self.table(name, database=database)

def drop_table(
self, name: str, *, database: str | None = None, force: bool = False
) -> None:
table_id = self._fully_qualified_name(name, database)
self.client.delete_table(table_id, not_found_ok=not force)

def create_view(
self,
name: str,
obj: ir.Table,
*,
database: str | None = None,
overwrite: bool = False,
) -> ir.Table:
or_replace = "OR REPLACE " * overwrite
sql_select = self.compile(obj)
table_id = self._fully_qualified_name(name, database)
code = f"CREATE {or_replace}VIEW {table_id} AS {sql_select}"
self.raw_sql(code)
return self.table(name, database=database)

def drop_view(
self, name: str, *, database: str | None = None, force: bool = False
) -> None:
self.drop_table(name=name, database=database, force=force)


def compile(expr, params=None, **kwargs):
"""Compile an expression for BigQuery."""
Expand Down Expand Up @@ -514,7 +579,6 @@ def connect(


__all__ = [
"__version__",
"Backend",
"compile",
"connect",
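
A rough usage sketch of the BigQuery DDL methods added above (the project ID, dataset, and table/view names are hypothetical, not part of this diff):

import ibis

con = ibis.bigquery.connect(project_id="my-project", dataset_id="my_dataset")

# create an empty table from an explicit ibis schema
t = con.create_table("events", schema=ibis.schema(dict(ts="timestamp", n="int64")))

# create a view over an expression, then clean up
con.create_view("recent_events", obj=t.filter(t.n > 0), overwrite=True)
con.drop_view("recent_events")
con.drop_table("events")
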
21 changes: 18 additions & 3 deletions ibis/backends/bigquery/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,23 @@ def bigquery_schema(table):
return sch.schema(fields)


def ibis_schema_to_bigquery_schema(schema: sch.Schema):
return [
(
bq.SchemaField(
name,
ibis_type_to_bigquery_type(type_),
mode='NULLABLE' if type_.nullable else 'REQUIRED',
)
if not type_.is_array()
else bq.SchemaField(
name, ibis_type_to_bigquery_type(type_.value_type), mode='REPEATED'
)
)
for name, type_ in schema.items()
]
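
As an illustration of the helper above (expected behaviour, not an exact doctest): nullable columns become NULLABLE fields, non-nullable columns become REQUIRED, and arrays become REPEATED fields of their element type.

import ibis
from ibis.backends.bigquery.client import ibis_schema_to_bigquery_schema

schema = ibis.schema({"a": "int64", "b": "!string", "c": "array<float64>"})
fields = ibis_schema_to_bigquery_schema(schema)
# roughly:
# [SchemaField('a', 'INT64', mode='NULLABLE'),
#  SchemaField('b', 'STRING', mode='REQUIRED'),
#  SchemaField('c', 'FLOAT64', mode='REPEATED')]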


class BigQueryCursor:
"""BigQuery cursor.
Expand Down Expand Up @@ -243,13 +260,11 @@ def parse_project_and_dataset(project: str, dataset: str = "") -> tuple[str, str
'ibis-gbq'
>>> dataset
'my_dataset'
>>> data_project, billing_project, dataset = parse_project_and_dataset(
>>> data_project, billing_project, _ = parse_project_and_dataset(
... 'ibis-gbq'
... )
>>> data_project
'ibis-gbq'
>>> print(dataset)
None
"""
if dataset.count(".") > 1:
raise ValueError(
13 changes: 11 additions & 2 deletions ibis/backends/bigquery/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

import ibis.common.graph as lin
import ibis.expr.operations as ops
import ibis.expr.types as ir
from ibis.backends.base.sql import compiler as sql_compiler
from ibis.backends.bigquery import operations, registry, rewrites

Expand All @@ -22,7 +23,8 @@ def __init__(self, expr, context):

def compile(self):
"""Generate UDF string from definition."""
return self.expr.op().sql
op = expr.op() if isinstance(expr := self.expr, ir.Expr) else expr
return op.sql


class BigQueryUnion(sql_compiler.Union):
Expand Down Expand Up @@ -111,6 +113,8 @@ class BigQueryCompiler(sql_compiler.Compiler):
intersect_class = BigQueryIntersection
difference_class = BigQueryDifference

support_values_syntax_in_select = False

@staticmethod
def _generate_setup_queries(expr, context):
"""Generate DDL for temporary resources."""
Expand All @@ -119,7 +123,12 @@ def _generate_setup_queries(expr, context):

# UDFs are uniquely identified by the name of the Node subclass we
# generate.
return list(toolz.unique(queries, key=lambda x: type(x.expr.op()).__name__))
def key(x):
expr = x.expr
op = expr.op() if isinstance(expr, ir.Expr) else expr
return op.__class__.__name__

return list(toolz.unique(queries, key=key))


# Register custom UDFs
15 changes: 9 additions & 6 deletions ibis/backends/bigquery/datatypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,12 +66,15 @@ def trans_type(t):

@ibis_type_to_bigquery_type.register(dt.Decimal)
def trans_numeric(t):
if (t.precision, t.scale) != (38, 9):
raise TypeError(
"BigQuery only supports decimal types with precision of 38 and "
"scale of 9"
)
return "NUMERIC"
if (t.precision, t.scale) == (76, 38):
return 'BIGNUMERIC'
if (t.precision, t.scale) in [(38, 9), (None, None)]:
return "NUMERIC"
raise TypeError(
"BigQuery only supports decimal types with precision of 38 and "
f"scale of 9 (NUMERIC) or precision of 76 and scale of 38 (BIGNUMERIC). "
f"Current precision: {t.precision}. Current scale: {t.scale}"
)
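
In short, only the two widths BigQuery itself supports are accepted; a quick sketch of the mapping:

import ibis.expr.datatypes as dt
from ibis.backends.bigquery.datatypes import ibis_type_to_bigquery_type

ibis_type_to_bigquery_type(dt.Decimal(38, 9))   # 'NUMERIC'
ibis_type_to_bigquery_type(dt.Decimal(76, 38))  # 'BIGNUMERIC'
ibis_type_to_bigquery_type(dt.Decimal(10, 2))   # raises TypeError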


@ibis_type_to_bigquery_type.register(dt.JSON)
78 changes: 54 additions & 24 deletions ibis/backends/bigquery/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,11 @@ def bigquery_cast_integer_to_timestamp(compiled_arg, from_, to):
return f"TIMESTAMP_SECONDS({compiled_arg})"


@bigquery_cast.register(str, dt.Interval, dt.Integer)
def bigquery_cast_interval_to_integer(compiled_arg, from_, to):
return f"EXTRACT({from_.resolution.upper()} from {compiled_arg})"


@bigquery_cast.register(str, dt.DataType, dt.DataType)
def bigquery_cast_generate(compiled_arg, from_, to):
"""Cast to desired type."""
Expand Down Expand Up @@ -191,12 +196,22 @@ def _string_right(translator, op):


def _string_substring(translator, op):
_, _, length = op.args
length = op.length
if (length := getattr(length, "value", None)) is not None and length < 0:
raise ValueError("Length parameter must be a non-negative value.")

base_substring = operation_registry[ops.Substring]
return base_substring(translator, op)
arg = translator.translate(op.arg)
start = translator.translate(op.start)

arg_length = f"LENGTH({arg})"
if op.length is not None:
suffix = f", {translator.translate(op.length)}"
else:
suffix = ""

if_pos = f"SUBSTR({arg}, {start} + 1{suffix})"
if_neg = f"SUBSTR({arg}, {arg_length} + {start} + 1{suffix})"
return f"IF({start} >= 0, {if_pos}, {if_neg})"


def _array_literal_format(op):
Expand All @@ -216,14 +231,28 @@ def _log(translator, op):

def _literal(translator, op):
dtype = op.output_dtype
value = op.value

if value is None:
return "NULL"

if dtype.is_decimal():
if value.is_nan():
return "CAST('NaN' AS FLOAT64)"
if value.is_infinite():
prefix = "-" * value.is_signed()
return f"CAST('{prefix}inf' AS FLOAT64)"
else:
return f"{ibis_type_to_bigquery_type(dtype)} '{value}'"
elif dtype.is_uuid():
return translator.translate(ops.Literal(str(value), dtype=dt.str))

if isinstance(dtype, dt.Numeric):
value = op.value
if not np.isfinite(value):
return f"CAST({str(value)!r} AS FLOAT64)"

# special case literal timestamp, date, and time scalars
if isinstance(op, ops.Literal):
value = op.value
if isinstance(dtype, dt.Date):
if isinstance(value, datetime.datetime):
raw_value = value.date()
Expand All @@ -241,7 +270,7 @@ def _literal(translator, op):
)
elif dtype.is_struct():
cols = (
f'{translator.translate(ops.Literal(op.value[name], dtype=type_))} AS {name}'
f'{translator.translate(ops.Literal(value[name], dtype=type_))} AS {name}'
for name, type_ in zip(dtype.names, dtype.types)
)
return "STRUCT({})".format(", ".join(cols))
Expand All @@ -251,7 +280,7 @@ def _literal(translator, op):
except NotImplementedError:
if isinstance(dtype, dt.Array):
return _array_literal_format(op)
raise NotImplementedError(type(op).__name__)
raise NotImplementedError(f'Unsupported type: {dtype!r}')


def _arbitrary(translator, op):
Expand Down Expand Up @@ -526,8 +555,8 @@ def _neg_idx_to_pos(array, idx):
def _array_slice(t, op):
arg = t.translate(op.arg)
cond = [f"index >= {_neg_idx_to_pos(arg, t.translate(op.start))}"]
if op.stop:
cond.append(f"index < {_neg_idx_to_pos(arg, t.translate(op.stop))}")
if stop := op.stop:
cond.append(f"index < {_neg_idx_to_pos(arg, t.translate(stop))}")
return (
f"ARRAY("
f"SELECT el "
Expand All @@ -538,19 +567,8 @@ def _array_slice(t, op):


def _capitalize(t, op):
return f"CONCAT(UPPER(SUBSTR({t.translate(op.arg)}, 1, 1)), SUBSTR({t.translate(op.arg)}, 2))"


def _clip(t, op):
arg = t.translate(op.arg)

if (upper := op.upper) is not None:
arg = f"LEAST({t.translate(upper)}, {arg})"

if (lower := op.lower) is not None:
arg = f"GREATEST({t.translate(lower)}, {arg})"

return arg
return f"CONCAT(UPPER(SUBSTR({arg}, 1, 1)), SUBSTR({arg}, 2))"


def _nth_value(t, op):
Expand All @@ -562,6 +580,17 @@ def _nth_value(t, op):
return f'NTH_VALUE({arg}, {nth_op.value + 1})'


def _interval_multiply(t, op):
if isinstance(op.left, ops.Literal) and isinstance(op.right, ops.Literal):
value = op.left.value * op.right.value
literal = ops.Literal(value, op.left.output_dtype)
return t.translate(literal)

left, right = t.translate(op.left), t.translate(op.right)
unit = op.left.output_dtype.resolution.upper()
return f"INTERVAL EXTRACT({unit} from {left}) * {right} {unit}"


OPERATION_REGISTRY = {
**operation_registry,
# Literal
Expand All @@ -584,9 +613,6 @@ def _nth_value(t, op):
ops.Floor: compiles_floor,
ops.Modulus: fixed_arity("MOD", 2),
ops.Sign: unary("SIGN"),
ops.Clip: _clip,
ops.Degrees: lambda t, op: f"(180 * {t.translate(op.arg)} / ACOS(-1))",
ops.Radians: lambda t, op: f"(ACOS(-1) * {t.translate(op.arg)} / 180)",
ops.BitwiseNot: lambda t, op: f"~ {t.translate(op.arg)}",
ops.BitwiseXor: lambda t, op: f"{t.translate(op.left)} ^ {t.translate(op.right)}",
ops.BitwiseOr: lambda t, op: f"{t.translate(op.left)} | {t.translate(op.right)}",
Expand Down Expand Up @@ -623,6 +649,7 @@ def _nth_value(t, op):
ops.TimestampNow: fixed_arity("CURRENT_TIMESTAMP", 0),
ops.TimestampSub: _timestamp_op("TIMESTAMP_SUB", {"h", "m", "s", "ms", "us"}),
ops.TimestampTruncate: _truncate("TIMESTAMP", _timestamp_units),
ops.IntervalMultiply: _interval_multiply,
ops.Hash: _hash,
ops.StringReplace: fixed_arity("REPLACE", 3),
ops.StringSplit: fixed_arity("SPLIT", 2),
Expand Down Expand Up @@ -706,6 +733,9 @@ def _nth_value(t, op):
ops.RandomScalar: fixed_arity("RAND", 0),
ops.NthValue: _nth_value,
ops.JSONGetItem: lambda t, op: f"{t.translate(op.arg)}[{t.translate(op.index)}]",
ops.ArrayStringJoin: lambda t, op: f"ARRAY_TO_STRING({t.translate(op.arg)}, {t.translate(op.sep)})",
ops.StartsWith: fixed_arity("STARTS_WITH", 2),
ops.EndsWith: fixed_arity("ENDS_WITH", 2),
}

_invalid_operations = {
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
SELECT t0.`title`, t0.`tags`
FROM (
SELECT t1.*
FROM `bigquery-public-data.stackoverflow.posts_questions` t1
WHERE STRPOS(t1.`tags`, 'ibis') - 1 >= 0
) t0
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
WITH t0 AS (
SELECT t2.`float_col`, t2.`timestamp_col`, t2.`int_col`, t2.`string_col`
FROM `ibis-gbq.ibis_gbq_testing.functional_alltypes` t2
WHERE t2.`timestamp_col` < @param_0
)
SELECT count(t1.`foo`) AS `count`
FROM (
SELECT t0.`string_col`, sum(t0.`float_col`) AS `foo`
FROM t0
GROUP BY 1
) t1
32 changes: 8 additions & 24 deletions ibis/backends/bigquery/tests/system/test_client.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import collections
import datetime
import decimal
import re
import itertools

import pandas as pd
import pandas.testing as tm
Expand All @@ -10,6 +10,7 @@

import ibis
import ibis.expr.datatypes as dt
import ibis.expr.operations as ops
from ibis.backends.bigquery.client import bigquery_param


Expand Down Expand Up @@ -112,18 +113,8 @@ def test_different_partition_col_name(monkeypatch, client):
assert col in parted_alltypes.columns


def test_subquery_scalar_params(alltypes, project_id, dataset_id):
expected = f"""\
SELECT count\\(t0\\.`foo`\\) AS `count`
FROM \\(
SELECT t1\\.`string_col`, sum\\(t1\\.`float_col`\\) AS `foo`
FROM \\(
SELECT t2\\.`float_col`, t2\\.`timestamp_col`, t2\\.`int_col`, t2\\.`string_col`
FROM `{project_id}\\.{dataset_id}\\.functional_alltypes` t2
WHERE t2\\.`timestamp_col` < @param_\\d+
\\) t1
GROUP BY 1
\\) t0"""
def test_subquery_scalar_params(alltypes, monkeypatch, snapshot):
monkeypatch.setattr(ops.ScalarParameter, "_counter", itertools.count())
t = alltypes
p = ibis.param("timestamp").name("my_param")
expr = (
Expand All @@ -136,7 +127,7 @@ def test_subquery_scalar_params(alltypes, project_id, dataset_id):
.name("count")
)
result = expr.compile(params={p: "20140101"})
assert re.match(expected, result) is not None
snapshot.assert_match(result, "out.sql")


def test_repr_struct_of_array_of_struct():
Expand Down Expand Up @@ -207,7 +198,7 @@ def test_scalar_param_partition_time(parted_alltypes):
assert "PARTITIONTIME" in parted_alltypes.columns
assert "PARTITIONTIME" in parted_alltypes.schema()
param = ibis.param("timestamp").name("time_param")
expr = parted_alltypes[parted_alltypes.PARTITIONTIME < param]
expr = parted_alltypes[param > parted_alltypes.PARTITIONTIME]
df = expr.execute(params={param: "2017-01-01"})
assert df.empty

Expand All @@ -220,18 +211,11 @@ def test_parted_column(client, kind):
assert t.columns == [expected_column, "string_col", "int_col"]


def test_cross_project_query(public):
def test_cross_project_query(public, snapshot):
table = public.table("posts_questions")
expr = table[table.tags.contains("ibis")][["title", "tags"]]
result = expr.compile()
expected = """\
SELECT t0.`title`, t0.`tags`
FROM (
SELECT t1.*
FROM `bigquery-public-data.stackoverflow.posts_questions` t1
WHERE STRPOS(t1.`tags`, 'ibis') - 1 >= 0
) t0"""
assert result == expected
snapshot.assert_match(result, "out.sql")
n = 5
df = expr.limit(n).execute()
assert len(df) == n
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ def my_str_len(s):
add = expr.op()

# generated javascript is identical
assert add.left.op().sql == add.right.op().sql
assert add.left.sql == add.right.sql
assert client.execute(expr) == 8.0


Original file line number Diff line number Diff line change
@@ -1,18 +1,22 @@
WITH t0 AS (
SELECT t4.*
FROM unbound_table t4
WHERE t4.`PARTITIONTIME` < DATE '2017-01-01'
SELECT t5.*
FROM unbound_table t5
WHERE t5.`PARTITIONTIME` < DATE '2017-01-01'
),
t1 AS (
SELECT CAST(t0.`file_date` AS DATE) AS `file_date`, t0.`PARTITIONTIME`,
t0.`val`
FROM t0
WHERE t0.`file_date` < DATE '2017-01-01'
),
t2 AS (
SELECT t1.*, t1.`val` * 2 AS `XYZ`
SELECT t1.*
FROM t1
WHERE t1.`file_date` < DATE '2017-01-01'
),
t3 AS (
SELECT t2.*, t2.`val` * 2 AS `XYZ`
FROM t2
)
SELECT t2.*
FROM t2
INNER JOIN t2 t3
SELECT t3.*
FROM t3
INNER JOIN t3 t4
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
SELECT t0.*,
avg(t0.`float_col`) OVER (PARTITION BY t0.`year` ORDER BY UNIX_MICROS(t0.`timestamp_col`) RANGE BETWEEN 4 PRECEDING AND 2 PRECEDING) AS `two_month_avg`
avg(t0.`float_col`) OVER (PARTITION BY t0.`year` ORDER BY UNIX_MICROS(t0.`timestamp_col`) ASC RANGE BETWEEN 4 PRECEDING AND 2 PRECEDING) AS `two_month_avg`
FROM functional_alltypes t0
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
SELECT substr(t0.`value`, 3 + 1, 1) AS `tmp`
SELECT IF(3 >= 0, SUBSTR(t0.`value`, 3 + 1, 1), SUBSTR(t0.`value`, LENGTH(t0.`value`) + 3 + 1, 1)) AS `tmp`
FROM t t0
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
SELECT t0.*,
avg(t0.`float_col`) OVER (ORDER BY UNIX_MICROS(t0.`timestamp_col`) RANGE BETWEEN 86400000000 PRECEDING AND CURRENT ROW) AS `win_avg`
avg(t0.`float_col`) OVER (ORDER BY UNIX_MICROS(t0.`timestamp_col`) ASC RANGE BETWEEN 86400000000 PRECEDING AND EXTRACT(DAY from INTERVAL 0 DAY) * 86400000000 FOLLOWING) AS `win_avg`
FROM functional_alltypes t0
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
SELECT t0.*,
avg(t0.`float_col`) OVER (ORDER BY UNIX_MICROS(t0.`timestamp_col`) RANGE BETWEEN 5 PRECEDING AND CURRENT ROW) AS `win_avg`
avg(t0.`float_col`) OVER (ORDER BY UNIX_MICROS(t0.`timestamp_col`) ASC RANGE BETWEEN 5 PRECEDING AND CURRENT ROW) AS `win_avg`
FROM functional_alltypes t0
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
SELECT t0.*,
avg(t0.`float_col`) OVER (ORDER BY UNIX_MICROS(t0.`timestamp_col`) RANGE BETWEEN 3600000000 PRECEDING AND CURRENT ROW) AS `win_avg`
avg(t0.`float_col`) OVER (ORDER BY UNIX_MICROS(t0.`timestamp_col`) ASC RANGE BETWEEN 3600000000 PRECEDING AND EXTRACT(HOUR from INTERVAL 0 HOUR) * 3600000000 FOLLOWING) AS `win_avg`
FROM functional_alltypes t0
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
SELECT t0.*,
avg(t0.`float_col`) OVER (ORDER BY UNIX_MICROS(t0.`timestamp_col`) RANGE BETWEEN 1 PRECEDING AND CURRENT ROW) AS `win_avg`
avg(t0.`float_col`) OVER (ORDER BY UNIX_MICROS(t0.`timestamp_col`) ASC RANGE BETWEEN 1 PRECEDING AND EXTRACT(MICROSECOND from INTERVAL 0 MICROSECOND) * 1 FOLLOWING) AS `win_avg`
FROM functional_alltypes t0
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
SELECT t0.*,
avg(t0.`float_col`) OVER (ORDER BY UNIX_MICROS(t0.`timestamp_col`) RANGE BETWEEN 60000000 PRECEDING AND CURRENT ROW) AS `win_avg`
avg(t0.`float_col`) OVER (ORDER BY UNIX_MICROS(t0.`timestamp_col`) ASC RANGE BETWEEN 60000000 PRECEDING AND EXTRACT(MINUTE from INTERVAL 0 MINUTE) * 60000000 FOLLOWING) AS `win_avg`
FROM functional_alltypes t0

This file was deleted.

Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
SELECT t0.*,
avg(t0.`float_col`) OVER (ORDER BY UNIX_MICROS(t0.`timestamp_col`) RANGE BETWEEN 1000000 PRECEDING AND CURRENT ROW) AS `win_avg`
avg(t0.`float_col`) OVER (ORDER BY UNIX_MICROS(t0.`timestamp_col`) ASC RANGE BETWEEN 1000000 PRECEDING AND EXTRACT(SECOND from INTERVAL 0 SECOND) * 1000000 FOLLOWING) AS `win_avg`
FROM functional_alltypes t0
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
SELECT t0.*,
avg(t0.`float_col`) OVER (ORDER BY UNIX_MICROS(t0.`timestamp_col`) RANGE BETWEEN 172800000000 PRECEDING AND CURRENT ROW) AS `win_avg`
avg(t0.`float_col`) OVER (ORDER BY UNIX_MICROS(t0.`timestamp_col`) ASC RANGE BETWEEN EXTRACT(DAY from INTERVAL 2 DAY) * 86400000000 PRECEDING AND EXTRACT(DAY from INTERVAL 0 DAY) * 86400000000 FOLLOWING) AS `win_avg`
FROM functional_alltypes t0
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
SELECT t0.*,
avg(t0.`float_col`) OVER (ORDER BY UNIX_MICROS(t0.`timestamp_col`) RANGE BETWEEN 604800000000 PRECEDING AND CURRENT ROW) AS `win_avg`
avg(t0.`float_col`) OVER (ORDER BY UNIX_MICROS(t0.`timestamp_col`) ASC RANGE BETWEEN 604800000000 PRECEDING AND EXTRACT(WEEK from INTERVAL 0 WEEK) * 604800000000 FOLLOWING) AS `win_avg`
FROM functional_alltypes t0
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
SELECT t0.*,
avg(t0.`float_col`) OVER (ORDER BY UNIX_MICROS(t0.`timestamp_col`) ASC RANGE BETWEEN 31536000000000 PRECEDING AND EXTRACT(YEAR from INTERVAL 0 YEAR) * 31536000000000 FOLLOWING) AS `win_avg`
FROM functional_alltypes t0
11 changes: 1 addition & 10 deletions ibis/backends/bigquery/tests/unit/test_compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -470,7 +470,7 @@ def test_range_window_function(alltypes, window, snapshot):
"preceding",
[
param(5, id="five"),
param(ibis.interval(nanoseconds=1), id="nanos"),
param(ibis.interval(nanoseconds=1), id="nanos", marks=pytest.mark.xfail),
param(ibis.interval(microseconds=1), id="micros"),
param(ibis.interval(seconds=1), id="seconds"),
param(ibis.interval(minutes=1), id="minutes"),
Expand All @@ -486,15 +486,6 @@ def test_trailing_range_window(alltypes, preceding, snapshot):
snapshot.assert_match(to_sql(expr), "out.sql")


def test_trailing_range_window_unsupported(alltypes):
t = alltypes
preceding = ibis.interval(years=1)
w = ibis.trailing_range_window(preceding=preceding, order_by=t.timestamp_col)
expr = t.mutate(win_avg=t.float_col.mean().over(w))
with pytest.raises(ValueError):
to_sql(expr)


@pytest.mark.parametrize("distinct1", [True, False])
@pytest.mark.parametrize("distinct2", [True, False])
def test_union_cte(alltypes, distinct1, distinct2, snapshot):
3 changes: 2 additions & 1 deletion ibis/backends/bigquery/tests/unit/test_datatypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,8 @@ def test_no_ambiguities():
param(
"array<struct<a: string>>", "ARRAY<STRUCT<a STRING>>", id="array<struct>"
),
param(dt.Decimal(38, 9), "NUMERIC", id="decimal"),
param(dt.Decimal(38, 9), "NUMERIC", id="decimal-numeric"),
param(dt.Decimal(76, 38), "BIGNUMERIC", id="decimal-bignumeric"),
],
)
def test_simple(datatype, expected):
4 changes: 2 additions & 2 deletions ibis/backends/bigquery/udf/find.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,8 +46,8 @@ def find_names(node: ast.AST) -> list[ast.Name]:
>>> import ast
>>> node = ast.parse('a + b')
>>> names = find_names(node)
>>> names # doctest: +ELLIPSIS
[<_ast.Name object at 0x...>, <_ast.Name object at 0x...>]
>>> names
[<....Name object at 0x...>, <....Name object at 0x...>]
>>> names[0].id
'a'
>>> names[1].id
149 changes: 123 additions & 26 deletions ibis/backends/clickhouse/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import toolz

import ibis
import ibis.common.exceptions as com
import ibis.config
import ibis.expr.analysis as an
import ibis.expr.operations as ops
Expand Down Expand Up @@ -86,14 +87,16 @@ class Options(ibis.config.Config):
"""

temp_db: str = "__ibis_tmp"
bool_type: str = "Boolean"
bool_type: Literal["Bool", "UInt8", "Int8"] = "Bool"

def __init__(self, *args, external_tables=None, **kwargs):
super().__init__(*args, **kwargs)
self._external_tables = external_tables or {}
self._external_tables = toolz.valmap(
lambda v: ibis.memtable(v).op(), external_tables or {}
)

def _register_in_memory_table(self, table_op):
self._external_tables[table_op.name] = table_op.data.to_frame()
def _register_in_memory_table(self, op: ops.InMemoryTable) -> None:
self._external_tables[op.name] = op

def _log(self, sql: str) -> None:
"""Log the SQL, usually to the standard output.
Expand Down Expand Up @@ -204,7 +207,9 @@ def do_connect(
# This won't start a connection until `cursor` is called, so in
# the common case this is cheap.
self.con = clickhouse_driver.dbapi.connect(**options)
self._external_tables = external_tables or {}
self._external_tables = toolz.valmap(
lambda v: ibis.memtable(v).op(), external_tables or {}
)

@property
def version(self) -> str:
Expand Down Expand Up @@ -241,21 +246,18 @@ def list_tables(self, like=None, database=None):

def _normalize_external_tables(self, external_tables=None):
"""Merge registered external tables with any new external tables."""
import pandas as pd

external_tables_list = []
if external_tables is None:
external_tables = {}
for name, df in toolz.merge(self._external_tables, external_tables).items():
if not isinstance(df, pd.DataFrame):
raise TypeError('External table is not an instance of pandas dataframe')
schema = sch.infer(df)
for name, obj in toolz.merge(
self._external_tables,
toolz.valmap(lambda v: ibis.memtable(v).op(), external_tables or {}),
).items():
if not (schema := obj.schema):
raise TypeError(f'Schema is empty for external table {name}')

df = obj.data.to_frame()
structure = list(zip(schema.names, map(serialize, schema.types)))
external_tables_list.append(
{
"name": name,
"data": df.to_dict("records"),
"structure": list(zip(schema.names, map(serialize, schema.types))),
}
dict(name=name, data=df.to_dict("records"), structure=structure)
)
return external_tables_list

Expand All @@ -270,6 +272,10 @@ def _client_execute(self, query, external_tables=None):
external_tables=external_tables,
)

def _register_in_memory_tables(self, expr: ir.TableExpr):
for memtable in an.find_memtables(expr.op()):
self._register_in_memory_table(memtable)

def to_pyarrow_batches(
self,
expr: ir.Expr,
Expand Down Expand Up @@ -321,6 +327,7 @@ def _cursor_batches(
limit: int | str | None = None,
chunk_size: int = 1_000_000,
) -> Iterable[list]:
self._register_in_memory_tables(expr)
sql = self.compile(expr, limit=limit, params=params)
cursor = self.raw_sql(sql)
try:
Expand All @@ -337,13 +344,10 @@ def execute(
**kwargs: Any,
) -> Any:
"""Execute an expression."""
self._register_in_memory_tables(expr)
table_expr = expr.as_table()
sql = self.compile(table_expr, limit=limit, **kwargs)
self._log(sql)

for memtable in an.find_memtables(expr.op()):
self._register_in_memory_table(memtable)

result = self.fetch_from_cursor(
self.raw_sql(sql, external_tables=external_tables),
table_expr.schema(),
Expand Down Expand Up @@ -375,6 +379,9 @@ def compile(self, expr: ir.Expr, limit: str | None = None, params=None, **_: Any
assert not isinstance(sql, sg.exp.Subquery)
return sql.sql(dialect="clickhouse", pretty=True)

def _to_sql(self, expr: ir.Expr, **kwargs) -> str:
return str(self.compile(expr, **kwargs))

def table(self, name: str, database: str | None = None) -> ir.Table:
"""Construct a table expression.
Expand Down Expand Up @@ -478,12 +485,102 @@ def _get_schema_using_query(self, query: str) -> sch.Schema:
]
return sch.Schema.from_tuples(fields)

def _table_command(self, cmd, name, database=None):
qualified_name = self._fully_qualified_name(name, database)
return f'{cmd} {qualified_name}'

@classmethod
def has_operation(cls, operation: type[ops.Value]) -> bool:
from ibis.backends.clickhouse.compiler.values import translate_val

return operation in translate_val.registry

def create_database(
self, name: str, *, force: bool = False, engine: str = "Atomic"
) -> None:
self.raw_sql(
f"CREATE DATABASE {'IF NOT EXISTS ' * force}{name} ENGINE = {engine}"
)

def drop_database(self, name: str, *, force: bool = False) -> None:
self.raw_sql(f"DROP DATABASE {'IF EXISTS ' * force}{name}")

def truncate_table(self, name: str, database: str | None = None) -> None:
ident = ".".join(filter(None, (database, name)))
self.raw_sql(f"DELETE FROM {ident}")

def drop_table(
self, name: str, database: str | None = None, force: bool = False
) -> None:
ident = ".".join(filter(None, (database, name)))
self.raw_sql(f"DROP TABLE {'IF EXISTS ' * force}{ident}")

def create_table(
self,
name: str,
obj: pd.DataFrame | ir.Table | None = None,
*,
schema: ibis.Schema | None = None,
database: str | None = None,
temp: bool = False,
overwrite: bool = False,
# backend specific arguments
engine: str | None,
order_by: Iterable[str] | None = None,
partition_by: Iterable[str] | None = None,
sample_by: str | None = None,
settings: Mapping[str, Any] | None = None,
) -> ir.Table:
tmp = "TEMPORARY " * temp
replace = "OR REPLACE " * overwrite
code = f"CREATE {replace}{tmp}TABLE {name}"

if obj is None and schema is None:
raise com.IbisError("The schema or obj parameter is required")

if schema is not None:
code += f" ({schema})"

if isinstance(obj, pd.DataFrame):
obj = ibis.memtable(obj, schema=schema)

if obj is not None:
self._register_in_memory_tables(obj)
query = self.compile(obj)
code += f" AS {query}"

code += f" ENGINE = {engine}"

if order_by is not None:
code += f" ORDER BY {', '.join(util.promote_list(order_by))}"

if partition_by is not None:
code += f" PARTITION BY {', '.join(util.promote_list(partition_by))}"

if sample_by is not None:
code += f" SAMPLE BY {sample_by}"

if settings:
kvs = ", ".join(f"{name}={value!r}" for name, value in settings.items())
code += f" SETTINGS {kvs}"

self.raw_sql(code)
return self.table(name, database=database)

def create_view(
self,
name: str,
obj: ir.Table,
*,
database: str | None = None,
overwrite: bool = False,
) -> ir.Table:
name = ".".join(filter(None, (database, name)))
replace = "OR REPLACE " * overwrite
query = self.compile(obj)
code = f"CREATE {replace}VIEW {name} AS {query}"
self.raw_sql(code)
return self.table(name, database=database)

def drop_view(
self, name: str, *, database: str | None = None, force: bool = False
) -> None:
name = ".".join(filter(None, (database, name)))
if_not_exists = "IF EXISTS " * force
self.raw_sql(f"DROP VIEW {if_not_exists}{name}")
31 changes: 20 additions & 11 deletions ibis/backends/clickhouse/compiler/relations.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

import sqlglot as sg

import ibis.common.exceptions as com
import ibis.expr.datatypes as dt
import ibis.expr.operations as ops
from ibis.backends.clickhouse.compiler.values import translate_val
Expand All @@ -13,7 +14,7 @@
@functools.singledispatch
def translate_rel(op: ops.TableNode, **_):
"""Translate a table node into sqlglot."""
raise NotImplementedError(type(op))
raise com.OperationNotDefinedError(f'No translation rule for {type(op)}')


@translate_rel.register(ops.DummyTable)
Expand Down Expand Up @@ -78,7 +79,7 @@ def _aggregation(op: ops.Aggregation, *, table, **kw):
sel = sg.select(*selections).from_(table)

if by:
sel = sel.group_by(*by, dialect="clickhouse")
sel = sel.group_by(*map(str, range(1, len(by) + 1)), dialect="clickhouse")

if predicates := op.predicates:
sel = sel.where(*map(tr_val, predicates), dialect="clickhouse")
Expand All @@ -102,6 +103,7 @@ def _aggregation(op: ops.Aggregation, *, table, **kw):
ops.CrossJoin: "CROSS",
ops.LeftSemiJoin: "LEFT SEMI",
ops.LeftAntiJoin: "LEFT ANTI",
ops.AsOfJoin: "LEFT ASOF",
}


Expand Down Expand Up @@ -149,21 +151,28 @@ def _query(op: ops.SQLQueryResult, *, aliases, **_):
return res.subquery(aliases.get(op, "_"))


_KEYWORD = {
ops.Union: "UNION",
ops.Intersection: "INTERSECT",
ops.Difference: "EXCEPT",
_SET_OP_FUNC = {
ops.Union: sg.union,
ops.Intersection: sg.intersect,
ops.Difference: sg.except_,
}


@translate_rel.register
def _set_op(op: ops.SetOp, *, left, right, **_):
dialect = "clickhouse"
left_query = left.args["this"].sql(dialect=dialect)
right_query = right.args["this"].sql(dialect=dialect)
distinct = "DISTINCT" if op.distinct else "ALL"
return sg.parse_one(
f"{left_query} {_KEYWORD[type(op)]} {distinct} {right_query}", read=dialect

if isinstance(left, sg.exp.Table):
left = sg.select("*", dialect=dialect).from_(left, dialect=dialect)

if isinstance(right, sg.exp.Table):
right = sg.select("*", dialect=dialect).from_(right, dialect=dialect)

return _SET_OP_FUNC[type(op)](
left.args.get("this", left),
right.args.get("this", right),
distinct=op.distinct,
dialect=dialect,
)

