Skip to content

Commit

Permalink
Merge ad3802b into 2ba0fc6
Browse files Browse the repository at this point in the history
  • Loading branch information
dynajoe committed Dec 6, 2019
2 parents 2ba0fc6 + ad3802b commit 60527d6
Show file tree
Hide file tree
Showing 2 changed files with 192 additions and 85 deletions.
230 changes: 157 additions & 73 deletions named.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,12 @@ package sqlx
import (
"bytes"
"database/sql"
"errors"
"fmt"
"reflect"
"regexp"
"strconv"
"unicode"
"unicode/utf8"

"github.com/jmoiron/sqlx/reflectx"
)
Expand Down Expand Up @@ -279,90 +279,174 @@ func bindMap(bindType int, query string, args map[string]interface{}) (string, [
// digits and numbers, where '5' is a digit but '五' is not.
var allowedBindRunes = []*unicode.RangeTable{unicode.Letter, unicode.Digit}

// FIXME: this function isn't safe for unicode named params, as a failing test
// can testify. This is not a regression but a failure of the original code
// as well. It should be modified to range over runes in a string rather than
// bytes, even though this is less convenient and slower. Hopefully the
// addition of the prepared NamedStmt (which will only do this once) will make
// up for the slightly slower ad-hoc NamedExec/NamedQuery.
type parseNamedState int

const (
parseStateConsumingIdent parseNamedState = iota
parseStateQuery
parseStateQuotedIdent
parseStateStringConstant
parseStateLineComment
parseStateBlockComment
parseStateSkipThenTransition
parseStateDollarQuoteLiteral
)

type parseNamedContext struct {
state parseNamedState
data map[string]interface{}
}

const (
colon = ':'
backSlash = '\\'
forwardSlash = '/'
singleQuote = '\''
dash = '-'
star = '*'
newLine = '\n'
dollarSign = '$'
doubleQuote = '"'
)

// compile a NamedQuery into an unbound query (using the '?' bindvar) and
// a list of names.
func compileNamedQuery(qs []byte, bindType int) (query string, names []string, err error) {
names = make([]string, 0, 10)
rebound := make([]byte, 0, len(qs))

inName := false
last := len(qs) - 1
currentVar := 1
name := make([]byte, 0, 10)

for i, b := range qs {
// a ':' while we're in a name is an error
if b == ':' {
// if this is the second ':' in a '::' escape sequence, append a ':'
if inName && i > 0 && qs[i-1] == ':' {
rebound = append(rebound, ':')
inName = false
continue
} else if inName {
err = errors.New("unexpected `:` while reading named param at " + strconv.Itoa(i))
return query, names, err
result := make([]byte, 0, len(qs))

addRuneToResult := func(r rune) {
result = append(result, []byte(string(r))...)
}

paramCount := 1
var params []string
addParam := func(paramName string) {
params = append(params, paramName)

switch bindType {
// oracle only supports named type bind vars even for positional
case NAMED:
result = append(result, ':')
result = append(result, paramName...)
case QUESTION, UNKNOWN:
result = append(result, '?')
case DOLLAR:
result = append(result, '$')
result = append(result, []byte(strconv.Itoa(paramCount))...)
case AT:
result = append(result, '@', 'p')
result = append(result, []byte(strconv.Itoa(paramCount))...)
}

paramCount++
}

isRunePartOfIdent := func(r rune) bool {
return unicode.In(r, allowedBindRunes...) || r == '_' || r == '.'
}

source := string(qs)

ctx := parseNamedContext{state: parseStateQuery}

setState := func(s parseNamedState, d map[string]interface{}) {
ctx.data = d
ctx.state = s
}

var previousRune rune

for byteIndex, currentRune := range source {
nextRuneByteIndex := byteIndex + utf8.RuneLen(currentRune)

var remainingBytes []byte
if nextRuneByteIndex < len(source) {
remainingBytes = []byte(source[nextRuneByteIndex:])
}

nextRune, _ := utf8.DecodeRune(remainingBytes)
addCurrentRune := true
switch ctx.state {
case parseStateQuery:
if currentRune == colon && previousRune != colon && isRunePartOfIdent(nextRune) {
// :foo
addCurrentRune = false
setState(parseStateConsumingIdent, map[string]interface{}{
"ident": []byte{},
})
} else if currentRune == singleQuote && previousRune != backSlash {
// \'
setState(parseStateStringConstant, nil)
} else if currentRune == dash && nextRune == dash {
// -- single line comment
setState(parseStateLineComment, nil)
} else if currentRune == forwardSlash && nextRune == star {
// /*
setState(parseStateSkipThenTransition, map[string]interface{}{
"state": parseStateBlockComment,
"data": map[string]interface{}{
"depth": 1,
},
})
} else if currentRune == dollarSign && previousRune == dollarSign {
// $$
setState(parseStateDollarQuoteLiteral, nil)
} else if currentRune == doubleQuote {
// "foo"."bar"
setState(parseStateQuotedIdent, nil)
}
inName = true
name = []byte{}
} else if inName && i > 0 && b == '=' && len(name) == 0 {
rebound = append(rebound, ':', '=')
inName = false
continue
// if we're in a name, and this is an allowed character, continue
} else if inName && (unicode.IsOneOf(allowedBindRunes, rune(b)) || b == '_' || b == '.') && i != last {
// append the byte to the name if we are in a name and not on the last byte
name = append(name, b)
// if we're in a name and it's not an allowed character, the name is done
} else if inName {
inName = false
// if this is the final byte of the string and it is part of the name, then
// make sure to add it to the name
if i == last && unicode.IsOneOf(allowedBindRunes, rune(b)) {
name = append(name, b)
case parseStateConsumingIdent:
if isRunePartOfIdent(currentRune) {
ctx.data["ident"] = append(ctx.data["ident"].([]byte), []byte(string(currentRune))...)
addCurrentRune = false
} else {
addParam(string(ctx.data["ident"].([]byte)))
setState(parseStateQuery, nil)
}
// add the string representation to the names list
names = append(names, string(name))
// add a proper bindvar for the bindType
switch bindType {
// oracle only supports named type bind vars even for positional
case NAMED:
rebound = append(rebound, ':')
rebound = append(rebound, name...)
case QUESTION, UNKNOWN:
rebound = append(rebound, '?')
case DOLLAR:
rebound = append(rebound, '$')
for _, b := range strconv.Itoa(currentVar) {
rebound = append(rebound, byte(b))
case parseStateBlockComment:
if previousRune == star && currentRune == forwardSlash {
newDepth := ctx.data["depth"].(int) - 1
if newDepth == 0 {
setState(parseStateQuery, nil)
} else {
ctx.data["depth"] = newDepth
}
currentVar++
case AT:
rebound = append(rebound, '@', 'p')
for _, b := range strconv.Itoa(currentVar) {
rebound = append(rebound, byte(b))
}
currentVar++
}
// add this byte to string unless it was not part of the name
if i != last {
rebound = append(rebound, b)
} else if !unicode.IsOneOf(allowedBindRunes, rune(b)) {
rebound = append(rebound, b)
case parseStateLineComment:
if currentRune == newLine {
setState(parseStateQuery, nil)
}
case parseStateStringConstant:
if currentRune == singleQuote && previousRune != backSlash {
setState(parseStateQuery, nil)
}
} else {
// this is a normal byte and should just go onto the rebound query
rebound = append(rebound, b)
case parseStateDollarQuoteLiteral:
if currentRune == dollarSign && previousRune != dollarSign {
setState(parseStateQuery, nil)
}
case parseStateQuotedIdent:
if currentRune == doubleQuote {
setState(parseStateQuery, nil)
}
case parseStateSkipThenTransition:
setState(ctx.data["state"].(parseNamedState), ctx.data["data"].(map[string]interface{}))
default:
setState(parseStateQuery, nil)
}

if addCurrentRune {
addRuneToResult(currentRune)
}

previousRune = currentRune
}

// If parsing left off while consuming an ident, add that ident to params
if ctx.state == parseStateConsumingIdent {
addParam(string(ctx.data["ident"].([]byte)))
}

return string(rebound), names, err
return string(result), params, nil
}

// BindNamed binds a struct or a map to a query with named parameters.
Expand Down
47 changes: 35 additions & 12 deletions named_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,19 +29,19 @@ func TestCompileQuery(t *testing.T) {
V: []string{"name1", "name2"},
},
{
Q: `SELECT "::foo" FROM a WHERE first_name=:name1 AND last_name=:name2`,
Q: `SELECT ":foo" FROM a WHERE first_name=:name1 AND last_name=:name2`,
R: `SELECT ":foo" FROM a WHERE first_name=? AND last_name=?`,
D: `SELECT ":foo" FROM a WHERE first_name=$1 AND last_name=$2`,
T: `SELECT ":foo" FROM a WHERE first_name=@p1 AND last_name=@p2`,
N: `SELECT ":foo" FROM a WHERE first_name=:name1 AND last_name=:name2`,
V: []string{"name1", "name2"},
},
{
Q: `SELECT 'a::b::c' || first_name, '::::ABC::_::' FROM person WHERE first_name=:first_name AND last_name=:last_name`,
R: `SELECT 'a:b:c' || first_name, '::ABC:_:' FROM person WHERE first_name=? AND last_name=?`,
D: `SELECT 'a:b:c' || first_name, '::ABC:_:' FROM person WHERE first_name=$1 AND last_name=$2`,
T: `SELECT 'a:b:c' || first_name, '::ABC:_:' FROM person WHERE first_name=@p1 AND last_name=@p2`,
N: `SELECT 'a:b:c' || first_name, '::ABC:_:' FROM person WHERE first_name=:first_name AND last_name=:last_name`,
Q: `SELECT 'a:b:c' || first_name, ':ABC:_:' FROM person WHERE first_name=:first_name AND last_name=:last_name`,
R: `SELECT 'a:b:c' || first_name, ':ABC:_:' FROM person WHERE first_name=? AND last_name=?`,
D: `SELECT 'a:b:c' || first_name, ':ABC:_:' FROM person WHERE first_name=$1 AND last_name=$2`,
T: `SELECT 'a:b:c' || first_name, ':ABC:_:' FROM person WHERE first_name=@p1 AND last_name=@p2`,
N: `SELECT 'a:b:c' || first_name, ':ABC:_:' FROM person WHERE first_name=:first_name AND last_name=:last_name`,
V: []string{"first_name", "last_name"},
},
{
Expand All @@ -52,16 +52,39 @@ func TestCompileQuery(t *testing.T) {
T: `SELECT @name := "name", @p1, @p2, @p3`,
V: []string{"age", "first", "last"},
},
/* This unicode awareness test sadly fails, because of our byte-wise worldview.
* We could certainly iterate by Rune instead, though it's a great deal slower,
* it's probably the RightWay(tm)
{
Q: `INSERT INTO foo (a,b,c,d) VALUES (:あ, :b, :キコ, :名前)`,
R: `INSERT INTO foo (a,b,c,d) VALUES (?, ?, ?, ?)`,
D: `INSERT INTO foo (a,b,c,d) VALUES ($1, $2, $3, $4)`,
N: []string{"name", "age", "first", "last"},
N: `INSERT INTO foo (a,b,c,d) VALUES (:あ, :b, :キコ, :名前)`,
T: `INSERT INTO foo (a,b,c,d) VALUES (@p1, @p2, @p3, @p4)`,
V: []string{"あ", "b", "キコ", "名前"},
},
{
Q: "-- A Line Comment should be ignored for :params\nINSERT INTO foo (a,b,c,d) VALUES (:あ, :b, :キコ, :名前)",
R: "-- A Line Comment should be ignored for :params\nINSERT INTO foo (a,b,c,d) VALUES (?, ?, ?, ?)",
D: "-- A Line Comment should be ignored for :params\nINSERT INTO foo (a,b,c,d) VALUES ($1, $2, $3, $4)",
N: "-- A Line Comment should be ignored for :params\nINSERT INTO foo (a,b,c,d) VALUES (:あ, :b, :キコ, :名前)",
T: "-- A Line Comment should be ignored for :params\nINSERT INTO foo (a,b,c,d) VALUES (@p1, @p2, @p3, @p4)",
V: []string{"あ", "b", "キコ", "名前"},
},
{
Q: `/* A Block Comment should be ignored for :params */INSERT INTO foo (a,b,c,d) VALUES (:あ, :b, :キコ, :名前)`,
R: `/* A Block Comment should be ignored for :params */INSERT INTO foo (a,b,c,d) VALUES (?, ?, ?, ?)`,
D: `/* A Block Comment should be ignored for :params */INSERT INTO foo (a,b,c,d) VALUES ($1, $2, $3, $4)`,
N: `/* A Block Comment should be ignored for :params */INSERT INTO foo (a,b,c,d) VALUES (:あ, :b, :キコ, :名前)`,
T: `/* A Block Comment should be ignored for :params */INSERT INTO foo (a,b,c,d) VALUES (@p1, @p2, @p3, @p4)`,
V: []string{"あ", "b", "キコ", "名前"},
},
// Repeated names are not distinct in the names list
{
Q: `INSERT INTO foo (a,b,c,d) VALUES (:name, :age, :name)`,
R: `INSERT INTO foo (a,b,c,d) VALUES (?, ?, ?)`,
D: `INSERT INTO foo (a,b,c,d) VALUES ($1, $2, $3)`,
T: `INSERT INTO foo (a,b,c,d) VALUES (@p1, @p2, @p3)`,
N: `INSERT INTO foo (a,b,c,d) VALUES (:name, :age, :name)`,
V: []string{"name", "age", "name"},
},
*/
}

for _, test := range table {
Expand Down Expand Up @@ -143,7 +166,7 @@ func TestNamedQueries(t *testing.T) {
test.Error(err)

ns, err = db.PrepareNamed(`
SELECT first_name, last_name, email
SELECT first_name, last_name, email
FROM person WHERE first_name=:first_name AND email=:email`)
test.Error(err)

Expand Down

0 comments on commit 60527d6

Please sign in to comment.