Skip to content

Commit

Permalink
Merge branch 'pr/75'
Browse files Browse the repository at this point in the history
Change-Id: Iac8815ecb955ac4915b8c29af60aba40d6dac09d
  • Loading branch information
skriptble committed Aug 8, 2018
2 parents 44fa48d + ad3c80f commit 9b26cbb
Show file tree
Hide file tree
Showing 12 changed files with 592 additions and 36 deletions.
140 changes: 140 additions & 0 deletions core/command/count_documents.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,140 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0

package command

import (
"context"
"errors"
"github.com/mongodb/mongo-go-driver/bson"
"github.com/mongodb/mongo-go-driver/core/description"
"github.com/mongodb/mongo-go-driver/core/option"
"github.com/mongodb/mongo-go-driver/core/readconcern"
"github.com/mongodb/mongo-go-driver/core/readpref"
"github.com/mongodb/mongo-go-driver/core/session"
"github.com/mongodb/mongo-go-driver/core/wiremessage"
)

// CountDocuments represents the CountDocuments command.
//
// The countDocuments command counts how many documents in a collection match the given query.
type CountDocuments struct {
NS Namespace
Pipeline *bson.Array
Opts []option.CountOptioner
ReadPref *readpref.ReadPref
ReadConcern *readconcern.ReadConcern
Clock *session.ClusterClock
Session *session.Client

result int64
err error
}

// Encode will encode this command into a wire message for the given server description.
func (c *CountDocuments) Encode(desc description.SelectedServer) (wiremessage.WireMessage, error) {
if err := c.NS.Validate(); err != nil {
return nil, err
}
command := bson.NewDocument()
command.Append(bson.EC.String("aggregate", c.NS.Collection), bson.EC.Array("pipeline", c.Pipeline))

cursor := bson.NewDocument()
command.Append(bson.EC.SubDocument("cursor", cursor))
for _, opt := range c.Opts {
if opt == nil {
continue
}
//because we already have these options in the pipeline
switch opt.(type) {
case option.OptSkip:
continue
case option.OptLimit:
continue
}
err := opt.Option(command)
if err != nil {
return nil, err
}
}

return (&Read{DB: c.NS.DB, ReadPref: c.ReadPref, Command: command}).Encode(desc)
}

// Decode will decode the wire message using the provided server description. Errors during decoding
// are deferred until either the Result or Err methods are called.
func (c *CountDocuments) Decode(ctx context.Context, desc description.SelectedServer, cb CursorBuilder, wm wiremessage.WireMessage) *CountDocuments {
rdr, err := (&Read{}).Decode(desc, wm).Result()
if err != nil {
c.err = err
return c
}
cur, err := cb.BuildCursor(rdr, c.Session, c.Clock)
if err != nil {
c.err = err
return c
}

var doc = bson.NewDocument()
if cur.Next(ctx) {
err = cur.Decode(doc)
if err != nil {
c.err = err
return c
}
val, err := doc.LookupErr("n")
switch {
case err == bson.ErrElementNotFound:
c.err = errors.New("Invalid response from server, no 'n' field")
return c
case err != nil:
c.err = err
return c
}
switch val.Type() {
case bson.TypeInt32:
c.result = int64(val.Int32())
case bson.TypeInt64:
c.result = val.Int64()
default:
c.err = errors.New("Invalid response from server, value field is not a number")
}

return c
}

c.result = 0
return c
}

// Result returns the result of a decoded wire message and server description.
func (c *CountDocuments) Result() (int64, error) {
if c.err != nil {
return 0, c.err
}
return c.result, nil
}

// Err returns the error set on this command.
func (c *CountDocuments) Err() error { return c.err }

// RoundTrip handles the execution of this command using the provided wiremessage.ReadWriter.
func (c *CountDocuments) RoundTrip(ctx context.Context, desc description.SelectedServer, cb CursorBuilder, rw wiremessage.ReadWriter) (int64, error) {
wm, err := c.Encode(desc)
if err != nil {
return 0, err
}

err = rw.WriteWireMessage(ctx, wm)
if err != nil {
return 0, err
}
wm, err = rw.ReadWireMessage(ctx)
if err != nil {
return 0, err
}
return c.Decode(ctx, desc, cb, wm).Result()
}
39 changes: 39 additions & 0 deletions core/dispatch/count_documents.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0

package dispatch

import (
"context"

"github.com/mongodb/mongo-go-driver/core/command"
"github.com/mongodb/mongo-go-driver/core/description"
"github.com/mongodb/mongo-go-driver/core/topology"
)

