Skip to content

Commit

Permalink
feat(duckdb): implement Table.sample as a TABLESAMPLE query
Browse files Browse the repository at this point in the history
  • Loading branch information
jcrist committed Oct 17, 2023
1 parent e1870ea commit 3a80f3a
Show file tree
Hide file tree
Showing 4 changed files with 40 additions and 26 deletions.
44 changes: 21 additions & 23 deletions ibis/backends/base/sql/alchemy/query_builder.py
Expand Up @@ -6,6 +6,7 @@
import toolz
from sqlalchemy import sql

import ibis.common.exceptions as com
import ibis.expr.analysis as an
import ibis.expr.operations as ops
from ibis.backends.base.sql.alchemy.translator import (
Expand Down Expand Up @@ -84,7 +85,7 @@ def _format_table(self, op):
ctx = self.context

orig_op = op
if isinstance(op, ops.SelfReference):
if isinstance(op, (ops.SelfReference, ops.Sample)):
op = op.table

alias = ctx.get_ref(orig_op)
Expand Down Expand Up @@ -128,28 +129,27 @@ def _format_table(self, op):
for name, value in zip(op.schema.names, op.values)
)
)
elif ctx.is_extracted(op):
if isinstance(orig_op, ops.SelfReference):
result = ctx.get_ref(op)
else:
result = alias
else:
# A subquery
if ctx.is_extracted(op):
# Was put elsewhere, e.g. WITH block, we just need to grab
# its alias
alias = ctx.get_ref(orig_op)

# hack
if isinstance(orig_op, ops.SelfReference):
table = ctx.get_ref(op)
self_ref = alias if hasattr(alias, "name") else table.alias(alias)
ctx.set_ref(orig_op, self_ref)
return self_ref
return alias

alias = ctx.get_ref(orig_op)
result = ctx.get_compiled_expr(orig_op)
result = ctx.get_compiled_expr(op)

result = alias if hasattr(alias, "name") else result.alias(alias)

if isinstance(orig_op, ops.Sample):
result = self._format_sample(orig_op, result)

ctx.set_ref(orig_op, result)
return result

def _format_sample(self, op, table):
# Should never be hit in practice, as Sample operations should be rewritten
# before this point for all backends without TABLESAMPLE support
raise com.UnsupportedOperationError("`Table.sample` is not supported")

def _format_in_memory_table(self, op, translator):
columns = translator._schema_to_sqlalchemy_columns(op.schema)
if self.context.compiler.cheap_in_memory_tables:
Expand All @@ -168,7 +168,7 @@ def _format_in_memory_table(self, op, translator):
).limit(0)
elif self.context.compiler.support_values_syntax_in_select:
rows = list(op.data.to_frame().itertuples(index=False))
result = sa.values(*columns, name=op.name).data(rows)
result = sa.values(*columns, name=op.name).data(rows).select().subquery()
else:
raw_rows = (
sa.select(
Expand Down Expand Up @@ -219,13 +219,11 @@ def _compile_subqueries(self):
self.context.set_ref(expr, result)

def _compile_table_set(self):
if self.table_set is not None:
helper = self.table_set_formatter_class(self, self.table_set)
result = helper.get_result()
return result
else:
if self.table_set is None:
return None

return self.table_set_formatter_class(self, self.table_set).get_result()

def _add_select(self, table_set):
if not self.select_set:
return table_set.element
Expand Down
5 changes: 5 additions & 0 deletions ibis/backends/base/sql/compiler/select_builder.py
Expand Up @@ -147,6 +147,11 @@ def _collect_Limit(self, op, toplevel=False):
assert self.limit is None
self.limit = _LimitSpec(op.n, op.offset)

def _collect_Sample(self, op, toplevel=False):
if toplevel:
self.table_set = op
self.select_set = [op]

def _collect_Union(self, op, toplevel=False):
if toplevel:
self.table_set = op
Expand Down
14 changes: 14 additions & 0 deletions ibis/backends/duckdb/compiler.py
Expand Up @@ -6,6 +6,7 @@
import ibis.backends.base.sql.alchemy.datatypes as sat
import ibis.expr.operations as ops
from ibis.backends.base.sql.alchemy import AlchemyCompiler, AlchemyExprTranslator
from ibis.backends.base.sql.alchemy.query_builder import _AlchemyTableSetFormatter
from ibis.backends.duckdb.datatypes import DuckDBType
from ibis.backends.duckdb.registry import operation_registry

Expand Down Expand Up @@ -60,6 +61,19 @@ def _no_op(expr):
return expr


class DuckDBTableSetFormatter(_AlchemyTableSetFormatter):
def _format_sample(self, op, table):
if op.method == "row":
method = sa.func.bernoulli
else:
method = sa.func.system
return table.tablesample(
sampling=method(sa.literal_column(f"{op.fraction * 100} PERCENT")),
seed=(None if op.seed is None else sa.literal_column(str(op.seed))),
)


class DuckDBSQLCompiler(AlchemyCompiler):
cheap_in_memory_tables = True
translator_class = DuckDBSQLExprTranslator
table_set_formatter_class = DuckDBTableSetFormatter
3 changes: 0 additions & 3 deletions ibis/backends/tests/test_generic.py
Expand Up @@ -1535,7 +1535,6 @@ def test_dynamic_table_slice_with_computed_offset(backend):
"dask",
"datafusion",
"druid",
"duckdb",
"flink",
"impala",
"pandas",
Expand Down Expand Up @@ -1566,7 +1565,6 @@ def test_sample(backend):
"dask",
"datafusion",
"druid",
"duckdb",
"flink",
"impala",
"pandas",
Expand All @@ -1590,7 +1588,6 @@ def test_sample_memtable(con, backend):
"dask",
"datafusion",
"druid",
"duckdb",
"flink",
"impala",
"mssql",
Expand Down

0 comments on commit 3a80f3a

Please sign in to comment.