This repository has been archived by the owner on Jun 6, 2022. It is now read-only.
generated from datewu/project-lib
/
middleware.go
127 lines (115 loc) · 3.32 KB
/
middleware.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
package toushi
import (
"expvar"
"net"
"net/http"
"sync"
"time"
"golang.org/x/time/rate"
)
// Middleware wraps an http.HandlerFunc and returns a new http.HandlerFunc,
// letting extra behavior run before and/or after the wrapped handler.
type Middleware func(http.HandlerFunc) http.HandlerFunc
// enabledCORS returns middleware that handles cross-origin requests for the
// origins listed in ro.config.CORS.TrustedOrigins.
//
// Every response gets "Vary: Origin" and "Vary: Access-Control-Request-Method"
// so shared caches keep per-origin variants apart. When the request's Origin
// header exactly matches a trusted origin, Access-Control-Allow-Origin echoes
// it back. A matching OPTIONS request that carries a non-empty
// Access-Control-Request-Method header is treated as a preflight and answered
// here (allowed methods/headers plus 200 OK) without calling next; every other
// request is passed on to next.
func (ro *router) enabledCORS(next http.HandlerFunc) http.HandlerFunc {
	return func(w http.ResponseWriter, r *http.Request) {
		header := w.Header()
		header.Add("Vary", "Origin")
		header.Add("Vary", "Access-Control-Request-Method")

		origin := r.Header.Get("Origin")
		trusted := false
		if origin != "" {
			for _, candidate := range ro.config.CORS.TrustedOrigins {
				if candidate == origin {
					trusted = true
					break
				}
			}
		}
		if !trusted {
			next(w, r)
			return
		}

		header.Set("Access-Control-Allow-Origin", origin)

		preflight := r.Method == http.MethodOptions &&
			r.Header.Get("Access-Control-Request-Method") != ""
		if !preflight {
			next(w, r)
			return
		}

		// Preflight: reply directly; the real request will follow separately.
		header.Set("Access-Control-Allow-Methods", "OPTIONS, PUT, PATCH, DELETE")
		header.Set("Access-Control-Allow-Headers", "Authorization, Content-Type")
		w.WriteHeader(http.StatusOK)
	}
}
// rateLimit returns middleware that applies a per-client-IP token-bucket
// limit built from ro.config.Limiter.Rps and ro.config.Limiter.Burst.
//
// The client key is the host part of r.RemoteAddr; if it cannot be split,
// the request is answered with HandleHelper.ServerErr. Each IP gets its own
// *rate.Limiter, created on first sight. Requests the limiter refuses are
// answered with HandleHelper.RateLimitExceede and never reach next.
//
// A background goroutine wakes every sweepEvery and evicts clients whose
// last request is older than staleAfter, bounding the map's growth.
//
// NOTE(review): that goroutine is started on every call to rateLimit and
// never exits (there is no context to stop it), so wrap handlers with this
// middleware once at startup, not per request.
func (ro *router) rateLimit(next http.HandlerFunc) http.HandlerFunc {
	const (
		sweepEvery = time.Minute
		staleAfter = 3 * time.Minute
	)

	type client struct {
		limiter  *rate.Limiter
		lastSeen time.Time
	}

	var (
		mu      sync.Mutex // guards clients
		clients = make(map[string]*client)
	)

	go func() {
		// The ticker lives for the life of the process, matching the goroutine.
		ticker := time.NewTicker(sweepEvery)
		for range ticker.C {
			mu.Lock()
			for ip, c := range clients {
				if time.Since(c.lastSeen) > staleAfter {
					delete(clients, ip)
				}
			}
			mu.Unlock()
		}
	}()

	return func(w http.ResponseWriter, r *http.Request) {
		h := NewHandleHelper(w, r)
		ip, _, err := net.SplitHostPort(r.RemoteAddr)
		if err != nil {
			h.ServerErr(err)
			return
		}

		mu.Lock()
		c, ok := clients[ip] // single lookup; reuse the pointer below
		if !ok {
			c = &client{
				limiter: rate.NewLimiter(rate.Limit(ro.config.Limiter.Rps),
					ro.config.Limiter.Burst),
			}
			clients[ip] = c
		}
		c.lastSeen = time.Now()
		allowed := c.limiter.Allow()
		mu.Unlock()

		if !allowed {
			h.RateLimitExceede()
			return
		}
		next(w, r)
	}
}
// metrics returns middleware that records request counters in expvar:
//   - total_requests_received: incremented before next runs
//   - total_responses_send: incremented after next returns
//   - total_processing_time_us: cumulative time spent in next, in microseconds
//
// If next panics, the two post-call counters are not updated for that request.
//
// expvar.NewInt panics when a name is already published, so an existing
// *expvar.Int is reused instead. Because of that reuse, metrics can wrap more
// than one handler, and all wrapped handlers share the same counters.
func (ro *router) metrics(next http.HandlerFunc) http.HandlerFunc {
	publishedInt := func(name string) *expvar.Int {
		if v, ok := expvar.Get(name).(*expvar.Int); ok {
			return v
		}
		// NOTE(review): two goroutines constructing this middleware at the
		// same moment could both reach NewInt; middleware is presumably
		// built once at startup, where this cannot happen.
		return expvar.NewInt(name)
	}

	totalRequestsReceived := publishedInt("total_requests_received")
	totalResponsesSent := publishedInt("total_responses_send")
	totalProcessingTimeMicroseconds := publishedInt("total_processing_time_us")

	return func(w http.ResponseWriter, r *http.Request) {
		start := time.Now()
		totalRequestsReceived.Add(1)
		next(w, r)
		totalResponsesSent.Add(1)
		totalProcessingTimeMicroseconds.Add(time.Since(start).Microseconds())
	}
}
// recoverPanic returns middleware that recovers a panic raised by next and
// answers with a 500 JSON response of the form {"recover": <panic value>},
// setting "Connection: close" on the response.
//
// Panic values that are errors are converted to their Error() string first.
// Many error types have no exported fields, so JSON-encoding the error value
// itself typically yields an empty object and loses the message. A panic
// carrying the http.ErrAbortHandler sentinel is re-raised, so net/http aborts
// the response as that sentinel is meant to do.
//
// NOTE(review): the panic value is sent to the client verbatim, which may
// leak internal details. Also, if next already wrote a status before
// panicking, the 500 cannot take effect; confirm WriteJSON's behavior there.
func recoverPanic(next http.HandlerFunc) http.HandlerFunc {
	return func(w http.ResponseWriter, r *http.Request) {
		defer func() {
			rec := recover()
			if rec == nil {
				return
			}
			if rec == http.ErrAbortHandler {
				panic(rec)
			}
			if e, ok := rec.(error); ok {
				rec = e.Error()
			}
			w.Header().Set("Connection", "close")
			WriteJSON(w, http.StatusInternalServerError, Envelope{"recover": rec}, nil)
		}()
		next(w, r)
	}
}