// utils.go
package sql
import (
"database/sql/driver"
"fmt"
"math"
"net"
"reflect"
"strings"
"time"
"github.com/google/uuid"
"github.com/bytehouse-cloud/driver-go/errors"
)
// bindArgsToQuery replaces each unquoted question mark in query with the
// quoted form (see quote) of the corresponding element of args, in order.
// This function should only be used for select queries,
// e.g. bindArgsToQuery("SELECT ?", []driver.Value{"Goo"}) -> "SELECT 'Goo'".
//
// Question marks inside single-quoted strings, double-quoted identifiers and
// backtick-quoted identifiers are copied through untouched. Inside any of
// those quoted sections a backslash escapes the byte that follows it, so \'
// does not close a single-quoted string while \\' does.
//
// If args is empty the query is returned unchanged. Otherwise an error is
// returned when the number of placeholders differs from len(args).
func bindArgsToQuery(query string, args []driver.Value) (string, error) {
	if len(args) == 0 {
		return query, nil
	}

	var sb strings.Builder
	// Preallocate memory (assume args 10 bytes each)
	sb.Grow(len(query) + len(args)*10)

	var (
		index   int  // next element of args to bind
		inQuote byte // opening quote of the section we are in, 0 when outside any quotes
		escaped bool // previous byte was an unescaped backslash inside a quoted section
	)
	// Scan bytes rather than runes: every character of interest is ASCII, and
	// UTF-8 continuation bytes never collide with ASCII, so copying byte-wise
	// preserves the query exactly (including any invalid UTF-8 sequences).
	for i := 0; i < len(query); i++ {
		c := query[i]
		switch {
		case inQuote != 0:
			// Inside a quoted section only the matching, unescaped quote ends it;
			// other quote characters and '?' are plain text here.
			switch {
			case escaped:
				escaped = false
			case c == '\\':
				escaped = true
			case c == inQuote:
				inQuote = 0
			}
		case c == '\'' || c == '"' || c == '`':
			inQuote = c
		case c == '?':
			if index >= len(args) {
				return "", errors.ErrorfWithCaller("less args then query's ? sign")
			}
			sb.WriteString(quote(args[index]))
			index++
			continue
		}
		sb.WriteByte(c)
	}

	// index should have advanced past all args
	if index != len(args) {
		return "", errors.ErrorfWithCaller("more args then query's ? sign")
	}
	return sb.String(), nil
}
// stringLiteralEscaper escapes backslashes and single quotes so a Go string
// can be embedded inside a single-quoted SQL string literal. It is built once
// instead of on every call to quote.
var stringLiteralEscaper = strings.NewReplacer(`\`, `\\`, `'`, `\'`)

// quote converts a driver.Value into the text used to inline it in a SQL
// statement, depending on its type. It was originally derived from clickhouse-go.
//
//   - nil                 -> NULL
//   - string              -> single-quoted literal with \ and ' escaped
//   - time.Time           -> formatTime(v)
//   - net.IP, uuid.UUID   -> their String() form as a quoted string literal
//   - slices and arrays   -> [e1,e2,...] with every element quoted recursively
//   - maps                -> {k1:v1,k2:v2,...} with keys and values quoted recursively
//     (entry order follows Go map iteration order, which is not deterministic)
//   - anything else       -> fmt.Sprint(v)
func quote(v driver.Value) string {
	switch v := v.(type) {
	case nil:
		return "NULL"
	case string:
		return "'" + stringLiteralEscaper.Replace(v) + "'"
	case time.Time:
		return formatTime(v)
	case net.IP:
		// An unquoted IP such as 10.0.0.1 is not valid SQL; emit a string literal.
		return quote(v.String())
	case uuid.UUID:
		// An unquoted UUID would be parsed as a subtraction expression; emit a string literal.
		return quote(v.String())
	}

	switch rv := reflect.ValueOf(v); rv.Kind() {
	case reflect.Slice, reflect.Array:
		var sb strings.Builder
		sb.WriteByte('[')
		for i := 0; i < rv.Len(); i++ {
			if i > 0 {
				sb.WriteByte(',')
			}
			sb.WriteString(quote(rv.Index(i).Interface()))
		}
		sb.WriteByte(']')
		return sb.String()
	case reflect.Map:
		var sb strings.Builder
		sb.WriteByte('{')
		iter := rv.MapRange()
		for first := true; iter.Next(); first = false {
			if !first {
				sb.WriteByte(',')
			}
			// Interface() unwraps the reflect.Value so quote sees the real type
			// (otherwise string keys/values would fall through unquoted).
			sb.WriteString(quote(iter.Key().Interface()))
			sb.WriteByte(':')
			sb.WriteString(quote(iter.Value().Interface()))
		}
		sb.WriteByte('}')
		return sb.String()
	}
	return fmt.Sprint(v)
}
// formatTime renders value as a ClickHouse expression.
//
// When value's wall-clock time in its own location is exactly midnight, it is
// rendered as toDate(N). N is the number of days from 1970-01-01 to value's
// calendar date in that location, and it must lie in [0, math.MaxUint16].
// Every other value is rendered as toDateTime(value.Unix()), so sub-second
// precision is dropped.
func formatTime(value time.Time) string {
	hour, minute, second := value.Clock()
	if hour == 0 && minute == 0 && second == 0 && value.Nanosecond() == 0 {
		// Count days from the calendar date as seen in value's own location.
		// Deriving them from value.Unix() would use the UTC date instead, which
		// is the previous day for midnights in zones east of UTC.
		year, month, day := value.Date()
		days := time.Date(year, month, day, 0, 0, 0, 0, time.UTC).Unix() / (24 * 3600)
		// toDate() only accepts day counts that fit in a UInt16; anything
		// outside that range falls back to toDateTime below.
		if days >= 0 && days <= math.MaxUint16 {
			return fmt.Sprintf("toDate(%d)", days)
		}
	}
	return fmt.Sprintf("toDateTime(%d)", value.Unix())
}
// namedArgsToArgs strips driver.NamedValue wrappers, returning the plain
// values in the same order. Any argument that carries a name is rejected,
// because named parameters are not supported.
func namedArgsToArgs(namedArgs []driver.NamedValue) ([]driver.Value, error) {
	values := make([]driver.Value, 0, len(namedArgs))
	for i := range namedArgs {
		if namedArgs[i].Name != "" {
			return nil, errors.Errorf("named params not supported")
		}
		values = append(values, namedArgs[i].Value)
	}
	return values, nil
}