Skip to content

Commit

Permalink
Merge pull request #17 from knocknote/feature/add_cases_time_and_nil_…
Browse files Browse the repository at this point in the history
…to_replace_insert_value_from_val_arg_function

Add cases time.Time and nil to replaceInsertValueFromValArg function
  • Loading branch information
goccy committed Aug 28, 2019
2 parents 7ece3ea + 3399bd8 commit 27424a2
Show file tree
Hide file tree
Showing 2 changed files with 627 additions and 45 deletions.
200 changes: 182 additions & 18 deletions sqlparser/sqlparser.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"reflect"
"regexp"
"strconv"
"time"

vtparser "github.com/knocknote/vitess-sqlparser/sqlparser"
"github.com/pkg/errors"
Expand All @@ -26,6 +27,10 @@ var (
replaceCharSetParam = regexp.MustCompile("charset=[A-Za-z-_0-9]+")
)

var (
ErrShardingKeyNotAllowNil = errors.New("sharding key does not allow nil")
)

func (p *Parser) shardColumnName(tableName string) string {
return p.cfg.ShardColumnName(tableName)
}
Expand Down Expand Up @@ -197,38 +202,154 @@ func (p *Parser) replaceInsertValueFromValArg(query *InsertQuery, colIndex int,
queryArg := query.Args[index-1]
switch arg := queryArg.(type) {
case string:
query.ColumnValues[colIndex] = func() *vtparser.SQLVal {
return &vtparser.SQLVal{
Type: vtparser.StrVal,
Val: []byte(arg),
query.ColumnValues[colIndex] = createSQLStringTypeVal(arg)
case *string:
if arg == nil {
query.ColumnValues[colIndex] = createSQLNilTypeVal()
} else {
query.ColumnValues[colIndex] = createSQLStringTypeVal(*arg)
}
case int:
p.replaceInsertValueFromValArgCaseInt(query, colIndex, colName, int64(arg))
case int8:
p.replaceInsertValueFromValArgCaseInt(query, colIndex, colName, int64(arg))
case int16:
p.replaceInsertValueFromValArgCaseInt(query, colIndex, colName, int64(arg))
case int32:
p.replaceInsertValueFromValArgCaseInt(query, colIndex, colName, int64(arg))
case int64:
p.replaceInsertValueFromValArgCaseInt(query, colIndex, colName, int64(arg))
case *int:
if arg == nil {
if err := p.replaceInsertValueFromValArgCaseIntNilPtr(query, colIndex, colName); err != nil {
return errors.WithStack(err)
}
} else {
p.replaceInsertValueFromValArgCaseInt(query, colIndex, colName, int64(*arg))
}
case int, int8, int16, int32, int64:
if colName == p.shardKeyColumnName(query.TableName) {
query.ShardKeyID = Identifier(arg.(int64))
case *int8:
if arg == nil {
if err := p.replaceInsertValueFromValArgCaseIntNilPtr(query, colIndex, colName); err != nil {
return errors.WithStack(err)
}
} else {
p.replaceInsertValueFromValArgCaseInt(query, colIndex, colName, int64(*arg))
}
query.ColumnValues[colIndex] = func() *vtparser.SQLVal {
return &vtparser.SQLVal{
Type: vtparser.IntVal,
Val: []byte(fmt.Sprintf("%d", arg)),
case *int16:
if arg == nil {
if err := p.replaceInsertValueFromValArgCaseIntNilPtr(query, colIndex, colName); err != nil {
return errors.WithStack(err)
}
} else {
p.replaceInsertValueFromValArgCaseInt(query, colIndex, colName, int64(*arg))
}
case uint, uint8, uint16, uint32, uint64:
if colName == p.shardKeyColumnName(query.TableName) {
query.ShardKeyID = Identifier(int64(arg.(uint64)))
case *int32:
if arg == nil {
if err := p.replaceInsertValueFromValArgCaseIntNilPtr(query, colIndex, colName); err != nil {
return errors.WithStack(err)
}
} else {
p.replaceInsertValueFromValArgCaseInt(query, colIndex, colName, int64(*arg))
}
query.ColumnValues[colIndex] = func() *vtparser.SQLVal {
return &vtparser.SQLVal{
Type: vtparser.IntVal,
Val: []byte(fmt.Sprintf("%d", arg)),
case *int64:
if arg == nil {
if err := p.replaceInsertValueFromValArgCaseIntNilPtr(query, colIndex, colName); err != nil {
return errors.WithStack(err)
}
} else {
p.replaceInsertValueFromValArgCaseInt(query, colIndex, colName, int64(*arg))
}
case uint:
p.replaceInsertValueFromValArgCaseInt(query, colIndex, colName, int64(arg))
case uint8:
p.replaceInsertValueFromValArgCaseInt(query, colIndex, colName, int64(arg))
case uint16:
p.replaceInsertValueFromValArgCaseInt(query, colIndex, colName, int64(arg))
case uint32:
p.replaceInsertValueFromValArgCaseInt(query, colIndex, colName, int64(arg))
case uint64:
p.replaceInsertValueFromValArgCaseInt(query, colIndex, colName, int64(arg))
case *uint:
if arg == nil {
if err := p.replaceInsertValueFromValArgCaseIntNilPtr(query, colIndex, colName); err != nil {
return errors.WithStack(err)
}
} else {
p.replaceInsertValueFromValArgCaseInt(query, colIndex, colName, int64(*arg))
}
case *uint8:
if arg == nil {
if err := p.replaceInsertValueFromValArgCaseIntNilPtr(query, colIndex, colName); err != nil {
return errors.WithStack(err)
}
} else {
p.replaceInsertValueFromValArgCaseInt(query, colIndex, colName, int64(*arg))
}
case *uint16:
if arg == nil {
if err := p.replaceInsertValueFromValArgCaseIntNilPtr(query, colIndex, colName); err != nil {
return errors.WithStack(err)
}
} else {
p.replaceInsertValueFromValArgCaseInt(query, colIndex, colName, int64(*arg))
}
case *uint32:
if arg == nil {
if err := p.replaceInsertValueFromValArgCaseIntNilPtr(query, colIndex, colName); err != nil {
return errors.WithStack(err)
}
} else {
p.replaceInsertValueFromValArgCaseInt(query, colIndex, colName, int64(*arg))
}
case *uint64:
if arg == nil {
if err := p.replaceInsertValueFromValArgCaseIntNilPtr(query, colIndex, colName); err != nil {
return errors.WithStack(err)
}
} else {
p.replaceInsertValueFromValArgCaseInt(query, colIndex, colName, int64(*arg))
}
case bool:
val := convertBoolToInt8(arg)
query.ColumnValues[colIndex] = createSQLIntTypeVal(val)
case *bool:
if arg == nil {
query.ColumnValues[colIndex] = createSQLNilTypeVal()
} else {
val := convertBoolToInt8(*arg)
query.ColumnValues[colIndex] = createSQLIntTypeVal(val)
}
case time.Time:
query.ColumnValues[colIndex] = createSQLTimeTypeVal(arg)
case *time.Time:
if arg == nil {
query.ColumnValues[colIndex] = createSQLNilTypeVal()
} else {
query.ColumnValues[colIndex] = createSQLTimeTypeVal(*arg)
}
case nil:
query.ColumnValues[colIndex] = createSQLNilTypeVal()
default:
debug.Printf("arg type = %s", reflect.TypeOf(arg))
}
return nil
}

func (p *Parser) replaceInsertValueFromValArgCaseInt(query *InsertQuery, colIndex int, colName string, arg int64) {
if colName == p.shardKeyColumnName(query.TableName) {
query.ShardKeyID = Identifier(arg)
}
query.ColumnValues[colIndex] = createSQLIntTypeVal(arg)
}

func (p *Parser) replaceInsertValueFromValArgCaseIntNilPtr(query *InsertQuery, colIndex int, colName string) error {
if colName == p.shardKeyColumnName(query.TableName) {
return errors.WithStack(ErrShardingKeyNotAllowNil)
}
query.ColumnValues[colIndex] = createSQLNilTypeVal()
return nil
}

func (p *Parser) replaceInsertValue(query *InsertQuery, colIndex int, colName string) error {
if colName == p.shardColumnName(query.TableName) {
query.ColumnValues[colIndex] = func() *vtparser.SQLVal {
Expand Down Expand Up @@ -472,3 +593,46 @@ func New() (*Parser, error) {

return &Parser{cfg: cfg}, nil
}

func createSQLIntTypeVal(val interface{}) func() *vtparser.SQLVal {
return func() *vtparser.SQLVal {
return &vtparser.SQLVal{
Type: vtparser.IntVal,
Val: []byte(fmt.Sprintf("%d", val)),
}
}
}

func createSQLStringTypeVal(val string) func() *vtparser.SQLVal {
return func() *vtparser.SQLVal {
return &vtparser.SQLVal{
Type: vtparser.StrVal,
Val: []byte(val),
}
}
}

func createSQLTimeTypeVal(val time.Time) func() *vtparser.SQLVal {
return func() *vtparser.SQLVal {
return &vtparser.SQLVal{
Type: vtparser.StrVal,
Val: []byte(val.Format("2006-01-02 15:04:05")),
}
}
}

func createSQLNilTypeVal() func() *vtparser.SQLVal {
return func() *vtparser.SQLVal {
return &vtparser.SQLVal{
Type: vtparser.IntVal,
Val: []byte("null"),
}
}
}

func convertBoolToInt8(val bool) (res int8) {
if val {
res = 1
}
return res
}
Loading

0 comments on commit 27424a2

Please sign in to comment.