Skip to content

Commit

Permalink
adds support for custom sqlc.arg(MyParam) params
Browse files Browse the repository at this point in the history
  • Loading branch information
cmoog authored and kyleconroy committed Jan 16, 2020
1 parent 014ce72 commit f9d952d
Show file tree
Hide file tree
Showing 6 changed files with 128 additions and 17 deletions.
5 changes: 4 additions & 1 deletion examples/booktest/mysql/query.sql
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,9 @@ WHERE book_id = ?;

/* name: UpdateBookISBN :exec */
UPDATE books
SET title = ?, tags = ?, isbn = ?
SET title = ?, tags = :book_tags, isbn = ?
WHERE book_id = ?;

/* name: DeleteAuthorBeforeYear :exec */
DELETE FROM books
WHERE yr < sqlc.arg(min_publish_year) AND author_id = ?;
24 changes: 19 additions & 5 deletions examples/booktest/mysql/query.sql.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

68 changes: 64 additions & 4 deletions internal/mysql/param.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ func paramsInLimitExpr(limit *sqlparser.Limit, s *Schema, tableAliasMap FromTabl
return params, nil
}

parseLimitSubExp := func(node sqlparser.Expr) {
parseLimitSubExp := func(node sqlparser.Expr) error {
switch v := node.(type) {
case *sqlparser.SQLVal:
if v.Type == sqlparser.ValArg {
Expand All @@ -33,11 +33,30 @@ func paramsInLimitExpr(limit *sqlparser.Limit, s *Schema, tableAliasMap FromTabl
Typ: "uint32",
})
}
case *sqlparser.FuncExpr:
name, raw, err := matchFuncExpr(v)
if err != nil {
return err
}
if name != "" && raw != "" {
params = append(params, &Param{
OriginalName: raw,
Name: name,
Typ: "uint32",
})
}
}
return nil
}

parseLimitSubExp(limit.Offset)
parseLimitSubExp(limit.Rowcount)
err := parseLimitSubExp(limit.Offset)
if err != nil {
return nil, err
}
err = parseLimitSubExp(limit.Rowcount)
if err != nil {
return nil, err
}

