Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add routing config support to Sagemaker Endpoint Configuration #34777

Merged
merged 9 commits into from
Feb 14, 2024
3 changes: 3 additions & 0 deletions .changelog/34777.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
```release-note:enhancement
resource/aws_sagemaker_endpoint_configuration: Add `routing_config` support, which enables specifying a `routing_strategy`.
```
50 changes: 50 additions & 0 deletions internal/service/sagemaker/endpoint_configuration.go
Original file line number Diff line number Diff line change
Expand Up @@ -325,6 +325,21 @@ func ResourceEndpointConfiguration() *schema.Resource {
Required: true,
ForceNew: true,
},
"routing_config": {
Type: schema.TypeList,
Optional: true,
ForceNew: true,
Elem: &schema.Resource{
Schema: map[string]*schema.Schema{
"routing_strategy": {
Type: schema.TypeString,
Required: true,
ForceNew: true,
ValidateFunc: validation.StringMatch(regexache.MustCompile(`^(LEAST_OUTSTANDING_REQUESTS|RANDOM)`), ""),
},
},
},
},
"serverless_config": {
Type: schema.TypeList,
Optional: true,
Expand Down Expand Up @@ -448,6 +463,21 @@ func ResourceEndpointConfiguration() *schema.Resource {
Required: true,
ForceNew: true,
},
"routing_config": {
Type: schema.TypeList,
Optional: true,
ForceNew: true,
Elem: &schema.Resource{
Schema: map[string]*schema.Schema{
"routing_strategy": {
Type: schema.TypeString,
Required: true,
ForceNew: true,
ValidateFunc: validation.StringMatch(regexache.MustCompile(`^(LEAST_OUTSTANDING_REQUESTS|RANDOM)`), ""),
jar-b marked this conversation as resolved.
Show resolved Hide resolved
},
},
},
},
"serverless_config": {
Type: schema.TypeList,
Optional: true,
Expand Down Expand Up @@ -651,6 +681,10 @@ func expandProductionVariants(configured []interface{}) []*sagemaker.ProductionV
l.AcceleratorType = aws.String(v)
}

if v, ok := data["routing_config"].([]interface{}); ok && len(v) > 0 {
l.RoutingConfig = expandRoutingConfig(v)
}

if v, ok := data["serverless_config"].([]interface{}); ok && len(v) > 0 {
l.ServerlessConfig = expandServerlessConfig(v)
}
Expand Down Expand Up @@ -916,6 +950,22 @@ func expandEndpointConfigNotificationConfig(configured []interface{}) *sagemaker
return c
}

func expandRoutingConfig(configured []interface{}) *sagemaker.ProductionVariantRoutingConfig {
jar-b marked this conversation as resolved.
Show resolved Hide resolved
if len(configured) == 0 {
return nil
}

m := configured[0].(map[string]interface{})

c := &sagemaker.ProductionVariantRoutingConfig{}

if v, ok := m["routing_strategy"].(string); ok {
c.RoutingStrategy = aws.String(v)
}

return c
}

func expandServerlessConfig(configured []interface{}) *sagemaker.ProductionVariantServerlessConfig {
if len(configured) == 0 {
return nil
Expand Down
48 changes: 48 additions & 0 deletions internal/service/sagemaker/endpoint_configuration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,34 @@ func TestAccSageMakerEndpointConfiguration_shadowProductionVariants(t *testing.T
})
}

// TestAccSageMakerEndpointConfiguration_ProductionVariants_routing creates an
// endpoint configuration whose production variant carries a routing_config
// block, verifies the routing strategy recorded in state, and then checks
// that the resource can be imported with matching state.
func TestAccSageMakerEndpointConfiguration_ProductionVariants_routing(t *testing.T) {
	ctx := acctest.Context(t)
	rName := sdkacctest.RandomWithPrefix(acctest.ResourcePrefix)
	resourceName := "aws_sagemaker_endpoint_configuration.test"

	resource.ParallelTest(t, resource.TestCase{
		PreCheck:                 func() { acctest.PreCheck(ctx, t) },
		ErrorCheck:               acctest.ErrorCheck(t, sagemaker.EndpointsID),
		ProtoV5ProviderFactories: acctest.ProtoV5ProviderFactories,
		CheckDestroy:             testAccCheckEndpointConfigurationDestroy(ctx),
		Steps: []resource.TestStep{
			{
				Config: testAccEndpointConfigurationConfig_routing(rName),
				Check: resource.ComposeTestCheckFunc(
					testAccCheckEndpointConfigurationExists(ctx, resourceName),
					resource.TestCheckResourceAttr(resourceName, "production_variants.#", "1"),
					// routing_config is a list block, so its attributes are
					// addressed through an element index in state.
					resource.TestCheckResourceAttr(resourceName, "production_variants.0.routing_config.#", "1"),
					resource.TestCheckResourceAttr(resourceName, "production_variants.0.routing_config.0.routing_strategy", "RANDOM"),
				),
			},
			{
				ResourceName:      resourceName,
				ImportState:       true,
				ImportStateVerify: true,
			},
		},
	})
}

func TestAccSageMakerEndpointConfiguration_ProductionVariants_serverless(t *testing.T) {
ctx := acctest.Context(t)
rName := sdkacctest.RandomWithPrefix(acctest.ResourcePrefix)
Expand Down Expand Up @@ -775,6 +803,9 @@ resource "aws_sagemaker_endpoint_configuration" "test" {
initial_instance_count = 2
instance_type = "ml.t2.medium"
initial_variant_weight = 1
routing_config {
jar-b marked this conversation as resolved.
Show resolved Hide resolved
routing_strategy = "RANDOM"
}
}
}
`)
Expand Down Expand Up @@ -1198,6 +1229,23 @@ resource "aws_sagemaker_endpoint_configuration" "test" {
`, rName))
}

func testAccEndpointConfigurationConfig_routing(rName string) string {
return acctest.ConfigCompose(testAccEndpointConfigurationConfig_base(rName), fmt.Sprintf(`
resource "aws_sagemaker_endpoint_configuration" "test" {
name = %[1]q

production_variants {
variant_name = "variant-1"
model_name = aws_sagemaker_model.test.name

routing_config {
routing_strategy = "RANDOM"
}
}
}
`, rName))
}

func testAccEndpointConfigurationConfig_serverless(rName string) string {
return acctest.ConfigCompose(testAccEndpointConfigurationConfig_base(rName), fmt.Sprintf(`
resource "aws_sagemaker_endpoint_configuration" "test" {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,10 @@ This resource supports the following arguments:
* `destination_s3_uri` - (Required) The Amazon S3 bucket to send the core dump to.
* `kms_key_id` - (Required) The Amazon Web Services Key Management Service (Amazon Web Services KMS) key that SageMaker uses to encrypt the core dump data at rest using Amazon S3 server-side encryption.

#### routing_config
jar-b marked this conversation as resolved.
Show resolved Hide resolved

* `routing_strategy` - (Required) Sets how the endpoint routes incoming traffic. Valid values are `LEAST_OUTSTANDING_REQUESTS` and `RANDOM`. With `LEAST_OUTSTANDING_REQUESTS`, the endpoint routes requests to the specific instances that have more capacity to process them. With `RANDOM`, the endpoint routes each request to a randomly chosen instance.

#### serverless_config

* `max_concurrency` - (Required) The maximum number of concurrent invocations your serverless endpoint can process. Valid values are between `1` and `200`.
Expand Down
Loading