diff --git a/datastore/query.go b/datastore/query.go index 0c3fe01db853..195ae102bbf5 100644 --- a/datastore/query.go +++ b/datastore/query.go @@ -627,6 +627,14 @@ func (c *Client) Run(ctx context.Context, q *Query) *Iterator { // RunAggregationQuery gets aggregation query (e.g. COUNT) results from the service. func (c *Client) RunAggregationQuery(ctx context.Context, aq *AggregationQuery) (AggregationResult, error) { + if aq == nil { + return nil, errors.New("datastore: aggregation query cannot be nil") + } + + if aq.query == nil { + return nil, errors.New("datastore: aggregation query must include nested query") + } + if len(aq.aggregationQueries) == 0 { return nil, errors.New("datastore: aggregation query must contain one or more operators (e.g. count)") } diff --git a/datastore/query_test.go b/datastore/query_test.go index c46b0bb5cb67..824c439e97a0 100644 --- a/datastore/query_test.go +++ b/datastore/query_test.go @@ -721,3 +721,33 @@ func TestAggregationQuery(t *testing.T) { t.Errorf("want: %v\ngot: %v\n", want, cv) } } + +func TestAggregationQueryIsNil(t *testing.T) { + client := &Client{ + client: &fakeClient{ + aggQueryFn: func(req *pb.RunAggregationQueryRequest) (*pb.RunAggregationQueryResponse, error) { + return fakeRunAggregationQuery(req) + }, + }, + } + + var q Query + aq := q.NewAggregationQuery() + _, err := client.RunAggregationQuery(context.Background(), aq) + if err == nil { + t.Fatal(err) + } + + q2 := NewQuery("Gopher") + aq2 := q2.NewAggregationQuery() + _, err = client.RunAggregationQuery(context.Background(), aq2) + if err == nil { + t.Fatal(err) + } + + aq3 := q2.NewAggregationQuery().WithCount("") + _, err = client.RunAggregationQuery(context.Background(), aq3) + if err == nil { + t.Fatal(err) + } +}