Skip to content

Commit

Permalink
fix: Do not wrap io.EOF intercepted by stream Sends (#37647)
Browse files Browse the repository at this point in the history
* Verify that intercepted stream Sends wrap io.EOF

* fix: Do not wrap io.EOF intercepted by stream Sends

* Use a helper func, fix duplicate Send/Recv calls

* Fix typo
  • Loading branch information
codingllama committed Feb 1, 2024
1 parent da9272e commit c8f1187
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 14 deletions.
18 changes: 10 additions & 8 deletions api/utils/grpc/interceptors/errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,22 +50,24 @@ type grpcClientStreamWrapper struct {

// SendMsg wraps around ClientStream.SendMsg
func (s *grpcClientStreamWrapper) SendMsg(m interface{}) error {
if err := s.ClientStream.SendMsg(m); err != nil {
return &RemoteError{Err: trace.Unwrap(trail.FromGRPC(s.ClientStream.SendMsg(m)))}
}
return nil
return wrapStreamErr(s.ClientStream.SendMsg(m))
}

// RecvMsg wraps around ClientStream.RecvMsg
func (s *grpcClientStreamWrapper) RecvMsg(m interface{}) error {
switch err := s.ClientStream.RecvMsg(m); {
return wrapStreamErr(s.ClientStream.RecvMsg(m))
}

func wrapStreamErr(err error) error {
switch {
case err == nil:
return nil
case errors.Is(err, io.EOF):
// Do not wrap io.EOF errors, they are often used as stop guards for streams.
return err
case err != nil:
return &RemoteError{Err: trace.Unwrap(trail.FromGRPC(s.ClientStream.RecvMsg(m)))}
default:
return &RemoteError{Err: trace.Unwrap(trail.FromGRPC(err))}
}
return nil
}

// GRPCServerUnaryErrorInterceptor is a gRPC unary server interceptor that
Expand Down
17 changes: 11 additions & 6 deletions api/utils/grpc/interceptors/errors_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,10 @@ package interceptors

import (
"context"
"errors"
"io"
"net"
"testing"
"time"

"github.com/gravitational/trace"
"github.com/stretchr/testify/assert"
Expand Down Expand Up @@ -87,14 +87,19 @@ func TestGRPCErrorWrapping(t *testing.T) {
stream, err := client.AddMFADevice(context.Background())
require.NoError(t, err)

// Give the server time to close the stream. This allows us to more
// consistently hit the io.EOF error.
time.Sleep(100 * time.Millisecond)

//nolint:staticcheck // SA1019. The specific stream used here doesn't matter.
sendErr := stream.Send(&proto.AddMFADeviceRequest{})

// io.EOF means the server closed the stream, which can
// happen depending in timing. In either case, it is
// still safe to recv from the stream and check for
// Expect either a success (unlikely because of the Sleep) or an unwrapped
// io.EOF error (meaning the server errored and closed the stream).
// In either case, it is still safe to recv from the stream and check for
// the already exists error.
if sendErr != nil && !errors.Is(sendErr, io.EOF) {
t.Fatalf("Unexpected error: %v", sendErr)
if sendErr != nil && sendErr != io.EOF /* == error comparison on purpose! */ {
t.Fatalf("Unexpected error: %q (%T)", sendErr, sendErr)
}

_, err = stream.Recv()
Expand Down

0 comments on commit c8f1187

Please sign in to comment.