/
proc_iters.go
343 lines (293 loc) · 9.14 KB
/
proc_iters.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
// Copyright 2023 Dolthub, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package rowexec
import (
"errors"
"fmt"
"io"
"strings"
"github.com/dolthub/vitess/go/mysql"
"github.com/dolthub/go-mysql-server/sql"
"github.com/dolthub/go-mysql-server/sql/expression"
"github.com/dolthub/go-mysql-server/sql/plan"
)
// ifElseIter is the row iterator produced for an *IfElseBlock. It forwards every call to the iterator of the branch
// that was selected, and keeps that branch's node and schema so they can be reported to callers.
type ifElseIter struct {
	branchIter sql.RowIter
	sch        sql.Schema
	branchNode sql.Node
}

var _ plan.BlockRowIter = (*ifElseIter)(nil)

// Next implements the sql.RowIter interface. If the context currently has no transaction (e.g. because an earlier
// statement in the procedure committed it), a new one is started before the branch iterator is advanced.
func (i *ifElseIter) Next(ctx *sql.Context) (sql.Row, error) {
	err := startTransaction(ctx)
	if err != nil {
		return nil, err
	}
	return i.branchIter.Next(ctx)
}

// Close implements the sql.RowIter interface by closing the selected branch's iterator.
func (i *ifElseIter) Close(ctx *sql.Context) error {
	err := i.branchIter.Close(ctx)
	return err
}

// RepresentingNode implements the plan.BlockRowIter interface, returning the node of the selected branch.
func (i *ifElseIter) RepresentingNode() sql.Node {
	node := i.branchNode
	return node
}

// Schema implements the plan.BlockRowIter interface.
func (i *ifElseIter) Schema() sql.Schema {
	schema := i.sch
	return schema
}
// beginEndIter is the sql.RowIter of *BeginEndBlock. Once its inner iterator stops producing rows it pops the
// block's variable scope, and it consumes a LEAVE that targets the block's own label.
type beginEndIter struct {
	*plan.BeginEndBlock
	rowIter sql.RowIter
}

var _ sql.RowIter = (*beginEndIter)(nil)

// Next implements the interface sql.RowIter.
//
// Once the inner iterator returns an error the block is finished: its scope is popped from the procedure
// reference and the error is translated as follows:
//
//   - a LEAVE (loopError with IsExit) targeting this block's label is consumed, so Next returns (nil, nil);
//   - an ITERATE targeting this block's label is reported as an internal error, since the analyzer should
//     have rejected it;
//   - expression.FetchEOF is converted to io.EOF.
//
// If popping the scope fails while the block otherwise ended cleanly (io.EOF or a consumed LEAVE), the pop
// error is returned instead so it is not lost.
func (b *beginEndIter) Next(ctx *sql.Context) (sql.Row, error) {
	if err := startTransaction(ctx); err != nil {
		return nil, err
	}
	row, err := b.rowIter.Next(ctx)
	if err == nil {
		return row, nil
	}

	// errors.As also recognizes a loopError wrapped by an intermediate layer; labels are case-insensitive.
	var controlFlow loopError
	if errors.As(err, &controlFlow) && strings.EqualFold(controlFlow.Label, b.Label) {
		if controlFlow.IsExit {
			err = nil
		} else {
			err = fmt.Errorf("encountered ITERATE on BEGIN...END, which should have been caught by the analyzer")
		}
	}
	// Surface a scope-pop failure only when it would not mask a more meaningful error.
	if popErr := b.Pref.PopScope(ctx); popErr != nil && (err == nil || err == io.EOF) {
		err = popErr
	}
	if errors.Is(err, expression.FetchEOF) {
		err = io.EOF
	}
	return nil, err
}

// Close implements the interface sql.RowIter.
func (b *beginEndIter) Close(ctx *sql.Context) error {
	return b.rowIter.Close(ctx)
}
// callIter is the row iterator for *Call. It streams the rows of the procedure body and, when closed, writes the
// final values of INOUT and OUT parameters back to the caller's variables.
type callIter struct {
	call      *plan.Call
	innerIter sql.RowIter
}

var _ sql.RowIter = (*callIter)(nil)

// Next implements the sql.RowIter interface.
func (iter *callIter) Next(ctx *sql.Context) (sql.Row, error) {
	return iter.innerIter.Next(ctx)
}

