/
m_trans.go
55 lines (46 loc) · 1.02 KB
/
m_trans.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
package model
import (
"context"
"github.com/levin9/go-admin/internal/app/errors"
"github.com/jinzhu/gorm"
)
// NewTrans 创建事务管理实例
func NewTrans(db *gorm.DB) *Trans {
return &Trans{db}
}
// Trans 事务管理
type Trans struct {
db *gorm.DB
}
// Begin 开启事务
func (a *Trans) Begin(ctx context.Context) (interface{}, error) {
result := a.db.Begin()
if err := result.Error; err != nil {
return nil, errors.WithStack(err)
}
return result, nil
}
// Commit 提交事务
func (a *Trans) Commit(ctx context.Context, trans interface{}) error {
db, ok := trans.(*gorm.DB)
if !ok {
return errors.New("unknow trans")
}
result := db.Commit()
if err := result.Error; err != nil {
return errors.WithStack(err)
}
return nil
}
// Rollback 回滚事务
func (a *Trans) Rollback(ctx context.Context, trans interface{}) error {
db, ok := trans.(*gorm.DB)
if !ok {
return errors.New("unknow trans")
}
result := db.Rollback()
if err := result.Error; err != nil {
return errors.WithStack(err)
}
return nil
}