diff --git a/builder.go b/builder.go index dd01f38..6db2bae 100644 --- a/builder.go +++ b/builder.go @@ -3,13 +3,14 @@ package queries import ( "fmt" + "io" + "reflect" "strings" ) // Builder is a raw SQL query builder. // The zero value is ready to use. // Do not copy a non-zero Builder. -// Do not reuse a single Builder for multiple queries. type Builder struct { query strings.Builder args []any @@ -22,13 +23,15 @@ type Builder struct { // In addition, Appendf supports %?, %$, and %@ verbs, which are automatically expanded to the query placeholders ?, $N, and @pN, // where N is the auto-incrementing counter. // The corresponding arguments can then be accessed with the [Builder.Args] method. +// // IMPORTANT: to avoid SQL injections, make sure to pass arguments from user input with placeholder verbs. -// Always test your queries. // // Placeholder verbs map to the following database placeholders: // - MySQL, SQLite: %? -> ? // - PostgreSQL: %$ -> $N // - MSSQL: %@ -> @pN +// +// TODO: document slice arguments usage. func (b *Builder) Appendf(format string, args ...any) { a := make([]any, len(args)) for i, arg := range args { @@ -53,7 +56,7 @@ func (b *Builder) Query() string { return query } -// Args returns the argument slice. +// Args returns the query arguments. func (b *Builder) Args() []any { return b.args } type argument struct { @@ -65,26 +68,65 @@ type argument struct { func (a argument) Format(s fmt.State, verb rune) { switch verb { case '?', '$', '@': - a.builder.args = append(a.builder.args, a.value) if a.builder.placeholder == 0 { a.builder.placeholder = verb } if a.builder.placeholder != verb { a.builder.placeholder = -1 } + default: + format := fmt.FormatString(s, verb) + fmt.Fprintf(s, format, a.value) + return } + if s.Flag('+') { + a.writeSlice(s, verb) + } else { + a.writePlaceholder(s, verb) + a.builder.args = append(a.builder.args, a.value) + } +} + +func (a argument) writePlaceholder(w io.Writer, verb rune) { switch verb { case '?': // MySQL, SQLite - fmt.Fprint(s, "?") + fmt.Fprint(w, "?") case '$': // PostgreSQL a.builder.counter++ - fmt.Fprintf(s, "$%d", a.builder.counter) + fmt.Fprintf(w, "$%d", a.builder.counter) case '@': // MSSQL a.builder.counter++ - fmt.Fprintf(s, "@p%d", a.builder.counter) - default: - format := fmt.FormatString(s, verb) - fmt.Fprintf(s, format, a.value) + fmt.Fprintf(w, "@p%d", a.builder.counter) + } +} + +func (a argument) writeSlice(w io.Writer, verb rune) { + slice := reflect.ValueOf(a.value) + if slice.Kind() != reflect.Slice { + panic("queries: %+ argument must be a slice") + } + + if slice.Len() == 0 { + // TODO: revisit. + // Unlike other errors produced by Builder, + // which are all the result of a programmer's mistake, + // this one may be caused by user input, so panicking is not an option here. + // "WHERE IN (NULL)" will always result in an empty result set, + // which may be undesirable in some situations. + fmt.Fprint(w, "NULL") + return } + + args := reflect.ValueOf(a.builder.args) + + for i := range slice.Len() { + if i > 0 { + fmt.Fprint(w, ", ") + } + a.writePlaceholder(w, verb) + args = reflect.Append(args, slice.Index(i)) + } + + a.builder.args = args.Interface().([]any) } diff --git a/builder_test.go b/builder_test.go index 4ded17e..85e8366 100644 --- a/builder_test.go +++ b/builder_test.go @@ -21,7 +21,7 @@ func TestBuilder(t *testing.T) { assert.Equal[E](t, qb.Args(), []any{42, "test", false}) } -func TestBuilder_placeholders(t *testing.T) { +func TestBuilder_dialects(t *testing.T) { tests := map[string]struct { format string query string @@ -50,42 +50,64 @@ func TestBuilder_placeholders(t *testing.T) { } } +func TestBuilder_sliceArgument(t *testing.T) { + t.Run("ok", func(t *testing.T) { + var qb queries.Builder + qb.Appendf("SELECT * FROM tbl WHERE foo IN (%+$)", []int{1, 2, 3}) + assert.Equal[E](t, qb.Query(), "SELECT * FROM tbl WHERE foo IN ($1, $2, $3)") + assert.Equal[E](t, qb.Args(), []any{1, 2, 3}) + }) + + t.Run("empty", func(t *testing.T) { + var qb queries.Builder + qb.Appendf("SELECT * FROM tbl WHERE foo IN (%+$)", []int{}) + assert.Equal[E](t, qb.Query(), "SELECT * FROM tbl WHERE foo IN (NULL)") + assert.Equal[E](t, len(qb.Args()), 0) + }) +} + func TestBuilder_badQuery(t *testing.T) { tests := map[string]struct { - appends func(*queries.Builder) + appendf func(*queries.Builder) panicMsg string }{ - "bad verb": { - appends: func(qb *queries.Builder) { + "wrong verb": { + appendf: func(qb *queries.Builder) { qb.Appendf("SELECT %d FROM tbl", "foo") }, panicMsg: "queries: bad query: SELECT %!d(string=foo) FROM tbl", }, "too few arguments": { - appends: func(qb *queries.Builder) { + appendf: func(qb *queries.Builder) { qb.Appendf("SELECT %s FROM tbl") }, panicMsg: "queries: bad query: SELECT %!s(MISSING) FROM tbl", }, "too many arguments": { - appends: func(qb *queries.Builder) { + appendf: func(qb *queries.Builder) { qb.Appendf("SELECT %s FROM tbl", "foo", "bar") }, panicMsg: "queries: bad query: SELECT foo FROM tbl%!(EXTRA queries.argument=bar)", }, "different placeholders": { - appends: func(qb *queries.Builder) { + appendf: func(qb *queries.Builder) { qb.Appendf("SELECT * FROM tbl WHERE foo = %? AND bar = %$ AND baz = %@", 1, 2, 3) }, panicMsg: "queries: different placeholders used", }, + "non-slice argument": { + appendf: func(qb *queries.Builder) { + qb.Appendf("SELECT * FROM tbl WHERE foo IN (%+$)", 1) + }, + panicMsg: "queries: bad query: SELECT * FROM tbl WHERE foo IN (%!$(PANIC=Format method: queries: %+ argument must be a slice))", + }, } for name, tt := range tests { t.Run(name, func(t *testing.T) { var qb queries.Builder - tt.appends(&qb) - assert.Panics[E](t, func() { _ = qb.Query() }, tt.panicMsg) + tt.appendf(&qb) + assert.Panics[E](t, func() { qb.Query() }, tt.panicMsg) }) } }