diff --git a/database/sqlserver/sqlserver.go b/database/sqlserver/sqlserver.go index 90e3926df..f9db2abc0 100644 --- a/database/sqlserver/sqlserver.go +++ b/database/sqlserver/sqlserver.go @@ -58,22 +58,28 @@ type SQLServer struct { config *Config } -// WithInstance returns a database instance from an already created database connection. +// WithConnection returns a database driver instance from an already created database connection. +// The connection will be closed when the database driver is closed. // // Note that the deprecated `mssql` driver is not supported. Please use the newer `sqlserver` driver. -func WithInstance(instance *sql.DB, config *Config) (database.Driver, error) { +func WithConnection(ctx context.Context, conn *sql.Conn, config *Config) (*SQLServer, error) { if config == nil { return nil, ErrNilConfig } - if err := instance.Ping(); err != nil { + if err := conn.PingContext(ctx); err != nil { return nil, err } + ss := SQLServer{ + conn: conn, + config: config, + } + if config.DatabaseName == "" { query := `SELECT DB_NAME()` var databaseName string - if err := instance.QueryRow(query).Scan(&databaseName); err != nil { + if err := conn.QueryRowContext(ctx, query).Scan(&databaseName); err != nil { return nil, &database.Error{OrigErr: err, Query: []byte(query)} } @@ -87,7 +93,7 @@ func WithInstance(instance *sql.DB, config *Config) (database.Driver, error) { if config.SchemaName == "" { query := `SELECT SCHEMA_NAME()` var schemaName string - if err := instance.QueryRow(query).Scan(&schemaName); err != nil { + if err := conn.QueryRowContext(ctx, query).Scan(&schemaName); err != nil { return nil, &database.Error{OrigErr: err, Query: []byte(query)} } @@ -102,22 +108,36 @@ func WithInstance(instance *sql.DB, config *Config) (database.Driver, error) { config.MigrationsTable = DefaultMigrationsTable } - conn, err := instance.Conn(context.Background()) + if err := ss.ensureVersionTable(); err != nil { + return nil, err + } - if err != nil { + return &ss, nil +} + +// WithInstance returns a database driver instance from an already created database handle. +// The database handle will be closed when the database driver is closed. +// +// Note that the deprecated `mssql` driver is not supported. Please use the newer `sqlserver` driver. +func WithInstance(instance *sql.DB, config *Config) (database.Driver, error) { + ctx := context.Background() + + if err := instance.Ping(); err != nil { return nil, err } - ss := &SQLServer{ - conn: conn, - db: instance, - config: config, + conn, err := instance.Conn(ctx) + if err != nil { + return nil, err } - if err := ss.ensureVersionTable(); err != nil { + ss, err := WithConnection(ctx, conn, config) + if err != nil { return nil, err } + ss.db = instance + return ss, nil } @@ -183,7 +203,10 @@ func (ss *SQLServer) Open(url string) (database.Driver, error) { // Close the database connection func (ss *SQLServer) Close() error { connErr := ss.conn.Close() - dbErr := ss.db.Close() + var dbErr error + if ss.db != nil { + dbErr = ss.db.Close() + } if connErr != nil || dbErr != nil { return fmt.Errorf("conn: %v, db: %v", connErr, dbErr) } diff --git a/database/sqlserver/sqlserver_test.go b/database/sqlserver/sqlserver_test.go index afe7fd253..9055197af 100644 --- a/database/sqlserver/sqlserver_test.go +++ b/database/sqlserver/sqlserver_test.go @@ -120,6 +120,53 @@ func Test(t *testing.T) { }) } +func TestWithConnection(t *testing.T) { + dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) { + ctx := context.Background() + + ip, port, err := c.Port(defaultPort) + if err != nil { + t.Fatal(err) + } + + db, err := sql.Open("sqlserver", msConnectionString(ip, port)) + if err != nil { + t.Fatal(err) + } + defer func() { + if err := db.Close(); err != nil { + t.Error(err) + } + }() + + conn, err := db.Conn(ctx) + if err != nil { + t.Fatal(err) + } + + p, err := WithConnection(ctx, conn, &Config{}) + if err != nil { + t.Fatal(err) + } + + defer func() { + if err := p.Close(); err != nil { + t.Error(err) + } + // Ensure connection is closed after database provider close + _, err := conn.QueryContext(ctx, "SELECT 1") + if err != sql.ErrConnDone { + t.Error("connection not closed") + } + _, err = db.QueryContext(ctx, "SELECT 1") + if err != nil { + t.Error("database handle should not be closed") + } + }() + dt.Test(t, p, []byte("SELECT 1")) + }) +} + func TestMigrate(t *testing.T) { dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) { SkipIfUnsupportedArch(t, c)