Skip to content

Commit

Permalink
Implement WithConnection for sqlserver database driver
Browse files Browse the repository at this point in the history
  • Loading branch information
selaux committed Aug 16, 2022
1 parent 03613f1 commit a5148ef
Show file tree
Hide file tree
Showing 2 changed files with 83 additions and 13 deletions.
49 changes: 36 additions & 13 deletions database/sqlserver/sqlserver.go
Expand Up @@ -58,22 +58,28 @@ type SQLServer struct {
config *Config
}

// WithInstance returns a database instance from an already created database connection.
// WithConnection returns a database driver instance from an already created database connection.
// The connection will be closed when the database driver is closed.
//
// Note that the deprecated `mssql` driver is not supported. Please use the newer `sqlserver` driver.
func WithInstance(instance *sql.DB, config *Config) (database.Driver, error) {
func WithConnection(ctx context.Context, conn *sql.Conn, config *Config) (*SQLServer, error) {
if config == nil {
return nil, ErrNilConfig
}

if err := instance.Ping(); err != nil {
if err := conn.PingContext(ctx); err != nil {
return nil, err
}

ss := SQLServer{
conn: conn,
config: config,
}

if config.DatabaseName == "" {
query := `SELECT DB_NAME()`
var databaseName string
if err := instance.QueryRow(query).Scan(&databaseName); err != nil {
if err := conn.QueryRowContext(ctx, query).Scan(&databaseName); err != nil {
return nil, &database.Error{OrigErr: err, Query: []byte(query)}
}

Expand All @@ -87,7 +93,7 @@ func WithInstance(instance *sql.DB, config *Config) (database.Driver, error) {
if config.SchemaName == "" {
query := `SELECT SCHEMA_NAME()`
var schemaName string
if err := instance.QueryRow(query).Scan(&schemaName); err != nil {
if err := conn.QueryRowContext(ctx, query).Scan(&schemaName); err != nil {
return nil, &database.Error{OrigErr: err, Query: []byte(query)}
}

Expand All @@ -102,22 +108,36 @@ func WithInstance(instance *sql.DB, config *Config) (database.Driver, error) {
config.MigrationsTable = DefaultMigrationsTable
}

conn, err := instance.Conn(context.Background())
if err := ss.ensureVersionTable(); err != nil {
return nil, err
}

if err != nil {
return &ss, nil
}

// WithInstance returns a database driver instance from an already created database handle.
// The database handle will be closed when the database driver is closed.
//
// Note that the deprecated `mssql` driver is not supported. Please use the newer `sqlserver` driver.
func WithInstance(instance *sql.DB, config *Config) (database.Driver, error) {
ctx := context.Background()

if err := instance.Ping(); err != nil {
return nil, err
}

ss := &SQLServer{
conn: conn,
db: instance,
config: config,
conn, err := instance.Conn(ctx)
if err != nil {
return nil, err
}

if err := ss.ensureVersionTable(); err != nil {
ss, err := WithConnection(ctx, conn, config)
if err != nil {
return nil, err
}

ss.db = instance

return ss, nil
}

Expand Down Expand Up @@ -183,7 +203,10 @@ func (ss *SQLServer) Open(url string) (database.Driver, error) {
// Close the database connection
func (ss *SQLServer) Close() error {
connErr := ss.conn.Close()
dbErr := ss.db.Close()
var dbErr error
if ss.db != nil {
dbErr = ss.db.Close()
}
if connErr != nil || dbErr != nil {
return fmt.Errorf("conn: %v, db: %v", connErr, dbErr)
}
Expand Down
47 changes: 47 additions & 0 deletions database/sqlserver/sqlserver_test.go
Expand Up @@ -120,6 +120,53 @@ func Test(t *testing.T) {
})
}

func TestWithConnection(t *testing.T) {
dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) {
ctx := context.Background()

ip, port, err := c.Port(defaultPort)
if err != nil {
t.Fatal(err)
}

db, err := sql.Open("sqlserver", msConnectionString(ip, port))
if err != nil {
t.Fatal(err)
}
defer func() {
if err := db.Close(); err != nil {
t.Error(err)
}
}()

conn, err := db.Conn(ctx)
if err != nil {
t.Fatal(err)
}

p, err := WithConnection(ctx, conn, &Config{})
if err != nil {
t.Fatal(err)
}

defer func() {
if err := p.Close(); err != nil {
t.Error(err)
}
// Ensure connection is closed after database provider close
_, err := conn.QueryContext(ctx, "SELECT 1")
if err != sql.ErrConnDone {
t.Error("connection not closed")
}
_, err = db.QueryContext(ctx, "SELECT 1")
if err != nil {
t.Error("database handle should not be closed")
}
}()
dt.Test(t, p, []byte("SELECT 1"))
})
}

func TestMigrate(t *testing.T) {
dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) {
SkipIfUnsupportedArch(t, c)
Expand Down

0 comments on commit a5148ef

Please sign in to comment.