Skip to content

Commit

Permalink
Implemented IsRetryable() and RetryAfter() for DatabricksError (#119)
Browse files Browse the repository at this point in the history
Added IsRetryable and RetryAfter functions to DBError interface.
Added an internal error type for retryable errors.
Updated client to insert a retryable error instance into the error
chain.
  • Loading branch information
rcypher-databricks committed Apr 20, 2023
2 parents 36b12cd + 014b68e commit 350ea35
Show file tree
Hide file tree
Showing 12 changed files with 260 additions and 55 deletions.
2 changes: 1 addition & 1 deletion connector.go
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ func (c *connector) Connect(ctx context.Context) (driver.Conn, error) {
setStmt := fmt.Sprintf("SET `%s` = `%s`;", k, v)
_, err := conn.ExecContext(ctx, setStmt, []driver.NamedValue{})
if err != nil {
return nil, err
return nil, dbsqlerrint.NewExecutionError(ctx, fmt.Sprintf("error setting session param: %s", setStmt), err, nil)
}
log.Info().Msgf("set session parameter: param=%s value=%s", k, v)
}
Expand Down
7 changes: 5 additions & 2 deletions doc.go
Original file line number Diff line number Diff line change
Expand Up @@ -175,12 +175,15 @@ Example usage:
if errors.Is(err, dbsqlerr.ExecutionError) {
var execErr dbsqlerr.DBExecutionError
if ok := errors.As(err, &execError); ok {
fmt.Printf("%s, corrId: %s, connId: %s, queryId: %s, sqlState: %s",
fmt.Printf("%s, corrId: %s, connId: %s, queryId: %s, sqlState: %s, isRetryable: %t, retryAfter: %f seconds",
execErr.Error(),
execErr.CorrelationId(),
execErr.ConnectionId(),
execErr.QueryId(),
execErr.SqlState())
execErr.SqlState(),
execErr.IsRetryable(),
execErr.RetryAfter().Seconds(),
)
}
}
...
Expand Down
68 changes: 68 additions & 0 deletions driver_e2e_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -421,6 +421,74 @@ func TestRetries(t *testing.T) {
require.ErrorContains(t, err, "after 1 attempt(s)")
})

t.Run("a 429 or 503 should result in a retryable error", func(t *testing.T) {

_ = logger.SetLogLevel("debug")
state := &callState{}
// load basic responses
loadTestData(t, "OpenSessionSuccess.json", &state.openSessionResp)
loadTestData(t, "CloseSessionSuccess.json", &state.closeSessionResp)
loadTestData(t, "CloseOperationSuccess.json", &state.closeOperationResp)

ts := getServer(state)

defer ts.Close()
r, err := url.Parse(ts.URL)
require.NoError(t, err)
port, err := strconv.Atoi(r.Port())
require.NoError(t, err)

connector, err := NewConnector(
WithServerHostname("localhost"),
WithHTTPPath("/429-5-retries"),
WithPort(port),
WithRetries(2, 10*time.Millisecond, 1*time.Second),
)
require.NoError(t, err)
db := sql.OpenDB(connector)
defer db.Close()

state.executeStatementResp = cli_service.TExecuteStatementResp{}
loadTestData(t, "ExecuteStatement1.json", &state.executeStatementResp)

err = db.Ping()
require.ErrorContains(t, err, "after 3 attempt(s)")

// The error chain should contain a databricks request error
b := errors.Is(err, dbsqlerr.RequestError)
require.True(t, b)
var re dbsqlerr.DBRequestError
b = errors.As(err, &re)
require.True(t, b)
require.NotNil(t, re)
require.True(t, re.IsRetryable())
require.Equal(t, 12*time.Second, re.RetryAfter())

connector2, err := NewConnector(
WithServerHostname("localhost"),
WithHTTPPath("/503-5-retries"),
WithPort(port),
WithRetries(2, 10*time.Millisecond, 1*time.Second),
)
require.NoError(t, err)
db2 := sql.OpenDB(connector2)
defer db.Close()

state.executeStatementResp = cli_service.TExecuteStatementResp{}
loadTestData(t, "ExecuteStatement1.json", &state.executeStatementResp)

err = db2.Ping()
require.ErrorContains(t, err, "after 3 attempt(s)")

// The error chain should contain a databricks request error
b = errors.Is(err, dbsqlerr.RequestError)
require.True(t, b)
b = errors.As(err, &re)
require.True(t, b)
require.NotNil(t, re)
require.True(t, re.IsRetryable())
})

}

// TODO: add tests for x-databricks headers
Expand Down
12 changes: 9 additions & 3 deletions errors/errors.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
package errors

import "github.com/pkg/errors"
import (
"time"

"github.com/pkg/errors"
)

