Skip to content

Commit

Permalink
Propery handle nil panics (#281)
Browse files Browse the repository at this point in the history
* Add test case for exhibiting nil panic behavior.

* Implement proper handling of nil panics.
  • Loading branch information
misberner committed Mar 31, 2020
1 parent 3ce3d51 commit 4705cb3
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 10 deletions.
18 changes: 13 additions & 5 deletions recovery/interceptors.go
Expand Up @@ -22,33 +22,41 @@ type RecoveryHandlerFuncContext func(ctx context.Context, p interface{}) (err er
func UnaryServerInterceptor(opts ...Option) grpc.UnaryServerInterceptor {
o := evaluateOptions(opts)
return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (_ interface{}, err error) {
panicked := true

defer func() {
if r := recover(); r != nil {
if r := recover(); r != nil || panicked {
err = recoverFrom(ctx, r, o.recoveryHandlerFunc)
}
}()

return handler(ctx, req)
resp, err := handler(ctx, req)
panicked = false
return resp, err
}
}

// StreamServerInterceptor returns a new streaming server interceptor for panic recovery.
func StreamServerInterceptor(opts ...Option) grpc.StreamServerInterceptor {
o := evaluateOptions(opts)
return func(srv interface{}, stream grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) (err error) {
panicked := true

defer func() {
if r := recover(); r != nil {
if r := recover(); r != nil || panicked {
err = recoverFrom(stream.Context(), r, o.recoveryHandlerFunc)
}
}()

return handler(srv, stream)
err = handler(srv, stream)
panicked = false
return err
}
}

func recoverFrom(ctx context.Context, p interface{}, r RecoveryHandlerFuncContext) error {
if r == nil {
return status.Errorf(codes.Internal, "%s", p)
return status.Errorf(codes.Internal, "%v", p)
}
return r(ctx, p)
}
33 changes: 28 additions & 5 deletions recovery/interceptors_test.go
Expand Up @@ -7,9 +7,9 @@ import (
"context"
"testing"

"github.com/grpc-ecosystem/go-grpc-middleware"
"github.com/grpc-ecosystem/go-grpc-middleware/recovery"
"github.com/grpc-ecosystem/go-grpc-middleware/testing"
grpc_middleware "github.com/grpc-ecosystem/go-grpc-middleware"
grpc_recovery "github.com/grpc-ecosystem/go-grpc-middleware/recovery"
grpc_testing "github.com/grpc-ecosystem/go-grpc-middleware/testing"
pb_testproto "github.com/grpc-ecosystem/go-grpc-middleware/testing/testproto"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
Expand All @@ -20,8 +20,9 @@ import (
)

var (
goodPing = &pb_testproto.PingRequest{Value: "something", SleepTimeMs: 9999}
panicPing = &pb_testproto.PingRequest{Value: "panic", SleepTimeMs: 9999}
goodPing = &pb_testproto.PingRequest{Value: "something", SleepTimeMs: 9999}
panicPing = &pb_testproto.PingRequest{Value: "panic", SleepTimeMs: 9999}
nilPanicPing = &pb_testproto.PingRequest{Value: "nilpanic", SleepTimeMs: 9999}
)

type recoveryAssertService struct {
Expand All @@ -32,13 +33,19 @@ func (s *recoveryAssertService) Ping(ctx context.Context, ping *pb_testproto.Pin
if ping.Value == "panic" {
panic("very bad thing happened")
}
if ping.Value == "nilpanic" {
panic(nil)
}
return s.TestServiceServer.Ping(ctx, ping)
}

func (s *recoveryAssertService) PingList(ping *pb_testproto.PingRequest, stream pb_testproto.TestService_PingListServer) error {
if ping.Value == "panic" {
panic("very bad thing happened")
}
if ping.Value == "nilpanic" {
panic(nil)
}
return s.TestServiceServer.PingList(ping, stream)
}

Expand Down Expand Up @@ -73,6 +80,13 @@ func (s *RecoverySuite) TestUnary_PanickingRequest() {
assert.Equal(s.T(), "very bad thing happened", status.Convert(err).Message(), "must error with message")
}

func (s *RecoverySuite) TestUnary_NilPanickingRequest() {
_, err := s.Client.Ping(s.SimpleCtx(), nilPanicPing)
require.Error(s.T(), err, "there must be an error")
assert.Equal(s.T(), codes.Internal, status.Code(err), "must error with internal")
assert.Equal(s.T(), "<nil>", status.Convert(err).Message(), "must error with <nil>")
}

func (s *RecoverySuite) TestStream_SuccessfulReceive() {
stream, err := s.Client.PingList(s.SimpleCtx(), goodPing)
require.NoError(s.T(), err, "should not fail on establishing the stream")
Expand All @@ -90,6 +104,15 @@ func (s *RecoverySuite) TestStream_PanickingReceive() {
assert.Equal(s.T(), "very bad thing happened", status.Convert(err).Message(), "must error with message")
}

func (s *RecoverySuite) TestStream_NilPanickingReceive() {
stream, err := s.Client.PingList(s.SimpleCtx(), nilPanicPing)
require.NoError(s.T(), err, "should not fail on establishing the stream")
_, err = stream.Recv()
require.Error(s.T(), err, "there must be an error")
assert.Equal(s.T(), codes.Internal, status.Code(err), "must error with internal")
assert.Equal(s.T(), "<nil>", status.Convert(err).Message(), "must error with <nil>")
}

func TestRecoveryOverrideSuite(t *testing.T) {
opts := []grpc_recovery.Option{
grpc_recovery.WithRecoveryHandler(func(p interface{}) (err error) {
Expand Down

0 comments on commit 4705cb3

Please sign in to comment.