From 02b51198a506e3c3967698e631d28b588819c3a4 Mon Sep 17 00:00:00 2001 From: Jorge Rojas Date: Wed, 10 Jul 2024 00:46:11 -0400 Subject: [PATCH] fix: connect to previous database on error --- drivers/postgres.go | 87 ++++++++++++++++++++++++++++++--------------- 1 file changed, 58 insertions(+), 29 deletions(-) diff --git a/drivers/postgres.go b/drivers/postgres.go index bb809b7..e95cbf8 100644 --- a/drivers/postgres.go +++ b/drivers/postgres.go @@ -13,12 +13,17 @@ import ( ) type Postgres struct { - Connection *sql.DB - Provider string - CurrentDatabase string - Urlstr string + Connection *sql.DB + Provider string + CurrentDatabase string + PreviousDatabase string + Urlstr string } +const ( + DEFAULT_PORT = "5432" +) + func (db *Postgres) Connect(urlstr string) (err error) { db.SetProvider("postgres") @@ -38,7 +43,12 @@ func (db *Postgres) Connect(urlstr string) (err error) { rows := db.Connection.QueryRow("SELECT current_database();") - err = rows.Scan(&db.CurrentDatabase) + database := "" + + err = rows.Scan(&database) + + db.CurrentDatabase = database + db.PreviousDatabase = database if err != nil { return err } @@ -71,38 +81,24 @@ func (db *Postgres) GetDatabases() (databases []string, err error) { func (db *Postgres) GetTables(database string) (tables map[string][]string, err error) { tables = make(map[string][]string) - if database != db.CurrentDatabase { - parsedConn, err := dburl.Parse(db.Urlstr) - if err != nil { - return tables, err - } - - user := parsedConn.User.Username() - password, _ := parsedConn.User.Password() - host := parsedConn.Host - port := parsedConn.Port() - dbname := parsedConn.Path + switchDatabase := false - if port == "" { - port = "5432" - } - - if dbname == "" { - dbname = database - } - - db.Connection.Close() - - db.Connection, err = sql.Open("postgres", fmt.Sprintf("host=%s port=%s user=%s password=%s dbname='%s' sslmode=disable", host, port, user, password, dbname)) + if database != db.CurrentDatabase { + err = db.SwitchDatabase(database) if err != nil { return tables, err } - - db.CurrentDatabase = database + switchDatabase = true } rows, err := db.Connection.Query(fmt.Sprintf("SELECT table_name, table_schema FROM information_schema.tables WHERE table_catalog = '%s'", database)) if err != nil { + if switchDatabase { + err = db.SwitchDatabase(db.PreviousDatabase) + if err != nil { + return tables, err + } + } return tables, err } @@ -542,3 +538,36 @@ func (db *Postgres) SetProvider(provider string) { func (db *Postgres) GetProvider() string { return db.Provider } + +func (db *Postgres) SwitchDatabase(database string) error { + parsedConn, err := dburl.Parse(db.Urlstr) + if err != nil { + return err + } + + user := parsedConn.User.Username() + password, _ := parsedConn.User.Password() + host := parsedConn.Host + port := parsedConn.Port() + dbname := parsedConn.Path + + if port == "" { + port = DEFAULT_PORT + } + + if dbname == "" { + dbname = database + } + + connection, err := sql.Open("postgres", fmt.Sprintf("host=%s port=%s user=%s password=%s dbname='%s' sslmode=disable", host, port, user, password, dbname)) + if err != nil { + return err + } + + db.Connection.Close() + db.Connection = connection + db.PreviousDatabase = db.CurrentDatabase + db.CurrentDatabase = database + + return nil +}