Skip to content

Commit

Permalink
BUG: fix isin and selection on selectables
Browse files Browse the repository at this point in the history
  • Loading branch information
Joe Jevnik committed Jun 13, 2016
1 parent 4363dc6 commit ed351f7
Show file tree
Hide file tree
Showing 3 changed files with 41 additions and 0 deletions.
19 changes: 19 additions & 0 deletions blaze/compute/sql.py
Expand Up @@ -459,6 +459,16 @@ def compute_up(expr, tbl, predicate, scope=None, **kwargs):
return select([tbl]).where(predicate)


@dispatch(Selection, Selectable, Selectable)
def compute_up(expr, tbl, predicate, **kwargs):
col, = inner_columns(predicate)
return reconstruct_select(
inner_columns(tbl),
tbl,
whereclause=unify_wheres((tbl, predicate)),
).where(col)


def select(s):
""" Permissive SQL select
Expand Down Expand Up @@ -1408,6 +1418,15 @@ def compute_up(expr, data, **kwargs):
return data.in_(expr._keys)


@dispatch(IsIn, Selectable)
def compute_up(expr, data, **kwargs):
assert len(data.columns) == 1, (
'only 1 column is allowed in a Select in IsIn'
)
col, = inner_columns(data)
return reconstruct_select((col.in_(expr._keys),), data)


@dispatch(Slice, (Select, Selectable, ColumnElement))
def compute_up(expr, data, **kwargs):
index = expr.index[0] # [0] replace_slices returns tuple ((start, stop), )
Expand Down
20 changes: 20 additions & 0 deletions blaze/compute/tests/test_postgresql_compute.py
Expand Up @@ -810,3 +810,23 @@ def test_all(sql):
assert compute(~(s.B == 1).all(), {s: sql}, return_type='core')
assert compute(~(s.B == 2).all(), {s: sql}, return_type='core')
assert compute(~(s.B == 3).all(), {s: sql}, return_type='core')


def test_isin_selectable(sql):
s = symbol('s', discover(sql))

# wrap the resource in a select
assert compute(s.B.isin({1, 3}),
sa.select(sql._resources()[sql].columns),
return_type=list) == [(True,), (False,)]


def test_selection_selectable(sql):
s = symbol('s', discover(sql))

# wrap the resource in a select
assert (compute(s[s.B.isin({1, 3})],
sa.select(sql._resources()[sql].columns),
return_type=pd.DataFrame) ==
pd.DataFrame([['a', 1]],
columns=s.dshape.measure.names)).all().all()
2 changes: 2 additions & 0 deletions docs/source/whatsnew/0.10.2.txt
Expand Up @@ -49,6 +49,8 @@ Bug Fixes
(:issue:`1517`).
* Fixes issue with string and datetime coercions on Pandas objects
(:issue:`1519` :issue:`1524`).
* Fixed a bug with ``isin`` and ``Selection``\s on sql selectables
(:issue:`1528`).

Miscellaneous
~~~~~~~~~~~~~
Expand Down

0 comments on commit ed351f7

Please sign in to comment.