diff --git a/share/p2p/recovery.go b/share/p2p/recovery.go new file mode 100644 index 0000000000..b214969399 --- /dev/null +++ b/share/p2p/recovery.go @@ -0,0 +1,21 @@ +package p2p + +import ( + "fmt" + + "github.com/libp2p/go-libp2p/core/network" +) + +// RecoveryMiddleware is a middleware that recovers from panics in the handler. +func RecoveryMiddleware(handler network.StreamHandler) network.StreamHandler { + return func(stream network.Stream) { + defer func() { + r := recover() + if r != nil { + err := fmt.Errorf("PANIC while handling request: %s", r) + log.Error(err) + } + }() + handler(stream) + } +} diff --git a/share/p2p/shrexeds/server.go b/share/p2p/shrexeds/server.go index 11b99a3438..15d67d2111 100644 --- a/share/p2p/shrexeds/server.go +++ b/share/p2p/shrexeds/server.go @@ -52,7 +52,10 @@ func NewServer(params *Parameters, host host.Host, store *eds.Store) (*Server, e func (s *Server) Start(context.Context) error { s.ctx, s.cancel = context.WithCancel(context.Background()) - s.host.SetStreamHandler(s.protocolID, s.middleware.RateLimitHandler(s.handleStream)) + handler := s.handleStream + withRateLimit := s.middleware.RateLimitHandler(handler) + withRecovery := p2p.RecoveryMiddleware(withRateLimit) + s.host.SetStreamHandler(s.protocolID, withRecovery) return nil } diff --git a/share/p2p/shrexnd/server.go b/share/p2p/shrexnd/server.go index 33e61ff472..4d540c20ba 100644 --- a/share/p2p/shrexnd/server.go +++ b/share/p2p/shrexnd/server.go @@ -54,7 +54,10 @@ func NewServer(params *Parameters, host host.Host, store *eds.Store) (*Server, e ctx, cancel := context.WithCancel(context.Background()) srv.cancel = cancel - srv.handler = srv.middleware.RateLimitHandler(srv.streamHandler(ctx)) + handler := srv.streamHandler(ctx) + withRateLimit := srv.middleware.RateLimitHandler(handler) + withRecovery := p2p.RecoveryMiddleware(withRateLimit) + srv.handler = withRecovery return srv, nil }