/
crud.go
124 lines (113 loc) · 2.69 KB
/
crud.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
package crud
import (
"errors"
"fmt"
"regexp"
"github.com/go-mixins/gorm/v4"
"github.com/oleiade/reflections"
g "gorm.io/gorm"
)
// Basic provides typed CRUD helpers (Create, Update, Get, Delete, Find) for
// records of type A on top of a gorm.Backend. Its underlying type is
// gorm.Backend, so a *gorm.Backend and a *Basic[A] can be converted into each
// other directly; the methods use the embedded backend's DB field as the
// query handle.
type Basic[A any] gorm.Backend
// Begin calls Begin on the underlying gorm.Backend and re-wraps the backend
// it returns as a *Basic[A], so the typed helpers keep working on the result.
// NOTE(review): presumably this opens a transaction that must be closed with
// End — confirm against gorm.Backend.Begin.
func (b *Basic[A]) Begin() *Basic[A] {
	return (*Basic[A])((*gorm.Backend)(b).Begin())
}
// End forwards rErr to the underlying gorm.Backend's End method and returns
// whatever error that method produces.
// NOTE(review): presumably End commits or rolls back the transaction started
// by Begin depending on rErr — confirm against gorm.Backend.End.
func (b *Basic[A]) End(rErr error) error {
	return (*gorm.Backend)(b).End(rErr)
}
// Create inserts src using the backend's DB handle, after applying each opt
// to the query in order.
//
// It returns ErrFound when gorm reports g.ErrDuplicatedKey. Any other
// database error is wrapped with %w, so callers can still inspect the cause
// with errors.Is / errors.As.
//
// NOTE: gorm only translates driver-specific unique-constraint errors into
// g.ErrDuplicatedKey when the gorm Config has TranslateError enabled.
func (b *Basic[A]) Create(src *A, opts ...func(*g.DB) *g.DB) error {
	q := b.DB
	for _, opt := range opts {
		q = opt(q)
	}
	err := q.Create(src).Error
	switch {
	case err == nil:
		return nil
	case errors.Is(err, g.ErrDuplicatedKey):
		return ErrFound
	default:
		return fmt.Errorf("creating %T: %w", src, err)
	}
}
// Update applies gorm's struct Updates with upd as both the model and the
// values, after applying each opt to the query in order.
//
// Returns:
//   - ErrFound when gorm reports g.ErrDuplicatedKey;
//   - ErrNotFound when gorm reports g.ErrRecordNotFound;
//   - a %w-wrapped error for any other database failure;
//   - ErrUpdateNotApplied when the statement succeeded but affected no rows.
func (b *Basic[A]) Update(upd A, opts ...func(*g.DB) *g.DB) error {
	q := b.DB.Model(upd)
	for _, opt := range opts {
		q = opt(q)
	}
	// Inspect the *g.DB returned by Updates, not q. If an opt starts a new
	// session (e.g. db.WithContext(ctx) or db.Session(...)), gorm executes
	// Updates on a fresh instance. q.RowsAffected would then stay 0, and every
	// successful update would be reported as ErrUpdateNotApplied.
	res := q.Updates(upd)
	switch err := res.Error; {
	case errors.Is(err, g.ErrDuplicatedKey):
		return ErrFound
	case errors.Is(err, g.ErrRecordNotFound):
		return ErrNotFound
	case err != nil:
		return fmt.Errorf("updating %T: %w", upd, err)
	case res.RowsAffected == 0:
		return ErrUpdateNotApplied
	}
	return nil
}
// Get loads the first record of type A matching conds (passed straight
// through to gorm's First) and returns a pointer to it.
//
// It returns ErrNotFound when gorm reports g.ErrRecordNotFound. Any other
// database error is wrapped with %w so the cause stays inspectable.
func (b *Basic[A]) Get(conds ...interface{}) (*A, error) {
	var dest A
	err := b.DB.Model(dest).First(&dest, conds...).Error
	switch {
	case errors.Is(err, g.ErrRecordNotFound):
		return nil, ErrNotFound
	case err != nil:
		return nil, fmt.Errorf("reading %T: %w", dest, err)
	}
	return &dest, nil
}
// Delete removes records of type A matching conds. conds are passed straight
// through to gorm's Delete, with a pointer to a zero-valued A as the model.
//
// It returns ErrNotFound if gorm reports g.ErrRecordNotFound. Any other
// database error is wrapped with %w so the cause stays inspectable.
//
// NOTE(review): gorm's Delete does not normally report ErrRecordNotFound when
// nothing matches; it leaves RowsAffected at 0 instead. Deleting a missing
// row therefore most likely returns nil here. Confirm whether callers rely on
// that before changing it.
func (b *Basic[A]) Delete(conds ...interface{}) error {
	var dest A
	err := b.DB.Delete(&dest, conds...).Error
	switch {
	case errors.Is(err, g.ErrRecordNotFound):
		return ErrNotFound
	case err != nil:
		return fmt.Errorf("deleting %T: %w", dest, err)
	}
	return nil
}
// splitRe splits a `paginate` struct-tag value into its options. It matches
// one ',' or ';' separator together with any whitespace on either side, so
// "key, reverse;isTime" yields ["key", "reverse", "isTime"].
var splitRe = regexp.MustCompile(`\s*[;,]\s*`)
// Find returns one page of records of type A, selected according to pgn,
// together with the pagination state produced by gorm.Paginator.Paginate.
//
// The pagination key comes from A's `paginate` struct tags (see
// paginatorFromTags). Each opt is applied to the query in order. Errors from
// tag inspection or from the query are returned with context, wrapped with %w.
func (b *Basic[A]) Find(pgn gorm.Pagination, opts ...func(*g.DB) *g.DB) ([]*A, *gorm.Pagination, error) {
	p, err := paginatorFromTags[A]()
	if err != nil {
		return nil, nil, err
	}
	q := b.DB.Scopes(p.Scope(&pgn))
	for _, opt := range opts {
		q = opt(q)
	}
	var res []*A
	if err := q.Find(&res).Error; err != nil {
		var elt A
		return nil, nil, fmt.Errorf("finding %T: %w", elt, err)
	}
	results, resPgn := p.Paginate(res, &pgn)
	return results, resPgn, nil
}

