Skip to content

Commit

Permalink
Revert "r/aws_apigatewayv2_stage: Add computed 'api_protocol_type' at…
Browse files Browse the repository at this point in the history
…tribute."

This reverts commit a7eb7cf9976ecabb04696dbe2f39805cc0ec1401.
  • Loading branch information
ewbankkit committed Jul 21, 2020
1 parent eaa6111 commit c755087
Show file tree
Hide file tree
Showing 2 changed files with 73 additions and 38 deletions.
109 changes: 73 additions & 36 deletions aws/resource_aws_apigatewayv2_stage.go
Original file line number Diff line number Diff line change
@@ -1,13 +1,15 @@
package aws

import (
"bytes"
"fmt"
"log"
"strings"

"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/arn"
"github.com/aws/aws-sdk-go/service/apigatewayv2"
"github.com/hashicorp/terraform-plugin-sdk/helper/hashcode"
"github.com/hashicorp/terraform-plugin-sdk/helper/schema"
"github.com/hashicorp/terraform-plugin-sdk/helper/validation"
"github.com/terraform-providers/terraform-provider-aws/aws/internal/keyvaluetags"
Expand Down Expand Up @@ -55,10 +57,6 @@ func resourceAwsApiGatewayV2Stage() *schema.Resource {
Required: true,
ForceNew: true,
},
"api_protocol_type": {
Type: schema.TypeString,
Computed: true,
},
"arn": {
Type: schema.TypeString,
Computed: true,
Expand Down Expand Up @@ -99,7 +97,13 @@ func resourceAwsApiGatewayV2Stage() *schema.Resource {
apigatewayv2.LoggingLevelInfo,
apigatewayv2.LoggingLevelOff,
}, false),
DiffSuppressFunc: suppressIfApigatewayv2ProtocolType(apigatewayv2.ProtocolTypeHttp),
DiffSuppressFunc: func(k, old, new string, d *schema.ResourceData) bool {
// Not set for HTTP APIs.
if d.Id() != "" && old == "" && new == apigatewayv2.LoggingLevelOff {
return true
}
return false
},
},
"throttling_burst_limit": {
Type: schema.TypeInt,
Expand Down Expand Up @@ -160,7 +164,13 @@ func resourceAwsApiGatewayV2Stage() *schema.Resource {
apigatewayv2.LoggingLevelInfo,
apigatewayv2.LoggingLevelOff,
}, false),
DiffSuppressFunc: suppressIfApigatewayv2ProtocolType(apigatewayv2.ProtocolTypeHttp),
DiffSuppressFunc: func(k, old, new string, d *schema.ResourceData) bool {
// Not set for HTTP APIs.
if d.Id() != "" && old == "" && new == apigatewayv2.LoggingLevelOff {
return true
}
return false
},
},
"route_key": {
Type: schema.TypeString,
Expand All @@ -176,6 +186,7 @@ func resourceAwsApiGatewayV2Stage() *schema.Resource {
},
},
},
Set: apiGatewayV2RouteSettingsHash,
},
"stage_variables": {
Type: schema.TypeMap,
Expand Down Expand Up @@ -236,7 +247,6 @@ func resourceAwsApiGatewayV2StageCreate(d *schema.ResourceData, meta interface{}
}

d.SetId(aws.StringValue(resp.StageName))
d.Set("api_protocol_type", protocolType)

return resourceAwsApiGatewayV2StageRead(d, meta)
}
Expand Down Expand Up @@ -293,7 +303,14 @@ func resourceAwsApiGatewayV2StageRead(d *schema.ResourceData, meta interface{})
return fmt.Errorf("error setting tags: %s", err)
}

