Skip to content

Commit

Permalink
Merge pull request #34127 from hashicorp/s3/b-s3-checksum
Browse files Browse the repository at this point in the history
backend/s3: Adds parameter `skip_s3_checksum` to skip checksum on upload
  • Loading branch information
gdavison committed Oct 24, 2023
2 parents 752e5a1 + aab15a9 commit 57e2fc9
Show file tree
Hide file tree
Showing 6 changed files with 148 additions and 10 deletions.
9 changes: 8 additions & 1 deletion internal/backend/remote-state/s3/backend.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ type Backend struct {
kmsKeyID string
ddbTable string
workspaceKeyPrefix string
skipS3Checksum bool
}

// ConfigSchema returns a description of the expected configuration
Expand Down Expand Up @@ -183,7 +184,7 @@ func (b *Backend) ConfigSchema() *configschema.Block {
"skip_credentials_validation": {
Type: cty.Bool,
Optional: true,
Description: "Skip the credentials validation via STS API.",
Description: "Skip the credentials validation via STS API. Useful for testing and for AWS API implementations that do not have STS available.",
},
"skip_requesting_account_id": {
Type: cty.Bool,
Expand All @@ -200,6 +201,11 @@ func (b *Backend) ConfigSchema() *configschema.Block {
Optional: true,
Description: "Skip static validation of region name.",
},
"skip_s3_checksum": {
Type: cty.Bool,
Optional: true,
Description: "Do not include checksum when uploading S3 Objects. Useful for some S3-Compatible APIs.",
},
"sse_customer_key": {
Type: cty.String,
Optional: true,
Expand Down Expand Up @@ -903,6 +909,7 @@ func (b *Backend) Configure(obj cty.Value) tfdiags.Diagnostics {
b.serverSideEncryption = boolAttr(obj, "encrypt")
b.kmsKeyID = stringAttr(obj, "kms_key_id")
b.ddbTable = stringAttr(obj, "dynamodb_table")
b.skipS3Checksum = boolAttr(obj, "skip_s3_checksum")

if _, ok := stringAttrOk(obj, "kms_key_id"); ok {
if customerKey := os.Getenv("AWS_SSE_CUSTOMER_KEY"); customerKey != "" {
Expand Down
1 change: 1 addition & 0 deletions internal/backend/remote-state/s3/backend_state.go
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,7 @@ func (b *Backend) remoteClient(name string) (*RemoteClient, error) {
acl: b.acl,
kmsKeyID: b.kmsKeyID,
ddbTable: b.ddbTable,
skipS3Checksum: b.skipS3Checksum,
}

return client, nil
Expand Down
9 changes: 9 additions & 0 deletions internal/backend/remote-state/s3/backend_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,15 @@ func TestBackend_impl(t *testing.T) {
var _ backend.Backend = new(Backend)
}

func TestBackend_InternalValidate(t *testing.T) {
b := New()

schema := b.ConfigSchema()
if err := schema.InternalValidate(); err != nil {
t.Fatalf("failed InternalValidate: %s", err)
}
}

func TestBackendConfig_original(t *testing.T) {
testACC(t)

Expand Down
25 changes: 17 additions & 8 deletions internal/backend/remote-state/s3/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ type RemoteClient struct {
acl string
kmsKeyID string
ddbTable string
skipS3Checksum bool
}

var (
Expand Down Expand Up @@ -182,6 +183,10 @@ func (c *RemoteClient) get(ctx context.Context) (*remote.Payload, error) {
}

func (c *RemoteClient) Put(data []byte) error {
return c.put(data)
}

func (c *RemoteClient) put(data []byte, optFns ...func(*s3.Options)) error {
ctx := context.TODO()
log := c.logger(operationClientPut)

Expand All @@ -193,11 +198,13 @@ func (c *RemoteClient) Put(data []byte) error {
sum := md5.Sum(data)

input := &s3.PutObjectInput{
ContentType: aws.String(contentType),
Body: bytes.NewReader(data),
Bucket: aws.String(c.bucketName),
Key: aws.String(c.path),
ChecksumAlgorithm: s3types.ChecksumAlgorithmSha256,
ContentType: aws.String(contentType),
Body: bytes.NewReader(data),
Bucket: aws.String(c.bucketName),
Key: aws.String(c.path),
}
if !c.skipS3Checksum {
input.ChecksumAlgorithm = s3types.ChecksumAlgorithmSha256
}

if c.serverSideEncryption {
Expand All @@ -219,16 +226,18 @@ func (c *RemoteClient) Put(data []byte) error {

log.Info("Uploading remote state")

uploader := manager.NewUploader(c.s3Client)
uploader := manager.NewUploader(c.s3Client, func(u *manager.Uploader) {
u.ClientOptions = optFns
})
_, err := uploader.Upload(ctx, input)
if err != nil {
return fmt.Errorf("failed to upload state: %s", err)
return fmt.Errorf("failed to upload state: %w", err)
}

if err := c.putMD5(ctx, sum[:]); err != nil {
// if this errors out, we unfortunately have to error out altogether,
// since the next Get will inevitably fail.
return fmt.Errorf("failed to store state MD5: %s", err)
return fmt.Errorf("failed to store state MD5: %w", err)
}

return nil
Expand Down
108 changes: 108 additions & 0 deletions internal/backend/remote-state/s3/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,17 +7,22 @@ import (
"bytes"
"context"
"crypto/md5"
"errors"
"fmt"
"io"
"testing"
"time"

"github.com/aws/aws-sdk-go-v2/feature/s3/manager"
"github.com/aws/aws-sdk-go-v2/service/s3"
s3types "github.com/aws/aws-sdk-go-v2/service/s3/types"
"github.com/aws/smithy-go/middleware"
smithyhttp "github.com/aws/smithy-go/transport/http"
"github.com/hashicorp/terraform/internal/backend"
"github.com/hashicorp/terraform/internal/states/remote"
"github.com/hashicorp/terraform/internal/states/statefile"
"github.com/hashicorp/terraform/internal/states/statemgr"
"golang.org/x/exp/maps"
)

func TestRemoteClient_impl(t *testing.T) {
Expand Down Expand Up @@ -383,3 +388,106 @@ func (b neverEnding) Read(p []byte) (n int, err error) {
}
return len(p), nil
}

func TestRemoteClientSkipS3Checksum(t *testing.T) {
testACC(t)

ctx := context.TODO()

bucketName := fmt.Sprintf("terraform-remote-s3-test-%x", time.Now().Unix())
keyName := "testState"

testcases := map[string]struct {
config map[string]any
expected string
}{
"default": {
config: map[string]any{},
expected: string(s3types.ChecksumAlgorithmSha256),
},
"true": {
config: map[string]any{
"skip_s3_checksum": true,
},
expected: "",
},
"false": {
config: map[string]any{
"skip_s3_checksum": false,
},
expected: string(s3types.ChecksumAlgorithmSha256),
},
}

for name, testcase := range testcases {
t.Run(name, func(t *testing.T) {
config := map[string]interface{}{
"bucket": bucketName,
"key": keyName,
}
maps.Copy(config, testcase.config)
b := backend.TestBackendConfig(t, New(), backend.TestWrapConfig(config)).(*Backend)

createS3Bucket(ctx, t, b.s3Client, bucketName, b.awsConfig.Region)
defer deleteS3Bucket(ctx, t, b.s3Client, bucketName, b.awsConfig.Region)

state, err := b.StateMgr(backend.DefaultStateName)
if err != nil {
t.Fatal(err)
}

c := state.(*remote.State).Client
client := c.(*RemoteClient)

s := statemgr.TestFullInitialState()
sf := &statefile.File{State: s}
var stateBuf bytes.Buffer
if err := statefile.Write(sf, &stateBuf); err != nil {
t.Fatal(err)
}

var checksum string
err = client.put(stateBuf.Bytes(), func(opts *s3.Options) {
opts.APIOptions = append(opts.APIOptions,
addRetrieveChecksumHeaderMiddleware(t, &checksum),
addCancelRequestMiddleware(),
)
})
if err == nil {
t.Fatal("Expected an error, got none")
} else if !errors.Is(err, errCancelOperation) {
t.Fatalf("Unexpected error: %s", err)
}

if a, e := checksum, testcase.expected; a != e {
t.Fatalf("expected %q, got %q", e, a)
}
})
}
}

func addRetrieveChecksumHeaderMiddleware(t *testing.T, checksum *string) func(*middleware.Stack) error {
return func(stack *middleware.Stack) error {
return stack.Finalize.Add(
retrieveChecksumHeaderMiddleware(t, checksum),
middleware.After,
)
}
}

func retrieveChecksumHeaderMiddleware(t *testing.T, checksum *string) middleware.FinalizeMiddleware {
return middleware.FinalizeMiddlewareFunc(
"Test: Retrieve Stuff",
func(ctx context.Context, in middleware.FinalizeInput, next middleware.FinalizeHandler) (middleware.FinalizeOutput, middleware.Metadata, error) {
t.Helper()

request, ok := in.Request.(*smithyhttp.Request)
if !ok {
t.Fatalf("Expected *github.com/aws/smithy-go/transport/http.Request, got %s", fullTypeName(in.Request))
}

*checksum = request.Header.Get("x-amz-sdk-checksum-algorithm")

return next.HandleFinalize(ctx, in)
})
}
6 changes: 5 additions & 1 deletion website/docs/language/settings/backends/s3.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -173,9 +173,13 @@ The following configuration is optional:
* `shared_credentials_file` - (Optional, **Deprecated**, use `shared_credentials_files` instead) Path to the AWS shared credentials file. Defaults to `~/.aws/credentials`.
* `shared_credentials_files` - (Optional) List of paths to AWS shared credentials files. Defaults to `~/.aws/credentials`.
* `skip_credentials_validation` - (Optional) Skip credentials validation via the STS API.
Useful for testing and for AWS API implementations that do not have STS available.
* `skip_region_validation` - (Optional) Skip validation of provided region name.
* `skip_requesting_account_id` - (Optional) Whether to skip requesting the account ID. Useful for AWS API implementations that do not have the IAM, STS API, or metadata API.
* `skip_requesting_account_id` - (Optional) Whether to skip requesting the account ID.
Useful for AWS API implementations that do not have the IAM, STS API, or metadata API.
* `skip_metadata_api_check` - (Optional) Skip usage of EC2 Metadata API.
* `skip_s3_checksum` - (Optional) Do not include checksum when uploading S3 Objects.
Useful for some S3-Compatible APIs.
* `sts_endpoint` - (Optional, **Deprecated**) Custom endpoint URL for the AWS Security Token Service (STS) API.
Use `endpoints.sts` instead.
* `sts_region` - (Optional) AWS region for STS. If unset, AWS will use the same region for STS as other non-STS operations.
Expand Down

0 comments on commit 57e2fc9

Please sign in to comment.