Skip to content

Commit

Permalink
feat(backends): add more array functions
Browse files Browse the repository at this point in the history
  • Loading branch information
cpcloud authored and kszucs committed Apr 11, 2023
1 parent 1b93011 commit 5208801
Show file tree
Hide file tree
Showing 10 changed files with 652 additions and 58 deletions.
59 changes: 30 additions & 29 deletions ibis/backends/clickhouse/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@
from ibis.backends.base import BaseBackend
from ibis.backends.clickhouse.compiler import translate
from ibis.backends.clickhouse.datatypes import parse, serialize
from ibis.config import options

if TYPE_CHECKING:
import pandas as pd
Expand Down Expand Up @@ -260,10 +259,22 @@ def _normalize_external_tables(self, external_tables=None):
raise TypeError(f'Schema is empty for external table {name}')

df = obj.data.to_frame()
structure = list(zip(schema.names, map(serialize, schema.types)))
external_tables_list.append(
dict(name=name, data=df.to_dict("records"), structure=structure)
structure = list(
zip(
schema.names,
map(
serialize,
(
# unwrap nested structures because clickhouse does
# not accept nullable arrays, maps or structs
typ.copy(nullable=not typ.is_nested())
for typ in schema.types
),
),
)
)
data = dict(name=name, data=df.to_dict("records"), structure=structure)
external_tables_list.append(data)
return external_tables_list

def _client_execute(self, query, external_tables=None):
Expand Down Expand Up @@ -446,16 +457,12 @@ def close(self):
self._client.disconnect()
self.con.close()

def _fully_qualified_name(self, name, database):
def _fully_qualified_name(self, name: str, database: str | None) -> str:
return sg.table(name, db=database or self.current_database or None).sql(
dialect="clickhouse"
)

def get_schema(
self,
table_name: str,
database: str | None = None,
) -> sch.Schema:
def get_schema(self, table_name: str, database: str | None = None) -> sch.Schema:
"""Return a Schema object for the indicated table and database.
Parameters
Expand All @@ -476,22 +483,16 @@ def get_schema(
f"DESCRIBE {qualified_name}"
)

return sch.Schema.from_tuples(zip(column_names, map(parse, types)))

def _ensure_temp_db_exists(self):
name = (options.clickhouse.temp_db,)
if name not in self.list_databases():
self.create_database(name, force=True)
return sch.Schema(dict(zip(column_names, map(parse, types))))

def _get_schema_using_query(self, query: str) -> sch.Schema:
[(raw_plans,)] = self._client.execute(
f"EXPLAIN json = 1, description = 0, header = 1 {query}"
)
[plan] = json.loads(raw_plans)
fields = [
(field["Name"], parse(field["Type"])) for field in plan["Plan"]["Header"]
]
return sch.Schema.from_tuples(fields)
return sch.Schema(
{field["Name"]: parse(field["Type"]) for field in plan["Plan"]["Header"]}
)

@classmethod
def has_operation(cls, operation: type[ops.Value]) -> bool:
Expand All @@ -502,12 +503,12 @@ def has_operation(cls, operation: type[ops.Value]) -> bool:
def create_database(
self, name: str, *, force: bool = False, engine: str = "Atomic"
) -> None:
self.raw_sql(
f"CREATE DATABASE {'IF NOT EXISTS ' * force}{name} ENGINE = {engine}"
)
if_not_exists = "IF NOT EXISTS " * force
self.raw_sql(f"CREATE DATABASE {if_not_exists}{name} ENGINE = {engine}")

def drop_database(self, name: str, *, force: bool = False) -> None:
self.raw_sql(f"DROP DATABASE {'IF EXISTS ' * force}{name}")
if_exists = "IF EXISTS " * force
self.raw_sql(f"DROP DATABASE {if_exists}{name}")

def truncate_table(self, name: str, database: str | None = None) -> None:
ident = self._fully_qualified_name(name, database)
Expand Down Expand Up @@ -621,16 +622,16 @@ def create_view(
database: str | None = None,
overwrite: bool = False,
) -> ir.Table:
name = ".".join(filter(None, (database, name)))
qualname = self._fully_qualified_name(name, database)
replace = "OR REPLACE " * overwrite
query = self.compile(obj)
code = f"CREATE {replace}VIEW {name} AS {query}"
code = f"CREATE {replace}VIEW {qualname} AS {query}"
self.raw_sql(code)
return self.table(name, database=database)

def drop_view(
self, name: str, *, database: str | None = None, force: bool = False
) -> None:
name = ".".join(filter(None, (database, name)))
if_not_exists = "IF EXISTS " * force
self.raw_sql(f"DROP VIEW {if_not_exists}{name}")
name = self._fully_qualified_name(name, database)
if_exists = "IF EXISTS " * force
self.raw_sql(f"DROP VIEW {if_exists}{name}")
39 changes: 24 additions & 15 deletions ibis/backends/clickhouse/compiler/values.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from typing import Any, Literal, Mapping

import sqlglot as sg
from toolz import flip

import ibis
import ibis.common.exceptions as com
Expand Down Expand Up @@ -1076,6 +1077,12 @@ def formatter(op, **kw):
ops.BitwiseLeftShift: "bitShiftLeft",
ops.BitwiseRightShift: "bitShiftRight",
ops.BitwiseNot: "bitNot",
ops.ArrayDistinct: "arrayDistinct",
ops.ArraySort: "arraySort",
ops.ArrayContains: "has",
ops.FirstValue: "first_value",
ops.LastValue: "last_value",
ops.NTile: "ntile",
}


Expand Down Expand Up @@ -1268,11 +1275,6 @@ def formatter(op, **kw):
shift_like(ops.Lead, "leadInFrame")


@translate_val.register(ops.NTile)
def _ntile(op, **kw):
    """Translate an NTile window op into a ClickHouse ``ntile()`` call."""
    buckets = translate_val(op.buckets, **kw)
    return f"ntile({buckets})"


@translate_val.register(ops.RowNumber)
def _row_number(op, **kw):
    """ROW_NUMBER takes no arguments; emit the bare function call."""
    return "row_number()"
Expand All @@ -1288,16 +1290,6 @@ def _rank(_, **kw):
return "rank()"


@translate_val.register(ops.FirstValue)
def _first_value(op, **kw):
    """Translate FirstValue into a ``first_value()`` window call."""
    arg = translate_val(op.arg, **kw)
    return f"first_value({arg})"


@translate_val.register(ops.LastValue)
def _last_value(op, **kw):
    """Translate LastValue into a ``last_value()`` window call."""
    arg = translate_val(op.arg, **kw)
    return f"last_value({arg})"


@translate_val.register(ops.ExtractProtocol)
def _extract_protocol(op, **kw):
arg = translate_val(op.arg, **kw)
Expand Down Expand Up @@ -1368,3 +1360,20 @@ def _array_filter(op, **kw):
arg = translate_val(op.arg, **kw)
result = translate_val(op.result, **kw)
return f"arrayFilter(({op.parameter}) -> {result}, {arg})"


@translate_val.register(ops.ArrayPosition)
def _array_position(op, **kw):
    """Translate ArrayPosition using ClickHouse ``indexOf``.

    ``indexOf`` is 1-based (0 when the element is absent); subtracting one
    yields ibis's 0-based position (-1 when absent).
    """
    haystack = translate_val(op.arg, **kw)
    needle = translate_val(op.other, **kw)
    return f"indexOf({haystack}, {needle}) - 1"


@translate_val.register(ops.ArrayRemove)
def _array_remove(op, **kw):
    # Rewrite removal as "keep elements != other" and delegate to the
    # ArrayFilter rule. toolz's curried `flip` turns NotEquals into a
    # one-argument callable x -> NotEquals(x, op.other), i.e. the filtered
    # element becomes the left operand.
    return translate_val(ops.ArrayFilter(op.arg, flip(ops.NotEquals, op.other)), **kw)


@translate_val.register(ops.ArrayUnion)
def _array_union(op, **kw):
    """Union as concatenation followed by deduplication, reusing the
    ArrayConcat and ArrayDistinct translation rules."""
    concatenated = ops.ArrayConcat(op.left, op.right)
    return translate_val(ops.ArrayDistinct(concatenated), **kw)
17 changes: 15 additions & 2 deletions ibis/backends/duckdb/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import sqlalchemy as sa
from sqlalchemy.ext.compiler import compiles
from sqlalchemy.sql.functions import GenericFunction
from toolz.curried import flip

import ibis.expr.operations as ops
from ibis.backends.base.sql import alchemy
Expand Down Expand Up @@ -311,6 +312,20 @@ def _map_merge(t, op):
ops.ArrayIndex: _array_index(
index_converter=_neg_idx_to_pos, func=sa.func.list_extract
),
ops.ArrayMap: _array_map,
ops.ArrayFilter: _array_filter,
ops.ArrayContains: fixed_arity(sa.func.list_has, 2),
ops.ArrayPosition: fixed_arity(
lambda lst, el: sa.func.list_indexof(lst, el) - 1, 2
),
ops.ArrayDistinct: fixed_arity(sa.func.list_distinct, 1),
ops.ArraySort: fixed_arity(sa.func.list_sort, 1),
ops.ArrayRemove: lambda t, op: _array_filter(
t, ops.ArrayFilter(op.arg, flip(ops.NotEquals, op.other))
),
ops.ArrayUnion: fixed_arity(
lambda left, right: sa.func.list_distinct(sa.func.list_cat(left, right)), 2
),
ops.DayOfWeekName: unary(sa.func.dayname),
ops.Literal: _literal,
ops.Log2: unary(sa.func.log2),
Expand Down Expand Up @@ -370,8 +385,6 @@ def _map_merge(t, op):
ops.SimpleCase: _simple_case,
ops.StartsWith: fixed_arity(sa.func.prefix, 2),
ops.EndsWith: fixed_arity(sa.func.suffix, 2),
ops.ArrayMap: _array_map,
ops.ArrayFilter: _array_filter,
ops.Argument: lambda _, op: sa.literal_column(op.name),
ops.Unnest: unary(sa.func.unnest),
ops.MapGet: fixed_arity(
Expand Down
42 changes: 42 additions & 0 deletions ibis/backends/pyspark/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -1999,3 +1999,45 @@ def compile_array_string_join(t, op, **kwargs):
arg = t.translate(op.arg, **kwargs)
sep = t.translate(op.sep, raw=True, **kwargs)
return F.concat_ws(sep, arg)


@compiles(ops.ArrayContains)
def compile_array_contains(t, op, **kwargs):
    """Compile ArrayContains with explicit NULL handling.

    A NULL array yields NULL; otherwise ``array_contains`` is used, with a
    NULL result coerced to False.
    """
    array = t.translate(op.arg, **kwargs)
    value = t.translate(op.other, **kwargs)
    contains = F.coalesce(F.array_contains(array, value), F.lit(False))
    return F.when(~F.isnull(array), contains).otherwise(F.lit(None))


@compiles(ops.ArrayPosition)
def compile_array_pos(t, op, **kwargs):
    """Spark's ``array_position`` is 1-based (0 when absent); subtract one
    to match ibis's 0-based indexing (-1 when absent)."""
    haystack = t.translate(op.arg, **kwargs)
    # NOTE(review): raw=True presumably yields a plain Python value rather
    # than a Column here — confirm against the translator.
    needle = t.translate(op.other, raw=True, **kwargs)
    return F.array_position(haystack, needle) - 1


@compiles(ops.ArrayDistinct)
def compile_array_distinct(t, op, **kwargs):
    """Drop duplicate elements via Spark's ``array_distinct``."""
    return F.array_distinct(t.translate(op.arg, **kwargs))


@compiles(ops.ArraySort)
def compile_array_sort(t, op, **kwargs):
    """Sort array elements via Spark's ``array_sort``."""
    return F.array_sort(t.translate(op.arg, **kwargs))


@compiles(ops.ArrayRemove)
def compile_array_remove(t, op, **kwargs):
    """Remove all occurrences of ``op.other`` via Spark's ``array_remove``."""
    source = t.translate(op.arg, **kwargs)
    # NOTE(review): raw=True presumably yields a plain Python value rather
    # than a Column here — confirm against the translator.
    target = t.translate(op.other, raw=True, **kwargs)
    return F.array_remove(source, target)


@compiles(ops.ArrayUnion)
def compile_array_union(t, op, **kwargs):
    """Set union of two arrays via Spark's ``array_union``."""
    lhs = t.translate(op.left, **kwargs)
    rhs = t.translate(op.right, **kwargs)
    return F.array_union(lhs, rhs)
16 changes: 15 additions & 1 deletion ibis/backends/snowflake/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -320,7 +320,6 @@ def _group_concat(t, op):
"",
)
),
# snowflake typeof only accepts VARIANT
ops.ArrayIndex: fixed_arity(sa.func.get, 2),
ops.ArrayLength: fixed_arity(sa.func.array_size, 1),
ops.ArrayConcat: fixed_arity(sa.func.array_cat, 2),
Expand All @@ -333,7 +332,21 @@ def _group_concat(t, op):
sa.func.ifnull(arg, sa.func.parse_json("null")), type_=ARRAY
)
),
ops.ArrayContains: fixed_arity(sa.func.array_contains, 2),
ops.ArrayPosition: fixed_arity(
lambda lst, el: sa.func.array_position(lst, el) - 1, 2
),
ops.ArrayDistinct: fixed_arity(sa.func.array_distinct, 1),
ops.ArrayRemove: fixed_arity(
lambda lst, el: sa.func.array_except(lst, sa.func.array_construct(el)),
2,
),
ops.ArrayUnion: fixed_arity(
lambda left, right: sa.func.array_distinct(sa.func.array_cat(left, right)),
2,
),
ops.StringSplit: fixed_arity(sa.func.split, 2),
# snowflake typeof only accepts VARIANT, so we cast
ops.TypeOf: unary(lambda arg: sa.func.typeof(sa.func.to_variant(arg))),
ops.All: reduction(sa.func.booland_agg),
ops.NotAll: reduction(lambda arg: ~sa.func.booland_agg(arg)),
Expand Down Expand Up @@ -384,6 +397,7 @@ def _group_concat(t, op):
ops.NTile,
# ibis.expr.operations.array
ops.ArrayRepeat,
ops.ArraySort,
# ibis.expr.operations.reductions
ops.MultiQuantile,
# ibis.expr.operations.strings
Expand Down
Loading

0 comments on commit 5208801

Please sign in to comment.