-
Notifications
You must be signed in to change notification settings - Fork 3.8k
/
row_sampling.go
344 lines (309 loc) · 11.1 KB
/
row_sampling.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
// Copyright 2017 The Cockroach Authors.
//
// Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt.
//
// As of the Change Date specified in that file, in accordance with
// the Business Source License, use of this software will be governed
// by the Apache License, Version 2.0, included in the file
// licenses/APL.txt.
package stats
import (
"container/heap"
"context"
"github.com/cockroachdb/cockroach/pkg/sql/memsize"
"github.com/cockroachdb/cockroach/pkg/sql/rowenc"
"github.com/cockroachdb/cockroach/pkg/sql/sem/eval"
"github.com/cockroachdb/cockroach/pkg/sql/sem/tree"
"github.com/cockroachdb/cockroach/pkg/sql/sqlerrors"
"github.com/cockroachdb/cockroach/pkg/sql/types"
"github.com/cockroachdb/cockroach/pkg/util/intsets"
"github.com/cockroachdb/cockroach/pkg/util/mon"
"github.com/cockroachdb/errors"
)
// SampledRow is a row that was sampled.
type SampledRow struct {
	// Row holds the sampled values. Columns not selected for sampling are set
	// to NULL when the row is copied into the reservoir (see copyRow).
	Row rowenc.EncDatumRow
	// Rank is the row's sort key; it should be a uniformly generated random
	// value. Rows with the smallest K ranks are the ones retained.
	Rank uint64
}
// SampleReservoir implements reservoir sampling using random sort. Each
// row is assigned a rank (which should be a uniformly generated random value),
// and rows with the smallest K ranks are retained.
//
// This is implemented as a max-heap of the smallest K ranks; each row can
// replace the row with the maximum rank. Note that heap operations only happen
// when we actually encounter a row that is among the top K so far; the
// probability of this is K/N if there were N rows so far; for large streams, we
// would have O(K log K) heap operations. The overall running time for a stream
// of size N is O(N + K log^2 K).
//
// The same structure can be used to combine sample sets (as long as the
// original ranks are preserved) for distributed reservoir sampling. The
// requirement is that the capacity of each distributed reservoir must have been
// at least as large as this reservoir.
type SampleReservoir struct {
	// samples holds at most cap(samples) == K rows. Once full, it is
	// maintained as a max-heap on Rank (see Less), so samples[0] has the
	// largest rank.
	samples []SampledRow
	// colTypes is the type of each column, needed to decode datums in copyRow.
	colTypes []*types.T
	// da and ra are allocators reused across rows to reduce allocations.
	da tree.DatumAlloc
	ra rowenc.EncDatumRowAlloc
	// memAcc, if non-nil, tracks the memory used by the sampled rows.
	memAcc *mon.BoundAccount
	// minNumSamples is the minimum capacity (K) needed for sampling to be
	// meaningful. If the reservoir capacity would fall below this, SampleRow will
	// err instead of decreasing it further.
	minNumSamples int
	// sampleCols contains the ordinals of columns that should be sampled from
	// each row. Note that the sampled rows still contain all columns, but
	// any columns not part of this set are given a null value.
	sampleCols intsets.Fast
}