// CountDocuments handles the full cycle dispatch and execution of a countDocuments command against the provided
// topology.
func CountDocuments(
ctx context.Context,
cmd command.CountDocuments,
topo *topology.Topology,
selector description.ServerSelector,
) (int64, error) {

ss, err := topo.SelectServer(ctx, selector)
if err != nil {
return 0, err
}

desc := ss.Description()
conn, err := ss.Connection(ctx)
if err != nil {
return 0, err
}
defer conn.Close()

return cmd.RoundTrip(ctx, desc, ss, conn)
}
22 changes: 11 additions & 11 deletions core/option/options.go
Original file line number Diff line number Diff line change
Expand Up @@ -620,17 +620,17 @@ func (opt OptMaxTime) Option(d *bson.Document) error {
return nil
}

func (OptMaxTime) aggregateOption() {}
func (OptMaxTime) countOption() {}
func (OptMaxTime) distinctOption() {}
func (OptMaxTime) findOption() {}
func (OptMaxTime) findOneOption() {}
func (OptMaxTime) findOneAndDeleteOption() {}
func (OptMaxTime) findOneAndReplaceOption() {}
func (OptMaxTime) findOneAndUpdateOption() {}
func (OptMaxTime) listIndexesOption() {}
func (OptMaxTime) dropIndexesOption() {}
func (OptMaxTime) createIndexesOption() {}
func (OptMaxTime) aggregateOption() {}
func (OptMaxTime) countOption() {}
func (OptMaxTime) distinctOption() {}
func (OptMaxTime) findOption() {}
func (OptMaxTime) findOneOption() {}
func (OptMaxTime) findOneAndDeleteOption() {}
func (OptMaxTime) findOneAndReplaceOption() {}
func (OptMaxTime) findOneAndUpdateOption() {}
func (OptMaxTime) listIndexesOption() {}
func (OptMaxTime) dropIndexesOption() {}
func (OptMaxTime) createIndexesOption() {}

