-
Notifications
You must be signed in to change notification settings - Fork 3.5k
/
iterator.go
308 lines (269 loc) · 9.69 KB
/
iterator.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
package orm
import (
"fmt"
"reflect"
"github.com/cosmos/cosmos-sdk/codec"
sdkerrors "github.com/cosmos/cosmos-sdk/types/errors"
"github.com/cosmos/cosmos-sdk/types/query"
"github.com/cosmos/cosmos-sdk/x/group/errors"
)
// defaultPageLimit is the page size Paginate falls back to when the request
// does not specify a limit (PageRequest.Limit == 0).
const defaultPageLimit = 100
// IteratorFunc adapts a plain function to the Iterator interface: every
// LoadNext call is forwarded to the wrapped function, and Close is a no-op.
type IteratorFunc func(dest codec.ProtoMarshaler) (RowID, error)

// LoadNext calls the wrapped function with dest and returns whatever it returns.
// By convention the function decodes the next value into dest and returns its
// RowID (not any MultiKeyIndex key), or errors.ErrORMIteratorDone once the
// sequence is exhausted.
func (fn IteratorFunc) LoadNext(dest codec.ProtoMarshaler) (RowID, error) {
	return fn(dest)
}

// Close does nothing for a function-backed iterator and always returns nil.
func (fn IteratorFunc) Close() error {
	return nil
}
// NewSingleValueIterator returns an Iterator that yields at most one element:
// rowID together with val decoded into the caller's destination.
//
// A nil val produces an empty iterator. A nil destination is rejected with
// errors.ErrORMInvalidArgument without consuming the element. Once the element
// has been handed out, every further LoadNext returns errors.ErrORMIteratorDone.
// If decoding fails, the decode error is returned with a nil RowID.
func NewSingleValueIterator(rowID RowID, val []byte) Iterator {
	var consumed bool
	return IteratorFunc(func(dest codec.ProtoMarshaler) (RowID, error) {
		if dest == nil {
			return nil, sdkerrors.Wrap(errors.ErrORMInvalidArgument, "destination object must not be nil")
		}
		if consumed || val == nil {
			return nil, errors.ErrORMIteratorDone
		}
		// The element is marked consumed before decoding, so a decode
		// failure also ends the iteration.
		consumed = true
		if err := dest.Unmarshal(val); err != nil {
			// Do not leak a RowID alongside an error.
			return nil, err
		}
		return rowID, nil
	})
}
// NewInvalidIterator returns an Iterator whose LoadNext always fails with
// errors.ErrORMInvalidIterator, regardless of the destination passed in.
func NewInvalidIterator() Iterator {
	alwaysInvalid := func(_ codec.ProtoMarshaler) (RowID, error) {
		return nil, errors.ErrORMInvalidIterator
	}
	return IteratorFunc(alwaysInvalid)
}
// LimitedIterator wraps a parent Iterator and yields at most a fixed number
// of its elements.
type LimitedIterator struct {
	// remainingCount is how many more elements may be returned.
	remainingCount int
	// parentIterator is the wrapped source of elements.
	parentIterator Iterator
}

// LimitIterator returns a new iterator that yields at most max elements of parent.
// The parent iterator must not be nil and max must not be negative. max may be 0,
// in which case the returned iterator is immediately exhausted.
func LimitIterator(parent Iterator, max int) (*LimitedIterator, error) {
	if max < 0 {
		return nil, errors.ErrORMInvalidArgument.Wrap("quantity must not be negative")
	}
	if parent == nil {
		return nil, errors.ErrORMInvalidArgument.Wrap("parent iterator must not be nil")
	}
	return &LimitedIterator{remainingCount: max, parentIterator: parent}, nil
}

// LoadNext loads the next value of the parent iterator into dest and returns its
// key. Once the configured maximum number of elements has been requested,
// errors.ErrORMIteratorDone is returned without consulting the parent.
// The key is the rowID and not any MultiKeyIndex key.
func (i *LimitedIterator) LoadNext(dest codec.ProtoMarshaler) (RowID, error) {
	if i.remainingCount == 0 {
		return nil, errors.ErrORMIteratorDone
	}
	i.remainingCount--
	return i.parentIterator.LoadNext(dest)
}

