diff --git a/x-pack/filebeat/input/awss3/input.go b/x-pack/filebeat/input/awss3/input.go index 855403e5dc4..2c0372fe561 100644 --- a/x-pack/filebeat/input/awss3/input.go +++ b/x-pack/filebeat/input/awss3/input.go @@ -13,6 +13,7 @@ import ( "time" awssdk "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/aws/retry" "github.com/aws/aws-sdk-go-v2/service/s3" "github.com/aws/aws-sdk-go-v2/service/sqs" "github.com/aws/smithy-go" @@ -21,7 +22,6 @@ import ( v2 "github.com/elastic/beats/v7/filebeat/input/v2" "github.com/elastic/beats/v7/libbeat/beat" "github.com/elastic/beats/v7/libbeat/feature" - "github.com/elastic/beats/v7/libbeat/statestore" awscommon "github.com/elastic/beats/v7/x-pack/libbeat/common/aws" conf "github.com/elastic/elastic-agent-libs/config" "github.com/elastic/go-concert/unison" @@ -99,21 +99,6 @@ func (in *s3Input) Test(ctx v2.TestContext) error { } func (in *s3Input) Run(inputContext v2.Context, pipeline beat.Pipeline) error { - var err error - - persistentStore, err := in.store.Access() - if err != nil { - return fmt.Errorf("can not access persistent store: %w", err) - } - - defer persistentStore.Close() - - states := newStates(inputContext) - err = states.readStatesFrom(persistentStore) - if err != nil { - return fmt.Errorf("can not start persistent store: %w", err) - } - // Wrap input Context's cancellation Done channel a context.Context. This // goroutine stops with the parent closes the Done channel. ctx, cancelInputCtx := context.WithCancel(context.Background()) @@ -168,8 +153,20 @@ func (in *s3Input) Run(inputContext v2.Context, pipeline beat.Pipeline) error { } defer client.Close() + // Connect to the registry and create our states lookup + persistentStore, err := in.store.Access() + if err != nil { + return fmt.Errorf("can not access persistent store: %w", err) + } + defer persistentStore.Close() + + states, err := newStates(inputContext, persistentStore) + if err != nil { + return fmt.Errorf("can not start persistent store: %w", err) + } + // Create S3 receiver and S3 notification processor. 
- poller, err := in.createS3Lister(inputContext, ctx, client, persistentStore, states) + poller, err := in.createS3Lister(inputContext, ctx, client, states) if err != nil { return fmt.Errorf("failed to initialize s3 poller: %w", err) } @@ -240,7 +237,7 @@ func (n nonAWSBucketResolver) ResolveEndpoint(region string, options s3.Endpoint return awssdk.Endpoint{URL: n.endpoint, SigningRegion: region, HostnameImmutable: true, Source: awssdk.EndpointSourceCustom}, nil } -func (in *s3Input) createS3Lister(ctx v2.Context, cancelCtx context.Context, client beat.Client, persistentStore *statestore.Store, states *states) (*s3Poller, error) { +func (in *s3Input) createS3Lister(ctx v2.Context, cancelCtx context.Context, client beat.Client, states *states) (*s3Poller, error) { var bucketName string var bucketID string if in.config.NonAWSBucketName != "" { @@ -260,6 +257,12 @@ func (in *s3Input) createS3Lister(ctx v2.Context, cancelCtx context.Context, cli o.EndpointOptions.UseFIPSEndpoint = awssdk.FIPSEndpointStateEnabled } o.UsePathStyle = in.config.PathStyle + + o.Retryer = retry.NewStandard(func(so *retry.StandardOptions) { + so.MaxAttempts = 5 + // Recover quickly when requests start working again + so.NoRetryIncrement = 100 + }) }) regionName, err := getRegionForBucket(cancelCtx, s3Client, bucketName) if err != nil { @@ -305,7 +308,6 @@ func (in *s3Input) createS3Lister(ctx v2.Context, cancelCtx context.Context, cli client, s3EventHandlerFactory, states, - persistentStore, bucketID, in.config.BucketListPrefix, in.awsConfig.Region, diff --git a/x-pack/filebeat/input/awss3/input_benchmark_test.go b/x-pack/filebeat/input/awss3/input_benchmark_test.go index e05e5b461ca..5d22d141168 100644 --- a/x-pack/filebeat/input/awss3/input_benchmark_test.go +++ b/x-pack/filebeat/input/awss3/input_benchmark_test.go @@ -8,7 +8,6 @@ import ( "context" "errors" "fmt" - "io/ioutil" "os" "path/filepath" "runtime" @@ -16,6 +15,8 @@ import ( "testing" "time" + "github.com/stretchr/testify/assert" + "github.com/elastic/beats/v7/libbeat/statestore" "github.com/elastic/beats/v7/libbeat/statestore/storetest" @@ -132,7 +133,7 @@ type constantS3 struct { var _ s3API = (*constantS3)(nil) func newConstantS3(t testing.TB) *constantS3 { - data, err := ioutil.ReadFile(cloudtrailTestFile) + data, err := os.ReadFile(cloudtrailTestFile) if err != nil { t.Fatal(err) } @@ -342,14 +343,11 @@ func benchmarkInputS3(t *testing.T, numberOfWorkers int) testing.BenchmarkResult return } - err = store.Set(awsS3WriteCommitPrefix+"bucket"+listPrefix, &commitWriteState{time.Time{}}) - if err != nil { - errChan <- err - return - } + states, err := newStates(inputCtx, store) + assert.NoError(t, err, "states creation should succeed") s3EventHandlerFactory := newS3ObjectProcessorFactory(log.Named("s3"), metrics, s3API, config.FileSelectors, backupConfig{}, numberOfWorkers) - s3Poller := newS3Poller(logp.NewLogger(inputName), metrics, s3API, client, s3EventHandlerFactory, newStates(inputCtx), store, "bucket", listPrefix, "region", "provider", numberOfWorkers, time.Second) + s3Poller := newS3Poller(logp.NewLogger(inputName), metrics, s3API, client, s3EventHandlerFactory, states, "bucket", listPrefix, "region", "provider", numberOfWorkers, time.Second) if err := s3Poller.Poll(ctx); err != nil { if !errors.Is(err, context.DeadlineExceeded) { diff --git a/x-pack/filebeat/input/awss3/s3.go b/x-pack/filebeat/input/awss3/s3.go index 5aa8d31e95d..8909f78bb39 100644 --- a/x-pack/filebeat/input/awss3/s3.go +++ b/x-pack/filebeat/input/awss3/s3.go @@ -11,34 +11,22 
@@ import ( "sync" "time" - "github.com/gofrs/uuid" - "go.uber.org/multierr" + "github.com/aws/aws-sdk-go-v2/aws/ratelimit" "github.com/elastic/beats/v7/libbeat/beat" - "github.com/elastic/beats/v7/libbeat/statestore" + "github.com/elastic/beats/v7/libbeat/common/backoff" awscommon "github.com/elastic/beats/v7/x-pack/libbeat/common/aws" "github.com/elastic/elastic-agent-libs/logp" "github.com/elastic/go-concert/timed" ) -const maxCircuitBreaker = 5 - -type commitWriteState struct { - time.Time -} - -type s3ObjectInfo struct { - name string - key string - etag string - lastModified time.Time - listingID string -} +// var instead of const so it can be reduced during unit tests (instead of waiting +// through 10 minutes of retry backoff) +var readerLoopMaxCircuitBreaker = 10 type s3ObjectPayload struct { s3ObjectHandler s3ObjectHandler - s3ObjectInfo s3ObjectInfo - s3ObjectEvent s3EventV2 + objectState state } type s3Poller struct { @@ -48,15 +36,12 @@ type s3Poller struct { region string provider string bucketPollInterval time.Duration - workerSem *awscommon.Sem s3 s3API log *logp.Logger metrics *inputMetrics client beat.Client s3ObjectHandler s3ObjectHandlerFactory states *states - store *statestore.Store - workersListingMap *sync.Map workersProcessingMap *sync.Map } @@ -66,7 +51,6 @@ func newS3Poller(log *logp.Logger, client beat.Client, s3ObjectHandler s3ObjectHandlerFactory, states *states, - store *statestore.Store, bucket string, listPrefix string, awsRegion string, @@ -85,41 +69,17 @@ func newS3Poller(log *logp.Logger, region: awsRegion, provider: provider, bucketPollInterval: bucketPollInterval, - workerSem: awscommon.NewSem(numberOfWorkers), s3: s3, log: log, metrics: metrics, client: client, s3ObjectHandler: s3ObjectHandler, states: states, - store: store, - workersListingMap: new(sync.Map), workersProcessingMap: new(sync.Map), } } -func (p *s3Poller) handlePurgingLock(info s3ObjectInfo, isStored bool) { - id := stateID(info.name, info.key, info.etag, info.lastModified) - previousState := p.states.FindPreviousByID(id) - if !previousState.IsEmpty() { - if isStored { - previousState.MarkAsStored() - } else { - previousState.MarkAsError() - } - - p.states.Update(previousState, info.listingID) - } - - // Manage locks for purging. - if p.states.IsListingFullyStored(info.listingID) { - // locked on processing we unlock when all the object were ACKed - lock, _ := p.workersListingMap.Load(info.listingID) - lock.(*sync.Mutex).Unlock() - } -} - -func (p *s3Poller) createS3ObjectProcessor(ctx context.Context, state state) (s3ObjectHandler, s3EventV2) { +func (p *s3Poller) createS3ObjectProcessor(ctx context.Context, state state) s3ObjectHandler { event := s3EventV2{} event.AWSRegion = p.region event.Provider = p.provider @@ -129,275 +89,126 @@ func (p *s3Poller) createS3ObjectProcessor(ctx context.Context, state state) (s3 acker := awscommon.NewEventACKTracker(ctx) - return p.s3ObjectHandler.Create(ctx, p.log, p.client, acker, event), event + return p.s3ObjectHandler.Create(ctx, p.log, p.client, acker, event) } -func (p *s3Poller) ProcessObject(s3ObjectPayloadChan <-chan *s3ObjectPayload) error { - var errs []error +func (p *s3Poller) workerLoop(ctx context.Context, s3ObjectPayloadChan <-chan *s3ObjectPayload) { + rateLimitWaiter := backoff.NewEqualJitterBackoff(ctx.Done(), 1*time.Second, 120*time.Second) for s3ObjectPayload := range s3ObjectPayloadChan { - // Process S3 object (download, parse, create events).
- err := s3ObjectPayload.s3ObjectHandler.ProcessS3Object() + objHandler := s3ObjectPayload.s3ObjectHandler + state := s3ObjectPayload.objectState - // Wait for all events to be ACKed before proceeding. - s3ObjectPayload.s3ObjectHandler.Wait() + // Process S3 object (download, parse, create events). + err := objHandler.ProcessS3Object() + if errors.Is(err, errS3DownloadFailed) { + // Download errors are ephemeral. Add a backoff delay, then skip to the + // next iteration so we don't mark the object as permanently failed. + rateLimitWaiter.Wait() + continue + } + // Reset the rate limit delay on results that aren't download errors. + rateLimitWaiter.Reset() - info := s3ObjectPayload.s3ObjectInfo + // Wait for downloaded objects to be ACKed. + objHandler.Wait() if err != nil { - event := s3ObjectPayload.s3ObjectEvent - errs = append(errs, - fmt.Errorf( - fmt.Sprintf("failed processing S3 event for object key %q in bucket %q: %%w", - event.S3.Object.Key, event.S3.Bucket.Name), - err)) - - p.handlePurgingLock(info, false) - continue + p.log.Errorf("failed processing S3 event for object key %q in bucket %q: %v", + state.Key, state.Bucket, err.Error()) + + // Non-retryable error. + state.Failed = true + } else { + state.Stored = true } - p.handlePurgingLock(info, true) + // Persist the result + p.states.AddState(state) // Metrics p.metrics.s3ObjectsAckedTotal.Inc() } - - return multierr.Combine(errs...) } -func (p *s3Poller) GetS3Objects(ctx context.Context, s3ObjectPayloadChan chan<- *s3ObjectPayload) { +func (p *s3Poller) readerLoop(ctx context.Context, s3ObjectPayloadChan chan<- *s3ObjectPayload) { defer close(s3ObjectPayloadChan) bucketName := getBucketNameFromARN(p.bucket) + errorBackoff := backoff.NewEqualJitterBackoff(ctx.Done(), 1*time.Second, 120*time.Second) circuitBreaker := 0 paginator := p.s3.ListObjectsPaginator(bucketName, p.listPrefix) for paginator.HasMorePages() { page, err := paginator.NextPage(ctx) - if err != nil { - if !paginator.HasMorePages() { - break - } + if err != nil { p.log.Warnw("Error when paginating listing.", "error", err) - circuitBreaker++ - if circuitBreaker >= maxCircuitBreaker { - p.log.Warnw(fmt.Sprintf("%d consecutive error when paginating listing, breaking the circuit.", circuitBreaker), "error", err) - break + // QuotaExceededError is client-side rate limiting in the AWS SDK; + // don't include it in the circuit breaker count + if !errors.As(err, &ratelimit.QuotaExceededError{}) { + circuitBreaker++ + if circuitBreaker >= readerLoopMaxCircuitBreaker { + p.log.Warnw(fmt.Sprintf("%d consecutive errors when paginating listing, breaking the circuit.", circuitBreaker), "error", err) + break + } } + // add a backoff delay and try again + errorBackoff.Wait() continue } + // Reset the circuit breaker and the error backoff if a read is successful + circuitBreaker = 0 + errorBackoff.Reset() - listingID, err := uuid.NewV4() - if err != nil { - p.log.Warnw("Error generating UUID for listing page.", "error", err) - continue - } - - // lock for the listing page and state in workersListingMap - // this map is shared with the storedOp and will be unlocked there - lock := new(sync.Mutex) - lock.Lock() - p.workersListingMap.Store(listingID.String(), lock) - - totProcessableObjects := 0 totListedObjects := len(page.Contents) - s3ObjectPayloadChanByPage := make(chan *s3ObjectPayload, totListedObjects) // Metrics p.metrics.s3ObjectsListedTotal.Add(uint64(totListedObjects)) for _, object := range page.Contents { - state := newState(bucketName, *object.Key, *object.ETag, p.listPrefix,
*object.LastModified) - if p.states.MustSkip(state, p.store) { + state := newState(bucketName, *object.Key, *object.ETag, *object.LastModified) + if p.states.IsProcessed(state) { p.log.Debugw("skipping state.", "state", state) continue } - // we have no previous state or the previous state - // is not stored: refresh the state - previousState := p.states.FindPrevious(state) - if previousState.IsEmpty() || !previousState.IsProcessed() { - p.states.Update(state, "") - } - - s3Processor, event := p.createS3ObjectProcessor(ctx, state) + s3Processor := p.createS3ObjectProcessor(ctx, state) if s3Processor == nil { p.log.Debugw("empty s3 processor.", "state", state) continue } - totProcessableObjects++ - - s3ObjectPayloadChanByPage <- &s3ObjectPayload{ + s3ObjectPayloadChan <- &s3ObjectPayload{ s3ObjectHandler: s3Processor, - s3ObjectInfo: s3ObjectInfo{ - name: bucketName, - key: *object.Key, - etag: *object.ETag, - lastModified: *object.LastModified, - listingID: listingID.String(), - }, - s3ObjectEvent: event, - } - } - - if totProcessableObjects == 0 { - p.log.Debugw("0 processable objects on bucket pagination.", "bucket", p.bucket, "listPrefix", p.listPrefix, "listingID", listingID) - // nothing to be ACKed, unlock here - p.states.DeleteListing(listingID.String()) - lock.Unlock() - } else { - listingInfo := &listingInfo{totObjects: totProcessableObjects} - p.states.AddListing(listingID.String(), listingInfo) - - // Metrics - p.metrics.s3ObjectsProcessedTotal.Add(uint64(totProcessableObjects)) - } - - close(s3ObjectPayloadChanByPage) - for s3ObjectPayload := range s3ObjectPayloadChanByPage { - s3ObjectPayloadChan <- s3ObjectPayload - } - } -} - -func (p *s3Poller) Purge(ctx context.Context) { - listingIDs := p.states.GetListingIDs() - p.log.Debugw("purging listing.", "listingIDs", listingIDs) - for _, listingID := range listingIDs { - // we lock here in order to process the purge only after - // full listing page is ACKed by all the workers - lock, loaded := p.workersListingMap.Load(listingID) - if !loaded { - // purge calls can overlap, GetListingIDs can return - // an outdated snapshot with listing already purged - p.states.DeleteListing(listingID) - p.log.Debugw("deleting already purged listing from states.", "listingID", listingID) - continue - } - - lock.(*sync.Mutex).Lock() - - states := map[string]*state{} - latestStoredTimeByBucketAndListPrefix := make(map[string]time.Time, 0) - - listingStates := p.states.GetStatesByListingID(listingID) - for i, state := range listingStates { - // it is not stored, keep - if !state.IsProcessed() { - p.log.Debugw("state not stored or with error, skip purge", "state", state) - continue + objectState: state, } - var latestStoredTime time.Time - states[state.ID] = &listingStates[i] - latestStoredTime, ok := latestStoredTimeByBucketAndListPrefix[state.Bucket+state.ListPrefix] - if !ok { - var commitWriteState commitWriteState - err := p.store.Get(awsS3WriteCommitPrefix+state.Bucket+state.ListPrefix, &commitWriteState) - if err == nil { - // we have no entry in the map, and we have no entry in the store - // set zero time - latestStoredTime = time.Time{} - p.log.Debugw("last stored time is zero time", "bucket", state.Bucket, "listPrefix", state.ListPrefix) - } else { - latestStoredTime = commitWriteState.Time - p.log.Debugw("last stored time is commitWriteState", "commitWriteState", commitWriteState, "bucket", state.Bucket, "listPrefix", state.ListPrefix) - } - } else { - p.log.Debugw("last stored time from memory", "latestStoredTime", latestStoredTime, 
"bucket", state.Bucket, "listPrefix", state.ListPrefix) - } - - if state.LastModified.After(latestStoredTime) { - p.log.Debugw("last stored time updated", "state.LastModified", state.LastModified, "bucket", state.Bucket, "listPrefix", state.ListPrefix) - latestStoredTimeByBucketAndListPrefix[state.Bucket+state.ListPrefix] = state.LastModified - } - } - - for key := range states { - p.states.Delete(key) - } - - if err := p.states.writeStates(p.store); err != nil { - p.log.Errorw("Failed to write states to the registry", "error", err) - } - - for bucketAndListPrefix, latestStoredTime := range latestStoredTimeByBucketAndListPrefix { - if err := p.store.Set(awsS3WriteCommitPrefix+bucketAndListPrefix, commitWriteState{latestStoredTime}); err != nil { - p.log.Errorw("Failed to write commit time to the registry", "error", err) - } - } - - // purge is done, we can unlock and clean - lock.(*sync.Mutex).Unlock() - p.workersListingMap.Delete(listingID) - p.states.DeleteListing(listingID) - - // Listing is removed from all states, we can finalize now - for _, state := range states { - processor, _ := p.createS3ObjectProcessor(ctx, *state) - if err := processor.FinalizeS3Object(); err != nil { - p.log.Errorw("Failed to finalize S3 object", "key", state.Key, "error", err) - } + p.metrics.s3ObjectsProcessedTotal.Inc() } } } func (p *s3Poller) Poll(ctx context.Context) error { - // This loop tries to keep the workers busy as much as possible while - // honoring the number in config opposed to a simpler loop that does one - // listing, sequentially processes every object and then does another listing - workerWg := new(sync.WaitGroup) for ctx.Err() == nil { - // Determine how many S3 workers are available. - workers, err := p.workerSem.AcquireContext(p.numberOfWorkers, ctx) - if err != nil { - break - } - - if workers == 0 { - continue - } + var workerWg sync.WaitGroup + workChan := make(chan *s3ObjectPayload) - s3ObjectPayloadChan := make(chan *s3ObjectPayload) - - workerWg.Add(1) - go func() { - defer func() { - workerWg.Done() - }() - - p.GetS3Objects(ctx, s3ObjectPayloadChan) - p.Purge(ctx) - }() - - workerWg.Add(workers) - for i := 0; i < workers; i++ { + // Start the worker goroutines to listen on the work channel + for i := 0; i < p.numberOfWorkers; i++ { + workerWg.Add(1) go func() { - defer func() { - workerWg.Done() - p.workerSem.Release(1) - }() - if err := p.ProcessObject(s3ObjectPayloadChan); err != nil { - p.log.Warnw("Failed processing S3 listing.", "error", err) - } + defer workerWg.Done() + p.workerLoop(ctx, workChan) }() } - err = timed.Wait(ctx, p.bucketPollInterval) - if err != nil { - if errors.Is(err, context.Canceled) { - // A canceled context is a normal shutdown. - return nil - } + // Start reading data and wait for its processing to be done + p.readerLoop(ctx, workChan) + workerWg.Wait() - return err - } + _ = timed.Wait(ctx, p.bucketPollInterval) } - // Wait for all workers to finish. - workerWg.Wait() - if errors.Is(ctx.Err(), context.Canceled) { // A canceled context is a normal shutdown. return nil diff --git a/x-pack/filebeat/input/awss3/s3_objects.go b/x-pack/filebeat/input/awss3/s3_objects.go index 32911778336..21dfa2243e7 100644 --- a/x-pack/filebeat/input/awss3/s3_objects.go +++ b/x-pack/filebeat/input/awss3/s3_objects.go @@ -43,6 +43,11 @@ type s3ObjectProcessorFactory struct { backupConfig backupConfig } +// errS3DownloadFailed reports problems downloading an S3 object. 
Download errors +// should never be treated as permanent; they are just an indication to apply a +// retry backoff until the connection is healthy again. +var errS3DownloadFailed = errors.New("S3 download failure") + func newS3ObjectProcessorFactory(log *logp.Logger, metrics *inputMetrics, s3 s3API, sel []fileSelectorConfig, backupConfig backupConfig, maxWorkers int) *s3ObjectProcessorFactory { if metrics == nil { // Metrics are optional. Initialize a stub. @@ -135,8 +140,9 @@ func (p *s3ObjectProcessor) ProcessS3Object() error { // Request object (download). contentType, meta, body, err := p.download() if err != nil { - return fmt.Errorf("failed to get s3 object (elapsed_time_ns=%d): %w", - time.Since(start).Nanoseconds(), err) + // Wrap errS3DownloadFailed in the result so the caller knows it's not a + // permanent failure. + return fmt.Errorf("%w: %w", errS3DownloadFailed, err) } defer body.Close() p.s3Metadata = meta @@ -434,10 +440,7 @@ func (p *s3ObjectProcessor) FinalizeS3Object() error { if bucketName == "" { return nil } - backupKey := p.s3Obj.S3.Object.Key - if p.backupConfig.BackupToBucketPrefix != "" { - backupKey = fmt.Sprintf("%s%s", p.backupConfig.BackupToBucketPrefix, backupKey) - } + backupKey := p.backupConfig.BackupToBucketPrefix + p.s3Obj.S3.Object.Key _, err := p.s3.CopyObject(p.ctx, p.s3Obj.S3.Bucket.Name, bucketName, p.s3Obj.S3.Object.Key, backupKey) if err != nil { return fmt.Errorf("failed to copy object to backup bucket: %w", err) diff --git a/x-pack/filebeat/input/awss3/s3_objects_test.go b/x-pack/filebeat/input/awss3/s3_objects_test.go index 6732c12e057..28e8f4f42a5 100644 --- a/x-pack/filebeat/input/awss3/s3_objects_test.go +++ b/x-pack/filebeat/input/awss3/s3_objects_test.go @@ -8,7 +8,8 @@ import ( "bytes" "context" "errors" - "io/ioutil" + "io" + "os" "path/filepath" "strings" "testing" @@ -27,7 +28,7 @@ import ( ) func newS3Object(t testing.TB, filename, contentType string) (s3EventV2, *s3.GetObjectOutput) { - data, err := ioutil.ReadFile(filename) + data, err := os.ReadFile(filename) if err != nil { t.Fatal(err) } @@ -39,7 +40,7 @@ func newS3GetObjectResponse(filename string, data []byte, contentType string) *s r := bytes.NewReader(data) getObjectOutput := s3.GetObjectOutput{} getObjectOutput.ContentLength = int64(r.Len()) - getObjectOutput.Body = ioutil.NopCloser(r) + getObjectOutput.Body = io.NopCloser(r) if contentType != "" { getObjectOutput.ContentType = &contentType } @@ -157,7 +158,7 @@ func TestS3ObjectProcessor(t *testing.T) { ack := awscommon.NewEventACKTracker(ctx) err := s3ObjProc.Create(ctx, logp.NewLogger(inputName), mockPublisher, ack, s3Event).ProcessS3Object() require.Error(t, err) - assert.True(t, errors.Is(err, errFakeConnectivityFailure), "expected errFakeConnectivityFailure error") + assert.True(t, errors.Is(err, errS3DownloadFailed), "expected errS3DownloadFailed") }) t.Run("no error empty result in download", func(t *testing.T) { diff --git a/x-pack/filebeat/input/awss3/s3_test.go b/x-pack/filebeat/input/awss3/s3_test.go index b94ba7cfb09..be1d65b796e 100644 --- a/x-pack/filebeat/input/awss3/s3_test.go +++ b/x-pack/filebeat/input/awss3/s3_test.go @@ -13,7 +13,6 @@ import ( "github.com/aws/aws-sdk-go-v2/service/s3" "github.com/aws/aws-sdk-go-v2/service/s3/types" "github.com/golang/mock/gomock" - "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/elastic/beats/v7/libbeat/statestore" @@ -134,12 +133,16 @@ func TestS3Poller(t *testing.T) { Return(nil, errFakeConnectivityFailure) s3ObjProc :=
newS3ObjectProcessorFactory(logp.NewLogger(inputName), nil, mockAPI, nil, backupConfig{}, numberOfWorkers) - receiver := newS3Poller(logp.NewLogger(inputName), nil, mockAPI, mockPublisher, s3ObjProc, newStates(inputCtx), store, bucket, "key", "region", "provider", numberOfWorkers, pollInterval) + states, err := newStates(inputCtx, store) + require.NoError(t, err, "states creation must succeed") + receiver := newS3Poller(logp.NewLogger(inputName), nil, mockAPI, mockPublisher, s3ObjProc, states, bucket, "key", "region", "provider", numberOfWorkers, pollInterval) require.Error(t, context.DeadlineExceeded, receiver.Poll(ctx)) - assert.Equal(t, numberOfWorkers, receiver.workerSem.Available()) }) - t.Run("retry after Poll error", func(t *testing.T) { + t.Run("restart bucket scan after paging errors", func(t *testing.T) { + // Change the restart limit to 2 consecutive errors, so the test doesn't + // take too long to run + readerLoopMaxCircuitBreaker = 2 storeReg := statestore.NewRegistry(storetest.NewMemoryStoreBackend()) store, err := storeReg.Get("test") if err != nil { @@ -176,13 +179,13 @@ func TestS3Poller(t *testing.T) { // Initial Next gets an error. mockPagerFirst.EXPECT(). HasMorePages(). - Times(10). + Times(2). DoAndReturn(func() bool { return true }) mockPagerFirst.EXPECT(). NextPage(gomock.Any()). - Times(5). + Times(2). DoAndReturn(func(_ context.Context, optFns ...func(*s3.Options)) (*s3.ListObjectsV2Output, error) { return nil, errFakeConnectivityFailure }) @@ -257,8 +260,9 @@ func TestS3Poller(t *testing.T) { Return(nil, errFakeConnectivityFailure) s3ObjProc := newS3ObjectProcessorFactory(logp.NewLogger(inputName), nil, mockAPI, nil, backupConfig{}, numberOfWorkers) - receiver := newS3Poller(logp.NewLogger(inputName), nil, mockAPI, mockPublisher, s3ObjProc, newStates(inputCtx), store, bucket, "key", "region", "provider", numberOfWorkers, pollInterval) + states, err := newStates(inputCtx, store) + require.NoError(t, err, "states creation must succeed") + receiver := newS3Poller(logp.NewLogger(inputName), nil, mockAPI, mockPublisher, s3ObjProc, states, bucket, "key", "region", "provider", numberOfWorkers, pollInterval) require.Error(t, context.DeadlineExceeded, receiver.Poll(ctx)) - assert.Equal(t, numberOfWorkers, receiver.workerSem.Available()) }) } diff --git a/x-pack/filebeat/input/awss3/state.go b/x-pack/filebeat/input/awss3/state.go index 97fb8d538cd..4b7e09f9e7f 100644 --- a/x-pack/filebeat/input/awss3/state.go +++ b/x-pack/filebeat/input/awss3/state.go @@ -5,84 +5,52 @@ package awss3 import ( - "fmt" "time" ) // state is used to communicate the publishing state of a s3 object type state struct { - // ID is used to identify the state in the store, and it is composed by - // Bucket + Key + Etag + LastModified.String(): changing this value or how it is - // composed will break backward compatibilities with entries already in the store. - ID string `json:"id" struct:"id"` Bucket string `json:"bucket" struct:"bucket"` Key string `json:"key" struct:"key"` Etag string `json:"etag" struct:"etag"` LastModified time.Time `json:"last_modified" struct:"last_modified"` - // ListPrefix is used for unique of the key in the store for awsS3WriteCommitPrefix - ListPrefix string `json:"list_prefix" struct:"list_prefix"` - // A state has Stored = true when all events are ACKed. 
Stored bool `json:"stored" struct:"stored"` - // A state has Error = true when ProcessS3Object returned an error - Error bool `json:"error" struct:"error"` + + // Failed is true when ProcessS3Object returned an error other than + // errS3DownloadFailed. + // Before 8.14, this field was called "error". However, that field was + // set for many ephemeral reasons including client-side rate limiting + // (see https://github.com/elastic/beats/issues/39114). Now that we + // don't treat download errors as permanent, the field name was changed + // so that users upgrading from old versions aren't prevented from + // retrying old download failures. + Failed bool `json:"failed" struct:"failed"` } +// ID is used to identify the state in the store, and it is composed of +// Bucket + Key + Etag + LastModified.String(): changing this value or how it is +// composed will break backward compatibility with entries already in the store. func stateID(bucket, key, etag string, lastModified time.Time) string { return bucket + key + etag + lastModified.String() } // newState creates a new s3 object state -func newState(bucket, key, etag, listPrefix string, lastModified time.Time) state { - s := state{ +func newState(bucket, key, etag string, lastModified time.Time) state { + return state{ Bucket: bucket, Key: key, LastModified: lastModified, Etag: etag, - ListPrefix: listPrefix, - Stored: false, - Error: false, } - - s.ID = stateID(s.Bucket, s.Key, s.Etag, s.LastModified) - - return s } -// MarkAsStored set the stored flag to true -func (s *state) MarkAsStored() { - s.Stored = true -} - -// MarkAsError set the error flag to true -func (s *state) MarkAsError() { - s.Error = true -} - -// IsProcessed checks if the state is either Stored or Error -func (s *state) IsProcessed() bool { - return s.Stored || s.Error +func (s *state) ID() string { + return stateID(s.Bucket, s.Key, s.Etag, s.LastModified) } // IsEqual checks if the two states point to the same s3 object.
func (s *state) IsEqual(c *state) bool { return s.Bucket == c.Bucket && s.Key == c.Key && s.Etag == c.Etag && s.LastModified.Equal(c.LastModified) } - -// IsEmpty checks if the state is empty -func (s *state) IsEmpty() bool { - c := state{} - return s.Bucket == c.Bucket && s.Key == c.Key && s.Etag == c.Etag && s.LastModified.Equal(c.LastModified) -} - -// String returns string representation of the struct -func (s *state) String() string { - return fmt.Sprintf( - "{ID: %v, Bucket: %v, Key: %v, Etag: %v, LastModified: %v}", - s.ID, - s.Bucket, - s.Key, - s.Etag, - s.LastModified) -} diff --git a/x-pack/filebeat/input/awss3/state_test.go b/x-pack/filebeat/input/awss3/state_test.go index 24a5e9d81b4..375a44ce79e 100644 --- a/x-pack/filebeat/input/awss3/state_test.go +++ b/x-pack/filebeat/input/awss3/state_test.go @@ -61,7 +61,7 @@ func TestStateIsEqual(t *testing.T) { Key: "/key/to/this/file/1", Etag: "etag", LastModified: lastModifed, - Error: true, + Failed: true, }, { Bucket: "bucket a", diff --git a/x-pack/filebeat/input/awss3/states.go b/x-pack/filebeat/input/awss3/states.go index 449219a867f..edbbcc73793 100644 --- a/x-pack/filebeat/input/awss3/states.go +++ b/x-pack/filebeat/input/awss3/states.go @@ -15,278 +15,64 @@ import ( "github.com/elastic/beats/v7/libbeat/statestore" ) -const ( - awsS3ObjectStatePrefix = "filebeat::aws-s3::state::" - awsS3WriteCommitPrefix = "filebeat::aws-s3::writeCommit::" -) - -type listingInfo struct { - totObjects int - - mu sync.Mutex - storedObjects int - errorObjects int - finalCheck bool -} +const awsS3ObjectStatePrefix = "filebeat::aws-s3::state::" // states handles list of s3 object state. One must use newStates to instantiate a // file states registry. Using the zero-value is not safe. type states struct { - sync.RWMutex - log *logp.Logger - // states store - states []state - - // idx maps state IDs to state indexes for fast lookup and modifications. - idx map[string]int + // Completed S3 object states, indexed by state ID. + // statesLock must be held to access states. + states map[string]state + statesLock sync.Mutex - listingIDs map[string]struct{} - listingInfo *sync.Map - statesByListingID map[string][]state + // The store used to persist state changes to the registry. + // storeLock must be held to access store. + store *statestore.Store + storeLock sync.Mutex } // newStates generates a new states registry. -func newStates(ctx v2.Context) *states { - return &states{ - log: ctx.Logger.Named("states"), - states: nil, - idx: map[string]int{}, - listingInfo: new(sync.Map), - listingIDs: map[string]struct{}{}, - statesByListingID: map[string][]state{}, - } -} - -func (s *states) MustSkip(state state, store *statestore.Store) bool { - if !s.IsNew(state) { - s.log.Debugw("not new state in must skip", "state", state) - return true - } - - previousState := s.FindPrevious(state) - - // status is forgotten. 
if there is no previous state and - // the state.LastModified is before the last cleanStore - // write commit we can remove - var commitWriteState commitWriteState - err := store.Get(awsS3WriteCommitPrefix+state.Bucket+state.ListPrefix, &commitWriteState) - if err == nil && previousState.IsEmpty() && - (state.LastModified.Before(commitWriteState.Time) || state.LastModified.Equal(commitWriteState.Time)) { - s.log.Debugw("state.LastModified older than writeCommitState in must skip", "state", state, "commitWriteState", commitWriteState) - return true - } - - // the previous state is stored or has error: let's skip - if !previousState.IsEmpty() && previousState.IsProcessed() { - s.log.Debugw("previous state is stored or has error", "state", state) - return true - } - - return false -} - -func (s *states) Delete(id string) { - s.Lock() - defer s.Unlock() - - index := s.findPrevious(id) - if index >= 0 { - last := len(s.states) - 1 - s.states[last], s.states[index] = s.states[index], s.states[last] - s.states = s.states[:last] - - s.idx = map[string]int{} - for i, state := range s.states { - s.idx[state.ID] = i - } - } -} - -// IsListingFullyStored check if listing if fully stored -// After first time the condition is met it will always return false -func (s *states) IsListingFullyStored(listingID string) bool { - info, ok := s.listingInfo.Load(listingID) - if !ok { - return false - } - listingInfo, ok := info.(*listingInfo) - if !ok { - return false - } - - listingInfo.mu.Lock() - defer listingInfo.mu.Unlock() - if listingInfo.finalCheck { - return false - } - - listingInfo.finalCheck = (listingInfo.storedObjects + listingInfo.errorObjects) == listingInfo.totObjects - - if (listingInfo.storedObjects + listingInfo.errorObjects) > listingInfo.totObjects { - s.log.Warnf("unexepected mixmatch between storedObjects (%d), errorObjects (%d) and totObjects (%d)", - listingInfo.storedObjects, listingInfo.errorObjects, listingInfo.totObjects) - } - - return listingInfo.finalCheck -} - -// AddListing add listing info -func (s *states) AddListing(listingID string, listingInfo *listingInfo) { - s.Lock() - defer s.Unlock() - s.listingIDs[listingID] = struct{}{} - s.listingInfo.Store(listingID, listingInfo) -} - -// DeleteListing delete listing info -func (s *states) DeleteListing(listingID string) { - s.Lock() - defer s.Unlock() - delete(s.listingIDs, listingID) - delete(s.statesByListingID, listingID) - s.listingInfo.Delete(listingID) -} - -// Update updates a state. 
If previous state didn't exist, new one is created -func (s *states) Update(newState state, listingID string) { - s.Lock() - defer s.Unlock() - - id := newState.ID - index := s.findPrevious(id) - - if index >= 0 { - s.states[index] = newState - } else { - // No existing state found, add new one - s.idx[id] = len(s.states) - s.states = append(s.states, newState) - s.log.Debug("New state added for ", newState.ID) - } - - if listingID == "" || !newState.IsProcessed() { - return - } - - // here we increase the number of stored object - info, ok := s.listingInfo.Load(listingID) - if !ok { - return - } - listingInfo, ok := info.(*listingInfo) - if !ok { - return - } - - listingInfo.mu.Lock() - - if newState.Stored { - listingInfo.storedObjects++ - } - - if newState.Error { - listingInfo.errorObjects++ - } - - listingInfo.mu.Unlock() - - if _, ok := s.statesByListingID[listingID]; !ok { - s.statesByListingID[listingID] = make([]state, 0) +func newStates(ctx v2.Context, store *statestore.Store) (*states, error) { + states := &states{ + log: ctx.Logger.Named("states"), + states: map[string]state{}, + store: store, } - - s.statesByListingID[listingID] = append(s.statesByListingID[listingID], newState) + return states, states.loadFromRegistry() } -// FindPrevious lookups a registered state, that matching the new state. -// Returns a zero-state if no match is found. -func (s *states) FindPrevious(newState state) state { - s.RLock() - defer s.RUnlock() - id := newState.ID - i := s.findPrevious(id) - if i < 0 { - return state{} - } - return s.states[i] +func (s *states) IsProcessed(state state) bool { + s.statesLock.Lock() + defer s.statesLock.Unlock() + // Our in-memory table only stores completed objects + _, ok := s.states[state.ID()] + return ok } -// FindPreviousByID lookups a registered state, that matching the id. -// Returns a zero-state if no match is found. -func (s *states) FindPreviousByID(id string) state { - s.RLock() - defer s.RUnlock() - i := s.findPrevious(id) - if i < 0 { - return state{} - } - return s.states[i] -} - -func (s *states) IsNew(state state) bool { - s.RLock() - defer s.RUnlock() - id := state.ID - i := s.findPrevious(id) - - if i < 0 { - return true - } +func (s *states) AddState(state state) { - return !s.states[i].IsEqual(&state) -} + id := state.ID() + // Update in-memory copy + s.statesLock.Lock() + s.states[id] = state + s.statesLock.Unlock() -// findPrevious returns the previous state for the file. -// In case no previous state exists, index -1 is returned -func (s *states) findPrevious(id string) int { - if i, exists := s.idx[id]; exists { - return i + // Persist to the registry + s.storeLock.Lock() + key := awsS3ObjectStatePrefix + id + if err := s.store.Set(key, state); err != nil { + s.log.Errorw("Failed to write states to the registry", "error", err) } - return -1 -} - -// GetStates creates copy of the file states. 
-func (s *states) GetStates() []state { - s.RLock() - defer s.RUnlock() - - newStates := make([]state, len(s.states)) - copy(newStates, s.states) - - return newStates -} - -// GetListingIDs return a of the listing IDs -func (s *states) GetListingIDs() []string { - s.RLock() - defer s.RUnlock() - listingIDs := make([]string, 0, len(s.listingIDs)) - for listingID := range s.listingIDs { - listingIDs = append(listingIDs, listingID) - } - - return listingIDs -} - -// GetStatesByListingID return a copy of the states by listing ID -func (s *states) GetStatesByListingID(listingID string) []state { - s.RLock() - defer s.RUnlock() - - if _, ok := s.statesByListingID[listingID]; !ok { - return nil - } - - newStates := make([]state, len(s.statesByListingID[listingID])) - copy(newStates, s.statesByListingID[listingID]) - return newStates + s.storeLock.Unlock() } -func (s *states) readStatesFrom(store *statestore.Store) error { - var states []state +func (s *states) loadFromRegistry() error { + states := map[string]state{} - err := store.Each(func(key string, dec statestore.ValueDecoder) (bool, error) { + s.storeLock.Lock() + err := s.store.Each(func(key string, dec statestore.ValueDecoder) (bool, error) { if !strings.HasPrefix(key, awsS3ObjectStatePrefix) { return true, nil } @@ -294,78 +80,30 @@ func (s *states) readStatesFrom(store *statestore.Store) error { // try to decode. Ignore faulty/incompatible values. var st state if err := dec.Decode(&st); err != nil { - // XXX: Do we want to log here? In case we start to store other - // state types in the registry, then this operation will likely fail - // quite often, producing some false-positives in the logs... - return false, err + // Skip this key but continue iteration + s.log.Warnf("invalid S3 state for object key %v, skipping", key) + //nolint:nilerr // One bad object shouldn't stop iteration + return true, nil + } + if !st.Stored && !st.Failed { + // This is from an older version where state could be stored in the + // registry even if the object wasn't processed, or if it encountered + // ephemeral download errors. We don't add these to the in-memory cache, + // so if we see them during a bucket scan we will still retry them. + return true, nil } - st.ID = key[len(awsS3ObjectStatePrefix):] - states = append(states, st) + states[st.ID()] = st return true, nil }) + s.storeLock.Unlock() if err != nil { return err } - states = fixStates(states) - - for _, state := range states { - s.Update(state, "") - } - - return nil -} - -// fixStates cleans up the registry states when updating from an older version -// of filebeat potentially writing invalid entries. -func fixStates(states []state) []state { - if len(states) == 0 { - return states - } - - // we use a map of states here, so to identify and merge duplicate entries. - idx := map[string]*state{} - for i := range states { - state := &states[i] - - old, exists := idx[state.ID] - if !exists { - idx[state.ID] = state - } else { - mergeStates(old, state) // overwrite the entry in 'old' - } - } - - if len(idx) == len(states) { - return states - } - - i := 0 - newStates := make([]state, len(idx)) - for _, state := range idx { - newStates[i] = *state - i++ - } - return newStates -} - -// mergeStates merges 2 states by trying to determine the 'newer' state. -// The st state is overwritten with the updated fields. -func mergeStates(st, other *state) { - // update file meta-data. As these are updated concurrently by the - // inputs, select the newer state based on the update timestamp.
- if st.LastModified.Before(other.LastModified) { - st.LastModified = other.LastModified - } -} + s.statesLock.Lock() + s.states = states + s.statesLock.Unlock() -func (s *states) writeStates(store *statestore.Store) error { - for _, state := range s.GetStates() { - key := awsS3ObjectStatePrefix + state.ID - if err := store.Set(key, state); err != nil { - return err - } - } return nil } diff --git a/x-pack/filebeat/input/awss3/states_test.go b/x-pack/filebeat/input/awss3/states_test.go index 39dc4cf82e6..2f8bbf58fdf 100644 --- a/x-pack/filebeat/input/awss3/states_test.go +++ b/x-pack/filebeat/input/awss3/states_test.go @@ -14,6 +14,7 @@ import ( "github.com/elastic/beats/v7/libbeat/statestore/storetest" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" v2 "github.com/elastic/beats/v7/filebeat/input/v2" "github.com/elastic/elastic-agent-libs/logp" @@ -46,287 +47,92 @@ var inputCtx = v2.Context{ Cancelation: context.Background(), } -func TestStatesIsNewAndMustSkip(t *testing.T) { +func TestStatesAddStateAndIsProcessed(t *testing.T) { type stateTestCase struct { - states func() *states - state state - mustBeNew bool - persistentStoreKV map[string]interface{} - expectedMustSkip bool - expectedIsNew bool + // An initialization callback to invoke on the (initially empty) states. + statesEdit func(states *states) + + // The state to call IsProcessed on and the expected result + state state + expectedIsProcessed bool + + // If true, the test will run statesEdit, then create a new states + // object from the same persistent store before calling IsProcessed + // (to test persistence between restarts). + shouldReload bool } lastModified := time.Date(2022, time.June, 30, 14, 13, 00, 0, time.UTC) + testState1 := newState("bucket", "key", "etag", lastModified) + testState2 := newState("bucket1", "key1", "etag1", lastModified) tests := map[string]stateTestCase{ "with empty states": { - states: func() *states { - return newStates(inputCtx) - }, - state: newState("bucket", "key", "etag", "listPrefix", lastModified), - expectedMustSkip: false, - expectedIsNew: true, + state: testState1, + expectedIsProcessed: false, }, "not existing state": { - states: func() *states { - states := newStates(inputCtx) - states.Update(newState("bucket", "key", "etag", "listPrefix", lastModified), "") - return states + statesEdit: func(states *states) { + states.AddState(testState2) }, - state: newState("bucket1", "key1", "etag1", "listPrefix1", lastModified), - expectedMustSkip: false, - expectedIsNew: true, + state: testState1, + expectedIsProcessed: false, }, "existing state": { - states: func() *states { - states := newStates(inputCtx) - states.Update(newState("bucket", "key", "etag", "listPrefix", lastModified), "") - return states - }, - state: newState("bucket", "key", "etag", "listPrefix", lastModified), - expectedMustSkip: true, - expectedIsNew: false, - }, - "with different etag": { - states: func() *states { - states := newStates(inputCtx) - states.Update(newState("bucket", "key", "etag1", "listPrefix", lastModified), "") - return states - }, - state: newState("bucket", "key", "etag2", "listPrefix", lastModified), - expectedMustSkip: false, - expectedIsNew: true, - }, - "with different lastmodified": { - states: func() *states { - states := newStates(inputCtx) - states.Update(newState("bucket", "key", "etag", "listPrefix", lastModified), "") - return states - }, - state: newState("bucket", "key", "etag", "listPrefix", lastModified.Add(1*time.Second)), - expectedMustSkip: false, - 
expectedIsNew: true, - }, - "with stored state": { - states: func() *states { - states := newStates(inputCtx) - aState := newState("bucket", "key", "etag", "listPrefix", lastModified) - aState.Stored = true - states.Update(aState, "") - return states + statesEdit: func(states *states) { + states.AddState(testState1) }, - state: newState("bucket", "key", "etag", "listPrefix", lastModified), - mustBeNew: true, - expectedMustSkip: true, - expectedIsNew: true, + state: testState1, + expectedIsProcessed: true, }, - "with error state": { - states: func() *states { - states := newStates(inputCtx) - aState := newState("bucket", "key", "etag", "listPrefix", lastModified) - aState.Error = true - states.Update(aState, "") - return states + "existing stored state is persisted": { + statesEdit: func(states *states) { + state := testState1 + state.Stored = true + states.AddState(state) }, - state: newState("bucket", "key", "etag", "listPrefix", lastModified), - mustBeNew: true, - expectedMustSkip: true, - expectedIsNew: true, + state: testState1, + shouldReload: true, + expectedIsProcessed: true, }, - "before commit write": { - states: func() *states { - return newStates(inputCtx) + "existing failed state is persisted": { + statesEdit: func(states *states) { + state := testState1 + state.Failed = true + states.AddState(state) }, - persistentStoreKV: map[string]interface{}{ - awsS3WriteCommitPrefix + "bucket" + "listPrefix": &commitWriteState{lastModified}, - }, - state: newState("bucket", "key", "etag", "listPrefix", lastModified.Add(-1*time.Second)), - expectedMustSkip: true, - expectedIsNew: true, + state: testState1, + shouldReload: true, + expectedIsProcessed: true, }, - "same commit write": { - states: func() *states { - return newStates(inputCtx) - }, - persistentStoreKV: map[string]interface{}{ - awsS3WriteCommitPrefix + "bucket" + "listPrefix": &commitWriteState{lastModified}, + "existing unprocessed state is not persisted": { + statesEdit: func(states *states) { + states.AddState(testState1) }, - state: newState("bucket", "key", "etag", "listPrefix", lastModified), - expectedMustSkip: true, - expectedIsNew: true, - }, - "after commit write": { - states: func() *states { - return newStates(inputCtx) - }, - persistentStoreKV: map[string]interface{}{ - awsS3WriteCommitPrefix + "bucket" + "listPrefix": &commitWriteState{lastModified}, - }, - state: newState("bucket", "key", "etag", "listPrefix", lastModified.Add(time.Second)), - expectedMustSkip: false, - expectedIsNew: true, + state: testState1, + shouldReload: true, + expectedIsProcessed: false, }, } for name, test := range tests { test := test t.Run(name, func(t *testing.T) { - states := test.states() store := openTestStatestore() persistentStore, err := store.Access() if err != nil { t.Fatalf("unexpected err: %v", err) } - for key, value := range test.persistentStoreKV { - _ = persistentStore.Set(key, value) + states, err := newStates(inputCtx, persistentStore) + require.NoError(t, err, "states creation must succeed") + if test.statesEdit != nil { + test.statesEdit(states) } - - if test.mustBeNew { - test.state.LastModified = test.state.LastModified.Add(1 * time.Second) + if test.shouldReload { + states, err = newStates(inputCtx, persistentStore) + require.NoError(t, err, "states creation must succeed") } - isNew := states.IsNew(test.state) - assert.Equal(t, test.expectedIsNew, isNew) - - mustSkip := states.MustSkip(test.state, persistentStore) - assert.Equal(t, test.expectedMustSkip, mustSkip) - }) - } -} - -func TestStatesDelete(t *testing.T) 
{ - type stateTestCase struct { - states func() *states - deleteID string - expected []state - } - - lastModified := time.Date(2021, time.July, 22, 18, 38, 00, 0, time.UTC) - tests := map[string]stateTestCase{ - "delete empty states": { - states: func() *states { - return newStates(inputCtx) - }, - deleteID: "an id", - expected: []state{}, - }, - "delete not existing state": { - states: func() *states { - states := newStates(inputCtx) - states.Update(newState("bucket", "key", "etag", "listPrefix", lastModified), "") - return states - }, - deleteID: "an id", - expected: []state{ - { - ID: stateID("bucket", "key", "etag", lastModified), - Bucket: "bucket", - Key: "key", - Etag: "etag", - ListPrefix: "listPrefix", - LastModified: lastModified, - }, - }, - }, - "delete only one existing": { - states: func() *states { - states := newStates(inputCtx) - states.Update(newState("bucket", "key", "etag", "listPrefix", lastModified), "") - return states - }, - deleteID: stateID("bucket", "key", "etag", lastModified), - expected: []state{}, - }, - "delete first": { - states: func() *states { - states := newStates(inputCtx) - states.Update(newState("bucket", "key1", "etag1", "listPrefix", lastModified), "") - states.Update(newState("bucket", "key2", "etag2", "listPrefix", lastModified), "") - states.Update(newState("bucket", "key3", "etag3", "listPrefix", lastModified), "") - return states - }, - deleteID: "bucketkey1etag1" + lastModified.String(), - expected: []state{ - { - ID: stateID("bucket", "key3", "etag3", lastModified), - Bucket: "bucket", - Key: "key3", - Etag: "etag3", - ListPrefix: "listPrefix", - LastModified: lastModified, - }, - { - ID: stateID("bucket", "key2", "etag2", lastModified), - Bucket: "bucket", - Key: "key2", - Etag: "etag2", - ListPrefix: "listPrefix", - LastModified: lastModified, - }, - }, - }, - "delete last": { - states: func() *states { - states := newStates(inputCtx) - states.Update(newState("bucket", "key1", "etag1", "listPrefix", lastModified), "") - states.Update(newState("bucket", "key2", "etag2", "listPrefix", lastModified), "") - states.Update(newState("bucket", "key3", "etag3", "listPrefix", lastModified), "") - return states - }, - deleteID: "bucketkey3etag3" + lastModified.String(), - expected: []state{ - { - ID: stateID("bucket", "key1", "etag1", lastModified), - Bucket: "bucket", - Key: "key1", - Etag: "etag1", - ListPrefix: "listPrefix", - LastModified: lastModified, - }, - { - ID: stateID("bucket", "key2", "etag2", lastModified), - Bucket: "bucket", - Key: "key2", - Etag: "etag2", - ListPrefix: "listPrefix", - LastModified: lastModified, - }, - }, - }, - "delete any": { - states: func() *states { - states := newStates(inputCtx) - states.Update(newState("bucket", "key1", "etag1", "listPrefix", lastModified), "") - states.Update(newState("bucket", "key2", "etag2", "listPrefix", lastModified), "") - states.Update(newState("bucket", "key3", "etag3", "listPrefix", lastModified), "") - return states - }, - deleteID: "bucketkey2etag2" + lastModified.String(), - expected: []state{ - { - ID: stateID("bucket", "key1", "etag1", lastModified), - Bucket: "bucket", - Key: "key1", - Etag: "etag1", - ListPrefix: "listPrefix", - LastModified: lastModified, - }, - { - ID: stateID("bucket", "key3", "etag3", lastModified), - Bucket: "bucket", - Key: "key3", - Etag: "etag3", - ListPrefix: "listPrefix", - LastModified: lastModified, - }, - }, - }, - } - - for name, test := range tests { - test := test - t.Run(name, func(t *testing.T) { - states := test.states() - 
states.Delete(test.deleteID) - assert.Equal(t, test.expected, states.GetStates()) + isProcessed := states.IsProcessed(test.state) + assert.Equal(t, test.expectedIsProcessed, isProcessed) }) } }
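
The notes below are self-contained Go sketches of the patterns this patch introduces. They are illustrative approximations for reference, not part of the patch itself.

First, the retry policy wired into createS3Lister: a minimal sketch of configuring an S3 client with the SDK's standard retryer, assuming only the aws-sdk-go-v2 default config chain (the ListBuckets call is just a placeholder request):

```go
package main

import (
	"context"
	"log"

	"github.com/aws/aws-sdk-go-v2/aws/retry"
	"github.com/aws/aws-sdk-go-v2/config"
	"github.com/aws/aws-sdk-go-v2/service/s3"
)

func main() {
	cfg, err := config.LoadDefaultConfig(context.Background())
	if err != nil {
		log.Fatal(err)
	}

	// Same knobs as the patch: cap each operation at 5 attempts, and lower
	// the no-retry token increment so the client-side token bucket refills
	// quickly once requests start succeeding again.
	client := s3.NewFromConfig(cfg, func(o *s3.Options) {
		o.Retryer = retry.NewStandard(func(so *retry.StandardOptions) {
			so.MaxAttempts = 5
			so.NoRetryIncrement = 100
		})
	})

	if _, err := client.ListBuckets(context.Background(), &s3.ListBucketsInput{}); err != nil {
		log.Fatal(err)
	}
}
```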
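workerLoop and readerLoop both lean on the equal-jitter backoff helper: Wait after a transient failure, Reset once an attempt succeeds. A usage sketch assuming the Backoff interface (Wait/Reset) from libbeat/common/backoff as imported by the patch; the failing operation is simulated:

```go
package main

import (
	"context"
	"fmt"
	"time"

	"github.com/elastic/beats/v7/libbeat/common/backoff"
)

// doWork simulates an operation that fails on its first two attempts.
func doWork(attempt int) error {
	if attempt < 2 {
		return fmt.Errorf("attempt %d failed", attempt)
	}
	return nil
}

func main() {
	ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
	defer cancel()

	// The init and max arguments are time.Duration values: start at 1s, cap at 120s.
	waiter := backoff.NewEqualJitterBackoff(ctx.Done(), 1*time.Second, 120*time.Second)
	for attempt := 0; attempt < 5; attempt++ {
		if err := doWork(attempt); err != nil {
			fmt.Println("transient error, backing off:", err)
			if !waiter.Wait() { // false means ctx.Done() fired during the wait
				return
			}
			continue
		}
		waiter.Reset() // healthy again: the next failure starts from the small delay
		fmt.Println("attempt", attempt, "succeeded")
	}
}
```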
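The reader loop's circuit breaker deliberately ignores ratelimit.QuotaExceededError: the SDK raises it client-side, before any request is sent, when the standard retryer's token bucket is empty, so it says nothing about the health of the S3 endpoint. A small sketch of that classification (countsAgainstCircuitBreaker is a hypothetical helper name, not in the patch):

```go
package main

import (
	"errors"
	"fmt"

	"github.com/aws/aws-sdk-go-v2/aws/ratelimit"
)

// countsAgainstCircuitBreaker mirrors the check in readerLoop: client-side
// throttling never reached the network, so it must not trip the breaker.
func countsAgainstCircuitBreaker(err error) bool {
	return !errors.As(err, &ratelimit.QuotaExceededError{})
}

func main() {
	fmt.Println(countsAgainstCircuitBreaker(ratelimit.QuotaExceededError{}))  // false
	fmt.Println(countsAgainstCircuitBreaker(errors.New("connection refused"))) // true
}
```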
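The rewritten Poll is a plain reader/worker fan-out: readerLoop publishes to an unbuffered channel, a fixed pool of workers drains it, and a WaitGroup gates the next poll cycle. The same shape in miniature, with ints standing in for *s3ObjectPayload:

```go
package main

import (
	"fmt"
	"sync"
)

func main() {
	const numberOfWorkers = 3
	var workerWg sync.WaitGroup
	workChan := make(chan int)

	// Start the workers first so the reader never blocks indefinitely.
	for i := 0; i < numberOfWorkers; i++ {
		workerWg.Add(1)
		go func(id int) {
			defer workerWg.Done()
			for item := range workChan {
				fmt.Printf("worker %d processed item %d\n", id, item)
			}
		}(i)
	}

	// The "reader loop": publish work, then close the channel so the
	// workers' range loops terminate.
	for item := 0; item < 10; item++ {
		workChan <- item
	}
	close(workChan)

	// Equivalent of Poll waiting before the next bucketPollInterval tick.
	workerWg.Wait()
}
```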
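errS3DownloadFailed relies on Go 1.20+ support for multiple %w verbs in fmt.Errorf: errors.Is still matches the sentinel through the extra wrapping while the root cause is preserved. A self-contained sketch with a simulated download failure:

```go
package main

import (
	"errors"
	"fmt"
)

var errS3DownloadFailed = errors.New("S3 download failure")

// download stands in for a failed GetObject call.
func download() error {
	return errors.New("connection reset by peer")
}

func process() error {
	if err := download(); err != nil {
		// Same double-wrap as ProcessS3Object: callers can test for the
		// sentinel without losing the underlying cause.
		return fmt.Errorf("%w: %w", errS3DownloadFailed, err)
	}
	return nil
}

func main() {
	err := process()
	fmt.Println(errors.Is(err, errS3DownloadFailed)) // true
	fmt.Println(err)                                 // S3 download failure: connection reset by peer
}
```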
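After this patch, states reduces to a mutex-guarded map of completed objects keyed by the composite ID, with every update written through to the registry under awsS3ObjectStatePrefix. A reduced sketch that swaps the statestore for a plain map (the store field here is a stand-in, not libbeat's *statestore.Store):

```go
package main

import (
	"fmt"
	"sync"
	"time"
)

const statePrefix = "filebeat::aws-s3::state::"

type state struct {
	Bucket, Key, Etag string
	LastModified      time.Time
	Stored, Failed    bool
}

// ID mirrors stateID in the patch: Bucket + Key + Etag + LastModified.String().
func (s state) ID() string { return s.Bucket + s.Key + s.Etag + s.LastModified.String() }

type states struct {
	mu    sync.Mutex
	table map[string]state // in-memory index of completed objects
	store map[string]state // stand-in for the persistent registry
}

// AddState updates the in-memory copy, then writes through to the "registry".
func (s *states) AddState(st state) {
	s.mu.Lock()
	defer s.mu.Unlock()
	id := st.ID()
	s.table[id] = st
	s.store[statePrefix+id] = st
}

// IsProcessed reports whether an object with the same ID already completed.
func (s *states) IsProcessed(st state) bool {
	s.mu.Lock()
	defer s.mu.Unlock()
	_, ok := s.table[st.ID()]
	return ok
}

func main() {
	s := &states{table: map[string]state{}, store: map[string]state{}}
	obj := state{Bucket: "bucket", Key: "key", Etag: "etag", LastModified: time.Now(), Stored: true}
	fmt.Println(s.IsProcessed(obj)) // false
	s.AddState(obj)
	fmt.Println(s.IsProcessed(obj)) // true
}
```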
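loadFromRegistry's iteration contract is worth calling out: statestore's Each keeps iterating while the callback returns true, which is why decode failures log a warning and return (true, nil) instead of aborting the scan. A sketch of that pattern against the same in-memory test backend the tests use (the entry layout here is a simplified assumption):

```go
package main

import (
	"fmt"
	"strings"

	"github.com/elastic/beats/v7/libbeat/statestore"
	"github.com/elastic/beats/v7/libbeat/statestore/storetest"
)

const awsS3ObjectStatePrefix = "filebeat::aws-s3::state::"

type entry struct {
	Stored bool `json:"stored" struct:"stored"`
	Failed bool `json:"failed" struct:"failed"`
}

func main() {
	reg := statestore.NewRegistry(storetest.NewMemoryStoreBackend())
	store, err := reg.Get("filebeat")
	if err != nil {
		panic(err)
	}
	defer store.Close()

	_ = store.Set(awsS3ObjectStatePrefix+"bucketkey1etag1", entry{Stored: true})
	_ = store.Set("some::other::key", entry{})

	err = store.Each(func(key string, dec statestore.ValueDecoder) (bool, error) {
		if !strings.HasPrefix(key, awsS3ObjectStatePrefix) {
			return true, nil // not ours: keep iterating
		}
		var e entry
		if err := dec.Decode(&e); err != nil {
			return true, nil // bad entry: skip it rather than abort the scan
		}
		fmt.Println(key, "stored:", e.Stored, "failed:", e.Failed)
		return true, nil
	})
	if err != nil {
		panic(err)
	}
}
```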
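Finally, the shouldReload flag in the rewritten tests works because both states instances share one persistent store, so the second newStates call replays loadFromRegistry. A condensed version of that pattern as it could appear inside states_test.go (it reuses that file's inputCtx, imports, and helpers, so it is a sketch rather than a drop-in test):

```go
func TestStateSurvivesReload(t *testing.T) {
	storeReg := statestore.NewRegistry(storetest.NewMemoryStoreBackend())
	store, err := storeReg.Get("test")
	require.NoError(t, err)

	states, err := newStates(inputCtx, store)
	require.NoError(t, err, "states creation must succeed")

	st := newState("bucket", "key", "etag", time.Unix(0, 0).UTC())
	st.Stored = true
	states.AddState(st)

	// Simulate a restart: a fresh states object reads back from the store.
	reloaded, err := newStates(inputCtx, store)
	require.NoError(t, err, "states creation must succeed")
	assert.True(t, reloaded.IsProcessed(st))
}
```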