Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We鈥檒l occasionally send you account related emails.

Already on GitHub? Sign in to your account

r/sagemaker_notebook_instance - lifecycle_config_name, root_access, and default_code_repository allow updating + refactor tests #15385

Merged
Merged
Show file tree
Hide file tree
Changes from 18 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
37 changes: 37 additions & 0 deletions aws/internal/service/sagemaker/waiter/status.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
package waiter

import (
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/service/sagemaker"
"github.com/hashicorp/aws-sdk-go-base/tfawserr"
"github.com/hashicorp/terraform-plugin-sdk/v2/helper/resource"
)

const (
SagemakerNotebookInstanceStatusNotFound = "NotFound"
)

// NotebookInstanceStatus fetches the NotebookInstance and its Status
func NotebookInstanceStatus(conn *sagemaker.SageMaker, notebookName string) resource.StateRefreshFunc {
return func() (interface{}, string, error) {
input := &sagemaker.DescribeNotebookInstanceInput{
NotebookInstanceName: aws.String(notebookName),
}

output, err := conn.DescribeNotebookInstance(input)

if tfawserr.ErrMessageContains(err, "ValidationException", "RecordNotFound") {
return nil, SagemakerNotebookInstanceStatusNotFound, nil
}

if err != nil {
return nil, sagemaker.NotebookInstanceStatusFailed, err
}

if output == nil {
return nil, SagemakerNotebookInstanceStatusNotFound, nil
}

return output, aws.StringValue(output.NotebookInstanceStatus), nil
}
}
78 changes: 78 additions & 0 deletions aws/internal/service/sagemaker/waiter/waiter.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
package waiter

import (
"time"

"github.com/aws/aws-sdk-go/service/sagemaker"
"github.com/hashicorp/terraform-plugin-sdk/v2/helper/resource"
)

const (
NotebookInstanceInServiceTimeout = 10 * time.Minute
NotebookInstanceStoppedTimeout = 10 * time.Minute
NotebookInstanceDeletedTimeout = 10 * time.Minute
)

// NotebookInstanceInService waits for a NotebookInstance to return InService
func NotebookInstanceInService(conn *sagemaker.SageMaker, notebookName string) (*sagemaker.DescribeNotebookInstanceOutput, error) {
stateConf := &resource.StateChangeConf{
Pending: []string{
SagemakerNotebookInstanceStatusNotFound,
sagemaker.NotebookInstanceStatusUpdating,
sagemaker.NotebookInstanceStatusPending,
sagemaker.NotebookInstanceStatusStopped,
},
Target: []string{sagemaker.NotebookInstanceStatusInService},
Refresh: NotebookInstanceStatus(conn, notebookName),
Timeout: NotebookInstanceInServiceTimeout,
}

outputRaw, err := stateConf.WaitForState()

if output, ok := outputRaw.(*sagemaker.DescribeNotebookInstanceOutput); ok {
return output, err
}

return nil, err
}

// NotebookInstanceStopped waits for a NotebookInstance to return Stopped
func NotebookInstanceStopped(conn *sagemaker.SageMaker, notebookName string) (*sagemaker.DescribeNotebookInstanceOutput, error) {
stateConf := &resource.StateChangeConf{
Pending: []string{
sagemaker.NotebookInstanceStatusUpdating,
sagemaker.NotebookInstanceStatusStopping,
},
Target: []string{sagemaker.NotebookInstanceStatusStopped},
Refresh: NotebookInstanceStatus(conn, notebookName),
Timeout: NotebookInstanceStoppedTimeout,
}

outputRaw, err := stateConf.WaitForState()

if output, ok := outputRaw.(*sagemaker.DescribeNotebookInstanceOutput); ok {
return output, err
}

return nil, err
}

// NotebookInstanceDeleted waits for a NotebookInstance to return Deleted
func NotebookInstanceDeleted(conn *sagemaker.SageMaker, notebookName string) (*sagemaker.DescribeNotebookInstanceOutput, error) {
stateConf := &resource.StateChangeConf{
Pending: []string{
sagemaker.NotebookInstanceStatusDeleting,
},
Target: []string{},
Refresh: NotebookInstanceStatus(conn, notebookName),
Timeout: NotebookInstanceDeletedTimeout,
}

outputRaw, err := stateConf.WaitForState()

if output, ok := outputRaw.(*sagemaker.DescribeNotebookInstanceOutput); ok {
return output, err
}

return nil, err
}
144 changes: 63 additions & 81 deletions aws/resource_aws_sagemaker_notebook_instance.go
Original file line number Diff line number Diff line change
@@ -1,16 +1,19 @@
package aws

import (
"context"
"fmt"
"log"
"time"

"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/service/sagemaker"
"github.com/hashicorp/terraform-plugin-sdk/v2/helper/customdiff"
"github.com/hashicorp/terraform-plugin-sdk/v2/helper/resource"
"github.com/hashicorp/terraform-plugin-sdk/v2/helper/schema"
"github.com/hashicorp/terraform-plugin-sdk/v2/helper/validation"
"github.com/terraform-providers/terraform-provider-aws/aws/internal/keyvaluetags"
"github.com/terraform-providers/terraform-provider-aws/aws/internal/service/sagemaker/waiter"
)

