diff --git a/src/xsql/executor.go b/src/xsql/executor.go index 2061a3d..d143661 100644 --- a/src/xsql/executor.go +++ b/src/xsql/executor.go @@ -4,6 +4,8 @@ import ( "database/sql" "errors" "fmt" + ora "github.com/sijms/go-ora/v2" + "google.golang.org/protobuf/types/known/timestamppb" "reflect" "strings" "time" @@ -321,6 +323,28 @@ func (t *executor) Exec(query string, args []interface{}, opts *sqlOptions) (sql return res, err } +func (t *executor) isTime(typ string) bool { + switch typ { + case "time.Time", "go_ora.TimeStamp", "*timestamppb.Timestamp": + return true + default: + return false + } +} + +func (t *executor) formatTime(typ string, v interface{}, opts *sqlOptions) string { + switch typ { + case "time.Time": + return v.(time.Time).Format(opts.TimeLayout) + case "go_ora.TimeStamp": + return time.Time(v.(ora.TimeStamp)).Format(opts.TimeLayout) + case "*timestamppb.Timestamp": + return v.(*timestamppb.Timestamp).AsTime().Format(opts.TimeLayout) + default: + return "" + } +} + func (t *executor) foreachInsert(value reflect.Value, typ reflect.Type, opts *sqlOptions) (fields, vars []string, bindArgs []interface{}) { for n := 0; n < value.NumField(); n++ { fieldValue := value.Field(n) @@ -336,32 +360,34 @@ func (t *executor) foreachInsert(value reflect.Value, typ reflect.Type, opts *sq if !value.Field(n).CanInterface() { continue } - isTime := value.Field(n).Type().String() == "time.Time" tag := value.Type().Field(n).Tag.Get(opts.Tag) - if tag == "" || tag == "-" || tag == "_" { + if tag == "" || tag == "-" { continue } + fields = append(fields, tag) - v := "" + vTyp := value.Field(n).Type().String() + var v string if opts.Placeholder == "?" { v = opts.Placeholder } else { v = fmt.Sprintf(opts.Placeholder, n) } + isTime := t.isTime(vTyp) if isTime { - vars = append(vars, opts.TimeFunc(v)) - } else { - vars = append(vars, v) + v = opts.TimeFunc(v) } + vars = append(vars, v) + var a interface{} if isTime { - ti := value.Field(n).Interface().(time.Time) - bindArgs = append(bindArgs, ti.Format(opts.TimeLayout)) + a = t.formatTime(vTyp, value.Field(n).Interface(), opts) } else { - bindArgs = append(bindArgs, value.Field(n).Interface()) + a = value.Field(n).Interface() } + bindArgs = append(bindArgs, a) } return } @@ -381,7 +407,7 @@ func (t *executor) foreachBatchInsertFields(value reflect.Value, typ reflect.Typ } tag := value.Type().Field(n).Tag.Get(opts.Tag) - if tag == "" || tag == "-" || tag == "_" { + if tag == "" || tag == "-" { continue } @@ -406,24 +432,31 @@ func (t *executor) foreachBatchInsertValues(ai int, value reflect.Value, typ ref } tag := value.Type().Field(n).Tag.Get(opts.Tag) - if tag == "" || tag == "-" || tag == "_" { + if tag == "" || tag == "-" { continue } + vTyp := value.Field(n).Type().String() + var v string if opts.Placeholder == "?" { - vars = append(vars, opts.Placeholder) + v = opts.Placeholder } else { - vars = append(vars, fmt.Sprintf(opts.Placeholder, ai)) + v = fmt.Sprintf(opts.Placeholder, ai) ai += 1 } + isTime := t.isTime(vTyp) + if isTime { + v = opts.TimeFunc(v) + } + vars = append(vars, v) - // time特殊处理 - if value.Field(n).Type().String() == "time.Time" { - ti := value.Field(n).Interface().(time.Time) - bindArgs = append(bindArgs, ti.Format(opts.TimeLayout)) + var a interface{} + if isTime { + a = t.formatTime(vTyp, value.Field(n).Interface(), opts) } else { - bindArgs = append(bindArgs, value.Field(n).Interface()) + a = value.Field(n).Interface() } + bindArgs = append(bindArgs, a) } return } @@ -444,23 +477,30 @@ func (t *executor) foreachUpdate(value reflect.Value, typ reflect.Type, opts *sq } tag := value.Type().Field(n).Tag.Get(opts.Tag) - if tag == "" || tag == "-" || tag == "_" { + if tag == "" || tag == "-" { continue } + vTyp := value.Field(n).Type().String() + var v string if opts.Placeholder == "?" { - set = append(set, fmt.Sprintf("%s = %s", opts.ColumnQuotes+tag+opts.ColumnQuotes, opts.Placeholder)) + v = opts.Placeholder } else { - set = append(set, fmt.Sprintf("%s = %s", opts.ColumnQuotes+tag+opts.ColumnQuotes, fmt.Sprintf(opts.Placeholder, n))) + v = fmt.Sprintf(opts.Placeholder, n) + } + isTime := t.isTime(vTyp) + if isTime { + v = opts.TimeFunc(v) } + set = append(set, fmt.Sprintf("%s = %s", opts.ColumnQuotes+tag+opts.ColumnQuotes, v)) - // time特殊处理 - if value.Field(n).Type().String() == "time.Time" { - ti := value.Field(n).Interface().(time.Time) - bindArgs = append(bindArgs, ti.Format(opts.TimeLayout)) + var a interface{} + if isTime { + a = t.formatTime(vTyp, value.Field(n).Interface(), opts) } else { - bindArgs = append(bindArgs, value.Field(n).Interface()) + a = value.Field(n).Interface() } + bindArgs = append(bindArgs, a) } return } diff --git a/src/xsql/fetcher.go b/src/xsql/fetcher.go index 69c2ce8..c0960a4 100644 --- a/src/xsql/fetcher.go +++ b/src/xsql/fetcher.go @@ -4,7 +4,8 @@ import ( "database/sql" "errors" "fmt" - ora "github.com/sijms/go-ora/v2" + "github.com/sijms/go-ora/v2" + "google.golang.org/protobuf/types/known/timestamppb" "reflect" "strconv" "time" @@ -263,8 +264,11 @@ func (t *RowResult) Time() time.Time { if typ == "time.Time" { return t.v.(time.Time) } - if typ == "ora.TimeStamp" { - return time.Time(t.v.(ora.TimeStamp)) + if typ == "go_ora.TimeStamp" { + return time.Time(t.v.(go_ora.TimeStamp)) + } + if typ == "*timestamppb.Timestamp" { + return t.v.(*timestamppb.Timestamp).AsTime() } return time.Time{} } @@ -346,17 +350,34 @@ func (t *Fetcher) mapped(row *Row, tag string, value reflect.Value, typ reflect. v = res.Int() == 1 break default: - if !res.Empty() && - typ.String() == "time.Time" && - reflect.ValueOf(v).Type().String() != "time.Time" { - if t, e := time.ParseInLocation(t.options.TimeLayout, res.String(), t.options.TimeLocation); e == nil { - v = t - } else { - return fmt.Errorf("time parse fail for field %s: %v", tag, e) + if !res.Empty() { + vTyp := reflect.ValueOf(v).Type().String() + // 如果结构体是time.Time类型,执行转换 + if typ.String() == "time.Time" { + if vTyp == "time.Time" { + // parseTime=true + v = res.Value() + } else { + // parseTime=false + if t, e := time.ParseInLocation(t.options.TimeLayout, res.String(), t.options.TimeLocation); e == nil { + v = t + } else { + return fmt.Errorf("time parse fail for field %s: %v", tag, e) + } + } + } + // 如果结构体是*timestamppb.Timestamp类型,执行转换 + if typ.String() == "*timestamppb.Timestamp" { + if vTyp != "*timestamppb.Timestamp" { + if t, e := time.ParseInLocation(t.options.TimeLayout, res.String(), t.options.TimeLocation); e == nil { + v = timestamppb.New(t) + } else { + return fmt.Errorf("time parse fail for field %s: %v", tag, e) + } + } } } } - // 追加异常信息 defer func() { if e := recover(); e != nil { diff --git a/src/xsql/go.mod b/src/xsql/go.mod index d77bd3c..5167535 100644 --- a/src/xsql/go.mod +++ b/src/xsql/go.mod @@ -6,6 +6,7 @@ require ( github.com/go-sql-driver/mysql v1.6.0 github.com/sijms/go-ora/v2 v2.5.2 github.com/stretchr/testify v1.7.1 + google.golang.org/protobuf v1.34.1 ) require ( diff --git a/src/xsql/go.sum b/src/xsql/go.sum index 20a5da4..d35cbc8 100644 --- a/src/xsql/go.sum +++ b/src/xsql/go.sum @@ -2,6 +2,9 @@ github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8 github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/go-sql-driver/mysql v1.6.0 h1:BCTh4TKNUYmOmMUcQ3IipzF5prigylS7XXjEkfCHuOE= github.com/go-sql-driver/mysql v1.6.0/go.mod h1:DCzpHaOWr8IXmIStZouvnhqoel9Qv2LBy8hT2VhHyBg= +github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk= +github.com/google/go-cmp v0.5.5 h1:Khx7svrCpmxxtHBq5j2mp/xVjsi8hQMfNLvJFAlrGgU= +github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/sijms/go-ora/v2 v2.5.2 h1:8ACnYT4rOI7vjCIXQuGopiClXrXt4AnmSrv+nyMxELQ= @@ -9,6 +12,11 @@ github.com/sijms/go-ora/v2 v2.5.2/go.mod h1:EHxlY6x7y9HAsdfumurRfTd+v8NrEOTR3Xl4 github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/testify v1.7.1 h1:5TQK59W5E3v0r2duFAb7P95B6hEeOyEnHRa8MjYSMTY= github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543 h1:E7g+9GITq07hpfrRu66IVDexMakfv52eLZ2CXBWiKr4= +golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw= +google.golang.org/protobuf v1.34.1 h1:9ddQBjfCyZPOHPUiPxpYESBLc+T8P3E+Vo4IbKZgFWg= +google.golang.org/protobuf v1.34.1/go.mod h1:c6P6GXX6sHbq/GpV6MGZEdwhWPcYBgnhAHhKbcUYpos= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/src/xsql/options.go b/src/xsql/options.go index 52b0b88..55acdec 100644 --- a/src/xsql/options.go +++ b/src/xsql/options.go @@ -14,7 +14,7 @@ func newDefaultOptions() sqlOptions { TableKey: "${TABLE}", Placeholder: "?", ColumnQuotes: "`", - TimeLayout: "2006-01-02 15:04:05", + TimeLayout: "2006-01-02 15:04:05.000000", TimeLocation: time.Local, TimeFunc: func(placeholder string) string { return placeholder @@ -41,7 +41,7 @@ type sqlOptions struct { // For oracle, can be configured as " ColumnQuotes string - // Default: 2006-01-02 15:04:05 + // Default: 2006-01-02 15:04:05.000000 TimeLayout string // Default: time.Local