diff --git a/v2/client/client.go b/v2/client/client.go index ea8fbfbb..452304ff 100644 --- a/v2/client/client.go +++ b/v2/client/client.go @@ -98,6 +98,7 @@ type ceClient struct { eventDefaulterFns []EventDefaulter pollGoroutines int blockingCallback bool + ackMalformedEvent bool } func (c *ceClient) applyOptions(opts ...Option) error { @@ -202,7 +203,13 @@ func (c *ceClient) StartReceiver(ctx context.Context, fn interface{}) error { return fmt.Errorf("client already has a receiver") } - invoker, err := newReceiveInvoker(fn, c.observabilityService, c.inboundContextDecorators, c.eventDefaulterFns...) + invoker, err := newReceiveInvoker( + fn, + c.observabilityService, + c.inboundContextDecorators, + c.eventDefaulterFns, + c.ackMalformedEvent, + ) if err != nil { return err } diff --git a/v2/client/client_test.go b/v2/client/client_test.go index 2c4d70a3..39f705a3 100644 --- a/v2/client/client_test.go +++ b/v2/client/client_test.go @@ -23,6 +23,7 @@ import ( "github.com/google/go-cmp/cmp" + "github.com/cloudevents/sdk-go/v2/binding" "github.com/cloudevents/sdk-go/v2/client" "github.com/cloudevents/sdk-go/v2/event" "github.com/cloudevents/sdk-go/v2/protocol" @@ -399,6 +400,71 @@ func TestClientContext(t *testing.T) { wg.Wait() } +func TestClientStartReceiverWithAckMalformedEvent(t *testing.T) { + testCases := []struct { + name string + opts []client.Option + expectedAck bool + }{ + { + name: "without ack", + }, + { + name: "with ack", + opts: []client.Option{client.WithAckMalformedEvent()}, + expectedAck: true, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + // make sure the receiver goroutine is closed + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + receiver := &mockReceiver{ + finished: make(chan struct{}), + } + + // only need 1 goroutine to exercise this + tc.opts = append(tc.opts, client.WithPollGoroutines(1)) + + c, err := client.New(receiver, tc.opts...) + if err != nil { + t.Errorf("failed to construct client: %v", err) + } + + go c.StartReceiver(ctx, func(ctx context.Context, e event.Event) protocol.Result { + t.Error("receiver callback called unexpectedly") + return nil + }) + + // wait for receive to occur + time.Sleep(time.Millisecond) + + ctx, cancelTimeout := context.WithTimeout(ctx, time.Second) + defer cancelTimeout() + + select { + case <-receiver.finished: + // continue to rest of the test + case <-ctx.Done(): + t.Errorf("timeoued out waiting for receiver to complete") + } + + if tc.expectedAck { + if protocol.IsNACK(receiver.result) { + t.Errorf("receiver did not receive ACK: %v", receiver.result) + } + } else { + if protocol.IsACK(receiver.result) { + t.Errorf("receiver did not receive NACK: %v", receiver.result) + } + } + }) + } +} + type requestValidation struct { Host string Headers http.Header @@ -488,3 +554,38 @@ func isImportantHeader(h string) bool { } return true } + +type mockMessage struct{} + +func (m *mockMessage) ReadEncoding() binding.Encoding { + return binding.EncodingUnknown +} + +func (m *mockMessage) ReadStructured(ctx context.Context, writer binding.StructuredWriter) error { + return nil +} +func (m *mockMessage) ReadBinary(ctx context.Context, writer binding.BinaryWriter) error { return nil } +func (m *mockMessage) Finish(err error) error { return nil } + +type mockReceiver struct { + mu sync.Mutex + count int + result error + finished chan struct{} +} + +func (m *mockReceiver) Receive(ctx context.Context) (binding.Message, error) { + m.mu.Lock() + defer m.mu.Unlock() + + if m.count > 0 { + return nil, io.EOF + } + + m.count++ + + return binding.WithFinish(&mockMessage{}, func(err error) { + m.result = err + close(m.finished) + }), nil +} diff --git a/v2/client/http_receiver.go b/v2/client/http_receiver.go index 94a4b4e6..672581b5 100644 --- a/v2/client/http_receiver.go +++ b/v2/client/http_receiver.go @@ -14,7 +14,7 @@ import ( ) func NewHTTPReceiveHandler(ctx context.Context, p *thttp.Protocol, fn interface{}) (*EventReceiver, error) { - invoker, err := newReceiveInvoker(fn, noopObservabilityService{}, nil) //TODO(slinkydeveloper) maybe not nil? + invoker, err := newReceiveInvoker(fn, noopObservabilityService{}, nil, nil, false) //TODO(slinkydeveloper) maybe not nil? if err != nil { return nil, err } diff --git a/v2/client/invoker.go b/v2/client/invoker.go index 403fb0f5..a3080b00 100644 --- a/v2/client/invoker.go +++ b/v2/client/invoker.go @@ -23,11 +23,18 @@ type Invoker interface { var _ Invoker = (*receiveInvoker)(nil) -func newReceiveInvoker(fn interface{}, observabilityService ObservabilityService, inboundContextDecorators []func(context.Context, binding.Message) context.Context, fns ...EventDefaulter) (Invoker, error) { +func newReceiveInvoker( + fn interface{}, + observabilityService ObservabilityService, + inboundContextDecorators []func(context.Context, binding.Message) context.Context, + fns []EventDefaulter, + ackMalformedEvent bool, +) (Invoker, error) { r := &receiveInvoker{ eventDefaulterFns: fns, observabilityService: observabilityService, inboundContextDecorators: inboundContextDecorators, + ackMalformedEvent: ackMalformedEvent, } if fn, err := receiver(fn); err != nil { @@ -44,6 +51,7 @@ type receiveInvoker struct { observabilityService ObservabilityService eventDefaulterFns []EventDefaulter inboundContextDecorators []func(context.Context, binding.Message) context.Context + ackMalformedEvent bool } func (r *receiveInvoker) Invoke(ctx context.Context, m binding.Message, respFn protocol.ResponseFn) (err error) { @@ -58,13 +66,13 @@ func (r *receiveInvoker) Invoke(ctx context.Context, m binding.Message, respFn p switch { case eventErr != nil && r.fn.hasEventIn: r.observabilityService.RecordReceivedMalformedEvent(ctx, eventErr) - return respFn(ctx, nil, protocol.NewReceipt(false, "failed to convert Message to Event: %w", eventErr)) + return respFn(ctx, nil, protocol.NewReceipt(r.ackMalformedEvent, "failed to convert Message to Event: %w", eventErr)) case r.fn != nil: // Check if event is valid before invoking the receiver function if e != nil { if validationErr := e.Validate(); validationErr != nil { r.observabilityService.RecordReceivedMalformedEvent(ctx, validationErr) - return respFn(ctx, nil, protocol.NewReceipt(false, "validation error in incoming event: %w", validationErr)) + return respFn(ctx, nil, protocol.NewReceipt(r.ackMalformedEvent, "validation error in incoming event: %w", validationErr)) } } diff --git a/v2/client/options.go b/v2/client/options.go index 93847816..44394be3 100644 --- a/v2/client/options.go +++ b/v2/client/options.go @@ -126,3 +126,16 @@ func WithBlockingCallback() Option { return nil } } + +// WithAckMalformedevents causes malformed events received within StartReceiver to be acknowledged +// rather than being permanently not-acknowledged. This can be useful when a protocol does not +// provide a responder implementation and would otherwise cause the receiver to be partially or +// fully stuck. +func WithAckMalformedEvent() Option { + return func(i interface{}) error { + if c, ok := i.(*ceClient); ok { + c.ackMalformedEvent = true + } + return nil + } +} diff --git a/v2/client/options_test.go b/v2/client/options_test.go index d712496e..a8224221 100644 --- a/v2/client/options_test.go +++ b/v2/client/options_test.go @@ -136,3 +136,31 @@ func TestWith_Defaulters(t *testing.T) { }) } } + +func TestWithAckMalformedEvent(t *testing.T) { + testCases := []struct { + name string + opts []Option + expected bool + }{ + { + name: "unset", + }, + { + name: "set", + opts: []Option{WithAckMalformedEvent()}, + expected: true, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + client := &ceClient{} + client.applyOptions(tc.opts...) + + if client.ackMalformedEvent != tc.expected { + t.Errorf("unexpected ackMalformedEvent; want: %t; got: %t", tc.expected, client.ackMalformedEvent) + } + }) + } +}