From 49012571a8355d5a48c230fd9219de08f5c9f124 Mon Sep 17 00:00:00 2001 From: Graham Davison Date: Wed, 16 Feb 2022 14:19:34 -0800 Subject: [PATCH 1/5] Adds tests for `AWS_CA_BUNDLE`. Cannot use AWS SDK v2 HTTP client with AWS SDK v1 when setting CA bundle --- .semgrep/imports.yml | 2 + aws_config.go | 3 +- aws_config_test.go | 141 +++++++++++++++++ .../http_client.go => http_client.go | 4 +- ...http_client_test.go => http_client_test.go | 22 +-- servicemocks/pem_file.go | 41 +++++ v2/awsv1shim/go.mod | 1 + v2/awsv1shim/go.sum | 2 + v2/awsv1shim/http_client.go | 41 +++++ v2/awsv1shim/http_client_test.go | 71 +++++++++ v2/awsv1shim/session.go | 12 +- v2/awsv1shim/session_test.go | 144 ++++++++++++++++++ 12 files changed, 461 insertions(+), 23 deletions(-) rename internal/httpclient/http_client.go => http_client.go (88%) rename internal/httpclient/http_client_test.go => http_client_test.go (62%) create mode 100644 servicemocks/pem_file.go create mode 100644 v2/awsv1shim/http_client.go create mode 100644 v2/awsv1shim/http_client_test.go diff --git a/.semgrep/imports.yml b/.semgrep/imports.yml index 526ab116..1609457d 100644 --- a/.semgrep/imports.yml +++ b/.semgrep/imports.yml @@ -28,4 +28,6 @@ rules: - metavariable-regex: metavariable: "$X" regex: '^"github.com/aws/aws-sdk-go-v2/.+"$' + - pattern-not: | + import ("github.com/aws/aws-sdk-go-v2/aws/transport/http") severity: ERROR diff --git a/aws_config.go b/aws_config.go index 5849f52b..8bbe394d 100644 --- a/aws_config.go +++ b/aws_config.go @@ -19,7 +19,6 @@ import ( "github.com/aws/smithy-go/middleware" "github.com/hashicorp/aws-sdk-go-base/v2/internal/constants" "github.com/hashicorp/aws-sdk-go-base/v2/internal/endpoints" - "github.com/hashicorp/aws-sdk-go-base/v2/internal/httpclient" "github.com/hashicorp/go-multierror" "github.com/mitchellh/go-homedir" ) @@ -140,7 +139,7 @@ func GetAwsAccountIDAndPartition(ctx context.Context, awsConfig aws.Config, c *C } func commonLoadOptions(c *Config) ([]func(*config.LoadOptions) error, error) { - httpClient, err := httpclient.DefaultHttpClient(c) + httpClient, err := defaultHttpClient(c) if err != nil { return nil, err } diff --git a/aws_config_test.go b/aws_config_test.go index 7dd76c3e..c3ed1a4c 100644 --- a/aws_config_test.go +++ b/aws_config_test.go @@ -6,6 +6,7 @@ import ( "fmt" "io/ioutil" "net" + "net/http" "os" "reflect" "runtime" @@ -2019,6 +2020,146 @@ ec2_metadata_service_endpoint_mode = IPv4 } } +func TestCustomCABundle(t *testing.T) { + testCases := map[string]struct { + Config *Config + SetEnvironmentVariable bool + SetSharedConfigurationFile bool + ExpectTLSClientConfigRootCAsSet bool + }{ + "no configuration": { + Config: &Config{ + AccessKey: servicemocks.MockStaticAccessKey, + Region: "us-east-1", + SecretKey: servicemocks.MockStaticSecretKey, + }, + ExpectTLSClientConfigRootCAsSet: false, + }, + + // "config": { + // Config: &Config{ + // AccessKey: servicemocks.MockStaticAccessKey, + // Region: "us-east-1", + // SecretKey: servicemocks.MockStaticSecretKey, + // EC2MetadataServiceEndpointMode: EC2MetadataEndpointModeIPv4, + // }, + // ExpectTLSClientConfigRootCAsSet: true, + // }, + + "envvar": { + Config: &Config{ + AccessKey: servicemocks.MockStaticAccessKey, + Region: "us-east-1", + SecretKey: servicemocks.MockStaticSecretKey, + }, + SetEnvironmentVariable: true, + ExpectTLSClientConfigRootCAsSet: true, + }, + + // Not implemented in AWS SDK for Go v2: https://github.com/aws/aws-sdk-go-v2/issues/1589 + // "shared configuration file": { + // Config: &Config{ + // AccessKey: servicemocks.MockStaticAccessKey, + // Region: "us-east-1", + // SecretKey: servicemocks.MockStaticSecretKey, + // }, + // SetSharedConfigurationFile: true, + // ExpectTLSClientConfigRootCAsSet: true, + // }, + + // "config overrides envvar": { + // Config: &Config{ + // AccessKey: servicemocks.MockStaticAccessKey, + // Region: "us-east-1", + // SecretKey: servicemocks.MockStaticSecretKey, + // EC2MetadataServiceEndpointMode: EC2MetadataEndpointModeIPv4, + // }, + // EnvironmentVariables: map[string]string{ + // "AWS_CA_BUNDLE": EC2MetadataEndpointModeIPv6, + // }, + // ExpectTLSClientConfigRootCAsSet: true, + // }, + + // "envvar overrides shared configuration": { + // Config: &Config{ + // AccessKey: servicemocks.MockStaticAccessKey, + // Region: "us-east-1", + // SecretKey: servicemocks.MockStaticSecretKey, + // }, + // EnvironmentVariables: map[string]string{ + // "AWS_CA_BUNDLE": EC2MetadataEndpointModeIPv6, + // }, + // SharedConfigurationFile: ` + // [default] + // ec2_metadata_service_endpoint_mode = IPv4 + // `, + // ExpectTLSClientConfigRootCAsSet: true, + // }, + } + + for testName, testCase := range testCases { + testCase := testCase + + t.Run(testName, func(t *testing.T) { + oldEnv := servicemocks.InitSessionTestEnv() + defer servicemocks.PopEnv(oldEnv) + + pemFile, err := servicemocks.TempPEMFile() + defer os.Remove(pemFile) + t.Logf("PEM file name: %s", pemFile) + if err != nil { + t.Fatalf("error creating PEM file: %s", err) + } + + if testCase.SetEnvironmentVariable { + os.Setenv("AWS_CA_BUNDLE", pemFile) + } + + if testCase.SetSharedConfigurationFile { + file, err := ioutil.TempFile("", "aws-sdk-go-base-shared-configuration-file") + + if err != nil { + t.Fatalf("unexpected error creating temporary shared configuration file: %s", err) + } + + defer os.Remove(file.Name()) + + err = ioutil.WriteFile( + file.Name(), + []byte(fmt.Sprintf(` +[default] +ca_bundle = %s +`, pemFile)), + 0600) + + if err != nil { + t.Fatalf("unexpected error writing shared configuration file: %s", err) + } + + testCase.Config.SharedConfigFiles = []string{file.Name()} + } + + testCase.Config.SkipCredsValidation = true + + awsConfig, err := GetAwsConfig(context.Background(), testCase.Config) + if err != nil { + t.Fatalf("error in GetAwsConfig() '%[1]T': %[1]s", err) + } + + type transportGetter interface { + GetTransport() *http.Transport + } + + trGetter := awsConfig.HTTPClient.(transportGetter) + tr := trGetter.GetTransport() + + if a, e := tr.TLSClientConfig.RootCAs != nil, testCase.ExpectTLSClientConfigRootCAsSet; a != e { + t.Errorf("expected(%t) CA Bundle, got: %t", e, a) + } + }) + } +} + func TestGetAwsConfigWithAccountIDAndPartition(t *testing.T) { oldEnv := servicemocks.InitSessionTestEnv() defer servicemocks.PopEnv(oldEnv) diff --git a/internal/httpclient/http_client.go b/http_client.go similarity index 88% rename from internal/httpclient/http_client.go rename to http_client.go index 2741c331..6fac1421 100644 --- a/internal/httpclient/http_client.go +++ b/http_client.go @@ -1,4 +1,4 @@ -package httpclient +package awsbase import ( "fmt" @@ -9,7 +9,7 @@ import ( "github.com/hashicorp/aws-sdk-go-base/v2/internal/config" ) -func DefaultHttpClient(c *config.Config) (*awshttp.BuildableClient, error) { +func defaultHttpClient(c *config.Config) (*awshttp.BuildableClient, error) { var err error httpClient := awshttp.NewBuildableClient(). diff --git a/internal/httpclient/http_client_test.go b/http_client_test.go similarity index 62% rename from internal/httpclient/http_client_test.go rename to http_client_test.go index 7e9ea8a7..2529fe20 100644 --- a/internal/httpclient/http_client_test.go +++ b/http_client_test.go @@ -1,40 +1,42 @@ -package httpclient_test +package awsbase import ( "crypto/tls" "testing" - "github.com/aws/aws-sdk-go-v2/aws/transport/http" + awshttp "github.com/aws/aws-sdk-go-v2/aws/transport/http" "github.com/hashicorp/aws-sdk-go-base/v2/internal/config" - "github.com/hashicorp/aws-sdk-go-base/v2/internal/httpclient" ) func TestHTTPClientConfiguration_basic(t *testing.T) { - client, err := httpclient.DefaultHttpClient(&config.Config{}) + client, err := defaultHttpClient(&config.Config{}) if err != nil { t.Fatalf("unexpected error: %s", err) } transport := client.GetTransport() - if a, e := transport.MaxIdleConns, http.DefaultHTTPTransportMaxIdleConns; a != e { + if a, e := transport.MaxIdleConns, awshttp.DefaultHTTPTransportMaxIdleConns; a != e { t.Errorf("expected MaxIdleConns to be %d, got %d", e, a) } - if a, e := transport.MaxIdleConnsPerHost, http.DefaultHTTPTransportMaxIdleConnsPerHost; a != e { + if a, e := transport.MaxIdleConnsPerHost, awshttp.DefaultHTTPTransportMaxIdleConnsPerHost; a != e { t.Errorf("expected MaxIdleConnsPerHost to be %d, got %d", e, a) } - if a, e := transport.IdleConnTimeout, http.DefaultHTTPTransportIdleConnTimeout; a != e { + if a, e := transport.IdleConnTimeout, awshttp.DefaultHTTPTransportIdleConnTimeout; a != e { t.Errorf("expected IdleConnTimeout to be %s, got %s", e, a) } - if a, e := transport.TLSHandshakeTimeout, http.DefaultHTTPTransportTLSHandleshakeTimeout; a != e { + if a, e := transport.TLSHandshakeTimeout, awshttp.DefaultHTTPTransportTLSHandleshakeTimeout; a != e { t.Errorf("expected TLSHandshakeTimeout to be %s, got %s", e, a) } - if a, e := transport.ExpectContinueTimeout, http.DefaultHTTPTransportExpectContinueTimeout; a != e { + if a, e := transport.ExpectContinueTimeout, awshttp.DefaultHTTPTransportExpectContinueTimeout; a != e { t.Errorf("expected ExpectContinueTimeout to be %s, got %s", e, a) } if !transport.ForceAttemptHTTP2 { t.Error("expected ForceAttemptHTTP2 to be true, got false") } + if transport.DisableKeepAlives { + t.Error("expected DisableKeepAlives to be false, got true") + } tlsConfig := transport.TLSClientConfig if a, e := int(tlsConfig.MinVersion), tls.VersionTLS12; a != e { @@ -46,7 +48,7 @@ func TestHTTPClientConfiguration_basic(t *testing.T) { } func TestHTTPClientConfiguration_insecureHTTPS(t *testing.T) { - client, err := httpclient.DefaultHttpClient(&config.Config{ + client, err := defaultHttpClient(&config.Config{ Insecure: true, }) if err != nil { diff --git a/servicemocks/pem_file.go b/servicemocks/pem_file.go new file mode 100644 index 00000000..fea9672f --- /dev/null +++ b/servicemocks/pem_file.go @@ -0,0 +1,41 @@ +package servicemocks + +import ( + "io/ioutil" +) + +func TempPEMFile() (string, error) { + file, err := ioutil.TempFile(".", "bundle-*.pem") + if err != nil { + return "", err + } + defer file.Close() + + _, err = file.Write(TLSBundleCA) + if err != nil { + return "", err + } + + return file.Name(), nil +} + +var ( + // TLSBundleCA ca.crt + TLSBundleCA = []byte(`-----BEGIN CERTIFICATE----- +MIICiTCCAfKgAwIBAgIJAJ5X1olt05XjMA0GCSqGSIb3DQEBCwUAMDgxCzAJBgNV +BAYTAkdPMQ8wDQYDVQQIEwZHb3BoZXIxGDAWBgNVBAoTD1Rlc3RpbmcgUk9PVCBD +QTAeFw0xNzAzMDkwMDAyMDZaFw0yNzAzMDcwMDAyMDZaMDgxCzAJBgNVBAYTAkdP +MQ8wDQYDVQQIEwZHb3BoZXIxGDAWBgNVBAoTD1Rlc3RpbmcgUk9PVCBDQTCBnzAN +BgkqhkiG9w0BAQEFAAOBjQAwgYkCgYEAw/8DN+t9XQR60jx42rsQ2WE2Dx85rb3n +GQxnKZZLNddsT8rDyxJNP18aFalbRbFlyln5fxWxZIblu9Xkm/HRhOpbSimSqo1y +uDx21NVZ1YsOvXpHby71jx3gPrrhSc/t/zikhi++6D/C6m1CiIGuiJ0GBiJxtrub +UBMXT0QtI2ECAwEAAaOBmjCBlzAdBgNVHQ4EFgQU8XG3X/YHBA6T04kdEkq6+4GV +YykwaAYDVR0jBGEwX4AU8XG3X/YHBA6T04kdEkq6+4GVYymhPKQ6MDgxCzAJBgNV +BAYTAkdPMQ8wDQYDVQQIEwZHb3BoZXIxGDAWBgNVBAoTD1Rlc3RpbmcgUk9PVCBD +QYIJAJ5X1olt05XjMAwGA1UdEwQFMAMBAf8wDQYJKoZIhvcNAQELBQADgYEAeILv +z49+uxmPcfOZzonuOloRcpdvyjiXblYxbzz6ch8GsE7Q886FTZbvwbgLhzdwSVgG +G8WHkodDUsymVepdqAamS3f8PdCUk8xIk9mop8LgaB9Ns0/TssxDvMr3sOD2Grb3 +xyWymTWMcj6uCiEBKtnUp4rPiefcvCRYZ17/hLE= +-----END CERTIFICATE----- +`) +) diff --git a/v2/awsv1shim/go.mod b/v2/awsv1shim/go.mod index 202640c2..dd0bd8a1 100644 --- a/v2/awsv1shim/go.mod +++ b/v2/awsv1shim/go.mod @@ -5,6 +5,7 @@ require ( github.com/aws/aws-sdk-go-v2 v1.13.0 github.com/google/go-cmp v0.5.7 github.com/hashicorp/aws-sdk-go-base/v2 v2.0.0-beta.7 + github.com/hashicorp/go-cleanhttp v0.5.2 ) go 1.16 diff --git a/v2/awsv1shim/go.sum b/v2/awsv1shim/go.sum index 509fea30..6088d0a9 100644 --- a/v2/awsv1shim/go.sum +++ b/v2/awsv1shim/go.sum @@ -32,6 +32,8 @@ github.com/google/go-cmp v0.5.7 h1:81/ik6ipDQS2aGcBfIN5dHDB36BwrStyeAQquSYCV4o= github.com/google/go-cmp v0.5.7/go.mod h1:n+brtR0CgQNWTVd5ZUFpTBC8YFBDLK/h/bpaJ8/DtOE= github.com/hashicorp/errwrap v1.0.0 h1:hLrqtEDnRye3+sgx6z4qVLNuviH3MR5aQ0ykNJa/UYA= github.com/hashicorp/errwrap v1.0.0/go.mod h1:YH+1FKiLXxHSkmPseP+kNlulaMuP3n2brvKWEqk/Jc4= +github.com/hashicorp/go-cleanhttp v0.5.2 h1:035FKYIWjmULyFRBKPs8TBQoi0x6d9G4xc9neXJWAZQ= +github.com/hashicorp/go-cleanhttp v0.5.2/go.mod h1:kO/YDlP8L1346E6Sodw+PrpBSV4/SoxCXGY6BqNFT48= github.com/hashicorp/go-multierror v1.1.1 h1:H5DkEtf6CXdFp0N0Em5UCwQpXMWke8IA0+lD48awMYo= github.com/hashicorp/go-multierror v1.1.1/go.mod h1:iw975J/qwKPdAO1clOe2L8331t/9/fmwbPZ6JB6eMoM= github.com/jmespath/go-jmespath v0.4.0 h1:BEgLn5cpjn8UN1mAw4NjwDrS35OdebyEtFe+9YPoQUg= diff --git a/v2/awsv1shim/http_client.go b/v2/awsv1shim/http_client.go new file mode 100644 index 00000000..af18b020 --- /dev/null +++ b/v2/awsv1shim/http_client.go @@ -0,0 +1,41 @@ +package awsv1shim + +import ( + "crypto/tls" + "fmt" + "net/http" + "net/url" + + awshttp "github.com/aws/aws-sdk-go-v2/aws/transport/http" + "github.com/hashicorp/aws-sdk-go-base/v2/internal/config" + "github.com/hashicorp/go-cleanhttp" +) + +func defaultHttpClient(c *config.Config) (*http.Client, error) { + httpClient := cleanhttp.DefaultPooledClient() + transport := httpClient.Transport.(*http.Transport) + + transport.MaxIdleConnsPerHost = awshttp.DefaultHTTPTransportMaxIdleConnsPerHost + + tlsConfig := transport.TLSClientConfig + if tlsConfig == nil { + tlsConfig = &tls.Config{} + transport.TLSClientConfig = tlsConfig + } + tlsConfig.MinVersion = tls.VersionTLS12 + + if c.Insecure { + tlsConfig.InsecureSkipVerify = true + } + + if c.HTTPProxy != "" { + proxyUrl, err := url.Parse(c.HTTPProxy) + if err != nil { + return nil, fmt.Errorf("error parsing HTTP proxy URL: %w", err) + } + + transport.Proxy = http.ProxyURL(proxyUrl) + } + + return httpClient, nil +} diff --git a/v2/awsv1shim/http_client_test.go b/v2/awsv1shim/http_client_test.go new file mode 100644 index 00000000..98129d97 --- /dev/null +++ b/v2/awsv1shim/http_client_test.go @@ -0,0 +1,71 @@ +package awsv1shim + +import ( + "crypto/tls" + "net/http" + "testing" + + awshttp "github.com/aws/aws-sdk-go-v2/aws/transport/http" + "github.com/hashicorp/aws-sdk-go-base/v2/internal/config" +) + +func TestHTTPClientConfiguration_basic(t *testing.T) { + client, err := defaultHttpClient(&config.Config{}) + if err != nil { + t.Fatalf("unexpected error: %s", err) + } + + transport, ok := client.Transport.(*http.Transport) + if !ok { + t.Fatalf("Unexpected type for HTTP client transport: %T", client.Transport) + } + + if a, e := transport.MaxIdleConns, awshttp.DefaultHTTPTransportMaxIdleConns; a != e { + t.Errorf("expected MaxIdleConns to be %d, got %d", e, a) + } + if a, e := transport.MaxIdleConnsPerHost, awshttp.DefaultHTTPTransportMaxIdleConnsPerHost; a != e { + t.Errorf("expected MaxIdleConnsPerHost to be %d, got %d", e, a) + } + if a, e := transport.IdleConnTimeout, awshttp.DefaultHTTPTransportIdleConnTimeout; a != e { + t.Errorf("expected IdleConnTimeout to be %s, got %s", e, a) + } + if a, e := transport.TLSHandshakeTimeout, awshttp.DefaultHTTPTransportTLSHandleshakeTimeout; a != e { + t.Errorf("expected TLSHandshakeTimeout to be %s, got %s", e, a) + } + if a, e := transport.ExpectContinueTimeout, awshttp.DefaultHTTPTransportExpectContinueTimeout; a != e { + t.Errorf("expected ExpectContinueTimeout to be %s, got %s", e, a) + } + if !transport.ForceAttemptHTTP2 { + t.Error("expected ForceAttemptHTTP2 to be true, got false") + } + if transport.DisableKeepAlives { + t.Error("expected DisableKeepAlives to be false, got true") + } + + tlsConfig := transport.TLSClientConfig + if a, e := int(tlsConfig.MinVersion), tls.VersionTLS12; a != e { + t.Errorf("expected tlsConfig.MinVersion to be %d, got %d", e, a) + } + if tlsConfig.InsecureSkipVerify { + t.Error("expected InsecureSkipVerify to be false, got true") + } +} + +func TestHTTPClientConfiguration_insecureHTTPS(t *testing.T) { + client, err := defaultHttpClient(&config.Config{ + Insecure: true, + }) + if err != nil { + t.Fatalf("unexpected error: %s", err) + } + + transport, ok := client.Transport.(*http.Transport) + if !ok { + t.Fatalf("Unexpected type for HTTP client transport: %T", client.Transport) + } + + tlsConfig := transport.TLSClientConfig + if !tlsConfig.InsecureSkipVerify { + t.Error("expected InsecureSkipVerify to be true, got false") + } +} diff --git a/v2/awsv1shim/session.go b/v2/awsv1shim/session.go index 89bce01a..0a57fa39 100644 --- a/v2/awsv1shim/session.go +++ b/v2/awsv1shim/session.go @@ -4,7 +4,6 @@ import ( // nosemgrep: no-sdkv2-imports-in-awsv1shim "context" "fmt" "log" - "net/http" "os" awsv2 "github.com/aws/aws-sdk-go-v2/aws" @@ -17,7 +16,6 @@ import ( // nosemgrep: no-sdkv2-imports-in-awsv1shim "github.com/hashicorp/aws-sdk-go-base/v2/awsv1shim/v2/tfawserr" "github.com/hashicorp/aws-sdk-go-base/v2/internal/awsconfig" "github.com/hashicorp/aws-sdk-go-base/v2/internal/constants" - "github.com/hashicorp/aws-sdk-go-base/v2/internal/httpclient" ) // getSessionOptions attempts to return valid AWS Go SDK session authentication @@ -39,13 +37,9 @@ func getSessionOptions(awsC *awsv2.Config, c *awsbase.Config) (*session.Options, return nil, fmt.Errorf("error resolving configuration: %w", err) } - httpClient, ok := awsC.HTTPClient.(*http.Client) - if !ok { // This is unlikely, but technically possible - client, err := httpclient.DefaultHttpClient(c) - if err != nil { - return nil, err - } - httpClient = client.Freeze().(*http.Client) + httpClient, err := defaultHttpClient(c) + if err != nil { + return nil, err } options := &session.Options{ Config: aws.Config{ diff --git a/v2/awsv1shim/session_test.go b/v2/awsv1shim/session_test.go index ccc2c58b..d70bb767 100644 --- a/v2/awsv1shim/session_test.go +++ b/v2/awsv1shim/session_test.go @@ -6,6 +6,7 @@ import ( "fmt" "io/ioutil" "net" + "net/http" "os" "runtime" "testing" @@ -1549,6 +1550,149 @@ func DualStackEndpointStateString(state endpoints.DualStackEndpointState) string return fmt.Sprintf("unknown endpoints.FIPSEndpointStateUnset (%d)", state) } +func TestCustomCABundle(t *testing.T) { + testCases := map[string]struct { + Config *awsbase.Config + SetEnvironmentVariable bool + SetSharedConfigurationFile bool + ExpectTLSClientConfigRootCAsSet bool + }{ + "no configuration": { + Config: &awsbase.Config{ + AccessKey: servicemocks.MockStaticAccessKey, + Region: "us-east-1", + SecretKey: servicemocks.MockStaticSecretKey, + }, + ExpectTLSClientConfigRootCAsSet: false, + }, + + // "config": { + // Config: &awsbase.Config{ + // AccessKey: servicemocks.MockStaticAccessKey, + // Region: "us-east-1", + // SecretKey: servicemocks.MockStaticSecretKey, + // EC2MetadataServiceEndpointMode: EC2MetadataEndpointModeIPv4, + // }, + // ExpectTLSClientConfigRootCAsSet: true, + // }, + + "envvar": { + Config: &awsbase.Config{ + AccessKey: servicemocks.MockStaticAccessKey, + Region: "us-east-1", + SecretKey: servicemocks.MockStaticSecretKey, + }, + SetEnvironmentVariable: true, + ExpectTLSClientConfigRootCAsSet: true, + }, + + // Not implemented in AWS SDK for Go v2: https://github.com/aws/aws-sdk-go-v2/issues/1589 + // "shared configuration file": { + // Config: &awsbase.Config{ + // AccessKey: servicemocks.MockStaticAccessKey, + // Region: "us-east-1", + // SecretKey: servicemocks.MockStaticSecretKey, + // }, + // SetSharedConfigurationFile: true, + // ExpectTLSClientConfigRootCAsSet: true, + // }, + + // "config overrides envvar": { + // Config: &awsbase.Config{ + // AccessKey: servicemocks.MockStaticAccessKey, + // Region: "us-east-1", + // SecretKey: servicemocks.MockStaticSecretKey, + // EC2MetadataServiceEndpointMode: EC2MetadataEndpointModeIPv4, + // }, + // EnvironmentVariables: map[string]string{ + // "AWS_CA_BUNDLE": EC2MetadataEndpointModeIPv6, + // }, + // ExpectTLSClientConfigRootCAsSet: true, + // }, + + // "envvar overrides shared configuration": { + // Config: &awsbase.Config{ + // AccessKey: servicemocks.MockStaticAccessKey, + // Region: "us-east-1", + // SecretKey: servicemocks.MockStaticSecretKey, + // }, + // EnvironmentVariables: map[string]string{ + // "AWS_CA_BUNDLE": EC2MetadataEndpointModeIPv6, + // }, + // SharedConfigurationFile: ` + // [default] + // ec2_metadata_service_endpoint_mode = IPv4 + // `, + // ExpectTLSClientConfigRootCAsSet: true, + // }, + } + + for testName, testCase := range testCases { + testCase := testCase + + t.Run(testName, func(t *testing.T) { + oldEnv := servicemocks.InitSessionTestEnv() + defer servicemocks.PopEnv(oldEnv) + + pemFile, err := servicemocks.TempPEMFile() + defer os.Remove(pemFile) + t.Logf("PEM file name: %s", pemFile) + if err != nil { + t.Fatalf("error creating PEM file: %s", err) + } + + if testCase.SetEnvironmentVariable { + os.Setenv("AWS_CA_BUNDLE", pemFile) + } + + if testCase.SetSharedConfigurationFile { + file, err := ioutil.TempFile("", "aws-sdk-go-base-shared-configuration-file") + + if err != nil { + t.Fatalf("unexpected error creating temporary shared configuration file: %s", err) + } + + defer os.Remove(file.Name()) + + err = ioutil.WriteFile( + file.Name(), + []byte(fmt.Sprintf(` +[default] +ca_bundle = %s +`, pemFile)), + 0600) + + if err != nil { + t.Fatalf("unexpected error writing shared configuration file: %s", err) + } + + testCase.Config.SharedConfigFiles = []string{file.Name()} + } + + testCase.Config.SkipCredsValidation = true + + awsConfig, err := awsbase.GetAwsConfig(context.Background(), testCase.Config) + if err != nil { + t.Fatalf("GetAwsConfig() returned error: %s", err) + } + actualSession, err := GetSession(&awsConfig, testCase.Config) + if err != nil { + t.Fatalf("error in GetSession() '%[1]T': %[1]s", err) + } + + roundTripper := actualSession.Config.HTTPClient.Transport + tr, ok := roundTripper.(*http.Transport) + if !ok { + t.Fatalf("Unexpected type for HTTP client transport: %T", roundTripper) + } + + if a, e := tr.TLSClientConfig.RootCAs != nil, testCase.ExpectTLSClientConfigRootCAsSet; a != e { + t.Errorf("expected(%t) CA Bundle, got: %t", e, a) + } + }) + } +} + func TestSessionRetryHandlers(t *testing.T) { const maxRetries = 25 From 7c77742b6ff6ca505422c2b2e5a8a85dfbe25a67 Mon Sep 17 00:00:00 2001 From: Graham Davison Date: Wed, 16 Feb 2022 14:37:11 -0800 Subject: [PATCH 2/5] Centralizes HTTP client configuration tests --- http_client_test.go | 38 ++----------------------- internal/test/http_client.go | 48 ++++++++++++++++++++++++++++++++ v2/awsv1shim/http_client_test.go | 38 ++----------------------- 3 files changed, 54 insertions(+), 70 deletions(-) create mode 100644 internal/test/http_client.go diff --git a/http_client_test.go b/http_client_test.go index 2529fe20..aac32985 100644 --- a/http_client_test.go +++ b/http_client_test.go @@ -1,11 +1,10 @@ package awsbase import ( - "crypto/tls" "testing" - awshttp "github.com/aws/aws-sdk-go-v2/aws/transport/http" "github.com/hashicorp/aws-sdk-go-base/v2/internal/config" + "github.com/hashicorp/aws-sdk-go-base/v2/internal/test" ) func TestHTTPClientConfiguration_basic(t *testing.T) { @@ -16,35 +15,7 @@ func TestHTTPClientConfiguration_basic(t *testing.T) { transport := client.GetTransport() - if a, e := transport.MaxIdleConns, awshttp.DefaultHTTPTransportMaxIdleConns; a != e { - t.Errorf("expected MaxIdleConns to be %d, got %d", e, a) - } - if a, e := transport.MaxIdleConnsPerHost, awshttp.DefaultHTTPTransportMaxIdleConnsPerHost; a != e { - t.Errorf("expected MaxIdleConnsPerHost to be %d, got %d", e, a) - } - if a, e := transport.IdleConnTimeout, awshttp.DefaultHTTPTransportIdleConnTimeout; a != e { - t.Errorf("expected IdleConnTimeout to be %s, got %s", e, a) - } - if a, e := transport.TLSHandshakeTimeout, awshttp.DefaultHTTPTransportTLSHandleshakeTimeout; a != e { - t.Errorf("expected TLSHandshakeTimeout to be %s, got %s", e, a) - } - if a, e := transport.ExpectContinueTimeout, awshttp.DefaultHTTPTransportExpectContinueTimeout; a != e { - t.Errorf("expected ExpectContinueTimeout to be %s, got %s", e, a) - } - if !transport.ForceAttemptHTTP2 { - t.Error("expected ForceAttemptHTTP2 to be true, got false") - } - if transport.DisableKeepAlives { - t.Error("expected DisableKeepAlives to be false, got true") - } - - tlsConfig := transport.TLSClientConfig - if a, e := int(tlsConfig.MinVersion), tls.VersionTLS12; a != e { - t.Errorf("expected tlsConfig.MinVersion to be %d, got %d", e, a) - } - if tlsConfig.InsecureSkipVerify { - t.Error("expected InsecureSkipVerify to be false, got true") - } + test.HTTPClientConfigurationTest_basic(t, transport) } func TestHTTPClientConfiguration_insecureHTTPS(t *testing.T) { @@ -57,8 +28,5 @@ func TestHTTPClientConfiguration_insecureHTTPS(t *testing.T) { transport := client.GetTransport() - tlsConfig := transport.TLSClientConfig - if !tlsConfig.InsecureSkipVerify { - t.Error("expected InsecureSkipVerify to be true, got false") - } + test.HTTPClientConfigurationTest_insecureHTTPS(t, transport) } diff --git a/internal/test/http_client.go b/internal/test/http_client.go new file mode 100644 index 00000000..31a05620 --- /dev/null +++ b/internal/test/http_client.go @@ -0,0 +1,48 @@ +package test + +import ( + "crypto/tls" + "net/http" + "testing" + + awshttp "github.com/aws/aws-sdk-go-v2/aws/transport/http" +) + +func HTTPClientConfigurationTest_basic(t *testing.T, transport *http.Transport) { + if a, e := transport.MaxIdleConns, awshttp.DefaultHTTPTransportMaxIdleConns; a != e { + t.Errorf("expected MaxIdleConns to be %d, got %d", e, a) + } + if a, e := transport.MaxIdleConnsPerHost, awshttp.DefaultHTTPTransportMaxIdleConnsPerHost; a != e { + t.Errorf("expected MaxIdleConnsPerHost to be %d, got %d", e, a) + } + if a, e := transport.IdleConnTimeout, awshttp.DefaultHTTPTransportIdleConnTimeout; a != e { + t.Errorf("expected IdleConnTimeout to be %s, got %s", e, a) + } + if a, e := transport.TLSHandshakeTimeout, awshttp.DefaultHTTPTransportTLSHandleshakeTimeout; a != e { + t.Errorf("expected TLSHandshakeTimeout to be %s, got %s", e, a) + } + if a, e := transport.ExpectContinueTimeout, awshttp.DefaultHTTPTransportExpectContinueTimeout; a != e { + t.Errorf("expected ExpectContinueTimeout to be %s, got %s", e, a) + } + if !transport.ForceAttemptHTTP2 { + t.Error("expected ForceAttemptHTTP2 to be true, got false") + } + if transport.DisableKeepAlives { + t.Error("expected DisableKeepAlives to be false, got true") + } + + tlsConfig := transport.TLSClientConfig + if a, e := int(tlsConfig.MinVersion), tls.VersionTLS12; a != e { + t.Errorf("expected tlsConfig.MinVersion to be %d, got %d", e, a) + } + if tlsConfig.InsecureSkipVerify { + t.Error("expected InsecureSkipVerify to be false, got true") + } +} + +func HTTPClientConfigurationTest_insecureHTTPS(t *testing.T, transport *http.Transport) { + tlsConfig := transport.TLSClientConfig + if !tlsConfig.InsecureSkipVerify { + t.Error("expected InsecureSkipVerify to be true, got false") + } +} diff --git a/v2/awsv1shim/http_client_test.go b/v2/awsv1shim/http_client_test.go index 98129d97..2fa27117 100644 --- a/v2/awsv1shim/http_client_test.go +++ b/v2/awsv1shim/http_client_test.go @@ -1,12 +1,11 @@ package awsv1shim import ( - "crypto/tls" "net/http" "testing" - awshttp "github.com/aws/aws-sdk-go-v2/aws/transport/http" "github.com/hashicorp/aws-sdk-go-base/v2/internal/config" + "github.com/hashicorp/aws-sdk-go-base/v2/internal/test" ) func TestHTTPClientConfiguration_basic(t *testing.T) { @@ -20,35 +19,7 @@ func TestHTTPClientConfiguration_basic(t *testing.T) { t.Fatalf("Unexpected type for HTTP client transport: %T", client.Transport) } - if a, e := transport.MaxIdleConns, awshttp.DefaultHTTPTransportMaxIdleConns; a != e { - t.Errorf("expected MaxIdleConns to be %d, got %d", e, a) - } - if a, e := transport.MaxIdleConnsPerHost, awshttp.DefaultHTTPTransportMaxIdleConnsPerHost; a != e { - t.Errorf("expected MaxIdleConnsPerHost to be %d, got %d", e, a) - } - if a, e := transport.IdleConnTimeout, awshttp.DefaultHTTPTransportIdleConnTimeout; a != e { - t.Errorf("expected IdleConnTimeout to be %s, got %s", e, a) - } - if a, e := transport.TLSHandshakeTimeout, awshttp.DefaultHTTPTransportTLSHandleshakeTimeout; a != e { - t.Errorf("expected TLSHandshakeTimeout to be %s, got %s", e, a) - } - if a, e := transport.ExpectContinueTimeout, awshttp.DefaultHTTPTransportExpectContinueTimeout; a != e { - t.Errorf("expected ExpectContinueTimeout to be %s, got %s", e, a) - } - if !transport.ForceAttemptHTTP2 { - t.Error("expected ForceAttemptHTTP2 to be true, got false") - } - if transport.DisableKeepAlives { - t.Error("expected DisableKeepAlives to be false, got true") - } - - tlsConfig := transport.TLSClientConfig - if a, e := int(tlsConfig.MinVersion), tls.VersionTLS12; a != e { - t.Errorf("expected tlsConfig.MinVersion to be %d, got %d", e, a) - } - if tlsConfig.InsecureSkipVerify { - t.Error("expected InsecureSkipVerify to be false, got true") - } + test.HTTPClientConfigurationTest_basic(t, transport) } func TestHTTPClientConfiguration_insecureHTTPS(t *testing.T) { @@ -64,8 +35,5 @@ func TestHTTPClientConfiguration_insecureHTTPS(t *testing.T) { t.Fatalf("Unexpected type for HTTP client transport: %T", client.Transport) } - tlsConfig := transport.TLSClientConfig - if !tlsConfig.InsecureSkipVerify { - t.Error("expected InsecureSkipVerify to be true, got false") - } + test.HTTPClientConfigurationTest_insecureHTTPS(t, transport) } From dc4cbe98d4b2ac01115401967473a0cfa046bbac Mon Sep 17 00:00:00 2001 From: Graham Davison Date: Wed, 16 Feb 2022 15:40:57 -0800 Subject: [PATCH 3/5] Adds config parameter --- aws_config.go | 11 +++++++++++ aws_config_test.go | 23 ++++++++++++++--------- internal/config/config.go | 1 + servicemocks/pem_file.go | 2 +- v2/awsv1shim/session.go | 10 ++++++++++ v2/awsv1shim/session_test.go | 24 ++++++++++++++---------- 6 files changed, 51 insertions(+), 20 deletions(-) diff --git a/aws_config.go b/aws_config.go index 8bbe394d..199d37bd 100644 --- a/aws_config.go +++ b/aws_config.go @@ -1,6 +1,7 @@ package awsbase import ( + "bytes" "context" "errors" "fmt" @@ -181,6 +182,16 @@ func commonLoadOptions(c *Config) ([]func(*config.LoadOptions) error, error) { ) } + if c.CustomCABundle != "" { + bundle, err := os.ReadFile(c.CustomCABundle) + if err != nil { + return nil, fmt.Errorf("error reading custom CA bundle %q: %w", c.CustomCABundle, err) + } + loadOptions = append(loadOptions, + config.WithCustomCABundle(bytes.NewReader(bundle)), + ) + } + if c.EC2MetadataServiceEndpoint != "" { loadOptions = append(loadOptions, config.WithEC2IMDSEndpoint(c.EC2MetadataServiceEndpoint), diff --git a/aws_config_test.go b/aws_config_test.go index c3ed1a4c..57f0c100 100644 --- a/aws_config_test.go +++ b/aws_config_test.go @@ -2023,6 +2023,7 @@ ec2_metadata_service_endpoint_mode = IPv4 func TestCustomCABundle(t *testing.T) { testCases := map[string]struct { Config *Config + SetConfig bool SetEnvironmentVariable bool SetSharedConfigurationFile bool ExpectTLSClientConfigRootCAsSet bool @@ -2036,15 +2037,15 @@ func TestCustomCABundle(t *testing.T) { ExpectTLSClientConfigRootCAsSet: false, }, - // "config": { - // Config: &Config{ - // AccessKey: servicemocks.MockStaticAccessKey, - // Region: "us-east-1", - // SecretKey: servicemocks.MockStaticSecretKey, - // EC2MetadataServiceEndpointMode: EC2MetadataEndpointModeIPv4, - // }, - // ExpectTLSClientConfigRootCAsSet: true, - // }, + "config": { + Config: &Config{ + AccessKey: servicemocks.MockStaticAccessKey, + Region: "us-east-1", + SecretKey: servicemocks.MockStaticSecretKey, + }, + SetConfig: true, + ExpectTLSClientConfigRootCAsSet: true, + }, "envvar": { Config: &Config{ @@ -2111,6 +2112,10 @@ func TestCustomCABundle(t *testing.T) { t.Fatalf("error creating PEM file: %s", err) } + if testCase.SetConfig { + testCase.Config.CustomCABundle = pemFile + } + if testCase.SetEnvironmentVariable { os.Setenv("AWS_CA_BUNDLE", pemFile) } diff --git a/internal/config/config.go b/internal/config/config.go index 318d9ae7..087b7b81 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -8,6 +8,7 @@ type Config struct { AssumeRole *AssumeRole CallerDocumentationURL string CallerName string + CustomCABundle string EC2MetadataServiceEndpoint string EC2MetadataServiceEndpointMode string HTTPProxy string diff --git a/servicemocks/pem_file.go b/servicemocks/pem_file.go index fea9672f..fde238e4 100644 --- a/servicemocks/pem_file.go +++ b/servicemocks/pem_file.go @@ -5,7 +5,7 @@ import ( ) func TempPEMFile() (string, error) { - file, err := ioutil.TempFile(".", "bundle-*.pem") + file, err := ioutil.TempFile("", "bundle-*.pem") if err != nil { return "", err } diff --git a/v2/awsv1shim/session.go b/v2/awsv1shim/session.go index 0a57fa39..32493bad 100644 --- a/v2/awsv1shim/session.go +++ b/v2/awsv1shim/session.go @@ -1,6 +1,7 @@ package awsv1shim import ( // nosemgrep: no-sdkv2-imports-in-awsv1shim + "bytes" "context" "fmt" "log" @@ -41,6 +42,7 @@ func getSessionOptions(awsC *awsv2.Config, c *awsbase.Config) (*session.Options, if err != nil { return nil, err } + options := &session.Options{ Config: aws.Config{ Credentials: credentials.NewStaticCredentials( @@ -58,6 +60,14 @@ func getSessionOptions(awsC *awsv2.Config, c *awsbase.Config) (*session.Options, }, } + if c.CustomCABundle != "" { + bundle, err := os.ReadFile(c.CustomCABundle) + if err != nil { + return nil, fmt.Errorf("error reading custom CA bundle %q: %w", c.CustomCABundle, err) + } + options.CustomCABundle = bytes.NewReader(bundle) + } + return options, nil } diff --git a/v2/awsv1shim/session_test.go b/v2/awsv1shim/session_test.go index d70bb767..82693727 100644 --- a/v2/awsv1shim/session_test.go +++ b/v2/awsv1shim/session_test.go @@ -1553,6 +1553,7 @@ func DualStackEndpointStateString(state endpoints.DualStackEndpointState) string func TestCustomCABundle(t *testing.T) { testCases := map[string]struct { Config *awsbase.Config + SetConfig bool SetEnvironmentVariable bool SetSharedConfigurationFile bool ExpectTLSClientConfigRootCAsSet bool @@ -1566,15 +1567,15 @@ func TestCustomCABundle(t *testing.T) { ExpectTLSClientConfigRootCAsSet: false, }, - // "config": { - // Config: &awsbase.Config{ - // AccessKey: servicemocks.MockStaticAccessKey, - // Region: "us-east-1", - // SecretKey: servicemocks.MockStaticSecretKey, - // EC2MetadataServiceEndpointMode: EC2MetadataEndpointModeIPv4, - // }, - // ExpectTLSClientConfigRootCAsSet: true, - // }, + "config": { + Config: &awsbase.Config{ + AccessKey: servicemocks.MockStaticAccessKey, + Region: "us-east-1", + SecretKey: servicemocks.MockStaticSecretKey, + }, + SetConfig: true, + ExpectTLSClientConfigRootCAsSet: true, + }, "envvar": { Config: &awsbase.Config{ @@ -1636,11 +1637,14 @@ func TestCustomCABundle(t *testing.T) { pemFile, err := servicemocks.TempPEMFile() defer os.Remove(pemFile) - t.Logf("PEM file name: %s", pemFile) if err != nil { t.Fatalf("error creating PEM file: %s", err) } + if testCase.SetConfig { + testCase.Config.CustomCABundle = pemFile + } + if testCase.SetEnvironmentVariable { os.Setenv("AWS_CA_BUNDLE", pemFile) } From c4ed8f61fe0e09c966ccbf04f322373fa5cef8e3 Mon Sep 17 00:00:00 2001 From: Graham Davison Date: Wed, 16 Feb 2022 16:19:40 -0800 Subject: [PATCH 4/5] Expands environment variables in CA bundle path --- aws_config.go | 34 ++------- aws_config_test.go | 114 ++++++++++++------------------- credentials.go | 5 +- internal/config/apn_info.go | 5 ++ internal/config/config.go | 37 ++++++---- internal/config/user_agent.go | 8 +++ internal/expand/filepath.go | 28 ++++++++ internal/expand/filepath_test.go | 67 ++++++++++++++++++ servicemocks/setup.go | 15 ++++ v2/awsv1shim/session.go | 7 +- v2/awsv1shim/session_test.go | 55 +++++++++++---- 11 files changed, 244 insertions(+), 131 deletions(-) create mode 100644 internal/expand/filepath.go create mode 100644 internal/expand/filepath_test.go diff --git a/aws_config.go b/aws_config.go index 199d37bd..090947a4 100644 --- a/aws_config.go +++ b/aws_config.go @@ -1,7 +1,6 @@ package awsbase import ( - "bytes" "context" "errors" "fmt" @@ -20,8 +19,7 @@ import ( "github.com/aws/smithy-go/middleware" "github.com/hashicorp/aws-sdk-go-base/v2/internal/constants" "github.com/hashicorp/aws-sdk-go-base/v2/internal/endpoints" - "github.com/hashicorp/go-multierror" - "github.com/mitchellh/go-homedir" + "github.com/hashicorp/aws-sdk-go-base/v2/internal/expand" ) func GetAwsConfig(ctx context.Context, c *Config) (aws.Config, error) { @@ -172,9 +170,9 @@ func commonLoadOptions(c *Config) ([]func(*config.LoadOptions) error, error) { } if len(c.SharedConfigFiles) > 0 { - configFiles, err := expandFilePaths(c.SharedConfigFiles) + configFiles, err := expand.FilePaths(c.SharedConfigFiles) if err != nil { - return nil, fmt.Errorf("error expanding shared config files: %w", err) + return nil, fmt.Errorf("expanding shared config files: %w", err) } loadOptions = append( loadOptions, @@ -183,12 +181,12 @@ func commonLoadOptions(c *Config) ([]func(*config.LoadOptions) error, error) { } if c.CustomCABundle != "" { - bundle, err := os.ReadFile(c.CustomCABundle) + reader, err := c.CustomCABundleReader() if err != nil { - return nil, fmt.Errorf("error reading custom CA bundle %q: %w", c.CustomCABundle, err) + return nil, err } loadOptions = append(loadOptions, - config.WithCustomCABundle(bytes.NewReader(bundle)), + config.WithCustomCABundle(reader), ) } @@ -232,23 +230,3 @@ func commonLoadOptions(c *Config) ([]func(*config.LoadOptions) error, error) { return loadOptions, nil } - -func expandFilePaths(in []string) ([]string, error) { - var errs *multierror.Error - result := make([]string, 0, len(in)) - for _, v := range in { - p, err := expandFilePath(v) - if err != nil { - errs = multierror.Append(errs, err) - continue - } - result = append(result, p) - } - return result, errs.ErrorOrNil() -} - -func expandFilePath(in string) (s string, err error) { - e := os.ExpandEnv(in) - s, err = homedir.Expand(e) - return -} diff --git a/aws_config_test.go b/aws_config_test.go index 57f0c100..1020dbce 100644 --- a/aws_config_test.go +++ b/aws_config_test.go @@ -8,6 +8,7 @@ import ( "net" "net/http" "os" + "path/filepath" "reflect" "runtime" "strings" @@ -2026,6 +2027,8 @@ func TestCustomCABundle(t *testing.T) { SetConfig bool SetEnvironmentVariable bool SetSharedConfigurationFile bool + ExpandEnvVars bool + EnvironmentVariables map[string]string ExpectTLSClientConfigRootCAsSet bool }{ "no configuration": { @@ -2047,6 +2050,17 @@ func TestCustomCABundle(t *testing.T) { ExpectTLSClientConfigRootCAsSet: true, }, + "expanded config": { + Config: &Config{ + AccessKey: servicemocks.MockStaticAccessKey, + Region: "us-east-1", + SecretKey: servicemocks.MockStaticSecretKey, + }, + SetConfig: true, + ExpandEnvVars: true, + ExpectTLSClientConfigRootCAsSet: true, + }, + "envvar": { Config: &Config{ AccessKey: servicemocks.MockStaticAccessKey, @@ -2068,19 +2082,20 @@ func TestCustomCABundle(t *testing.T) { // ExpectTLSClientConfigRootCAsSet: true, // }, - // "config overrides envvar": { - // Config: &Config{ - // AccessKey: servicemocks.MockStaticAccessKey, - // Region: "us-east-1", - // SecretKey: servicemocks.MockStaticSecretKey, - // EC2MetadataServiceEndpointMode: EC2MetadataEndpointModeIPv4, - // }, - // EnvironmentVariables: map[string]string{ - // "AWS_CA_BUNDLE": EC2MetadataEndpointModeIPv6, - // }, - // ExpectTLSClientConfigRootCAsSet: true, - // }, + "config overrides envvar": { + Config: &Config{ + AccessKey: servicemocks.MockStaticAccessKey, + Region: "us-east-1", + SecretKey: servicemocks.MockStaticSecretKey, + }, + SetConfig: true, + EnvironmentVariables: map[string]string{ + "AWS_CA_BUNDLE": "no-such-file", + }, + ExpectTLSClientConfigRootCAsSet: true, + }, + // Not implemented in AWS SDK for Go v2: https://github.com/aws/aws-sdk-go-v2/issues/1589 // "envvar overrides shared configuration": { // Config: &Config{ // AccessKey: servicemocks.MockStaticAccessKey, @@ -2104,14 +2119,29 @@ func TestCustomCABundle(t *testing.T) { t.Run(testName, func(t *testing.T) { oldEnv := servicemocks.InitSessionTestEnv() defer servicemocks.PopEnv(oldEnv) + servicemocks.RestoreEnv(oldEnv, "TMPDIR") + + for k, v := range testCase.EnvironmentVariables { + os.Setenv(k, v) + } pemFile, err := servicemocks.TempPEMFile() defer os.Remove(pemFile) - t.Logf("PEM file name: %s", pemFile) if err != nil { t.Fatalf("error creating PEM file: %s", err) } + if testCase.ExpandEnvVars { + tmpdir := os.Getenv("TMPDIR") + rel, err := filepath.Rel(tmpdir, pemFile) + if err != nil { + t.Fatalf("error making path relative: %s", err) + } + t.Logf("relative: %s", rel) + pemFile = filepath.Join("$TMPDIR", rel) + t.Logf("env tempfile: %s", pemFile) + } + if testCase.SetConfig { testCase.Config.CustomCABundle = pemFile } @@ -2496,61 +2526,3 @@ func (r *withNoDelay) RetryDelay(attempt int, err error) (time.Duration, error) return 0 * time.Second, nil } - -func TestExpandFilePath(t *testing.T) { - testcases := map[string]struct { - path string - expected string - envvars map[string]string - }{ - "filename": { - path: "file", - expected: "file", - }, - "file in current dir": { - path: "./file", - expected: "./file", - }, - "file with tilde": { - path: "~/file", - expected: "/my/home/dir/file", - envvars: map[string]string{ - "HOME": "/my/home/dir", - }, - }, - "file with envvar": { - path: "$HOME/file", - expected: "/home/dir/file", - envvars: map[string]string{ - "HOME": "/home/dir", - }, - }, - "full file in envvar": { - path: "$CONF_FILE", - expected: "/path/to/conf/file", - envvars: map[string]string{ - "CONF_FILE": "/path/to/conf/file", - }, - }, - } - - for name, testcase := range testcases { - t.Run(name, func(t *testing.T) { - oldEnv := servicemocks.StashEnv() - defer servicemocks.PopEnv(oldEnv) - - for k, v := range testcase.envvars { - os.Setenv(k, v) - } - - a, err := expandFilePath(testcase.path) - if err != nil { - t.Fatalf("unexpected error: %s", err) - } - - if a != testcase.expected { - t.Errorf("expected expansion to %q, got %q", testcase.expected, a) - } - }) - } -} diff --git a/credentials.go b/credentials.go index 66ee896f..61048af8 100644 --- a/credentials.go +++ b/credentials.go @@ -10,6 +10,7 @@ import ( "github.com/aws/aws-sdk-go-v2/credentials" "github.com/aws/aws-sdk-go-v2/credentials/stscreds" "github.com/aws/aws-sdk-go-v2/service/sts/types" + "github.com/hashicorp/aws-sdk-go-base/v2/internal/expand" ) func getCredentialsProvider(ctx context.Context, c *Config) (aws.CredentialsProvider, error) { @@ -41,9 +42,9 @@ func getCredentialsProvider(ctx context.Context, c *Config) (aws.CredentialsProv ) } if len(c.SharedCredentialsFiles) > 0 { - credsFiles, err := expandFilePaths(c.SharedCredentialsFiles) + credsFiles, err := expand.FilePaths(c.SharedCredentialsFiles) if err != nil { - return nil, fmt.Errorf("error expanding shared credentials files: %w", err) + return nil, fmt.Errorf("expanding shared credentials files: %w", err) } loadOptions = append( loadOptions, diff --git a/internal/config/apn_info.go b/internal/config/apn_info.go index 9079fe98..3f639391 100644 --- a/internal/config/apn_info.go +++ b/internal/config/apn_info.go @@ -4,6 +4,11 @@ import ( smithyhttp "github.com/aws/smithy-go/transport/http" ) +type APNInfo struct { + PartnerName string + Products []UserAgentProduct +} + // Builds the user-agent string for APN func (apn APNInfo) BuildUserAgentString() string { builder := smithyhttp.NewUserAgentBuilder() diff --git a/internal/config/config.go b/internal/config/config.go index 087b7b81..196be84f 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -1,6 +1,13 @@ package config -import "time" +import ( + "bytes" + "fmt" + "os" + "time" + + "github.com/hashicorp/aws-sdk-go-base/v2/internal/expand" +) type Config struct { AccessKey string @@ -31,19 +38,6 @@ type Config struct { UserAgent UserAgentProducts } -type APNInfo struct { - PartnerName string - Products []UserAgentProduct -} - -type UserAgentProduct struct { - Name string - Version string - Comment string -} - -type UserAgentProducts []UserAgentProduct - type AssumeRole struct { RoleARN string Duration time.Duration @@ -54,3 +48,18 @@ type AssumeRole struct { Tags map[string]string TransitiveTagKeys []string } + +func (c Config) CustomCABundleReader() (*bytes.Reader, error) { + if c.CustomCABundle == "" { + return nil, nil + } + bundleFile, err := expand.FilePath(c.CustomCABundle) + if err != nil { + return nil, fmt.Errorf("expanding custom CA bundle: %w", err) + } + bundle, err := os.ReadFile(bundleFile) + if err != nil { + return nil, fmt.Errorf("reading custom CA bundle: %w", err) + } + return bytes.NewReader(bundle), nil +} diff --git a/internal/config/user_agent.go b/internal/config/user_agent.go index 31f8ba9a..ffcef587 100644 --- a/internal/config/user_agent.go +++ b/internal/config/user_agent.go @@ -4,6 +4,14 @@ import ( smithyhttp "github.com/aws/smithy-go/transport/http" ) +type UserAgentProduct struct { + Name string + Version string + Comment string +} + +type UserAgentProducts []UserAgentProduct + func (ua UserAgentProducts) BuildUserAgentString() string { builder := smithyhttp.NewUserAgentBuilder() for _, p := range ua { diff --git a/internal/expand/filepath.go b/internal/expand/filepath.go new file mode 100644 index 00000000..44e2d205 --- /dev/null +++ b/internal/expand/filepath.go @@ -0,0 +1,28 @@ +package expand + +import ( + "os" + + "github.com/hashicorp/go-multierror" + "github.com/mitchellh/go-homedir" +) + +func FilePaths(in []string) ([]string, error) { + var errs *multierror.Error + result := make([]string, 0, len(in)) + for _, v := range in { + p, err := FilePath(v) + if err != nil { + errs = multierror.Append(errs, err) + continue + } + result = append(result, p) + } + return result, errs.ErrorOrNil() +} + +func FilePath(in string) (s string, err error) { + e := os.ExpandEnv(in) + s, err = homedir.Expand(e) + return +} diff --git a/internal/expand/filepath_test.go b/internal/expand/filepath_test.go new file mode 100644 index 00000000..bb0c2c02 --- /dev/null +++ b/internal/expand/filepath_test.go @@ -0,0 +1,67 @@ +package expand_test + +import ( + "os" + "testing" + + "github.com/hashicorp/aws-sdk-go-base/v2/internal/expand" + "github.com/hashicorp/aws-sdk-go-base/v2/servicemocks" +) + +func TestExpandFilePath(t *testing.T) { + testcases := map[string]struct { + path string + expected string + envvars map[string]string + }{ + "filename": { + path: "file", + expected: "file", + }, + "file in current dir": { + path: "./file", + expected: "./file", + }, + "file with tilde": { + path: "~/file", + expected: "/my/home/dir/file", + envvars: map[string]string{ + "HOME": "/my/home/dir", + }, + }, + "file with envvar": { + path: "$HOME/file", + expected: "/home/dir/file", + envvars: map[string]string{ + "HOME": "/home/dir", + }, + }, + "full file in envvar": { + path: "$CONF_FILE", + expected: "/path/to/conf/file", + envvars: map[string]string{ + "CONF_FILE": "/path/to/conf/file", + }, + }, + } + + for name, testcase := range testcases { + t.Run(name, func(t *testing.T) { + oldEnv := servicemocks.StashEnv() + defer servicemocks.PopEnv(oldEnv) + + for k, v := range testcase.envvars { + os.Setenv(k, v) + } + + a, err := expand.FilePath(testcase.path) + if err != nil { + t.Fatalf("unexpected error: %s", err) + } + + if a != testcase.expected { + t.Errorf("expected expansion to %q, got %q", testcase.expected, a) + } + }) + } +} diff --git a/servicemocks/setup.go b/servicemocks/setup.go index f0e5111f..ad4e9d5c 100644 --- a/servicemocks/setup.go +++ b/servicemocks/setup.go @@ -36,6 +36,21 @@ func PopEnv(env []string) { } } +func RestoreEnv(env []string, key string) { + for _, e := range env { + p := strings.SplitN(e, "=", 2) + k, v := p[0], "" + if k != key { + continue + } + if len(p) > 1 { + v = p[1] + } + os.Setenv(k, v) + break + } +} + // InvalidEC2MetadataEndpoint establishes a httptest server to simulate behaviour // when endpoint doesn't respond as expected func InvalidEC2MetadataEndpoint() func() { diff --git a/v2/awsv1shim/session.go b/v2/awsv1shim/session.go index 32493bad..252f2cc6 100644 --- a/v2/awsv1shim/session.go +++ b/v2/awsv1shim/session.go @@ -1,7 +1,6 @@ package awsv1shim import ( // nosemgrep: no-sdkv2-imports-in-awsv1shim - "bytes" "context" "fmt" "log" @@ -61,11 +60,11 @@ func getSessionOptions(awsC *awsv2.Config, c *awsbase.Config) (*session.Options, } if c.CustomCABundle != "" { - bundle, err := os.ReadFile(c.CustomCABundle) + reader, err := c.CustomCABundleReader() if err != nil { - return nil, fmt.Errorf("error reading custom CA bundle %q: %w", c.CustomCABundle, err) + return nil, err } - options.CustomCABundle = bytes.NewReader(bundle) + options.CustomCABundle = reader } return options, nil diff --git a/v2/awsv1shim/session_test.go b/v2/awsv1shim/session_test.go index 82693727..00c08267 100644 --- a/v2/awsv1shim/session_test.go +++ b/v2/awsv1shim/session_test.go @@ -8,6 +8,7 @@ import ( "net" "net/http" "os" + "path/filepath" "runtime" "testing" "time" @@ -1556,6 +1557,8 @@ func TestCustomCABundle(t *testing.T) { SetConfig bool SetEnvironmentVariable bool SetSharedConfigurationFile bool + ExpandEnvVars bool + EnvironmentVariables map[string]string ExpectTLSClientConfigRootCAsSet bool }{ "no configuration": { @@ -1577,6 +1580,17 @@ func TestCustomCABundle(t *testing.T) { ExpectTLSClientConfigRootCAsSet: true, }, + "expanded config": { + Config: &awsbase.Config{ + AccessKey: servicemocks.MockStaticAccessKey, + Region: "us-east-1", + SecretKey: servicemocks.MockStaticSecretKey, + }, + SetConfig: true, + ExpandEnvVars: true, + ExpectTLSClientConfigRootCAsSet: true, + }, + "envvar": { Config: &awsbase.Config{ AccessKey: servicemocks.MockStaticAccessKey, @@ -1598,19 +1612,20 @@ func TestCustomCABundle(t *testing.T) { // ExpectTLSClientConfigRootCAsSet: true, // }, - // "config overrides envvar": { - // Config: &awsbase.Config{ - // AccessKey: servicemocks.MockStaticAccessKey, - // Region: "us-east-1", - // SecretKey: servicemocks.MockStaticSecretKey, - // EC2MetadataServiceEndpointMode: EC2MetadataEndpointModeIPv4, - // }, - // EnvironmentVariables: map[string]string{ - // "AWS_CA_BUNDLE": EC2MetadataEndpointModeIPv6, - // }, - // ExpectTLSClientConfigRootCAsSet: true, - // }, + "config overrides envvar": { + Config: &awsbase.Config{ + AccessKey: servicemocks.MockStaticAccessKey, + Region: "us-east-1", + SecretKey: servicemocks.MockStaticSecretKey, + }, + SetConfig: true, + EnvironmentVariables: map[string]string{ + "AWS_CA_BUNDLE": "no-such-file", + }, + ExpectTLSClientConfigRootCAsSet: true, + }, + // Not implemented in AWS SDK for Go v2: https://github.com/aws/aws-sdk-go-v2/issues/1589 // "envvar overrides shared configuration": { // Config: &awsbase.Config{ // AccessKey: servicemocks.MockStaticAccessKey, @@ -1634,6 +1649,11 @@ func TestCustomCABundle(t *testing.T) { t.Run(testName, func(t *testing.T) { oldEnv := servicemocks.InitSessionTestEnv() defer servicemocks.PopEnv(oldEnv) + servicemocks.RestoreEnv(oldEnv, "TMPDIR") + + for k, v := range testCase.EnvironmentVariables { + os.Setenv(k, v) + } pemFile, err := servicemocks.TempPEMFile() defer os.Remove(pemFile) @@ -1641,6 +1661,17 @@ func TestCustomCABundle(t *testing.T) { t.Fatalf("error creating PEM file: %s", err) } + if testCase.ExpandEnvVars { + tmpdir := os.Getenv("TMPDIR") + rel, err := filepath.Rel(tmpdir, pemFile) + if err != nil { + t.Fatalf("error making path relative: %s", err) + } + t.Logf("relative: %s", rel) + pemFile = filepath.Join("$TMPDIR", rel) + t.Logf("env tempfile: %s", pemFile) + } + if testCase.SetConfig { testCase.Config.CustomCABundle = pemFile } From 8ae03cd044fcf99cc77e6a69250ccebaa0c175ac Mon Sep 17 00:00:00 2001 From: Graham Davison Date: Wed, 16 Feb 2022 16:38:55 -0800 Subject: [PATCH 5/5] GitHub Actions doesn't define `TMPDIR` --- aws_config_test.go | 8 +++++++- servicemocks/setup.go | 15 --------------- v2/awsv1shim/session_test.go | 8 +++++++- 3 files changed, 14 insertions(+), 17 deletions(-) diff --git a/aws_config_test.go b/aws_config_test.go index 1020dbce..fef8e36b 100644 --- a/aws_config_test.go +++ b/aws_config_test.go @@ -2119,12 +2119,18 @@ func TestCustomCABundle(t *testing.T) { t.Run(testName, func(t *testing.T) { oldEnv := servicemocks.InitSessionTestEnv() defer servicemocks.PopEnv(oldEnv) - servicemocks.RestoreEnv(oldEnv, "TMPDIR") for k, v := range testCase.EnvironmentVariables { os.Setenv(k, v) } + tempdir, err := ioutil.TempDir("", "temp") + if err != nil { + t.Fatalf("error creating temp dir: %s", err) + } + defer os.Remove(tempdir) + os.Setenv("TMPDIR", tempdir) + pemFile, err := servicemocks.TempPEMFile() defer os.Remove(pemFile) if err != nil { diff --git a/servicemocks/setup.go b/servicemocks/setup.go index ad4e9d5c..f0e5111f 100644 --- a/servicemocks/setup.go +++ b/servicemocks/setup.go @@ -36,21 +36,6 @@ func PopEnv(env []string) { } } -func RestoreEnv(env []string, key string) { - for _, e := range env { - p := strings.SplitN(e, "=", 2) - k, v := p[0], "" - if k != key { - continue - } - if len(p) > 1 { - v = p[1] - } - os.Setenv(k, v) - break - } -} - // InvalidEC2MetadataEndpoint establishes a httptest server to simulate behaviour // when endpoint doesn't respond as expected func InvalidEC2MetadataEndpoint() func() { diff --git a/v2/awsv1shim/session_test.go b/v2/awsv1shim/session_test.go index 00c08267..c5695ad3 100644 --- a/v2/awsv1shim/session_test.go +++ b/v2/awsv1shim/session_test.go @@ -1649,12 +1649,18 @@ func TestCustomCABundle(t *testing.T) { t.Run(testName, func(t *testing.T) { oldEnv := servicemocks.InitSessionTestEnv() defer servicemocks.PopEnv(oldEnv) - servicemocks.RestoreEnv(oldEnv, "TMPDIR") for k, v := range testCase.EnvironmentVariables { os.Setenv(k, v) } + tempdir, err := ioutil.TempDir("", "temp") + if err != nil { + t.Fatalf("error creating temp dir: %s", err) + } + defer os.Remove(tempdir) + os.Setenv("TMPDIR", tempdir) + pemFile, err := servicemocks.TempPEMFile() defer os.Remove(pemFile) if err != nil {