forked from hashicorp/vault
/
sql.go
136 lines (111 loc) · 3.42 KB
/
sql.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
package connutil
import (
"database/sql"
"fmt"
"strings"
"sync"
"time"
// Import sql drivers
_ "github.com/denisenkom/go-mssqldb"
_ "github.com/go-sql-driver/mysql"
"github.com/hashicorp/vault/helper/parseutil"
_ "github.com/lib/pq"
"github.com/mitchellh/mapstructure"
)
// SQLConnectionProducer implements ConnectionProducer and provides a generic
// producer for most sql databases. The tagged fields are populated from the
// configuration map passed to Initialize; the embedded mutex is taken by
// Initialize and Close.
type SQLConnectionProducer struct {
	// ConnectionURL is the data source name handed to sql.Open. Initialize
	// rejects an empty value.
	ConnectionURL string `json:"connection_url" structs:"connection_url" mapstructure:"connection_url"`

	// MaxOpenConnections is applied with SetMaxOpenConns. Initialize replaces
	// a zero value with 2.
	MaxOpenConnections int `json:"max_open_connections" structs:"max_open_connections" mapstructure:"max_open_connections"`

	// MaxIdleConnections is applied with SetMaxIdleConns. Initialize sets it
	// to MaxOpenConnections when it is zero or larger than that limit.
	MaxIdleConnections int `json:"max_idle_connections" structs:"max_idle_connections" mapstructure:"max_idle_connections"`

	// MaxConnectionLifetimeRaw is the unparsed lifetime setting. Initialize
	// defaults it to "0s" when nil and parses it into maxConnectionLifetime.
	MaxConnectionLifetimeRaw interface{} `json:"max_connection_lifetime" structs:"max_connection_lifetime" mapstructure:"max_connection_lifetime"`

	// Type is the driver name passed to sql.Open; "mssql" is mapped to
	// "sqlserver" when the connection is opened.
	Type string

	// maxConnectionLifetime is the parsed form of MaxConnectionLifetimeRaw,
	// applied with SetConnMaxLifetime.
	maxConnectionLifetime time.Duration

	// Initialized is set to true once Initialize completes without error.
	Initialized bool

	// db is the current database handle, or nil when none is open.
	db *sql.DB

	sync.Mutex
}
// Initialize decodes conf into the producer, fills in connection pool
// defaults, and, when verifyConnection is true, opens a connection and pings
// the database. The producer's mutex is held for the whole call.
//
// It returns an error if conf cannot be decoded, if connection_url is empty,
// if max_connection_lifetime cannot be parsed, or if verification fails.
// Initialized is only set to true when none of those errors occur.
func (c *SQLConnectionProducer) Initialize(conf map[string]interface{}, verifyConnection bool) error {
	c.Lock()
	defer c.Unlock()

	if decodeErr := mapstructure.Decode(conf, c); decodeErr != nil {
		return decodeErr
	}

	if c.ConnectionURL == "" {
		return fmt.Errorf("connection_url cannot be empty")
	}

	// Pool sizing: default to two open connections, and never keep more
	// idle connections than may be open. An unset (zero) idle limit follows
	// the open limit.
	if c.MaxOpenConnections == 0 {
		c.MaxOpenConnections = 2
	}
	if c.MaxIdleConnections == 0 || c.MaxIdleConnections > c.MaxOpenConnections {
		c.MaxIdleConnections = c.MaxOpenConnections
	}

	// An unset lifetime is treated as "0s".
	if c.MaxConnectionLifetimeRaw == nil {
		c.MaxConnectionLifetimeRaw = "0s"
	}
	var parseErr error
	c.maxConnectionLifetime, parseErr = parseutil.ParseDurationSecond(c.MaxConnectionLifetimeRaw)
	if parseErr != nil {
		return fmt.Errorf("invalid max_connection_lifetime: %s", parseErr)
	}

	if verifyConnection {
		// Connection is called with the lock already held; it does not
		// take the mutex itself. The explicit Ping covers the case where
		// Connection has just opened a new handle without contacting the
		// server.
		_, verifyErr := c.Connection()
		if verifyErr == nil {
			verifyErr = c.db.Ping()
		}
		if verifyErr != nil {
			return fmt.Errorf("error initalizing connection: %s", verifyErr)
		}
	}

	c.Initialized = true
	return nil
}
// Connection returns the producer's *sql.DB, creating it when necessary.
//
// If a handle already exists and answers a ping, it is returned unchanged.
// If the ping fails, the handle is closed and a new one is opened from
// ConnectionURL with the configured pool limits applied. This method does
// not take the producer's mutex; Initialize calls it while holding the lock.
func (c *SQLConnectionProducer) Connection() (interface{}, error) {
	// Reuse the current handle while it still responds.
	if existing := c.db; existing != nil {
		if existing.Ping() == nil {
			return existing, nil
		}
		// The handle is about to be replaced, so a failure to close it
		// is deliberately ignored.
		existing.Close()
	}

	// The mssql backend is opened through the "sqlserver" driver name.
	driver := c.Type
	if driver == "mssql" {
		driver = "sqlserver"
	}

	// Ensure timezone is set to UTC for all postgres connections by adding
	// the parameter to the URL's query string.
	dsn := c.ConnectionURL
	isPostgres := strings.HasPrefix(dsn, "postgres://") || strings.HasPrefix(dsn, "postgresql://")
	if isPostgres {
		separator := "?"
		if strings.Contains(dsn, "?") {
			separator = "&"
		}
		dsn += separator + "timezone=utc"
	}

	// Store whatever sql.Open returned before checking the error, so the
	// field never keeps pointing at the handle closed above.
	pool, err := sql.Open(driver, dsn)
	c.db = pool
	if err != nil {
		return nil, err
	}

	// Set some connection pool settings. We don't need much of this,
	// since the request rate shouldn't be high.
	pool.SetMaxOpenConns(c.MaxOpenConnections)
	pool.SetMaxIdleConns(c.MaxIdleConnections)
	pool.SetConnMaxLifetime(c.maxConnectionLifetime)

	return pool, nil
}
// Close closes the database handle, if one is open, and clears it so that a
// later call to Connection opens a new one. The producer's mutex is held for
// the duration of the call.
//
// It returns the error reported by the underlying (*sql.DB).Close, or nil
// when there was no open handle. The handle is cleared even when closing
// fails.
func (c *SQLConnectionProducer) Close() error {
	c.Lock()
	defer c.Unlock()

	if c.db == nil {
		return nil
	}

	// Detach the handle before closing so the producer never retains a
	// closed or half-closed *sql.DB, whatever Close reports.
	db := c.db
	c.db = nil

	return db.Close()
}