/
db.go
103 lines (84 loc) · 1.72 KB
/
db.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
package common
import (
"database/sql"
_ "github.com/go-sql-driver/mysql"
)
// DBmysql pairs an open database handle with the data source name it was
// created from. Its methods run SQL statements through Conn.
type DBmysql struct {
// DSN is the data source name that was passed to sql.Open for Conn.
DSN string
// Conn is the underlying database/sql handle used by every method.
Conn *sql.DB
}
// DBConf carries the settings InitMysql needs to open a database handle.
// The struct tags map the fields to the JSON keys "db_type" and "dsn".
type DBConf struct {
// DBType is handed to sql.Open as the driver name.
DBType string `json:"db_type"`
// DSN is handed to sql.Open as the data source name.
DSN string `json:"dsn"`
}
// InitMysql opens a database handle from conf and wraps it in a DBmysql.
//
// conf.DBType is used as the driver name and conf.DSN as the data source
// name for sql.Open. Per the database/sql documentation, sql.Open may only
// validate its arguments without connecting, so a nil error here does not
// guarantee the database is reachable.
//
// On failure it returns a nil *DBmysql and the error from sql.Open, so a
// caller can never receive a wrapper whose Conn is nil.
func InitMysql(conf DBConf) (db *DBmysql, err error) {
	conn, err := sql.Open(conf.DBType, conf.DSN)
	if err != nil {
		return nil, err
	}
	return &DBmysql{DSN: conf.DSN, Conn: conn}, nil
}
// Query prepares sqlStr, executes it with args, and returns every result row
// as a map from column name to that column's value converted to a string.
//
// SQL NULL values are returned as the empty string, so they cannot be told
// apart from genuinely empty values. If several result columns share a name,
// the last one wins in the row map.
//
// rst is nil when the statement yields no rows. If preparing, executing,
// scanning, or iterating fails, Query returns a nil slice and the error;
// partial results are never returned.
func (db *DBmysql) Query(sqlStr string, args ...interface{}) (rst []map[string]string, err error) {
	stmt, err := db.Conn.Prepare(sqlStr)
	if err != nil {
		return nil, err
	}
	defer stmt.Close()

	rows, err := stmt.Query(args...)
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	cols, err := rows.Columns()
	if err != nil {
		return nil, err
	}

	// Every column is scanned into a sql.RawBytes so that no per-column type
	// needs to be known up front. The same buffers are reused for each row,
	// which is why values are copied into strings before the next Scan.
	vals := make([]sql.RawBytes, len(cols))
	scanArgs := make([]interface{}, len(cols))
	for i := range vals {
		scanArgs[i] = &vals[i]
	}

	for rows.Next() {
		if err = rows.Scan(scanArgs...); err != nil {
			return nil, err
		}
		row := make(map[string]string, len(cols))
		for i, v := range vals {
			// A nil RawBytes (SQL NULL) converts to "" just like an empty one.
			row[cols[i]] = string(v)
		}
		rst = append(rst, row)
	}

	// rows.Next also returns false when iteration stops because of an error,
	// so the loop ending is not proof that every row was read.
	if err = rows.Err(); err != nil {
		return nil, err
	}
	return rst, nil
}
// Insert prepares sqlStr, executes it with args, and reports the last insert
// ID taken from the execution result.
//
// If preparing or executing the statement fails, it returns -1 and that
// error. Otherwise it returns whatever the result's LastInsertId reports,
// including any error from that call.
func (db *DBmysql) Insert(sqlStr string, args ...interface{}) (int64, error) {
	prepared, prepErr := db.Conn.Prepare(sqlStr)
	if prepErr != nil {
		return -1, prepErr
	}
	// Release the prepared statement once the result has been read.
	defer prepared.Close()

	outcome, execErr := prepared.Exec(args...)
	if execErr != nil {
		return -1, execErr
	}

	return outcome.LastInsertId()
}
// Update prepares sqlStr, executes it with args, and reports how many rows
// the execution result says were affected.
//
// If preparing or executing the statement fails, it returns -1 and that
// error. Otherwise it returns whatever the result's RowsAffected reports,
// including any error from that call.
func (db *DBmysql) Update(sqlStr string, args ...interface{}) (int64, error) {
	prepared, prepErr := db.Conn.Prepare(sqlStr)
	if prepErr != nil {
		return -1, prepErr
	}
	// Release the prepared statement once the result has been read.
	defer prepared.Close()

	outcome, execErr := prepared.Exec(args...)
	if execErr != nil {
		return -1, execErr
	}

	return outcome.RowsAffected()
}