Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[PECO-1016-2] Add handling for special types #158

Merged
merged 7 commits into from
Sep 7, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
104 changes: 14 additions & 90 deletions connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,6 @@ package dbsql
import (
"context"
"database/sql/driver"
"fmt"
"strconv"
"time"

"github.com/databricks/databricks-sql-go/driverctx"
Expand Down Expand Up @@ -102,9 +100,6 @@ func (c *conn) ExecContext(ctx context.Context, query string, args []driver.Name
defer log.Duration(msg, start)

ctx = driverctx.NewContextWithConnId(ctx, c.id)
if len(args) > 0 {
return nil, dbsqlerrint.NewDriverError(ctx, dbsqlerr.ErrParametersNotSupported, nil)
}

exStmtResp, opStatusResp, err := c.runQuery(ctx, query, args)

Expand Down Expand Up @@ -145,9 +140,6 @@ func (c *conn) QueryContext(ctx context.Context, query string, args []driver.Nam
msg, start := log.Track("QueryContext")

ctx = driverctx.NewContextWithConnId(ctx, c.id)
if len(args) > 0 {
return nil, dbsqlerrint.NewDriverError(ctx, dbsqlerr.ErrParametersNotSupported, nil)
}

// first we try to get the results synchronously.
// at any point in time that the context is done we must cancel and return
Expand Down Expand Up @@ -288,7 +280,7 @@ func (c *conn) executeStatement(ctx context.Context, query string, args []driver
MaxRows: int64(c.cfg.MaxRows),
},
CanDecompressLZ4Result_: &c.cfg.UseLz4Compression,
Parameters: namedValuesToTSparkParams(args),
Parameters: convertNamedValuesToSparkParams(args),
}

if c.cfg.UseArrowBatches {
Expand Down Expand Up @@ -342,87 +334,6 @@ func (c *conn) executeStatement(ctx context.Context, query string, args []driver
return resp, err
}

func namedValuesToTSparkParams(args []driver.NamedValue) []*cli_service.TSparkParameter {
var ts []string = []string{"STRING", "DOUBLE", "BOOLEAN", "TIMESTAMP", "FLOAT", "INTEGER", "TINYINT", "SMALLINT", "BIGINT"}
var params []*cli_service.TSparkParameter
for i := range args {
arg := args[i]
param := cli_service.TSparkParameter{Value: &cli_service.TSparkParameterValue{}}
if arg.Name != "" {
param.Name = &arg.Name
} else {
i := int32(arg.Ordinal)
param.Ordinal = &i
}

switch t := arg.Value.(type) {
case bool:
b := arg.Value.(bool)
param.Value.BooleanValue = &b
param.Type = &ts[2]
case string:
s := arg.Value.(string)
param.Value.StringValue = &s
param.Type = &ts[0]
case int:
f := float64(t)
param.Value.DoubleValue = &f
param.Type = &ts[5]
case uint:
f := float64(t)
param.Value.DoubleValue = &f
param.Type = &ts[5]
case int8:
f := float64(t)
param.Value.DoubleValue = &f
param.Type = &ts[6]
case uint8:
f := float64(t)
param.Value.DoubleValue = &f
param.Type = &ts[6]
case int16:
f := float64(t)
param.Value.DoubleValue = &f
param.Type = &ts[7]
case uint16:
f := float64(t)
param.Value.DoubleValue = &f
param.Type = &ts[7]
case int32:
f := float64(t)
param.Value.DoubleValue = &f
param.Type = &ts[5]
case uint32:
f := float64(t)
param.Value.DoubleValue = &f
param.Type = &ts[5]
case int64:
s := strconv.FormatInt(t, 10)
param.Value.StringValue = &s
param.Type = &ts[8]
case uint64:
s := strconv.FormatUint(t, 10)
param.Value.StringValue = &s
param.Type = &ts[8]
case float32:
f := float64(t)
param.Value.DoubleValue = &f
param.Type = &ts[4]
case time.Time:
s := t.String()
param.Value.StringValue = &s
param.Type = &ts[3]
default:
s := fmt.Sprintf("%s", arg.Value)
param.Value.StringValue = &s
param.Type = &ts[0]
}

params = append(params, &param)
}
return params
}

func (c *conn) pollOperation(ctx context.Context, opHandle *cli_service.TOperationHandle) (*cli_service.TGetOperationStatusResp, error) {
corrId := driverctx.CorrelationIdFromContext(ctx)
log := logger.WithContext(c.id, corrId, client.SprintGuid(opHandle.OperationId.GUID))
Expand Down Expand Up @@ -481,6 +392,18 @@ func (c *conn) pollOperation(ctx context.Context, opHandle *cli_service.TOperati
return statusResp, nil
}

func (c *conn) CheckNamedValue(nv *driver.NamedValue) error {
var err error
if dbsqlParam, ok := nv.Value.(DBSqlParam); ok {
nv.Name = dbsqlParam.Name
dbsqlParam.Value, err = driver.DefaultParameterConverter.ConvertValue(dbsqlParam.Value)
return err
}

nv.Value, err = driver.DefaultParameterConverter.ConvertValue(nv.Value)
return err
}

var _ driver.Conn = (*conn)(nil)
var _ driver.Pinger = (*conn)(nil)
var _ driver.SessionResetter = (*conn)(nil)
Expand All @@ -489,3 +412,4 @@ var _ driver.ExecerContext = (*conn)(nil)
var _ driver.QueryerContext = (*conn)(nil)
var _ driver.ConnPrepareContext = (*conn)(nil)
var _ driver.ConnBeginTx = (*conn)(nil)
var _ driver.NamedValueChecker = (*conn)(nil)
1 change: 0 additions & 1 deletion errors/errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@ const (
// Driver errors
ErrNotImplemented = "not implemented"
ErrTransactionsNotSupported = "transactions are not supported"
ErrParametersNotSupported = "query parameters are not supported"
ErrReadQueryStatus = "could not read query status"
ErrSentinelTimeout = "sentinel timed out waiting for operation to complete"

Expand Down
36 changes: 36 additions & 0 deletions parameter_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
package dbsql

import (
"database/sql/driver"
"strconv"
"testing"
"time"

"github.com/databricks/databricks-sql-go/internal/cli_service"
"github.com/stretchr/testify/assert"
)

func TestParameter_Inference(t *testing.T) {
t.Run("Should infer types correctly", func(t *testing.T) {
values := [5]driver.NamedValue{{Name: "", Value: float32(5.1)}, {Name: "", Value: time.Now()}, {Name: "", Value: int64(5)}, {Name: "", Value: true}, {Name: "", Value: DBSqlParam{Value: "6.2", Type: Decimal}}}
parameters := convertNamedValuesToSparkParams(values[:])
assert.Equal(t, strconv.FormatFloat(float64(5.1), 'f', -1, 64), *parameters[0].Value.StringValue)
assert.NotNil(t, parameters[1].Value.StringValue)
assert.Equal(t, string("TIMESTAMP"), *parameters[1].Type)
assert.Equal(t, &cli_service.TSparkParameterValue{StringValue: strPtr("5")}, parameters[2].Value)
assert.Equal(t, string("true"), *parameters[3].Value.StringValue)
assert.Equal(t, string("DECIMAL"), *parameters[4].Type)
assert.Equal(t, string("6.2"), *parameters[4].Value.StringValue)
nithinkdb marked this conversation as resolved.
Show resolved Hide resolved
})
}
func TestParameters_Names(t *testing.T) {
t.Run("Should infer types correctly", func(t *testing.T) {
values := [2]driver.NamedValue{{Name: "1", Value: int(26)}, {Name: "", Value: DBSqlParam{Name: "2", Type: Decimal, Value: "6.2"}}}
parameters := convertNamedValuesToSparkParams(values[:])
assert.Equal(t, string("1"), *parameters[0].Name)
assert.Equal(t, cli_service.TSparkParameterValue{StringValue: strPtr("26")}, *parameters[0].Value)
assert.Equal(t, string("2"), *parameters[1].Name)
assert.Equal(t, cli_service.TSparkParameterValue{StringValue: strPtr("6.2")}, *parameters[1].Value)
assert.Equal(t, string("DECIMAL"), *parameters[1].Type)
})
}
150 changes: 150 additions & 0 deletions parameters.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,150 @@
package dbsql

import (
"database/sql/driver"
"fmt"
"strconv"
"time"

"github.com/databricks/databricks-sql-go/internal/cli_service"
)

type DBSqlParam struct {
Name string
Type SqlType
Value any
}

type SqlType int64

const (
String SqlType = iota
Date
Timestamp
Float
Decimal
Double
Integer
BigInt
SmallInt
TinyInt
Boolean
IntervalMonth
IntervalDay
)

func (s SqlType) String() string {
switch s {
case String:
return "STRING"
case Date:
return "DATE"
case Timestamp:
return "TIMESTAMP"
case Float:
return "FLOAT"
case Decimal:
return "DECIMAL"
case Double:
return "DOUBLE"
case Integer:
return "INTEGER"
case BigInt:
return "BIGINT"
case SmallInt:
return "SMALLINT"
case TinyInt:
return "TINYINT"
case Boolean:
return "BOOLEAN"
case IntervalMonth:
return "INTERVAL MONTH"
case IntervalDay:
return "INTERVAL DAY"
}
return "unknown"
}

func valuesToDBSQLParams(namedValues []driver.NamedValue) []DBSqlParam {
var params []DBSqlParam
for i := range namedValues {
namedValue := namedValues[i]
param := *new(DBSqlParam)
param.Name = namedValue.Name
param.Value = namedValue.Value
params = append(params, param)
}
return params
}

func inferTypes(params []DBSqlParam) {
for i := range params {
param := &params[i]
switch value := param.Value.(type) {
case bool:
param.Value = strconv.FormatBool(value)
param.Type = Boolean
case string:
param.Value = value
param.Type = String
case int:
param.Value = strconv.Itoa(value)
param.Type = Integer
case uint:
param.Value = strconv.FormatUint(uint64(value), 10)
param.Type = Integer
case int8:
param.Value = strconv.Itoa(int(value))
param.Type = Integer
case uint8:
param.Value = strconv.FormatUint(uint64(value), 10)
param.Type = Integer
case int16:
param.Value = strconv.Itoa(int(value))
param.Type = Integer
case uint16:
param.Value = strconv.FormatUint(uint64(value), 10)
param.Type = Integer
case int32:
param.Value = strconv.Itoa(int(value))
param.Type = Integer
case uint32:
param.Value = strconv.FormatUint(uint64(value), 10)
param.Type = Integer
case int64:
param.Value = strconv.Itoa(int(value))
param.Type = Integer
case uint64:
param.Value = strconv.FormatUint(uint64(value), 10)
param.Type = Integer
case float32:
param.Value = strconv.FormatFloat(float64(value), 'f', -1, 32)
param.Type = Float
case time.Time:
param.Value = value.String()
param.Type = Timestamp
case DBSqlParam:
param.Name = value.Name
param.Value = value.Value
param.Type = value.Type
default:
s := fmt.Sprintf("%s", value)
param.Value = s
param.Type = String
}
}
}
func convertNamedValuesToSparkParams(values []driver.NamedValue) []*cli_service.TSparkParameter {
var sparkParams []*cli_service.TSparkParameter

sqlParams := valuesToDBSQLParams(values)
inferTypes(sqlParams)
for i := range sqlParams {
sqlParam := sqlParams[i]
sparkParamValue := sqlParam.Value.(string)
sparkParamType := sqlParam.Type.String()
sparkParam := cli_service.TSparkParameter{Name: &sqlParam.Name, Type: &sparkParamType, Value: &cli_service.TSparkParameterValue{StringValue: &sparkParamValue}}
sparkParams = append(sparkParams, &sparkParam)
}
return sparkParams
}
11 changes: 0 additions & 11 deletions statement_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ import (
"context"
"database/sql/driver"
"testing"
"time"

"github.com/apache/thrift/lib/go/thrift"
"github.com/databricks/databricks-sql-go/internal/cli_service"
Expand Down Expand Up @@ -166,13 +165,3 @@ func TestStmt_QueryContext(t *testing.T) {
assert.Equal(t, testQuery, savedQueryString)
})
}
func TestParameters(t *testing.T) {
t.Run("Parameter casting should be correct", func(t *testing.T) {
values := [3]driver.NamedValue{{Ordinal: 1, Name: "", Value: float32(5)}, {Ordinal: 2, Name: "", Value: time.Now()}, {Ordinal: 3, Name: "", Value: int64(5)}}
parameters := namedValuesToTSparkParams(values[:])
assert.Equal(t, &cli_service.TSparkParameterValue{DoubleValue: thrift.Float64Ptr(5)}, parameters[0].Value)
assert.NotNil(t, parameters[1].Value.StringValue)
assert.Equal(t, string("TIMESTAMP"), *parameters[1].Type)
assert.Equal(t, &cli_service.TSparkParameterValue{StringValue: strPtr("5")}, parameters[2].Value)
})
}