Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Updated retry behaviour #125

Merged
merged 3 commits into from
May 23, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
182 changes: 125 additions & 57 deletions internal/client/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,18 @@ package client

import (
"context"
"crypto/x509"
"encoding/json"
"fmt"
"log"
"math"
"net"
"net/http"
"net/http/httptrace"
"net/url"
"os"
"regexp"
"strconv"
"time"

dbsqlerr "github.com/databricks/databricks-sql-go/errors"
Expand All @@ -34,9 +37,31 @@ type ThriftServiceClient struct {
*cli_service.TCLIServiceClient
}

// contextKey is the unexported type used for keys that this package stores in
// a context.Context.
type contextKey int

// ClientMethod is the context key under which the ThriftServiceClient wrappers
// record the clientMethod they are about to invoke.
const ClientMethod contextKey = iota

// clientMethod enumerates the thrift operations wrapped by ThriftServiceClient.
type clientMethod int

// One value per wrapped thrift operation. The numeric values are assigned by
// iota in declaration order, starting at zero.
const (
	openSession clientMethod = iota
	closeSession
	fetchResults
	getResultSetMetadata
	executeStatement
	getOperationStatus
	closeOperation
	cancelOperation
)

// nonRetryableClientMethods is used as a set: only key presence matters.
// RetryPolicy looks the calling method up here before retrying on a
// server-side failure status.
var nonRetryableClientMethods = map[clientMethod]any{
	executeStatement: struct{}{},
}

// OpenSession is a wrapper around the thrift operation OpenSession
// If RecordResults is true, the results will be marshalled to JSON format and written to OpenSession<index>.json
func (tsc *ThriftServiceClient) OpenSession(ctx context.Context, req *cli_service.TOpenSessionReq) (*cli_service.TOpenSessionResp, error) {
ctx = context.WithValue(ctx, ClientMethod, openSession)
msg, start := logger.Track("OpenSession")
resp, err := tsc.TCLIServiceClient.OpenSession(ctx, req)
if err != nil {
Expand All @@ -55,6 +80,7 @@ func (tsc *ThriftServiceClient) OpenSession(ctx context.Context, req *cli_servic
// CloseSession is a wrapper around the thrift operation CloseSession
// If RecordResults is true, the results will be marshalled to JSON format and written to CloseSession<index>.json
func (tsc *ThriftServiceClient) CloseSession(ctx context.Context, req *cli_service.TCloseSessionReq) (*cli_service.TCloseSessionResp, error) {
ctx = context.WithValue(ctx, ClientMethod, closeSession)
log := logger.WithContext(driverctx.ConnIdFromContext(ctx), driverctx.CorrelationIdFromContext(ctx), "")
defer log.Duration(logger.Track("CloseSession"))
resp, err := tsc.TCLIServiceClient.CloseSession(ctx, req)
Expand All @@ -72,6 +98,7 @@ func (tsc *ThriftServiceClient) CloseSession(ctx context.Context, req *cli_servi
// FetchResults is a wrapper around the thrift operation FetchResults
// If RecordResults is true, the results will be marshalled to JSON format and written to FetchResults<index>.json
func (tsc *ThriftServiceClient) FetchResults(ctx context.Context, req *cli_service.TFetchResultsReq) (*cli_service.TFetchResultsResp, error) {
ctx = context.WithValue(ctx, ClientMethod, fetchResults)
log := logger.WithContext(driverctx.ConnIdFromContext(ctx), driverctx.CorrelationIdFromContext(ctx), SprintGuid(req.OperationHandle.OperationId.GUID))
defer log.Duration(logger.Track("FetchResults"))
resp, err := tsc.TCLIServiceClient.FetchResults(ctx, req)
Expand All @@ -89,6 +116,7 @@ func (tsc *ThriftServiceClient) FetchResults(ctx context.Context, req *cli_servi
// GetResultSetMetadata is a wrapper around the thrift operation GetResultSetMetadata
// If RecordResults is true, the results will be marshalled to JSON format and written to GetResultSetMetadata<index>.json
func (tsc *ThriftServiceClient) GetResultSetMetadata(ctx context.Context, req *cli_service.TGetResultSetMetadataReq) (*cli_service.TGetResultSetMetadataResp, error) {
ctx = context.WithValue(ctx, ClientMethod, getResultSetMetadata)
log := logger.WithContext(driverctx.ConnIdFromContext(ctx), driverctx.CorrelationIdFromContext(ctx), SprintGuid(req.OperationHandle.OperationId.GUID))
defer log.Duration(logger.Track("GetResultSetMetadata"))
resp, err := tsc.TCLIServiceClient.GetResultSetMetadata(ctx, req)
Expand All @@ -106,7 +134,10 @@ func (tsc *ThriftServiceClient) GetResultSetMetadata(ctx context.Context, req *c
// ExecuteStatement is a wrapper around the thrift operation ExecuteStatement
// If RecordResults is true, the results will be marshalled to JSON format and written to ExecuteStatement<index>.json
func (tsc *ThriftServiceClient) ExecuteStatement(ctx context.Context, req *cli_service.TExecuteStatementReq) (*cli_service.TExecuteStatementResp, error) {
ctx = context.WithValue(ctx, ClientMethod, executeStatement)
msg, start := logger.Track("ExecuteStatement")

// We use context.Background to fix a problem where on context done the query would not be cancelled.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

thank you!

resp, err := tsc.TCLIServiceClient.ExecuteStatement(context.Background(), req)
if err != nil {
return resp, dbsqlerrint.NewRequestError(ctx, "execute statement request error", err)
Expand All @@ -129,6 +160,7 @@ func (tsc *ThriftServiceClient) ExecuteStatement(ctx context.Context, req *cli_s
// GetOperationStatus is a wrapper around the thrift operation GetOperationStatus
// If RecordResults is true, the results will be marshalled to JSON format and written to GetOperationStatus<index>.json
func (tsc *ThriftServiceClient) GetOperationStatus(ctx context.Context, req *cli_service.TGetOperationStatusReq) (*cli_service.TGetOperationStatusResp, error) {
ctx = context.WithValue(ctx, ClientMethod, getOperationStatus)
log := logger.WithContext(driverctx.ConnIdFromContext(ctx), driverctx.CorrelationIdFromContext(ctx), SprintGuid(req.OperationHandle.OperationId.GUID))
defer log.Duration(logger.Track("GetOperationStatus"))
resp, err := tsc.TCLIServiceClient.GetOperationStatus(ctx, req)
Expand All @@ -146,6 +178,7 @@ func (tsc *ThriftServiceClient) GetOperationStatus(ctx context.Context, req *cli
// CloseOperation is a wrapper around the thrift operation CloseOperation
// If RecordResults is true, the results will be marshalled to JSON format and written to CloseOperation<index>.json
func (tsc *ThriftServiceClient) CloseOperation(ctx context.Context, req *cli_service.TCloseOperationReq) (*cli_service.TCloseOperationResp, error) {
ctx = context.WithValue(ctx, ClientMethod, closeOperation)
log := logger.WithContext(driverctx.ConnIdFromContext(ctx), driverctx.CorrelationIdFromContext(ctx), SprintGuid(req.OperationHandle.OperationId.GUID))
defer log.Duration(logger.Track("CloseOperation"))
resp, err := tsc.TCLIServiceClient.CloseOperation(ctx, req)
Expand All @@ -163,6 +196,7 @@ func (tsc *ThriftServiceClient) CloseOperation(ctx context.Context, req *cli_ser
// CancelOperation is a wrapper around the thrift operation CancelOperation
// If RecordResults is true, the results will be marshalled to JSON format and written to CancelOperation<index>.json
func (tsc *ThriftServiceClient) CancelOperation(ctx context.Context, req *cli_service.TCancelOperationReq) (*cli_service.TCancelOperationResp, error) {
ctx = context.WithValue(ctx, ClientMethod, cancelOperation)
log := logger.WithContext(driverctx.ConnIdFromContext(ctx), driverctx.CorrelationIdFromContext(ctx), SprintGuid(req.OperationHandle.OperationId.GUID))
defer log.Duration(logger.Track("CancelOperation"))
resp, err := tsc.TCLIServiceClient.CancelOperation(ctx, req)
Expand Down Expand Up @@ -277,15 +311,11 @@ func SprintGuid(bts []byte) string {
return fmt.Sprintf("%x", bts)
}

var retryableStatusCode = []int{http.StatusTooManyRequests, http.StatusServiceUnavailable}
var retryableStatusCodes = map[int]any{http.StatusTooManyRequests: struct{}{}, http.StatusServiceUnavailable: struct{}{}}

func isRetryable(statusCode int) bool {
for _, c := range retryableStatusCode {
if c == statusCode {
return true
}
}
return false
// isRetryableServerResponse reports whether the status code of resp is one of
// the keys of retryableStatusCodes. resp must be non-nil.
func isRetryableServerResponse(resp *http.Response) bool {
	if _, found := retryableStatusCodes[resp.StatusCode]; found {
		return true
	}
	return false
}

type Transport struct {
Expand Down Expand Up @@ -324,31 +354,8 @@ func (t *Transport) RoundTrip(req *http.Request) (*http.Response, error) {
// req.Body is assumed to be closed by the base RoundTripper.
reqBodyClosed = true
resp, err := t.Base.RoundTrip(req2)
if err != nil {
return nil, err
}
if resp.StatusCode != http.StatusOK {
reason := resp.Header.Get("X-Databricks-Reason-Phrase")
terrmsg := resp.Header.Get("X-Thriftserver-Error-Message")
if isRetryable(resp.StatusCode) {
if terrmsg != "" {
logger.Warn().Msg(terrmsg)
}
return resp, nil
}

if reason != "" {
logger.Err(fmt.Errorf(reason)).Msg("non retryable error")
return nil, errors.New(reason)
}
if terrmsg != "" {
logger.Err(fmt.Errorf(terrmsg)).Msg("non retryable error")
return nil, errors.New(terrmsg)
}
return nil, errors.New(resp.Status)
}

return resp, nil
return resp, err
}

func RetryableClient(cfg *config.Config) *http.Client {
Expand All @@ -361,7 +368,7 @@ func RetryableClient(cfg *config.Config) *http.Client {
RetryMax: cfg.RetryMax,
ErrorHandler: errorHandler,
CheckRetry: RetryPolicy,
Backoff: retryablehttp.DefaultBackoff,
Backoff: backoff,
}
return retryableClient.StandardClient()
}
Expand Down Expand Up @@ -431,59 +438,120 @@ func (l *leveledLogger) Warn(msg string, keysAndValues ...interface{}) {

func errorHandler(resp *http.Response, err error, numTries int) (*http.Response, error) {
var werr error
msg := fmt.Sprintf("request error after %d attempt(s)", numTries)
if err == nil {
err = errors.New(fmt.Sprintf("request error after %d attempt(s)", numTries))
werr = errors.New(msg)
} else {
werr = errors.Wrap(err, msg)
}

if resp != nil {
var orgid, reason, terrmsg, errmsg, retryAfter string
// TODO @mattdeekay: convert these to specific error types
if resp.Header != nil {
orgid = resp.Header.Get("X-Databricks-Org-Id")
reason = resp.Header.Get("X-Databricks-Reason-Phrase") // TODO note: shown on notebook
terrmsg = resp.Header.Get("X-Thriftserver-Error-Message")
errmsg = resp.Header.Get("x-databricks-error-or-redirect-message")
retryAfter = resp.Header.Get("Retry-After")
// TODO note: need to see if there's other headers
}
msg := fmt.Sprintf("orgId: %s, reason: %s, thriftErr: %s, err: %s", orgid, reason, terrmsg, errmsg)
reason := resp.Header.Get("X-Databricks-Reason-Phrase")
terrmsg := resp.Header.Get("X-Thriftserver-Error-Message")

if isRetryable(resp.StatusCode) {
err = dbsqlerrint.NewRetryableError(err, retryAfter)
if reason != "" {
werr = dbsqlerrint.WrapErr(werr, reason)
} else if terrmsg != "" {
werr = dbsqlerrint.WrapErr(werr, terrmsg)
}
}

werr = dbsqlerrint.WrapErr(err, msg)
} else {
werr = err
logger.Err(werr).Msg(resp.Status)
}

return resp, werr
}

func RetryPolicy(ctx context.Context, resp *http.Response, err error) (bool, error) {
var lostConn = regexp.MustCompile(`EOF`)
// net/http does not expose typed errors for the conditions below, so they are
// recognised by matching on the error text.

// redirectsErrorRe matches the error reported once the configured number of
// redirects has been exhausted.
var redirectsErrorRe = regexp.MustCompile(`stopped after \d+ redirects\z`)

// schemeErrorRe matches the error reported when the URL scheme is invalid.
var schemeErrorRe = regexp.MustCompile(`unsupported protocol scheme`)

// notTrustedErrorRe matches the error reported when the TLS certificate is
// not trusted.
var notTrustedErrorRe = regexp.MustCompile(`certificate is not trusted`)

// errorRes collects the patterns above so that callers can test an error
// string against all of them in one loop.
var errorRes = []*regexp.Regexp{
	redirectsErrorRe,
	schemeErrorRe,
	notTrustedErrorRe,
}

func RetryPolicy(ctx context.Context, resp *http.Response, err error) (bool, error) {
// do not retry on context.Canceled or context.DeadlineExceeded
if ctx.Err() != nil {
return false, ctx.Err()
}

if err != nil {
if v, ok := err.(*url.Error); ok {
if lostConn.MatchString(v.Error()) {
return true, v
s := v.Error()
for _, re := range errorRes {
if re.MatchString(s) {
return false, v
}
}

if _, ok := v.Err.(x509.UnknownAuthorityError); ok {
return false, v
}
}
return false, nil

// The error is likely recoverable so retry.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should probably not retry ExecuteStatement when the network connection has been severed.

return true, nil
}

var checkErr error
if resp.StatusCode != http.StatusOK {
checkErr = fmt.Errorf("unexpected HTTP status %s", resp.Status)
}

// 429 Too Many Requests and 503 Service Unavailable are recoverable. The server
// sometimes sets a Retry-After response header to indicate when it will be
// ready to process requests from the client again.
if isRetryable(resp.StatusCode) {
return true, nil
if isRetryableServerResponse(resp) {
var retryAfter string
if resp.Header != nil {
retryAfter = resp.Header.Get("Retry-After")
}

return true, dbsqlerrint.NewRetryableError(checkErr, retryAfter)
}

return false, nil
if resp.StatusCode == 0 || (resp.StatusCode >= 500 && resp.StatusCode != http.StatusNotImplemented) {
callerAny := ctx.Value(ClientMethod)
if caller, ok := callerAny.(clientMethod); ok {
if _, noRetry := nonRetryableClientMethods[caller]; !noRetry {
return true, checkErr
}
}
}

// checkErr will be non-nil if the response code was not StatusOK.
// Returning it here ensures that the error handler will be called.
return false, checkErr
}

// backoff returns how long to wait before the next retry attempt. It is
// installed as the Backoff callback of the retryable HTTP client.
//
// If resp carries a Retry-After header whose value is a non-negative whole
// number of seconds, that delay is returned as-is and is not clamped to max.
// A Retry-After value that is missing, empty, non-numeric, negative, or too
// large to express as a time.Duration is ignored.
//
// Otherwise the delay is min * 2^attemptNum, capped at max.
func backoff(min, max time.Duration, attemptNum int, resp *http.Response) time.Duration {
	// Honour the Retry-After header. Header.Get is safe on a nil Header and on
	// a key that is present with no values, where indexing s[0] would panic.
	if resp != nil {
		if s := resp.Header.Get("Retry-After"); s != "" {
			// Largest number of whole seconds that fits in a time.Duration.
			const maxSeconds = int64(math.MaxInt64 / time.Second)

			// Reject negative values (they would mean "retry immediately")
			// and values whose conversion to nanoseconds would overflow.
			if secs, err := strconv.ParseInt(s, 10, 64); err == nil && secs >= 0 && secs <= maxSeconds {
				return time.Duration(secs) * time.Second
			}
		}
	}

	// Exponential backoff.
	mult := math.Pow(2, float64(attemptNum)) * float64(min)
	sleep := time.Duration(mult)
	// The round-trip comparison catches a float64 -> Duration conversion that
	// did not preserve the value (e.g. overflow); fall back to max then.
	if float64(sleep) != mult || sleep > max {
		sleep = max
	}
	return sleep
}