Skip to content

Commit

Permalink
Merge pull request #34777 from curtisallen/feat-sagemakerendpointconf…
Browse files Browse the repository at this point in the history
…ig-lor

Add routing config support to Sagemaker Endpoint Configuration
  • Loading branch information
jar-b committed Feb 14, 2024
2 parents f76769d + 0074a44 commit 8b5bf36
Show file tree
Hide file tree
Showing 4 changed files with 125 additions and 1 deletion.
3 changes: 3 additions & 0 deletions .changelog/34777.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
```release-note:enhancement
resource/aws_sagemaker_endpoint_configuration: Add `routing_config` argument. Enables the specification of a `routing_strategy`.
```
68 changes: 68 additions & 0 deletions internal/service/sagemaker/endpoint_configuration.go
Original file line number Diff line number Diff line change
Expand Up @@ -325,6 +325,21 @@ func ResourceEndpointConfiguration() *schema.Resource {
Required: true,
ForceNew: true,
},
"routing_config": {
Type: schema.TypeList,
Optional: true,
ForceNew: true,
Elem: &schema.Resource{
Schema: map[string]*schema.Schema{
"routing_strategy": {
Type: schema.TypeString,
Required: true,
ForceNew: true,
ValidateFunc: validation.StringInSlice(sagemaker.RoutingStrategy_Values(), false),
},
},
},
},
"serverless_config": {
Type: schema.TypeList,
Optional: true,
Expand Down Expand Up @@ -448,6 +463,21 @@ func ResourceEndpointConfiguration() *schema.Resource {
Required: true,
ForceNew: true,
},
"routing_config": {
Type: schema.TypeList,
Optional: true,
ForceNew: true,
Elem: &schema.Resource{
Schema: map[string]*schema.Schema{
"routing_strategy": {
Type: schema.TypeString,
Required: true,
ForceNew: true,
ValidateFunc: validation.StringInSlice(sagemaker.RoutingStrategy_Values(), false),
},
},
},
},
"serverless_config": {
Type: schema.TypeList,
Optional: true,
Expand Down Expand Up @@ -651,6 +681,10 @@ func expandProductionVariants(configured []interface{}) []*sagemaker.ProductionV
l.AcceleratorType = aws.String(v)
}

if v, ok := data["routing_config"].([]interface{}); ok && len(v) > 0 {
l.RoutingConfig = expandRoutingConfig(v)
}

if v, ok := data["serverless_config"].([]interface{}); ok && len(v) > 0 {
l.ServerlessConfig = expandServerlessConfig(v)
}
Expand Down Expand Up @@ -700,6 +734,10 @@ func flattenProductionVariants(list []*sagemaker.ProductionVariant) []map[string
l["instance_type"] = aws.StringValue(i.InstanceType)
}

if i.RoutingConfig != nil {
l["routing_config"] = flattenRoutingConfig(i.RoutingConfig)
}

if i.ServerlessConfig != nil {
l["serverless_config"] = flattenServerlessConfig(i.ServerlessConfig)
}
Expand Down Expand Up @@ -916,6 +954,22 @@ func expandEndpointConfigNotificationConfig(configured []interface{}) *sagemaker
return c
}

// expandRoutingConfig converts the Terraform "routing_config" block list into
// its SageMaker API representation.
//
// Only the first element of configured is read. The function returns nil when
// configured is empty, or when its first element is not a
// map[string]interface{} (for example a nil element), rather than panicking on
// the type assertion. A missing or empty "routing_strategy" value leaves
// RoutingStrategy unset on the returned struct.
func expandRoutingConfig(configured []interface{}) *sagemaker.ProductionVariantRoutingConfig {
	if len(configured) == 0 {
		return nil
	}

	// Use the checked form of the assertion: the unchecked form panics at
	// run time when the element is nil or holds any other type.
	m, ok := configured[0].(map[string]interface{})
	if !ok {
		return nil
	}

	c := &sagemaker.ProductionVariantRoutingConfig{}

	if v, ok := m["routing_strategy"].(string); ok && v != "" {
		c.RoutingStrategy = aws.String(v)
	}

	return c
}

