Skip to content

Commit

Permalink
Parse numbers on text protocol too (#1452)
Browse files Browse the repository at this point in the history
  • Loading branch information
methane committed Jul 3, 2023
1 parent 564dee9 commit 5d4a831
Show file tree
Hide file tree
Showing 2 changed files with 90 additions and 36 deletions.
87 changes: 59 additions & 28 deletions driver_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -148,29 +148,18 @@ func runTests(t *testing.T, dsn string, tests ...func(dbt *DBTest)) {
defer db2.Close()
}

dsn3 := dsn + "&multiStatements=true"
var db3 *sql.DB
if _, err := ParseDSN(dsn3); err != errInvalidDSNUnsafeCollation {
db3, err = sql.Open("mysql", dsn3)
if err != nil {
t.Fatalf("error connecting: %s", err.Error())
}
defer db3.Close()
}

dbt := &DBTest{t, db}
dbt2 := &DBTest{t, db2}
dbt3 := &DBTest{t, db3}
for _, test := range tests {
test(dbt)
dbt.db.Exec("DROP TABLE IF EXISTS test")
t.Run("default", func(t *testing.T) {
dbt := &DBTest{t, db}
test(dbt)
dbt.db.Exec("DROP TABLE IF EXISTS test")
})
if db2 != nil {
test(dbt2)
dbt2.db.Exec("DROP TABLE IF EXISTS test")
}
if db3 != nil {
test(dbt3)
dbt3.db.Exec("DROP TABLE IF EXISTS test")
t.Run("interpolateParams", func(t *testing.T) {
dbt2 := &DBTest{t, db2}
test(dbt2)
dbt2.db.Exec("DROP TABLE IF EXISTS test")
})
}
}
}
Expand Down Expand Up @@ -316,6 +305,48 @@ func TestCRUD(t *testing.T) {
})
}

// TestNumbers test that selecting numeric columns.
// Both of textRows and binaryRows should return same type and value.
func TestNumbersToAny(t *testing.T) {
runTests(t, dsn, func(dbt *DBTest) {
dbt.mustExec("CREATE TABLE `test` (id INT PRIMARY KEY, b BOOL, i8 TINYINT, " +
"i16 SMALLINT, i32 INT, i64 BIGINT, f32 FLOAT, f64 DOUBLE)")
dbt.mustExec("INSERT INTO `test` VALUES (1, true, 127, 32767, 2147483647, 9223372036854775807, 1.25, 2.5)")

// Use binaryRows for intarpolateParams=false and textRows for intarpolateParams=true.
rows := dbt.mustQuery("SELECT b, i8, i16, i32, i64, f32, f64 FROM `test` WHERE id=?", 1)
if !rows.Next() {
dbt.Fatal("no data")
}
var b, i8, i16, i32, i64, f32, f64 any
err := rows.Scan(&b, &i8, &i16, &i32, &i64, &f32, &f64)
if err != nil {
dbt.Fatal(err)
}
if b.(int64) != 1 {
dbt.Errorf("b != 1")
}
if i8.(int64) != 127 {
dbt.Errorf("i8 != 127")
}
if i16.(int64) != 32767 {
dbt.Errorf("i16 != 32767")
}
if i32.(int64) != 2147483647 {
dbt.Errorf("i32 != 2147483647")
}
if i64.(int64) != 9223372036854775807 {
dbt.Errorf("i64 != 9223372036854775807")
}
if f32.(float32) != 1.25 {
dbt.Errorf("f32 != 1.25")
}
if f64.(float64) != 2.5 {
dbt.Errorf("f64 != 2.5")
}
})
}

func TestMultiQuery(t *testing.T) {
runTestsWithMultiStatement(t, dsn, func(dbt *DBTest) {
// Create Table
Expand Down Expand Up @@ -1808,13 +1839,13 @@ func TestConcurrent(t *testing.T) {
}

runTests(t, dsn, func(dbt *DBTest) {
var version string
if err := dbt.db.QueryRow("SELECT @@version").Scan(&version); err != nil {
dbt.Fatalf("%s", err.Error())
}
if strings.Contains(strings.ToLower(version), "mariadb") {
t.Skip(`TODO: "fix commands out of sync. Did you run multiple statements at once?" on MariaDB`)
}
// var version string
// if err := dbt.db.QueryRow("SELECT @@version").Scan(&version); err != nil {
// dbt.Fatal(err)
// }
// if strings.Contains(strings.ToLower(version), "mariadb") {
// t.Skip(`TODO: "fix commands out of sync. Did you run multiple statements at once?" on MariaDB`)
// }

var max int
err := dbt.db.QueryRow("SELECT @@max_connections").Scan(&max)
Expand Down
39 changes: 31 additions & 8 deletions packets.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ import (
"fmt"
"io"
"math"
"strconv"
"time"
)

Expand Down Expand Up @@ -834,7 +835,8 @@ func (rows *textRows) readRow(dest []driver.Value) error {

for i := range dest {
// Read bytes and convert to string
dest[i], isNull, n, err = readLengthEncodedString(data[pos:])
var buf []byte
buf, isNull, n, err = readLengthEncodedString(data[pos:])
pos += n

if err != nil {
Expand All @@ -846,19 +848,40 @@ func (rows *textRows) readRow(dest []driver.Value) error {
continue
}

if !mc.parseTime {
continue
}

// Parse time field
switch rows.rs.columns[i].fieldType {
case fieldTypeTimestamp,
fieldTypeDateTime,
fieldTypeDate,
fieldTypeNewDate:
if dest[i], err = parseDateTime(dest[i].([]byte), mc.cfg.Loc); err != nil {
return err
if mc.parseTime {
dest[i], err = parseDateTime(buf, mc.cfg.Loc)
} else {
dest[i] = buf
}

case fieldTypeTiny, fieldTypeShort, fieldTypeInt24, fieldTypeYear, fieldTypeLong:
dest[i], err = strconv.ParseInt(string(buf), 10, 32)

case fieldTypeLongLong:
if rows.rs.columns[i].flags&flagUnsigned != 0 {
dest[i], err = strconv.ParseUint(string(buf), 10, 64)
} else {
dest[i], err = strconv.ParseInt(string(buf), 10, 64)
}

case fieldTypeFloat:
var d float64
d, err = strconv.ParseFloat(string(buf), 32)
dest[i] = float32(d)

case fieldTypeDouble:
dest[i], err = strconv.ParseFloat(string(buf), 64)

default:
dest[i] = buf
}
if err != nil {
return err
}
}

Expand Down

0 comments on commit 5d4a831

Please sign in to comment.