Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Max/sqlerror codes #297

Merged
merged 9 commits into from
Feb 16, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
2 changes: 1 addition & 1 deletion enginetest/memory_engine_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ func TestSingleQuery(t *testing.T) {

var test enginetest.QueryTest
test = enginetest.QueryTest{
Query: "SELECT i from mytable where 4 = :foo * 2 order by 1",
Query: "SELECT i from mytable where 4 = :foo * 2 order by 1",
Expected: []sql.Row{
{1},
{2},
Expand Down
8 changes: 4 additions & 4 deletions enginetest/queries.go
Original file line number Diff line number Diff line change
Expand Up @@ -287,7 +287,7 @@ var QueryTests = []QueryTest{
},
},
{
Query: "SELECT :foo * 2",
Query: "SELECT :foo * 2",
Expected: []sql.Row{
{2},
},
Expand All @@ -296,7 +296,7 @@ var QueryTests = []QueryTest{
},
},
{
Query: "SELECT i from mytable where i in (:foo, :bar) order by 1",
Query: "SELECT i from mytable where i in (:foo, :bar) order by 1",
Expected: []sql.Row{
{1},
{2},
Expand All @@ -307,7 +307,7 @@ var QueryTests = []QueryTest{
},
},
{
Query: "SELECT i from mytable where i = :foo * 2",
Query: "SELECT i from mytable where i = :foo * 2",
Expected: []sql.Row{
{2},
},
Expand All @@ -316,7 +316,7 @@ var QueryTests = []QueryTest{
},
},
{
Query: "SELECT i from mytable where 4 = :foo * 2 order by 1",
Query: "SELECT i from mytable where 4 = :foo * 2 order by 1",
Expected: []sql.Row{
{1},
{2},
Expand Down
20 changes: 18 additions & 2 deletions server/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ func (h *Handler) ComPrepare(c *mysql.Conn, query string) ([]*query.Field, error
}

func (h *Handler) ComStmtExecute(c *mysql.Conn, prepare *mysql.PrepareData, callback func(*sqltypes.Result) error) error {
return h.doQuery(c, prepare.PrepareStmt, prepare.BindVars, callback)
return h.errorWrappedDoQuery(c, prepare.PrepareStmt, prepare.BindVars, callback)
}

func (h *Handler) ComResetConnection(c *mysql.Conn) {
Expand Down Expand Up @@ -156,7 +156,7 @@ func (h *Handler) ComQuery(
query string,
callback func(*sqltypes.Result) error,
) error {
return h.doQuery(c, query, nil, callback)
return h.errorWrappedDoQuery(c, query, nil, callback)
}

func bindingsToExprs(bindings map[string]*query.BindVariable) (map[string]sql.Expression, error) {
Expand Down Expand Up @@ -455,6 +455,22 @@ rowLoop:
return callback(r)
}

// Call doQuery and cast known errors to SQLError
func (h *Handler) errorWrappedDoQuery(
c *mysql.Conn,
query string,
bindings map[string]*query.BindVariable,
callback func(*sqltypes.Result) error,
) error {
err := h.doQuery(c, query, bindings, callback)
err, ok := sql.CastSQLError(err)
if ok {
return nil
} else {
return err
}
}

// Periodically polls the connection socket to determine if it is has been closed by the client, sending an error on
// the supplied error channel if it has. Meant to be run in a separate goroutine from the query handler routine.
// Returns immediately on platforms that can't support TCP socket checks.
Expand Down
2 changes: 1 addition & 1 deletion server/handler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -268,7 +268,7 @@ func TestHandlerTimeout(t *testing.T) {
err := timeOutHandler.ComQuery(connTimeout, "SELECT SLEEP(2)", func(res *sqltypes.Result) error {
return nil
})
require.EqualError(err, "row read wait bigger than connection timeout")
require.EqualError(err, "row read wait bigger than connection timeout (errno 1105) (sqlstate HY000)")

err = timeOutHandler.ComQuery(connTimeout, "SELECT SLEEP(0.5)", func(res *sqltypes.Result) error {
return nil
Expand Down
23 changes: 22 additions & 1 deletion sql/errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,10 @@

package sql

import "gopkg.in/src-d/go-errors.v1"
import (
"github.com/dolthub/vitess/go/mysql"
"gopkg.in/src-d/go-errors.v1"
)

var (
// ErrInvalidType is thrown when there is an unexpected type at some part of
Expand Down Expand Up @@ -127,3 +130,21 @@ var (
// ErrTruncateReferencedFromForeignKey is returned when a table is referenced in a foreign key and TRUNCATE is called on it.
ErrTruncateReferencedFromForeignKey = errors.NewKind("cannot truncate table %s as it is referenced in foreign key %s on table %s")
)

func CastSQLError(err error) (*mysql.SQLError, bool) {
if err == nil {
return nil, true
}

var code int
var sqlState string = ""

switch {
case ErrTableNotFound.Is(err):
code = mysql.ERNoSuchTable
default:
code = mysql.ERUnknownError
}

return mysql.NewSQLError(code, sqlState, err.Error()), false
}
36 changes: 36 additions & 0 deletions sql/errors_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
package sql

import (
"fmt"
"github.com/dolthub/vitess/go/mysql"
"testing"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

func TestSQLErrorCast(t *testing.T) {

tests := []struct {
err error
code int
}{
{ErrTableNotFound.New("table not found err"), mysql.ERNoSuchTable},
{ErrInvalidType.New("unhandled mysql error"), mysql.ERUnknownError},
{fmt.Errorf("generic error"), mysql.ERUnknownError},
{nil, mysql.ERUnknownError},
}

for _, test := range tests {
var nilErr *mysql.SQLError = nil
t.Run(fmt.Sprintf("%v %v", test.err, test.code), func(t *testing.T) {
err, ok := CastSQLError(test.err)
if !ok {
require.Error(t, err)
assert.Equal(t, err.Number(), test.code)
} else {
assert.Equal(t, err, nilErr)
}
})
}
}