Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 11 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -59,8 +59,18 @@ type User struct {
Name string `sql:"name"`
}

// single column, single row:
name, _ := queries.QueryRow[string](ctx, db, "SELECT name FROM users WHERE id = 1")

// single column, multiple rows:
names, _ := queries.Collect(queries.Query[string](ctx, db, "SELECT name FROM users"))

// multiple columns, single row:
user, _ := queries.QueryRow[User](ctx, db, "SELECT id, name FROM users WHERE id = 1")

// multiple columns, multiple rows:
for user, _ := range queries.Query[User](ctx, db, "SELECT id, name FROM users") {
// user.ID, user.Name
// ...
}
```

Expand Down Expand Up @@ -98,7 +108,6 @@ Integration tests cover the following databases and drivers:

## 🚧 TODOs

- Add missing documentation.
- Add more tests for different databases and drivers. See https://go.dev/wiki/SQLDrivers.
- Add examples for tested databases and drivers.
- Add benchmarks.
Expand Down
2 changes: 1 addition & 1 deletion builder.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ type Builder struct {
// IMPORTANT: to avoid SQL injections, make sure to pass arguments from user input with placeholder verbs.
// Always test your queries.
//
// Placeholder verbs to database placeholders:
// Placeholder verbs map to the following database placeholders:
// - MySQL, SQLite: %? -> ?
// - PostgreSQL: %$ -> $N
// - MSSQL: %@ -> @pN
Expand Down
55 changes: 50 additions & 5 deletions query.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,40 @@ import (
"time"
)

type queryer interface {
// Queryer is an interface implemented by [sql.DB] and [sql.Tx].
type Queryer interface {
QueryContext(ctx context.Context, query string, args ...any) (*sql.Rows, error)
}

// TODO: document me.
func Query[T any](ctx context.Context, q queryer, query string, args ...any) iter.Seq2[T, error] {
// Query executes a query that returns rows, scans each row into a T, and returns an iterator over the Ts.
// If an error occurs, the iterator yields it as the second value, and the caller should then stop the iteration.
// [Queryer] can be either [sql.DB] or [sql.Tx], the rest of the arguments are passed directly to [Queryer.QueryContext].
// Query fully manages the lifecycle of the [sql.Rows] returned by [Queryer.QueryContext], so the caller does not have to.
//
// The following Ts are supported:
// - int (any kind)
// - uint (any kind)
// - float (any kind)
// - bool
// - string
// - time.Time
// - [sql.Scanner] (implemented by [sql.Null] types)
// - any struct
//
// See the [sql.Rows.Scan] documentation for the scanning rules.
// If the query has multiple columns, T must be a struct, other types can only be used for single-column queries.
// The fields of a struct T must have the `sql:"COLUMN"` tag, where COLUMN is the name of the corresponding column in the query.
// Unexported and untagged fields are ignored.
//
// Query panics if:
// - The query has no columns.
// - A non-struct T is specified with a multi-column query.
// - The specified struct T has no field for one of the query columns.
// - An unsupported T is specified.
// - One of the fields in a struct T has an empty `sql` tag.
//
// If the caller prefers the result to be a slice rather than an iterator, Query can be combined with [Collect].
func Query[T any](ctx context.Context, q Queryer, query string, args ...any) iter.Seq2[T, error] {
return func(yield func(T, error) bool) {
rows, err := q.QueryContext(ctx, query, args...)
if err != nil {
Expand Down Expand Up @@ -47,8 +75,12 @@ func Query[T any](ctx context.Context, q queryer, query string, args ...any) ite
}
}

// TODO: document me.
func QueryRow[T any](ctx context.Context, q queryer, query string, args ...any) (T, error) {
// QueryRow is a [Query] variant for queries that are expected to return at most one row,
// so instead of an iterator, it returns a single T.
// Like [sql.DB.QueryRow], QueryRow returns [sql.ErrNoRows] if the query selects no rows,
// otherwise it scans the first row and discards the rest.
// See the [Query] documentation for details on supported Ts.
func QueryRow[T any](ctx context.Context, q Queryer, query string, args ...any) (T, error) {
rows, err := q.QueryContext(ctx, query, args...)
if err != nil {
return zero[T](), err
Expand Down Expand Up @@ -78,6 +110,19 @@ func QueryRow[T any](ctx context.Context, q queryer, query string, args ...any)
return t, nil
}

// Collect is a [slices.Collect] variant that collects values from an iter.Seq2[T, error].
// If an error occurs during the collection, Collect stops the iteration and returns the error.
func Collect[T any](seq iter.Seq2[T, error]) ([]T, error) {
var ts []T
for t, err := range seq {
if err != nil {
return nil, err
}
ts = append(ts, t)
}
return ts, nil
}

func zero[T any]() (t T) { return t }

type scanner interface {
Expand Down
23 changes: 23 additions & 0 deletions query_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,37 @@ package queries
import (
"database/sql"
"errors"
"iter"
"reflect"
"slices"
"testing"
"time"

"go-simpler.org/queries/internal/assert"
. "go-simpler.org/queries/internal/assert/EF"
)

func TestCollect(t *testing.T) {
anErr := errors.New("an error")

tests := map[string]struct {
seq iter.Seq2[int, error]
want []int
wantErr error
}{
"no error": {slices.All([]error{nil, nil}), []int{0, 1}, nil},
"an error": {slices.All([]error{nil, anErr}), nil, anErr},
}

for name, tt := range tests {
t.Run(name, func(t *testing.T) {
got, err := Collect(tt.seq)
assert.IsErr[F](t, err, tt.wantErr)
assert.Equal[E](t, got, tt.want)
})
}
}

func Test_scan(t *testing.T) {
t.Run("no columns", func(t *testing.T) {
fn := func() { _, _ = scan[int](nil, nil) }
Expand Down
14 changes: 1 addition & 13 deletions tests/integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ import (
"context"
"database/sql"
"database/sql/driver"
"iter"
"testing"
"time"

Expand Down Expand Up @@ -91,7 +90,7 @@ func TestIntegration(t *testing.T) {
assert.NoErr[F](t, err)
assert.Equal[E](t, name, TableUsers[0].Name)

names, err := collect(queries.Query[string](ctx, queryer, "SELECT name FROM users"))
names, err := queries.Collect(queries.Query[string](ctx, queryer, "SELECT name FROM users"))
assert.NoErr[F](t, err)
assert.Equal[E](t, names, []string{TableUsers[0].Name, TableUsers[1].Name, TableUsers[2].Name})

Expand Down Expand Up @@ -123,17 +122,6 @@ func namedToAny(values []driver.NamedValue) []any {
return args
}

func collect[T any](seq iter.Seq2[T, error]) ([]T, error) {
var ts []T
for t, err := range seq {
if err != nil {
return nil, err
}
ts = append(ts, t)
}
return ts, nil
}

func migrate(ctx context.Context, db *sql.DB) error {
type migration struct {
query string
Expand Down
Loading