From bc15c9efa724e102c46bd76b207dab1a0f67cd90 Mon Sep 17 00:00:00 2001 From: Nik <73077675+tmzane@users.noreply.github.com> Date: Sun, 6 Apr 2025 00:59:07 +0300 Subject: [PATCH] feat(scanner): support non-struct T for single-column queries --- query.go | 50 +++++++++++++---- query_test.go | 109 ++++++++++++++++++++++++++++++++------ tests/integration_test.go | 29 +++++++++- 3 files changed, 161 insertions(+), 27 deletions(-) diff --git a/query.go b/query.go index 3064191..ce89018 100644 --- a/query.go +++ b/query.go @@ -7,6 +7,7 @@ import ( "iter" "reflect" "sync" + "time" ) type queryer interface { @@ -84,22 +85,33 @@ type scanner interface { } func scan[T any](s scanner, columns []string) (T, error) { - var t T - v := reflect.ValueOf(&t).Elem() - if v.Kind() != reflect.Struct { - panic("queries: T must be a struct") + if len(columns) == 0 { + panic("queries: no columns specified") // valid in PostgreSQL (for some reason). } - indexes := parseStruct(v.Type()) + var t T + v := reflect.ValueOf(&t).Elem() args := make([]any, len(columns)) - for i, column := range columns { - idx, ok := indexes[column] - if !ok { - panic(fmt.Sprintf("queries: no field for column %q", column)) + switch { + case scannable(v): + if len(columns) > 1 { + panic("queries: T must be a struct if len(columns) > 1") + } + args[0] = v.Addr().Interface() + case v.Kind() == reflect.Struct: + indexes := parseStruct(v.Type()) + for i, column := range columns { + idx, ok := indexes[column] + if !ok { + panic(fmt.Sprintf("queries: no field for column %q", column)) + } + args[i] = v.Field(idx).Addr().Interface() } - args[i] = v.Field(idx).Addr().Interface() + default: + panic(fmt.Sprintf("queries: unsupported T %T", t)) } + if err := s.Scan(args...); err != nil { return zero[T](), err } @@ -107,6 +119,24 @@ func scan[T any](s scanner, columns []string) (T, error) { return t, nil } +func scannable(v reflect.Value) bool { + switch v.Kind() { + case reflect.Bool, + reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, + reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, + reflect.Float32, reflect.Float64, + reflect.String: + return true + } + if v.Type() == reflect.TypeFor[time.Time]() { + return true + } + if v.Addr().Type().Implements(reflect.TypeFor[sql.Scanner]()) { + return true + } + return false +} + var cache sync.Map // map[reflect.Type]map[string]int // parseStruct parses the given struct type and returns a map of column names to field indexes. diff --git a/query_test.go b/query_test.go index ae00253..7b1a76e 100644 --- a/query_test.go +++ b/query_test.go @@ -1,62 +1,133 @@ package queries import ( + "database/sql" "errors" "reflect" "testing" + "time" "go-simpler.org/queries/internal/assert" . "go-simpler.org/queries/internal/assert/EF" ) func Test_scan(t *testing.T) { - t.Run("non-struct T", func(t *testing.T) { + t.Run("no columns", func(t *testing.T) { fn := func() { _, _ = scan[int](nil, nil) } - assert.Panics[E](t, fn, "queries: T must be a struct") + assert.Panics[E](t, fn, "queries: no columns specified") + }) + + t.Run("unsupported T", func(t *testing.T) { + columns := []string{"foo", "bar"} + + fn := func() { _, _ = scan[complex64](nil, columns) } + assert.Panics[E](t, fn, "queries: unsupported T complex64") + }) + + t.Run("non-struct T with len(columns) > 1", func(t *testing.T) { + columns := []string{"foo", "bar"} + + fn := func() { _, _ = scan[int](nil, columns) } + assert.Panics[E](t, fn, "queries: T must be a struct if len(columns) > 1") }) t.Run("empty tag", func(t *testing.T) { + columns := []string{"foo", "bar"} + type row struct { - Foo int `sql:""` + Foo int `sql:"foo"` + Bar string `sql:""` } - fn := func() { _, _ = scan[row](nil, nil) } - assert.Panics[E](t, fn, "queries: field Foo has an empty `sql` tag") + fn := func() { _, _ = scan[row](nil, columns) } + assert.Panics[E](t, fn, "queries: field Bar has an empty `sql` tag") }) t.Run("missing field", func(t *testing.T) { + columns := []string{"foo", "bar"} + type row struct { Foo int `sql:"foo"` Bar string } - fn := func() { _, _ = scan[row](nil, []string{"foo", "bar"}) } + fn := func() { _, _ = scan[row](nil, columns) } assert.Panics[E](t, fn, `queries: no field for column "bar"`) }) t.Run("scan error", func(t *testing.T) { - columns := []string{"foo"} + columns := []string{"foo", "bar"} s := mockScanner{err: errors.New("an error")} type row struct { - Foo int `sql:"foo"` + Foo int `sql:"foo"` + Bar string `sql:"bar"` } _, err := scan[row](&s, columns) assert.IsErr[E](t, err, s.err) }) - t.Run("ok", func(t *testing.T) { + t.Run("struct T", func(t *testing.T) { columns := []string{"foo", "bar"} - s := mockScanner{values: []any{1, "A"}} + s := mockScanner{values: []any{1, "test"}} type row struct { Foo int `sql:"foo"` Bar string `sql:"bar"` unexported bool } - r, err := scan[row](&s, columns) + v, err := scan[row](&s, columns) + assert.NoErr[F](t, err) + assert.Equal[E](t, v.Foo, 1) + assert.Equal[E](t, v.Bar, "test") + assert.Equal[E](t, v.unexported, false) + }) + + t.Run("struct T with len(columns) == 1", func(t *testing.T) { + columns := []string{"foo"} + s := mockScanner{values: []any{1}} + + type row struct { + Foo int `sql:"foo"` + } + v, err := scan[row](&s, columns) + assert.NoErr[F](t, err) + assert.Equal[E](t, v.Foo, 1) + }) + + t.Run("non-struct T with len(columns) == 1", func(t *testing.T) { + columns := []string{"foo"} + + tests := []struct { + scan func(scanner) (any, error) + value any + }{ + {func(s scanner) (any, error) { return scan[bool](s, columns) }, true}, + {func(s scanner) (any, error) { return scan[int](s, columns) }, int(-1)}, + {func(s scanner) (any, error) { return scan[int8](s, columns) }, int8(-8)}, + {func(s scanner) (any, error) { return scan[int16](s, columns) }, int16(-16)}, + {func(s scanner) (any, error) { return scan[int32](s, columns) }, int32(-32)}, + {func(s scanner) (any, error) { return scan[int64](s, columns) }, int64(-64)}, + {func(s scanner) (any, error) { return scan[uint](s, columns) }, uint(1)}, + {func(s scanner) (any, error) { return scan[uint8](s, columns) }, uint8(8)}, + {func(s scanner) (any, error) { return scan[uint16](s, columns) }, uint16(16)}, + {func(s scanner) (any, error) { return scan[uint32](s, columns) }, uint32(32)}, + {func(s scanner) (any, error) { return scan[uint64](s, columns) }, uint64(64)}, + {func(s scanner) (any, error) { return scan[float32](s, columns) }, float32(0.32)}, + {func(s scanner) (any, error) { return scan[float64](s, columns) }, float64(0.64)}, + {func(s scanner) (any, error) { return scan[string](s, columns) }, "test"}, + {func(s scanner) (any, error) { return scan[time.Time](s, columns) }, time.Now()}, + } + for _, tt := range tests { + s := mockScanner{values: []any{tt.value}} + v, err := tt.scan(&s) + assert.NoErr[F](t, err) + assert.Equal[E](t, v, tt.value) + } + + // sql.Scanner implementation: + s := mockScanner{values: []any{"test"}} + v, err := scan[sql.Null[string]](&s, columns) assert.NoErr[F](t, err) - assert.Equal[E](t, r.Foo, 1) - assert.Equal[E](t, r.Bar, "A") - assert.Equal[E](t, r.unexported, false) + assert.Equal[E](t, v, sql.Null[string]{V: "test", Valid: true}) }) } @@ -70,8 +141,14 @@ func (s *mockScanner) Scan(dst ...any) error { return s.err } for i := range dst { - v := reflect.ValueOf(s.values[i]) - reflect.ValueOf(dst[i]).Elem().Set(v) + if sc, ok := dst[i].(sql.Scanner); ok { + if err := sc.Scan(s.values[i]); err != nil { + return err + } + } else { + v := reflect.ValueOf(s.values[i]) + reflect.ValueOf(dst[i]).Elem().Set(v) + } } return nil } diff --git a/tests/integration_test.go b/tests/integration_test.go index 9fd2296..d2ca093 100644 --- a/tests/integration_test.go +++ b/tests/integration_test.go @@ -4,6 +4,7 @@ import ( "context" "database/sql" "database/sql/driver" + "iter" "testing" "time" @@ -38,13 +39,18 @@ func TestIntegration(t *testing.T) { ctx := t.Context() for name, database := range DBs { + var execCalls int + var queryCalls int + interceptor := queries.Interceptor{ Driver: database.driver, ExecContext: func(ctx context.Context, query string, args []driver.NamedValue, execer driver.ExecerContext) (driver.Result, error) { + execCalls++ t.Logf("[%s] ExecContext: %s %v", name, query, namedToAny(args)) return execer.ExecContext(ctx, query, args) }, QueryContext: func(ctx context.Context, query string, args []driver.NamedValue, queryer driver.QueryerContext) (driver.Rows, error) { + queryCalls++ t.Logf("[%s] QueryContext: %s %v", name, query, namedToAny(args)) return queryer.QueryContext(ctx, query, args) }, @@ -78,9 +84,17 @@ func TestIntegration(t *testing.T) { for _, queryer := range []interface { QueryContext(ctx context.Context, query string, args ...any) (*sql.Rows, error) }{db, tx} { - _, err := queries.QueryRow[User](ctx, queryer, "SELECT id, name, created_at FROM users WHERE id = 0") + _, err := queries.QueryRow[string](ctx, queryer, "SELECT name FROM users WHERE id = 0") assert.IsErr[E](t, err, sql.ErrNoRows) + name, err := queries.QueryRow[string](ctx, queryer, "SELECT name FROM users WHERE id = 1") + assert.NoErr[F](t, err) + assert.Equal[E](t, name, TableUsers[0].Name) + + names, err := collect(queries.Query[string](ctx, queryer, "SELECT name FROM users")) + assert.NoErr[F](t, err) + assert.Equal[E](t, names, []string{TableUsers[0].Name, TableUsers[1].Name, TableUsers[2].Name}) + user, err := queries.QueryRow[User](ctx, queryer, "SELECT id, name, created_at FROM users WHERE id = 1") assert.NoErr[F](t, err) assert.Equal[E](t, user.ID, TableUsers[0].ID) @@ -96,6 +110,8 @@ func TestIntegration(t *testing.T) { } assert.NoErr[F](t, tx.Commit()) + assert.Equal[E](t, execCalls, 2) + assert.Equal[E](t, queryCalls, 5*2) } } @@ -107,6 +123,17 @@ func namedToAny(values []driver.NamedValue) []any { return args } +func collect[T any](seq iter.Seq2[T, error]) ([]T, error) { + var ts []T + for t, err := range seq { + if err != nil { + return nil, err + } + ts = append(ts, t) + } + return ts, nil +} + func migrate(ctx context.Context, db *sql.DB) error { type migration struct { query string