/
postgres.go
94 lines (84 loc) · 2.46 KB
/
postgres.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
package postgres
import (
"database/sql"
"fmt"
"log"
"github.com/go-logr/logr"
)
// PG is the set of PostgreSQL administration operations exposed by this
// package. NewPG returns either the plain implementation (*pg) or a value
// built from it by a cloud-provider constructor, selected by its cloud_type
// argument.
//
// Methods are grouped below by the kind of object they act on; the exact SQL
// each one issues is defined by the implementations, not here.
type PG interface {
	// Databases, schemas and extensions.
	CreateDB(dbname, username string) error
	DropDatabase(db string, logger logr.Logger) error
	CreateSchema(db, role, schema string, logger logr.Logger) error
	CreateExtension(db, extension string, logger logr.Logger) error
	SetSchemaPrivileges(schemaPrivileges PostgresSchemaPrivileges, logger logr.Logger) error

	// Roles and role membership.
	CreateGroupRole(role string) error
	// NOTE(review): the string result of CreateUserRole is presumably the
	// name of the role actually created — confirm against the implementations.
	CreateUserRole(role, password string) (string, error)
	UpdatePassword(role, password string) error
	GrantRole(role, grantee string) error
	RevokeRole(role, revoked string) error
	AlterDefaultLoginRole(role, setRole string) error
	DropRole(role, newOwner, database string, logger logr.Logger) error

	// Connection details.
	GetUser() string
	GetDefaultDatabase() string
}
// pg is the plain PG implementation. It holds the connection pool opened by
// NewPG, the logger passed to NewPG, and the parameters that were used to open
// the pool.
type pg struct {
	db  *sql.DB
	log logr.Logger

	// Connection parameters exactly as given to NewPG: host, user, password
	// (pass), the URL query string (args) and the database the pool targets.
	host, user, pass, args string
	default_database       string
}
// PostgresSchemaPrivileges is the argument to PG.SetSchemaPrivileges.
//
// NOTE(review): the precise meaning of each field (for example the expected
// format of Privs, and what Creator and CreateSchema control) is defined by
// the SetSchemaPrivileges implementations — confirm there before relying on it.
type PostgresSchemaPrivileges struct {
	DB, Creator, Role, Schema, Privs string
	CreateSchema                     bool
}
// NewPG opens a connection pool to the PostgreSQL server and returns a PG
// implementation built on it.
//
// The pool is opened with GetConnection against default_database on host,
// authenticating as user/password, with uri_args used as the URL query string.
// cloud_type selects what is returned: for "AWS", "Azure" or "GCP" (exact,
// case-sensitive match) the plain implementation is passed to newAWSPG,
// newAzurePG or newGCPPG respectively and that result is returned; any other
// value returns the plain implementation.
//
// An error returned by GetConnection is wrapped and returned to the caller
// with a nil PG rather than terminating the process, so the caller decides
// how to react.
func NewPG(host, user, password, uri_args, default_database, cloud_type string, logger logr.Logger) (PG, error) {
	db, err := GetConnection(user, password, host, default_database, uri_args, logger)
	if err != nil {
		// Don't leak a pool that may have been returned alongside the error.
		if db != nil {
			_ = db.Close() // best effort; the connection error is the one worth reporting
		}
		return nil, fmt.Errorf("connecting to PostgreSQL server: %w", err)
	}
	logger.Info("connected to postgres server")

	postgres := &pg{
		db:               db,
		log:              logger,
		host:             host,
		user:             user,
		pass:             password,
		args:             uri_args,
		default_database: default_database,
	}

	switch cloud_type {
	case "AWS":
		logger.Info("Using AWS wrapper")
		return newAWSPG(postgres), nil
	case "Azure":
		logger.Info("Using Azure wrapper")
		return newAzurePG(postgres), nil
	case "GCP":
		logger.Info("Using GCP wrapper")
		return newGCPPG(postgres), nil
	default:
		logger.Info("Using default postgres implementation")
		return postgres, nil
	}
}
// GetUser reports the user name this pg was constructed with, i.e. the login
// its connection pool was opened as.
func (c *pg) GetUser() string { return c.user }
// GetDefaultDatabase reports the database name this pg was constructed with,
// which is also the database its connection pool was opened against.
func (c *pg) GetDefaultDatabase() string { return c.default_database }
// GetConnection opens a connection pool for database on host, authenticating
// as user/password, and verifies it with a Ping.
//
// uri_args is appended verbatim as the connection URL's query string. The
// logger parameter is currently unused.
//
// On success the caller owns the returned *sql.DB and must Close it. If the
// Ping fails, the pool is closed here and (nil, wrapped error) is returned.
func GetConnection(user, password, host, database, uri_args string, logger logr.Logger) (*sql.DB, error) {
	// NOTE(review): user, password and database are interpolated without URL
	// escaping, so values containing reserved characters such as '@', ':',
	// '/', '?', '#' or '%' produce a malformed DSN. Consider building it with
	// net/url (url.URL + url.UserPassword) instead.
	dsn := fmt.Sprintf("postgresql://%s:%s@%s/%s?%s", user, password, host, database, uri_args)

	db, err := sql.Open("postgres", dsn)
	if err != nil {
		// NOTE(review): this terminates the whole process. sql.Open typically
		// fails only when no driver is registered under "postgres"; consider
		// returning the error instead once all callers are audited.
		log.Fatal(err)
	}

	if err := db.Ping(); err != nil {
		// Release the pool so a failed connection attempt doesn't leak it.
		_ = db.Close() // best effort; the Ping error is the one worth reporting
		return nil, fmt.Errorf("pinging database %q on %s: %w", database, host, err)
	}
	return db, nil
}