// Close implements the sql.RowIter interface.
//
// It closes the procedure body's iterator and all cursors held by the procedure reference, then writes parameter
// values back to the caller:
//
//   - INOUT parameters, and OUT parameters that were set in the body, receive their final value;
//   - OUT parameters that were never set receive nil, so any value the target held before the call is removed.
//
// Targets may be user variables or enclosing procedure parameters; a system variable target is reported as an
// error, and other expression kinds are left untouched. If either close fails, nothing is written back.
func (iter *callIter) Close(ctx *sql.Context) error {
	// Attempt both closes so that a failing inner iterator does not leave the procedure's cursors open.
	innerErr := iter.innerIter.Close(ctx)
	cursorErr := iter.call.Pref.CloseAllCursors(ctx)
	if innerErr != nil {
		return innerErr
	}
	if cursorErr != nil {
		return cursorErr
	}

	pref := iter.call.Pref
	for i, param := range iter.call.Procedure.Params {
		var val interface{}
		switch {
		case param.Direction == plan.ProcedureParamDirection_Inout ||
			(param.Direction == plan.ProcedureParamDirection_Out && pref.VariableHasBeenSet(param.Name)):
			v, err := pref.GetVariableValue(param.Name)
			if err != nil {
				return err
			}
			val = v
		case param.Direction == plan.ProcedureParamDirection_Out:
			// OUT parameter never assigned within the procedure body: the caller's target is set to nil.
			val = nil
		default:
			// Any other direction (IN) is not written back.
			continue
		}
		if err := assignCallParam(ctx, iter.call.Params[i], val, pref.GetVariableType(param.Name), param.Type); err != nil {
			return err
		}
	}
	return nil
}

// assignCallParam stores val into target, the caller-side expression that was passed for an INOUT or OUT
// parameter. userVarType is used when target is a user variable, paramType when it is a procedure parameter.
func assignCallParam(ctx *sql.Context, target sql.Expression, val interface{}, userVarType, paramType sql.Type) error {
	switch callParam := target.(type) {
	case *expression.UserVar:
		return ctx.SetUserVariable(ctx, callParam.Name, val, userVarType)
	case *expression.SystemVar:
		// This should have been caught by the analyzer, so a major bug exists somewhere
		return fmt.Errorf("unable to set `%s` as it is a system variable", callParam.Name)
	case *expression.ProcedureParam:
		return callParam.Set(val, paramType)
	}
	return nil
}
// elseCaseErrorIter is a sql.RowIter whose Next always fails with MySQL error 1339 (SQLSTATE 20000,
// "Case not found for CASE statement"). Presumably it stands in for the ELSE branch of a CASE statement that
// declares none — confirm against the builder that creates it.
type elseCaseErrorIter struct{}

var _ sql.RowIter = elseCaseErrorIter{}

// Next implements the interface sql.RowIter. It never yields a row.
func (elseCaseErrorIter) Next(ctx *sql.Context) (sql.Row, error) {
	const (
		errCode  = 1339
		sqlState = "20000"
	)
	return nil, mysql.NewSQLError(errCode, sqlState, "Case not found for CASE statement")
}

// Close implements the interface sql.RowIter; there is nothing to release.
func (elseCaseErrorIter) Close(ctx *sql.Context) error {
	return nil
}
// openIter is the sql.RowIter of *Open. Next opens the named cursor and then reports io.EOF; it never yields rows.
type openIter struct {
	pRef *expression.ProcedureReference
	name string
	row  sql.Row
	b    *BaseBuilder
}

var _ sql.RowIter = (*openIter)(nil)

// Next implements the interface sql.RowIter.
func (o *openIter) Next(ctx *sql.Context) (sql.Row, error) {
	if err := o.openCursor(ctx, o.pRef, o.name, o.row); err != nil {
		return nil, err
	}
	return nil, io.EOF
}

// openCursor looks up the lowercased cursor name in ref's innermost scope and then each parent scope in turn. In the
// first scope that declares it, it builds a row iterator for the cursor's SELECT statement (passing row through to
// buildNodeExec) and stores it on the cursor.
//
// It returns sql.ErrCursorAlreadyOpen if that cursor already holds an iterator, and an error if no scope declares
// the cursor. The cursor is only marked open when building its iterator succeeds, so a failed OPEN can be retried.
func (o *openIter) openCursor(ctx *sql.Context, ref *expression.ProcedureReference, name string, row sql.Row) error {
	lowerName := strings.ToLower(name)
	for scope := ref.InnermostScope; scope != nil; scope = scope.Parent {
		cursorRefVal, ok := scope.Cursors[lowerName]
		if !ok {
			continue
		}
		if cursorRefVal.RowIter != nil {
			return sql.ErrCursorAlreadyOpen.New(name)
		}
		iter, err := o.b.buildNodeExec(ctx, cursorRefVal.SelectStmt, row)
		if err != nil {
			return err
		}
		cursorRefVal.RowIter = iter
		return nil
	}
	return fmt.Errorf("cannot find cursor `%s`", name)
}

