64 changes: 60 additions & 4 deletions ibis/backends/pandas/executor.py
@@ -30,7 +30,11 @@
plan,
)
from ibis.common.dispatch import Dispatched
from ibis.common.exceptions import OperationNotDefinedError, UnboundExpressionError
from ibis.common.exceptions import (
OperationNotDefinedError,
UnboundExpressionError,
UnsupportedOperationError,
)
from ibis.formats.pandas import PandasData, PandasType
from ibis.util import any_of, gen_name

@@ -185,7 +189,7 @@ def visit(cls, op: ops.TimestampTruncate | ops.DateTruncate, arg, unit):

unit = units.get(unit.short, unit.short)

if unit in "YMWD":
if unit in "YQMWD":
return arg.dt.to_period(unit).dt.to_timestamp()
try:
return arg.dt.floor(unit)
@@ -253,7 +257,12 @@ def visit(
############################# Reductions ##################################

@classmethod
def visit(cls, op: ops.Reduction, arg, where):
def visit(cls, op: ops.Reduction, arg, where, order_by=()):
if order_by:
raise UnsupportedOperationError(
"ordering of order-sensitive aggregations via `order_by` is "
"not supported for this backend"
)
func = cls.kernels.reductions[type(op)]
return cls.agg(func, arg, where)
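
For context, a minimal sketch of what this guard means for users of the pandas backend, assuming the user-facing `order_by` keyword on `collect` mirrors the operation field handled here:

import pandas as pd
import ibis

con = ibis.pandas.connect({"t": pd.DataFrame({"x": [1, 2], "ts": [2, 1]})})
t = con.table("t")

t.x.collect().execute()                # works: plain list aggregation
t.x.collect(order_by=t.ts).execute()   # raises UnsupportedOperationError via the guard above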

@@ -310,6 +319,47 @@ def visit(cls, op: ops.StandardDev, arg, where, how):
ddof = {"pop": 0, "sample": 1}[how]
return cls.agg(lambda x: x.std(ddof=ddof), arg, where)

@classmethod
def visit(cls, op: ops.ArrayCollect, arg, where, order_by, include_null):
if order_by:
raise UnsupportedOperationError(
"ordering of order-sensitive aggregations via `order_by` is "
"not supported for this backend"
)
return cls.agg(
(lambda x: x.tolist() if include_null else x.dropna().tolist()), arg, where
)

@classmethod
def visit(cls, op: ops.First, arg, where, order_by, include_null):
if order_by:
raise UnsupportedOperationError(
"ordering of order-sensitive aggregations via `order_by` is "
"not supported for this backend"
)

def first(arg):
if not include_null:
arg = arg.dropna()
return arg.iat[0] if len(arg) else None

return cls.agg(first, arg, where)

@classmethod
def visit(cls, op: ops.Last, arg, where, order_by, include_null):
if order_by:
raise UnsupportedOperationError(
"ordering of order-sensitive aggregations via `order_by` is "
"not supported for this backend"
)

def last(arg):
if not include_null:
arg = arg.dropna()
return arg.iat[-1] if len(arg) else None

return cls.agg(last, arg, where)

@classmethod
def visit(cls, op: ops.Correlation, left, right, where, how):
if where is None:
@@ -344,7 +394,13 @@ def agg(df):
return agg

@classmethod
def visit(cls, op: ops.GroupConcat, arg, sep, where):
def visit(cls, op: ops.GroupConcat, arg, sep, where, order_by):
if order_by:
raise UnsupportedOperationError(
"ordering of order-sensitive aggregations via `order_by` is "
"not supported for this backend"
)

if where is None:

def agg(df):
23 changes: 10 additions & 13 deletions ibis/backends/pandas/kernels.py
@@ -144,6 +144,12 @@ def array_position_rowwise(row):
return -1


def array_remove_rowwise(row):
if row["arg"] is None:
return None
return [x for x in row["arg"] if x != row["other"]]


def array_slice_rowwise(row):
arg, start, stop = row["arg"], row["start"], row["stop"]
if isnull(start) and isnull(stop):
@@ -254,18 +260,11 @@ def round_serieswise(arg, digits):
return np.round(arg, digits).astype("float64")


def first(arg):
# first excludes null values unless they're all null
def arbitrary(arg):
arg = arg.dropna()
return arg.iat[0] if len(arg) else None


def last(arg):
# last excludes null values unless they're all null
arg = arg.dropna()
return arg.iat[-1] if len(arg) else None


reductions = {
ops.Min: lambda x: x.min(),
ops.Max: lambda x: x.max(),
@@ -280,12 +279,9 @@ def last(arg):
ops.BitAnd: lambda x: np.bitwise_and.reduce(x.values),
ops.BitOr: lambda x: np.bitwise_or.reduce(x.values),
ops.BitXor: lambda x: np.bitwise_xor.reduce(x.values),
ops.Last: last,
ops.First: first,
ops.Arbitrary: first,
ops.Arbitrary: arbitrary,
ops.CountDistinct: lambda x: x.nunique(),
ops.ApproxCountDistinct: lambda x: x.nunique(),
ops.ArrayCollect: lambda x: x.dropna().tolist(),
}


@@ -380,11 +376,12 @@ def wrapper(*args, **kwargs):
ops.Repeat: lambda df: df["arg"] * df["times"],
}


rowwise = {
ops.ArrayContains: lambda row: row["other"] in row["arg"],
ops.ArrayIndex: array_index_rowwise,
ops.ArrayPosition: array_position_rowwise,
ops.ArrayRemove: lambda row: [x for x in row["arg"] if x != row["other"]],
ops.ArrayRemove: array_remove_rowwise,
ops.ArrayRepeat: lambda row: np.tile(row["arg"], max(0, row["times"])),
ops.ArraySlice: array_slice_rowwise,
ops.ArrayUnion: lambda row: toolz.unique(row["left"] + row["right"]),
8 changes: 6 additions & 2 deletions ibis/backends/polars/__init__.py
@@ -496,11 +496,15 @@ def execute(
return expr.__pandas_result__(df.to_pandas())
else:
assert isinstance(expr, ir.Column), type(expr)
if expr.type().is_temporal():

dtype = expr.type()
if dtype.is_temporal():
return expr.__pandas_result__(df.to_pandas())
else:
from ibis.formats.pandas import PandasData

# note: skip frame-construction overhead
return df.to_series().to_pandas()
return PandasData.convert_column(df.to_series().to_pandas(), dtype)

def to_polars(
self,
232 changes: 141 additions & 91 deletions ibis/backends/polars/compiler.py

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion ibis/backends/polars/tests/test_udf.py
@@ -47,7 +47,7 @@ def test_multiple_argument_udf(alltypes):
df = alltypes[["smallint_col", "int_col"]].execute()
expected = df.smallint_col + df.int_col

tm.assert_series_equal(result, expected.rename("tmp"))
tm.assert_series_equal(result, expected.astype("int64").rename("tmp"))


@pytest.mark.parametrize(
125 changes: 30 additions & 95 deletions ibis/backends/postgres/__init__.py
@@ -4,20 +4,16 @@

import contextlib
import inspect
import textwrap
from functools import partial
from itertools import takewhile
from operator import itemgetter
from typing import TYPE_CHECKING, Any
from urllib.parse import unquote_plus

import numpy as np
import pandas as pd
import sqlglot as sg
import sqlglot.expressions as sge
from pandas.api.types import is_float_dtype

import ibis
import ibis.backends.sql.compilers as sc
import ibis.common.exceptions as com
import ibis.common.exceptions as exc
import ibis.expr.datatypes as dt
@@ -27,9 +23,7 @@
from ibis import util
from ibis.backends import CanCreateDatabase, CanCreateSchema, CanListCatalog
from ibis.backends.sql import SQLBackend
from ibis.backends.sql.compilers import PostgresCompiler
from ibis.backends.sql.compilers.base import TRUE, C, ColGen, F
from ibis.common.exceptions import InvalidDecoratorError

if TYPE_CHECKING:
from collections.abc import Callable
@@ -41,15 +35,9 @@
import pyarrow as pa


def _verify_source_line(func_name: str, line: str):
if line.startswith("@"):
raise InvalidDecoratorError(func_name, line)
return line


class Backend(SQLBackend, CanListCatalog, CanCreateDatabase, CanCreateSchema):
name = "postgres"
compiler = PostgresCompiler()
compiler = sc.postgres.compiler
supports_python_udfs = True

def _from_url(self, url: ParseResult, **kwargs):
@@ -149,7 +137,7 @@ def _register_in_memory_table(self, op: ops.InMemoryTable) -> None:
convert_df = df.convert_dtypes()
for col in convert_df.columns:
if not is_float_dtype(convert_df[col]):
df[col] = df[col].replace(np.nan, None)
df[col] = df[col].replace(float("nan"), None)

data = df.itertuples(index=False)
sql = self._build_insert_template(
@@ -303,6 +291,20 @@ def _post_connect(self) -> None:
with self.begin() as cur:
cur.execute("SET TIMEZONE = UTC")

@property
def _session_temp_db(self) -> str | None:
# Postgres doesn't assign the temporary table database until the first
# temp table is created in a given session.
# Before that temp table is created, this will return `None`
# After a temp table is created, it will return `pg_temp_N` where N is
# some integer
res = self.raw_sql(
"select nspname from pg_namespace where oid = pg_my_temp_schema()"
).fetchone()
if res is not None:
return res[0]
return res
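
A hedged illustration of the behavior described in the comments above; the connection parameters are placeholders and the final value is an example of the `pg_temp_N` pattern, not captured output:

import ibis

con = ibis.postgres.connect(
    host="localhost", user="postgres", password="postgres", database="postgres"
)  # placeholder credentials
con._session_temp_db  # None: no temp schema has been assigned to this session yet
con.create_table("scratch", schema=ibis.schema({"x": "int64"}), temp=True)
con._session_temp_db  # now something like "pg_temp_3"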

def list_tables(
self,
like: str | None = None,
@@ -458,7 +460,7 @@ def function(self, name: str, *, database: str | None = None) -> Callable:
on=n.oid.eq(p.pronamespace),
join_type="LEFT",
)
.where(sg.and_(*predicates))
.where(*predicates)
)

def split_name_type(arg: str) -> tuple[str, dt.DataType]:
@@ -495,69 +497,6 @@ def fake_func(*args, **kwargs): ...
op = ops.udf.scalar.builtin(fake_func, database=database)
return op

def _get_udf_source(self, udf_node: ops.ScalarUDF):
config = udf_node.__config__
func = udf_node.__func__
func_name = func.__name__

lines, _ = inspect.getsourcelines(func)
iter_lines = iter(lines)

function_premable_lines = list(
takewhile(lambda line: not line.lstrip().startswith("def "), iter_lines)
)

if len(function_premable_lines) > 1:
raise InvalidDecoratorError(
name=func_name, lines="".join(function_premable_lines)
)

source = textwrap.dedent(
"".join(map(partial(_verify_source_line, func_name), iter_lines))
).strip()

type_mapper = self.compiler.type_mapper
argnames = udf_node.argnames
return dict(
name=type(udf_node).__name__,
ident=self.compiler.__sql_name__(udf_node),
signature=", ".join(
f"{argname} {type_mapper.to_string(arg.dtype)}"
for argname, arg in zip(argnames, udf_node.args)
),
return_type=type_mapper.to_string(udf_node.dtype),
language=config.get("language", "plpython3u"),
source=source,
args=", ".join(argnames),
)

def _define_udf_translation_rules(self, expr: ir.Expr) -> None:
"""No-op, these are defined in the compiler."""

def _compile_python_udf(self, udf_node: ops.ScalarUDF) -> str:
return """\
CREATE OR REPLACE FUNCTION {ident}({signature})
RETURNS {return_type}
LANGUAGE {language}
AS $$
{source}
return {name}({args})
$$""".format(**self._get_udf_source(udf_node))

def _register_udfs(self, expr: ir.Expr) -> None:
udf_sources = []
for udf_node in expr.op().find(ops.ScalarUDF):
compile_func = getattr(
self, f"_compile_{udf_node.__input_type__.name.lower()}_udf"
)
if sql := compile_func(udf_node):
udf_sources.append(sql)
if udf_sources:
# define every udf in one execution to avoid the overhead of
# database round trips per udf
with self._safe_raw_sql(";\n".join(udf_sources)):
pass

def get_schema(
self,
name: str,
@@ -571,6 +510,16 @@ def get_schema(

format_type = self.compiler.f["pg_catalog.format_type"]

# If no database is specified, assume the current database
db = database or self.current_database

dbs = [sge.convert(db)]

# If a database isn't specified, then include temp tables in the
# returned values
if database is None and (temp_table_db := self._session_temp_db) is not None:
dbs.append(sge.convert(temp_table_db))

type_info = (
sg.select(
a.attname.as_("column_name"),
Expand All @@ -591,7 +540,7 @@ def get_schema(
.where(
a.attnum > 0,
sg.not_(a.attisdropped),
n.nspname.eq(sge.convert(database)) if database is not None else TRUE,
n.nspname.isin(*dbs),
c.relname.eq(sge.convert(name)),
)
.order_by(a.attnum)
@@ -720,7 +669,7 @@ def create_table(

self._run_pre_execute_hooks(table)

query = self._to_sqlglot(table)
query = self.compiler.to_sqlglot(table)
else:
query = None

@@ -823,17 +772,3 @@ def raw_sql(self, query: str | sg.Expression, **kwargs: Any) -> Any:
else:
con.commit()
return cursor

def _to_sqlglot(
self, expr: ir.Expr, limit: str | None = None, params=None, **kwargs: Any
):
table_expr = expr.as_table()
conversions = {
name: table_expr[name].as_ewkb()
for name, typ in table_expr.schema().items()
if typ.is_geospatial()
}

if conversions:
table_expr = table_expr.mutate(**conversions)
return super()._to_sqlglot(table_expr, limit=limit, params=params, **kwargs)
17 changes: 16 additions & 1 deletion ibis/backends/postgres/tests/test_client.py
@@ -384,7 +384,7 @@ def test_infoschema_dtypes(con):


def test_password_with_bracket():
password = f"{IBIS_POSTGRES_PASS}["
password = f"{IBIS_POSTGRES_PASS}[]"
quoted_pass = quote_plus(password)
url = f"postgres://{IBIS_POSTGRES_USER}:{quoted_pass}@{IBIS_POSTGRES_HOST}:{IBIS_POSTGRES_PORT}/{POSTGRES_TEST_DB}"
with pytest.raises(
@@ -417,3 +417,18 @@ def test_create_geospatial_table_with_srid(con):
for column, dtype in zip(column_names, column_types)
}
)


@pytest.fixture(scope="module")
def enum_table(con):
name = gen_name("enum_table")
con.raw_sql("CREATE TYPE mood AS ENUM ('sad', 'ok', 'happy')")
con.raw_sql(f"CREATE TEMP TABLE {name} (mood mood)")
yield name
con.raw_sql(f"DROP TABLE {name}")
con.raw_sql("DROP TYPE mood")


def test_enum_table(con, enum_table):
t = con.table(enum_table)
assert t.mood.type() == dt.unknown
36 changes: 22 additions & 14 deletions ibis/backends/pyspark/__init__.py
@@ -14,6 +14,7 @@
from pyspark.sql import SparkSession
from pyspark.sql.types import BooleanType, DoubleType, LongType, StringType

import ibis.backends.sql.compilers as sc
import ibis.common.exceptions as com
import ibis.config
import ibis.expr.operations as ops
@@ -24,7 +25,6 @@
from ibis.backends.pyspark.converter import PySparkPandasData
from ibis.backends.pyspark.datatypes import PySparkSchema, PySparkType
from ibis.backends.sql import SQLBackend
from ibis.backends.sql.compilers import PySparkCompiler
from ibis.expr.operations.udf import InputType
from ibis.legacy.udf.vectorized import _coerce_to_series
from ibis.util import deprecated
@@ -46,7 +46,7 @@
from ibis.expr.api import Watermark

PYSPARK_LT_34 = vparse(pyspark.__version__) < vparse("3.4")

PYSPARK_LT_35 = vparse(pyspark.__version__) < vparse("3.5")
ConnectionMode = Literal["streaming", "batch"]


@@ -104,7 +104,7 @@ def _interval_to_string(interval):

class Backend(SQLBackend, CanListCatalog, CanCreateDatabase):
name = "pyspark"
compiler = PySparkCompiler()
compiler = sc.pyspark.compiler

class Options(ibis.config.Config):
"""PySpark options.
@@ -359,18 +359,26 @@ def wrapper(*args):
def _register_udfs(self, expr: ir.Expr) -> None:
node = expr.op()
for udf in node.find(ops.ScalarUDF):
if udf.__input_type__ not in (InputType.PANDAS, InputType.BUILTIN):
raise NotImplementedError(
"Only Builtin UDFs and Pandas UDFs are supported in the PySpark backend"
)
# register pandas UDFs
udf_name = self.compiler.__sql_name__(udf)
udf_return = PySparkType.from_ibis(udf.dtype)
if udf.__input_type__ == InputType.PANDAS:
udf_name = self.compiler.__sql_name__(udf)
udf_func = self._wrap_udf_to_return_pandas(udf.__func__, udf.dtype)
udf_return = PySparkType.from_ibis(udf.dtype)
spark_udf = F.pandas_udf(udf_func, udf_return, F.PandasUDFType.SCALAR)
self._session.udf.register(udf_name, spark_udf)

elif udf.__input_type__ == InputType.PYTHON:
udf_func = udf.__func__
spark_udf = F.udf(udf_func, udf_return)
elif udf.__input_type__ == InputType.PYARROW:
# raise not implemented error if running on pyspark < 3.5
if PYSPARK_LT_35:
raise NotImplementedError(
"pyarrow UDFs are only supported in pyspark >= 3.5"
)
udf_func = udf.__func__
spark_udf = F.udf(udf_func, udf_return, useArrow=True)
else:
# Builtin functions don't need to be registered
continue
self._session.udf.register(udf_name, spark_udf)
for udf in node.find(ops.ElementWiseVectorizedUDF):
udf_name = self.compiler.__sql_name__(udf)
udf_func = self._wrap_udf_to_return_pandas(udf.func, udf.return_type)
@@ -439,7 +447,7 @@ def create_database(
name
Database name
catalog
Catalog to create database in (defaults to ``current_catalog``)
Catalog to create database in (defaults to `current_catalog`)
path
Path where to store the database data; otherwise uses Spark default
force
@@ -473,7 +481,7 @@ def drop_database(
name
Database name
catalog
Catalog containing database to drop (defaults to ``current_catalog``)
Catalog containing database to drop (defaults to `current_catalog`)
force
If False, Spark throws exception if database is not empty or
database does not exist
27 changes: 26 additions & 1 deletion ibis/backends/pyspark/tests/test_udf.py
@@ -4,6 +4,7 @@
import pytest

import ibis
from ibis.backends.pyspark import PYSPARK_LT_35

pytest.importorskip("pyspark")

@@ -22,12 +23,36 @@ def df(con):
def repeat(x, n) -> str: ...


@ibis.udf.scalar.python
def py_repeat(x: str, n: int) -> str:
return x * n


@ibis.udf.scalar.pyarrow
def pyarrow_repeat(x: str, n: int) -> str:
return x * n


def test_builtin_udf(t, df):
result = t.mutate(repeated=repeat(t.str_col, 2)).execute()
expected = df.assign(repeated=df.str_col * 2)
tm.assert_frame_equal(result, expected)


def test_python_udf(t, df):
result = t.mutate(repeated=py_repeat(t.str_col, 2)).execute()
expected = df.assign(repeated=df.str_col * 2)
tm.assert_frame_equal(result, expected)


@pytest.mark.xfail(PYSPARK_LT_35, reason="pyarrow UDFs require PySpark 3.5+")
def test_pyarrow_udf(t, df):
result = t.mutate(repeated=pyarrow_repeat(t.str_col, 2)).execute()
expected = df.assign(repeated=df.str_col * 2)
tm.assert_frame_equal(result, expected)


@pytest.mark.xfail(not PYSPARK_LT_35, reason="pyarrow UDFs require PySpark 3.5+")
def test_illegal_udf_type(t):
@ibis.udf.scalar.pyarrow
def my_add_one(x) -> str:
@@ -39,6 +64,6 @@ def my_add_one(x) -> str:

with pytest.raises(
NotImplementedError,
match="Only Builtin UDFs and Pandas UDFs are supported in the PySpark backend",
match="pyarrow UDFs are only supported in pyspark >= 3.5",
):
expr.execute()
12 changes: 9 additions & 3 deletions ibis/backends/risingwave/__init__.py
@@ -11,12 +11,12 @@
from psycopg2 import extras

import ibis
import ibis.backends.sql.compilers as sc
import ibis.common.exceptions as com
import ibis.expr.operations as ops
import ibis.expr.types as ir
from ibis import util
from ibis.backends.postgres import Backend as PostgresBackend
from ibis.backends.sql.compilers import RisingWaveCompiler
from ibis.util import experimental

if TYPE_CHECKING:
@@ -45,7 +45,7 @@ def format_properties(props):

class Backend(PostgresBackend):
name = "risingwave"
compiler = RisingWaveCompiler()
compiler = sc.risingwave.compiler
supports_python_udfs = False

def do_connect(
@@ -202,7 +202,7 @@ def create_table(

self._run_pre_execute_hooks(table)

query = self._to_sqlglot(table)
query = self.compiler.to_sqlglot(table)
else:
query = None

@@ -586,3 +586,9 @@ def drop_sink(
)
with self._safe_raw_sql(src):
pass

@property
def _session_temp_db(self) -> str | None:
# Return `None`, because RisingWave does not implement temp tables like
# Postgres
return None
192 changes: 78 additions & 114 deletions ibis/backends/snowflake/__init__.py
@@ -3,15 +3,11 @@
import contextlib
import functools
import glob
import inspect
import itertools
import json
import os
import platform
import shutil
import sys
import tempfile
import textwrap
import warnings
from operator import itemgetter
from pathlib import Path
@@ -25,6 +21,7 @@
import sqlglot.expressions as sge

import ibis
import ibis.backends.sql.compilers as sc
import ibis.common.exceptions as com
import ibis.expr.datatypes as dt
import ibis.expr.operations as ops
@@ -34,9 +31,7 @@
from ibis.backends import CanCreateCatalog, CanCreateDatabase, CanCreateSchema
from ibis.backends.snowflake.converter import SnowflakePandasData
from ibis.backends.sql import SQLBackend
from ibis.backends.sql.compilers import SnowflakeCompiler
from ibis.backends.sql.compilers.base import STAR
from ibis.backends.sql.datatypes import SnowflakeType

if TYPE_CHECKING:
from collections.abc import Iterable, Iterator, Mapping
@@ -76,15 +71,78 @@
"returns": "ARRAY",
"source": """return Array(count).fill(value).flat();""",
},
"ibis_udfs.public.array_sum": {
"inputs": {"array": "ARRAY"},
"returns": "DOUBLE",
"source": """\
let total = 0.0;
let allNull = true;
for (val of array) {
if (val !== null) {
total += val;
allNull = false;
}
}
return !allNull ? total : null;""",
},
"ibis_udfs.public.array_avg": {
"inputs": {"array": "ARRAY"},
"returns": "DOUBLE",
"source": """\
let count = 0;
let total = 0.0;
for (val of array) {
if (val !== null) {
total += val;
++count;
}
}
return count !== 0 ? total / count : null;""",
},
"ibis_udfs.public.array_any": {
"inputs": {"array": "ARRAY"},
"returns": "BOOLEAN",
"source": """\
let count = 0;
for (val of array) {
if (val === true) {
return true;
} else if (val === false) {
++count;
}
}
return count !== 0 ? false : null;""",
},
"ibis_udfs.public.array_all": {
"inputs": {"array": "ARRAY"},
"returns": "BOOLEAN",
"source": """\
let count = 0;
for (val of array) {
if (val === false) {
return false;
} else if (val === true) {
++count;
}
}
return count !== 0 ? true : null;""",
},
}
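
For orientation, a sketch of the array reductions these JavaScript helpers are meant to back on Snowflake; the plural method names (`sums`, `means`, `anys`, `alls`) are assumed from the array-reduction API targeted by this change:

import ibis

t = ibis.table({"xs": "array<float64>", "flags": "array<boolean>"}, name="t")
expr = t.select(
    total=t.xs.sums(),        # intended to compile to ibis_udfs.public.array_sum
    mean=t.xs.means(),        # intended to compile to ibis_udfs.public.array_avg
    any_true=t.flags.anys(),  # intended to compile to ibis_udfs.public.array_any
    all_true=t.flags.alls(),  # intended to compile to ibis_udfs.public.array_all
)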


class Backend(SQLBackend, CanCreateCatalog, CanCreateDatabase, CanCreateSchema):
name = "snowflake"
compiler = SnowflakeCompiler()
compiler = sc.snowflake.compiler
supports_python_udfs = True

_latest_udf_python_version = (3, 10)
_top_level_methods = ("from_connection", "from_snowpark")

def __init__(self, *args, _from_snowpark: bool = False, **kwargs) -> None:
@@ -274,7 +332,7 @@ def _setup_session(self, *, session_parameters, create_object_udfs: bool):
f"Unable to create Ibis UDFs, some functionality will not work: {e}"
)

@util.experimental
@util.deprecated(as_of="10.0", instead="use from_connection instead")
@classmethod
def from_snowpark(
cls, session: snowflake.snowpark.Session, *, create_object_udfs: bool = True
@@ -394,107 +452,6 @@ def reconnect(self) -> None:
)
super().reconnect()

def _get_udf_source(self, udf_node: ops.ScalarUDF):
name = type(udf_node).__name__
signature = ", ".join(
f"{name} {self.compiler.type_mapper.to_string(arg.dtype)}"
for name, arg in zip(udf_node.argnames, udf_node.args)
)
return_type = SnowflakeType.to_string(udf_node.dtype)
lines, _ = inspect.getsourcelines(udf_node.__func__)
source = textwrap.dedent(
"".join(
itertools.dropwhile(
lambda line: not line.lstrip().startswith("def "), lines
)
)
).strip()

config = udf_node.__config__

preamble_lines = [*self._UDF_PREAMBLE_LINES]

if imports := config.get("imports"):
preamble_lines.append(f"IMPORTS = ({', '.join(map(repr, imports))})")

packages = "({})".format(
", ".join(map(repr, ("pandas", *config.get("packages", ()))))
)
preamble_lines.append(f"PACKAGES = {packages}")

return dict(
source=source,
name=name,
func_name=udf_node.__func_name__,
preamble="\n".join(preamble_lines).format(
name=name,
signature=signature,
return_type=return_type,
comment=f"Generated by ibis {ibis.__version__} using Python {platform.python_version()}",
version=".".join(
map(str, min(sys.version_info[:2], self._latest_udf_python_version))
),
),
)

_UDF_PREAMBLE_LINES = (
"CREATE OR REPLACE TEMPORARY FUNCTION {name}({signature})",
"RETURNS {return_type}",
"LANGUAGE PYTHON",
"IMMUTABLE",
"RUNTIME_VERSION = '{version}'",
"COMMENT = '{comment}'",
)

def _define_udf_translation_rules(self, expr):
"""No-op, these are defined in the compiler."""

def _register_udfs(self, expr: ir.Expr) -> None:
udf_sources = []
for udf_node in expr.op().find(ops.ScalarUDF):
compile_func = getattr(
self, f"_compile_{udf_node.__input_type__.name.lower()}_udf"
)
if sql := compile_func(udf_node):
udf_sources.append(sql)
if udf_sources:
# define every udf in one execution to avoid the overhead of db
# round trips per udf
with self._safe_raw_sql(";\n".join(udf_sources)):
pass

def _compile_python_udf(self, udf_node: ops.ScalarUDF) -> str:
return """\
{preamble}
HANDLER = '{func_name}'
AS $$
from __future__ import annotations
from typing import *
{source}
$$""".format(**self._get_udf_source(udf_node))

def _compile_pandas_udf(self, udf_node: ops.ScalarUDF) -> str:
template = """\
{preamble}
HANDLER = 'wrapper'
AS $$
from __future__ import annotations
from typing import *
import _snowflake
import pandas as pd
{source}
@_snowflake.vectorized(input=pd.DataFrame)
def wrapper(df):
return {func_name}(*(col for _, col in df.items()))
$$"""
return template.format(**self._get_udf_source(udf_node))

def to_pyarrow(
self,
expr: ir.Expr,
@@ -530,10 +487,10 @@ def to_pandas_batches(
*,
params: Mapping[ir.Scalar, Any] | None = None,
limit: int | str | None = None,
**kwargs: Any,
chunk_size: int = 1_000_000,
) -> Iterator[pd.DataFrame | pd.Series | Any]:
self._run_pre_execute_hooks(expr)
sql = self.compile(expr, limit=limit, params=params, **kwargs)
sql = self.compile(expr, limit=limit, params=params)
target_schema = expr.as_table().schema()
converter = functools.partial(
SnowflakePandasData.convert_table, schema=target_schema
@@ -582,6 +539,13 @@ def get_schema(
catalog: str | None = None,
database: str | None = None,
) -> Iterable[tuple[str, dt.DataType]]:
# this will always show temp tables with the same name as a non-temp
# table first
#
# snowflake puts temp tables in the same catalog and database as
# non-temp tables and differentiates between them using a different
# mechanism than other databases, which often put temp tables in a hidden
# or intentionally-difficult-to-access catalog/database
table = sg.table(
table_name, db=database, catalog=catalog, quoted=self.compiler.quoted
)
@@ -670,7 +634,7 @@ def list_tables(
tables_query = "SHOW TABLES"
views_query = "SHOW VIEWS"

if table_loc is not None:
if table_loc.catalog or table_loc.db:
tables_query += f" IN {table_loc}"
views_query += f" IN {table_loc}"

@@ -879,7 +843,7 @@ def create_table(

self._run_pre_execute_hooks(table)

query = self._to_sqlglot(table)
query = self.compiler.to_sqlglot(table)
else:
query = None

2 changes: 1 addition & 1 deletion ibis/backends/snowflake/tests/conftest.py
@@ -209,7 +209,7 @@ def connect(*, tmpdir, worker_id, **kw) -> BaseBackend:
"schema": os.environ["SNOWFLAKE_SCHEMA"],
}
)
return ibis.backends.snowflake.Backend.from_snowpark(builder.create())
return ibis.backends.snowflake.Backend.from_connection(builder.create())
else:
return ibis.connect(_get_url(), **kw)

2 changes: 1 addition & 1 deletion ibis/backends/snowflake/tests/test_udf.py
@@ -238,7 +238,7 @@ def test_ibis_inside_snowpark(snowpark_session, execute_as):
def ibis_sproc(session):
import ibis.backends.snowflake

con = ibis.backends.snowflake.Backend.from_snowpark(session)
con = ibis.backends.snowflake.Backend.from_connection(session)

expr = (
con.tables.functional_alltypes.group_by("string_col")
94 changes: 40 additions & 54 deletions ibis/backends/sql/__init__.py
@@ -14,7 +14,6 @@
import ibis.expr.types as ir
from ibis import util
from ibis.backends import BaseBackend
from ibis.backends.sql.compilers.base import STAR

if TYPE_CHECKING:
from collections.abc import Iterable, Mapping
@@ -132,10 +131,8 @@ def table(
"""
table_loc = self._warn_and_create_table_loc(database, schema)

catalog, database = None, None
if table_loc is not None:
catalog = table_loc.catalog or None
database = table_loc.db or None
catalog = table_loc.catalog or None
database = table_loc.db or None

table_schema = self.get_schema(name, catalog=catalog, database=database)
return ops.DatabaseTable(
Expand All @@ -145,39 +142,15 @@ def table(
namespace=ops.Namespace(catalog=catalog, database=database),
).to_expr()

def _to_sqlglot(
self, expr: ir.Expr, *, limit: str | None = None, params=None, **_: Any
):
"""Compile an Ibis expression to a sqlglot object."""
table_expr = expr.as_table()

if limit == "default":
limit = ibis.options.sql.default_limit
if limit is not None:
table_expr = table_expr.limit(limit)

if params is None:
params = {}

sql = self.compiler.translate(table_expr.op(), params=params)
assert not isinstance(sql, sge.Subquery)

if isinstance(sql, sge.Table):
sql = sg.select(STAR, copy=False).from_(sql, copy=False)

assert not isinstance(sql, sge.Subquery)
return sql

def compile(
self,
expr: ir.Expr,
limit: str | None = None,
params=None,
params: Mapping[ir.Expr, Any] | None = None,
pretty: bool = False,
**kwargs: Any,
):
"""Compile an Ibis expression to a SQL string."""
query = self._to_sqlglot(expr, limit=limit, params=params, **kwargs)
query = self.compiler.to_sqlglot(expr, limit=limit, params=params)
sql = query.sql(dialect=self.dialect, pretty=pretty, copy=False)
self._log(sql)
return sql
@@ -222,7 +195,7 @@ def _get_sql_string_view_schema(self, name, table, query) -> sch.Schema:
compiler = self.compiler
dialect = compiler.dialect

cte = self._to_sqlglot(table)
cte = compiler.to_sqlglot(table)
parsed = sg.parse_one(query, read=dialect)
parsed.args["with"] = cte.args.pop("with", [])
parsed = parsed.with_(
Expand All @@ -232,6 +205,21 @@ def _get_sql_string_view_schema(self, name, table, query) -> sch.Schema:
sql = parsed.sql(dialect)
return self._get_schema_using_query(sql)

def _register_udfs(self, expr: ir.Expr) -> None:
udf_sources = []
compiler = self.compiler
for udf_node in expr.op().find(ops.ScalarUDF):
compile_func = getattr(
compiler, f"_compile_{udf_node.__input_type__.name.lower()}_udf"
)
if sql := compile_func(udf_node):
udf_sources.append(sql)
if udf_sources:
# define every udf in one execution to avoid the overhead of db
# round trips per udf
with self._safe_raw_sql(";\n".join(udf_sources)):
pass
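
As a reminder of what flows through this hook, a minimal scalar UDF defined with the public decorator; whether it actually compiles depends on the compiler's `_compile_python_udf` support for the backend in question:

import ibis

@ibis.udf.scalar.python
def add_one(x: int) -> int:
    return x + 1

expr = add_one(ibis.literal(1))
# executing `expr` finds the ScalarUDF node above, compiles it via the matching
# compiler hook, and registers all UDF definitions in a single round trip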

def create_view(
self,
name: str,
@@ -570,28 +558,7 @@ def disconnect(self):
# _most_ sqlglot backends
self.con.close()

def _compile_builtin_udf(self, udf_node: ops.ScalarUDF | ops.AggUDF) -> None:
"""Compile a built-in UDF. No-op by default."""

def _compile_python_udf(self, udf_node: ops.ScalarUDF) -> None:
raise NotImplementedError(
f"Python UDFs are not supported in the {self.name} backend"
)

def _compile_pyarrow_udf(self, udf_node: ops.ScalarUDF) -> None:
raise NotImplementedError(
f"PyArrow UDFs are not supported in the {self.name} backend"
)

def _compile_pandas_udf(self, udf_node: ops.ScalarUDF) -> str:
raise NotImplementedError(
f"pandas UDFs are not supported in the {self.name} backend"
)

def _to_catalog_db_tuple(self, table_loc: sge.Table):
if table_loc is None or table_loc == (None, None):
return None, None

if (sg_cat := table_loc.args["catalog"]) is not None:
sg_cat.args["quoted"] = False
sg_cat = sg_cat.sql(self.name)
@@ -603,7 +570,8 @@ def _to_catalog_db_tuple(self, table_loc: sge.Table):

def _to_sqlglot_table(self, database):
if database is None:
return None
# Create "table" with empty catalog and db
database = sg.exp.Table(catalog=None, db=None)
elif isinstance(database, (list, tuple)):
if len(database) > 2:
raise ValueError(
@@ -647,3 +615,21 @@ def _to_sqlglot_table(self, database):
)

return database

def _register_builtin_udf(self, udf_node: ops.ScalarUDF) -> None:
"""No-op."""

def _register_python_udf(self, udf_node: ops.ScalarUDF) -> str:
raise NotImplementedError(
f"Python UDFs are not supported in the {self.dialect} backend"
)

def _register_pyarrow_udf(self, udf_node: ops.ScalarUDF) -> str:
raise NotImplementedError(
f"PyArrow UDFs are not supported in the {self.dialect} backend"
)

def _register_pandas_udf(self, udf_node: ops.ScalarUDF) -> str:
raise NotImplementedError(
f"pandas UDFs are not supported in the {self.dialect} backend"
)
132 changes: 106 additions & 26 deletions ibis/backends/sql/compilers/base.py
@@ -63,6 +63,9 @@ class AggGen:
supports_filter
Whether the backend supports a FILTER clause in the aggregate.
Defaults to False.
supports_order_by
Whether the backend supports an ORDER BY clause in (relevant)
aggregates. Defaults to False.
"""

class _Accessor:
@@ -79,10 +82,13 @@ def __getattr__(self, name: str) -> Callable:

__getitem__ = __getattr__

__slots__ = ("supports_filter",)
__slots__ = ("supports_filter", "supports_order_by")

def __init__(self, *, supports_filter: bool = False):
def __init__(
self, *, supports_filter: bool = False, supports_order_by: bool = False
):
self.supports_filter = supports_filter
self.supports_order_by = supports_order_by

def __get__(self, instance, owner=None):
if instance is None:
@@ -96,6 +102,7 @@ def aggregate(
name: str,
*args: Any,
where: Any = None,
order_by: tuple = (),
):
"""Compile the specified aggregate.
@@ -109,21 +116,31 @@
Any arguments to pass to the aggregate.
where
An optional column filter to apply before performing the aggregate.
order_by
Optional ordering keys to use to order the rows before performing
the aggregate.
"""
func = compiler.f[name]

if where is None:
return func(*args)

if self.supports_filter:
return sge.Filter(
this=func(*args),
expression=sge.Where(this=where),
if order_by and not self.supports_order_by:
raise com.UnsupportedOperationError(
"ordering of order-sensitive aggregations via `order_by` is "
f"not supported for the {compiler.dialect} backend"
)
else:

if where is not None and not self.supports_filter:
args = tuple(compiler.if_(where, arg, NULL) for arg in args)
return func(*args)

if order_by and self.supports_order_by:
*rest, last = args
out = func(*rest, sge.Order(this=last, expressions=order_by))
else:
out = func(*args)

if where is not None and self.supports_filter:
out = sge.Filter(this=out, expression=sge.Where(this=where))

return out
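
A rough sketch of how the two capabilities combine for a call like `self.agg.array_agg(arg, where=cond, order_by=keys)`; the SQL shown in the comments is approximate:

agg = AggGen(supports_filter=True, supports_order_by=True)
# with both capabilities:  ARRAY_AGG(arg ORDER BY keys) FILTER (WHERE cond)
# supports_filter=False:   ARRAY_AGG(IF(cond, arg, NULL) ORDER BY keys)
# supports_order_by=False with a non-empty order_by: UnsupportedOperationError
# is raised before any SQL is generated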


class VarGen:
@@ -250,6 +267,9 @@ class SQLGlotCompiler(abc.ABC):
copy_func_args: bool = False
"""Whether to copy function arguments when generating SQL."""

supports_qualify: bool = False
"""Whether the backend supports the QUALIFY clause."""

NAN: ClassVar[sge.Expression] = sge.Cast(
this=sge.convert("NaN"), to=sge.DataType(this=sge.DataType.Type.DOUBLE)
)
@@ -296,7 +316,6 @@ class SQLGlotCompiler(abc.ABC):
ops.ApproxCountDistinct: "approx_distinct",
ops.ArgMax: "max_by",
ops.ArgMin: "min_by",
ops.ArrayCollect: "array_agg",
ops.ArrayContains: "array_contains",
ops.ArrayFlatten: "flatten",
ops.ArrayLength: "array_size",
@@ -314,15 +333,13 @@ class SQLGlotCompiler(abc.ABC):
ops.Degrees: "degrees",
ops.DenseRank: "dense_rank",
ops.Exp: "exp",
ops.First: "first",
FirstValue: "first_value",
ops.GroupConcat: "group_concat",
ops.IfElse: "if",
ops.IsInf: "isinf",
ops.IsNan: "isnan",
ops.JSONGetItem: "json_extract",
ops.LPad: "lpad",
ops.Last: "last",
LastValue: "last_value",
ops.Levenshtein: "levenshtein",
ops.Ln: "ln",
@@ -424,8 +441,10 @@ def make_impl(op, target_name):

if issubclass(op, ops.Reduction):

def impl(self, _, *, _name: str = target_name, where, **kw):
return self.agg[_name](*kw.values(), where=where)
def impl(
self, _, *, _name: str = target_name, where, order_by=(), **kw
):
return self.agg[_name](*kw.values(), where=where, order_by=order_by)

else:

@@ -434,15 +453,18 @@ def impl(self, _, *, _name: str = target_name, **kw):

return impl

for op, target_name in cls.SIMPLE_OPS.items():
setattr(cls, methodname(op), make_impl(op, target_name))

# unconditionally raise an exception for unsupported operations
#
# these *must* be defined after SIMPLE_OPS to handle compilers that
# subclass other compilers
for op in cls.UNSUPPORTED_OPS:
# change to visit_Unsupported in a follow up
# TODO: handle geoespatial ops as a separate case?
setattr(cls, methodname(op), cls.visit_Undefined)

for op, target_name in cls.SIMPLE_OPS.items():
setattr(cls, methodname(op), make_impl(op, target_name))

# raise on any remaining unsupported operations
for op in ALL_OPERATIONS:
name = methodname(op)
@@ -474,6 +496,24 @@ def dialect(self) -> str:
def type_mapper(self) -> type[SqlglotType]:
"""The type mapper for the backend."""

def _compile_builtin_udf(self, udf_node: ops.ScalarUDF) -> None: # noqa: B027
"""No-op."""

def _compile_python_udf(self, udf_node: ops.ScalarUDF) -> None:
raise NotImplementedError(
f"Python UDFs are not supported in the {self.dialect} backend"
)

def _compile_pyarrow_udf(self, udf_node: ops.ScalarUDF) -> None:
raise NotImplementedError(
f"PyArrow UDFs are not supported in the {self.dialect} backend"
)

def _compile_pandas_udf(self, udf_node: ops.ScalarUDF) -> str:
raise NotImplementedError(
f"pandas UDFs are not supported in the {self.dialect} backend"
)

# Concrete API

def if_(self, condition, true, false: sge.Expression | None = None) -> sge.If:
@@ -495,6 +535,34 @@ def _prepare_params(self, params):
result[node] = value
return result

def to_sqlglot(
self,
expr: ir.Expr,
*,
limit: str | None = None,
params: Mapping[ir.Expr, Any] | None = None,
):
import ibis

table_expr = expr.as_table()

if limit == "default":
limit = ibis.options.sql.default_limit
if limit is not None:
table_expr = table_expr.limit(limit)

if params is None:
params = {}

sql = self.translate(table_expr.op(), params=params)
assert not isinstance(sql, sge.Subquery)

if isinstance(sql, sge.Table):
sql = sg.select(STAR, copy=False).from_(sql, copy=False)

assert not isinstance(sql, sge.Subquery)
return sql

def translate(self, op, *, params: Mapping[ir.Value, Any]) -> sge.Expression:
"""Translate an ibis operation to a sqlglot expression.
@@ -837,6 +905,7 @@ def visit_ExtractSecond(self, op, *, arg):
def visit_TimestampTruncate(self, op, *, arg, unit):
unit_mapping = {
"Y": "year",
"Q": "quarter",
"M": "month",
"W": "week",
"D": "day",
@@ -847,10 +916,12 @@ def visit_TimestampTruncate(self, op, *, arg, unit):
"us": "us",
}

if (unit := unit_mapping.get(unit.short)) is None:
raise com.UnsupportedOperationError(f"Unsupported truncate unit {unit}")
if (raw_unit := unit_mapping.get(unit.short)) is None:
raise com.UnsupportedOperationError(
f"Unsupported truncate unit {unit.short!r}"
)

return self.f.date_trunc(unit, arg)
return self.f.date_trunc(raw_unit, arg)

def visit_DateTruncate(self, op, *, arg, unit):
return self.visit_TimestampTruncate(op, arg=arg, unit=unit)
@@ -1181,15 +1252,21 @@ def _cleanup_names(self, exprs: Mapping[str, sge.Expression]):
else:
yield value.as_(name, quoted=self.quoted, copy=False)

def visit_Select(self, op, *, parent, selections, predicates, sort_keys):
def visit_Select(self, op, *, parent, selections, predicates, qualified, sort_keys):
# if we've constructed a useless projection return the parent relation
if not selections and not predicates and not sort_keys:
if not (selections or predicates or qualified or sort_keys):
return parent

result = parent

if selections:
if op.is_star_selection():
# if there are `qualify` predicates then sqlglot adds a hidden
# column to implement the functionality if the dialect doesn't
# support it
#
# using STAR in that case would lead to an extra column, so in that
# case we have to spell out the columns
if op.is_star_selection() and (not qualified or self.supports_qualify):
fields = [STAR]
else:
fields = self._cleanup_names(selections)
@@ -1198,6 +1275,9 @@ def visit_Select(self, op, *, parent, selections, predicates, sort_keys):
if predicates:
result = result.where(*predicates, copy=False)

if qualified:
result = result.qualify(*qualified, copy=False)

if sort_keys:
result = result.order_by(*sort_keys, copy=False)


Large diffs are not rendered by default.

File renamed without changes.

@@ -10,8 +10,8 @@
from collections import ChainMap
from typing import TYPE_CHECKING

from ibis.backends.bigquery.udf.find import find_names
from ibis.backends.bigquery.udf.rewrite import rewrite
from ibis.backends.sql.compilers.bigquery.udf.find import find_names
from ibis.backends.sql.compilers.bigquery.udf.rewrite import rewrite

if TYPE_CHECKING:
from collections.abc import Callable
File renamed without changes.
File renamed without changes.
112 changes: 97 additions & 15 deletions ibis/backends/sql/compilers/clickhouse.py
@@ -20,7 +20,12 @@


class ClickhouseAggGen(AggGen):
def aggregate(self, compiler, name, *args, where=None):
def aggregate(self, compiler, name, *args, where=None, order_by=()):
if order_by:
raise com.UnsupportedOperationError(
"ordering of order-sensitive aggregations via `order_by` is "
"not supported for this backend"
)
# Clickhouse aggregate functions all have filtering variants with a
# `If` suffix (e.g. `SumIf` instead of `Sum`).
if where is not None:
@@ -37,6 +42,8 @@ class ClickHouseCompiler(SQLGlotCompiler):

agg = ClickhouseAggGen()

supports_qualify = True

UNSUPPORTED_OPS = (
ops.RowID,
ops.CumeDist,
@@ -56,16 +63,13 @@ class ClickHouseCompiler(SQLGlotCompiler):
ops.Arbitrary: "any",
ops.ArgMax: "argMax",
ops.ArgMin: "argMin",
ops.ArrayCollect: "groupArray",
ops.ArrayContains: "has",
ops.ArrayFlatten: "arrayFlatten",
ops.ArrayIntersect: "arrayIntersect",
ops.ArrayPosition: "indexOf",
ops.BitwiseAnd: "bitAnd",
ops.BitwiseLeftShift: "bitShiftLeft",
ops.BitwiseNot: "bitNot",
ops.BitwiseOr: "bitOr",
ops.BitwiseRightShift: "bitShiftRight",
ops.BitwiseXor: "bitXor",
ops.Capitalize: "initcap",
ops.CountDistinct: "uniq",
@@ -88,16 +92,13 @@ class ClickHouseCompiler(SQLGlotCompiler):
ops.ExtractWeekOfYear: "toISOWeek",
ops.ExtractYear: "toYear",
ops.ExtractIsoYear: "toISOYear",
ops.First: "any",
ops.IntegerRange: "range",
ops.IsInf: "isInfinite",
ops.IsNan: "isNaN",
ops.IsNull: "isNull",
ops.LStrip: "trimLeft",
ops.Last: "anyLast",
ops.Ln: "log",
ops.Log10: "log10",
ops.MapContains: "mapContains",
ops.MapKeys: "mapKeys",
ops.MapLength: "length",
ops.MapMerge: "mapUpdate",
@@ -373,6 +374,7 @@ def visit_TimestampFromUNIX(self, op, *, arg, unit):
def visit_TimestampTruncate(self, op, *, arg, unit):
converters = {
"Y": "toStartOfYear",
"Q": "toStartOfQuarter",
"M": "toStartOfMonth",
"W": "toMonday",
"D": "toDate",
@@ -435,17 +437,20 @@ def visit_StringSplit(self, op, *, arg, delimiter):
delimiter, self.cast(arg, dt.String(nullable=False))
)

def visit_GroupConcat(self, op, *, arg, sep, where):
def visit_GroupConcat(self, op, *, arg, sep, where, order_by):
if order_by:
raise com.UnsupportedOperationError(
"ordering of order-sensitive aggregations via `order_by` is "
"not supported for this backend"
)
call = self.agg.groupArray(arg, where=where)
return self.if_(self.f.empty(call), NULL, self.f.arrayStringConcat(call, sep))

def visit_Cot(self, op, *, arg):
return 1.0 / self.f.tan(arg)

def visit_StructColumn(self, op, *, values, names):
# ClickHouse struct types cannot be nullable
# (non-nested fields can be nullable)
return self.cast(self.f.tuple(*values), op.dtype.copy(nullable=False))
def visit_StructColumn(self, op, *, values, **_):
return self.f.tuple(*values)

def visit_Clip(self, op, *, arg, lower, upper):
if upper is not None:
@@ -583,9 +588,10 @@ def visit_ArrayFilter(self, op, *, arg, param, body):
return self.f.arrayFilter(func, arg)

def visit_ArrayRemove(self, op, *, arg, other):
x = sg.to_identifier("x")
body = x.neq(other)
return self.f.arrayFilter(sge.Lambda(this=body, expressions=[x]), arg)
x = sg.to_identifier(util.gen_name("x"))
should_keep_null = sg.and_(x.is_(NULL), other.is_(sg.not_(NULL)))
cond = sg.or_(x.neq(other), should_keep_null)
return self.f.arrayFilter(sge.Lambda(this=cond, expressions=[x]), arg)
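
A small expression-level illustration of the null-preserving semantics the rewritten predicate encodes; the expected value is inferred from the filter above, not a captured result:

import ibis

arr = ibis.array([1, None, 2, 1])
expr = arr.remove(1)  # null elements are kept, so the expected value is roughly [None, 2]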

def visit_ArrayUnion(self, op, *, left, right):
arg = self.f.arrayConcat(left, right)
@@ -597,6 +603,27 @@ def visit_ArrayUnion(self, op, *, left, right):
def visit_ArrayZip(self, op: ops.ArrayZip, *, arg, **_: Any) -> str:
return self.f.arrayZip(*arg)

def visit_ArrayCollect(self, op, *, arg, where, order_by, include_null):
if include_null:
raise com.UnsupportedOperationError(
"`include_null=True` is not supported by the clickhouse backend"
)
return self.agg.groupArray(arg, where=where, order_by=order_by)

def visit_First(self, op, *, arg, where, order_by, include_null):
if include_null:
raise com.UnsupportedOperationError(
"`include_null=True` is not supported by the clickhouse backend"
)
return self.agg.any(arg, where=where, order_by=order_by)

def visit_Last(self, op, *, arg, where, order_by, include_null):
if include_null:
raise com.UnsupportedOperationError(
"`include_null=True` is not supported by the clickhouse backend"
)
return self.agg.anyLast(arg, where=where, order_by=order_by)

def visit_CountDistinctStar(
self, op: ops.CountDistinctStar, *, where, **_: Any
) -> str:
@@ -736,3 +763,58 @@ def _cleanup_names(
value.as_(self._gen_valid_name(name), quoted=quoted, copy=False)
for name, value in exprs.items()
)

def _array_reduction(self, arg):
x = sg.to_identifier("x", quoted=self.quoted)
not_null = sge.Lambda(this=x.is_(sg.not_(NULL)), expressions=[x])
return self.f.arrayFilter(not_null, arg)

def visit_ArrayMin(self, op, *, arg):
return self.f.arrayReduce("min", self._array_reduction(arg))

visit_ArrayAll = visit_ArrayMin

def visit_ArrayMax(self, op, *, arg):
return self.f.arrayReduce("max", self._array_reduction(arg))

visit_ArrayAny = visit_ArrayMax

def visit_ArraySum(self, op, *, arg):
return self.f.arrayReduce("sum", self._array_reduction(arg))

def visit_ArrayMean(self, op, *, arg):
return self.f.arrayReduce("avg", self._array_reduction(arg))

def _promote_bitshift_inputs(self, *, op, left, right):
# clickhouse is incredibly pedantic about types allowed in bit shifting
#
# e.g., a UInt8 cannot be bitshift by more than 8 bits, UInt16 by more
# than 16, and so on.
#
# This is why something like Ibis is necessary so that people have just
# _consistent_ things, let alone *nice* things.
left_dtype = op.left.dtype
right_dtype = op.right.dtype

if left_dtype != right_dtype:
promoted = dt.higher_precedence(left_dtype, right_dtype)
return self.cast(left, promoted), self.cast(right, promoted)
return left, right
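
To make the promotion concrete, a hedged sketch of an expression that exercises it (the schema and operator usage are illustrative):

import ibis

t = ibis.table({"small": "uint8", "shift": "int64"}, name="t")
expr = t.small << t.shift
# both operands are cast to the higher-precedence of their two types before
# bitShiftLeft is emitted, sidestepping ClickHouse's strict width checks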

def visit_BitwiseLeftShift(self, op, *, left, right):
return self.f.bitShiftLeft(
*self._promote_bitshift_inputs(op=op, left=left, right=right)
)

def visit_BitwiseRightShift(self, op, *, left, right):
return self.f.bitShiftRight(
*self._promote_bitshift_inputs(op=op, left=left, right=right)
)

def visit_MapContains(self, op, *, arg, key):
return self.if_(
sg.or_(arg.is_(NULL), key.is_(NULL)), NULL, self.f.mapContains(arg, key)
)


compiler = ClickHouseCompiler()
61 changes: 42 additions & 19 deletions ibis/backends/sql/compilers/datafusion.py
@@ -14,7 +14,6 @@
from ibis.backends.sql.compilers.base import FALSE, NULL, STAR, AggGen, SQLGlotCompiler
from ibis.backends.sql.datatypes import DataFusionType
from ibis.backends.sql.dialects import DataFusion
from ibis.backends.sql.rewrites import exclude_nulls_from_array_collect
from ibis.common.temporal import IntervalUnit, TimestampUnit
from ibis.expr.operations.udf import InputType

@@ -25,27 +24,20 @@ class DataFusionCompiler(SQLGlotCompiler):
dialect = DataFusion
type_mapper = DataFusionType

rewrites = (
exclude_nulls_from_array_collect,
*SQLGlotCompiler.rewrites,
)

agg = AggGen(supports_filter=True)
agg = AggGen(supports_filter=True, supports_order_by=True)

UNSUPPORTED_OPS = (
ops.ArgMax,
ops.ArgMin,
ops.ArrayDistinct,
ops.ArrayFilter,
ops.ArrayFlatten,
ops.ArrayMap,
ops.ArrayZip,
ops.BitwiseNot,
ops.Clip,
ops.CountDistinctStar,
ops.DateDelta,
ops.Greatest,
ops.GroupConcat,
ops.IntervalFromInteger,
ops.Least,
ops.MultiQuantile,
@@ -55,9 +47,7 @@ class DataFusionCompiler(SQLGlotCompiler):
ops.TimeDelta,
ops.TimestampBucket,
ops.TimestampDelta,
ops.TimestampNow,
ops.TypeOf,
ops.Unnest,
ops.StringToDate,
ops.StringToTimestamp,
)
@@ -123,6 +113,12 @@ def visit_NonNullLiteral(self, op, *, value, dtype):
return sg.exp.HexString(this=value.hex())
elif dtype.is_uuid():
return sge.convert(str(value))
elif dtype.is_struct():
args = []
for name, field_value in value.items():
args.append(sge.convert(name))
args.append(field_value)
return self.f.named_struct(*args)
else:
return None

@@ -329,6 +325,12 @@ def visit_ArrayRepeat(self, op, *, arg, times):
def visit_ArrayPosition(self, op, *, arg, other):
return self.f.coalesce(self.f.array_position(arg, other), 0)

def visit_ArrayCollect(self, op, *, arg, where, order_by, include_null):
if not include_null:
cond = arg.is_(sg.not_(NULL, copy=False))
where = cond if where is None else sge.And(this=cond, expression=where)
return self.agg.array_agg(arg, where=where, order_by=order_by)

def visit_Covariance(self, op, *, left, right, how, where):
x = op.left
if x.dtype.is_boolean():
@@ -423,15 +425,17 @@ def visit_StringConcat(self, op, *, arg):
sg.or_(*any_args_null), self.cast(NULL, dt.string), self.f.concat(*arg)
)

def visit_First(self, op, *, arg, where):
cond = arg.is_(sg.not_(NULL, copy=False))
where = cond if where is None else sge.And(this=cond, expression=where)
return self.agg.first_value(arg, where=where)
def visit_First(self, op, *, arg, where, order_by, include_null):
if not include_null:
cond = arg.is_(sg.not_(NULL, copy=False))
where = cond if where is None else sge.And(this=cond, expression=where)
return self.agg.first_value(arg, where=where, order_by=order_by)

def visit_Last(self, op, *, arg, where):
cond = arg.is_(sg.not_(NULL, copy=False))
where = cond if where is None else sge.And(this=cond, expression=where)
return self.agg.last_value(arg, where=where)
def visit_Last(self, op, *, arg, where, order_by, include_null):
if not include_null:
cond = arg.is_(sg.not_(NULL, copy=False))
where = cond if where is None else sge.And(this=cond, expression=where)
return self.agg.last_value(arg, where=where, order_by=order_by)

def visit_Aggregate(self, op, *, parent, groups, metrics):
"""Support `GROUP BY` expressions in `SELECT` since DataFusion does not."""
@@ -479,3 +483,22 @@ def visit_Aggregate(self, op, *, parent, groups, metrics):
sel = sel.group_by(*by_names_quoted)

return sel

def visit_StructColumn(self, op, *, names, values):
args = []
for name, value in zip(names, values):
args.append(sge.convert(name))
args.append(value)
return self.f.named_struct(*args)

def visit_GroupConcat(self, op, *, arg, sep, where, order_by):
if order_by:
raise com.UnsupportedOperationError(
"DataFusion does not support order-sensitive group_concat"
)
return super().visit_GroupConcat(
op, arg=arg, sep=sep, where=where, order_by=order_by
)


compiler = DataFusionCompiler()
21 changes: 16 additions & 5 deletions ibis/backends/sql/compilers/druid.py
@@ -4,11 +4,13 @@
import sqlglot.expressions as sge
import toolz

import ibis.common.exceptions as exc
import ibis.expr.datatypes as dt
import ibis.expr.operations as ops
from ibis.backends.sql.compilers.base import NULL, AggGen, SQLGlotCompiler
from ibis.backends.sql.datatypes import DruidType
from ibis.backends.sql.dialects import Druid
from ibis.common.temporal import TimestampUnit


class DruidCompiler(SQLGlotCompiler):
@@ -25,7 +27,6 @@
ops.ApproxMedian,
ops.ArgMax,
ops.ArgMin,
ops.ArrayCollect,
ops.ArrayDistinct,
ops.ArrayFilter,
ops.ArrayFlatten,
@@ -36,14 +37,14 @@
ops.ArrayZip,
ops.CountDistinctStar,
ops.Covariance,
ops.Date,
ops.DateDelta,
ops.DateFromYMD,
ops.DayOfWeekIndex,
ops.DayOfWeekName,
ops.First,
ops.IntervalFromInteger,
ops.IsNan,
ops.IsInf,
ops.Last,
ops.Levenshtein,
ops.Median,
ops.MultiQuantile,
@@ -117,8 +118,8 @@ def visit_Pi(self, op):
def visit_Sign(self, op, *, arg):
return self.if_(arg.eq(0), 0, self.if_(arg > 0, 1, -1))

def visit_GroupConcat(self, op, *, arg, sep, where):
return self.agg.string_agg(arg, sep, 1 << 20, where=where)
def visit_GroupConcat(self, op, *, arg, sep, where, order_by):
return self.agg.string_agg(arg, sep, 1 << 20, where=where, order_by=order_by)

def visit_StartsWith(self, op, *, arg, start):
return self.f.left(arg, self.f.length(start)).eq(start)
@@ -169,6 +170,13 @@ def visit_Cast(self, op, *, arg, to):
return self.f.time_parse(arg)
return super().visit_Cast(op, arg=arg, to=to)

def visit_TimestampFromUNIX(self, op, *, arg, unit):
if unit == TimestampUnit.SECOND:
return self.f.millis_to_timestamp(arg * 1_000)
elif unit == TimestampUnit.MILLISECOND:
return self.f.millis_to_timestamp(arg)
raise exc.UnsupportedArgumentError(f"Druid doesn't support {unit} units")

def visit_TimestampFromYMDHMS(
self, op, *, year, month, day, hours, minutes, seconds
):
@@ -188,3 +196,6 @@ def visit_TimestampFromYMDHMS(
"Z",
)
)


compiler = DruidCompiler()
113 changes: 85 additions & 28 deletions ibis/backends/sql/compilers/duckdb.py
@@ -2,6 +2,7 @@

import math
from functools import partial, reduce
from typing import TYPE_CHECKING, Any

import sqlglot as sg
import sqlglot.expressions as sge
@@ -10,11 +11,17 @@
import ibis.common.exceptions as com
import ibis.expr.datatypes as dt
import ibis.expr.operations as ops
from ibis import util
from ibis.backends.sql.compilers.base import NULL, STAR, AggGen, SQLGlotCompiler
from ibis.backends.sql.datatypes import DuckDBType
from ibis.backends.sql.rewrites import exclude_nulls_from_array_collect
from ibis.util import gen_name

if TYPE_CHECKING:
from collections.abc import Mapping

import ibis.expr.types as ir


_INTERVAL_SUFFIXES = {
"ms": "milliseconds",
"us": "microseconds",
@@ -33,12 +40,9 @@
dialect = DuckDB
type_mapper = DuckDBType

agg = AggGen(supports_filter=True)
agg = AggGen(supports_filter=True, supports_order_by=True)

rewrites = (
exclude_nulls_from_array_collect,
*SQLGlotCompiler.rewrites,
)
supports_qualify = True

LOWERED_OPS = {
ops.Sample: None,
@@ -48,6 +52,12 @@
SIMPLE_OPS = {
ops.Arbitrary: "any_value",
ops.ArrayPosition: "list_indexof",
ops.ArrayMin: "list_min",
ops.ArrayMax: "list_max",
ops.ArrayAny: "list_bool_or",
ops.ArrayAll: "list_bool_and",
ops.ArraySum: "list_sum",
ops.ArrayMean: "list_avg",
ops.BitAnd: "bit_and",
ops.BitOr: "bit_or",
ops.BitXor: "bit_xor",
@@ -91,6 +101,33 @@ class DuckDBCompiler(SQLGlotCompiler):
ops.GeoY: "st_y",
}

def to_sqlglot(
self,
expr: ir.Expr,
*,
limit: str | None = None,
params: Mapping[ir.Expr, Any] | None = None,
):
sql = super().to_sqlglot(expr, limit=limit, params=params)

table_expr = expr.as_table()
geocols = table_expr.schema().geospatial

if not geocols:
return sql

quoted = self.quoted
return sg.select(
sge.Star(
replace=[
self.f.st_aswkb(sg.column(col, quoted=quoted)).as_(
col, quoted=quoted
)
for col in geocols
]
)
).from_(sql.subquery())

def visit_StructColumn(self, op, *, names, values):
return sge.Struct.from_arg_list(
[
@@ -113,6 +150,12 @@ def visit_ArrayDistinct(self, op, *, arg):
),
)

def visit_ArrayCollect(self, op, *, arg, where, order_by, include_null):
if not include_null:
cond = arg.is_(sg.not_(NULL, copy=False))
where = cond if where is None else sge.And(this=cond, expression=where)
return self.agg.array_agg(arg, where=where, order_by=order_by)

def visit_ArrayIndex(self, op, *, arg, index):
return self.f.list_extract(arg, index + self.cast(index >= 0, op.index.dtype))

@@ -162,10 +205,10 @@ def visit_ArrayIntersect(self, op, *, left, right):
return self.f.list_filter(left, lamduh)

def visit_ArrayRemove(self, op, *, arg, other):
param = sg.to_identifier("x")
body = param.neq(other)
lamduh = sge.Lambda(this=body, expressions=[param])
return self.f.list_filter(arg, lamduh)
x = sg.to_identifier(util.gen_name("x"))
should_keep_null = sg.and_(x.is_(NULL), other.is_(sg.not_(NULL)))
cond = sg.or_(x.neq(other), should_keep_null)
return self.f.list_filter(arg, sge.Lambda(this=cond, expressions=[x]))

def visit_ArrayUnion(self, op, *, left, right):
arg = self.f.list_concat(left, right)
@@ -201,10 +244,17 @@ def visit_ArrayZip(self, op, *, arg):
any_arg_null = sg.or_(*(arr.is_(NULL) for arr in arg))
return self.if_(any_arg_null, NULL, zipped_arrays)

def visit_Array(self, op, *, exprs):
return self.cast(self.f.array(*exprs), op.dtype)

def visit_Map(self, op, *, keys, values):
# workaround for https://github.com/ibis-project/ibis/issues/8632
return self.if_(
sg.or_(keys.is_(NULL), values.is_(NULL)), NULL, self.f.map(keys, values)
sg.or_(keys.is_(NULL), values.is_(NULL)),
NULL,
self.f.map(
self.cast(keys, op.keys.dtype), self.cast(values, op.values.dtype)
),
)

def visit_MapGet(self, op, *, arg, key, default):
@@ -378,6 +428,8 @@ def visit_NonNullLiteral(self, op, *, value, dtype):
return self.cast(
str(value), to=dt.float32 if dtype.is_decimal() else dtype
)
if dtype.is_floating() or dtype.is_integer():
return sge.convert(value)
return self.cast(value, dtype)
elif dtype.is_time():
return self.f.make_time(
@@ -401,16 +453,16 @@ def visit_NonNullLiteral(self, op, *, value, dtype):

return self.f[funcname](*args)
elif dtype.is_struct():
return sge.Struct.from_arg_list(
[
sge.PropertyEQ(
this=sg.to_identifier(k, quoted=self.quoted),
expression=self.visit_Literal(
return self.cast(
sge.Struct.from_arg_list(
[
self.visit_Literal(
ops.Literal(v, field_dtype), value=v, dtype=field_dtype
),
)
for field_dtype, (k, v) in zip(dtype.types, value.items())
]
)
for field_dtype, v in zip(dtype.types, value.values())
]
),
op.dtype,
)
else:
return None
@@ -460,15 +512,17 @@ def visit_RegexReplace(self, op, *, arg, pattern, replacement):
arg, pattern, replacement, "g", dialect=self.dialect
)

def visit_First(self, op, *, arg, where):
cond = arg.is_(sg.not_(NULL, copy=False))
where = cond if where is None else sge.And(this=cond, expression=where)
return self.agg.first(arg, where=where)
def visit_First(self, op, *, arg, where, order_by, include_null):
if not include_null:
cond = arg.is_(sg.not_(NULL, copy=False))
where = cond if where is None else sge.And(this=cond, expression=where)
return self.agg.first(arg, where=where, order_by=order_by)

def visit_Last(self, op, *, arg, where):
cond = arg.is_(sg.not_(NULL, copy=False))
where = cond if where is None else sge.And(this=cond, expression=where)
return self.agg.last(arg, where=where)
def visit_Last(self, op, *, arg, where, order_by, include_null):
if not include_null:
cond = arg.is_(sg.not_(NULL, copy=False))
where = cond if where is None else sge.And(this=cond, expression=where)
return self.agg.last(arg, where=where, order_by=order_by)

def visit_Quantile(self, op, *, arg, quantile, where):
suffix = "cont" if op.arg.dtype.is_numeric() else "disc"
@@ -598,3 +652,6 @@ def visit_TableUnnest(
.from_(parent)
.join(unnest, join_type="CROSS" if not keep_empty else "LEFT")
)


compiler = DuckDBCompiler()
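
As a usage illustration of the DuckDB changes above (the AggGen with supports_order_by=True and the new visit_ArrayCollect), the sketch below shows how an order-sensitive collect is expected to compile. It assumes the user-facing collect API accepts the order_by argument this PR threads through the compilers; the events table and its columns are purely illustrative.

import ibis

events = ibis.table(
    {"user": "string", "ts": "timestamp", "tag": "string"}, name="events"
)
expr = events.group_by("user").agg(
    # order-sensitive aggregation; DuckDB should emit ARRAY_AGG(tag ORDER BY ts),
    # with nulls filtered out while include_null stays at its default
    tags=events.tag.collect(order_by=events.ts),
)
print(ibis.to_sql(expr, dialect="duckdb"))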
29 changes: 26 additions & 3 deletions ibis/backends/sql/compilers/exasol.py
@@ -35,7 +35,6 @@ class ExasolCompiler(SQLGlotCompiler):
ops.ApproxMedian,
ops.ArgMax,
ops.ArgMin,
ops.ArrayCollect,
ops.ArrayDistinct,
ops.ArrayFilter,
ops.ArrayFlatten,
@@ -88,8 +87,6 @@ class ExasolCompiler(SQLGlotCompiler):
ops.Log10: "log10",
ops.All: "min",
ops.Any: "max",
ops.First: "first_value",
ops.Last: "last_value",
}

@staticmethod
@@ -128,6 +125,29 @@ def visit_NonNullLiteral(self, op, *, value, dtype):
def visit_Date(self, op, *, arg):
return self.cast(arg, dt.date)

def visit_GroupConcat(self, op, *, arg, sep, where, order_by):
if where is not None:
arg = self.if_(where, arg, NULL)

if order_by:
arg = sge.Order(this=arg, expressions=order_by)

return sge.GroupConcat(this=arg, separator=sep)

def visit_First(self, op, *, arg, where, order_by, include_null):
if include_null:
raise com.UnsupportedOperationError(
"`include_null=True` is not supported by the exasol backend"
)
return self.agg.first_value(arg, where=where, order_by=order_by)

def visit_Last(self, op, *, arg, where, order_by, include_null):
if include_null:
raise com.UnsupportedOperationError(
"`include_null=True` is not supported by the exasol backend"
)
return self.agg.last_value(arg, where=where, order_by=order_by)

def visit_StartsWith(self, op, *, arg, start):
return self.f.left(arg, self.f.length(start)).eq(start)

@@ -241,3 +261,6 @@ def visit_BitwiseOr(self, op, *, left, right):

def visit_BitwiseXor(self, op, *, left, right):
return self.cast(self.f.bit_xor(left, right), op.dtype)


compiler = ExasolCompiler()
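
The Exasol visit_GroupConcat above builds a sqlglot GroupConcat node whose `this` is an Order expression when an ordering is given. A minimal sketch of that shape follows; it renders through sqlglot's stock mysql dialect only because Exasol's dialect lives inside ibis, and the column names are invented.

import sqlglot.expressions as sge

arg = sge.column("s")
ordered = sge.Order(this=arg, expressions=[sge.Ordered(this=sge.column("rank"))])
expr = sge.GroupConcat(this=ordered, separator=sge.convert(","))
# prints roughly: GROUP_CONCAT(s ORDER BY rank SEPARATOR ',')
print(expr.sql(dialect="mysql"))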
28 changes: 24 additions & 4 deletions ibis/backends/sql/compilers/flink.py
@@ -19,7 +19,13 @@


class FlinkAggGen(AggGen):
def aggregate(self, compiler, name, *args, where=None):
def aggregate(self, compiler, name, *args, where=None, order_by=()):
if order_by:
raise com.UnsupportedOperationError(
"ordering of order-sensitive aggregations via `order_by` is "
"not supported for this backend"
)

func = compiler.f[name]
if where is not None:
# Flink does support FILTER, but it's broken for:
@@ -65,7 +71,6 @@ class FlinkCompiler(SQLGlotCompiler):
ops.ApproxMedian,
ops.ArgMax,
ops.ArgMin,
ops.ArrayCollect,
ops.ArrayFlatten,
ops.ArraySort,
ops.ArrayStringJoin,
@@ -99,8 +104,6 @@ class FlinkCompiler(SQLGlotCompiler):
ops.ArrayRemove: "array_remove",
ops.ArrayUnion: "array_union",
ops.ExtractDayOfYear: "dayofyear",
ops.First: "first_value",
ops.Last: "last_value",
ops.MapKeys: "map_keys",
ops.MapValues: "map_values",
ops.Power: "power",
@@ -302,6 +305,20 @@ def visit_ArraySlice(self, op, *, arg, start, stop):

return self.f.array_slice(*args)

def visit_First(self, op, *, arg, where, order_by, include_null):
if include_null:
raise com.UnsupportedOperationError(
"`include_null=True` is not supported by the flink backend"
)
return self.agg.first_value(arg, where=where, order_by=order_by)

def visit_Last(self, op, *, arg, where, order_by, include_null):
if include_null:
raise com.UnsupportedOperationError(
"`include_null=True` is not supported by the flink backend"
)
return self.agg.last_value(arg, where=where, order_by=order_by)

def visit_Not(self, op, *, arg):
return sg.not_(self.cast(arg, dt.boolean))

@@ -557,3 +574,6 @@ def visit_MapMerge(self, op: ops.MapMerge, *, left, right):

def visit_StructColumn(self, op, *, names, values):
return self.cast(sge.Struct(expressions=list(values)), op.dtype)


compiler = FlinkCompiler()
6 changes: 3 additions & 3 deletions ibis/backends/sql/compilers/impala.py
@@ -26,14 +26,11 @@ class ImpalaCompiler(SQLGlotCompiler):
UNSUPPORTED_OPS = (
ops.ArgMax,
ops.ArgMin,
ops.ArrayCollect,
ops.ArrayPosition,
ops.Array,
ops.Covariance,
ops.DateDelta,
ops.ExtractDayOfYear,
ops.First,
ops.Last,
ops.Levenshtein,
ops.Map,
ops.Median,
@@ -320,3 +317,6 @@ def visit_Sign(self, op, *, arg):
if not dtype.is_float32():
return self.cast(sign, dtype)
return sign


compiler = ImpalaCompiler()
69 changes: 50 additions & 19 deletions ibis/backends/sql/compilers/mssql.py
@@ -1,6 +1,7 @@
from __future__ import annotations

import calendar
from typing import TYPE_CHECKING, Any

import sqlglot as sg
import sqlglot.expressions as sge
@@ -13,6 +14,7 @@
NULL,
STAR,
TRUE,
AggGen,
SQLGlotCompiler,
)
from ibis.backends.sql.datatypes import MSSQLType
@@ -25,6 +27,11 @@
)
from ibis.common.deferred import var

if TYPE_CHECKING:
from collections.abc import Mapping

import ibis.expr.types as ir

y = var("y")
start = var("start")
end = var("end")
@@ -52,6 +59,8 @@ def rewrite_rows_range_order_by_window(_, **kwargs):
class MSSQLCompiler(SQLGlotCompiler):
__slots__ = ()

agg = AggGen(supports_order_by=True)

dialect = MSSQL
type_mapper = MSSQLType
rewrites = (
@@ -66,7 +75,6 @@ class MSSQLCompiler(SQLGlotCompiler):
ops.ApproxMedian,
ops.ArgMax,
ops.ArgMin,
ops.ArrayCollect,
ops.Array,
ops.ArrayDistinct,
ops.ArrayFlatten,
@@ -82,14 +90,12 @@ class MSSQLCompiler(SQLGlotCompiler):
ops.DateDiff,
ops.DateSub,
ops.EndsWith,
ops.First,
ops.IntervalAdd,
ops.IntervalFromInteger,
ops.IntervalMultiply,
ops.IntervalSubtract,
ops.IsInf,
ops.IsNan,
ops.Last,
ops.LPad,
ops.Levenshtein,
ops.Map,
@@ -130,17 +136,9 @@ class MSSQLCompiler(SQLGlotCompiler):
ops.Max: "max",
}

@property
def NAN(self):
return self.f.double("NaN")

@property
def POS_INF(self):
return self.f.double("Infinity")

@property
def NEG_INF(self):
return self.f.double("-Infinity")
NAN = sg.func("double", sge.convert("NaN"))
POS_INF = sg.func("double", sge.convert("Infinity"))
NEG_INF = sg.func("double", sge.convert("-Infinity"))

@staticmethod
def _generate_groups(groups):
@@ -157,7 +155,28 @@ def _minimize_spec(start, end, spec):
return None
return spec

def visit_RandomUUID(self, op, **kwargs):
def to_sqlglot(
self,
expr: ir.Expr,
*,
limit: str | None = None,
params: Mapping[ir.Expr, Any] | None = None,
):
"""Compile an Ibis expression to a sqlglot object."""
import ibis

table_expr = expr.as_table()
conversions = {
name: ibis.ifelse(table_expr[name], 1, 0).cast(dt.boolean)
for name, typ in table_expr.schema().items()
if typ.is_boolean()
}

if conversions:
table_expr = table_expr.mutate(**conversions)
return super().to_sqlglot(table_expr, limit=limit, params=params)

def visit_RandomUUID(self, op, **_):
return self.f.newid()

def visit_StringLength(self, op, *, arg):
@@ -185,10 +204,16 @@ def visit_Substring(self, op, *, arg, start, length):
length = self.f.length(arg)
return self.f.substring(arg, start, length)

def visit_GroupConcat(self, op, *, arg, sep, where):
def visit_GroupConcat(self, op, *, arg, sep, where, order_by):
if where is not None:
arg = self.if_(where, arg, NULL)
return self.f.group_concat(arg, sep)

out = self.f.group_concat(arg, sep)

if order_by:
out = sge.WithinGroup(this=out, expression=sge.Order(expressions=order_by))

return out

def visit_CountStar(self, op, *, arg, where):
if where is not None:
@@ -452,9 +477,9 @@ def visit_All(self, op, *, arg, where):
arg = self.if_(where, arg, NULL)
return sge.Min(this=arg)

def visit_Select(self, op, *, parent, selections, predicates, sort_keys):
def visit_Select(self, op, *, parent, selections, predicates, qualified, sort_keys):
# if we've constructed a useless projection return the parent relation
if not selections and not predicates and not sort_keys:
if not (selections or predicates or qualified or sort_keys):
return parent

result = parent
@@ -467,7 +492,13 @@ def visit_Select(self, op, *, parent, selections, predicates, sort_keys):
if predicates:
result = result.where(*predicates, copy=True)

if qualified:
result = result.qualify(*qualified, copy=True)

if sort_keys:
result = result.order_by(*sort_keys, copy=False)

return result


compiler = MSSQLCompiler()
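
For context on the new MSSQL to_sqlglot hook: SQL Server has no boolean value type in result sets, so every boolean column is rewritten through ifelse(col, 1, 0) before compilation. A hedged sketch of how that surfaces, using an invented table:

import ibis

t = ibis.table({"flag": "boolean", "x": "int64"}, name="t")
expr = t.mutate(is_big=t.x > 10)
# both boolean columns (flag, is_big) should be wrapped via ifelse(col, 1, 0)
# before the final SELECT is emitted
print(ibis.to_sql(expr, dialect="mssql"))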
30 changes: 24 additions & 6 deletions ibis/backends/sql/compilers/mysql.py
@@ -67,13 +67,10 @@ def POS_INF(self):
ops.ApproxMedian,
ops.ArgMax,
ops.ArgMin,
ops.ArrayCollect,
ops.Array,
ops.ArrayFlatten,
ops.ArrayMap,
ops.Covariance,
ops.First,
ops.Last,
ops.Levenshtein,
ops.Median,
ops.Mode,
@@ -165,14 +162,19 @@ def visit_CountDistinctStar(self, op, *, arg, where):
sge.Distinct(expressions=list(map(func, op.arg.schema.keys())))
)

def visit_GroupConcat(self, op, *, arg, sep, where):
def visit_GroupConcat(self, op, *, arg, sep, where, order_by):
if not isinstance(op.sep, ops.Literal):
raise com.UnsupportedOperationError(
"Only string literal separators are supported"
)

if where is not None:
arg = self.if_(where, arg)
return self.f.group_concat(arg, sep)
arg = self.if_(where, arg, NULL)

if order_by:
arg = sge.Order(this=arg, expressions=order_by)

return sge.GroupConcat(this=arg, separator=sep)

def visit_DayOfWeekIndex(self, op, *, arg):
return (self.f.dayofweek(arg) + 5) % 7
@@ -285,6 +287,19 @@ def visit_LRStrip(self, op, *, arg, position):
)

def visit_DateTimestampTruncate(self, op, *, arg, unit):
if unit.short == "Q":
# adapted from https://stackoverflow.com/a/11884743
return (
# January 1 of the year of the `arg`
self.f.makedate(self.f.year(arg), 1)
# add the current quarter's number of quarters minus one to Jan 1
# first quarter: add zero
# second quarter: add one
# third quarter: add two
# fourth quarter: add three
+ sge.Interval(this=self.f.quarter(arg) - 1, unit=self.v.QUARTER)
)

truncate_formats = {
"s": "%Y-%m-%d %H:%i:%s",
"m": "%Y-%m-%d %H:%i:00",
@@ -359,3 +374,6 @@ def visit_UnwrapJSONBoolean(self, op, *, arg):
self.if_(arg.eq(sge.convert("true")), 1, 0),
NULL,
)


compiler = MySQLCompiler()
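
The quarter branch added to visit_DateTimestampTruncate computes a quarter's first day as January 1 plus (quarter - 1) quarter intervals. A small pure-Python check of the same arithmetic; the helper name is only for illustration.

from datetime import date

def quarter_start(d: date) -> date:
    # mirrors MAKEDATE(YEAR(d), 1) + INTERVAL (QUARTER(d) - 1) QUARTER
    quarter = (d.month - 1) // 3 + 1          # QUARTER(d)
    return date(d.year, 3 * (quarter - 1) + 1, 1)

assert quarter_start(date(2024, 5, 17)) == date(2024, 4, 1)   # Q2
assert quarter_start(date(2024, 11, 2)) == date(2024, 10, 1)  # Q4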
22 changes: 18 additions & 4 deletions ibis/backends/sql/compilers/oracle.py
@@ -6,7 +6,7 @@

import ibis.common.exceptions as com
import ibis.expr.operations as ops
from ibis.backends.sql.compilers.base import NULL, STAR, SQLGlotCompiler
from ibis.backends.sql.compilers.base import NULL, STAR, AggGen, SQLGlotCompiler
from ibis.backends.sql.datatypes import OracleType
from ibis.backends.sql.dialects import Oracle
from ibis.backends.sql.rewrites import (
@@ -23,6 +23,8 @@
class OracleCompiler(SQLGlotCompiler):
__slots__ = ()

agg = AggGen(supports_order_by=True)

dialect = Oracle
type_mapper = OracleType
rewrites = (
@@ -49,13 +51,10 @@ class OracleCompiler(SQLGlotCompiler):
UNSUPPORTED_OPS = (
ops.ArgMax,
ops.ArgMin,
ops.ArrayCollect,
ops.Array,
ops.ArrayFlatten,
ops.ArrayMap,
ops.ArrayStringJoin,
ops.First,
ops.Last,
ops.Mode,
ops.MultiQuantile,
ops.RegexSplit,
@@ -329,6 +328,7 @@ def visit_Xor(self, op, *, left, right):
def visit_DateTruncate(self, op, *, arg, unit):
trunc_unit_mapping = {
"Y": "year",
"Q": "Q",
"M": "MONTH",
"W": "IW",
"D": "DDD",
@@ -445,3 +445,17 @@ def visit_StringConcat(self, op, *, arg):

def visit_ExtractIsoYear(self, op, *, arg):
return self.cast(self.f.to_char(arg, "IYYY"), op.dtype)

def visit_GroupConcat(self, op, *, arg, where, sep, order_by):
if where is not None:
arg = self.if_(where, arg, NULL)

out = self.f.listagg(arg, sep)

if order_by:
out = sge.WithinGroup(this=out, expression=sge.Order(expressions=order_by))

return out


compiler = OracleCompiler()
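
The Oracle visit_GroupConcat above wraps LISTAGG in WITHIN GROUP when an ordering is supplied. A usage sketch, assuming group_concat grows the order_by argument added elsewhere in this PR; the table and column names are invented.

import ibis

t = ibis.table({"g": "string", "s": "string", "rank": "int64"}, name="t")
expr = t.group_by("g").agg(joined=t.s.group_concat(",", order_by=t.rank))
# expected shape: LISTAGG(s, ',') WITHIN GROUP (ORDER BY rank)
print(ibis.to_sql(expr, dialect="oracle"))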
133 changes: 128 additions & 5 deletions ibis/backends/sql/compilers/postgres.py
@@ -1,7 +1,11 @@
from __future__ import annotations

import inspect
import string
import textwrap
from functools import partial, reduce
from itertools import takewhile
from typing import TYPE_CHECKING, Any

import sqlglot as sg
import sqlglot.expressions as sge
@@ -13,9 +17,20 @@
from ibis.backends.sql.compilers.base import NULL, STAR, AggGen, SQLGlotCompiler
from ibis.backends.sql.datatypes import PostgresType
from ibis.backends.sql.dialects import Postgres
from ibis.backends.sql.rewrites import exclude_nulls_from_array_collect
from ibis.common.exceptions import InvalidDecoratorError
from ibis.util import gen_name

if TYPE_CHECKING:
from collections.abc import Mapping

import ibis.expr.types as ir


def _verify_source_line(func_name: str, line: str):
if line.startswith("@"):
raise InvalidDecoratorError(func_name, line)
return line


class PostgresUDFNode(ops.Value):
shape = rlz.shape_like("args")
@@ -27,9 +42,7 @@ class PostgresCompiler(SQLGlotCompiler):
dialect = Postgres
type_mapper = PostgresType

rewrites = (exclude_nulls_from_array_collect, *SQLGlotCompiler.rewrites)

agg = AggGen(supports_filter=True)
agg = AggGen(supports_filter=True, supports_order_by=True)

NAN = sge.Literal.number("'NaN'::double precision")
POS_INF = sge.Literal.number("'Inf'::double precision")
@@ -43,7 +56,6 @@ class PostgresCompiler(SQLGlotCompiler):

SIMPLE_OPS = {
ops.Arbitrary: "first", # could use any_value for postgres>=16
ops.ArrayCollect: "array_agg",
ops.ArrayRemove: "array_remove",
ops.BitAnd: "bit_and",
ops.BitOr: "bit_or",
@@ -100,6 +112,64 @@ class PostgresCompiler(SQLGlotCompiler):
ops.TimeFromHMS: "make_time",
}

def to_sqlglot(
self,
expr: ir.Expr,
*,
limit: str | None = None,
params: Mapping[ir.Expr, Any] | None = None,
):
table_expr = expr.as_table()
geocols = table_expr.schema().geospatial
conversions = {name: table_expr[name].as_ewkb() for name in geocols}

if conversions:
table_expr = table_expr.mutate(**conversions)
return super().to_sqlglot(table_expr, limit=limit, params=params)

def _compile_python_udf(self, udf_node: ops.ScalarUDF):
config = udf_node.__config__
func = udf_node.__func__
func_name = func.__name__

lines, _ = inspect.getsourcelines(func)
iter_lines = iter(lines)

function_preamble_lines = list(
takewhile(lambda line: not line.lstrip().startswith("def "), iter_lines)
)

if len(function_preamble_lines) > 1:
raise InvalidDecoratorError(
name=func_name, lines="".join(function_preamble_lines)
)

source = textwrap.dedent(
"".join(map(partial(_verify_source_line, func_name), iter_lines))
).strip()

type_mapper = self.type_mapper
argnames = udf_node.argnames
return """\
CREATE OR REPLACE FUNCTION {ident}({signature})
RETURNS {return_type}
LANGUAGE {language}
AS $$
{source}
return {name}({args})
$$""".format(
name=type(udf_node).__name__,
ident=self.__sql_name__(udf_node),
signature=", ".join(
f"{argname} {type_mapper.to_string(arg.dtype)}"
for argname, arg in zip(argnames, udf_node.args)
),
return_type=type_mapper.to_string(udf_node.dtype),
language=config.get("language", "plpython3u"),
source=source,
args=", ".join(argnames),
)

def visit_RandomUUID(self, op, **kwargs):
return self.f.gen_random_uuid()

@@ -285,6 +355,26 @@ def visit_ArrayIntersect(self, op, *, left, right):
)
)

def visit_ArrayCollect(self, op, *, arg, where, order_by, include_null):
if not include_null:
cond = arg.is_(sg.not_(NULL, copy=False))
where = cond if where is None else sge.And(this=cond, expression=where)
return self.agg.array_agg(arg, where=where, order_by=order_by)

def visit_First(self, op, *, arg, where, order_by, include_null):
if include_null:
raise com.UnsupportedOperationError(
"`include_null=True` is not supported by the postgres backend"
)
return self.agg.first(arg, where=where, order_by=order_by)

def visit_Last(self, op, *, arg, where, order_by, include_null):
if include_null:
raise com.UnsupportedOperationError(
"`include_null=True` is not supported by the postgres backend"
)
return self.agg.last(arg, where=where, order_by=order_by)

def visit_Log2(self, op, *, arg):
return self.cast(
self.f.log(
@@ -670,3 +760,36 @@ def visit_TableUnnest(
join_type="CROSS" if not keep_empty else "LEFT",
)
)

def _unnest(self, expression, *, as_):
alias = sge.TableAlias(columns=[sg.to_identifier(as_)])
return sge.Unnest(expressions=[expression], alias=alias)

def _array_reduction(self, *, arg, reduction):
name = sg.to_identifier(gen_name(f"pg_arr_{reduction}"))
return (
sg.select(self.f[reduction](name))
.from_(self._unnest(arg, as_=name))
.subquery()
)

def visit_ArrayMin(self, op, *, arg):
return self._array_reduction(arg=arg, reduction="min")

def visit_ArrayMax(self, op, *, arg):
return self._array_reduction(arg=arg, reduction="max")

def visit_ArraySum(self, op, *, arg):
return self._array_reduction(arg=arg, reduction="sum")

def visit_ArrayMean(self, op, *, arg):
return self._array_reduction(arg=arg, reduction="avg")

def visit_ArrayAny(self, op, *, arg):
return self._array_reduction(arg=arg, reduction="bool_or")

def visit_ArrayAll(self, op, *, arg):
return self._array_reduction(arg=arg, reduction="bool_and")


compiler = PostgresCompiler()
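
The new Postgres array reductions compile an array into a scalar subquery over UNNEST. The sqlglot sketch below reproduces the rough shape _array_reduction aims for; the column and table names are invented, and the real helper uses a gen_name-based alias instead of the fixed one shown here.

import sqlglot as sg
import sqlglot.expressions as sge

elem = sg.to_identifier("elem")
arr = sg.column("vals", table="t")
unnest = sge.Unnest(expressions=[arr], alias=sge.TableAlias(columns=[elem]))
# rough target: (SELECT AVG(elem) FROM UNNEST(t.vals) AS _(elem))
query = sg.select(sg.func("avg", elem)).from_(unnest).subquery()
print(query.sql(dialect="postgres"))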
89 changes: 80 additions & 9 deletions ibis/backends/sql/compilers/pyspark.py
@@ -2,6 +2,7 @@

import calendar
import itertools
import operator
import re

import sqlglot as sg
@@ -68,6 +69,10 @@ class PySparkCompiler(SQLGlotCompiler):
ops.ArrayRemove: "array_remove",
ops.ArraySort: "array_sort",
ops.ArrayUnion: "array_union",
ops.ArrayMin: "array_min",
ops.ArrayMax: "array_max",
ops.ArrayAll: "array_min",
ops.ArrayAny: "array_max",
ops.EndsWith: "endswith",
ops.Hash: "hash",
ops.Log10: "log10",
@@ -235,15 +240,27 @@ def visit_FirstValue(self, op, *, arg):
def visit_LastValue(self, op, *, arg):
return sge.IgnoreNulls(this=self.f.last(arg))

def visit_First(self, op, *, arg, where):
if where is not None:
arg = self.if_(where, arg, NULL)
return sge.IgnoreNulls(this=self.f.first(arg))
def visit_First(self, op, *, arg, where, order_by, include_null):
if where is not None and include_null:
raise com.UnsupportedOperationError(
"Combining `include_null=True` and `where` is not supported "
"by pyspark"
)
out = self.agg.first(arg, where=where, order_by=order_by)
if not include_null:
out = sge.IgnoreNulls(this=out)
return out

def visit_Last(self, op, *, arg, where):
if where is not None:
arg = self.if_(where, arg, NULL)
return sge.IgnoreNulls(this=self.f.last(arg))
def visit_Last(self, op, *, arg, where, order_by, include_null):
if where is not None and include_null:
raise com.UnsupportedOperationError(
"Combining `include_null=True` and `where` is not supported "
"by pyspark"
)
out = self.agg.last(arg, where=where, order_by=order_by)
if not include_null:
out = sge.IgnoreNulls(this=out)
return out

def visit_Arbitrary(self, op, *, arg, where):
# For Spark>=3.4 we could use any_value here
@@ -254,7 +271,12 @@ def visit_Arbitrary(self, op, *, arg, where):
def visit_Median(self, op, *, arg, where):
return self.agg.percentile(arg, 0.5, where=where)

def visit_GroupConcat(self, op, *, arg, sep, where):
def visit_GroupConcat(self, op, *, arg, sep, where, order_by):
if order_by:
raise com.UnsupportedOperationError(
"ordering of order-sensitive aggregations via `order_by` is "
"not supported for this backend"
)
if where is not None:
arg = self.if_(where, arg, NULL)
collected = self.f.collect_list(arg)
@@ -391,6 +413,13 @@ def visit_ArrayContains(self, op, *, arg, other):
def visit_ArrayStringJoin(self, op, *, arg, sep):
return self.f.concat_ws(sep, arg)

def visit_ArrayCollect(self, op, *, arg, where, order_by, include_null):
if include_null:
raise com.UnsupportedOperationError(
"`include_null=True` is not supported by the pyspark backend"
)
return self.agg.array_agg(arg, where=where, order_by=order_by)

def visit_StringFind(self, op, *, arg, substr, start, end):
if end is not None:
raise com.UnsupportedOperationError(
@@ -589,3 +618,45 @@ def _format_window_interval(self, expression):
this = expression.this.this # avoid quoting the interval as a string literal

return f"{this}{unit}"

def _array_reduction(self, *, dtype, arg, output):
quoted = self.quoted
dot = lambda a, f: sge.Dot.build((a, sge.to_identifier(f, quoted=quoted)))
state_dtype = dt.Struct({"sum": dtype, "count": dt.int64})
initial_state = self.cast(
sge.Struct.from_arg_list([sge.convert(0), sge.convert(0)]), state_dtype
)

s = sg.to_identifier("s", quoted=quoted)
x = sg.to_identifier("x", quoted=quoted)

s_sum = dot(s, "sum")
s_count = dot(s, "count")

input_fn_body = self.cast(
sge.Struct.from_arg_list(
[
x + self.f.coalesce(s_sum, 0),
s_count + self.if_(x.is_(sg.not_(NULL)), 1, 0),
]
),
state_dtype,
)
input_fn = sge.Lambda(this=input_fn_body, expressions=[s, x])

output_fn_body = self.if_(s_count > 0, output(s_sum, s_count), NULL)
return self.f.aggregate(
arg,
initial_state,
input_fn,
sge.Lambda(this=output_fn_body, expressions=[s]),
)

def visit_ArraySum(self, op, *, arg):
return self._array_reduction(dtype=op.dtype, arg=arg, output=lambda sum, _: sum)

def visit_ArrayMean(self, op, *, arg):
return self._array_reduction(dtype=op.dtype, arg=arg, output=operator.truediv)


compiler = PySparkCompiler()
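
Because Spark has no list_sum/list_avg style functions, _array_reduction above folds the array with Spark's aggregate higher-order function, carrying a (sum, count) struct as state. A pure-Python model of the null-skipping semantics the fold is aiming for; the function name is invented.

def array_mean(arr):
    total, count = 0, 0
    for x in arr:
        if x is not None:      # nulls add nothing to the sum or the count
            total, count = total + x, count + 1
    # mirrors the finisher: NULL when no non-null elements were seen
    return total / count if count else None

assert array_mean([1, 2, None, 3]) == 2.0
assert array_mean([None, None]) is None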
30 changes: 25 additions & 5 deletions ibis/backends/sql/compilers/risingwave.py
@@ -32,14 +32,31 @@ class RisingWaveCompiler(PostgresCompiler):
),
)

SIMPLE_OPS = {
ops.First: "first_value",
ops.Last: "last_value",
}

def visit_DateNow(self, op):
return self.cast(sge.CurrentTimestamp(), dt.date)

def visit_First(self, op, *, arg, where, order_by, include_null):
if include_null:
raise com.UnsupportedOperationError(
"`include_null=True` is not supported by the risingwave backend"
)
if not order_by:
raise com.UnsupportedOperationError(
"RisingWave requires an `order_by` be specified in `first`"
)
return self.agg.first_value(arg, where=where, order_by=order_by)

def visit_Last(self, op, *, arg, where, order_by, include_null):
if include_null:
raise com.UnsupportedOperationError(
"`include_null=True` is not supported by the risingwave backend"
)
if not order_by:
raise com.UnsupportedOperationError(
"RisingWave requires an `order_by` be specified in `last`"
)
return self.agg.last_value(arg, where=where, order_by=order_by)

def visit_Correlation(self, op, *, left, right, how, where):
if how == "sample":
raise com.UnsupportedOperationError(
@@ -86,3 +103,6 @@ def visit_NonNullLiteral(self, op, *, value, dtype):
elif dtype.is_json():
return sge.convert(str(value))
return None


compiler = RisingWaveCompiler()
188 changes: 175 additions & 13 deletions ibis/backends/sql/compilers/snowflake.py
@@ -1,6 +1,10 @@
from __future__ import annotations

import inspect
import itertools
import platform
import sys
import textwrap
from functools import partial

import sqlglot as sg
@@ -10,7 +14,14 @@
import ibis.expr.datatypes as dt
import ibis.expr.operations as ops
from ibis import util
from ibis.backends.sql.compilers.base import NULL, STAR, C, FuncGen, SQLGlotCompiler
from ibis.backends.sql.compilers.base import (
NULL,
STAR,
AggGen,
C,
FuncGen,
SQLGlotCompiler,
)
from ibis.backends.sql.datatypes import SnowflakeType
from ibis.backends.sql.dialects import Snowflake
from ibis.backends.sql.rewrites import (
@@ -29,9 +40,15 @@ class SnowflakeFuncGen(FuncGen):
class SnowflakeCompiler(SQLGlotCompiler):
__slots__ = ()

latest_udf_python_version = (3, 11)

dialect = Snowflake
type_mapper = SnowflakeType
no_limit_value = NULL
supports_qualify = True

agg = AggGen(supports_order_by=True)

rewrites = (
exclude_unsupported_window_frame_from_row_number,
exclude_unsupported_window_frame_from_ops,
@@ -85,6 +102,94 @@ def __init__(self):
super().__init__()
self.f = SnowflakeFuncGen()

_UDF_TEMPLATES = {
ops.udf.InputType.PYTHON: """\
{preamble}
HANDLER = '{func_name}'
AS $$
from __future__ import annotations
from typing import *
{source}
$$""",
ops.udf.InputType.PANDAS: """\
{preamble}
HANDLER = 'wrapper'
AS $$
from __future__ import annotations
from typing import *
import _snowflake
import pandas as pd
{source}
@_snowflake.vectorized(input=pd.DataFrame)
def wrapper(df):
return {func_name}(*(col for _, col in df.items()))
$$""",
}

_UDF_PREAMBLE_LINES = (
"CREATE OR REPLACE TEMPORARY FUNCTION {name}({signature})",
"RETURNS {return_type}",
"LANGUAGE PYTHON",
"IMMUTABLE",
"RUNTIME_VERSION = '{version}'",
"COMMENT = '{comment}'",
)

def _compile_udf(self, udf_node: ops.ScalarUDF):
import ibis

name = type(udf_node).__name__
signature = ", ".join(
f"{name} {self.type_mapper.to_string(arg.dtype)}"
for name, arg in zip(udf_node.argnames, udf_node.args)
)
return_type = SnowflakeType.to_string(udf_node.dtype)
lines, _ = inspect.getsourcelines(udf_node.__func__)
source = textwrap.dedent(
"".join(
itertools.dropwhile(
lambda line: not line.lstrip().startswith("def "), lines
)
)
).strip()

config = udf_node.__config__

preamble_lines = [*self._UDF_PREAMBLE_LINES]

if imports := config.get("imports"):
preamble_lines.append(f"IMPORTS = ({', '.join(map(repr, imports))})")

packages = "({})".format(
", ".join(map(repr, ("pandas", *config.get("packages", ()))))
)
preamble_lines.append(f"PACKAGES = {packages}")

template = self._UDF_TEMPLATES[udf_node.__input_type__]
return template.format(
source=source,
name=name,
func_name=udf_node.__func_name__,
preamble="\n".join(preamble_lines).format(
name=name,
signature=signature,
return_type=return_type,
comment=f"Generated by ibis {ibis.__version__} using Python {platform.python_version()}",
version=".".join(
map(str, min(sys.version_info[:2], self.latest_udf_python_version))
),
),
)

_compile_pandas_udf = _compile_udf
_compile_python_udf = _compile_udf

@staticmethod
def _minimize_spec(start, end, spec):
if (
Expand Down Expand Up @@ -351,23 +456,53 @@ def visit_TimestampFromUNIX(self, op, *, arg, unit):
timestamp_units_to_scale = {"s": 0, "ms": 3, "us": 6, "ns": 9}
return self.f.to_timestamp(arg, timestamp_units_to_scale[unit.short])

def visit_First(self, op, *, arg, where):
return self.f.get(self.agg.array_agg(arg, where=where), 0)
def _array_collect(self, *, arg, where, order_by, include_null):
if include_null:
raise com.UnsupportedOperationError(
"`include_null=True` is not supported by the snowflake backend"
)

def visit_Last(self, op, *, arg, where):
expr = self.agg.array_agg(arg, where=where)
return self.f.get(expr, self.f.array_size(expr) - 1)
if where is not None:
arg = self.if_(where, arg, NULL)

def visit_GroupConcat(self, op, *, arg, where, sep):
if where is None:
return self.f.listagg(arg, sep)
out = self.f.array_agg(arg)

return self.if_(
self.f.count_if(where) > 0,
self.f.listagg(self.if_(where, arg, NULL), sep),
NULL,
if order_by:
out = sge.WithinGroup(this=out, expression=sge.Order(expressions=order_by))

return out

def visit_ArrayCollect(self, op, *, arg, where, order_by, include_null):
return self._array_collect(
arg=arg, where=where, order_by=order_by, include_null=include_null
)

def visit_First(self, op, *, arg, where, order_by, include_null):
out = self._array_collect(
arg=arg, where=where, order_by=order_by, include_null=include_null
)
return self.f.get(out, 0)

def visit_Last(self, op, *, arg, where, order_by, include_null):
out = self._array_collect(
arg=arg, where=where, order_by=order_by, include_null=include_null
)
return self.f.get(out, self.f.array_size(out) - 1)

def visit_GroupConcat(self, op, *, arg, where, sep, order_by):
if where is not None:
arg = self.if_(where, arg, NULL)

out = self.f.listagg(arg, sep)

if order_by:
out = sge.WithinGroup(this=out, expression=sge.Order(expressions=order_by))

if where is None:
return out

return self.if_(self.f.count_if(where) > 0, out, NULL)

def visit_TimestampBucket(self, op, *, arg, interval, offset):
if offset is not None:
raise com.UnsupportedOperationError(
@@ -436,6 +571,12 @@ def visit_ExtractFragment(self, op, *, arg):
self.f.as_varchar(self.f.get(self.f.parse_url(arg, 1), "fragment")), ""
)

def visit_ExtractUserInfo(self, op, *, arg):
host = self.f.get(self.f.parse_url(arg), "host")
return self.if_(
host.like(sge.convert("%@%")), self.f.split_part(host, "@", 1), NULL
)

def visit_Unnest(self, op, *, arg):
sep = sge.convert(util.guid())
split = self.f.split(
@@ -727,3 +868,24 @@ def visit_TableUnnest(
.from_(parent)
.join(unnest, join_type="CROSS" if not keep_empty else "LEFT")
)

def visit_ArrayMin(self, op, *, arg):
return self.cast(self.f.array_min(self.f.array_compact(arg)), op.dtype)

def visit_ArrayMax(self, op, *, arg):
return self.cast(self.f.array_max(self.f.array_compact(arg)), op.dtype)

def visit_ArrayAny(self, op, *, arg):
return self.f.udf.array_any(arg)

def visit_ArrayAll(self, op, *, arg):
return self.f.udf.array_all(arg)

def visit_ArraySum(self, op, *, arg):
return self.cast(self.f.udf.array_sum(arg), op.dtype)

def visit_ArrayMean(self, op, *, arg):
return self.cast(self.f.udf.array_avg(arg), op.dtype)


compiler = SnowflakeCompiler()
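
The Snowflake UDF support above assembles a CREATE FUNCTION preamble from _UDF_PREAMBLE_LINES and then picks a body template per input type. The snippet below only demonstrates how the preamble lines format together; every value passed to format is made up for display.

preamble_lines = [
    "CREATE OR REPLACE TEMPORARY FUNCTION {name}({signature})",
    "RETURNS {return_type}",
    "LANGUAGE PYTHON",
    "IMMUTABLE",
    "RUNTIME_VERSION = '{version}'",
    "COMMENT = '{comment}'",
]
print(
    "\n".join(preamble_lines).format(
        name="my_add",
        signature="x BIGINT, y BIGINT",
        return_type="BIGINT",
        version="3.11",
        comment="generated by ibis",
    )
)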