diff --git a/google/cloud/datastore/aggregation.py b/google/cloud/datastore/aggregation.py index 421ffc93..0518514e 100644 --- a/google/cloud/datastore/aggregation.py +++ b/google/cloud/datastore/aggregation.py @@ -39,6 +39,9 @@ class BaseAggregation(ABC): Base class representing an Aggregation operation in Datastore """ + def __init__(self, alias=None): + self.alias = alias + @abc.abstractmethod def _to_pb(self): """ @@ -59,7 +62,7 @@ class CountAggregation(BaseAggregation): """ def __init__(self, alias=None): - self.alias = alias + super(CountAggregation, self).__init__(alias=alias) def _to_pb(self): """ @@ -71,6 +74,60 @@ def _to_pb(self): return aggregation_pb +class SumAggregation(BaseAggregation): + """ + Representation of a "Sum" aggregation query. + + :type property_ref: str + :param property_ref: The property_ref for the aggregation. + + :type value: int + :param value: The resulting value from the aggregation. + + """ + + def __init__(self, property_ref, alias=None): + self.property_ref = property_ref + super(SumAggregation, self).__init__(alias=alias) + + def _to_pb(self): + """ + Convert this instance to the protobuf representation + """ + aggregation_pb = query_pb2.AggregationQuery.Aggregation() + aggregation_pb.sum = query_pb2.AggregationQuery.Aggregation.Sum() + aggregation_pb.sum.property.name = self.property_ref + aggregation_pb.alias = self.alias + return aggregation_pb + + +class AvgAggregation(BaseAggregation): + """ + Representation of a "Avg" aggregation query. + + :type property_ref: str + :param property_ref: The property_ref for the aggregation. + + :type value: int + :param value: The resulting value from the aggregation. + + """ + + def __init__(self, property_ref, alias=None): + self.property_ref = property_ref + super(AvgAggregation, self).__init__(alias=alias) + + def _to_pb(self): + """ + Convert this instance to the protobuf representation + """ + aggregation_pb = query_pb2.AggregationQuery.Aggregation() + aggregation_pb.avg = query_pb2.AggregationQuery.Aggregation.Avg() + aggregation_pb.avg.property.name = self.property_ref + aggregation_pb.alias = self.alias + return aggregation_pb + + class AggregationResult(object): """ A class representing result from Aggregation Query @@ -154,6 +211,28 @@ def count(self, alias=None): self._aggregations.append(count_aggregation) return self + def sum(self, property_ref, alias=None): + """ + Adds a sum over the nested query + + :type property_ref: str + :param property_ref: The property_ref for the sum + """ + sum_aggregation = SumAggregation(property_ref=property_ref, alias=alias) + self._aggregations.append(sum_aggregation) + return self + + def avg(self, property_ref, alias=None): + """ + Adds a avg over the nested query + + :type property_ref: str + :param property_ref: The property_ref for the sum + """ + avg_aggregation = AvgAggregation(property_ref=property_ref, alias=alias) + self._aggregations.append(avg_aggregation) + return self + def add_aggregation(self, aggregation): """ Adds an aggregation operation to the nested query @@ -327,8 +406,7 @@ def _build_protobuf(self): """ pb = self._aggregation_query._to_pb() if self._limit is not None and self._limit > 0: - for aggregation in pb.aggregations: - aggregation.count.up_to = self._limit + pb.nested_query.limit = self._limit return pb def _process_query_results(self, response_pb): @@ -438,5 +516,8 @@ def _item_to_aggregation_result(iterator, pb): :rtype: :class:`google.cloud.datastore.aggregation.AggregationResult` :returns: The list of AggregationResults """ - results = [AggregationResult(alias=k, value=pb[k].integer_value) for k in pb.keys()] + results = [ + AggregationResult(alias=k, value=pb[k].integer_value or pb[k].double_value) + for k in pb.keys() + ] return results diff --git a/tests/system/index.yaml b/tests/system/index.yaml index f9cc2a5b..1f27c246 100644 --- a/tests/system/index.yaml +++ b/tests/system/index.yaml @@ -39,9 +39,9 @@ indexes: - name: family - name: appearances - - kind: Character ancestor: yes properties: - name: family - name: appearances + diff --git a/tests/system/test_aggregation_query.py b/tests/system/test_aggregation_query.py index 51045003..ae9a8297 100644 --- a/tests/system/test_aggregation_query.py +++ b/tests/system/test_aggregation_query.py @@ -70,54 +70,281 @@ def nested_query(aggregation_query_client, ancestor_key): @pytest.mark.parametrize("database_id", [None, _helpers.TEST_DATABASE], indirect=True) -def test_aggregation_query_default(aggregation_query_client, nested_query, database_id): +def test_count_query_default(aggregation_query_client, nested_query, database_id): query = nested_query aggregation_query = aggregation_query_client.aggregation_query(query) aggregation_query.count() result = _do_fetch(aggregation_query) assert len(result) == 1 - for r in result[0]: - assert r.alias == "property_1" - assert r.value == 8 + assert len(result[0]) == 1 + r = result[0][0] + assert r.alias == "property_1" + expected_count = len(populate_datastore.CHARACTERS) + assert r.value == expected_count @pytest.mark.parametrize("database_id", [None, _helpers.TEST_DATABASE], indirect=True) -def test_aggregation_query_with_alias( - aggregation_query_client, nested_query, database_id +@pytest.mark.parametrize( + "aggregation_type,aggregation_args,expected", + [ + ("count", (), len(populate_datastore.CHARACTERS)), + ( + "sum", + ("appearances",), + sum(c["appearances"] for c in populate_datastore.CHARACTERS), + ), + ( + "avg", + ("appearances",), + sum(c["appearances"] for c in populate_datastore.CHARACTERS) + / len(populate_datastore.CHARACTERS), + ), + ], +) +def test_aggregation_query_in_transaction( + aggregation_query_client, + nested_query, + database_id, + aggregation_type, + aggregation_args, + expected, ): + """ + When an aggregation query is run in a transaction, the transaction id should be sent with the request. + The result is the same as when it is run outside of a transaction. + """ + with aggregation_query_client.transaction(): + query = nested_query + + aggregation_query = aggregation_query_client.aggregation_query(query) + getattr(aggregation_query, aggregation_type)(*aggregation_args) + # run full query + result = _do_fetch(aggregation_query) + assert len(result) == 1 + assert len(result[0]) == 1 + r = result[0][0] + assert r.alias == "property_1" + assert r.value == expected + + +@pytest.mark.parametrize("database_id", [None, _helpers.TEST_DATABASE], indirect=True) +def test_count_query_with_alias(aggregation_query_client, nested_query, database_id): query = nested_query aggregation_query = aggregation_query_client.aggregation_query(query) aggregation_query.count(alias="total") result = _do_fetch(aggregation_query) assert len(result) == 1 - for r in result[0]: - assert r.alias == "total" - assert r.value > 0 + assert len(result[0]) == 1 + r = result[0][0] + assert r.alias == "total" + expected_count = len(populate_datastore.CHARACTERS) + assert r.value == expected_count @pytest.mark.parametrize("database_id", [None, _helpers.TEST_DATABASE], indirect=True) -def test_aggregation_query_with_limit( - aggregation_query_client, nested_query, database_id -): +def test_sum_query_default(aggregation_query_client, nested_query, database_id): + query = nested_query + + aggregation_query = aggregation_query_client.aggregation_query(query) + aggregation_query.sum("appearances") + result = _do_fetch(aggregation_query) + assert len(result) == 1 + assert len(result[0]) == 1 + r = result[0][0] + assert r.alias == "property_1" + expected_sum = sum(c["appearances"] for c in populate_datastore.CHARACTERS) + assert r.value == expected_sum + + +@pytest.mark.parametrize("database_id", [None, _helpers.TEST_DATABASE], indirect=True) +def test_sum_query_with_alias(aggregation_query_client, nested_query, database_id): + query = nested_query + + aggregation_query = aggregation_query_client.aggregation_query(query) + aggregation_query.sum("appearances", alias="sum_appearances") + result = _do_fetch(aggregation_query) + assert len(result) == 1 + assert len(result[0]) == 1 + r = result[0][0] + assert r.alias == "sum_appearances" + expected_sum = sum(c["appearances"] for c in populate_datastore.CHARACTERS) + assert r.value == expected_sum + + +@pytest.mark.parametrize("database_id", [None, _helpers.TEST_DATABASE], indirect=True) +def test_avg_query_default(aggregation_query_client, nested_query, database_id): + query = nested_query + + aggregation_query = aggregation_query_client.aggregation_query(query) + aggregation_query.avg("appearances") + result = _do_fetch(aggregation_query) + assert len(result) == 1 + assert len(result[0]) == 1 + r = result[0][0] + assert r.alias == "property_1" + expected_avg = sum(c["appearances"] for c in populate_datastore.CHARACTERS) / len( + populate_datastore.CHARACTERS + ) + assert r.value == expected_avg + + +@pytest.mark.parametrize("database_id", [None, _helpers.TEST_DATABASE], indirect=True) +def test_avg_query_with_alias(aggregation_query_client, nested_query, database_id): + query = nested_query + + aggregation_query = aggregation_query_client.aggregation_query(query) + aggregation_query.avg("appearances", alias="avg_appearances") + result = _do_fetch(aggregation_query) + assert len(result) == 1 + assert len(result[0]) == 1 + r = result[0][0] + assert r.alias == "avg_appearances" + expected_avg = sum(c["appearances"] for c in populate_datastore.CHARACTERS) / len( + populate_datastore.CHARACTERS + ) + assert r.value == expected_avg + + +@pytest.mark.parametrize("database_id", [None, _helpers.TEST_DATABASE], indirect=True) +def test_count_query_with_limit(aggregation_query_client, nested_query, database_id): query = nested_query aggregation_query = aggregation_query_client.aggregation_query(query) aggregation_query.count(alias="total") result = _do_fetch(aggregation_query) # count without limit assert len(result) == 1 - for r in result[0]: - assert r.alias == "total" - assert r.value == 8 + assert len(result[0]) == 1 + r = result[0][0] + assert r.alias == "total" + expected_count = len(populate_datastore.CHARACTERS) + assert r.value == expected_count aggregation_query = aggregation_query_client.aggregation_query(query) aggregation_query.count(alias="total_up_to") - result = _do_fetch(aggregation_query, limit=2) # count with limit = 2 + limit = 2 + result = _do_fetch(aggregation_query, limit=limit) # count with limit = 2 + assert len(result) == 1 + assert len(result[0]) == 1 + r = result[0][0] + assert r.alias == "total_up_to" + assert r.value == limit + + aggregation_query = aggregation_query_client.aggregation_query(query) + aggregation_query.count(alias="total_high_limit") + limit = 2 + result = _do_fetch( + aggregation_query, limit=expected_count * 2 + ) # count with limit > total + assert len(result) == 1 + assert len(result[0]) == 1 + r = result[0][0] + assert r.alias == "total_high_limit" + assert r.value == expected_count + + +@pytest.mark.parametrize("database_id", [None, _helpers.TEST_DATABASE], indirect=True) +def test_sum_query_with_limit(aggregation_query_client, nested_query, database_id): + query = nested_query + + aggregation_query = aggregation_query_client.aggregation_query(query) + aggregation_query.sum("appearances", alias="sum_limited") + limit = 2 + result = _do_fetch(aggregation_query, limit=limit) # count with limit = 2 + assert len(result) == 1 + assert len(result[0]) == 1 + r = result[0][0] + assert r.alias == "sum_limited" + expected = sum(c["appearances"] for c in populate_datastore.CHARACTERS[:limit]) + assert r.value == expected + + aggregation_query = aggregation_query_client.aggregation_query(query) + aggregation_query.sum("appearances", alias="sum_high_limit") + num_characters = len(populate_datastore.CHARACTERS) + result = _do_fetch( + aggregation_query, limit=num_characters * 2 + ) # count with limit > total + assert len(result) == 1 + assert len(result[0]) == 1 + r = result[0][0] + assert r.alias == "sum_high_limit" + assert r.value == sum(c["appearances"] for c in populate_datastore.CHARACTERS) + + +@pytest.mark.parametrize("database_id", [None, _helpers.TEST_DATABASE], indirect=True) +def test_avg_query_with_limit(aggregation_query_client, nested_query, database_id): + query = nested_query + + aggregation_query = aggregation_query_client.aggregation_query(query) + aggregation_query.avg("appearances", alias="avg_limited") + limit = 2 + result = _do_fetch(aggregation_query, limit=limit) # count with limit = 2 + assert len(result) == 1 + assert len(result[0]) == 1 + r = result[0][0] + assert r.alias == "avg_limited" + expected = ( + sum(c["appearances"] for c in populate_datastore.CHARACTERS[:limit]) / limit + ) + assert r.value == expected + + aggregation_query = aggregation_query_client.aggregation_query(query) + aggregation_query.avg("appearances", alias="avg_high_limit") + num_characters = len(populate_datastore.CHARACTERS) + result = _do_fetch( + aggregation_query, limit=num_characters * 2 + ) # count with limit > total + assert len(result) == 1 + assert len(result[0]) == 1 + r = result[0][0] + assert r.alias == "avg_high_limit" + assert ( + r.value + == sum(c["appearances"] for c in populate_datastore.CHARACTERS) / num_characters + ) + + +@pytest.mark.parametrize("database_id", [None, _helpers.TEST_DATABASE], indirect=True) +def test_count_query_empty(aggregation_query_client, nested_query, database_id): + query = nested_query + query.add_filter("name", "=", "nonexistent") + aggregation_query = aggregation_query_client.aggregation_query(query) + aggregation_query.count(alias="total") + result = _do_fetch(aggregation_query) + assert len(result) == 1 + assert len(result[0]) == 1 + r = result[0][0] + assert r.alias == "total" + assert r.value == 0 + + +@pytest.mark.parametrize("database_id", [None, _helpers.TEST_DATABASE], indirect=True) +def test_sum_query_empty(aggregation_query_client, nested_query, database_id): + query = nested_query + query.add_filter("family", "=", "nonexistent") + aggregation_query = aggregation_query_client.aggregation_query(query) + aggregation_query.sum("appearances", alias="sum") + result = _do_fetch(aggregation_query) assert len(result) == 1 - for r in result[0]: - assert r.alias == "total_up_to" - assert r.value == 2 + assert len(result[0]) == 1 + r = result[0][0] + assert r.alias == "sum" + assert r.value == 0 + + +@pytest.mark.parametrize("database_id", [None, _helpers.TEST_DATABASE], indirect=True) +def test_avg_query_empty(aggregation_query_client, nested_query, database_id): + query = nested_query + query.add_filter("family", "=", "nonexistent") + aggregation_query = aggregation_query_client.aggregation_query(query) + aggregation_query.avg("appearances", alias="avg") + result = _do_fetch(aggregation_query) + assert len(result) == 1 + assert len(result[0]) == 1 + r = result[0][0] + assert r.alias == "avg" + assert r.value == 0 @pytest.mark.parametrize("database_id", [None, _helpers.TEST_DATABASE], indirect=True) @@ -129,11 +356,20 @@ def test_aggregation_query_multiple_aggregations( aggregation_query = aggregation_query_client.aggregation_query(query) aggregation_query.count(alias="total") aggregation_query.count(alias="all") + aggregation_query.sum("appearances", alias="sum_appearances") + aggregation_query.avg("appearances", alias="avg_appearances") result = _do_fetch(aggregation_query) assert len(result) == 1 - for r in result[0]: - assert r.alias in ["all", "total"] - assert r.value > 0 + assert len(result[0]) == 4 + result_dict = {r.alias: r for r in result[0]} + assert result_dict["total"].value == len(populate_datastore.CHARACTERS) + assert result_dict["all"].value == len(populate_datastore.CHARACTERS) + assert result_dict["sum_appearances"].value == sum( + c["appearances"] for c in populate_datastore.CHARACTERS + ) + assert result_dict["avg_appearances"].value == sum( + c["appearances"] for c in populate_datastore.CHARACTERS + ) / len(populate_datastore.CHARACTERS) @pytest.mark.parametrize("database_id", [None, _helpers.TEST_DATABASE], indirect=True) @@ -141,36 +377,66 @@ def test_aggregation_query_add_aggregation( aggregation_query_client, nested_query, database_id ): from google.cloud.datastore.aggregation import CountAggregation + from google.cloud.datastore.aggregation import SumAggregation + from google.cloud.datastore.aggregation import AvgAggregation query = nested_query aggregation_query = aggregation_query_client.aggregation_query(query) count_aggregation = CountAggregation(alias="total") aggregation_query.add_aggregation(count_aggregation) + + sum_aggregation = SumAggregation("appearances", alias="sum_appearances") + aggregation_query.add_aggregation(sum_aggregation) + + avg_aggregation = AvgAggregation("appearances", alias="avg_appearances") + aggregation_query.add_aggregation(avg_aggregation) + result = _do_fetch(aggregation_query) assert len(result) == 1 - for r in result[0]: - assert r.alias == "total" - assert r.value > 0 + assert len(result[0]) == 3 + result_dict = {r.alias: r for r in result[0]} + assert result_dict["total"].value == len(populate_datastore.CHARACTERS) + assert result_dict["sum_appearances"].value == sum( + c["appearances"] for c in populate_datastore.CHARACTERS + ) + assert result_dict["avg_appearances"].value == sum( + c["appearances"] for c in populate_datastore.CHARACTERS + ) / len(populate_datastore.CHARACTERS) @pytest.mark.parametrize("database_id", [None, _helpers.TEST_DATABASE], indirect=True) def test_aggregation_query_add_aggregations( aggregation_query_client, nested_query, database_id ): - from google.cloud.datastore.aggregation import CountAggregation + from google.cloud.datastore.aggregation import ( + CountAggregation, + SumAggregation, + AvgAggregation, + ) query = nested_query aggregation_query = aggregation_query_client.aggregation_query(query) count_aggregation_1 = CountAggregation(alias="total") count_aggregation_2 = CountAggregation(alias="all") - aggregation_query.add_aggregations([count_aggregation_1, count_aggregation_2]) + sum_aggregation = SumAggregation("appearances", alias="sum_appearances") + avg_aggregation = AvgAggregation("appearances", alias="avg_appearances") + aggregation_query.add_aggregations( + [count_aggregation_1, count_aggregation_2, sum_aggregation, avg_aggregation] + ) result = _do_fetch(aggregation_query) assert len(result) == 1 - for r in result[0]: - assert r.alias in ["total", "all"] - assert r.value > 0 + assert len(result[0]) == 4 + result_dict = {r.alias: r for r in result[0]} + assert result_dict["total"].value == len(populate_datastore.CHARACTERS) + assert result_dict["all"].value == len(populate_datastore.CHARACTERS) + assert result_dict["sum_appearances"].value == sum( + c["appearances"] for c in populate_datastore.CHARACTERS + ) + assert result_dict["avg_appearances"].value == sum( + c["appearances"] for c in populate_datastore.CHARACTERS + ) / len(populate_datastore.CHARACTERS) @pytest.mark.parametrize("database_id", [None, _helpers.TEST_DATABASE], indirect=True) @@ -218,12 +484,20 @@ def test_aggregation_query_with_nested_query_filtered( aggregation_query = aggregation_query_client.aggregation_query(query) aggregation_query.count(alias="total") + aggregation_query.sum("appearances", alias="sum_appearances") + aggregation_query.avg("appearances", alias="avg_appearances") result = _do_fetch(aggregation_query) assert len(result) == 1 - - for r in result[0]: - assert r.alias == "total" - assert r.value == 6 + assert len(result[0]) == 3 + result_dict = {r.alias: r for r in result[0]} + assert result_dict["total"].value == expected_matches + expected_sum = sum( + c["appearances"] + for c in populate_datastore.CHARACTERS + if c["appearances"] >= 20 + ) + assert result_dict["sum_appearances"].value == expected_sum + assert result_dict["avg_appearances"].value == expected_sum / expected_matches @pytest.mark.parametrize("database_id", [None, _helpers.TEST_DATABASE], indirect=True) @@ -243,9 +517,17 @@ def test_aggregation_query_with_nested_query_multiple_filters( aggregation_query = aggregation_query_client.aggregation_query(query) aggregation_query.count(alias="total") + aggregation_query.sum("appearances", alias="sum_appearances") + aggregation_query.avg("appearances", alias="avg_appearances") result = _do_fetch(aggregation_query) assert len(result) == 1 - - for r in result[0]: - assert r.alias == "total" - assert r.value == 4 + assert len(result[0]) == 3 + result_dict = {r.alias: r for r in result[0]} + assert result_dict["total"].value == expected_matches + expected_sum = sum( + c["appearances"] + for c in populate_datastore.CHARACTERS + if c["appearances"] >= 26 and "Stark" in c["family"] + ) + assert result_dict["sum_appearances"].value == expected_sum + assert result_dict["avg_appearances"].value == expected_sum / expected_matches diff --git a/tests/system/test_query.py b/tests/system/test_query.py index 864bab57..9d7bec06 100644 --- a/tests/system/test_query.py +++ b/tests/system/test_query.py @@ -82,6 +82,21 @@ def test_query_w_ancestor(ancestor_query, database_id): assert len(entities) == expected_matches +@pytest.mark.parametrize("database_id", [None, _helpers.TEST_DATABASE], indirect=True) +def test_query_in_transaction(ancestor_query, database_id): + """ + when a query is run in a transaction, the transaction id should be sent with the request. + the result is the same as when it is run outside of a transaction. + """ + query = ancestor_query + client = query._client + expected_matches = 8 + with client.transaction(): + # run full query + entities = _do_fetch(query) + assert len(entities) == expected_matches + + @pytest.mark.parametrize("database_id", [None, _helpers.TEST_DATABASE], indirect=True) def test_query_w_limit_paging(ancestor_query, database_id): query = ancestor_query diff --git a/tests/unit/test_aggregation.py b/tests/unit/test_aggregation.py index fd72ad0a..15d11aca 100644 --- a/tests/unit/test_aggregation.py +++ b/tests/unit/test_aggregation.py @@ -15,7 +15,12 @@ import mock import pytest -from google.cloud.datastore.aggregation import CountAggregation, AggregationQuery +from google.cloud.datastore.aggregation import ( + CountAggregation, + SumAggregation, + AvgAggregation, + AggregationQuery, +) from google.cloud.datastore.helpers import set_database_id_to_request from tests.unit.test_query import _make_query, _make_client @@ -34,6 +39,30 @@ def test_count_aggregation_to_pb(): assert count_aggregation._to_pb() == expected_aggregation_query_pb +def test_sum_aggregation_to_pb(): + from google.cloud.datastore_v1.types import query as query_pb2 + + sum_aggregation = SumAggregation("appearances", alias="total") + + expected_aggregation_query_pb = query_pb2.AggregationQuery.Aggregation() + expected_aggregation_query_pb.sum = query_pb2.AggregationQuery.Aggregation.Sum() + expected_aggregation_query_pb.sum.property.name = sum_aggregation.property_ref + expected_aggregation_query_pb.alias = sum_aggregation.alias + assert sum_aggregation._to_pb() == expected_aggregation_query_pb + + +def test_avg_aggregation_to_pb(): + from google.cloud.datastore_v1.types import query as query_pb2 + + avg_aggregation = AvgAggregation("appearances", alias="total") + + expected_aggregation_query_pb = query_pb2.AggregationQuery.Aggregation() + expected_aggregation_query_pb.avg = query_pb2.AggregationQuery.Aggregation.Avg() + expected_aggregation_query_pb.avg.property.name = avg_aggregation.property_ref + expected_aggregation_query_pb.alias = avg_aggregation.alias + assert avg_aggregation._to_pb() == expected_aggregation_query_pb + + @pytest.fixture def database_id(request): return request.param @@ -117,6 +146,8 @@ def test_pb_over_query_with_add_aggregations(client, database_id): aggregations = [ CountAggregation(alias="total"), CountAggregation(alias="all"), + SumAggregation("appearances", alias="sum_appearances"), + AvgAggregation("appearances", alias="avg_appearances"), ] query = _make_query(client) @@ -125,9 +156,73 @@ def test_pb_over_query_with_add_aggregations(client, database_id): aggregation_query.add_aggregations(aggregations) pb = aggregation_query._to_pb() assert pb.nested_query == _pb_from_query(query) - assert len(pb.aggregations) == 2 + assert len(pb.aggregations) == 4 assert pb.aggregations[0] == CountAggregation(alias="total")._to_pb() assert pb.aggregations[1] == CountAggregation(alias="all")._to_pb() + assert ( + pb.aggregations[2] + == SumAggregation("appearances", alias="sum_appearances")._to_pb() + ) + assert ( + pb.aggregations[3] + == AvgAggregation("appearances", alias="avg_appearances")._to_pb() + ) + + +@pytest.mark.parametrize("database_id", [None, "somedb"], indirect=True) +def test_pb_over_query_with_sum(client, database_id): + from google.cloud.datastore.query import _pb_from_query + + query = _make_query(client) + aggregation_query = _make_aggregation_query(client=client, query=query) + + aggregation_query.sum("appearances", alias="total") + pb = aggregation_query._to_pb() + assert pb.nested_query == _pb_from_query(query) + assert len(pb.aggregations) == 1 + assert pb.aggregations[0] == SumAggregation("appearances", alias="total")._to_pb() + + +@pytest.mark.parametrize("database_id", [None, "somedb"], indirect=True) +def test_pb_over_query_sum_with_add_aggregation(client, database_id): + from google.cloud.datastore.query import _pb_from_query + + query = _make_query(client) + aggregation_query = _make_aggregation_query(client=client, query=query) + + aggregation_query.add_aggregation(SumAggregation("appearances", alias="total")) + pb = aggregation_query._to_pb() + assert pb.nested_query == _pb_from_query(query) + assert len(pb.aggregations) == 1 + assert pb.aggregations[0] == SumAggregation("appearances", alias="total")._to_pb() + + +@pytest.mark.parametrize("database_id", [None, "somedb"], indirect=True) +def test_pb_over_query_with_avg(client, database_id): + from google.cloud.datastore.query import _pb_from_query + + query = _make_query(client) + aggregation_query = _make_aggregation_query(client=client, query=query) + + aggregation_query.avg("appearances", alias="avg") + pb = aggregation_query._to_pb() + assert pb.nested_query == _pb_from_query(query) + assert len(pb.aggregations) == 1 + assert pb.aggregations[0] == AvgAggregation("appearances", alias="avg")._to_pb() + + +@pytest.mark.parametrize("database_id", [None, "somedb"], indirect=True) +def test_pb_over_query_avg_with_add_aggregation(client, database_id): + from google.cloud.datastore.query import _pb_from_query + + query = _make_query(client) + aggregation_query = _make_aggregation_query(client=client, query=query) + + aggregation_query.add_aggregation(AvgAggregation("appearances", alias="avg")) + pb = aggregation_query._to_pb() + assert pb.nested_query == _pb_from_query(query) + assert len(pb.aggregations) == 1 + assert pb.aggregations[0] == AvgAggregation("appearances", alias="avg")._to_pb() @pytest.mark.parametrize("database_id", [None, "somedb"], indirect=True) @@ -243,8 +338,11 @@ def test_iterator__build_protobuf_all_values(): query = _make_query(client) alias = "total" limit = 2 + property_ref = "appearances" aggregation_query = AggregationQuery(client=client, query=query) aggregation_query.count(alias) + aggregation_query.sum(property_ref) + aggregation_query.avg(property_ref) iterator = _make_aggregation_iterator(aggregation_query, client, limit=limit) iterator.num_results = 4 @@ -252,9 +350,22 @@ def test_iterator__build_protobuf_all_values(): pb = iterator._build_protobuf() expected_pb = query_pb2.AggregationQuery() expected_pb.nested_query = query_pb2.Query() + expected_pb.nested_query.limit = limit + expected_count_pb = query_pb2.AggregationQuery.Aggregation(alias=alias) - expected_count_pb.count.up_to = limit + expected_count_pb.count = query_pb2.AggregationQuery.Aggregation.Count() expected_pb.aggregations.append(expected_count_pb) + + expected_sum_pb = query_pb2.AggregationQuery.Aggregation() + expected_sum_pb.sum = query_pb2.AggregationQuery.Aggregation.Sum() + expected_sum_pb.sum.property.name = property_ref + expected_pb.aggregations.append(expected_sum_pb) + + expected_avg_pb = query_pb2.AggregationQuery.Aggregation() + expected_avg_pb.avg = query_pb2.AggregationQuery.Aggregation.Avg() + expected_avg_pb.avg.property.name = property_ref + expected_pb.aggregations.append(expected_avg_pb) + assert pb == expected_pb @@ -426,6 +537,81 @@ def test__item_to_aggregation_result(): assert result[0].value == map_composite_mock.__getitem__().integer_value +@pytest.mark.parametrize("database_id", [None, "somedb"], indirect=True) +@pytest.mark.parametrize( + "aggregation_type,aggregation_args", + [ + ("count", ()), + ( + "sum", + ("appearances",), + ), + ("avg", ("appearances",)), + ], +) +def test_eventual_transaction_fails(database_id, aggregation_type, aggregation_args): + """ + Queries with eventual consistency cannot be used in a transaction. + """ + import mock + + transaction = mock.Mock() + transaction.id = b"expected_id" + client = _Client(None, database=database_id, transaction=transaction) + + query = _make_query(client) + aggregation_query = _make_aggregation_query(client=client, query=query) + # initiate requested aggregation (ex count, sum, avg) + getattr(aggregation_query, aggregation_type)(*aggregation_args) + with pytest.raises(ValueError): + list(aggregation_query.fetch(eventual=True)) + + +@pytest.mark.parametrize("database_id", [None, "somedb"], indirect=True) +@pytest.mark.parametrize( + "aggregation_type,aggregation_args", + [ + ("count", ()), + ( + "sum", + ("appearances",), + ), + ("avg", ("appearances",)), + ], +) +def test_transaction_id_populated(database_id, aggregation_type, aggregation_args): + """ + When an aggregation is run in the context of a transaction, the transaction + ID should be populated in the request. + """ + import mock + + transaction = mock.Mock() + transaction.id = b"expected_id" + mock_datastore_api = mock.Mock() + mock_gapic = mock_datastore_api.run_aggregation_query + mock_gapic.return_value = _make_aggregation_query_response([]) + client = _Client( + None, + datastore_api=mock_datastore_api, + database=database_id, + transaction=transaction, + ) + + query = _make_query(client) + aggregation_query = _make_aggregation_query(client=client, query=query) + + # initiate requested aggregation (ex count, sum, avg) + getattr(aggregation_query, aggregation_type)(*aggregation_args) + # run mock query + list(aggregation_query.fetch()) + assert mock_gapic.call_count == 1 + request = mock_gapic.call_args[1]["request"] + read_options = request["read_options"] + # ensure transaction ID is populated + assert read_options.transaction == client.current_transaction.id + + class _Client(object): def __init__( self, @@ -459,7 +645,9 @@ def _make_aggregation_iterator(*args, **kw): return AggregationResultIterator(*args, **kw) -def _make_aggregation_query_response(aggregation_pbs, more_results_enum): +def _make_aggregation_query_response( + aggregation_pbs, more_results_enum=3 +): # 3 = NO_MORE_RESULTS from google.cloud.datastore_v1.types import datastore as datastore_pb2 from google.cloud.datastore_v1.types import aggregation_result diff --git a/tests/unit/test_query.py b/tests/unit/test_query.py index 25b3febb..7758d7fb 100644 --- a/tests/unit/test_query.py +++ b/tests/unit/test_query.py @@ -652,6 +652,56 @@ def test_query_fetch_w_explicit_client_w_retry_w_timeout(database_id): assert iterator._timeout == timeout +@pytest.mark.parametrize("database_id", [None, "somedb"]) +def test_eventual_transaction_fails(database_id): + """ + Queries with eventual consistency cannot be used in a transaction. + """ + import mock + + transaction = mock.Mock() + transaction.id = b"expected_id" + client = _Client(None, database=database_id, transaction=transaction) + + query = _make_query(client) + with pytest.raises(ValueError): + list(query.fetch(eventual=True)) + + +@pytest.mark.parametrize("database_id", [None, "somedb"]) +def test_transaction_id_populated(database_id): + """ + When an aggregation is run in the context of a transaction, the transaction + ID should be populated in the request. + """ + import mock + + transaction = mock.Mock() + transaction.id = b"expected_id" + mock_datastore_api = mock.Mock() + mock_gapic = mock_datastore_api.run_query + + more_results_enum = 3 # NO_MORE_RESULTS + response_pb = _make_query_response([], b"", more_results_enum, 0) + mock_gapic.return_value = response_pb + + client = _Client( + None, + datastore_api=mock_datastore_api, + database=database_id, + transaction=transaction, + ) + + query = _make_query(client) + # run mock query + list(query.fetch()) + assert mock_gapic.call_count == 1 + request = mock_gapic.call_args[1]["request"] + read_options = request["read_options"] + # ensure transaction ID is populated + assert read_options.transaction == client.current_transaction.id + + def test_iterator_constructor_defaults(): query = object() client = object()