// Close implements the interface sql.RowIter.
func (o *openIter) Close(ctx *sql.Context) error {
	return nil
}
// closeIter is the sql.RowIter of *Close. Each call to Next closes the named cursor through the procedure
// reference and then reports io.EOF; it never yields rows.
type closeIter struct {
	pRef *expression.ProcedureReference
	name string
}

var _ sql.RowIter = (*closeIter)(nil)

// Next implements the interface sql.RowIter. A failure to close the cursor is returned in place of io.EOF.
func (c *closeIter) Next(ctx *sql.Context) (sql.Row, error) {
	err := c.pRef.CloseCursor(ctx, c.name)
	if err == nil {
		err = io.EOF
	}
	return nil, err
}

// Close implements the interface sql.RowIter; the cursor itself is closed by Next, so nothing remains to release.
func (c *closeIter) Close(ctx *sql.Context) error {
	return nil
}
// loopError is an error used to control a loop's flow. IsExit distinguishes an exit (as produced for LEAVE) from a
// continue (as produced for ITERATE); Label names the labeled construct the signal is aimed at.
type loopError struct {
	Label  string
	IsExit bool
}

var _ error = loopError{}

// Error implements the interface error. As long as the analysis step is implemented correctly, this should never be
// seen, because the construct carrying Label consumes the error before it escapes.
func (l loopError) Error() string {
	var option string
	if l.IsExit {
		option = "exited"
	} else {
		option = "continued"
	}
	return fmt.Sprintf("should have %s the loop `%s` but it was somehow not found in the call stack", option, l.Label)
}
// loopAcquireRowIter is a helper function for LOOP that conditionally acquires a new sql.RowIter by building block.
//
// If building the block fails with a loopError whose label matches label (case-insensitively):
//
//   - an exit (LEAVE) returns an empty iterator when exitIter is true, and io.EOF otherwise;
//   - a continue (ITERATE) returns an empty iterator with a nil error.
//
// An io.EOF from building the block also yields an empty iterator with a nil error. Any other error is returned
// unchanged alongside whatever iterator buildBlock produced.
func (b *BaseBuilder) loopAcquireRowIter(ctx *sql.Context, row sql.Row, label string, block *plan.Block, exitIter bool) (sql.RowIter, error) {
	blockIter, err := b.buildBlock(ctx, block, row)

	// errors.As also recognizes a loopError wrapped by an intermediate layer.
	var controlFlow loopError
	if errors.As(err, &controlFlow) && strings.EqualFold(controlFlow.Label, label) {
		if !controlFlow.IsExit {
			// ITERATE on this loop ends the current pass the same way io.EOF does.
			return sql.RowsToRowIter(), nil
		}
		if exitIter {
			return sql.RowsToRowIter(), nil
		}
		return nil, io.EOF
	}
	if err == io.EOF {
		return sql.RowsToRowIter(), nil
	}
	return blockIter, err
}
// leaveIter is the sql.RowIter of *Leave. It never yields rows: Next always fails with an exiting loopError for
// Label, which the enclosing construct carrying that label is expected to consume.
type leaveIter struct {
	Label string
}

var _ sql.RowIter = (*leaveIter)(nil)

// Next implements the interface sql.RowIter.
func (l *leaveIter) Next(ctx *sql.Context) (sql.Row, error) {
	exit := loopError{IsExit: true, Label: l.Label}
	return nil, exit
}

// Close implements the interface sql.RowIter; there is nothing to release.
func (l *leaveIter) Close(ctx *sql.Context) error {
	return nil
}
// iterateIter is the sql.RowIter of *Iterate. It never yields rows: Next always fails with a non-exiting loopError
// for Label, signalling the loop carrying that label to continue.
type iterateIter struct {
	Label string
}

var _ sql.RowIter = (*iterateIter)(nil)

// Next implements the interface sql.RowIter.
func (i *iterateIter) Next(ctx *sql.Context) (sql.Row, error) {
	cont := loopError{IsExit: false, Label: i.Label}
	return nil, cont
}

// Close implements the interface sql.RowIter; there is nothing to release.
func (i *iterateIter) Close(ctx *sql.Context) error {
	return nil
}
// startTransaction begins a new read-write transaction when the context has none, e.g. because a statement in a
// stored procedure committed the current one. Sessions that do not implement sql.TransactionSession are left as-is.
func startTransaction(ctx *sql.Context) error {
	if ctx.GetTransaction() != nil {
		return nil
	}
	txSession, isTxSession := ctx.Session.(sql.TransactionSession)
	if !isTxSession {
		return nil
	}
	tx, err := txSession.StartTransaction(ctx, sql.ReadWrite)
	if err != nil {
		return err
	}
	ctx.SetTransaction(tx)
	return nil
}