// Close releases the parent iterator and should be called at the end of iteration.
// It uses a pointer receiver to match LoadNext. A zero-value LimitedIterator,
// which has no parent, closes as a no-op instead of panicking.
func (i *LimitedIterator) Close() error {
	if i.parentIterator == nil {
		return nil
	}
	return i.parentIterator.Close()
}
// First decodes the first element of it into dest and returns its RowID.
// A non-nil iterator is always closed before First returns. If it is nil, or
// LoadNext fails (for example because the iterator has no elements), the
// corresponding error is returned together with a nil RowID.
func First(it Iterator, dest codec.ProtoMarshaler) (RowID, error) {
	if it == nil {
		return nil, sdkerrors.Wrap(errors.ErrORMInvalidArgument, "iterator must not be nil")
	}
	defer it.Close()

	rowID, loadErr := it.LoadNext(dest)
	if loadErr != nil {
		return nil, loadErr
	}
	return rowID, nil
}
// Paginate does pagination with a given Iterator based on the provided
// PageRequest and unmarshals the results into dest, which must be a non-nil
// pointer to a slice.
//
// If pageRequest is nil, then these default values are used:
//   - Offset: 0
//   - Key: nil
//   - Limit: 100
//   - CountTotal: true
//
// If pageRequest.Key is provided, it must already have been used to
// instantiate the Iterator, for instance via the UInt64Index.GetPaginated
// method. Only one of pageRequest.Offset or pageRequest.Key may be set. Using
// pageRequest.Key is more efficient for querying the next page.
//
// If pageRequest.CountTotal is set, all of the iterator's elements are visited
// so the total can be reported. pageRequest.CountTotal is only respected when
// offset is used. A zero (unset) Limit also turns CountTotal on.
//
// This function always calls it.Close() when it is non-nil, including when the
// request is rejected.
func Paginate(
	it Iterator,
	pageRequest *query.PageRequest,
	dest ModelSlicePtr,
) (*query.PageResponse, error) {
	if it == nil {
		return nil, sdkerrors.Wrap(errors.ErrORMInvalidArgument, "iterator must not be nil")
	}
	// Register Close before any validation so every return path, including a
	// rejected request, releases the iterator as documented.
	defer it.Close()

	// if the PageRequest is nil, use default PageRequest
	if pageRequest == nil {
		pageRequest = &query.PageRequest{}
	}

	offset := pageRequest.Offset
	key := pageRequest.Key
	limit := pageRequest.Limit
	countTotal := pageRequest.CountTotal

	// An empty key is treated as "no key" everywhere below, so only a
	// non-empty key conflicts with an offset.
	if offset > 0 && len(key) != 0 {
		return nil, fmt.Errorf("invalid request, either offset or key is expected, got both")
	}

	if limit == 0 {
		limit = defaultPageLimit

		// count total results when the limit is zero/not supplied
		countTotal = true
	}

	var destRef, tmpSlice reflect.Value
	elemType, err := assertDest(dest, &destRef, &tmpSlice)
	if err != nil {
		return nil, err
	}

	end := offset + limit
	var count uint64
	var nextKey []byte
	for {
		obj := reflect.New(elemType)
		val := obj.Elem()
		model := obj
		if elemType.Kind() == reflect.Ptr {
			val.Set(reflect.New(elemType.Elem()))
			// if elemType is already a pointer (e.g. dest being some pointer to a slice of pointers,
			// like []*GroupMember), then obj is a pointer to a pointer which might cause issues
			// if we try to do obj.Interface().(codec.ProtoMarshaler).
			// For that reason, we copy obj into model if we have a simple pointer
			// but in case elemType.Kind() == reflect.Ptr, we overwrite it with model = val
			// so we can safely call model.Interface().(codec.ProtoMarshaler) afterwards.
			model = val
		}

		modelProto, ok := model.Interface().(codec.ProtoMarshaler)
		if !ok {
			return nil, sdkerrors.Wrapf(errors.ErrORMInvalidArgument, "%s should implement codec.ProtoMarshaler", elemType)
		}
		binKey, err := it.LoadNext(modelProto)
		if err != nil {
			if errors.ErrORMIteratorDone.Is(err) {
				break
			}
			return nil, err
		}

		count++

		// During the first loop, count value at this point will be 1,
		// so if offset is >= 1, it will continue to load the next value until count > offset
		// else (offset = 0, key might be set or not),
		// it will start to append values to tmpSlice.
		if count <= offset {
			continue
		}

		if count <= end {
			tmpSlice = reflect.Append(tmpSlice, val)
		} else if count == end+1 {
			nextKey = binKey

			// countTotal is set to true to indicate that the result set should include
			// a count of the total number of items available for pagination in UIs.
			// countTotal is only respected when offset is used. It is ignored when key
			// is set.
			if !countTotal || len(key) != 0 {
				break
			}
		}
	}
	destRef.Set(tmpSlice)

	res := &query.PageResponse{NextKey: nextKey}
	if countTotal && len(key) == 0 {
		res.Total = count
	}

	return res, nil
}
// ModelSlicePtr represents a pointer to a slice of models. Think of it as
// *[]Model. Because of Go's type system, using a []Model type would not work
// for us. Instead we use a placeholder type, and the validation is done at
// runtime (see assertDest).
type ModelSlicePtr interface{}
// ReadAll consumes all values of the iterator and stores them in a new slice
// at the passed ModelSlicePtr. It returns the RowIDs in the order they were read.
// The stored slice is empty but not nil when the iterator yields no values.
// The iterator is closed afterwards.
//
// Example:
//
//	var loaded []testdata.GroupInfo
//	rowIDs, err := ReadAll(it, &loaded)
//	require.NoError(t, err)
func ReadAll(it Iterator, dest ModelSlicePtr) ([]RowID, error) {
	if it == nil {
		return nil, sdkerrors.Wrap(errors.ErrORMInvalidArgument, "iterator must not be nil")
	}
	defer it.Close()

	var destRef, tmpSlice reflect.Value
	elemType, err := assertDest(dest, &destRef, &tmpSlice)
	if err != nil {
		return nil, err
	}

	var rowIDs []RowID
	for {
		obj := reflect.New(elemType)
		val := obj.Elem()
		model := obj
		if elemType.Kind() == reflect.Ptr {
			// For a slice of pointers, allocate the pointee and decode into the
			// *T held in val rather than into the **T held by obj.
			val.Set(reflect.New(elemType.Elem()))
			model = val
		}
		// Use a checked assertion (matching Paginate) instead of panicking when
		// the element type cannot be instantiated as a codec.ProtoMarshaler,
		// e.g. an interface element type.
		modelProto, ok := model.Interface().(codec.ProtoMarshaler)
		if !ok {
			return nil, sdkerrors.Wrapf(errors.ErrORMInvalidArgument, "%s should implement codec.ProtoMarshaler", elemType)
		}
		binKey, err := it.LoadNext(modelProto)
		switch {
		case err == nil:
			tmpSlice = reflect.Append(tmpSlice, val)
		case errors.ErrORMIteratorDone.Is(err):
			destRef.Set(tmpSlice)
			return rowIDs, nil
		default:
			return nil, err
		}
		rowIDs = append(rowIDs, binKey)
	}
}
// assertDest checks that dest is a non-nil pointer to a slice whose elements
// can be decoded into. The element type, or a pointer to it, must implement
// codec.ProtoMarshaler, and it must not be an interface type.
//
// On success it sets *destRef to the settable slice that dest points to. It
// sets *tmpSlice to a new empty (non-nil) slice of the same element type, and
// returns that element type.
func assertDest(dest ModelSlicePtr, destRef *reflect.Value, tmpSlice *reflect.Value) (reflect.Type, error) {
	if dest == nil {
		return nil, sdkerrors.Wrap(errors.ErrORMInvalidArgument, "destination must not be nil")
	}

	tp := reflect.ValueOf(dest)
	if tp.Kind() != reflect.Ptr {
		return nil, sdkerrors.Wrap(errors.ErrORMInvalidArgument, "destination must be a pointer to a slice")
	}
	if tp.Elem().Kind() != reflect.Slice {
		return nil, sdkerrors.Wrap(errors.ErrORMInvalidArgument, "destination must point to a slice")
	}

	// Since dest is just an interface{}, we overwrite destRef using reflection
	// to have an assignable copy of it.
	*destRef = tp.Elem()
	// We need to verify that we can call Set() on destRef.
	if !destRef.CanSet() {
		return nil, sdkerrors.Wrap(errors.ErrORMInvalidArgument, "destination not assignable")
	}

	elemType := destRef.Type().Elem()
	protoMarshaler := reflect.TypeOf((*codec.ProtoMarshaler)(nil)).Elem()
	// An interface element type (e.g. []codec.ProtoMarshaler) trivially passes
	// the Implements check. However, reflect.New on it only yields a nil
	// interface that cannot be unmarshaled into, so it is rejected here.
	if elemType.Kind() == reflect.Interface ||
		(!elemType.Implements(protoMarshaler) && !reflect.PtrTo(elemType).Implements(protoMarshaler)) {
		return nil, sdkerrors.Wrapf(errors.ErrORMInvalidArgument, "unsupported type: %s", elemType)
	}

	// tmpSlice is a slice value for the specified type
	// that we'll use for appending new elements.
	*tmpSlice = reflect.MakeSlice(reflect.SliceOf(elemType), 0, 0)

	return elemType, nil
}