Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add routing config support to Sagemaker Endpoint Configuration #34777

Merged
merged 9 commits into from
Feb 14, 2024
3 changes: 3 additions & 0 deletions .changelog/34777.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
```release-note:enhancement
resource/aws_sagemaker_endpoint_configuration: Add routing_config support. Enables the specification of `routing_strategy`.
```
68 changes: 68 additions & 0 deletions internal/service/sagemaker/endpoint_configuration.go
Original file line number Diff line number Diff line change
Expand Up @@ -325,6 +325,21 @@ func ResourceEndpointConfiguration() *schema.Resource {
Required: true,
ForceNew: true,
},
"routing_config": {
Type: schema.TypeList,
Optional: true,
ForceNew: true,
Elem: &schema.Resource{
Schema: map[string]*schema.Schema{
"routing_strategy": {
Type: schema.TypeString,
Required: true,
ForceNew: true,
ValidateFunc: validation.StringMatch(regexache.MustCompile(`^(LEAST_OUTSTANDING_REQUESTS|RANDOM)`), ""),
},
},
},
},
"serverless_config": {
Type: schema.TypeList,
Optional: true,
Expand Down Expand Up @@ -448,6 +463,21 @@ func ResourceEndpointConfiguration() *schema.Resource {
Required: true,
ForceNew: true,
},
"routing_config": {
Type: schema.TypeList,
Optional: true,
ForceNew: true,
Elem: &schema.Resource{
Schema: map[string]*schema.Schema{
"routing_strategy": {
Type: schema.TypeString,
Required: true,
ForceNew: true,
ValidateFunc: validation.StringInSlice(sagemaker.RoutingStrategy_Values(), false),
},
},
},
},
"serverless_config": {
Type: schema.TypeList,
Optional: true,
Expand Down Expand Up @@ -651,6 +681,10 @@ func expandProductionVariants(configured []interface{}) []*sagemaker.ProductionV
l.AcceleratorType = aws.String(v)
}

if v, ok := data["routing_config"].([]interface{}); ok && len(v) > 0 {
l.RoutingConfig = expandRoutingConfig(v)
}

if v, ok := data["serverless_config"].([]interface{}); ok && len(v) > 0 {
l.ServerlessConfig = expandServerlessConfig(v)
}
Expand Down Expand Up @@ -700,6 +734,10 @@ func flattenProductionVariants(list []*sagemaker.ProductionVariant) []map[string
l["instance_type"] = aws.StringValue(i.InstanceType)
}

if i.RoutingConfig != nil {
l["routing_config"] = flattenRoutingConfig(i.RoutingConfig)
}

if i.ServerlessConfig != nil {
l["serverless_config"] = flattenServerlessConfig(i.ServerlessConfig)
}
Expand Down Expand Up @@ -916,6 +954,22 @@ func expandEndpointConfigNotificationConfig(configured []interface{}) *sagemaker
return c
}

func expandRoutingConfig(configured []interface{}) *sagemaker.ProductionVariantRoutingConfig {
jar-b marked this conversation as resolved.
Show resolved Hide resolved
if len(configured) == 0 {
return nil
}

m := configured[0].(map[string]interface{})

c := &sagemaker.ProductionVariantRoutingConfig{}

if v, ok := m["routing_strategy"].(string); ok && v != "" {
c.RoutingStrategy = aws.String(v)
}

return c
}

