/
parallelism.go
61 lines (54 loc) · 1.49 KB
/
parallelism.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
package enginetest
import (
"fmt"
"testing"
"github.com/stretchr/testify/require"
"github.com/dolthub/go-mysql-server/enginetest/queries"
"github.com/dolthub/go-mysql-server/enginetest/scriptgen/setup"
"github.com/dolthub/go-mysql-server/sql"
"github.com/dolthub/go-mysql-server/sql/plan"
"github.com/dolthub/go-mysql-server/sql/transform"
)
// TestParallelismQueries loads the XySetup fixture scripts into harness, builds
// a single engine, and runs one subtest per entry in queries.ParallelismTests.
// Each subtest checks whether the analyzed plan for the entry's query contains
// an exchange node, matching the entry's Parallel flag.
func TestParallelismQueries(t *testing.T, harness Harness) {
	harness.Setup(setup.XySetup...)

	engine := mustNewEngine(t, harness)
	defer engine.Close()

	for _, tc := range queries.ParallelismTests {
		query, wantParallel := tc.Query, tc.Parallel
		t.Run(query, func(t *testing.T) {
			evalParallelismTest(t, harness, engine, query, wantParallel)
		})
	}
}
// evalParallelismTest analyzes query with engine e and requires that the
// resulting plan contains a *plan.Exchange node (directly or inside a subquery
// expression) exactly when parallel is true. If the expectation fails, the
// failure message includes the debug rendering of the analyzed plan.
func evalParallelismTest(t *testing.T, harness Harness, e QueryEngine, query string, parallel bool) {
	ctx := NewContext(harness).WithQuery(query)

	analyzed, err := analyzeQuery(ctx, e, query)
	require.NoError(t, err)

	hasExchange := findExchange(analyzed)
	failureMsg := fmt.Sprintf("expected exchange: %t\nplan:\n%s", parallel, sql.DebugString(analyzed))
	require.Equal(t, parallel, hasExchange, failureMsg)
}
// findExchange reports whether the plan rooted at n contains a *plan.Exchange
// node. It walks the node tree with transform.InspectUp. For every node that
// exposes expressions (sql.Expressioner), it also searches those expression
// trees for *plan.Subquery values and recursively checks each subquery's plan.
// Returning true from either inspection callback marks a match.
func findExchange(n sql.Node) bool {
	return transform.InspectUp(n, func(node sql.Node) bool {
		switch current := node.(type) {
		case nil:
			// Defensive guard carried over from the original: nothing to inspect.
			return false
		case *plan.Exchange:
			return true
		case sql.Expressioner:
			for _, expr := range current.Expressions() {
				inSubquery := transform.InspectExpr(expr, func(inner sql.Expression) bool {
					if sq, isSubquery := inner.(*plan.Subquery); isSubquery {
						return findExchange(sq.Query)
					}
					return false
				})
				if inSubquery {
					return true
				}
			}
		}
		return false
	})
}