From 65368d9707fdd7e0fa3e25d304008f13440a77c6 Mon Sep 17 00:00:00 2001 From: zihengCat Date: Thu, 15 Jul 2021 15:22:18 +0800 Subject: [PATCH] map time.Duration param into MySQL TIME type --- connection.go | 7 +++ utils.go | 74 +++++++++++++++++++++++++++++ utils_test.go | 127 ++++++++++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 208 insertions(+) diff --git a/connection.go b/connection.go index 835f89729..f27309a76 100644 --- a/connection.go +++ b/connection.go @@ -252,6 +252,13 @@ func (mc *mysqlConn) interpolateParams(query string, args []driver.Value) (strin } buf = append(buf, '\'') } + case time.Duration: + buf = append(buf, '\'') + buf, err = appendTime(buf, v) + if err != nil { + return "", err + } + buf = append(buf, '\'') case json.RawMessage: buf = append(buf, '\'') if mc.status&statusNoBackslashEscapes == 0 { diff --git a/utils.go b/utils.go index bcdee1b46..88ae18cb4 100644 --- a/utils.go +++ b/utils.go @@ -276,6 +276,80 @@ func parseBinaryDateTime(num uint64, data []byte, loc *time.Location) (driver.Va return nil, fmt.Errorf("invalid DATETIME packet length %d", num) } +func appendTime(buf []byte, td time.Duration) ([]byte, error) { + if td == 0 { + return append(buf, "00:00:00"...), nil + } + + // See: https://dev.mysql.com/doc/refman/8.0/en/time.html + tBuf := make([]byte, 0, len("-838:59:59.000000")) + + // Time could be negative + if td < 0 { + td *= -1 + tBuf = append(tBuf, '-') + } + + var ns time.Duration + + hour := int64(td / time.Hour) + ns = td % time.Hour + + min := int64(ns / time.Minute) + ns = ns % time.Minute + + sec := int64(ns / time.Second) + ns = ns % time.Second + + msec := int64(ns / 1000) + + // nsec := int64(ns) + + // hour + if hour >= 0 && hour < 839 { + if hour < 10 { + tBuf = append(tBuf, '0') + } + tBuf = strconv.AppendInt(tBuf, hour, 10) + } else { + return buf, errors.New("hour is not in the range [-838, 838]") + } + tBuf = append(tBuf, ':') + + // minute + if min >= 0 && min < 10 { + tBuf = append(tBuf, '0') + } + tBuf = strconv.AppendInt(tBuf, min, 10) + tBuf = append(tBuf, ':') + + // second + if sec >= 0 && sec < 10 { + tBuf = append(tBuf, '0') + } + tBuf = strconv.AppendInt(tBuf, sec, 10) + + // microsecond + if msec > 0 { + tBuf = append(tBuf, '.') + switch { + case msec < 10: + tBuf = append(tBuf, "00000"...) + case msec >= 10 && msec < 100: + tBuf = append(tBuf, "0000"...) + case msec >= 100 && msec < 1000: + tBuf = append(tBuf, "000"...) + case msec >= 1000 && msec < 10000: + tBuf = append(tBuf, "00"...) + case msec >= 1000 && msec < 10000: + tBuf = append(tBuf, "0"...) + } + tBuf = strconv.AppendInt(tBuf, msec, 10) + } + + return append(buf, tBuf...), nil +} + func appendDateTime(buf []byte, t time.Time) ([]byte, error) { year, month, day := t.Date() hour, min, sec := t.Clock() diff --git a/utils_test.go b/utils_test.go index b0069251e..9ab726b58 100644 --- a/utils_test.go +++ b/utils_test.go @@ -13,6 +13,7 @@ import ( "database/sql" "database/sql/driver" "encoding/binary" + "errors" "testing" "time" ) @@ -295,6 +296,132 @@ func TestIsolationLevelMapping(t *testing.T) { } } +func TestAppendTime(t *testing.T) { + tests := []struct { + td time.Duration + str string + }{ + // hour + { + td: 1 * time.Hour, + str: "01:00:00", + }, + { + td: 10 * time.Hour, + str: "10:00:00", + }, + { + td: 11 * time.Hour, + str: "11:00:00", + }, + { + td: 23 * time.Hour, + str: "23:00:00", + }, + // minute + { + td: 1 * time.Minute, + str: "00:01:00", + }, + { + td: 10 * time.Minute, + str: "00:10:00", + }, + { + td: 59 * time.Minute, + str: "00:59:00", + }, + // second + { + td: 1 * time.Second, + str: "00:00:01", + }, + { + td: 10 * time.Second, + str: "00:00:10", + }, + { + td: 59 * time.Second, + str: "00:00:59", + }, + { + td: 60 * time.Second, + str: "00:01:00", + }, + // hour + minute + second + { + td: 1*time.Hour + 2*time.Minute + 3*time.Second, + str: "01:02:03", + }, + { + td: 23*time.Hour + 59*time.Minute + 59*time.Second, + str: "23:59:59", + }, + // microsecond + { + td: 1 * time.Microsecond, + str: "00:00:00.000001", + }, + { + td: 10 * time.Microsecond, + str: "00:00:00.000010", + }, + { + td: 100 * time.Microsecond, + str: "00:00:00.000100", + }, + { + td: 10 * 100000 * time.Microsecond, + str: "00:00:01", + }, + { + td: 1*time.Second + 1*time.Microsecond, + str: "00:00:01.000001", + }, + // > 23:59:59.999999 + { + td: 24 * time.Hour, + str: "24:00:00", + }, + { + td: 100*time.Hour + 20*time.Minute + 30*time.Second + 123456*time.Microsecond, + str: "100:20:30.123456", + }, + // upper bound / lower bound + { + td: 838*time.Hour + 59*time.Minute + 59*time.Second, + str: "838:59:59", + }, + { + td: -1 * (838*time.Hour + 59*time.Minute + 59*time.Second), + str: "-838:59:59", + }, + } + for _, v := range tests { + buf := make([]byte, 0, 32) + buf, _ = appendTime(buf, v.td) + if str := string(buf); str != v.str { + t.Errorf("appendTime(%v), have: %s, want: %s", v.td, str, v.str) + } + } + + // hour out of range [-838, 838] + failedCases := []time.Duration{ + 839 * time.Hour, + -839 * time.Hour, + 1000 * time.Hour, + -1000 * time.Hour, + } + outOfRangeErr := errors.New("hour is not in the range [-838, 838]") + + for _, v := range failedCases { + buf := make([]byte, 0, 32) + newBuf, err := appendTime(buf, v) + if !bytes.Equal(buf, newBuf) || err == nil { + t.Errorf("appendTime(%v), have: %v, want: %v", v, err, outOfRangeErr) + } + } +} func TestAppendDateTime(t *testing.T) { tests := []struct { t time.Time