-
Notifications
You must be signed in to change notification settings - Fork 35
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Better handling of bad connection errors and specifying server protocol. #152
Merged
Merged
Changes from all commits
Commits
File filter
Filter by extension
Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -14,6 +14,7 @@ import ( | |
"os" | ||
"regexp" | ||
"strconv" | ||
"strings" | ||
"time" | ||
|
||
dbsqlerr "github.com/databricks/databricks-sql-go/errors" | ||
|
@@ -45,116 +46,121 @@ const ( | |
|
||
type clientMethod int | ||
|
||
//go:generate go run golang.org/x/tools/cmd/stringer -type=clientMethod | ||
//go:generate go run golang.org/x/tools/cmd/stringer -type=clientMethod -trimprefix=clientMethod | ||
|
||
const ( | ||
unknown clientMethod = iota | ||
openSession | ||
closeSession | ||
fetchResults | ||
getResultSetMetadata | ||
executeStatement | ||
getOperationStatus | ||
closeOperation | ||
cancelOperation | ||
clientMethodUnknown clientMethod = iota | ||
clientMethodOpenSession | ||
clientMethodCloseSession | ||
clientMethodFetchResults | ||
clientMethodGetResultSetMetadata | ||
clientMethodExecuteStatement | ||
clientMethodGetOperationStatus | ||
clientMethodCloseOperation | ||
clientMethodCancelOperation | ||
) | ||
|
||
var nonRetryableClientMethods map[clientMethod]any = map[clientMethod]any{ | ||
executeStatement: struct{}{}, | ||
unknown: struct{}{}} | ||
clientMethodExecuteStatement: struct{}{}, | ||
clientMethodUnknown: struct{}{}, | ||
} | ||
|
||
var clientMethodRequestErrorMsgs map[clientMethod]string = map[clientMethod]string{ | ||
clientMethodOpenSession: "open session request error", | ||
clientMethodCloseSession: "close session request error", | ||
clientMethodFetchResults: "fetch results request error", | ||
clientMethodGetResultSetMetadata: "get result set metadata request error", | ||
clientMethodExecuteStatement: "execute statement request error", | ||
clientMethodGetOperationStatus: "get operation status request error", | ||
clientMethodCloseOperation: "close operation request error", | ||
clientMethodCancelOperation: "cancel operation request error", | ||
} | ||
|
||
// OpenSession is a wrapper around the thrift operation OpenSession | ||
// If RecordResults is true, the results will be marshalled to JSON format and written to OpenSession<index>.json | ||
func (tsc *ThriftServiceClient) OpenSession(ctx context.Context, req *cli_service.TOpenSessionReq) (*cli_service.TOpenSessionResp, error) { | ||
ctx = context.WithValue(ctx, ClientMethod, openSession) | ||
ctx = context.WithValue(ctx, ClientMethod, clientMethodOpenSession) | ||
msg, start := logger.Track("OpenSession") | ||
resp, err := tsc.TCLIServiceClient.OpenSession(ctx, req) | ||
if err != nil { | ||
return nil, dbsqlerrint.NewRequestError(ctx, "open session request error", err) | ||
err = handleClientMethodError(ctx, err) | ||
return resp, err | ||
} | ||
|
||
recordResult(ctx, resp) | ||
|
||
log := logger.WithContext(SprintGuid(resp.SessionHandle.SessionId.GUID), driverctx.CorrelationIdFromContext(ctx), "") | ||
defer log.Duration(msg, start) | ||
if RecordResults { | ||
j, _ := json.MarshalIndent(resp, "", " ") | ||
_ = os.WriteFile(fmt.Sprintf("OpenSession%d.json", resultIndex), j, 0600) | ||
resultIndex++ | ||
} | ||
|
||
return resp, CheckStatus(resp) | ||
} | ||
|
||
// CloseSession is a wrapper around the thrift operation CloseSession | ||
// If RecordResults is true, the results will be marshalled to JSON format and written to CloseSession<index>.json | ||
func (tsc *ThriftServiceClient) CloseSession(ctx context.Context, req *cli_service.TCloseSessionReq) (*cli_service.TCloseSessionResp, error) { | ||
ctx = context.WithValue(ctx, ClientMethod, closeSession) | ||
ctx = context.WithValue(ctx, ClientMethod, clientMethodCloseSession) | ||
log := logger.WithContext(driverctx.ConnIdFromContext(ctx), driverctx.CorrelationIdFromContext(ctx), "") | ||
defer log.Duration(logger.Track("CloseSession")) | ||
resp, err := tsc.TCLIServiceClient.CloseSession(ctx, req) | ||
if err != nil { | ||
return resp, dbsqlerrint.NewRequestError(ctx, "close session request error", err) | ||
} | ||
if RecordResults { | ||
j, _ := json.MarshalIndent(resp, "", " ") | ||
_ = os.WriteFile(fmt.Sprintf("CloseSession%d.json", resultIndex), j, 0600) | ||
resultIndex++ | ||
err = handleClientMethodError(ctx, err) | ||
return resp, err | ||
} | ||
|
||
recordResult(ctx, resp) | ||
|
||
return resp, CheckStatus(resp) | ||
} | ||
|
||
// FetchResults is a wrapper around the thrift operation FetchResults | ||
// If RecordResults is true, the results will be marshalled to JSON format and written to FetchResults<index>.json | ||
func (tsc *ThriftServiceClient) FetchResults(ctx context.Context, req *cli_service.TFetchResultsReq) (*cli_service.TFetchResultsResp, error) { | ||
ctx = context.WithValue(ctx, ClientMethod, fetchResults) | ||
ctx = context.WithValue(ctx, ClientMethod, clientMethodFetchResults) | ||
log := logger.WithContext(driverctx.ConnIdFromContext(ctx), driverctx.CorrelationIdFromContext(ctx), SprintGuid(req.OperationHandle.OperationId.GUID)) | ||
defer log.Duration(logger.Track("FetchResults")) | ||
resp, err := tsc.TCLIServiceClient.FetchResults(ctx, req) | ||
if err != nil { | ||
return resp, dbsqlerrint.NewRequestError(ctx, "fetch results request error", err) | ||
} | ||
if RecordResults { | ||
j, _ := json.MarshalIndent(resp, "", " ") | ||
_ = os.WriteFile(fmt.Sprintf("FetchResults%d.json", resultIndex), j, 0600) | ||
resultIndex++ | ||
err = handleClientMethodError(ctx, err) | ||
return resp, err | ||
} | ||
|
||
recordResult(ctx, resp) | ||
|
||
return resp, CheckStatus(resp) | ||
} | ||
|
||
// GetResultSetMetadata is a wrapper around the thrift operation GetResultSetMetadata | ||
// If RecordResults is true, the results will be marshalled to JSON format and written to GetResultSetMetadata<index>.json | ||
func (tsc *ThriftServiceClient) GetResultSetMetadata(ctx context.Context, req *cli_service.TGetResultSetMetadataReq) (*cli_service.TGetResultSetMetadataResp, error) { | ||
ctx = context.WithValue(ctx, ClientMethod, getResultSetMetadata) | ||
ctx = context.WithValue(ctx, ClientMethod, clientMethodGetResultSetMetadata) | ||
log := logger.WithContext(driverctx.ConnIdFromContext(ctx), driverctx.CorrelationIdFromContext(ctx), SprintGuid(req.OperationHandle.OperationId.GUID)) | ||
defer log.Duration(logger.Track("GetResultSetMetadata")) | ||
resp, err := tsc.TCLIServiceClient.GetResultSetMetadata(ctx, req) | ||
if err != nil { | ||
return resp, dbsqlerrint.NewRequestError(ctx, "get result set metadata request error", err) | ||
} | ||
if RecordResults { | ||
j, _ := json.MarshalIndent(resp, "", " ") | ||
_ = os.WriteFile(fmt.Sprintf("GetResultSetMetadata%d.json", resultIndex), j, 0600) | ||
resultIndex++ | ||
err = handleClientMethodError(ctx, err) | ||
return resp, err | ||
} | ||
|
||
recordResult(ctx, resp) | ||
|
||
return resp, CheckStatus(resp) | ||
} | ||
|
||
// ExecuteStatement is a wrapper around the thrift operation ExecuteStatement | ||
// If RecordResults is true, the results will be marshalled to JSON format and written to ExecuteStatement<index>.json | ||
func (tsc *ThriftServiceClient) ExecuteStatement(ctx context.Context, req *cli_service.TExecuteStatementReq) (*cli_service.TExecuteStatementResp, error) { | ||
ctx = context.WithValue(ctx, ClientMethod, executeStatement) | ||
ctx = context.WithValue(ctx, ClientMethod, clientMethodExecuteStatement) | ||
msg, start := logger.Track("ExecuteStatement") | ||
|
||
// We use context.Background to fix a problem where on context done the query would not be cancelled. | ||
resp, err := tsc.TCLIServiceClient.ExecuteStatement(context.Background(), req) | ||
if err != nil { | ||
return resp, dbsqlerrint.NewRequestError(ctx, "execute statement request error", err) | ||
} | ||
if RecordResults { | ||
j, _ := json.MarshalIndent(resp, "", " ") | ||
_ = os.WriteFile(fmt.Sprintf("ExecuteStatement%d.json", resultIndex), j, 0600) | ||
// f, _ := os.ReadFile(fmt.Sprintf("ExecuteStatement%d.json", resultIndex)) | ||
// var resp2 cli_service.TExecuteStatementResp | ||
// json.Unmarshal(f, &resp2) | ||
resultIndex++ | ||
err = handleClientMethodError(ctx, err) | ||
return resp, err | ||
} | ||
|
||
recordResult(ctx, resp) | ||
|
||
if resp != nil && resp.OperationHandle != nil { | ||
log := logger.WithContext(driverctx.ConnIdFromContext(ctx), driverctx.CorrelationIdFromContext(ctx), SprintGuid(resp.OperationHandle.OperationId.GUID)) | ||
defer log.Duration(msg, start) | ||
|
@@ -165,54 +171,51 @@ func (tsc *ThriftServiceClient) ExecuteStatement(ctx context.Context, req *cli_s | |
// GetOperationStatus is a wrapper around the thrift operation GetOperationStatus | ||
// If RecordResults is true, the results will be marshalled to JSON format and written to GetOperationStatus<index>.json | ||
func (tsc *ThriftServiceClient) GetOperationStatus(ctx context.Context, req *cli_service.TGetOperationStatusReq) (*cli_service.TGetOperationStatusResp, error) { | ||
ctx = context.WithValue(ctx, ClientMethod, getOperationStatus) | ||
ctx = context.WithValue(ctx, ClientMethod, clientMethodGetOperationStatus) | ||
log := logger.WithContext(driverctx.ConnIdFromContext(ctx), driverctx.CorrelationIdFromContext(ctx), SprintGuid(req.OperationHandle.OperationId.GUID)) | ||
defer log.Duration(logger.Track("GetOperationStatus")) | ||
resp, err := tsc.TCLIServiceClient.GetOperationStatus(ctx, req) | ||
if err != nil { | ||
return resp, dbsqlerrint.NewRequestError(driverctx.NewContextWithQueryId(ctx, SprintGuid(req.OperationHandle.OperationId.GUID)), "databricks: get operation status request error", err) | ||
} | ||
if RecordResults { | ||
j, _ := json.MarshalIndent(resp, "", " ") | ||
_ = os.WriteFile(fmt.Sprintf("GetOperationStatus%d.json", resultIndex), j, 0600) | ||
resultIndex++ | ||
err = handleClientMethodError(driverctx.NewContextWithQueryId(ctx, SprintGuid(req.OperationHandle.OperationId.GUID)), err) | ||
return resp, err | ||
} | ||
|
||
recordResult(ctx, resp) | ||
|
||
return resp, CheckStatus(resp) | ||
} | ||
|
||
// CloseOperation is a wrapper around the thrift operation CloseOperation | ||
// If RecordResults is true, the results will be marshalled to JSON format and written to CloseOperation<index>.json | ||
func (tsc *ThriftServiceClient) CloseOperation(ctx context.Context, req *cli_service.TCloseOperationReq) (*cli_service.TCloseOperationResp, error) { | ||
ctx = context.WithValue(ctx, ClientMethod, closeOperation) | ||
ctx = context.WithValue(ctx, ClientMethod, clientMethodCloseOperation) | ||
log := logger.WithContext(driverctx.ConnIdFromContext(ctx), driverctx.CorrelationIdFromContext(ctx), SprintGuid(req.OperationHandle.OperationId.GUID)) | ||
defer log.Duration(logger.Track("CloseOperation")) | ||
resp, err := tsc.TCLIServiceClient.CloseOperation(ctx, req) | ||
if err != nil { | ||
return resp, dbsqlerrint.NewRequestError(ctx, "close operation request error", err) | ||
} | ||
if RecordResults { | ||
j, _ := json.MarshalIndent(resp, "", " ") | ||
_ = os.WriteFile(fmt.Sprintf("CloseOperation%d.json", resultIndex), j, 0600) | ||
resultIndex++ | ||
err = handleClientMethodError(ctx, err) | ||
return resp, err | ||
} | ||
|
||
recordResult(ctx, resp) | ||
|
||
return resp, CheckStatus(resp) | ||
} | ||
|
||
// CancelOperation is a wrapper around the thrift operation CancelOperation | ||
// If RecordResults is true, the results will be marshalled to JSON format and written to CancelOperation<index>.json | ||
func (tsc *ThriftServiceClient) CancelOperation(ctx context.Context, req *cli_service.TCancelOperationReq) (*cli_service.TCancelOperationResp, error) { | ||
ctx = context.WithValue(ctx, ClientMethod, cancelOperation) | ||
ctx = context.WithValue(ctx, ClientMethod, clientMethodCancelOperation) | ||
log := logger.WithContext(driverctx.ConnIdFromContext(ctx), driverctx.CorrelationIdFromContext(ctx), SprintGuid(req.OperationHandle.OperationId.GUID)) | ||
defer log.Duration(logger.Track("CancelOperation")) | ||
resp, err := tsc.TCLIServiceClient.CancelOperation(ctx, req) | ||
if err != nil { | ||
return resp, dbsqlerrint.NewRequestError(ctx, "cancel operation request error", err) | ||
} | ||
if RecordResults { | ||
j, _ := json.MarshalIndent(resp, "", " ") | ||
_ = os.WriteFile(fmt.Sprintf("CancelOperation%d.json", resultIndex), j, 0600) | ||
resultIndex++ | ||
err = handleClientMethodError(ctx, err) | ||
return resp, err | ||
} | ||
|
||
recordResult(ctx, resp) | ||
|
||
return resp, CheckStatus(resp) | ||
} | ||
|
||
|
@@ -283,6 +286,42 @@ func InitThriftClient(cfg *config.Config, httpclient *http.Client) (*ThriftServi | |
return tsClient, nil | ||
} | ||
|
||
// handler function for errors returned by the thrift client methods | ||
func handleClientMethodError(ctx context.Context, err error) dbsqlerr.DBRequestError { | ||
if err == nil { | ||
return nil | ||
} | ||
|
||
// If the passed error indicates an invalid session we inject a bad connection error | ||
// into the error stack. This allows the for retrying with a new connection. | ||
s := err.Error() | ||
if strings.Contains(s, "Invalid SessionHandle") { | ||
err = dbsqlerrint.NewBadConnectionError(err) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Out of curiosity, looks like the NewBadConnectionError will be wrapped again in a NewRequestError. Would it be better to just return the error as is? Or is the client expecting a NewRequestError? |
||
} | ||
|
||
// the passed error will be wrapped in a DBRequestError | ||
method := getClientMethod(ctx) | ||
msg := clientMethodRequestErrorMsgs[method] | ||
|
||
return dbsqlerrint.NewRequestError(ctx, msg, err) | ||
} | ||
|
||
// Extract a clientMethod value from the given Context. | ||
func getClientMethod(ctx context.Context) clientMethod { | ||
v, _ := ctx.Value(ClientMethod).(clientMethod) | ||
return v | ||
} | ||
|
||
// Write the result | ||
func recordResult(ctx context.Context, resp any) { | ||
if RecordResults && resp != nil { | ||
method := getClientMethod(ctx) | ||
j, _ := json.MarshalIndent(resp, "", " ") | ||
_ = os.WriteFile(fmt.Sprintf("%s%d.json", method, resultIndex), j, 0600) | ||
resultIndex++ | ||
} | ||
} | ||
|
||
// ThriftResponse represents the thrift rpc response | ||
type ThriftResponse interface { | ||
GetStatus() *cli_service.TStatus | ||
|
@@ -507,7 +546,7 @@ func RetryPolicy(ctx context.Context, resp *http.Response, err error) (bool, err | |
return false, ctx.Err() | ||
} | ||
|
||
caller, _ := ctx.Value(ClientMethod).(clientMethod) | ||
caller := getClientMethod(ctx) | ||
_, nonRetryableClientMethod := nonRetryableClientMethods[caller] | ||
|
||
if err != nil { | ||
|
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
we need to be very careful with session state. If the user expects some state - temp views, etc - and we just open a new session in the background, it may be very hard to understand and debug what is going on.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
In the context of
database/sql
, if the user is using methods on thesql.DB
object directly (e.g.db.QueryContext(...)
) then they cannot rely on maintaining per-session state anyway because these methods might create new connections or reuse an arbitrary connection from the pool.If a user does need to rely on per-session state, then doing so correctly requires explicitly grabbing a connection and then calling methods on the connection. In that case, it's the user's responsibility to manage the connection lifecycle, which includes handling errors and accounting for the possibility that a connection may end up in a bad state.