From ae78fa09c6a0e7c2276841afcdc62ec73c3f3e97 Mon Sep 17 00:00:00 2001 From: Andrew Kroh Date: Sun, 13 Nov 2022 14:53:20 -0500 Subject: [PATCH] aws-s3 - create beat.Client for each SQS worker To address mutex contention in the single beat.Client used for publishing S3 events, create a unique beat.Client for each worker goroutine that is processing an SQS message. The beat.Client is used for all S3 objects contained within the SQS message (they are processed serially). After all events are ACKed the beat.Client is closed. --- CHANGELOG.next.asciidoc | 1 + x-pack/filebeat/input/awss3/input.go | 31 ++++++----- .../input/awss3/input_benchmark_test.go | 49 ++++++++++++----- x-pack/filebeat/input/awss3/interfaces.go | 5 +- .../input/awss3/mock_interfaces_test.go | 9 +-- .../input/awss3/mock_publisher_test.go | 55 ++++++++++++++++++- x-pack/filebeat/input/awss3/s3.go | 9 ++- x-pack/filebeat/input/awss3/s3_objects.go | 8 +-- .../filebeat/input/awss3/s3_objects_test.go | 16 +++--- x-pack/filebeat/input/awss3/s3_test.go | 8 +-- x-pack/filebeat/input/awss3/sqs_s3_event.go | 29 +++++++++- .../filebeat/input/awss3/sqs_s3_event_test.go | 46 +++++++++++----- 12 files changed, 198 insertions(+), 68 deletions(-) diff --git a/CHANGELOG.next.asciidoc b/CHANGELOG.next.asciidoc index 42685e384ec..cd3903fce5b 100644 --- a/CHANGELOG.next.asciidoc +++ b/CHANGELOG.next.asciidoc @@ -156,6 +156,7 @@ https://github.com/elastic/beats/compare/v8.2.0\...main[Check the HEAD diff] - Improve httpjson documentation for split processor. {pull}33473[33473] - Added separation of transform context object inside httpjson. Introduced new clause `.parent_last_response.*` {pull}33499[33499] - Cloud Foundry input uses server-side filtering when retrieving logs. {pull}33456[33456] +- Modified `aws-s3` input to reduce mutex contention when multiple SQS messages are being processed concurrently. 
{pull}33658[33658] *Auditbeat* diff --git a/x-pack/filebeat/input/awss3/input.go b/x-pack/filebeat/input/awss3/input.go index e93ff421693..4533d3144bf 100644 --- a/x-pack/filebeat/input/awss3/input.go +++ b/x-pack/filebeat/input/awss3/input.go @@ -108,16 +108,6 @@ func (in *s3Input) Run(inputContext v2.Context, pipeline beat.Pipeline) error { }() defer cancelInputCtx() - // Create client for publishing events and receive notification of their ACKs. - client, err := pipeline.ConnectWith(beat.ClientConfig{ - CloseRef: inputContext.Cancelation, - ACKHandler: awscommon.NewEventACKHandler(), - }) - if err != nil { - return fmt.Errorf("failed to create pipeline client: %w", err) - } - defer client.Close() - if in.config.QueueURL != "" { regionName, err := getRegionFromQueueURL(in.config.QueueURL, in.config.AWSConfig.Endpoint) if err != nil { @@ -127,7 +117,7 @@ func (in *s3Input) Run(inputContext v2.Context, pipeline beat.Pipeline) error { in.awsConfig.Region = regionName // Create SQS receiver and S3 notification processor. - receiver, err := in.createSQSReceiver(inputContext, client) + receiver, err := in.createSQSReceiver(inputContext, pipeline) if err != nil { return fmt.Errorf("failed to initialize sqs receiver: %w", err) } @@ -139,6 +129,16 @@ func (in *s3Input) Run(inputContext v2.Context, pipeline beat.Pipeline) error { } if in.config.BucketARN != "" || in.config.NonAWSBucketName != "" { + // Create client for publishing events and receive notification of their ACKs. + client, err := pipeline.ConnectWith(beat.ClientConfig{ + CloseRef: inputContext.Cancelation, + ACKHandler: awscommon.NewEventACKHandler(), + }) + if err != nil { + return fmt.Errorf("failed to create pipeline client: %w", err) + } + defer client.Close() + // Create S3 receiver and S3 notification processor. 
poller, err := in.createS3Lister(inputContext, ctx, client, persistentStore, states) if err != nil { @@ -154,7 +154,7 @@ func (in *s3Input) Run(inputContext v2.Context, pipeline beat.Pipeline) error { return nil } -func (in *s3Input) createSQSReceiver(ctx v2.Context, client beat.Client) (*sqsReader, error) { +func (in *s3Input) createSQSReceiver(ctx v2.Context, pipeline beat.Pipeline) (*sqsReader, error) { sqsAPI := &awsSQSAPI{ client: sqs.NewFromConfig(in.awsConfig, func(o *sqs.Options) { if in.config.AWSConfig.FIPSEnabled { @@ -192,8 +192,8 @@ func (in *s3Input) createSQSReceiver(ctx v2.Context, client beat.Client) (*sqsRe if err != nil { return nil, err } - s3EventHandlerFactory := newS3ObjectProcessorFactory(log.Named("s3"), metrics, s3API, client, fileSelectors) - sqsMessageHandler := newSQSS3EventProcessor(log.Named("sqs_s3_event"), metrics, sqsAPI, script, in.config.VisibilityTimeout, in.config.SQSMaxReceiveCount, s3EventHandlerFactory) + s3EventHandlerFactory := newS3ObjectProcessorFactory(log.Named("s3"), metrics, s3API, fileSelectors) + sqsMessageHandler := newSQSS3EventProcessor(log.Named("sqs_s3_event"), metrics, sqsAPI, script, in.config.VisibilityTimeout, in.config.SQSMaxReceiveCount, pipeline, s3EventHandlerFactory) sqsReader := newSQSReader(log.Named("sqs"), metrics, sqsAPI, in.config.MaxNumberOfMessages, sqsMessageHandler) return sqsReader, nil @@ -267,10 +267,11 @@ func (in *s3Input) createS3Lister(ctx v2.Context, cancelCtx context.Context, cli if len(in.config.FileSelectors) == 0 { fileSelectors = []fileSelectorConfig{{ReaderConfig: in.config.ReaderConfig}} } - s3EventHandlerFactory := newS3ObjectProcessorFactory(log.Named("s3"), metrics, s3API, client, fileSelectors) + s3EventHandlerFactory := newS3ObjectProcessorFactory(log.Named("s3"), metrics, s3API, fileSelectors) s3Poller := newS3Poller(log.Named("s3_poller"), metrics, s3API, + client, s3EventHandlerFactory, states, persistentStore, diff --git 
a/x-pack/filebeat/input/awss3/input_benchmark_test.go b/x-pack/filebeat/input/awss3/input_benchmark_test.go index b1e652ef635..b8d9f29ce36 100644 --- a/x-pack/filebeat/input/awss3/input_benchmark_test.go +++ b/x-pack/filebeat/input/awss3/input_benchmark_test.go @@ -148,6 +148,37 @@ func (c constantS3) ListObjectsPaginator(bucket, prefix string) s3Pager { return c.pagerConstant } +var _ beat.Pipeline = (*fakePipeline)(nil) + +// fakePipeline returns new ackClients. +type fakePipeline struct{} + +func (c *fakePipeline) ConnectWith(clientConfig beat.ClientConfig) (beat.Client, error) { + return &ackClient{}, nil +} + +func (c *fakePipeline) Connect() (beat.Client, error) { + panic("Connect() is not implemented.") +} + +var _ beat.Client = (*ackClient)(nil) + +// ackClient is a fake beat.Client that ACKs the published messages. +type ackClient struct{} + +func (c *ackClient) Close() error { return nil } + +func (c *ackClient) Publish(event beat.Event) { + // Fake the ACK handling. + event.Private.(*awscommon.EventACKTracker).ACK() +} + +func (c *ackClient) PublishAll(event []beat.Event) { + for _, e := range event { + c.Publish(e) + } +} + func makeBenchmarkConfig(t testing.TB) config { cfg := conf.MustNewConfigFrom(`--- queue_url: foo @@ -171,21 +202,13 @@ func benchmarkInputSQS(t *testing.T, maxMessagesInflight int) testing.BenchmarkR metrics := newInputMetrics(metricRegistry, "test_id") sqsAPI := newConstantSQS() s3API := newConstantS3(t) - client := pubtest.NewChanClient(100) - defer close(client.Channel) + pipeline := &fakePipeline{} conf := makeBenchmarkConfig(t) - s3EventHandlerFactory := newS3ObjectProcessorFactory(log.Named("s3"), metrics, s3API, client, conf.FileSelectors) - sqsMessageHandler := newSQSS3EventProcessor(log.Named("sqs_s3_event"), metrics, sqsAPI, nil, time.Minute, 5, s3EventHandlerFactory) + s3EventHandlerFactory := newS3ObjectProcessorFactory(log.Named("s3"), metrics, s3API, conf.FileSelectors) + sqsMessageHandler := 
newSQSS3EventProcessor(log.Named("sqs_s3_event"), metrics, sqsAPI, nil, time.Minute, 5, pipeline, s3EventHandlerFactory) sqsReader := newSQSReader(log.Named("sqs"), metrics, sqsAPI, maxMessagesInflight, sqsMessageHandler) - go func() { - for event := range client.Channel { - // Fake the ACK handling that's not implemented in pubtest. - event.Private.(*awscommon.EventACKTracker).ACK() - } - }() - ctx, cancel := context.WithCancel(context.Background()) b.Cleanup(cancel) @@ -313,8 +336,8 @@ func benchmarkInputS3(t *testing.T, numberOfWorkers int) testing.BenchmarkResult return } - s3EventHandlerFactory := newS3ObjectProcessorFactory(log.Named("s3"), metrics, s3API, client, config.FileSelectors) - s3Poller := newS3Poller(logp.NewLogger(inputName), metrics, s3API, s3EventHandlerFactory, newStates(inputCtx), store, "bucket", listPrefix, "region", "provider", numberOfWorkers, time.Second) + s3EventHandlerFactory := newS3ObjectProcessorFactory(log.Named("s3"), metrics, s3API, config.FileSelectors) + s3Poller := newS3Poller(logp.NewLogger(inputName), metrics, s3API, client, s3EventHandlerFactory, newStates(inputCtx), store, "bucket", listPrefix, "region", "provider", numberOfWorkers, time.Second) if err := s3Poller.Poll(ctx); err != nil { if !errors.Is(err, context.DeadlineExceeded) { diff --git a/x-pack/filebeat/input/awss3/interfaces.go b/x-pack/filebeat/input/awss3/interfaces.go index 0196f831af9..77c51890b8b 100644 --- a/x-pack/filebeat/input/awss3/interfaces.go +++ b/x-pack/filebeat/input/awss3/interfaces.go @@ -15,6 +15,7 @@ import ( "github.com/aws/smithy-go/middleware" + "github.com/elastic/beats/v7/libbeat/beat" awscommon "github.com/elastic/beats/v7/x-pack/libbeat/common/aws" awssdk "github.com/aws/aws-sdk-go-v2/aws" @@ -28,7 +29,7 @@ import ( // Run 'go generate' to create mocks that are used in tests. 
//go:generate go install github.com/golang/mock/mockgen@v1.6.0 //go:generate mockgen -source=interfaces.go -destination=mock_interfaces_test.go -package awss3 -mock_names=sqsAPI=MockSQSAPI,sqsProcessor=MockSQSProcessor,s3API=MockS3API,s3Pager=MockS3Pager,s3ObjectHandlerFactory=MockS3ObjectHandlerFactory,s3ObjectHandler=MockS3ObjectHandler -//go:generate mockgen -destination=mock_publisher_test.go -package=awss3 -mock_names=Client=MockBeatClient github.com/elastic/beats/v7/libbeat/beat Client +//go:generate mockgen -destination=mock_publisher_test.go -package=awss3 -mock_names=Client=MockBeatClient,Pipeline=MockBeatPipeline github.com/elastic/beats/v7/libbeat/beat Client,Pipeline // ------ // SQS interfaces @@ -88,7 +89,7 @@ type s3ObjectHandlerFactory interface { // Create returns a new s3ObjectHandler that can be used to process the // specified S3 object. If the handler is not configured to process the // given S3 object (based on key name) then it will return nil. - Create(ctx context.Context, log *logp.Logger, acker *awscommon.EventACKTracker, obj s3EventV2) s3ObjectHandler + Create(ctx context.Context, log *logp.Logger, client beat.Client, acker *awscommon.EventACKTracker, obj s3EventV2) s3ObjectHandler } type s3ObjectHandler interface { diff --git a/x-pack/filebeat/input/awss3/mock_interfaces_test.go b/x-pack/filebeat/input/awss3/mock_interfaces_test.go index 63d26918302..0044c87da89 100644 --- a/x-pack/filebeat/input/awss3/mock_interfaces_test.go +++ b/x-pack/filebeat/input/awss3/mock_interfaces_test.go @@ -17,6 +17,7 @@ import ( types "github.com/aws/aws-sdk-go-v2/service/sqs/types" gomock "github.com/golang/mock/gomock" + beat "github.com/elastic/beats/v7/libbeat/beat" aws "github.com/elastic/beats/v7/x-pack/libbeat/common/aws" logp "github.com/elastic/elastic-agent-libs/logp" ) @@ -444,17 +445,17 @@ func (m *MockS3ObjectHandlerFactory) EXPECT() *MockS3ObjectHandlerFactoryMockRec } // Create mocks base method. 
-func (m *MockS3ObjectHandlerFactory) Create(ctx context.Context, log *logp.Logger, acker *aws.EventACKTracker, obj s3EventV2) s3ObjectHandler { +func (m *MockS3ObjectHandlerFactory) Create(ctx context.Context, log *logp.Logger, client beat.Client, acker *aws.EventACKTracker, obj s3EventV2) s3ObjectHandler { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Create", ctx, log, acker, obj) + ret := m.ctrl.Call(m, "Create", ctx, log, client, acker, obj) ret0, _ := ret[0].(s3ObjectHandler) return ret0 } // Create indicates an expected call of Create. -func (mr *MockS3ObjectHandlerFactoryMockRecorder) Create(ctx, log, acker, obj interface{}) *gomock.Call { +func (mr *MockS3ObjectHandlerFactoryMockRecorder) Create(ctx, log, client, acker, obj interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Create", reflect.TypeOf((*MockS3ObjectHandlerFactory)(nil).Create), ctx, log, acker, obj) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Create", reflect.TypeOf((*MockS3ObjectHandlerFactory)(nil).Create), ctx, log, client, acker, obj) } // MockS3ObjectHandler is a mock of s3ObjectHandler interface. diff --git a/x-pack/filebeat/input/awss3/mock_publisher_test.go b/x-pack/filebeat/input/awss3/mock_publisher_test.go index 7fa935496aa..efbd5bcef97 100644 --- a/x-pack/filebeat/input/awss3/mock_publisher_test.go +++ b/x-pack/filebeat/input/awss3/mock_publisher_test.go @@ -3,7 +3,7 @@ // you may not use this file except in compliance with the Elastic License. // Code generated by MockGen. DO NOT EDIT. -// Source: github.com/elastic/beats/v7/libbeat/beat (interfaces: Client) +// Source: github.com/elastic/beats/v7/libbeat/beat (interfaces: Client,Pipeline) // Package awss3 is a generated GoMock package. 
package awss3 @@ -76,3 +76,56 @@ func (mr *MockBeatClientMockRecorder) PublishAll(arg0 interface{}) *gomock.Call mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "PublishAll", reflect.TypeOf((*MockBeatClient)(nil).PublishAll), arg0) } + +// MockBeatPipeline is a mock of Pipeline interface. +type MockBeatPipeline struct { + ctrl *gomock.Controller + recorder *MockBeatPipelineMockRecorder +} + +// MockBeatPipelineMockRecorder is the mock recorder for MockBeatPipeline. +type MockBeatPipelineMockRecorder struct { + mock *MockBeatPipeline +} + +// NewMockBeatPipeline creates a new mock instance. +func NewMockBeatPipeline(ctrl *gomock.Controller) *MockBeatPipeline { + mock := &MockBeatPipeline{ctrl: ctrl} + mock.recorder = &MockBeatPipelineMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockBeatPipeline) EXPECT() *MockBeatPipelineMockRecorder { + return m.recorder +} + +// Connect mocks base method. +func (m *MockBeatPipeline) Connect() (beat.Client, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Connect") + ret0, _ := ret[0].(beat.Client) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Connect indicates an expected call of Connect. +func (mr *MockBeatPipelineMockRecorder) Connect() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Connect", reflect.TypeOf((*MockBeatPipeline)(nil).Connect)) +} + +// ConnectWith mocks base method. +func (m *MockBeatPipeline) ConnectWith(arg0 beat.ClientConfig) (beat.Client, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "ConnectWith", arg0) + ret0, _ := ret[0].(beat.Client) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// ConnectWith indicates an expected call of ConnectWith. 
+func (mr *MockBeatPipelineMockRecorder) ConnectWith(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ConnectWith", reflect.TypeOf((*MockBeatPipeline)(nil).ConnectWith), arg0) +} diff --git a/x-pack/filebeat/input/awss3/s3.go b/x-pack/filebeat/input/awss3/s3.go index 349d5f7cfdd..5b1187e4317 100644 --- a/x-pack/filebeat/input/awss3/s3.go +++ b/x-pack/filebeat/input/awss3/s3.go @@ -15,6 +15,7 @@ import ( "github.com/gofrs/uuid" "go.uber.org/multierr" + "github.com/elastic/beats/v7/libbeat/beat" "github.com/elastic/beats/v7/libbeat/statestore" awscommon "github.com/elastic/beats/v7/x-pack/libbeat/common/aws" "github.com/elastic/elastic-agent-libs/logp" @@ -53,6 +54,7 @@ type s3Poller struct { s3 s3API log *logp.Logger metrics *inputMetrics + client beat.Client s3ObjectHandler s3ObjectHandlerFactory states *states store *statestore.Store @@ -63,6 +65,7 @@ type s3Poller struct { func newS3Poller(log *logp.Logger, metrics *inputMetrics, s3 s3API, + client beat.Client, s3ObjectHandler s3ObjectHandlerFactory, states *states, store *statestore.Store, @@ -71,7 +74,8 @@ func newS3Poller(log *logp.Logger, awsRegion string, provider string, numberOfWorkers int, - bucketPollInterval time.Duration) *s3Poller { + bucketPollInterval time.Duration, +) *s3Poller { if metrics == nil { metrics = newInputMetrics(monitoring.NewRegistry(), "") } @@ -86,6 +90,7 @@ func newS3Poller(log *logp.Logger, s3: s3, log: log, metrics: metrics, + client: client, s3ObjectHandler: s3ObjectHandler, states: states, store: store, @@ -214,7 +219,7 @@ func (p *s3Poller) GetS3Objects(ctx context.Context, s3ObjectPayloadChan chan<- acker := awscommon.NewEventACKTracker(ctx) - s3Processor := p.s3ObjectHandler.Create(ctx, p.log, acker, event) + s3Processor := p.s3ObjectHandler.Create(ctx, p.log, p.client, acker, event) if s3Processor == nil { p.log.Debugw("empty s3 processor.", "state", state) continue diff --git 
a/x-pack/filebeat/input/awss3/s3_objects.go b/x-pack/filebeat/input/awss3/s3_objects.go index a1d70c604c2..0f273828dc4 100644 --- a/x-pack/filebeat/input/awss3/s3_objects.go +++ b/x-pack/filebeat/input/awss3/s3_objects.go @@ -42,11 +42,10 @@ type s3ObjectProcessorFactory struct { log *logp.Logger metrics *inputMetrics s3 s3Getter - publisher beat.Client fileSelectors []fileSelectorConfig } -func newS3ObjectProcessorFactory(log *logp.Logger, metrics *inputMetrics, s3 s3Getter, publisher beat.Client, sel []fileSelectorConfig) *s3ObjectProcessorFactory { +func newS3ObjectProcessorFactory(log *logp.Logger, metrics *inputMetrics, s3 s3Getter, sel []fileSelectorConfig) *s3ObjectProcessorFactory { if metrics == nil { metrics = newInputMetrics(monitoring.NewRegistry(), "") } @@ -59,7 +58,6 @@ func newS3ObjectProcessorFactory(log *logp.Logger, metrics *inputMetrics, s3 s3G log: log, metrics: metrics, s3: s3, - publisher: publisher, fileSelectors: sel, } } @@ -75,7 +73,7 @@ func (f *s3ObjectProcessorFactory) findReaderConfig(key string) *readerConfig { // Create returns a new s3ObjectProcessor. It returns nil when no file selectors // match the S3 object key. 
-func (f *s3ObjectProcessorFactory) Create(ctx context.Context, log *logp.Logger, ack *awscommon.EventACKTracker, obj s3EventV2) s3ObjectHandler { +func (f *s3ObjectProcessorFactory) Create(ctx context.Context, log *logp.Logger, client beat.Client, ack *awscommon.EventACKTracker, obj s3EventV2) s3ObjectHandler { log = log.With( "bucket_arn", obj.S3.Bucket.Name, "object_key", obj.S3.Object.Key) @@ -90,6 +88,7 @@ func (f *s3ObjectProcessorFactory) Create(ctx context.Context, log *logp.Logger, s3ObjectProcessorFactory: f, log: log, ctx: ctx, + publisher: client, acker: ack, readerConfig: readerConfig, s3Obj: obj, @@ -102,6 +101,7 @@ type s3ObjectProcessor struct { log *logp.Logger ctx context.Context + publisher beat.Client acker *awscommon.EventACKTracker // ACKer tied to the SQS message (multiple S3 readers share an ACKer when the S3 notification event contains more than one S3 object). readerConfig *readerConfig // Config about how to process the object. s3Obj s3EventV2 // S3 object information. diff --git a/x-pack/filebeat/input/awss3/s3_objects_test.go b/x-pack/filebeat/input/awss3/s3_objects_test.go index 4541874303b..a3a9168d967 100644 --- a/x-pack/filebeat/input/awss3/s3_objects_test.go +++ b/x-pack/filebeat/input/awss3/s3_objects_test.go @@ -144,9 +144,9 @@ func TestS3ObjectProcessor(t *testing.T) { GetObject(gomock.Any(), gomock.Eq(s3Event.S3.Bucket.Name), gomock.Eq(s3Event.S3.Object.Key)). 
Return(nil, errFakeConnectivityFailure) - s3ObjProc := newS3ObjectProcessorFactory(logp.NewLogger(inputName), nil, mockS3API, mockPublisher, nil) + s3ObjProc := newS3ObjectProcessorFactory(logp.NewLogger(inputName), nil, mockS3API, nil) ack := awscommon.NewEventACKTracker(ctx) - err := s3ObjProc.Create(ctx, logp.NewLogger(inputName), ack, s3Event).ProcessS3Object() + err := s3ObjProc.Create(ctx, logp.NewLogger(inputName), mockPublisher, ack, s3Event).ProcessS3Object() require.Error(t, err) assert.True(t, errors.Is(err, errFakeConnectivityFailure), "expected errFakeConnectivityFailure error") }) @@ -166,9 +166,9 @@ func TestS3ObjectProcessor(t *testing.T) { GetObject(gomock.Any(), gomock.Eq(s3Event.S3.Bucket.Name), gomock.Eq(s3Event.S3.Object.Key)). Return(nil, nil) - s3ObjProc := newS3ObjectProcessorFactory(logp.NewLogger(inputName), nil, mockS3API, mockPublisher, nil) + s3ObjProc := newS3ObjectProcessorFactory(logp.NewLogger(inputName), nil, mockS3API, nil) ack := awscommon.NewEventACKTracker(ctx) - err := s3ObjProc.Create(ctx, logp.NewLogger(inputName), ack, s3Event).ProcessS3Object() + err := s3ObjProc.Create(ctx, logp.NewLogger(inputName), mockPublisher, ack, s3Event).ProcessS3Object() require.Error(t, err) }) @@ -193,9 +193,9 @@ func TestS3ObjectProcessor(t *testing.T) { Times(2), ) - s3ObjProc := newS3ObjectProcessorFactory(logp.NewLogger(inputName), nil, mockS3API, mockPublisher, nil) + s3ObjProc := newS3ObjectProcessorFactory(logp.NewLogger(inputName), nil, mockS3API, nil) ack := awscommon.NewEventACKTracker(ctx) - err := s3ObjProc.Create(ctx, logp.NewLogger(inputName), ack, s3Event).ProcessS3Object() + err := s3ObjProc.Create(ctx, logp.NewLogger(inputName), mockPublisher, ack, s3Event).ProcessS3Object() require.NoError(t, err) }) } @@ -231,9 +231,9 @@ func _testProcessS3Object(t testing.TB, file, contentType string, numEvents int, Times(numEvents), ) - s3ObjProc := newS3ObjectProcessorFactory(logp.NewLogger(inputName), nil, mockS3API, mockPublisher, 
selectors) + s3ObjProc := newS3ObjectProcessorFactory(logp.NewLogger(inputName), nil, mockS3API, selectors) ack := awscommon.NewEventACKTracker(ctx) - err := s3ObjProc.Create(ctx, logp.NewLogger(inputName), ack, s3Event).ProcessS3Object() + err := s3ObjProc.Create(ctx, logp.NewLogger(inputName), mockPublisher, ack, s3Event).ProcessS3Object() if !expectErr { require.NoError(t, err) diff --git a/x-pack/filebeat/input/awss3/s3_test.go b/x-pack/filebeat/input/awss3/s3_test.go index 367f707b183..674e700fd7b 100644 --- a/x-pack/filebeat/input/awss3/s3_test.go +++ b/x-pack/filebeat/input/awss3/s3_test.go @@ -125,8 +125,8 @@ func TestS3Poller(t *testing.T) { GetObject(gomock.Any(), gomock.Eq(bucket), gomock.Eq("key5")). Return(nil, errFakeConnectivityFailure) - s3ObjProc := newS3ObjectProcessorFactory(logp.NewLogger(inputName), nil, mockAPI, mockPublisher, nil) - receiver := newS3Poller(logp.NewLogger(inputName), nil, mockAPI, s3ObjProc, newStates(inputCtx), store, bucket, "key", "region", "provider", numberOfWorkers, pollInterval) + s3ObjProc := newS3ObjectProcessorFactory(logp.NewLogger(inputName), nil, mockAPI, nil) + receiver := newS3Poller(logp.NewLogger(inputName), nil, mockAPI, mockPublisher, s3ObjProc, newStates(inputCtx), store, bucket, "key", "region", "provider", numberOfWorkers, pollInterval) require.Error(t, context.DeadlineExceeded, receiver.Poll(ctx)) assert.Equal(t, numberOfWorkers, receiver.workerSem.Available()) }) @@ -248,8 +248,8 @@ func TestS3Poller(t *testing.T) { GetObject(gomock.Any(), gomock.Eq(bucket), gomock.Eq("key5")). 
Return(nil, errFakeConnectivityFailure) - s3ObjProc := newS3ObjectProcessorFactory(logp.NewLogger(inputName), nil, mockAPI, mockPublisher, nil) - receiver := newS3Poller(logp.NewLogger(inputName), nil, mockAPI, s3ObjProc, newStates(inputCtx), store, bucket, "key", "region", "provider", numberOfWorkers, pollInterval) + s3ObjProc := newS3ObjectProcessorFactory(logp.NewLogger(inputName), nil, mockAPI, nil) + receiver := newS3Poller(logp.NewLogger(inputName), nil, mockAPI, mockPublisher, s3ObjProc, newStates(inputCtx), store, bucket, "key", "region", "provider", numberOfWorkers, pollInterval) require.Error(t, context.DeadlineExceeded, receiver.Poll(ctx)) assert.Equal(t, numberOfWorkers, receiver.workerSem.Available()) }) diff --git a/x-pack/filebeat/input/awss3/sqs_s3_event.go b/x-pack/filebeat/input/awss3/sqs_s3_event.go index 72489b1550c..dba18898d1b 100644 --- a/x-pack/filebeat/input/awss3/sqs_s3_event.go +++ b/x-pack/filebeat/input/awss3/sqs_s3_event.go @@ -17,6 +17,7 @@ import ( "github.com/aws/smithy-go" + "github.com/elastic/beats/v7/libbeat/beat" awscommon "github.com/elastic/beats/v7/x-pack/libbeat/common/aws" "github.com/aws/aws-sdk-go-v2/service/sqs/types" @@ -89,13 +90,23 @@ type sqsS3EventProcessor struct { sqsVisibilityTimeout time.Duration maxReceiveCount int sqs sqsAPI + pipeline beat.Pipeline // Pipeline creates clients for publishing events. 
log *logp.Logger warnOnce sync.Once metrics *inputMetrics script *script } -func newSQSS3EventProcessor(log *logp.Logger, metrics *inputMetrics, sqs sqsAPI, script *script, sqsVisibilityTimeout time.Duration, maxReceiveCount int, s3 s3ObjectHandlerFactory) *sqsS3EventProcessor { +func newSQSS3EventProcessor( + log *logp.Logger, + metrics *inputMetrics, + sqs sqsAPI, + script *script, + sqsVisibilityTimeout time.Duration, + maxReceiveCount int, + pipeline beat.Pipeline, + s3 s3ObjectHandlerFactory, +) *sqsS3EventProcessor { if metrics == nil { metrics = newInputMetrics(monitoring.NewRegistry(), "") } @@ -104,6 +115,7 @@ func newSQSS3EventProcessor(log *logp.Logger, metrics *inputMetrics, sqs sqsAPI, sqsVisibilityTimeout: sqsVisibilityTimeout, maxReceiveCount: maxReceiveCount, sqs: sqs, + pipeline: pipeline, log: log, metrics: metrics, script: script, @@ -277,13 +289,26 @@ func (p *sqsS3EventProcessor) processS3Events(ctx context.Context, log *logp.Log log.Debugf("SQS message contained %d S3 event notifications.", len(s3Events)) defer log.Debug("End processing SQS S3 event notifications.") + if len(s3Events) == 0 { + return nil + } + + // Create a pipeline client scoped to this goroutine. + client, err := p.pipeline.ConnectWith(beat.ClientConfig{ + ACKHandler: awscommon.NewEventACKHandler(), + }) + if err != nil { + return err + } + defer client.Close() + // Wait for all events to be ACKed before proceeding. 
acker := awscommon.NewEventACKTracker(ctx) defer acker.Wait() var errs []error for i, event := range s3Events { - s3Processor := p.s3ObjectHandler.Create(ctx, log, acker, event) + s3Processor := p.s3ObjectHandler.Create(ctx, log, client, acker, event) if s3Processor == nil { continue } diff --git a/x-pack/filebeat/input/awss3/sqs_s3_event_test.go b/x-pack/filebeat/input/awss3/sqs_s3_event_test.go index 98312fc9401..886fdfe1711 100644 --- a/x-pack/filebeat/input/awss3/sqs_s3_event_test.go +++ b/x-pack/filebeat/input/awss3/sqs_s3_event_test.go @@ -21,6 +21,7 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "github.com/elastic/beats/v7/libbeat/beat" awscommon "github.com/elastic/beats/v7/x-pack/libbeat/common/aws" "github.com/elastic/elastic-agent-libs/logp" "github.com/elastic/go-concert/timed" @@ -39,13 +40,17 @@ func TestSQSS3EventProcessor(t *testing.T) { defer ctrl.Finish() mockAPI := NewMockSQSAPI(ctrl) mockS3HandlerFactory := NewMockS3ObjectHandlerFactory(ctrl) + mockClient := NewMockBeatClient(ctrl) + mockBeatPipeline := NewMockBeatPipeline(ctrl) gomock.InOrder( - mockS3HandlerFactory.EXPECT().Create(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(nil), + mockBeatPipeline.EXPECT().ConnectWith(gomock.Any()).Return(mockClient, nil), + mockS3HandlerFactory.EXPECT().Create(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(nil), + mockClient.EXPECT().Close(), mockAPI.EXPECT().DeleteMessage(gomock.Any(), gomock.Eq(&msg)).Return(nil), ) - p := newSQSS3EventProcessor(logp.NewLogger(inputName), nil, mockAPI, nil, time.Minute, 5, mockS3HandlerFactory) + p := newSQSS3EventProcessor(logp.NewLogger(inputName), nil, mockAPI, nil, time.Minute, 5, mockBeatPipeline, mockS3HandlerFactory) require.NoError(t, p.ProcessSQS(ctx, &msg)) }) @@ -57,6 +62,7 @@ func TestSQSS3EventProcessor(t *testing.T) { defer ctrl.Finish() mockAPI := NewMockSQSAPI(ctrl) mockS3HandlerFactory := 
NewMockS3ObjectHandlerFactory(ctrl) + mockBeatPipeline := NewMockBeatPipeline(ctrl) invalidBodyMsg := newSQSMessage(newS3Event("log.json")) body := *invalidBodyMsg.Body @@ -67,7 +73,7 @@ func TestSQSS3EventProcessor(t *testing.T) { mockAPI.EXPECT().DeleteMessage(gomock.Any(), gomock.Eq(&invalidBodyMsg)).Return(nil), ) - p := newSQSS3EventProcessor(logp.NewLogger(inputName), nil, mockAPI, nil, time.Minute, 5, mockS3HandlerFactory) + p := newSQSS3EventProcessor(logp.NewLogger(inputName), nil, mockAPI, nil, time.Minute, 5, mockBeatPipeline, mockS3HandlerFactory) err := p.ProcessSQS(ctx, &invalidBodyMsg) require.Error(t, err) t.Log(err) @@ -81,6 +87,7 @@ func TestSQSS3EventProcessor(t *testing.T) { defer ctrl.Finish() mockAPI := NewMockSQSAPI(ctrl) mockS3HandlerFactory := NewMockS3ObjectHandlerFactory(ctrl) + mockBeatPipeline := NewMockBeatPipeline(ctrl) emptyRecordsMsg := newSQSMessage([]s3EventV2{}...) @@ -88,7 +95,7 @@ func TestSQSS3EventProcessor(t *testing.T) { mockAPI.EXPECT().DeleteMessage(gomock.Any(), gomock.Eq(&emptyRecordsMsg)).Return(nil), ) - p := newSQSS3EventProcessor(logp.NewLogger(inputName), nil, mockAPI, nil, time.Minute, 5, mockS3HandlerFactory) + p := newSQSS3EventProcessor(logp.NewLogger(inputName), nil, mockAPI, nil, time.Minute, 5, mockBeatPipeline, mockS3HandlerFactory) require.NoError(t, p.ProcessSQS(ctx, &emptyRecordsMsg)) }) @@ -103,19 +110,23 @@ func TestSQSS3EventProcessor(t *testing.T) { mockAPI := NewMockSQSAPI(ctrl) mockS3HandlerFactory := NewMockS3ObjectHandlerFactory(ctrl) mockS3Handler := NewMockS3ObjectHandler(ctrl) + mockClient := NewMockBeatClient(ctrl) + mockBeatPipeline := NewMockBeatPipeline(ctrl) mockAPI.EXPECT().ChangeMessageVisibility(gomock.Any(), gomock.Eq(&msg), gomock.Eq(visibilityTimeout)).AnyTimes().Return(nil) gomock.InOrder( - mockS3HandlerFactory.EXPECT().Create(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). 
- Do(func(ctx context.Context, _ *logp.Logger, _ *awscommon.EventACKTracker, _ s3EventV2) { + mockBeatPipeline.EXPECT().ConnectWith(gomock.Any()).Return(mockClient, nil), + mockS3HandlerFactory.EXPECT().Create(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). + Do(func(ctx context.Context, _ *logp.Logger, _ beat.Client, _ *awscommon.EventACKTracker, _ s3EventV2) { require.NoError(t, timed.Wait(ctx, 5*visibilityTimeout)) }).Return(mockS3Handler), mockS3Handler.EXPECT().ProcessS3Object().Return(nil), + mockClient.EXPECT().Close(), mockAPI.EXPECT().DeleteMessage(gomock.Any(), gomock.Eq(&msg)).Return(nil), ) - p := newSQSS3EventProcessor(logp.NewLogger(inputName), nil, mockAPI, nil, visibilityTimeout, 5, mockS3HandlerFactory) + p := newSQSS3EventProcessor(logp.NewLogger(inputName), nil, mockAPI, nil, visibilityTimeout, 5, mockBeatPipeline, mockS3HandlerFactory) require.NoError(t, p.ProcessSQS(ctx, &msg)) }) @@ -128,13 +139,17 @@ func TestSQSS3EventProcessor(t *testing.T) { mockAPI := NewMockSQSAPI(ctrl) mockS3HandlerFactory := NewMockS3ObjectHandlerFactory(ctrl) mockS3Handler := NewMockS3ObjectHandler(ctrl) + mockClient := NewMockBeatClient(ctrl) + mockBeatPipeline := NewMockBeatPipeline(ctrl) gomock.InOrder( - mockS3HandlerFactory.EXPECT().Create(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(mockS3Handler), + mockBeatPipeline.EXPECT().ConnectWith(gomock.Any()).Return(mockClient, nil), + mockS3HandlerFactory.EXPECT().Create(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(mockS3Handler), mockS3Handler.EXPECT().ProcessS3Object().Return(errors.New("fake connectivity problem")), + mockClient.EXPECT().Close(), ) - p := newSQSS3EventProcessor(logp.NewLogger(inputName), nil, mockAPI, nil, time.Minute, 5, mockS3HandlerFactory) + p := newSQSS3EventProcessor(logp.NewLogger(inputName), nil, mockAPI, nil, time.Minute, 5, mockBeatPipeline, mockS3HandlerFactory) err := p.ProcessSQS(ctx, &msg) t.Log(err) 
require.Error(t, err) @@ -149,6 +164,8 @@ func TestSQSS3EventProcessor(t *testing.T) { mockAPI := NewMockSQSAPI(ctrl) mockS3HandlerFactory := NewMockS3ObjectHandlerFactory(ctrl) mockS3Handler := NewMockS3ObjectHandler(ctrl) + mockClient := NewMockBeatClient(ctrl) + mockBeatPipeline := NewMockBeatPipeline(ctrl) msg := msg msg.Attributes = map[string]string{ @@ -156,12 +173,14 @@ func TestSQSS3EventProcessor(t *testing.T) { } gomock.InOrder( - mockS3HandlerFactory.EXPECT().Create(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(mockS3Handler), + mockBeatPipeline.EXPECT().ConnectWith(gomock.Any()).Return(mockClient, nil), + mockS3HandlerFactory.EXPECT().Create(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(mockS3Handler), mockS3Handler.EXPECT().ProcessS3Object().Return(errors.New("fake connectivity problem")), + mockClient.EXPECT().Close(), mockAPI.EXPECT().DeleteMessage(gomock.Any(), gomock.Eq(&msg)).Return(nil), ) - p := newSQSS3EventProcessor(logp.NewLogger(inputName), nil, mockAPI, nil, time.Minute, 5, mockS3HandlerFactory) + p := newSQSS3EventProcessor(logp.NewLogger(inputName), nil, mockAPI, nil, time.Minute, 5, mockBeatPipeline, mockS3HandlerFactory) err := p.ProcessSQS(ctx, &msg) t.Log(err) require.Error(t, err) @@ -202,11 +221,12 @@ func TestSqsProcessor_keepalive(t *testing.T) { defer ctrl.Finish() mockAPI := NewMockSQSAPI(ctrl) mockS3HandlerFactory := NewMockS3ObjectHandlerFactory(ctrl) + mockBeatPipeline := NewMockBeatPipeline(ctrl) mockAPI.EXPECT().ChangeMessageVisibility(gomock.Any(), gomock.Eq(&msg), gomock.Eq(visibilityTimeout)). 
Times(1).Return(tc.Err) - p := newSQSS3EventProcessor(logp.NewLogger(inputName), nil, mockAPI, nil, visibilityTimeout, 5, mockS3HandlerFactory) + p := newSQSS3EventProcessor(logp.NewLogger(inputName), nil, mockAPI, nil, visibilityTimeout, 5, mockBeatPipeline, mockS3HandlerFactory) var wg sync.WaitGroup wg.Add(1) p.keepalive(ctx, p.log, &wg, &msg) @@ -218,7 +238,7 @@ func TestSqsProcessor_keepalive(t *testing.T) { func TestSqsProcessor_getS3Notifications(t *testing.T) { require.NoError(t, logp.TestingSetup()) - p := newSQSS3EventProcessor(logp.NewLogger(inputName), nil, nil, nil, time.Minute, 5, nil) + p := newSQSS3EventProcessor(logp.NewLogger(inputName), nil, nil, nil, time.Minute, 5, nil, nil) t.Run("s3 key is url unescaped", func(t *testing.T) { msg := newSQSMessage(newS3Event("Happy+Face.jpg"))