func expandServerlessConfig(configured []interface{}) *sagemaker.ProductionVariantServerlessConfig {
if len(configured) == 0 {
return nil
Expand Down Expand Up @@ -1038,6 +1092,20 @@ func flattenEndpointConfigNotificationConfig(config *sagemaker.AsyncInferenceNot
return []map[string]interface{}{cfg}
}

func flattenRoutingConfig(config *sagemaker.ProductionVariantRoutingConfig) []map[string]interface{} {
if config == nil {
return []map[string]interface{}{}
}

cfg := map[string]interface{}{}

if config.RoutingStrategy != nil {
cfg["routing_strategy"] = aws.StringValue(config.RoutingStrategy)
}

return []map[string]interface{}{cfg}
}

func flattenServerlessConfig(config *sagemaker.ProductionVariantServerlessConfig) []map[string]interface{} {
if config == nil {
return []map[string]interface{}{}
Expand Down
45 changes: 45 additions & 0 deletions internal/service/sagemaker/endpoint_configuration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,34 @@ func TestAccSageMakerEndpointConfiguration_shadowProductionVariants(t *testing.T
})
}

func TestAccSageMakerEndpointConfiguration_ProductionVariants_routing(t *testing.T) {
ctx := acctest.Context(t)
rName := sdkacctest.RandomWithPrefix(acctest.ResourcePrefix)
resourceName := "aws_sagemaker_endpoint_configuration.test"

resource.ParallelTest(t, resource.TestCase{
PreCheck: func() { acctest.PreCheck(ctx, t) },
ErrorCheck: acctest.ErrorCheck(t, sagemaker.EndpointsID),
ProtoV5ProviderFactories: acctest.ProtoV5ProviderFactories,
CheckDestroy: testAccCheckEndpointConfigurationDestroy(ctx),
Steps: []resource.TestStep{
{
Config: testAccEndpointConfigurationConfig_routing(rName),
Check: resource.ComposeTestCheckFunc(
testAccCheckEndpointConfigurationExists(ctx, resourceName),
resource.TestCheckResourceAttr(resourceName, "production_variants.#", "1"),
resource.TestCheckResourceAttr(resourceName, "production_variants.0.routing_config.routing_strategy", "RANDOM"),
),
},
{
ResourceName: resourceName,
ImportState: true,
ImportStateVerify: true,
},
},
})
}

func TestAccSageMakerEndpointConfiguration_ProductionVariants_serverless(t *testing.T) {
ctx := acctest.Context(t)
rName := sdkacctest.RandomWithPrefix(acctest.ResourcePrefix)
Expand Down Expand Up @@ -1198,6 +1226,23 @@ resource "aws_sagemaker_endpoint_configuration" "test" {
`, rName))
}

func testAccEndpointConfigurationConfig_routing(rName string) string {
return acctest.ConfigCompose(testAccEndpointConfigurationConfig_base(rName), fmt.Sprintf(`
resource "aws_sagemaker_endpoint_configuration" "test" {
name = %[1]q

production_variants {
variant_name = "variant-1"
model_name = aws_sagemaker_model.test.name

routing_config {
routing_strategy = "RANDOM"
}
}
}
`, rName))
}

func testAccEndpointConfigurationConfig_serverless(rName string) string {
return acctest.ConfigCompose(testAccEndpointConfigurationConfig_base(rName), fmt.Sprintf(`
resource "aws_sagemaker_endpoint_configuration" "test" {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ This resource supports the following arguments:
* `initial_variant_weight` - (Optional) Determines initial traffic distribution among all of the models that you specify in the endpoint configuration. If unspecified, it defaults to `1.0`.
* `model_data_download_timeout_in_seconds` - (Optional) The timeout value, in seconds, to download and extract the model that you want to host from Amazon S3 to the individual inference instance associated with this production variant. Valid values between `60` and `3600`.
* `model_name` - (Required) The name of the model to use.
* `routing_strategy` - (Optional) Sets how the endpoint routes incoming traffic.
* `serverless_config` - (Optional) Specifies configuration for how an endpoint performs asynchronous inference.
* `variant_name` - (Optional) The name of the variant. If omitted, Terraform will assign a random, unique name.
* `volume_size_in_gb` - (Optional) The size, in GB, of the ML storage volume attached to individual inference instance associated with the production variant. Valid values between `1` and `512`.
Expand All @@ -64,6 +65,10 @@ This resource supports the following arguments:
* `destination_s3_uri` - (Required) The Amazon S3 bucket to send the core dump to.
* `kms_key_id` - (Required) The Amazon Web Services Key Management Service (Amazon Web Services KMS) key that SageMaker uses to encrypt the core dump data at rest using Amazon S3 server-side encryption.

#### routing_config
jar-b marked this conversation as resolved.
Show resolved Hide resolved

* `routing_strategy` - (Required) Sets how the endpoint routes incoming traffic. `LEAST_OUTSTANDING_REQUESTS`: The endpoint routes requests to the specific instances that have more capacity to process them. `RANDOM`: The endpoint routes each request to a randomly chosen instance.

#### serverless_config

* `max_concurrency` - (Required) The maximum number of concurrent invocations your serverless endpoint can process. Valid values are between `1` and `200`.
Expand Down
Loading