Skip to content

Commit

Permalink
fix(sql): case_when now passes pandas tests
Browse files Browse the repository at this point in the history
  • Loading branch information
machow committed Oct 5, 2022
1 parent b12fdc3 commit 05b08a2
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 8 deletions.
17 changes: 14 additions & 3 deletions siuba/sql/verbs/conditional.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,9 @@
from siuba.dply.verbs import case_when, if_else
from siuba.siu import Call

from ..backend import LazyTbl


@case_when.register(sql.base.ImmutableColumnCollection)
def _case_when(__data, cases):
# TODO: will need listener to enter case statements, to handle when they use windows
Expand All @@ -22,14 +25,22 @@ def _case_when(__data, cases):
val = val(__data)

# handle when expressions
if ii+1 == n_items and expr is True:
else_val = val
elif callable(expr):
#if ii+1 == n_items and expr is True:
# else_val = val
if callable(expr):
whens.append((expr(__data), val))
else:
whens.append((expr, val))

return sql.case(whens, else_ = else_val)


@case_when.register(LazyTbl)
def _case_when(__data, cases):
raise NotImplementedError(
"`case_when()` must be used inside a verb like `mutate()`, when using a "
"SQL backend."
)


# if_else ---------------------------------------------------------------------
Expand Down
21 changes: 16 additions & 5 deletions siuba/tests/test_verb_case_when.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,13 @@
import numpy as np
import pytest

from siuba.dply.verbs import case_when
from pandas.testing import assert_series_equal
from numpy.testing import assert_equal
from siuba.tests.helpers import assert_equal_query

from siuba.siu import _
from siuba.dply.verbs import case_when, mutate


DATA = pd.DataFrame({
'x': [0,1,2],
Expand All @@ -18,7 +21,7 @@ def data():
return DATA.copy()


@pytest.mark.parametrize("k,v, res", [
@pytest.mark.parametrize("k,v, dst", [
(True, 1, [1]*3),
(True, False, [False]*3),
(True, _.y, [10, 11, 12]),
Expand All @@ -29,10 +32,18 @@ def data():
(lambda _: _.x < 2, 0, [0, 0, None]),
#(np.array([True, True, False]), 0, [0, 0, None])
])
def test_case_when_single_cond(k, v, res, data):
out = case_when(data, {k: v})
def test_case_when_single_cond(backend, data, k, v, dst):
src = backend.load_df(data)
query = mutate(_, res = case_when(_, {k: v}))

assert_equal_query(src, query, data.assign(res = dst))


def test_case_when_multiple_clauses(backend, data):
src = backend.load_df(data)
query = mutate(_, res = case_when({_.x == 0: "zero", _.x > 1: "big", True: "small"}))

assert_series_equal(out, pd.Series(res))
assert_equal_query(src, query, data.assign(res = ["zero", "small", "big"]))


def test_case_when_cond_order(data):
Expand Down

0 comments on commit 05b08a2

Please sign in to comment.