Skip to content
Merged
40 changes: 21 additions & 19 deletions awsauth.go
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ func GetAccountIDAndPartitionFromEC2Metadata() (string, string, error) {
setOptionalEndpoint(cfg)
sess, err := session.NewSession(cfg)
if err != nil {
return "", "", fmt.Errorf("error creating EC2 Metadata session: %s", err)
return "", "", fmt.Errorf("error creating EC2 Metadata session: %w", err)
}

metadataClient := ec2metadata.New(sess)
Expand All @@ -88,7 +88,7 @@ func GetAccountIDAndPartitionFromEC2Metadata() (string, string, error) {
// We can end up here if there's an issue with the instance metadata service
// or if we're getting credentials from AdRoll's Hologram (in which case IAMInfo will
// error out).
err = fmt.Errorf("failed getting account information via EC2 Metadata IAM information: %s", err)
err = fmt.Errorf("failed getting account information via EC2 Metadata IAM information: %w", err)
log.Printf("[DEBUG] %s", err)
return "", "", err
}
Expand All @@ -111,7 +111,7 @@ func GetAccountIDAndPartitionFromIAMGetUser(iamconn *iam.IAM) (string, string, e
return "", "", nil
}
}
err = fmt.Errorf("failed getting account information via iam:GetUser: %s", err)
err = fmt.Errorf("failed getting account information via iam:GetUser: %w", err)
log.Printf("[DEBUG] %s", err)
return "", "", err
}
Expand All @@ -134,7 +134,7 @@ func GetAccountIDAndPartitionFromIAMListRoles(iamconn *iam.IAM) (string, string,
MaxItems: aws.Int64(int64(1)),
})
if err != nil {
err = fmt.Errorf("failed getting account information via iam:ListRoles: %s", err)
err = fmt.Errorf("failed getting account information via iam:ListRoles: %w", err)
log.Printf("[DEBUG] %s", err)
return "", "", err
}
Expand All @@ -155,7 +155,7 @@ func GetAccountIDAndPartitionFromSTSGetCallerIdentity(stsconn *sts.STS) (string,

output, err := stsconn.GetCallerIdentity(&sts.GetCallerIdentityInput{})
if err != nil {
return "", "", fmt.Errorf("error calling sts:GetCallerIdentity: %s", err)
return "", "", fmt.Errorf("error calling sts:GetCallerIdentity: %w", err)
}

if output == nil || output.Arn == nil {
Expand Down Expand Up @@ -184,16 +184,17 @@ func GetCredentialsFromSession(c *Config) (*awsCredentials.Credentials, error) {
var sess *session.Session
var err error
if c.Profile == "" {
sess, err = session.NewSession()
sess, err = session.NewSession(&aws.Config{EndpointResolver: c.EndpointResolver()})
if err != nil {
return nil, ErrNoValidCredentialSources
}
} else {
options := &session.Options{
Config: aws.Config{
HTTPClient: cleanhttp.DefaultClient(),
MaxRetries: aws.Int(0),
Region: aws.String(c.Region),
EndpointResolver: c.EndpointResolver(),
HTTPClient: cleanhttp.DefaultClient(),
MaxRetries: aws.Int(0),
Region: aws.String(c.Region),
},
}
options.Profile = c.Profile
Expand All @@ -204,7 +205,7 @@ func GetCredentialsFromSession(c *Config) (*awsCredentials.Credentials, error) {
if IsAWSErr(err, "NoCredentialProviders", "") {
return nil, ErrNoValidCredentialSources
}
return nil, fmt.Errorf("Error creating AWS session: %s", err)
return nil, fmt.Errorf("Error creating AWS session: %w", err)
}
}

Expand Down Expand Up @@ -276,7 +277,7 @@ func GetCredentials(c *Config) (*awsCredentials.Credentials, error) {
ec2Session, err := session.NewSession(cfg)

if err != nil {
return nil, fmt.Errorf("error creating EC2 Metadata session: %s", err)
return nil, fmt.Errorf("error creating EC2 Metadata session: %w", err)
}

metadataClient := ec2metadata.New(ec2Session)
Expand Down Expand Up @@ -305,7 +306,7 @@ func GetCredentials(c *Config) (*awsCredentials.Credentials, error) {
return nil, err
}
} else {
return nil, fmt.Errorf("Error loading credentials for AWS Provider: %s", err)
return nil, fmt.Errorf("Error loading credentials for AWS Provider: %w", err)
}
} else {
log.Printf("[INFO] AWS Auth provider used: %q", cp.ProviderName)
Expand All @@ -322,16 +323,17 @@ func GetCredentials(c *Config) (*awsCredentials.Credentials, error) {
c.AssumeRoleARN, c.AssumeRoleSessionName, c.AssumeRoleExternalID, c.AssumeRolePolicy)

awsConfig := &aws.Config{
Credentials: creds,
Region: aws.String(c.Region),
MaxRetries: aws.Int(c.MaxRetries),
HTTPClient: cleanhttp.DefaultClient(),
Credentials: creds,
EndpointResolver: c.EndpointResolver(),
Region: aws.String(c.Region),
MaxRetries: aws.Int(c.MaxRetries),
HTTPClient: cleanhttp.DefaultClient(),
}

assumeRoleSession, err := session.NewSession(awsConfig)

if err != nil {
return nil, fmt.Errorf("error creating assume role session: %s", err)
return nil, fmt.Errorf("error creating assume role session: %w", err)
}

stsclient := sts.New(assumeRoleSession)
Expand All @@ -354,7 +356,7 @@ func GetCredentials(c *Config) (*awsCredentials.Credentials, error) {
assumeRoleCreds := awsCredentials.NewChainCredentials(providers)
_, err = assumeRoleCreds.Get()
if err != nil {
if awsErr, ok := err.(awserr.Error); ok && awsErr.Code() == "NoCredentialProviders" {
if IsAWSErr(err, "NoCredentialProviders", "") {
return nil, fmt.Errorf("The role %q cannot be assumed.\n\n"+
" There are a number of possible causes of this - the most common are:\n"+
" * The credentials used in order to assume the role are invalid\n"+
Expand All @@ -363,7 +365,7 @@ func GetCredentials(c *Config) (*awsCredentials.Credentials, error) {
c.AssumeRoleARN)
}

return nil, fmt.Errorf("Error loading credentials for AWS Provider: %s", err)
return nil, fmt.Errorf("Error loading credentials for AWS Provider: %w", err)
}

return assumeRoleCreds, nil
Expand Down
12 changes: 6 additions & 6 deletions awsauth_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -518,14 +518,14 @@ func TestAWSGetCredentials_shouldIAM(t *testing.T) {
if err != nil {
t.Fatalf("Error gettings creds: %s", err)
}
if v.AccessKeyID != "somekey" {
t.Fatalf("AccessKeyID mismatch, expected: (somekey), got (%s)", v.AccessKeyID)
if expected, actual := "Ec2MetadataAccessKey", v.AccessKeyID; expected != actual {
t.Fatalf("expected access key (%s), got: %s", expected, actual)
}
if v.SecretAccessKey != "somesecret" {
t.Fatalf("SecretAccessKey mismatch, expected: (somesecret), got (%s)", v.SecretAccessKey)
if expected, actual := "Ec2MetadataSecretKey", v.SecretAccessKey; expected != actual {
t.Fatalf("expected secret key (%s), got: %s", expected, actual)
}
if v.SessionToken != "sometoken" {
t.Fatalf("SessionToken mismatch, expected: (sometoken), got (%s)", v.SessionToken)
if expected, actual := "Ec2MetadataSessionToken", v.SessionToken; expected != actual {
t.Fatalf("expected session token (%s), got: %s", expected, actual)
}
}

Expand Down
13 changes: 5 additions & 8 deletions awserr.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package awsbase

import (
"errors"
"strings"

"github.com/aws/aws-sdk-go/aws/awserr"
Expand All @@ -11,17 +12,13 @@ import (
// * Error.Code() matches code
// * Error.Message() contains message
func IsAWSErr(err error, code string, message string) bool {
awsErr, ok := err.(awserr.Error)
var awsErr awserr.Error

if !ok {
return false
if errors.As(err, &awsErr) {
return awsErr.Code() == code && strings.Contains(awsErr.Message(), message)
}

if awsErr.Code() != code {
return false
}

return strings.Contains(awsErr.Message(), message)
return false
}

// IsAWSErrExtended returns true if the error matches all these conditions:
Expand Down
66 changes: 66 additions & 0 deletions awserr_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package awsbase

import (
"errors"
"fmt"
"testing"

"github.com/aws/aws-sdk-go/aws/awserr"
Expand Down Expand Up @@ -100,6 +101,71 @@ func TestIsAwsErr(t *testing.T) {
Err: awserr.New("TestCode", "TestMessage", nil),
Message: "TestMessage",
},
{
Name: "wrapped other error",
Err: fmt.Errorf("test: %w", errors.New("test")),
},
{
Name: "wrapped other error code",
Err: fmt.Errorf("test: %w", errors.New("test")),
Code: "test",
},
{
Name: "wrapped other error message",
Err: fmt.Errorf("test: %w", errors.New("test")),
Message: "test",
},
{
Name: "wrapped other error code and message",
Err: fmt.Errorf("test: %w", errors.New("test")),
Code: "test",
Message: "test",
},
{
Name: "wrapped awserr error matching code and no message",
Err: fmt.Errorf("test: %w", awserr.New("TestCode", "TestMessage", nil)),
Code: "TestCode",
Expected: true,
},
{
Name: "wrapped awserr error matching code and matching message exact",
Err: fmt.Errorf("test: %w", awserr.New("TestCode", "TestMessage", nil)),
Code: "TestCode",
Message: "TestMessage",
Expected: true,
},
{
Name: "wrapped awserr error matching code and matching message contains",
Err: fmt.Errorf("test: %w", awserr.New("TestCode", "TestMessage", nil)),
Code: "TestCode",
Message: "Message",
Expected: true,
},
{
Name: "wrapped awserr error matching code and non-matching message",
Err: fmt.Errorf("test: %w", awserr.New("TestCode", "TestMessage", nil)),
Code: "TestCode",
Message: "NotMatching",
},
{
Name: "wrapped awserr error no code",
Err: fmt.Errorf("test: %w", awserr.New("TestCode", "TestMessage", nil)),
},
{
Name: "wrapped awserr error no code and matching message exact",
Err: fmt.Errorf("test: %w", awserr.New("TestCode", "TestMessage", nil)),
Message: "TestMessage",
},
{
Name: "wrapped awserr error non-matching code",
Err: fmt.Errorf("test: %w", awserr.New("TestCode", "TestMessage", nil)),
Code: "NotMatching",
},
{
Name: "wrapped awserr error non-matching code and message exact",
Err: fmt.Errorf("test: %w", awserr.New("TestCode", "TestMessage", nil)),
Message: "TestMessage",
},
}

for _, testCase := range testCases {
Expand Down
46 changes: 46 additions & 0 deletions endpoints.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
package awsbase

import (
"log"
"os"

"github.com/aws/aws-sdk-go/aws/ec2metadata"
"github.com/aws/aws-sdk-go/aws/endpoints"
"github.com/aws/aws-sdk-go/service/iam"
"github.com/aws/aws-sdk-go/service/sts"
)

func (c *Config) EndpointResolver() endpoints.Resolver {
resolver := func(service, region string, optFns ...func(*endpoints.Options)) (endpoints.ResolvedEndpoint, error) {
// Ensure we pass all existing information (e.g. SigningRegion) and
// only override the URL, otherwise a MissingRegion error can occur
// when aws.Config.Region is not defined.
resolvedEndpoint, err := endpoints.DefaultResolver().EndpointFor(service, region, optFns...)

if err != nil {
return resolvedEndpoint, err
}

switch service {
case ec2metadata.ServiceName:
if endpoint := os.Getenv("AWS_METADATA_URL"); endpoint != "" {
log.Printf("[INFO] Setting custom EC2 metadata endpoint: %s", endpoint)
resolvedEndpoint.URL = endpoint
}
case iam.ServiceName:
if endpoint := c.IamEndpoint; endpoint != "" {
log.Printf("[INFO] Setting custom IAM endpoint: %s", endpoint)
resolvedEndpoint.URL = endpoint
}
case sts.ServiceName:
if endpoint := c.StsEndpoint; endpoint != "" {
log.Printf("[INFO] Setting custom STS endpoint: %s", endpoint)
resolvedEndpoint.URL = endpoint
}
}

return resolvedEndpoint, nil
}

return endpoints.ResolverFunc(resolver)
}
Loading