/
sql.go
89 lines (76 loc) · 2 KB
/
sql.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
package database
import (
"context"
"database/sql"
"fmt"
"net/url"
"strings"
"github.com/golangid/candi/codebase/interfaces"
"github.com/golangid/candi/config/env"
"github.com/golangid/candi/logger"
)
// sqlInstance bundles two *sql.DB handles: one serving read queries and one
// serving write queries.
type sqlInstance struct {
	read  *sql.DB
	write *sql.DB
}

// ReadDB returns the handle meant for read queries.
func (s *sqlInstance) ReadDB() *sql.DB { return s.read }

// WriteDB returns the handle meant for write queries.
func (s *sqlInstance) WriteDB() *sql.DB { return s.write }

// Health pings the read handle and then the write handle, reporting each
// outcome under the keys "sql_read" and "sql_write". A nil entry means that
// ping succeeded.
func (s *sqlInstance) Health() map[string]error {
	readErr := s.read.Ping()
	writeErr := s.write.Ping()
	return map[string]error{
		"sql_read":  readErr,
		"sql_write": writeErr,
	}
}
// Disconnect closes both the read and the write handles.
//
// Both handles are always closed, even if closing the first one fails, so a
// failure on the read side no longer leaks the write pool. If only one close
// fails, that error is returned as-is. If both fail, the read error is wrapped
// (%w) and the write error's text is appended. ctx is currently unused.
func (s *sqlInstance) Disconnect(ctx context.Context) (err error) {
	defer logger.LogWithDefer("sql: disconnect...")()

	readErr := s.read.Close()
	writeErr := s.write.Close()

	switch {
	case readErr != nil && writeErr != nil:
		return fmt.Errorf("sql: close read: %w; close write: %v", readErr, writeErr)
	case readErr != nil:
		return readErr
	default:
		return writeErr
	}
}
// InitSQLDatabase returns an SQL database instance holding separate read and
// write connections, configured from the environment variables
// SQL_DB_READ_DSN and SQL_DB_WRITE_DSN.
//
// The read connection is established first, then the write connection. Each
// is created with ConnectSQLDatabase, which panics if it cannot connect.
func InitSQLDatabase() interfaces.SQLDatabase {
	defer logger.LogWithDefer("Load SQL connection...")()

	readDB := ConnectSQLDatabase(env.BaseEnv().DbSQLReadDSN)
	writeDB := ConnectSQLDatabase(env.BaseEnv().DbSQLWriteDSN)
	return &sqlInstance{read: readDB, write: writeDB}
}
// ParseSQLDSN splits a URL-style source of the form "<driver>://<connection>"
// into the driver name to pass to sql.Open and the DSN that driver expects.
//
//   - mysql, sqlite3: the DSN is the part after "://".
//   - postgres: the DSN is the full URL. The username and password are
//     re-escaped for the URL userinfo component, so characters such as
//     '@', ':', '/', or ' ' inside them survive URL parsing.
//   - sqlserver: handled like postgres, but the driver name becomes "mssql".
//   - any other scheme: the source is returned unchanged as the DSN. This suits
//     drivers that take URL-form DSNs; previously an empty DSN was returned.
//
// It panics if source does not contain "://".
func ParseSQLDSN(source string) (driverName string, dsn string) {
	sqlDriver, conn, ok := strings.Cut(source, "://")
	if !ok {
		panic("SQL DSN: invalid url format")
	}

	driverName = sqlDriver
	switch sqlDriver {
	case "mysql", "sqlite3":
		dsn = conn
	case "sqlserver":
		driverName = "mssql"
		fallthrough
	case "postgres":
		// The password may itself contain '@', so the host part starts after the last one.
		if i := strings.LastIndex(conn, "@"); i > 0 {
			if username, password, ok := strings.Cut(conn[:i], ":"); ok {
				// url.UserPassword escapes for the userinfo component. url.QueryEscape
				// would encode spaces as '+', which URL parsers keep as a literal '+'.
				conn = url.UserPassword(username, password).String() + "@" + conn[i+1:]
			}
		}
		dsn = fmt.Sprintf("%s://%s", sqlDriver, conn)
	default:
		// Unknown scheme: hand the URL to the driver untouched rather than an empty DSN.
		dsn = source
	}
	return driverName, dsn
}
// ConnectSQLDatabase opens a database handle for a URL-style source of the
// form "<driver>://<connection>" and verifies it with Ping. The source is
// split into driver name and DSN by ParseSQLDSN.
//
// It panics with "SQL Connection: ..." if sql.Open fails, and with
// "SQL Ping: ..." if the ping fails. In the latter case the handle is closed
// first, so a caller that recovers does not leak it.
func ConnectSQLDatabase(dsn string) *sql.DB {
	db, err := sql.Open(ParseSQLDSN(dsn))
	if err != nil {
		panic(fmt.Sprintf("SQL Connection: %v", err))
	}
	if err = db.Ping(); err != nil {
		// The ping error is what gets reported; a Close error here adds nothing useful.
		_ = db.Close()
		panic(fmt.Sprintf("SQL Ping: %v", err))
	}
	return db
}