Skip to content

Commit

Permalink
add pipe operator on QueryBuilder (#759)
Browse files Browse the repository at this point in the history
* some draft work

* add docstring

* add test

* add to docstring

* add explicit string

* add pipe section
  • Loading branch information
wd60622 committed Dec 10, 2023
1 parent 627b60a commit 8841520
Show file tree
Hide file tree
Showing 3 changed files with 137 additions and 1 deletion.
52 changes: 52 additions & 0 deletions README.rst
Expand Up @@ -1368,6 +1368,58 @@ This produces:
DROP INDEX IF EXISTS my_index
Chaining Functions
^^^^^^^^^^^^^^^^^^

The ``QueryBuilder.pipe`` method gives a more readable alternative while chaining functions.

.. code-block:: python
# This
(
query
.pipe(func1, *args)
.pipe(func2, **kwargs)
.pipe(func3)
)
# Is equivalent to this
func3(func2(func1(query, *args), **kwargs))
Or for a more concrete example:

.. code-block:: python
from pypika import Field, Query, functions as fn
from pypika.queries import QueryBuilder
def filter_days(query: QueryBuilder, col, num_days: int) -> QueryBuilder:
if isinstance(col, str):
col = Field(col)
return query.where(col > fn.Now() - num_days)
def count_groups(query: QueryBuilder, *groups) -> QueryBuilder:
return query.groupby(*groups).select(*groups, fn.Count("*").as_("n_rows"))
base_query = Query.from_("table")
query = (
base_query
.pipe(filter_days, "date", num_days=7)
.pipe(count_groups, "col1", "col2")
)
This produces:

.. code-block:: sql
SELECT "col1","col2",COUNT(*) n_rows
FROM "table"
WHERE "date">NOW()-7
GROUP BY "col1","col2"
.. _tutorial_end:

.. _contributing_start:
Expand Down
48 changes: 48 additions & 0 deletions pypika/queries.py
Expand Up @@ -1560,6 +1560,54 @@ def _set_sql(self, **kwargs: Any) -> str:
)
)

def pipe(self, func, *args, **kwargs):
"""Call a function on the current object and return the result.
Example usage:
.. code-block:: python
from pypika import Query, functions as fn
from pypika.queries import QueryBuilder
def rows_by_group(query: QueryBuilder, *groups) -> QueryBuilder:
return (
query
.select(*groups, fn.Count("*").as_("n_rows"))
.groupby(*groups)
)
base_query = Query.from_("table")
col1_agg = base_query.pipe(rows_by_group, "col1")
col2_agg = base_query.pipe(rows_by_group, "col2")
col1_col2_agg = base_query.pipe(rows_by_group, "col1", "col2")
Makes chaining functions together easier, especially when the functions are
defined elsewhere. For example, you could define a function that filters
rows by a date range and then group by a set of columns:
.. code-block:: python
from datetime import datetime, timedelta
from pypika import Field
def days_since(query: QueryBuilder, n_days: int) -> QueryBuilder:
return (
query
.where("date" > fn.Date(datetime.now().date() - timedelta(days=n_days)))
)
(
base_query
.pipe(days_since, n_days=7)
.pipe(rows_by_group, "col1", "col2")
)
"""
return func(self, *args, **kwargs)


class Joiner:
def __init__(
Expand Down
38 changes: 37 additions & 1 deletion pypika/tests/test_query.py
@@ -1,6 +1,6 @@
import unittest

from pypika import Case, Query, Tables, Tuple, functions
from pypika import Case, Query, Tables, Tuple, functions, Field
from pypika.dialects import (
ClickHouseQuery,
ClickHouseQueryBuilder,
Expand Down Expand Up @@ -204,3 +204,39 @@ def test_query_builders_have_reference_to_correct_query_class(self):

with self.subTest('OracleQueryBuilder'):
self.assertEqual(OracleQuery, OracleQueryBuilder.QUERY_CLS)

def test_pipe(self) -> None:
base_query = Query.from_("test")

def select(query: QueryBuilder) -> QueryBuilder:
return query.select("test1", "test2")

def count_group(query: QueryBuilder, *groups) -> QueryBuilder:
return query.groupby(*groups).select(*groups, functions.Count("*"))

for func, args, kwargs, expected_str in [
(select, [], {}, 'SELECT "test1","test2" FROM "test"'),
(
count_group,
["test1", "test2"],
{},
'SELECT "test1","test2",COUNT(*) FROM "test" GROUP BY "test1","test2"',
),
(count_group, ["test1"], {}, 'SELECT "test1",COUNT(*) FROM "test" GROUP BY "test1"'),
]:
result_str = str(base_query.pipe(func, *args, **kwargs))
self.assertEqual(result_str, str(func(base_query, *args, **kwargs)))
self.assertEqual(result_str, expected_str)

def where_clause(query: QueryBuilder, num_days: int) -> QueryBuilder:
return query.where(Field("date") > functions.Now() - num_days)

result_str = str(base_query.pipe(select).pipe(where_clause, num_days=1))
self.assertEqual(
result_str,
str(select(where_clause(base_query, num_days=1))),
)
self.assertEqual(
result_str,
'SELECT "test1","test2" FROM "test" WHERE "date">NOW()-1',
)

0 comments on commit 8841520

Please sign in to comment.