Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ Convenience helpers for working with SQL queries.

## 📦 Install

Go 1.23+
Go 1.24+

```shell
go get go-simpler.org/queries
Expand Down
111 changes: 58 additions & 53 deletions builder.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,101 +12,106 @@ import (
// The zero value is ready to use.
// Do not copy a non-zero Builder.
type Builder struct {
// TODO: prealloc?
query strings.Builder
args []any
counter int
placeholder rune
}

// Appendf formats according to the given format and appends the result to the query.
// It works like [fmt.Appendf], i.e. all rules from the [fmt] package are applied.
// In addition, Appendf supports %?, %$, and %@ verbs, which are automatically expanded to the query placeholders ?, $N, and @pN,
// where N is the auto-incrementing counter.
// The corresponding arguments can then be accessed with the [Builder.Args] method.
// It works like [fmt.Appendf], meaning all the rules from the [fmt] package are applied.
// In addition, Appendf supports special verbs that automatically expand to database placeholders.
//
// IMPORTANT: to avoid SQL injections, make sure to pass arguments from user input with placeholder verbs.
// ---------------------------------------------
// | Database | Verb | Placeholder |
// |----------------------|------|-------------|
// | MySQL, SQLite | %? | ? |
// | PostgreSQL | %$ | $N |
// | Microsoft SQL Server | %@ | @pN |
// ---------------------------------------------
//
// Placeholder verbs map to the following database placeholders:
// - MySQL, SQLite: %? -> ?
// - PostgreSQL: %$ -> $N
// - MSSQL: %@ -> @pN
// Here, N is an auto-incrementing counter.
// For example, "%$, %$, %$" expands to "$1, $2, $3".
//
// TODO: document slice arguments usage.
func (b *Builder) Appendf(format string, args ...any) {
a := make([]any, len(args))
for i, arg := range args {
a[i] = argument{value: arg, builder: b}
// If a special verb includes the "+" flag, it automatically expands to multiple placeholders.
// For example, given the verb "%+?" and the argument []int{1, 2, 3},
// Appendf writes "?, ?, ?" to the query and appends 1, 2, and 3 to the arguments.
// You may want to use this flag to build "WHERE IN (...)" clauses.
//
// Make sure to always pass arguments from user input with placeholder verbs to avoid SQL injections.
func (b *Builder) Appendf(format string, a ...any) {
fs := make([]any, len(a))
for i := range a {
fs[i] = formatter{arg: a[i], builder: b}
}
fmt.Fprintf(&b.query, format, a...)
fmt.Fprintf(&b.query, format, fs...)
}

// Query returns the query string.
func (b *Builder) Query() string { return b.query.String() }
// Build returns the query and its arguments.
func (b *Builder) Build() (query string, args []any) {
return b.query.String(), b.args
}

// Args returns the query arguments.
func (b *Builder) Args() []any { return b.args }
// Build is a shorthand for a new [Builder] + [Builder.Appendf] + [Builder.Build].
func Build(format string, a ...any) (query string, args []any) {
var b Builder
b.Appendf(format, a...)
return b.Build()
}

type argument struct {
value any
type formatter struct {
arg any
builder *Builder
}

// Format implements [fmt.Formatter].
func (a argument) Format(s fmt.State, verb rune) {
func (f formatter) Format(s fmt.State, verb rune) {
switch verb {
case '?', '$', '@':
if a.builder.placeholder == 0 {
a.builder.placeholder = verb
if f.builder.placeholder == 0 {
f.builder.placeholder = verb
}
if a.builder.placeholder != verb {
if f.builder.placeholder != verb {
panic("unexpected placeholder")
}
if s.Flag('+') {
appendAll(s, f.builder, verb, f.arg)
} else {
appendOne(s, f.builder, verb, f.arg)
}
default:
format := fmt.FormatString(s, verb)
fmt.Fprintf(s, format, a.value)
return
}

if s.Flag('+') {
a.writeSlice(s, verb)
} else {
a.writePlaceholder(s, verb)
a.builder.args = append(a.builder.args, a.value)
fmt.Fprintf(s, format, f.arg)
}
}

func (a argument) writePlaceholder(w io.Writer, verb rune) {
func appendOne(w io.Writer, b *Builder, verb rune, arg any) {
switch verb {
case '?': // MySQL, SQLite
case '?':
fmt.Fprint(w, "?")
case '$': // PostgreSQL
a.builder.counter++
fmt.Fprintf(w, "$%d", a.builder.counter)
case '@': // MSSQL
a.builder.counter++
fmt.Fprintf(w, "@p%d", a.builder.counter)
case '$':
b.counter++
fmt.Fprintf(w, "$%d", b.counter)
case '@':
b.counter++
fmt.Fprintf(w, "@p%d", b.counter)
}
b.args = append(b.args, arg)
}

func (a argument) writeSlice(w io.Writer, verb rune) {
slice := reflect.ValueOf(a.value)
func appendAll(w io.Writer, b *Builder, verb rune, arg any) {
slice := reflect.ValueOf(arg)
if slice.Kind() != reflect.Slice {
panic("non-slice argument")
}

if slice.Len() == 0 {
// TODO: revisit.
// "WHERE IN (NULL)" will always result in an empty result set,
// which may be undesirable in some situations.
fmt.Fprint(w, "NULL")
return
panic("zero-length slice argument")
}

for i := range slice.Len() {
if i > 0 {
fmt.Fprint(w, ", ")
}
a.writePlaceholder(w, verb)
a.builder.args = append(a.builder.args, slice.Index(i).Interface())
appendOne(w, b, verb, slice.Index(i).Interface())
}
}
78 changes: 34 additions & 44 deletions builder_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,9 @@ func TestBuilder(t *testing.T) {
qb.Appendf(" AND bar = %$", "test")
qb.Appendf(" AND baz = %$", false)

assert.Equal[E](t, qb.Query(), "SELECT * FROM tbl WHERE 1=1 AND foo = $1 AND bar = $2 AND baz = $3")
assert.Equal[E](t, qb.Args(), []any{42, "test", false})
query, args := qb.Build()
assert.Equal[E](t, query, "SELECT * FROM tbl WHERE 1=1 AND foo = $1 AND bar = $2 AND baz = $3")
assert.Equal[E](t, args, []any{42, "test", false})
}

func TestBuilder_dialects(t *testing.T) {
Expand All @@ -42,72 +43,61 @@ func TestBuilder_dialects(t *testing.T) {

for name, test := range tests {
t.Run(name, func(t *testing.T) {
var qb queries.Builder
qb.Appendf(test.format, 1, 2, 3)
assert.Equal[E](t, qb.Query(), test.query)
assert.Equal[E](t, qb.Args(), []any{1, 2, 3})
query, args := queries.Build(test.format, 1, 2, 3)
assert.Equal[E](t, query, test.query)
assert.Equal[E](t, args, []any{1, 2, 3})
})
}
}

func TestBuilder_sliceArgument(t *testing.T) {
t.Run("ok", func(t *testing.T) {
var qb queries.Builder
qb.Appendf("SELECT * FROM tbl WHERE foo IN (%+$)", []int{1, 2, 3})
assert.Equal[E](t, qb.Query(), "SELECT * FROM tbl WHERE foo IN ($1, $2, $3)")
assert.Equal[E](t, qb.Args(), []any{1, 2, 3})
})

t.Run("empty", func(t *testing.T) {
var qb queries.Builder
qb.Appendf("SELECT * FROM tbl WHERE foo IN (%+$)", []int{})
assert.Equal[E](t, qb.Query(), "SELECT * FROM tbl WHERE foo IN (NULL)")
assert.Equal[E](t, len(qb.Args()), 0)
})
query, args := queries.Build("SELECT * FROM tbl WHERE foo IN (%+$)", []int{1, 2, 3})
assert.Equal[E](t, query, "SELECT * FROM tbl WHERE foo IN ($1, $2, $3)")
assert.Equal[E](t, args, []any{1, 2, 3})
}

func TestBuilder_badQuery(t *testing.T) {
tests := map[string]struct {
appendf func(*queries.Builder)
query string
format string
args []any
query string
}{
"wrong verb": {
appendf: func(qb *queries.Builder) {
qb.Appendf("SELECT %d FROM tbl", "foo")
},
query: "SELECT %!d(string=foo) FROM tbl",
format: "SELECT %d FROM tbl",
args: []any{"foo"},
query: "SELECT %!d(string=foo) FROM tbl",
},
"too few arguments": {
appendf: func(qb *queries.Builder) {
qb.Appendf("SELECT %s FROM tbl")
},
query: "SELECT %!s(MISSING) FROM tbl",
format: "SELECT %s FROM tbl",
args: []any{},
query: "SELECT %!s(MISSING) FROM tbl",
},
"too many arguments": {
appendf: func(qb *queries.Builder) {
qb.Appendf("SELECT %s FROM tbl", "foo", "bar")
},
query: "SELECT foo FROM tbl%!(EXTRA queries.argument=bar)",
format: "SELECT %s FROM tbl",
args: []any{"foo", "bar"},
query: "SELECT foo FROM tbl%!(EXTRA queries.formatter=bar)",
},
"unexpected placeholder": {
appendf: func(qb *queries.Builder) {
qb.Appendf("SELECT * FROM tbl WHERE foo = %? AND bar = %$", 1, 2)
},
query: "SELECT * FROM tbl WHERE foo = ? AND bar = %!$(PANIC=Format method: unexpected placeholder)",
format: "SELECT * FROM tbl WHERE foo = %? AND bar = %$",
args: []any{1, 2},
query: "SELECT * FROM tbl WHERE foo = ? AND bar = %!$(PANIC=Format method: unexpected placeholder)",
},
"non-slice argument": {
appendf: func(qb *queries.Builder) {
qb.Appendf("SELECT * FROM tbl WHERE foo IN (%+$)", 1)
},
query: "SELECT * FROM tbl WHERE foo IN (%!$(PANIC=Format method: non-slice argument))",
format: "SELECT * FROM tbl WHERE foo IN (%+$)",
args: []any{1},
query: "SELECT * FROM tbl WHERE foo IN (%!$(PANIC=Format method: non-slice argument))",
},
"zero-length slice argument": {
format: "SELECT * FROM tbl WHERE foo IN (%+$)",
args: []any{[]int{}},
query: "SELECT * FROM tbl WHERE foo IN (%!$(PANIC=Format method: zero-length slice argument))",
},
}

for name, test := range tests {
t.Run(name, func(t *testing.T) {
var qb queries.Builder
test.appendf(&qb)
assert.Equal[E](t, qb.Query(), test.query)
query, _ := queries.Build(test.format, test.args...)
assert.Equal[E](t, query, test.query)
})
}
}
5 changes: 2 additions & 3 deletions tests/integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -183,9 +183,8 @@ func migrate(ctx context.Context, db *sql.DB) error {
}

for _, m := range migrations {
var qb queries.Builder
qb.Appendf(m.query, m.args...)
if _, err := db.ExecContext(ctx, qb.Query(), qb.Args()...); err != nil {
query, args := queries.Build(m.query, m.args...)
if _, err := db.ExecContext(ctx, query, args...); err != nil {
return err
}
}
Expand Down
Loading