Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Mocking AWS Calls for Unit Testing for API Gateway #497

Merged
merged 1 commit into from Jul 13, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
60 changes: 32 additions & 28 deletions aws/apigateway.go
@@ -1,39 +1,40 @@
package aws

import (
"github.com/gruntwork-io/cloud-nuke/telemetry"
commonTelemetry "github.com/gruntwork-io/go-commons/telemetry"
"sync"
"time"

commonTelemetry "github.com/gruntwork-io/go-commons/telemetry"

"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/session"
"github.com/aws/aws-sdk-go/service/apigateway"
"github.com/gruntwork-io/cloud-nuke/config"
"github.com/gruntwork-io/cloud-nuke/logging"
"github.com/gruntwork-io/cloud-nuke/report"
"github.com/gruntwork-io/cloud-nuke/telemetry"
"github.com/gruntwork-io/go-commons/errors"
"github.com/hashicorp/go-multierror"
)

func getAllAPIGateways(session *session.Session, excludeAfter time.Time, configObj config.Config) ([]*string, error) {
svc := apigateway.New(session)

result, err := svc.GetRestApis(&apigateway.GetRestApisInput{})
func (gateway ApiGateway) getAll(
excludeAfter time.Time, configObj config.Config) ([]*string, error) {
result, err := gateway.Client.GetRestApis(&apigateway.GetRestApisInput{})
if err != nil {
return []*string{}, errors.WithStackTrace(err)
}
Ids := []*string{}
for _, apigateway := range result.Items {
if shouldIncludeAPIGateway(apigateway, excludeAfter, configObj) {
Ids = append(Ids, apigateway.Id)

var IDs []*string
for _, api := range result.Items {
if gateway.shouldInclude(api, excludeAfter, configObj) {
IDs = append(IDs, api.Id)
}
}

return Ids, nil
return IDs, nil
}

func shouldIncludeAPIGateway(apigw *apigateway.RestApi, excludeAfter time.Time, configObj config.Config) bool {
func (gateway ApiGateway) shouldInclude(
apigw *apigateway.RestApi, excludeAfter time.Time, configObj config.Config) bool {
if apigw == nil {
return false
}
Expand All @@ -51,28 +52,25 @@ func shouldIncludeAPIGateway(apigw *apigateway.RestApi, excludeAfter time.Time,
)
}

func nukeAllAPIGateways(session *session.Session, identifiers []*string) error {
region := aws.StringValue(session.Config.Region)

svc := apigateway.New(session)

func (gateway ApiGateway) nukeAll(identifiers []*string) error {
if len(identifiers) == 0 {
logging.Logger.Debugf("No API Gateways (v1) to nuke in region %s", region)
logging.Logger.Debugf("No API Gateways (v1) to nuke in region %s", gateway.Region)
}

if len(identifiers) > 100 {
logging.Logger.Errorf("Nuking too many API Gateways (v1) at once (100): halting to avoid hitting AWS API rate limiting")
logging.Logger.Errorf("Nuking too many API Gateways (v1) at once (100): " +
"halting to avoid hitting AWS API rate limiting")
return TooManyApiGatewayErr{}
}

// There is no bulk delete Api Gateway API, so we delete the batch of gateways concurrently using goroutines
logging.Logger.Debugf("Deleting Api Gateways (v1) in region %s", region)
logging.Logger.Debugf("Deleting Api Gateways (v1) in region %s", gateway.Region)
wg := new(sync.WaitGroup)
wg.Add(len(identifiers))
errChans := make([]chan error, len(identifiers))
for i, apigwID := range identifiers {
errChans[i] = make(chan error, 1)
go deleteApiGatewayAsync(wg, errChans[i], svc, apigwID, region)
go gateway.nukeAsync(wg, errChans[i], apigwID)
}
wg.Wait()

Expand All @@ -81,25 +79,28 @@ func nukeAllAPIGateways(session *session.Session, identifiers []*string) error {
if err := <-errChan; err != nil {
allErrs = multierror.Append(allErrs, err)
logging.Logger.Debugf("[Failed] %s", err)

telemetry.TrackEvent(commonTelemetry.EventContext{
EventName: "Error Nuking API Gateway",
}, map[string]interface{}{
"region": *session.Config.Region,
"region": gateway.Region,
})
}
}
finalErr := allErrs.ErrorOrNil()
if finalErr != nil {
return errors.WithStackTrace(finalErr)
}

return nil
}

func deleteApiGatewayAsync(wg *sync.WaitGroup, errChan chan error, svc *apigateway.APIGateway, apigwID *string, region string) {
func (gateway ApiGateway) nukeAsync(
wg *sync.WaitGroup, errChan chan error, apigwID *string) {
defer wg.Done()

input := &apigateway.DeleteRestApiInput{RestApiId: apigwID}
_, err := svc.DeleteRestApi(input)
_, err := gateway.Client.DeleteRestApi(input)
errChan <- err

// Record status of this resource
Expand All @@ -111,8 +112,11 @@ func deleteApiGatewayAsync(wg *sync.WaitGroup, errChan chan error, svc *apigatew
report.Record(e)

if err == nil {
logging.Logger.Debugf("[OK] API Gateway (v1) %s deleted in %s", aws.StringValue(apigwID), region)
} else {
logging.Logger.Debugf("[Failed] Error deleting API Gateway (v1) %s in %s", aws.StringValue(apigwID), region)
logging.Logger.Debugf("["+
"OK] API Gateway (v1) %s deleted in %s", aws.StringValue(apigwID), gateway.Region)
return
}

logging.Logger.Debugf(
"[Failed] Error deleting API Gateway (v1) %s in %s", aws.StringValue(apigwID), gateway.Region)
}
168 changes: 63 additions & 105 deletions aws/apigateway_test.go
@@ -1,153 +1,111 @@
package aws

import (
"github.com/gruntwork-io/cloud-nuke/telemetry"
"testing"
"time"

"github.com/aws/aws-sdk-go/aws"
awsgo "github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/session"
"github.com/aws/aws-sdk-go/service/apigateway"
"github.com/aws/aws-sdk-go/service/apigateway/apigatewayiface"
"github.com/gruntwork-io/cloud-nuke/config"
"github.com/gruntwork-io/cloud-nuke/telemetry"
"github.com/gruntwork-io/cloud-nuke/util"
"github.com/gruntwork-io/go-commons/errors"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

type TestAPIGateway struct {
ID *string
Name *string
type mockedApiGateway struct {
apigatewayiface.APIGatewayAPI
GetRestApisResp apigateway.GetRestApisOutput
DeleteRestApiResp apigateway.DeleteRestApiOutput
}

func createTestAPIGateway(t *testing.T, session *session.Session, name string) (*TestAPIGateway, error) {
svc := apigateway.New(session)

testGw := &TestAPIGateway{
Name: aws.String(name),
}

param := &apigateway.CreateRestApiInput{
Name: aws.String(name),
}

output, err := svc.CreateRestApi(param)
if err != nil {
assert.Failf(t, "Could not create test API Gateway: %s", errors.WithStackTrace(err).Error())
}

testGw.ID = output.Id

return testGw, nil
func (m mockedApiGateway) GetRestApis(*apigateway.GetRestApisInput) (*apigateway.GetRestApisOutput, error) {
// Only need to return mocked response output
return &m.GetRestApisResp, nil
}

func TestListAPIGateways(t *testing.T) {
telemetry.InitTelemetry("cloud-nuke", "")
t.Parallel()

region, err := getRandomRegion()
require.NoError(t, err)
session, err := session.NewSession(&awsgo.Config{
Region: awsgo.String(region),
},
)
if err != nil {
assert.Fail(t, errors.WithStackTrace(err).Error())
}

apigwName := "aws-nuke-test-" + util.UniqueID()
testGw, createTestGwErr := createTestAPIGateway(t, session, apigwName)
require.NoError(t, createTestGwErr)
// clean up after this test
defer nukeAllAPIGateways(session, []*string{testGw.ID})

apigwIds, err := getAllAPIGateways(session, time.Now(), config.Config{})
if err != nil {
assert.Fail(t, "Unable to fetch list of API Gateways (v1)")
}

assert.Contains(t, awsgo.StringValueSlice(apigwIds), aws.StringValue(testGw.ID))
func (m mockedApiGateway) DeleteRestApi(*apigateway.DeleteRestApiInput) (*apigateway.DeleteRestApiOutput, error) {
// Only need to return mocked response output
return &m.DeleteRestApiResp, nil
}

func TestTimeFilterExclusionNewlyCreatedAPIGateway(t *testing.T) {
func TestAPIGatewayGetAllAndNukeAll(t *testing.T) {
telemetry.InitTelemetry("cloud-nuke", "")
t.Parallel()

region, err := getRandomRegion()
require.NoError(t, err)

session, err := session.NewSession(&aws.Config{Region: aws.String(region)})
require.NoError(t, err)

apigwName := "aws-nuke-test-" + util.UniqueID()

testGw, createTestGwErr := createTestAPIGateway(t, session, apigwName)
require.NoError(t, createTestGwErr)
defer nukeAllAPIGateways(session, []*string{testGw.ID})
testApiID := "aws-nuke-test-" + util.UniqueID()
apiGateway := ApiGateway{
Client: mockedApiGateway{
GetRestApisResp: apigateway.GetRestApisOutput{
Items: []*apigateway.RestApi{
{Id: aws.String(testApiID)},
},
},
DeleteRestApiResp: apigateway.DeleteRestApiOutput{},
},
}

// Assert API Gateway is picked up without filters
apigwIds, err := getAllAPIGateways(session, time.Now(), config.Config{})
apis, err := apiGateway.getAll(time.Now(), config.Config{})
require.NoError(t, err)
assert.Contains(t, aws.StringValueSlice(apigwIds), aws.StringValue(testGw.ID))
require.Contains(t, awsgo.StringValueSlice(apis), testApiID)

// Assert API Gateway doesn't appear when we look at API Gateways older than 1 Hour
olderThan := time.Now().Add(-1 * time.Hour)
apiGwIdsOlder, err := getAllAPIGateways(session, olderThan, config.Config{})
err = apiGateway.nukeAll([]*string{aws.String(testApiID)})
require.NoError(t, err)
assert.NotContains(t, aws.StringValueSlice(apiGwIdsOlder), aws.StringValue(testGw.ID))
}

func TestNukeAPIGatewayOne(t *testing.T) {
func TestAPIGatewayGetAllTimeFilter(t *testing.T) {
telemetry.InitTelemetry("cloud-nuke", "")
t.Parallel()

region, err := getRandomRegion()
require.NoError(t, err)
testApiID := "aws-nuke-test-" + util.UniqueID()
now := time.Now()
apiGateway := ApiGateway{
Client: mockedApiGateway{
GetRestApisResp: apigateway.GetRestApisOutput{
Items: []*apigateway.RestApi{{
Id: aws.String(testApiID),
CreatedDate: aws.Time(now),
}},
},
},
}

session, err := session.NewSession(&aws.Config{Region: aws.String(region)})
// test API is not excluded from the filter
IDs, err := apiGateway.getAll(now.Add(1), config.Config{})
require.NoError(t, err)
assert.Contains(t, aws.StringValueSlice(IDs), testApiID)

apigwName := "aws-nuke-test-" + util.UniqueID()
// We ignore errors in the delete call here, because it is intended to be a stop gap in case there is a bug in nuke.
testGw, createTestErr := createTestAPIGateway(t, session, apigwName)
require.NoError(t, createTestErr)

nukeErr := nukeAllAPIGateways(session, []*string{testGw.ID})
require.NoError(t, nukeErr)

// Make sure the API Gateway was deleted
apigwIds, err := getAllAPIGateways(session, time.Now(), config.Config{})
// test API being excluded from the filter
apiGwIdsOlder, err := apiGateway.getAll(now.Add(-1), config.Config{})
require.NoError(t, err)

assert.NotContains(t, aws.StringValueSlice(apigwIds), aws.StringValue(testGw.ID))
assert.NotContains(t, aws.StringValueSlice(apiGwIdsOlder), testApiID)
}

func TestNukeAPIGatewayMoreThanOne(t *testing.T) {
telemetry.InitTelemetry("cloud-nuke", "")
t.Parallel()

region, err := getRandomRegion()
require.NoError(t, err)
testApiID1 := "aws-nuke-test-" + util.UniqueID()
testApiID2 := "aws-nuke-test-" + util.UniqueID()
apiGateway := ApiGateway{
Client: mockedApiGateway{
GetRestApisResp: apigateway.GetRestApisOutput{
Items: []*apigateway.RestApi{
{Id: aws.String(testApiID1)},
{Id: aws.String(testApiID2)},
},
},
DeleteRestApiResp: apigateway.DeleteRestApiOutput{},
},
}

session, err := session.NewSession(&aws.Config{Region: aws.String(region)})
apis, err := apiGateway.getAll(time.Now(), config.Config{})
require.NoError(t, err)
require.Contains(t, awsgo.StringValueSlice(apis), testApiID1)
require.Contains(t, awsgo.StringValueSlice(apis), testApiID2)

apigwName := "aws-nuke-test-" + util.UniqueID()
apigwName2 := "aws-nuke-test-" + util.UniqueID()
// We ignore errors in the delete call here, because it is intended to be a stop gap in case there is a bug in nuke.
testGw, createTestErr := createTestAPIGateway(t, session, apigwName)
require.NoError(t, createTestErr)
testGw2, createTestErr2 := createTestAPIGateway(t, session, apigwName2)
require.NoError(t, createTestErr2)

nukeErr := nukeAllAPIGateways(session, []*string{testGw.ID, testGw2.ID})
require.NoError(t, nukeErr)

// Make sure the API Gateway was deleted
apigwIds, err := getAllAPIGateways(session, time.Now(), config.Config{})
err = apiGateway.nukeAll([]*string{aws.String(testApiID1), aws.String(testApiID2)})
require.NoError(t, err)

assert.NotContains(t, aws.StringValueSlice(apigwIds), aws.StringValue(testGw.ID))
assert.NotContains(t, aws.StringValueSlice(apigwIds), aws.StringValue(testGw2.ID))
}
19 changes: 12 additions & 7 deletions aws/apigateway_types.go
Expand Up @@ -3,29 +3,34 @@ package aws
import (
awsgo "github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/session"
"github.com/aws/aws-sdk-go/service/apigateway/apigatewayiface"
"github.com/gruntwork-io/go-commons/errors"
)

type ApiGateway struct {
Ids []string
Client apigatewayiface.APIGatewayAPI
Region string
Ids []string
}

func (apigateway ApiGateway) ResourceName() string {
func (gateway ApiGateway) ResourceName() string {
return "apigateway"
}

func (apigateway ApiGateway) ResourceIdentifiers() []string {
return apigateway.Ids
func (gateway ApiGateway) ResourceIdentifiers() []string {
return gateway.Ids
}

func (apigateway ApiGateway) MaxBatchSize() int {
func (gateway ApiGateway) MaxBatchSize() int {
return 10
}

func (apigateway ApiGateway) Nuke(session *session.Session, identifiers []string) error {
if err := nukeAllAPIGateways(session, awsgo.StringSlice(identifiers)); err != nil {
func (gateway ApiGateway) Nuke(session *session.Session, identifiers []string) error {
// TODO(james): stop passing in session argument as it is included as part of the gateway struct.
if err := gateway.nukeAll(awsgo.StringSlice(identifiers)); err != nil {
return errors.WithStackTrace(err)
}

return nil
}

Expand Down