Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Better handling of bad connection errors and specifying server protocol. #152

Merged
merged 1 commit into from
Aug 30, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
30 changes: 27 additions & 3 deletions connector.go
Original file line number Diff line number Diff line change
Expand Up @@ -109,13 +109,37 @@ func withUserConfig(ucfg config.UserConfig) connOption {
// WithServerHostname sets up the server hostname. Mandatory.
func WithServerHostname(host string) connOption {
return func(c *config.Config) {
if host == "localhost" {
c.Protocol = "http"
protocol, hostname := parseHostName(host)
if protocol != "" {
c.Protocol = protocol
}
c.Host = host

c.Host = hostname
}
}

// parseHostName splits an optional protocol prefix off a server host string.
//
// A host that begins with "http:" or "https:" (optionally followed by "//")
// yields that protocol and the remainder as hostname. Any other host is
// returned unchanged with an empty protocol, with one exception: a bare
// "localhost" defaults to the "http" protocol.
func parseHostName(host string) (protocol, hostname string) {
	hostname = host

	// Require the ':' separator so that hostnames which merely start with
	// "http" (e.g. "httpbin.org" or "https-gateway.internal") are not
	// mistaken for a protocol prefix and truncated.
	switch {
	case strings.HasPrefix(host, "https:"):
		protocol = "https"
	case strings.HasPrefix(host, "http:"):
		protocol = "http"
	}

	if protocol != "" {
		// Drop "<protocol>:" and then an optional "//".
		hostname = strings.TrimPrefix(host[len(protocol)+1:], "//")
	}

	// An explicit protocol always wins; only an unqualified localhost
	// falls back to plain http.
	if protocol == "" && hostname == "localhost" {
		protocol = "http"
	}

	return protocol, hostname
}

// WithPort sets up the server port. Mandatory.
func WithPort(port int) connOption {
return func(c *config.Config) {
Expand Down
31 changes: 31 additions & 0 deletions connector_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,37 @@ func TestNewConnector(t *testing.T) {
assert.Nil(t, err)
assert.Equal(t, expectedCfg, coni.cfg)
})

t.Run("Connector test WithServerHostname", func(t *testing.T) {
cases := []struct {
hostname, host, protocol string
}{
{"databricks-host", "databricks-host", "https"},
{"http://databricks-host", "databricks-host", "http"},
{"https://databricks-host", "databricks-host", "https"},
{"http:databricks-host", "databricks-host", "http"},
{"https:databricks-host", "databricks-host", "https"},
{"htt://databricks-host", "htt://databricks-host", "https"},
{"localhost", "localhost", "http"},
{"http:localhost", "localhost", "http"},
{"https:localhost", "localhost", "https"},
}

for i := range cases {
c := cases[i]
con, err := NewConnector(
WithServerHostname(c.hostname),
)
assert.Nil(t, err)

coni, ok := con.(*connector)
require.True(t, ok)
userConfig := coni.cfg.UserConfig
require.Equal(t, c.protocol, userConfig.Protocol)
require.Equal(t, c.host, userConfig.Host)
}

})
}

type mockRoundTripper struct{}
Expand Down
2 changes: 1 addition & 1 deletion doc.go
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ Use sql.OpenDB() to create a database handle via a new connector object created

Supported functional options include:

- WithServerHostname(<hostname> string): Sets up the server hostname. Mandatory
- WithServerHostname(<hostname> string): Sets up the server hostname. The hostname can be prefixed with "http:" or "https:" to specify a protocol to use. Mandatory
- WithPort(<port> int): Sets up the server port. Mandatory
- WithAccessToken(<my_token> string): Sets up the Personal Access Token. Mandatory
- WithHTTPPath(<http_path> string): Sets up the endpoint to the warehouse. Mandatory
Expand Down
183 changes: 111 additions & 72 deletions internal/client/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ import (
"os"
"regexp"
"strconv"
"strings"
"time"

dbsqlerr "github.com/databricks/databricks-sql-go/errors"
Expand Down Expand Up @@ -45,116 +46,121 @@ const (

type clientMethod int

//go:generate go run golang.org/x/tools/cmd/stringer -type=clientMethod
//go:generate go run golang.org/x/tools/cmd/stringer -type=clientMethod -trimprefix=clientMethod

const (
unknown clientMethod = iota
openSession
closeSession
fetchResults
getResultSetMetadata
executeStatement
getOperationStatus
closeOperation
cancelOperation
clientMethodUnknown clientMethod = iota
clientMethodOpenSession
clientMethodCloseSession
clientMethodFetchResults
clientMethodGetResultSetMetadata
clientMethodExecuteStatement
clientMethodGetOperationStatus
clientMethodCloseOperation
clientMethodCancelOperation
)

var nonRetryableClientMethods map[clientMethod]any = map[clientMethod]any{
executeStatement: struct{}{},
unknown: struct{}{}}
clientMethodExecuteStatement: struct{}{},
clientMethodUnknown: struct{}{},
}

// clientMethodRequestErrorMsgs maps each thrift client method to the message
// used when a request made through that method fails. Methods without an
// entry yield the empty string on lookup.
var clientMethodRequestErrorMsgs = map[clientMethod]string{
	clientMethodOpenSession:          "open session request error",
	clientMethodCloseSession:         "close session request error",
	clientMethodFetchResults:         "fetch results request error",
	clientMethodGetResultSetMetadata: "get result set metadata request error",
	clientMethodExecuteStatement:     "execute statement request error",
	clientMethodGetOperationStatus:   "get operation status request error",
	clientMethodCloseOperation:       "close operation request error",
	clientMethodCancelOperation:      "cancel operation request error",
}

// OpenSession is a wrapper around the thrift operation OpenSession
// If RecordResults is true, the results will be marshalled to JSON format and written to OpenSession<index>.json
func (tsc *ThriftServiceClient) OpenSession(ctx context.Context, req *cli_service.TOpenSessionReq) (*cli_service.TOpenSessionResp, error) {
ctx = context.WithValue(ctx, ClientMethod, openSession)
ctx = context.WithValue(ctx, ClientMethod, clientMethodOpenSession)
msg, start := logger.Track("OpenSession")
resp, err := tsc.TCLIServiceClient.OpenSession(ctx, req)
if err != nil {
return nil, dbsqlerrint.NewRequestError(ctx, "open session request error", err)
err = handleClientMethodError(ctx, err)
return resp, err
}

recordResult(ctx, resp)

log := logger.WithContext(SprintGuid(resp.SessionHandle.SessionId.GUID), driverctx.CorrelationIdFromContext(ctx), "")
defer log.Duration(msg, start)
if RecordResults {
j, _ := json.MarshalIndent(resp, "", " ")
_ = os.WriteFile(fmt.Sprintf("OpenSession%d.json", resultIndex), j, 0600)
resultIndex++
}

return resp, CheckStatus(resp)
}

// CloseSession is a wrapper around the thrift operation CloseSession
// If RecordResults is true, the results will be marshalled to JSON format and written to CloseSession<index>.json
func (tsc *ThriftServiceClient) CloseSession(ctx context.Context, req *cli_service.TCloseSessionReq) (*cli_service.TCloseSessionResp, error) {
ctx = context.WithValue(ctx, ClientMethod, closeSession)
ctx = context.WithValue(ctx, ClientMethod, clientMethodCloseSession)
log := logger.WithContext(driverctx.ConnIdFromContext(ctx), driverctx.CorrelationIdFromContext(ctx), "")
defer log.Duration(logger.Track("CloseSession"))
resp, err := tsc.TCLIServiceClient.CloseSession(ctx, req)
if err != nil {
return resp, dbsqlerrint.NewRequestError(ctx, "close session request error", err)
}
if RecordResults {
j, _ := json.MarshalIndent(resp, "", " ")
_ = os.WriteFile(fmt.Sprintf("CloseSession%d.json", resultIndex), j, 0600)
resultIndex++
err = handleClientMethodError(ctx, err)
return resp, err
}

recordResult(ctx, resp)

return resp, CheckStatus(resp)
}

// FetchResults is a wrapper around the thrift operation FetchResults
// If RecordResults is true, the results will be marshalled to JSON format and written to FetchResults<index>.json
func (tsc *ThriftServiceClient) FetchResults(ctx context.Context, req *cli_service.TFetchResultsReq) (*cli_service.TFetchResultsResp, error) {
ctx = context.WithValue(ctx, ClientMethod, fetchResults)
ctx = context.WithValue(ctx, ClientMethod, clientMethodFetchResults)
log := logger.WithContext(driverctx.ConnIdFromContext(ctx), driverctx.CorrelationIdFromContext(ctx), SprintGuid(req.OperationHandle.OperationId.GUID))
defer log.Duration(logger.Track("FetchResults"))
resp, err := tsc.TCLIServiceClient.FetchResults(ctx, req)
if err != nil {
return resp, dbsqlerrint.NewRequestError(ctx, "fetch results request error", err)
}
if RecordResults {
j, _ := json.MarshalIndent(resp, "", " ")
_ = os.WriteFile(fmt.Sprintf("FetchResults%d.json", resultIndex), j, 0600)
resultIndex++
err = handleClientMethodError(ctx, err)
return resp, err
}

recordResult(ctx, resp)

return resp, CheckStatus(resp)
}

// GetResultSetMetadata is a wrapper around the thrift operation GetResultSetMetadata
// If RecordResults is true, the results will be marshalled to JSON format and written to GetResultSetMetadata<index>.json
func (tsc *ThriftServiceClient) GetResultSetMetadata(ctx context.Context, req *cli_service.TGetResultSetMetadataReq) (*cli_service.TGetResultSetMetadataResp, error) {
ctx = context.WithValue(ctx, ClientMethod, getResultSetMetadata)
ctx = context.WithValue(ctx, ClientMethod, clientMethodGetResultSetMetadata)
log := logger.WithContext(driverctx.ConnIdFromContext(ctx), driverctx.CorrelationIdFromContext(ctx), SprintGuid(req.OperationHandle.OperationId.GUID))
defer log.Duration(logger.Track("GetResultSetMetadata"))
resp, err := tsc.TCLIServiceClient.GetResultSetMetadata(ctx, req)
if err != nil {
return resp, dbsqlerrint.NewRequestError(ctx, "get result set metadata request error", err)
}
if RecordResults {
j, _ := json.MarshalIndent(resp, "", " ")
_ = os.WriteFile(fmt.Sprintf("GetResultSetMetadata%d.json", resultIndex), j, 0600)
resultIndex++
err = handleClientMethodError(ctx, err)
return resp, err
}

recordResult(ctx, resp)

return resp, CheckStatus(resp)
}

// ExecuteStatement is a wrapper around the thrift operation ExecuteStatement
// If RecordResults is true, the results will be marshalled to JSON format and written to ExecuteStatement<index>.json
func (tsc *ThriftServiceClient) ExecuteStatement(ctx context.Context, req *cli_service.TExecuteStatementReq) (*cli_service.TExecuteStatementResp, error) {
ctx = context.WithValue(ctx, ClientMethod, executeStatement)
ctx = context.WithValue(ctx, ClientMethod, clientMethodExecuteStatement)
msg, start := logger.Track("ExecuteStatement")

// We use context.Background to fix a problem where on context done the query would not be cancelled.
resp, err := tsc.TCLIServiceClient.ExecuteStatement(context.Background(), req)
if err != nil {
return resp, dbsqlerrint.NewRequestError(ctx, "execute statement request error", err)
}
if RecordResults {
j, _ := json.MarshalIndent(resp, "", " ")
_ = os.WriteFile(fmt.Sprintf("ExecuteStatement%d.json", resultIndex), j, 0600)
// f, _ := os.ReadFile(fmt.Sprintf("ExecuteStatement%d.json", resultIndex))
// var resp2 cli_service.TExecuteStatementResp
// json.Unmarshal(f, &resp2)
resultIndex++
err = handleClientMethodError(ctx, err)
return resp, err
}

recordResult(ctx, resp)

if resp != nil && resp.OperationHandle != nil {
log := logger.WithContext(driverctx.ConnIdFromContext(ctx), driverctx.CorrelationIdFromContext(ctx), SprintGuid(resp.OperationHandle.OperationId.GUID))
defer log.Duration(msg, start)
Expand All @@ -165,54 +171,51 @@ func (tsc *ThriftServiceClient) ExecuteStatement(ctx context.Context, req *cli_s
// GetOperationStatus is a wrapper around the thrift operation GetOperationStatus
// If RecordResults is true, the results will be marshalled to JSON format and written to GetOperationStatus<index>.json
func (tsc *ThriftServiceClient) GetOperationStatus(ctx context.Context, req *cli_service.TGetOperationStatusReq) (*cli_service.TGetOperationStatusResp, error) {
ctx = context.WithValue(ctx, ClientMethod, getOperationStatus)
ctx = context.WithValue(ctx, ClientMethod, clientMethodGetOperationStatus)
log := logger.WithContext(driverctx.ConnIdFromContext(ctx), driverctx.CorrelationIdFromContext(ctx), SprintGuid(req.OperationHandle.OperationId.GUID))
defer log.Duration(logger.Track("GetOperationStatus"))
resp, err := tsc.TCLIServiceClient.GetOperationStatus(ctx, req)
if err != nil {
return resp, dbsqlerrint.NewRequestError(driverctx.NewContextWithQueryId(ctx, SprintGuid(req.OperationHandle.OperationId.GUID)), "databricks: get operation status request error", err)
}
if RecordResults {
j, _ := json.MarshalIndent(resp, "", " ")
_ = os.WriteFile(fmt.Sprintf("GetOperationStatus%d.json", resultIndex), j, 0600)
resultIndex++
err = handleClientMethodError(driverctx.NewContextWithQueryId(ctx, SprintGuid(req.OperationHandle.OperationId.GUID)), err)
return resp, err
}

recordResult(ctx, resp)

return resp, CheckStatus(resp)
}

// CloseOperation is a wrapper around the thrift operation CloseOperation
// If RecordResults is true, the results will be marshalled to JSON format and written to CloseOperation<index>.json
func (tsc *ThriftServiceClient) CloseOperation(ctx context.Context, req *cli_service.TCloseOperationReq) (*cli_service.TCloseOperationResp, error) {
ctx = context.WithValue(ctx, ClientMethod, closeOperation)
ctx = context.WithValue(ctx, ClientMethod, clientMethodCloseOperation)
log := logger.WithContext(driverctx.ConnIdFromContext(ctx), driverctx.CorrelationIdFromContext(ctx), SprintGuid(req.OperationHandle.OperationId.GUID))
defer log.Duration(logger.Track("CloseOperation"))
resp, err := tsc.TCLIServiceClient.CloseOperation(ctx, req)
if err != nil {
return resp, dbsqlerrint.NewRequestError(ctx, "close operation request error", err)
}
if RecordResults {
j, _ := json.MarshalIndent(resp, "", " ")
_ = os.WriteFile(fmt.Sprintf("CloseOperation%d.json", resultIndex), j, 0600)
resultIndex++
err = handleClientMethodError(ctx, err)
return resp, err
}

recordResult(ctx, resp)

return resp, CheckStatus(resp)
}

// CancelOperation is a wrapper around the thrift operation CancelOperation
// If RecordResults is true, the results will be marshalled to JSON format and written to CancelOperation<index>.json
func (tsc *ThriftServiceClient) CancelOperation(ctx context.Context, req *cli_service.TCancelOperationReq) (*cli_service.TCancelOperationResp, error) {
ctx = context.WithValue(ctx, ClientMethod, cancelOperation)
ctx = context.WithValue(ctx, ClientMethod, clientMethodCancelOperation)
log := logger.WithContext(driverctx.ConnIdFromContext(ctx), driverctx.CorrelationIdFromContext(ctx), SprintGuid(req.OperationHandle.OperationId.GUID))
defer log.Duration(logger.Track("CancelOperation"))
resp, err := tsc.TCLIServiceClient.CancelOperation(ctx, req)
if err != nil {
return resp, dbsqlerrint.NewRequestError(ctx, "cancel operation request error", err)
}
if RecordResults {
j, _ := json.MarshalIndent(resp, "", " ")
_ = os.WriteFile(fmt.Sprintf("CancelOperation%d.json", resultIndex), j, 0600)
resultIndex++
err = handleClientMethodError(ctx, err)
return resp, err
}

recordResult(ctx, resp)

return resp, CheckStatus(resp)
}

Expand Down Expand Up @@ -283,6 +286,42 @@ func InitThriftClient(cfg *config.Config, httpclient *http.Client) (*ThriftServi
return tsClient, nil
}

// handler function for errors returned by the thrift client methods
func handleClientMethodError(ctx context.Context, err error) dbsqlerr.DBRequestError {
if err == nil {
return nil
}

// If the passed error indicates an invalid session we inject a bad connection error
// into the error stack. This allows for retrying with a new connection.
s := err.Error()
if strings.Contains(s, "Invalid SessionHandle") {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we need to be very careful with session state. If the user expects some state - temp views, etc - and we just open a new session in the background, it may be very hard to understand and debug what is going on.

Copy link
Contributor

@aldld aldld Aug 11, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In the context of database/sql, if the user is using methods on the sql.DB object directly (e.g. db.QueryContext(...)) then they cannot rely on maintaining per-session state anyway because these methods might create new connections or reuse an arbitrary connection from the pool.

If a user does need to rely on per-session state, then doing so correctly requires explicitly grabbing a connection and then calling methods on the connection. In that case, it's the user's responsibility to manage the connection lifecycle, which includes handling errors and accounting for the possibility that a connection may end up in a bad state.

err = dbsqlerrint.NewBadConnectionError(err)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Out of curiosity, looks like the NewBadConnectionError will be wrapped again in a NewRequestError. Would it be better to just return the error as is? Or is the client expecting a NewRequestError?

}

// the passed error will be wrapped in a DBRequestError
method := getClientMethod(ctx)
msg := clientMethodRequestErrorMsgs[method]

return dbsqlerrint.NewRequestError(ctx, msg, err)
}

// getClientMethod returns the clientMethod stored in ctx under the
// ClientMethod key. If the key is absent or holds a value of another type,
// the zero clientMethod is returned.
func getClientMethod(ctx context.Context) clientMethod {
	var method clientMethod
	if stored, ok := ctx.Value(ClientMethod).(clientMethod); ok {
		method = stored
	}
	return method
}

// recordResult dumps resp as indented JSON to a file named
// "<method><index>.json", where method is the client method found in ctx and
// index is the current value of resultIndex, which is then incremented.
// It does nothing unless RecordResults is set and resp is non-nil.
func recordResult(ctx context.Context, resp any) {
	if !RecordResults || resp == nil {
		return
	}

	// Recording is best effort: marshalling and write failures are
	// deliberately ignored.
	payload, _ := json.MarshalIndent(resp, "", " ")
	fileName := fmt.Sprintf("%s%d.json", getClientMethod(ctx), resultIndex)
	_ = os.WriteFile(fileName, payload, 0600)
	resultIndex++
}

// ThriftResponse represents the thrift rpc response
type ThriftResponse interface {
GetStatus() *cli_service.TStatus
Expand Down Expand Up @@ -507,7 +546,7 @@ func RetryPolicy(ctx context.Context, resp *http.Response, err error) (bool, err
return false, ctx.Err()
}

caller, _ := ctx.Value(ClientMethod).(clientMethod)
caller := getClientMethod(ctx)
_, nonRetryableClientMethod := nonRetryableClientMethods[caller]

if err != nil {
Expand Down