Skip to content

Commit

Permalink
Fix formatting of *float64 parameters (#215)
Browse files Browse the repository at this point in the history
Attempt to fix #214

Signed-off-by: Esdras Beleza <esdras@esdrasbeleza.com>
  • Loading branch information
esdrasbeleza committed Apr 12, 2024
1 parent e82880f commit 7c4ada8
Show file tree
Hide file tree
Showing 4 changed files with 78 additions and 49 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
# Release History

- Fix formatting of *float64 parameters

## v1.5.4 (2024-04-10)

- Added OAuth support for GCP (databricks/databricks-sql-go#189 by @rcypher-databricks)
Expand Down
4 changes: 4 additions & 0 deletions driver_e2e_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -504,6 +504,10 @@ func strPtr(s string) *string {
return &s
}

func float64Ptr(f float64) *float64 {
return &f
}

func loadTestData(t *testing.T, name string, v any) {
if f, err := os.ReadFile(fmt.Sprintf("testdata/%s", name)); err != nil {
t.Errorf("could not read data from: %s", name)
Expand Down
10 changes: 9 additions & 1 deletion parameter_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,14 @@ import (

func TestParameter_Inference(t *testing.T) {
t.Run("Should infer types correctly", func(t *testing.T) {
values := [5]driver.NamedValue{{Name: "", Value: float32(5.1)}, {Name: "", Value: time.Now()}, {Name: "", Value: int64(5)}, {Name: "", Value: true}, {Name: "", Value: Parameter{Value: "6.2", Type: SqlDecimal}}}
values := [6]driver.NamedValue{
{Name: "", Value: float32(5.1)},
{Name: "", Value: time.Now()},
{Name: "", Value: int64(5)},
{Name: "", Value: true},
{Name: "", Value: Parameter{Value: "6.2", Type: SqlDecimal}},
{Name: "", Value: Parameter{Value: float64Ptr(6.2), Type: SqlUnkown}},
}
parameters := convertNamedValuesToSparkParams(values[:])
assert.Equal(t, strconv.FormatFloat(float64(5.1), 'f', -1, 64), *parameters[0].Value.StringValue)
assert.NotNil(t, parameters[1].Value.StringValue)
Expand All @@ -21,6 +28,7 @@ func TestParameter_Inference(t *testing.T) {
assert.Equal(t, string("true"), *parameters[3].Value.StringValue)
assert.Equal(t, string("DECIMAL(2,1)"), *parameters[4].Type)
assert.Equal(t, string("6.2"), *parameters[4].Value.StringValue)
assert.Equal(t, &cli_service.TSparkParameterValue{StringValue: strPtr("6.2")}, parameters[5].Value)
})
}
func TestParameters_Names(t *testing.T) {
Expand Down
111 changes: 63 additions & 48 deletions parameters.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package dbsql
import (
"database/sql/driver"
"fmt"
"reflect"
"strconv"
"strings"
"time"
Expand Down Expand Up @@ -90,57 +91,71 @@ func inferTypes(params []Parameter) {
for i := range params {
param := &params[i]
if param.Type == SqlUnkown {
switch value := param.Value.(type) {
case bool:
param.Value = strconv.FormatBool(value)
param.Type = SqlBoolean
case string:
param.Value = value
param.Type = SqlString
case int:
param.Value = strconv.Itoa(value)
param.Type = SqlInteger
case uint:
param.Value = strconv.FormatUint(uint64(value), 10)
param.Type = SqlInteger
case int8:
param.Value = strconv.Itoa(int(value))
param.Type = SqlInteger
case uint8:
param.Value = strconv.FormatUint(uint64(value), 10)
param.Type = SqlInteger
case int16:
param.Value = strconv.Itoa(int(value))
param.Type = SqlInteger
case uint16:
param.Value = strconv.FormatUint(uint64(value), 10)
param.Type = SqlInteger
case int32:
param.Value = strconv.Itoa(int(value))
param.Type = SqlInteger
case uint32:
param.Value = strconv.FormatUint(uint64(value), 10)
param.Type = SqlInteger
case int64:
param.Value = strconv.Itoa(int(value))
param.Type = SqlInteger
case uint64:
param.Value = strconv.FormatUint(uint64(value), 10)
param.Type = SqlInteger
case float32:
param.Value = strconv.FormatFloat(float64(value), 'f', -1, 32)
param.Type = SqlFloat
case time.Time:
param.Value = value.Format(time.RFC3339Nano)
param.Type = SqlTimestamp
default:
s := fmt.Sprintf("%s", param.Value)
param.Value = s
param.Type = SqlString
}
inferType(param)
}
}
}

func inferType(param *Parameter) {
if param.Value != nil && reflect.ValueOf(param.Value).Kind() == reflect.Ptr {
param.Value = reflect.ValueOf(param.Value).Elem().Interface()
inferType(param)
return
}

switch value := param.Value.(type) {
case bool:
param.Value = strconv.FormatBool(value)
param.Type = SqlBoolean
case string:
param.Value = value
param.Type = SqlString
case int:
param.Value = strconv.Itoa(value)
param.Type = SqlInteger
case uint:
param.Value = strconv.FormatUint(uint64(value), 10)
param.Type = SqlInteger
case int8:
param.Value = strconv.Itoa(int(value))
param.Type = SqlInteger
case uint8:
param.Value = strconv.FormatUint(uint64(value), 10)
param.Type = SqlInteger
case int16:
param.Value = strconv.Itoa(int(value))
param.Type = SqlInteger
case uint16:
param.Value = strconv.FormatUint(uint64(value), 10)
param.Type = SqlInteger
case int32:
param.Value = strconv.Itoa(int(value))
param.Type = SqlInteger
case uint32:
param.Value = strconv.FormatUint(uint64(value), 10)
param.Type = SqlInteger
case int64:
param.Value = strconv.Itoa(int(value))
param.Type = SqlInteger
case uint64:
param.Value = strconv.FormatUint(uint64(value), 10)
param.Type = SqlInteger
case float32:
param.Value = strconv.FormatFloat(float64(value), 'f', -1, 32)
param.Type = SqlFloat
case float64:
param.Value = strconv.FormatFloat(float64(value), 'f', -1, 64)
param.Type = SqlFloat
case time.Time:
param.Value = value.Format(time.RFC3339Nano)
param.Type = SqlTimestamp
default:
s := fmt.Sprintf("%s", param.Value)
param.Value = s
param.Type = SqlString
}
}

func convertNamedValuesToSparkParams(values []driver.NamedValue) []*cli_service.TSparkParameter {
var sparkParams []*cli_service.TSparkParameter

Expand Down

0 comments on commit 7c4ada8

Please sign in to comment.