Skip to content

Commit

Permalink
fix(api): allow boolean scalars in predicate APIs
Browse files Browse the repository at this point in the history
  • Loading branch information
cpcloud committed Nov 28, 2022
1 parent a90ce35 commit 2a2636b
Show file tree
Hide file tree
Showing 3 changed files with 25 additions and 3 deletions.
5 changes: 4 additions & 1 deletion ibis/expr/operations/relations.py
Original file line number Diff line number Diff line change
Expand Up @@ -346,7 +346,10 @@ def __init__(self, table, selections, predicates, sort_keys, **kwargs):
)

for predicate in predicates:
if not shares_some_roots(predicate, table):
if isinstance(predicate, ops.Literal):
if not (dtype := predicate.output_dtype).is_boolean():
raise com.IbisTypeError(f"Invalid predicate dtype: {dtype}")
elif not shares_some_roots(predicate, table):
raise com.RelationError("Predicate doesn't share any roots with table")

super().__init__(
Expand Down
4 changes: 2 additions & 2 deletions ibis/expr/types/relations.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ def __rich_console__(self, console, options):

def __getitem__(self, what):
from ibis.expr.types.generic import Column
from ibis.expr.types.logical import BooleanColumn
from ibis.expr.types.logical import BooleanValue

if isinstance(what, (str, int)):
return self.get_column(what)
Expand All @@ -133,7 +133,7 @@ def __getitem__(self, what):
if isinstance(what, (list, tuple, Table)):
# Projection case
return self.select(what)
elif isinstance(what, BooleanColumn):
elif isinstance(what, BooleanValue):
# Boolean predicate
return self.filter([what])
elif isinstance(what, Column):
Expand Down
19 changes: 19 additions & 0 deletions ibis/tests/expr/test_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -1636,3 +1636,22 @@ def test_array_string_compare():
t = ibis.table(schema=dict(by="string", words="array<string>"), name="t")
expr = t[t.by == "foo"].mutate(words=_.words.unnest()).filter(_.words == "the")
assert expr is not None


@pytest.mark.parametrize("value", [True, False])
@pytest.mark.parametrize(
"api",
[
param(lambda t, value: t[value], id="getitem"),
param(lambda t, value: t.filter(value), id="filter"),
],
)
def test_filter_with_literal(value, api):
t = ibis.table(dict(a="string"))
filt = api(t, ibis.literal(value))
assert filt is not None

# ints are invalid predicates
int_val = ibis.literal(int(value))
with pytest.raises((NotImplementedError, com.IbisTypeError)):
api(t, int_val)

0 comments on commit 2a2636b

Please sign in to comment.