Skip to content

Commit

Permalink
Update xsql supports "time.Time", "go_ora.TimeStamp", "*timestamppb.T…
Browse files Browse the repository at this point in the history
…imestamp"
  • Loading branch information
onanying committed May 15, 2024
1 parent 1f1133a commit e5fe8b4
Show file tree
Hide file tree
Showing 5 changed files with 109 additions and 39 deletions.
92 changes: 66 additions & 26 deletions src/xsql/executor.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@ import (
"database/sql"
"errors"
"fmt"
ora "github.com/sijms/go-ora/v2"
"google.golang.org/protobuf/types/known/timestamppb"
"reflect"
"strings"
"time"
Expand Down Expand Up @@ -321,6 +323,28 @@ func (t *executor) Exec(query string, args []interface{}, opts *sqlOptions) (sql
return res, err
}

func (t *executor) isTime(typ string) bool {
switch typ {
case "time.Time", "go_ora.TimeStamp", "*timestamppb.Timestamp":
return true
default:
return false
}
}

func (t *executor) formatTime(typ string, v interface{}, opts *sqlOptions) string {
switch typ {
case "time.Time":
return v.(time.Time).Format(opts.TimeLayout)
case "go_ora.TimeStamp":
return time.Time(v.(ora.TimeStamp)).Format(opts.TimeLayout)
case "*timestamppb.Timestamp":
return v.(*timestamppb.Timestamp).AsTime().Format(opts.TimeLayout)
default:
return ""
}
}

func (t *executor) foreachInsert(value reflect.Value, typ reflect.Type, opts *sqlOptions) (fields, vars []string, bindArgs []interface{}) {
for n := 0; n < value.NumField(); n++ {
fieldValue := value.Field(n)
Expand All @@ -336,32 +360,34 @@ func (t *executor) foreachInsert(value reflect.Value, typ reflect.Type, opts *sq
if !value.Field(n).CanInterface() {
continue
}
isTime := value.Field(n).Type().String() == "time.Time"

tag := value.Type().Field(n).Tag.Get(opts.Tag)
if tag == "" || tag == "-" || tag == "_" {
if tag == "" || tag == "-" {
continue
}

fields = append(fields, tag)

v := ""
vTyp := value.Field(n).Type().String()
var v string
if opts.Placeholder == "?" {
v = opts.Placeholder
} else {
v = fmt.Sprintf(opts.Placeholder, n)
}
isTime := t.isTime(vTyp)
if isTime {
vars = append(vars, opts.TimeFunc(v))
} else {
vars = append(vars, v)
v = opts.TimeFunc(v)
}
vars = append(vars, v)

var a interface{}
if isTime {
ti := value.Field(n).Interface().(time.Time)
bindArgs = append(bindArgs, ti.Format(opts.TimeLayout))
a = t.formatTime(vTyp, value.Field(n).Interface(), opts)
} else {
bindArgs = append(bindArgs, value.Field(n).Interface())
a = value.Field(n).Interface()
}
bindArgs = append(bindArgs, a)
}
return
}
Expand All @@ -381,7 +407,7 @@ func (t *executor) foreachBatchInsertFields(value reflect.Value, typ reflect.Typ
}

tag := value.Type().Field(n).Tag.Get(opts.Tag)
if tag == "" || tag == "-" || tag == "_" {
if tag == "" || tag == "-" {
continue
}

Expand All @@ -406,24 +432,31 @@ func (t *executor) foreachBatchInsertValues(ai int, value reflect.Value, typ ref
}

tag := value.Type().Field(n).Tag.Get(opts.Tag)
if tag == "" || tag == "-" || tag == "_" {
if tag == "" || tag == "-" {
continue
}

vTyp := value.Field(n).Type().String()
var v string
if opts.Placeholder == "?" {
vars = append(vars, opts.Placeholder)
v = opts.Placeholder
} else {
vars = append(vars, fmt.Sprintf(opts.Placeholder, ai))
v = fmt.Sprintf(opts.Placeholder, ai)
ai += 1
}
isTime := t.isTime(vTyp)
if isTime {
v = opts.TimeFunc(v)
}
vars = append(vars, v)

// time特殊处理
if value.Field(n).Type().String() == "time.Time" {
ti := value.Field(n).Interface().(time.Time)
bindArgs = append(bindArgs, ti.Format(opts.TimeLayout))
var a interface{}
if isTime {
a = t.formatTime(vTyp, value.Field(n).Interface(), opts)
} else {
bindArgs = append(bindArgs, value.Field(n).Interface())
a = value.Field(n).Interface()
}
bindArgs = append(bindArgs, a)
}
return
}
Expand All @@ -444,23 +477,30 @@ func (t *executor) foreachUpdate(value reflect.Value, typ reflect.Type, opts *sq
}

tag := value.Type().Field(n).Tag.Get(opts.Tag)
if tag == "" || tag == "-" || tag == "_" {
if tag == "" || tag == "-" {
continue
}

vTyp := value.Field(n).Type().String()
var v string
if opts.Placeholder == "?" {
set = append(set, fmt.Sprintf("%s = %s", opts.ColumnQuotes+tag+opts.ColumnQuotes, opts.Placeholder))
v = opts.Placeholder
} else {
set = append(set, fmt.Sprintf("%s = %s", opts.ColumnQuotes+tag+opts.ColumnQuotes, fmt.Sprintf(opts.Placeholder, n)))
v = fmt.Sprintf(opts.Placeholder, n)
}
isTime := t.isTime(vTyp)
if isTime {
v = opts.TimeFunc(v)
}
set = append(set, fmt.Sprintf("%s = %s", opts.ColumnQuotes+tag+opts.ColumnQuotes, v))

// time特殊处理
if value.Field(n).Type().String() == "time.Time" {
ti := value.Field(n).Interface().(time.Time)
bindArgs = append(bindArgs, ti.Format(opts.TimeLayout))
var a interface{}
if isTime {
a = t.formatTime(vTyp, value.Field(n).Interface(), opts)
} else {
bindArgs = append(bindArgs, value.Field(n).Interface())
a = value.Field(n).Interface()
}
bindArgs = append(bindArgs, a)
}
return
}
43 changes: 32 additions & 11 deletions src/xsql/fetcher.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@ import (
"database/sql"
"errors"
"fmt"
ora "github.com/sijms/go-ora/v2"
"github.com/sijms/go-ora/v2"
"google.golang.org/protobuf/types/known/timestamppb"
"reflect"
"strconv"
"time"
Expand Down Expand Up @@ -263,8 +264,11 @@ func (t *RowResult) Time() time.Time {
if typ == "time.Time" {
return t.v.(time.Time)
}
if typ == "ora.TimeStamp" {
return time.Time(t.v.(ora.TimeStamp))
if typ == "go_ora.TimeStamp" {
return time.Time(t.v.(go_ora.TimeStamp))
}
if typ == "*timestamppb.Timestamp" {
return t.v.(*timestamppb.Timestamp).AsTime()
}
return time.Time{}
}
Expand Down Expand Up @@ -346,17 +350,34 @@ func (t *Fetcher) mapped(row *Row, tag string, value reflect.Value, typ reflect.
v = res.Int() == 1
break
default:
if !res.Empty() &&
typ.String() == "time.Time" &&
reflect.ValueOf(v).Type().String() != "time.Time" {
if t, e := time.ParseInLocation(t.options.TimeLayout, res.String(), t.options.TimeLocation); e == nil {
v = t
} else {
return fmt.Errorf("time parse fail for field %s: %v", tag, e)
if !res.Empty() {
vTyp := reflect.ValueOf(v).Type().String()
// 如果结构体是time.Time类型,执行转换
if typ.String() == "time.Time" {
if vTyp == "time.Time" {
// parseTime=true
v = res.Value()
} else {
// parseTime=false
if t, e := time.ParseInLocation(t.options.TimeLayout, res.String(), t.options.TimeLocation); e == nil {
v = t
} else {
return fmt.Errorf("time parse fail for field %s: %v", tag, e)
}
}
}
// 如果结构体是*timestamppb.Timestamp类型,执行转换
if typ.String() == "*timestamppb.Timestamp" {
if vTyp != "*timestamppb.Timestamp" {
if t, e := time.ParseInLocation(t.options.TimeLayout, res.String(), t.options.TimeLocation); e == nil {
v = timestamppb.New(t)
} else {
return fmt.Errorf("time parse fail for field %s: %v", tag, e)
}
}
}
}
}

// 追加异常信息
defer func() {
if e := recover(); e != nil {
Expand Down
1 change: 1 addition & 0 deletions src/xsql/go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ require (
github.com/go-sql-driver/mysql v1.6.0
github.com/sijms/go-ora/v2 v2.5.2
github.com/stretchr/testify v1.7.1
google.golang.org/protobuf v1.34.1
)

require (
Expand Down
8 changes: 8 additions & 0 deletions src/xsql/go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,21 @@ github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/go-sql-driver/mysql v1.6.0 h1:BCTh4TKNUYmOmMUcQ3IipzF5prigylS7XXjEkfCHuOE=
github.com/go-sql-driver/mysql v1.6.0/go.mod h1:DCzpHaOWr8IXmIStZouvnhqoel9Qv2LBy8hT2VhHyBg=
github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk=
github.com/google/go-cmp v0.5.5 h1:Khx7svrCpmxxtHBq5j2mp/xVjsi8hQMfNLvJFAlrGgU=
github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/sijms/go-ora/v2 v2.5.2 h1:8ACnYT4rOI7vjCIXQuGopiClXrXt4AnmSrv+nyMxELQ=
github.com/sijms/go-ora/v2 v2.5.2/go.mod h1:EHxlY6x7y9HAsdfumurRfTd+v8NrEOTR3Xl4FWlH6xk=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/testify v1.7.1 h1:5TQK59W5E3v0r2duFAb7P95B6hEeOyEnHRa8MjYSMTY=
github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543 h1:E7g+9GITq07hpfrRu66IVDexMakfv52eLZ2CXBWiKr4=
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw=
google.golang.org/protobuf v1.34.1 h1:9ddQBjfCyZPOHPUiPxpYESBLc+T8P3E+Vo4IbKZgFWg=
google.golang.org/protobuf v1.34.1/go.mod h1:c6P6GXX6sHbq/GpV6MGZEdwhWPcYBgnhAHhKbcUYpos=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
Expand Down
4 changes: 2 additions & 2 deletions src/xsql/options.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ func newDefaultOptions() sqlOptions {
TableKey: "${TABLE}",
Placeholder: "?",
ColumnQuotes: "`",
TimeLayout: "2006-01-02 15:04:05",
TimeLayout: "2006-01-02 15:04:05.000000",
TimeLocation: time.Local,
TimeFunc: func(placeholder string) string {
return placeholder
Expand All @@ -41,7 +41,7 @@ type sqlOptions struct {
// For oracle, can be configured as "
ColumnQuotes string

// Default: 2006-01-02 15:04:05
// Default: 2006-01-02 15:04:05.000000
TimeLayout string

// Default: time.Local
Expand Down

0 comments on commit e5fe8b4

Please sign in to comment.