-
-
Notifications
You must be signed in to change notification settings - Fork 1.6k
/
idempotency.go
159 lines (129 loc) · 4 KB
/
idempotency.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
package idempotency
import (
"fmt"
"strings"
"github.com/gofiber/fiber/v3"
"github.com/gofiber/fiber/v3/log"
"github.com/gofiber/utils/v2"
)
// Inspired by https://datatracker.ietf.org/doc/html/draft-ietf-httpapi-idempotency-key-header-02
// and https://github.com/penguin-statistics/backend-next/blob/f2f7d5ba54fc8a58f168d153baa17b2ad4a14e45/internal/pkg/middlewares/idempotency.go

// contextKey is the type of the keys this package stores in fiber.Ctx.Locals.
// It is unexported so that keys defined in other packages cannot collide with
// these, even when they share the same underlying integer value.
type contextKey int

// Locals keys recording how the idempotency middleware handled a request.
const (
	// localsKeyIsFromCache is set once a stored response has been replayed.
	localsKeyIsFromCache contextKey = 0
	// localsKeyWasPutToCache is set once a freshly produced response has been saved.
	localsKeyWasPutToCache contextKey = 1
)
// IsFromCache reports whether the response for c was replayed from the
// idempotency storage instead of being produced by the downstream handlers.
func IsFromCache(c fiber.Ctx) bool {
	flag := c.Locals(localsKeyIsFromCache)
	return flag != nil
}
// WasPutToCache reports whether the response for c was successfully saved to
// the idempotency storage after the downstream handlers ran.
func WasPutToCache(c fiber.Ctx) bool {
	flag := c.Locals(localsKeyWasPutToCache)
	return flag != nil
}
// New creates the idempotency middleware.
//
// A request bypasses the middleware when cfg.Next returns true or when the
// cfg.KeyHeader request header is empty. Otherwise the key is validated with
// cfg.KeyHeaderValidate. If a response for the key is already in cfg.Storage,
// its status, headers and body are replayed and the downstream handlers are
// not run.
//
// If no stored response exists, the key is locked through cfg.Lock and the
// storage is checked again. The downstream handlers then run. When they
// return no error, the resulting status, body and (optionally filtered)
// headers are marshaled and saved under the key for cfg.Lifetime.
func New(config ...Config) fiber.Handler {
	cfg := configDefault(config...)

	// Lower-cased allow-list of response headers to persist. It is only
	// consulted when cfg.KeepResponseHeaders is non-nil.
	keepHeaders := make(map[string]struct{}, len(cfg.KeepResponseHeaders))
	for _, name := range cfg.KeepResponseHeaders {
		keepHeaders[strings.ToLower(name)] = struct{}{}
	}

	// replayCached writes the stored response for key (if any) to c.
	// The bool result reports whether a stored response was found.
	replayCached := func(c fiber.Ctx, key string) (bool, error) {
		raw, err := cfg.Storage.Get(key)
		if err != nil {
			return false, fmt.Errorf("failed to read response: %w", err)
		}
		if raw == nil {
			return false, nil
		}

		var cached response
		if _, err := cached.UnmarshalMsg(raw); err != nil {
			return false, fmt.Errorf("failed to unmarshal response: %w", err)
		}

		_ = c.Status(cached.StatusCode)
		for name, values := range cached.Headers {
			for _, v := range values {
				c.Context().Response.Header.Add(name, v)
			}
		}
		if len(cached.Body) > 0 {
			if err := c.Send(cached.Body); err != nil {
				return true, err
			}
		}

		_ = c.Locals(localsKeyIsFromCache, true)
		return true, nil
	}

	// selectHeaders returns the subset of response headers to persist. A nil
	// cfg.KeepResponseHeaders means "keep everything"; a non-nil one (even an
	// empty one) acts as a case-insensitive allow-list.
	selectHeaders := func(all map[string][]string) map[string][]string {
		if cfg.KeepResponseHeaders == nil {
			return all
		}
		kept := make(map[string][]string)
		for name, values := range all {
			if _, ok := keepHeaders[utils.ToLower(name)]; ok {
				kept[name] = values
			}
		}
		return kept
	}

	return func(c fiber.Ctx) error {
		if cfg.Next != nil && cfg.Next(c) {
			return c.Next()
		}

		// Copied because the key is handed on to the lock and the storage.
		key := utils.CopyString(c.Get(cfg.KeyHeader))
		if key == "" {
			return c.Next()
		}
		if err := cfg.KeyHeaderValidate(key); err != nil {
			return err
		}

		// Fast path: answer from storage without taking the lock.
		hit, err := replayCached(c, key)
		if err != nil {
			return fmt.Errorf("failed to write cached response at fastpath: %w", err)
		}
		if hit {
			return nil
		}

		if err := cfg.Lock.Lock(key); err != nil {
			return fmt.Errorf("failed to lock: %w", err)
		}
		defer func() {
			if err := cfg.Lock.Unlock(key); err != nil {
				log.Errorf("[IDEMPOTENCY] failed to unlock key %q: %v", key, err)
			}
		}()

		// Another request carrying the same key may have stored its response
		// while this one was waiting for the lock.
		hit, err = replayCached(c, key)
		if err != nil {
			return fmt.Errorf("failed to write cached response while locked: %w", err)
		}
		if hit {
			return nil
		}

		// A handler error is propagated as-is and nothing is stored.
		if err := c.Next(); err != nil {
			return err
		}

		// Capture status and body first, then the headers.
		status := c.Response().StatusCode()
		body := utils.CopyBytes(c.Response().Body())

		headers := make(map[string][]string)
		if err := c.Bind().RespHeader(headers); err != nil {
			return fmt.Errorf("failed to bind to response headers: %w", err)
		}

		res := &response{
			StatusCode: status,
			Body:       body,
			Headers:    selectHeaders(headers),
		}

		encoded, err := res.MarshalMsg(nil)
		if err != nil {
			return fmt.Errorf("failed to marshal response: %w", err)
		}
		if err := cfg.Storage.Set(key, encoded, cfg.Lifetime); err != nil {
			return fmt.Errorf("failed to save response: %w", err)
		}

		_ = c.Locals(localsKeyWasPutToCache, true)
		return nil
	}
}