diff --git a/pypika/terms.py b/pypika/terms.py index e3505b5c..719421b6 100644 --- a/pypika/terms.py +++ b/pypika/terms.py @@ -1050,7 +1050,7 @@ def get_sql(self, with_alias: bool = False, **kwargs: Any) -> str: return arithmetic_sql -class Case(Term): +class Case(Criterion): def __init__(self, alias: Optional[str] = None) -> None: super().__init__(alias=alias) self._cases = [] diff --git a/pypika/tests/test_selects.py b/pypika/tests/test_selects.py index 5d9916e6..d07d5ee8 100644 --- a/pypika/tests/test_selects.py +++ b/pypika/tests/test_selects.py @@ -473,6 +473,34 @@ def test_select_with_force_index_and_where(self): self.assertEqual('SELECT "foo" FROM "abc" FORCE INDEX ("egg") WHERE "foo"="bar"', str(q)) + def test_where_with_multiple_wheres_using_and_case(self): + case_stmt = Case().when(self.t.foo == 'bar', 1).else_(0) + query = ( + Query.from_(self.t) + .select(case_stmt) + .where(case_stmt & self.t.blah.isin(['test'])) + ) + + self.assertEqual( + 'SELECT CASE WHEN "foo"=\'bar\' THEN 1 ELSE 0 END FROM "abc" WHERE CASE WHEN "foo"=\'bar\' THEN 1 ELSE 0 ' + 'END AND "blah" IN (\'test\')', + str(query) + ) + + def test_where_with_multiple_wheres_using_or_case(self): + case_stmt = Case().when(self.t.foo == 'bar', 1).else_(0) + query = ( + Query.from_(self.t) + .select(case_stmt) + .where(case_stmt | self.t.blah.isin(['test'])) + ) + + self.assertEqual( + 'SELECT CASE WHEN "foo"=\'bar\' THEN 1 ELSE 0 END FROM "abc" WHERE CASE WHEN "foo"=\'bar\' THEN 1 ELSE 0 ' + 'END OR "blah" IN (\'test\')', + str(query) + ) + class PreWhereTests(WhereTests): t = Table("abc")