Skip to content

Commit

Permalink
Merge pull request #448 from andyNewman42/locking
Browse files Browse the repository at this point in the history
Add advisory locking to mongodb
  • Loading branch information
dhui committed Sep 26, 2020
2 parents 50439fe + 2c2f691 commit c602605
Show file tree
Hide file tree
Showing 5 changed files with 275 additions and 13 deletions.
4 changes: 4 additions & 0 deletions database/mongodb/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,10 @@
|------------|---------------------|-------------|
| `x-migrations-collection` | `MigrationsCollection` | Name of the migrations collection |
| `x-transaction-mode` | `TransactionMode` | If set to `true` wrap commands in [transaction](https://docs.mongodb.com/manual/core/transactions). Available only for replica set. Driver is using [strconv.ParseBool](https://golang.org/pkg/strconv/#ParseBool) for parsing|
| `x-advisory-locking` | `true` | Feature flag for advisory locking, if set to false, disable advisory locking |
| `x-advisory-lock-collection` | `migrate_advisory_lock` | The name of the collection to use for advisory locking.|
| `x-advisory-lock-timout` | `15` | The max time in seconds that the advisory lock will wait if the db is already locked. |
| `x-advisory-lock-timout-interval` | `10` | The max timeout in seconds interval that the advisory lock will wait if the db is already locked. |
| `dbname` | `DatabaseName` | The name of the database to connect to |
| `user` | | The user to sign in as. Can be omitted |
| `password` | | The user's password. Can be omitted |
Expand Down
212 changes: 200 additions & 12 deletions database/mongodb/mongodb.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,16 +3,19 @@ package mongodb
import (
"context"
"fmt"
"io"
"io/ioutil"
"net/url"
"strconv"

"github.com/cenkalti/backoff/v4"
"github.com/golang-migrate/migrate/v4/database"
"github.com/hashicorp/go-multierror"
"go.mongodb.org/mongo-driver/bson"
"go.mongodb.org/mongo-driver/mongo"
"go.mongodb.org/mongo-driver/mongo/options"
"go.mongodb.org/mongo-driver/x/mongo/driver/connstring"
"io"
"io/ioutil"
"net/url"
os "os"
"strconv"
"time"
)

func init() {
Expand All @@ -23,6 +26,14 @@ func init() {

var DefaultMigrationsCollection = "schema_migrations"

const DefaultLockingCollection = "migrate_advisory_lock" // the collection to use for advisory locking by default.
const lockKeyUniqueValue = 0 // the unique value to lock on. If multiple clients try to insert the same key, it will fail (locked).
const DefaultLockTimeout = 15 // the default maximum time to wait for a lock to be released.
const DefaultLockTimeoutInterval = 10 // the default maximum intervals time for the locking timout.
const DefaultAdvisoryLockingFlag = true // the default value for the advisory locking feature flag. Default is true.
const LockIndexName = "lock_unique_key" // the name of the index which adds unique constraint to the locking_key field.
const contextWaitTimeout = 5 * time.Second // how long to wait for the request to mongo to block/wait for.

var (
ErrNoDatabaseName = fmt.Errorf("no database name")
ErrNilConfig = fmt.Errorf("no config")
Expand All @@ -31,21 +42,36 @@ var (
type Mongo struct {
client *mongo.Client
db *mongo.Database

config *Config
}

type Locking struct {
CollectionName string
Timeout int
Enabled bool
Interval int
}
type Config struct {
DatabaseName string
MigrationsCollection string
TransactionMode bool
Locking Locking
}

type versionInfo struct {
Version int `bson:"version"`
Dirty bool `bson:"dirty"`
}

type lockObj struct {
Key int `bson:"locking_key"`
Pid int `bson:"pid"`
Hostname string `bson:"hostname"`
CreatedAt time.Time `bson:"created_at"`
}
type findFilter struct {
Key int `bson:"locking_key"`
}

func WithInstance(instance *mongo.Client, config *Config) (database.Driver, error) {
if config == nil {
return nil, ErrNilConfig
Expand All @@ -56,48 +82,121 @@ func WithInstance(instance *mongo.Client, config *Config) (database.Driver, erro
if len(config.MigrationsCollection) == 0 {
config.MigrationsCollection = DefaultMigrationsCollection
}
if len(config.Locking.CollectionName) == 0 {
config.Locking.CollectionName = DefaultLockingCollection
}
if config.Locking.Timeout <= 0 {
config.Locking.Timeout = DefaultLockTimeout
}
if config.Locking.Interval <= 0 {
config.Locking.Interval = DefaultLockTimeoutInterval
}

mc := &Mongo{
client: instance,
db: instance.Database(config.DatabaseName),
config: config,
}

if mc.config.Locking.Enabled {
if err := mc.ensureLockTable(); err != nil {
return nil, err
}
}
if err := mc.ensureVersionTable(); err != nil {
return nil, err
}

return mc, nil
}

func (m *Mongo) Open(dsn string) (database.Driver, error) {
//connsting is experimental package, but it used for parse connection string in mongo.Connect function
//connstring is experimental package, but it used for parse connection string in mongo.Connect function
uri, err := connstring.Parse(dsn)
if err != nil {
return nil, err
}
if len(uri.Database) == 0 {
return nil, ErrNoDatabaseName
}

unknown := url.Values(uri.UnknownOptions)

migrationsCollection := unknown.Get("x-migrations-collection")
transactionMode, _ := strconv.ParseBool(unknown.Get("x-transaction-mode"))

lockCollection := unknown.Get("x-advisory-lock-collection")
transactionMode, err := parseBoolean(unknown.Get("x-transaction-mode"), false)
if err != nil {
return nil, err
}
advisoryLockingFlag, err := parseBoolean(unknown.Get("x-advisory-locking"), DefaultAdvisoryLockingFlag)
if err != nil {
return nil, err
}
lockingTimout, err := parseInt(unknown.Get("x-advisory-lock-timeout"), DefaultLockTimeout)
if err != nil {
return nil, err
}
maxLockingIntervals, err := parseInt(unknown.Get("x-advisory-lock-timout-interval"), DefaultLockTimeoutInterval)
if err != nil {
return nil, err
}
client, err := mongo.Connect(context.TODO(), options.Client().ApplyURI(dsn))
if err != nil {
return nil, err
}

if err = client.Ping(context.TODO(), nil); err != nil {
return nil, err
}
mc, err := WithInstance(client, &Config{
DatabaseName: uri.Database,
MigrationsCollection: migrationsCollection,
TransactionMode: transactionMode,
Locking: Locking{
CollectionName: lockCollection,
Timeout: lockingTimout,
Enabled: advisoryLockingFlag,
Interval: maxLockingIntervals,
},
})
if err != nil {
return nil, err
}
return mc, nil
}

//Parse the url param, convert it to boolean
// returns error if param invalid. returns defaultValue if param not present
func parseBoolean(urlParam string, defaultValue bool) (bool, error) {

// if parameter passed, parse it (otherwise return default value)
if urlParam != "" {
result, err := strconv.ParseBool(urlParam)
if err != nil {
return false, err
}
return result, nil
}

// if no url Param passed, return default value
return defaultValue, nil
}

//Parse the url param, convert it to int
// returns error if param invalid. returns defaultValue if param not present
func parseInt(urlParam string, defaultValue int) (int, error) {

// if parameter passed, parse it (otherwise return default value)
if urlParam != "" {
result, err := strconv.Atoi(urlParam)
if err != nil {
return -1, err
}
return result, nil
}

// if no url Param passed, return default value
return defaultValue, nil
}
func (m *Mongo) SetVersion(version int, dirty bool) error {
migrationsCollection := m.db.Collection(m.config.MigrationsCollection)
if err := migrationsCollection.Drop(context.TODO()); err != nil {
Expand Down Expand Up @@ -184,10 +283,99 @@ func (m *Mongo) Drop() error {
return m.db.Drop(context.TODO())
}

func (m *Mongo) Lock() error {
func (m *Mongo) ensureLockTable() error {
indexes := m.db.Collection(m.config.Locking.CollectionName).Indexes()

indexOptions := options.Index().SetUnique(true).SetName(LockIndexName)
_, err := indexes.CreateOne(context.TODO(), mongo.IndexModel{
Options: indexOptions,
Keys: findFilter{Key: -1},
})
if err != nil {
return err
}
return nil
}

// ensureVersionTable checks if versions table exists and, if not, creates it.
// Note that this function locks the database, which deviates from the usual
// convention of "caller locks" in the MongoDb type.
func (m *Mongo) ensureVersionTable() (err error) {
if err = m.Lock(); err != nil {
return err
}

defer func() {
if e := m.Unlock(); e != nil {
if err == nil {
err = e
} else {
err = multierror.Append(err, e)
}
}
}()

if err != nil {
return err
}
if _, _, err = m.Version(); err != nil {
return err
}
return nil
}

// Utilizes advisory locking on the config.LockingCollection collection
// This uses a unique index on the `locking_key` field.
func (m *Mongo) Lock() error {
if !m.config.Locking.Enabled {
return nil
}
pid := os.Getpid()
hostname, err := os.Hostname()
if err != nil {
hostname = fmt.Sprintf("Could not determine hostname. Error: %s", err.Error())
}

newLockObj := lockObj{
Key: lockKeyUniqueValue,
Pid: pid,
Hostname: hostname,
CreatedAt: time.Now(),
}
operation := func() error {
timeout, cancelFunc := context.WithTimeout(context.Background(), contextWaitTimeout)
_, err := m.db.Collection(m.config.Locking.CollectionName).InsertOne(timeout, newLockObj)
defer cancelFunc()
return err
}
exponentialBackOff := backoff.NewExponentialBackOff()
duration := time.Duration(m.config.Locking.Timeout) * time.Second
exponentialBackOff.MaxElapsedTime = duration
exponentialBackOff.MaxInterval = time.Duration(m.config.Locking.Interval) * time.Second

err = backoff.Retry(operation, exponentialBackOff)
if err != nil {
return database.ErrLocked
}

return nil

}
func (m *Mongo) Unlock() error {
if !m.config.Locking.Enabled {
return nil
}

filter := findFilter{
Key: lockKeyUniqueValue,
}

ctx, cancel := context.WithTimeout(context.Background(), contextWaitTimeout)
_, err := m.db.Collection(m.config.Locking.CollectionName).DeleteMany(ctx, filter)
defer cancel()

if err != nil {
return err
}
return nil
}
69 changes: 68 additions & 1 deletion database/mongodb/mongodb_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ func Test(t *testing.T) {
}
}()
dt.TestNilVersion(t, d)
//TestLockAndUnlock(t, d) driver doesn't support lock on database level
dt.TestLockAndUnlock(t, d)
dt.TestRun(t, d, bytes.NewReader([]byte(`[{"insert":"hello","documents":[{"wild":"world"}]}]`)))
dt.TestSetVersion(t, d)
dt.TestDrop(t, d)
Expand Down Expand Up @@ -180,6 +180,73 @@ func TestWithAuth(t *testing.T) {
})
}

func TestLockWorks(t *testing.T) {
dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) {
ip, port, err := c.FirstPort()
if err != nil {
t.Fatal(err)
}

addr := mongoConnectionString(ip, port)
p := &Mongo{}
d, err := p.Open(addr)
if err != nil {
t.Fatal(err)
}
defer func() {
if err := d.Close(); err != nil {
t.Error(err)
}
}()

dt.TestRun(t, d, bytes.NewReader([]byte(`[{"insert":"hello","documents":[{"wild":"world"}]}]`)))

mc := d.(*Mongo)

err = mc.Lock()
if err != nil {
t.Fatal(err)
}
err = mc.Unlock()
if err != nil {
t.Fatal(err)
}

err = mc.Lock()
if err != nil {
t.Fatal(err)
}
err = mc.Unlock()
if err != nil {
t.Fatal(err)
}

// disable locking, validate wer can lock twice
mc.config.Locking.Enabled = false
err = mc.Lock()
if err != nil {
t.Fatal(err)
}
err = mc.Lock()
if err != nil {
t.Fatal(err)
}

// re-enable locking,
//try to hit a lock conflict
mc.config.Locking.Enabled = true
mc.config.Locking.Timeout = 1
err = mc.Lock()
if err != nil {
t.Fatal(err)
}
err = mc.Lock()
if err == nil {
t.Fatal("should have failed, mongo should be locked already")
}
})
}

func TestTransaction(t *testing.T) {
transactionSpecs := []dktesting.ContainerSpec{
{ImageName: "mongo:4", Options: dktest.Options{PortRequired: true, ReadyFunc: isReady,
Expand Down
Loading

0 comments on commit c602605

Please sign in to comment.