func resourceAwsSagemakerNotebookInstance() *schema.Resource {
Expand All @@ -22,6 +25,11 @@ func resourceAwsSagemakerNotebookInstance() *schema.Resource {
Importer: &schema.ResourceImporter{
State: schema.ImportStatePassthrough,
},
CustomizeDiff: customdiff.Sequence(
customdiff.ForceNewIfChange("volume_size", func(_ context.Context, old, new, meta interface{}) bool {
return new.(int) < old.(int)
}),
),

Schema: map[string]*schema.Schema{
"arn": {
Expand All @@ -37,13 +45,15 @@ func resourceAwsSagemakerNotebookInstance() *schema.Resource {
},

"role_arn": {
Type: schema.TypeString,
Required: true,
Type: schema.TypeString,
Required: true,
ValidateFunc: validateArn,
},

"instance_type": {
Type: schema.TypeString,
Required: true,
Type: schema.TypeString,
Required: true,
ValidateFunc: validation.StringInSlice(sagemaker.InstanceType_Values(), false),
},

"volume_size": {
Expand Down Expand Up @@ -77,33 +87,26 @@ func resourceAwsSagemakerNotebookInstance() *schema.Resource {
"lifecycle_config_name": {
Type: schema.TypeString,
Optional: true,
ForceNew: true,
},

"root_access": {
Type: schema.TypeString,
Optional: true,
ForceNew: true,
Default: sagemaker.RootAccessEnabled,
ValidateFunc: validation.StringInSlice(
sagemaker.RootAccess_Values(), false),
Type: schema.TypeString,
Optional: true,
Default: sagemaker.RootAccessEnabled,
ValidateFunc: validation.StringInSlice(sagemaker.RootAccess_Values(), false),
},

"direct_internet_access": {
Type: schema.TypeString,
Optional: true,
ForceNew: true,
Default: sagemaker.DirectInternetAccessEnabled,
ValidateFunc: validation.StringInSlice([]string{
sagemaker.DirectInternetAccessDisabled,
sagemaker.DirectInternetAccessEnabled,
}, false),
Type: schema.TypeString,
Optional: true,
ForceNew: true,
Default: sagemaker.DirectInternetAccessEnabled,
ValidateFunc: validation.StringInSlice(sagemaker.DirectInternetAccess_Values(), false),
},

"default_code_repository": {
Type: schema.TypeString,
Optional: true,
ForceNew: true,
},

"tags": tagsSchema(),
Expand Down Expand Up @@ -164,19 +167,8 @@ func resourceAwsSagemakerNotebookInstanceCreate(d *schema.ResourceData, meta int
d.SetId(name)
log.Printf("[INFO] sagemaker notebook instance ID: %s", d.Id())

stateConf := &resource.StateChangeConf{
Pending: []string{
sagemaker.NotebookInstanceStatusUpdating,
sagemaker.NotebookInstanceStatusPending,
sagemaker.NotebookInstanceStatusStopped,
},
Target: []string{sagemaker.NotebookInstanceStatusInService},
Refresh: sagemakerNotebookInstanceStateRefreshFunc(conn, d.Id()),
Timeout: 10 * time.Minute,
}
_, err = stateConf.WaitForState()
if err != nil {
return fmt.Errorf("error waiting for sagemaker notebook instance (%s) to create: %s", d.Id(), err)
if _, err := waiter.NotebookInstanceInService(conn, d.Id()); err != nil {
return fmt.Errorf("error waiting for sagemaker notebook instance (%s) to create: %w", d.Id(), err)
}

return resourceAwsSagemakerNotebookInstanceRead(d, meta)
Expand Down Expand Up @@ -289,6 +281,29 @@ func resourceAwsSagemakerNotebookInstanceUpdate(d *schema.ResourceData, meta int
hasChanged = true
}

if d.HasChange("lifecycle_config_name") {
if v, ok := d.GetOk("lifecycle_config_name"); ok {
updateOpts.LifecycleConfigName = aws.String(v.(string))
} else {
updateOpts.DisassociateLifecycleConfig = aws.Bool(true)
}
hasChanged = true
}

if d.HasChange("default_code_repository") {
if v, ok := d.GetOk("default_code_repository"); ok {
updateOpts.DefaultCodeRepository = aws.String(v.(string))
} else {
updateOpts.DisassociateDefaultCodeRepository = aws.Bool(true)
}
hasChanged = true
}

if d.HasChange("root_access") {
updateOpts.RootAccess = aws.String(d.Get("root_access").(string))
hasChanged = true
}

if hasChanged {

// Stop notebook
Expand All @@ -303,17 +318,8 @@ func resourceAwsSagemakerNotebookInstanceUpdate(d *schema.ResourceData, meta int
return fmt.Errorf("error updating sagemaker notebook instance: %s", err)
}

stateConf := &resource.StateChangeConf{
Pending: []string{
sagemaker.NotebookInstanceStatusUpdating,
},
Target: []string{sagemaker.NotebookInstanceStatusStopped},
Refresh: sagemakerNotebookInstanceStateRefreshFunc(conn, d.Id()),
Timeout: 10 * time.Minute,
}
_, err := stateConf.WaitForState()
if err != nil {
return fmt.Errorf("error waiting for sagemaker notebook instance (%s) to update: %s", d.Id(), err)
if _, err := waiter.NotebookInstanceStopped(conn, d.Id()); err != nil {
return fmt.Errorf("error waiting for sagemaker notebook instance (%s) to stop: %w", d.Id(), err)
}

// Restart if needed
Expand Down Expand Up @@ -356,19 +362,8 @@ func resourceAwsSagemakerNotebookInstanceUpdate(d *schema.ResourceData, meta int
return fmt.Errorf("Error waiting for sagemaker notebook instance to start: %s", err)
}

stateConf = &resource.StateChangeConf{
Pending: []string{
sagemaker.NotebookInstanceStatusUpdating,
sagemaker.NotebookInstanceStatusPending,
sagemaker.NotebookInstanceStatusStopped,
},
Target: []string{sagemaker.NotebookInstanceStatusInService},
Refresh: sagemakerNotebookInstanceStateRefreshFunc(conn, d.Id()),
Timeout: 10 * time.Minute,
}
_, err = stateConf.WaitForState()
if err != nil {
return fmt.Errorf("error waiting for sagemaker notebook instance (%s) to start after update: %s", d.Id(), err)
if _, err := waiter.NotebookInstanceInService(conn, d.Id()); err != nil {
return fmt.Errorf("error waiting for sagemaker notebook instance (%s) to to start after update: %w", d.Id(), err)
}
}
}
Expand All @@ -389,7 +384,9 @@ func resourceAwsSagemakerNotebookInstanceDelete(d *schema.ResourceData, meta int
}
return fmt.Errorf("unable to find sagemaker notebook instance to delete (%s): %s", d.Id(), err)
}
if *notebook.NotebookInstanceStatus != sagemaker.NotebookInstanceStatusFailed && *notebook.NotebookInstanceStatus != sagemaker.NotebookInstanceStatusStopped {

if aws.StringValue(notebook.NotebookInstanceStatus) != sagemaker.NotebookInstanceStatusFailed &&
aws.StringValue(notebook.NotebookInstanceStatus) != sagemaker.NotebookInstanceStatusStopped {
if err := stopSagemakerNotebookInstance(conn, d.Id()); err != nil {
return err
}
Expand All @@ -403,17 +400,11 @@ func resourceAwsSagemakerNotebookInstanceDelete(d *schema.ResourceData, meta int
return fmt.Errorf("error trying to delete sagemaker notebook instance (%s): %s", d.Id(), err)
}

stateConf := &resource.StateChangeConf{
Pending: []string{
sagemaker.NotebookInstanceStatusDeleting,
},
Target: []string{""},
Refresh: sagemakerNotebookInstanceStateRefreshFunc(conn, d.Id()),
Timeout: 10 * time.Minute,
}
_, err = stateConf.WaitForState()
if err != nil {
return fmt.Errorf("error waiting for sagemaker notebook instance (%s) to delete: %s", d.Id(), err)
if _, err := waiter.NotebookInstanceDeleted(conn, d.Id()); err != nil {
if isAWSErr(err, "ValidationException", "RecordNotFound") {
return nil
}
return fmt.Errorf("error waiting for sagemaker notebook instance (%s) to delete: %w", d.Id(), err)
}

return nil
Expand All @@ -430,7 +421,7 @@ func stopSagemakerNotebookInstance(conn *sagemaker.SageMaker, id string) error {
}
return fmt.Errorf("unable to find sagemaker notebook instance (%s): %s", id, err)
}
if *notebook.NotebookInstanceStatus == sagemaker.NotebookInstanceStatusStopped {
if aws.StringValue(notebook.NotebookInstanceStatus) == sagemaker.NotebookInstanceStatusStopped {
return nil
}

Expand All @@ -442,17 +433,8 @@ func stopSagemakerNotebookInstance(conn *sagemaker.SageMaker, id string) error {
return fmt.Errorf("Error stopping sagemaker notebook instance: %s", err)
}

stateConf := &resource.StateChangeConf{
Pending: []string{
sagemaker.NotebookInstanceStatusStopping,
},
Target: []string{sagemaker.NotebookInstanceStatusStopped},
Refresh: sagemakerNotebookInstanceStateRefreshFunc(conn, id),
Timeout: 10 * time.Minute,
}
_, err = stateConf.WaitForState()
if err != nil {
return fmt.Errorf("error waiting for sagemaker notebook instance (%s) to stop: %s", id, err)
if _, err := waiter.NotebookInstanceStopped(conn, id); err != nil {
return fmt.Errorf("error waiting for sagemaker notebook instance (%s) to stop: %w", id, err)
}

return nil
Expand Down