-
Notifications
You must be signed in to change notification settings - Fork 0
/
query.go
211 lines (193 loc) · 6.91 KB
/
query.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
package db
import (
"context"
"database/sql"
"errors"
"fmt"
"reflect"
"strings"
"time"
"github.com/domonda/go-sqldb"
)
// Now asks the connection stored in ctx for the result of the SQL now()
// function. Within a transaction this gives the transaction's timestamp,
// so Go code can use the same point in time as the SQL statements.
func Now(ctx context.Context) (time.Time, error) {
	conn := Conn(ctx)
	return conn.Now()
}
// Exec runs query with the optional args on the connection from ctx.
// Only an error is reported back; no result rows are returned.
func Exec(ctx context.Context, query string, args ...any) error {
	conn := Conn(ctx)
	return conn.Exec(query, args...)
}
// QueryRow runs query with args on the connection from ctx and
// returns a sqldb.RowScanner for reading the single resulting row.
func QueryRow(ctx context.Context, query string, args ...any) sqldb.RowScanner {
	conn := Conn(ctx)
	return conn.QueryRow(query, args...)
}
// QueryRows runs query with args on the connection from ctx and
// returns a sqldb.RowsScanner for reading all resulting rows.
func QueryRows(ctx context.Context, query string, args ...any) sqldb.RowsScanner {
	conn := Conn(ctx)
	return conn.QueryRows(query, args...)
}
// QueryValue runs query with args and scans the single resulting
// value into a T. On any error the zero value of T is returned
// together with that error.
func QueryValue[T any](ctx context.Context, query string, args ...any) (value T, err error) {
	var result T
	if err = Conn(ctx).QueryRow(query, args...).Scan(&result); err != nil {
		var zero T
		return zero, err
	}
	return result, nil
}
// QueryValueReplaceErrNoRows runs query with args and scans the single
// resulting value into a T.
//
// If the scan fails with sql.ErrNoRows and errNoRows is non-nil, the
// error produced by calling errNoRows is returned instead. In every error
// case the returned value is the zero value of T.
func QueryValueReplaceErrNoRows[T any](ctx context.Context, errNoRows func() error, query string, args ...any) (value T, err error) {
	err = Conn(ctx).QueryRow(query, args...).Scan(&value)
	var zero T
	switch {
	case err == nil:
		return value, nil
	case errNoRows != nil && errors.Is(err, sql.ErrNoRows):
		return zero, errNoRows()
	default:
		return zero, err
	}
}
// QueryValueOr runs query with args and scans the single resulting value
// into a T. When the query yields no row (sql.ErrNoRows), defaultValue is
// returned with a nil error. Any other error is returned together with
// the zero value of T.
func QueryValueOr[T any](ctx context.Context, defaultValue T, query string, args ...any) (value T, err error) {
	err = Conn(ctx).QueryRow(query, args...).Scan(&value)
	if err == nil {
		return value, nil
	}
	if errors.Is(err, sql.ErrNoRows) {
		return defaultValue, nil
	}
	var zero T
	return zero, err
}
// QueryRowStruct runs query with args and scans the resulting row into
// a struct of type S. A pointer to that struct is returned on success;
// on error the returned pointer is nil.
func QueryRowStruct[S any](ctx context.Context, query string, args ...any) (row *S, err error) {
	// result starts out nil; ScanStruct is handed its address so it can
	// store the scanned struct through it.
	var result *S
	scanner := Conn(ctx).QueryRow(query, args...)
	if err = scanner.ScanStruct(&result); err != nil {
		return nil, err
	}
	return result, nil
}
// QueryRowStructReplaceErrNoRows runs query with args and scans the
// resulting row into a struct of type S.
//
// If the scan fails with sql.ErrNoRows and errNoRows is non-nil, the
// error returned by errNoRows is used instead. The returned row is nil
// whenever an error is returned.
func QueryRowStructReplaceErrNoRows[S any](ctx context.Context, errNoRows func() error, query string, args ...any) (row *S, err error) {
	var result *S
	err = Conn(ctx).QueryRow(query, args...).ScanStruct(&result)
	switch {
	case err == nil:
		return result, nil
	case errNoRows != nil && errors.Is(err, sql.ErrNoRows):
		return nil, errNoRows()
	default:
		return nil, err
	}
}
// QueryRowStructOrNil runs query with args and scans the resulting row
// into a struct of type S. When no row is found (sql.ErrNoRows), it
// returns a nil row and a nil error. Other errors are returned with a
// nil row.
func QueryRowStructOrNil[S any](ctx context.Context, query string, args ...any) (row *S, err error) {
	var result *S
	err = Conn(ctx).QueryRow(query, args...).ScanStruct(&result)
	if err == nil {
		return result, nil
	}
	if errors.Is(err, sql.ErrNoRows) {
		return nil, nil
	}
	return nil, err
}
// GetRow selects a single table row by its primary key and scans it
// into a struct of type S.
//
// S must be a struct type. The connection's StructFieldMapper maps its
// fields, and that mapping must yield a table name and the primary key
// columns. pkValue followed by pkValues are bound, in struct field order,
// to those primary key columns. pkValue is a separate parameter so that
// a call without any key value does not compile.
//
// An error is returned in these cases:
//   - S is not a struct.
//   - No table name is mapped.
//   - The number of key values differs from the number of primary key columns.
//   - The query or the scan fails. A missing row is reported through the
//     scanner's error.
func GetRow[S any](ctx context.Context, pkValue any, pkValues ...any) (row *S, err error) {
	keys := make([]any, 0, 1+len(pkValues))
	keys = append(keys, pkValue)
	keys = append(keys, pkValues...)

	t := reflect.TypeOf(row).Elem()
	if t.Kind() != reflect.Struct {
		return nil, fmt.Errorf("expected struct template type instead of %s", t)
	}
	conn := Conn(ctx)
	table, pkColumns, err := pkColumnsOfStruct(conn, t)
	if err != nil {
		return nil, err
	}
	// Without a table name the query below would be invalid SQL
	// ("SELECT * FROM  WHERE ..."), so fail early with a clear message.
	if table == "" {
		return nil, fmt.Errorf("no table name mapped for struct %s", t)
	}
	// keys always holds at least one value, so this check also
	// guarantees that pkColumns[0] below exists.
	if len(pkColumns) != len(keys) {
		return nil, fmt.Errorf("got %d primary key values, but struct %s has %d primary key fields", len(keys), t, len(pkColumns))
	}

	// pkColumnsOfStruct validated every column name with
	// conn.ValidateColumnName before they are formatted into the query.
	var query strings.Builder
	fmt.Fprintf(&query, `SELECT * FROM %s WHERE "%s" = $1`, table, pkColumns[0]) //#nosec G104
	for i := 1; i < len(pkColumns); i++ {
		fmt.Fprintf(&query, ` AND "%s" = $%d`, pkColumns[i], i+1) //#nosec G104
	}

	err = conn.QueryRow(query.String(), keys...).ScanStruct(&row)
	if err != nil {
		return nil, err
	}
	return row, nil
}
// GetRowOrNil works like GetRow: it selects a table row by its primary
// key values (pkValue followed by pkValues) and scans it into a struct
// of type S.
//
// If no row matches (sql.ErrNoRows), both the returned row and the
// returned error are nil. Any other error is returned with a nil row.
func GetRowOrNil[S any](ctx context.Context, pkValue any, pkValues ...any) (row *S, err error) {
	found, err := GetRow[S](ctx, pkValue, pkValues...)
	switch {
	case err == nil:
		return found, nil
	case errors.Is(err, sql.ErrNoRows):
		return nil, nil
	default:
		return nil, err
	}
}
// pkColumnsOfStruct walks the fields of the struct type t with the
// StructFieldMapper of conn. It returns the mapped table name and the
// primary key column names in field order.
//
// Fields that the mapper accepts but maps to an empty column name are
// treated as embedded structs and searched recursively. Such fields whose
// type is not a struct are skipped, because reflect.Type.NumField would
// panic on them. Every primary key column name is checked with
// conn.ValidateColumnName. An error is returned if fields map to
// different non-empty table names.
func pkColumnsOfStruct(conn sqldb.Connection, t reflect.Type) (table string, columns []string, err error) {
	// setTable records name as the struct's table name. It fails if a
	// different non-empty name has already been recorded.
	setTable := func(name string) error {
		if name == "" || name == table {
			return nil
		}
		if table != "" {
			return fmt.Errorf("table name not unique (%s vs %s) in struct %s", table, name, t)
		}
		table = name
		return nil
	}

	mapper := conn.StructFieldMapper()
	for i := 0; i < t.NumField(); i++ {
		field := t.Field(i)
		fieldTable, column, flags, ok := mapper.MapStructField(field)
		if !ok {
			continue
		}
		if err = setTable(fieldTable); err != nil {
			return "", nil, err
		}

		if column == "" {
			// Embedded field: only struct types can contribute columns.
			if field.Type.Kind() != reflect.Struct {
				continue
			}
			var (
				embeddedTable   string
				embeddedColumns []string
			)
			embeddedTable, embeddedColumns, err = pkColumnsOfStruct(conn, field.Type)
			if err != nil {
				return "", nil, err
			}
			if err = setTable(embeddedTable); err != nil {
				return "", nil, err
			}
			columns = append(columns, embeddedColumns...)
			continue
		}

		if flags.PrimaryKey() {
			if err = conn.ValidateColumnName(column); err != nil {
				return "", nil, fmt.Errorf("%w in struct field %s.%s", err, t, field.Name)
			}
			columns = append(columns, column)
		}
	}
	return table, columns, nil
}
// QueryStructSlice runs query with args and scans every resulting row
// into a slice of S. S must be a struct or a pointer to a struct.
// On error the returned slice is nil.
func QueryStructSlice[S any](ctx context.Context, query string, args ...any) (rows []S, err error) {
	var result []S
	scanner := Conn(ctx).QueryRows(query, args...)
	if err = scanner.ScanStructSlice(&result); err != nil {
		return nil, err
	}
	return result, nil
}
// InsertStruct writes rowStruct as a new row into table through the
// connection from ctx. The connection's StructFieldMapper maps struct
// fields to column names. Any ignoreColumns filters are passed on so
// that the mapped columns they match are left out.
func InsertStruct(ctx context.Context, table string, rowStruct any, ignoreColumns ...sqldb.ColumnFilter) error {
	conn := Conn(ctx)
	return conn.InsertStruct(table, rowStruct, ignoreColumns...)
}