From 0b33659eba8809569f7aadbfc7d7cfdec9a81143 Mon Sep 17 00:00:00 2001 From: Nik <73077675+tmzane@users.noreply.github.com> Date: Sat, 12 Apr 2025 17:38:22 +0300 Subject: [PATCH 1/2] feat(scanner): add docs for Query/QueryRow, implement Collect --- builder.go | 2 +- query.go | 55 +++++++++++++++++++++++++++++++++++---- query_test.go | 23 ++++++++++++++++ tests/integration_test.go | 14 +--------- 4 files changed, 75 insertions(+), 19 deletions(-) diff --git a/builder.go b/builder.go index dd00e7d..dd01f38 100644 --- a/builder.go +++ b/builder.go @@ -25,7 +25,7 @@ type Builder struct { // IMPORTANT: to avoid SQL injections, make sure to pass arguments from user input with placeholder verbs. // Always test your queries. // -// Placeholder verbs to database placeholders: +// Placeholder verbs map to the following database placeholders: // - MySQL, SQLite: %? -> ? // - PostgreSQL: %$ -> $N // - MSSQL: %@ -> @pN diff --git a/query.go b/query.go index ce89018..80e5c0a 100644 --- a/query.go +++ b/query.go @@ -10,12 +10,40 @@ import ( "time" ) -type queryer interface { +// Queryer is an interface implemented by [sql.DB] and [sql.Tx]. +type Queryer interface { QueryContext(ctx context.Context, query string, args ...any) (*sql.Rows, error) } -// TODO: document me. -func Query[T any](ctx context.Context, q queryer, query string, args ...any) iter.Seq2[T, error] { +// Query executes a query that returns rows, scans each row into a T, and returns an iterator over the Ts. +// If an error occurs, the iterator yields it as the second value, and the caller should then stop the iteration. +// [Queryer] can be either [sql.DB] or [sql.Tx], the rest of the arguments are passed directly to [Queryer.QueryContext]. +// Query fully manages the lifecycle of the [sql.Rows] returned by [Queryer.QueryContext], so the caller does not have to. +// +// The following Ts are supported: +// - int (any kind) +// - uint (any kind) +// - float (any kind) +// - bool +// - string +// - time.Time +// - [sql.Scanner] (implemented by [sql.Null] types) +// - any struct +// +// See the [sql.Rows.Scan] documentation for the scanning rules. +// If the query has multiple columns, T must be a struct, other types can only be used for single-column queries. +// The fields of a struct T must have the `sql:"COLUMN"` tag, where COLUMN is the name of the corresponding column in the query. +// Unexported and untagged fields are ignored. +// +// Query panics if: +// - The query has no columns. +// - A non-struct T is specified with a multi-column query. +// - The specified struct T has no field for one of the query columns. +// - An unsupported T is specified. +// - One of the fields in a struct T has an empty `sql` tag. +// +// If the caller prefers the result to be a slice rather than an iterator, Query can be combined with [Collect]. +func Query[T any](ctx context.Context, q Queryer, query string, args ...any) iter.Seq2[T, error] { return func(yield func(T, error) bool) { rows, err := q.QueryContext(ctx, query, args...) if err != nil { @@ -47,8 +75,12 @@ func Query[T any](ctx context.Context, q queryer, query string, args ...any) ite } } -// TODO: document me. -func QueryRow[T any](ctx context.Context, q queryer, query string, args ...any) (T, error) { +// QueryRow is a [Query] variant for queries that are expected to return at most one row, +// so instead of an iterator, it returns a single T. +// Like [sql.DB.QueryRowContext], QueryRow returns [sql.ErrNoRows] if the query selects no rows, +// otherwise it scans the first row and discards the rest. +// See the [Query] documentation for details on supported Ts. +func QueryRow[T any](ctx context.Context, q Queryer, query string, args ...any) (T, error) { rows, err := q.QueryContext(ctx, query, args...) if err != nil { return zero[T](), err @@ -78,6 +110,19 @@ func QueryRow[T any](ctx context.Context, q queryer, query string, args ...any) return t, nil } +// Collect is a [slices.Collect] variant that collects values from an iter.Seq2[T, error]. +// If an error occurs during the collection, Collect stops the iteration and returns the error. +func Collect[T any](seq iter.Seq2[T, error]) ([]T, error) { + var ts []T + for t, err := range seq { + if err != nil { + return nil, err + } + ts = append(ts, t) + } + return ts, nil +} + func zero[T any]() (t T) { return t } type scanner interface { diff --git a/query_test.go b/query_test.go index 7b1a76e..9d964ba 100644 --- a/query_test.go +++ b/query_test.go @@ -3,7 +3,9 @@ package queries import ( "database/sql" "errors" + "iter" "reflect" + "slices" "testing" "time" @@ -11,6 +13,27 @@ import ( . "go-simpler.org/queries/internal/assert/EF" ) +func TestCollect(t *testing.T) { + anErr := errors.New("an error") + + tests := map[string]struct { + seq iter.Seq2[int, error] + want []int + wantErr error + }{ + "no error": {slices.All([]error{nil, nil}), []int{0, 1}, nil}, + "an error": {slices.All([]error{nil, anErr}), nil, anErr}, + } + + for name, tt := range tests { + t.Run(name, func(t *testing.T) { + got, err := Collect(tt.seq) + assert.IsErr[F](t, err, tt.wantErr) + assert.Equal[E](t, got, tt.want) + }) + } +} + func Test_scan(t *testing.T) { t.Run("no columns", func(t *testing.T) { fn := func() { _, _ = scan[int](nil, nil) } diff --git a/tests/integration_test.go b/tests/integration_test.go index d2ca093..2c6ddeb 100644 --- a/tests/integration_test.go +++ b/tests/integration_test.go @@ -4,7 +4,6 @@ import ( "context" "database/sql" "database/sql/driver" - "iter" "testing" "time" @@ -91,7 +90,7 @@ func TestIntegration(t *testing.T) { assert.NoErr[F](t, err) assert.Equal[E](t, name, TableUsers[0].Name) - names, err := collect(queries.Query[string](ctx, queryer, "SELECT name FROM users")) + names, err := queries.Collect(queries.Query[string](ctx, queryer, "SELECT name FROM users")) assert.NoErr[F](t, err) assert.Equal[E](t, names, []string{TableUsers[0].Name, TableUsers[1].Name, TableUsers[2].Name}) @@ -123,17 +122,6 @@ func namedToAny(values []driver.NamedValue) []any { return args } -func collect[T any](seq iter.Seq2[T, error]) ([]T, error) { - var ts []T - for t, err := range seq { - if err != nil { - return nil, err - } - ts = append(ts, t) - } - return ts, nil -} - func migrate(ctx context.Context, db *sql.DB) error { type migration struct { query string From 079dca42fde81af7fdc5258456c005cf71047421 Mon Sep 17 00:00:00 2001 From: Nik <73077675+tmzane@users.noreply.github.com> Date: Sat, 12 Apr 2025 19:05:22 +0300 Subject: [PATCH 2/2] update README --- README.md | 13 +++++++++++-- query.go | 2 +- 2 files changed, 12 insertions(+), 3 deletions(-) diff --git a/README.md b/README.md index fa32155..852ecfd 100644 --- a/README.md +++ b/README.md @@ -59,8 +59,18 @@ type User struct { Name string `sql:"name"` } +// single column, single row: +name, _ := queries.QueryRow[string](ctx, db, "SELECT name FROM users WHERE id = 1") + +// single column, multiple rows: +names, _ := queries.Collect(queries.Query[string](ctx, db, "SELECT name FROM users")) + +// multiple columns, single row: +user, _ := queries.QueryRow[User](ctx, db, "SELECT id, name FROM users WHERE id = 1") + +// multiple columns, multiple rows: for user, _ := range queries.Query[User](ctx, db, "SELECT id, name FROM users") { - // user.ID, user.Name + // ... } ``` @@ -98,7 +108,6 @@ Integration tests cover the following databases and drivers: ## 🚧 TODOs -- Add missing documentation. - Add more tests for different databases and drivers. See https://go.dev/wiki/SQLDrivers. - Add examples for tested databases and drivers. - Add benchmarks. diff --git a/query.go b/query.go index 80e5c0a..ae17ebc 100644 --- a/query.go +++ b/query.go @@ -77,7 +77,7 @@ func Query[T any](ctx context.Context, q Queryer, query string, args ...any) ite // QueryRow is a [Query] variant for queries that are expected to return at most one row, // so instead of an iterator, it returns a single T. -// Like [sql.DB.QueryRowContext], QueryRow returns [sql.ErrNoRows] if the query selects no rows, +// Like [sql.DB.QueryRow], QueryRow returns [sql.ErrNoRows] if the query selects no rows, // otherwise it scans the first row and discards the rest. // See the [Query] documentation for details on supported Ts. func QueryRow[T any](ctx context.Context, q Queryer, query string, args ...any) (T, error) {