diff --git a/internal/verifier/compare.go b/internal/verifier/compare.go
index 7cfcfa7a..8c97d43b 100644
--- a/internal/verifier/compare.go
+++ b/internal/verifier/compare.go
@@ -4,6 +4,7 @@ import (
 	"bytes"
 	"context"
 	"fmt"
+	"iter"
 	"time"
 
 	"github.com/10gen/migration-verifier/chanutil"
@@ -12,6 +13,7 @@ import (
 	"github.com/10gen/migration-verifier/internal/retry"
 	"github.com/10gen/migration-verifier/internal/types"
 	"github.com/10gen/migration-verifier/internal/util"
+	"github.com/10gen/migration-verifier/mmongo/cursor"
 	"github.com/10gen/migration-verifier/option"
 	"github.com/pkg/errors"
 	"go.mongodb.org/mongo-driver/v2/bson"
@@ -30,6 +32,11 @@ const (
 	docKeyInHashedCompare = "k"
 )
 
+type seqWithTs struct {
+	seq iter.Seq2[bson.Raw, error]
+	ts  bson.Timestamp
+}
+
 type docWithTs struct {
 	doc bson.Raw
 	ts  bson.Timestamp
@@ -45,7 +52,7 @@ func (verifier *Verifier) FetchAndCompareDocuments(
 	types.ByteCount,
 	error,
 ) {
-	var srcChannel, dstChannel <-chan docWithTs
+	var srcChannel, dstChannel <-chan seqWithTs
 	var readSrcCallback, readDstCallback func(context.Context, *retry.FuncInfo) error
 
 	results := []VerificationResult{}
@@ -100,7 +107,7 @@ func (verifier *Verifier) compareDocsFromChannels(
 	workerNum int,
 	fi *retry.FuncInfo,
 	task *VerificationTask,
-	srcChannel, dstChannel <-chan docWithTs,
+	srcChannel, dstChannel <-chan seqWithTs,
 ) (
 	[]VerificationResult,
 	types.DocumentCount,
@@ -204,7 +211,7 @@ func (verifier *Verifier) compareDocsFromChannels(
 	for !srcClosed || !dstClosed {
 		simpleTimerReset(readTimer, readTimeout)
 
-		var srcDocWithTs, dstDocWithTs docWithTs
+		var srcDocsWithTs, dstDocsWithTs seqWithTs
 
 		eg, egCtx := contextplus.ErrGroup(ctx)
 
@@ -219,21 +226,13 @@ func (verifier *Verifier) compareDocsFromChannels(
 					"failed to read from source after %s",
 					readTimeout,
 				)
-			case srcDocWithTs, alive = <-srcChannel:
+			case srcDocsWithTs, alive = <-srcChannel:
 				if !alive {
 					srcClosed = true
 					break
 				}
 
 				fi.NoteSuccess("received document from source")
-
-				srcDocCount++
-				srcByteCount += types.ByteCount(len(srcDocWithTs.doc))
-				verifier.workerTracker.SetSrcCounts(
-					workerNum,
-					srcDocCount,
-					srcByteCount,
-				)
 			}
 
 			return nil
@@ -251,7 +250,7 @@ func (verifier *Verifier) compareDocsFromChannels(
 					"failed to read from destination after %s",
 					readTimeout,
 				)
-			case dstDocWithTs, alive = <-dstChannel:
+			case dstDocsWithTs, alive = <-dstChannel:
 				if !alive {
 					dstClosed = true
 					break
@@ -271,32 +270,72 @@ func (verifier *Verifier) compareDocsFromChannels(
 				)
 			}
 
-		if srcDocWithTs.doc != nil {
-			err := handleNewDoc(srcDocWithTs, true)
+		if srcDocsWithTs.seq != nil {
+			for doc, err := range srcDocsWithTs.seq {
+				if err != nil {
+					return nil, 0, 0, errors.Wrapf(
+						err,
+						"reading batch of docs from source (task: %s)",
+						task.PrimaryKey,
+					)
+				}
 
-			if err != nil {
+				srcDocCount++
+				srcByteCount += types.ByteCount(len(doc))
+				verifier.workerTracker.SetSrcCounts(
+					workerNum,
+					srcDocCount,
+					srcByteCount,
+				)
 
-				return nil, 0, 0, errors.Wrapf(
-					err,
-					"comparer thread failed to handle %#q's source doc (task: %s) with ID %v",
-					namespace,
-					task.PrimaryKey,
-					srcDocWithTs.doc.Lookup("_id"),
+				err := handleNewDoc(
+					docWithTs{
+						doc: doc,
+						ts:  srcDocsWithTs.ts,
+					},
+					true,
 				)
+
+				if err != nil {
+					return nil, 0, 0, errors.Wrapf(
+						err,
+						"comparer thread failed to handle %#q's source doc (task: %s) with ID %v",
+						namespace,
+						task.PrimaryKey,
+						doc.Lookup("_id"),
+					)
+				}
 			}
+		}
 
-		if dstDocWithTs.doc != nil {
-			err := handleNewDoc(dstDocWithTs, false)
+		if dstDocsWithTs.seq != nil {
+			for doc, err := range dstDocsWithTs.seq {
+				if err != nil {
+					return nil, 0, 0, errors.Wrapf(
+						err,
+						"reading batch of docs from destination (task: %s)",
+						task.PrimaryKey,
+					)
+				}
 
-			if err != nil {
-				return nil, 0, 0, errors.Wrapf(
-					err,
-					"comparer thread failed to handle %#q's destination doc (task: %s) with ID %v",
-					namespace,
-					task.PrimaryKey,
-					dstDocWithTs.doc.Lookup("_id"),
+				err := handleNewDoc(
+					docWithTs{
+						doc: doc,
+						ts:  dstDocsWithTs.ts,
+					},
+					false,
 				)
+
+				if err != nil {
+					return nil, 0, 0, errors.Wrapf(
+						err,
+						"comparer thread failed to handle %#q's destination doc (task: %s) with ID %v",
+						namespace,
+						task.PrimaryKey,
+						doc.Lookup("_id"),
+					)
+				}
 			}
 		}
 	}
@@ -427,13 +466,13 @@ func simpleTimerReset(t *time.Timer, dur time.Duration) {
 func (verifier *Verifier) getFetcherChannelsAndCallbacks(
 	task *VerificationTask,
 ) (
-	<-chan docWithTs,
-	<-chan docWithTs,
+	<-chan seqWithTs,
+	<-chan seqWithTs,
 	func(context.Context, *retry.FuncInfo) error,
 	func(context.Context, *retry.FuncInfo) error,
 ) {
-	srcChannel := make(chan docWithTs)
-	dstChannel := make(chan docWithTs)
+	srcChannel := make(chan seqWithTs)
+	dstChannel := make(chan seqWithTs)
 
 	readSrcCallback := func(ctx context.Context, state *retry.FuncInfo) error {
 		// We open a session here so that we can read the session’s cluster
@@ -510,38 +549,44 @@ func (verifier *Verifier) getFetcherChannelsAndCallbacks(
 }
 
 func iterateCursorToChannel(
-	sctx context.Context,
+	ctx context.Context,
 	state *retry.FuncInfo,
-	cursor *mongo.Cursor,
-	writer chan<- docWithTs,
+	myCursor *cursor.BatchCursor,
+	writer chan<- seqWithTs,
 ) error {
 	defer close(writer)
 
-	sess := mongo.SessionFromContext(sctx)
+	for {
+		seq := myCursor.GetCurrentBatchIterator()
 
-	for cursor.Next(sctx) {
 		state.NoteSuccess("received a document")
 
-		clusterTime, err := util.GetClusterTimeFromSession(sess)
+		ct, err := myCursor.GetClusterTime()
 		if err != nil {
-			return errors.Wrap(err, "reading cluster time from session")
+			return errors.Wrap(err, "reading cluster time from batch")
 		}
 
 		err = chanutil.WriteWithDoneCheck(
-			sctx,
+			ctx,
 			writer,
-			docWithTs{
-				doc: slices.Clone(cursor.Current),
-				ts:  clusterTime,
+			seqWithTs{
+				seq: seq,
+				ts:  ct,
 			},
 		)
 
 		if err != nil {
-			return errors.Wrapf(err, "sending document to compare thread")
+			return errors.Wrapf(err, "sending iterator to compare thread")
+		}
+
+		if myCursor.IsFinished() {
+			return nil
 		}
-	}
 
-	return errors.Wrap(cursor.Err(), "failed to iterate cursor")
+		if err := myCursor.GetNext(ctx); err != nil {
+			return errors.Wrap(err, "failed to iterate cursor")
+		}
+	}
 }
 
 func getMapKey(docKeyValues []bson.RawValue) string {
@@ -555,8 +600,13 @@ func getMapKey(docKeyValues []bson.RawValue) string {
 	return keyBuffer.String()
 }
 
-func (verifier *Verifier) getDocumentsCursor(ctx context.Context, collection *mongo.Collection, clusterInfo *util.ClusterInfo,
-	startAtTs *bson.Timestamp, task *VerificationTask) (*mongo.Cursor, error) {
+func (verifier *Verifier) getDocumentsCursor(
+	ctx context.Context,
+	collection *mongo.Collection,
+	clusterInfo *util.ClusterInfo,
+	startAtTs *bson.Timestamp,
+	task *VerificationTask,
+) (*cursor.BatchCursor, error) {
 	var findOptions bson.D
 	runCommandOptions := options.RunCmd()
 	var andPredicates bson.A
@@ -673,7 +723,16 @@ func (verifier *Verifier) getDocumentsCursor(
 		}
 	}
 
-	return collection.Database().RunCommandCursor(ctx, cmd, runCommandOptions)
+	c, err := cursor.New(
+		collection.Database(),
+		collection.Database().RunCommand(ctx, cmd, runCommandOptions),
+	)
+
+	if err == nil {
+		c.SetSession(mongo.SessionFromContext(ctx))
+	}
+
+	return c, err
 }
 
 func transformPipelineForToHashedIndexKey(
diff --git a/internal/verifier/migration_verifier_test.go b/internal/verifier/migration_verifier_test.go
index 92ab6bf3..b4e35638 100644
--- a/internal/verifier/migration_verifier_test.go
+++ b/internal/verifier/migration_verifier_test.go
@@ -26,6 +26,7 @@ import (
 	"github.com/10gen/migration-verifier/internal/types"
 	"github.com/10gen/migration-verifier/internal/util"
 	"github.com/10gen/migration-verifier/mbson"
+	"github.com/10gen/migration-verifier/mseq"
 	"github.com/10gen/migration-verifier/mslices"
 	"github.com/cespare/permute/v2"
 	"github.com/rs/zerolog"
@@ -1150,13 +1151,15 @@ func TestVerifierCompareDocs(t *testing.T) {
 
 	namespace := "testdb.testns"
 
-	makeDocChannel := func(docs []bson.D) <-chan docWithTs {
-		theChan := make(chan docWithTs, len(docs))
+	makeDocChannel := func(docs []bson.D) <-chan seqWithTs {
+		theChan := make(chan seqWithTs, len(docs))
 
 		for d, doc := range docs {
-			theChan <- docWithTs{
-				doc: testutil.MustMarshal(doc),
-				ts:  bson.Timestamp{1, uint32(d)},
+			theChan <- seqWithTs{
+				seq: mseq.FromSliceWithNilErr(
+					mslices.Of(testutil.MustMarshal(doc)),
+				),
+				ts: bson.Timestamp{1, uint32(d)},
 			}
 		}
 
diff --git a/mmongo/cursor/batch.go b/mmongo/cursor/batch.go
new file mode 100644
index 00000000..3b900507
--- /dev/null
+++ b/mmongo/cursor/batch.go
@@ -0,0 +1,268 @@
+// Package cursor exposes a cursor implementation that facilitates easy
+// batch reads as well as reading of custom cursor properties like
+// resume tokens.
+package cursor
+
+import (
+	"context"
+	"fmt"
+	"io"
+	"iter"
+	"strings"
+	"time"
+
+	"github.com/10gen/migration-verifier/mbson"
+	"github.com/10gen/migration-verifier/mslices"
+	"github.com/10gen/migration-verifier/option"
+	"github.com/pkg/errors"
+	"go.mongodb.org/mongo-driver/v2/bson"
+	"go.mongodb.org/mongo-driver/v2/mongo"
+	"go.mongodb.org/mongo-driver/v2/x/bsonx/bsoncore"
+)
+
+var (
+	clusterTimePath = mslices.Of("$clusterTime", "clusterTime")
+)
+
+// BatchCursor is like mongo.Cursor, but it exposes each batch of documents
+// rather than a per-document reader. This facilitates more efficient memory
+// usage because there is no need to clone each document individually, as
+// mongo.Cursor requires.
+type BatchCursor struct {
+	sess         *mongo.Session
+	maxAwaitTime option.Option[time.Duration]
+	id           int64
+	ns           string
+	db           *mongo.Database
+	rawResp      bson.Raw
+	curBatch     bson.RawArray
+}
+
+// GetCurrentBatchIterator returns an iterator over the BatchCursor’s current batch.
+// Note that the yielded documents are NOT copied; the expectation is that each
+// batch will be iterated exactly once. (Nothing *requires* that, of course.)
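+//
+// Typical consumption, as a sketch (handleDoc is a hypothetical callback):
+//
+//	for doc, err := range c.GetCurrentBatchIterator() {
+//		if err != nil {
+//			return err
+//		}
+//		handleDoc(doc)
+//	}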
+func (c *BatchCursor) GetCurrentBatchIterator() iter.Seq2[bson.Raw, error] {
+	// NB: Use of iter.Seq2 to return an error is a bit controversial.
+	// The pattern is used here in order to minimize the odds that a caller
+	// would overlook the need to check the error, which seems more probable
+	// with various other patterns.
+	//
+	// See “https://sinclairtarget.com/blog/2025/07/error-handling-with-iterators-in-go/”.
+
+	batch := c.curBatch
+
+	// NB: This MUST NOT close over c (the receiver), or else there can be
+	// a race condition between this callback and GetNext().
+	return func(yield func(bson.Raw, error) bool) {
+		iterator := &bsoncore.Iterator{
+			List: bsoncore.Array(batch),
+		}
+
+		for {
+			val, err := iterator.Next()
+			if errors.Is(err, io.EOF) {
+				return
+			}
+
+			doc, ok := val.DocumentOK()
+			if err == nil && !ok {
+				err = fmt.Errorf("expected BSON %s but found %s", bson.TypeEmbeddedDocument, val.Type)
+			}
+
+			if !yield(bson.Raw(doc), err) {
+				return
+			}
+
+			if err != nil {
+				panic(fmt.Sprintf("Iteration must stop after error (%v)", err))
+			}
+		}
+	}
+}
+
+// GetClusterTime returns the server response’s cluster time.
+func (c *BatchCursor) GetClusterTime() (bson.Timestamp, error) {
+	ctRV, err := c.rawResp.LookupErr(clusterTimePath...)
+
+	if err != nil {
+		return bson.Timestamp{}, errors.Wrapf(
+			err,
+			"extracting %#q from server response",
+			clusterTimePath,
+		)
+	}
+
+	ts, err := mbson.CastRawValue[bson.Timestamp](ctRV)
+	if err != nil {
+		return bson.Timestamp{}, errors.Wrapf(
+			err,
+			"parsing server response’s %#q",
+			clusterTimePath,
+		)
+	}
+
+	return ts, nil
+}
+
+// IsFinished indicates whether the present batch is the final one.
+func (c *BatchCursor) IsFinished() bool {
+	return c.id == 0
+}
+
+// GetNext fetches the next batch of documents from the server and caches it
+// for access via GetCurrentBatchIterator().
+//
+// extraPieces are things you want to add to the underlying `getMore`
+// server call, such as `batchSize`.
+func (c *BatchCursor) GetNext(ctx context.Context, extraPieces ...bson.E) error {
+	if c.IsFinished() {
+		panic("internal error: cursor already finished!")
+	}
+
+	nsDB, nsColl, found := strings.Cut(c.ns, ".")
+	if !found {
+		panic("Malformed namespace from cursor (expect a dot): " + c.ns)
+	}
+	if nsDB != c.db.Name() {
+		panic(fmt.Sprintf("db from cursor (%s) mismatches db struct (%s)", nsDB, c.db.Name()))
+	}
+
+	cmd := bson.D{
+		{"getMore", c.id},
+		{"collection", nsColl},
+	}
+
+	if awaitTime, has := c.maxAwaitTime.Get(); has {
+		cmd = append(cmd, bson.E{"maxTimeMS", awaitTime.Milliseconds()})
+	}
+
+	cmd = append(cmd, extraPieces...)
+
+	if c.sess != nil {
+		ctx = mongo.NewSessionContext(ctx, c.sess)
+	}
+	resp := c.db.RunCommand(ctx, cmd)
+
+	raw, err := resp.Raw()
+	if err != nil {
+		return fmt.Errorf("iterating %#q’s cursor: %w", c.ns, err)
+	}
+
+	nextBatch, err := raw.LookupErr("cursor", "nextBatch")
+	if err != nil {
+		return errors.Wrap(err, "extracting nextBatch")
+	}
+
+	var ok bool
+	c.curBatch, ok = nextBatch.ArrayOK()
+	if !ok {
+		return fmt.Errorf("nextBatch should be BSON %s but found %s", bson.TypeArray, nextBatch.Type)
+	}
+
+	c.rawResp = raw
+
+	cursorID, err := raw.LookupErr("cursor", "id")
+	if err != nil {
+		return errors.Wrap(err, "extracting cursor ID")
+	}
+
+	c.id, ok = cursorID.AsInt64OK()
+	if !ok {
+		return fmt.Errorf("cursor.id should be numeric but found BSON %s", cursorID.Type)
+	}
+
+	return nil
+}
+
+type cursorResponse struct {
+	ID int64
+	Ns string
+
+	// These are both BSON arrays. We use bson.Raw here to delay parsing
+	// and avoid allocating a large slice.
+	FirstBatch bson.Raw
+	NextBatch  bson.Raw
+}
+
+type baseResponse struct {
+	Cursor cursorResponse
+}
+
+// New creates a Cursor from the response of a cursor-returning command
+// like `find` or `bulkWrite`.
+//
+// Use this type (rather than the Go driver’s cursor implementation)
+// to extract parts of the cursor responses that the driver’s API doesn’t
+// expose. This is useful, e.g., to do a resumable $natural scan by
+// extracting resume tokens from `find` responses.
+//
+// See SetSession() as well.
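+//
+// A usage sketch (the collection name and batch size are illustrative):
+//
+//	res := db.RunCommand(ctx, bson.D{{"find", "myColl"}, {"batchSize", 1000}})
+//	c, err := cursor.New(db, res)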
+func New(
+	db *mongo.Database,
+	resp *mongo.SingleResult,
+) (*BatchCursor, error) {
+	raw, err := resp.Raw()
+	if err != nil {
+		return nil, errors.Wrap(err, "cursor open failed")
+	}
+
+	baseResp := baseResponse{}
+
+	err = bson.Unmarshal(raw, &baseResp)
+	if err != nil {
+		return nil, errors.Wrapf(err, "failed to decode cursor-open response to %T", baseResp)
+	}
+
+	return &BatchCursor{
+		db:       db,
+		id:       baseResp.Cursor.ID,
+		ns:       baseResp.Cursor.Ns,
+		rawResp:  raw,
+		curBatch: bson.RawArray(baseResp.Cursor.FirstBatch),
+	}, nil
+}
+
+// SetSession assigns a session for the cursor’s getMore calls to run under.
+func (c *BatchCursor) SetSession(sess *mongo.Session) {
+	c.sess = sess
+}
+
+// SetMaxAwaitTime sets the maxTimeMS that getMore calls send to the server.
+func (c *BatchCursor) SetMaxAwaitTime(d time.Duration) {
+	c.maxAwaitTime = option.Some(d)
+}
+
+// GetResumeToken is a convenience function that extracts the
+// post-batch resume token from the cursor.
+func GetResumeToken(c *BatchCursor) (bson.Raw, error) {
+	var resumeToken bson.Raw
+
+	tokenRV, err := c.rawResp.LookupErr("cursor", "postBatchResumeToken")
+	if err != nil {
+		return nil, errors.Wrap(err, "extracting cursor’s post-batch resume token")
+	}
+
+	resumeToken, err = mbson.CastRawValue[bson.Raw](tokenRV)
+	if err != nil {
+		return nil, errors.Wrap(
+			err,
+			"parsing cursor’s post-batch resume token",
+		)
+	}
+
+	return resumeToken, nil
+}
diff --git a/mseq/mseq.go b/mseq/mseq.go
new file mode 100644
index 00000000..c4b41ec5
--- /dev/null
+++ b/mseq/mseq.go
@@ -0,0 +1,29 @@
+package mseq
+
+import "iter"
+
+// FromSlice returns an iterator over a slice.
+//
+// NB: See slices.Collect for the opposite operation.
+func FromSlice[T any](s []T) iter.Seq[T] {
+	return func(yield func(T) bool) {
+		for _, v := range s {
+			if !yield(v) {
+				return
+			}
+		}
+	}
+}
+
+// FromSliceWithNilErr is like FromSlice but returns a Seq2
+// whose second return is always a nil error. This is useful
+// in testing.
+func FromSliceWithNilErr[T any](s []T) iter.Seq2[T, error] {
+	return func(yield func(T, error) bool) {
+		for _, v := range s {
+			if !yield(v, nil) {
+				return
+			}
+		}
+	}
+}
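diff --git a/mseq/mseq_test.go b/mseq/mseq_test.go
new file mode 100644
--- /dev/null
+++ b/mseq/mseq_test.go
@@ -0,0 +1,23 @@
+package mseq
+
+import (
+	"slices"
+	"testing"
+)
+
+// TestFromSliceRoundTrip is a minimal sketch of the intended use:
+// FromSlice should round-trip with slices.Collect, and
+// FromSliceWithNilErr should always yield nil errors.
+func TestFromSliceRoundTrip(t *testing.T) {
+	in := []int{1, 2, 3}
+
+	if got := slices.Collect(FromSlice(in)); !slices.Equal(got, in) {
+		t.Errorf("got %v, want %v", got, in)
+	}
+
+	for v, err := range FromSliceWithNilErr(in) {
+		if err != nil {
+			t.Errorf("unexpected error for %v: %v", v, err)
+		}
+	}
+}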