// paginatorFromTags builds a gorm.Paginator for *A from the `paginate` tags
// on A's fields (including embedded fields, via reflections.FieldsDeep).
//
// Tag syntax (options separated by ',' or ';', see splitRe):
//   - `paginate:"key[,reverse][,isTime]"` sets FieldName to that field and,
//     if present, Reverse / IsTime on the paginator;
//   - `paginate:"tieBreak"` sets TieBreakField to that field.
//
// Exactly which field wins when several carry the same role is "last one
// seen". A missing key field is an error.
func paginatorFromTags[A any]() (*gorm.Paginator[*A], error) {
	var elt A
	p := &gorm.Paginator[*A]{}
	fields, err := reflections.FieldsDeep(&elt)
	if err != nil {
		return nil, fmt.Errorf("listing fields of %T: %w", elt, err)
	}
	for _, f := range fields {
		tag, err := reflections.GetFieldTag(&elt, f, `paginate`)
		if err != nil {
			return nil, fmt.Errorf("reading paginate tag of %T.%s: %w", elt, f, err)
		}
		if tag == "" {
			continue
		}
		options := splitRe.Split(tag, -1)
		switch options[0] {
		case "key":
			p.FieldName = f
			// options[0] is "key" itself; the modifiers follow it.
			for _, o := range options[1:] {
				switch o {
				case "reverse":
					p.Reverse = true
				case "isTime":
					p.IsTime = true
				}
			}
		case "tieBreak":
			p.TieBreakField = f
		}
	}
	if p.FieldName == "" {
		return nil, fmt.Errorf("key field for %T must be tagged", elt)
	}
	return p, nil
}