-
Notifications
You must be signed in to change notification settings - Fork 8
/
middleware.go
420 lines (359 loc) · 13.8 KB
/
middleware.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
/*
Copyright 2023 eatmoreapple
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package juice
import (
"context"
"crypto/md5"
"database/sql"
"encoding/hex"
"encoding/json"
"errors"
"fmt"
"github.com/eatmoreapple/juice/cache"
"github.com/eatmoreapple/juice/internal/reflectlite"
"log"
"reflect"
"strconv"
"strings"
"time"
"unicode"
)
// Middleware decorates the handlers that execute SQL statements.
//
// For every statement, QueryContext receives the QueryHandler that would run a
// row-returning query and returns the handler to use in its place; ExecContext
// does the same for an ExecHandler. Returning next unchanged leaves the
// statement undecorated.
type Middleware interface {
	// QueryContext returns the QueryHandler to use for stmt, usually a
	// function that wraps next.
	QueryContext(stmt Statement, next QueryHandler) QueryHandler

	// ExecContext returns the ExecHandler to use for stmt, usually a
	// function that wraps next.
	ExecContext(stmt Statement, next ExecHandler) ExecHandler
}
// ensure MiddlewareGroup implements Middleware.
var _ Middleware = MiddlewareGroup(nil) // compile time check

// MiddlewareGroup combines several Middleware values into a single Middleware.
//
// Members are applied in slice order, each one wrapping the handler produced
// by the member before it, so the last member yields the outermost handler
// (its wrapper is entered first when the final handler is called).
// An empty or nil group returns next unchanged.
type MiddlewareGroup []Middleware

// QueryContext implements Middleware by folding every member's QueryContext
// over next.
func (m MiddlewareGroup) QueryContext(stmt Statement, next QueryHandler) QueryHandler {
	wrapped := next
	for i := range m {
		wrapped = m[i].QueryContext(stmt, wrapped)
	}
	return wrapped
}

// ExecContext implements Middleware by folding every member's ExecContext
// over next.
func (m MiddlewareGroup) ExecContext(stmt Statement, next ExecHandler) ExecHandler {
	wrapped := next
	for i := range m {
		wrapped = m[i].ExecContext(stmt, wrapped)
	}
	return wrapped
}
// logger is the package-level logger used for debug output. It is built at
// package initialization from the standard logger's current writer and flags,
// and prefixes every line with "[juice] ".
var logger = log.New(log.Default().Writer(), "[juice] ", log.Default().Flags())
// ensure DebugMiddleware implements Middleware.
var _ Middleware = (*DebugMiddleware)(nil) // compile time check

// DebugMiddleware is a middleware that logs each SQL statement it wraps: the
// statement name, the query text, its arguments and the time the wrapped
// handler took, written through the package logger with ANSI color codes.
//
// Logging is on by default; isDeBugMode describes how to switch it off.
type DebugMiddleware struct{}

// QueryContext implements Middleware.
// When debug mode is on for stmt, the returned handler times next and logs the
// query afterwards, whether or not next returned an error.
func (m *DebugMiddleware) QueryContext(stmt Statement, next QueryHandler) QueryHandler {
	if m.isDeBugMode(stmt) {
		return func(ctx context.Context, query string, args ...any) (*sql.Rows, error) {
			began := time.Now()
			rows, err := next(ctx, query, args...)
			m.logStatement(stmt, query, args, time.Since(began))
			return rows, err
		}
	}
	return next
}

// ExecContext implements Middleware.
// When debug mode is on for stmt, the returned handler times next and logs the
// statement afterwards, whether or not next returned an error.
func (m *DebugMiddleware) ExecContext(stmt Statement, next ExecHandler) ExecHandler {
	if m.isDeBugMode(stmt) {
		return func(ctx context.Context, query string, args ...any) (sql.Result, error) {
			began := time.Now()
			result, err := next(ctx, query, args...)
			m.logStatement(stmt, query, args, time.Since(began))
			return result, err
		}
	}
	return next
}

// logStatement writes one log line: the statement name in yellow, the query
// in green, the argument slice in blue and the elapsed time in red.
func (m *DebugMiddleware) logStatement(stmt Statement, query string, args []any, spent time.Duration) {
	logger.Printf("\x1b[33m[%s]\x1b[0m \x1b[32m %s\x1b[0m \x1b[34m %v\x1b[0m \x1b[31m %v\x1b[0m\n", stmt.Name(), query, args, spent)
}

// isDeBugMode reports whether statements for stmt should be logged.
//
// Debug mode is on unless the statement's "debug" attribute is "false", or the
// configuration's "debug" setting is "false". The configuration is consulted
// only when the attribute has not already switched debugging off.
func (m *DebugMiddleware) isDeBugMode(stmt Statement) bool {
	if stmt.Attribute("debug") == "false" {
		return false
	}
	return stmt.Configuration().Settings().Get("debug") != "false"
}
// ensure TimeoutMiddleware implements Middleware
var _ Middleware = (*TimeoutMiddleware)(nil) // compile time check

// TimeoutMiddleware is a middleware that bounds how long a SQL statement may run.
//
// The limit comes from the statement's "timeout" attribute, in milliseconds.
// A missing, malformed, out-of-range, zero or negative value leaves the
// statement unbounded.
type TimeoutMiddleware struct{}

// QueryContext implements Middleware.
//
// The returned handler runs next under a context carrying the statement's
// deadline. On success that context is intentionally NOT canceled on return:
// database/sql closes *sql.Rows as soon as the context they were queried with
// is done, and the caller still has to iterate them. The context's deadline
// timer releases it once the timeout elapses, which also bounds how long the
// rows may be read. On error the context is canceled immediately.
func (t TimeoutMiddleware) QueryContext(stmt Statement, next QueryHandler) QueryHandler {
	timeout := t.timeoutDuration(stmt)
	if timeout <= 0 {
		return next
	}
	return func(ctx context.Context, query string, args ...any) (*sql.Rows, error) {
		ctx, cancel := context.WithTimeout(ctx, timeout)
		rows, err := next(ctx, query, args...)
		if err != nil {
			cancel()
			return rows, err
		}
		// Keep cancel referenced (satisfies vet's lostcancel); it must not be
		// called here, see the doc comment above.
		_ = cancel
		return rows, nil
	}
}

// ExecContext implements Middleware.
//
// The returned handler runs next under a context carrying the statement's
// deadline and releases that context when next returns. Unlike *sql.Rows, the
// sql.Result is not closed by context cancellation.
func (t TimeoutMiddleware) ExecContext(stmt Statement, next ExecHandler) ExecHandler {
	timeout := t.timeoutDuration(stmt)
	if timeout <= 0 {
		return next
	}
	return func(ctx context.Context, query string, args ...any) (sql.Result, error) {
		ctx, cancel := context.WithTimeout(ctx, timeout)
		defer cancel()
		return next(ctx, query, args...)
	}
}

// timeoutDuration converts the statement's timeout attribute to a
// time.Duration. It returns 0 when no positive timeout is configured and
// saturates at the largest representable Duration instead of overflowing
// (an overflowing product could wrap to a negative, already-expired timeout).
func (t TimeoutMiddleware) timeoutDuration(stmt Statement) time.Duration {
	const maxDuration = time.Duration(1<<63 - 1)
	millis := t.getTimeout(stmt)
	if millis <= 0 {
		return 0
	}
	if millis > int64(maxDuration/time.Millisecond) {
		return maxDuration
	}
	return time.Duration(millis) * time.Millisecond
}

// getTimeout returns the statement's "timeout" attribute parsed as a base-10
// int64 number of milliseconds, or 0 when the attribute is empty or cannot be
// parsed.
func (t TimeoutMiddleware) getTimeout(stmt Statement) int64 {
	timeoutAttr := stmt.Attribute("timeout")
	if timeoutAttr == "" {
		return 0
	}
	value, err := strconv.ParseInt(timeoutAttr, 10, 64)
	if err != nil {
		// Malformed or out-of-range values disable the timeout. On a range
		// error ParseInt returns ±MaxInt64 rather than 0, so the value must
		// not be used.
		return 0
	}
	return value
}
// ensure useGeneratedKeysMiddleware implements Middleware
var _ Middleware = (*useGeneratedKeysMiddleware)(nil) // compile time check

// useGeneratedKeysMiddleware is a middleware that writes the id reported by
// sql.Result.LastInsertId back into the struct passed as the statement param.
//
// It only acts on Insert statements whose "useGeneratedKeys" attribute is
// "true", or when the configuration setting "useGeneratedKeys" is "true".
type useGeneratedKeysMiddleware struct{}

// QueryContext implements Middleware.
// Queries never produce generated keys, so next is returned unchanged.
func (m *useGeneratedKeysMiddleware) QueryContext(_ Statement, next QueryHandler) QueryHandler {
	return next
}

// ExecContext implements Middleware.
//
// For enabled Insert statements the returned handler runs next and then stores
// the generated id into a field of the param obtained via ParamFromContext,
// which must be a pointer to a struct. The target field is selected by the
// "keyProperty" attribute (see findKeyField) and may be any signed or unsigned
// integer kind. Every failure after next succeeded is returned as an error and
// the sql.Result is discarded.
func (m *useGeneratedKeysMiddleware) ExecContext(stmt Statement, next ExecHandler) ExecHandler {
	if stmt.Action() != Insert {
		return next
	}
	// The statement attribute enables the feature; otherwise fall back to the
	// global setting.
	useGeneratedKeys := stmt.Attribute("useGeneratedKeys") == "true" ||
		stmt.Configuration().Settings().Get("useGeneratedKeys") == "true"
	if !useGeneratedKeys {
		return next
	}
	return func(ctx context.Context, query string, args ...any) (sql.Result, error) {
		result, err := next(ctx, query, args...)
		if err != nil {
			return nil, err
		}
		// ParamCtxInjectorExecutor is expected to have put the param into ctx.
		param := ParamFromContext(ctx)
		if param == nil {
			return nil, errors.New("useGeneratedKeys is true, but the param is nil")
		}
		rv := reflect.ValueOf(param)
		if rv.Kind() != reflect.Ptr {
			return nil, errors.New("useGeneratedKeys is true, but the param is not a pointer")
		}
		rv = reflect.Indirect(rv)
		// NOTE: batch insert does not support useGeneratedKeys yet.
		// TODO: support batch insert useGeneratedKeys.
		if rv.Kind() != reflect.Struct {
			return nil, errors.New("useGeneratedKeys is true, but the param is not a struct pointer")
		}
		keyProperty := stmt.Attribute("keyProperty")
		field, err := m.findKeyField(rv, keyProperty)
		if err != nil {
			return nil, err
		}
		if !field.IsValid() {
			return nil, fmt.Errorf("the keyProperty %s is not found or not field has the autoincr tag", keyProperty)
		}
		if !field.CanInt() && !field.CanUint() {
			return nil, fmt.Errorf("the keyProperty %s is not a int", keyProperty)
		}
		// reflect panics when writing to a field that is not settable (for
		// example an unexported struct field); report it as an error instead.
		if !field.CanSet() {
			return nil, fmt.Errorf("the keyProperty %s can not be set", keyProperty)
		}
		id, err := result.LastInsertId()
		if err != nil {
			return nil, err
		}
		if err = m.setGeneratedKey(field, id, keyProperty); err != nil {
			return nil, err
		}
		return result, nil
	}
}

// findKeyField locates the field of the struct value rv that receives the
// generated key.
//
// With an empty keyProperty the field tagged `autoincr:"true"` is looked up;
// the returned Value is invalid (with a nil error) when there is none.
// Otherwise keyProperty is a dot-separated path: when its first character is
// upper case every segment is matched against struct field names, otherwise
// every segment is matched against the `column` tag.
func (m *useGeneratedKeysMiddleware) findKeyField(rv reflect.Value, keyProperty string) (reflect.Value, error) {
	if keyProperty == "" {
		return reflectlite.From(rv).FindFieldFromTag("autoincr", "true").Value, nil
	}
	// Classify by the first rune, not the first byte, so multi-byte names
	// are judged correctly.
	isPublic := unicode.IsUpper([]rune(keyProperty)[0])
	current := rv
	for _, segment := range strings.Split(keyProperty, ".") {
		value := reflectlite.From(current)
		if ik := value.IndirectKind(); ik != reflect.Struct {
			return reflect.Value{}, fmt.Errorf("expect struct, but got %s", ik)
		}
		if isPublic {
			current = value.FieldByName(segment)
		} else {
			current = value.FindFieldFromTag("column", segment).Value
		}
		if !current.IsValid() {
			return reflect.Value{}, fmt.Errorf("the keyProperty %s is not found", keyProperty)
		}
	}
	return current, nil
}

// setGeneratedKey stores id into field, which must be a settable signed or
// unsigned integer. It returns an error instead of silently truncating when id
// does not fit the field's type, or when id is negative for an unsigned field.
func (m *useGeneratedKeysMiddleware) setGeneratedKey(field reflect.Value, id int64, keyProperty string) error {
	if field.CanInt() {
		if field.OverflowInt(id) {
			return fmt.Errorf("the generated key %d overflows the keyProperty %s", id, keyProperty)
		}
		field.SetInt(id)
		return nil
	}
	if id < 0 || field.OverflowUint(uint64(id)) {
		return fmt.Errorf("the generated key %d overflows the keyProperty %s", id, keyProperty)
	}
	field.SetUint(uint64(id))
	return nil
}
// GenericMiddleware is the counterpart of Middleware for query handlers that
// return a value of type T instead of *sql.Rows.
//
// QueryContext decorates a GenericQueryHandler[T], a function taking a
// context.Context, a query string and variadic arguments. ExecContext
// decorates an ExecHandler with the same shape of inputs. Returning next
// unchanged leaves the statement undecorated.
type GenericMiddleware[T any] interface {
	// QueryContext returns the GenericQueryHandler[T] to use for stmt,
	// usually a function that wraps next.
	QueryContext(stmt Statement, next GenericQueryHandler[T]) GenericQueryHandler[T]

	// ExecContext returns the ExecHandler to use for stmt, usually a
	// function that wraps next.
	ExecContext(stmt Statement, next ExecHandler) ExecHandler
}
// ensure GenericMiddlewareGroup implements GenericMiddleware
var _ GenericMiddleware[any] = (GenericMiddlewareGroup[any])(nil) // compile time check

// GenericMiddlewareGroup combines several GenericMiddleware[T] values into a
// single GenericMiddleware[T].
//
// Members are applied in slice order, each wrapping the handler returned by
// the member before it, so the last member yields the outermost handler.
// An empty or nil group returns next unchanged.
type GenericMiddlewareGroup[T any] []GenericMiddleware[T]

// QueryContext implements GenericMiddleware by folding every member's
// QueryContext over next.
func (m GenericMiddlewareGroup[T]) QueryContext(stmt Statement, next GenericQueryHandler[T]) GenericQueryHandler[T] {
	wrapped := next
	for i := range m {
		wrapped = m[i].QueryContext(stmt, wrapped)
	}
	return wrapped
}

// ExecContext implements GenericMiddleware by folding every member's
// ExecContext over next.
func (m GenericMiddlewareGroup[T]) ExecContext(stmt Statement, next ExecHandler) ExecHandler {
	wrapped := next
	for i := range m {
		wrapped = m[i].ExecContext(stmt, wrapped)
	}
	return wrapped
}
// ensure CacheMiddleware implements GenericMiddleware
var _ GenericMiddleware[any] = (*CacheMiddleware[any])(nil) // compile time check
// cacheKeyFunc defines the function which is used to generate the scopeCache key.
type cacheKeyFunc func(stmt Statement, query string, args []any) (string, error)

// CacheKeyFunc is the function which is used to generate the scopeCache key.
//
// The default returns the hex-encoded MD5 digest of the statement ID, the
// query and, when args is non-empty, the JSON encoding of args. It returns an
// error when args cannot be JSON-encoded.
//
// Reassign the CacheKeyFunc variable to change the default behavior. No lock
// guards this variable, so reassign it before queries start running.
var CacheKeyFunc cacheKeyFunc = func(stmt Statement, query string, args []any) (string, error) {
	// Only the same statement with the same query and the same args may
	// share a key. A NUL byte separates the parts: plain concatenation is
	// ambiguous (ID "a" + query "bc" == ID "ab" + query "c", and query "q"
	// with args [1] == query "q[1]" without args), which could serve one
	// statement's cached result to another.
	hasher := md5.New()
	// hash.Hash.Write never returns an error, so its results are ignored.
	hasher.Write([]byte(stmt.ID()))
	hasher.Write([]byte{0})
	hasher.Write([]byte(query))
	if len(args) > 0 {
		encoded, err := json.Marshal(args)
		if err != nil {
			return "", err
		}
		hasher.Write([]byte{0})
		hasher.Write(encoded)
	}
	return hex.EncodeToString(hasher.Sum(nil)), nil
}
// CacheMiddleware is a middleware that caches query results in a
// cache.ScopeCache and flushes that cache after exec statements succeed.
//
// A nil scopeCache turns both behaviors off. Cached values are returned as
// they come out of the cache (NOTE(review): if the cache does not copy,
// mutating a returned slice/map/pointer would affect later hits — confirm).
type CacheMiddleware[T any] struct {
	scopeCache cache.ScopeCache
}

// QueryContext implements GenericMiddleware.
//
// It returns next unchanged when the cache is nil or the statement sets
// useCache="false". Otherwise the handler derives a key with CacheKeyFunc and
// returns the cached value when one of type T is stored under it. On a miss it
// runs next and stores the result. Lookup errors other than
// cache.ErrCacheNotFound are returned without running next. A failing Set is
// reported together with the freshly queried result.
func (c *CacheMiddleware[T]) QueryContext(stmt Statement, next GenericQueryHandler[T]) GenericQueryHandler[T] {
	if c.scopeCache == nil || stmt.Attribute("useCache") == "false" {
		return next
	}
	return func(ctx context.Context, query string, args ...any) (T, error) {
		var zero T
		// Read the package-level variable once so the nil check and the call
		// use the same function (the read itself is unsynchronized).
		keyOf := CacheKeyFunc
		if keyOf == nil {
			return zero, errors.New("CacheKeyFunc is nil")
		}
		key, err := keyOf(stmt, query, args)
		if err != nil {
			return zero, err
		}
		cached, err := c.scopeCache.Get(ctx, key)
		// A miss (cache.ErrCacheNotFound) falls through to next; any other
		// lookup failure is returned as-is.
		if err != nil && !errors.Is(err, cache.ErrCacheNotFound) {
			return zero, err
		}
		// A cached value of a different type is treated like a miss.
		if hit, ok := cached.(T); ok {
			return hit, nil
		}
		fresh, err := next(ctx, query, args...)
		if err != nil {
			return fresh, err
		}
		return fresh, c.scopeCache.Set(ctx, key, fresh)
	}
}

// ExecContext implements GenericMiddleware.
//
// It returns next unchanged when the statement sets flushCache="false" or the
// cache is nil. Otherwise the handler runs next and, only if that succeeds,
// calls scopeCache.Flush. A flush error is returned together with next's
// result.
func (c *CacheMiddleware[T]) ExecContext(stmt Statement, next ExecHandler) ExecHandler {
	if stmt.Attribute("flushCache") == "false" || c.scopeCache == nil {
		return next
	}
	return func(ctx context.Context, query string, args ...any) (sql.Result, error) {
		res, err := next(ctx, query, args...)
		if err != nil {
			return res, err
		}
		return res, c.scopeCache.Flush(ctx)
	}
}