Skip to content

Commit

Permalink
Calling WithConnection from WithInstance to de-duplicate code
Browse files Browse the repository at this point in the history
  • Loading branch information
Sébastien CAPARROS committed Jun 29, 2021
1 parent f98ad3a commit 394279a
Showing 1 changed file with 20 additions and 39 deletions.
59 changes: 20 additions & 39 deletions database/mysql/mysql.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ type Mysql struct {
}

// connection instance must have `multiStatements` set to true
func WithConnection(conn *sql.Conn, config *Config) (database.Driver, error) {
func WithConnection(conn *sql.Conn, config *Config) (*Mysql, error) {
if config == nil {
return nil, ErrNilConfig
}
Expand All @@ -66,8 +66,22 @@ func WithConnection(conn *sql.Conn, config *Config) (database.Driver, error) {
config: config,
}

if err := mx.setupDefaultConfig(); err != nil {
return nil, err
if config.DatabaseName == "" {
query := `SELECT DATABASE()`
var databaseName sql.NullString
if err := conn.QueryRowContext(context.Background(), query).Scan(&databaseName); err != nil {
return nil, &database.Error{OrigErr: err, Query: []byte(query)}
}

if len(databaseName.String) == 0 {
return nil, ErrNoDatabaseName
}

config.DatabaseName = databaseName.String
}

if len(config.MigrationsTable) == 0 {
config.MigrationsTable = DefaultMigrationsTable
}

if err := mx.ensureVersionTable(); err != nil {
Expand All @@ -79,10 +93,6 @@ func WithConnection(conn *sql.Conn, config *Config) (database.Driver, error) {

// instance must have `multiStatements` set to true
func WithInstance(instance *sql.DB, config *Config) (database.Driver, error) {
if config == nil {
return nil, ErrNilConfig
}

if err := instance.Ping(); err != nil {
return nil, err
}
Expand All @@ -92,45 +102,16 @@ func WithInstance(instance *sql.DB, config *Config) (database.Driver, error) {
return nil, err
}

mx := &Mysql{
conn: conn,
db: instance,
config: config,
}

if err := mx.setupDefaultConfig(); err != nil {
mx, err := WithConnection(conn, config)
if err != nil {
return nil, err
}

if err := mx.ensureVersionTable(); err != nil {
return nil, err
}
mx.db = instance

return mx, nil
}

func (m *Mysql) setupDefaultConfig() error {
if m.config.DatabaseName == "" {
query := `SELECT DATABASE()`
var databaseName sql.NullString
if err := m.conn.QueryRowContext(context.Background(), query).Scan(&databaseName); err != nil {
return &database.Error{OrigErr: err, Query: []byte(query)}
}

if len(databaseName.String) == 0 {
return ErrNoDatabaseName
}

m.config.DatabaseName = databaseName.String
}

if len(m.config.MigrationsTable) == 0 {
m.config.MigrationsTable = DefaultMigrationsTable
}

return nil
}

// extractCustomQueryParams extracts the custom query params (ones that start with "x-") from
// mysql.Config.Params (connection parameters) as to not interfere with connecting to MySQL
func extractCustomQueryParams(c *mysql.Config) (map[string]string, error) {
Expand Down

0 comments on commit 394279a

Please sign in to comment.