From 3693f5605d14a15966b889e352d245744eb3b5f9 Mon Sep 17 00:00:00 2001 From: Pieter Callewaert Date: Tue, 31 Aug 2021 09:13:38 +0200 Subject: [PATCH] Refactor GetConnection to be able to get the pq error code of the error --- .gitignore | 3 +++ pkg/postgres/database.go | 21 +++++++++++++++------ pkg/postgres/postgres.go | 15 ++++++++------- pkg/postgres/role.go | 11 +++++++++-- 4 files changed, 35 insertions(+), 15 deletions(-) diff --git a/.gitignore b/.gitignore index 278594949..d734ce916 100644 --- a/.gitignore +++ b/.gitignore @@ -76,5 +76,8 @@ tags .history # End of https://www.gitignore.io/api/go,vim,emacs,visualstudiocode +### MacOS +.DS_Store + .idea deploy/secret.yaml \ No newline at end of file diff --git a/pkg/postgres/database.go b/pkg/postgres/database.go index 5600fd351..8f3c4e828 100644 --- a/pkg/postgres/database.go +++ b/pkg/postgres/database.go @@ -35,10 +35,13 @@ func (c *pg) CreateDB(dbname, role string) error { } func (c *pg) CreateSchema(db, role, schema string, logger logr.Logger) error { - tmpDb := GetConnection(c.user, c.pass, c.host, db, c.args, logger) + tmpDb, err := GetConnection(c.user, c.pass, c.host, db, c.args, logger) + if err != nil { + return err + } defer tmpDb.Close() - _, err := tmpDb.Exec(fmt.Sprintf(CREATE_SCHEMA, schema, role)) + _, err = tmpDb.Exec(fmt.Sprintf(CREATE_SCHEMA, schema, role)) if err != nil { return err } @@ -58,10 +61,13 @@ func (c *pg) DropDatabase(database string, logger logr.Logger) error { } func (c *pg) CreateExtension(db, extension string, logger logr.Logger) error { - tmpDb := GetConnection(c.user, c.pass, c.host, db, c.args, logger) + tmpDb, err := GetConnection(c.user, c.pass, c.host, db, c.args, logger) + if err != nil { + return err + } defer tmpDb.Close() - _, err := tmpDb.Exec(fmt.Sprintf(CREATE_EXTENSION, extension)) + _, err = tmpDb.Exec(fmt.Sprintf(CREATE_EXTENSION, extension)) if err != nil { return err } @@ -69,11 +75,14 @@ func (c *pg) CreateExtension(db, extension string, logger logr.Logger) error { } func (c *pg) SetSchemaPrivileges(db, creator, role, schema, privs string, logger logr.Logger) error { - tmpDb := GetConnection(c.user, c.pass, c.host, db, c.args, logger) + tmpDb, err := GetConnection(c.user, c.pass, c.host, db, c.args, logger) + if err != nil { + return err + } defer tmpDb.Close() // Grant role usage on schema - _, err := tmpDb.Exec(fmt.Sprintf(GRANT_USAGE_SCHEMA, schema, role)) + _, err = tmpDb.Exec(fmt.Sprintf(GRANT_USAGE_SCHEMA, schema, role)) if err != nil { return err } diff --git a/pkg/postgres/postgres.go b/pkg/postgres/postgres.go index 687b3c169..bbe9778bd 100644 --- a/pkg/postgres/postgres.go +++ b/pkg/postgres/postgres.go @@ -36,8 +36,13 @@ type pg struct { } func NewPG(host, user, password, uri_args, default_database, cloud_type string, logger logr.Logger) (PG, error) { + db, err := GetConnection(user, password, host, default_database, uri_args, logger) + if err != nil { + log.Fatalf("failed to connect to PostgreSQL server: %s", err.Error()) + } + logger.Info("connected to postgres server") postgres := &pg{ - db: GetConnection(user, password, host, default_database, uri_args, logger), + db: db, log: logger, host: host, user: user, @@ -64,15 +69,11 @@ func (c *pg) GetDefaultDatabase() string { return c.default_database } -func GetConnection(user, password, host, database, uri_args string, logger logr.Logger) *sql.DB { +func GetConnection(user, password, host, database, uri_args string, logger logr.Logger) (*sql.DB, error) { db, err := sql.Open("postgres", fmt.Sprintf("postgresql://%s:%s@%s/%s?%s", user, password, host, database, uri_args)) if err != nil { log.Fatal(err) } err = db.Ping() - if err != nil { - log.Fatalf("failed to connect to PostgreSQL server: %s", err.Error()) - } - logger.Info("connected to postgres server") - return db + return db, err } diff --git a/pkg/postgres/role.go b/pkg/postgres/role.go index 24a603736..8bf4f4b71 100644 --- a/pkg/postgres/role.go +++ b/pkg/postgres/role.go @@ -62,8 +62,15 @@ func (c *pg) RevokeRole(role, revoked string) error { func (c *pg) DropRole(role, newOwner, database string, logger logr.Logger) error { // REASSIGN OWNED BY only works if the correct database is selected - tmpDb := GetConnection(c.user, c.pass, c.host, database, c.args, logger) - _, err := tmpDb.Exec(fmt.Sprintf(REASIGN_OBJECTS, role, newOwner)) + tmpDb, err := GetConnection(c.user, c.pass, c.host, database, c.args, logger) + if err != nil { + if err.(*pq.Error).Code == "3D000" { + return nil // Database is does not exist (anymore) + } else { + return err + } + } + _, err = tmpDb.Exec(fmt.Sprintf(REASIGN_OBJECTS, role, newOwner)) defer tmpDb.Close() // Check if error exists and if different from "ROLE NOT FOUND" => 42704 if err != nil && err.(*pq.Error).Code != "42704" {