Skip to content

Commit

Permalink
fix: connect to previous database on error
Browse files Browse the repository at this point in the history
  • Loading branch information
jorgerojas26 committed Jul 10, 2024
1 parent f14d255 commit 02b5119
Showing 1 changed file with 58 additions and 29 deletions.
87 changes: 58 additions & 29 deletions drivers/postgres.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,17 @@ import (
)

type Postgres struct {
Connection *sql.DB
Provider string
CurrentDatabase string
Urlstr string
Connection *sql.DB
Provider string
CurrentDatabase string
PreviousDatabase string
Urlstr string
}

const (
DEFAULT_PORT = "5432"
)

func (db *Postgres) Connect(urlstr string) (err error) {
db.SetProvider("postgres")

Expand All @@ -38,7 +43,12 @@ func (db *Postgres) Connect(urlstr string) (err error) {

rows := db.Connection.QueryRow("SELECT current_database();")

err = rows.Scan(&db.CurrentDatabase)
database := ""

err = rows.Scan(&database)

db.CurrentDatabase = database
db.PreviousDatabase = database
if err != nil {
return err
}
Expand Down Expand Up @@ -71,38 +81,24 @@ func (db *Postgres) GetDatabases() (databases []string, err error) {
func (db *Postgres) GetTables(database string) (tables map[string][]string, err error) {
tables = make(map[string][]string)

if database != db.CurrentDatabase {
parsedConn, err := dburl.Parse(db.Urlstr)
if err != nil {
return tables, err
}

user := parsedConn.User.Username()
password, _ := parsedConn.User.Password()
host := parsedConn.Host
port := parsedConn.Port()
dbname := parsedConn.Path
switchDatabase := false

if port == "" {
port = "5432"
}

if dbname == "" {
dbname = database
}

db.Connection.Close()

db.Connection, err = sql.Open("postgres", fmt.Sprintf("host=%s port=%s user=%s password=%s dbname='%s' sslmode=disable", host, port, user, password, dbname))
if database != db.CurrentDatabase {
err = db.SwitchDatabase(database)
if err != nil {
return tables, err
}

db.CurrentDatabase = database
switchDatabase = true
}

rows, err := db.Connection.Query(fmt.Sprintf("SELECT table_name, table_schema FROM information_schema.tables WHERE table_catalog = '%s'", database))
if err != nil {
if switchDatabase {
err = db.SwitchDatabase(db.PreviousDatabase)
if err != nil {
return tables, err
}
}
return tables, err
}

Expand Down Expand Up @@ -542,3 +538,36 @@ func (db *Postgres) SetProvider(provider string) {
func (db *Postgres) GetProvider() string {
return db.Provider
}

func (db *Postgres) SwitchDatabase(database string) error {
parsedConn, err := dburl.Parse(db.Urlstr)
if err != nil {
return err
}

user := parsedConn.User.Username()
password, _ := parsedConn.User.Password()
host := parsedConn.Host
port := parsedConn.Port()
dbname := parsedConn.Path

if port == "" {
port = DEFAULT_PORT
}

if dbname == "" {
dbname = database
}

connection, err := sql.Open("postgres", fmt.Sprintf("host=%s port=%s user=%s password=%s dbname='%s' sslmode=disable", host, port, user, password, dbname))
if err != nil {
return err
}

db.Connection.Close()
db.Connection = connection
db.PreviousDatabase = db.CurrentDatabase
db.CurrentDatabase = database

return nil
}

0 comments on commit 02b5119

Please sign in to comment.