From e4485e1e4e5da0feb356abb0b9adc553c7722b1d Mon Sep 17 00:00:00 2001 From: James Kwon Date: Wed, 12 Jul 2023 18:12:28 -0400 Subject: [PATCH] Mocking AWS Calls for Unit Testing for API Gateway --- aws/apigateway.go | 60 +++++++------- aws/apigateway_test.go | 168 +++++++++++++++------------------------- aws/apigateway_types.go | 19 +++-- aws/aws.go | 8 +- aws/test_utils.go | 19 +++++ config/config.go | 1 + 6 files changed, 133 insertions(+), 142 deletions(-) create mode 100644 aws/test_utils.go diff --git a/aws/apigateway.go b/aws/apigateway.go index 10be3af0..d29b6667 100644 --- a/aws/apigateway.go +++ b/aws/apigateway.go @@ -1,39 +1,40 @@ package aws import ( - "github.com/gruntwork-io/cloud-nuke/telemetry" - commonTelemetry "github.com/gruntwork-io/go-commons/telemetry" "sync" "time" + commonTelemetry "github.com/gruntwork-io/go-commons/telemetry" + "github.com/aws/aws-sdk-go/aws" - "github.com/aws/aws-sdk-go/aws/session" "github.com/aws/aws-sdk-go/service/apigateway" "github.com/gruntwork-io/cloud-nuke/config" "github.com/gruntwork-io/cloud-nuke/logging" "github.com/gruntwork-io/cloud-nuke/report" + "github.com/gruntwork-io/cloud-nuke/telemetry" "github.com/gruntwork-io/go-commons/errors" "github.com/hashicorp/go-multierror" ) -func getAllAPIGateways(session *session.Session, excludeAfter time.Time, configObj config.Config) ([]*string, error) { - svc := apigateway.New(session) - - result, err := svc.GetRestApis(&apigateway.GetRestApisInput{}) +func (gateway ApiGateway) getAll( + excludeAfter time.Time, configObj config.Config) ([]*string, error) { + result, err := gateway.Client.GetRestApis(&apigateway.GetRestApisInput{}) if err != nil { return []*string{}, errors.WithStackTrace(err) } - Ids := []*string{} - for _, apigateway := range result.Items { - if shouldIncludeAPIGateway(apigateway, excludeAfter, configObj) { - Ids = append(Ids, apigateway.Id) + + var IDs []*string + for _, api := range result.Items { + if gateway.shouldInclude(api, excludeAfter, configObj) { + IDs = append(IDs, api.Id) } } - return Ids, nil + return IDs, nil } -func shouldIncludeAPIGateway(apigw *apigateway.RestApi, excludeAfter time.Time, configObj config.Config) bool { +func (gateway ApiGateway) shouldInclude( + apigw *apigateway.RestApi, excludeAfter time.Time, configObj config.Config) bool { if apigw == nil { return false } @@ -51,28 +52,25 @@ func shouldIncludeAPIGateway(apigw *apigateway.RestApi, excludeAfter time.Time, ) } -func nukeAllAPIGateways(session *session.Session, identifiers []*string) error { - region := aws.StringValue(session.Config.Region) - - svc := apigateway.New(session) - +func (gateway ApiGateway) nukeAll(identifiers []*string) error { if len(identifiers) == 0 { - logging.Logger.Debugf("No API Gateways (v1) to nuke in region %s", region) + logging.Logger.Debugf("No API Gateways (v1) to nuke in region %s", gateway.Region) } if len(identifiers) > 100 { - logging.Logger.Errorf("Nuking too many API Gateways (v1) at once (100): halting to avoid hitting AWS API rate limiting") + logging.Logger.Errorf("Nuking too many API Gateways (v1) at once (100): " + + "halting to avoid hitting AWS API rate limiting") return TooManyApiGatewayErr{} } // There is no bulk delete Api Gateway API, so we delete the batch of gateways concurrently using goroutines - logging.Logger.Debugf("Deleting Api Gateways (v1) in region %s", region) + logging.Logger.Debugf("Deleting Api Gateways (v1) in region %s", gateway.Region) wg := new(sync.WaitGroup) wg.Add(len(identifiers)) errChans := make([]chan error, len(identifiers)) for i, apigwID := range identifiers { errChans[i] = make(chan error, 1) - go deleteApiGatewayAsync(wg, errChans[i], svc, apigwID, region) + go gateway.nukeAsync(wg, errChans[i], apigwID) } wg.Wait() @@ -81,10 +79,11 @@ func nukeAllAPIGateways(session *session.Session, identifiers []*string) error { if err := <-errChan; err != nil { allErrs = multierror.Append(allErrs, err) logging.Logger.Debugf("[Failed] %s", err) + telemetry.TrackEvent(commonTelemetry.EventContext{ EventName: "Error Nuking API Gateway", }, map[string]interface{}{ - "region": *session.Config.Region, + "region": gateway.Region, }) } } @@ -92,14 +91,16 @@ func nukeAllAPIGateways(session *session.Session, identifiers []*string) error { if finalErr != nil { return errors.WithStackTrace(finalErr) } + return nil } -func deleteApiGatewayAsync(wg *sync.WaitGroup, errChan chan error, svc *apigateway.APIGateway, apigwID *string, region string) { +func (gateway ApiGateway) nukeAsync( + wg *sync.WaitGroup, errChan chan error, apigwID *string) { defer wg.Done() input := &apigateway.DeleteRestApiInput{RestApiId: apigwID} - _, err := svc.DeleteRestApi(input) + _, err := gateway.Client.DeleteRestApi(input) errChan <- err // Record status of this resource @@ -111,8 +112,11 @@ func deleteApiGatewayAsync(wg *sync.WaitGroup, errChan chan error, svc *apigatew report.Record(e) if err == nil { - logging.Logger.Debugf("[OK] API Gateway (v1) %s deleted in %s", aws.StringValue(apigwID), region) - } else { - logging.Logger.Debugf("[Failed] Error deleting API Gateway (v1) %s in %s", aws.StringValue(apigwID), region) + logging.Logger.Debugf("["+ + "OK] API Gateway (v1) %s deleted in %s", aws.StringValue(apigwID), gateway.Region) + return } + + logging.Logger.Debugf( + "[Failed] Error deleting API Gateway (v1) %s in %s", aws.StringValue(apigwID), gateway.Region) } diff --git a/aws/apigateway_test.go b/aws/apigateway_test.go index 85d2e8a4..4776be41 100644 --- a/aws/apigateway_test.go +++ b/aws/apigateway_test.go @@ -1,153 +1,111 @@ package aws import ( - "github.com/gruntwork-io/cloud-nuke/telemetry" "testing" "time" "github.com/aws/aws-sdk-go/aws" awsgo "github.com/aws/aws-sdk-go/aws" - "github.com/aws/aws-sdk-go/aws/session" "github.com/aws/aws-sdk-go/service/apigateway" + "github.com/aws/aws-sdk-go/service/apigateway/apigatewayiface" "github.com/gruntwork-io/cloud-nuke/config" + "github.com/gruntwork-io/cloud-nuke/telemetry" "github.com/gruntwork-io/cloud-nuke/util" - "github.com/gruntwork-io/go-commons/errors" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) -type TestAPIGateway struct { - ID *string - Name *string +type mockedApiGateway struct { + apigatewayiface.APIGatewayAPI + GetRestApisResp apigateway.GetRestApisOutput + DeleteRestApiResp apigateway.DeleteRestApiOutput } -func createTestAPIGateway(t *testing.T, session *session.Session, name string) (*TestAPIGateway, error) { - svc := apigateway.New(session) - - testGw := &TestAPIGateway{ - Name: aws.String(name), - } - - param := &apigateway.CreateRestApiInput{ - Name: aws.String(name), - } - - output, err := svc.CreateRestApi(param) - if err != nil { - assert.Failf(t, "Could not create test API Gateway: %s", errors.WithStackTrace(err).Error()) - } - - testGw.ID = output.Id - - return testGw, nil +func (m mockedApiGateway) GetRestApis(*apigateway.GetRestApisInput) (*apigateway.GetRestApisOutput, error) { + // Only need to return mocked response output + return &m.GetRestApisResp, nil } -func TestListAPIGateways(t *testing.T) { - telemetry.InitTelemetry("cloud-nuke", "") - t.Parallel() - - region, err := getRandomRegion() - require.NoError(t, err) - session, err := session.NewSession(&awsgo.Config{ - Region: awsgo.String(region), - }, - ) - if err != nil { - assert.Fail(t, errors.WithStackTrace(err).Error()) - } - - apigwName := "aws-nuke-test-" + util.UniqueID() - testGw, createTestGwErr := createTestAPIGateway(t, session, apigwName) - require.NoError(t, createTestGwErr) - // clean up after this test - defer nukeAllAPIGateways(session, []*string{testGw.ID}) - - apigwIds, err := getAllAPIGateways(session, time.Now(), config.Config{}) - if err != nil { - assert.Fail(t, "Unable to fetch list of API Gateways (v1)") - } - - assert.Contains(t, awsgo.StringValueSlice(apigwIds), aws.StringValue(testGw.ID)) +func (m mockedApiGateway) DeleteRestApi(*apigateway.DeleteRestApiInput) (*apigateway.DeleteRestApiOutput, error) { + // Only need to return mocked response output + return &m.DeleteRestApiResp, nil } -func TestTimeFilterExclusionNewlyCreatedAPIGateway(t *testing.T) { +func TestAPIGatewayGetAllAndNukeAll(t *testing.T) { telemetry.InitTelemetry("cloud-nuke", "") t.Parallel() - region, err := getRandomRegion() - require.NoError(t, err) - - session, err := session.NewSession(&aws.Config{Region: aws.String(region)}) - require.NoError(t, err) - - apigwName := "aws-nuke-test-" + util.UniqueID() - - testGw, createTestGwErr := createTestAPIGateway(t, session, apigwName) - require.NoError(t, createTestGwErr) - defer nukeAllAPIGateways(session, []*string{testGw.ID}) + testApiID := "aws-nuke-test-" + util.UniqueID() + apiGateway := ApiGateway{ + Client: mockedApiGateway{ + GetRestApisResp: apigateway.GetRestApisOutput{ + Items: []*apigateway.RestApi{ + {Id: aws.String(testApiID)}, + }, + }, + DeleteRestApiResp: apigateway.DeleteRestApiOutput{}, + }, + } - // Assert API Gateway is picked up without filters - apigwIds, err := getAllAPIGateways(session, time.Now(), config.Config{}) + apis, err := apiGateway.getAll(time.Now(), config.Config{}) require.NoError(t, err) - assert.Contains(t, aws.StringValueSlice(apigwIds), aws.StringValue(testGw.ID)) + require.Contains(t, awsgo.StringValueSlice(apis), testApiID) - // Assert API Gateway doesn't appear when we look at API Gateways older than 1 Hour - olderThan := time.Now().Add(-1 * time.Hour) - apiGwIdsOlder, err := getAllAPIGateways(session, olderThan, config.Config{}) + err = apiGateway.nukeAll([]*string{aws.String(testApiID)}) require.NoError(t, err) - assert.NotContains(t, aws.StringValueSlice(apiGwIdsOlder), aws.StringValue(testGw.ID)) } -func TestNukeAPIGatewayOne(t *testing.T) { +func TestAPIGatewayGetAllTimeFilter(t *testing.T) { telemetry.InitTelemetry("cloud-nuke", "") t.Parallel() - region, err := getRandomRegion() - require.NoError(t, err) + testApiID := "aws-nuke-test-" + util.UniqueID() + now := time.Now() + apiGateway := ApiGateway{ + Client: mockedApiGateway{ + GetRestApisResp: apigateway.GetRestApisOutput{ + Items: []*apigateway.RestApi{{ + Id: aws.String(testApiID), + CreatedDate: aws.Time(now), + }}, + }, + }, + } - session, err := session.NewSession(&aws.Config{Region: aws.String(region)}) + // test API is not excluded from the filter + IDs, err := apiGateway.getAll(now.Add(1), config.Config{}) require.NoError(t, err) + assert.Contains(t, aws.StringValueSlice(IDs), testApiID) - apigwName := "aws-nuke-test-" + util.UniqueID() - // We ignore errors in the delete call here, because it is intended to be a stop gap in case there is a bug in nuke. - testGw, createTestErr := createTestAPIGateway(t, session, apigwName) - require.NoError(t, createTestErr) - - nukeErr := nukeAllAPIGateways(session, []*string{testGw.ID}) - require.NoError(t, nukeErr) - - // Make sure the API Gateway was deleted - apigwIds, err := getAllAPIGateways(session, time.Now(), config.Config{}) + // test API being excluded from the filter + apiGwIdsOlder, err := apiGateway.getAll(now.Add(-1), config.Config{}) require.NoError(t, err) - - assert.NotContains(t, aws.StringValueSlice(apigwIds), aws.StringValue(testGw.ID)) + assert.NotContains(t, aws.StringValueSlice(apiGwIdsOlder), testApiID) } func TestNukeAPIGatewayMoreThanOne(t *testing.T) { telemetry.InitTelemetry("cloud-nuke", "") t.Parallel() - region, err := getRandomRegion() - require.NoError(t, err) + testApiID1 := "aws-nuke-test-" + util.UniqueID() + testApiID2 := "aws-nuke-test-" + util.UniqueID() + apiGateway := ApiGateway{ + Client: mockedApiGateway{ + GetRestApisResp: apigateway.GetRestApisOutput{ + Items: []*apigateway.RestApi{ + {Id: aws.String(testApiID1)}, + {Id: aws.String(testApiID2)}, + }, + }, + DeleteRestApiResp: apigateway.DeleteRestApiOutput{}, + }, + } - session, err := session.NewSession(&aws.Config{Region: aws.String(region)}) + apis, err := apiGateway.getAll(time.Now(), config.Config{}) require.NoError(t, err) + require.Contains(t, awsgo.StringValueSlice(apis), testApiID1) + require.Contains(t, awsgo.StringValueSlice(apis), testApiID2) - apigwName := "aws-nuke-test-" + util.UniqueID() - apigwName2 := "aws-nuke-test-" + util.UniqueID() - // We ignore errors in the delete call here, because it is intended to be a stop gap in case there is a bug in nuke. - testGw, createTestErr := createTestAPIGateway(t, session, apigwName) - require.NoError(t, createTestErr) - testGw2, createTestErr2 := createTestAPIGateway(t, session, apigwName2) - require.NoError(t, createTestErr2) - - nukeErr := nukeAllAPIGateways(session, []*string{testGw.ID, testGw2.ID}) - require.NoError(t, nukeErr) - - // Make sure the API Gateway was deleted - apigwIds, err := getAllAPIGateways(session, time.Now(), config.Config{}) + err = apiGateway.nukeAll([]*string{aws.String(testApiID1), aws.String(testApiID2)}) require.NoError(t, err) - - assert.NotContains(t, aws.StringValueSlice(apigwIds), aws.StringValue(testGw.ID)) - assert.NotContains(t, aws.StringValueSlice(apigwIds), aws.StringValue(testGw2.ID)) } diff --git a/aws/apigateway_types.go b/aws/apigateway_types.go index 2e3acd38..7d41c3b3 100644 --- a/aws/apigateway_types.go +++ b/aws/apigateway_types.go @@ -3,29 +3,34 @@ package aws import ( awsgo "github.com/aws/aws-sdk-go/aws" "github.com/aws/aws-sdk-go/aws/session" + "github.com/aws/aws-sdk-go/service/apigateway/apigatewayiface" "github.com/gruntwork-io/go-commons/errors" ) type ApiGateway struct { - Ids []string + Client apigatewayiface.APIGatewayAPI + Region string + Ids []string } -func (apigateway ApiGateway) ResourceName() string { +func (gateway ApiGateway) ResourceName() string { return "apigateway" } -func (apigateway ApiGateway) ResourceIdentifiers() []string { - return apigateway.Ids +func (gateway ApiGateway) ResourceIdentifiers() []string { + return gateway.Ids } -func (apigateway ApiGateway) MaxBatchSize() int { +func (gateway ApiGateway) MaxBatchSize() int { return 10 } -func (apigateway ApiGateway) Nuke(session *session.Session, identifiers []string) error { - if err := nukeAllAPIGateways(session, awsgo.StringSlice(identifiers)); err != nil { +func (gateway ApiGateway) Nuke(session *session.Session, identifiers []string) error { + // TODO(james): stop passing in session argument as it is included as part of the gateway struct. + if err := gateway.nukeAll(awsgo.StringSlice(identifiers)); err != nil { return errors.WithStackTrace(err) } + return nil } diff --git a/aws/aws.go b/aws/aws.go index 1d16fa77..820bc7e0 100644 --- a/aws/aws.go +++ b/aws/aws.go @@ -2,6 +2,7 @@ package aws import ( "fmt" + "github.com/aws/aws-sdk-go/service/apigateway" "math/rand" "sort" "strings" @@ -1463,10 +1464,13 @@ func GetAllResources(targetRegions []string, excludeAfter time.Time, resourceTyp // End Redshift Clusters // API Gateways (v1) - apiGateways := ApiGateway{} + apiGateways := ApiGateway{ + Client: apigateway.New(cloudNukeSession), + Region: region, + } if IsNukeable(apiGateways.ResourceName(), resourceTypes) { start := time.Now() - gatewayIds, err := getAllAPIGateways(cloudNukeSession, excludeAfter, configObj) + gatewayIds, err := apiGateways.getAll(excludeAfter, configObj) if err != nil { ge := report.GeneralError{ Error: err, diff --git a/aws/test_utils.go b/aws/test_utils.go new file mode 100644 index 00000000..a6638797 --- /dev/null +++ b/aws/test_utils.go @@ -0,0 +1,19 @@ +package aws + +import ( + "testing" + + terratestaws "github.com/gruntwork-io/terratest/modules/aws" + + "github.com/aws/aws-sdk-go/aws" + "github.com/aws/aws-sdk-go/aws/session" + "github.com/stretchr/testify/require" +) + +func GetTestSession(t *testing.T, approvedRegions []string, forbiddenRegions []string) *session.Session { + region := terratestaws.GetRandomStableRegion(t, approvedRegions, forbiddenRegions) + session, err := session.NewSession(&aws.Config{Region: aws.String(region)}) + require.NoError(t, err) + + return session +} diff --git a/config/config.go b/config/config.go index a02ccfeb..417558fe 100644 --- a/config/config.go +++ b/config/config.go @@ -68,6 +68,7 @@ type ResourceType struct { type FilterRule struct { NamesRegExp []Expression `yaml:"names_regex"` + // TODO(james): consider adding time filter rule here instead of passing in as function argument in other place. } type Expression struct {