Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 52 additions & 10 deletions builder.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,14 @@ package queries

import (
"fmt"
"io"
"reflect"
"strings"
)

// Builder is a raw SQL query builder.
// The zero value is ready to use.
// Do not copy a non-zero Builder.
// Do not reuse a single Builder for multiple queries.
type Builder struct {
query strings.Builder
args []any
Expand All @@ -22,13 +23,15 @@ type Builder struct {
// In addition, Appendf supports %?, %$, and %@ verbs, which are automatically expanded to the query placeholders ?, $N, and @pN,
// where N is the auto-incrementing counter.
// The corresponding arguments can then be accessed with the [Builder.Args] method.
//
// IMPORTANT: to avoid SQL injections, make sure to pass arguments from user input with placeholder verbs.
// Always test your queries.
//
// Placeholder verbs map to the following database placeholders:
// - MySQL, SQLite: %? -> ?
// - PostgreSQL: %$ -> $N
// - MSSQL: %@ -> @pN
//
// TODO: document slice arguments usage.
func (b *Builder) Appendf(format string, args ...any) {
a := make([]any, len(args))
for i, arg := range args {
Expand All @@ -53,7 +56,7 @@ func (b *Builder) Query() string {
return query
}

// Args returns the argument slice.
// Args returns the query arguments.
func (b *Builder) Args() []any { return b.args }

type argument struct {
Expand All @@ -65,26 +68,65 @@ type argument struct {
func (a argument) Format(s fmt.State, verb rune) {
switch verb {
case '?', '$', '@':
a.builder.args = append(a.builder.args, a.value)
if a.builder.placeholder == 0 {
a.builder.placeholder = verb
}
if a.builder.placeholder != verb {
a.builder.placeholder = -1
}
default:
format := fmt.FormatString(s, verb)
fmt.Fprintf(s, format, a.value)
return
}

if s.Flag('+') {
a.writeSlice(s, verb)
} else {
a.writePlaceholder(s, verb)
a.builder.args = append(a.builder.args, a.value)
}
}

func (a argument) writePlaceholder(w io.Writer, verb rune) {
switch verb {
case '?': // MySQL, SQLite
fmt.Fprint(s, "?")
fmt.Fprint(w, "?")
case '$': // PostgreSQL
a.builder.counter++
fmt.Fprintf(s, "$%d", a.builder.counter)
fmt.Fprintf(w, "$%d", a.builder.counter)
case '@': // MSSQL
a.builder.counter++
fmt.Fprintf(s, "@p%d", a.builder.counter)
default:
format := fmt.FormatString(s, verb)
fmt.Fprintf(s, format, a.value)
fmt.Fprintf(w, "@p%d", a.builder.counter)
}
}

func (a argument) writeSlice(w io.Writer, verb rune) {
slice := reflect.ValueOf(a.value)
if slice.Kind() != reflect.Slice {
panic("queries: %+ argument must be a slice")
}

if slice.Len() == 0 {
// TODO: revisit.
// Unlike other errors produced by Builder,
// which are all the result of a programmer's mistake,
// this one may be caused by user input, so panicking is not an option here.
// "WHERE IN (NULL)" will always result in an empty result set,
// which may be undesirable in some situations.
fmt.Fprint(w, "NULL")
return
}

args := reflect.ValueOf(a.builder.args)

for i := range slice.Len() {
if i > 0 {
fmt.Fprint(w, ", ")
}
a.writePlaceholder(w, verb)
args = reflect.Append(args, slice.Index(i))
}

a.builder.args = args.Interface().([]any)
}
40 changes: 31 additions & 9 deletions builder_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ func TestBuilder(t *testing.T) {
assert.Equal[E](t, qb.Args(), []any{42, "test", false})
}

func TestBuilder_placeholders(t *testing.T) {
func TestBuilder_dialects(t *testing.T) {
tests := map[string]struct {
format string
query string
Expand Down Expand Up @@ -50,42 +50,64 @@ func TestBuilder_placeholders(t *testing.T) {
}
}

func TestBuilder_sliceArgument(t *testing.T) {
t.Run("ok", func(t *testing.T) {
var qb queries.Builder
qb.Appendf("SELECT * FROM tbl WHERE foo IN (%+$)", []int{1, 2, 3})
assert.Equal[E](t, qb.Query(), "SELECT * FROM tbl WHERE foo IN ($1, $2, $3)")
assert.Equal[E](t, qb.Args(), []any{1, 2, 3})
})

t.Run("empty", func(t *testing.T) {
var qb queries.Builder
qb.Appendf("SELECT * FROM tbl WHERE foo IN (%+$)", []int{})
assert.Equal[E](t, qb.Query(), "SELECT * FROM tbl WHERE foo IN (NULL)")
assert.Equal[E](t, len(qb.Args()), 0)
})
}

func TestBuilder_badQuery(t *testing.T) {
tests := map[string]struct {
appends func(*queries.Builder)
appendf func(*queries.Builder)
panicMsg string
}{
"bad verb": {
appends: func(qb *queries.Builder) {
"wrong verb": {
appendf: func(qb *queries.Builder) {
qb.Appendf("SELECT %d FROM tbl", "foo")
},
panicMsg: "queries: bad query: SELECT %!d(string=foo) FROM tbl",
},
"too few arguments": {
appends: func(qb *queries.Builder) {
appendf: func(qb *queries.Builder) {
qb.Appendf("SELECT %s FROM tbl")
},
panicMsg: "queries: bad query: SELECT %!s(MISSING) FROM tbl",
},
"too many arguments": {
appends: func(qb *queries.Builder) {
appendf: func(qb *queries.Builder) {
qb.Appendf("SELECT %s FROM tbl", "foo", "bar")
},
panicMsg: "queries: bad query: SELECT foo FROM tbl%!(EXTRA queries.argument=bar)",
},
"different placeholders": {
appends: func(qb *queries.Builder) {
appendf: func(qb *queries.Builder) {
qb.Appendf("SELECT * FROM tbl WHERE foo = %? AND bar = %$ AND baz = %@", 1, 2, 3)
},
panicMsg: "queries: different placeholders used",
},
"non-slice argument": {
appendf: func(qb *queries.Builder) {
qb.Appendf("SELECT * FROM tbl WHERE foo IN (%+$)", 1)
},
panicMsg: "queries: bad query: SELECT * FROM tbl WHERE foo IN (%!$(PANIC=Format method: queries: %+ argument must be a slice))",
},
}

for name, tt := range tests {
t.Run(name, func(t *testing.T) {
var qb queries.Builder
tt.appends(&qb)
assert.Panics[E](t, func() { _ = qb.Query() }, tt.panicMsg)
tt.appendf(&qb)
assert.Panics[E](t, func() { qb.Query() }, tt.panicMsg)
})
}
}
Loading