Skip to content

Commit

Permalink
fix(snowflake): ensure that memtables are translated correctly
Browse files Browse the repository at this point in the history
  • Loading branch information
krzysztof-kwitt authored and cpcloud committed Mar 14, 2023
1 parent 3f918de commit b361e07
Show file tree
Hide file tree
Showing 3 changed files with 34 additions and 29 deletions.
39 changes: 21 additions & 18 deletions ibis/backends/base/sql/alchemy/query_builder.py
Expand Up @@ -115,24 +115,7 @@ def _format_table(self, op):
backend = child_expr._find_backend()
backend._create_temp_view(view=result, definition=definition)
elif isinstance(ref_op, ops.InMemoryTable):
columns = translator._schema_to_sqlalchemy_columns(ref_op.schema)

if self.context.compiler.cheap_in_memory_tables:
result = sa.table(ref_op.name, *columns)
elif self.context.compiler.support_values_syntax_in_select:
rows = list(ref_op.data.to_frame().itertuples(index=False))
result = sa.values(*columns, name=ref_op.name).data(rows)
else:
raw_rows = (
sa.select(
*(
translator.translate(ops.Literal(val, dtype=type_))
for val, type_ in zip(row, op.schema.types)
)
)
for row in op.data.to_frame().itertuples(index=False)
)
result = sa.union_all(*raw_rows).alias(ref_op.name)
result = self._format_in_memory_table(op, ref_op, translator)
else:
# A subquery
if ctx.is_extracted(ref_op):
Expand All @@ -155,6 +138,26 @@ def _format_table(self, op):
ctx.set_ref(op, result)
return result

def _format_in_memory_table(self, op, ref_op, translator):
    """Build the SQLAlchemy selectable that stands in for an in-memory table.

    Three strategies are used, chosen from flags on the compiler:

    1. ``cheap_in_memory_tables``: reference the table by name only.
    2. ``support_values_syntax_in_select``: emit a named ``VALUES`` clause
       holding every row.
    3. Otherwise: emit one ``SELECT`` of literals per row and combine them
       with ``UNION ALL``, aliased to the table name.

    Parameters
    ----------
    op
        The node being formatted by the caller. It is kept in the signature
        so that callers and subclass overrides keep working, but it is not
        read here: the caller only guarantees that ``ref_op`` is the
        in-memory table, so ``op`` may not expose ``data``.
    ref_op
        The in-memory table operation; supplies ``schema``, ``name`` and
        ``data``.
    translator
        Expression translator used to build the SQLAlchemy columns and, in
        the fallback strategy, to translate each literal value.

    Returns
    -------
    A SQLAlchemy table, ``VALUES`` construct or aliased ``UNION ALL``.
    """
    columns = translator._schema_to_sqlalchemy_columns(ref_op.schema)
    compiler = self.context.compiler

    if compiler.cheap_in_memory_tables:
        return sa.table(ref_op.name, *columns)

    if compiler.support_values_syntax_in_select:
        rows = list(ref_op.data.to_frame().itertuples(index=False))
        return sa.values(*columns, name=ref_op.name).data(rows)

    # Fallback: one SELECT of typed literals per row, glued with UNION ALL.
    # Schema and data are read from ``ref_op`` (not ``op``) so this branch
    # agrees with the two branches above and with the alias name below.
    types = ref_op.schema.types
    row_selects = (
        sa.select(
            *(
                translator.translate(ops.Literal(val, dtype=type_))
                for val, type_ in zip(row, types)
            )
        )
        for row in ref_op.data.to_frame().itertuples(index=False)
    )
    return sa.union_all(*row_selects).alias(ref_op.name)


class AlchemySelect(Select):
def __init__(self, *args, **kwargs):
Expand Down
13 changes: 13 additions & 0 deletions ibis/backends/snowflake/__init__.py
Expand Up @@ -15,6 +15,7 @@
AlchemyExprTranslator,
BaseAlchemyBackend,
)
from ibis.backends.base.sql.alchemy.query_builder import _AlchemyTableSetFormatter


@contextlib.contextmanager
Expand Down Expand Up @@ -74,9 +75,21 @@ class SnowflakeExprTranslator(AlchemyExprTranslator):
supports_unnest_in_select = False


class SnowflakeTableSetFormatter(_AlchemyTableSetFormatter):
    """Table-set formatter with Snowflake-specific in-memory table handling."""

    def _format_in_memory_table(self, op, ref_op, translator):
        """Render an in-memory table as a SELECT over an unnamed VALUES clause.

        The projection consists of positional column references (``$1``,
        ``$2``, ...), one per column in ``ref_op.schema``. ``op`` is accepted
        for signature compatibility with the base class and is not used.
        """
        columns = translator._schema_to_sqlalchemy_columns(ref_op.schema)
        rows = list(ref_op.data.to_frame().itertuples(index=False))

        # Only the 1-based position is needed; the column name is ignored.
        positional = [
            sa.column(f"${position}")
            for position, _ in enumerate(ref_op.schema.names, start=1)
        ]

        projection = sa.select(*positional)
        return projection.select_from(sa.values(*columns).data(rows))


class SnowflakeCompiler(AlchemyCompiler):
    """Compiler configuration for the Snowflake backend."""

    # Classes plugged into the compilation pipeline for this backend.
    translator_class = SnowflakeExprTranslator
    table_set_formatter_class = SnowflakeTableSetFormatter

    # Mirrors ``_NATIVE_ARROW``, which is defined elsewhere in the module.
    # NOTE(review): presumably indicates native Arrow support -- confirm.
    cheap_in_memory_tables = _NATIVE_ARROW


class _SnowFlakeConverter(_BaseSnowflakeConverter):
Expand Down
11 changes: 0 additions & 11 deletions ibis/backends/tests/test_join.py
Expand Up @@ -3,7 +3,6 @@
import numpy as np
import pandas as pd
import pytest
import sqlalchemy as sa
from packaging.version import parse as vparse
from pytest import param

Expand Down Expand Up @@ -206,11 +205,6 @@ def test_semi_join_topk(batting, awards_players):
raises=exc.IbisTypeError,
reason="DuckDB as of 0.7.1 occasionally segfaults when there are `null`-typed columns present",
)
@pytest.mark.broken(
["snowflake"],
reason="Case isn't preserved for memtable yet",
raises=sa.exc.ProgrammingError,
)
def test_join_with_pandas(batting, awards_players):
batting_filt = batting[lambda t: t.yearID < 1900]
awards_players_filt = awards_players[lambda t: t.yearID < 1900].execute()
Expand All @@ -221,11 +215,6 @@ def test_join_with_pandas(batting, awards_players):


@pytest.mark.notimpl(["dask", "datafusion", "pandas"])
@pytest.mark.broken(
["snowflake"],
reason="Case isn't preserved for memtable yet",
raises=sa.exc.ProgrammingError,
)
def test_join_with_pandas_non_null_typed_columns(batting, awards_players):
batting_filt = batting[lambda t: t.yearID < 1900][["yearID"]]
awards_players_filt = awards_players[lambda t: t.yearID < 1900][
Expand Down

0 comments on commit b361e07

Please sign in to comment.