-
Notifications
You must be signed in to change notification settings - Fork 0
/
recovery.go
145 lines (129 loc) · 4.18 KB
/
recovery.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
// Package recovery can be used to add panic recovery middleware to the server.
package recovery
import (
"context"
"fmt"
"net"
"net/http"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
"htdvisser.dev/exp/backbone/server"
"htdvisser.dev/exp/backbone/server/packet"
"htdvisser.dev/exp/backbone/server/stream"
)
// Middleware is middleware for panic recovery. It holds the two hooks used
// when a recovered panic has to be reported back to the caller.
type Middleware struct {
	// panicToError converts a value obtained from recover() into an error.
	panicToError func(context.Context, interface{}) error

	// errorToHTTPResponse writes an error produced by panicToError to an
	// HTTP response.
	errorToHTTPResponse func(http.ResponseWriter, *http.Request, error)
}
// Option is an option for the panic recovery middleware. Options are passed
// to NewMiddleware or Register.
type Option interface {
	apply(*Middleware)
}

// option adapts a plain function to the Option interface.
type option func(*Middleware)

// apply invokes the wrapped function on m.
func (fn option) apply(m *Middleware) { fn(m) }
// WithPanicToError returns an option that sets the function to convert panics
// to errors. The function receives the context of the failing call and the
// value returned by recover().
//
// A nil f is ignored and leaves the previously configured function in place.
// Otherwise the middleware would call a nil function while recovering, and
// that would panic again.
func WithPanicToError(f func(ctx context.Context, p interface{}) error) Option {
	return option(func(opts *Middleware) {
		if f != nil {
			opts.panicToError = f
		}
	})
}
// WithErrorToHTTPResponse returns an option that sets the function to write
// errors to HTTP responses. The function receives the response writer, the
// request whose handler panicked, and the error produced from the panic.
//
// A nil f is ignored and leaves the previously configured function in place.
// Otherwise the HTTP recovery would call a nil function and panic again.
func WithErrorToHTTPResponse(f func(w http.ResponseWriter, r *http.Request, err error)) Option {
	return option(func(opts *Middleware) {
		if f != nil {
			opts.errorToHTTPResponse = f
		}
	})
}
// NewMiddleware returns new middleware for panic recovery, configured by opts.
//
// Without options, the middleware:
//   - returns a recovered panic value unchanged when it is an error, and
//     otherwise wraps it in a gRPC status error with code Internal;
//   - answers an HTTP request whose handler panicked with status 500 and the
//     error text as the body.
//
// Nil entries in opts are skipped. The returned error is always nil in this
// implementation.
func NewMiddleware(opts ...Option) (*Middleware, error) {
	m := &Middleware{
		panicToError: func(_ context.Context, p interface{}) error {
			if err, ok := p.(error); ok {
				return err
			}
			// %v rather than %s: panic values are often not strings
			// (e.g. panic(42)), which %s would render as "%!s(int=42)".
			return status.Errorf(codes.Internal, "%v", p)
		},
		errorToHTTPResponse: func(w http.ResponseWriter, _ *http.Request, err error) {
			// fmt.Sprint is nil-safe, unlike err.Error(), in case a custom
			// panicToError returns nil.
			http.Error(w, fmt.Sprint(err), http.StatusInternalServerError)
		},
	}
	for _, opt := range opts {
		if opt == nil {
			continue
		}
		opt.apply(m)
	}
	return m, nil
}
// RecoverUnaryRPC recovers from panics in unary RPCs. It calls handler and,
// if handler panics, returns a nil response together with the panic value
// converted by the middleware's panicToError function.
func (m *Middleware) RecoverUnaryRPC(ctx context.Context, req interface{}, _ *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (resp interface{}, err error) {
	defer func() {
		r := recover()
		if r == nil {
			return
		}
		err = m.panicToError(ctx, r)
	}()
	resp, err = handler(ctx, req)
	return resp, err
}
// RecoverStreamingRPC recovers from panics in streaming RPCs. If handler
// panics, the panic value is converted by the middleware's panicToError
// function, using the stream's context, and returned as the RPC error.
func (m *Middleware) RecoverStreamingRPC(srv interface{}, ss grpc.ServerStream, _ *grpc.StreamServerInfo, handler grpc.StreamHandler) (err error) {
	defer func() {
		r := recover()
		if r == nil {
			return
		}
		err = m.panicToError(ss.Context(), r)
	}()
	err = handler(srv, ss)
	return err
}
// RecoverHTTP recovers from panics in HTTP handlers. A panic in next is
// converted with the middleware's panicToError function (using the request
// context) and written with its errorToHTTPResponse function.
//
// A panic with http.ErrAbortHandler is re-raised instead of handled. net/http
// uses that value as a sentinel to abort the response; for example,
// httputil.ReverseProxy panics with it when copying a response body fails
// mid-stream. Writing an error response at that point would corrupt the
// partially sent response instead of aborting it.
func (m *Middleware) RecoverHTTP(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		defer func() {
			p := recover()
			if p == nil {
				return
			}
			if p == http.ErrAbortHandler {
				panic(p)
			}
			err := m.panicToError(r.Context(), p)
			m.errorToHTTPResponse(w, r, err)
		}()
		next.ServeHTTP(w, r)
	})
}
// RecoverStream recovers from panics in stream handlers. The returned handler
// calls next and, if it panics, returns the panic value converted by the
// middleware's panicToError function.
func (m *Middleware) RecoverStream(next stream.HandlerFunc) stream.HandlerFunc {
	guarded := func(ctx context.Context, conn net.Conn) (err error) {
		defer func() {
			if r := recover(); r != nil {
				err = m.panicToError(ctx, r)
			}
		}()
		err = next.HandleStream(ctx, conn)
		return err
	}
	return stream.HandlerFunc(guarded)
}
// RecoverPacket recovers from panics in packet handlers. The returned handler
// calls next and, if it panics, returns the panic value converted by the
// middleware's panicToError function.
func (m *Middleware) RecoverPacket(next packet.HandlerFunc) packet.HandlerFunc {
	guarded := func(ctx context.Context, pkt []byte, addr net.Addr, reply func([]byte) error) (err error) {
		defer func() {
			if r := recover(); r != nil {
				err = m.panicToError(ctx, r)
			}
		}()
		err = next.HandlePacket(ctx, pkt, addr, reply)
		return err
	}
	return packet.HandlerFunc(guarded)
}
// Register registers the panic recovery to the server s:
//   - the unary and streaming interceptors on s.GRPC and s.InternalGRPC;
//   - the HTTP middleware on s.HTTP and s.InternalHTTP.
//
// It always returns nil.
func (m *Middleware) Register(s *server.Server) error {
	unary, streaming, web := m.RecoverUnaryRPC, m.RecoverStreamingRPC, m.RecoverHTTP

	s.GRPC.AddUnaryInterceptor(unary)
	s.GRPC.AddStreamInterceptor(streaming)
	s.HTTP.AddMiddleware(web)

	s.InternalGRPC.AddUnaryInterceptor(unary)
	s.InternalGRPC.AddStreamInterceptor(streaming)
	s.InternalHTTP.AddMiddleware(web)

	return nil
}
// Register creates panic recovery middleware configured by opts and registers
// it to the server s. It returns any error from creating or registering the
// middleware.
func Register(s *server.Server, opts ...Option) error {
	mw, err := NewMiddleware(opts...)
	if err != nil {
		return err
	}
	return mw.Register(s)
}