-
Notifications
You must be signed in to change notification settings - Fork 67
/
join.go
262 lines (250 loc) · 6.68 KB
/
join.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
package join
import (
"context"
"fmt"
"sync"
"github.com/brimdata/zed"
"github.com/brimdata/zed/order"
"github.com/brimdata/zed/runtime"
"github.com/brimdata/zed/runtime/sam/expr"
"github.com/brimdata/zed/runtime/sam/op/sort"
"github.com/brimdata/zed/zbuf"
"github.com/brimdata/zed/zio"
)
// Op is a join operator that merges a left and a right input on a key.
// Both inputs are read in key order; New inserts a sort on any input whose
// declared order does not already match.
type Op struct {
	// rctx is the runtime context; its Zctx is used to build the cutter
	// and to look up combined record types.
	rctx *runtime.Context
	// anti suppresses output for left records that have a right-hand match.
	anti bool
	// inner drops left records that have no right-hand match instead of
	// passing them through unchanged.
	inner bool

	// ctx is derived from rctx.Context in New and handed to both pullers;
	// cancel is its cancel function.
	ctx    context.Context
	cancel context.CancelFunc
	// once guards the start of the two puller goroutines on the first Pull.
	once sync.Once

	// left and right are the two inputs. The right side is wrapped in a
	// Peeker so a record can be examined before it is consumed.
	left  *puller
	right *zio.Peeker

	// getLeftKey and getRightKey evaluate the join key of a record from
	// the corresponding side; compare orders two key values.
	getLeftKey  expr.Evaluator
	getRightKey expr.Evaluator
	compare     expr.CompareFn
	// cutter is evaluated on each matching right record before it is
	// spliced onto the left record.
	cutter *expr.Cutter

	// joinKey and joinSet cache the most recently matched key and the
	// right-hand records read for it, so that consecutive left records
	// with the same key reuse them.
	joinKey *zed.Value
	joinSet []zed.Value
	// types caches combined record types, indexed first by the left
	// type's ID and then by the right type's ID.
	types map[int]map[int]*zed.TypeRecord
}
// New builds a join operator over the left and right inputs.
//
// leftKey and rightKey evaluate the join key on each side. leftDir and
// rightDir describe the order each input is already in: the operator adopts
// the direction of the left input if it is known, otherwise that of the
// right input, otherwise the zero value of order.Which. Any input that does
// not already have that order is wrapped in a sort on its key. lhs and rhs
// configure the cutter applied to matching right-hand records.
//
// An error is returned only if creating one of the sorts fails.
func New(rctx *runtime.Context, anti, inner bool, left, right zbuf.Puller, leftKey, rightKey expr.Evaluator,
	leftDir, rightDir order.Direction, lhs []*expr.Lval,
	rhs []expr.Evaluator) (*Op, error) {
	var which order.Which
	if leftDir != order.Unknown {
		which = leftDir == order.Down
	} else if rightDir != order.Unknown {
		which = rightDir == order.Down
	}
	// Add sorts if needed.
	if !leftDir.HasOrder(which) {
		sorted, err := sort.New(rctx, left, []expr.Evaluator{leftKey}, which, false)
		if err != nil {
			return nil, err
		}
		left = sorted
	}
	if !rightDir.HasOrder(which) {
		sorted, err := sort.New(rctx, right, []expr.Evaluator{rightKey}, which, false)
		if err != nil {
			return nil, err
		}
		right = sorted
	}
	// Both pullers share one cancellable context derived from the runtime's.
	ctx, cancel := context.WithCancel(rctx.Context)
	op := &Op{
		rctx:        rctx,
		anti:        anti,
		inner:       inner,
		ctx:         ctx,
		cancel:      cancel,
		getLeftKey:  leftKey,
		getRightKey: rightKey,
		left:        newPuller(left, ctx),
		right:       zio.NewPeeker(newPuller(right, ctx)),
		compare:     expr.NewValueCompareFn(which, true),
		cutter:      expr.NewCutter(rctx.Zctx, lhs, rhs),
		types:       make(map[int]map[int]*zed.TypeRecord),
	}
	return op, nil
}
// Pull runs the merge join and returns its results.
//
// The first call starts the goroutines that feed the left and right inputs.
// Left records are then read one at a time:
//   - a record whose key is missing is dropped;
//   - a record with no matching right records is passed through unchanged,
//     unless this is an inner join, in which case it is dropped;
//   - a record with matches is dropped for an anti join, and otherwise
//     yields one spliced record per matching right record.
//
// All output is accumulated and returned as a single batch once the left
// input is exhausted (see issue #3427); nil is returned if there is no
// output. The done argument is currently ignored (see issue #3437).
func (o *Op) Pull(done bool) (zbuf.Batch, error) {
	// XXX see issue #3437 regarding done protocol.
	o.once.Do(func() {
		go o.left.run()
		go o.right.Reader.(*puller).run()
	})
	var joined []zed.Value
	// See #3366
	ectx := expr.NewContext()
	for {
		lrec, err := o.left.Read()
		switch {
		case err != nil:
			return nil, err
		case lrec == nil:
			// Left input is exhausted.
			//XXX See issue #3427.
			if len(joined) > 0 {
				return zbuf.NewArray(joined), nil
			}
			return nil, nil
		}
		lkey := o.getLeftKey.Eval(ectx, *lrec)
		if lkey.IsMissing() {
			// A left record that cannot produce a key (not a thing in
			// a sql join) never takes part in the join.
			continue
		}
		matches, err := o.getJoinSet(lkey)
		if err != nil {
			return nil, err
		}
		switch {
		case matches == nil:
			// Nothing on the right for this key: keep the bare left
			// record unless only matched records are wanted.
			if !o.inner {
				joined = append(joined, lrec.Copy())
			}
		case !o.anti:
			// Emit one joined record per matching right record.
			// XXX This could be more efficient with a CutAppend that
			// builds the record in a reusable buffer and then copies
			// both inputs into a right-sized output buffer, ideally
			// taken from a pooled Batch the downstream user can
			// release to bypass GC.
			for _, rrec := range matches {
				cut := o.cutter.Eval(ectx, rrec)
				spliced, err := o.splice(*lrec, cut)
				if err != nil {
					return nil, err
				}
				joined = append(joined, spliced)
			}
		}
	}
}
// getJoinSet returns the right-hand records whose key compares equal to
// leftKey, or nil if there are none.
//
// If leftKey equals the key of the previous match, the cached set is
// returned without touching the right input. Otherwise the right input is
// advanced: records with a missing key, and records for which the
// comparison of leftKey against their key is positive, are discarded. The
// scan stops at the first record whose key matches (its whole run is read
// and cached) or for which the comparison is negative (left in place for a
// later call).
func (o *Op) getJoinSet(leftKey zed.Value) ([]zed.Value, error) {
	if cached := o.joinKey; cached != nil && o.compare(leftKey, *cached) == 0 {
		return o.joinSet, nil
	}
	// See #3366
	ectx := expr.NewContext()
	for {
		peeked, err := o.right.Peek()
		if err != nil {
			return nil, err
		}
		if peeked == nil {
			// Right input is exhausted.
			return nil, nil
		}
		rkey := o.getRightKey.Eval(ectx, *peeked)
		if !rkey.IsMissing() {
			switch cmp := o.compare(leftKey, rkey); {
			case cmp == 0:
				// Copy leftKey.Bytes since it might get reused.
				if o.joinKey != nil {
					o.joinKey.CopyFrom(leftKey)
				} else {
					o.joinKey = leftKey.Copy().Ptr()
				}
				set, err := o.readJoinSet(o.joinKey)
				o.joinSet = set
				return set, err
			case cmp < 0:
				// The left key is smaller than the next eligible join
				// key, so this left record has nothing to join with.
				return nil, nil
			}
		}
		// Either the peeked record has no key or the comparison was
		// positive: discard it and keep looking for a right-hand key
		// that matches or exceeds the left-hand key.
		o.right.Read()
	}
}
// readJoinSet is called once a right-hand record matching the current
// left-hand key has been found. It consumes and returns copies of the
// consecutive right-hand records whose key compares equal to joinKey,
// discarding records with a missing key along the way. It stops at the end
// of the right input or at the first record with a different key, which is
// left unread.
func (o *Op) readJoinSet(joinKey *zed.Value) ([]zed.Value, error) {
	// See #3366
	ectx := expr.NewContext()
	var set []zed.Value
	for {
		peeked, err := o.right.Peek()
		if err != nil {
			return nil, err
		}
		if peeked == nil {
			break
		}
		key := o.getRightKey.Eval(ectx, *peeked)
		if !key.IsMissing() {
			if o.compare(key, *joinKey) != 0 {
				break
			}
			set = append(set, peeked.Copy())
		}
		// Consume the record just examined, whether kept or skipped.
		o.right.Read()
	}
	return set, nil
}
// lookupType returns the cached combined type for the given pair of left
// and right record types, or nil if no such pair has been entered.
func (o *Op) lookupType(left, right *zed.TypeRecord) *zed.TypeRecord {
	byRight, found := o.types[left.ID()]
	if !found {
		return nil
	}
	return byRight[right.ID()]
}
// enterType records combined as the cached result type for the given pair
// of left and right record types, creating the per-left-type table on
// first use.
func (o *Op) enterType(combined, left, right *zed.TypeRecord) {
	leftID := left.ID()
	if o.types[leftID] == nil {
		o.types[leftID] = make(map[int]*zed.TypeRecord)
	}
	o.types[leftID][right.ID()] = combined
}
// buildType returns the record type of a joined record: all fields of left
// followed by all fields of right.
//
// A right field whose name already exists in left is renamed to
// "<name>_<k>", using the smallest k >= 2 that yields an unused name. A
// name counts as used if it belongs to any left field, to any right field,
// or to a rename already issued by this call. Checking only the left
// fields is not enough: with left {a} and right {a, a_2}, the right "a"
// would be renamed to "a_2" and clash with the right's own "a_2", giving
// the combined record two fields with the same name.
//
// Right fields that do not collide with left keep their original names.
// The returned error is whatever the type context's LookupTypeRecord
// reports.
func (o *Op) buildType(left, right *zed.TypeRecord) (*zed.TypeRecord, error) {
	total := len(left.Fields) + len(right.Fields)
	// taken holds every name that a renamed field must avoid.
	taken := make(map[string]struct{}, total)
	for _, f := range left.Fields {
		taken[f.Name] = struct{}{}
	}
	for _, f := range right.Fields {
		taken[f.Name] = struct{}{}
	}
	fields := make([]zed.Field, 0, total)
	fields = append(fields, left.Fields...)
	for _, f := range right.Fields {
		name := f.Name
		if left.HasField(name) {
			for k := 2; ; k++ {
				name = fmt.Sprintf("%s_%d", f.Name, k)
				if _, used := taken[name]; !used {
					break
				}
			}
			// Reserve the new name so a later rename cannot reuse it.
			taken[name] = struct{}{}
		}
		fields = append(fields, zed.NewField(name, f.Type))
	}
	return o.rctx.Zctx.LookupTypeRecord(fields)
}
// combinedType returns the record type for a left record joined with a
// right record. The type is built on the first request for a given pair of
// input types and served from the cache afterwards. An error is returned
// only if building the type fails, in which case nothing is cached.
func (o *Op) combinedType(left, right *zed.TypeRecord) (*zed.TypeRecord, error) {
	typ := o.lookupType(left, right)
	if typ == nil {
		built, err := o.buildType(left, right)
		if err != nil {
			return nil, err
		}
		o.enterType(built, left, right)
		typ = built
	}
	return typ, nil
}
// splice joins two record values into one. Both values are first reduced
// with Under; the result has the combined type of the two record types and
// a body made of the left value's bytes followed by the right value's
// bytes. zed.Null and an error are returned if the combined type cannot be
// determined.
func (o *Op) splice(left, right zed.Value) (zed.Value, error) {
	left, right = left.Under(), right.Under()
	leftType := zed.TypeRecordOf(left.Type())
	rightType := zed.TypeRecordOf(right.Type())
	typ, err := o.combinedType(leftType, rightType)
	if err != nil {
		return zed.Null, err
	}
	leftBytes, rightBytes := left.Bytes(), right.Bytes()
	// Concatenate into a fresh buffer so the result shares no storage
	// with either input.
	body := make([]byte, 0, len(leftBytes)+len(rightBytes))
	body = append(body, leftBytes...)
	body = append(body, rightBytes...)
	return zed.NewValue(typ, body), nil
}