return params, nil
}
Expand Down Expand Up @@ -115,13 +134,26 @@ func paramInComparison(cond *sqlparser.ComparisonExpr, s *Schema, tableAliasMap
if v.Type == sqlparser.ValArg {
p.OriginalName = string(v.Val)
}
case *sqlparser.FuncExpr:
name, raw, err := matchFuncExpr(v)
if err != nil {
return false, err
}
if name != "" && raw != "" {
p.OriginalName = raw
p.Name = name
}
return false, nil
}
return true, nil
}
err := sqlparser.Walk(walker, cond)
if err != nil {
return nil, false, err
}
if p.Name != "" {
return p, true, nil
}
if p.OriginalName != "" && p.Typ != "" {
p.Name = paramName(colIdent, p.OriginalName)
return p, true, nil
Expand All @@ -143,11 +175,39 @@ func paramName(col sqlparser.ColIdent, originalName string) string {

func replaceParamStrs(query string, params []*Param) (string, error) {
for _, p := range params {
re, err := regexp.Compile(fmt.Sprintf("(%v)", p.OriginalName))
re, err := regexp.Compile(fmt.Sprintf("(%v)", regexp.QuoteMeta(p.OriginalName)))
if err != nil {
return "", err
}
query = re.ReplaceAllString(query, "?")
}
return query, nil
}

func matchFuncExpr(v *sqlparser.FuncExpr) (name string, raw string, err error) {
namespace := "sqlc"
fakeFunc := "arg"
if v.Qualifier.String() == namespace {
if v.Name.String() == fakeFunc {
if expr, ok := v.Exprs[0].(*sqlparser.AliasedExpr); ok {
if colName, ok := expr.Expr.(*sqlparser.ColName); ok {
customName := colName.Name.String()
return customName, fmt.Sprintf("%s.%s(%s)", namespace, fakeFunc, customName), nil
}
return "", "", fmt.Errorf("invalid custom argument value \"%s.%s(%s)\"", namespace, fakeFunc, replaceVParamExprs(sqlparser.String(v.Exprs[0])))
}
return "", "", fmt.Errorf("invalid custom argument value \"%s.%s(%s)\"", namespace, fakeFunc, replaceVParamExprs(sqlparser.String(v.Exprs[0])))
}
return "", "", fmt.Errorf("invalid function call \"%s.%s\", did you mean \"%s.%s\"?", namespace, v.Name.String(), namespace, fakeFunc)
}
return "", "", nil
}

func replaceVParamExprs(sql string) string {
/*
the sqlparser replaces "?" with ":v1"
to display a helpful error message, these should be replaced back to "?"
*/
matcher := regexp.MustCompile(":v[0-9]*")
return matcher.ReplaceAllString(sql, "?")
}
18 changes: 14 additions & 4 deletions internal/mysql/param_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,16 @@ func TestSelectParamSearcher(t *testing.T) {
},
},
},
{
input: "select first_name, id FROM users LIMIT sqlc.arg(UsersLimit)",
output: []*Param{
&Param{
OriginalName: "sqlc.arg(UsersLimit)",
Name: "UsersLimit",
Typ: "uint32",
},
},
},
}
for _, tCase := range tests {
tree, err := sqlparser.Parse(tCase.input)
Expand Down Expand Up @@ -118,20 +128,20 @@ func TestInsertParamSearcher(t *testing.T) {

tests := []testCase{
testCase{
input: "/* name: InsertNewUser :exec */\nINSERT INTO users (first_name, last_name) VALUES (?, ?)",
input: "/* name: InsertNewUser :exec */\nINSERT INTO users (first_name, last_name) VALUES (?, sqlc.arg(user_last_name))",
output: []*Param{
&Param{
OriginalName: ":v1",
Name: "first_name",
Typ: "string",
},
&Param{
OriginalName: ":v2",
Name: "last_name",
OriginalName: "sqlc.arg(user_last_name)",
Name: "user_last_name",
Typ: "sql.NullString",
},
},
expectedNames: []string{"first_name", "last_name"},
expectedNames: []string{"first_name", "user_last_name"},
},
}
for _, tCase := range tests {
Expand Down
23 changes: 22 additions & 1 deletion internal/mysql/parse.go
Original file line number Diff line number Diff line change
Expand Up @@ -345,9 +345,30 @@ func parseInsert(node *sqlparser.Insert, query string, s *Schema, settings dinos
}
params = append(params, p)
}
case *sqlparser.FuncExpr:
name, raw, err := matchFuncExpr(v)

if err != nil {
return nil, err
}
if name == "" || raw == "" {
continue
}
colName := cols[colIx].String()
colDfn, err := s.schemaLookup(tableName, colName)
p := &Param{
OriginalName: raw,
}
if err == nil {
p.Name = name
p.Typ = goTypeCol(colDfn, settings)
} else {
p.Name = "Unknown"
p.Typ = "interface{}"
}
params = append(params, p)
default:
panic("Error occurred in parsing INSERT statement")
return nil, fmt.Errorf("failed to parse insert query value")
}
}
}
Expand Down
7 changes: 5 additions & 2 deletions internal/mysql/schema.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,9 @@ type Schema struct {
// returns a deep copy of the column definition for using as a query return type or param type
func (s *Schema) getColType(col *sqlparser.ColName, tableAliasMap FromTables, defaultTableName string) (*sqlparser.ColumnDefinition, error) {
realTable, err := tableColReferences(col, defaultTableName, tableAliasMap)
if err != nil {
return nil, err
}

colDfn, err := s.schemaLookup(realTable.TrueName, col.Name.String())
if err != nil {
Expand All @@ -38,13 +41,13 @@ func tableColReferences(col *sqlparser.ColName, defaultTable string, tableAliasM
var table FromTable
if col.Qualifier.IsEmpty() {
if defaultTable == "" {
return FromTable{}, fmt.Errorf("Column reference [%v] is ambiguous -- Add a qualifier", col.Name.String())
return FromTable{}, fmt.Errorf("column reference \"%s\" is ambiguous, add a qualifier", col.Name.String())
}
table = FromTable{defaultTable, false}
} else {
fromTable, ok := tableAliasMap[col.Qualifier.Name.String()]
if !ok {
return FromTable{}, fmt.Errorf("Column qualifier [%v] not found in table alias map", col.Qualifier.Name.String())
return FromTable{}, fmt.Errorf("column qualifier \"%s\" is not in schema or is an invalid alias", col.Qualifier.Name.String())
}
return fromTable, nil
}
Expand Down

0 comments on commit f9d952d

Please sign in to comment.