diff --git a/database/mongodb/README.md b/database/mongodb/README.md index 4204581a9..32b736fae 100644 --- a/database/mongodb/README.md +++ b/database/mongodb/README.md @@ -13,6 +13,10 @@ |------------|---------------------|-------------| | `x-migrations-collection` | `MigrationsCollection` | Name of the migrations collection | | `x-transaction-mode` | `TransactionMode` | If set to `true` wrap commands in [transaction](https://docs.mongodb.com/manual/core/transactions). Available only for replica set. Driver is using [strconv.ParseBool](https://golang.org/pkg/strconv/#ParseBool) for parsing| +| `x-advisory-locking` | `true` | Feature flag for advisory locking, if set to false, disable advisory locking | +| `x-advisory-lock-collection` | `migrate_advisory_lock` | The name of the collection to use for advisory locking.| +| `x-advisory-lock-timout` | `15` | The max time in seconds that the advisory lock will wait if the db is already locked. | +| `x-advisory-lock-timout-interval` | `10` | The max timeout in seconds interval that the advisory lock will wait if the db is already locked. | | `dbname` | `DatabaseName` | The name of the database to connect to | | `user` | | The user to sign in as. Can be omitted | | `password` | | The user's password. Can be omitted | diff --git a/database/mongodb/mongodb.go b/database/mongodb/mongodb.go index 95992ecb2..17ca804f2 100644 --- a/database/mongodb/mongodb.go +++ b/database/mongodb/mongodb.go @@ -3,16 +3,19 @@ package mongodb import ( "context" "fmt" - "io" - "io/ioutil" - "net/url" - "strconv" - + "github.com/cenkalti/backoff/v4" "github.com/golang-migrate/migrate/v4/database" + "github.com/hashicorp/go-multierror" "go.mongodb.org/mongo-driver/bson" "go.mongodb.org/mongo-driver/mongo" "go.mongodb.org/mongo-driver/mongo/options" "go.mongodb.org/mongo-driver/x/mongo/driver/connstring" + "io" + "io/ioutil" + "net/url" + os "os" + "strconv" + "time" ) func init() { @@ -23,6 +26,14 @@ func init() { var DefaultMigrationsCollection = "schema_migrations" +const DefaultLockingCollection = "migrate_advisory_lock" // the collection to use for advisory locking by default. +const lockKeyUniqueValue = 0 // the unique value to lock on. If multiple clients try to insert the same key, it will fail (locked). +const DefaultLockTimeout = 15 // the default maximum time to wait for a lock to be released. +const DefaultLockTimeoutInterval = 10 // the default maximum intervals time for the locking timout. +const DefaultAdvisoryLockingFlag = true // the default value for the advisory locking feature flag. Default is true. +const LockIndexName = "lock_unique_key" // the name of the index which adds unique constraint to the locking_key field. +const contextWaitTimeout = 5 * time.Second // how long to wait for the request to mongo to block/wait for. + var ( ErrNoDatabaseName = fmt.Errorf("no database name") ErrNilConfig = fmt.Errorf("no config") @@ -31,21 +42,36 @@ var ( type Mongo struct { client *mongo.Client db *mongo.Database - config *Config } +type Locking struct { + CollectionName string + Timeout int + Enabled bool + Interval int +} type Config struct { DatabaseName string MigrationsCollection string TransactionMode bool + Locking Locking } - type versionInfo struct { Version int `bson:"version"` Dirty bool `bson:"dirty"` } +type lockObj struct { + Key int `bson:"locking_key"` + Pid int `bson:"pid"` + Hostname string `bson:"hostname"` + CreatedAt time.Time `bson:"created_at"` +} +type findFilter struct { + Key int `bson:"locking_key"` +} + func WithInstance(instance *mongo.Client, config *Config) (database.Driver, error) { if config == nil { return nil, ErrNilConfig @@ -56,17 +82,36 @@ func WithInstance(instance *mongo.Client, config *Config) (database.Driver, erro if len(config.MigrationsCollection) == 0 { config.MigrationsCollection = DefaultMigrationsCollection } + if len(config.Locking.CollectionName) == 0 { + config.Locking.CollectionName = DefaultLockingCollection + } + if config.Locking.Timeout <= 0 { + config.Locking.Timeout = DefaultLockTimeout + } + if config.Locking.Interval <= 0 { + config.Locking.Interval = DefaultLockTimeoutInterval + } + mc := &Mongo{ client: instance, db: instance.Database(config.DatabaseName), config: config, } + if mc.config.Locking.Enabled { + if err := mc.ensureLockTable(); err != nil { + return nil, err + } + } + if err := mc.ensureVersionTable(); err != nil { + return nil, err + } + return mc, nil } func (m *Mongo) Open(dsn string) (database.Driver, error) { - //connsting is experimental package, but it used for parse connection string in mongo.Connect function + //connstring is experimental package, but it used for parse connection string in mongo.Connect function uri, err := connstring.Parse(dsn) if err != nil { return nil, err @@ -74,16 +119,31 @@ func (m *Mongo) Open(dsn string) (database.Driver, error) { if len(uri.Database) == 0 { return nil, ErrNoDatabaseName } - unknown := url.Values(uri.UnknownOptions) migrationsCollection := unknown.Get("x-migrations-collection") - transactionMode, _ := strconv.ParseBool(unknown.Get("x-transaction-mode")) - + lockCollection := unknown.Get("x-advisory-lock-collection") + transactionMode, err := parseBoolean(unknown.Get("x-transaction-mode"), false) + if err != nil { + return nil, err + } + advisoryLockingFlag, err := parseBoolean(unknown.Get("x-advisory-locking"), DefaultAdvisoryLockingFlag) + if err != nil { + return nil, err + } + lockingTimout, err := parseInt(unknown.Get("x-advisory-lock-timeout"), DefaultLockTimeout) + if err != nil { + return nil, err + } + maxLockingIntervals, err := parseInt(unknown.Get("x-advisory-lock-timout-interval"), DefaultLockTimeoutInterval) + if err != nil { + return nil, err + } client, err := mongo.Connect(context.TODO(), options.Client().ApplyURI(dsn)) if err != nil { return nil, err } + if err = client.Ping(context.TODO(), nil); err != nil { return nil, err } @@ -91,6 +151,12 @@ func (m *Mongo) Open(dsn string) (database.Driver, error) { DatabaseName: uri.Database, MigrationsCollection: migrationsCollection, TransactionMode: transactionMode, + Locking: Locking{ + CollectionName: lockCollection, + Timeout: lockingTimout, + Enabled: advisoryLockingFlag, + Interval: maxLockingIntervals, + }, }) if err != nil { return nil, err @@ -98,6 +164,39 @@ func (m *Mongo) Open(dsn string) (database.Driver, error) { return mc, nil } +//Parse the url param, convert it to boolean +// returns error if param invalid. returns defaultValue if param not present +func parseBoolean(urlParam string, defaultValue bool) (bool, error) { + + // if parameter passed, parse it (otherwise return default value) + if urlParam != "" { + result, err := strconv.ParseBool(urlParam) + if err != nil { + return false, err + } + return result, nil + } + + // if no url Param passed, return default value + return defaultValue, nil +} + +//Parse the url param, convert it to int +// returns error if param invalid. returns defaultValue if param not present +func parseInt(urlParam string, defaultValue int) (int, error) { + + // if parameter passed, parse it (otherwise return default value) + if urlParam != "" { + result, err := strconv.Atoi(urlParam) + if err != nil { + return -1, err + } + return result, nil + } + + // if no url Param passed, return default value + return defaultValue, nil +} func (m *Mongo) SetVersion(version int, dirty bool) error { migrationsCollection := m.db.Collection(m.config.MigrationsCollection) if err := migrationsCollection.Drop(context.TODO()); err != nil { @@ -184,10 +283,99 @@ func (m *Mongo) Drop() error { return m.db.Drop(context.TODO()) } -func (m *Mongo) Lock() error { +func (m *Mongo) ensureLockTable() error { + indexes := m.db.Collection(m.config.Locking.CollectionName).Indexes() + + indexOptions := options.Index().SetUnique(true).SetName(LockIndexName) + _, err := indexes.CreateOne(context.TODO(), mongo.IndexModel{ + Options: indexOptions, + Keys: findFilter{Key: -1}, + }) + if err != nil { + return err + } + return nil +} + +// ensureVersionTable checks if versions table exists and, if not, creates it. +// Note that this function locks the database, which deviates from the usual +// convention of "caller locks" in the MongoDb type. +func (m *Mongo) ensureVersionTable() (err error) { + if err = m.Lock(); err != nil { + return err + } + + defer func() { + if e := m.Unlock(); e != nil { + if err == nil { + err = e + } else { + err = multierror.Append(err, e) + } + } + }() + + if err != nil { + return err + } + if _, _, err = m.Version(); err != nil { + return err + } return nil } +// Utilizes advisory locking on the config.LockingCollection collection +// This uses a unique index on the `locking_key` field. +func (m *Mongo) Lock() error { + if !m.config.Locking.Enabled { + return nil + } + pid := os.Getpid() + hostname, err := os.Hostname() + if err != nil { + hostname = fmt.Sprintf("Could not determine hostname. Error: %s", err.Error()) + } + + newLockObj := lockObj{ + Key: lockKeyUniqueValue, + Pid: pid, + Hostname: hostname, + CreatedAt: time.Now(), + } + operation := func() error { + timeout, cancelFunc := context.WithTimeout(context.Background(), contextWaitTimeout) + _, err := m.db.Collection(m.config.Locking.CollectionName).InsertOne(timeout, newLockObj) + defer cancelFunc() + return err + } + exponentialBackOff := backoff.NewExponentialBackOff() + duration := time.Duration(m.config.Locking.Timeout) * time.Second + exponentialBackOff.MaxElapsedTime = duration + exponentialBackOff.MaxInterval = time.Duration(m.config.Locking.Interval) * time.Second + + err = backoff.Retry(operation, exponentialBackOff) + if err != nil { + return database.ErrLocked + } + + return nil + +} func (m *Mongo) Unlock() error { + if !m.config.Locking.Enabled { + return nil + } + + filter := findFilter{ + Key: lockKeyUniqueValue, + } + + ctx, cancel := context.WithTimeout(context.Background(), contextWaitTimeout) + _, err := m.db.Collection(m.config.Locking.CollectionName).DeleteMany(ctx, filter) + defer cancel() + + if err != nil { + return err + } return nil } diff --git a/database/mongodb/mongodb_test.go b/database/mongodb/mongodb_test.go index c0d09f387..c73da46c4 100644 --- a/database/mongodb/mongodb_test.go +++ b/database/mongodb/mongodb_test.go @@ -92,7 +92,7 @@ func Test(t *testing.T) { } }() dt.TestNilVersion(t, d) - //TestLockAndUnlock(t, d) driver doesn't support lock on database level + dt.TestLockAndUnlock(t, d) dt.TestRun(t, d, bytes.NewReader([]byte(`[{"insert":"hello","documents":[{"wild":"world"}]}]`))) dt.TestSetVersion(t, d) dt.TestDrop(t, d) @@ -180,6 +180,73 @@ func TestWithAuth(t *testing.T) { }) } +func TestLockWorks(t *testing.T) { + dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) { + ip, port, err := c.FirstPort() + if err != nil { + t.Fatal(err) + } + + addr := mongoConnectionString(ip, port) + p := &Mongo{} + d, err := p.Open(addr) + if err != nil { + t.Fatal(err) + } + defer func() { + if err := d.Close(); err != nil { + t.Error(err) + } + }() + + dt.TestRun(t, d, bytes.NewReader([]byte(`[{"insert":"hello","documents":[{"wild":"world"}]}]`))) + + mc := d.(*Mongo) + + err = mc.Lock() + if err != nil { + t.Fatal(err) + } + err = mc.Unlock() + if err != nil { + t.Fatal(err) + } + + err = mc.Lock() + if err != nil { + t.Fatal(err) + } + err = mc.Unlock() + if err != nil { + t.Fatal(err) + } + + // disable locking, validate wer can lock twice + mc.config.Locking.Enabled = false + err = mc.Lock() + if err != nil { + t.Fatal(err) + } + err = mc.Lock() + if err != nil { + t.Fatal(err) + } + + // re-enable locking, + //try to hit a lock conflict + mc.config.Locking.Enabled = true + mc.config.Locking.Timeout = 1 + err = mc.Lock() + if err != nil { + t.Fatal(err) + } + err = mc.Lock() + if err == nil { + t.Fatal("should have failed, mongo should be locked already") + } + }) +} + func TestTransaction(t *testing.T) { transactionSpecs := []dktesting.ContainerSpec{ {ImageName: "mongo:4", Options: dktest.Options{PortRequired: true, ReadyFunc: isReady, diff --git a/go.mod b/go.mod index b8746d8fb..fcef92a8d 100644 --- a/go.mod +++ b/go.mod @@ -9,6 +9,7 @@ require ( github.com/aws/aws-sdk-go v1.17.7 github.com/bitly/go-hostpool v0.0.0-20171023180738-a3a6125de932 // indirect github.com/bmizerany/assert v0.0.0-20160611221934-b7ed37b82869 // indirect + github.com/cenkalti/backoff/v4 v4.0.2 github.com/cockroachdb/cockroach-go v0.0.0-20190925194419-606b3d062051 github.com/cznic/mathutil v0.0.0-20180504122225-ca4c9f2c1369 // indirect github.com/denisenkom/go-mssqldb v0.0.0-20200620013148-b91950f658ec diff --git a/go.sum b/go.sum index 2ef1a3e45..aea4b9257 100644 --- a/go.sum +++ b/go.sum @@ -69,6 +69,8 @@ github.com/bkaradzic/go-lz4 v1.0.0 h1:RXc4wYsyz985CkXXeX04y4VnZFGG8Rd43pRaHsOXAK github.com/bkaradzic/go-lz4 v1.0.0/go.mod h1:0YdlkowM3VswSROI7qDxhRvJ3sLhlFrRRwjwegp5jy4= github.com/bmizerany/assert v0.0.0-20160611221934-b7ed37b82869 h1:DDGfHa7BWjL4YnC6+E63dPcxHo2sUxDIu8g3QgEJdRY= github.com/bmizerany/assert v0.0.0-20160611221934-b7ed37b82869/go.mod h1:Ekp36dRnpXw/yCqJaO+ZrUyxD+3VXMFFr56k5XYrpB4= +github.com/cenkalti/backoff/v4 v4.0.2 h1:JIufpQLbh4DkbQoii76ItQIUFzevQSqOLZca4eamEDs= +github.com/cenkalti/backoff/v4 v4.0.2/go.mod h1:eEew/i+1Q6OrCDZh3WiXYv3+nJwBASZ8Bog/87DQnVg= github.com/census-instrumentation/opencensus-proto v0.2.1/go.mod h1:f6KPmirojxKA12rnyqOA5BBL4O983OfeGPqjHWSTneU= github.com/chzyer/logex v1.1.10/go.mod h1:+Ywpsq7O8HXn0nuIou7OrIPyXbp3wmkHB+jjWRnGsAI= github.com/chzyer/readline v0.0.0-20180603132655-2972be24d48e/go.mod h1:nSuG5e5PlCu98SY8svDHJxuZscDgtXS6KTTbou5AhLI=