Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions enginetest/engine_only_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -656,6 +656,7 @@ func TestTriggerViewWarning(t *testing.T) {
assert.NoError(t, err)

ctx := harness.NewContext()
ctx.SetCurrentDatabase("mydb")
enginetest.CreateNewConnectionForServerEngine(ctx, e)

enginetest.TestQueryWithContext(t, ctx, e, harness, "insert into mytable values (4, 'fourth row')", []sql.Row{{types.NewOkResult(1)}}, nil, nil, nil)
Expand Down Expand Up @@ -1000,6 +1001,7 @@ func TestAlterTableWithBadSchema(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
ctx := harness.NewContext()
ctx.SetCurrentDatabase("mydb")
_, iter, _, err := engine.Query(ctx, tt.q)
// errors should be analyze time, not execution time
if tt.err {
Expand Down
43 changes: 39 additions & 4 deletions enginetest/evaluation.go
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,10 @@ func TestScriptWithEngine(t *testing.T, e QueryEngine, harness Harness, script q
if sh.SkipQueryTest(script.Name) {
t.Skip()
}

if !supportedDialect(harness, script.Dialect) {
t.Skip()
}
}

for _, statement := range script.SetUpScript {
Expand All @@ -95,6 +99,7 @@ func TestScriptWithEngine(t *testing.T, e QueryEngine, harness Harness, script q
t.Skip()
}
}

ctx = ctx.WithQuery(statement)
RunQueryWithContext(t, e, harness, ctx, statement)
}
Expand All @@ -120,10 +125,7 @@ func TestScriptWithEngine(t *testing.T, e QueryEngine, harness Harness, script q
ctx = th.NewSession()
}

if sh, ok := harness.(SkippingHarness); ok && sh.SkipQueryTest(assertion.Query) {
t.Skip()
}
if assertion.Skip {
if skipAssertion(t, harness, assertion) {
t.Skip()
}

Expand Down Expand Up @@ -158,6 +160,22 @@ func TestScriptWithEngine(t *testing.T, e QueryEngine, harness Harness, script q
})
}

func skipAssertion(t *testing.T, harness Harness, assertion queries.ScriptTestAssertion) bool {
if sh, ok := harness.(SkippingHarness); ok && sh.SkipQueryTest(assertion.Query) {
return true
}

if !supportedDialect(harness, assertion.Dialect) {
return true
}

if assertion.Skip {
return true
}

return false
}

// TestScriptPrepared substitutes literals for bindvars, runs the test script given,
// and makes any assertions given
func TestScriptPrepared(t *testing.T, harness Harness, script queries.ScriptTest) bool {
Expand Down Expand Up @@ -1113,6 +1131,11 @@ func RunWriteQueryTestWithEngine(t *testing.T, harness Harness, e QueryEngine, t
t.Skip()
}
}

if !supportedDialect(harness, tt.Dialect) {
t.Skip()
}

ctx := NewContext(harness)
TestQueryWithContext(t, ctx, e, harness, tt.WriteQuery, tt.ExpectedWriteResult, nil, nil, nil)
expectedSelect := tt.ExpectedSelect
Expand All @@ -1122,6 +1145,18 @@ func RunWriteQueryTestWithEngine(t *testing.T, harness Harness, e QueryEngine, t
TestQueryWithContext(t, ctx, e, harness, tt.SelectQuery, expectedSelect, nil, nil, nil)
}

func supportedDialect(harness Harness, dialect string) bool {
if dialect == "" {
return true
}

harnessDialect := "mysql"
if hd, ok := harness.(DialectHarness); ok {
harnessDialect = hd.Dialect()
}
return harnessDialect == dialect
}

func runWriteQueryTestPrepared(t *testing.T, harness Harness, tt queries.WriteQueryTest) {
t.Run(tt.WriteQuery, func(t *testing.T) {
if tt.Skip {
Expand Down
8 changes: 8 additions & 0 deletions enginetest/harness.go
Original file line number Diff line number Diff line change
Expand Up @@ -171,3 +171,11 @@ type ResultEvaluationHarness interface {
// EvaluateExpectedErrorKind compares expected error kinds to actual errors and emits failed test assertions in the
EvaluateExpectedErrorKind(t *testing.T, expected *errors.Kind, err error)
}

type DialectHarness interface {
Harness

// Dialect returns the dialect that the engine being tested supports. If this harness interface isn't implemented,
// the dialect "mysql" is used by engine tests.
Dialect() string
}
1 change: 1 addition & 0 deletions enginetest/join_stats_tests.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ func TestJoinStats(t *testing.T, harness Harness) {
e.EngineAnalyzer().Catalog.DbProvider = newPro.(sql.DatabaseProvider)

ctx := harness.NewContext()
ctx.SetCurrentDatabase("mydb")
for _, q := range tt.setup {
_, iter, _, err := e.Query(ctx, q)
require.NoError(t, err)
Expand Down
5 changes: 3 additions & 2 deletions enginetest/memory_engine_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,8 @@ func TestJoinOps(t *testing.T) {
}

func TestJoinStats(t *testing.T) {
harness := enginetest.NewDefaultMemoryHarness()
// We keep join stats in the session, so we need to retain the session after setup
harness := enginetest.NewDefaultMemoryHarness().RetainSessionAfterSetup()
if harness.IsUsingServer() {
t.Skip("join stats don't work with bindvars")
}
Expand Down Expand Up @@ -450,7 +451,7 @@ func TestTpchQueryPlans(t *testing.T) {

for _, indexInit := range indexBehaviors {
t.Run(indexInit.name, func(t *testing.T) {
harness := enginetest.NewMemoryHarness(indexInit.name, 1, 1, indexInit.nativeIndexes, indexInit.driverInitializer)
harness := enginetest.NewMemoryHarness(indexInit.name, 1, 1, indexInit.nativeIndexes, indexInit.driverInitializer).RetainSessionAfterSetup()
enginetest.TestTpchPlans(t, harness)
})
}
Expand Down
11 changes: 11 additions & 0 deletions enginetest/memory_harness.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ type MemoryHarness struct {
skippedQueries map[string]struct{}
session sql.Session
retainSession bool
retainSessionAfterSetup bool
setupData []setup.SetupScript
externalProcedureRegistry sql.ExternalStoredProcedureRegistry
server bool
Expand Down Expand Up @@ -98,6 +99,11 @@ func NewReadOnlyMemoryHarness() *MemoryHarness {
return h
}

func (m MemoryHarness) RetainSessionAfterSetup() *MemoryHarness {
m.retainSessionAfterSetup = true
return &m
}

func (m *MemoryHarness) SessionBuilder() server.SessionBuilder {
return func(ctx context.Context, c *mysql.Conn, addr string) (sql.Session, error) {
host := ""
Expand Down Expand Up @@ -191,6 +197,11 @@ func (m *MemoryHarness) NewEngine(t *testing.T) (QueryEngine, error) {
return NewServerQueryEngine(t, engine, m.SessionBuilder())
}

// reset the session to clear any session state that may have been set during engine creation
if !m.retainSessionAfterSetup {
m.session = nil
}

return engine, nil
}

Expand Down
Loading
Loading