-
Notifications
You must be signed in to change notification settings - Fork 0
/
parse.go
154 lines (132 loc) · 3.71 KB
/
parse.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
package bindvar
import (
"database/sql/driver"
"fmt"
"reflect"
"regexp"
)
// naPrefix is the character that introduces a named argument in a query.
// It follows the named argument prefix syntax used by the std lib.
// https://pkg.go.dev/database/sql#Named
const naPrefix = "@"
// Parser converts SQL statements written with named parameters into the
// bindvar syntax of a specific database engine.
type Parser interface {
// Parse parses all named parameters in a SQL statement, and returns
// a statement with the params converted to bindvars appropriate for
// the engine, e.g. @foo, @bar => $1, $2 (postgres).
//
// data supplies the parameter values, which are looked up by name.
// Additionally, the positional args are returned in the order the
// parameters appear in the statement.
Parse(query string, data any) (q string, args []any, err error)
}
// New returns a Parser that rewrites named args into the bindvar syntax of
// the given driver name. The name is stored as-is and is only interpreted
// when a statement is parsed.
func New(driver string) Parser {
	p := parser{driver: driver}
	return &p
}
// parser implements Parser for a single driver name.
type parser struct {
// driver is the driver name given to New; it is passed to parse to pick
// the bindvar syntax.
driver string
}
// Parse implements Parser. It rewrites every named arg in query to the
// bindvar syntax of p.driver and looks each name up in data, returning the
// values in the order the args appear in the statement. The returned slice is
// never nil and the returned error is always nil.
func (p parser) Parse(query string, data any) (string, []any, error) {
	// Work on runes so multi-byte characters are handled as single units.
	rewritten, named := parse(p.driver, []rune(query))

	// One value per named arg, positionally aligned with the bindvars.
	vals := make([]any, len(named))
	for i := range named {
		vals[i] = value(data, named[i].Name)
	}
	return string(rewritten), vals, nil
}
// reArgTerm matches a single character that terminates a named arg:
// whitespace, ';', ')' or ','.
var reArgTerm = regexp.MustCompile(`[[:space:]]|;|\)|,`)
// parse scans query for named args (naPrefix followed by a name) and returns
// the query rewritten with the bindvar syntax of driverName, together with
// the named args in order of appearance. Ordinals start at 1 and are assigned
// per occurrence, so a name used twice yields two entries.
//
// Text between single quotes is treated as a SQL string literal and copied
// through untouched. A prefix that is not followed by a name (it is the last
// character of the query, or the next character is a terminator) is not an
// arg: it is copied through as-is instead of being replaced by a bindvar
// bound to an empty name.
func parse(driverName string, query []rune) (s []rune, args []driver.NamedValue) {
	var (
		pos       int  // index of the rune currently being examined
		ordinal   int  // 1-based position of the most recently captured arg
		inLiteral bool // true while inside a single-quoted string literal
	)

	for pos < len(query) {
		r := query[pos]

		// A quote toggles literal mode. An escaped quote ('') toggles twice,
		// which leaves the state unchanged.
		if r == '\'' {
			inLiteral = !inLiteral
		}

		// Anything that is not an arg prefix outside a literal is copied.
		if inLiteral || string(r) != naPrefix {
			s = append(s, r)
			pos++
			continue
		}

		// Found the prefix: the name runs up to the first terminating
		// character, or to the end of the query.
		end := pos + 1
		for end < len(query) && !reArgTerm.MatchString(string(query[end])) {
			end++
		}

		name := string(query[pos+1 : end])
		if name == "" {
			// A bare prefix is not an arg; keep it literally so the
			// statement and the arg count stay in sync.
			s = append(s, r)
			pos++
			continue
		}

		ordinal++
		nv := driver.NamedValue{
			Ordinal: ordinal,
			Name:    name,
		}
		args = append(args, nv)

		// Emit the arg in the driver's bindvar syntax and resume scanning at
		// the terminator, which is then copied by the next iteration.
		s = append(s, []rune(argfmt(driverName, nv))...)
		pos = end
	}
	return s, args
}
// value returns the value stored under name in data, or nil when it cannot
// be found.
//
// Supported shapes of data:
//   - map[string]any (fast path, no reflection)
//   - any other map whose key kind is string
//   - a struct, matched on the exported field called name
//   - any depth of pointers or interfaces leading to one of the above
//
// Everything else, including nil, nil pointers, scalars, unexported fields
// and missing keys or fields, yields nil rather than a panic.
func value(data any, name string) any {
	if m, ok := data.(map[string]any); ok {
		return m[name] // a missing key yields the zero value, nil
	}

	// Peel off pointers and interfaces until a concrete value is reached.
	v := reflect.ValueOf(data)
	for v.Kind() == reflect.Pointer || v.Kind() == reflect.Interface {
		if v.IsNil() {
			return nil
		}
		v = v.Elem()
	}

	switch v.Kind() {
	case reflect.Struct:
		// CanInterface is false for unexported fields.
		if f := v.FieldByName(name); f.IsValid() && f.CanInterface() {
			return f.Interface()
		}
	case reflect.Map:
		keyType := v.Type().Key()
		if keyType.Kind() != reflect.String {
			return nil
		}
		// Convert so that maps keyed by a named string type also work.
		key := reflect.ValueOf(name).Convert(keyType)
		// MapIndex returns the invalid Value for a missing key; IsValid
		// (unlike IsZero) is safe to call on it and keeps zero-valued entries.
		if e := v.MapIndex(key); e.IsValid() && e.CanInterface() {
			return e.Interface()
		}
	}
	return nil
}
// argfmt renders the bindvar for nv in the syntax the given driver expects:
//   - "postgres": the ordinal form, e.g. @Foo => $1
//   - "mssql", "sqlserver": the named form, e.g. @Foo => @Foo
//   - any other driver: the positional placeholder ?
func argfmt(driver string, nv driver.NamedValue) string {
	// TODO: support more sql engines
	if driver == "postgres" {
		return fmt.Sprintf("$%d", nv.Ordinal)
	}
	if driver == "mssql" || driver == "sqlserver" {
		return fmt.Sprintf("@%s", nv.Name)
	}
	return "?"
}