// SampleReservoir is its own heap over the samples slice.
var _ heap.Interface = &SampleReservoir{}
// Init initializes a SampleReservoir with capacity K = numSamples.
//
// minNumSamples is the smallest capacity the reservoir may be shrunk to (under
// memory pressure) before sampling operations start returning errors; a value
// that is non-positive or exceeds numSamples is clamped to numSamples.
// colTypes describes the columns of incoming rows, memAcc (optional) accounts
// for sample memory, and sampleCols selects which column ordinals to retain.
func (sr *SampleReservoir) Init(
	numSamples, minNumSamples int,
	colTypes []*types.T,
	memAcc *mon.BoundAccount,
	sampleCols intsets.Fast,
) {
	// A nonsensical minimum degenerates to "never shrink below full capacity".
	if minNumSamples > numSamples || minNumSamples < 1 {
		minNumSamples = numSamples
	}
	sr.minNumSamples = minNumSamples
	sr.colTypes = colTypes
	sr.memAcc = memAcc
	sr.sampleCols = sampleCols
	sr.samples = make([]SampledRow, 0, numSamples)
}
// Disable releases the memory of this SampleReservoir and sets its capacity
// to zero.
//
// NOTE(review): this drops the samples slice without shrinking memAcc;
// presumably the caller clears or closes the bound account separately —
// confirm against callers.
func (sr *SampleReservoir) Disable() {
	sr.samples = nil
}
// Len is part of heap.Interface. It returns the number of samples currently
// held, which may be less than the capacity K.
func (sr *SampleReservoir) Len() int {
	return len(sr.samples)
}
// Cap returns K, the maximum number of samples the reservoir can hold.
// After Disable, Cap returns zero.
func (sr *SampleReservoir) Cap() int {
	return cap(sr.samples)
}
// Less is part of heap.Interface. The heap must be a max-heap on rank (the
// root holds the current maximum), so an element with a higher rank sorts
// before one with a lower rank.
func (sr *SampleReservoir) Less(i, j int) bool {
	return sr.samples[j].Rank < sr.samples[i].Rank
}
// Swap is part of heap.Interface.
func (sr *SampleReservoir) Swap(i, j int) {
	sr.samples[i], sr.samples[j] = sr.samples[j], sr.samples[i]
}
// Push is part of heap.Interface, but we're not using it: rows enter the
// reservoir via SampleRow (append, or in-place replacement of the heap root),
// never via heap.Push.
func (sr *SampleReservoir) Push(x interface{}) { panic("unimplemented") }
// Pop is part of heap.Interface. Per the heap contract it removes and returns
// the final element of the slice (which heap.Pop has already swapped the
// max-rank root into).
func (sr *SampleReservoir) Pop() interface{} {
	last := len(sr.samples) - 1
	removed := sr.samples[last]
	// Zero the vacated slot so the row it referenced can be collected.
	sr.samples[last] = SampledRow{}
	sr.samples = sr.samples[:last]
	return removed
}
// MaybeResize safely shrinks the capacity of the reservoir (K) without
// introducing bias if the requested capacity is less than the current
// capacity, and returns whether the capacity changed. (Note that the capacity
// can only decrease without introducing bias. Increasing capacity would cause
// later rows to be over-represented relative to earlier rows.)
func (sr *SampleReservoir) MaybeResize(ctx context.Context, k int) bool {
	if cap(sr.samples) <= k {
		// Growing (or staying put) is not allowed; nothing to do.
		return false
	}
	// Establish the heap invariant so the largest ranks are evicted first.
	heap.Init(sr)
	for len(sr.samples) > k {
		evicted := heap.Pop(sr).(SampledRow)
		if sr.memAcc != nil {
			sr.memAcc.Shrink(ctx, int64(evicted.Row.Size()))
		}
	}
	// Reallocate with the smaller capacity k so the old, larger backing array
	// can be garbage collected.
	shrunk := make([]SampledRow, len(sr.samples), k)
	copy(shrunk, sr.samples)
	sr.samples = shrunk
	return true
}
// retryMaybeResize tries to execute a memory-allocating operation, shrinking
// the capacity of the reservoir (K) as necessary until the operation succeeds
// or the capacity reaches minNumSamples, at which point an error is returned.
func (sr *SampleReservoir) retryMaybeResize(ctx context.Context, op func() error) error {
	for {
		err := op()
		if err == nil || !sqlerrors.IsOutOfMemoryError(err) {
			// Success, or a failure we cannot fix by shrinking.
			return err
		}
		half := len(sr.samples) / 2
		if len(sr.samples) == 0 || half < sr.minNumSamples {
			// Shrinking further would make sampling meaningless; give up.
			return err
		}
		// We've used too much memory. Remove half the samples and try again.
		sr.MaybeResize(ctx, half)
	}
}
// SampleRow looks at a row and either drops it or adds it to the reservoir. The
// capacity of the reservoir (K) will shrink if it hits a memory limit. If
// capacity goes below minNumSamples, SampleRow will return an error. If
// SampleRow returns an error (any type of error), no additional calls to
// SampleRow should be made as the failed samples will have introduced bias.
func (sr *SampleReservoir) SampleRow(
	ctx context.Context, evalCtx *eval.Context, row rowenc.EncDatumRow, rank uint64,
) error {
	return sr.retryMaybeResize(ctx, func() error {
		if len(sr.samples) < cap(sr.samples) {
			// We haven't accumulated enough rows yet, just append.
			rowCopy := sr.ra.AllocRow(len(row))
			// Perform memory accounting for the allocated EncDatumRow. We will
			// account for the additional memory used after copying inside copyRow.
			if sr.memAcc != nil {
				if err := sr.memAcc.Grow(ctx, int64(rowCopy.Size())); err != nil {
					return err
				}
			}
			if err := sr.copyRow(ctx, evalCtx, rowCopy, row); err != nil {
				return err
			}
			sr.samples = append(sr.samples, SampledRow{Row: rowCopy, Rank: rank})
			if len(sr.samples) == cap(sr.samples) {
				// We just reached the limit; initialize the heap. From here on,
				// samples is a max-heap on rank and samples[0] is the maximum.
				heap.Init(sr)
			}
			return nil
		}
		// Replace the max rank if ours is smaller. samples[0] is the heap root,
		// i.e. the currently retained row with the largest rank. The len check
		// guards against a zero-capacity reservoir (see Disable), in which case
		// the row is simply dropped.
		if len(sr.samples) > 0 && rank < sr.samples[0].Rank {
			if err := sr.copyRow(ctx, evalCtx, sr.samples[0].Row, row); err != nil {
				// WARNING: At this point sr.samples[0].Row might have a mix of old and
				// new values. The caller must call heap.Pop() to keep using the
				// reservoir.
				return err
			}
			sr.samples[0].Rank = rank
			// Restore the heap invariant after lowering the root's rank.
			heap.Fix(sr, 0)
		}
		return nil
	})
}
// Get returns the sampled rows. The returned slice aliases the reservoir's
// internal storage and is not sorted by rank.
func (sr *SampleReservoir) Get() []SampledRow {
	return sr.samples
}
// GetNonNullDatums returns the non-null values of the specified column. The
// capacity of the reservoir (K) will shrink if we hit a memory limit while
// building this return slice. If the capacity goes below minNumSamples,
// GetNonNullDatums will return an error.
func (sr *SampleReservoir) GetNonNullDatums(
	ctx context.Context, memAcc *mon.BoundAccount, colIdx int,
) (values tree.Datums, err error) {
	err = sr.retryMaybeResize(ctx, func() error {
		// Account for the memory we'll use copying the samples into values.
		if memAcc != nil {
			if growErr := memAcc.Grow(ctx, memsize.DatumOverhead*int64(len(sr.samples))); growErr != nil {
				return growErr
			}
		}
		values = make(tree.Datums, 0, len(sr.samples))
		for i := range sr.samples {
			datum := sr.samples[i].Row[colIdx].Datum
			if datum == nil {
				// The column was never decoded; this indicates a programming
				// error upstream (copyRow decodes every sampled column).
				values = nil
				return errors.AssertionFailedf("value in column %d not decoded", colIdx)
			}
			if datum != tree.DNull {
				values = append(values, datum)
			}
		}
		return nil
	})
	return values, err
}
// copyRow copies the decoded datums of src into dst. Columns not in
// sr.sampleCols are replaced with NULL, and datums larger than
// maxBytesPerSample are truncated. The size delta of each overwritten datum
// is accounted against sr.memAcc.
//
// On error, dst may contain a mix of old and new values (see the warning in
// SampleRow).
func (sr *SampleReservoir) copyRow(
	ctx context.Context, evalCtx *eval.Context, dst, src rowenc.EncDatumRow,
) error {
	for i := range src {
		if !sr.sampleCols.Contains(i) {
			// Unsampled columns are nulled out rather than dropped so the row
			// keeps its full width.
			dst[i].Datum = tree.DNull
			continue
		}
		// Copy only the decoded datum to ensure that we remove any reference to
		// the encoded bytes. The encoded bytes would have been scanned in a batch
		// of ~10000 rows, so we must delete the reference to allow the garbage
		// collector to release the memory from the batch.
		if err := src[i].EnsureDecoded(sr.colTypes[i], &sr.da); err != nil {
			return err
		}
		// Record the size of the datum being replaced so we can account for
		// the delta below.
		beforeSize := dst[i].Size()
		dst[i] = rowenc.DatumToEncDatum(sr.colTypes[i], src[i].Datum)
		afterSize := dst[i].Size()
		// If the datum is too large, truncate it.
		if afterSize > uintptr(maxBytesPerSample) {
			dst[i].Datum = truncateDatum(evalCtx, dst[i].Datum, maxBytesPerSample)
			afterSize = dst[i].Size()
		}
		// Perform memory accounting.
		if sr.memAcc != nil {
			if err := sr.memAcc.Resize(ctx, int64(beforeSize), int64(afterSize)); err != nil {
				return err
			}
		}
	}
	return nil
}
// maxBytesPerSample is the largest datum size (in bytes) retained in a
// sample; datums exceeding it are shrunk via truncateDatum (see copyRow).
const maxBytesPerSample = 400