func expandServerlessConfig(configured []interface{}) *sagemaker.ProductionVariantServerlessConfig {
if len(configured) == 0 {
return nil
Expand Down Expand Up @@ -1038,6 +1092,20 @@ func flattenEndpointConfigNotificationConfig(config *sagemaker.AsyncInferenceNot
return []map[string]interface{}{cfg}
}

// flattenRoutingConfig converts a SageMaker routing configuration into the
// single-element list-of-maps form stored under "routing_config".
//
// A nil config yields an empty (non-nil) slice. Otherwise the result holds one
// map, which carries "routing_strategy" only when RoutingStrategy is set.
func flattenRoutingConfig(config *sagemaker.ProductionVariantRoutingConfig) []map[string]interface{} {
	if config == nil {
		return []map[string]interface{}{}
	}

	tfMap := make(map[string]interface{})
	if strategy := config.RoutingStrategy; strategy != nil {
		tfMap["routing_strategy"] = aws.StringValue(strategy)
	}

	return []map[string]interface{}{tfMap}
}

func flattenServerlessConfig(config *sagemaker.ProductionVariantServerlessConfig) []map[string]interface{} {
if config == nil {
return []map[string]interface{}{}
Expand Down
48 changes: 48 additions & 0 deletions internal/service/sagemaker/endpoint_configuration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,35 @@ func TestAccSageMakerEndpointConfiguration_shadowProductionVariants(t *testing.T
})
}

// TestAccSageMakerEndpointConfiguration_ProductionVariants_routing is an
// acceptance test for the routing_config block of a production variant. It
// applies a configuration whose routing_strategy is "RANDOM", checks the
// resulting attributes, and then verifies that the resource imports cleanly.
func TestAccSageMakerEndpointConfiguration_ProductionVariants_routing(t *testing.T) {
	const resourceName = "aws_sagemaker_endpoint_configuration.test"

	ctx := acctest.Context(t)
	rName := sdkacctest.RandomWithPrefix(acctest.ResourcePrefix)

	// Step 1: create the endpoint configuration and assert on routing_config.
	createStep := resource.TestStep{
		Config: testAccEndpointConfigurationConfig_routing(rName),
		Check: resource.ComposeTestCheckFunc(
			testAccCheckEndpointConfigurationExists(ctx, resourceName),
			resource.TestCheckResourceAttr(resourceName, "production_variants.#", "1"),
			resource.TestCheckResourceAttr(resourceName, "production_variants.0.routing_config.#", "1"),
			resource.TestCheckResourceAttr(resourceName, "production_variants.0.routing_config.0.routing_strategy", "RANDOM"),
		),
	}

	// Step 2: import the resource and verify the imported state.
	importStep := resource.TestStep{
		ResourceName:      resourceName,
		ImportState:       true,
		ImportStateVerify: true,
	}

	resource.ParallelTest(t, resource.TestCase{
		PreCheck:                 func() { acctest.PreCheck(ctx, t) },
		ErrorCheck:               acctest.ErrorCheck(t, sagemaker.EndpointsID),
		ProtoV5ProviderFactories: acctest.ProtoV5ProviderFactories,
		CheckDestroy:             testAccCheckEndpointConfigurationDestroy(ctx),
		Steps:                    []resource.TestStep{createStep, importStep},
	})
}

func TestAccSageMakerEndpointConfiguration_ProductionVariants_serverless(t *testing.T) {
ctx := acctest.Context(t)
rName := sdkacctest.RandomWithPrefix(acctest.ResourcePrefix)
Expand Down Expand Up @@ -1198,6 +1227,25 @@ resource "aws_sagemaker_endpoint_configuration" "test" {
`, rName))
}

func testAccEndpointConfigurationConfig_routing(rName string) string {
return acctest.ConfigCompose(testAccEndpointConfigurationConfig_base(rName), fmt.Sprintf(`
resource "aws_sagemaker_endpoint_configuration" "test" {
name = %[1]q
production_variants {
variant_name = "variant-1"
model_name = aws_sagemaker_model.test.name
initial_instance_count = 2
instance_type = "ml.t2.medium"
routing_config {
routing_strategy = "RANDOM"
}
}
}
`, rName))
}

func testAccEndpointConfigurationConfig_serverless(rName string) string {
return acctest.ConfigCompose(testAccEndpointConfigurationConfig_base(rName), fmt.Sprintf(`
resource "aws_sagemaker_endpoint_configuration" "test" {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ This resource supports the following arguments:
* `tags` - (Optional) A mapping of tags to assign to the resource. If configured with a provider [`default_tags` configuration block](https://registry.terraform.io/providers/hashicorp/aws/latest/docs#default_tags-configuration-block) present, tags with matching keys will overwrite those defined at the provider-level.
* `data_capture_config` - (Optional) Specifies the parameters to capture input/output of SageMaker model endpoints. Fields are documented below.
* `async_inference_config` - (Optional) Specifies configuration for how an endpoint performs asynchronous inference.
* `shadow_production_variants` - (Optional) Array of ProductionVariant objects. There is one for each model that you want to host at this endpoint in shadow mode with production traffic replicated from the model specified on ProductionVariants.If you use this field, you can only specify one variant for ProductionVariants and one variant for ShadowProductionVariants. Fields are documented below.
* `shadow_production_variants` - (Optional) Array of ProductionVariant objects. There is one for each model that you want to host at this endpoint in shadow mode with production traffic replicated from the model specified on ProductionVariants. If you use this field, you can only specify one variant for ProductionVariants and one variant for ShadowProductionVariants. Fields are documented below.

### production_variants

Expand All @@ -55,6 +55,7 @@ This resource supports the following arguments:
* `initial_variant_weight` - (Optional) Determines initial traffic distribution among all of the models that you specify in the endpoint configuration. If unspecified, it defaults to `1.0`.
* `model_data_download_timeout_in_seconds` - (Optional) The timeout value, in seconds, to download and extract the model that you want to host from Amazon S3 to the individual inference instance associated with this production variant. Valid values between `60` and `3600`.
* `model_name` - (Required) The name of the model to use.
* `routing_config` - (Optional) Sets how the endpoint routes incoming traffic. See [routing_config](#routing_config) below.
* `serverless_config` - (Optional) Specifies configuration for how an endpoint performs serverless inference. Fields are documented below.
* `variant_name` - (Optional) The name of the variant. If omitted, Terraform will assign a random, unique name.
* `volume_size_in_gb` - (Optional) The size, in GB, of the ML storage volume attached to individual inference instance associated with the production variant. Valid values between `1` and `512`.
Expand All @@ -64,6 +65,10 @@ This resource supports the following arguments:
* `destination_s3_uri` - (Required) The Amazon S3 bucket to send the core dump to.
* `kms_key_id` - (Required) The Amazon Web Services Key Management Service (Amazon Web Services KMS) key that SageMaker uses to encrypt the core dump data at rest using Amazon S3 server-side encryption.

#### routing_config

* `routing_strategy` - (Required) Sets how the endpoint routes incoming traffic. Valid values are `LEAST_OUTSTANDING_REQUESTS` and `RANDOM`. `LEAST_OUTSTANDING_REQUESTS` routes requests to the specific instances that have more capacity to process them. `RANDOM` routes each request to a randomly chosen instance.

#### serverless_config

* `max_concurrency` - (Required) The maximum number of concurrent invocations your serverless endpoint can process. Valid values are between `1` and `200`.
Expand Down

0 comments on commit 8b5bf36

Please sign in to comment.