switch d.Get("api_protocol_type").(string) {
apiOutput, err := conn.GetApi(&apigatewayv2.GetApiInput{
ApiId: aws.String(apiId),
})
if err != nil {
return fmt.Errorf("error reading API Gateway v2 API (%s): %s", apiId, err)
}

switch aws.StringValue(apiOutput.ProtocolType) {
case apigatewayv2.ProtocolTypeWebsocket:
executionArn := arn.ARN{
Partition: meta.(*AWSClient).partition,
Expand Down Expand Up @@ -322,10 +339,19 @@ func resourceAwsApiGatewayV2StageUpdate(d *schema.ResourceData, meta interface{}
if d.HasChanges("access_log_settings", "auto_deploy", "client_certificate_id",
"default_route_settings", "deployment_id", "description",
"route_settings", "stage_variables") {
protocolType := d.Get("api_protocol_type").(string)
apiId := d.Get("api_id").(string)

apiOutput, err := conn.GetApi(&apigatewayv2.GetApiInput{
ApiId: aws.String(apiId),
})
if err != nil {
return fmt.Errorf("error reading API Gateway v2 API (%s): %s", apiId, err)
}

protocolType := aws.StringValue(apiOutput.ProtocolType)

req := &apigatewayv2.UpdateStageInput{
ApiId: aws.String(d.Get("api_id").(string)),
ApiId: aws.String(apiId),
StageName: aws.String(d.Id()),
}
if d.HasChange("access_log_settings") {
Expand Down Expand Up @@ -364,7 +390,7 @@ func resourceAwsApiGatewayV2StageUpdate(d *schema.ResourceData, meta interface{}
}

log.Printf("[DEBUG] Updating API Gateway v2 stage: %s", req)
_, err := conn.UpdateStage(req)
_, err = conn.UpdateStage(req)
if err != nil {
return fmt.Errorf("error updating API Gateway v2 stage (%s): %s", d.Id(), err)
}
Expand Down Expand Up @@ -421,16 +447,8 @@ func resourceAwsApiGatewayV2StageImport(d *schema.ResourceData, meta interface{}
return nil, fmt.Errorf("API Gateway v2 stage (%s) was created via quick create", stageName)
}

apiOutput, err := conn.GetApi(&apigatewayv2.GetApiInput{
ApiId: aws.String(apiId),
})
if err != nil {
return nil, fmt.Errorf("error reading API Gateway v2 API (%s): %s", apiId, err)
}

d.SetId(stageName)
d.Set("api_id", apiId)
d.Set("api_protocol_type", apiOutput.ProtocolType)

return []*schema.ResourceData{d}, nil
}
Expand Down Expand Up @@ -513,20 +531,20 @@ func expandApiGatewayV2RouteSettings(vSettings *schema.Set, protocolType string)

mSettings := v.(map[string]interface{})

if vDataTraceEnabled, ok := mSettings["data_trace_enabled"].(bool); ok && protocolType == apigatewayv2.ProtocolTypeWebsocket {
routeSettings.DataTraceEnabled = aws.Bool(vDataTraceEnabled)
if v, ok := mSettings["data_trace_enabled"].(bool); ok && protocolType == apigatewayv2.ProtocolTypeWebsocket {
routeSettings.DataTraceEnabled = aws.Bool(v)
}
if vDetailedMetricsEnabled, ok := mSettings["detailed_metrics_enabled"].(bool); ok {
routeSettings.DetailedMetricsEnabled = aws.Bool(vDetailedMetricsEnabled)
if v, ok := mSettings["detailed_metrics_enabled"].(bool); ok {
routeSettings.DetailedMetricsEnabled = aws.Bool(v)
}
if vLoggingLevel, ok := mSettings["logging_level"].(string); ok && vLoggingLevel != "" && protocolType == apigatewayv2.ProtocolTypeWebsocket {
routeSettings.LoggingLevel = aws.String(vLoggingLevel)
if v, ok := mSettings["logging_level"].(string); ok && v != "" && protocolType == apigatewayv2.ProtocolTypeWebsocket {
routeSettings.LoggingLevel = aws.String(v)
}
if vThrottlingBurstLimit, ok := mSettings["throttling_burst_limit"].(int); ok {
routeSettings.ThrottlingBurstLimit = aws.Int64(int64(vThrottlingBurstLimit))
if v, ok := mSettings["throttling_burst_limit"].(int); ok {
routeSettings.ThrottlingBurstLimit = aws.Int64(int64(v))
}
if vThrottlingRateLimit, ok := mSettings["throttling_rate_limit"].(float64); ok {
routeSettings.ThrottlingRateLimit = aws.Float64(vThrottlingRateLimit)
if v, ok := mSettings["throttling_rate_limit"].(float64); ok {
routeSettings.ThrottlingRateLimit = aws.Float64(v)
}

settings[mSettings["route_key"].(string)] = routeSettings
Expand All @@ -535,7 +553,7 @@ func expandApiGatewayV2RouteSettings(vSettings *schema.Set, protocolType string)
return settings
}

func flattenApiGatewayV2RouteSettings(settings map[string]*apigatewayv2.RouteSettings) []interface{} {
func flattenApiGatewayV2RouteSettings(settings map[string]*apigatewayv2.RouteSettings) *schema.Set {
vSettings := []interface{}{}

for k, routeSetting := range settings {
Expand All @@ -549,13 +567,32 @@ func flattenApiGatewayV2RouteSettings(settings map[string]*apigatewayv2.RouteSet
})
}

return vSettings
return schema.NewSet(apiGatewayV2RouteSettingsHash, vSettings)
}

// suppressIfApigatewayv2ProtocolType suppresses attribute differences
// if the API protocol type is the specified value.
func suppressIfApigatewayv2ProtocolType(t string) schema.SchemaDiffSuppressFunc {
return func(k, old, new string, d *schema.ResourceData) bool {
return d.Get("api_protocol_type").(string) == t
func apiGatewayV2RouteSettingsHash(vSettings interface{}) int {
var buf bytes.Buffer

mSettings := vSettings.(map[string]interface{})

if v, ok := mSettings["route_key"].(string); ok {
buf.WriteString(fmt.Sprintf("%s-", v))
}
if v, ok := mSettings["data_trace_enabled"].(bool); ok {
buf.WriteString(fmt.Sprintf("%t-", v))
}
if v, ok := mSettings["detailed_metrics_enabled"].(bool); ok {
buf.WriteString(fmt.Sprintf("%t-", v))
}
if v, ok := mSettings["logging_level"].(string); ok {
buf.WriteString(fmt.Sprintf("%s-", v))
}
if v, ok := mSettings["throttling_burst_limit"].(int); ok {
buf.WriteString(fmt.Sprintf("%d-", v))
}
if v, ok := mSettings["throttling_rate_limit"].(float64); ok {
buf.WriteString(fmt.Sprintf("%g-", v))
}

return hashcode.String(buf.String())
}
2 changes: 0 additions & 2 deletions aws/resource_aws_apigatewayv2_stage_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@ func TestAccAWSAPIGatewayV2Stage_basicWebSocket(t *testing.T) {
Check: resource.ComposeTestCheckFunc(
testAccCheckAWSAPIGatewayV2StageExists(resourceName, &apiId, &v),
resource.TestCheckResourceAttr(resourceName, "access_log_settings.#", "0"),
resource.TestCheckResourceAttr(resourceName, "api_protocol_type", "WEBSOCKET"),
testAccMatchResourceAttrRegionalARNNoAccount(resourceName, "arn", "apigateway", regexp.MustCompile(fmt.Sprintf("/apis/.+/stages/%s", rName))),
resource.TestCheckResourceAttr(resourceName, "auto_deploy", "false"),
resource.TestCheckResourceAttr(resourceName, "client_certificate_id", ""),
Expand Down Expand Up @@ -76,7 +75,6 @@ func TestAccAWSAPIGatewayV2Stage_basicHttp(t *testing.T) {
Check: resource.ComposeTestCheckFunc(
testAccCheckAWSAPIGatewayV2StageExists(resourceName, &apiId, &v),
resource.TestCheckResourceAttr(resourceName, "access_log_settings.#", "0"),
resource.TestCheckResourceAttr(resourceName, "api_protocol_type", "HTTP"),
testAccMatchResourceAttrRegionalARNNoAccount(resourceName, "arn", "apigateway", regexp.MustCompile(fmt.Sprintf("/apis/.+/stages/%s", rName))),
resource.TestCheckResourceAttr(resourceName, "auto_deploy", "false"),
resource.TestCheckResourceAttr(resourceName, "client_certificate_id", ""),
Expand Down

0 comments on commit c755087

Please sign in to comment.