Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add mean aggregation function #2916

Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
12 changes: 12 additions & 0 deletions db/functions/base.py
Expand Up @@ -383,6 +383,18 @@ def to_sa_expression(column_expr):
return sa_call_sql_function('count', column_expr, return_type=PostgresType.INTEGER)


class Mean(DBFunction):
id = 'mean'
name = 'mean'
hints = tuple([
hints.aggregation,
])

@staticmethod
def to_sa_expression(column_expr):
return sa_call_sql_function('avg', column_expr, return_type=PostgresType.NUMERIC)


class ArrayAgg(DBFunction):
id = 'aggregate_to_array'
name = 'aggregate to array'
Expand Down
3 changes: 2 additions & 1 deletion mathesar/models/query.py
Expand Up @@ -6,7 +6,7 @@
from db.transforms.operations.deserialize import deserialize_transformation
from db.transforms.operations.serialize import serialize_transformation
from db.transforms.base import Summarize
from db.functions.base import Count, ArrayAgg
from db.functions.base import Count, ArrayAgg, Mean
from db.functions.packed import DistinctArrayAgg

from mathesar.api.exceptions.query_exceptions.exceptions import DeletedColumnAccess
Expand Down Expand Up @@ -419,6 +419,7 @@ def _get_default_display_name_for_agg_output_alias(
DistinctArrayAgg.id: " distinct list",
ArrayAgg.id: " list",
Count.id: " count",
Mean.id: " mean",
}
suffix_to_add = map_of_agg_function_to_suffix.get(agg_function)
if suffix_to_add:
Expand Down
63 changes: 63 additions & 0 deletions mathesar/tests/api/query/test_aggregation_functions.py
@@ -0,0 +1,63 @@
display_option_origin = "display_option_origin"


def test_mean_aggregation(library_ma_tables, get_uid, client):
_ = library_ma_tables
checkouts = {
t["name"]: t for t in client.get("/api/db/v0/tables/").json()["results"]
}["Checkouts"]
columns = {
c["name"]: c for c in checkouts["columns"]
}
request_data = {
"name": get_uid(),
"base_table": checkouts["id"],
"initial_columns": [
{"id": columns["Checkout Time"]["id"], "alias": "Checkout Time"},
{"id": columns["Patron"]["id"], "alias": "Patron"},
],
"display_names": {
"Checkout Month": "Month",
"Mean": "Mean of patron",
},
"display_options": {
"Checkout Time": {
display_option_origin: "Checkout Time",
},
"Patron": {
display_option_origin: "Patron",
},
},
"transformations": [
{
"spec": {
"grouping_expressions": [
{
"input_alias": "Checkout Time",
"output_alias": "Checkout Month",
"preproc": "truncate_to_month",
}
],
"aggregation_expressions": [
{
"input_alias": "Patron",
"output_alias": "Mean",
"function": "mean",
}
]
},
"type": "summarize",
}
]
}
response = client.post('/api/db/v0/queries/', data=request_data)
assert response.status_code == 201
query_id = response.json()['id']
expect_records = [
{'Checkout Month': '2022-05', 'Mean': 16.641025641025642},
{'Checkout Month': '2022-06', 'Mean': 11.461538461538462},
{'Checkout Month': '2022-07', 'Mean': 18.06896551724138},
{'Checkout Month': '2022-08', 'Mean': 12.6},
]
actual_records = client.get(f'/api/db/v0/queries/{query_id}/records/').json()['results']
assert sorted(actual_records, key=lambda x: x['Checkout Month']) == expect_records