Skip to content

Commit

Permalink
Merge 581858e into 2ba0fc6
Browse files Browse the repository at this point in the history
  • Loading branch information
dynajoe committed Dec 13, 2019
2 parents 2ba0fc6 + 581858e commit 9ef5106
Show file tree
Hide file tree
Showing 2 changed files with 235 additions and 85 deletions.
231 changes: 158 additions & 73 deletions named.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,13 @@ package sqlx
import (
"bytes"
"database/sql"
"errors"
"fmt"
"reflect"
"regexp"
"strconv"
"strings"
"unicode"
"unicode/utf8"

"github.com/jmoiron/sqlx/reflectx"
)
Expand Down Expand Up @@ -279,90 +280,174 @@ func bindMap(bindType int, query string, args map[string]interface{}) (string, [
// digits and numbers, where '5' is a digit but '五' is not.
var allowedBindRunes = []*unicode.RangeTable{unicode.Letter, unicode.Digit}

// FIXME: this function isn't safe for unicode named params, as a failing test
// can testify. This is not a regression but a failure of the original code
// as well. It should be modified to range over runes in a string rather than
// bytes, even though this is less convenient and slower. Hopefully the
// addition of the prepared NamedStmt (which will only do this once) will make
// up for the slightly slower ad-hoc NamedExec/NamedQuery.
type parseNamedState int

const (
parseStateConsumingIdent parseNamedState = iota
parseStateQuery
parseStateQuotedIdent
parseStateStringConstant
parseStateLineComment
parseStateBlockComment
parseStateSkipThenTransition
parseStateDollarQuoteLiteral
)

type parseNamedContext struct {
state parseNamedState
data map[string]interface{}
}

const (
colon = ':'
backSlash = '\\'
forwardSlash = '/'
singleQuote = '\''
dash = '-'
star = '*'
newLine = '\n'
dollarSign = '$'
doubleQuote = '"'
)

// compile a NamedQuery into an unbound query (using the '?' bindvar) and
// a list of names.
func compileNamedQuery(qs []byte, bindType int) (query string, names []string, err error) {
names = make([]string, 0, 10)
rebound := make([]byte, 0, len(qs))

inName := false
last := len(qs) - 1
currentVar := 1
name := make([]byte, 0, 10)

for i, b := range qs {
// a ':' while we're in a name is an error
if b == ':' {
// if this is the second ':' in a '::' escape sequence, append a ':'
if inName && i > 0 && qs[i-1] == ':' {
rebound = append(rebound, ':')
inName = false
continue
} else if inName {
err = errors.New("unexpected `:` while reading named param at " + strconv.Itoa(i))
return query, names, err
var result strings.Builder

paramCount := 1
var params []string
addParam := func(paramName string) {
params = append(params, paramName)

switch bindType {
// oracle only supports named type bind vars even for positional
case NAMED:
result.WriteByte(':')
result.WriteString(paramName)
case QUESTION, UNKNOWN:
result.WriteByte('?')
case DOLLAR:
result.WriteByte('$')
result.WriteString(strconv.Itoa(paramCount))
case AT:
result.WriteString("@p")
result.WriteString(strconv.Itoa(paramCount))
}

paramCount++
}

isRuneStartOfIdent := func(r rune) bool {
return unicode.In(r, unicode.Letter) || r == '_'
}

isRunePartOfIdent := func(r rune) bool {
return isRuneStartOfIdent(r) || unicode.In(r, allowedBindRunes...) || r == '_' || r == '.'
}

ctx := parseNamedContext{state: parseStateQuery}

setState := func(s parseNamedState, d map[string]interface{}) {
ctx.data = d
ctx.state = s
}

var previousRune rune
maxIndex := len(qs)

for byteIndex := 0; byteIndex < maxIndex; {
currentRune, runeWidth := utf8.DecodeRune(qs[byteIndex:])
nextRuneByteIndex := byteIndex + runeWidth

nextRune := utf8.RuneError
if nextRuneByteIndex < maxIndex {
nextRune, _ = utf8.DecodeRune(qs[nextRuneByteIndex:])
}

writeCurrentRune := true
switch ctx.state {
case parseStateQuery:
if currentRune == colon && previousRune != colon && isRuneStartOfIdent(nextRune) {
// :foo
writeCurrentRune = false
setState(parseStateConsumingIdent, map[string]interface{}{
"ident": &strings.Builder{},
})
} else if currentRune == singleQuote && previousRune != backSlash {
// \'
setState(parseStateStringConstant, nil)
} else if currentRune == dash && nextRune == dash {
// -- single line comment
setState(parseStateLineComment, nil)
} else if currentRune == forwardSlash && nextRune == star {
// /*
setState(parseStateSkipThenTransition, map[string]interface{}{
"state": parseStateBlockComment,
"data": map[string]interface{}{
"depth": 1,
},
})
} else if currentRune == dollarSign && previousRune == dollarSign {
// $$
setState(parseStateDollarQuoteLiteral, nil)
} else if currentRune == doubleQuote {
// "foo"."bar"
setState(parseStateQuotedIdent, nil)
}
inName = true
name = []byte{}
} else if inName && i > 0 && b == '=' && len(name) == 0 {
rebound = append(rebound, ':', '=')
inName = false
continue
// if we're in a name, and this is an allowed character, continue
} else if inName && (unicode.IsOneOf(allowedBindRunes, rune(b)) || b == '_' || b == '.') && i != last {
// append the byte to the name if we are in a name and not on the last byte
name = append(name, b)
// if we're in a name and it's not an allowed character, the name is done
} else if inName {
inName = false
// if this is the final byte of the string and it is part of the name, then
// make sure to add it to the name
if i == last && unicode.IsOneOf(allowedBindRunes, rune(b)) {
name = append(name, b)
case parseStateConsumingIdent:
if isRunePartOfIdent(currentRune) {
ctx.data["ident"].(*strings.Builder).WriteRune(currentRune)
writeCurrentRune = false
} else {
addParam(ctx.data["ident"].(*strings.Builder).String())
setState(parseStateQuery, nil)
}
// add the string representation to the names list
names = append(names, string(name))
// add a proper bindvar for the bindType
switch bindType {
// oracle only supports named type bind vars even for positional
case NAMED:
rebound = append(rebound, ':')
rebound = append(rebound, name...)
case QUESTION, UNKNOWN:
rebound = append(rebound, '?')
case DOLLAR:
rebound = append(rebound, '$')
for _, b := range strconv.Itoa(currentVar) {
rebound = append(rebound, byte(b))
}
currentVar++
case AT:
rebound = append(rebound, '@', 'p')
for _, b := range strconv.Itoa(currentVar) {
rebound = append(rebound, byte(b))
case parseStateBlockComment:
if previousRune == star && currentRune == forwardSlash {
newDepth := ctx.data["depth"].(int) - 1
if newDepth == 0 {
setState(parseStateQuery, nil)
} else {
ctx.data["depth"] = newDepth
}
currentVar++
}
// add this byte to string unless it was not part of the name
if i != last {
rebound = append(rebound, b)
} else if !unicode.IsOneOf(allowedBindRunes, rune(b)) {
rebound = append(rebound, b)
case parseStateLineComment:
if currentRune == newLine {
setState(parseStateQuery, nil)
}
case parseStateStringConstant:
if currentRune == singleQuote && previousRune != backSlash {
setState(parseStateQuery, nil)
}
} else {
// this is a normal byte and should just go onto the rebound query
rebound = append(rebound, b)
case parseStateDollarQuoteLiteral:
if currentRune == dollarSign && previousRune != dollarSign {
setState(parseStateQuery, nil)
}
case parseStateQuotedIdent:
if currentRune == doubleQuote {
setState(parseStateQuery, nil)
}
case parseStateSkipThenTransition:
setState(ctx.data["state"].(parseNamedState), ctx.data["data"].(map[string]interface{}))
default:
setState(parseStateQuery, nil)
}

if writeCurrentRune {
result.WriteRune(currentRune)
}

previousRune = currentRune
byteIndex = nextRuneByteIndex
}

// If parsing left off while consuming an ident, add that ident to params
if ctx.state == parseStateConsumingIdent {
addParam(ctx.data["ident"].(*strings.Builder).String())
}

return string(rebound), names, err
return result.String(), params, nil
}

// BindNamed binds a struct or a map to a query with named parameters.
Expand Down
89 changes: 77 additions & 12 deletions named_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,19 +29,19 @@ func TestCompileQuery(t *testing.T) {
V: []string{"name1", "name2"},
},
{
Q: `SELECT "::foo" FROM a WHERE first_name=:name1 AND last_name=:name2`,
Q: `SELECT ":foo" FROM a WHERE first_name=:name1 AND last_name=:name2`,
R: `SELECT ":foo" FROM a WHERE first_name=? AND last_name=?`,
D: `SELECT ":foo" FROM a WHERE first_name=$1 AND last_name=$2`,
T: `SELECT ":foo" FROM a WHERE first_name=@p1 AND last_name=@p2`,
N: `SELECT ":foo" FROM a WHERE first_name=:name1 AND last_name=:name2`,
V: []string{"name1", "name2"},
},
{
Q: `SELECT 'a::b::c' || first_name, '::::ABC::_::' FROM person WHERE first_name=:first_name AND last_name=:last_name`,
R: `SELECT 'a:b:c' || first_name, '::ABC:_:' FROM person WHERE first_name=? AND last_name=?`,
D: `SELECT 'a:b:c' || first_name, '::ABC:_:' FROM person WHERE first_name=$1 AND last_name=$2`,
T: `SELECT 'a:b:c' || first_name, '::ABC:_:' FROM person WHERE first_name=@p1 AND last_name=@p2`,
N: `SELECT 'a:b:c' || first_name, '::ABC:_:' FROM person WHERE first_name=:first_name AND last_name=:last_name`,
Q: `SELECT 'a:b:c' || first_name, ':ABC:_:' FROM person WHERE first_name=:first_name AND last_name=:last_name`,
R: `SELECT 'a:b:c' || first_name, ':ABC:_:' FROM person WHERE first_name=? AND last_name=?`,
D: `SELECT 'a:b:c' || first_name, ':ABC:_:' FROM person WHERE first_name=$1 AND last_name=$2`,
T: `SELECT 'a:b:c' || first_name, ':ABC:_:' FROM person WHERE first_name=@p1 AND last_name=@p2`,
N: `SELECT 'a:b:c' || first_name, ':ABC:_:' FROM person WHERE first_name=:first_name AND last_name=:last_name`,
V: []string{"first_name", "last_name"},
},
{
Expand All @@ -52,16 +52,39 @@ func TestCompileQuery(t *testing.T) {
T: `SELECT @name := "name", @p1, @p2, @p3`,
V: []string{"age", "first", "last"},
},
/* This unicode awareness test sadly fails, because of our byte-wise worldview.
* We could certainly iterate by Rune instead, though it's a great deal slower,
* it's probably the RightWay(tm)
{
Q: `INSERT INTO foo (a,b,c,d) VALUES (:あ, :b, :キコ, :名前)`,
R: `INSERT INTO foo (a,b,c,d) VALUES (?, ?, ?, ?)`,
D: `INSERT INTO foo (a,b,c,d) VALUES ($1, $2, $3, $4)`,
N: []string{"name", "age", "first", "last"},
N: `INSERT INTO foo (a,b,c,d) VALUES (:あ, :b, :キコ, :名前)`,
T: `INSERT INTO foo (a,b,c,d) VALUES (@p1, @p2, @p3, @p4)`,
V: []string{"あ", "b", "キコ", "名前"},
},
{
Q: "-- A Line Comment should be ignored for :params\nINSERT INTO foo (a,b,c,d) VALUES (:あ, :b, :キコ, :名前)",
R: "-- A Line Comment should be ignored for :params\nINSERT INTO foo (a,b,c,d) VALUES (?, ?, ?, ?)",
D: "-- A Line Comment should be ignored for :params\nINSERT INTO foo (a,b,c,d) VALUES ($1, $2, $3, $4)",
N: "-- A Line Comment should be ignored for :params\nINSERT INTO foo (a,b,c,d) VALUES (:あ, :b, :キコ, :名前)",
T: "-- A Line Comment should be ignored for :params\nINSERT INTO foo (a,b,c,d) VALUES (@p1, @p2, @p3, @p4)",
V: []string{"あ", "b", "キコ", "名前"},
},
{
Q: `/* A Block Comment should be ignored for :params */INSERT INTO foo (a,b,c,d) VALUES (:あ, :b, :キコ, :名前)`,
R: `/* A Block Comment should be ignored for :params */INSERT INTO foo (a,b,c,d) VALUES (?, ?, ?, ?)`,
D: `/* A Block Comment should be ignored for :params */INSERT INTO foo (a,b,c,d) VALUES ($1, $2, $3, $4)`,
N: `/* A Block Comment should be ignored for :params */INSERT INTO foo (a,b,c,d) VALUES (:あ, :b, :キコ, :名前)`,
T: `/* A Block Comment should be ignored for :params */INSERT INTO foo (a,b,c,d) VALUES (@p1, @p2, @p3, @p4)`,
V: []string{"あ", "b", "キコ", "名前"},
},
// Repeated names are not distinct in the names list
{
Q: `INSERT INTO foo (a,b,c,d) VALUES (:name, :age, :name)`,
R: `INSERT INTO foo (a,b,c,d) VALUES (?, ?, ?)`,
D: `INSERT INTO foo (a,b,c,d) VALUES ($1, $2, $3)`,
T: `INSERT INTO foo (a,b,c,d) VALUES (@p1, @p2, @p3)`,
N: `INSERT INTO foo (a,b,c,d) VALUES (:name, :age, :name)`,
V: []string{"name", "age", "name"},
},
*/
}

for _, test := range table {
Expand Down Expand Up @@ -98,6 +121,48 @@ func TestCompileQuery(t *testing.T) {
}
}

func TestNamedQueryWithoutParams(t *testing.T) {
var queries []string = []string{
// Array Slice Syntax
`SELECT schedule[1:2][1:1] FROM sal_emp WHERE name = 'Bill';`,
`SELECT f1[1][-2][3] AS e1, f1[1][-1][5] AS e2 FROM (SELECT '[1:1][-2:-1][3:5]={{{1,2,3},{4,5,6}}}'::int[] AS f1) AS ss;`,
`SELECT array_dims(1 || '[0:1]={2,3}'::int[]);`,
// String Constant Syntax
`'Dianne'':not_a_parameter horse'`,
`'Dianne'''':not_a_parameter horse'`,
`SELECT ':not_an_parameter'`,
`$$Dia:not_an_parameter's horse$$`,
`$$Dianne's horse$$`,
`SELECT 'foo'
'bar';`,
`E'user\'s log'`,
`$$escape ' with ''$$`,
// Quoted Ident Syntax
`SELECT "addr:city" FROM "location";`,
// Type Cast Syntax
`select '1' :: numeric;`, `select '1' :: text :: numeric;`,
// Nested Block Quotes
`SELECT * FROM users
/* Ignore all things who aren't after a certain :date
* More lines /* nested block comment
*/*/
WHERE some_text LIKE 'foo -- bar'`,
}

for _, q := range queries {
qr, names, err := compileNamedQuery([]byte(q), QUESTION)
if err != nil {
t.Error(err)
}
if qr != q {
t.Errorf("expected query to be unaltered\nexpected: %s\ngot:%s", q, qr)
}
if len(names) > 0 {
t.Errorf("expected params to be empty got: %v", names)
}
}
}

type Test struct {
t *testing.T
}
Expand Down Expand Up @@ -143,7 +208,7 @@ func TestNamedQueries(t *testing.T) {
test.Error(err)

ns, err = db.PrepareNamed(`
SELECT first_name, last_name, email
SELECT first_name, last_name, email
FROM person WHERE first_name=:first_name AND email=:email`)
test.Error(err)

Expand Down

0 comments on commit 9ef5106

Please sign in to comment.