Skip to content

Commit

Permalink
refactor(polars): update the polars backend to use the new relational…
Browse files Browse the repository at this point in the history
… abstractions (#7868)

Bare minimum changes required for the polars backend to work with the
new relational operations. Since polars' join API follows the same
semantics as pandas, I'm using the pandas-specific rewrites here.

In the future we may want to rewrite the compiler similarly to the
pandas one: using `node.map()` and `Dispatched`.
  • Loading branch information
kszucs committed Feb 12, 2024
1 parent db45e41 commit 29b5b53
Show file tree
Hide file tree
Showing 7 changed files with 178 additions and 115 deletions.
10 changes: 5 additions & 5 deletions .github/workflows/ibis-backends.yml
Original file line number Diff line number Diff line change
Expand Up @@ -91,11 +91,11 @@ jobs:
title: Datafusion
extras:
- datafusion
# - name: polars
# title: Polars
# extras:
# - polars
# - deltalake
- name: polars
title: Polars
extras:
- polars
- deltalake
# - name: mysql
# title: MySQL
# services:
Expand Down
1 change: 1 addition & 0 deletions ibis/backends/pandas/rewrites.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,6 +224,7 @@ def split_join_predicates(left, right, predicates, only_equality=True):

@replace(ops.JoinChain)
def rewrite_join(_, **kwargs):
# TODO(kszucs): JoinTable.index can be used as a prefix
prefixes = {}
prefixes[_.first] = prefix = str(len(prefixes))
left = PandasRename.from_prefix(_.first, prefix)
Expand Down
26 changes: 14 additions & 12 deletions ibis/backends/polars/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,13 @@
import ibis.expr.schema as sch
import ibis.expr.types as ir
from ibis.backends.base import BaseBackend, Database
from ibis.backends.pandas.rewrites import (
bind_unbound_table,
replace_parameter,
rewrite_join,
)
from ibis.backends.polars.compiler import translate
from ibis.backends.polars.datatypes import dtype_to_polars, schema_from_polars
from ibis.common.patterns import Replace
from ibis.util import gen_name, normalize_filename

if TYPE_CHECKING:
Expand Down Expand Up @@ -379,20 +383,18 @@ def has_operation(cls, operation: type[ops.Value]) -> bool:
def compile(
self, expr: ir.Expr, params: Mapping[ir.Expr, object] | None = None, **_: Any
):
node = expr.op()
ctx = self._context

if params:
if params is None:
params = dict()
else:
params = {param.op(): value for param, value in params.items()}
rule = Replace(
ops.ScalarParameter,
lambda _: ops.Literal(value=params[_], dtype=_.dtype),
)
node = node.replace(rule)
expr = node.to_expr()

node = expr.as_table().op()
return translate(node, ctx=ctx)
node = node.replace(
rewrite_join | replace_parameter | bind_unbound_table,
context={"params": params, "backend": self},
)

return translate(node, ctx=self._context)

def _get_schema_using_query(self, query: str) -> sch.Schema:
return schema_from_polars(self._context.execute(query).schema)
Expand Down
195 changes: 126 additions & 69 deletions ibis/backends/polars/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,10 @@
import ibis.common.exceptions as com
import ibis.expr.datatypes as dt
import ibis.expr.operations as ops
from ibis.backends.pandas.rewrites import PandasAsofJoin, PandasJoin, PandasRename
from ibis.backends.polars.datatypes import dtype_to_polars, schema_from_polars
from ibis.expr.operations.udf import InputType
from ibis.util import gen_name


def _expr_method(expr, op, methods):
Expand Down Expand Up @@ -59,7 +61,7 @@ def table(op, **_):

@translate.register(ops.DummyTable)
def dummy_table(op, **kw):
selections = [translate(arg, **kw) for arg in op.values]
selections = [translate(arg, **kw) for name, arg in op.values.items()]
return pl.DataFrame().lazy().select(selections)


Expand Down Expand Up @@ -181,7 +183,7 @@ def _cast(op, strict=True, **kw):
return arg.cast(typ, strict=strict)


@translate.register(ops.TableColumn)
@translate.register(ops.Field)
def column(op, **_):
return pl.col(op.name)

Expand All @@ -196,45 +198,41 @@ def sort_key(op, **kw):
return arg.sort(reverse=descending) # pragma: no cover


@translate.register(ops.Selection)
def selection(op, **kw):
lf = translate(op.table, **kw)

if op.predicates:
predicates = map(partial(translate, **kw), op.predicates)
predicate = reduce(operator.and_, predicates)
lf = lf.filter(predicate)
@translate.register(ops.Project)
def project(op, **kw):
lf = translate(op.parent, **kw)

selections = []
unnests = []
for arg in op.selections:
if isinstance(arg, ops.TableNode):
for name in arg.schema.names:
column = ops.TableColumn(table=arg, name=name)
selections.append(translate(column, **kw))
elif (
isinstance(arg, ops.Alias) and isinstance(unnest := arg.arg, ops.Unnest)
) or isinstance(unnest := arg, ops.Unnest):
name = arg.name
for name, arg in op.values.items():
if isinstance(arg, ops.Unnest):
unnests.append(name)
selections.append(translate(unnest.arg, **kw).alias(name))
translated = translate(arg.arg, **kw)
elif isinstance(arg, ops.Value):
selections.append(translate(arg, **kw))
translated = translate(arg, **kw)
else:
raise com.TranslationError(
"Polars backend is unable to compile selection with "
f"operation type of {type(arg)}"
)
selections.append(translated.alias(name))

if selections:
lf = lf.select(selections)

if unnests:
lf = lf.explode(*unnests)

if op.sort_keys:
by = [key.name for key in op.sort_keys]
descending = [key.descending for key in op.sort_keys]
return lf


@translate.register(ops.Sort)
def sort(op, **kw):
lf = translate(op.parent, **kw)

if op.keys:
by = [key.name for key in op.keys]
descending = [key.descending for key in op.keys]
try:
lf = lf.sort(by, descending=descending)
except TypeError: # pragma: no cover
Expand All @@ -243,6 +241,18 @@ def selection(op, **kw):
return lf


@translate.register(ops.Filter)
def filter_(op, **kw):
    """Compile a Filter node: translate the parent relation, then apply
    the conjunction of all predicate expressions (if any) as one filter.
    """
    frame = translate(op.parent, **kw)

    if not op.predicates:
        return frame

    # Fold the translated predicates into a single boolean expression so
    # polars receives exactly one filter condition.
    combined = translate(op.predicates[0], **kw)
    for pred in op.predicates[1:]:
        combined = combined & translate(pred, **kw)

    return frame.filter(combined)


@translate.register(ops.Limit)
def limit(op, **kw):
if (n := op.n) is not None and not isinstance(n, int):
Expand All @@ -251,75 +261,99 @@ def limit(op, **kw):
if not isinstance(offset := op.offset, int):
raise NotImplementedError("Dynamic offset not supported")

lf = translate(op.table, **kw)
lf = translate(op.parent, **kw)
return lf.slice(offset, n)


@translate.register(ops.Aggregation)
@translate.register(ops.Aggregate)
def aggregation(op, **kw):
lf = translate(op.table, **kw)
lf = translate(op.parent, **kw)

if op.predicates:
lf = lf.filter(
reduce(
operator.and_,
map(partial(translate, **kw), op.predicates),
if op.groups:
# project first to handle computed group by columns
lf = (
lf.with_columns(
[translate(arg, **kw).alias(name) for name, arg in op.groups.items()]
)
.group_by(list(op.groups.keys()))
.agg
)

# project first to handle computed group by columns
lf = lf.with_columns([translate(arg, **kw) for arg in op.by])

if op.by:
lf = lf.group_by([pl.col(by.name) for by in op.by]).agg
else:
lf = lf.select

if op.metrics:
metrics = [translate(arg, **kw).alias(arg.name) for arg in op.metrics]
metrics = [translate(arg, **kw).alias(name) for name, arg in op.metrics.items()]
lf = lf(metrics)

return lf


_join_types = {
ops.InnerJoin: "inner",
ops.LeftJoin: "left",
ops.RightJoin: "right",
ops.OuterJoin: "outer",
ops.LeftAntiJoin: "anti",
ops.LeftSemiJoin: "semi",
}
@translate.register(PandasRename)
def rename(op, **kw):
    """Apply the column-rename mapping produced by the pandas join rewrites."""
    return translate(op.parent, **kw).rename(op.mapping)


@translate.register(ops.Join)
@translate.register(PandasJoin)
def join(op, **kw):
how = op.how
left = translate(op.left, **kw)
right = translate(op.right, **kw)

if isinstance(op, ops.RightJoin):
# workaround required for https://github.com/pola-rs/polars/issues/13130
prefix = gen_name("on")
left_on = {f"{prefix}_{i}": translate(v, **kw) for i, v in enumerate(op.left_on)}
right_on = {f"{prefix}_{i}": translate(v, **kw) for i, v in enumerate(op.right_on)}
left = left.with_columns(**left_on)
right = right.with_columns(**right_on)
on = list(left_on.keys())

if how == "right":
how = "left"
left, right = right, left
else:
how = _join_types[type(op)]

left_on, right_on = [], []
for pred in op.predicates:
if isinstance(pred, ops.Equals):
left_on.append(translate(pred.left, **kw))
right_on.append(translate(pred.right, **kw))
else:
raise com.TranslationError(
"Polars backend is unable to compile join predicate "
f"with operation type of {type(pred)}"
)
joined = left.join(right, on=on, how=how)
joined = joined.drop(columns=on)

return joined

return left.join(right, left_on=left_on, right_on=right_on, how=how)

@translate.register(PandasAsofJoin)
def asof_join(op, **kw):
    """Compile an as-of join.

    The join keys are materialized as uniquely-named temporary columns,
    ``join_asof`` is performed on those columns, and the temporaries are
    dropped afterwards — a workaround required for
    https://github.com/pola-rs/polars/issues/13130.
    """
    left = translate(op.left, **kw)
    right = translate(op.right, **kw)

    # gen_name yields fresh prefixes, so the temporary key columns cannot
    # collide with real columns on either side.
    on, by = gen_name("on"), gen_name("by")
    left_on = {f"{on}_{i}": translate(v, **kw) for i, v in enumerate(op.left_on)}
    right_on = {f"{on}_{i}": translate(v, **kw) for i, v in enumerate(op.right_on)}
    left_by = {f"{by}_{i}": translate(v, **kw) for i, v in enumerate(op.left_by)}
    right_by = {f"{by}_{i}": translate(v, **kw) for i, v in enumerate(op.right_by)}

    left = left.with_columns(**left_on, **left_by)
    right = right.with_columns(**right_on, **right_by)

    on = list(left_on.keys())
    by = list(left_by.keys())

    # Map the ibis comparison operator onto polars' asof search strategy.
    if op.operator in {ops.Less, ops.LessEqual}:
        direction = "forward"
    elif op.operator in {ops.Greater, ops.GreaterEqual}:
        direction = "backward"
    elif op.operator == ops.Equals:
        direction = "nearest"
    else:
        # Fix: the original interpolated `operator` (the stdlib module
        # imported at the top of this file) instead of the rejected
        # operator class, yielding a useless error message.
        raise NotImplementedError(
            f"Operator {op.operator} not supported for asof join"
        )

    # polars' join_asof accepts exactly one `on` key column.
    assert len(on) == 1
    joined = left.join_asof(right, on=on[0], by=by, strategy=direction)
    joined = joined.drop(columns=on + by)
    return joined


@translate.register(ops.DropNa)
def dropna(op, **kw):
lf = translate(op.table, **kw)
lf = translate(op.parent, **kw)

if op.subset is None:
subset = None
Expand All @@ -337,10 +371,28 @@ def dropna(op, **kw):

@translate.register(ops.FillNa)
def fillna(op, **kw):
table = translate(op.table, **kw)
table = translate(op.parent, **kw)

columns = []
for name, dtype in op.table.schema.items():

repls = op.replacements

if isinstance(repls, Mapping):

def get_replacement(name):
repl = repls.get(name)
if repl is not None:
return repl.value
else:
return None

else:
value = repls.value

def get_replacement(_):
return value

for name, dtype in op.parent.schema.items():
column = pl.col(name)
if isinstance(op.replacements, Mapping):
value = op.replacements.get(name)
Expand Down Expand Up @@ -422,11 +474,11 @@ def greatest(op, **kw):
return pl.max_horizontal(arg)


@translate.register(ops.InColumn)
@translate.register(ops.InSubquery)
def in_column(op, **kw):
value = translate(op.value, **kw)
options = translate(op.options, **kw)
return value.is_in(options)
needle = translate(op.needle, **kw)
return needle.is_in(value)


@translate.register(ops.InValues)
Expand Down Expand Up @@ -734,7 +786,7 @@ def correlation(op, **kw):

@translate.register(ops.Distinct)
def distinct(op, **kw):
table = translate(op.table, **kw)
table = translate(op.parent, **kw)
return table.unique()


Expand Down Expand Up @@ -1166,6 +1218,11 @@ def execute_self_reference(op, **kw):
return translate(op.table, **kw)


@translate.register(ops.JoinTable)
def execute_join_table(op, **kw):
    # JoinTable marks a relation's position inside a JoinChain; compiling
    # it simply delegates to the wrapped parent relation.
    return translate(op.parent, **kw)


@translate.register(ops.CountDistinctStar)
def execute_count_distinct_star(op, **kw):
arg = pl.struct(*op.arg.schema.names)
Expand Down
Loading

0 comments on commit 29b5b53

Please sign in to comment.