/
iteration.go
137 lines (123 loc) · 2.99 KB
/
iteration.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
package gormutil
import (
"errors"
"fmt"
"reflect"
"strings"
"gorm.io/gorm"
"gorm.io/gorm/clause"
)
// ErrIterationEoF is the sentinel error an iterator's Next method returns
// once its query produces no further rows.
var (
	ErrIterationEoF = errors.New("iteration eof")
)
// limitIterator walks the rows matched by query page by page using
// OFFSET/LIMIT; Next advances offset by limit as it goes.
type limitIterator[T any] struct {
	offset int      // rows skipped before the next page
	limit  int      // page size passed to Limit
	query  *gorm.DB // base query every page is derived from
}
// Count returns the number of rows matched by the iterator's base query,
// together with any error gorm reported while counting.
//
// NOTE(review): i.query is also used by Next; if it is not a fresh gorm
// Session, Offset/Limit clauses chained onto it elsewhere would shape this
// count too — confirm how the iterator is constructed.
func (i *limitIterator[T]) Count() (int64, error) {
	var total int64
	res := i.query.Count(&total)
	return total, res.Error
}
// Next fetches the page of at most i.limit rows that starts i.offset rows
// into the query, then advances i.offset by i.limit for the following call.
//
// It returns ErrIterationEoF when the page comes back empty. On a query error
// it returns (nil, err) and leaves i.offset unchanged, so a caller that
// retries re-reads the same page instead of silently skipping it.
//
// NOTE(review): with i.limit <= 0 the offset never moves forward; depending
// on the gorm version Limit(0) may mean "no limit", in which case every call
// returns the same rows — confirm the constructor rejects such values.
func (i *limitIterator[T]) Next() ([]*T, error) {
	// Chain Offset/Limit onto a new session so they land on a copy of the
	// statement rather than being written into i.query, which Count and
	// later Next calls reuse (see GORM "Method Chain Safety").
	page := i.query.Session(&gorm.Session{}).Offset(i.offset).Limit(i.limit)

	var results []*T
	if err := page.Find(&results).Error; err != nil {
		return nil, err
	}
	if len(results) == 0 {
		return nil, ErrIterationEoF
	}
	i.offset += i.limit
	return results, nil
}
// columnIterator pages through rows by key: each Next call selects rows whose
// column value is greater than the last one returned, ordered by that column.
type columnIterator[T any] struct {
	// lastID is the column value of the final row of the previous page.
	// While it is nil, Next applies no key filter.
	lastID any
	// lastIDColumn caches the name of the struct field that backs column,
	// once extractLastID has resolved it.
	lastIDColumn string
	// column is the database column used for ordering and filtering.
	column string
	// db is the base query that every page and count is built from.
	db *gorm.DB
	// where is an optional extra condition: a clause.Expression is added via
	// Clauses, any other value is passed to Where.
	where any
}
// Count returns how many rows match i.db combined with the optional i.where
// condition. The key position (lastID) is not applied, so the figure covers
// the whole iteration rather than only the rows still pending.
//
// NOTE(review): if i.db is not a fresh gorm Session, the Clauses/Where call
// below may be written into its shared statement and reappear in later
// queries built from i.db — confirm how the iterator is constructed.
func (i *columnIterator[T]) Count() (int64, error) {
	var total int64
	tx := i.db
	switch cond := i.where.(type) {
	case nil:
		// no extra condition to apply
	case clause.Expression:
		tx = tx.Clauses(cond)
	default:
		tx = tx.Where(cond)
	}
	err := tx.Count(&total).Error
	return total, err
}
// Next returns the next batch of rows, ordered by i.column ascending, whose
// i.column value is greater than the key recorded from the previous batch
// (i.lastID; no key filter while it is nil).
//
// When i.where is set, it is combined with that key filter:
//   - a clause.Expression is AND-ed with it through clause.And;
//   - any other value is passed to Where.
//
// It returns ErrIterationEoF once the query yields no rows. Otherwise it
// records the key of the last row for the following call; an error from that
// step is returned together with the rows. On a query error it returns
// (nil, err).
//
// NOTE(review): no Limit is applied here, so the batch size presumably comes
// from i.db — confirm where the iterator is constructed.
func (i *columnIterator[T]) Next() ([]*T, error) {
	// Route the identifier through clause.Column so the active dialector
	// quotes it (backticks on MySQL, double quotes on PostgreSQL,
	// "table.column" quoted per part) instead of hard-coding MySQL backticks.
	col := clause.Column{Name: i.column}

	// Build on a new session so the conditions and ORDER BY below go onto a
	// copy of the statement instead of accumulating on i.db across calls
	// (see GORM "Method Chain Safety").
	query := i.db.Session(&gorm.Session{})
	switch cond := i.where.(type) {
	case nil:
		if i.lastID != nil {
			query = query.Where(clause.Gt{Column: col, Value: i.lastID})
		}
	case clause.Expression:
		if i.lastID != nil {
			query = query.Clauses(clause.And(cond, clause.Gt{Column: col, Value: i.lastID}))
		} else {
			query = query.Clauses(cond)
		}
	default:
		query = query.Where(cond)
		if i.lastID != nil {
			query = query.Where(clause.Gt{Column: col, Value: i.lastID})
		}
	}

	var result []*T
	if err := query.Order(clause.OrderByColumn{Column: col}).Find(&result).Error; err != nil {
		return nil, err
	}
	if len(result) == 0 {
		return nil, ErrIterationEoF
	}
	return result, i.extractLastID(result[len(result)-1])
}
// extractLastID stores in i.lastID the value of the struct field that backs
// i.column, read from item (the last row of the page Next just fetched), so
// the following Next call can resume after it.
//
// item may be a *T (as Next passes it) or a T; either way the underlying
// type must be a struct. The matching field is resolved once and its name
// cached in i.lastIDColumn:
//   - a field whose `gorm` tag declares `column:<name>` matches only when
//     <name> equals i.column exactly;
//   - any other field matches when its Go name equals i.column with the
//     underscores removed, ignoring case, so "user_id" finds UserID or
//     UserId and "id" finds ID.
//
// Exported fields promoted from embedded structs (for example an embedded
// gorm.Model) are searched too. Fields tagged `gorm:"-"` and unexported
// fields are skipped.
//
// It returns an error when item is nil or not a struct, when no field
// matches, or when the field cannot be read.
func (i *columnIterator[T]) extractLastID(item any) error {
	// Next passes a *T; NumField/FieldByName panic on pointer values, so
	// work on the struct it points to.
	val := reflect.Indirect(reflect.ValueOf(item))
	if !val.IsValid() || val.Kind() != reflect.Struct {
		return fmt.Errorf("cannot read column %s from %T: not a struct", i.column, item)
	}
	typ := val.Type()

	if i.lastIDColumn == "" {
		// tagColumn returns the value of a `column:` setting in a gorm tag,
		// or "" when the tag does not declare one.
		tagColumn := func(tag string) string {
			for _, setting := range strings.Split(tag, ";") {
				key, value, ok := strings.Cut(setting, ":")
				if ok && strings.EqualFold(strings.TrimSpace(key), "column") {
					return strings.TrimSpace(value)
				}
			}
			return ""
		}
		wantName := strings.ReplaceAll(i.column, "_", "")

		for _, f := range reflect.VisibleFields(typ) {
			// Skip embedded structs themselves (their promoted fields are
			// visited separately) and fields reflection cannot read.
			if f.Anonymous || !f.IsExported() {
				continue
			}
			tag := f.Tag.Get("gorm")
			if strings.TrimSpace(tag) == "-" {
				continue
			}
			var match bool
			if explicit := tagColumn(tag); explicit != "" {
				// An explicit column tag is authoritative for this field.
				match = explicit == i.column
			} else {
				match = strings.EqualFold(f.Name, wantName)
			}
			if match {
				i.lastIDColumn = f.Name
				break
			}
		}
	}
	if i.lastIDColumn == "" {
		return fmt.Errorf("struct %s does not contain field %s", reflect.TypeOf(item).String(), i.column)
	}

	sf, ok := typ.FieldByName(i.lastIDColumn)
	if !ok {
		return fmt.Errorf("struct %s does not contain field %s", reflect.TypeOf(item).String(), i.lastIDColumn)
	}
	// FieldByIndexErr reports (rather than panics on) a nil embedded pointer
	// along the path to a promoted field.
	field, err := val.FieldByIndexErr(sf.Index)
	if err != nil {
		return fmt.Errorf("reading field %s of %v: %w", sf.Name, typ, err)
	}
	if !field.CanInterface() {
		return fmt.Errorf("field %s of %v is not exported", sf.Name, typ)
	}
	i.lastID = field.Interface()
	return nil
}