Skip to content

Commit

Permalink
add sqlm
Browse files Browse the repository at this point in the history
  • Loading branch information
MunMunMiao committed Jan 4, 2022
1 parent 81d8af2 commit 629cdf2
Show file tree
Hide file tree
Showing 10 changed files with 460 additions and 32 deletions.
2 changes: 1 addition & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -13,5 +13,5 @@
.com.apple.timemachine.donotpresent
config.toml
dist/
test
test/
temp/
4 changes: 2 additions & 2 deletions database/orm/orm.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
package orm

import (
_ "github.com/go-sql-driver/mysql"
_ "github.com/go-sqlm-driver/mysql"
"gorm.io/driver/mysql"
"gorm.io/gorm"
"gorm.io/gorm/schema"
Expand Down Expand Up @@ -33,7 +33,7 @@ func New(c *Config) *gorm.DB {

db.Set("gorm:table_options", "ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci")

// 获取通用数据库对象 sql.DB ,然后使用其提供的功能
// 获取通用数据库对象 sqlm.DB ,然后使用其提供的功能
sqlDB, err := db.DB()
if err != nil {
panic(err)
Expand Down
29 changes: 0 additions & 29 deletions database/sql/sql.go

This file was deleted.

108 changes: 108 additions & 0 deletions database/sqlm/db.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
package sqlm

import (
"context"
"database/sql"
"time"
)

type DB struct {
db *sql.DB
}

type DBStats struct {
stats sql.DBStats
}

func (d *DB) Begin(ctx context.Context, opts *sql.TxOptions) (*Tx, error) {
t, err := d.db.BeginTx(ctx, opts)
if err != nil {
return nil, err
}

return &Tx{t}, nil
}

func (d *DB) Transaction(ctx context.Context, f func(tx *Tx) (interface{}, error)) (interface{}, error) {
tx, err := d.Begin(ctx, nil)
if err != nil {
return nil, err
}

r, err := f(tx)
if err != nil {
if err = tx.Rollback(); err != nil {
return nil, err
}
return nil, err
}

err = tx.Commit()
if err != nil {
if err = tx.Rollback(); err != nil {
return nil, err
}
return nil, err
}

return r, nil
}

func (d *DB) Insert(ctx context.Context, table string, field []string, values [][]interface{}) (Result, error) {
query, args := genInsertParam(table, field, values)
return d.Exec(ctx, query, args...)
}

func (d *DB) Close() error {
return d.db.Close()
}

func (d *DB) Exec(ctx context.Context, query string, args ...interface{}) (Result, error) {
return d.db.ExecContext(ctx, query, args...)
}

func (d *DB) Ping(ctx context.Context) error {
return d.db.PingContext(ctx)
}

func (d *DB) Prepare(ctx context.Context, query string) (*Stmt, error) {
s, err := d.db.PrepareContext(ctx, query)
if err != nil {
return nil, err
}
return &Stmt{stmt: s}, nil
}

func (d *DB) Query(ctx context.Context, query string, args ...interface{}) (*Rows, error) {
rows, err := d.db.QueryContext(ctx, query, args...)
if err != nil {
return nil, err
}
return &Rows{rows: rows}, nil
}

func (d *DB) QueryRow(ctx context.Context, query string, args ...interface{}) *Row {
row := d.db.QueryRowContext(ctx, query, args...)
return &Row{row: row}
}

func (d *DB) SetConnMaxIdleTime(t time.Duration) {
d.db.SetConnMaxIdleTime(t)
}

func (d *DB) SetConnMaxLifetime(t time.Duration) {
d.db.SetConnMaxLifetime(t)
}

func (d *DB) SetMaxIdleConns(n int) {
d.db.SetMaxIdleConns(n)
}

func (d *DB) SetMaxOpenConns(n int) {
d.SetMaxOpenConns(n)
}

func (d *DB) Stats() DBStats {
s := d.db.Stats()
return DBStats{stats: s}
}
21 changes: 21 additions & 0 deletions database/sqlm/row.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
package sqlm

import (
"database/sql"
"errors"
)

type Row struct {
row *sql.Row
}

func (r *Row) Err() error {
err := r.row.Err()
if errors.Is(err, sql.ErrNoRows){
return ErrNoRows
}
return err
}
func (r *Row) Scan(dest ...interface{}) error {
return r.row.Scan(dest...)
}
41 changes: 41 additions & 0 deletions database/sqlm/rows.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
package sqlm

import "database/sql"

type Rows struct {
rows *sql.Rows
}

type ColumnType struct {
columnType *sql.ColumnType
}

func (r *Rows) Close() error {
return r.rows.Close()
}
func (r *Rows) ColumnTypes() ([]*ColumnType, error) {
c, err := r.rows.ColumnTypes()
if err != nil {
return nil, err
}
ct := make([]*ColumnType, 0, len(c))
for _, i := range c {
ct = append(ct, &ColumnType{columnType: i})
}
return ct, nil
}
func (r *Rows) Columns() ([]string, error) {
return r.rows.Columns()
}
func (r *Rows) Err() error {
return r.rows.Err()
}
func (r *Rows) Next() bool {
return r.rows.Next()
}
func (r *Rows) NextResultSet() bool {
return r.rows.NextResultSet()
}
func (r *Rows) Scan(dest ...interface{}) error {
return r.rows.Scan(dest...)
}
119 changes: 119 additions & 0 deletions database/sqlm/sqlm.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,119 @@
package sqlm

import (
"database/sql"
"errors"
"fmt"
_ "github.com/go-sql-driver/mysql"
"time"
)

var (
ErrNoRows = sql.ErrNoRows
ErrTxDone = sql.ErrTxDone
)

// A Result summarizes an executed SQL command.
type Result interface {
// LastInsertId returns the integer generated by the database
// in response to a command. Typically this will be from an
// "auto increment" column when inserting a new row. Not all
// databases support this feature, and the syntax of such
// statements varies.
LastInsertId() (int64, error)

// RowsAffected returns the number of rows affected by an
// update, insert, or delete. Not every database or database
// driver may support this.
RowsAffected() (int64, error)
}

type Config struct {
DSN string
MaxOpenConn int
MaxIdleConn int
ConnMaxLifetime time.Duration
}

func mergeConfig(c *Config) {
if c == nil {
panic(errors.New("Config is nil\n"))
}

if c.MaxOpenConn == 0 {
c.MaxOpenConn = 200
}
if c.MaxIdleConn == 0 {
c.MaxIdleConn = 10
}
if c.ConnMaxLifetime == 0 {
c.ConnMaxLifetime = time.Minute * 5
}
}

func New(c *Config) *DB {
mergeConfig(c)

db, err := sql.Open("mysql", c.DSN)
if err != nil {
panic(err)
}

db.SetMaxOpenConns(c.MaxOpenConn)
db.SetMaxIdleConns(c.MaxIdleConn)
db.SetConnMaxLifetime(c.ConnMaxLifetime)

if err = db.Ping(); err != nil {
panic(err)
}

return &DB{db}
}

func genInsertParam(table string, fields []string, values [][]interface{}) (string, []interface{}) {
s := fmt.Sprintf("INSERT INTO `%s` ", table)

fieldStr := ""
for i, field := range fields {
if i == 0 {
fieldStr += fmt.Sprintf("`%s`", field)
} else {
fieldStr += fmt.Sprintf(",`%s`", field)
}
}
s += fmt.Sprintf("(%s) ", fieldStr)

value := ""
if len(values) > 0 {
if len(values) > 1 {
value += "VALUES "
} else {
value += "VALUE "
}
}
s += value

placeholder := ""
vs := make([]interface{}, 0)
for i, v := range values {
p := ""
for a, m := range v {
if a == 0 {
p += "?"
} else {
p += ",?"
}
vs = append(vs, m)
}

if i == 0 {
p = fmt.Sprintf("(%s)", p)
} else {
p = fmt.Sprintf(",(%s)", p)
}
placeholder += p
}
s += placeholder

return s, vs
}
40 changes: 40 additions & 0 deletions database/sqlm/sqlm_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
package sqlm

import (
"context"
"fmt"
"testing"
)

func TestGenInsertParam(t *testing.T) {
value := [][]interface{}{
{1, 100, 200},
{2, 200, 400},
{3, 300, 600},
}
_, _ = genInsertParam("file", []string{"file_id", "width", "height"}, value)
}

func TestInsert(t *testing.T) {
db := New(&Config{DSN: "root@(127.0.0.1:3306)/t?charset=utf8mb4&collation=utf8_unicode_ci"})

r, err := db.Insert(context.Background(), "values", []string{"name", "age"}, [][]interface{}{
{"小红", 16},
//{"小丽", 13},
})
if err != nil {
t.Fatal(err)
}

affected, err := r.RowsAffected()
if err != nil {
t.Fatal(err)
}
fmt.Printf("affected: %v\n", affected)

id, err := r.LastInsertId()
if err != nil {
t.Fatal(err)
}
fmt.Printf("lastInsertId: %+v\n", id)
}

0 comments on commit 629cdf2

Please sign in to comment.