Skip to content
This repository has been archived by the owner on Jan 21, 2022. It is now read-only.

Commit

Permalink
Fix race condition with DB connections
Browse files Browse the repository at this point in the history
  • Loading branch information
stefanschneider committed May 21, 2015
1 parent 2141b06 commit d428c7f
Show file tree
Hide file tree
Showing 2 changed files with 179 additions and 20 deletions.
16 changes: 12 additions & 4 deletions provisioner/provisioner.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,11 @@ import (

// fmt template paramters: 1.databaseId
var createDatabaseTemplate = []string{
"use master",
"create database [%[1]v] containment = partial",
}

// fmt template parameters: 1.databaseId
var deleteDatabaseTemplate = []string{
"use master",
"alter database [%[1]v] set single_user with rollback immediate",
"drop database [%[1]v]",
}
Expand Down Expand Up @@ -72,6 +70,11 @@ func (provisioner *MssqlProvisioner) Init() error {
var err error = nil
connString := buildConnectionString(provisioner.connectionParams)
provisioner.dbClient, err = sql.Open(provisioner.goSqlDriver, connString)

// Set idle connections to 0 to prevent keeping open databases
// Enabling idle connections will create problems with ODBC driver when deleting DBs
provisioner.dbClient.SetMaxIdleConns(0)

if err != nil {
return err
}
Expand All @@ -84,6 +87,11 @@ func (provisioner *MssqlProvisioner) Init() error {
return nil
}

func (provisioner *MssqlProvisioner) Close() error {
err := provisioner.dbClient.Close()
return err
}

func (provisioner *MssqlProvisioner) CreateDatabase(databaseId string) error {
return provisioner.executeTemplateWithoutTx(createDatabaseTemplate, databaseId)
}
Expand All @@ -93,11 +101,11 @@ func (provisioner *MssqlProvisioner) DeleteDatabase(databaseId string) error {
}

func (provisioner *MssqlProvisioner) CreateUser(databaseId, userId, password string) error {
return provisioner.executeTemplateWithoutTx(createUserTemplate, databaseId, userId, password)
return provisioner.executeTemplateWithTx(createUserTemplate, databaseId, userId, password)
}

func (provisioner *MssqlProvisioner) DeleteUser(databaseId, userId string) error {
return provisioner.executeTemplateWithoutTx(deleteUserTemplate, databaseId, userId)
return provisioner.executeTemplateWithTx(deleteUserTemplate, databaseId, userId)
}

func (provisioner *MssqlProvisioner) IsDatabaseCreated(databaseId string) (bool, error) {
Expand Down
183 changes: 167 additions & 16 deletions provisioner/provisioner_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,27 +8,37 @@ import (
)

var logger *lagertest.TestLogger = lagertest.NewTestLogger("mssql-provisioner")
var mssqlPars = map[string]string{
"driver": "sql server",
"server": "127.0.0.1",

var odbcPars = map[string]string{
"driver": "{SQL Server Native Client 11.0}", // or with an older driver version "{SQL Server}"
"server": "127.0.0.1", // or (local)\\sqlexpress
"database": "master",
"trusted_connection": "yes",
}

var mssqlPars = map[string]string{
"server": "127.0.0.1",
"port": "38017",
"database": "master",
"user id": "sa",
"password": "password",
}

func TestCreateDatabaseOdbcDriver(t *testing.T) {
dbName := "cf-broker-testing.create-db"

sqlClient, err := sql.Open("odbc", buildConnectionString(mssqlPars))
sqlClient, err := sql.Open("odbc", buildConnectionString(odbcPars))
defer sqlClient.Close()

sqlClient.Exec("drop database [" + dbName + "]")

logger = lagertest.NewTestLogger("process-controller")
mssqlProv := NewMssqlProvisioner(logger, "odbc", mssqlPars)
mssqlProv := NewMssqlProvisioner(logger, "odbc", odbcPars)
err = mssqlProv.Init()
if err != nil {
t.Errorf("Provisioner init error, %v", err)
}
defer mssqlProv.Close()

// Act
err = mssqlProv.CreateDatabase(dbName)
Expand All @@ -50,17 +60,19 @@ func TestCreateDatabaseOdbcDriver(t *testing.T) {
func TestDeleteDatabaseOdbcDriver(t *testing.T) {
dbName := "cf-broker-testing.delete-db"

sqlClient, err := sql.Open("odbc", buildConnectionString(mssqlPars))
sqlClient, err := sql.Open("odbc", buildConnectionString(odbcPars))
defer sqlClient.Close()

sqlClient.Exec("drop database [" + dbName + "]")

logger = lagertest.NewTestLogger("process-controller")
mssqlProv := NewMssqlProvisioner(logger, "odbc", mssqlPars)
mssqlProv := NewMssqlProvisioner(logger, "odbc", odbcPars)
err = mssqlProv.Init()
if err != nil {
t.Errorf("Database init error, %v", err)
}
defer mssqlProv.Close()

err = mssqlProv.CreateDatabase(dbName)

// Act
Expand All @@ -80,18 +92,19 @@ func TestDeleteDatabaseOdbcDriver(t *testing.T) {
}
}

func TestDeleteUserOdbcDriver(t *testing.T) {
func TestCreateUserOdbcDriver(t *testing.T) {
dbName := "cf-broker-testing.create-db"
userNanme := "cf-broker-testing.create-user"

sqlClient, err := sql.Open("odbc", buildConnectionString(mssqlPars))
sqlClient, err := sql.Open("odbc", buildConnectionString(odbcPars))
defer sqlClient.Close()

sqlClient.Exec("drop database [" + dbName + "]")

logger = lagertest.NewTestLogger("process-controller")
mssqlProv := NewMssqlProvisioner(logger, "odbc", mssqlPars)
mssqlProv := NewMssqlProvisioner(logger, "odbc", odbcPars)
err = mssqlProv.Init()

if err != nil {
t.Errorf("Provisioner init error, %v", err)
}
Expand All @@ -108,6 +121,7 @@ func TestDeleteUserOdbcDriver(t *testing.T) {
if err != nil {
t.Errorf("User create error, %v", err)
}

defer sqlClient.Exec("drop database [" + dbName + "]")

row := sqlClient.QueryRow(fmt.Sprintf("select count(*) from [%s].sys.database_principals where name = ?", dbName), userNanme)
Expand All @@ -118,21 +132,22 @@ func TestDeleteUserOdbcDriver(t *testing.T) {
}
}

func TestCreateUserOdbcDriver(t *testing.T) {
func TestDeleteUserOdbcDriver(t *testing.T) {
dbName := "cf-broker-testing.create-db"
userNanme := "cf-broker-testing.create-user"

sqlClient, err := sql.Open("odbc", buildConnectionString(mssqlPars))
sqlClient, err := sql.Open("odbc", buildConnectionString(odbcPars))
defer sqlClient.Close()

sqlClient.Exec("drop database [" + dbName + "]")

logger = lagertest.NewTestLogger("process-controller")
mssqlProv := NewMssqlProvisioner(logger, "odbc", mssqlPars)
mssqlProv := NewMssqlProvisioner(logger, "odbc", odbcPars)
err = mssqlProv.Init()
if err != nil {
t.Errorf("Provisioner init error, %v", err)
}
defer mssqlProv.Close()

err = mssqlProv.CreateDatabase(dbName)
if err != nil {
Expand Down Expand Up @@ -187,7 +202,7 @@ func TestIsDatabaseCreatedOdbcDriver(t *testing.T) {
dbName := "cf-broker-testing.nonexisting-db"

logger = lagertest.NewTestLogger("process-controller")
mssqlProv := NewMssqlProvisioner(logger, "odbc", mssqlPars)
mssqlProv := NewMssqlProvisioner(logger, "odbc", odbcPars)
err := mssqlProv.Init()
if err != nil {
t.Errorf("Provisioner init error, %v", err)
Expand All @@ -208,13 +223,13 @@ func TestIsDatabaseCreatedOdbcDriver(t *testing.T) {
func TestIsDatabaseCreatedOdbcDriver2(t *testing.T) {
dbName := "cf-broker-testing.create-db"

sqlClient, err := sql.Open("odbc", buildConnectionString(mssqlPars))
sqlClient, err := sql.Open("odbc", buildConnectionString(odbcPars))
defer sqlClient.Close()

sqlClient.Exec("drop database [" + dbName + "]")

logger = lagertest.NewTestLogger("process-controller")
mssqlProv := NewMssqlProvisioner(logger, "odbc", mssqlPars)
mssqlProv := NewMssqlProvisioner(logger, "odbc", odbcPars)
mssqlProv.Init()
if err != nil {
t.Errorf("Provisioner init error, %v", err)
Expand All @@ -237,3 +252,139 @@ func TestIsDatabaseCreatedOdbcDriver2(t *testing.T) {

defer sqlClient.Exec("drop database [" + dbName + "]")
}

func TestStressOdbcDriver(t *testing.T) {
dbName := "cf-broker-testing.create-db"
dbNameA := "cf-broker-testing.create-db-A"
dbName2 := "cf-broker-testing.create-db-2"
userNanme := "cf-broker-testing.create-user"

sqlClient, err := sql.Open("odbc", buildConnectionString(odbcPars))
defer sqlClient.Close()

sqlClient.Exec("drop database [" + dbName + "]")
sqlClient.Exec("drop database [" + dbNameA + "]")
sqlClient.Exec("drop database [" + dbName2 + "]")

logger = lagertest.NewTestLogger("process-controller")
mssqlProv := NewMssqlProvisioner(logger, "odbc", odbcPars)
err = mssqlProv.Init()
if err != nil {
t.Errorf("Provisioner init error, %v", err)
}

err = mssqlProv.CreateDatabase(dbName)
if err != nil {
t.Errorf("Database create error, %v", err)
}

err = mssqlProv.CreateDatabase(dbNameA)
if err != nil {
t.Errorf("Database create error, %v", err)
}

wait := make(chan bool)

go func() {
for i := 1; i < 8; i++ {

err := mssqlProv.CreateDatabase(dbName2)
if err != nil {
t.Errorf("Database create error, %v", err)
break
}

err = mssqlProv.DeleteDatabase(dbName2)
if err != nil {
t.Errorf("Database delete error, %v", err)
break
}
}

wait <- true
}()

go func() {
for i := 1; i < 32; i++ {
err = mssqlProv.CreateUser(dbName, userNanme, "passwordAa_0")
if err != nil {
t.Errorf("User create error, %v", err)
break
}

err = mssqlProv.DeleteUser(dbName, userNanme)
if err != nil {
t.Errorf("User delete error, %v", err)
break
}

}

wait <- true
}()

go func() {
for i := 1; i < 32; i++ {
err = mssqlProv.CreateUser(dbNameA, userNanme, "passwordAa_0")
if err != nil {
t.Errorf("User create error, %v", err)
break
}

err = mssqlProv.DeleteUser(dbNameA, userNanme)
if err != nil {
t.Errorf("User delete error, %v", err)
break
}

}

wait <- true
}()

<-wait
<-wait
<-wait

sqlClient.Exec("drop database [" + dbName + "]")
sqlClient.Exec("drop database [" + dbName2 + "]")
sqlClient.Exec("drop database [" + dbNameA + "]")
}

func TestCreateDatabaseMssqlDriver(t *testing.T) {
dbName := "cf-broker-testing.create-db"

sqlClient, err := sql.Open("mssql", buildConnectionString(mssqlPars))
defer sqlClient.Close()

err = sqlClient.Ping()
if err != nil {
t.Skipf("Could not connect with pure mssql driver to %v", mssqlPars)
return
}

sqlClient.Exec("drop database [" + dbName + "]")

logger = lagertest.NewTestLogger("process-controller")
mssqlProv := NewMssqlProvisioner(logger, "mssql", mssqlPars)
err = mssqlProv.Init()
if err != nil {
t.Errorf("Provisioner init error, %v", err)
}

// Act
err = mssqlProv.CreateDatabase(dbName)

// Assert
if err != nil {
t.Errorf("Database create error, %v", err)
}
defer sqlClient.Exec("drop database [" + dbName + "]")

row := sqlClient.QueryRow("SELECT count(*) FROM sys.databases where name = ?", dbName)
dbCount := 0
row.Scan(&dbCount)
if dbCount == 0 {
t.Errorf("Database was not created")
}
}

0 comments on commit d428c7f

Please sign in to comment.