// truncateDatum truncates large datums to avoid using excessive memory or disk
// space. It performs a best-effort attempt to return a datum that is similar
// to d using at most maxBytes bytes.
//
// For example, if maxBytes=10, "Cockroach Labs" would be truncated to
// "Cockroach ".
func truncateDatum(evalCtx *eval.Context, d tree.Datum, maxBytes int) tree.Datum {
	switch t := d.(type) {
	case *tree.DBitArray:
		b := tree.DBitArray{BitArray: t.ToWidth(uint(maxBytes * 8))}
		return &b
	case *tree.DBytes:
		// Make a copy so the memory from the original byte string can be
		// garbage collected. Cap the copy length at the input length so that a
		// byte string already within the limit is not zero-padded out to
		// maxBytes (which would change its value and grow it).
		n := maxBytes
		if len(*t) < n {
			n = len(*t)
		}
		b := make([]byte, n)
		copy(b, *t)
		return tree.NewDBytes(tree.DBytes(b))
	case *tree.DString:
		return tree.NewDString(truncateString(string(*t), maxBytes))
	case *tree.DCollatedString:
		contents := truncateString(t.Contents, maxBytes)
		// Note: this will end up being larger than maxBytes due to the key and
		// locale, so this is just a best-effort attempt to limit the size.
		res, err := tree.NewDCollatedString(contents, t.Locale, &evalCtx.CollationEnv)
		if err != nil {
			// Fall back to the untruncated datum rather than failing the sample.
			return d
		}
		return res
	case *tree.DOidWrapper:
		// Truncate the wrapped datum and preserve the OID.
		return &tree.DOidWrapper{
			Wrapped: truncateDatum(evalCtx, t.Wrapped, maxBytes),
			Oid:     t.Oid,
		}
	default:
		// It's not easy to truncate other types (e.g. Decimal).
		return d
	}
}
// truncateString truncates long strings to a valid UTF-8 prefix that is at
// most maxBytes bytes. It is rune-aware so it does not cut unicode characters
// in half. A string that already fits within maxBytes is returned whole (as a
// copy); previously-documented behavior of dropping the final rune applied
// only to over-long inputs, and short inputs must not lose data.
//
// The result is always a fresh allocation so that the memory backing the
// (possibly much larger) original string can be garbage collected.
func truncateString(s string, maxBytes int) string {
	if len(s) <= maxBytes {
		// Fast path: nothing to truncate. Copy the bytes so we drop any
		// reference to the original backing array.
		return string([]byte(s))
	}
	// Find the byte index of the last rune starting at or before maxBytes; the
	// prefix up to (and excluding) that rune is at most maxBytes bytes and
	// never splits a rune.
	last := 0
	// For strings, range skips from rune to rune and i is the byte index of
	// the current rune.
	for i := range s {
		if i > maxBytes {
			break
		}
		last = i
	}
	// Copy the truncated prefix so that the memory from the longer string can
	// be garbage collected.
	b := make([]byte, last)
	copy(b, s)
	return string(b)
}