Skip to content

Commit

Permalink
add pyspark compile rule for greatest, fix bug with selection (#4)
Browse files Browse the repository at this point in the history
  • Loading branch information
hjoo authored and icexelloss committed Aug 22, 2019
1 parent fa4ad23 commit c4a2b79
Show file tree
Hide file tree
Showing 2 changed files with 52 additions and 5 deletions.
36 changes: 31 additions & 5 deletions ibis/pyspark/compiler.py
Expand Up @@ -2,6 +2,7 @@
import ibis.sql.compiler as comp
import ibis.expr.window as window
import ibis.expr.operations as ops
import ibis.expr.types as types


from ibis.pyspark.operations import PysparkTable
Expand Down Expand Up @@ -54,11 +55,16 @@ def compile_datasource(t, expr):
@compiles(ops.Selection)
def compile_selection(t, expr):
    """Compile an ibis ``Selection`` into a selection on the source table.

    Two shapes of selection are handled, decided by the first entry of
    ``op.selections``:

    * a column expression: every selection is treated as a plain column
      and the translated source table is indexed by the column names;
    * a table expression: the translated source table is kept as-is and
      each remaining selection is appended with ``withColumn``.

    Parameters
    ----------
    t : translator exposing ``translate(expr)``
    expr : ibis expression whose ``op()`` is an ``ops.Selection``

    Returns
    -------
    The translated source table with the selection applied.

    Raises
    ------
    NotImplementedError
        If the first selection is neither a column nor a table expression.
    """
    op = expr.op()
    selections = op.selections
    first = selections[0]

    # Always start from the translated source table.  Translating
    # ``selections[0]`` instead (as the superseded implementation did)
    # yields a column, not a table, when the selection is a column list.
    src_table = t.translate(op.table)

    if isinstance(first, types.ColumnExpr):
        # ``selection`` rather than ``expr`` so the function argument is
        # not shadowed inside the comprehension.
        column_names = [selection.op().name for selection in selections]
        return src_table[column_names]

    if isinstance(first, types.TableExpr):
        # The first entry is the table itself; only the rest add columns.
        for selection in selections[1:]:
            src_table = src_table.withColumn(
                selection.get_name(), t.translate(selection)
            )
        return src_table

    # Fail loudly instead of falling through with an unusable result.
    raise NotImplementedError(
        'Unsupported selection type: {}'.format(type(first).__name__)
    )

Expand Down Expand Up @@ -107,18 +113,38 @@ def max(v):
src_column = t.translate(op.arg)
return max(src_column)


@compiles(ops.Mean)
def compile_mean(t, expr):
    """Compile an ibis ``Mean`` op: apply ``F.mean`` to the translated arg."""
    translated_arg = t.translate(expr.op().arg)
    return F.mean(translated_arg)


@compiles(ops.WindowOp)
def compile_window_op(t, expr):
    """Compile an ibis ``WindowOp``.

    The wrapped expression is translated first, then evaluated ``over``
    the window produced by ``compile_window`` for ``op.window``.
    """
    window_op = expr.op()
    translated = t.translate(window_op.expr)
    window_spec = compile_window(window_op.window)
    return translated.over(window_spec)


@compiles(ops.Greatest)
def compile_greatest(t, expr):
    """Compile an ibis ``Greatest`` op.

    ``op.arg`` is translated into a sequence of columns.  A lone column is
    returned unchanged; otherwise the columns are passed to ``F.greatest``.
    """
    columns = t.translate(expr.op().arg)

    # NOTE(review): the single-column shortcut presumably exists because
    # F.greatest does not accept fewer than two columns -- confirm against
    # the pyspark documentation.
    if len(columns) != 1:
        return F.greatest(*columns)
    return columns[0]


@compiles(ops.ValueList)
def compile_value_list(t, expr):
    """Compile an ibis ``ValueList`` into a list of translated values.

    Each entry of ``op.values`` is translated in order.
    """
    values = expr.op().values
    return list(map(t.translate, values))


# Cannot be registered with @compiles because a window object doesn't
# have an op() method
def compile_window(expr):
Expand Down
21 changes: 21 additions & 0 deletions ibis/pyspark/tests/test_basic.py
Expand Up @@ -78,3 +78,24 @@ def test_window(client):

tm.assert_frame_equal(result.toPandas(), expected.toPandas())
tm.assert_frame_equal(result2.toPandas(), expected.toPandas())


def test_greatest(client):
    """``ibis.greatest`` over a single column yields that column's values."""
    table = client.table('table1')
    mutated = table.mutate(greatest=ibis.greatest(table.id))
    result = mutated.compile()

    # Expected frame: the compiled table plus a copy of ``id`` named
    # ``greatest``.
    df = table.compile()
    expected = table.compile().withColumn('greatest', df.id)

    actual_pd = result.toPandas()
    expected_pd = expected.toPandas()
    tm.assert_frame_equal(actual_pd, expected_pd)


def test_selection(client):
    """Selecting one or several columns matches indexing the compiled table."""
    base = client.table('table1')
    table = base.mutate(id2=base['id'])

    column_sets = (['id'], ['id', 'id2'])
    # Compile every selection before compiling the reference table.
    compiled = [table[columns].compile() for columns in column_sets]

    df = table.compile()
    for columns, result in zip(column_sets, compiled):
        tm.assert_frame_equal(result.toPandas(), df[columns].toPandas())

0 comments on commit c4a2b79

Please sign in to comment.