Skip to content

Commit

Permalink
Address comments
Browse files Browse the repository at this point in the history
  • Loading branch information
icexelloss committed Aug 22, 2019
1 parent f173425 commit 8f1c35e
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 5 deletions.
7 changes: 6 additions & 1 deletion ibis/pyspark/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ def compile_selection(t, expr, scope, **kwargs):
op = expr.op()

# TODO: Support predicates and sort_keys
if len(op.predicates) > 0 or len(op.sort_keys) > 0:
if op.predicates or op.sort_keys:
raise NotImplementedError(
"predicates and sort_keys are not supported with Selection")

Expand Down Expand Up @@ -192,6 +192,11 @@ def compile_aggregator(t, expr, scope, fn, context=None, **kwargs):
if context:
return col
else:
# We are trying to compile an expr such as some_col.max()
# to a Spark expression.
# Here we get the root table df of that column and compile
# the expr to:
# df.select(max(some_col))
return t.translate(expr.op().arg.op().table, scope).select(col)


Expand Down
16 changes: 12 additions & 4 deletions ibis/pyspark/tests/test_basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,14 @@
import pytest

import ibis
import ibis.common.exceptions as comm

pytest.importorskip('pyspark')
pytestmark = pytest.mark.pyspark


@pytest.fixture(scope='session')
def client():
pytest.importorskip('pyspark')
from pyspark.sql import SparkSession
import pyspark.sql.functions as F

Expand Down Expand Up @@ -91,7 +92,10 @@ def test_groupby(client):
tm.assert_frame_equal(result.toPandas(), expected.toPandas())


@pytest.mark.xfail
@pytest.mark.xfail(
reason='This is not implemented yet',
raises=comm.OperationNotDefinedError
)
def test_window(client):
import pyspark.sql.functions as F
from pyspark.sql.window import Window
Expand Down Expand Up @@ -156,13 +160,17 @@ def test_selection(client):
[['plus1', *df.columns]].toPandas())


@pytest.mark.xfail(
reason='Join is not fully implemented',
raises=AssertionError
)
def test_join(client):
table = client.table('table1')
result = table.join(table, 'id').compile()
result = table.join(table, ['id', 'str_col']).compile()
spark_table = table.compile()
expected = (
spark_table
.join(spark_table, spark_table['id'] == spark_table['id'])
.join(spark_table, ['id', 'str_col'])
)

tm.assert_frame_equal(result.toPandas(), expected.toPandas())

0 comments on commit 8f1c35e

Please sign in to comment.