Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implementing error types #117

Merged
merged 8 commits into from
Apr 17, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
45 changes: 26 additions & 19 deletions connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,11 @@ import (
"time"

"github.com/databricks/databricks-sql-go/driverctx"
dbsqlerr "github.com/databricks/databricks-sql-go/errors"
"github.com/databricks/databricks-sql-go/internal/cli_service"
"github.com/databricks/databricks-sql-go/internal/client"
"github.com/databricks/databricks-sql-go/internal/config"
dbsqlerr "github.com/databricks/databricks-sql-go/internal/err"
dbsqlerrint "github.com/databricks/databricks-sql-go/internal/errors"
"github.com/databricks/databricks-sql-go/internal/rows"
"github.com/databricks/databricks-sql-go/internal/sentinel"
"github.com/databricks/databricks-sql-go/logger"
Expand Down Expand Up @@ -46,19 +47,19 @@ func (c *conn) Close() error {

if err != nil {
log.Err(err).Msg("databricks: failed to close connection")
return dbsqlerr.WrapErr(err, "failed to close connection")
return dbsqlerrint.NewRequestError(ctx, dbsqlerr.ErrCloseConnection, err)
}
return nil
}

// Not supported in Databricks.
func (c *conn) Begin() (driver.Tx, error) {
return nil, errors.New(dbsqlerr.ErrTransactionsNotSupported)
return nil, dbsqlerrint.NewDriverError(context.TODO(), dbsqlerr.ErrNotImplemented, nil)
}

// Not supported in Databricks.
func (c *conn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, error) {
return nil, errors.New(dbsqlerr.ErrTransactionsNotSupported)
return nil, dbsqlerrint.NewDriverError(context.TODO(), dbsqlerr.ErrNotImplemented, nil)
}

// Ping attempts to verify that the server is accessible.
Expand Down Expand Up @@ -100,7 +101,7 @@ func (c *conn) ExecContext(ctx context.Context, query string, args []driver.Name

ctx = driverctx.NewContextWithConnId(ctx, c.id)
if len(args) > 0 {
return nil, errors.New(dbsqlerr.ErrParametersNotSupported)
return nil, dbsqlerrint.NewDriverError(ctx, dbsqlerr.ErrParametersNotSupported, nil)
}
exStmtResp, opStatusResp, err := c.runQuery(ctx, query, args)

Expand All @@ -122,7 +123,7 @@ func (c *conn) ExecContext(ctx context.Context, query string, args []driver.Name
}
if err != nil {
log.Err(err).Msgf("databricks: failed to execute query: query %s", query)
return nil, dbsqlerr.WrapErrf(err, "failed to execute query")
return nil, dbsqlerrint.NewExecutionError(ctx, dbsqlerr.ErrQueryExecution, err, opStatusResp)
}

res := result{AffectedRows: opStatusResp.GetNumModifiedRows()}
Expand All @@ -142,20 +143,21 @@ func (c *conn) QueryContext(ctx context.Context, query string, args []driver.Nam

ctx = driverctx.NewContextWithConnId(ctx, c.id)
if len(args) > 0 {
return nil, errors.New(dbsqlerr.ErrParametersNotSupported)
return nil, dbsqlerrint.NewDriverError(ctx, dbsqlerr.ErrParametersNotSupported, nil)
}
// first we try to get the results synchronously.
// at any point in time that the context is done we must cancel and return
exStmtResp, _, err := c.runQuery(ctx, query, args)
exStmtResp, opStatusResp, err := c.runQuery(ctx, query, args)

if exStmtResp != nil && exStmtResp.OperationHandle != nil {
ctx = driverctx.NewContextWithQueryId(ctx, client.SprintGuid(exStmtResp.OperationHandle.OperationId.GUID))
log = logger.WithContext(c.id, driverctx.CorrelationIdFromContext(ctx), client.SprintGuid(exStmtResp.OperationHandle.OperationId.GUID))
}
defer log.Duration(msg, start)

if err != nil {
log.Err(err).Msg("databricks: failed to run query") // To log query we need to redact credentials
return nil, dbsqlerr.WrapErrf(err, "failed to run query")
return nil, dbsqlerrint.NewExecutionError(ctx, dbsqlerr.ErrQueryExecution, err, opStatusResp)
}
// hold on to the operation handle
opHandle := exStmtResp.OperationHandle
Expand All @@ -177,9 +179,10 @@ func (c *conn) runQuery(ctx context.Context, query string, args []driver.NamedVa
}
opHandle := exStmtResp.OperationHandle
if opHandle != nil && opHandle.OperationId != nil {
ctx = driverctx.NewContextWithQueryId(ctx, client.SprintGuid(opHandle.OperationId.GUID))
log = logger.WithContext(
c.id,
driverctx.CorrelationIdFromContext(ctx), client.SprintGuid(opHandle.OperationId.GUID),
driverctx.CorrelationIdFromContext(ctx), driverctx.QueryIdFromContext(ctx),
)
}

Expand Down Expand Up @@ -217,16 +220,16 @@ func (c *conn) runQuery(ctx context.Context, query string, args []driver.NamedVa
cli_service.TOperationState_ERROR_STATE,
cli_service.TOperationState_TIMEDOUT_STATE:
logBadQueryState(log, statusResp)
return exStmtResp, statusResp, errors.New(statusResp.GetDisplayMessage())
return exStmtResp, statusResp, dbsqlerrint.NewRequestError(ctx, dbsqlerr.ErrInvalidOperationState, nil)
// live states
default:
logBadQueryState(log, statusResp)
return exStmtResp, statusResp, errors.New("invalid operation state. This should not have happened")
return exStmtResp, statusResp, dbsqlerrint.NewDriverError(ctx, dbsqlerr.ErrInvalidOperationState, nil)
}
// weird states
default:
logBadQueryState(log, opStatus)
return exStmtResp, opStatus, errors.New("invalid operation state. This should not have happened")
return exStmtResp, opStatus, dbsqlerrint.NewDriverError(ctx, dbsqlerr.ErrInvalidOperationState, nil)
}

} else {
Expand All @@ -245,11 +248,11 @@ func (c *conn) runQuery(ctx context.Context, query string, args []driver.NamedVa
cli_service.TOperationState_ERROR_STATE,
cli_service.TOperationState_TIMEDOUT_STATE:
logBadQueryState(log, statusResp)
return exStmtResp, statusResp, errors.New(statusResp.GetDisplayMessage())
return exStmtResp, statusResp, dbsqlerrint.NewDriverError(ctx, dbsqlerr.ErrInvalidOperationState, nil)
// live states
default:
logBadQueryState(log, statusResp)
return exStmtResp, statusResp, errors.New("invalid operation state. This should not have happened")
return exStmtResp, statusResp, dbsqlerrint.NewDriverError(ctx, dbsqlerr.ErrInvalidOperationState, nil)
}
}
}
Expand Down Expand Up @@ -311,7 +314,6 @@ func (c *conn) executeStatement(ctx context.Context, query string, args []driver
} else {
log.Debug().Msgf("databricks: cancel success")
}

} else {
log.Debug().Msg("databricks: query did not need cancellation")
}
Expand All @@ -337,6 +339,7 @@ func (c *conn) pollOperation(ctx context.Context, opHandle *cli_service.TOperati
statusResp, err = c.client.GetOperationStatus(newCtx, &cli_service.TGetOperationStatusReq{
OperationHandle: opHandle,
})

if statusResp != nil && statusResp.OperationState != nil {
log.Debug().Msgf("databricks: status %s", statusResp.GetOperationState().String())
}
Expand All @@ -363,13 +366,17 @@ func (c *conn) pollOperation(ctx context.Context, opHandle *cli_service.TOperati
return ret, err
},
}
_, resp, err := pollSentinel.Watch(ctx, c.cfg.PollInterval, 0)
status, resp, err := pollSentinel.Watch(ctx, c.cfg.PollInterval, 0)
if err != nil {
return nil, dbsqlerr.WrapErr(err, "failed to poll query state")
if status == sentinel.WatchTimeout {
err = dbsqlerrint.NewRequestError(ctx, dbsqlerr.ErrSentinelTimeout, err)
}
return nil, err
}

