/
rdbms.go
85 lines (70 loc) · 1.44 KB
/
rdbms.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
package waitforit
import (
"database/sql"
"fmt"
"strconv"
"strings"
log "github.com/Sirupsen/logrus"
_ "github.com/go-sql-driver/mysql"
_ "github.com/lib/pq"
)
// Blank reference to the logrus package so its import stays in use even when
// this file contains no other logging calls.
var _ = log.Print
// DbWaiter waits for a relational database (PostgreSQL or MySQL) by opening
// a database/sql handle to the target, pinging it, and optionally querying a
// table named by the target.
type DbWaiter struct {
// target supplies the URL, host, port, timeout and test settings that
// Connect and RunTest read.
target *Target
// db is the handle opened by Connect; the constructors leave it nil.
db *sql.DB
// driver is the database/sql driver name handed to sql.Open
// ("postgres" or "mysql").
driver string
}
// NewPostgresWaiter returns a DbWaiter for target that uses the "postgres"
// driver. No connection is opened until Connect is called.
func NewPostgresWaiter(target *Target) *DbWaiter {
	w := new(DbWaiter)
	w.target = target
	w.driver = "postgres"
	return w
}
// NewMySQLWaiter returns a DbWaiter for target that uses the "mysql" driver.
// No connection is opened until Connect is called.
func NewMySQLWaiter(target *Target) *DbWaiter {
	w := new(DbWaiter)
	w.target = target
	w.driver = "mysql"
	return w
}
// Connect builds a driver-specific DSN from the target URL, opens a database
// handle for it and pings the server.
//
// For "postgres" the URL is kept as a URL and connect_timeout (the target
// timeout in whole seconds, fraction truncated) is added to the query string,
// along with sslmode=disable when the target is marked Insecure. For "mysql"
// the scheme is dropped, the host is rewritten as tcp(host:port) and the
// timeout is passed as "<seconds>s".
//
// Connect may be called repeatedly (for example when retrying): a handle left
// over from an earlier call is closed before a new one is opened, so retries
// do not accumulate open connection pools.
//
// It returns the error from sql.Open or from Ping, or nil once the ping
// succeeds.
func (w *DbWaiter) Connect() (err error) {
	if w.db != nil {
		// Best effort: the stale handle is being replaced, so a Close error
		// here is not actionable.
		_ = w.db.Close()
		w.db = nil
	}

	// NOTE(review): if target.url is a pointer, the Scheme/Host/RawQuery
	// assignments below modify the target's own URL — confirm that is intended.
	u := w.target.url
	values := u.Query()
	timeoutSecs := int(w.target.Timeout.Seconds())

	switch w.driver {
	case "postgres":
		values.Set("connect_timeout", strconv.Itoa(timeoutSecs))
		if w.target.Insecure {
			values.Set("sslmode", "disable")
		}
	case "mysql":
		// The MySQL DSN has no scheme and wraps the address in a protocol
		// marker: user:pass@tcp(host:port)/db?params
		u.Scheme = ""
		u.Host = fmt.Sprintf("tcp(%s:%d)", w.target.host, w.target.port)
		values.Set("timeout", fmt.Sprintf("%ds", timeoutSecs))
	}
	u.RawQuery = values.Encode()

	// With an empty scheme the URL renders with a leading "//", which is not
	// part of the DSN.
	dsn := strings.TrimPrefix(u.String(), "//")

	w.db, err = sql.Open(w.driver, dsn)
	if err != nil {
		return err
	}
	return w.db.Ping()
}
// RunTest performs the optional table check configured on the target.
//
// When target.Exists is empty there is nothing to check and nil is returned.
// Otherwise a `select exists (...)` query is run against the named table and
// any error from executing or scanning it is returned. The scanned value
// itself is discarded; only the error decides the outcome.
func (w *DbWaiter) RunTest() (err error) {
	table := w.target.Exists
	if table == "" {
		return nil
	}

	var found string
	query := fmt.Sprintf(`select exists (select 1 from %s limit 1)`, table)
	return w.db.QueryRow(query).Scan(&found)
}
// Cancel closes the database handle if one has been opened and returns the
// error from Close. It returns nil when there is no handle to close.
func (w *DbWaiter) Cancel() (err error) {
	if w.db == nil {
		return nil
	}
	return w.db.Close()
}