diff --git a/firestore/integration_test.go b/firestore/integration_test.go index 40598e6df1c2..aa66957bf4a1 100644 --- a/firestore/integration_test.go +++ b/firestore/integration_test.go @@ -35,6 +35,7 @@ import ( "github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp/cmpopts" "google.golang.org/api/option" + firestore "google.golang.org/genproto/googleapis/firestore/v1beta1" "google.golang.org/genproto/googleapis/type/latlng" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" @@ -1773,3 +1774,52 @@ func TestIntegration_BulkWriter(t *testing.T) { } } } + +func TestIntegration_CountAggregationQuery(t *testing.T) { + docs := []*DocumentRef{ + iColl.NewDoc(), + iColl.NewDoc(), + } + + c := integrationClient(t) + ctx := context.Background() + bw := c.BulkWriter(ctx) + jobs := make([]*BulkWriterJob, 0) + + // Populate the collection + f := integrationTestMap + for _, d := range docs { + j, err := bw.Create(d, f) + jobs = append(jobs, j) + if err != nil { + t.Fatal(err) + } + } + bw.End() + + for _, j := range jobs { + _, err := j.Results() + if err != nil { + t.Fatal(err) + } + } + + // [START firestore_count_query] + alias := "twos" + q := iColl.Where("str", "==", "two") + aq := q.NewAggregationQuery() + ar, err := aq.WithCount(alias).Get(ctx) + // [END firestore_count_query] + if err != nil { + t.Fatal(err) + } + + count, ok := ar[alias] + if !ok { + t.Errorf("key %s not in response %v", alias, ar) + } + cv := count.(*firestore.Value) + if cv.GetIntegerValue() != 2 { + t.Errorf("COUNT aggregation query mismatch;\ngot: %d, want: %d", cv.GetIntegerValue(), 2) + } +} diff --git a/firestore/mock_test.go b/firestore/mock_test.go index ef40a92953b3..34fa3582b04c 100644 --- a/firestore/mock_test.go +++ b/firestore/mock_test.go @@ -187,6 +187,27 @@ func (s *mockServer) RunQuery(req *pb.RunQueryRequest, qs pb.Firestore_RunQueryS return nil } +func (s *mockServer) RunAggregationQuery(req *pb.RunAggregationQueryRequest, qs pb.Firestore_RunAggregationQueryServer) error { + res, err := s.popRPC(req) + if err != nil { + return err + } + responses := res.([]interface{}) + for _, res := range responses { + switch res := res.(type) { + case *pb.RunAggregationQueryResponse: + if err := qs.Send(res); err != nil { + return err + } + case error: + return res + default: + return fmt.Errorf("bad response type in RunAggregationQuery: %+v", res) + } + } + return nil +} + func (s *mockServer) BeginTransaction(_ context.Context, req *pb.BeginTransactionRequest) (*pb.BeginTransactionResponse, error) { res, err := s.popRPC(req) if err != nil { diff --git a/firestore/query.go b/firestore/query.go index d9bb471e5bc4..218738b7a359 100644 --- a/firestore/query.go +++ b/firestore/query.go @@ -304,6 +304,14 @@ func (q Query) Deserialize(bytes []byte) (Query, error) { return q.fromProto(&runQueryRequest) } +// NewAggregationQuery returns an AggregationQuery with this query as its +// base query. +func (q *Query) NewAggregationQuery() *AggregationQuery { + return &AggregationQuery{ + query: q, + } +} + // fromProto creates a new Query object from a RunQueryRequest. This can be used // in combination with ToProto to serialize Query objects. This could be useful, // for instance, if executing a query formed in one process in another. @@ -1036,3 +1044,74 @@ func (it *btreeDocumentIterator) next() (*DocumentSnapshot, error) { } func (*btreeDocumentIterator) stop() {} + +// AggregationQuery allows for generating aggregation results of an underlying +// basic query. A single AggregationQuery can contain multiple aggregations. +type AggregationQuery struct { + // aggregateQueries contains all of the queries for this request. + aggregateQueries []*pb.StructuredAggregationQuery_Aggregation + // query contains a reference pointer to the underlying structured query. + query *Query +} + +// WithCount specifies that the aggregation query provide a count of results +// returned by the underlying Query. +func (a *AggregationQuery) WithCount(alias string) *AggregationQuery { + aq := &pb.StructuredAggregationQuery_Aggregation{ + Alias: alias, + Operator: &pb.StructuredAggregationQuery_Aggregation_Count_{}, + } + + a.aggregateQueries = append(a.aggregateQueries, aq) + + return a +} + +// Get retrieves the aggregation query results from the service. +func (a *AggregationQuery) Get(ctx context.Context) (AggregationResult, error) { + + client := a.query.c.c + q, err := a.query.toProto() + if err != nil { + return nil, err + } + + req := &pb.RunAggregationQueryRequest{ + Parent: a.query.parentPath, + QueryType: &pb.RunAggregationQueryRequest_StructuredAggregationQuery{ + StructuredAggregationQuery: &pb.StructuredAggregationQuery{ + QueryType: &pb.StructuredAggregationQuery_StructuredQuery{ + StructuredQuery: q, + }, + Aggregations: a.aggregateQueries, + }, + }, + } + ctx = withResourceHeader(ctx, a.query.c.path()) + stream, err := client.RunAggregationQuery(ctx, req) + if err != nil { + return nil, err + } + + resp := make(AggregationResult) + + for { + res, err := stream.Recv() + if err == io.EOF { + break + } + if err != nil { + return nil, err + } + + f := res.Result.AggregateFields + + for k, v := range f { + resp[k] = v + } + } + return resp, nil +} + +// AggregationResult contains the results of an aggregation query. +type AggregationResult map[string]interface{} diff --git a/firestore/query_test.go b/firestore/query_test.go index fafbdf96b980..ea67419875e8 100644 --- a/firestore/query_test.go +++ b/firestore/query_test.go @@ -923,3 +923,35 @@ func (b byQuery) Less(i, j int) bool { } return c < 0 } + +func TestAggregationQuery(t *testing.T) { + ctx := context.Background() + c, srv, cleanup := newMock(t) + defer cleanup() + + srv.addRPC(nil, []interface{}{ + &pb.RunAggregationQueryResponse{ + Result: &pb.AggregationResult{ + AggregateFields: map[string]*pb.Value{ + "testAlias": intval(1), + }, + }, + }, + }) + + q := c.Collection("coll1").Where("f", "==", 2) + ar, err := q.NewAggregationQuery().WithCount("testAlias").Get(ctx) + if err != nil { + t.Fatal(err) + } + + count, ok := ar["testAlias"] + if !ok { + t.Errorf("aggregation query key not found") + } + + cv := count.(*pb.Value) + if cv.GetIntegerValue() != 1 { + t.Errorf("got: %v\nwant: %v\n; result: %v\n", cv.GetIntegerValue(), 1, count) + } +}