Skip to content

Commit

Permalink
Merge f5963d6 into 1d0076b
Browse files Browse the repository at this point in the history
  • Loading branch information
Sovianum committed Jun 29, 2018
2 parents 1d0076b + f5963d6 commit 7790107
Show file tree
Hide file tree
Showing 6 changed files with 79 additions and 31 deletions.
25 changes: 13 additions & 12 deletions config.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,18 +10,19 @@ import (

// Config is a configuration parsed from a DSN string
type Config struct {
User string
Password string
Scheme string
Host string
Database string
Timeout time.Duration
IdleTimeout time.Duration
ReadTimeout time.Duration
WriteTimeout time.Duration
Location *time.Location
Debug bool
Params map[string]string
User string
Password string
Scheme string
Host string
Database string
Timeout time.Duration
IdleTimeout time.Duration
ReadTimeout time.Duration
WriteTimeout time.Duration
Location *time.Location
Debug bool
UseDBLocation bool
Params map[string]string
}

// NewConfig creates a new config with default values
Expand Down
24 changes: 13 additions & 11 deletions conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,15 +16,16 @@ import (

// conn implements an interface sql.Conn
type conn struct {
url *url.URL
user *url.Userinfo
location *time.Location
transport *http.Transport
cancel context.CancelFunc
txCtx context.Context
stmts []*stmt
logger *log.Logger
closed int32
url *url.URL
user *url.Userinfo
location *time.Location
useDBLocation bool
transport *http.Transport
cancel context.CancelFunc
txCtx context.Context
stmts []*stmt
logger *log.Logger
closed int32
}

func newConn(cfg *Config) *conn {
Expand All @@ -45,7 +46,8 @@ func newConn(cfg *Config) *conn {
IdleConnTimeout: cfg.IdleTimeout,
ResponseHeaderTimeout: cfg.ReadTimeout,
},
logger: logger,
logger: logger,
useDBLocation: cfg.UseDBLocation,
}
// store userinfo in separate member, we will handle it manually
c.user = c.url.User
Expand Down Expand Up @@ -166,7 +168,7 @@ func (c *conn) query(ctx context.Context, query string, args []driver.Value) (dr
if err != nil {
return nil, err
}
return newTextRows(body, c.location)
return newTextRows(body, c.location, c.useDBLocation)
}

func (c *conn) exec(ctx context.Context, query string, args []driver.Value) (driver.Result, error) {
Expand Down
36 changes: 34 additions & 2 deletions encoder.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"database/sql/driver"
"fmt"
"reflect"
"regexp"
"strconv"
"strings"
"time"
Expand All @@ -15,7 +16,8 @@ const (
)

var (
textEncode encoder = new(textEncoder)
textEncode encoder = new(textEncoder)
timeZoneRegexp = regexp.MustCompile("\\\\'.+\\\\'")
)

type encoder interface {
Expand All @@ -30,7 +32,8 @@ type textEncoder struct {
}

type textDecoder struct {
location *time.Location
location *time.Location
useDBLocation bool
}

func (e *textEncoder) Encode(value driver.Value) string {
Expand Down Expand Up @@ -130,6 +133,35 @@ func (d *textDecoder) Decode(t string, value []byte) (driver.Value, error) {
case "String":
return unescape(unquote(v)), nil
}

// got zoned datetime
if strings.HasPrefix(t, "DateTime") {
timeZoneName := timeZoneRegexp.FindString(t)
if timeZoneName == "" {
return nil, fmt.Errorf("time zone not found")
}
var (
loc *time.Location
err error
)

if d.useDBLocation {
timeZoneName = timeZoneName[2 : len(timeZoneName)-2] // remove \' in the beginning and in the end
loc, err = time.LoadLocation(timeZoneName)
if err != nil {
return nil, err
}
} else {
loc = d.location
}

var t time.Time
if t, err = time.ParseInLocation(timeFormat, unquote(v), loc); err != nil {
return t, err
}
return t.In(d.location), nil
}

if strings.HasPrefix(t, "FixedString") {
return unescape(unquote(v)), nil
}
Expand Down
19 changes: 16 additions & 3 deletions encoder_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ func TestTextDecoder(t *testing.T) {
{"Float64", "1", float64(1)},
{"Date", "'2012-05-31'", d},
{"DateTime", "'2011-03-06 06:20:00'", dt},
{"DateTime(\\'Europe/Moscow\\')", "'2011-03-06 06:20:00'", dt},
{"String", "'hello'", "hello"},
{"String", `'\\\\\'hello'`, `\\'hello`},
{"FixedString(5)", "'hello'", "hello"},
Expand All @@ -74,11 +75,23 @@ func TestTextDecoder(t *testing.T) {
{"Array(UInt32)", "[]", []uint32{}},
}

dec := &textDecoder{location: time.UTC}
for _, tc := range testCases {
dec := &textDecoder{location: time.UTC, useDBLocation: false}
for i, tc := range testCases {
v, err := dec.Decode(tc.tt, []byte(tc.value))
if assert.NoError(t, err) {
if assert.NoError(t, err, "%d", i) {
assert.Equal(t, tc.expected, v)
}
}
}

func TestDecodeTimeWithLocation(t *testing.T) {
dt := time.Date(2011, 3, 6, 3, 20, 0, 0, time.UTC)
dataType := "DateTime(\\'Europe/Moscow\\')"
dtStr := "'2011-03-06 06:20:00'"
dec := &textDecoder{location: time.UTC, useDBLocation: true}

v, err := dec.Decode(dataType, []byte(dtStr))
if assert.NoError(t, err) {
assert.Equal(t, dt, v)
}
}
4 changes: 2 additions & 2 deletions rows.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ func (r *textRows) Next(dest []driver.Value) error {
return io.EOF
}

func newTextRows(data []byte, location *time.Location) (*textRows, error) {
func newTextRows(data []byte, location *time.Location, useDBLocation bool) (*textRows, error) {
colCount := numOfColumns(data)
if colCount < 0 {
return nil, ErrMalformed
Expand All @@ -71,5 +71,5 @@ func newTextRows(data []byte, location *time.Location) (*textRows, error) {
types := make([]string, colCount)
data = data[splitTSV(data, columns):]
data = data[splitTSV(data, types):]
return &textRows{columns: columns, types: types, data: data, decode: &textDecoder{location: location}}, nil
return &textRows{columns: columns, types: types, data: data, decode: &textDecoder{location: location, useDBLocation: useDBLocation}}, nil
}
2 changes: 1 addition & 1 deletion rows_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ import (
)

func TestTextRows(t *testing.T) {
rows, err := newTextRows([]byte("Number\tText\nInt32\tString\n1\t'hello'\n2\t'world'\n"), time.Local)
rows, err := newTextRows([]byte("Number\tText\nInt32\tString\n1\t'hello'\n2\t'world'\n"), time.Local, false)
if !assert.NoError(t, err) {
return
}
Expand Down

0 comments on commit 7790107

Please sign in to comment.