statusResp, ok := resp.(*cli_service.TGetOperationStatusResp)
if !ok {
return nil, errors.New("could not read query status")
return nil, dbsqlerrint.NewDriverError(ctx, dbsqlerr.ErrReadQueryStatus, nil)
}
return statusResp, nil
}
Expand Down
3 changes: 2 additions & 1 deletion connection_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,7 @@ func TestConn_executeStatement(t *testing.T) {
if opTest.err == "" {
assert.NoError(t, err)
} else {
assert.EqualError(t, err, opTest.err)
assert.EqualError(t, err, "databricks: execution error: failed to execute query: "+opTest.err)
}
assert.Equal(t, 1, executeStatementCount)
assert.Equal(t, opTest.closeOperationCount, closeOperationCount)
Expand Down Expand Up @@ -539,6 +539,7 @@ func TestConn_pollOperation(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond)
defer cancel()
res, err := testConn.pollOperation(ctx, &cli_service.TOperationHandle{

OperationId: &cli_service.THandleIdentifier{
GUID: []byte{1, 2, 3, 4, 2, 23, 4, 2, 3, 1, 2, 4, 4, 223, 34, 54},
Secret: []byte("b"),
Expand Down
7 changes: 4 additions & 3 deletions connector.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,11 @@ import (

"github.com/databricks/databricks-sql-go/auth/pat"
"github.com/databricks/databricks-sql-go/driverctx"
dbsqlerr "github.com/databricks/databricks-sql-go/errors"
"github.com/databricks/databricks-sql-go/internal/cli_service"
"github.com/databricks/databricks-sql-go/internal/client"
"github.com/databricks/databricks-sql-go/internal/config"
dbsqlerr "github.com/databricks/databricks-sql-go/internal/err"
dbsqlerrint "github.com/databricks/databricks-sql-go/internal/errors"
"github.com/databricks/databricks-sql-go/logger"
)

Expand All @@ -35,7 +36,7 @@ func (c *connector) Connect(ctx context.Context) (driver.Conn, error) {

tclient, err := client.InitThriftClient(c.cfg, c.client)
if err != nil {
return nil, dbsqlerr.WrapErr(err, "error initializing thrift client")
return nil, dbsqlerrint.NewDriverError(ctx, dbsqlerr.ErrThriftClient, err)
}
protocolVersion := int64(c.cfg.ThriftProtocolVersion)
session, err := tclient.OpenSession(ctx, &cli_service.TOpenSessionReq{
Expand All @@ -49,7 +50,7 @@ func (c *connector) Connect(ctx context.Context) (driver.Conn, error) {
})

if err != nil {
return nil, dbsqlerr.WrapErrf(err, "error connecting: host=%s port=%d, httpPath=%s", c.cfg.Host, c.cfg.Port, c.cfg.HTTPPath)
return nil, dbsqlerrint.NewRequestError(ctx, fmt.Sprintf("error connecting: host=%s port=%d, httpPath=%s", c.cfg.Host, c.cfg.Port, c.cfg.HTTPPath), err)
}

conn := &conn{
Expand Down
46 changes: 46 additions & 0 deletions doc.go
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,52 @@ The result log may look like this:

{"level":"debug","connId":"01ed6545-5669-1ec7-8c7e-6d8a1ea0ab16","corrId":"workflow-example","queryId":"01ed6545-57cc-188a-bfc5-d9c0eaf8e189","time":1668558402,"message":"Run Main elapsed time: 1.298712292s"}

# Errors

There are three error types exposed via dbsql/errors

DBDriverError - An error in the go driver. Example: unimplemented functionality, invalid driver state, errors processing a server response, etc.

DBRequestError - An error that is caused by an invalid request. Example: permission denied, invalid http path or other connection parameter, resource not available, etc.

DBExecutionError - Any error that occurs after the SQL statement has been accepted such as a SQL syntax error, missing table, etc.

Each type has a corresponding sentinel value which can be used with errors.Is() to determine if one of the types is present in an error chain.

DriverError
RequestError
ExecutionError

Example usage:

import (
fmt
errors
dbsqlerr "github.com/databricks/databricks-sql-go/errors"
)

func main() {
...
_, err := db.ExecContext(ogCtx, `Select id from range(100)`)
if err != nil {
if errors.Is(err, dbsqlerr.ExecutionError) {
var execErr dbsqlerr.DBExecutionError
if ok := errors.As(err, &execError); ok {
fmt.Printf("%s, corrId: %s, connId: %s, queryId: %s, sqlState: %s",
execErr.Error(),
execErr.CorrelationId(),
execErr.ConnectionId(),
execErr.QueryId(),
execErr.SqlState())
}
}
...
}
...
}

See the documentation for dbsql/errors for more information.

# Supported Data Types

==================================
Expand Down
17 changes: 13 additions & 4 deletions driver_e2e_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,11 @@ import (
"time"

"github.com/databricks/databricks-sql-go/driverctx"
dbsqlerr "github.com/databricks/databricks-sql-go/errors"
"github.com/databricks/databricks-sql-go/internal/cli_service"
"github.com/databricks/databricks-sql-go/internal/client"
dbsqlerr "github.com/databricks/databricks-sql-go/internal/err"
"github.com/databricks/databricks-sql-go/logger"
"github.com/pkg/errors"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
Expand Down Expand Up @@ -281,11 +282,19 @@ func TestContextTimeoutExample(t *testing.T) {
ctx1, cancel := context.WithTimeout(ogCtx, 5*time.Second)
defer cancel()
rows, err := db.QueryContext(ctx1, `SELECT id FROM RANGE(100000000) ORDER BY RANDOM() + 2 asc`)
require.ErrorContains(t, err, context.DeadlineExceeded.Error())
if err, ok := err.(interface{ StackTrace() errors.StackTrace }); ok {
fmt.Printf("Stack trace: %v", err.StackTrace())
}
require.True(t, errors.Is(err, context.DeadlineExceeded))
require.True(t, errors.Is(err, dbsqlerr.ExecutionError))
var ee dbsqlerr.DBExecutionError
require.True(t, errors.As(err, &ee))
require.Equal(t, "context-timeout-example", ee.CorrelationId())
require.Nil(t, rows)
_, ok := err.(dbsqlerr.Causer)

_, ok := err.(interface{ Cause() error })
assert.True(t, ok)
_, ok = err.(dbsqlerr.StackTracer)
_, ok = err.(interface{ StackTrace() errors.StackTrace })
assert.True(t, ok)
assert.Equal(t, 1, state.executeStatementCalls)
assert.GreaterOrEqual(t, state.getOperationStatusCalls, 1)
Expand Down
46 changes: 46 additions & 0 deletions driverctx/ctx.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,15 +11,24 @@ type contextKey int
const (
CorrelationIdContextKey contextKey = iota
ConnIdContextKey
QueryIdContextKey
QueryIdCallbackKey
ConnIdCallbackKey
)

type IdCallbackFunc func(string)

// NewContextWithCorrelationId creates a new context with correlationId value. Used by Logger to populate field corrId.
func NewContextWithCorrelationId(ctx context.Context, correlationId string) context.Context {
return context.WithValue(ctx, CorrelationIdContextKey, correlationId)
}

// CorrelationIdFromContext retrieves the correlationId stored in context.
func CorrelationIdFromContext(ctx context.Context) string {
if ctx == nil {
return ""
}

corrId, ok := ctx.Value(CorrelationIdContextKey).(string)
if !ok {
return ""
Expand All @@ -29,14 +38,51 @@ func CorrelationIdFromContext(ctx context.Context) string {

// NewContextWithConnId creates a new context with connectionId value.
func NewContextWithConnId(ctx context.Context, connId string) context.Context {
if callback, ok := ctx.Value(ConnIdCallbackKey).(IdCallbackFunc); ok {
callback(connId)
}
return context.WithValue(ctx, ConnIdContextKey, connId)
}

// ConnIdFromContext retrieves the connectionId stored in context.
func ConnIdFromContext(ctx context.Context) string {
if ctx == nil {
return ""
}

connId, ok := ctx.Value(ConnIdContextKey).(string)
if !ok {
return ""
}
return connId
}

// NewContextWithQueryId creates a new context with queryId value.
func NewContextWithQueryId(ctx context.Context, queryId string) context.Context {
if callback, ok := ctx.Value(QueryIdCallbackKey).(IdCallbackFunc); ok {
callback(queryId)
}

return context.WithValue(ctx, QueryIdContextKey, queryId)
}

// QueryIdFromContext retrieves the queryId stored in context.
func QueryIdFromContext(ctx context.Context) string {
if ctx == nil {
return ""
}

queryId, ok := ctx.Value(QueryIdContextKey).(string)
if !ok {
return ""
}
return queryId
}

func NewContextWithQueryIdCallback(ctx context.Context, callback IdCallbackFunc) context.Context {
return context.WithValue(ctx, QueryIdCallbackKey, callback)
}

func NewContextWithConnIdCallback(ctx context.Context, callback IdCallbackFunc) context.Context {
return context.WithValue(ctx, ConnIdCallbackKey, callback)
}