Merge pull request #34158 from wahi-com/f-aws_sagemaker_model_datasource
Add support for the model_data_source attribute in the aws_sagemaker_model resource
ewbankkit committed Nov 9, 2023
2 parents 9289064 + d2ed4cf commit ab5eaf0
Showing 4 changed files with 269 additions and 0 deletions.
3 changes: 3 additions & 0 deletions .changelog/34158.txt
@@ -0,0 +1,3 @@
```release-note:enhancement
resource/aws_sagemaker_model: Add `model_data_source` configuration block to `primary_container` and `container` configuration blocks
```
148 changes: 148 additions & 0 deletions internal/service/sagemaker/model.go
@@ -115,6 +115,41 @@ func ResourceModel() *schema.Resource {
ForceNew: true,
ValidateFunc: verify.ValidARN,
},
"model_data_source": {
Type: schema.TypeList,
Optional: true,
MaxItems: 1,
Elem: &schema.Resource{
Schema: map[string]*schema.Schema{
"s3_data_source": {
Type: schema.TypeList,
Required: true,
Elem: &schema.Resource{
Schema: map[string]*schema.Schema{
"s3_uri": {
Type: schema.TypeString,
Required: true,
ForceNew: true,
ValidateFunc: validModelDataURL,
},
"s3_data_type": {
Type: schema.TypeString,
Required: true,
ForceNew: true,
ValidateFunc: validation.StringInSlice(sagemaker.S3ModelDataType_Values(), false),
},
"compression_type": {
Type: schema.TypeString,
Required: true,
ForceNew: true,
ValidateFunc: validation.StringInSlice(sagemaker.ModelCompressionType_Values(), false),
},
},
},
},
},
},
},
},
},
},
@@ -226,6 +261,41 @@ func ResourceModel() *schema.Resource {
ForceNew: true,
ValidateFunc: verify.ValidARN,
},
"model_data_source": {
Type: schema.TypeList,
Optional: true,
MaxItems: 1,
Elem: &schema.Resource{
Schema: map[string]*schema.Schema{
"s3_data_source": {
Type: schema.TypeList,
Required: true,
Elem: &schema.Resource{
Schema: map[string]*schema.Schema{
"s3_uri": {
Type: schema.TypeString,
Required: true,
ForceNew: true,
ValidateFunc: validModelDataURL,
},
"s3_data_type": {
Type: schema.TypeString,
Required: true,
ForceNew: true,
ValidateFunc: validation.StringInSlice(sagemaker.S3ModelDataType_Values(), false),
},
"compression_type": {
Type: schema.TypeString,
Required: true,
ForceNew: true,
ValidateFunc: validation.StringInSlice(sagemaker.ModelCompressionType_Values(), false),
},
},
},
},
},
},
},
},
},
},
@@ -438,6 +508,9 @@ func expandContainer(m map[string]interface{}) *sagemaker.ContainerDefinition {
if v, ok := m["model_package_name"]; ok && v.(string) != "" {
container.ModelPackageName = aws.String(v.(string))
}
if v, ok := m["model_data_source"]; ok {
container.ModelDataSource = expandModelDataSource(v.([]interface{}))
}
if v, ok := m["environment"].(map[string]interface{}); ok && len(v) > 0 {
container.Environment = flex.ExpandStringMap(v)
}
@@ -449,6 +522,44 @@ func expandContainer(m map[string]interface{}) *sagemaker.ContainerDefinition {
return &container
}

func expandModelDataSource(l []interface{}) *sagemaker.ModelDataSource {
if len(l) == 0 {
return nil
}

modelDataSource := sagemaker.ModelDataSource{}

m := l[0].(map[string]interface{})

if v, ok := m["s3_data_source"]; ok {
modelDataSource.S3DataSource = expandS3ModelDataSource(v.([]interface{}))
}

return &modelDataSource
}

func expandS3ModelDataSource(l []interface{}) *sagemaker.S3ModelDataSource {
if len(l) == 0 {
return nil
}

s3ModelDataSource := sagemaker.S3ModelDataSource{}

m := l[0].(map[string]interface{})

if v, ok := m["s3_uri"]; ok && v.(string) != "" {
s3ModelDataSource.S3Uri = aws.String(v.(string))
}
if v, ok := m["s3_data_type"]; ok && v.(string) != "" {
s3ModelDataSource.S3DataType = aws.String(v.(string))
}
if v, ok := m["compression_type"]; ok && v.(string) != "" {
s3ModelDataSource.CompressionType = aws.String(v.(string))
}

return &s3ModelDataSource
}

func expandModelImageConfig(l []interface{}) *sagemaker.ImageConfig {
if len(l) == 0 {
return nil
@@ -512,6 +623,9 @@ func flattenContainer(container *sagemaker.ContainerDefinition) []interface{} {
if container.ModelDataUrl != nil {
cfg["model_data_url"] = aws.StringValue(container.ModelDataUrl)
}
if container.ModelDataSource != nil {
cfg["model_data_source"] = flattenModelDataSource(container.ModelDataSource)
}
if container.ModelPackageName != nil {
cfg["model_package_name"] = aws.StringValue(container.ModelPackageName)
}
@@ -526,6 +640,40 @@ func flattenContainer(container *sagemaker.ContainerDefinition) []interface{} {
return []interface{}{cfg}
}

func flattenModelDataSource(modelDataSource *sagemaker.ModelDataSource) []interface{} {
if modelDataSource == nil {
return []interface{}{}
}

cfg := make(map[string]interface{})

if modelDataSource.S3DataSource != nil {
cfg["s3_data_source"] = flattenS3ModelDataSource(modelDataSource.S3DataSource)
}

return []interface{}{cfg}
}

func flattenS3ModelDataSource(s3ModelDataSource *sagemaker.S3ModelDataSource) []interface{} {
if s3ModelDataSource == nil {
return []interface{}{}
}

cfg := make(map[string]interface{})

if s3ModelDataSource.S3Uri != nil {
cfg["s3_uri"] = aws.StringValue(s3ModelDataSource.S3Uri)
}
if s3ModelDataSource.S3DataType != nil {
cfg["s3_data_type"] = aws.StringValue(s3ModelDataSource.S3DataType)
}
if s3ModelDataSource.CompressionType != nil {
cfg["compression_type"] = aws.StringValue(s3ModelDataSource.CompressionType)
}

return []interface{}{cfg}
}

func flattenImageConfig(imageConfig *sagemaker.ImageConfig) []interface{} {
if imageConfig == nil {
return []interface{}{}
107 changes: 107 additions & 0 deletions internal/service/sagemaker/model_test.go
@@ -292,6 +292,34 @@ func TestAccSageMakerModel_primaryContainerModelPackageName(t *testing.T) {
})
}

func TestAccSageMakerModel_primaryContainerModelDataSource(t *testing.T) {
ctx := acctest.Context(t)
rName := sdkacctest.RandomWithPrefix(acctest.ResourcePrefix)
resourceName := "aws_sagemaker_model.test"

resource.ParallelTest(t, resource.TestCase{
PreCheck: func() { acctest.PreCheck(ctx, t) },
ErrorCheck: acctest.ErrorCheck(t, sagemaker.EndpointsID),
ProtoV5ProviderFactories: acctest.ProtoV5ProviderFactories,
CheckDestroy: testAccCheckModelDestroy(ctx),
Steps: []resource.TestStep{
{
Config: testAccModelConfig_primaryContainerUncompressedModel(rName),
Check: resource.ComposeTestCheckFunc(
testAccCheckModelExists(ctx, resourceName),
resource.TestCheckResourceAttr(resourceName, "primary_container.0.model_data_source.#", "1"),
resource.TestCheckResourceAttr(resourceName, "primary_container.0.model_data_source.0.s3_data_source.0.s3_data_type", "S3Prefix"),
),
},
{
ResourceName: resourceName,
ImportState: true,
ImportStateVerify: true,
},
},
})
}

func TestAccSageMakerModel_containers(t *testing.T) {
ctx := acctest.Context(t)
rName := sdkacctest.RandomWithPrefix(acctest.ResourcePrefix)
@@ -751,6 +779,85 @@ resource "aws_sagemaker_model" "test" {
`, rName))
}

func testAccModelConfig_primaryContainerUncompressedModel(rName string) string {
return acctest.ConfigCompose(testAccModelConfig_base(rName), fmt.Sprintf(`
resource "aws_sagemaker_model" "test" {
name = %[1]q
execution_role_arn = aws_iam_role.test.arn
primary_container {
image = data.aws_sagemaker_prebuilt_ecr_image.test.registry_path
model_data_source {
s3_data_source {
s3_data_type = "S3Prefix"
s3_uri = "s3://${aws_s3_object.test.bucket}/model/"
compression_type = "None"
}
}
}
}
resource "aws_iam_policy" "test" {
name = %[1]q
description = "Allow SageMaker to create model"
policy = data.aws_iam_policy_document.policy.json
}
data "aws_iam_policy_document" "policy" {
statement {
effect = "Allow"
actions = [
"cloudwatch:PutMetricData",
"logs:CreateLogStream",
"logs:PutLogEvents",
"logs:CreateLogGroup",
"logs:DescribeLogStreams",
"ecr:GetAuthorizationToken",
"ecr:BatchCheckLayerAvailability",
"ecr:GetDownloadUrlForLayer",
"ecr:BatchGetImage",
]
resources = [
"*",
]
}
statement {
effect = "Allow"
actions = [
"s3:GetObject",
"s3:ListBucket",
]
resources = [
"${aws_s3_bucket.test.arn}",
"${aws_s3_bucket.test.arn}/*",
]
}
}
resource "aws_iam_role_policy_attachment" "test" {
role = aws_iam_role.test.name
policy_arn = aws_iam_policy.test.arn
}
resource "aws_s3_bucket" "test" {
bucket = %[1]q
force_destroy = true
}
resource "aws_s3_object" "test" {
bucket = aws_s3_bucket.test.bucket
key = "model/inference.py"
content = "some-data"
}
`, rName))
}

func testAccModelConfig_containers(rName string) string {
return acctest.ConfigCompose(testAccModelConfig_base(rName), fmt.Sprintf(`
resource "aws_sagemaker_model" "test" {
11 changes: 11 additions & 0 deletions website/docs/r/sagemaker_model.html.markdown
@@ -63,6 +63,7 @@ The `primary_container` and `container` block both support:
* `mode` - (Optional) The container hosting mode. Valid values are `SingleModel` and `MultiModel`. The default value is `SingleModel`.
* `model_data_url` - (Optional) The URL for the S3 location where model artifacts are stored.
* `model_package_name` - (Optional) The Amazon Resource Name (ARN) of the model package to use to create the model.
* `model_data_source` - (Optional) The location of model data to deploy. Use this for uncompressed model deployment; an example configuration is shown below. For information about how to deploy an uncompressed model, see [Deploying uncompressed models](https://docs.aws.amazon.com/sagemaker/latest/dg/large-model-inference-uncompressed.html) in the _AWS SageMaker Developer Guide_.
* `container_hostname` - (Optional) The DNS host name for the container.
* `environment` - (Optional) Environment variables for the Docker container.
A list of key-value pairs.
@@ -77,6 +78,16 @@ The `primary_container` and `container` block both support:

* `repository_credentials_provider_arn` - (Required) The Amazon Resource Name (ARN) of an AWS Lambda function that provides credentials to authenticate to the private Docker registry where your model image is hosted. For information about how to create an AWS Lambda function, see [Create a Lambda function with the console](https://docs.aws.amazon.com/lambda/latest/dg/getting-started-create-function.html) in the _AWS Lambda Developer Guide_.

### Model Data Source

* `s3_data_source` - (Required) The S3 location of model data to deploy.

#### S3 Data Source

* `compression_type` - (Required) How the model data is prepared. Allowed values are: `None` and `Gzip`.
* `s3_data_type` - (Required) The type of model data to deploy. Allowed values are: `S3Object` and `S3Prefix`.
* `s3_uri` - (Required) The S3 path of model data to deploy.
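
The following is a minimal configuration sketch for deploying an uncompressed model from an S3 prefix. The role, image, and bucket references (`aws_iam_role.example`, `data.aws_sagemaker_prebuilt_ecr_image.example`, `example-bucket`) are illustrative placeholders, not values defined by this change.

```terraform
resource "aws_sagemaker_model" "example" {
  name               = "example-model"
  execution_role_arn = aws_iam_role.example.arn

  primary_container {
    image = data.aws_sagemaker_prebuilt_ecr_image.example.registry_path

    # Deploy uncompressed artifacts from an S3 prefix instead of a single
    # compressed archive referenced by model_data_url.
    model_data_source {
      s3_data_source {
        s3_data_type     = "S3Prefix"
        s3_uri           = "s3://example-bucket/model/"
        compression_type = "None"
      }
    }
  }
}
```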

## Inference Execution Config

* `mode` - (Required) How containers in a multi-container endpoint are run. Valid values are `Serial` and `Direct`.
