39 changes: 14 additions & 25 deletions ibis/backends/polars/__init__.py
@@ -148,10 +148,6 @@ def _add_table(self, name: str, obj: pl.LazyFrame | pl.DataFrame) -> None:
self._tables[name] = obj
self._context.register(name, obj)

def _remove_table(self, name: str) -> None:
del self._tables[name]
self._context.unregister(name)

def sql(
self, query: str, schema: sch.Schema | None = None, dialect: str | None = None
) -> ir.Table:
@@ -362,7 +358,7 @@ def create_table(
| pl.LazyFrame
| None = None,
*,
schema: ibis.Schema | None = None,
schema: sch.SchemaLike | None = None,
database: str | None = None,
temp: bool | None = None,
overwrite: bool = False,
@@ -407,6 +403,7 @@ def create_view(
def drop_table(self, name: str, *, force: bool = False) -> None:
if name in self._tables:
del self._tables[name]
self._context.unregister(name)
elif not force:
raise com.IbisError(f"Table {name!r} does not exist")

@@ -449,17 +446,12 @@ def compile(

return translate(node, ctx=self._context)

def _get_sql_string_view_schema(self, name, table, query) -> sch.Schema:
import sqlglot as sg

cte = sg.parse_one(str(ibis.to_sql(table, dialect="postgres")), read="postgres")
parsed = sg.parse_one(query, read=self.dialect)
parsed.args["with"] = cte.args.pop("with", [])
parsed = parsed.with_(
sg.to_identifier(name, quoted=True), as_=cte, dialect=self.dialect
)
def _get_sql_string_view_schema(
self, *, name: str, table: ir.Table, query: str
) -> sch.Schema:
from ibis.backends.sql.compilers.postgres import compiler

sql = parsed.sql(self.dialect)
sql = compiler.add_query_to_expr(name=name, table=table, query=query)
return self._get_schema_using_query(sql)

def _get_schema_using_query(self, query: str) -> sch.Schema:
@@ -527,15 +519,12 @@ def _to_pyarrow_table(
streaming: bool = False,
**kwargs: Any,
):
from ibis.formats.pyarrow import PyArrowData

df = self._to_dataframe(
expr, params=params, limit=limit, streaming=streaming, **kwargs
)
table = df.to_arrow()
if isinstance(expr, (ir.Table, ir.Value)):
schema = expr.as_table().schema().to_pyarrow()
return table.rename_columns(schema.names).cast(schema)
else:
raise com.IbisError(f"Cannot execute expression of type: {type(expr)}")
return PyArrowData.convert_table(df.to_arrow(), expr.as_table().schema())

def to_pyarrow(
self,
@@ -560,11 +549,11 @@ def to_pyarrow_batches(
table = self._to_pyarrow_table(expr, params=params, limit=limit, **kwargs)
return table.to_reader(chunk_size)

def _load_into_cache(self, name, expr):
self.create_table(name, self.compile(expr).cache())
def _create_cached_table(self, name, expr):
return self.create_table(name, self.compile(expr).cache())

def _clean_up_cached_table(self, name):
self._remove_table(name)
def _drop_cached_table(self, name):
self.drop_table(name, force=True)


@lazy_singledispatch
53 changes: 43 additions & 10 deletions ibis/backends/polars/compiler.py
@@ -59,7 +59,7 @@ def table(op, **_):

@translate.register(ops.DummyTable)
def dummy_table(op, **kw):
selections = [translate(arg, **kw) for name, arg in op.values.items()]
selections = [translate(arg, **kw).alias(name) for name, arg in op.values.items()]
return pl.DataFrame().lazy().select(selections)


@@ -68,12 +68,6 @@ def in_memory_table(op, **_):
return op.data.to_polars(op.schema).lazy()


@translate.register(ops.Alias)
def alias(op, **kw):
arg = translate(op.arg, **kw)
return arg.alias(op.name)


def _make_duration(value, dtype):
kwargs = {f"{dtype.resolution}s": value}
return pl.duration(**kwargs)
@@ -708,6 +702,7 @@ def struct_column(op, **kw):
ops.All: "all",
ops.Any: "any",
ops.ApproxMedian: "median",
ops.ApproxCountDistinct: "approx_n_unique",
ops.Count: "count",
ops.CountDistinct: "n_unique",
ops.Max: "max",
@@ -784,6 +779,7 @@ def execute_mode(op, **kw):


@translate.register(ops.Quantile)
@translate.register(ops.ApproxQuantile)
def execute_quantile(op, **kw):
arg = translate(op.arg, **kw)
quantile = translate(op.quantile, **kw)
@@ -955,6 +951,12 @@ def timestamp_diff(op, **kw):
return left.dt.truncate("1s") - right.dt.truncate("1s")


@translate.register(ops.ArraySort)
def array_sort(op, **kw):
arg = translate(op.arg, **kw)
return arg.list.sort()


@translate.register(ops.ArrayLength)
def array_length(op, **kw):
arg = translate(op.arg, **kw)
@@ -1006,7 +1008,14 @@ def array_collect(op, in_group_by=False, **kw):

@translate.register(ops.ArrayFlatten)
def array_flatten(op, **kw):
return pl.concat_list(translate(op.arg, **kw))
result = translate(op.arg, **kw)
return (
pl.when(result.is_null())
.then(None)
.when(result.list.len() == 0)
.then([])
.otherwise(result.flatten())
)


_date_methods = {
@@ -1107,9 +1116,24 @@ def comparison(op, **kw):
@translate.register(ops.Between)
def between(op, **kw):
op_arg = op.arg
arg_dtype = op_arg.dtype

arg = translate(op_arg, **kw)
lower = translate(op.lower_bound, **kw)
upper = translate(op.upper_bound, **kw)

dtype = PolarsType.from_ibis(arg_dtype)

lower_bound = op.lower_bound
lower = translate(lower_bound, **kw)

if lower_bound.dtype != arg_dtype:
lower = lower.cast(dtype)

upper_bound = op.upper_bound
upper = translate(upper_bound, **kw)

if upper_bound.dtype != arg_dtype:
upper = upper.cast(dtype)

return arg.is_between(lower, upper, closed="both")


@@ -1419,3 +1443,12 @@ def execute_group_concat(op, **kw):
arg = arg.sort_by(keys, descending=descending)

return pl.when(arg.count() > 0).then(arg.str.join(sep)).otherwise(None)


@translate.register(ops.DateDelta)
def execute_date_delta(op, **kw):
left = translate(op.left, **kw)
right = translate(op.right, **kw)
delta = left - right
method_name = f"total_{_literal_value(op.part)}s"
return getattr(delta.dt, method_name)()
5 changes: 0 additions & 5 deletions ibis/backends/polars/tests/conftest.py
@@ -57,8 +57,3 @@ def con(data_dir, tmp_path_factory, worker_id):
@pytest.fixture(scope="session")
def alltypes(con):
return con.table("functional_alltypes")


@pytest.fixture(scope="session")
def alltypes_df(alltypes):
return alltypes.execute()
39 changes: 39 additions & 0 deletions ibis/backends/polars/tests/test_client.py
@@ -0,0 +1,39 @@
from __future__ import annotations

import pytest

import ibis
from ibis.backends.tests.errors import PolarsSQLInterfaceError
from ibis.util import gen_name

pd = pytest.importorskip("pandas")
tm = pytest.importorskip("pandas.testing")


def test_cannot_run_sql_after_drop(con):
t = con.table("functional_alltypes")
n = t.count().execute()

name = gen_name("polars_dot_sql")
con.create_table(name, t)

sql = f"SELECT COUNT(*) FROM {name}"

expr = con.sql(sql)
result = expr.execute()
assert result.iat[0, 0] == n

con.drop_table(name)
with pytest.raises(PolarsSQLInterfaceError):
con.sql(sql)


def test_array_flatten(con):
data = {"id": range(3), "happy": [[["abc"]], [["bcd"]], [["def"]]]}
t = ibis.memtable(data)
expr = t.select("id", flat=t.happy.flatten()).order_by("id")
result = con.to_pyarrow(expr)
expected = pd.DataFrame(
{"id": data["id"], "flat": [row[0] for row in data["happy"]]}
)
tm.assert_frame_equal(result.to_pandas(), expected)
37 changes: 18 additions & 19 deletions ibis/backends/postgres/__init__.py
@@ -231,28 +231,25 @@ def do_connect(
>>> user = os.environ.get("IBIS_TEST_POSTGRES_USER", getpass.getuser())
>>> password = os.environ.get("IBIS_TEST_POSTGRES_PASSWORD")
>>> database = os.environ.get("IBIS_TEST_POSTGRES_DATABASE", "ibis_testing")
>>> con = connect(database=database, host=host, user=user, password=password)
>>> con = ibis.postgres.connect(database=database, host=host, user=user, password=password)
>>> con.list_tables() # doctest: +ELLIPSIS
[...]
>>> t = con.table("functional_alltypes")
>>> t
PostgreSQLTable[table]
name: functional_alltypes
schema:
id : int32
bool_col : boolean
tinyint_col : int16
smallint_col : int16
int_col : int32
bigint_col : int64
float_col : float32
double_col : float64
date_string_col : string
string_col : string
timestamp_col : timestamp
year : int32
month : int32
DatabaseTable: functional_alltypes
id int32
bool_col boolean
tinyint_col int16
smallint_col int16
int_col int32
bigint_col int64
float_col float32
double_col float64
date_string_col string
string_col string
timestamp_col timestamp(6)
year int32
month int32
"""
import psycopg2
import psycopg2.extras
@@ -626,7 +623,7 @@ def create_table(
| pl.LazyFrame
| None = None,
*,
schema: ibis.Schema | None = None,
schema: sch.SchemaLike | None = None,
database: str | None = None,
temp: bool = False,
overwrite: bool = False,
@@ -655,6 +652,8 @@ def create_table(
"""
if obj is None and schema is None:
raise ValueError("Either `obj` or `schema` must be specified")
if schema is not None:
schema = ibis.schema(schema)

properties = []

8 changes: 0 additions & 8 deletions ibis/backends/postgres/tests/conftest.py
@@ -96,11 +96,3 @@ def gdf(geotable):
@pytest.fixture(scope="module")
def intervals(con):
return con.table("intervals")


@pytest.fixture
def translate():
from ibis.backends.postgres import Backend

context = Backend.compiler.make_context()
return lambda expr: Backend.compiler.translator_class(expr, context).get_result()
25 changes: 0 additions & 25 deletions ibis/backends/postgres/tests/test_functions.py
@@ -959,31 +959,6 @@ def test_array_concat_mixed_types(array_types):
array_types.y + array_types.x.cast("array<double>")


@pytest.fixture
def t(con, temp_table):
with con.begin() as c:
c.execute(f"CREATE TABLE {temp_table} (id SERIAL PRIMARY KEY, name TEXT)")
return con.table(temp_table)


@pytest.fixture
def s(con, t, temp_table2):
temp_table = t.op().name
assert temp_table != temp_table2

with con.begin() as c:
c.execute(
f"""
CREATE TABLE {temp_table2} (
id SERIAL PRIMARY KEY,
left_t_id INTEGER REFERENCES {temp_table},
cost DOUBLE PRECISION
)
"""
)
return con.table(temp_table2)


@pytest.fixture
def trunc(con, temp_table):
quoted = temp_table
31 changes: 22 additions & 9 deletions ibis/backends/pyspark/__init__.py
@@ -25,6 +25,7 @@
from ibis.backends.pyspark.converter import PySparkPandasData
from ibis.backends.pyspark.datatypes import PySparkSchema, PySparkType
from ibis.backends.sql import SQLBackend
from ibis.backends.sql.compilers.base import AlterTable
from ibis.expr.operations.udf import InputType
from ibis.legacy.udf.vectorized import _coerce_to_series
from ibis.util import deprecated
@@ -180,7 +181,19 @@ def do_connect(
# local time to UTC with microsecond resolution.
# https://spark.apache.org/docs/latest/sql-pyspark-pandas-with-arrow.html#timestamp-with-time-zone-semantics
self._session.conf.set("spark.sql.session.timeZone", "UTC")
self._session.conf.set("spark.sql.mapKeyDedupPolicy", "LAST_WIN")

# Databricks Serverless compute only supports limited properties
# and any attempt to set unsupported properties will result in an error.
# https://docs.databricks.com/en/spark/conf.html
try:
from pyspark.errors.exceptions.connect import SparkConnectGrpcException
except ImportError:
# Use a dummy class for when spark connect is not available
class SparkConnectGrpcException(Exception):
pass

with contextlib.suppress(SparkConnectGrpcException):
self._session.conf.set("spark.sql.mapKeyDedupPolicy", "LAST_WIN")

for key, value in kwargs.items():
self._session.conf.set(key, value)
@@ -534,7 +547,7 @@ def create_table(
ir.Table | pd.DataFrame | pa.Table | pl.DataFrame | pl.LazyFrame | None
) = None,
*,
schema: sch.Schema | None = None,
schema: sch.SchemaLike | None = None,
database: str | None = None,
temp: bool | None = None,
overwrite: bool = False,
@@ -595,6 +608,7 @@ def create_table(
df = self._session.sql(query)
df.write.saveAsTable(name, format=format, mode=mode)
elif schema is not None:
schema = ibis.schema(schema)
schema = PySparkSchema.from_ibis(schema)
with self._active_catalog_database(catalog, db):
self._session.catalog.createTable(name, schema=schema, format=format)
@@ -659,10 +673,8 @@ def rename_table(self, old_name: str, new_name: str) -> None:
"""
old = sg.table(old_name, quoted=True)
new = sg.table(new_name, quoted=True)
query = sge.AlterTable(
this=old,
exists=False,
actions=[sge.RenameTable(this=new, exists=True)],
query = AlterTable(
this=old, exists=False, actions=[sge.RenameTable(this=new, exists=True)]
)
with self._safe_raw_sql(query):
pass
@@ -692,16 +704,17 @@ def compute_stats(
)
return self.raw_sql(f"ANALYZE TABLE {table} COMPUTE STATISTICS{maybe_noscan}")

def _load_into_cache(self, name, expr):
def _create_cached_table(self, name, expr):
query = self.compile(expr)
t = self._session.sql(query).cache()
assert t.is_cached
t.createOrReplaceTempView(name)
# store the underlying spark dataframe so we can release memory when
# asked to, instead of when the session ends
self._cached_dataframes[name] = t
return self.table(name)

def _clean_up_cached_table(self, name):
def _drop_cached_table(self, name):
self._session.catalog.dropTempView(name)
t = self._cached_dataframes.pop(name)
assert t.is_cached
@@ -1286,7 +1299,7 @@ def _to_filesystem_output(
df = df.write.format(format)
for k, v in (options or {}).items():
df = df.option(k, v)
df.save(path)
df.save(os.fspath(path))
return None
sq = df.writeStream.format(format)
sq = sq.option("path", os.fspath(path))
10 changes: 0 additions & 10 deletions ibis/backends/pyspark/tests/test_ddl.py
@@ -166,16 +166,6 @@ def test_drop_view(con, created_view):
assert created_view not in con.list_tables()


@pytest.fixture
def table(con, temp_database):
table_name = f"table_{util.guid()}"
schema = ibis.schema([("foo", "string"), ("bar", "int64")])
yield con.create_table(
table_name, database=temp_database, schema=schema, format="parquet"
)
con.drop_table(table_name, database=temp_database)


@pytest.fixture
def keyword_t(con):
yield "distinct"
5 changes: 4 additions & 1 deletion ibis/backends/risingwave/__init__.py
@@ -14,6 +14,7 @@
import ibis.backends.sql.compilers as sc
import ibis.common.exceptions as com
import ibis.expr.operations as ops
import ibis.expr.schema as sch
import ibis.expr.types as ir
from ibis import util
from ibis.backends.postgres import Backend as PostgresBackend
@@ -130,7 +131,7 @@ def create_table(
| pl.LazyFrame
| None = None,
*,
schema: ibis.Schema | None = None,
schema: sch.SchemaLike | None = None,
database: str | None = None,
temp: bool = False,
overwrite: bool = False,
@@ -177,6 +178,8 @@ def create_table(
"""
if obj is None and schema is None:
raise ValueError("Either `obj` or `schema` must be specified")
if schema is not None:
schema = ibis.schema(schema)

if connector_properties is not None and (
encode_format is None or data_format is None
5 changes: 0 additions & 5 deletions ibis/backends/risingwave/tests/conftest.py
@@ -80,8 +80,3 @@ def alltypes(con):
@pytest.fixture(scope="module")
def df(alltypes):
return alltypes.execute()


@pytest.fixture(scope="module")
def intervals(con):
return con.table("intervals")
8 changes: 4 additions & 4 deletions ibis/backends/snowflake/__init__.py
@@ -545,7 +545,7 @@ def get_schema(
# snowflake puts temp tables in the same catalog and database as
# non-temp tables and differentiates between them using a different
# mechanism than other database that often put temp tables in a hidden
# or intentionall-difficult-to-access catalog/database
# or intentionally-difficult-to-access catalog/database
table = sg.table(
table_name, db=database, catalog=catalog, quoted=self.compiler.quoted
)
@@ -765,7 +765,7 @@ def create_table(
| pl.LazyFrame
| None = None,
*,
schema: sch.Schema | None = None,
schema: sch.SchemaLike | None = None,
database: str | None = None,
temp: bool = False,
overwrite: bool = False,
@@ -797,6 +797,8 @@ def create_table(
"""
if obj is None and schema is None:
raise ValueError("Either `obj` or `schema` must be specified")
if schema is not None:
schema = ibis.schema(schema)

quoted = self.compiler.quoted

@@ -1200,8 +1202,6 @@ def insert(
The name of the table to which data needs will be inserted
obj
The source data or expression to insert
schema
The name of the schema that the table is located in
schema
[deprecated] The name of the schema that the table is located in
database
2 changes: 1 addition & 1 deletion ibis/backends/snowflake/tests/test_client.py
@@ -106,7 +106,7 @@ def test_timestamp_tz_column(simple_con):
ibis.util.gen_name("snowflake_timestamp_tz_column"),
schema=ibis.schema({"ts": "string"}),
temp=True,
).mutate(ts=lambda t: t.ts.to_timestamp("YYYY-MM-DD HH24-MI-SS"))
).mutate(ts=lambda t: t.ts.as_timestamp("YYYY-MM-DD HH24-MI-SS"))
expr = t.ts
assert expr.execute().empty

Expand Down
33 changes: 9 additions & 24 deletions ibis/backends/sql/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,18 +191,10 @@ def _get_schema_using_query(self, query: str) -> sch.Schema:
The schema inferred from `query`
"""

def _get_sql_string_view_schema(self, name, table, query) -> sch.Schema:
compiler = self.compiler
dialect = compiler.dialect

cte = compiler.to_sqlglot(table)
parsed = sg.parse_one(query, read=dialect)
parsed.args["with"] = cte.args.pop("with", [])
parsed = parsed.with_(
sg.to_identifier(name, quoted=compiler.quoted), as_=cte, dialect=dialect
)

sql = parsed.sql(dialect)
def _get_sql_string_view_schema(
self, *, name: str, table: ir.Table, query: str
) -> sch.Schema:
sql = self.compiler.add_query_to_expr(name=name, table=table, query=query)
return self._get_schema_using_query(sql)

def _register_udfs(self, expr: ir.Expr) -> None:
@@ -262,12 +254,6 @@ def drop_view(
with self._safe_raw_sql(src):
pass

def _load_into_cache(self, name, expr):
self.create_table(name, expr, schema=expr.schema(), temp=True)

def _clean_up_cached_table(self, name):
self.drop_table(name, force=True)

def execute(
self,
expr: ir.Expr,
@@ -431,14 +417,13 @@ def _build_insert_from_table(
compiler = self.compiler
quoted = compiler.quoted
# Compare the columns between the target table and the object to be inserted
# If they don't match, assume auto-generated column names and use positional
# ordering.
source_cols = source.columns
# If source is a subset of target, use source columns for insert list
# Otherwise, assume auto-generated column names and use positional ordering.
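# For example (column names here are illustrative, not from this diff):
# inserting a source with columns (a, b) into a target with columns
# (a, b, c) emits INSERT INTO target (a, b) ..., while a source whose
# names are not all present in the target falls back to the target's
# own column order.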
target_cols = self.get_schema(target).keys()

columns = (
source_cols
if not set(target_cols := self.get_schema(target).names).difference(
source_cols
)
if (source_cols := source.schema().keys()) <= target_cols
else target_cols
)

94 changes: 69 additions & 25 deletions ibis/backends/sql/compilers/base.py
@@ -33,6 +33,16 @@
from ibis.expr.operations.udf import InputType
from ibis.expr.rewrites import lower_stringslice

try:
from sqlglot.expressions import Alter
except ImportError:
from sqlglot.expressions import AlterTable
else:

def AlterTable(*args, kind="TABLE", **kwargs):
return Alter(*args, kind=kind, **kwargs)


if TYPE_CHECKING:
from collections.abc import Callable, Iterable, Mapping

@@ -256,7 +266,10 @@ class SQLGlotCompiler(abc.ABC):
one_to_zero_index,
add_one_to_nth_value_input,
)
"""A sequence of rewrites to apply to the expression tree before compilation."""
"""A sequence of rewrites to apply to the expression tree before SQL-specific transforms."""

post_rewrites: tuple[type[pats.Replace], ...] = ()
"""A sequence of rewrites to apply to the expression tree after SQL-specific transforms."""

no_limit_value: sge.Null | None = None
"""The value to use to indicate no limit."""
@@ -524,7 +537,9 @@ def if_(self, condition, true, false: sge.Expression | None = None) -> sge.If:
)

def cast(self, arg, to: dt.DataType) -> sge.Cast:
return sg.cast(sge.convert(arg), to=self.type_mapper.from_ibis(to), copy=False)
return sge.Cast(
this=sge.convert(arg), to=self.type_mapper.from_ibis(to), copy=False
)

def _prepare_params(self, params):
result = {}
@@ -594,6 +609,7 @@ def translate(self, op, *, params: Mapping[ir.Value, Any]) -> sge.Expression:
op,
params=params,
rewrites=self.rewrites,
post_rewrites=self.post_rewrites,
fuse_selects=options.sql.fuse_selects,
)

@@ -659,6 +675,11 @@ def visit_Field(self, op, *, rel, name):
)

def visit_Cast(self, op, *, arg, to):
from_ = op.arg.dtype

if from_.is_integer() and to.is_interval():
return self._make_interval(arg, to.unit)

return self.cast(arg, to)

def visit_ScalarSubquery(self, op, *, rel):
@@ -941,10 +962,11 @@ def visit_DayOfWeekName(self, op, *, arg):
ifs=list(itertools.starmap(self.if_, enumerate(calendar.day_name))),
)

def _make_interval(self, arg, unit):
return sge.Interval(this=arg, unit=self.v[unit.singular])

def visit_IntervalFromInteger(self, op, *, arg, unit):
return sge.Interval(
this=sge.convert(arg), unit=sge.Var(this=unit.singular.upper())
)
return self._make_interval(arg, unit)

### String Instruments
def visit_Strip(self, op, *, arg):
@@ -1038,19 +1060,6 @@ def visit_Max(self, op, *, arg, where):

### Stats

def visit_Quantile(self, op, *, arg, quantile, where):
suffix = "cont" if op.arg.dtype.is_numeric() else "disc"
funcname = f"percentile_{suffix}"
expr = sge.WithinGroup(
this=self.f[funcname](quantile),
expression=sge.Order(expressions=[sge.Ordered(this=arg)]),
)
if where is not None:
expr = sge.Filter(this=expr, expression=sge.Where(this=where))
return expr

visit_MultiQuantile = visit_Quantile

def visit_VarianceStandardDevCovariance(self, op, *, how, where, **kw):
hows = {"sample": "samp", "pop": "pop"}
funcs = {
@@ -1252,9 +1261,11 @@ def _cleanup_names(self, exprs: Mapping[str, sge.Expression]):
else:
yield value.as_(name, quoted=self.quoted, copy=False)

def visit_Select(self, op, *, parent, selections, predicates, qualified, sort_keys):
def visit_Select(
self, op, *, parent, selections, predicates, qualified, sort_keys, distinct
):
# if we've constructed a useless projection return the parent relation
if not (selections or predicates or qualified or sort_keys):
if not (selections or predicates or qualified or sort_keys or distinct):
return parent

result = parent
@@ -1281,6 +1292,9 @@ def visit_Select(self, op, *, parent, selections, predicates, qualified, sort_ke
if sort_keys:
result = result.order_by(*sort_keys, copy=False)

if distinct:
result = result.distinct()

return result

def visit_DummyTable(self, op, *, values):
@@ -1465,11 +1479,6 @@ def visit_Limit(self, op, *, parent, n, offset):
return result.subquery(alias, copy=False)
return result

def visit_Distinct(self, op, *, parent):
return (
sg.select(STAR, copy=False).distinct(copy=False).from_(parent, copy=False)
)

def visit_CTE(self, op, *, parent):
return sg.table(parent.alias_or_name, quoted=self.quoted)

@@ -1607,6 +1616,41 @@ def visit_DropColumns(self, op, *, parent, columns_to_drop):
)
return sg.select(*columns_to_keep).from_(parent)

def add_query_to_expr(self, *, name: str, table: ir.Table, query: str) -> str:
dialect = self.dialect

compiled_ibis_expr = self.to_sqlglot(table)

# pull existing CTEs from the compiled Ibis expression and combine them
# with the new query
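# (illustrative sketch, names made up: for name="ibis_view" and
# query="SELECT count(*) FROM ibis_view", the final SQL is roughly
#   WITH <ctes from the compiled table>, "ibis_view" AS (<compiled table body>)
#   SELECT count(*) FROM ibis_view)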
parsed = reduce(
lambda parsed, cte: parsed.with_(cte.args["alias"], as_=cte.args["this"]),
compiled_ibis_expr.ctes,
sg.parse_one(query, read=dialect),
)

# remove all ctes from the compiled expression, since they're now in
# our larger expression
compiled_ibis_expr.args.pop("with", None)

# add the new str query as a CTE
parsed = parsed.with_(
sg.to_identifier(name, quoted=self.quoted), as_=compiled_ibis_expr
)

# generate the SQL string
return parsed.sql(dialect)

def _make_sample_backwards_compatible(self, *, sample, parent):
# sample was changed to be owned by the table being sampled in 25.17.0
#
# this is a small workaround for backwards compatibility
if "this" in sample.__class__.arg_types:
sample.args["this"] = parent
else:
parent.args["sample"] = sample
return sg.select(STAR).from_(parent)


# `__init_subclass__` is uncalled for subclasses - we manually call it here to
# autogenerate the base class implementations as well.
63 changes: 58 additions & 5 deletions ibis/backends/sql/compilers/bigquery/__init__.py
@@ -2,6 +2,8 @@

from __future__ import annotations

import decimal
import math
import re
from typing import TYPE_CHECKING, Any

@@ -20,6 +22,7 @@
exclude_unsupported_window_frame_from_ops,
exclude_unsupported_window_frame_from_rank,
exclude_unsupported_window_frame_from_row_number,
split_select_distinct_with_order_by,
)
from ibis.common.temporal import DateUnit, IntervalUnit, TimestampUnit, TimeUnit

@@ -111,6 +114,7 @@ class BigQueryCompiler(SQLGlotCompiler):
exclude_unsupported_window_frame_from_rank,
*SQLGlotCompiler.rewrites,
)
post_rewrites = (split_select_distinct_with_order_by,)

supports_qualify = True

@@ -120,8 +124,6 @@ class BigQueryCompiler(SQLGlotCompiler):
ops.ExtractUserInfo,
ops.FindInSet,
ops.Median,
ops.Quantile,
ops.MultiQuantile,
ops.RegexSplit,
ops.RowID,
ops.TimestampDiff,
@@ -394,6 +396,41 @@ def visit_GroupConcat(self, op, *, arg, sep, where, order_by):

return sge.GroupConcat(this=arg, separator=sep)

def visit_ApproxQuantile(self, op, *, arg, quantile, where):
if not isinstance(op.quantile, ops.Literal):
raise com.UnsupportedOperationError(
"quantile must be a literal in BigQuery"
)

# BigQuery syntax is `APPROX_QUANTILES(col, resolution)` to return
# `resolution + 1` quantiles array. To handle this, we compute the
# resolution ourselves then restructure the output array as needed.
# To avoid excessive resolution we arbitrarily cap it at 100,000 -
# since these are approximate quantiles anyway this seems fine.
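# As a worked example (values illustrative): a single quantile of 0.5
# gives fracs == [(1, 2)], resolution == 2 and indices == [1], i.e.
# APPROX_QUANTILES(col, 2)[OFFSET(1)]; quantiles [0.25, 0.5, 0.75] give
# resolution == 4 and indices == [1, 2, 3] into the 5-element result array.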
quantiles = util.promote_list(op.quantile.value)
fracs = [decimal.Decimal(str(q)).as_integer_ratio() for q in quantiles]
resolution = min(math.lcm(*(den for _, den in fracs)), 100_000)
indices = [(num * resolution) // den for num, den in fracs]

if where is not None:
arg = self.if_(where, arg, NULL)

if not op.arg.dtype.is_floating():
arg = self.cast(arg, dt.float64)

array = self.f.approx_quantiles(
arg, sge.IgnoreNulls(this=sge.convert(resolution))
)
if isinstance(op, ops.ApproxQuantile):
return array[indices[0]]

if indices == list(range(resolution + 1)):
return array
else:
return sge.Array(expressions=[array[i] for i in indices])

visit_ApproxMultiQuantile = visit_ApproxQuantile

def visit_FloorDivide(self, op, *, left, right):
return self.cast(self.f.floor(self.f.ieee_divide(left, right)), op.dtype)

@@ -584,8 +621,6 @@ def visit_Cast(self, op, *, arg, to):
f"BigQuery does not allow extracting date part `{from_.unit}` from intervals"
)
return self.f.extract(self.v[to.resolution.upper()], arg)
elif from_.is_integer() and to.is_interval():
return sge.Interval(this=arg, unit=self.v[to.unit.singular])
elif from_.is_floating() and to.is_integer():
return self.cast(self.f.trunc(arg), dt.int64)
return super().visit_Cast(op, arg=arg, to=to)
@@ -891,7 +926,25 @@ def visit_HashBytes(self, op, *, arg, how):

@staticmethod
def _gen_valid_name(name: str) -> str:
return "_".join(map(str.strip, _NAME_REGEX.findall(name))) or "tmp"
candidate = "_".join(map(str.strip, _NAME_REGEX.findall(name))) or "tmp"
# column names cannot be longer than 300 characters
#
# https://cloud.google.com/bigquery/docs/schemas#column_names
#
# it's easy to rename columns, so raise an exception telling the user
# to do so
#
# we could potentially relax this and support arbitrary-length columns
# by compressing the information using hashing, but there's no reason
# to solve that problem until someone encounters this error and cannot
# rename their columns
limit = 300
if len(candidate) > limit:
raise com.IbisError(
f"BigQuery does not allow column names longer than {limit:d} characters. "
"Please rename your columns to have fewer characters."
)
return candidate

def visit_CountStar(self, op, *, arg, where):
if where is not None:
51 changes: 28 additions & 23 deletions ibis/backends/sql/compilers/clickhouse.py
@@ -188,18 +188,24 @@ def visit_CountStar(self, op, *, where, arg):
return self.f.countIf(where)
return sge.Count(this=STAR)

def visit_Quantile(self, op, *, arg, quantile, where):
if where is None:
return self.agg.quantile(arg, quantile, where=where)

func = "quantile" + "s" * isinstance(op, ops.MultiQuantile)
def _visit_quantile(self, func, arg, quantile, where):
return sge.ParameterizedAgg(
this=f"{func}If",
this=f"{func}If" if where is not None else func,
expressions=util.promote_list(quantile),
params=[arg, where],
params=[arg, where] if where is not None else [arg],
)

visit_MultiQuantile = visit_Quantile
def visit_Quantile(self, op, *, arg, quantile, where):
return self._visit_quantile("quantile", arg, quantile, where)

def visit_MultiQuantile(self, op, *, arg, quantile, where):
return self._visit_quantile("quantiles", arg, quantile, where)

def visit_ApproxQuantile(self, op, *, arg, quantile, where):
return self._visit_quantile("quantileTDigest", arg, quantile, where)

def visit_ApproxMultiQuantile(self, op, *, arg, quantile, where):
return self._visit_quantile("quantilesTDigest", arg, quantile, where)

def visit_Correlation(self, op, *, left, right, how, where):
if how == "pop":
@@ -372,24 +378,23 @@ def visit_TimestampFromUNIX(self, op, *, arg, unit):
return self.f.toDateTime(arg)

def visit_TimestampTruncate(self, op, *, arg, unit):
converters = {
"Y": "toStartOfYear",
"Q": "toStartOfQuarter",
"M": "toStartOfMonth",
"W": "toMonday",
"D": "toDate",
"h": "toStartOfHour",
"m": "toStartOfMinute",
"s": "toDateTime",
}
if (short := unit.short) == "W":
func = "toMonday"
else:
func = f"toStartOf{unit.singular.capitalize()}"

unit = unit.short
if (converter := converters.get(unit)) is None:
raise com.UnsupportedOperationError(f"Unsupported truncate unit {unit}")
if short in ("s", "ms", "us", "ns"):
arg = self.f.toDateTime64(arg, op.arg.dtype.scale or 0)
return self.f[func](arg)

return self.f[converter](arg)
visit_TimeTruncate = visit_TimestampTruncate

visit_TimeTruncate = visit_DateTruncate = visit_TimestampTruncate
def visit_DateTruncate(self, op, *, arg, unit):
if unit.short == "W":
func = "toMonday"
else:
func = f"toStartOf{unit.singular.capitalize()}"
return self.f[func](arg)

def visit_TimestampBucket(self, op, *, arg, interval, offset):
if offset is not None:
9 changes: 7 additions & 2 deletions ibis/backends/sql/compilers/datafusion.py
@@ -14,6 +14,7 @@
from ibis.backends.sql.compilers.base import FALSE, NULL, STAR, AggGen, SQLGlotCompiler
from ibis.backends.sql.datatypes import DataFusionType
from ibis.backends.sql.dialects import DataFusion
from ibis.backends.sql.rewrites import split_select_distinct_with_order_by
from ibis.common.temporal import IntervalUnit, TimestampUnit
from ibis.expr.operations.udf import InputType

@@ -26,6 +27,8 @@ class DataFusionCompiler(SQLGlotCompiler):

agg = AggGen(supports_filter=True, supports_order_by=True)

post_rewrites = (split_select_distinct_with_order_by,)

UNSUPPORTED_OPS = (
ops.ArgMax,
ops.ArgMin,
@@ -40,8 +43,6 @@ class DataFusionCompiler(SQLGlotCompiler):
ops.Greatest,
ops.IntervalFromInteger,
ops.Least,
ops.MultiQuantile,
ops.Quantile,
ops.RowID,
ops.Strftime,
ops.TimeDelta,
@@ -53,6 +54,7 @@ class DataFusionCompiler(SQLGlotCompiler):
)

SIMPLE_OPS = {
ops.ApproxQuantile: "approx_percentile_cont",
ops.ApproxMedian: "approx_median",
ops.ArrayRemove: "array_remove_all",
ops.BitAnd: "bit_and",
@@ -500,5 +502,8 @@ def visit_GroupConcat(self, op, *, arg, sep, where, order_by):
op, arg=arg, sep=sep, where=where, order_by=order_by
)

def visit_ArrayFlatten(self, op, *, arg):
return self.if_(arg.is_(NULL), NULL, self.f.flatten(arg))


compiler = DataFusionCompiler()
2 changes: 0 additions & 2 deletions ibis/backends/sql/compilers/druid.py
@@ -47,8 +47,6 @@ class DruidCompiler(SQLGlotCompiler):
ops.IsInf,
ops.Levenshtein,
ops.Median,
ops.MultiQuantile,
ops.Quantile,
ops.RandomUUID,
ops.RegexReplace,
ops.RegexSplit,
11 changes: 9 additions & 2 deletions ibis/backends/sql/compilers/duckdb.py
@@ -168,12 +168,12 @@ def visit_Sample(
self, op, *, parent, fraction: float, method: str, seed: int | None, **_
):
sample = sge.TableSample(
this=parent,
method="bernoulli" if method == "row" else "system",
percent=sge.convert(fraction * 100.0),
seed=None if seed is None else sge.convert(seed),
)
return sg.select(STAR).from_(sample)

return self._make_sample_backwards_compatible(sample=sample, parent=parent)

def visit_ArraySlice(self, op, *, arg, start, stop):
arg_length = self.f.len(arg)
@@ -532,6 +532,13 @@ def visit_Quantile(self, op, *, arg, quantile, where):
def visit_MultiQuantile(self, op, *, arg, quantile, where):
return self.visit_Quantile(op, arg=arg, quantile=quantile, where=where)

def visit_ApproxQuantile(self, op, *, arg, quantile, where):
if not op.arg.dtype.is_floating():
arg = self.cast(arg, dt.float64)
return self.agg.approx_quantile(arg, quantile, where=where)

visit_ApproxMultiQuantile = visit_ApproxQuantile

def visit_HexDigest(self, op, *, arg, how):
if how in ("md5", "sha256"):
return getattr(self.f, how)(arg)
55 changes: 46 additions & 9 deletions ibis/backends/sql/compilers/exasol.py
@@ -6,7 +6,7 @@
import ibis.common.exceptions as com
import ibis.expr.datatypes as dt
import ibis.expr.operations as ops
from ibis.backends.sql.compilers.base import NULL, SQLGlotCompiler
from ibis.backends.sql.compilers.base import NULL, STAR, SQLGlotCompiler
from ibis.backends.sql.datatypes import ExasolType
from ibis.backends.sql.dialects import Exasol
from ibis.backends.sql.rewrites import (
@@ -32,7 +32,6 @@ class ExasolCompiler(SQLGlotCompiler):

UNSUPPORTED_OPS = (
ops.AnalyticVectorizedUDF,
ops.ApproxMedian,
ops.ArgMax,
ops.ArgMin,
ops.ArrayDistinct,
@@ -45,7 +44,6 @@ class ExasolCompiler(SQLGlotCompiler):
ops.ArrayUnion,
ops.ArrayZip,
ops.BitwiseNot,
ops.Covariance,
ops.CumeDist,
ops.DateAdd,
ops.DateSub,
@@ -56,17 +54,14 @@ class ExasolCompiler(SQLGlotCompiler):
ops.IsInf,
ops.IsNan,
ops.Levenshtein,
ops.Median,
ops.MultiQuantile,
ops.Quantile,
ops.RandomUUID,
ops.ReductionVectorizedUDF,
ops.RegexExtract,
ops.RegexReplace,
ops.RegexSearch,
ops.RegexSplit,
ops.RowID,
ops.StandardDev,
ops.Strftime,
ops.StringJoin,
ops.StringSplit,
@@ -80,7 +75,6 @@ class ExasolCompiler(SQLGlotCompiler):
ops.TimestampSub,
ops.TypeOf,
ops.Unnest,
ops.Variance,
)

SIMPLE_OPS = {
Expand Down Expand Up @@ -125,6 +119,20 @@ def visit_NonNullLiteral(self, op, *, value, dtype):
def visit_Date(self, op, *, arg):
return self.cast(arg, dt.date)

def visit_Correlation(self, op, *, left, right, how, where):
if how == "sample":
raise com.UnsupportedOperationError(
"Exasol only implements `pop` correlation coefficient"
)

if (left_type := op.left.dtype).is_boolean():
left = self.cast(left, dt.Int32(nullable=left_type.nullable))

if (right_type := op.right.dtype).is_boolean():
right = self.cast(right, dt.Int32(nullable=right_type.nullable))

return self.agg.corr(left, right, where=where)

def visit_GroupConcat(self, op, *, arg, sep, where, order_by):
if where is not None:
arg = self.if_(where, arg, NULL)
@@ -183,11 +191,40 @@ def visit_StringConcat(self, op, *, arg):
any_args_null = (a.is_(NULL) for a in arg)
return self.if_(sg.or_(*any_args_null), NULL, self.f.concat(*arg))

def visit_CountDistinct(self, op, *, arg, where):
if where is not None:
arg = self.if_(where, arg, NULL)
return self.f.count(sge.Distinct(expressions=[arg]))

def visit_CountStar(self, op, *, arg, where):
if where is not None:
return self.f.sum(self.cast(where, op.dtype))
return self.f.count(STAR)

def visit_CountDistinctStar(self, op, *, arg, where):
raise com.UnsupportedOperationError(
"COUNT(DISTINCT *) is not supported in Exasol"
cols = [sg.column(k, quoted=self.quoted) for k in op.arg.schema.keys()]
if where is not None:
cols = [self.if_(where, c, NULL) for c in cols]
row = sge.Tuple(expressions=cols)
return self.f.count(sge.Distinct(expressions=[row]))

def visit_Median(self, op, *, arg, where):
return self.visit_Quantile(op, arg=arg, quantile=sge.convert(0.5), where=where)

visit_ApproxMedian = visit_Median

def visit_Quantile(self, op, *, arg, quantile, where):
suffix = "cont" if op.arg.dtype.is_numeric() else "disc"
funcname = f"percentile_{suffix}"
if where is not None:
arg = self.if_(where, arg, NULL)
return sge.WithinGroup(
this=self.f[funcname](quantile),
expression=sge.Order(expressions=[sge.Ordered(this=arg)]),
)

visit_ApproxQuantile = visit_Quantile

def visit_TimestampTruncate(self, op, *, arg, unit):
short_name = unit.short
unit_mapping = {"W": "IW"}
20 changes: 17 additions & 3 deletions ibis/backends/sql/compilers/flink.py
@@ -72,7 +72,6 @@ class FlinkCompiler(SQLGlotCompiler):
ops.ArgMax,
ops.ArgMin,
ops.ArrayFlatten,
ops.ArraySort,
ops.ArrayStringJoin,
ops.Correlation,
ops.CountDistinctStar,
@@ -84,9 +83,7 @@ class FlinkCompiler(SQLGlotCompiler):
ops.IsNan,
ops.Levenshtein,
ops.Median,
ops.MultiQuantile,
ops.NthValue,
ops.Quantile,
ops.ReductionVectorizedUDF,
ops.RegexSplit,
ops.RowID,
@@ -102,6 +99,7 @@ class FlinkCompiler(SQLGlotCompiler):
ops.ArrayLength: "cardinality",
ops.ArrayPosition: "array_position",
ops.ArrayRemove: "array_remove",
ops.ArraySort: "array_sort",
ops.ArrayUnion: "array_union",
ops.ExtractDayOfYear: "dayofyear",
ops.MapKeys: "map_keys",
@@ -575,5 +573,21 @@ def visit_MapMerge(self, op: ops.MapMerge, *, left, right):
def visit_StructColumn(self, op, *, names, values):
return self.cast(sge.Struct(expressions=list(values)), op.dtype)

def visit_ArrayCollect(self, op, *, arg, where, order_by, include_null):
if order_by:
raise com.UnsupportedOperationError(
"ordering of order-sensitive aggregations via `order_by` is "
"not supported for this backend"
)
# the only way to get filtering *and* respecting nulls is to use
# `FILTER` syntax, but it's broken in various ways for other aggregates
out = self.f.array_agg(arg)
if not include_null:
cond = arg.is_(sg.not_(NULL, copy=False))
where = cond if where is None else sge.And(this=cond, expression=where)
if where is not None:
out = sge.Filter(this=out, expression=sge.Where(this=where))
return out


compiler = FlinkCompiler()
18 changes: 12 additions & 6 deletions ibis/backends/sql/compilers/impala.py
@@ -29,14 +29,11 @@ class ImpalaCompiler(SQLGlotCompiler):
ops.ArrayPosition,
ops.Array,
ops.Covariance,
ops.DateDelta,
ops.ExtractDayOfYear,
ops.Levenshtein,
ops.Map,
ops.Median,
ops.MultiQuantile,
ops.NthValue,
ops.Quantile,
ops.RegexSplit,
ops.RowID,
ops.StringSplit,
@@ -191,9 +188,7 @@ def visit_NonNullLiteral(self, op, *, value, dtype):

def visit_Cast(self, op, *, arg, to):
from_ = op.arg.dtype
if from_.is_integer() and to.is_interval():
return sge.Interval(this=sge.convert(arg), unit=to.unit.singular.upper())
elif from_.is_temporal() and to.is_integer():
if from_.is_temporal() and to.is_integer():
return 1_000_000 * self.f.unix_timestamp(arg)
return super().visit_Cast(op, arg=arg, to=to)

@@ -318,5 +313,16 @@ def visit_Sign(self, op, *, arg):
return self.cast(sign, dtype)
return sign

def visit_DateDelta(self, op, *, left, right, part):
if not isinstance(part, sge.Literal):
raise com.UnsupportedOperationError(
"Only literal `part` values are supported for date delta"
)
if part.this != "day":
raise com.UnsupportedOperationError(
f"Only 'day' part is supported for date delta in the {self.dialect} backend"
)
return self.f.datediff(left, right)


compiler = ImpalaCompiler()
42 changes: 32 additions & 10 deletions ibis/backends/sql/compilers/mssql.py
@@ -24,6 +24,7 @@
exclude_unsupported_window_frame_from_row_number,
p,
replace,
split_select_distinct_with_order_by,
)
from ibis.common.deferred import var

@@ -69,6 +70,7 @@ class MSSQLCompiler(SQLGlotCompiler):
rewrite_rows_range_order_by_window,
*SQLGlotCompiler.rewrites,
)
post_rewrites = (split_select_distinct_with_order_by,)
copy_func_args = True

UNSUPPORTED_OPS = (
@@ -86,24 +88,20 @@ class MSSQLCompiler(SQLGlotCompiler):
ops.BitXor,
ops.Covariance,
ops.CountDistinctStar,
ops.DateAdd,
ops.DateDiff,
ops.DateSub,
ops.EndsWith,
ops.IntervalAdd,
ops.IntervalFromInteger,
ops.IntervalMultiply,
ops.IntervalSubtract,
ops.IntervalMultiply,
ops.IntervalFloorDivide,
ops.IsInf,
ops.IsNan,
ops.LPad,
ops.Levenshtein,
ops.Map,
ops.Median,
ops.Mode,
ops.MultiQuantile,
ops.NthValue,
ops.Quantile,
ops.RegexExtract,
ops.RegexReplace,
ops.RegexSearch,
@@ -115,9 +113,7 @@ class MSSQLCompiler(SQLGlotCompiler):
ops.StringToDate,
ops.StringToTimestamp,
ops.StructColumn,
ops.TimestampAdd,
ops.TimestampDiff,
ops.TimestampSub,
ops.Unnest,
)

Expand Down Expand Up @@ -225,6 +221,14 @@ def visit_CountDistinct(self, op, *, arg, where):
arg = self.if_(where, arg, NULL)
return self.f.count(sge.Distinct(expressions=[arg]))

def visit_ApproxQuantile(self, op, *, arg, quantile, where):
if where is not None:
arg = self.if_(where, arg, NULL)
return sge.WithinGroup(
this=self.f.approx_percentile_cont(quantile),
expression=sge.Order(expressions=[sge.Ordered(this=arg, nulls_first=True)]),
)

def visit_DayOfWeekIndex(self, op, *, arg):
return self.f.datepart(self.v.weekday, arg) - 1

Expand Down Expand Up @@ -477,9 +481,11 @@ def visit_All(self, op, *, arg, where):
arg = self.if_(where, arg, NULL)
return sge.Min(this=arg)

def visit_Select(self, op, *, parent, selections, predicates, qualified, sort_keys):
def visit_Select(
self, op, *, parent, selections, predicates, qualified, sort_keys, distinct
):
# if we've constructed a useless projection return the parent relation
if not (selections or predicates or qualified or sort_keys):
if not (selections or predicates or qualified or sort_keys or distinct):
return parent

result = parent
@@ -498,7 +504,23 @@ def visit_Select(self, op, *, parent, selections, predicates, qualified, sort_ke
if sort_keys:
result = result.order_by(*sort_keys, copy=False)

if distinct:
result = result.distinct()

return result

def visit_TimestampAdd(self, op, *, left, right):
return self.f.dateadd(
right.unit, self.cast(right.this, dt.int64), left, dialect=self.dialect
)

def visit_TimestampSub(self, op, *, left, right):
return self.f.dateadd(
right.unit, -self.cast(right.this, dt.int64), left, dialect=self.dialect
)

visit_DateAdd = visit_TimestampAdd
visit_DateSub = visit_TimestampSub


compiler = MSSQLCompiler()
27 changes: 6 additions & 21 deletions ibis/backends/sql/compilers/mysql.py
@@ -74,8 +74,6 @@ def POS_INF(self):
ops.Levenshtein,
ops.Median,
ops.Mode,
ops.MultiQuantile,
ops.Quantile,
ops.RegexReplace,
ops.RegexSplit,
ops.RowID,
@@ -119,23 +117,15 @@ def visit_Cast(self, op, *, arg, to):
# MariaDB does not support casting to JSON because it's an alias
# for TEXT (except when casting of course!)
return arg
elif from_.is_integer() and to.is_interval():
return self.visit_IntervalFromInteger(
ops.IntervalFromInteger(op.arg, unit=to.unit), arg=arg, unit=to.unit
)
elif from_.is_integer() and to.is_timestamp():
return self.f.from_unixtime(arg)
return super().visit_Cast(op, arg=arg, to=to)

def visit_TimestampDiff(self, op, *, left, right):
return self.f.timestampdiff(
sge.Var(this="SECOND"), right, left, dialect=self.dialect
)
return self.f.timestampdiff(self.v.SECOND, right, left, dialect=self.dialect)

def visit_DateDiff(self, op, *, left, right):
return self.f.timestampdiff(
sge.Var(this="DAY"), right, left, dialect=self.dialect
)
return self.f.timestampdiff(self.v.DAY, right, left, dialect=self.dialect)

def visit_ApproxCountDistinct(self, op, *, arg, where):
if where is not None:
@@ -317,16 +307,16 @@ def visit_DateTimestampTruncate(self, op, *, arg, unit):

def visit_DateTimeDelta(self, op, *, left, right, part):
return self.f.timestampdiff(
sge.Var(this=part.this), right, left, dialect=self.dialect
self.v[part.this], right, left, dialect=self.dialect
)

visit_TimeDelta = visit_DateDelta = visit_DateTimeDelta

def visit_ExtractMillisecond(self, op, *, arg):
return self.f.floor(self.f.extract(sge.Var(this="microsecond"), arg) / 1_000)
return self.f.floor(self.f.extract(self.v.microsecond, arg) / 1_000)

def visit_ExtractMicrosecond(self, op, *, arg):
return self.f.floor(self.f.extract(sge.Var(this="microsecond"), arg))
return self.f.floor(self.f.extract(self.v.microsecond, arg))

def visit_Strip(self, op, *, arg):
return self.visit_LRStrip(op, arg=arg, position="BOTH")
@@ -337,14 +327,9 @@ def visit_LStrip(self, op, *, arg):
def visit_RStrip(self, op, *, arg):
return self.visit_LRStrip(op, arg=arg, position="TRAILING")

def visit_IntervalFromInteger(self, op, *, arg, unit):
return sge.Interval(this=arg, unit=sge.Var(this=op.resolution.upper()))

def visit_TimestampAdd(self, op, *, left, right):
if op.right.dtype.unit.short == "ms":
right = sge.Interval(
this=right.this * 1_000, unit=sge.Var(this="MICROSECOND")
)
right = sge.Interval(this=right.this * 1_000, unit=self.v.MICROSECOND)
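# e.g. a 3-millisecond interval becomes INTERVAL 3000 MICROSECOND, since
# MySQL's DATE_ADD has no millisecond unit (example values illustrative)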
return self.f.date_add(left, right, dialect=self.dialect)

def visit_UnwrapJSONString(self, op, *, arg):
85 changes: 63 additions & 22 deletions ibis/backends/sql/compilers/oracle.py
@@ -1,5 +1,7 @@
from __future__ import annotations

import string

import sqlglot as sg
import sqlglot.expressions as sge
import toolz
@@ -55,19 +57,16 @@ class OracleCompiler(SQLGlotCompiler):
ops.ArrayFlatten,
ops.ArrayMap,
ops.ArrayStringJoin,
ops.Mode,
ops.MultiQuantile,
ops.RegexSplit,
ops.StringSplit,
ops.TimeTruncate,
ops.Bucket,
ops.TimestampBucket,
ops.TimeDelta,
ops.DateDelta,
ops.TimestampDelta,
ops.TimestampFromYMDHMS,
ops.TimeFromHMS,
ops.IntervalFromInteger,
ops.DayOfWeekIndex,
ops.DayOfWeekName,
ops.DateDiff,
@@ -84,12 +83,12 @@ class OracleCompiler(SQLGlotCompiler):
ops.BitOr: "bit_or_agg",
ops.BitXor: "bit_xor_agg",
ops.BitwiseAnd: "bitand",
ops.Hash: "hash",
ops.Hash: "ora_hash",
ops.LPad: "lpad",
ops.RPad: "rpad",
ops.StringAscii: "ascii",
ops.Strip: "trim",
ops.Hash: "ora_hash",
ops.Mode: "stats_mode",
}

@staticmethod
Expand Down Expand Up @@ -138,30 +137,37 @@ def visit_Literal(self, op, *, value, dtype):
elif dtype.is_uuid():
return sge.convert(str(value))
elif dtype.is_interval():
if dtype.unit.short in ("Y", "M"):
return self.f.numtoyminterval(value, dtype.unit.name)
elif dtype.unit.short in ("D", "h", "m", "s"):
return self.f.numtodsinterval(value, dtype.unit.name)
else:
raise com.UnsupportedOperationError(
f"Intervals with precision {dtype.unit.name} not supported in Oracle."
)
return self._value_to_interval(value, dtype.unit)

return super().visit_Literal(op, value=value, dtype=dtype)

def _value_to_interval(self, arg, unit):
short = unit.short

if short in ("Y", "M"):
return self.f.numtoyminterval(arg, unit.singular)
elif short in ("D", "h", "m", "s"):
return self.f.numtodsinterval(arg, unit.singular)
elif short == "ms":
return self.f.numtodsinterval(arg / 1e3, "second")
elif short in "us":
return self.f.numtodsinterval(arg / 1e6, "second")
elif short in "ns":
return self.f.numtodsinterval(arg / 1e9, "second")
else:
raise com.UnsupportedArgumentError(
f"Interval {unit.name} not supported by Oracle"
)

def visit_Cast(self, op, *, arg, to):
if to.is_interval():
from_ = op.arg.dtype
if from_.is_numeric() and to.is_interval():
# CASTing to an INTERVAL in Oracle requires specifying digits of
# precision that are a pain. There are two helper functions that
# should be used instead.
if to.unit.short in ("D", "h", "m", "s"):
return self.f.numtodsinterval(arg, to.unit.name)
elif to.unit.short in ("Y", "M"):
return self.f.numtoyminterval(arg, to.unit.name)
else:
raise com.UnsupportedArgumentError(
f"Interval {to.unit.name} not supported by Oracle"
)
return self._value_to_interval(arg, to.unit)
elif from_.is_string() and to.is_date():
return self.f.to_date(arg, "FXYYYY-MM-DD")
return self.cast(arg, to)

def visit_Limit(self, op, *, parent, n, offset):
Expand Down Expand Up @@ -302,6 +308,15 @@ def visit_Quantile(self, op, *, arg, quantile, where):
)
return expr

def visit_ApproxQuantile(self, op, *, arg, quantile, where):
if where is not None:
arg = self.if_(where, arg)

return sge.WithinGroup(
this=self.f.approx_percentile(quantile),
expression=sge.Order(expressions=[sge.Ordered(this=arg)]),
)

def visit_CountDistinct(self, op, *, arg, where):
if where is not None:
arg = self.if_(where, arg)
@@ -457,5 +472,31 @@ def visit_GroupConcat(self, op, *, arg, where, sep, order_by):

return out

def visit_IntervalFromInteger(self, op, *, arg, unit):
return self._value_to_interval(arg, unit)

def visit_DateFromYMD(self, op, *, year, month, day):
year = self.f.lpad(year, 4, "0")
month = self.f.lpad(month, 2, "0")
day = self.f.lpad(day, 2, "0")
return self.f.to_date(self.f.concat(year, month, day), "FXYYYYMMDD")

def visit_DateDelta(self, op, *, left, right, part):
if not isinstance(part, sge.Literal):
raise com.UnsupportedOperationError(
"Only literal `part` values are supported for date delta"
)
if part.this != "day":
raise com.UnsupportedOperationError(
f"Only 'day' part is supported for date delta in the {self.dialect} backend"
)
return left - right

def visit_RStrip(self, op, *, arg):
return self.f.anon.rtrim(arg, string.whitespace)

def visit_LStrip(self, op, *, arg):
return self.f.anon.ltrim(arg, string.whitespace)


compiler = OracleCompiler()
26 changes: 19 additions & 7 deletions ibis/backends/sql/compilers/postgres.py
@@ -17,6 +17,7 @@
from ibis.backends.sql.compilers.base import NULL, STAR, AggGen, SQLGlotCompiler
from ibis.backends.sql.datatypes import PostgresType
from ibis.backends.sql.dialects import Postgres
from ibis.backends.sql.rewrites import split_select_distinct_with_order_by
from ibis.common.exceptions import InvalidDecoratorError
from ibis.util import gen_name

@@ -41,6 +42,7 @@ class PostgresCompiler(SQLGlotCompiler):

dialect = Postgres
type_mapper = PostgresType
post_rewrites = (split_select_distinct_with_order_by,)

agg = AggGen(supports_filter=True, supports_order_by=True)

@@ -226,6 +228,21 @@ def visit_CountDistinctStar(self, op, *, where, arg):
)
return self.agg.count(sge.Distinct(expressions=[row]), where=where)

def visit_Quantile(self, op, *, arg, quantile, where):
suffix = "cont" if op.arg.dtype.is_numeric() else "disc"
funcname = f"percentile_{suffix}"
expr = sge.WithinGroup(
this=self.f[funcname](quantile),
expression=sge.Order(expressions=[sge.Ordered(this=arg)]),
)
if where is not None:
expr = sge.Filter(this=expr, expression=sge.Where(this=where))
return expr

visit_MultiQuantile = visit_Quantile
visit_ApproxQuantile = visit_Quantile
visit_ApproxMultiQuantile = visit_Quantile

def visit_Correlation(self, op, *, left, right, how, where):
if how == "sample":
raise com.UnsupportedOperationError(
Expand Down Expand Up @@ -632,7 +649,7 @@ def visit_ArraySlice(self, op, *, arg, start, stop):
slice_expr = sge.Slice(this=start + 1, expression=stop)
return sge.paren(arg, copy=False)[slice_expr]

def visit_IntervalFromInteger(self, op, *, arg, unit):
def _make_interval(self, arg, unit):
plural = unit.plural
if plural == "minutes":
plural = "mins"
@@ -666,19 +683,14 @@ def visit_Cast(self, op, *, arg, to):
if (timezone := to.timezone) is not None:
arg = self.f.timezone(timezone, arg)
return arg
elif from_.is_integer() and to.is_interval():
unit = to.unit
return self.visit_IntervalFromInteger(
ops.IntervalFromInteger(op.arg, unit), arg=arg, unit=unit
)
elif from_.is_string() and to.is_binary():
# Postgres and Python use the words "decode" and "encode" in
# opposite ways, sweet!
return self.f.decode(arg, "escape")
elif from_.is_binary() and to.is_string():
return self.f.encode(arg, "escape")

return self.cast(arg, op.to)
return super().visit_Cast(op, arg=arg, to=to)

visit_TryCast = visit_Cast

31 changes: 25 additions & 6 deletions ibis/backends/sql/compilers/pyspark.py
@@ -15,7 +15,12 @@
from ibis.backends.sql.compilers.base import FALSE, NULL, STAR, SQLGlotCompiler
from ibis.backends.sql.datatypes import PySparkType
from ibis.backends.sql.dialects import PySpark
from ibis.backends.sql.rewrites import FirstValue, LastValue, p
from ibis.backends.sql.rewrites import (
FirstValue,
LastValue,
p,
split_select_distinct_with_order_by,
)
from ibis.common.patterns import replace
from ibis.config import options
from ibis.expr.operations.udf import InputType
@@ -51,6 +56,7 @@ class PySparkCompiler(SQLGlotCompiler):
dialect = PySpark
type_mapper = PySparkType
rewrites = (offset_to_filter, *SQLGlotCompiler.rewrites)
post_rewrites = (split_select_distinct_with_order_by,)

UNSUPPORTED_OPS = (
ops.RowID,
@@ -283,6 +289,22 @@ def visit_GroupConcat(self, op, *, arg, sep, where, order_by):
collected = self.if_(self.f.size(collected).eq(0), NULL, collected)
return self.f.array_join(collected, sep)

def visit_Quantile(self, op, *, arg, quantile, where):
if where is not None:
arg = self.if_(where, arg, NULL)
return self.f.percentile(arg, quantile)

visit_MultiQuantile = visit_Quantile

def visit_ApproxQuantile(self, op, *, arg, quantile, where):
if not op.arg.dtype.is_floating():
arg = self.cast(arg, dt.float64)
if where is not None:
arg = self.if_(where, arg, NULL)
return self.f.approx_percentile(arg, quantile)

visit_ApproxMultiQuantile = visit_ApproxQuantile

def visit_Correlation(self, op, *, left, right, how, where):
if (left_type := op.left.dtype).is_boolean():
left = self.cast(left, dt.Int32(nullable=left_type.nullable))
@@ -318,11 +340,8 @@ def visit_Sample(
raise com.UnsupportedOperationError(
"PySpark backend does not support sampling with seed."
)
sample = sge.TableSample(
this=parent,
percent=sge.convert(fraction * 100.0),
)
return sg.select(STAR).from_(sample)
sample = sge.TableSample(percent=sge.convert(int(fraction * 100.0)))
return self._make_sample_backwards_compatible(sample=sample, parent=parent)

def visit_WindowBoundary(self, op, *, value, preceding):
if isinstance(op.value, ops.Literal) and op.value.value == 0:
26 changes: 16 additions & 10 deletions ibis/backends/sql/compilers/risingwave.py
@@ -3,11 +3,10 @@
import sqlglot.expressions as sge

import ibis.common.exceptions as com
import ibis.expr.datashape as ds
import ibis.expr.datatypes as dt
import ibis.expr.operations as ops
from ibis.backends.sql.compilers import PostgresCompiler
from ibis.backends.sql.compilers.base import ALL_OPERATIONS
from ibis.backends.sql.compilers.base import ALL_OPERATIONS, NULL
from ibis.backends.sql.datatypes import RisingWaveType
from ibis.backends.sql.dialects import RisingWave

Expand All @@ -20,9 +19,10 @@ class RisingWaveCompiler(PostgresCompiler):

UNSUPPORTED_OPS = (
ops.Arbitrary,
ops.DateFromYMD,
ops.Mode,
ops.RandomUUID,
ops.MultiQuantile,
ops.ApproxMultiQuantile,
*(
op
for op in ALL_OPERATIONS
Expand Down Expand Up @@ -66,6 +66,17 @@ def visit_Correlation(self, op, *, left, right, how, where):
op, left=left, right=right, how=how, where=where
)

def visit_Quantile(self, op, *, arg, quantile, where):
if where is not None:
arg = self.if_(where, arg, NULL)
suffix = "cont" if op.arg.dtype.is_numeric() else "disc"
return sge.WithinGroup(
this=self.f[f"percentile_{suffix}"](quantile),
expression=sge.Order(expressions=[sge.Ordered(this=arg)]),
)

visit_ApproxQuantile = visit_Quantile

def visit_TimestampTruncate(self, op, *, arg, unit):
unit_mapping = {
"Y": "year",
Expand All @@ -87,13 +98,8 @@ def visit_TimestampTruncate(self, op, *, arg, unit):

visit_TimeTruncate = visit_DateTruncate = visit_TimestampTruncate

def visit_IntervalFromInteger(self, op, *, arg, unit):
if op.arg.shape == ds.scalar:
return sge.Interval(this=arg, unit=self.v[unit.name])
elif op.arg.shape == ds.columnar:
return arg * sge.Interval(this=sge.convert(1), unit=self.v[unit.name])
else:
raise ValueError("Invalid shape for converting to interval")
def _make_interval(self, arg, unit):
return arg * sge.Interval(this=sge.convert(1), unit=self.v[unit.name])

def visit_NonNullLiteral(self, op, *, value, dtype):
if dtype.is_binary():
Expand Down
29 changes: 18 additions & 11 deletions ibis/backends/sql/compilers/snowflake.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,8 +65,10 @@ class SnowflakeCompiler(SQLGlotCompiler):
UNSUPPORTED_OPS = (
ops.RowID,
ops.MultiQuantile,
ops.IntervalFromInteger,
ops.IntervalAdd,
ops.IntervalSubtract,
ops.IntervalMultiply,
ops.IntervalFloorDivide,
ops.TimestampDiff,
)

Expand Down Expand Up @@ -266,7 +268,7 @@ def visit_Cast(self, op, *, arg, to):
return self.if_(self.f.is_object(arg), arg, NULL)
elif to.is_array():
return self.if_(self.f.is_array(arg), arg, NULL)
return self.cast(arg, to)
return super().visit_Cast(op, arg=arg, to=to)

def visit_ToJSONMap(self, op, *, arg):
return self.if_(self.f.is_object(arg), arg, NULL)
Expand Down Expand Up @@ -365,14 +367,14 @@ def visit_DateDelta(self, op, *, part, left, right):
def visit_TimestampDelta(self, op, *, part, left, right):
return self.f.timestampdiff(part, right, left, dialect=self.dialect)

def visit_TimestampDateAdd(self, op, *, left, right):
if not isinstance(op.right, ops.Literal):
raise com.OperationNotDefinedError(
f"right side of {type(op).__name__} operation must be an interval literal"
)
return sg.exp.Add(this=left, expression=right)
def visit_TimestampAdd(self, op, *, left, right):
return self.f.timestampadd(right.unit, right.this, left, dialect=self.dialect)

visit_DateAdd = visit_TimestampAdd = visit_TimestampDateAdd
def visit_TimestampSub(self, op, *, left, right):
return self.f.timestampadd(right.unit, -right.this, left, dialect=self.dialect)

visit_DateAdd = visit_TimestampAdd
visit_DateSub = visit_TimestampSub

def visit_IntegerRange(self, op, *, start, stop, step):
return self.if_(
Expand Down Expand Up @@ -606,6 +608,12 @@ def visit_Quantile(self, op, *, arg, quantile, where):
quantile = self.f.percentile_cont(quantile)
return sge.WithinGroup(this=quantile, expression=order_by)

def visit_ApproxQuantile(self, op, *, arg, quantile, where):
if where is not None:
arg = self.if_(where, arg, NULL)

return self.f.approx_percentile(arg, quantile)

def visit_CountStar(self, op, *, arg, where):
if where is None:
return super().visit_CountStar(op, arg=arg, where=where)
Expand Down Expand Up @@ -753,12 +761,11 @@ def visit_Sample(
self, op, *, parent, fraction: float, method: str, seed: int | None, **_
):
sample = sge.TableSample(
this=parent,
method="bernoulli" if method == "row" else "system",
percent=sge.convert(fraction * 100.0),
seed=None if seed is None else sge.convert(seed),
)
return sg.select(STAR).from_(sample)
return self._make_sample_backwards_compatible(sample=sample, parent=parent)

def visit_ArrayMap(self, op, *, arg, param, body):
return self.f.transform(arg, sge.Lambda(this=body, expressions=[param]))
Expand Down
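With the Snowflake changes above, timestamp and date arithmetic with interval literals compiles through `TIMESTAMPADD` rather than a bare `+`/`-`. A small sketch of the kind of expression this covers — the table and column are made up for illustration:

```python
import ibis

t = ibis.table({"ts": "timestamp"}, name="t")  # hypothetical table

# an interval literal on the right-hand side becomes TIMESTAMPADD(unit, n, ts);
# subtraction negates n instead of emitting a separate function
expr = t.ts + ibis.interval(days=3)
print(ibis.to_sql(expr, dialect="snowflake"))
```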
83 changes: 69 additions & 14 deletions ibis/backends/sql/compilers/sqlite.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

import math
import sqlite3

import sqlglot as sg
import sqlglot.expressions as sge
Expand All @@ -19,6 +20,8 @@ class SQLiteCompiler(SQLGlotCompiler):

dialect = SQLite
type_mapper = SQLiteType
supports_time_shift_modifiers = sqlite3.sqlite_version_info >= (3, 46, 0)
supports_subsec = sqlite3.sqlite_version_info >= (3, 42, 0)

# We could set `supports_order_by=True` for SQLite >= 3.44.0 (2023-11-01).
agg = AggGen(supports_filter=True)
Expand All @@ -35,8 +38,6 @@ class SQLiteCompiler(SQLGlotCompiler):
ops.IsInf,
ops.Covariance,
ops.Correlation,
ops.Quantile,
ops.MultiQuantile,
ops.Median,
ops.ApproxMedian,
ops.Array,
Expand All @@ -53,15 +54,11 @@ class SQLiteCompiler(SQLGlotCompiler):
ops.IntervalSubtract,
ops.IntervalMultiply,
ops.IntervalFloorDivide,
ops.IntervalFromInteger,
ops.TimestampBucket,
ops.TimestampAdd,
ops.TimestampSub,
ops.TimestampDiff,
ops.StringToDate,
ops.StringToTimestamp,
ops.TimeDelta,
ops.DateDelta,
ops.TimestampDelta,
ops.TryCast,
)
Expand Down Expand Up @@ -333,18 +330,65 @@ def visit_TimestampTruncate(self, op, *, arg, unit):
return self._temporal_truncate(self.f.anon.datetime, arg, unit)

def visit_DateArithmetic(self, op, *, left, right):
unit = op.right.dtype.unit
sign = "+" if isinstance(op, ops.DateAdd) else "-"
if unit not in (IntervalUnit.YEAR, IntervalUnit.MONTH, IntervalUnit.DAY):
right = right.this

if (unit := op.right.dtype.unit) in (
IntervalUnit.QUARTER,
IntervalUnit.MICROSECOND,
IntervalUnit.NANOSECOND,
):
raise com.UnsupportedOperationError(
"SQLite does not allow binary op {sign!r} with INTERVAL offset {unit}"
f"SQLite does not support `{unit}` units in temporal arithmetic"
)
if isinstance(op.right, ops.Literal):
return self.f.date(left, f"{sign}{op.right.value} {unit.plural}")
elif unit == IntervalUnit.WEEK:
unit = IntervalUnit.DAY
right *= 7
elif unit == IntervalUnit.MILLISECOND:
# sqlite doesn't allow milliseconds, so divide milliseconds by 1e3 to
# get seconds, and change the unit to seconds
unit = IntervalUnit.SECOND
right /= 1e3

# compute whether we're adding or subtracting an interval
sign = "+" if isinstance(op, (ops.DateAdd, ops.TimestampAdd)) else "-"

modifiers = []

# floor the result if the unit is a year, month, or day to match other
# backend behavior
if unit in (IntervalUnit.YEAR, IntervalUnit.MONTH, IntervalUnit.DAY):
if not self.supports_time_shift_modifiers:
raise com.UnsupportedOperationError(
"SQLite does not support time shift modifiers until version 3.46; "
f"found version {sqlite3.sqlite_version}"
)
modifiers.append("floor")

if isinstance(op, (ops.TimestampAdd, ops.TimestampSub)):
# if the left operand is a timestamp, return as much precision as
# possible
if not self.supports_subsec:
raise com.UnsupportedOperationError(
"SQLite does not support subsecond resolution until version 3.42; "
f"found version {sqlite3.sqlite_version}"
)
func = self.f.datetime
modifiers.append("subsec")
else:
return self.f.date(left, self.f.concat(sign, right, f" {unit.plural}"))
func = self.f.date

return func(
left,
self.f.concat(
sign, self.cast(right, dt.string), " ", unit.singular.lower()
),
*modifiers,
dialect=self.dialect,
)

visit_DateAdd = visit_DateSub = visit_DateArithmetic
visit_TimestampAdd = visit_TimestampSub = visit_DateAdd = visit_DateSub = (
visit_DateArithmetic
)

def visit_DateDiff(self, op, *, left, right):
return self.f.julianday(left) - self.f.julianday(right)
Expand Down Expand Up @@ -486,5 +530,16 @@ def visit_NonNullLiteral(self, op, *, value, dtype):
raise com.UnsupportedBackendType(f"Unsupported type: {dtype!r}")
return super().visit_NonNullLiteral(op, value=value, dtype=dtype)

def visit_DateDelta(self, op, *, left, right, part):
if not isinstance(part, sge.Literal):
raise com.UnsupportedOperationError(
"Only literal `part` values are supported for date delta"
)
if part.this != "day":
raise com.UnsupportedOperationError(
f"Only 'day' part is supported for date delta in the {self.dialect} backend"
)
return self.f._ibis_date_delta(left, right)


compiler = SQLiteCompiler()
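The reworked SQLite date arithmetic above converts weeks to days and milliseconds to seconds, floors year/month/day shifts, and keeps subsecond precision for timestamps when the runtime SQLite is new enough. A rough sketch of the expressions it targets — the table, the columns, and the exact SQL text are illustrative assumptions, and the version-gated modifiers may raise on older SQLite builds:

```python
import ibis

t = ibis.table({"d": "date", "ts": "timestamp"}, name="t")  # hypothetical table

# date + 3 days should render roughly as DATE(d, '+' || '3' || ' day', 'floor')
date_expr = t.d + ibis.interval(days=3)

# timestamp + 90 minutes keeps subsecond precision via DATETIME(..., 'subsec')
ts_expr = t.ts + ibis.interval(minutes=90)

print(ibis.to_sql(date_expr, dialect="sqlite"))
print(ibis.to_sql(ts_expr, dialect="sqlite"))
```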
53 changes: 31 additions & 22 deletions ibis/backends/sql/compilers/trino.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from ibis.backends.sql.dialects import Trino
from ibis.backends.sql.rewrites import (
exclude_unsupported_window_frame_from_ops,
split_select_distinct_with_order_by,
)
from ibis.util import gen_name

Expand All @@ -39,15 +40,14 @@ class TrinoCompiler(SQLGlotCompiler):
exclude_unsupported_window_frame_from_ops,
*SQLGlotCompiler.rewrites,
)
post_rewrites = (split_select_distinct_with_order_by,)
quoted = True

NAN = sg.func("nan")
POS_INF = sg.func("infinity")
NEG_INF = -POS_INF

UNSUPPORTED_OPS = (
ops.Quantile,
ops.MultiQuantile,
ops.Median,
ops.RowID,
ops.TimestampBucket,
Expand Down Expand Up @@ -110,17 +110,16 @@ def _minimize_spec(start, end, spec):
def visit_Sample(
self, op, *, parent, fraction: float, method: str, seed: int | None, **_
):
if op.seed is not None:
if seed is not None:
raise com.UnsupportedOperationError(
"`Table.sample` with a random seed is unsupported"
)
sample = sge.TableSample(
this=parent,
method="bernoulli" if method == "row" else "system",
percent=sge.convert(fraction * 100.0),
seed=None if seed is None else sge.convert(seed),
)
return sg.select(STAR).from_(sample)
return self._make_sample_backwards_compatible(sample=sample, parent=parent)

def visit_Correlation(self, op, *, left, right, how, where):
if how == "sample":
Expand All @@ -135,6 +134,13 @@ def visit_Correlation(self, op, *, left, right, how, where):

return self.agg.corr(left, right, where=where)

def visit_ApproxQuantile(self, op, *, arg, quantile, where):
if not op.arg.dtype.is_floating():
arg = self.cast(arg, dt.float64)
return self.agg.approx_quantile(arg, quantile, where=where)

visit_ApproxMultiQuantile = visit_ApproxQuantile

def visit_BitXor(self, op, *, arg, where):
a, b = map(sg.to_identifier, "ab")
input_fn = combine_fn = sge.Lambda(
Expand Down Expand Up @@ -327,9 +333,7 @@ def visit_NonNullLiteral(self, op, *, value, dtype):
elif dtype.is_time():
return self.cast(value.isoformat(), dtype)
elif dtype.is_interval():
return sge.Interval(
this=sge.convert(str(value)), unit=self.v[dtype.resolution.upper()]
)
return self._make_interval(sge.convert(str(value)), dtype.unit)
elif dtype.is_binary():
return self.f.from_hex(value.hex())
else:
Expand Down Expand Up @@ -442,15 +446,26 @@ def visit_TemporalDelta(self, op, *, part, left, right):

visit_TimeDelta = visit_DateDelta = visit_TimestampDelta = visit_TemporalDelta

def visit_IntervalFromInteger(self, op, *, arg, unit):
unit = op.unit.short
if unit in ("Y", "Q", "M", "W"):
def _make_interval(self, arg, unit):
short = unit.short
if short in ("Q", "W"):
raise com.UnsupportedOperationError(f"Interval unit {unit!r} not supported")
return self.f.parse_duration(
self.f.concat(
self.cast(arg, dt.String(nullable=op.arg.dtype.nullable)), unit.lower()

if isinstance(arg, sge.Literal):
# force strings in interval literals because trino requires it
arg.args["is_string"] = True
return super()._make_interval(arg, unit)

elif short in ("Y", "M"):
return arg * super()._make_interval(sge.convert("1"), unit)
elif short in ("D", "h", "m", "s", "ms", "us"):
return self.f.parse_duration(
self.f.concat(self.cast(arg, dt.string), short.lower())
)
else:
raise com.UnsupportedOperationError(
f"Interval unit {unit.name!r} not supported"
)
)

def visit_Range(self, op, *, start, stop, step):
def zero_value(dtype):
Expand Down Expand Up @@ -492,13 +507,7 @@ def visit_ArrayIndex(self, op, *, arg, index):

def visit_Cast(self, op, *, arg, to):
from_ = op.arg.dtype
if from_.is_integer() and to.is_interval():
return self.visit_IntervalFromInteger(
ops.IntervalFromInteger(op.arg, unit=to.unit),
arg=arg,
unit=to.unit,
)
elif from_.is_integer() and to.is_timestamp():
if from_.is_integer() and to.is_timestamp():
return self.f.from_unixtime(arg, to.timezone or "UTC")
return super().visit_Cast(op, arg=arg, to=to)

Expand Down
10 changes: 9 additions & 1 deletion ibis/backends/sql/datatypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
typecode.BOOLEAN: dt.Boolean,
typecode.CHAR: dt.String,
typecode.DATE: dt.Date,
typecode.DATETIME: dt.Timestamp,
typecode.DATE32: dt.Date,
typecode.DOUBLE: dt.Float64,
typecode.ENUM: dt.String,
Expand Down Expand Up @@ -169,8 +170,10 @@ def to_ibis(cls, typ: sge.DataType, nullable: bool | None = None) -> dt.DataType

if method := getattr(cls, f"_from_sqlglot_{typecode.name}", None):
dtype = method(*typ.expressions)
elif (known_typ := _from_sqlglot_types.get(typecode)) is not None:
dtype = known_typ(nullable=cls.default_nullable)
else:
dtype = _from_sqlglot_types[typecode](nullable=cls.default_nullable)
dtype = dt.unknown

if nullable is not None:
return dtype.copy(nullable=nullable)
Expand Down Expand Up @@ -1055,7 +1058,12 @@ class ClickHouseType(SqlglotType):
@classmethod
def from_ibis(cls, dtype: dt.DataType) -> sge.DataType:
typ = super().from_ibis(dtype)

if typ.this == typecode.NULLABLE:
return typ

# nested types cannot be nullable in clickhouse
typ.args["nullable"] = False
if dtype.nullable and not (
dtype.is_map() or dtype.is_array() or dtype.is_struct()
):
Expand Down
37 changes: 21 additions & 16 deletions ibis/backends/sql/dialects.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
from sqlglot import transforms
from sqlglot.dialects import (
TSQL,
ClickHouse,
Hive,
MySQL,
Oracle,
Expand All @@ -19,15 +18,27 @@
SQLite,
Trino,
)
from sqlglot.dialects import ClickHouse as _ClickHouse
from sqlglot.dialects.dialect import rename_func
from sqlglot.helper import find_new_name, seq_get

ClickHouse.Generator.TRANSFORMS |= {
sge.ArraySize: rename_func("length"),
sge.ArraySort: rename_func("arraySort"),
sge.LogicalAnd: rename_func("min"),
sge.LogicalOr: rename_func("max"),
}

class ClickHouse(_ClickHouse):
class Generator(_ClickHouse.Generator):
_ClickHouse.Generator.TRANSFORMS |= {
sge.ArraySize: rename_func("length"),
sge.ArraySort: rename_func("arraySort"),
sge.LogicalAnd: rename_func("min"),
sge.LogicalOr: rename_func("max"),
}

def except_op(self, expression: sge.Except) -> str:
return f"EXCEPT{' DISTINCT' if expression.args.get('distinct') else ' ALL'}"

def intersect_op(self, expression: sge.Intersect) -> str:
return (
f"INTERSECT{' DISTINCT' if expression.args.get('distinct') else ' ALL'}"
)


class DataFusion(Postgres):
Expand Down Expand Up @@ -76,6 +87,7 @@ class Generator(Postgres.Generator):
TRANSFORMS = Postgres.Generator.TRANSFORMS.copy() | {
sge.Interval: _interval,
sge.GroupConcat: _group_concat,
sge.ApproxDistinct: rename_func("approximate_count_distinct"),
}
TYPE_MAPPING = Postgres.Generator.TYPE_MAPPING.copy() | {
sge.DataType.Type.TIMESTAMPTZ: "TIMESTAMP WITH LOCAL TIME ZONE",
Expand Down Expand Up @@ -211,6 +223,8 @@ class Generator(Hive.Generator):
sge.VariancePop: rename_func("var_pop"),
sge.ArrayConcat: rename_func("array_concat"),
sge.ArraySize: rename_func("cardinality"),
sge.ArrayAgg: rename_func("array_agg"),
sge.ArraySort: rename_func("array_sort"),
sge.Length: rename_func("char_length"),
sge.TryCast: lambda self,
e: f"TRY_CAST({e.this.sql(self.dialect)} AS {e.to.sql(self.dialect)})",
Expand Down Expand Up @@ -443,18 +457,9 @@ class Generator(Postgres.Generator):
SQLite.Generator.TYPE_MAPPING |= {sge.DataType.Type.BOOLEAN: "BOOLEAN"}


# TODO(cpcloud): remove this hack once
# https://github.com/tobymao/sqlglot/issues/2735 is resolved
def make_cross_joins_explicit(node):
if not (node.kind or node.side):
node.args["kind"] = "CROSS"
return node


Trino.Generator.TRANSFORMS |= {
sge.BitwiseLeftShift: rename_func("bitwise_left_shift"),
sge.BitwiseRightShift: rename_func("bitwise_right_shift"),
sge.FirstValue: rename_func("first_value"),
sge.Join: transforms.preprocess([make_cross_joins_explicit]),
sge.LastValue: rename_func("last_value"),
}
99 changes: 98 additions & 1 deletion ibis/backends/sql/rewrites.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ class Select(ops.Relation):
predicates: VarTuple[ops.Value[dt.Boolean]] = ()
qualified: VarTuple[ops.Value[dt.Boolean]] = ()
sort_keys: VarTuple[ops.SortKey] = ()
distinct: bool = False

def is_star_selection(self):
return tuple(self.values.items()) == tuple(self.parent.fields.items())
Expand Down Expand Up @@ -128,6 +129,12 @@ def sort_to_select(_, **kwargs):
return Select(_.parent, selections=_.values, sort_keys=_.keys)


@replace(p.Distinct)
def distinct_to_select(_, **kwargs):
"""Convert a Distinct node to a Select node."""
return Select(_.parent, selections=_.values, distinct=True)


@replace(p.DropColumns)
def drop_columns_to_select(_, **kwargs):
"""Convert a DropColumns node to a Select node."""
Expand Down Expand Up @@ -244,6 +251,48 @@ def merge_select_select(_, **kwargs):
if _.parent.find_below(blocking, filter=ops.Value):
return _

if _.parent.distinct:
# The inner query is distinct.
#
# If the outer query is distinct, it's only safe to merge if it's a simple subselection:
# - Fusing in the presence of non-deterministic calls in the select would lead to
# incorrect results
# - Fusing in the presence of expensive calls in the select would lead to potential
# performance pitfalls
if _.distinct and not all(
isinstance(v, ops.Field) for v in _.selections.values()
):
return _

# If the outer query isn't distinct, it's only safe to merge if the outer is a SELECT *:
# - If new columns are added, they might be non-distinct, changing the distinctness
# - If previous columns are removed, that would also change the distinctness
if not _.distinct and not _.is_star_selection():
return _

distinct = True
elif _.distinct:
# The outer query is distinct and the inner isn't. It's only safe to merge if either
# - The inner query isn't ordered
# - The outer query is a SELECT *
#
# Otherwise we run the risk that the outer query drops columns needed for the ordering of
# the inner query - many backends don't allow select distinct queries to order by columns
# that aren't present in their selection, like
#
# SELECT DISTINCT a, b FROM t ORDER BY c --- some backends will explode at this
#
# An alternate solution would be to drop the inner ORDER BY clause, since the backend will
# ignore it anyway once it's wrapped in a subquery. That feels potentially risky though; better
# to generate the SQL as written.
if _.parent.sort_keys and not _.is_star_selection():
return _

distinct = True
else:
# Neither query is distinct, safe to merge
distinct = False

subs = {ops.Field(_.parent, k): v for k, v in _.parent.values.items()}
selections = {k: v.replace(subs, filter=ops.Value) for k, v in _.selections.items()}

Expand All @@ -266,6 +315,7 @@ def merge_select_select(_, **kwargs):
predicates=unique_predicates,
qualified=unique_qualified,
sort_keys=unique_sort_keys,
distinct=distinct,
)
return result if complexity(result) <= complexity(_) else _
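A hedged illustration of the distinct-fusion rules commented above — the table and columns are invented for this sketch, and the exact generated SQL may differ by backend:

```python
import ibis

t = ibis.table({"a": "int64", "b": "string"}, name="t")  # hypothetical table

# inner DISTINCT plus an outer select that adds a derived column: per the rules
# above this should stay as two SELECT levels, since fusing could change which
# rows are distinct
unsafe_to_fuse = t.select("a", "b").distinct().mutate(c=ibis._.a + 1)

# inner DISTINCT plus an outer SELECT *-style projection: safe to fuse into a
# single SELECT DISTINCT
safe_to_fuse = t.select("a", "b").distinct().select("a", "b")

print(ibis.to_sql(unsafe_to_fuse, dialect="duckdb"))
print(ibis.to_sql(safe_to_fuse, dialect="duckdb"))
```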

Expand All @@ -289,6 +339,7 @@ def sqlize(
node: ops.Node,
params: Mapping[ops.ScalarParameter, Any],
rewrites: Sequence[Pattern] = (),
post_rewrites: Sequence[Pattern] = (),
fuse_selects: bool = True,
) -> tuple[ops.Node, list[ops.Node]]:
"""Lower the ibis expression graph to a SQL-like relational algebra.
Expand All @@ -300,7 +351,9 @@ def sqlize(
params
A mapping of scalar parameters to their values.
rewrites
Supplementary rewrites to apply to the expression graph.
Supplementary rewrites to apply before SQL-specific transforms.
post_rewrites
Supplementary rewrites to apply after SQL-specific transforms.
fuse_selects
Whether to merge subsequent Select nodes into one where possible.
Expand All @@ -322,6 +375,7 @@ def sqlize(
| project_to_select
| filter_to_select
| sort_to_select
| distinct_to_select
| fill_null_to_select
| drop_null_to_select
| drop_columns_to_select
Expand All @@ -335,6 +389,9 @@ def sqlize(
else:
simplified = sqlized

if post_rewrites:
simplified = simplified.replace(reduce(operator.or_, post_rewrites))

# extract common table expressions while wrapping them in a CTE node
ctes = extract_ctes(simplified)

Expand All @@ -351,6 +408,46 @@ def wrap(node, _, **kwargs):
# supplemental rewrites selectively used on a per-backend basis


@replace(Select)
def split_select_distinct_with_order_by(_):
"""Split a `SELECT DISTINCT ... ORDER BY` query when needed.
Some databases (postgres, pyspark, ...) have issues with two types of
ordered select distinct statements:
```
--- ORDER BY with an expression instead of a name in the select list
SELECT DISTINCT a, b FROM t ORDER BY a + 1
--- ORDER BY using a qualified column name, rather than the alias in the select list
SELECT DISTINCT a, b as x FROM t ORDER BY b --- or t.b
```
We solve both these cases by splitting everything except the `ORDER BY`
into a subquery.
```
SELECT DISTINCT a, b FROM t WHERE a > 10 ORDER BY a + 1
--- is rewritten as ->
SELECT * FROM (SELECT DISTINCT a, b FROM t WHERE a > 10) ORDER BY a + 1
```
"""
# risingwave and pyspark also don't allow qualified names as sort keys, like
# SELECT DISTINCT t.a FROM t ORDER BY t.a
# To avoid backend-specific rewrite rules that force sort keys to use only
# local names, we always split SELECT DISTINCT from ORDER BY here. Otherwise we
# could skip the split whenever all sort keys already appear in the select list.
if _.distinct and _.sort_keys:
inner = _.copy(sort_keys=())
subs = {v: ops.Field(inner, k) for k, v in inner.values.items()}
sort_keys = tuple(s.replace(subs, filter=ops.Value) for s in _.sort_keys)
selections = {
k: v.replace(subs, filter=ops.Value) for k, v in _.selections.items()
}
return Select(inner, selections=selections, sort_keys=sort_keys)
return _


@replace(p.WindowFunction(func=p.NTile(y), order_by=()))
def add_order_by_to_empty_ranking_window_functions(_, **kwargs):
"""Add an ORDER BY clause to rank window functions that don't have one."""
Expand Down
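The `split_select_distinct_with_order_by` post-rewrite documented above can be exercised with an ordered distinct projection. A rough sketch — the table, the columns, and the choice of the Trino dialect (one of the backends opting into the post-rewrite in this diff) are assumptions for illustration:

```python
import ibis

t = ibis.table({"a": "int64", "b": "string"}, name="t")  # hypothetical table

# ordering a DISTINCT projection by an expression that isn't in the select
# list; the post-rewrite should wrap the DISTINCT part in a subquery and
# apply ORDER BY on the outside
expr = t.select("a", "b").distinct().order_by(ibis._.a + 1)
print(ibis.to_sql(expr, dialect="trino"))
```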
10 changes: 10 additions & 0 deletions ibis/backends/sql/tests/test_compiler.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
from __future__ import annotations

import sqlglot as sg

import ibis
from ibis import _
from ibis.backends.sql.dialects import Trino


def test_window_with_row_number_compiles():
Expand All @@ -16,3 +19,10 @@ def test_window_with_row_number_compiles():
.filter(~_.is_test)
)
assert ibis.to_sql(expr)


def test_transpile_join():
(result,) = sg.transpile(
"SELECT * FROM t1 JOIN t2 ON x = y", read="duckdb", write=Trino
)
assert "CROSS JOIN" not in result
24 changes: 23 additions & 1 deletion ibis/backends/sql/tests/test_datatypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,12 @@
import ibis.common.exceptions as com
import ibis.expr.datatypes as dt
import ibis.tests.strategies as its
from ibis.backends.sql.datatypes import DuckDBType, PostgresType, SqlglotType
from ibis.backends.sql.datatypes import (
ClickHouseType,
DuckDBType,
PostgresType,
SqlglotType,
)


def assert_dtype_roundtrip(ibis_type, sqlglot_expected=None):
Expand Down Expand Up @@ -63,3 +68,20 @@ def test_interval_without_unit():
SqlglotType.from_string("INTERVAL")
assert PostgresType.from_string("INTERVAL") == dt.Interval("s")
assert DuckDBType.from_string("INTERVAL") == dt.Interval("us")


@pytest.mark.parametrize(
"typ",
[
sge.DataType.Type.UINT256,
sge.DataType.Type.UINT128,
sge.DataType.Type.BIGSERIAL,
sge.DataType.Type.HLLSKETCH,
],
)
@pytest.mark.parametrize(
"typengine",
[ClickHouseType, PostgresType, DuckDBType],
)
def test_unsupported_dtypes_are_unknown(typengine, typ):
assert typengine.to_ibis(sge.DataType(this=typ)) == dt.unknown
2 changes: 1 addition & 1 deletion ibis/backends/sqlite/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -436,7 +436,7 @@ def create_table(
| pl.LazyFrame
| None = None,
*,
schema: ibis.Schema | None = None,
schema: sch.SchemaLike | None = None,
database: str | None = None,
temp: bool = False,
overwrite: bool = False,
Expand Down
25 changes: 0 additions & 25 deletions ibis/backends/sqlite/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
import contextlib
import csv
import io
import sqlite3
from typing import Any

import pytest
Expand Down Expand Up @@ -55,30 +54,6 @@ def functional_alltypes(self) -> ir.Table:
return t.mutate(timestamp_col=t.timestamp_col.cast("timestamp"))


@pytest.fixture
def dbpath(tmp_path):
path = tmp_path / "test.db"
con = sqlite3.connect(path)
con.execute("CREATE TABLE t AS SELECT 1 a UNION SELECT 2 UNION SELECT 3")
con.execute("CREATE TABLE s AS SELECT 1 b UNION SELECT 2")
return path


@pytest.fixture(scope="session")
def con(data_dir, tmp_path_factory, worker_id):
return TestConf.load_data(data_dir, tmp_path_factory, worker_id).connection


@pytest.fixture(scope="session")
def translate(dialect):
return lambda expr: ibis.to_sql(expr, dialect="sqlite")


@pytest.fixture(scope="session")
def alltypes(con):
return con.table("functional_alltypes")


@pytest.fixture(scope="session")
def df(alltypes):
return alltypes.execute()
7 changes: 7 additions & 0 deletions ibis/backends/sqlite/udf.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import math
import operator
from collections import defaultdict
from datetime import date
from typing import TYPE_CHECKING, Any, NamedTuple
from urllib.parse import parse_qs, urlsplit
from uuid import uuid4
Expand Down Expand Up @@ -357,6 +358,12 @@ def _ibis_extract_user_info(url):
return f"{username}:{password}"


@udf
def _ibis_date_delta(left, right):
delta = date.fromisoformat(left) - date.fromisoformat(right)
return delta.days


class _ibis_var:
def __init__(self, offset):
self.mean = 0.0
Expand Down
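The `_ibis_date_delta` UDF registered above backs the new SQLite `visit_DateDelta` for the `'day'` part. A minimal standalone re-statement of what it computes, assuming ISO-formatted date strings as SQLite stores them:

```python
from datetime import date

def date_delta_days(left: str, right: str) -> int:
    # whole days between two ISO-formatted date strings, signed like left - right
    return (date.fromisoformat(left) - date.fromisoformat(right)).days

assert date_delta_days("2024-03-10", "2024-03-01") == 9
assert date_delta_days("2024-03-01", "2024-03-10") == -9
```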
5 changes: 1 addition & 4 deletions ibis/backends/tests/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from filelock import FileLock

if TYPE_CHECKING:
from collections.abc import Iterable, Iterator, Mapping
from collections.abc import Iterable, Iterator

import ibis.expr.types as ir

Expand Down Expand Up @@ -307,9 +307,6 @@ def win(self) -> ir.Table | None:
def api(self):
return self.connection

def make_context(self, params: Mapping[ir.Value, Any] | None = None):
return self.api.compiler.make_context(params=params)

def _tpc_table(self, name: str, benchmark: Literal["h", "ds"]):
if not getattr(self, f"supports_tpc{benchmark}"):
pytest.skip(
Expand Down
6 changes: 4 additions & 2 deletions ibis/backends/tests/errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,10 +65,11 @@
from polars.exceptions import InvalidOperationError as PolarsInvalidOperationError
from polars.exceptions import PanicException as PolarsPanicException
from polars.exceptions import SchemaError as PolarsSchemaError
from polars.exceptions import SQLInterfaceError as PolarsSQLInterfaceError
except ImportError:
PolarsComputeError = PolarsPanicException = PolarsInvalidOperationError = (
PolarsSchemaError
) = PolarsColumnNotFoundError = None
) = PolarsColumnNotFoundError = PolarsSQLInterfaceError = None

try:
from pyarrow import ArrowInvalid, ArrowNotImplementedError
Expand Down Expand Up @@ -133,8 +134,9 @@

try:
from oracledb.exceptions import DatabaseError as OracleDatabaseError
from oracledb.exceptions import InterfaceError as OracleInterfaceError
except ImportError:
OracleDatabaseError = None
OracleDatabaseError = OracleInterfaceError = None

try:
from pyodbc import DataError as PyODBCDataError
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
SELECT
ntile(2) OVER (ORDER BY RAND() ASC) - 1 AS `new_col`
FROM `test` AS `t0`
LIMIT 10
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
SELECT
ntile(2) OVER (ORDER BY randCanonical() ASC ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) - 1 AS "new_col"
FROM "test" AS "t0"
LIMIT 10
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
SELECT
NTILE(2) OVER (ORDER BY RANDOM() ASC ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) - 1 AS "new_col"
FROM "test" AS "t0"
LIMIT 10
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
SELECT
NTILE(2) OVER (ORDER BY RANDOM() ASC ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) - 1 AS "new_col"
FROM (
SELECT
"test"
FROM (VALUES
(1),
(2),
(3),
(4),
(5)) AS "test"("test")
) AS "t0"
LIMIT 10
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
SELECT
NTILE(2) OVER (ORDER BY RANDOM() ASC ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) - 1 AS "new_col"
FROM "test" AS "t0"
LIMIT 10
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
SELECT
NTILE(2) OVER (ORDER BY RANDOM() ASC) - 1 AS "new_col"
FROM "test" AS "t0"
LIMIT 10
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
SELECT
NTILE(2) OVER (ORDER BY RAND() ASC NULLS LAST) - 1 AS `new_col`
FROM (
SELECT
`test`
FROM (VALUES
(1),
(2),
(3),
(4),
(5)) AS `test`(`test`)
) AS `t0`
LIMIT 10
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
SELECT
NTILE(2) OVER (ORDER BY RAND(UTC_TO_UNIX_MICROS(UTC_TIMESTAMP())) ASC) - 1 AS `new_col`
FROM `test` AS `t0`
LIMIT 10
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
SELECT
TOP 10
NTILE(2) OVER (ORDER BY RAND() ASC) - 1 AS [new_col]
FROM [test] AS [t0]
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
SELECT
NTILE(2) OVER (ORDER BY RAND() ASC) - 1 AS `new_col`
FROM `test` AS `t0`
LIMIT 10
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
SELECT
NTILE(2) OVER (ORDER BY DBMS_RANDOM.VALUE() ASC) - 1 AS "new_col"
FROM "test" "t0"
FETCH FIRST 10 ROWS ONLY
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
SELECT
NTILE(2) OVER (ORDER BY RANDOM() ASC ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) - 1 AS "new_col"
FROM "test" AS "t0"
LIMIT 10
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
SELECT
NTILE(2) OVER (ORDER BY RAND() ASC NULLS LAST) - 1 AS `new_col`
FROM `test` AS `t0`
LIMIT 10
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
SELECT
NTILE(2) OVER (ORDER BY RANDOM() ASC ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) - 1 AS "new_col"
FROM "test" AS "t0"
LIMIT 10
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
SELECT
NTILE(2) OVER (ORDER BY UNIFORM(TO_DOUBLE(0.0), TO_DOUBLE(1.0), RANDOM()) ASC) - 1 AS "new_col"
FROM "test" AS "t0"
LIMIT 10
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
SELECT
NTILE(2) OVER (ORDER BY 0.5 + (
CAST(RANDOM() AS REAL) / -1.8446744073709552e+19
) ASC NULLS LAST ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) - 1 AS "new_col"
FROM "test" AS "t0"
LIMIT 10
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
SELECT
NTILE(2) OVER (ORDER BY RAND() ASC ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) - 1 AS "new_col"
FROM "test" AS "t0"
LIMIT 10
5 changes: 0 additions & 5 deletions ibis/backends/tests/sql/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,11 +91,6 @@ def t2(con):
return con.table("t2")


@pytest.fixture(scope="module")
def where_uncorrelated_subquery(foo, bar):
return foo[foo.job.isin(bar.job)]


@pytest.fixture(scope="module")
def not_exists(foo_t, bar_t):
return foo_t[-(foo_t.key1 == bar_t.key1).any()]
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,3 @@
SELECT DISTINCT
*
FROM (
SELECT
"t0"."string_col"
FROM "functional_alltypes" AS "t0"
) AS "t1"
"t0"."string_col"
FROM "functional_alltypes" AS "t0"
Original file line number Diff line number Diff line change
Expand Up @@ -3,4 +3,4 @@

int_col_table = ibis.table(name="int_col_table", schema={"int_col": "int32"})

result = (int_col_table.int_col + 4).name("foo")
result = ((int_col_table.int_col + 4)).name("foo")
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,6 @@
"month": "int32",
},
)
f = functional_alltypes.filter(functional_alltypes.bigint_col > 0)
f = functional_alltypes.filter((functional_alltypes.bigint_col > 0))

result = f.aggregate([f.int_col.nunique().name("nunique")], by=[f.string_col])
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,8 @@
},
)
lit = ibis.literal(0)
f = functional_alltypes.filter(functional_alltypes.int_col > lit)
f1 = functional_alltypes.filter(functional_alltypes.int_col <= lit)
f = functional_alltypes.filter((functional_alltypes.int_col > lit))
f1 = functional_alltypes.filter((functional_alltypes.int_col <= lit))
difference = f.select(
f.string_col.name("key"), f.float_col.cast("float64").name("value")
).difference(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@


t = ibis.table(name="t", schema={"a": "int64", "b": "string"})
f = t.filter(t.b == "m")
f = t.filter((t.b == "m"))
agg = f.aggregate([f.a.sum().name("sum"), f.a.max()], by=[f.b])
f1 = agg.filter(agg["Max(a)"] == 2)
f1 = agg.filter((agg["Max(a)"] == 2))

result = f1.select(f1.b, f1.sum)
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,8 @@
},
)
lit = ibis.literal(0)
f = functional_alltypes.filter(functional_alltypes.int_col > lit)
f1 = functional_alltypes.filter(functional_alltypes.int_col <= lit)
f = functional_alltypes.filter((functional_alltypes.int_col > lit))
f1 = functional_alltypes.filter((functional_alltypes.int_col <= lit))
intersection = f.select(
f.string_col.name("key"), f.float_col.cast("float64").name("value")
).intersect(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
},
)
param = ibis.param("timestamp")
f = alltypes.filter(alltypes.timestamp_col < param.name("my_param"))
f = alltypes.filter((alltypes.timestamp_col < param.name("my_param")))
agg = f.aggregate([f.float_col.sum().name("foo")], by=[f.string_col])

result = agg.foo.count()
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,8 @@
},
)
lit = ibis.literal(0)
f = functional_alltypes.filter(functional_alltypes.int_col > lit)
f1 = functional_alltypes.filter(functional_alltypes.int_col <= lit)
f = functional_alltypes.filter((functional_alltypes.int_col > lit))
f1 = functional_alltypes.filter((functional_alltypes.int_col <= lit))

result = f.select(
f.string_col.name("key"), f.float_col.cast("float64").name("value")
Expand Down
Original file line number Diff line number Diff line change
@@ -1,8 +1,4 @@
SELECT DISTINCT
*
FROM (
SELECT
"t0"."string_col",
"t0"."int_col"
FROM "functional_alltypes" AS "t0"
) AS "t1"
"t0"."string_col",
"t0"."int_col"
FROM "functional_alltypes" AS "t0"
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,18 @@
lit = ibis.timestamp("2018-01-01 00:00:00")
s = ibis.table(name="s", schema={"b": "string"})
t = ibis.table(name="t", schema={"a": "int64", "b": "string", "c": "timestamp"})
f = t.filter(t.c == lit)
f = t.filter((t.c == lit))
dropcolumns = f.select(f.a, f.b, f.c.name("C")).drop("C")
joinchain = (
dropcolumns.select(dropcolumns.a, dropcolumns.b, lit.name("the_date"))
.inner_join(
s,
dropcolumns.select(dropcolumns.a, dropcolumns.b, lit.name("the_date")).b == s.b,
(
dropcolumns.select(dropcolumns.a, dropcolumns.b, lit.name("the_date")).b
== s.b
),
)
.select(dropcolumns.select(dropcolumns.a, dropcolumns.b, lit.name("the_date")).a)
)

result = joinchain.filter(joinchain.a < 1.0)
result = joinchain.filter((joinchain.a < 1.0))
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,8 @@
},
)
lit = ibis.literal(0)
f = functional_alltypes.filter(functional_alltypes.int_col > lit)
f1 = functional_alltypes.filter(functional_alltypes.int_col <= lit)
f = functional_alltypes.filter((functional_alltypes.int_col > lit))
f1 = functional_alltypes.filter((functional_alltypes.int_col <= lit))

result = f.select(
f.string_col.name("key"), f.float_col.cast("float64").name("value")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,8 @@
},
)
lit = ibis.literal(0)
f = functional_alltypes.filter(functional_alltypes.int_col > lit)
f1 = functional_alltypes.filter(functional_alltypes.int_col <= lit)
f = functional_alltypes.filter((functional_alltypes.int_col > lit))
f1 = functional_alltypes.filter((functional_alltypes.int_col <= lit))

result = f.select(
f.string_col.name("key"), f.float_col.cast("float64").name("value")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,8 @@
},
)
lit = ibis.literal(0)
f = functional_alltypes.filter(functional_alltypes.int_col > lit)
f1 = functional_alltypes.filter(functional_alltypes.int_col <= lit)
f = functional_alltypes.filter((functional_alltypes.int_col > lit))
f1 = functional_alltypes.filter((functional_alltypes.int_col <= lit))
union = f.select(
f.string_col.name("key"), f.float_col.cast("float64").name("value")
).union(f1.select(f1.string_col.name("key"), f1.double_col.name("value")))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

result = (
tpch_region.inner_join(
tpch_nation, tpch_region.r_regionkey == tpch_nation.n_regionkey
tpch_nation, (tpch_region.r_regionkey == tpch_nation.n_regionkey)
)
.select(
tpch_nation.n_nationkey,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,6 @@
name="star2", schema={"foo_id": "string", "value1": "float64", "value3": "float64"}
)

result = star1.anti_join(star2, star1.foo_id == star2.foo_id).select(
result = star1.anti_join(star2, (star1.foo_id == star2.foo_id)).select(
star1.c, star1.f, star1.foo_id, star1.bar_id
)