// Error messages
const (
Expand Down Expand Up @@ -59,6 +63,10 @@ type DBError interface {

// Underlying causative error. May be nil.
Cause() error

IsRetryable() bool

RetryAfter() time.Duration
}

// An error that is caused by an invalid request.
Expand All @@ -70,8 +78,6 @@ type DBRequestError interface {
// A fault that is caused by Databricks services
type DBDriverError interface {
DBError

IsRetryable() bool
}

// Any error that occurs after the SQL statement has been accepted (e.g. SQL syntax error).
Expand Down
41 changes: 38 additions & 3 deletions examples/error/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -71,12 +71,34 @@ func main() {
}
ctx = driverctx.NewContextWithQueryIdCallback(ctx, queryIdCallback)

rows, err1 := db.QueryContext(ctx, `select * from default.intervals`)
fmt.Printf("conn Id: %s, query Id: %s\n", connId, queryId)
var rows *sql.Rows
maxRetries := 3
shouldTry := true

// We want to retry running the query if an error is returned where IsRetryable() is true up
// to the maximum number of retries.
for i := 0; i < maxRetries && shouldTry; i++ {
var err1 error
var wait time.Duration

rows, err1 = db.QueryContext(ctx, `select * from default.Intervals`)

// Check if the error is retryable and if there is a wait before
// trying again.
if shouldTry, wait = isRetryable(err1); shouldTry {
fmt.Printf("query failed, retrying after %f seconds", wait.Seconds())
time.Sleep(wait)
} else {
// handle the error, which may be nil
handleErr(err1)
}
}

handleErr(err1)
// At this point the query completed successfully
defer rows.Close()

fmt.Printf("conn Id: %s, query Id: %s\n", connId, queryId)

colNames, _ := rows.Columns()
for i := range colNames {
fmt.Printf("%d: %s\n", i, colNames[i])
Expand All @@ -91,6 +113,8 @@ func main() {

}

// If the error is not nil extract/ databricks specific error information and then
// terminate the program.
func handleErr(err error) {
if err == nil {
return
Expand Down Expand Up @@ -155,3 +179,14 @@ func getQueryIdAndSQLState(err error) (queryId, sqlState string) {

return
}

// Use errors.As to extract a DBError from the error chain and return the associated
// values for isRetryable and retryAfter
func isRetryable(err error) (isRetryable bool, retryAfter time.Duration) {
var dbErr dbsqlerr.DBError
if errors.As(err, &dbErr) {
isRetryable = dbErr.IsRetryable()
retryAfter = dbErr.RetryAfter()
}
return
}
49 changes: 31 additions & 18 deletions internal/client/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -279,6 +279,15 @@ func SprintGuid(bts []byte) string {

var retryableStatusCode = []int{http.StatusTooManyRequests, http.StatusServiceUnavailable}

func isRetryable(statusCode int) bool {
for _, c := range retryableStatusCode {
if c == statusCode {
return true
}
}
return false
}

type Transport struct {
Base *http.Transport
Authr auth.Authenticator
Expand Down Expand Up @@ -321,14 +330,13 @@ func (t *Transport) RoundTrip(req *http.Request) (*http.Response, error) {
if resp.StatusCode != http.StatusOK {
reason := resp.Header.Get("X-Databricks-Reason-Phrase")
terrmsg := resp.Header.Get("X-Thriftserver-Error-Message")
for _, c := range retryableStatusCode {
if c == resp.StatusCode {
if terrmsg != "" {
logger.Warn().Msg(terrmsg)
}
return resp, nil
if isRetryable(resp.StatusCode) {
if terrmsg != "" {
logger.Warn().Msg(terrmsg)
}
return resp, nil
}

if reason != "" {
logger.Err(fmt.Errorf(reason)).Msg("non retryable error")
return nil, errors.New(reason)
Expand Down Expand Up @@ -426,17 +434,25 @@ func errorHandler(resp *http.Response, err error, numTries int) (*http.Response,
if err == nil {
err = errors.New(fmt.Sprintf("request error after %d attempt(s)", numTries))
}
if resp != nil && resp.Header != nil {

if resp != nil {
var orgid, reason, terrmsg, errmsg, retryAfter string
// TODO @mattdeekay: convert these to specific error types
if resp.Header != nil {
orgid = resp.Header.Get("X-Databricks-Org-Id")
reason = resp.Header.Get("X-Databricks-Reason-Phrase") // TODO note: shown on notebook
terrmsg = resp.Header.Get("X-Thriftserver-Error-Message")
errmsg = resp.Header.Get("x-databricks-error-or-redirect-message")
retryAfter = resp.Header.Get("Retry-After")
// TODO note: need to see if there's other headers
}
msg := fmt.Sprintf("orgId: %s, reason: %s, thriftErr: %s, err: %s", orgid, reason, terrmsg, errmsg)

orgid := resp.Header.Get("X-Databricks-Org-Id")
reason := resp.Header.Get("X-Databricks-Reason-Phrase") // TODO note: shown on notebook
terrmsg := resp.Header.Get("X-Thriftserver-Error-Message")
errmsg := resp.Header.Get("x-databricks-error-or-redirect-message")
// TODO note: need to see if there's other headers
if isRetryable(resp.StatusCode) {
err = dbsqlerrint.NewRetryableError(err, retryAfter)
}

werr = errors.Wrapf(err, fmt.Sprintf("orgId: %s, reason: %s, thriftErr: %s, err: %s", orgid, reason, terrmsg, errmsg))
werr = dbsqlerrint.WrapErr(err, msg)
} else {
werr = err
}
Expand Down Expand Up @@ -464,11 +480,8 @@ func RetryPolicy(ctx context.Context, resp *http.Response, err error) (bool, err
// 429 Too Many Requests or 503 service unavailable is recoverable. Sometimes the server puts
// a Retry-After response header to indicate when the server is
// available to start processing request from client.

for _, c := range retryableStatusCode {
if c == resp.StatusCode {
return true, nil
}
if isRetryable(resp.StatusCode) {
return true, nil
}

return false, nil
Expand Down
2 changes: 1 addition & 1 deletion internal/config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,12 @@ import (
"strings"
"time"

dbsqlerr "github.com/databricks/databricks-sql-go/errors"
"github.com/pkg/errors"

"github.com/databricks/databricks-sql-go/auth"
"github.com/databricks/databricks-sql-go/auth/noop"
"github.com/databricks/databricks-sql-go/auth/pat"
dbsqlerr "github.com/databricks/databricks-sql-go/errors"
"github.com/databricks/databricks-sql-go/internal/cli_service"
dbsqlerrint "github.com/databricks/databricks-sql-go/internal/errors"
"github.com/databricks/databricks-sql-go/logger"
Expand Down

0 comments on commit 350ea35

Please sign in to comment.