Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 40 additions & 10 deletions query.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"iter"
"reflect"
"sync"
"time"
)

type queryer interface {
Expand Down Expand Up @@ -84,29 +85,58 @@ type scanner interface {
}

func scan[T any](s scanner, columns []string) (T, error) {
var t T
v := reflect.ValueOf(&t).Elem()
if v.Kind() != reflect.Struct {
panic("queries: T must be a struct")
if len(columns) == 0 {
panic("queries: no columns specified") // valid in PostgreSQL (for some reason).
}

indexes := parseStruct(v.Type())
var t T
v := reflect.ValueOf(&t).Elem()
args := make([]any, len(columns))

for i, column := range columns {
idx, ok := indexes[column]
if !ok {
panic(fmt.Sprintf("queries: no field for column %q", column))
switch {
case scannable(v):
if len(columns) > 1 {
panic("queries: T must be a struct if len(columns) > 1")
}
args[0] = v.Addr().Interface()
case v.Kind() == reflect.Struct:
indexes := parseStruct(v.Type())
for i, column := range columns {
idx, ok := indexes[column]
if !ok {
panic(fmt.Sprintf("queries: no field for column %q", column))
}
args[i] = v.Field(idx).Addr().Interface()
}
args[i] = v.Field(idx).Addr().Interface()
default:
panic(fmt.Sprintf("queries: unsupported T %T", t))
}

if err := s.Scan(args...); err != nil {
return zero[T](), err
}

return t, nil
}

func scannable(v reflect.Value) bool {
switch v.Kind() {
case reflect.Bool,
reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64,
reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64,
reflect.Float32, reflect.Float64,
reflect.String:
return true
}
if v.Type() == reflect.TypeFor[time.Time]() {
return true
}
if v.Addr().Type().Implements(reflect.TypeFor[sql.Scanner]()) {
return true
}
return false
}

var cache sync.Map // map[reflect.Type]map[string]int

// parseStruct parses the given struct type and returns a map of column names to field indexes.
Expand Down
109 changes: 93 additions & 16 deletions query_test.go
Original file line number Diff line number Diff line change
@@ -1,62 +1,133 @@
package queries

import (
"database/sql"
"errors"
"reflect"
"testing"
"time"

"go-simpler.org/queries/internal/assert"
. "go-simpler.org/queries/internal/assert/EF"
)

func Test_scan(t *testing.T) {
t.Run("non-struct T", func(t *testing.T) {
t.Run("no columns", func(t *testing.T) {
fn := func() { _, _ = scan[int](nil, nil) }
assert.Panics[E](t, fn, "queries: T must be a struct")
assert.Panics[E](t, fn, "queries: no columns specified")
})

t.Run("unsupported T", func(t *testing.T) {
columns := []string{"foo", "bar"}

fn := func() { _, _ = scan[complex64](nil, columns) }
assert.Panics[E](t, fn, "queries: unsupported T complex64")
})

t.Run("non-struct T with len(columns) > 1", func(t *testing.T) {
columns := []string{"foo", "bar"}

fn := func() { _, _ = scan[int](nil, columns) }
assert.Panics[E](t, fn, "queries: T must be a struct if len(columns) > 1")
})

t.Run("empty tag", func(t *testing.T) {
columns := []string{"foo", "bar"}

type row struct {
Foo int `sql:""`
Foo int `sql:"foo"`
Bar string `sql:""`
}
fn := func() { _, _ = scan[row](nil, nil) }
assert.Panics[E](t, fn, "queries: field Foo has an empty `sql` tag")
fn := func() { _, _ = scan[row](nil, columns) }
assert.Panics[E](t, fn, "queries: field Bar has an empty `sql` tag")
})

t.Run("missing field", func(t *testing.T) {
columns := []string{"foo", "bar"}

type row struct {
Foo int `sql:"foo"`
Bar string
}
fn := func() { _, _ = scan[row](nil, []string{"foo", "bar"}) }
fn := func() { _, _ = scan[row](nil, columns) }
assert.Panics[E](t, fn, `queries: no field for column "bar"`)
})

t.Run("scan error", func(t *testing.T) {
columns := []string{"foo"}
columns := []string{"foo", "bar"}
s := mockScanner{err: errors.New("an error")}

type row struct {
Foo int `sql:"foo"`
Foo int `sql:"foo"`
Bar string `sql:"bar"`
}
_, err := scan[row](&s, columns)
assert.IsErr[E](t, err, s.err)
})

t.Run("ok", func(t *testing.T) {
t.Run("struct T", func(t *testing.T) {
columns := []string{"foo", "bar"}
s := mockScanner{values: []any{1, "A"}}
s := mockScanner{values: []any{1, "test"}}

type row struct {
Foo int `sql:"foo"`
Bar string `sql:"bar"`
unexported bool
}
r, err := scan[row](&s, columns)
v, err := scan[row](&s, columns)
assert.NoErr[F](t, err)
assert.Equal[E](t, v.Foo, 1)
assert.Equal[E](t, v.Bar, "test")
assert.Equal[E](t, v.unexported, false)
})

t.Run("struct T with len(columns) == 1", func(t *testing.T) {
columns := []string{"foo"}
s := mockScanner{values: []any{1}}

type row struct {
Foo int `sql:"foo"`
}
v, err := scan[row](&s, columns)
assert.NoErr[F](t, err)
assert.Equal[E](t, v.Foo, 1)
})

t.Run("non-struct T with len(columns) == 1", func(t *testing.T) {
columns := []string{"foo"}

tests := []struct {
scan func(scanner) (any, error)
value any
}{
{func(s scanner) (any, error) { return scan[bool](s, columns) }, true},
{func(s scanner) (any, error) { return scan[int](s, columns) }, int(-1)},
{func(s scanner) (any, error) { return scan[int8](s, columns) }, int8(-8)},
{func(s scanner) (any, error) { return scan[int16](s, columns) }, int16(-16)},
{func(s scanner) (any, error) { return scan[int32](s, columns) }, int32(-32)},
{func(s scanner) (any, error) { return scan[int64](s, columns) }, int64(-64)},
{func(s scanner) (any, error) { return scan[uint](s, columns) }, uint(1)},
{func(s scanner) (any, error) { return scan[uint8](s, columns) }, uint8(8)},
{func(s scanner) (any, error) { return scan[uint16](s, columns) }, uint16(16)},
{func(s scanner) (any, error) { return scan[uint32](s, columns) }, uint32(32)},
{func(s scanner) (any, error) { return scan[uint64](s, columns) }, uint64(64)},
{func(s scanner) (any, error) { return scan[float32](s, columns) }, float32(0.32)},
{func(s scanner) (any, error) { return scan[float64](s, columns) }, float64(0.64)},
{func(s scanner) (any, error) { return scan[string](s, columns) }, "test"},
{func(s scanner) (any, error) { return scan[time.Time](s, columns) }, time.Now()},
}
for _, tt := range tests {
s := mockScanner{values: []any{tt.value}}
v, err := tt.scan(&s)
assert.NoErr[F](t, err)
assert.Equal[E](t, v, tt.value)
}

// sql.Scanner implementation:
s := mockScanner{values: []any{"test"}}
v, err := scan[sql.Null[string]](&s, columns)
assert.NoErr[F](t, err)
assert.Equal[E](t, r.Foo, 1)
assert.Equal[E](t, r.Bar, "A")
assert.Equal[E](t, r.unexported, false)
assert.Equal[E](t, v, sql.Null[string]{V: "test", Valid: true})
})
}

Expand All @@ -70,8 +141,14 @@ func (s *mockScanner) Scan(dst ...any) error {
return s.err
}
for i := range dst {
v := reflect.ValueOf(s.values[i])
reflect.ValueOf(dst[i]).Elem().Set(v)
if sc, ok := dst[i].(sql.Scanner); ok {
if err := sc.Scan(s.values[i]); err != nil {
return err
}
} else {
v := reflect.ValueOf(s.values[i])
reflect.ValueOf(dst[i]).Elem().Set(v)
}
}
return nil
}
29 changes: 28 additions & 1 deletion tests/integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"context"
"database/sql"
"database/sql/driver"
"iter"
"testing"
"time"

Expand Down Expand Up @@ -38,13 +39,18 @@ func TestIntegration(t *testing.T) {
ctx := t.Context()

for name, database := range DBs {
var execCalls int
var queryCalls int

interceptor := queries.Interceptor{
Driver: database.driver,
ExecContext: func(ctx context.Context, query string, args []driver.NamedValue, execer driver.ExecerContext) (driver.Result, error) {
execCalls++
t.Logf("[%s] ExecContext: %s %v", name, query, namedToAny(args))
return execer.ExecContext(ctx, query, args)
},
QueryContext: func(ctx context.Context, query string, args []driver.NamedValue, queryer driver.QueryerContext) (driver.Rows, error) {
queryCalls++
t.Logf("[%s] QueryContext: %s %v", name, query, namedToAny(args))
return queryer.QueryContext(ctx, query, args)
},
Expand Down Expand Up @@ -78,9 +84,17 @@ func TestIntegration(t *testing.T) {
for _, queryer := range []interface {
QueryContext(ctx context.Context, query string, args ...any) (*sql.Rows, error)
}{db, tx} {
_, err := queries.QueryRow[User](ctx, queryer, "SELECT id, name, created_at FROM users WHERE id = 0")
_, err := queries.QueryRow[string](ctx, queryer, "SELECT name FROM users WHERE id = 0")
assert.IsErr[E](t, err, sql.ErrNoRows)

name, err := queries.QueryRow[string](ctx, queryer, "SELECT name FROM users WHERE id = 1")
assert.NoErr[F](t, err)
assert.Equal[E](t, name, TableUsers[0].Name)

names, err := collect(queries.Query[string](ctx, queryer, "SELECT name FROM users"))
assert.NoErr[F](t, err)
assert.Equal[E](t, names, []string{TableUsers[0].Name, TableUsers[1].Name, TableUsers[2].Name})

user, err := queries.QueryRow[User](ctx, queryer, "SELECT id, name, created_at FROM users WHERE id = 1")
assert.NoErr[F](t, err)
assert.Equal[E](t, user.ID, TableUsers[0].ID)
Expand All @@ -96,6 +110,8 @@ func TestIntegration(t *testing.T) {
}

assert.NoErr[F](t, tx.Commit())
assert.Equal[E](t, execCalls, 2)
assert.Equal[E](t, queryCalls, 5*2)
}
}

Expand All @@ -107,6 +123,17 @@ func namedToAny(values []driver.NamedValue) []any {
return args
}

func collect[T any](seq iter.Seq2[T, error]) ([]T, error) {
var ts []T
for t, err := range seq {
if err != nil {
return nil, err
}
ts = append(ts, t)
}
return ts, nil
}

func migrate(ctx context.Context, db *sql.DB) error {
type migration struct {
query string
Expand Down
Loading