diff --git a/.semgrep/imports.yml b/.semgrep/imports.yml index 526ab116..1609457d 100644 --- a/.semgrep/imports.yml +++ b/.semgrep/imports.yml @@ -28,4 +28,6 @@ rules: - metavariable-regex: metavariable: "$X" regex: '^"github.com/aws/aws-sdk-go-v2/.+"$' + - pattern-not: | + import ("github.com/aws/aws-sdk-go-v2/aws/transport/http") severity: ERROR diff --git a/aws_config.go b/aws_config.go index 5849f52b..090947a4 100644 --- a/aws_config.go +++ b/aws_config.go @@ -19,9 +19,7 @@ import ( "github.com/aws/smithy-go/middleware" "github.com/hashicorp/aws-sdk-go-base/v2/internal/constants" "github.com/hashicorp/aws-sdk-go-base/v2/internal/endpoints" - "github.com/hashicorp/aws-sdk-go-base/v2/internal/httpclient" - "github.com/hashicorp/go-multierror" - "github.com/mitchellh/go-homedir" + "github.com/hashicorp/aws-sdk-go-base/v2/internal/expand" ) func GetAwsConfig(ctx context.Context, c *Config) (aws.Config, error) { @@ -140,7 +138,7 @@ func GetAwsAccountIDAndPartition(ctx context.Context, awsConfig aws.Config, c *C } func commonLoadOptions(c *Config) ([]func(*config.LoadOptions) error, error) { - httpClient, err := httpclient.DefaultHttpClient(c) + httpClient, err := defaultHttpClient(c) if err != nil { return nil, err } @@ -172,9 +170,9 @@ func commonLoadOptions(c *Config) ([]func(*config.LoadOptions) error, error) { } if len(c.SharedConfigFiles) > 0 { - configFiles, err := expandFilePaths(c.SharedConfigFiles) + configFiles, err := expand.FilePaths(c.SharedConfigFiles) if err != nil { - return nil, fmt.Errorf("error expanding shared config files: %w", err) + return nil, fmt.Errorf("expanding shared config files: %w", err) } loadOptions = append( loadOptions, @@ -182,6 +180,16 @@ func commonLoadOptions(c *Config) ([]func(*config.LoadOptions) error, error) { ) } + if c.CustomCABundle != "" { + reader, err := c.CustomCABundleReader() + if err != nil { + return nil, err + } + loadOptions = append(loadOptions, + config.WithCustomCABundle(reader), + ) + } + if c.EC2MetadataServiceEndpoint != "" { loadOptions = append(loadOptions, config.WithEC2IMDSEndpoint(c.EC2MetadataServiceEndpoint), @@ -222,23 +230,3 @@ func commonLoadOptions(c *Config) ([]func(*config.LoadOptions) error, error) { return loadOptions, nil } - -func expandFilePaths(in []string) ([]string, error) { - var errs *multierror.Error - result := make([]string, 0, len(in)) - for _, v := range in { - p, err := expandFilePath(v) - if err != nil { - errs = multierror.Append(errs, err) - continue - } - result = append(result, p) - } - return result, errs.ErrorOrNil() -} - -func expandFilePath(in string) (s string, err error) { - e := os.ExpandEnv(in) - s, err = homedir.Expand(e) - return -} diff --git a/aws_config_test.go b/aws_config_test.go index 7dd76c3e..fef8e36b 100644 --- a/aws_config_test.go +++ b/aws_config_test.go @@ -6,7 +6,9 @@ import ( "fmt" "io/ioutil" "net" + "net/http" "os" + "path/filepath" "reflect" "runtime" "strings" @@ -2019,6 +2021,186 @@ ec2_metadata_service_endpoint_mode = IPv4 } } +func TestCustomCABundle(t *testing.T) { + testCases := map[string]struct { + Config *Config + SetConfig bool + SetEnvironmentVariable bool + SetSharedConfigurationFile bool + ExpandEnvVars bool + EnvironmentVariables map[string]string + ExpectTLSClientConfigRootCAsSet bool + }{ + "no configuration": { + Config: &Config{ + AccessKey: servicemocks.MockStaticAccessKey, + Region: "us-east-1", + SecretKey: servicemocks.MockStaticSecretKey, + }, + ExpectTLSClientConfigRootCAsSet: false, + }, + + "config": { + Config: &Config{ + AccessKey: servicemocks.MockStaticAccessKey, + Region: "us-east-1", + SecretKey: servicemocks.MockStaticSecretKey, + }, + SetConfig: true, + ExpectTLSClientConfigRootCAsSet: true, + }, + + "expanded config": { + Config: &Config{ + AccessKey: servicemocks.MockStaticAccessKey, + Region: "us-east-1", + SecretKey: servicemocks.MockStaticSecretKey, + }, + SetConfig: true, + ExpandEnvVars: true, + ExpectTLSClientConfigRootCAsSet: true, + }, + + "envvar": { + Config: &Config{ + AccessKey: servicemocks.MockStaticAccessKey, + Region: "us-east-1", + SecretKey: servicemocks.MockStaticSecretKey, + }, + SetEnvironmentVariable: true, + ExpectTLSClientConfigRootCAsSet: true, + }, + + // Not implemented in AWS SDK for Go v2: https://github.com/aws/aws-sdk-go-v2/issues/1589 + // "shared configuration file": { + // Config: &Config{ + // AccessKey: servicemocks.MockStaticAccessKey, + // Region: "us-east-1", + // SecretKey: servicemocks.MockStaticSecretKey, + // }, + // SetSharedConfigurationFile: true, + // ExpectTLSClientConfigRootCAsSet: true, + // }, + + "config overrides envvar": { + Config: &Config{ + AccessKey: servicemocks.MockStaticAccessKey, + Region: "us-east-1", + SecretKey: servicemocks.MockStaticSecretKey, + }, + SetConfig: true, + EnvironmentVariables: map[string]string{ + "AWS_CA_BUNDLE": "no-such-file", + }, + ExpectTLSClientConfigRootCAsSet: true, + }, + + // Not implemented in AWS SDK for Go v2: https://github.com/aws/aws-sdk-go-v2/issues/1589 + // "envvar overrides shared configuration": { + // Config: &Config{ + // AccessKey: servicemocks.MockStaticAccessKey, + // Region: "us-east-1", + // SecretKey: servicemocks.MockStaticSecretKey, + // }, + // EnvironmentVariables: map[string]string{ + // "AWS_CA_BUNDLE": EC2MetadataEndpointModeIPv6, + // }, + // SharedConfigurationFile: ` + // [default] + // ec2_metadata_service_endpoint_mode = IPv4 + // `, + // ExpectTLSClientConfigRootCAsSet: true, + // }, + } + + for testName, testCase := range testCases { + testCase := testCase + + t.Run(testName, func(t *testing.T) { + oldEnv := servicemocks.InitSessionTestEnv() + defer servicemocks.PopEnv(oldEnv) + + for k, v := range testCase.EnvironmentVariables { + os.Setenv(k, v) + } + + tempdir, err := ioutil.TempDir("", "temp") + if err != nil { + t.Fatalf("error creating temp dir: %s", err) + } + defer os.Remove(tempdir) + os.Setenv("TMPDIR", tempdir) + + pemFile, err := servicemocks.TempPEMFile() + defer os.Remove(pemFile) + if err != nil { + t.Fatalf("error creating PEM file: %s", err) + } + + if testCase.ExpandEnvVars { + tmpdir := os.Getenv("TMPDIR") + rel, err := filepath.Rel(tmpdir, pemFile) + if err != nil { + t.Fatalf("error making path relative: %s", err) + } + t.Logf("relative: %s", rel) + pemFile = filepath.Join("$TMPDIR", rel) + t.Logf("env tempfile: %s", pemFile) + } + + if testCase.SetConfig { + testCase.Config.CustomCABundle = pemFile + } + + if testCase.SetEnvironmentVariable { + os.Setenv("AWS_CA_BUNDLE", pemFile) + } + + if testCase.SetSharedConfigurationFile { + file, err := ioutil.TempFile("", "aws-sdk-go-base-shared-configuration-file") + + if err != nil { + t.Fatalf("unexpected error creating temporary shared configuration file: %s", err) + } + + defer os.Remove(file.Name()) + + err = ioutil.WriteFile( + file.Name(), + []byte(fmt.Sprintf(` +[default] +ca_bundle = %s +`, pemFile)), + 0600) + + if err != nil { + t.Fatalf("unexpected error writing shared configuration file: %s", err) + } + + testCase.Config.SharedConfigFiles = []string{file.Name()} + } + + testCase.Config.SkipCredsValidation = true + + awsConfig, err := GetAwsConfig(context.Background(), testCase.Config) + if err != nil { + t.Fatalf("error in GetAwsConfig() '%[1]T': %[1]s", err) + } + + type transportGetter interface { + GetTransport() *http.Transport + } + + trGetter := awsConfig.HTTPClient.(transportGetter) + tr := trGetter.GetTransport() + + if a, e := tr.TLSClientConfig.RootCAs != nil, testCase.ExpectTLSClientConfigRootCAsSet; a != e { + t.Errorf("expected(%t) CA Bundle, got: %t", e, a) + } + }) + } +} + func TestGetAwsConfigWithAccountIDAndPartition(t *testing.T) { oldEnv := servicemocks.InitSessionTestEnv() defer servicemocks.PopEnv(oldEnv) @@ -2350,61 +2532,3 @@ func (r *withNoDelay) RetryDelay(attempt int, err error) (time.Duration, error) return 0 * time.Second, nil } - -func TestExpandFilePath(t *testing.T) { - testcases := map[string]struct { - path string - expected string - envvars map[string]string - }{ - "filename": { - path: "file", - expected: "file", - }, - "file in current dir": { - path: "./file", - expected: "./file", - }, - "file with tilde": { - path: "~/file", - expected: "/my/home/dir/file", - envvars: map[string]string{ - "HOME": "/my/home/dir", - }, - }, - "file with envvar": { - path: "$HOME/file", - expected: "/home/dir/file", - envvars: map[string]string{ - "HOME": "/home/dir", - }, - }, - "full file in envvar": { - path: "$CONF_FILE", - expected: "/path/to/conf/file", - envvars: map[string]string{ - "CONF_FILE": "/path/to/conf/file", - }, - }, - } - - for name, testcase := range testcases { - t.Run(name, func(t *testing.T) { - oldEnv := servicemocks.StashEnv() - defer servicemocks.PopEnv(oldEnv) - - for k, v := range testcase.envvars { - os.Setenv(k, v) - } - - a, err := expandFilePath(testcase.path) - if err != nil { - t.Fatalf("unexpected error: %s", err) - } - - if a != testcase.expected { - t.Errorf("expected expansion to %q, got %q", testcase.expected, a) - } - }) - } -} diff --git a/credentials.go b/credentials.go index 66ee896f..61048af8 100644 --- a/credentials.go +++ b/credentials.go @@ -10,6 +10,7 @@ import ( "github.com/aws/aws-sdk-go-v2/credentials" "github.com/aws/aws-sdk-go-v2/credentials/stscreds" "github.com/aws/aws-sdk-go-v2/service/sts/types" + "github.com/hashicorp/aws-sdk-go-base/v2/internal/expand" ) func getCredentialsProvider(ctx context.Context, c *Config) (aws.CredentialsProvider, error) { @@ -41,9 +42,9 @@ func getCredentialsProvider(ctx context.Context, c *Config) (aws.CredentialsProv ) } if len(c.SharedCredentialsFiles) > 0 { - credsFiles, err := expandFilePaths(c.SharedCredentialsFiles) + credsFiles, err := expand.FilePaths(c.SharedCredentialsFiles) if err != nil { - return nil, fmt.Errorf("error expanding shared credentials files: %w", err) + return nil, fmt.Errorf("expanding shared credentials files: %w", err) } loadOptions = append( loadOptions, diff --git a/internal/httpclient/http_client.go b/http_client.go similarity index 88% rename from internal/httpclient/http_client.go rename to http_client.go index 2741c331..6fac1421 100644 --- a/internal/httpclient/http_client.go +++ b/http_client.go @@ -1,4 +1,4 @@ -package httpclient +package awsbase import ( "fmt" @@ -9,7 +9,7 @@ import ( "github.com/hashicorp/aws-sdk-go-base/v2/internal/config" ) -func DefaultHttpClient(c *config.Config) (*awshttp.BuildableClient, error) { +func defaultHttpClient(c *config.Config) (*awshttp.BuildableClient, error) { var err error httpClient := awshttp.NewBuildableClient(). diff --git a/http_client_test.go b/http_client_test.go new file mode 100644 index 00000000..aac32985 --- /dev/null +++ b/http_client_test.go @@ -0,0 +1,32 @@ +package awsbase + +import ( + "testing" + + "github.com/hashicorp/aws-sdk-go-base/v2/internal/config" + "github.com/hashicorp/aws-sdk-go-base/v2/internal/test" +) + +func TestHTTPClientConfiguration_basic(t *testing.T) { + client, err := defaultHttpClient(&config.Config{}) + if err != nil { + t.Fatalf("unexpected error: %s", err) + } + + transport := client.GetTransport() + + test.HTTPClientConfigurationTest_basic(t, transport) +} + +func TestHTTPClientConfiguration_insecureHTTPS(t *testing.T) { + client, err := defaultHttpClient(&config.Config{ + Insecure: true, + }) + if err != nil { + t.Fatalf("unexpected error: %s", err) + } + + transport := client.GetTransport() + + test.HTTPClientConfigurationTest_insecureHTTPS(t, transport) +} diff --git a/internal/config/apn_info.go b/internal/config/apn_info.go index 9079fe98..3f639391 100644 --- a/internal/config/apn_info.go +++ b/internal/config/apn_info.go @@ -4,6 +4,11 @@ import ( smithyhttp "github.com/aws/smithy-go/transport/http" ) +type APNInfo struct { + PartnerName string + Products []UserAgentProduct +} + // Builds the user-agent string for APN func (apn APNInfo) BuildUserAgentString() string { builder := smithyhttp.NewUserAgentBuilder() diff --git a/internal/config/config.go b/internal/config/config.go index 318d9ae7..196be84f 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -1,6 +1,13 @@ package config -import "time" +import ( + "bytes" + "fmt" + "os" + "time" + + "github.com/hashicorp/aws-sdk-go-base/v2/internal/expand" +) type Config struct { AccessKey string @@ -8,6 +15,7 @@ type Config struct { AssumeRole *AssumeRole CallerDocumentationURL string CallerName string + CustomCABundle string EC2MetadataServiceEndpoint string EC2MetadataServiceEndpointMode string HTTPProxy string @@ -30,19 +38,6 @@ type Config struct { UserAgent UserAgentProducts } -type APNInfo struct { - PartnerName string - Products []UserAgentProduct -} - -type UserAgentProduct struct { - Name string - Version string - Comment string -} - -type UserAgentProducts []UserAgentProduct - type AssumeRole struct { RoleARN string Duration time.Duration @@ -53,3 +48,18 @@ type AssumeRole struct { Tags map[string]string TransitiveTagKeys []string } + +func (c Config) CustomCABundleReader() (*bytes.Reader, error) { + if c.CustomCABundle == "" { + return nil, nil + } + bundleFile, err := expand.FilePath(c.CustomCABundle) + if err != nil { + return nil, fmt.Errorf("expanding custom CA bundle: %w", err) + } + bundle, err := os.ReadFile(bundleFile) + if err != nil { + return nil, fmt.Errorf("reading custom CA bundle: %w", err) + } + return bytes.NewReader(bundle), nil +} diff --git a/internal/config/user_agent.go b/internal/config/user_agent.go index 31f8ba9a..ffcef587 100644 --- a/internal/config/user_agent.go +++ b/internal/config/user_agent.go @@ -4,6 +4,14 @@ import ( smithyhttp "github.com/aws/smithy-go/transport/http" ) +type UserAgentProduct struct { + Name string + Version string + Comment string +} + +type UserAgentProducts []UserAgentProduct + func (ua UserAgentProducts) BuildUserAgentString() string { builder := smithyhttp.NewUserAgentBuilder() for _, p := range ua { diff --git a/internal/expand/filepath.go b/internal/expand/filepath.go new file mode 100644 index 00000000..44e2d205 --- /dev/null +++ b/internal/expand/filepath.go @@ -0,0 +1,28 @@ +package expand + +import ( + "os" + + "github.com/hashicorp/go-multierror" + "github.com/mitchellh/go-homedir" +) + +func FilePaths(in []string) ([]string, error) { + var errs *multierror.Error + result := make([]string, 0, len(in)) + for _, v := range in { + p, err := FilePath(v) + if err != nil { + errs = multierror.Append(errs, err) + continue + } + result = append(result, p) + } + return result, errs.ErrorOrNil() +} + +func FilePath(in string) (s string, err error) { + e := os.ExpandEnv(in) + s, err = homedir.Expand(e) + return +} diff --git a/internal/expand/filepath_test.go b/internal/expand/filepath_test.go new file mode 100644 index 00000000..bb0c2c02 --- /dev/null +++ b/internal/expand/filepath_test.go @@ -0,0 +1,67 @@ +package expand_test + +import ( + "os" + "testing" + + "github.com/hashicorp/aws-sdk-go-base/v2/internal/expand" + "github.com/hashicorp/aws-sdk-go-base/v2/servicemocks" +) + +func TestExpandFilePath(t *testing.T) { + testcases := map[string]struct { + path string + expected string + envvars map[string]string + }{ + "filename": { + path: "file", + expected: "file", + }, + "file in current dir": { + path: "./file", + expected: "./file", + }, + "file with tilde": { + path: "~/file", + expected: "/my/home/dir/file", + envvars: map[string]string{ + "HOME": "/my/home/dir", + }, + }, + "file with envvar": { + path: "$HOME/file", + expected: "/home/dir/file", + envvars: map[string]string{ + "HOME": "/home/dir", + }, + }, + "full file in envvar": { + path: "$CONF_FILE", + expected: "/path/to/conf/file", + envvars: map[string]string{ + "CONF_FILE": "/path/to/conf/file", + }, + }, + } + + for name, testcase := range testcases { + t.Run(name, func(t *testing.T) { + oldEnv := servicemocks.StashEnv() + defer servicemocks.PopEnv(oldEnv) + + for k, v := range testcase.envvars { + os.Setenv(k, v) + } + + a, err := expand.FilePath(testcase.path) + if err != nil { + t.Fatalf("unexpected error: %s", err) + } + + if a != testcase.expected { + t.Errorf("expected expansion to %q, got %q", testcase.expected, a) + } + }) + } +} diff --git a/internal/httpclient/http_client_test.go b/internal/httpclient/http_client_test.go deleted file mode 100644 index 7e9ea8a7..00000000 --- a/internal/httpclient/http_client_test.go +++ /dev/null @@ -1,62 +0,0 @@ -package httpclient_test - -import ( - "crypto/tls" - "testing" - - "github.com/aws/aws-sdk-go-v2/aws/transport/http" - "github.com/hashicorp/aws-sdk-go-base/v2/internal/config" - "github.com/hashicorp/aws-sdk-go-base/v2/internal/httpclient" -) - -func TestHTTPClientConfiguration_basic(t *testing.T) { - client, err := httpclient.DefaultHttpClient(&config.Config{}) - if err != nil { - t.Fatalf("unexpected error: %s", err) - } - - transport := client.GetTransport() - - if a, e := transport.MaxIdleConns, http.DefaultHTTPTransportMaxIdleConns; a != e { - t.Errorf("expected MaxIdleConns to be %d, got %d", e, a) - } - if a, e := transport.MaxIdleConnsPerHost, http.DefaultHTTPTransportMaxIdleConnsPerHost; a != e { - t.Errorf("expected MaxIdleConnsPerHost to be %d, got %d", e, a) - } - if a, e := transport.IdleConnTimeout, http.DefaultHTTPTransportIdleConnTimeout; a != e { - t.Errorf("expected IdleConnTimeout to be %s, got %s", e, a) - } - if a, e := transport.TLSHandshakeTimeout, http.DefaultHTTPTransportTLSHandleshakeTimeout; a != e { - t.Errorf("expected TLSHandshakeTimeout to be %s, got %s", e, a) - } - if a, e := transport.ExpectContinueTimeout, http.DefaultHTTPTransportExpectContinueTimeout; a != e { - t.Errorf("expected ExpectContinueTimeout to be %s, got %s", e, a) - } - if !transport.ForceAttemptHTTP2 { - t.Error("expected ForceAttemptHTTP2 to be true, got false") - } - - tlsConfig := transport.TLSClientConfig - if a, e := int(tlsConfig.MinVersion), tls.VersionTLS12; a != e { - t.Errorf("expected tlsConfig.MinVersion to be %d, got %d", e, a) - } - if tlsConfig.InsecureSkipVerify { - t.Error("expected InsecureSkipVerify to be false, got true") - } -} - -func TestHTTPClientConfiguration_insecureHTTPS(t *testing.T) { - client, err := httpclient.DefaultHttpClient(&config.Config{ - Insecure: true, - }) - if err != nil { - t.Fatalf("unexpected error: %s", err) - } - - transport := client.GetTransport() - - tlsConfig := transport.TLSClientConfig - if !tlsConfig.InsecureSkipVerify { - t.Error("expected InsecureSkipVerify to be true, got false") - } -} diff --git a/internal/test/http_client.go b/internal/test/http_client.go new file mode 100644 index 00000000..31a05620 --- /dev/null +++ b/internal/test/http_client.go @@ -0,0 +1,48 @@ +package test + +import ( + "crypto/tls" + "net/http" + "testing" + + awshttp "github.com/aws/aws-sdk-go-v2/aws/transport/http" +) + +func HTTPClientConfigurationTest_basic(t *testing.T, transport *http.Transport) { + if a, e := transport.MaxIdleConns, awshttp.DefaultHTTPTransportMaxIdleConns; a != e { + t.Errorf("expected MaxIdleConns to be %d, got %d", e, a) + } + if a, e := transport.MaxIdleConnsPerHost, awshttp.DefaultHTTPTransportMaxIdleConnsPerHost; a != e { + t.Errorf("expected MaxIdleConnsPerHost to be %d, got %d", e, a) + } + if a, e := transport.IdleConnTimeout, awshttp.DefaultHTTPTransportIdleConnTimeout; a != e { + t.Errorf("expected IdleConnTimeout to be %s, got %s", e, a) + } + if a, e := transport.TLSHandshakeTimeout, awshttp.DefaultHTTPTransportTLSHandleshakeTimeout; a != e { + t.Errorf("expected TLSHandshakeTimeout to be %s, got %s", e, a) + } + if a, e := transport.ExpectContinueTimeout, awshttp.DefaultHTTPTransportExpectContinueTimeout; a != e { + t.Errorf("expected ExpectContinueTimeout to be %s, got %s", e, a) + } + if !transport.ForceAttemptHTTP2 { + t.Error("expected ForceAttemptHTTP2 to be true, got false") + } + if transport.DisableKeepAlives { + t.Error("expected DisableKeepAlives to be false, got true") + } + + tlsConfig := transport.TLSClientConfig + if a, e := int(tlsConfig.MinVersion), tls.VersionTLS12; a != e { + t.Errorf("expected tlsConfig.MinVersion to be %d, got %d", e, a) + } + if tlsConfig.InsecureSkipVerify { + t.Error("expected InsecureSkipVerify to be false, got true") + } +} + +func HTTPClientConfigurationTest_insecureHTTPS(t *testing.T, transport *http.Transport) { + tlsConfig := transport.TLSClientConfig + if !tlsConfig.InsecureSkipVerify { + t.Error("expected InsecureSkipVerify to be true, got false") + } +} diff --git a/servicemocks/pem_file.go b/servicemocks/pem_file.go new file mode 100644 index 00000000..fde238e4 --- /dev/null +++ b/servicemocks/pem_file.go @@ -0,0 +1,41 @@ +package servicemocks + +import ( + "io/ioutil" +) + +func TempPEMFile() (string, error) { + file, err := ioutil.TempFile("", "bundle-*.pem") + if err != nil { + return "", err + } + defer file.Close() + + _, err = file.Write(TLSBundleCA) + if err != nil { + return "", err + } + + return file.Name(), nil +} + +var ( + // TLSBundleCA ca.crt + TLSBundleCA = []byte(`-----BEGIN CERTIFICATE----- +MIICiTCCAfKgAwIBAgIJAJ5X1olt05XjMA0GCSqGSIb3DQEBCwUAMDgxCzAJBgNV +BAYTAkdPMQ8wDQYDVQQIEwZHb3BoZXIxGDAWBgNVBAoTD1Rlc3RpbmcgUk9PVCBD +QTAeFw0xNzAzMDkwMDAyMDZaFw0yNzAzMDcwMDAyMDZaMDgxCzAJBgNVBAYTAkdP +MQ8wDQYDVQQIEwZHb3BoZXIxGDAWBgNVBAoTD1Rlc3RpbmcgUk9PVCBDQTCBnzAN +BgkqhkiG9w0BAQEFAAOBjQAwgYkCgYEAw/8DN+t9XQR60jx42rsQ2WE2Dx85rb3n +GQxnKZZLNddsT8rDyxJNP18aFalbRbFlyln5fxWxZIblu9Xkm/HRhOpbSimSqo1y +uDx21NVZ1YsOvXpHby71jx3gPrrhSc/t/zikhi++6D/C6m1CiIGuiJ0GBiJxtrub +UBMXT0QtI2ECAwEAAaOBmjCBlzAdBgNVHQ4EFgQU8XG3X/YHBA6T04kdEkq6+4GV +YykwaAYDVR0jBGEwX4AU8XG3X/YHBA6T04kdEkq6+4GVYymhPKQ6MDgxCzAJBgNV +BAYTAkdPMQ8wDQYDVQQIEwZHb3BoZXIxGDAWBgNVBAoTD1Rlc3RpbmcgUk9PVCBD +QYIJAJ5X1olt05XjMAwGA1UdEwQFMAMBAf8wDQYJKoZIhvcNAQELBQADgYEAeILv +z49+uxmPcfOZzonuOloRcpdvyjiXblYxbzz6ch8GsE7Q886FTZbvwbgLhzdwSVgG +G8WHkodDUsymVepdqAamS3f8PdCUk8xIk9mop8LgaB9Ns0/TssxDvMr3sOD2Grb3 +xyWymTWMcj6uCiEBKtnUp4rPiefcvCRYZ17/hLE= +-----END CERTIFICATE----- +`) +) diff --git a/v2/awsv1shim/go.mod b/v2/awsv1shim/go.mod index 202640c2..dd0bd8a1 100644 --- a/v2/awsv1shim/go.mod +++ b/v2/awsv1shim/go.mod @@ -5,6 +5,7 @@ require ( github.com/aws/aws-sdk-go-v2 v1.13.0 github.com/google/go-cmp v0.5.7 github.com/hashicorp/aws-sdk-go-base/v2 v2.0.0-beta.7 + github.com/hashicorp/go-cleanhttp v0.5.2 ) go 1.16 diff --git a/v2/awsv1shim/go.sum b/v2/awsv1shim/go.sum index 509fea30..6088d0a9 100644 --- a/v2/awsv1shim/go.sum +++ b/v2/awsv1shim/go.sum @@ -32,6 +32,8 @@ github.com/google/go-cmp v0.5.7 h1:81/ik6ipDQS2aGcBfIN5dHDB36BwrStyeAQquSYCV4o= github.com/google/go-cmp v0.5.7/go.mod h1:n+brtR0CgQNWTVd5ZUFpTBC8YFBDLK/h/bpaJ8/DtOE= github.com/hashicorp/errwrap v1.0.0 h1:hLrqtEDnRye3+sgx6z4qVLNuviH3MR5aQ0ykNJa/UYA= github.com/hashicorp/errwrap v1.0.0/go.mod h1:YH+1FKiLXxHSkmPseP+kNlulaMuP3n2brvKWEqk/Jc4= +github.com/hashicorp/go-cleanhttp v0.5.2 h1:035FKYIWjmULyFRBKPs8TBQoi0x6d9G4xc9neXJWAZQ= +github.com/hashicorp/go-cleanhttp v0.5.2/go.mod h1:kO/YDlP8L1346E6Sodw+PrpBSV4/SoxCXGY6BqNFT48= github.com/hashicorp/go-multierror v1.1.1 h1:H5DkEtf6CXdFp0N0Em5UCwQpXMWke8IA0+lD48awMYo= github.com/hashicorp/go-multierror v1.1.1/go.mod h1:iw975J/qwKPdAO1clOe2L8331t/9/fmwbPZ6JB6eMoM= github.com/jmespath/go-jmespath v0.4.0 h1:BEgLn5cpjn8UN1mAw4NjwDrS35OdebyEtFe+9YPoQUg= diff --git a/v2/awsv1shim/http_client.go b/v2/awsv1shim/http_client.go new file mode 100644 index 00000000..af18b020 --- /dev/null +++ b/v2/awsv1shim/http_client.go @@ -0,0 +1,41 @@ +package awsv1shim + +import ( + "crypto/tls" + "fmt" + "net/http" + "net/url" + + awshttp "github.com/aws/aws-sdk-go-v2/aws/transport/http" + "github.com/hashicorp/aws-sdk-go-base/v2/internal/config" + "github.com/hashicorp/go-cleanhttp" +) + +func defaultHttpClient(c *config.Config) (*http.Client, error) { + httpClient := cleanhttp.DefaultPooledClient() + transport := httpClient.Transport.(*http.Transport) + + transport.MaxIdleConnsPerHost = awshttp.DefaultHTTPTransportMaxIdleConnsPerHost + + tlsConfig := transport.TLSClientConfig + if tlsConfig == nil { + tlsConfig = &tls.Config{} + transport.TLSClientConfig = tlsConfig + } + tlsConfig.MinVersion = tls.VersionTLS12 + + if c.Insecure { + tlsConfig.InsecureSkipVerify = true + } + + if c.HTTPProxy != "" { + proxyUrl, err := url.Parse(c.HTTPProxy) + if err != nil { + return nil, fmt.Errorf("error parsing HTTP proxy URL: %w", err) + } + + transport.Proxy = http.ProxyURL(proxyUrl) + } + + return httpClient, nil +} diff --git a/v2/awsv1shim/http_client_test.go b/v2/awsv1shim/http_client_test.go new file mode 100644 index 00000000..2fa27117 --- /dev/null +++ b/v2/awsv1shim/http_client_test.go @@ -0,0 +1,39 @@ +package awsv1shim + +import ( + "net/http" + "testing" + + "github.com/hashicorp/aws-sdk-go-base/v2/internal/config" + "github.com/hashicorp/aws-sdk-go-base/v2/internal/test" +) + +func TestHTTPClientConfiguration_basic(t *testing.T) { + client, err := defaultHttpClient(&config.Config{}) + if err != nil { + t.Fatalf("unexpected error: %s", err) + } + + transport, ok := client.Transport.(*http.Transport) + if !ok { + t.Fatalf("Unexpected type for HTTP client transport: %T", client.Transport) + } + + test.HTTPClientConfigurationTest_basic(t, transport) +} + +func TestHTTPClientConfiguration_insecureHTTPS(t *testing.T) { + client, err := defaultHttpClient(&config.Config{ + Insecure: true, + }) + if err != nil { + t.Fatalf("unexpected error: %s", err) + } + + transport, ok := client.Transport.(*http.Transport) + if !ok { + t.Fatalf("Unexpected type for HTTP client transport: %T", client.Transport) + } + + test.HTTPClientConfigurationTest_insecureHTTPS(t, transport) +} diff --git a/v2/awsv1shim/session.go b/v2/awsv1shim/session.go index 89bce01a..252f2cc6 100644 --- a/v2/awsv1shim/session.go +++ b/v2/awsv1shim/session.go @@ -4,7 +4,6 @@ import ( // nosemgrep: no-sdkv2-imports-in-awsv1shim "context" "fmt" "log" - "net/http" "os" awsv2 "github.com/aws/aws-sdk-go-v2/aws" @@ -17,7 +16,6 @@ import ( // nosemgrep: no-sdkv2-imports-in-awsv1shim "github.com/hashicorp/aws-sdk-go-base/v2/awsv1shim/v2/tfawserr" "github.com/hashicorp/aws-sdk-go-base/v2/internal/awsconfig" "github.com/hashicorp/aws-sdk-go-base/v2/internal/constants" - "github.com/hashicorp/aws-sdk-go-base/v2/internal/httpclient" ) // getSessionOptions attempts to return valid AWS Go SDK session authentication @@ -39,14 +37,11 @@ func getSessionOptions(awsC *awsv2.Config, c *awsbase.Config) (*session.Options, return nil, fmt.Errorf("error resolving configuration: %w", err) } - httpClient, ok := awsC.HTTPClient.(*http.Client) - if !ok { // This is unlikely, but technically possible - client, err := httpclient.DefaultHttpClient(c) - if err != nil { - return nil, err - } - httpClient = client.Freeze().(*http.Client) + httpClient, err := defaultHttpClient(c) + if err != nil { + return nil, err } + options := &session.Options{ Config: aws.Config{ Credentials: credentials.NewStaticCredentials( @@ -64,6 +59,14 @@ func getSessionOptions(awsC *awsv2.Config, c *awsbase.Config) (*session.Options, }, } + if c.CustomCABundle != "" { + reader, err := c.CustomCABundleReader() + if err != nil { + return nil, err + } + options.CustomCABundle = reader + } + return options, nil } diff --git a/v2/awsv1shim/session_test.go b/v2/awsv1shim/session_test.go index ccc2c58b..c5695ad3 100644 --- a/v2/awsv1shim/session_test.go +++ b/v2/awsv1shim/session_test.go @@ -6,7 +6,9 @@ import ( "fmt" "io/ioutil" "net" + "net/http" "os" + "path/filepath" "runtime" "testing" "time" @@ -1549,6 +1551,189 @@ func DualStackEndpointStateString(state endpoints.DualStackEndpointState) string return fmt.Sprintf("unknown endpoints.FIPSEndpointStateUnset (%d)", state) } +func TestCustomCABundle(t *testing.T) { + testCases := map[string]struct { + Config *awsbase.Config + SetConfig bool + SetEnvironmentVariable bool + SetSharedConfigurationFile bool + ExpandEnvVars bool + EnvironmentVariables map[string]string + ExpectTLSClientConfigRootCAsSet bool + }{ + "no configuration": { + Config: &awsbase.Config{ + AccessKey: servicemocks.MockStaticAccessKey, + Region: "us-east-1", + SecretKey: servicemocks.MockStaticSecretKey, + }, + ExpectTLSClientConfigRootCAsSet: false, + }, + + "config": { + Config: &awsbase.Config{ + AccessKey: servicemocks.MockStaticAccessKey, + Region: "us-east-1", + SecretKey: servicemocks.MockStaticSecretKey, + }, + SetConfig: true, + ExpectTLSClientConfigRootCAsSet: true, + }, + + "expanded config": { + Config: &awsbase.Config{ + AccessKey: servicemocks.MockStaticAccessKey, + Region: "us-east-1", + SecretKey: servicemocks.MockStaticSecretKey, + }, + SetConfig: true, + ExpandEnvVars: true, + ExpectTLSClientConfigRootCAsSet: true, + }, + + "envvar": { + Config: &awsbase.Config{ + AccessKey: servicemocks.MockStaticAccessKey, + Region: "us-east-1", + SecretKey: servicemocks.MockStaticSecretKey, + }, + SetEnvironmentVariable: true, + ExpectTLSClientConfigRootCAsSet: true, + }, + + // Not implemented in AWS SDK for Go v2: https://github.com/aws/aws-sdk-go-v2/issues/1589 + // "shared configuration file": { + // Config: &awsbase.Config{ + // AccessKey: servicemocks.MockStaticAccessKey, + // Region: "us-east-1", + // SecretKey: servicemocks.MockStaticSecretKey, + // }, + // SetSharedConfigurationFile: true, + // ExpectTLSClientConfigRootCAsSet: true, + // }, + + "config overrides envvar": { + Config: &awsbase.Config{ + AccessKey: servicemocks.MockStaticAccessKey, + Region: "us-east-1", + SecretKey: servicemocks.MockStaticSecretKey, + }, + SetConfig: true, + EnvironmentVariables: map[string]string{ + "AWS_CA_BUNDLE": "no-such-file", + }, + ExpectTLSClientConfigRootCAsSet: true, + }, + + // Not implemented in AWS SDK for Go v2: https://github.com/aws/aws-sdk-go-v2/issues/1589 + // "envvar overrides shared configuration": { + // Config: &awsbase.Config{ + // AccessKey: servicemocks.MockStaticAccessKey, + // Region: "us-east-1", + // SecretKey: servicemocks.MockStaticSecretKey, + // }, + // EnvironmentVariables: map[string]string{ + // "AWS_CA_BUNDLE": EC2MetadataEndpointModeIPv6, + // }, + // SharedConfigurationFile: ` + // [default] + // ec2_metadata_service_endpoint_mode = IPv4 + // `, + // ExpectTLSClientConfigRootCAsSet: true, + // }, + } + + for testName, testCase := range testCases { + testCase := testCase + + t.Run(testName, func(t *testing.T) { + oldEnv := servicemocks.InitSessionTestEnv() + defer servicemocks.PopEnv(oldEnv) + + for k, v := range testCase.EnvironmentVariables { + os.Setenv(k, v) + } + + tempdir, err := ioutil.TempDir("", "temp") + if err != nil { + t.Fatalf("error creating temp dir: %s", err) + } + defer os.Remove(tempdir) + os.Setenv("TMPDIR", tempdir) + + pemFile, err := servicemocks.TempPEMFile() + defer os.Remove(pemFile) + if err != nil { + t.Fatalf("error creating PEM file: %s", err) + } + + if testCase.ExpandEnvVars { + tmpdir := os.Getenv("TMPDIR") + rel, err := filepath.Rel(tmpdir, pemFile) + if err != nil { + t.Fatalf("error making path relative: %s", err) + } + t.Logf("relative: %s", rel) + pemFile = filepath.Join("$TMPDIR", rel) + t.Logf("env tempfile: %s", pemFile) + } + + if testCase.SetConfig { + testCase.Config.CustomCABundle = pemFile + } + + if testCase.SetEnvironmentVariable { + os.Setenv("AWS_CA_BUNDLE", pemFile) + } + + if testCase.SetSharedConfigurationFile { + file, err := ioutil.TempFile("", "aws-sdk-go-base-shared-configuration-file") + + if err != nil { + t.Fatalf("unexpected error creating temporary shared configuration file: %s", err) + } + + defer os.Remove(file.Name()) + + err = ioutil.WriteFile( + file.Name(), + []byte(fmt.Sprintf(` +[default] +ca_bundle = %s +`, pemFile)), + 0600) + + if err != nil { + t.Fatalf("unexpected error writing shared configuration file: %s", err) + } + + testCase.Config.SharedConfigFiles = []string{file.Name()} + } + + testCase.Config.SkipCredsValidation = true + + awsConfig, err := awsbase.GetAwsConfig(context.Background(), testCase.Config) + if err != nil { + t.Fatalf("GetAwsConfig() returned error: %s", err) + } + actualSession, err := GetSession(&awsConfig, testCase.Config) + if err != nil { + t.Fatalf("error in GetSession() '%[1]T': %[1]s", err) + } + + roundTripper := actualSession.Config.HTTPClient.Transport + tr, ok := roundTripper.(*http.Transport) + if !ok { + t.Fatalf("Unexpected type for HTTP client transport: %T", roundTripper) + } + + if a, e := tr.TLSClientConfig.RootCAs != nil, testCase.ExpectTLSClientConfigRootCAsSet; a != e { + t.Errorf("expected(%t) CA Bundle, got: %t", e, a) + } + }) + } +} + func TestSessionRetryHandlers(t *testing.T) { const maxRetries = 25