/
slices.go
106 lines (87 loc) · 2.37 KB
/
slices.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
package query
import (
"context"
"database/sql"
"fmt"
"strings"
)
// SelectStrings runs query (with args) inside tx and collects the value of the
// single string column of every returned row, in the order the rows are
// yielded. When the query matches no rows the result is an empty, non-nil
// slice; on any error the result is nil.
func SelectStrings(ctx context.Context, tx *sql.Tx, query string, args ...any) ([]string, error) {
	result := make([]string, 0)
	collect := func(rows *sql.Rows) error {
		var s string
		if err := rows.Scan(&s); err != nil {
			return err
		}
		result = append(result, s)
		return nil
	}
	if err := scanSingleColumn(ctx, tx, query, args, "TEXT", collect); err != nil {
		return nil, err
	}
	return result, nil
}
// SelectIntegers runs query (with args) inside tx and gathers the value of the
// single integer column of each returned row, preserving row order. When the
// query matches no rows the result is an empty, non-nil slice; on any error
// the result is nil.
func SelectIntegers(ctx context.Context, tx *sql.Tx, query string, args ...any) ([]int, error) {
	ints := make([]int, 0)
	err := scanSingleColumn(ctx, tx, query, args, "INTEGER", func(rows *sql.Rows) error {
		var n int
		if scanErr := rows.Scan(&n); scanErr != nil {
			return scanErr
		}
		ints = append(ints, n)
		return nil
	})
	if err != nil {
		return nil, err
	}
	return ints, nil
}
// InsertStrings inserts one row per element of values using a single
// statement. stmt is a fmt template that must name exactly one insertion
// column and contain one %s verb. The verb is replaced by a VALUES list with
// one "(?)" placeholder group per value, for example:
//
//	InsertStrings(tx, "INSERT INTO foo(name) VALUES %s", []string{"bar"})
//
// If values is empty, the function returns nil without executing anything.
func InsertStrings(tx *sql.Tx, stmt string, values []string) error {
	if len(values) == 0 {
		return nil
	}

	args := make([]any, 0, len(values))
	for _, v := range values {
		args = append(args, v)
	}

	// Produces "(?), (?), ..., (?)" with len(values) groups.
	placeholders := "(?)" + strings.Repeat(", (?)", len(values)-1)

	_, err := tx.Exec(fmt.Sprintf(stmt, placeholders), args...)
	return err
}
// scanSingleColumn executes query (with args) within tx and invokes scan once
// for every row it yields, in order.
//
// The result set must have exactly one column. If it does not, an error is
// returned before any row is scanned. The check is done up front so that a
// mismatch is reported even when the query yields no rows; otherwise it would
// never surface through rows.Scan.
//
// typeName names the column type the caller expects (e.g. "TEXT" or
// "INTEGER") and appears in the column-count error message.
// NOTE(review): typeName is not checked against the result set. Enforcing it
// via rows.ColumnTypes() would depend on what the driver reports from
// DatabaseTypeName (possibly empty for computed expressions), so confirm
// driver behavior before adding such a check.
//
// Errors from QueryContext, from scan, and from row iteration are returned
// unchanged.
func scanSingleColumn(ctx context.Context, tx *sql.Tx, query string, args []any, typeName string, scan scanFunc) error {
	rows, err := tx.QueryContext(ctx, query, args...)
	if err != nil {
		return err
	}
	// The Close error is deliberately discarded; iteration errors are
	// reported through rows.Err() below.
	defer func() { _ = rows.Close() }()

	columns, err := rows.Columns()
	if err != nil {
		return err
	}
	if len(columns) != 1 {
		return fmt.Errorf("query yields %d columns, expected a single %s column", len(columns), typeName)
	}

	for rows.Next() {
		if err := scan(rows); err != nil {
			return err
		}
	}
	return rows.Err()
}
// scanFunc consumes the current row of rows (the callers in this file do so
// via rows.Scan) and returns any error encountered while doing so.
type scanFunc func(rows *sql.Rows) error