/
db.go
156 lines (130 loc) · 4.17 KB
/
db.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
// Package test provides shared testing utilities.
package test
import (
"database/sql"
"log"
"os"
"strings"
"sync"
"time"
sqlite3Driver "github.com/mattn/go-sqlite3"
"gopkg.in/reform.v1"
"gopkg.in/reform.v1/dialects"
"gopkg.in/reform.v1/dialects/mssql" //nolint:staticcheck
"gopkg.in/reform.v1/dialects/mysql"
"gopkg.in/reform.v1/dialects/postgresql"
"gopkg.in/reform.v1/dialects/sqlite3"
"gopkg.in/reform.v1/dialects/sqlserver"
)
//nolint:gochecknoglobals
var (
// sqlite3RegisterOnce guards the sql.Register call for the
// "sqlite3_with_sleep" driver in ConnectToTestDB, so that the driver
// is registered at most once.
sqlite3RegisterOnce sync.Once
// inspectOnce guards the diagnostic logging in ConnectToTestDB
// (driver, source, current time, server version and settings),
// so that it is emitted at most once.
inspectOnce sync.Once
)
// ConnectToTestDB opens the test database named by the REFORM_TEST_DRIVER and
// REFORM_TEST_SOURCE environment variables, verifies the connection, and
// returns it wrapped in a *reform.DB with the dialect matching the driver.
//
// Every failure (missing variables, open/ping error, failed inspection query,
// unknown dialect) terminates the process via log.Fatal / log.Fatalf.
// The first call that reaches the inspection step also logs the driver,
// source, current time, and server version/settings.
func ConnectToTestDB() *reform.DB {
	driverName := strings.TrimSpace(os.Getenv("REFORM_TEST_DRIVER"))
	dsn := strings.TrimSpace(os.Getenv("REFORM_TEST_SOURCE"))
	if driverName == "" || dsn == "" {
		log.Fatal("no driver or source, set REFORM_TEST_DRIVER and REFORM_TEST_SOURCE")
	}

	// For SQLite3, switch to a driver variant whose connections expose
	// a custom SQL function "sleep"; it is needed by context tests.
	if driverName == "sqlite3" {
		driverName = "sqlite3_with_sleep"

		sqlite3RegisterOnce.Do(func() {
			// sleep blocks for nsec nanoseconds and returns its argument.
			sleep := func(nsec int64) (int64, error) {
				time.Sleep(time.Duration(nsec))
				return nsec, nil
			}

			withSleep := &sqlite3Driver.SQLiteDriver{
				ConnectHook: func(conn *sqlite3Driver.SQLiteConn) error {
					return conn.RegisterFunc("sleep", sleep, false)
				},
			}
			sql.Register(driverName, withSleep)
		})
	}

	db, err := sql.Open(driverName, dsn)
	if err != nil {
		log.Fatal(err)
	}

	// Restrict the pool to a single connection so session-scoped settings stick.
	// For example: "PRAGMA foreign_keys" for SQLite3, "SET IDENTITY_INSERT" for MS SQL, etc.
	db.SetMaxIdleConns(1)
	db.SetMaxOpenConns(1)
	db.SetConnMaxLifetime(0)

	err = db.Ping()
	if err != nil {
		log.Fatal(err)
	}

	startedAt := time.Now()

	// Pick the dialect for the driver, and with it the dialect-specific
	// inspection routine that logs server details.
	dialect := dialects.ForDriver(driverName)

	var describeServer func()

	switch dialect {
	case postgresql.Dialect:
		describeServer = func() {
			var version, tz string
			if err := db.QueryRow("SHOW server_version").Scan(&version); err != nil {
				log.Fatal(err)
			}
			if err := db.QueryRow("SHOW TimeZone").Scan(&tz); err != nil {
				log.Fatal(err)
			}

			log.Printf("PostgreSQL version = %q", version)
			log.Printf("PostgreSQL TimeZone = %q", tz)
		}

	case mysql.Dialect:
		describeServer = func() {
			var version, mode, autocommit, tz string
			row := db.QueryRow("SELECT @@version, @@sql_mode, @@autocommit, @@time_zone")
			if err := row.Scan(&version, &mode, &autocommit, &tz); err != nil {
				log.Fatal(err)
			}

			log.Printf("MySQL version = %q", version)
			log.Printf("MySQL sql_mode = %q", mode)
			log.Printf("MySQL autocommit = %q", autocommit)
			log.Printf("MySQL time_zone = %q", tz)
		}

	case sqlite3.Dialect:
		// Foreign keys are enabled on every call, not only on the first one.
		if _, err := db.Exec("PRAGMA foreign_keys = ON"); err != nil {
			log.Fatal(err)
		}

		describeServer = func() {
			var version, sourceID string
			row := db.QueryRow("SELECT sqlite_version(), sqlite_source_id()")
			if err := row.Scan(&version, &sourceID); err != nil {
				log.Fatal(err)
			}

			log.Printf("SQLite3 version = %q", version)
			log.Printf("SQLite3 source = %q", sourceID)
		}

	case mssql.Dialect, sqlserver.Dialect: //nolint:staticcheck
		describeServer = func() {
			var (
				version string
				options uint16
			)
			row := db.QueryRow("SELECT @@VERSION, @@OPTIONS")
			if err := row.Scan(&version, &options); err != nil {
				log.Fatal(err)
			}

			// Bit 0x4000 of @@OPTIONS is reported as the XACT_ABORT state.
			xact := "OFF"
			if options&0x4000 != 0 {
				xact = "ON"
			}

			log.Printf("MS SQL VERSION = %s", version)
			log.Printf("MS SQL OPTIONS = %#4x (XACT_ABORT %s)", options, xact)
		}

	default:
		log.Fatalf("reform: no dialect for driver %s", driverName)
	}

	// Log the common header followed by server details, only once.
	inspectOnce.Do(func() {
		log.Printf("driver = %q, source = %q", driverName, dsn)
		log.Printf("time.Now() = %s", startedAt)
		log.Printf("time.Now().UTC() = %s", startedAt.UTC())

		describeServer()
	})

	return reform.NewDB(db, dialect, nil)
}