Skip to content

Commit

Permalink
feat: variadic union, intersect, & difference functions
Browse files Browse the repository at this point in the history
Also makes the corresponding methods variadic.
  • Loading branch information
jcrist authored and cpcloud committed Aug 24, 2022
1 parent b1f30ba commit 05aca5a
Show file tree
Hide file tree
Showing 4 changed files with 43 additions and 33 deletions.
3 changes: 3 additions & 0 deletions docs/api/expressions/top_level.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,10 @@ These methods and objects are available directly in the `ibis` module.
::: ibis.cumulative_window
::: ibis.date
::: ibis.desc
::: ibis.difference
::: ibis.greatest
::: ibis.ifelse
::: ibis.intersect
::: ibis.interval
::: ibis.least
::: ibis.literal
Expand All @@ -34,5 +36,6 @@ These methods and objects are available directly in the `ibis` module.
::: ibis.timestamp
::: ibis.trailing_range_window
::: ibis.trailing_window
::: ibis.union
::: ibis.where
::: ibis.window
17 changes: 4 additions & 13 deletions ibis/backends/tests/test_set_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import pytest
from pytest import param

import ibis
from ibis import _


Expand All @@ -26,9 +27,7 @@ def union_subsets(alltypes, df):
def test_union(backend, union_subsets, distinct):
(a, b, c), (da, db, dc) = union_subsets

expr = (
a.union(b, distinct=distinct).union(c, distinct=distinct).sort_by("id")
)
expr = ibis.union(a, b, c, distinct=distinct).sort_by("id")
result = expr.execute()

expected = (
Expand Down Expand Up @@ -82,11 +81,7 @@ def test_intersect(backend, alltypes, df, distinct):
db = df[(5205 <= df.id) & (df.id <= 5215)]
dc = df[(5195 <= df.id) & (df.id <= 5208)]

expr = (
a.intersect(b, distinct=distinct)
.intersect(c, distinct=distinct)
.sort_by("id")
)
expr = ibis.intersect(a, b, c, distinct=distinct).sort_by("id")
result = expr.execute()

index = da.index.intersection(db.index).intersection(dc.index)
Expand Down Expand Up @@ -124,11 +119,7 @@ def test_difference(backend, alltypes, df, distinct):
db = df[(5205 <= df.id) & (df.id <= 5215)]
dc = df[(5195 <= df.id) & (df.id <= 5202)]

expr = (
a.difference(b, distinct=distinct)
.difference(c, distinct=distinct)
.sort_by("id")
)
expr = ibis.difference(a, b, c, distinct=distinct).sort_by("id")
result = expr.execute()

index = da.index.difference(db.index).difference(dc.index)
Expand Down
7 changes: 7 additions & 0 deletions ibis/expr/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,7 @@
'cumulative_window',
'date',
'desc',
'difference',
'asc',
'e',
'Expr',
Expand Down Expand Up @@ -181,6 +182,7 @@
'ifelse',
'infer_dtype',
'infer_schema',
'intersect',
'interval',
'join',
'least',
Expand All @@ -207,6 +209,7 @@
'timestamp',
'trailing_range_window',
'trailing_window',
'union',
'where',
'window',
'_',
Expand Down Expand Up @@ -803,4 +806,8 @@ def category_label(
join = ir.Table.join
asof_join = ir.Table.asof_join

union = ir.Table.union
intersect = ir.Table.intersect
difference = ir.Table.difference

_ = Deferred()
49 changes: 29 additions & 20 deletions ibis/expr/types/relations.py
Original file line number Diff line number Diff line change
Expand Up @@ -305,26 +305,29 @@ def view(self) -> Table:

return ops.SelfReference(self).to_expr()

def difference(self, right: Table, distinct: bool = True) -> Table:
"""Compute the set difference of two table expressions.
def difference(self, *tables: Table, distinct: bool = True) -> Table:
"""Compute the set difference of multiple table expressions.
The input tables must have identical schemas.
Parameters
----------
right
Table expression
*tables
One or more table expressions
distinct
Only return distinct rows
Returns
-------
Table
The rows present in `left` that are not present in `right`.
The rows present in `self` that are not present in `tables`.
"""
from ibis.expr import operations as ops

return ops.Difference(self, right, distinct=distinct).to_expr()
left = self
for right in tables:
left = ops.Difference(left, right, distinct=distinct).to_expr()
return left

def aggregate(
self,
Expand Down Expand Up @@ -465,49 +468,55 @@ def sort_by(

def union(
self,
right: Table,
*tables: Table,
distinct: bool = False,
) -> Table:
"""Compute the set union of two table expressions.
"""Compute the set union of multiple table expressions.
The input tables must have identical schemas.
Parameters
----------
right
Table expression
*tables
One or more table expressions
distinct
Only union distinct rows not occurring in the calling table
Only return distinct rows
Returns
-------
Table
Union of table and `right`
A new table containing the union of all input tables.
"""
from ibis.expr import operations as ops

return ops.Union(self, right, distinct=distinct).to_expr()
left = self
for right in tables:
left = ops.Union(left, right, distinct=distinct).to_expr()
return left

def intersect(self, right: Table, distinct: bool = True) -> Table:
"""Compute the set intersection of two table expressions.
def intersect(self, *tables: Table, distinct: bool = True) -> Table:
"""Compute the set intersection of multiple table expressions.
The input tables must have identical schemas.
Parameters
----------
right
Table expression
*tables
One or more table expressions
distinct
Only intersect distinct rows not occurring in the calling table
Only return distinct rows
Returns
-------
Table
The rows common amongst `left` and `right`.
A new table containing the intersection of all input tables.
"""
from ibis.expr import operations as ops

return ops.Intersection(self, right, distinct=distinct).to_expr()
left = self
for right in tables:
left = ops.Intersection(left, right, distinct=distinct).to_expr()
return left

def to_array(self) -> ir.Column:
"""View a single column table as an array.
Expand Down

0 comments on commit 05aca5a

Please sign in to comment.