Permalink
Browse files

Fix handling of named parameters. Closes #190.

	modified:   Makefile
	modified:   all_test.go
	modified:   driver1.8.go
	modified:   go1.8_test.go
	modified:   parser.go
	modified:   ql.y
	modified:   scanner.go
	modified:   scanner.l
  • Loading branch information...
cznic committed Nov 22, 2017
1 parent 11f89f7 commit 3f53e147d722f949b627631bc771623ab9bdb396
Showing with 1,099 additions and 987 deletions.
  1. +2 −2 Makefile
  2. +17 −0 all_test.go
  3. +92 −24 driver1.8.go
  4. +12 −8 go1.8_test.go
  5. +3 −1 parser.go
  6. +1 −1 ql.y
  7. +967 −950 scanner.go
  8. +5 −1 scanner.l
View
@@ -29,7 +29,7 @@ coerce.go: helper/helper.go
go run helper/helper.go | gofmt > $@
cover:
t=$(shell tempfile) ; go test -coverprofile $$t && go tool cover -html $$t && unlink $$t
t=$(shell mktemp) ; go test -coverprofile $$t && go tool cover -html $$t && unlink $$t
cpu: clean
go test -run @ -bench . -cpuprofile cpu.out
@@ -58,7 +58,7 @@ nuke: clean
go clean -i
parser.go: parser.y
a=$(shell tempfile) ; \
a=$(shell mktemp) ; \
goyacc -o /dev/null -xegen $$a $< ; \
goyacc -cr -o $@ -xe $$a $< ; \
rm -f $$a
View
@@ -3444,6 +3444,23 @@ func TestIssue142(t *testing.T) {
}
}
func TestTokenize(t *testing.T) {
toks, err := tokenize("\"a$1\" `a$2` $3 $x $x_Yřa 'z' 3+6 -- foo\nbar")
if err != nil {
t.Fatal(err)
}
exp := []string{"\"a$1\"", "`a$2`", "$3", "$x", "$x_Yřa", "'z'", "3", "+", "6", "bar"}
if g, e := len(toks), len(exp); g != e {
t.Fatalf("\ngot %q\nexp %q", toks, exp)
}
for i, g := range toks {
if e := exp[i]; g != e {
t.Fatalf("\not %q\nexp %q", toks, exp)
}
}
}
// Both of the UPDATEs _should_ work but the 2nd one results in a _type missmatch_ error at the time of writing.
// see https://github.com/cznic/ql/issues/190
func TestIssue190(t *testing.T) {
View
@@ -6,53 +6,99 @@ import (
"context"
"database/sql/driver"
"fmt"
"strconv"
"strings"
"regexp"
)
const prefix = "$"
func (c *driverConn) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) {
return c.Exec(replaceNamed(query, args))
query, vals, err := replaceNamed(query, args)
if err != nil {
return nil, err
}
return c.Exec(query, vals)
}
func replaceNamed(query string, args []driver.NamedValue) (string, []driver.Value) {
func replaceNamed(query string, args []driver.NamedValue) (string, []driver.Value, error) {
toks, err := tokenize(query)
if err != nil {
return "", nil, err
}
a := make([]driver.Value, len(args))
for k, v := range args {
if v.Name != "" {
query = strings.Replace(query, prefix+v.Name, fmt.Sprintf("%s%d", prefix, v.Ordinal), -1)
m := map[string]int{}
for _, v := range args {
m[v.Name] = v.Ordinal
a[v.Ordinal-1] = v.Value
}
for i, v := range toks {
if len(v) > 1 && strings.HasPrefix(v, prefix) {
if v[1] >= '1' && v[1] <= '9' {
continue
}
nm := v[1:]
k, ok := m[nm]
if !ok {
return query, nil, fmt.Errorf("unknown named parameter %s", nm)
}
toks[i] = fmt.Sprintf("$%d", k)
}
a[k] = v.Value
}
return query, a
return strings.Join(toks, " "), a, nil
}
func (c *driverConn) QueryContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Rows, error) {
return c.Query(replaceNamed(query, args))
query, vals, err := replaceNamed(query, args)
if err != nil {
return nil, err
}
return c.Query(query, vals)
}
func (c *driverConn) PrepareContext(ctx context.Context, query string) (driver.Stmt, error) {
return c.Prepare(filterNamedArgs(query))
}
query, err := filterNamedArgs(query)
if err != nil {
return nil, err
}
var re = regexp.MustCompile(`^\w+`)
return c.Prepare(query)
}
func filterNamedArgs(q string) string {
c := strings.Count(q, prefix)
if c == 0 || c == len(q) {
return q
func filterNamedArgs(query string) (string, error) {
toks, err := tokenize(query)
if err != nil {
return "", err
}
pc := strings.Split(q, prefix)
for k, v := range pc {
if k == 0 {
continue
n := 0
for _, v := range toks {
if len(v) > 1 && strings.HasPrefix(v, prefix) && v[1] >= '1' && v[1] <= '9' {
m, err := strconv.ParseUint(v[1:], 10, 31)
if err != nil {
return "", err
}
if int(m) > n {
n = int(m)
}
}
if v != "" {
pc[k] = re.ReplaceAllString(v, fmt.Sprint(k))
}
for i, v := range toks {
if len(v) > 1 && strings.HasPrefix(v, prefix) {
if v[1] >= '1' && v[1] <= '9' {
continue
}
n++
toks[i] = fmt.Sprintf("$%d", n)
}
}
return strings.Join(pc, prefix)
return strings.Join(toks, " "), nil
}
func (s *driverStmt) ExecContext(ctx context.Context, args []driver.NamedValue) (driver.Result, error) {
@@ -70,3 +116,25 @@ func (s *driverStmt) QueryContext(ctx context.Context, args []driver.NamedValue)
}
return s.Query(a)
}
func tokenize(s string) (r []string, _ error) {
lx, err := newLexer(s)
if err != nil {
return nil, err
}
var lval yySymType
for lx.Lex(&lval) != 0 {
s := string(lx.TokenBytes(nil))
if s != "" {
switch s[len(s)-1] {
case '"':
s = "\"" + s
case '`':
s = "`" + s
}
}
r = append(r, s)
}
return r, nil
}
View
@@ -84,9 +84,9 @@ func TestNamedArgs(t *testing.T) {
rows, err := db.QueryContext(
context.Background(),
`select $one;select $two;select $three;`,
sql.Named("one", 1),
sql.Named("two", 2),
`select $two;select $one;select $three;`,
sql.Named("one", 2),
sql.Named("two", 1),
sql.Named("three", 3),
)
if err != nil {
@@ -119,21 +119,25 @@ func TestNamedArgs(t *testing.T) {
}{
{
`select $one;select $two;select $three;`,
`select $1;select $2;select $3;`,
`select $1 ; select $2 ; select $3 ;`,
},
{
`select * from foo where t=$1`,
`select * from foo where t=$1`,
`select * from foo where t = $1`,
},
{
`select * from foo where t=$1&&name=$name`,
`select * from foo where t=$1&&name=$2`,
`select * from foo where t = $1 && name = $2`,
},
}
for _, s := range samples {
e := filterNamedArgs(s.src)
e, err := filterNamedArgs(s.src)
if err != nil {
t.Fatal(err)
}
if e != s.exp {
t.Errorf("expected %s got %s", s.exp, e)
t.Errorf("\nexpected %q\n got %q", s.exp, e)
}
}
View

Some generated files are not rendered by default. Learn more.

Oops, something went wrong.
View
2 ql.y
@@ -3,7 +3,7 @@
//TODO Put your favorite license here
// yacc source generated by ebnf2y[1]
// at 2017-08-31 17:43:07.227157474 +0200 CEST m=+0.001846399
// at 2017-11-22 13:44:30.7008477 +0100 CET m=+0.004756809
//
// $ ebnf2y -o ql.y -oe ql.ebnf -start StatementList -pkg ql -p _
//
Oops, something went wrong.

0 comments on commit 3f53e14

Please sign in to comment.