// Package gormutil provides GORM-backed helpers, including a generic
// Repository that implements repository.Repository.
package gormutil
import (
"context"
"gorm.io/gorm"
"gorm.io/gorm/clause"
"github.com/go-chocolate/contrib/database/repository"
)
// Repository is a generic data-access helper for records of type T that
// runs its queries through a *gorm.DB handle.
type Repository[T any] struct {
// db is the handle bound to this repository. When it is nil, GetDB falls
// back to FromContext to obtain a handle.
db *gorm.DB
}
// NewRepository builds a Repository for records of type T bound to db.
// db may be nil; in that case GetDB resolves the handle via FromContext
// on each call.
func NewRepository[T any](db *gorm.DB) *Repository[T] {
	repo := new(Repository[T])
	repo.db = db
	return repo
}
// GetDB returns the handle this repository should query with: the bound
// handle when one is set, otherwise whatever FromContext(ctx) yields.
//
// NOTE(review): when a bound handle exists it is returned as-is; ctx is not
// attached to it here.
func (r *Repository[T]) GetDB(ctx context.Context) *gorm.DB {
	if r.db == nil {
		return FromContext(ctx)
	}
	return r.db
}
// SetDB replaces the handle bound to this repository. Passing nil makes
// GetDB fall back to FromContext again.
func (r *Repository[T]) SetDB(db *gorm.DB) {
r.db = db
}
// where scopes cmd to the model T and applies the optional filter.
//
// The filter is interpreted by its dynamic type:
//   - nil: no condition is added;
//   - clause.Expression: attached through Clauses;
//   - int64: matched against the "id" column;
//   - anything else: handed to gorm's Where unchanged.
func (r *Repository[T]) where(cmd *gorm.DB, where any) *gorm.DB {
	var zero T
	scoped := cmd.Model(zero)

	// The nil guard must come first: a nil interface would otherwise fall
	// through to the generic Where call below.
	if where == nil {
		return scoped
	}
	if expr, ok := where.(clause.Expression); ok {
		return scoped.Clauses(expr)
	}
	if id, ok := where.(int64); ok {
		return scoped.Where("id = ?", id)
	}
	return scoped.Where(where)
}
// list runs a paginated query on an already-scoped cmd.
//
// It first counts the rows matching cmd, then applies each order value via
// Order, followed by Offset and Limit, and loads the page.
//
// It returns the loaded records, the total count of matching rows (before
// pagination) and the first error encountered. On error the record slice is
// nil.
func (r *Repository[T]) list(cmd *gorm.DB, offset, limit int, order ...any) ([]*T, int64, error) {
	var count int64
	if err := cmd.Count(&count).Error; err != nil {
		return nil, count, err
	}

	for _, v := range order {
		cmd = cmd.Order(v)
	}

	var dst []*T
	// Find needs a pointer to the slice so it can append the scanned rows;
	// passing the nil slice by value leaves nothing to populate.
	if err := cmd.Offset(offset).Limit(limit).Find(&dst).Error; err != nil {
		return nil, count, err
	}
	return dst, count, nil
}
// FindOne loads a single record matching where (see the where helper for
// how the filter is interpreted) using gorm's Take.
//
// The handle is resolved through GetDB, so a repository without a bound
// handle uses the one derived from ctx, like every other method.
//
// dst is always a non-nil pointer to T; when err is non-nil its contents
// are not meaningful.
func (r *Repository[T]) FindOne(ctx context.Context, where any) (dst *T, err error) {
	dst = new(T)
	err = r.where(r.GetDB(ctx), where).Take(dst).Error
	return dst, err
}
// List returns one page of records matching where together with the total
// number of matching rows.
//
// where is interpreted by the where helper; offset and limit select the
// page; each order value is passed to gorm's Order in the given sequence.
func (r *Repository[T]) List(ctx context.Context, where any, offset, limit int, order ...any) ([]*T, int64, error) {
	cmd := r.where(r.GetDB(ctx), where)
	// Spread the variadic values: forwarding the slice itself would hand
	// list a single []any element, which Order does not recognise.
	return r.list(cmd, offset, limit, order...)
}
// Count reports how many rows of model T match where. The filter is
// interpreted by the where helper; a nil filter counts every row.
func (r *Repository[T]) Count(ctx context.Context, where any) (int64, error) {
	var total int64
	result := r.where(r.GetDB(ctx), where).Count(&total)
	return total, result.Error
}
// Update applies update to the rows of model T matching where by calling
// gorm's Updates. It returns the number of affected rows and any error
// reported by the statement.
func (r *Repository[T]) Update(ctx context.Context, where any, update any) (int64, error) {
	scoped := r.where(r.GetDB(ctx), where)
	result := scoped.Updates(update)
	return result.RowsAffected, result.Error
}
// Insert creates data through gorm's Create on a handle scoped to model T.
// It returns the number of affected rows and any error reported by the
// statement.
func (r *Repository[T]) Insert(ctx context.Context, data any) (int64, error) {
	// A nil filter only scopes the handle to the model; no condition is added.
	scoped := r.where(r.GetDB(ctx), nil)
	result := scoped.Create(data)
	return result.RowsAffected, result.Error
}
// Delete removes the rows of model T matching where. The target is taken
// from the model set by the where helper, so nil is passed to gorm's Delete.
// It returns the number of affected rows and any error reported by the
// statement.
func (r *Repository[T]) Delete(ctx context.Context, where any) (int64, error) {
	scoped := r.where(r.GetDB(ctx), where)
	result := scoped.Delete(nil)
	return result.RowsAffected, result.Error
}
// Iterate builds a columnIterator over model T configured with the given
// column and filter, using the handle resolved by GetDB. No query is issued
// here and the returned error is always nil.
func (r *Repository[T]) Iterate(ctx context.Context, column string, where any) (repository.Iterator[T], error) {
	it := new(columnIterator[T])
	it.column = column
	it.db = r.GetDB(ctx)
	it.where = where
	return it, nil
}
// Compile-time check that *Repository satisfies repository.Repository.
var _ repository.Repository[any] = (*Repository[any])(nil)