// String implements the Stringer interface.
func (opt OptMaxTime) String() string {
Expand Down
2 changes: 1 addition & 1 deletion examples/documentation_examples/examples.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@ import (

"github.com/mongodb/mongo-go-driver/bson"
"github.com/mongodb/mongo-go-driver/mongo"
"github.com/stretchr/testify/require"
"github.com/mongodb/mongo-go-driver/mongo/findopt"
"github.com/stretchr/testify/require"
)

func requireCursorLength(t *testing.T, cursor mongo.Cursor, length int) {
Expand Down
4 changes: 2 additions & 2 deletions internal/channel_connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,9 @@ package internal

import (
"context"
"testing"
"github.com/mongodb/mongo-go-driver/core/wiremessage"
"github.com/mongodb/mongo-go-driver/bson"
"github.com/mongodb/mongo-go-driver/core/wiremessage"
"testing"
)

// Implements the connection.Connection interface by reading and writing wire messages
Expand Down
40 changes: 20 additions & 20 deletions internal/testutil/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,14 @@ import (
"sync"
"testing"

"github.com/mongodb/mongo-go-driver/core/connstring"
"github.com/mongodb/mongo-go-driver/core/topology"
"github.com/mongodb/mongo-go-driver/core/event"
"github.com/mongodb/mongo-go-driver/bson"
"github.com/mongodb/mongo-go-driver/core/command"
"github.com/mongodb/mongo-go-driver/core/connection"
"github.com/mongodb/mongo-go-driver/core/connstring"
"github.com/mongodb/mongo-go-driver/core/description"
"github.com/mongodb/mongo-go-driver/core/event"
"github.com/mongodb/mongo-go-driver/core/topology"
"github.com/stretchr/testify/require"
"github.com/mongodb/mongo-go-driver/core/command"
"github.com/mongodb/mongo-go-driver/bson"
)

var connectionString connstring.ConnString
Expand Down Expand Up @@ -80,20 +80,20 @@ func AddCompressorToUri(uri string) string {
func MonitoredTopology(t *testing.T, monitor *event.CommandMonitor) *topology.Topology {
cs := ConnString(t)
opts := []topology.Option{
topology.WithConnString(func(connstring.ConnString) connstring.ConnString { return cs }),
topology.WithServerOptions(func(opts ...topology.ServerOption) []topology.ServerOption {
return append(
opts,
topology.WithConnectionOptions(func(opts ...connection.Option) []connection.Option {
return append(
opts,
connection.WithMonitor(func(*event.CommandMonitor) *event.CommandMonitor {
return monitor
}),
)
}),
)
}),
topology.WithConnString(func(connstring.ConnString) connstring.ConnString { return cs }),
topology.WithServerOptions(func(opts ...topology.ServerOption) []topology.ServerOption {
return append(
opts,
topology.WithConnectionOptions(func(opts ...connection.Option) []connection.Option {
return append(
opts,
connection.WithMonitor(func(*event.CommandMonitor) *event.CommandMonitor {
return monitor
}),
)
}),
)
}),
}

monitoredTopologyOnce.Do(func() {
Expand All @@ -110,7 +110,7 @@ func MonitoredTopology(t *testing.T, monitor *event.CommandMonitor) *topology.To
require.NoError(t, err)

_, err = (&command.Write{
DB: DBName(t),
DB: DBName(t),
Command: bson.NewDocument(bson.EC.Int32("dropDatabase", 1)),
}).RoundTrip(context.Background(), s.SelectedDescription(), c)

Expand Down
4 changes: 2 additions & 2 deletions internal/testutil/ops.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,12 @@ import (
"github.com/mongodb/mongo-go-driver/core/command"
"github.com/mongodb/mongo-go-driver/core/description"
"github.com/mongodb/mongo-go-driver/core/dispatch"
"github.com/mongodb/mongo-go-driver/core/session"
"github.com/mongodb/mongo-go-driver/core/topology"
"github.com/mongodb/mongo-go-driver/core/uuid"
"github.com/mongodb/mongo-go-driver/core/writeconcern"
"github.com/mongodb/mongo-go-driver/internal/testutil/helpers"
"github.com/stretchr/testify/require"
"github.com/mongodb/mongo-go-driver/core/uuid"
"github.com/mongodb/mongo-go-driver/core/session"
)

// AutoCreateIndexes creates an index in the test cluster.
Expand Down
73 changes: 73 additions & 0 deletions mongo/collection.go
Original file line number Diff line number Diff line change
Expand Up @@ -702,6 +702,79 @@ func (coll *Collection) Count(ctx context.Context, filter interface{},
)
}

// CountDocuments gets the number of documents matching the filter. A user can supply a
// custom context to this method, or nil to default to context.Background().
//
// This method uses countDocumentsAggregatePipeline to turn the filter parameter and options
// into aggregate pipeline.
func (coll *Collection) CountDocuments(ctx context.Context, filter interface{},
opts ...countopt.Count) (int64, error) {

if ctx == nil {
ctx = context.Background()
}

pipelineArr, err := countDocumentsAggregatePipeline(filter, opts...)
if err != nil {
return 0, err
}

countOpts, sess, err := countopt.BundleCount(opts...).Unbundle(true)
if err != nil {
return 0, err
}

err = coll.client.ValidSession(sess)
if err != nil {
return 0, err
}

oldns := coll.namespace()
cmd := command.CountDocuments{
NS: command.Namespace{DB: oldns.DB, Collection: oldns.Collection},
Pipeline: pipelineArr,
Opts: countOpts,
ReadPref: coll.readPreference,
ReadConcern: coll.readConcern,
Session: sess,
Clock: coll.client.clock,
}
return dispatch.CountDocuments(ctx, cmd, coll.client.topology, coll.readSelector)
}

// EstimatedDocumentCount gets an estimate of the count of documents in a collection using collection metadata.
func (coll *Collection) EstimatedDocumentCount(ctx context.Context,
opts ...countopt.EstimatedDocumentCount) (int64, error) {

if ctx == nil {
ctx = context.Background()
}

countOpts, sess, err := countopt.BundleEstimatedDocumentCount(opts...).Unbundle(true)
if err != nil {
return 0, err
}

err = coll.client.ValidSession(sess)
if err != nil {
return 0, err
}

oldns := coll.namespace()

cmd := command.Count{
NS: command.Namespace{DB: oldns.DB, Collection: oldns.Collection},
Query: bson.NewDocument(),
Opts: countOpts,
ReadPref: coll.readPreference,
ReadConcern: coll.readConcern,
Session: sess,
Clock: coll.client.clock,
}
return dispatch.Count(ctx, cmd, coll.client.topology, coll.readSelector, coll.client.id,
coll.client.topology.SessionPool)
}

// Distinct finds the distinct values for a specified field across a single
// collection. A user can supply a custom context to this method, or nil to
// default to context.Background().
Expand Down
Loading

0 comments on commit 9b26cbb

Please sign in to comment.