diff --git a/document.go b/document.go index 1592e12..6fac69a 100644 --- a/document.go +++ b/document.go @@ -145,7 +145,7 @@ func (doc *Document[T, P]) Save(ctx context.Context) error { return nil, err } - _, err := withTransaction(ctx, doc.Collection(), callback) + _, err := withAtomicity(ctx, doc.Collection(), callback) return err } @@ -168,7 +168,7 @@ func (doc *Document[T, P]) Delete(ctx context.Context) error { return nil, err } - _, err = withTransaction(ctx, doc.Collection(), callback) + _, err = withAtomicity(ctx, doc.Collection(), callback) if err != nil { return err } diff --git a/model.go b/model.go index 849509e..7248b87 100644 --- a/model.go +++ b/model.go @@ -13,7 +13,11 @@ import ( "go.mongodb.org/mongo-driver/mongo/options" ) -type sessFn func(sessCtx mongo.SessionContext) (interface{}, error) +type SessionFunc func(sessCtx mongo.SessionContext) (interface{}, error) + +type SessionLike interface { + *mongo.Database | *mongo.Collection | *mongo.SessionContext | *mongo.Client | *mongo.Session +} type Model[T Schema, P IDefaultSchema] struct { collection *mongo.Collection @@ -91,7 +95,7 @@ func (model *Model[T, P]) CreateOne(ctx context.Context, doc T, opts ...*mopt.In return nil, err } - _, err := withTransaction(ctx, model.collection, callback) + _, err := withAtomicity(ctx, model.collection, callback) if err != nil { return nil, err } @@ -123,7 +127,7 @@ func (model *Model[T, P]) CreateMany(ctx context.Context, docs []T, opts ...*mop return newDocs, err } - newDocs, err := withTransaction(ctx, model.collection, callback) + newDocs, err := withAtomicity(ctx, model.collection, callback) if err != nil { return nil, err } @@ -153,7 +157,7 @@ func (model *Model[T, P]) DeleteOne(ctx context.Context, query bson.M, opts ...* return res, err } - res, err := withTransaction(ctx, model.collection, callback) + res, err := withAtomicity(ctx, model.collection, callback) if err != nil { return nil, err } @@ -183,7 +187,7 @@ func (model *Model[T, P]) DeleteMany(ctx context.Context, query bson.M, opts ... return res, err } - res, err := withTransaction(ctx, model.collection, callback) + res, err := withAtomicity(ctx, model.collection, callback) if err != nil { return nil, err } @@ -375,7 +379,7 @@ func (model *Model[T, P]) UpdateOne(ctx context.Context, query bson.M, update bs err = runAfterUpdateHooks(sessCtx, ds, newHookArg[T](res, UpdateOne)) return res, err } - res, err := withTransaction(ctx, model.collection, callback) + res, err := withAtomicity(ctx, model.collection, callback) if err != nil { return nil, err } @@ -411,13 +415,44 @@ func (model *Model[T, P]) UpdateMany(ctx context.Context, query bson.M, update b return res, err } - res, err := withTransaction(ctx, model.collection, callback) + res, err := withAtomicity(ctx, model.collection, callback) if err != nil { return nil, err } return res.(*mongo.UpdateResult), nil } +// WithTransaction executes the callback function in a transaction. +// When a transaction is started with [mongo.SessionContext] options are ignored because the session is already created. +func WithTransaction[T SessionLike](ctx context.Context, sess T, fn SessionFunc, opts ...*options.TransactionOptions) (any, error) { + var session mongo.Session + var err error + switch sess := any(sess).(type) { + case *mongo.SessionContext: + return fn(*sess) + case *mongo.Session: + return (*sess).WithTransaction(ctx, fn, opts...) + case *mongo.Client: + session, err = sess.StartSession() + case *mongo.Database: + session, err = sess.Client().StartSession() + case *mongo.Collection: + session, err = sess.Database().Client().StartSession() + } + if err != nil { + return nil, err + } + defer session.EndSession(ctx) + return session.WithTransaction(ctx, fn, opts...) +} + +func withAtomicity(ctx context.Context, coll *mongo.Collection, callback SessionFunc) (interface{}, error) { + if ctx, ok := ctx.(mongo.SessionContext); ok { + return callback(ctx) + } + return WithTransaction(ctx, coll, callback) +} + // func (model *Model[T, P]) CountDocuments(ctx context.Context, // query bson.M, opts ...*options.CountOptions, // ) (int64, error) { @@ -472,17 +507,6 @@ func findWithPopulate[U int.UnionFindOpts, T Schema, P IDefaultSchema](ctx conte return docs, nil } -func withTransaction(ctx context.Context, coll *mongo.Collection, fn sessFn, opts ...*options.TransactionOptions) (interface{}, error) { - session, err := coll.Database().Client().StartSession() - if err != nil { - return nil, err - } - defer session.EndSession(ctx) - - res, err := session.WithTransaction(ctx, fn, opts...) - return res, err -} - func getObjectId(id any) (*primitive.ObjectID, error) { var oid primitive.ObjectID switch id := id.(type) { diff --git a/model_test.go b/model_test.go index d4b611b..690ad03 100644 --- a/model_test.go +++ b/model_test.go @@ -2,6 +2,7 @@ package mgs_test import ( "context" + "errors" "fmt" "math/rand" "testing" @@ -287,6 +288,95 @@ func TestModel_Populate(t *testing.T) { }) } +func TestWithTransaction(t *testing.T) { + ctx := context.Background() + db, cleanup := getDb(ctx) + defer cleanup(ctx) + + bookModel := mgs.NewModel[Book, *mgs.DefaultSchema](db.Collection("books")) + genBooks := generateBooks(ctx, db) + + t.Run("Should run transaction with mongo.Client", func(t *testing.T) { + callback := func(sessCtx mongo.SessionContext) (interface{}, error) { + err := genBooks[0].Delete(sessCtx) + return nil, err + } + _, err := mgs.WithTransaction(ctx, db.Client(), callback) + assert.NoError(t, err, "WithTransaction should not return error") + }) + + t.Run("Should run transaction with mongo.Database", func(t *testing.T) { + callback := func(sessCtx mongo.SessionContext) (interface{}, error) { + err := genBooks[1].Delete(sessCtx) + return nil, err + } + + _, err := mgs.WithTransaction(ctx, db, callback) + assert.NoError(t, err, "WithTransaction should not return error") + }) + + t.Run("Should run transaction with mongo.Collection", func(t *testing.T) { + callback := func(sessCtx mongo.SessionContext) (interface{}, error) { + err := genBooks[2].Delete(sessCtx) + return nil, err + } + _, err := mgs.WithTransaction(ctx, db.Collection("books"), callback) + assert.NoError(t, err, "WithTransaction should not return error") + book, err := bookModel.FindById(ctx, genBooks[2].GetID()) + assert.Error(t, err, "FindById return error") + assert.Nil(t, book, "book should be nil") + }) + + t.Run("Should run transaction with mongo.Session", func(t *testing.T) { + callback := func(sessCtx mongo.SessionContext) (interface{}, error) { + err := genBooks[3].Delete(sessCtx) + return nil, err + } + + sess, err := db.Client().StartSession() + if err != nil { + t.Fatal(err) + } + defer sess.EndSession(ctx) + _, err = mgs.WithTransaction(ctx, &sess, callback) + assert.NoError(t, err, "WithTransaction should not return error") + book, err := bookModel.FindById(ctx, genBooks[3].GetID()) + assert.Error(t, err, "FindById return error") + assert.Nil(t, book, "book should be nil") + }) + + t.Run("Should run transaction with mongo.SessionContext", func(t *testing.T) { + callback := func(sessCtx mongo.SessionContext) (interface{}, error) { + callback2 := func(sCtx mongo.SessionContext) (interface{}, error) { + err := genBooks[4].Delete(sCtx) + return nil, err + } + + if _, err := mgs.WithTransaction(context.TODO(), &sessCtx, callback2); err != nil { + return nil, err + } + + genBooks[5].Doc.Title = "This is a test title" + if err = genBooks[5].Save(sessCtx); err != nil { + return nil, err + } + + return nil, errors.New("this is a test error") + } + _, err := mgs.WithTransaction(ctx, db, callback) + assert.Error(t, err, "WithTransaction should return error") + + book, err := bookModel.FindById(ctx, genBooks[5].GetID()) + if err != nil { + t.Fatal(err) + } + assert.True(t, book.Doc.Title != "This is a test title", "WithTransaction should rollback changes") + + _, err = bookModel.FindById(ctx, genBooks[4].GetID()) + assert.NoError(t, err, "WithTransaction should rollback changes") + }) +} + func TestModelMongodbErrors(t *testing.T) { mt := mtest.New(t, mtest.NewOptions().ClientType(mtest.Mock)) defer mt.Close()