Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

F add support for model_data_source attribute in aws sagemaker model #34158

Merged
merged 10 commits into from
Nov 9, 2023
3 changes: 3 additions & 0 deletions .changelog/34158.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
```release-note:enhancement
resource/sagemaker: Add model_data_source in model containers
```
148 changes: 148 additions & 0 deletions internal/service/sagemaker/model.go
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,41 @@ func ResourceModel() *schema.Resource {
ForceNew: true,
ValidateFunc: verify.ValidARN,
},
"model_data_source": {
Type: schema.TypeList,
Optional: true,
MaxItems: 1,
Elem: &schema.Resource{
Schema: map[string]*schema.Schema{
"s3_data_source": {
Type: schema.TypeList,
Required: true,
Elem: &schema.Resource{
Schema: map[string]*schema.Schema{
"s3_uri": {
Type: schema.TypeString,
Required: true,
ForceNew: true,
ValidateFunc: validModelDataURL,
},
"s3_data_type": {
Type: schema.TypeString,
Required: true,
ForceNew: true,
ValidateFunc: validation.StringInSlice(sagemaker.S3ModelDataType_Values(), false),
},
"compression_type": {
Type: schema.TypeString,
Required: true,
ForceNew: true,
ValidateFunc: validation.StringInSlice(sagemaker.ModelCompressionType_Values(), false),
},
},
},
},
},
},
},
},
},
},
Expand Down Expand Up @@ -226,6 +261,41 @@ func ResourceModel() *schema.Resource {
ForceNew: true,
ValidateFunc: verify.ValidARN,
},
"model_data_source": {
Type: schema.TypeList,
Optional: true,
MaxItems: 1,
Elem: &schema.Resource{
Schema: map[string]*schema.Schema{
"s3_data_source": {
Type: schema.TypeList,
Required: true,
Elem: &schema.Resource{
Schema: map[string]*schema.Schema{
"s3_uri": {
Type: schema.TypeString,
Required: true,
ForceNew: true,
ValidateFunc: validModelDataURL,
},
"s3_data_type": {
Type: schema.TypeString,
Required: true,
ForceNew: true,
ValidateFunc: validation.StringInSlice(sagemaker.S3ModelDataType_Values(), false),
},
"compression_type": {
Type: schema.TypeString,
Required: true,
ForceNew: true,
ValidateFunc: validation.StringInSlice(sagemaker.ModelCompressionType_Values(), false),
},
},
},
},
},
},
},
},
},
},
Expand Down Expand Up @@ -438,6 +508,9 @@ func expandContainer(m map[string]interface{}) *sagemaker.ContainerDefinition {
if v, ok := m["model_package_name"]; ok && v.(string) != "" {
container.ModelPackageName = aws.String(v.(string))
}
if v, ok := m["model_data_source"]; ok {
container.ModelDataSource = expandModelDataSource(v.([]interface{}))
}
if v, ok := m["environment"].(map[string]interface{}); ok && len(v) > 0 {
container.Environment = flex.ExpandStringMap(v)
}
Expand All @@ -449,6 +522,44 @@ func expandContainer(m map[string]interface{}) *sagemaker.ContainerDefinition {
return &container
}

func expandModelDataSource(l []interface{}) *sagemaker.ModelDataSource {
if len(l) == 0 {
return nil
}

modelDataSource := sagemaker.ModelDataSource{}

m := l[0].(map[string]interface{})

if v, ok := m["s3_data_source"]; ok {
modelDataSource.S3DataSource = expandS3ModelDataSource(v.([]interface{}))
}

return &modelDataSource
}

func expandS3ModelDataSource(l []interface{}) *sagemaker.S3ModelDataSource {
if len(l) == 0 {
return nil
}

s3ModelDataSource := sagemaker.S3ModelDataSource{}

m := l[0].(map[string]interface{})

if v, ok := m["s3_uri"]; ok && v.(string) != "" {
s3ModelDataSource.S3Uri = aws.String(v.(string))
}
if v, ok := m["s3_data_type"]; ok && v.(string) != "" {
s3ModelDataSource.S3DataType = aws.String(v.(string))
}
if v, ok := m["compression_type"]; ok && v.(string) != "" {
s3ModelDataSource.CompressionType = aws.String(v.(string))
}

return &s3ModelDataSource
}

func expandModelImageConfig(l []interface{}) *sagemaker.ImageConfig {
if len(l) == 0 {
return nil
Expand Down Expand Up @@ -512,6 +623,9 @@ func flattenContainer(container *sagemaker.ContainerDefinition) []interface{} {
if container.ModelDataUrl != nil {
cfg["model_data_url"] = aws.StringValue(container.ModelDataUrl)
}
if container.ModelDataSource != nil {
cfg["model_data_source"] = flattenModelDataSource(container.ModelDataSource)
}
if container.ModelPackageName != nil {
cfg["model_package_name"] = aws.StringValue(container.ModelPackageName)
}
Expand All @@ -526,6 +640,40 @@ func flattenContainer(container *sagemaker.ContainerDefinition) []interface{} {
return []interface{}{cfg}
}

func flattenModelDataSource(modelDataSource *sagemaker.ModelDataSource) []interface{} {
if modelDataSource == nil {
return []interface{}{}
}

cfg := make(map[string]interface{})

if modelDataSource.S3DataSource != nil {
cfg["s3_data_source"] = flattenS3ModelDataSource(modelDataSource.S3DataSource)
}

return []interface{}{cfg}
}

func flattenS3ModelDataSource(s3ModelDataSource *sagemaker.S3ModelDataSource) []interface{} {
if s3ModelDataSource == nil {
return []interface{}{}
}

cfg := make(map[string]interface{})

if s3ModelDataSource.S3Uri != nil {
cfg["s3_uri"] = aws.StringValue(s3ModelDataSource.S3Uri)
}
if s3ModelDataSource.S3DataType != nil {
cfg["s3_data_type"] = aws.StringValue(s3ModelDataSource.S3DataType)
}
if s3ModelDataSource.CompressionType != nil {
cfg["compression_type"] = aws.StringValue(s3ModelDataSource.CompressionType)
}

return []interface{}{cfg}
}

func flattenImageConfig(imageConfig *sagemaker.ImageConfig) []interface{} {
if imageConfig == nil {
return []interface{}{}
Expand Down
107 changes: 107 additions & 0 deletions internal/service/sagemaker/model_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -292,6 +292,34 @@ func TestAccSageMakerModel_primaryContainerModelPackageName(t *testing.T) {
})
}

func TestAccSageMakerModel_primaryContainerModelDataSource(t *testing.T) {
ctx := acctest.Context(t)
rName := sdkacctest.RandomWithPrefix(acctest.ResourcePrefix)
resourceName := "aws_sagemaker_model.test"

resource.ParallelTest(t, resource.TestCase{
PreCheck: func() { acctest.PreCheck(ctx, t) },
ErrorCheck: acctest.ErrorCheck(t, sagemaker.EndpointsID),
ProtoV5ProviderFactories: acctest.ProtoV5ProviderFactories,
CheckDestroy: testAccCheckModelDestroy(ctx),
Steps: []resource.TestStep{
{
Config: testAccModelConfig_primaryContainerUncompressedModel(rName),
Check: resource.ComposeTestCheckFunc(
testAccCheckModelExists(ctx, resourceName),
resource.TestCheckResourceAttr(resourceName, "primary_container.0.model_data_source.#", "1"),
resource.TestCheckResourceAttr(resourceName, "primary_container.0.model_data_source.0.s3_data_source.0.s3_data_type", "S3Prefix"),
),
},
{
ResourceName: resourceName,
ImportState: true,
ImportStateVerify: true,
},
},
})
}

func TestAccSageMakerModel_containers(t *testing.T) {
ctx := acctest.Context(t)
rName := sdkacctest.RandomWithPrefix(acctest.ResourcePrefix)
Expand Down Expand Up @@ -751,6 +779,85 @@ resource "aws_sagemaker_model" "test" {
`, rName))
}

func testAccModelConfig_primaryContainerUncompressedModel(rName string) string {
return acctest.ConfigCompose(testAccModelConfig_base(rName), fmt.Sprintf(`
resource "aws_sagemaker_model" "test" {
name = %[1]q
execution_role_arn = aws_iam_role.test.arn
primary_container {
image = data.aws_sagemaker_prebuilt_ecr_image.test.registry_path
model_data_source {
s3_data_source {
s3_data_type = "S3Prefix"
s3_uri = "s3://${aws_s3_object.test.bucket}/model/"
compression_type = "None"
}
}
}
}
resource "aws_iam_policy" "test" {
name = %[1]q
description = "Allow SageMaker to create model"
policy = data.aws_iam_policy_document.policy.json
}
data "aws_iam_policy_document" "policy" {
statement {
effect = "Allow"
actions = [
"cloudwatch:PutMetricData",
"logs:CreateLogStream",
"logs:PutLogEvents",
"logs:CreateLogGroup",
"logs:DescribeLogStreams",
"ecr:GetAuthorizationToken",
"ecr:BatchCheckLayerAvailability",
"ecr:GetDownloadUrlForLayer",
"ecr:BatchGetImage",
]
resources = [
"*",
]
}
statement {
effect = "Allow"
actions = [
"s3:GetObject",
"s3:ListBucket",
]
resources = [
"${aws_s3_bucket.test.arn}",
"${aws_s3_bucket.test.arn}/*",
]
}
}
resource "aws_iam_role_policy_attachment" "test" {
role = aws_iam_role.test.name
policy_arn = aws_iam_policy.test.arn
}
resource "aws_s3_bucket" "test" {
bucket = %[1]q
force_destroy = true
}
resource "aws_s3_object" "test" {
bucket = aws_s3_bucket.test.bucket
key = "model/inference.py"
content = "some-data"
}
`, rName))
}

func testAccModelConfig_containers(rName string) string {
return acctest.ConfigCompose(testAccModelConfig_base(rName), fmt.Sprintf(`
resource "aws_sagemaker_model" "test" {
Expand Down
11 changes: 11 additions & 0 deletions website/docs/r/sagemaker_model.html.markdown
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ The `primary_container` and `container` block both support:
* `mode` - (Optional) The container hosts value `SingleModel/MultiModel`. The default value is `SingleModel`.
* `model_data_url` - (Optional) The URL for the S3 location where model artifacts are stored.
* `model_package_name` - (Optional) The Amazon Resource Name (ARN) of the model package to use to create the model.
* `model_data_source` - (Optional) The location of model data to deploy. Use this for uncompressed model deployment. For information about how to deploy an uncompressed model, see [Deploying uncompressed models](https://docs.aws.amazon.com/sagemaker/latest/dg/large-model-inference-uncompressed.html) in the _AWS SageMaker Developer Guide_.
* `container_hostname` - (Optional) The DNS host name for the container.
* `environment` - (Optional) Environment variables for the Docker container.
A list of key value pairs.
Expand All @@ -77,6 +78,16 @@ The `primary_container` and `container` block both support:

* `repository_credentials_provider_arn` - (Required) The Amazon Resource Name (ARN) of an AWS Lambda function that provides credentials to authenticate to the private Docker registry where your model image is hosted. For information about how to create an AWS Lambda function, see [Create a Lambda function with the console](https://docs.aws.amazon.com/lambda/latest/dg/getting-started-create-function.html) in the _AWS Lambda Developer Guide_.

### Model Data Source

* `s3_data_source` - (Required) The S3 location of model data to deploy.

#### S3 Data Source

* `compression_type` - (Required) How the model data is prepared. Allowed values are: `None` and `Gzip`.
* `s3_data_type` - (Required) The type of model data to deploy. Allowed values are: `S3Object` and `S3Prefix`.
* `s3_uri` - (Required) The S3 path of model data to deploy.

## Inference Execution Config

* `mode` - (Required) How containers in a multi-container are run. The following values are valid `Serial` and `Direct`.
Expand Down