forked from zhufuyi/sponge
/
breaker.go
88 lines (74 loc) · 2.16 KB
/
breaker.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
package middleware
import (
"net/http"
"github.com/github-tree/sponge/pkg/container/group"
"github.com/github-tree/sponge/pkg/gin/response"
"github.com/github-tree/sponge/pkg/shield/circuitbreaker"
"github.com/gin-gonic/gin"
)
var (
	// ErrNotAllowed re-exports circuitbreaker.ErrNotAllowed so it is reachable
	// from this package; presumably it is the error a breaker's Allow reports
	// when it rejects a request (confirm against the circuitbreaker package).
	ErrNotAllowed = circuitbreaker.ErrNotAllowed
)
type (
	// CircuitBreakerOption customizes the configuration used by the
	// CircuitBreaker middleware; any number of them may be passed to it.
	CircuitBreakerOption func(*circuitBreakerOptions)

	// circuitBreakerOptions holds the CircuitBreaker middleware configuration.
	circuitBreakerOptions struct {
		// group hands out one breaker per route key (c.FullPath()); every value
		// it produces must implement circuitbreaker.CircuitBreaker.
		group *group.Group
		// validCodes is the set of response status codes recorded as breaker
		// failures. The defaults are 500 and 503; WithValidCode adds more.
		validCodes map[int]struct{}
	}
)
// defaultCircuitBreakerOptions builds a fresh configuration. It has a group
// whose constructor creates breakers with circuitbreaker.NewBreaker(), and a
// failure-status set containing 500 (Internal Server Error) and 503 (Service
// Unavailable). A new map is allocated on every call, so options applied to one
// middleware instance never leak into another.
func defaultCircuitBreakerOptions() *circuitBreakerOptions {
	failureCodes := make(map[int]struct{}, 2)
	failureCodes[http.StatusInternalServerError] = struct{}{}
	failureCodes[http.StatusServiceUnavailable] = struct{}{}

	newBreaker := func() interface{} {
		return circuitbreaker.NewBreaker()
	}

	return &circuitBreakerOptions{
		group:      group.NewGroup(newBreaker),
		validCodes: failureCodes,
	}
}
// apply invokes each option against o, in the order they were given, so later
// options can override earlier ones.
func (o *circuitBreakerOptions) apply(opts ...CircuitBreakerOption) {
	for i := range opts {
		opts[i](o)
	}
}
// WithGroup replaces the group that supplies one breaker per route key.
// A nil g is ignored and the default group is kept.
//
// NOTE: every value the group produces must implement
// circuitbreaker.CircuitBreaker. The middleware type-asserts each value
// without a comma-ok check, so any other type panics at request time.
func WithGroup(g *group.Group) CircuitBreakerOption {
	return func(o *circuitBreakerOptions) {
		if g == nil {
			return
		}
		o.group = g
	}
}
// WithValidCode adds HTTP status codes that the middleware records as breaker
// failures when a handler responds with them. The codes are added to the
// defaults (500 and 503); they do not replace them.
func WithValidCode(code ...int) CircuitBreakerOption {
	return func(o *circuitBreakerOptions) {
		for i := 0; i < len(code); i++ {
			o.validCodes[code[i]] = struct{}{}
		}
	}
}
// CircuitBreaker returns a gin middleware that guards each route with its own
// circuit breaker, taken from the configured group and keyed by c.FullPath().
//
// Before the rest of the chain runs, the breaker is asked whether the request
// may proceed. If it refuses, the rejection is itself recorded as a failure.
// A 503 response carrying the breaker's error text is then written and the
// chain is aborted. Otherwise the remaining handlers run and the final response
// status decides the outcome. A status in the failure set (500 and 503 by
// default, extended with WithValidCode) is recorded as a failure; anything else
// is recorded as a success. A panic raised further down the chain is also
// recorded as a failure and then re-raised, so an outer recovery middleware
// still handles it.
//
// NOTE(review): per gin's documentation FullPath returns "" for requests that
// matched no route, so all such requests would share a single breaker.
func CircuitBreaker(opts ...CircuitBreakerOption) gin.HandlerFunc {
	o := defaultCircuitBreakerOptions()
	o.apply(opts...)

	return func(c *gin.Context) {
		// The group must yield circuitbreaker.CircuitBreaker values; any other
		// type makes this assertion panic.
		breaker := o.group.Get(c.FullPath()).(circuitbreaker.CircuitBreaker)

		if err := breaker.Allow(); err != nil {
			// NOTE: when the request is rejected locally, keep counting it as a
			// failure so the drop ratio keeps rising.
			breaker.MarkFailed()
			response.Output(c, http.StatusServiceUnavailable, err.Error())
			c.Abort()
			return
		}

		// A panicking handler never returns through c.Next(), so the status
		// check below would be skipped. An upstream recovery middleware
		// typically turns that panic into a 500, which is a failure status.
		// Record the failure, then re-panic so that recovery still runs.
		defer func() {
			if r := recover(); r != nil {
				breaker.MarkFailed()
				panic(r)
			}
		}()

		c.Next()

		// Statuses in the configured failure set count against the breaker.
		if _, failed := o.validCodes[c.Writer.Status()]; failed {
			breaker.MarkFailed()
			return
		}
		breaker.MarkSuccess()
	}
}